Commit c6fe3c5 (parent: 31e368b), committed by sachin

successful local run
.vscode/settings.json CHANGED
@@ -2,7 +2,7 @@
     "files.insertFinalNewline": true,
     "jupyter.debugJustMyCode": false,
     "editor.formatOnSave": true,
-    "editor.formatOnPaste": true,
+    // "editor.formatOnPaste": true,
     "files.autoSave": "onFocusChange",
     "editor.defaultFormatter": "ms-python.black-formatter",
     "black-formatter.path": ["/opt/homebrew/bin/black"],
@@ -12,12 +12,12 @@
     "isort.check": true,
     "python.analysis.typeCheckingMode": "basic",
     "python.defaultInterpreterPath": "/opt/homebrew/bin/python3",
-    "[python]": {
-        "editor.defaultFormatter": "ms-python.black-formatter",
-        "editor.formatOnSave": true,
-        "editor.codeActionsOnSave": {
-            "source.organizeImports": "explicit"
-        },
-    },
-    "isort.args": ["--profile", "black"],
+    // "[python]": {
+    //     "editor.defaultFormatter": "ms-python.black-formatter",
+    //     "editor.formatOnSave": true,
+    //     "editor.codeActionsOnSave": {
+    //         "source.organizeImports": "explicit"
+    //     },
+    // },
+    // "isort.args": ["--profile", "black"],
 }
src/config.py CHANGED
@@ -6,6 +6,14 @@ from transformers import PretrainedConfig
 MAX_DOWNLOAD_TIME = 0.2
 
 IMAGE_DOWNLOAD_PATH = pathlib.Path("./data/images")
+WANDB_LOG_PATH = pathlib.Path("/tmp/wandb_logs")
+
+IMAGE_DOWNLOAD_PATH.mkdir(parents=True, exist_ok=True)
+WANDB_LOG_PATH.mkdir(parents=True, exist_ok=True)
+
+MODEL_NAME = "tiny_clip"
+
+WANDB_ENTITY = "sachinruk"
 
 
 class DataConfig(pydantic.BaseModel):
@@ -97,6 +105,8 @@ class TrainerConfig(pydantic.BaseModel):
     lambda_2: float = 1.0
 
     val_check_interval: int = 1000
+    log_every_n_steps: int = 100
+    debug: bool = False
 
     run_openai_clip: bool = False
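Note: the new constants create their directories at import time, and the two new TrainerConfig fields drive the logging and debug behaviour in src/utils.py below. A minimal sketch of how they are consumed, assuming TrainerConfig's other fields all have defaults (as the diff suggests):

```python
# Minimal sketch, assuming only the fields shown in the diff above.
from src import config

trainer_config = config.TrainerConfig(debug=True)
# debug=True shortens the run (see get_trainer in src/utils.py);
# log_every_n_steps defaults to 100 as added above.
print(trainer_config.debug, trainer_config.log_every_n_steps)  # True 100
```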
src/data.py CHANGED
@@ -37,7 +37,7 @@ class CollateFn:
         tokenized_text = self.tokenizer([item["caption"] for item in batch])
 
         return {
-            "image": stacked_images,
+            "images": stacked_images,
             **tokenized_text,
         }
 
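Note: the "image" to "images" rename is what common_step in src/lightning_module.py (below) relies on: it reads batch["images"] for the vision encoder and forwards every remaining key to the text encoder. A hedged sketch of the collated batch this produces (shapes and tokenizer keys are illustrative; the actual keys depend on the Hugging Face tokenizer wrapped by tk.Tokenizer):

```python
import torch

# Illustrative collated batch; "input_ids"/"attention_mask" are typical
# Hugging Face tokenizer outputs, assumed here rather than taken from the repo.
batch = {
    "images": torch.randn(4, 3, 224, 224),         # stacked, pre-transformed images
    "input_ids": torch.randint(0, 1000, (4, 32)),  # tokenized captions
    "attention_mask": torch.ones(4, 32, dtype=torch.long),
}
text_inputs = {k: v for k, v in batch.items() if k != "images"}  # what common_step forwards
```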
src/lightning_module.py CHANGED
@@ -24,10 +24,11 @@ class LightningModule(pl.LightningModule):
         self.hyper_parameters = hyper_parameters
         self.len_train_dl = len_train_dl
 
-    def common_step(self, batch: tuple[torch.Tensor, list[str]], step_kind: str) -> torch.Tensor:
-        text, images = batch
-        image_features = self.vision_encoder(images)
-        text_features = self.text_encoder(text)
+    def common_step(self, batch: dict[str, torch.Tensor], step_kind: str) -> torch.Tensor:
+        image_features = self.vision_encoder(batch["images"])
+        text_features = self.text_encoder(
+            {key: value for key, value in batch.items() if key != "images"}
+        )
         similarity_matrix = loss_utils.get_similarity_matrix(image_features, text_features)
 
         loss = self.loss_fn(similarity_matrix, image_features, text_features)
@@ -52,10 +53,6 @@ class LightningModule(pl.LightningModule):
                 "params": self.vision_encoder.projection.parameters(),
                 "lr": self.hyper_parameters.learning_rate,
             },
-            {
-                "params": self.vision_encoder.base.parameters(),
-                "lr": self.hyper_parameters.learning_rate / 2,
-            },
         ]
         caption_params = [
             {
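Note: besides switching common_step to the dict batch produced by CollateFn, this removes vision_encoder.base from the optimizer's parameter groups, so (assuming no other group picks those parameters up) the vision backbone receives no updates and only the projection head is trained. A minimal sketch of the resulting group setup, with torch.optim.AdamW as an assumed optimizer since the diff does not show which one configure_optimizers uses:

```python
import torch

# Hypothetical stand-ins for self.vision_encoder.projection and the config.
projection = torch.nn.Linear(512, 128)
learning_rate = 1e-3  # stand-in for hyper_parameters.learning_rate

vision_params = [
    {"params": projection.parameters(), "lr": learning_rate},
]
# The backbone ("base") no longer appears in any group, so it gets no updates.
optimizer = torch.optim.AdamW(vision_params)  # AdamW is an assumption
```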
src/models.py CHANGED
@@ -77,10 +77,8 @@ class TinyCLIPVisionEncoder(PreTrainedModel):
             num_features, config.embed_dims, config.projection_layers
         )
 
-    def forward(self, images: list[Image.Image]):
-        x: torch.Tensor = torch.stack([self.transform(image) for image in images])  # type: ignore
-
-        projected_vec = self.projection(self.base(x))
+    def forward(self, images: torch.Tensor):
+        projected_vec = self.projection(self.base(images))
         return F.normalize(projected_vec, dim=-1)
 
 
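Note: forward now expects an already-transformed batch tensor (the "images" entry built by CollateFn) instead of raw PIL images, so the per-image transform runs once in the data pipeline rather than inside every forward pass. A runnable sketch of the new contract, with hypothetical stand-ins for base and projection:

```python
import torch
import torch.nn.functional as F

# Stand-ins for self.base and self.projection; shapes are illustrative.
base = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 224 * 224, 512))
projection = torch.nn.Linear(512, 128)

images = torch.randn(4, 3, 224, 224)  # what CollateFn's "images" entry supplies
features = F.normalize(projection(base(images)), dim=-1)  # mirrors forward() above
print(features.shape)  # torch.Size([4, 128])
```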
src/trainer.py CHANGED
@@ -1,25 +1,34 @@
-from src import data
 from src import config
-from src import vision_model
-from src import tokenizer as tk
-from src.lightning_module import LightningModule
+from src import data
 from src import loss
 from src import models
+from src import tokenizer as tk
+from src import vision_model
+from src import utils
+from src.lightning_module import LightningModule
 
 
-def train(config: config.TrainerConfig):
-    transform = vision_model.get_vision_transform(config._model_config.vision_config)
-    tokenizer = tk.Tokenizer(config._model_config.text_config)
+def train(trainer_config: config.TrainerConfig):
+    transform = vision_model.get_vision_transform(trainer_config._model_config.vision_config)
+    tokenizer = tk.Tokenizer(trainer_config._model_config.text_config)
     train_dl, valid_dl = data.get_dataset(
-        transform=transform, tokenizer=tokenizer, hyper_parameters=config  # type: ignore
+        transform=transform, tokenizer=tokenizer, hyper_parameters=trainer_config  # type: ignore
     )
-    vision_encoder = models.TinyCLIPVisionEncoder(config=config._model_config.vision_config)
-    text_encoder = models.TinyCLIPTextEncoder(config=config._model_config.text_config)
+    vision_encoder = models.TinyCLIPVisionEncoder(config=trainer_config._model_config.vision_config)
+    text_encoder = models.TinyCLIPTextEncoder(config=trainer_config._model_config.text_config)
 
     lightning_module = LightningModule(
         vision_encoder=vision_encoder,
         text_encoder=text_encoder,
-        loss_fn=loss.get_loss(config._model_config.loss_type),
-        hyper_parameters=config,
+        loss_fn=loss.get_loss(trainer_config._model_config.loss_type),
+        hyper_parameters=trainer_config,
         len_train_dl=len(train_dl),
     )
+
+    trainer = utils.get_trainer(trainer_config)
+    trainer.fit(lightning_module, train_dl, valid_dl)
+
+
+if __name__ == "__main__":
+    trainer_config = config.TrainerConfig(debug=True)
+    train(trainer_config)
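Note: renaming the parameter from config to trainer_config removes the shadowing of the src.config module inside train(), and the new __main__ block makes the module directly runnable, which is presumably the "successful local run" in the commit message. Equivalent invocation from Python (assumes the repo root is on sys.path):

```python
# Same as `python -m src.trainer`: a one-epoch debug run, capped to
# 5 train/val batches by get_trainer() in src/utils.py.
from src import config, trainer

trainer.train(config.TrainerConfig(debug=True))
```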
src/utils.py ADDED
@@ -0,0 +1,31 @@
+import datetime
+
+import pytorch_lightning as pl
+from pytorch_lightning import loggers
+
+from src import config
+
+
+def _get_wandb_logger(trainer_config: config.TrainerConfig):
+    name = f"{config.MODEL_NAME}-{datetime.datetime.now()}"
+    if trainer_config.debug:
+        name = "debug-" + name
+    return loggers.WandbLogger(
+        entity=config.WANDB_ENTITY,
+        save_dir=config.WANDB_LOG_PATH,
+        project=config.MODEL_NAME,
+        name=name,
+        config=trainer_config._model_config.to_dict(),
+    )
+
+
+def get_trainer(trainer_config: config.TrainerConfig):
+    return pl.Trainer(
+        max_epochs=trainer_config.epochs if not trainer_config.debug else 1,
+        logger=_get_wandb_logger(trainer_config),
+        log_every_n_steps=trainer_config.log_every_n_steps,
+        gradient_clip_val=1.0,
+        limit_train_batches=5 if trainer_config.debug else 1.0,
+        limit_val_batches=5 if trainer_config.debug else 1.0,
+        accelerator="auto",
+    )
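Note: limit_train_batches and limit_val_batches follow PyTorch Lightning's convention that an int caps the number of batches while a float is a fraction of the dataloader, so debug runs see exactly 5 batches per split and full runs see everything. The debug switch in isolation:

```python
debug = True
epochs = 10  # stand-in for trainer_config.epochs

max_epochs = epochs if not debug else 1
limit_train_batches = 5 if debug else 1.0  # int = batch count, float = fraction
print(max_epochs, limit_train_batches)     # 1 5
```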