import dataclasses
import json

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers


class Projection(nn.Module):
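    """Projection block: a linear map with a GELU -> linear -> dropout residual
    branch, combined through a skip connection and LayerNorm."""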
    def __init__(self, d_in: int, d_out: int, p: float = 0.5) -> None:
        super().__init__()
        self.linear1 = nn.Linear(d_in, d_out, bias=False)
        self.linear2 = nn.Linear(d_out, d_out, bias=False)
        self.layer_norm = nn.LayerNorm(d_out)
        self.drop = nn.Dropout(p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        embed1 = self.linear1(x)
        embed2 = self.drop(self.linear2(F.gelu(embed1)))
        embeds = self.layer_norm(embed1 + embed2)
        return embeds


def projection_layers(d_in: int, d_out: int, num_layers: int) -> nn.Module:
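    """Stack num_layers Projection blocks: the first num_layers - 1 keep width
    d_in (with GELUs in between); the last maps d_in to d_out."""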
    layers = []
    for _ in range(num_layers - 1):
        layers.extend([Projection(d_in, d_in), nn.GELU()])
    layers += [Projection(d_in, d_out)]
    return nn.Sequential(*layers)


def mean_pooling(
    text_representation: torch.FloatTensor, attention_mask: torch.LongTensor
) -> torch.FloatTensor:
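    """Average token embeddings over non-padding positions; the clamp guards
    against division by zero when a mask is all zeros."""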
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(text_representation.size()).float()
    return torch.sum(text_representation * input_mask_expanded, 1) / torch.clamp(
        input_mask_expanded.sum(1), min=1e-9
    )


class TextEncoder(nn.Module):
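    """Frozen text backbone with a trainable projection head. Pools the last
    hidden state (CLS token or masked mean) and L2-normalizes the result."""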
    def __init__(
        self,
        base: nn.Module,
        d_in: int,
        d_out: int,
        n_projection_layers: int,
        cls_token: bool = False,
    ) -> None:
        super().__init__()
        self.base = base
        self.cls_token = cls_token
        self.projection = projection_layers(d_in, d_out, n_projection_layers)
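        # freeze the text backbone; only the projection head is trained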
        self.base.eval()
        for p in self.base.parameters():
            p.requires_grad = False

    def forward(self, x: transformers.BatchEncoding) -> torch.Tensor:
        out = self.base(**x).last_hidden_state
        if self.cls_token:
            out = out[:, 0]  # get CLS token output
        else:
            out = mean_pooling(out, x["attention_mask"])

        projected_vec = self.projection(out)
        return F.normalize(projected_vec, dim=-1)


class VisionEncoder(nn.Module):
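    """Frozen vision backbone (timm model) with a trainable projection head;
    returns L2-normalized embeddings."""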
    def __init__(self, base: nn.Module, d_in: int, d_out: int, n_projection_layers: int) -> None:
        super().__init__()
        self.base = base
        self.projection = projection_layers(d_in, d_out, n_projection_layers)

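        # freeze the vision backbone; only the projection head is trained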
        self.base.eval()
        for p in self.base.parameters():
            p.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected_vec = self.projection(self.base(x))
        return F.normalize(projected_vec, dim=-1)


class Tokenizer:
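    """Wraps a Hugging Face tokenizer: truncates to max_len, pads each batch,
    and returns PyTorch tensors; decode() strips padding via the attention mask."""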
    def __init__(self, tokenizer, max_len: int) -> None:
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __call__(self, x: str) -> transformers.BatchEncoding:
        return self.tokenizer(
            x, max_length=self.max_len, truncation=True, padding=True, return_tensors="pt"
        )

    def decode(self, x: dict[str, torch.LongTensor]) -> list[str]:
        # keep only the non-padding prefix of each sequence before decoding
        return [
            self.tokenizer.decode(sentence[:sentence_len])
            for sentence, sentence_len in zip(x["input_ids"], x["attention_mask"].sum(dim=-1))
        ]


@dataclasses.dataclass(frozen=True)
class CLIPConfig:
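    """Encoder hyperparameters; get_model() loads overrides from clip_config.json."""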
    cls_token: bool = True
    n_projection_layers: int = 3
    embed_dims: int = 512
    vision_model: str = "edgenext_small"
    text_model: str = "microsoft/xtremedistil-l6-h256-uncased"
    max_len: int = 128


def get_model():
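    """Build the text and vision encoders from local config files and CPU-mapped
    checkpoints, returning them with their preprocessing callables."""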
    with open("./clip_config.json", "r") as f:
        config = CLIPConfig(**json.load(f))

    # load text model and tokenizer
    text_config = transformers.AutoConfig.from_pretrained("./text_model_config/")
    text_base = transformers.AutoModel.from_config(text_config)
    tokenizer = Tokenizer(
        transformers.AutoTokenizer.from_pretrained("./tokenizer/"), config.max_len
    )
    text_encoder = TextEncoder(
        text_base,
        text_base.config.hidden_size,
        config.embed_dims,
        config.n_projection_layers,
        config.cls_token,
    )
    text_encoder.load_state_dict(torch.load("./text.ckpt", map_location=torch.device("cpu")))

    # load vision model and image transform
    image_base = timm.create_model(config.vision_model, num_classes=0)
    timm_config = timm.data.resolve_data_config({}, model=image_base)
    transform = timm.data.transforms_factory.create_transform(**timm_config)
    vision_encoder = VisionEncoder(
        image_base, image_base.num_features, config.embed_dims, config.n_projection_layers
    )
    vision_encoder.load_state_dict(torch.load("./vision.ckpt", map_location=torch.device("cpu")))

    return text_encoder, tokenizer, vision_encoder, transform
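

if __name__ == "__main__":
    # Minimal usage sketch (assumes the config, tokenizer, and checkpoint files
    # referenced in get_model() are present; "example.jpg" is a hypothetical path).
    from PIL import Image

    text_encoder, tokenizer, vision_encoder, transform = get_model()

    with torch.no_grad():
        text_embed = text_encoder(tokenizer("a photo of a dog"))
        image = transform(Image.open("example.jpg").convert("RGB")).unsqueeze(0)
        image_embed = vision_encoder(image)

    # Both embeddings are L2-normalized, so cosine similarity is a dot product.
    print("text-image similarity:", (text_embed @ image_embed.T).item())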