import datetime

import pytorch_lightning as pl
from pytorch_lightning import loggers

from src import config


def _get_wandb_logger(trainer_config: config.TrainerConfig):
    """Create the Weights & Biases logger used by the trainer.

    The run name is ``config.MODEL_NAME`` joined with the current local
    timestamp; when ``trainer_config.debug`` is truthy the name is
    additionally prefixed with ``"debug-"``.

    Args:
        trainer_config: Trainer settings. Its ``debug`` flag selects the run
            name prefix, and ``_model_config.to_dict()`` supplies the
            configuration passed to the logger.

    Returns:
        A ``loggers.WandbLogger`` configured with the entity, save
        directory and project name taken from the ``config`` module.
    """
    prefix = "debug-" if trainer_config.debug else ""
    run_name = prefix + f"{config.MODEL_NAME}-{datetime.datetime.now()}"

    logger_kwargs = dict(
        entity=config.WANDB_ENTITY,
        save_dir=config.WANDB_LOG_PATH,
        project=config.MODEL_NAME,
        name=run_name,
        config=trainer_config._model_config.to_dict(),
    )
    return loggers.WandbLogger(**logger_kwargs)


def get_trainer(trainer_config: config.TrainerConfig):
    """Build the ``pl.Trainer`` described by ``trainer_config``.

    In debug mode the run is cut down to a single epoch with the train and
    validation batch limits set to 5; otherwise ``trainer_config.epochs`` is
    used and the batch limits are set to 1.0.

    Args:
        trainer_config: Trainer settings providing ``debug``, ``epochs`` and
            ``log_every_n_steps``.

    Returns:
        A ``pl.Trainer`` that logs to Weights & Biases, clips gradients at
        1.0 and uses ``accelerator="auto"``.
    """
    if trainer_config.debug:
        epoch_count = 1
        # An integer limit is a batch count in Lightning (see Trainer docs).
        batch_limit = 5
    else:
        epoch_count = trainer_config.epochs
        # A float limit is a fraction of the dataset; 1.0 keeps all batches.
        batch_limit = 1.0

    wandb_logger = _get_wandb_logger(trainer_config)

    return pl.Trainer(
        max_epochs=epoch_count,
        logger=wandb_logger,
        log_every_n_steps=trainer_config.log_every_n_steps,
        gradient_clip_val=1.0,
        limit_train_batches=batch_limit,
        limit_val_batches=batch_limit,
        accelerator="auto",
    )