| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import numpy as np |
| | from typing import Any, Callable, Optional, Union |
| |
|
| | from transformers import Qwen2_5_VLForConditionalGeneration, AutoModelForImageTextToText |
| | from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( |
| | Qwen2_5_VisionTransformerPretrainedModel, |
| | Qwen2_5_VLModel, |
| | Qwen2RMSNorm, |
| | Qwen2_5_VLMLP, |
| | ALL_ATTENTION_FUNCTIONS |
| | ) |
| | from transformers.image_utils import ImageInput |
| | from transformers.tokenization_utils import TextInput, PreTokenizedInput |
| | from transformers.video_utils import VideoInput |
| | from transformers.feature_extraction_utils import BatchFeature |
| |
|
| | from transformers import Qwen2_5_VLProcessor, Qwen2_5_VLConfig |
| | from transformers.models.qwen2_5_vl.processing_qwen2_5_vl import Qwen2_5_VLProcessorKwargs |
| |
|
class ADCopilotConfig(Qwen2_5_VLConfig):
    """Qwen2.5-VL configuration extended with the AD-Copilot compare-branch settings.

    Extra settings:
        vision_config.compare_token_size: number of learned compare tokens appended
            after each image's visual tokens. Defaults to 100; a value already present
            on the vision config is kept instead of being overwritten.
        sequence_compare: keyword argument (default ``True``). Stored on this config
            and mirrored onto ``vision_config``, which is the config the vision tower
            (and its compare encoder) is constructed from.
    """

    model_type = "ad_copilot"

    def __init__(self, **kwargs):
        # Pop before super() so the base config does not also store the raw kwarg.
        sequence_compare = kwargs.pop("sequence_compare", True)
        super().__init__(**kwargs)
        # Previously this was unconditionally reset to 100, discarding any value
        # carried by the vision config.
        if getattr(self.vision_config, "compare_token_size", None) is None:
            self.vision_config.compare_token_size = 100
        self.architectures = ["ADCopilotVLForConditionalGeneration"]
        # Previously hard-coded to True, so a passed/saved False was ignored.
        self.sequence_compare = sequence_compare
        # The compare encoder reads `sequence_compare` from vision_config, not from this
        # top-level config, so mirror it there for the flag to have any effect.
        self.vision_config.sequence_compare = sequence_compare
| | |
class ADCopilotProcessor(Qwen2_5_VLProcessor):
    """Qwen2.5-VL processor that reserves extra image-token slots for compare tokens.

    Every image placeholder in the text is expanded to the regular number of image
    tokens plus ``compare_token_size`` additional image tokens. This matches the
    per-image embeddings built by ``ADCopilotVLModel.get_image_features``, which are
    the visual tokens followed by the compare tokens, so ``compare_token_size`` should
    equal the model's ``vision_config.compare_token_size``. Video placeholders are
    expanded without compare tokens.
    """

    config_class = ADCopilotConfig

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        video_processor=None,
        chat_template=None,
        compare_token_size=100,
        **kwargs,
    ):
        # `compare_token_size` is an explicit parameter (it used to be read out of
        # **kwargs while still being forwarded), so the base processor never receives
        # it as an unknown keyword argument.
        super().__init__(image_processor, tokenizer, video_processor, chat_template, **kwargs)
        self.compare_token_size = compare_token_size

    @staticmethod
    def _expand_placeholders(text, token, grid_thw, merge_length, extra_tokens=0):
        """Expand every ``token`` in ``text`` to ``prod(grid_thw[k]) // merge_length + extra_tokens`` copies.

        Rows of ``grid_thw`` are consumed in order across all strings in ``text``. A
        temporary ``<|placeholder|>`` keeps already-expanded tokens from being matched
        again. The entries of the list ``text`` are replaced in place and the list is
        returned.
        """
        index = 0
        for i, sample in enumerate(text):
            while token in sample:
                num_tokens = grid_thw[index].prod() // merge_length + extra_tokens
                sample = sample.replace(token, "<|placeholder|>" * num_tokens, 1)
                index += 1
            text[i] = sample.replace("<|placeholder|>", token)
        return text

    def __call__(
        self,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
        videos: VideoInput = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Main method to prepare one or several sequence(s) and image(s)/video(s) for the model.

        `text` and the text-related `kwargs` are forwarded to the tokenizer. `images` and
        `videos`, together with their respective `kwargs`, are forwarded to the image
        processor and the video processor. Each image placeholder in `text` is expanded
        to the image's token count plus `self.compare_token_size` compare tokens.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            text (`str`, `list[str]`, `list[list[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            videos (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`):
                The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
                tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
            - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
            - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
            - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
            - **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`.
            - **mm_token_type_ids** -- 1 at image-token positions, 0 elsewhere. Returned when
              `return_mm_token_type_ids=True`.
        """
        output_kwargs = self._merge_kwargs(
            Qwen2_5_VLProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        # Two independent dicts (the original `a = b = {}` aliased a single dict).
        image_inputs, videos_inputs = {}, {}
        image_grid_thw = video_grid_thw = None

        if images is not None:
            image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
            image_grid_thw = image_inputs["image_grid_thw"]

        if videos is not None:
            fps = output_kwargs["videos_kwargs"].get("fps", 2.0)
            videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
            video_grid_thw = videos_inputs["video_grid_thw"]

            temporal_patch_size = self.video_processor.temporal_patch_size
            if isinstance(fps, (int, float)):
                second_per_grid_ts = [temporal_patch_size / fps] * len(video_grid_thw)
            elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
                second_per_grid_ts = [temporal_patch_size / tmp for tmp in fps]
            else:
                raise ValueError(
                    f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
                )
            videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})

        if not isinstance(text, list):
            text = [text]
        # Copy so the caller's list is not modified by the in-place expansion below.
        text = text.copy()

        if images is not None:
            text = self._expand_placeholders(
                text,
                self.image_token,
                image_grid_thw,
                self.image_processor.merge_size**2,
                extra_tokens=self.compare_token_size,
            )
        if videos is not None:
            text = self._expand_placeholders(
                text,
                self.video_token,
                video_grid_thw,
                self.video_processor.merge_size**2,
            )

        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
        return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
        self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])

        if return_mm_token_type_ids:
            array_ids = np.array(text_inputs["input_ids"])
            mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
            mm_token_type_ids[array_ids == self.image_token_id] = 1
            text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()

        return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
| |
|
| |
|
| | class OptimizedCrossAttention(nn.Module): |
| | """ |
| | 仿照 Qwen2_5_VLVisionAttention 结构的优化 Cross Attention |
| | """ |
| | def __init__(self, config, is_cross_attention=True): |
| | super().__init__() |
| | self.config = config |
| | self.dim = config.hidden_size |
| | self.num_heads = config.num_heads |
| | self.head_dim = self.dim // self.num_heads |
| | self.scaling = self.head_dim**-0.5 |
| | self.attention_dropout = 0.0 |
| | self.is_causal = False |
| | self.is_cross_attention = is_cross_attention |
| | |
| | if is_cross_attention: |
| | |
| | self.q_proj = nn.Linear(self.dim, self.dim, bias=True) |
| | self.kv = nn.Linear(self.dim, self.dim * 2, bias=True) |
| | else: |
| | |
| | self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True) |
| | |
| | self.proj = nn.Linear(self.dim, self.dim, bias=True) |
| | |
| | def forward( |
| | self, |
| | query_states: torch.Tensor, |
| | key_value_states: Optional[torch.Tensor] = None, |
| | attention_mask: Optional[torch.Tensor] = None, |
| | cu_seqlens: Optional[torch.Tensor] = None, |
| | kv_cu_seqlens: Optional[torch.Tensor] = None, |
| | **kwargs, |
| | ) -> torch.Tensor: |
| | |
| | orig_2d = False |
| | if query_states.dim() == 2: |
| | query_states = query_states.unsqueeze(0) |
| | orig_2d = True |
| |
|
| | batch_size, seq_len_q, _ = query_states.shape |
| |
|
| | |
| | if self.is_cross_attention and key_value_states is not None: |
| | if key_value_states.dim() == 2: |
| | key_value_states = key_value_states.unsqueeze(0) |
| | q = self.q_proj(query_states) |
| | kv = self.kv(key_value_states) |
| | seq_len_kv = kv.shape[1] |
| | k, v = kv.reshape(batch_size, seq_len_kv, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4).unbind(0) |
| | q = q.reshape(batch_size, seq_len_q, self.num_heads, self.head_dim).transpose(1, 2) |
| | else: |
| | if key_value_states is None: |
| | key_value_states = query_states |
| | qkv = self.qkv(query_states) |
| | q, k, v = qkv.reshape(batch_size, seq_len_q, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4).unbind(0) |
| |
|
| | |
| | attn_impl = getattr(self.config, '_attn_implementation', 'sdpa') |
| | attn_impl = 'sdpa' |
| | attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[attn_impl] |
| |
|
| | |
| | if attn_impl == "flash_attention_2": |
| | |
| | |
| |
|
| | |
| | if cu_seqlens is None: |
| | |
| | cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len_q, step=seq_len_q, dtype=torch.int32, device=q.device) |
| | if kv_cu_seqlens is None: |
| | cu_seqlens_k = torch.arange(0, (batch_size + 1) * k.shape[2], step=k.shape[2], dtype=torch.int32, device=k.device) |
| | else: |
| | cu_seqlens_k = kv_cu_seqlens |
| |
|
| | |
| | |
| | |
| | |
| | q_ = q.transpose(1, 2).contiguous().view(-1, self.num_heads, self.head_dim) |
| | k_ = k.transpose(1, 2).contiguous().view(-1, self.num_heads, self.head_dim) |
| | v_ = v.transpose(1, 2).contiguous().view(-1, self.num_heads, self.head_dim) |
| | |
| | max_seqlen_q = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() |
| | max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item() |
| |
|
| | attn_output, _ = attention_interface( |
| | self, |
| | q_, |
| | k_, |
| | v_, |
| | attention_mask=None, |
| | scaling=self.scaling, |
| | dropout=0.0 if not self.training else self.attention_dropout, |
| | cu_seq_lens_q=cu_seqlens, |
| | cu_seq_lens_k=cu_seqlens_k, |
| | max_length_q=max_seqlen_q, |
| | max_length_k=max_seqlen_k, |
| | is_causal=self.is_causal, |
| | **kwargs, |
| | ) |
| | |
| | |
| | |
| | attn_output = attn_output.view(batch_size, seq_len_q, self.num_heads, self.head_dim).contiguous() |
| | else: |
| | |
| | attn_output, _ = attention_interface( |
| | self, |
| | q, k, v, |
| | attention_mask=attention_mask, |
| | scaling=self.scaling, |
| | dropout=0.0 if not self.training else self.attention_dropout, |
| | is_causal=self.is_causal, |
| | **kwargs, |
| | ) |
| | |
| | attn_output = attn_output.transpose(1, 2).contiguous() |
| |
|
| | attn_output = attn_output.reshape(batch_size, seq_len_q, self.dim) |
| | attn_output = self.proj(attn_output) |
| | if orig_2d: |
| | attn_output = attn_output.squeeze(0) |
| | return attn_output.contiguous() |
| |
|
| |
|
class ADCopilotCompareVisualEncoder(nn.Module):
    """Summarise how each image differs from the previous one into ``token_size`` tokens.

    For a sequence of per-image vision features, the module:
      1. zero-pads them into one batch and pairs every image with its predecessor
         (``torch.roll`` by one). With ``sequence_compare`` and more than one image,
         the first image is paired with itself instead of wrapping to the last image.
      2. runs a two-way cross-attention encoder: previous <- current, then
         current <- updated previous.
      3. lets ``token_size`` learned queries cross-attend to the encoded features.
      4. projects the result from ``hidden_size`` to ``out_hidden_size``.

    NOTE(review): pairing runs over the flat sequence passed to ``forward``. If that
    sequence spans several samples, the first image of one sample is compared with the
    last image of the previous sample. Confirm this is intended.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.sequence_compare = getattr(config, "sequence_compare", True)
        self.hidden_size = config.hidden_size
        # getattr instead of `"compare_token_size" in config`, which depends on the
        # config object supporting membership tests.
        self.token_size = getattr(config, "compare_token_size", 100)

        # Encoder: two cross-attention + MLP sub-blocks with pre-norm residuals.
        self.encoder_cross_attn1 = OptimizedCrossAttention(config, is_cross_attention=True)
        self.encoder_cross_attn2 = OptimizedCrossAttention(config, is_cross_attention=True)

        self.encoder_norm1 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
        self.encoder_norm2 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
        self.encoder_norm3 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
        self.encoder_norm4 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
        self.encoder_mlp1 = Qwen2_5_VLMLP(config)
        self.encoder_mlp2 = Qwen2_5_VLMLP(config)

        # Decoder: learned queries that read out the encoded comparison.
        self.query_embeddings = nn.Parameter(torch.empty(self.token_size, self.hidden_size))
        # torch.empty leaves arbitrary memory (possibly NaN/inf), so initialise the
        # queries immediately. A loaded state dict replaces this value.
        self.init_query_embeddings()

        self.decoder_cross_attn = OptimizedCrossAttention(config, is_cross_attention=True)

        self.decoder_norm1 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
        self.decoder_norm2 = Qwen2RMSNorm(self.hidden_size, eps=1e-6)
        self.decoder_mlp = Qwen2_5_VLMLP(config)

        self.compare_projector = nn.Linear(config.hidden_size, config.out_hidden_size)

    def init_query_embeddings(self):
        """(Re-)initialise the learned decoder queries from N(0, 0.02^2)."""
        nn.init.normal_(self.query_embeddings, mean=0.0, std=0.02)

    def forward(self, images_hidden_states: list) -> torch.Tensor:
        """
        Args:
            images_hidden_states: sequence (list or tuple) of tensors, one per image, each
                of shape ``[seq_len_i, hidden_size]``. Lengths may differ.

        Returns:
            Tensor of shape ``[num_images, token_size, out_hidden_size]``.
        """
        if not images_hidden_states:
            # Match the non-empty path: projected width, parameter device/dtype.
            return self.query_embeddings.new_empty(0, self.token_size, self.compare_projector.out_features)

        # Debug guard: prints "Warning: query_embeddings contains NaN values".
        # Note that `.any()` forces a device sync on every call.
        if torch.isnan(self.query_embeddings).any():
            print("警告:query_embeddings 包含 NaN 值")

        device = images_hidden_states[0].device
        num_images = len(images_hidden_states)

        # Zero-pad to the longest image. In the mask, True marks real (non-padding) tokens.
        batched_states = nn.utils.rnn.pad_sequence(list(images_hidden_states), batch_first=True)
        lengths = torch.tensor([state.size(0) for state in images_hidden_states], device=device)
        attention_masks = torch.arange(batched_states.size(1), device=device).unsqueeze(0) < lengths.unsqueeze(1)

        # Pair image k with image k-1. Image 0 wraps to the last image unless
        # sequence_compare is set, in which case it is paired with itself.
        previous_states = torch.roll(batched_states, shifts=1, dims=0)
        previous_masks = torch.roll(attention_masks, shifts=1, dims=0)
        if num_images > 1 and self.sequence_compare:
            previous_states[0] = previous_states[1]
            previous_masks[0] = previous_masks[1]

        encoded_features = self._encoder_forward(batched_states, previous_states, attention_masks, previous_masks)

        queries = self.query_embeddings.unsqueeze(0).expand(num_images, -1, -1)
        compare_visual_embeds = self._decoder_forward(queries, encoded_features, encoded_mask=attention_masks)

        # nn.Linear acts on the last dimension, so no flatten/view round-trip is needed.
        return self.compare_projector(compare_visual_embeds)

    def _encoder_forward(self, current_features, previous_features, current_mask=None, previous_mask=None):
        """Two-way cross-attention between each image and its predecessor.

        Args:
            current_features: ``[batch_size, seq_len, hidden_size]``
            previous_features: ``[batch_size, seq_len, hidden_size]``
            current_mask: optional boolean ``[batch_size, seq_len]``; True marks real tokens.
            previous_mask: optional boolean ``[batch_size, seq_len]``; True marks real tokens.

        Returns:
            Updated current-image features, ``[batch_size, seq_len, hidden_size]``.
        """
        current_attn_mask = current_mask[:, None, None, :] if current_mask is not None else None
        previous_attn_mask = previous_mask[:, None, None, :] if previous_mask is not None else None

        # 1) previous <- current: the previous image queries the current one.
        previous_features = previous_features + self.encoder_cross_attn1(
            query_states=self.encoder_norm1(previous_features),
            key_value_states=self.encoder_norm1(current_features),
            attention_mask=current_attn_mask,
        )
        previous_features = previous_features + self.encoder_mlp1(self.encoder_norm2(previous_features))

        # 2) current <- updated previous.
        current_features = current_features + self.encoder_cross_attn2(
            query_states=self.encoder_norm3(current_features),
            key_value_states=self.encoder_norm3(previous_features),
            attention_mask=previous_attn_mask,
        )
        # NOTE(review): the MLP output is subtracted here, unlike every other residual
        # in this module. Kept as-is because changing it alters the behaviour of any
        # trained weights. Confirm whether it is intentional (e.g. a "difference" signal).
        current_features = current_features - self.encoder_mlp2(self.encoder_norm4(current_features))
        return current_features

    def _decoder_forward(self, queries, encoded_features, query_mask=None, encoded_mask=None):
        """Let the learned queries read from the encoded comparison features.

        Args:
            queries: ``[batch_size, token_size, hidden_size]``
            encoded_features: ``[batch_size, seq_len, hidden_size]``
            query_mask: accepted for interface symmetry; unused.
            encoded_mask: optional boolean ``[batch_size, seq_len]``; True marks real tokens.

        Returns:
            Updated queries, ``[batch_size, token_size, hidden_size]``.
        """
        encoded_attn_mask = encoded_mask[:, None, None, :] if encoded_mask is not None else None
        queries = queries + self.decoder_cross_attn(
            query_states=self.decoder_norm1(queries),
            key_value_states=self.decoder_norm1(encoded_features),
            attention_mask=encoded_attn_mask,
        )
        queries = queries + self.decoder_mlp(self.decoder_norm2(queries))
        return queries
| |
|
| |
|
| | |
class ADCopilotVisionTransformerPretrainedModel(Qwen2_5_VisionTransformerPretrainedModel):
    """Qwen2.5-VL vision tower that also returns compare embeddings per image/video.

    It runs patch embedding, window/full attention blocks and the patch merger. The
    pre-merger hidden states of each item are additionally fed to
    ``ADCopilotCompareVisualEncoder``.
    """

    def __init__(self, config, *inputs, **kwargs) -> None:
        super().__init__(config, *inputs, **kwargs)
        self.compare_visual_encoder = ADCopilotCompareVisualEncoder(config)

    def forward(
        self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
                Flattened patch inputs for all images/videos in the batch.
            grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
                The temporal, height and width of feature shape of each image in LLM.

        Returns:
            `tuple[torch.Tensor, torch.Tensor]`: `(hidden_states, compare_visual_embeds)`.
            `hidden_states` are the merged visual tokens, restored to their original
            (pre-window) order. `compare_visual_embeds` has shape
            `(num_images_or_videos, compare_token_size, out_hidden_size)`.
        """
        hidden_states = self.patch_embed(hidden_states)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=hidden_states.device,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        # Reorder patches, in groups of `spatial_merge_unit`, into window order. The
        # inverse permutation is applied after the merger below.
        seq_len, _ = hidden_states.size()
        hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        hidden_states = hidden_states[window_index, :, :]
        hidden_states = hidden_states.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # Full-attention boundaries: one segment of h*w patches per temporal slice.
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for layer_num, blk in enumerate(self.blocks):
            # Layers in `fullatt_block_indexes` use full segments; the rest use windows.
            cu_seqlens_now = cu_seqlens if layer_num in self.fullatt_block_indexes else cu_window_seqlens
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens_now,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        # Compare branch on the pre-merger features, one chunk per image/video.
        # NOTE(review): tokens are still in window order here. Splitting by
        # grid_thw.prod(-1) assumes get_window_index only permutes tokens within each
        # item. Confirm this against the installed transformers version.
        split_sizes = grid_thw.prod(-1).tolist()
        compare_visual_embeds = self.compare_visual_encoder(torch.split(hidden_states, split_sizes))

        hidden_states = self.merger(hidden_states)
        reverse_indices = torch.argsort(window_index)
        hidden_states = hidden_states[reverse_indices, :]

        return hidden_states, compare_visual_embeds
| |
|
class ADCopilotVLModel(Qwen2_5_VLModel):
    """Qwen2.5-VL backbone whose image features carry extra compare tokens.

    Each image contributes its merged visual tokens followed by
    ``vision_config.compare_token_size`` compare tokens. The multimodal RoPE index
    assigns positions to those extra tokens.
    """

    def __init__(self, config):
        super().__init__(config)
        # Replace the vision tower assigned by the parent with the compare-aware one.
        self.visual = ADCopilotVisionTransformerPretrainedModel._from_config(config.vision_config)
        self.compare_token_size = config.vision_config.compare_token_size

    def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
        """
        Encodes images into continuous embeddings that can be forwarded to the language model.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
                The tensors corresponding to the input images.
            image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
                The temporal, height and width of feature shape of each image in LLM.

        Returns:
            list of tensors, one per image: the image's merged visual tokens followed by
            its `compare_token_size` compare embeddings (concatenated along dim 0).
        """
        pixel_values = pixel_values.type(self.visual.dtype)
        image_embeds, compare_visual_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)

        split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
        image_embeds = torch.split(image_embeds, split_sizes)

        return [
            torch.cat([embeds, compare.to(device=embeds.device, dtype=embeds.dtype)], dim=0)
            for embeds, compare in zip(image_embeds, compare_visual_embeds)
        ]

    def get_video_features(
        self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
    ):
        """Encode videos into per-video embeddings; compare embeddings are discarded.

        This model's vision tower returns `(embeddings, compare_embeddings)`, so its
        output must be unpacked before splitting. Videos get no compare tokens:
        `ADCopilotProcessor` expands video placeholders without them, and the RoPE index
        only adds them after images.
        NOTE(review): the compare encoder still runs on video features; skip it if that
        cost matters.

        Args:
            pixel_values_videos: video pixel patches as produced by the video processor.
            video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
                The temporal, height and width of feature shape of each video in LLM.
        """
        pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
        video_embeds, _ = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
        split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
        return torch.split(video_embeds, split_sizes)

    def get_rope_index(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Delegate to :meth:`get_rope_index_with_compare_token`."""
        return self.get_rope_index_with_compare_token(
            input_ids, image_grid_thw, video_grid_thw, second_per_grid_ts, attention_mask
        )

    def get_rope_index_with_compare_token(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute 3-axis (temporal, height, width) RoPE positions, including compare tokens.

        Text, image and video tokens follow the Qwen2.5-VL scheme. After each image's
        `t*h*w` visual tokens come `compare_token_size` compare tokens. All of them are
        placed at the image's last temporal index, in all three axes, with the same
        offset as the image.

        Text that follows a vision segment continues from one past the largest position
        assigned so far. Previously it continued from the compare tokens' position, which
        lies inside the image's range. Following text therefore reused image positions
        and disagreed with `mrope_position_deltas`, which is computed from the global
        maximum.

        Returns:
            position_ids: `(3, batch, seq_len)`; padded positions are filled with 1.
            mrope_position_deltas: `(batch, 1)`.
        """
        spatial_merge_size = self.config.vision_config.spatial_merge_size
        image_token_id = self.config.image_token_id
        video_token_id = self.config.video_token_id
        vision_start_token_id = self.config.vision_start_token_id
        mrope_position_deltas = []
        if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
            total_input_ids = input_ids
            if attention_mask is None:
                attention_mask = torch.ones_like(total_input_ids)
            position_ids = torch.ones(
                3,
                input_ids.shape[0],
                input_ids.shape[1],
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
            image_index, video_index = 0, 0
            attention_mask = attention_mask.to(total_input_ids.device)
            for i, input_ids in enumerate(total_input_ids):
                input_ids = input_ids[attention_mask[i] == 1]
                vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
                vision_tokens = input_ids[vision_start_indices + 1]
                image_nums = (vision_tokens == image_token_id).sum()
                video_nums = (vision_tokens == video_token_id).sum()
                input_tokens = input_ids.tolist()
                llm_pos_ids_list: list = []
                st = 0
                # One past the largest position assigned so far in this sample.
                next_pos = 0
                remain_images, remain_videos = image_nums, video_nums
                for _ in range(image_nums + video_nums):
                    if image_token_id in input_tokens and remain_images > 0:
                        ed_image = input_tokens.index(image_token_id, st)
                    else:
                        ed_image = len(input_tokens) + 1
                    if video_token_id in input_tokens and remain_videos > 0:
                        ed_video = input_tokens.index(video_token_id, st)
                    else:
                        ed_video = len(input_tokens) + 1
                    is_image = ed_image < ed_video
                    if is_image:
                        t, h, w = (
                            image_grid_thw[image_index][0],
                            image_grid_thw[image_index][1],
                            image_grid_thw[image_index][2],
                        )
                        second_per_grid_t = 0
                        image_index += 1
                        remain_images -= 1
                        ed = ed_image
                    else:
                        t, h, w = (
                            video_grid_thw[video_index][0],
                            video_grid_thw[video_index][1],
                            video_grid_thw[video_index][2],
                        )
                        if second_per_grid_ts is not None:
                            second_per_grid_t = second_per_grid_ts[video_index]
                        else:
                            second_per_grid_t = 1.0
                        video_index += 1
                        remain_videos -= 1
                        ed = ed_video
                    llm_grid_t, llm_grid_h, llm_grid_w = (
                        t.item(),
                        h.item() // spatial_merge_size,
                        w.item() // spatial_merge_size,
                    )
                    text_len = ed - st

                    # Text preceding this vision segment.
                    st_idx = next_pos
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                    # Temporal index scaled by seconds-per-grid and tokens_per_second
                    # (all zeros for images, since second_per_grid_t is 0).
                    range_tensor = torch.arange(llm_grid_t).view(-1, 1)
                    expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)
                    second_per_grid_t = torch.as_tensor(
                        second_per_grid_t, dtype=range_tensor.dtype, device=range_tensor.device
                    )
                    time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second
                    t_index = time_tensor.long().flatten()

                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
                    vision_pos_ids = torch.stack([t_index, h_index, w_index]) + text_len + st_idx
                    llm_pos_ids_list.append(vision_pos_ids)
                    next_pos = vision_pos_ids.max() + 1
                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w

                    # Compare tokens follow each image in the token stream. The guard
                    # avoids taking .max() of an empty tensor when compare_token_size == 0.
                    if is_image and self.compare_token_size > 0:
                        compare_index = t_index[-1].repeat(self.compare_token_size)
                        compare_pos_ids = (
                            torch.stack([compare_index, compare_index, compare_index]) + text_len + st_idx
                        )
                        llm_pos_ids_list.append(compare_pos_ids)
                        next_pos = max(next_pos, compare_pos_ids.max() + 1)
                        st = st + self.compare_token_size

                # Trailing text after the last vision segment.
                if st < len(input_tokens):
                    text_len = len(input_tokens) - st
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + next_pos)

                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
                mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
            mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
            return position_ids, mrope_position_deltas
        else:
            if attention_mask is not None:
                position_ids = attention_mask.long().cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
                max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
            else:
                position_ids = (
                    torch.arange(input_ids.shape[1], device=input_ids.device)
                    .view(1, 1, -1)
                    .expand(3, input_ids.shape[0], -1)
                )
                mrope_position_deltas = torch.zeros(
                    [input_ids.shape[0], 1],
                    device=input_ids.device,
                    dtype=input_ids.dtype,
                )

            return position_ids, mrope_position_deltas
| |
|
class ADCopilotVLForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
    """Qwen2.5-VL conditional-generation model whose backbone is ``ADCopilotVLModel``."""

    config_class = ADCopilotConfig

    def __init__(self, config):
        super().__init__(config)
        # Overwrite the backbone the parent constructor stored in `self.model` with the
        # compare-aware one.
        # NOTE(review): if the parent ties weights (e.g. lm_head <-> input embeddings)
        # inside its __init__, that tie refers to the replaced backbone. Confirm that
        # tie_weights() runs again afterwards when embeddings are meant to be tied.
        compare_backbone = ADCopilotVLModel(config)
        self.model = compare_backbone