---
license: mit
---

# Gemma 2b Residual Stream SAEs

This is a "quick and dirty" SAE release to unblock researchers. These SAEs have not been extensively studied or characterized. However, I will try to update the readme here when I add SAEs here to reflect what I know about them.

These SAEs were trained with SAE Lens, and the library version used is recorded in cfg.json.

All training hyperparameters are also specified in cfg.json.
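
If you want to inspect a checkpoint's configuration before loading it, the file is plain JSON. A minimal sketch (the exact key names depend on your SAE Lens version, so treat them as illustrative):

```python
import json

# cfg.json sits alongside the .safetensors weights in each SAE's folder.
with open("path/to/folder_containing_cfgjson_and_safetensors_file/cfg.json") as f:
    cfg = json.load(f)

# Print every recorded hyperparameter; key names vary by SAE Lens version.
for key, value in sorted(cfg.items()):
    print(key, "=", value)
```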

They can be loaded with SAE Lens in a few ways. One method that currently works (though it may soon be replaced by a more convenient one) is the following:

```python
import torch
from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader

# Inference only: no gradients needed.
torch.set_grad_enabled(False)

# Point this at the folder holding cfg.json and the .safetensors weights.
path = "path/to/folder_containing_cfgjson_and_safetensors_file"
model, sae, activation_store = LMSparseAutoencoderSessionloader.load_pretrained_sae(
    path, device="cuda",
)
```
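
Once loaded, a quick way to exercise the SAE is to grab residual-stream activations with TransformerLens and run them through it. A hedged sketch: the hook name below assumes the Resid Post 0 SAE (check the hook point recorded in cfg.json), and the SAE's forward signature varies across SAE Lens versions.

```python
# Tokenize a prompt and cache activations at the SAE's hook point.
tokens = model.to_tokens("The quick brown fox jumps over the lazy dog.")
_, cache = model.run_with_cache(tokens)
acts = cache["blocks.0.hook_resid_post"]  # [batch, seq, d_model]

# Assumption: the SAE behaves as a torch module whose forward returns a
# reconstruction (possibly inside a tuple); check your SAE Lens version.
sae_out = sae(acts)
```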

## Resid Post 0

Stats:

- 16384 features (expansion factor 8)
- CE Loss score of 99.1% (2.647 without the SAE, 2.732 with the SAE)
- Mean L0 of 54 (in practice, L0 is log-normally distributed and heavily right-tailed)
- Dead features: we think this SAE may have ~2.5k dead features (a sketch of how L0 and dead features can be measured follows this list)
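
For concreteness, here is a minimal sketch of how stats like these can be measured, assuming `feature_acts` is a `[n_tokens, n_features]` tensor of SAE feature activations collected over some evaluation tokens (the tensor name and collection loop are hypothetical):

```python
import torch

# L0 = number of active (nonzero) features per token.
l0_per_token = (feature_acts > 0).sum(dim=-1).float()
print("mean L0:", l0_per_token.mean().item())

# "Dead" features never fire anywhere on the evaluation set.
alive = (feature_acts > 0).any(dim=0)
print("dead features:", (~alive).sum().item())
```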

Notes:

- This SAE was trained with methods from the Anthropic April Update, with the exception of activation normalization (see the loss sketch after this list).
- It is likely under-trained.
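
For reference, here is a hedged sketch of an April-update-style loss: MSE reconstruction plus an L1 penalty on feature activations weighted by decoder column norms, with no activation normalization. This is an illustration of the technique, not the exact SAE Lens implementation, and the `l1_coefficient` value is hypothetical:

```python
import torch

def sae_loss(x, sae_out, feature_acts, W_dec, l1_coefficient=5.0):
    # Reconstruction term: squared error summed over the model dimension.
    mse = (sae_out - x).pow(2).sum(dim=-1).mean()
    # Sparsity term: L1 on feature activations, weighted by each feature's
    # decoder direction norm (W_dec assumed to be [n_features, d_model]).
    decoder_norms = W_dec.norm(dim=-1)
    sparsity = (feature_acts.abs() * decoder_norms).sum(dim=-1).mean()
    return mse + l1_coefficient * sparsity
```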

## Resid Post 6

Stats:

- 16384 features (expansion factor 8)
- CE Loss score of 95.33% (2.647 without the SAE, 3.103 with the SAE)
- Mean L0 of 53 (in practice, L0 is log-normally distributed and heavily right-tailed)
- Dead features: we think this SAE may have up to 7k dead features

Notes:

- This SAE was trained with methods from the Anthropic April Update
  - Excepting activation normalization.
  - We increased the learning rate by one order of magnitude to explore whether this resulted in faster training (in particular, a lower L0 more quickly).
    - In practice, we find that the drop in L0 is accelerated, but this results in significantly more dead features (likely causing worse reconstruction).
- As above, it is likely under-trained.