Update model.py
Browse files
model.py
CHANGED
@@ -43,7 +43,7 @@ np.random.seed(0)
|
|
43 |
|
44 |
# find a better way to abstract the class
|
45 |
class GPT2PPLV2:
|
46 |
-
def __init__(self, device="cpu", model_id="gpt2"):
|
47 |
self.device = device
|
48 |
self.model_id = model_id
|
49 |
self.model = GPT2LMHeadModel.from_pretrained(model_id).to(device)
|
@@ -54,7 +54,7 @@ class GPT2PPLV2:
|
|
54 |
self.threshold = 0.7
|
55 |
|
56 |
self.t5_model = transformers.AutoModelForSeq2SeqLM.from_pretrained("t5-base").to(device)
|
57 |
-
self.t5_tokenizer = T5Tokenizer.from_pretrained("t5-
|
58 |
|
59 |
def apply_extracted_fills(self, masked_texts, extracted_fills):
|
60 |
texts = []
|
|
|
43 |
|
44 |
# find a better way to abstract the class
|
45 |
class GPT2PPLV2:
|
46 |
+
def __init__(self, device="cpu", model_id="gpt2-medium"):
|
47 |
self.device = device
|
48 |
self.model_id = model_id
|
49 |
self.model = GPT2LMHeadModel.from_pretrained(model_id).to(device)
|
|
|
54 |
self.threshold = 0.7
|
55 |
|
56 |
self.t5_model = transformers.AutoModelForSeq2SeqLM.from_pretrained("t5-base").to(device)
|
57 |
+
self.t5_tokenizer = T5Tokenizer.from_pretrained("t5-large", model_max_length=512)
|
58 |
|
59 |
def apply_extracted_fills(self, masked_texts, extracted_fills):
|
60 |
texts = []
|