sachin committed
Commit: 6d1b6c6
Parent: 180681d

Initial training code

Files changed (3):
  1. src/loss.py +2 -13
  2. src/metrics.py +12 -0
  3. src/trainer.py +91 -0
src/loss.py CHANGED
@@ -3,17 +3,6 @@ from torch import nn
 import torch.nn.functional as F
 
 
-def metrics(similarity: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-    y = torch.arange(len(similarity)).to(similarity.device)
-    img2cap_match_idx = similarity.argmax(dim=1)
-    cap2img_match_idx = similarity.argmax(dim=0)
-
-    img_acc = (img2cap_match_idx == y).float().mean()
-    cap_acc = (cap2img_match_idx == y).float().mean()
-
-    return img_acc, cap_acc
-
-
 def get_similarity_matrix(
     image_features: torch.Tensor, text_features: torch.Tensor
 ) -> torch.Tensor:
@@ -34,7 +23,7 @@ class CLIPLoss(nn.Module):
         super().__init__()
         self.logit_temperature = nn.Parameter(torch.tensor(logit_temperature))
 
-    def forward(self, similarity_matrix: torch.Tensor):
+    def forward(self, similarity_matrix: torch.Tensor, *args):
         temperature = self.logit_temperature.sigmoid()
 
         caption_loss = contrastive_loss(similarity_matrix / temperature, dim=0)
@@ -77,7 +66,7 @@ class SigLIPLoss(nn.Module):
         super().__init__()
         self.logit_temperature = nn.Parameter(torch.tensor(logit_temperature))
 
-    def forward(self, similarity_matrix: torch.Tensor):
+    def forward(self, similarity_matrix: torch.Tensor, *args):
         temperature = self.logit_temperature.sigmoid()
         return contrastive_sigmoid_loss(similarity_matrix / temperature)
 
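
The *args added to both forward signatures lets the trainer call any loss uniformly as loss_fn(similarity_matrix, image_features, text_features); CLIPLoss and SigLIPLoss simply ignore the extra feature tensors. A minimal sketch of the call pattern, assuming a batch of 8 L2-normalised 128-d embeddings (shapes and the init value are illustrative, not taken from this diff):

    import torch
    import torch.nn.functional as F

    from src.loss import CLIPLoss, get_similarity_matrix

    # Hypothetical paired embeddings; real ones come from the encoders.
    image_features = F.normalize(torch.randn(8, 128), dim=-1)
    text_features = F.normalize(torch.randn(8, 128), dim=-1)

    similarity = get_similarity_matrix(image_features, text_features)
    loss_fn = CLIPLoss(logit_temperature=0.07)  # assumed init value
    # The extra feature tensors are swallowed by *args and ignored.
    loss = loss_fn(similarity, image_features, text_features)
    loss.backward()  # logit_temperature is an nn.Parameter, so it is trained too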
 
src/metrics.py ADDED
@@ -0,0 +1,12 @@
+import torch
+
+
+def metrics(similarity_matrix: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    y = torch.arange(len(similarity_matrix)).to(similarity_matrix.device)
+    img2cap_match_idx = similarity_matrix.argmax(dim=1)
+    cap2img_match_idx = similarity_matrix.argmax(dim=0)
+
+    img_acc = (img2cap_match_idx == y).float().mean()
+    cap_acc = (cap2img_match_idx == y).float().mean()
+
+    return img_acc, cap_acc
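
metrics assumes the batch is aligned so that image i belongs to caption i, then reports top-1 retrieval accuracy in both directions: argmax over rows for image-to-caption and argmax over columns for caption-to-image. A quick sanity check on a hand-built similarity matrix (values are illustrative only):

    import torch

    from src.metrics import metrics

    # Row i scores image i against every caption; the diagonal should win.
    similarity = torch.tensor([
        [0.9, 0.1, 0.2],
        [0.0, 0.8, 0.3],
        [0.2, 0.7, 0.4],  # image 2 wrongly prefers caption 1
    ])
    img_acc, cap_acc = metrics(similarity)
    print(img_acc)  # tensor(0.6667): 2 of 3 images retrieve their own caption
    print(cap_acc)  # tensor(1.): each caption's best image is still on the diagonal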
src/trainer.py ADDED
@@ -0,0 +1,91 @@
+import pytorch_lightning as pl
+import torch
+import torch.nn as nn
+
+from src import config
+from src import loss as loss_utils
+from src import metrics
+from src import models
+
+
+class LightningModule(pl.LightningModule):
+    def __init__(
+        self,
+        vision_encoder: models.TinyCLIPVisionEncoder,
+        text_encoder: models.TinyCLIPTextEncoder,
+        loss_fn: nn.Module,
+        hyper_parameters: config.TrainerConfig,
+        len_train_dl: int,
+    ) -> None:
+        super().__init__()
+        self.vision_encoder = vision_encoder
+        self.text_encoder = text_encoder
+        self.loss_fn = loss_fn
+        self.hyper_parameters = hyper_parameters
+        self.len_train_dl = len_train_dl
+
+    def common_step(self, batch: tuple[list[str], torch.Tensor], step_kind: str) -> torch.Tensor:
+        text, images = batch
+        image_features = self.vision_encoder(images)
+        text_features = self.text_encoder(text)
+        similarity_matrix = loss_utils.get_similarity_matrix(image_features, text_features)
+
+        loss = self.loss_fn(similarity_matrix, image_features, text_features)
+
+        img_acc, cap_acc = metrics.metrics(similarity_matrix)
+
+        self.log(f"{step_kind}_loss", loss, on_step=False, on_epoch=True)
+        self.log(f"{step_kind}_img_acc", img_acc, on_step=False, on_epoch=True, prog_bar=True)
+        self.log(f"{step_kind}_cap_acc", cap_acc, on_step=False, on_epoch=True, prog_bar=True)
+        return loss
+
+    def training_step(self, batch: tuple[list[str], torch.Tensor], *args: list) -> torch.Tensor:
+        loss = self.common_step(batch, step_kind="training")
+        return loss
+
+    def validation_step(self, batch: tuple[list[str], torch.Tensor], *args: list):
+        _ = self.common_step(batch, step_kind="validation")
+
+    def configure_optimizers(self):
+        # TODO: Add loss parameters here
+        vision_params = [
+            {
+                "params": self.vision_encoder.projection.parameters(),
+                "lr": self.hyper_parameters.learning_rate,
+            },
+            {
+                "params": self.vision_encoder.base.parameters(),
+                "lr": self.hyper_parameters.learning_rate / 2,  # backbone trains at half LR
+            },
+        ]
+        caption_params = [
+            {
+                "params": self.text_encoder.projection.parameters(),
+                "lr": self.hyper_parameters.learning_rate,
+            },
+        ]
+        if not self.hyper_parameters.freeze_text_base:
+            caption_params += [
+                {
+                    "params": self.text_encoder.base.encoder.parameters(),
+                    "lr": self.hyper_parameters.learning_rate / 2,
+                },
+            ]
+
+        optimizer = torch.optim.Adam(vision_params + caption_params)
+
+        if self.hyper_parameters.lr_scheduler:
+            scheduler = torch.optim.lr_scheduler.OneCycleLR(
+                optimizer,
+                max_lr=self.hyper_parameters.learning_rate,
+                total_steps=self.trainer.estimated_stepping_batches,
+            )
+            return [optimizer], [{"scheduler": scheduler, "interval": "step"}]  # OneCycleLR counts batches, so step per batch
+        else:
+            return optimizer
+
+    def on_train_epoch_end(self):  # after the first epoch, unfreeze the vision backbone
+        if self.current_epoch == 0:
+            for p in self.vision_encoder.base.parameters():
+                p.requires_grad = True
+            self.vision_encoder.base.train()
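
For context, a sketch of how this module could be wired up. Everything below that is not in this commit is an assumption: the encoder and TrainerConfig constructors, the toy dataset, the input resolution, and the Trainer settings.

    import pytorch_lightning as pl
    import torch
    from torch.utils.data import DataLoader, Dataset

    from src import config, loss, models
    from src.trainer import LightningModule

    class ToyPairs(Dataset):
        # Hypothetical stand-in yielding (caption, image tensor) pairs.
        def __init__(self, n: int = 32):
            self.captions = [f"a toy caption number {i}" for i in range(n)]
            self.images = torch.randn(n, 3, 224, 224)  # assumed resolution

        def __len__(self):
            return len(self.captions)

        def __getitem__(self, idx):
            return self.captions[idx], self.images[idx]

    train_dl = DataLoader(ToyPairs(), batch_size=8)

    # Hypothetical constructor calls: only the class names appear in this commit.
    module = LightningModule(
        vision_encoder=models.TinyCLIPVisionEncoder(),
        text_encoder=models.TinyCLIPTextEncoder(),
        loss_fn=loss.CLIPLoss(logit_temperature=0.07),  # or loss.SigLIPLoss
        hyper_parameters=config.TrainerConfig(),
        len_train_dl=len(train_dl),
    )
    trainer = pl.Trainer(max_epochs=2)  # arbitrary choice for the sketch
    trainer.fit(module, train_dl, val_dataloaders=train_dl)  # toy data reused for val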