# %% import torch from transformers import ( BertTokenizer, BertForMaskedLM, AutoModelForMaskedLM, AutoTokenizer, BertModel, ) import numpy as np import random from itertools import islice from torch.utils.data import Dataset, DataLoader from torch.optim import AdamW, SGD from tqdm import tqdm import os def index_to_onehot(l, length): # l=[1, 5], len=6 -> [0,1,0,0,0,1] return [1 if i in l else 0 for i in range(length)] def get_punctuation_position(tokenized_text, tokenizer): # adjust comma_pos and period_pos count = 0 comma_pos = [] period_pos = [] punctuation_removed_text = [] comma_id = tokenizer.convert_tokens_to_ids("、") period_id = tokenizer.convert_tokens_to_ids("。") for i, c in enumerate(tokenized_text): if c == comma_id: comma_pos.append(i - count - 1) count += 1 elif c == period_id: period_pos.append(i - count - 1) count += 1 else: punctuation_removed_text.append(c) if len(punctuation_removed_text) < 512: punctuation_removed_text += [tokenizer.pad_token_id] * ( 512 - len(punctuation_removed_text) ) return ( torch.tensor(punctuation_removed_text), [ index_to_onehot(comma_pos, 512), index_to_onehot(period_pos, 512), ], ) # %% # get_punctuation_position("今日は、いい天気です。") # # %% # index_to_onehot([1, 2, 3, 4, 5], 7) # tokenizer = BertTokenizer.from_pretrained("cl-tohoku/bert-base-japanese-char") # tokenized_text = tokenizer( # "今 日 は 、 い い 天 気 で す 。", # max_length=512, # padding="max_length", # truncation=True, # return_tensors="pt", # ) # inputs, label = get_punctuation_position(tokenized_text["input_ids"][0], tokenizer) # print(inputs) # ->tensor([ 2, 732, 48, 12, 19, 19, 411, 343, 17, 46, 3, 0, 0, 0, ...]) # print(label) # -> [[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, ...], # 点の位置(最初に[SOS]が入るため、1つずれる) # -> [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, ...]] # 丸の位置 # %% class PunctuationPositionDataset(torch.utils.data.Dataset): def __init__(self, data, tokenizer): self.data = data self.tokenizer = tokenizer def __len__(self): return len(self.data) def __getitem__(self, idx): text = self.data[idx] text = " ".join(list(text)) inputs = self.tokenizer( text, max_length=512, padding="max_length", truncation=True, return_tensors="pt", ) # if idx % 100 == 0: # print(masked_text, label) input_ids, label = get_punctuation_position( inputs["input_ids"][0], self.tokenizer ) label = torch.tensor(label, dtype=torch.float32).transpose(0, 1) return (input_ids, inputs.attention_mask.squeeze(), label.squeeze(), text) # %% model_name = "tohoku-nlp/bert-base-japanese-char-v3" tokenizer = BertTokenizer.from_pretrained(model_name) base_model = BertModel.from_pretrained(model_name) # %% class punctuation_predictor(torch.nn.Module): def __init__(self, base_model): super().__init__() self.base_model = base_model self.dropout = torch.nn.Dropout(0.2) self.linear = torch.nn.Linear(768, 2) def forward(self, input_ids, attention_mask): last_hidden_state = self.base_model( input_ids=input_ids, attention_mask=attention_mask ).last_hidden_state # get last hidden state token by token and apply linear layer return self.linear(self.dropout(last_hidden_state)) model = punctuation_predictor(base_model) # %% # a = tokenizer("今 日 は い い 天 気 で す 。",max_length=512, # padding="max_length", # truncation=True, # return_tensors="pt",) # %% with open("data/train.txt", "r") as f: texts = f.readlines() dataset = PunctuationPositionDataset(texts, tokenizer) # %% data_loader = DataLoader( dataset, batch_size=16, shuffle=True, num_workers=8, ) # %% # set lr to 5e-5 to base model optimizer = AdamW( [ {"params": model.base_model.parameters(), "lr": 5e-5}, {"params": model.linear.parameters(), "lr": 1e-3}, ], ) criteria = torch.nn.BCEWithLogitsLoss() # %% model.train() model.to("cuda") for epoch in range(10): epoch_loss = 0.0 progress_bar = tqdm(data_loader, desc=f"Epoch {epoch+1}") for batch in progress_bar: input_ids, attention_masks, labels, text = batch input_ids = input_ids.to("cuda") attention_masks = attention_masks.to("cuda") labels = labels.to("cuda") outputs = model(input_ids=input_ids, attention_mask=attention_masks) loss = criteria(outputs, labels) loss.backward() optimizer.step() optimizer.zero_grad() epoch_loss += loss.item() progress_bar.set_postfix({"loss": epoch_loss / len(data_loader)}) # %% torch.save(model.state_dict(), "weight/punctuation_position_model.pth")