import logging
import os

import pytorch_lightning as pl
import torch
import torch.nn as nn
from pytorch_lightning.utilities import rank_zero_only


class SaveCheckpointsCallback(pl.Callback):
    def __init__(
        self,
        model: nn.Module,
        checkpoints_dir: str,
        save_step_frequency: int,
    ):
        r"""Callback to save a checkpoint every `save_step_frequency` steps.

        Args:
            model: nn.Module, the model whose state_dict is saved
            checkpoints_dir: str, directory to save checkpoints to
            save_step_frequency: int, number of global steps between checkpoints
        """
        self.model = model
        self.checkpoints_dir = checkpoints_dir
        self.save_step_frequency = save_step_frequency
        os.makedirs(self.checkpoints_dir, exist_ok=True)

    @rank_zero_only
    def on_batch_end(self, trainer: pl.Trainer, _) -> None:
        r"""Save a checkpoint at every `save_step_frequency`-th global step."""
        global_step = trainer.global_step

        if global_step % self.save_step_frequency == 0:
            checkpoint_path = os.path.join(
                self.checkpoints_dir, "step={}.pth".format(global_step)
            )
            checkpoint = {'step': global_step, 'model': self.model.state_dict()}
            torch.save(checkpoint, checkpoint_path)
            logging.info("Save checkpoint to {}".format(checkpoint_path))
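

if __name__ == "__main__":
    # Minimal usage sketch (an assumption, not part of the original module):
    # `DemoModule`, the random dataset, and all hyperparameters below are
    # hypothetical placeholders showing how the callback is attached to a
    # pl.Trainer. Assumes a pytorch_lightning version that still dispatches
    # the `on_batch_end` callback hook.
    from torch.utils.data import DataLoader, TensorDataset

    class DemoModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(4, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return nn.functional.mse_loss(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.01)

    model = DemoModule()
    callback = SaveCheckpointsCallback(
        model=model,
        checkpoints_dir="./checkpoints",
        save_step_frequency=4,
    )
    dataset = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
    trainer = pl.Trainer(max_epochs=1, callbacks=[callback])
    trainer.fit(model, DataLoader(dataset, batch_size=8))

    # A saved checkpoint can later be restored with, e.g.:
    #     state = torch.load("./checkpoints/step=4.pth")
    #     model.load_state_dict(state["model"])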