# tiny_clip/src/tokenizer.py
from typing import Union
import torch
from transformers import AutoTokenizer
from src.config import TinyCLIPTextConfig
class Tokenizer:
def __init__(self, text_config: TinyCLIPTextConfig) -> None:
self.tokenizer = AutoTokenizer.from_pretrained(text_config.text_model)
self.max_len = text_config.max_len
def __call__(self, x: Union[str, list[str]]) -> dict[str, torch.LongTensor]:
return self.tokenizer(
x, max_length=self.max_len, truncation=True, padding=True, return_tensors="pt"
) # type: ignore
def decode(self, x: dict[str, torch.LongTensor]) -> list[str]:
return [
self.tokenizer.decode(sentence[:sentence_len])
for sentence, sentence_len in zip(x["input_ids"], x["attention_mask"].sum(axis=-1))
]