import pathlib
import pydantic
from transformers import PretrainedConfig
# --- Module-level paths and constants --------------------------------------
# Per-image download budget in seconds — NOTE(review): 0.2s is very tight;
# confirm this is intentional for the downloader that reads it.
MAX_DOWNLOAD_TIME = 0.2
IMAGE_DOWNLOAD_PATH = pathlib.Path("./data/images")  # where downloaded images land
WANDB_LOG_PATH = pathlib.Path("/tmp/wandb_logs")  # Weights & Biases log directory
MODEL_PATH = pathlib.Path("/tmp/models")  # root directory for saved models
VISION_MODEL_PATH = MODEL_PATH / "vision"  # vision-tower artifacts
TEXT_MODEL_PATH = MODEL_PATH / "text"  # text-tower artifacts
# Create every output directory eagerly at import time so later writes cannot
# fail on a missing parent (idempotent thanks to exist_ok=True).
IMAGE_DOWNLOAD_PATH.mkdir(parents=True, exist_ok=True)
WANDB_LOG_PATH.mkdir(parents=True, exist_ok=True)
MODEL_PATH.mkdir(parents=True, exist_ok=True)
VISION_MODEL_PATH.mkdir(parents=True, exist_ok=True)
TEXT_MODEL_PATH.mkdir(parents=True, exist_ok=True)
MODEL_NAME = "tiny_clip"  # identifier used when naming/saving the model
REPO_ID = "sachin/clip-model"  # presumably a Hugging Face Hub repo id — confirm
WANDB_ENTITY = "sachinruk"  # W&B entity (user/team) runs are logged under
class DataConfig(pydantic.BaseModel):
    """Dataset selection and sizing options for a training run."""
    buffer_size: int = 1000  # presumably a streaming-shuffle buffer size — confirm in data loader
    data_len: int = 100  # total number of samples used
    train_len: int = 90  # samples for training; remainder presumably validation — confirm
    small_dataset: str = "laion/220k-gpt4vision-captions-from-livis"  # small dataset id
    large_dataset: str = "laion/laion400m"  # large-scale alternative dataset id
    dataset: str = small_dataset  # the active dataset; defaults to the small one
class TinyCLIPTextConfig(PretrainedConfig):
    """Configuration for the TinyCLIP text tower.

    Records the backbone checkpoint name plus projection-head
    hyper-parameters; all remaining keyword arguments are forwarded to
    ``PretrainedConfig``.
    """

    model_type = "text"

    def __init__(
        self,
        text_model: str = "microsoft/xtremedistil-l6-h256-uncased",
        projection_layers: int = 3,
        embed_dims: int = 512,
        max_len: int = 128,
        cls_type: bool = True,
        **kwargs,
    ):
        # Pin every tower-specific hyper-parameter on the instance first,
        # then hand the leftover kwargs to the base-class initializer
        # (kept last, matching the original attribute-set order).
        tower_params = {
            "text_model": text_model,
            "projection_layers": projection_layers,
            "embed_dims": embed_dims,
            "max_len": max_len,
            "cls_type": cls_type,
        }
        for attr_name, attr_value in tower_params.items():
            setattr(self, attr_name, attr_value)
        super().__init__(**kwargs)
class TinyCLIPVisionConfig(PretrainedConfig):
    """Configuration for the TinyCLIP vision tower.

    Holds the backbone name (presumably a timm model id — confirm) and the
    projection-head hyper-parameters; extra keyword arguments go to
    ``PretrainedConfig``.
    """

    model_type = "vision"

    def __init__(
        self,
        vision_model: str = "edgenext_small",
        projection_layers: int = 3,
        embed_dims: int = 512,
        **kwargs,
    ):
        # Attach tower-specific settings in declaration order, then defer
        # to the base class with whatever kwargs remain.
        for attr_name, attr_value in (
            ("vision_model", vision_model),
            ("projection_layers", projection_layers),
            ("embed_dims", embed_dims),
        ):
            setattr(self, attr_name, attr_value)
        super().__init__(**kwargs)
class TinyCLIPConfig(PretrainedConfig):
    """Combined configuration for the full TinyCLIP model.

    Composes a :class:`TinyCLIPTextConfig` and a :class:`TinyCLIPVisionConfig`
    from shared scalar arguments, plus training-related flags.

    NOTE(review): this class spells the embedding size ``embed_dim`` while
    both sub-configs use ``embed_dims``; kept as-is for backward
    compatibility with existing callers.
    """

    model_type = "clip"

    def __init__(
        self,
        text_model: str = "microsoft/xtremedistil-l6-h256-uncased",
        vision_model: str = "edgenext_small",
        projection_layers: int = 3,
        embed_dim: int = 512,
        max_len: int = 128,
        cls_type: bool = True,
        freeze_vision_base: bool = False,
        freeze_text_base: bool = True,
        loss_type: str = "cyclip",
        **kwargs,
    ):
        # Build sub-configs from the scalar arguments. If the caller passed
        # ready-made ``text_config``/``vision_config`` objects via kwargs (as
        # ``from_dict`` does), they reach ``super().__init__`` below, which
        # presumably setattr's them over these defaults — confirm against
        # transformers' PretrainedConfig kwargs handling.
        self.text_config = TinyCLIPTextConfig(
            text_model=text_model,
            projection_layers=projection_layers,
            embed_dims=embed_dim,
            max_len=max_len,
            cls_type=cls_type,
        )
        self.vision_config = TinyCLIPVisionConfig(
            vision_model=vision_model, projection_layers=projection_layers, embed_dims=embed_dim
        )
        self.freeze_vision_base = freeze_vision_base
        self.freeze_text_base = freeze_text_base
        self.loss_type = loss_type
        super().__init__(**kwargs)

    @classmethod
    def from_dict(cls, config_dict, **kwargs):
        """Recreate a config from a plain dict (e.g. a serialized config).

        Bug fix: work on a shallow copy so the ``pop`` calls below no longer
        mutate the caller's dictionary.
        """
        config_dict = dict(config_dict)
        text_config = TinyCLIPTextConfig.from_dict(config_dict.pop("text_config", {}))
        vision_config = TinyCLIPVisionConfig.from_dict(config_dict.pop("vision_config", {}))
        return cls(text_config=text_config, vision_config=vision_config, **config_dict, **kwargs)
class TrainerConfig(pydantic.BaseModel):
    """Training-loop hyper-parameters.

    Underscore-prefixed attributes are treated as private by pydantic and
    bypass normal field handling, so they are restored manually in
    ``__init__`` when the caller supplies them as raw dicts.
    """

    epochs: int = 20
    batch_size: int = 64
    learning_rate: float = 5e-4
    lr_scheduler: bool = True  # enable an LR scheduler — exact schedule defined by the trainer
    accumulate_grad_batches: int = 1  # gradient-accumulation steps
    temperature: float = 1.0  # presumably the contrastive-loss temperature — confirm in loss code
    vision_freeze_layers: int = 2
    lambda_1: float = 1.0  # loss-term weight — TODO confirm meaning against the cyclip loss
    lambda_2: float = 1.0  # loss-term weight — TODO confirm meaning against the cyclip loss
    val_check_interval: int = 1000  # steps between validation runs
    log_every_n_steps: int = 100
    debug: bool = False
    run_openai_clip: bool = False  # compare/run against OpenAI CLIP — confirm usage
    # NOTE(review): these class-level defaults are shared instances across
    # TrainerConfig objects; safe only while never mutated in place — confirm.
    _model_config: TinyCLIPConfig = TinyCLIPConfig()
    _data_config: DataConfig = DataConfig()

    def __init__(self, **data):
        """Build the config, reviving private sub-configs from raw dicts.

        Bug fix: previously only ``_model_config`` was restored from *data*;
        a supplied ``_data_config`` was silently dropped. Both are now
        handled symmetrically.
        """
        super().__init__(**data)
        if "_model_config" in data:
            self._model_config = TinyCLIPConfig.from_dict(data["_model_config"])
        if "_data_config" in data:
            self._data_config = DataConfig(**data["_data_config"])