File size: 3,436 Bytes
0a3525d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from collections.abc import Iterable
from typing import Optional, Union

import lightning.pytorch as pl
import torch
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback
from torch import Tensor, nn
from torch.utils._foreach_utils import (
    _group_tensors_by_device_and_dtype,
    _has_foreach_support,
)


@torch.no_grad()
def grad_norm(
    parameters: Union[Tensor, list[Tensor]],
    norm_type: float = 2.0,
) -> float:
    """
    Returns the norm of the gradients of the given parameters.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        norm_type (float): type of the used p-norm.

    Returns:
        Total norm of the parameter gradients (viewed as a single vector).
    """  # noqa: E501

    if isinstance(parameters, Tensor):
        parameters = [parameters]

    grads = [p.grad for p in parameters if p.grad is not None]
    if len(grads) == 0:
        return None

    first_device = grads[0].device
    grouped_grads: dict[
        tuple[torch.device, torch.dtype], list[list[Tensor]]
    ] = _group_tensors_by_device_and_dtype(
        [[g.detach() for g in grads]]
    )  # type: ignore[assignment]

    norms = []
    for (device, _), ([grads], _) in grouped_grads.items():
        if _has_foreach_support(grads, device=device):
            norms.extend(torch._foreach_norm(grads, norm_type))
        else:
            norms.extend([torch.norm(g, norm_type) for g in grads])

    return torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)


class GradNormMonitor(Callback):
    """
    Callback that logs the gradient norm of the model parameters after every
    backward pass.

    The norm over all parameters is logged as ``train/grad_norm``. If
    sub-modules are requested, one value per sub-module is logged as
    ``train/<sub_module>/grad_norm`` instead.
    """

    # Accepted values for ``logging_interval``.
    _VALID_INTERVALS = ("step", "epoch")

    def __init__(
        self,
        norm_type: float = 2.0,
        logging_interval: str = "step",
        sub_module: Optional[Union[str, list[str]]] = None,
    ) -> None:
        """
        Args:
            norm_type (float): type of the used p-norm.
            logging_interval (str): "step" or "epoch". "step" logs with
                ``on_step=True, on_epoch=False``; "epoch" logs with
                ``on_step=False, on_epoch=True``.
            sub_module (str or list[str], optional): attribute name(s) of
                sub-modules of the LightningModule whose gradient norms are
                logged separately. Dotted paths such as ``"encoder.layer1"``
                are resolved. If ``None``, one norm over all model
                parameters is logged.

        Raises:
            ValueError: if ``logging_interval`` is neither "step" nor "epoch".
        """
        super().__init__()

        # Reject typos early instead of silently falling back to epoch logging.
        if logging_interval not in self._VALID_INTERVALS:
            raise ValueError(
                f"logging_interval must be one of {self._VALID_INTERVALS}, "
                f"got {logging_interval!r}"
            )

        self.norm_type = norm_type
        self.logging_interval = logging_interval
        self.sub_module = sub_module

    def on_after_backward(self, trainer: Trainer, model: LightningModule) -> None:
        """
        Computes the gradient norm of the model parameters and logs it to the logger.

        Args:
            trainer (Trainer): The trainer object
            model (LightningModule): The current lightningModule
        """
        if self.sub_module is None:
            self.log_sub_module_grad_norm(model, model, "")
            return

        sub_modules = (
            [self.sub_module] if isinstance(self.sub_module, str) else self.sub_module
        )

        for sub_module in sub_modules:
            # get_submodule resolves dotted paths. For a plain name it returns
            # the same module getattr would, and it raises AttributeError if
            # the name is missing.
            self.log_sub_module_grad_norm(
                model, model.get_submodule(sub_module), f"/{sub_module}"
            )

    def log_sub_module_grad_norm(
        self, lightning_model: LightningModule, model: nn.Module, path: str
    ) -> None:
        """
        Logs the gradient norm of ``model``'s parameters through ``lightning_model``.

        Nothing is logged if none of the parameters currently has a gradient.

        Args:
            lightning_model (LightningModule): module whose ``log`` is used.
            model (nn.Module): module whose parameter gradients are measured.
            path (str): key fragment inserted as ``train{path}/grad_norm``.
        """
        grad_norm_val = grad_norm(model.parameters(), self.norm_type)
        if grad_norm_val is None:
            return

        on_step = self.logging_interval == "step"
        lightning_model.log(
            f"train{path}/grad_norm",
            grad_norm_val,
            on_step=on_step,
            on_epoch=not on_step,
        )