#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
================================================
@author: Jaron
@time: 2024/08/21 17:51:45
@email: fjjth98@163.com
@description: Configuration classes for the CCAM projector (CCAMConfig)
              and the composite Video-CCAM model (VideoCCAMConfig).
================================================
"""

from typing import Optional, Union

from transformers import PretrainedConfig
from transformers.models.auto import CONFIG_MAPPING


def _to_auto_config(
    config: Optional[Union[dict, PretrainedConfig]]
) -> Optional[PretrainedConfig]:
    """Rebuild a sub-config given as a dict; pass anything else through.

    A ``dict`` is dispatched through ``CONFIG_MAPPING`` using its
    ``'model_type'`` entry, and the whole dict (``'model_type'`` included)
    is forwarded as keyword arguments. Any other value, such as an existing
    config object or ``None``, is returned unchanged.

    Raises:
        KeyError: if ``config`` is a dict without a ``'model_type'`` key.
    """
    if isinstance(config, dict):
        return CONFIG_MAPPING[config['model_type']](**config)
    return config


class CCAMConfig(PretrainedConfig):
    """Configuration for the CCAM projector.

    Every constructor argument is stored unchanged as an attribute of the
    same name. No validation or derivation happens here. Remaining keyword
    arguments are forwarded to ``PretrainedConfig``.

    The module that consumes this config is not shown here, so the meanings
    below are inferred from the parameter names. Confirm them against the
    model code.

    Args:
        num_query: number of query tokens.
        num_heads: number of attention heads.
        hidden_size: width of the projector's hidden states.
        intermediate_size: width of the MLP hidden layer.
        num_key_value_heads: number of key/value heads.
        dropout: general dropout probability.
        mlp_bias: whether MLP layers use bias terms.
        hidden_act: name of the MLP activation.
        output_size: output width. It defaults to ``None`` and, per the
            original author's note, is meant to be inferred from the LLM.
        attention_bias: whether attention projections use bias terms.
        layer_norm_eps: epsilon used by layer normalization.
        cross_hidden_size: width of the cross-attended features. It
            defaults to ``None`` and, per the original author's note, is
            meant to be inferred from the vision encoder.
        attention_dropout: dropout probability on attention weights.
        _attn_implementation: attention backend name, assigned to
            ``self._attn_implementation``.
    """

    def __init__(
        self,
        num_query: int = 1024,
        num_heads: int = 16,
        hidden_size: int = 1024,
        intermediate_size: int = 4096,
        num_key_value_heads: int = 16,
        dropout: float = 0.1,
        mlp_bias: bool = True,
        hidden_act: str = 'swiglu',
        output_size: Optional[int] = None,  # inferred from llm
        attention_bias: bool = True,
        layer_norm_eps: float = 1e-5,
        cross_hidden_size: Optional[int] = None,  # inferred from vision encoder
        attention_dropout: float = 0.1,
        _attn_implementation: str = 'flash_attention_2',
        **kwargs
    ):
        super().__init__(**kwargs)
        # Assignment order is kept as originally written so the attribute
        # order in the instance __dict__ (and thus serialization) is unchanged.
        self.dropout = dropout
        self.mlp_bias = mlp_bias
        self.num_query = num_query
        self.num_heads = num_heads
        self.hidden_act = hidden_act
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.layer_norm_eps = layer_norm_eps
        self.attention_bias = attention_bias
        self.intermediate_size = intermediate_size
        self.cross_hidden_size = cross_hidden_size
        self.attention_dropout = attention_dropout
        self.num_key_value_heads = num_key_value_heads
        self._attn_implementation = _attn_implementation


class VideoCCAMConfig(PretrainedConfig):
    """Composite configuration for the Video-CCAM model.

    Bundles a vision sub-config, a text sub-config, a projector sub-config,
    and the ids of the image and video tokens. Sub-configs passed as plain
    dicts are rebuilt into config objects:

    * ``vision_config`` and ``text_config`` go through ``CONFIG_MAPPING``,
      keyed by the dict's ``'model_type'`` entry;
    * ``projector_config`` goes through ``CCAMConfig``.

    Non-dict values, such as existing config objects or ``None``, are stored
    unchanged. No defaults are filled in here.

    Args:
        vision_config: vision sub-config, as a dict or a config object.
        text_config: text sub-config, as a dict or a config object.
        projector_config: projector sub-config, as a dict or a
            ``CCAMConfig``.
        image_token_id: token id stored as ``self.image_token_id``.
        video_token_id: token id stored as ``self.video_token_id``.
        **kwargs: forwarded to ``PretrainedConfig``.

    Raises:
        KeyError: if ``vision_config`` or ``text_config`` is a dict without a
            ``'model_type'`` key.
    """

    model_type = 'videoccam'
    _auto_class = 'AutoConfig'

    def __init__(
        self,
        vision_config: Optional[Union[dict, PretrainedConfig]] = None,
        text_config: Optional[Union[dict, PretrainedConfig]] = None,
        projector_config: Optional[Union[dict, CCAMConfig]] = None,
        image_token_id: Optional[int] = None,
        video_token_id: Optional[int] = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vision_config = _to_auto_config(vision_config)
        self.text_config = _to_auto_config(text_config)
        # The projector config is project-specific, so it is rebuilt via
        # CCAMConfig rather than looked up in CONFIG_MAPPING.
        if isinstance(projector_config, dict):
            projector_config = CCAMConfig(**projector_config)
        self.projector_config = projector_config
        self.image_token_id = image_token_id
        self.video_token_id = video_token_id