File size: 807 Bytes
16d5d78
 
 
 
 
24d96ab
 
16d5d78
 
24d96ab
 
 
16d5d78
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from typing import Union

import torch
from transformers import AutoTokenizer

from src.config import TinyCLIPTextConfig


class Tokenizer:
    """Thin wrapper around a Hugging Face tokenizer configured from a ``TinyCLIPTextConfig``.

    Encodes text into truncated/padded PyTorch tensors and decodes such batches back to strings.
    """

    def __init__(self, text_config: TinyCLIPTextConfig) -> None:
        """Load the pretrained tokenizer named by the text config.

        Args:
            text_config: Supplies ``text_model`` (the identifier passed to
                ``AutoTokenizer.from_pretrained``) and ``max_len`` (the truncation length
                used by ``__call__``).
        """
        self.tokenizer = AutoTokenizer.from_pretrained(text_config.text_model)
        self.max_len = text_config.max_len

    def __call__(self, x: Union[str, list[str]]) -> dict[str, torch.LongTensor]:
        """Tokenize a single string or a list of strings.

        Sequences are truncated to ``self.max_len`` tokens. ``padding=True`` pads the batch
        to its longest sequence.

        Args:
            x: One string or a list of strings.

        Returns:
            The tokenizer output with PyTorch tensors (``return_tensors="pt"``). ``decode``
            reads its ``"input_ids"`` and ``"attention_mask"`` entries.
        """
        return self.tokenizer(
            x, max_length=self.max_len, truncation=True, padding=True, return_tensors="pt"
        )  # type: ignore

    def decode(self, x: dict[str, torch.LongTensor]) -> list[str]:
        """Decode a tokenized batch back into one string per row.

        Only positions whose attention mask is set are passed to ``tokenizer.decode``, so
        padding is dropped whether the tokenizer pads on the right or on the left. No
        ``skip_special_tokens`` flag is passed, so other special tokens are left to
        ``tokenizer.decode``'s default handling.

        Args:
            x: Mapping with ``"input_ids"`` and ``"attention_mask"`` tensors, e.g. the
                output of ``__call__``. Presumably shaped (batch, seq_len); TODO confirm
                for other callers.

        Returns:
            One decoded string per row of ``x["input_ids"]``.
        """
        # Select with the boolean mask instead of slicing to ``mask.sum()``. Slicing assumes
        # right padding and would keep pads / drop real tokens for left-padding tokenizers.
        return [
            self.tokenizer.decode(ids[mask.bool()])
            for ids, mask in zip(x["input_ids"], x["attention_mask"])
        ]