JosephH commited on
Commit
9cc1e73
1 Parent(s): 1bd4388

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +2 -2
model.py CHANGED
@@ -43,7 +43,7 @@ np.random.seed(0)
43
 
44
  # find a better way to abstract the class
45
  class GPT2PPLV2:
46
- def __init__(self, device="cpu", model_id="gpt2"):
47
  self.device = device
48
  self.model_id = model_id
49
  self.model = GPT2LMHeadModel.from_pretrained(model_id).to(device)
@@ -54,7 +54,7 @@ class GPT2PPLV2:
54
  self.threshold = 0.7
55
 
56
  self.t5_model = transformers.AutoModelForSeq2SeqLM.from_pretrained("t5-base").to(device)
57
- self.t5_tokenizer = T5Tokenizer.from_pretrained("t5-base", model_max_length=512)
58
 
59
  def apply_extracted_fills(self, masked_texts, extracted_fills):
60
  texts = []
 
43
 
44
  # find a better way to abstract the class
45
  class GPT2PPLV2:
46
+ def __init__(self, device="cpu", model_id="gpt2-medium"):
47
  self.device = device
48
  self.model_id = model_id
49
  self.model = GPT2LMHeadModel.from_pretrained(model_id).to(device)
 
54
  self.threshold = 0.7
55
 
56
  self.t5_model = transformers.AutoModelForSeq2SeqLM.from_pretrained("t5-base").to(device)
57
+ self.t5_tokenizer = T5Tokenizer.from_pretrained("t5-larger", model_max_length=512)
58
 
59
  def apply_extracted_fills(self, masked_texts, extracted_fills):
60
  texts = []