# tiny_clip/src/tokenizer.py
from typing import Union

import torch
from transformers import AutoTokenizer


class Tokenizer:
    """Thin wrapper around a Hugging Face tokenizer with a fixed max length."""

    def __init__(self, model_name: str, max_len: int) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.max_len = max_len

    def __call__(self, x: Union[str, list[str]]) -> dict[str, torch.LongTensor]:
        # Tokenize, truncating to max_len and padding to the longest item in the batch.
        return self.tokenizer(
            x, max_length=self.max_len, truncation=True, padding=True, return_tensors="pt"
        )  # type: ignore

    def decode(self, x: dict[str, torch.LongTensor]) -> list[str]:
        # Trim each sequence to its true length (sum of the attention mask)
        # so right-side padding tokens are not decoded.
        return [
            self.tokenizer.decode(sentence[:sentence_len])
            for sentence, sentence_len in zip(x["input_ids"], x["attention_mask"].sum(dim=-1))
        ]
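

# A minimal usage sketch (not part of the original file): the model name
# "distilbert-base-uncased" and max_len=32 are illustrative assumptions, not
# values taken from this repo.
if __name__ == "__main__":
    tokenizer = Tokenizer("distilbert-base-uncased", max_len=32)
    batch = tokenizer(["a photo of a cat", "a photo of a dog"])
    print(batch["input_ids"].shape)  # (batch, seq_len), padded to the longest item
    print(tokenizer.decode(batch))   # decoded text, padding stripped (special tokens remain)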