|
|
|
|
|
""" |
|
================================================ |
|
@author: Jaron |
|
@time: 2024/08/21 17:51:45 |
|
@email: [email protected] |
|
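@description: Configuration classes for the CCAM projector (CCAMConfig) and the composite VideoCCAM model (VideoCCAMConfig).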
|
================================================ |
|
""" |
|
from typing import Union |
|
|
|
from transformers import PretrainedConfig |
|
from transformers.models.auto import CONFIG_MAPPING |
|
|
|
|
|
class CCAMConfig(PretrainedConfig):
    """Configuration for the CCAM projector.

    Every named constructor argument is stored unchanged as an instance
    attribute of the same name; any remaining keyword arguments are
    forwarded to ``PretrainedConfig.__init__``.

    The meaning of each field is defined by the model code that consumes
    this config, which is not part of this file. The descriptions below are
    inferred from the parameter names and should be confirmed against that
    code.

    Args:
        num_query: Presumably the number of query tokens -- TODO confirm.
        num_heads: Presumably the number of attention heads.
        hidden_size: Presumably the projector's hidden width.
        intermediate_size: Presumably the MLP inner width.
        num_key_value_heads: Presumably the number of key/value heads.
        dropout: Dropout probability (exact placement not shown here).
        mlp_bias: Presumably whether MLP linear layers use a bias term.
        hidden_act: Name of the activation function.
        output_size: Optional output width; ``None`` leaves the choice to
            the consumer of this config.
        attention_bias: Presumably whether attention projections use a bias.
        layer_norm_eps: Presumably the epsilon used by normalisation layers.
        cross_hidden_size: Optional width of the cross-attended features;
            ``None`` leaves the choice to the consumer of this config.
        attention_dropout: Presumably dropout applied inside attention.
        _attn_implementation: Attention backend identifier.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        num_query: int = 1024,
        num_heads: int = 16,
        hidden_size: int = 1024,
        intermediate_size: int = 4096,
        num_key_value_heads: int = 16,
        dropout: float = 0.1,
        mlp_bias: bool = True,
        hidden_act: str = 'swiglu',
        output_size: Union[int, None] = None,
        attention_bias: bool = True,
        layer_norm_eps: float = 1e-5,
        cross_hidden_size: Union[int, None] = None,
        attention_dropout: float = 0.1,
        _attn_implementation: str = 'flash_attention_2',
        **kwargs
    ):
        # Base-class initialisation runs first so the explicit assignments
        # below take precedence over anything it sets from ``kwargs``.
        super().__init__(**kwargs)
        self.dropout = dropout
        self.mlp_bias = mlp_bias
        self.num_query = num_query
        self.num_heads = num_heads
        self.hidden_act = hidden_act
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.layer_norm_eps = layer_norm_eps
        self.attention_bias = attention_bias
        self.intermediate_size = intermediate_size
        self.cross_hidden_size = cross_hidden_size
        self.attention_dropout = attention_dropout
        self.num_key_value_heads = num_key_value_heads
        # NOTE(review): recent transformers releases also manage
        # `_attn_implementation` on the base class -- confirm this explicit
        # assignment interacts with that mechanism as intended.
        self._attn_implementation = _attn_implementation
|
|
|
|
|
class VideoCCAMConfig(PretrainedConfig):
    """Composite configuration holding vision, text and projector configs.

    Each sub-config may be given either as an already-built config object
    (stored as-is, including ``None``) or as a plain ``dict``:

    * ``vision_config`` / ``text_config`` dicts are turned into the config
      class that ``CONFIG_MAPPING`` registers for their ``'model_type'``.
    * a ``projector_config`` dict is turned into a ``CCAMConfig``.

    Args:
        vision_config: Vision sub-config, as a dict or config object.
        text_config: Text sub-config, as a dict or config object.
        projector_config: Projector sub-config, as a dict or config object.
        image_token_id: Stored unchanged; presumably the token id marking
            image inputs -- TODO confirm against the model code.
        video_token_id: Stored unchanged; presumably the token id marking
            video inputs -- TODO confirm against the model code.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    Raises:
        KeyError: If a ``vision_config`` / ``text_config`` dict has no
            ``'model_type'`` entry, or names a model type that
            ``CONFIG_MAPPING`` does not know.
    """

    model_type = 'videoccam'
    _auto_class = 'AutoConfig'

    def __init__(
        self,
        vision_config: Union[dict, PretrainedConfig, None] = None,
        text_config: Union[dict, PretrainedConfig, None] = None,
        projector_config: Union[dict, PretrainedConfig, None] = None,
        image_token_id: Union[int, None] = None,
        video_token_id: Union[int, None] = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vision_config = self._build_sub_config(vision_config, 'vision_config')
        self.text_config = self._build_sub_config(text_config, 'text_config')
        # The projector is always a CCAMConfig, so no registry lookup is needed.
        if isinstance(projector_config, dict):
            self.projector_config = CCAMConfig(**projector_config)
        else:
            self.projector_config = projector_config
        self.image_token_id = image_token_id
        self.video_token_id = video_token_id

    @staticmethod
    def _build_sub_config(config, name):
        """Return ``config`` as a config object, building it from a dict if needed."""
        if not isinstance(config, dict):
            return config
        if 'model_type' not in config:
            # Same exception type as the plain dict lookup would raise, but
            # with a message that identifies which argument is malformed.
            raise KeyError(
                f"`{name}` was given as a dict but has no 'model_type' key; "
                "cannot select a config class from CONFIG_MAPPING."
            )
        return CONFIG_MAPPING[config['model_type']](**config)
|
|