sachin committed on
Commit
18cb46c
1 Parent(s): 69fda24

Successfully uploaded model to HF hub in the correct place

Browse files
Files changed (2) hide show
  1. src/config.py +5 -0
  2. src/trainer.py +29 -10
src/config.py CHANGED
@@ -8,12 +8,17 @@ MAX_DOWNLOAD_TIME = 0.2
8
  IMAGE_DOWNLOAD_PATH = pathlib.Path("./data/images")
9
  WANDB_LOG_PATH = pathlib.Path("/tmp/wandb_logs")
10
  MODEL_PATH = pathlib.Path("/tmp/models")
 
 
11
 
12
  IMAGE_DOWNLOAD_PATH.mkdir(parents=True, exist_ok=True)
13
  WANDB_LOG_PATH.mkdir(parents=True, exist_ok=True)
14
  MODEL_PATH.mkdir(parents=True, exist_ok=True)
 
 
15
 
16
  MODEL_NAME = "tiny_clip"
 
17
 
18
  WANDB_ENTITY = "sachinruk"
19
 
 
8
  IMAGE_DOWNLOAD_PATH = pathlib.Path("./data/images")
9
  WANDB_LOG_PATH = pathlib.Path("/tmp/wandb_logs")
10
  MODEL_PATH = pathlib.Path("/tmp/models")
11
+ VISION_MODEL_PATH = MODEL_PATH / "vision"
12
+ TEXT_MODEL_PATH = MODEL_PATH / "text"
13
 
14
  IMAGE_DOWNLOAD_PATH.mkdir(parents=True, exist_ok=True)
15
  WANDB_LOG_PATH.mkdir(parents=True, exist_ok=True)
16
  MODEL_PATH.mkdir(parents=True, exist_ok=True)
17
+ VISION_MODEL_PATH.mkdir(parents=True, exist_ok=True)
18
+ TEXT_MODEL_PATH.mkdir(parents=True, exist_ok=True)
19
 
20
  MODEL_NAME = "tiny_clip"
21
+ REPO_ID = "sachin/clip-model"
22
 
23
  WANDB_ENTITY = "sachinruk"
24
 
src/trainer.py CHANGED
@@ -1,5 +1,8 @@
1
  import os
2
 
 
 
 
3
  from src import config
4
  from src import data
5
  from src import loss
@@ -11,23 +14,39 @@ from src.lightning_module import LightningModule
11
 
12
 
13
  def _upload_model_to_hub(
14
- vision_encoder: models.TinyCLIPVisionEncoder, text_encoder: models.TinyCLIPTextEncoder
 
 
15
  ):
16
  vision_encoder.save_pretrained(
17
- str(config.MODEL_PATH),
18
- variant="vision_encoder",
19
  safe_serialization=True,
20
- push_to_hub=True,
21
- repo_id="debug-clip-model",
22
  )
23
  text_encoder.save_pretrained(
24
- str(config.MODEL_PATH),
25
- variant="text_encoder",
26
  safe_serialization=True,
27
- push_to_hub=True,
28
- repo_id="debug-clip-model",
29
  )
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  def train(trainer_config: config.TrainerConfig):
33
  if "HF_TOKEN" not in os.environ:
@@ -51,7 +70,7 @@ def train(trainer_config: config.TrainerConfig):
51
  trainer = utils.get_trainer(trainer_config)
52
  trainer.fit(lightning_module, train_dl, valid_dl)
53
 
54
- _upload_model_to_hub(vision_encoder, text_encoder)
55
 
56
 
57
  if __name__ == "__main__":
 
1
  import os
2
 
3
+ from huggingface_hub import HfApi
4
+ from loguru import logger
5
+
6
  from src import config
7
  from src import data
8
  from src import loss
 
14
 
15
 
16
  def _upload_model_to_hub(
17
+ vision_encoder: models.TinyCLIPVisionEncoder,
18
+ text_encoder: models.TinyCLIPTextEncoder,
19
+ debug: bool = False,
20
  ):
21
  vision_encoder.save_pretrained(
22
+ str(config.VISION_MODEL_PATH),
 
23
  safe_serialization=True,
 
 
24
  )
25
  text_encoder.save_pretrained(
26
+ str(config.TEXT_MODEL_PATH),
 
27
  safe_serialization=True,
 
 
28
  )
29
 
30
+ api = HfApi()
31
+ if debug:
32
+ repo_components = config.REPO_ID.split("/", maxsplit=1)
33
+ repo_components[1] = f"debug-{repo_components[1]}"
34
+ repo_id = "/".join(repo_components)
35
+ else:
36
+ repo_id = config.REPO_ID
37
+ common_hf_api_params = {
38
+ "repo_id": repo_id,
39
+ "repo_type": "model",
40
+ }
41
+ if not api.repo_exists(**common_hf_api_params):
42
+ logger.info(f"Creating repo {repo_id} on Hugging Face Hub.")
43
+ api.create_repo(**common_hf_api_params) # type: ignore
44
+ logger.info(f"Uploading models in {str(config.MODEL_PATH)} to {repo_id}.")
45
+ api.upload_folder(
46
+ folder_path=config.MODEL_PATH,
47
+ **common_hf_api_params, # type: ignore
48
+ ) # type: ignore
49
+
50
 
51
  def train(trainer_config: config.TrainerConfig):
52
  if "HF_TOKEN" not in os.environ:
 
70
  trainer = utils.get_trainer(trainer_config)
71
  trainer.fit(lightning_module, train_dl, valid_dl)
72
 
73
+ _upload_model_to_hub(vision_encoder, text_encoder, trainer_config.debug)
74
 
75
 
76
  if __name__ == "__main__":