from transformers import BertTokenizer, BertModel
import torch


class TextEmbedder:
    """Embeds text with BERT, mean-pooling token vectors into one embedding."""

    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        # from_pretrained returns the model in eval mode, so dropout is
        # already disabled for inference.
        self.model = BertModel.from_pretrained('bert-base-uncased')

    def _mean_pooling(self, model_output, attention_mask):
        # Average over the sequence dimension, counting only real tokens
        # (attention_mask == 1) so padding does not dilute the mean.
        token_embeddings = model_output.last_hidden_state
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        # Clamp the denominator to avoid division by zero on all-padding rows.
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

    def embed_text(self, examples):
        # Batched map function for a Hugging Face Dataset: tokenizes the
        # "content" column and returns mean-pooled embeddings.
        inputs = self.tokenizer(
            examples["content"], padding=True, truncation=True, return_tensors="pt"
        )
        with torch.no_grad():
            model_output = self.model(**inputs)
        pooled_embeds = self._mean_pooling(model_output, inputs["attention_mask"])
        return {"embedding": pooled_embeds.cpu().numpy()}

    def generate_embeddings(self, dataset):
        # Adds an "embedding" column to the dataset, encoding 128 rows
        # per forward pass.
        return dataset.map(self.embed_text, batched=True, batch_size=128)

    def embed_query(self, query_text):
        # Embeds a single query string (or list of strings) with the same
        # tokenize -> encode -> mean-pool pipeline used for documents.
        query_inputs = self.tokenizer(
            query_text,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )

        with torch.no_grad():
            query_model_output = self.model(**query_inputs)

        query_embedding = self._mean_pooling(query_model_output, query_inputs["attention_mask"])

        return query_embedding
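

# Usage sketch (not part of the original module): a minimal example, assuming
# the Hugging Face `datasets` library is installed and the corpus has a
# "content" column. The sample texts and query below are hypothetical.
if __name__ == "__main__":
    import torch.nn.functional as F
    from datasets import Dataset

    corpus = Dataset.from_dict({
        "content": [
            "The cat sat on the mat.",
            "Transformers produce contextual token embeddings.",
        ]
    })

    embedder = TextEmbedder()
    embedded = embedder.generate_embeddings(corpus)

    # Rank documents against a query by cosine similarity.
    query_vec = embedder.embed_query("What are contextual embeddings?")  # (1, 768)
    doc_vecs = torch.tensor(embedded["embedding"])                       # (2, 768)
    print(F.cosine_similarity(query_vec, doc_vecs))                      # (2,)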