zpn commited on
Commit
2cea8db
1 Parent(s): 487b330

mup base shapes path

Browse files
Files changed (1) hide show
  1. modeling_nt_bert.py +4 -2
modeling_nt_bert.py CHANGED
@@ -78,11 +78,13 @@ class BertPreTrainedModel(PreTrainedModel):
78
 
79
  # since we used MuP, need to reset values since they're not saved with the model
80
  if os.path.exists("base_shapes.bsh") is False:
81
- hf_hub_download(
82
  "zpn/human_bp_bert", "base_shapes.bsh"
83
  )
 
 
84
 
85
- set_base_shapes(model, "base_shapes.bsh", rescale_params=False)
86
 
87
  return model
88
 
 
78
 
79
  # since we used MuP, need to reset values since they're not saved with the model
80
  if os.path.exists("base_shapes.bsh") is False:
81
+ path = hf_hub_download(
82
  "zpn/human_bp_bert", "base_shapes.bsh"
83
  )
84
+ else:
85
+ path = "base_shapes.bsh"
86
 
87
+ set_base_shapes(model, path, rescale_params=False)
88
 
89
  return model
90