---
language:
- en
library_name: transformers
license: apache-2.0
pipeline_tag: image-text-to-text
tags:
- Sentence Similarity
- Embedding
- zero-shot-image-classification
- video-text-to-text
---

# UME-R1-7B

## Model Summary

UME-R1-7B is trained in two stages, a cold-start SFT stage followed by an RL stage, and can embed text, images, multiple images, and videos. In particular, UME-R1 can produce either discriminative or generative embeddings as needed, and its generative embeddings have the potential for test-time scaling.

- **Repository:** [UME-R1](https://github.com/XMUDeepLIT/UME-R1)
- **Paper:** [UME-R1](https://arxiv.org/abs/2511.00405)

## Train/Eval Data

- Train data: https://huggingface.co/datasets/zhibinlan/UME-sft-train
- Eval data: https://huggingface.co/datasets/TIGER-Lab/MMEB-V2

## Model Performance

UME-R1 significantly outperforms conventional discriminative embeddings and can provide discriminative or generative representations as needed. Its oracle performance, obtained by selecting the better of the discriminative and generative embeddings for each query, far exceeds using either mode alone.

*(Figure: results on MMEB-V2)*

In addition, UME-R1 can produce improved embedding representations through repeated sampling, indicating that generative embeddings also hold strong promise for inference-time scaling.

*(Figure: pass@k under repeated sampling)*

### Quick Start

First, clone our GitHub repository:

```bash
git clone https://github.com/DeepLearnXMU/UME-R1
cd UME-R1
bash setup.sh
```

Below, we provide simple examples to show how to use UME-R1 with 🤗 Transformers.

Example of obtaining generative embeddings:

```python
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "zhibinlan/UME-R1-7B",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="cuda:0",
)
processor = AutoProcessor.from_pretrained("zhibinlan/UME-R1-7B")

prompt = '''Represent the above input text, images, videos, or any combination of the three as embeddings. First output the thinking process in tags and then summarize the entire input in a word or sentence.
Finally, use the tag to represent the entire input.'''

messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "assets/example.jpg",
            },
            {"type": "text", "text": "Represent the given image with the following question: What is in the image?\n\n" + prompt},
        ],
    }
]

# Preparation for inference
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to(model.device)

# Inference: generation of the output, keeping hidden states
generated_output = model.generate(
    **inputs,
    max_new_tokens=8192,
    output_hidden_states=True,
    return_dict_in_generate=True,
    use_cache=True,
)

# Post-process the output
generated_ids = generated_output.sequences
hidden_states = generated_output.hidden_states
generated_ids_trimmed = [
    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]


def get_embedding_idx(generated_ids_trimmed, EMBEDDING_TOKEN_ID):
    # Search backward from the last token for the special embedding token
    embedding_idx = []
    for i, out_ids in enumerate(generated_ids_trimmed):
        embed_exist = False
        for j in range(len(out_ids) - 1, -1, -1):
            if out_ids[j] == EMBEDDING_TOKEN_ID:
                embedding_idx.append(j + 1)
                embed_exist = True
                break
        if not embed_exist:
            embedding_idx.append(-1)
    return embedding_idx


def normalize_reps(reps):
    # L2-normalize the representations
    reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
    return reps


# Get the last hidden state of the embedding token
embedding_idx = get_embedding_idx(generated_ids_trimmed, processor.tokenizer.get_vocab()[""])
embedding_reps = hidden_states[embedding_idx[0]][-1].squeeze(1)

# Normalize the representations
embedding_reps = normalize_reps(embedding_reps)

output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=False, clean_up_tokenization_spaces=False
)
```
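As noted in the Model Performance section, generative embeddings can improve with repeated sampling. Below is a minimal sketch of one way to exploit this at inference time, reusing `model`, `inputs`, `processor`, `get_embedding_idx`, and `normalize_reps` from the example above; the sampling settings (`do_sample`, `temperature`) and the mean-pooling aggregation are illustrative assumptions rather than the exact protocol evaluated in the paper.

```python
# Illustrative sketch of inference-time scaling via repeated sampling:
# draw several generative embeddings with stochastic decoding and aggregate them.
num_samples = 4  # assumed value
sampled_reps = []
for _ in range(num_samples):
    out = model.generate(
        **inputs,
        max_new_tokens=8192,
        do_sample=True,      # assumed sampling settings; tune as needed
        temperature=0.7,
        output_hidden_states=True,
        return_dict_in_generate=True,
        use_cache=True,
    )
    ids_trimmed = [o[len(i):] for i, o in zip(inputs.input_ids, out.sequences)]
    # Same embedding-token lookup as in the example above
    idx = get_embedding_idx(ids_trimmed, processor.tokenizer.get_vocab()[""])
    if idx[0] == -1:
        continue  # this sample did not emit the embedding token
    rep = out.hidden_states[idx[0]][-1].squeeze(1)
    sampled_reps.append(normalize_reps(rep))

# One simple aggregation: average the normalized embeddings and re-normalize.
pooled_rep = normalize_reps(torch.stack(sampled_reps, dim=0).mean(dim=0))
```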
Example of obtaining discriminative embeddings:

```python
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch

pretrained_path = "zhibinlan/UME-R1-7B"

# We recommend enabling flash_attention_2 for better acceleration and memory saving,
# especially in multi-image and video scenarios.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    pretrained_path,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="cuda:0",
)

# default processor
processor = AutoProcessor.from_pretrained(pretrained_path)

messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "UME-R1/assets/example.jpg",
            },
            {"type": "text", "text": "Represent the given image with the following question: What is in the image?\n\n"},
        ],
    }
]

# Preparation for inference
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to(model.device)


def get_embedding_idx(generated_ids_trimmed, EMBEDDING_TOKEN_ID):
    embedding_idx = []
    # Search backward from the last token
    for i, out_ids in enumerate(generated_ids_trimmed):
        embed_exist = False
        for j in range(len(out_ids) - 1, -1, -1):
            if out_ids[j] == EMBEDDING_TOKEN_ID:
                embedding_idx.append(j)
                embed_exist = True
                break
        if not embed_exist:
            embedding_idx.append(-1)
    return embedding_idx


def normalize_reps(reps):
    # L2-normalize the representations
    reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
    return reps


output = model(**inputs, output_hidden_states=True, return_dict=True)
hidden_states = output.hidden_states[-1][0]
# print("output.hidden_states shape: ", hidden_states.shape)

embedding_idx = get_embedding_idx(inputs['input_ids'], processor.tokenizer.get_vocab()[""])

# Get the last hidden state of the embedding token
embedding_reps = hidden_states[embedding_idx[0]]

# Normalize the representations
embedding_reps = normalize_reps(embedding_reps)
```
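Both procedures above L2-normalize the extracted embedding, so relevance between a query and a set of candidates can be scored with a plain dot product. The snippet below is a small illustrative sketch; the `score_candidates` helper, the stand-in random tensors, and the assumed hidden size of 3584 for the 7B backbone are not part of the released code.

```python
import torch

def score_candidates(query_rep: torch.Tensor, candidate_reps: torch.Tensor) -> torch.Tensor:
    # Cosine similarity reduces to a dot product when both sides are L2-normalized,
    # as they are after normalize_reps above.
    return candidate_reps @ query_rep

# Illustrative usage with stand-in tensors; in practice, use the embedding_reps
# extracted for the query and for each candidate with either procedure above.
hidden_dim = 3584  # assumed hidden size of the 7B backbone
query_rep = torch.nn.functional.normalize(torch.randn(hidden_dim), dim=-1)
candidate_reps = torch.nn.functional.normalize(torch.randn(5, hidden_dim), dim=-1)

scores = score_candidates(query_rep, candidate_reps)
print("Best candidate index:", scores.argmax().item())
```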
Multi-image inference:

```python
# Messages containing multiple images and a text query
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "file:///path/to/image1.jpg"},
            {"type": "image", "image": "file:///path/to/image2.jpg"},
            {"type": "text", "text": "Represent the given images."},
        ],
    }
]
```
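Preprocessing for multi-image inputs follows the single-image examples unchanged; a brief sketch, reusing `processor`, `model`, and `process_vision_info` from above:

```python
# Build the chat text and collect the vision inputs exactly as before;
# the resulting `inputs` can be fed to either embedding procedure above.
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
).to(model.device)
```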
Video inference:

```python
# Messages containing an image list as a video and a text query
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "video",
                "video": [
                    "file:///path/to/frame1.jpg",
                    "file:///path/to/frame2.jpg",
                    "file:///path/to/frame3.jpg",
                    "file:///path/to/frame4.jpg",
                ],
            },
            {"type": "text", "text": "Represent this video."},
        ],
    }
]

# Messages containing a local video path and a text query
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "video",
                "video": "file:///path/to/video1.mp4",
                "max_pixels": 360 * 420,
                "fps": 1.0,
            },
            {"type": "text", "text": "Represent this video."},
        ],
    }
]

# Messages containing a video url and a text query
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "video",
                "video": "https://path/to/video.mp4",
                "min_pixels": 4 * 28 * 28,
                "max_pixels": 256 * 28 * 28,
                "total_pixels": 20480 * 28 * 28,
            },
            {"type": "text", "text": "Represent this video."},
        ],
    }
]

image_inputs, video_inputs, video_kwargs = process_vision_info(messages, return_video_kwargs=True)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
    **video_kwargs,  # carries the per-video fps returned by process_vision_info
)
```
For more usage tips, please refer to our [GitHub page](https://github.com/DeepLearnXMU/UME-R1).

## Citation

If you find our work useful, please consider citing it.

```bibtex
@article{lan2025ume,
  title={UME-R1: Exploring Reasoning-Driven Generative Multimodal Embeddings},
  author={Lan, Zhibin and Niu, Liqiang and Meng, Fandong and Zhou, Jie and Su, Jinsong},
  journal={arXiv preprint arXiv:2511.00405},
  year={2025}
}
```