#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
================================================
@author: Jaron
@time: 2024/08/21 17:41:52
@email: fjjth98@163.com
@description: Video-CCAM
================================================
"""
from typing import Optional, Union

import torch
from PIL import Image
from torch import nn
from torch.nn import functional as F
from transformers import (AutoImageProcessor, AutoModel, AutoModelForCausalLM,
                          AutoTokenizer, Cache, DynamicCache, GenerationConfig,
                          PreTrainedModel)
from transformers.activations import ACT2FN

from .configuration_videoccam import CCAMConfig, VideoCCAMConfig


class CCAMMLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.hidden_act = config.hidden_act
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.output_size = config.output_size
        if self.hidden_act == 'swiglu':
            self.fc1 = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.mlp_bias)
            self.act_fn = ACT2FN['silu']
        else:
            self.fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
            self.act_fn = ACT2FN[self.hidden_act]
        self.fc2 = nn.Linear(self.intermediate_size, self.output_size, bias=config.mlp_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        if self.hidden_act == 'swiglu':
            gate, up = hidden_states.chunk(2, dim=-1)
            hidden_states = self.act_fn(gate) * up
        else:
            hidden_states = self.act_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class CCAMCrossAttention(nn.Module):
    """Cross-attention layer of the CCAM projector.

    Flash Attention 2 is not supported since the mask may be neither full nor causal.
    Only `eager` and `sdpa` are supported as `attn_implementation`.
    """

    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_heads
        self.hidden_size = config.hidden_size
        self.attention_bias = config.attention_bias
        self.attention_dropout = config.attention_dropout
        self.cross_hidden_size = config.cross_hidden_size
        self.num_key_value_heads = config.num_key_value_heads
        self.attn_implementation = config._attn_implementation
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        assert self.head_dim * self.num_heads == self.hidden_size, f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`: {self.num_heads}).'
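
        # Grouped-query attention: keys/values are projected from `cross_hidden_size` to
        # `num_key_value_heads` heads and repeated in forward() to match the `num_heads`
        # query heads.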
        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.attention_bias)
        self.k_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=self.attention_bias)
        self.v_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=self.attention_bias)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.attention_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,            # (B, Q, C)
        cross_hidden_states: torch.Tensor,      # (B, L, C')
        attention_mask: torch.Tensor = None     # (Q, L), '-inf' means masked, 0 means not masked
    ) -> torch.Tensor:                          # (B, Q, C)
        B, Q, C = hidden_states.size()
        query_states = self.q_proj(hidden_states)      # (B, Q, C)
        key_states = self.k_proj(cross_hidden_states)
        value_states = self.v_proj(cross_hidden_states)
        L = key_states.size(1)

        query_states = query_states.view(B, Q, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(B, L, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(B, L, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        if self.num_key_value_groups > 1:
            key_states = key_states.repeat_interleave(repeats=self.num_key_value_groups, dim=1)
            value_states = value_states.repeat_interleave(repeats=self.num_key_value_groups, dim=1)

        if self.attn_implementation == 'eager':
            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / self.head_dim ** 0.5   # (B, num_heads, Q, L)
            if attention_mask is not None:
                attn_weights = attn_weights + attention_mask.view(1, 1, Q, L)
            # upcast attention to fp32
            attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
            attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
            attn_output = torch.matmul(attn_weights, value_states)     # (B, num_heads, Q, head_dim)
        else:   # 'sdpa'
            # torch <= 2.1.0 has bugs that require q/k/v to be contiguous(); be careful
            attn_output = F.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attn_mask=attention_mask,
                dropout_p=self.attention_dropout if self.training else 0.0
            )
        attn_output = attn_output.transpose(1, 2).reshape(B, Q, C)     # (B, Q, C)
        attn_output = self.o_proj(attn_output)

        return attn_output


class CCAMModel(PreTrainedModel):
    config_class = CCAMConfig
    _no_split_modules = ['CCAMCrossAttention']
    _supports_flash_attn_2 = True   # flash_attention_2 is not actually supported by the projector; it is manually converted to sdpa
    _supports_sdpa = True

    def __init__(self, config: CCAMConfig):
        super().__init__(config)
        self.num_query = config.num_query
        self.hidden_size = config.hidden_size
        self.output_size = config.output_size
        self.cross_hidden_size = config.cross_hidden_size

        self.query = nn.Parameter(torch.empty(1, self.num_query, self.hidden_size).normal_(mean=.0, std=.02))
        self.pre_ccam = nn.Sequential(
            nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps),
            nn.Dropout(config.dropout)
        )
        self.ccam = CCAMCrossAttention(config)
        self.post_ccam = nn.Sequential(
            nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps),
            nn.Dropout(config.dropout),
            CCAMMLP(config)
        )

    def get_ccam(self, vision_hidden_state: torch.Tensor) -> torch.Tensor:     # (Q, T*L)
        """Compute the CCAM mask for a vision hidden state.

        Args:
            vision_hidden_state (torch.Tensor): (T, L, C)

        Returns:
            torch.Tensor: (Q, T*L), -inf means masked
        """
        T, L, _ = vision_hidden_state.size()
        dtype, device = vision_hidden_state.dtype, vision_hidden_state.device
        base_mask = torch.zeros(T, T, dtype=dtype, device=device)
        t = torch.arange(T, device=device)
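        # Frame-causal mask: entry (i, j) is -inf when frame j comes after frame i, so the
        # query group assigned to frame i attends only to frames 0..i. Leftover queries
        # (when num_query is not divisible by T) keep a zero mask and attend to all frames.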
        base_mask.masked_fill_(t > t[:, None], float('-inf'))
        attention_mask = torch.zeros(self.num_query, T * L, dtype=dtype, device=device)
        attention_mask[:self.num_query // T * T] = torch.kron(base_mask, torch.ones(self.num_query // T, L, dtype=dtype, device=device))

        return attention_mask

    def forward(self, vision_hidden_states: list[torch.Tensor]) -> torch.Tensor:   # (B, Q, C)
        """Forward pass. Samples are processed one by one instead of being collected into a batch, for ZeRO-3 compatibility.

        Args:
            vision_hidden_states (list[torch.Tensor]): [(t0, L, C), (t1, L, C), ...]

        Returns:
            torch.Tensor: (B, Q, C)
        """
        output = []
        for hidden_states in vision_hidden_states:
            # reshape inputs and construct ccam masks
            attention_mask = self.get_ccam(hidden_states)   # (Q, ti * L)
            # forward
            x = self.pre_ccam(self.query)   # (1, Q, C)
            x = self.ccam(
                hidden_states=x,                                        # (1, Q, C)
                cross_hidden_states=hidden_states.flatten(0, 1)[None],  # (1, ti * L, C')
                attention_mask=attention_mask[None]                     # (1, Q, ti * L)
            ) + x
            x = self.post_ccam(x)
            output.append(x)
        output = torch.cat(output, dim=0)

        return output


# Modified from transformers.models.llava_next.modeling_llava_next.py
class VideoCCAM(PreTrainedModel):
    config_class = VideoCCAMConfig
    _auto_class = 'AutoModel'
    _supports_flash_attn_2 = True

    def __init__(self, config: VideoCCAMConfig):
        super().__init__(config)
        # the following only works for SiglipVisionModel
        self.vision_encoder = AutoModel.from_config(config.vision_config, torch_dtype=config.torch_dtype, attn_implementation=config._attn_implementation)
        self.vision_encoder.vision_model.post_layernorm = nn.Identity()
        self.projector = CCAMModel._from_config(config.projector_config, torch_dtype=config.torch_dtype, attn_implementation=config._attn_implementation)
        self.llm = AutoModelForCausalLM.from_config(config.text_config, torch_dtype=config.torch_dtype, attn_implementation=config._attn_implementation)

        self.post_init()

    # copied from transformers.models.llava_next.modeling_llava_next
    def _init_weights(self, module, std=.02):
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @property
    def _supports_sdpa(self):
        """
        Retrieve language_model's attribute to check whether the model supports SDPA or not.
        """
        return self.llm._supports_sdpa

    @property
    def _no_split_modules(self):
        """
        Collect the `_no_split_modules` of the vision encoder, projector, and language model.
        """
        return self.vision_encoder._no_split_modules + self.projector._no_split_modules + self.llm._no_split_modules

    @torch.inference_mode()
    def generate(
        self,
        input_ids: list[list[int]] = None,          # [(l_0,), (l_1,), ...]
        pixel_values: torch.FloatTensor = None,     # (t_0+t_1+..., 3, H, W)
        vision_split_sizes: list[int] = None,       # [t_0, t_1, ...]
        past_key_values: Union[tuple, Cache] = None,
        batch_generation: bool = False,
        generation_config: GenerationConfig = None,
        **kwargs
    ) -> tuple[torch.LongTensor, Optional[Cache]]:
        """Generation for multi-modal inputs.

        Args:
            input_ids (list[list[int]]): input token indices; `list[int]` is used for efficient embedding concatenation.
            pixel_values (torch.FloatTensor): input image/video (processed) pixel values.
            vision_split_sizes (list[int]): for each vision token (,