# Requires transformers>=4.51.0
"""Example: score search queries against documents with Qwen3-Embedding-4B.

Queries are wrapped in a one-sentence task instruction, documents are embedded
as-is, every text is reduced to the hidden state of its last real token, and
the L2-normalized embeddings are compared with a dot product.
"""

import torch
import torch.nn.functional as F
from torch import Tensor


def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Return, for each sequence, the hidden state of its last attended token.

    Args:
        last_hidden_states: Model output indexed as ``[batch, position, ...]``.
        attention_mask: Mask indexed as ``[batch, position]`` with 1 on real
            tokens and 0 on padding.

    Returns:
        One hidden-state vector per batch row.
    """
    # If every row has a real token in the final position, the batch is
    # left-padded (or not padded at all), so the last position is the answer.
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]

    # Right padding: the last real token of each row sits at (row length - 1).
    sequence_lengths = attention_mask.sum(dim=1) - 1
    batch_size = last_hidden_states.shape[0]
    batch_indices = torch.arange(batch_size, device=last_hidden_states.device)
    return last_hidden_states[batch_indices, sequence_lengths]


def get_detailed_instruct(task_description: str, query: str) -> str:
    """Prefix ``query`` with its task instruction in the prompt format used here."""
    return f'Instruct: {task_description}\nQuery:{query}'


def main() -> None:
    """Embed the sample queries and documents and print their similarity scores."""
    # Imported here so the helpers above stay importable without transformers.
    from transformers import AutoModel, AutoTokenizer

    # Each query must come with a one-sentence instruction that describes the task
    task = 'Given a web search query, retrieve relevant passages that answer the query'

    queries = [
        get_detailed_instruct(task, 'What is the capital of China?'),
        get_detailed_instruct(task, 'Explain gravity'),
    ]
    # No need to add instruction for retrieval documents
    documents = [
        "The capital of China is Beijing.",
        "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
    ]
    input_texts = queries + documents

    tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-Embedding-4B', padding_side='left')
    model = AutoModel.from_pretrained('Qwen/Qwen3-Embedding-4B')

    # We recommend enabling flash_attention_2 for better acceleration and memory saving.
    # model = AutoModel.from_pretrained(
    #     'Qwen/Qwen3-Embedding-4B',
    #     attn_implementation="flash_attention_2",
    #     torch_dtype=torch.float16,
    # ).cuda()

    max_length = 8192

    # Tokenize the input texts
    batch_dict = tokenizer(
        input_texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    batch_dict.to(model.device)

    # Inference only: skip autograd bookkeeping so activations are not retained.
    with torch.no_grad():
        outputs = model(**batch_dict)
    embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

    # normalize embeddings
    embeddings = F.normalize(embeddings, p=2, dim=1)
    # The first two rows are the queries, the remaining rows the documents.
    scores = (embeddings[:2] @ embeddings[2:].T)
    print(scores.tolist())
    # [[0.7534257769584656, 0.1146894246339798], [0.03198453038930893, 0.6258305311203003]]


if __name__ == "__main__":
    main()