File size: 4,251 Bytes
25369b9
 
4c39b84
 
67cda2a
25369b9
 
 
 
88a44d0
25369b9
88a44d0
d115284
25369b9
 
 
 
 
 
d115284
25369b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9814918
25369b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67cda2a
d115284
 
d72c970
 
 
 
5e50a13
d72c970
 
 
 
 
 
 
 
 
0d6ff10
d72c970
 
 
 
 
 
 
 
d115284
 
 
25369b9
 
 
5e50a13
25369b9
 
 
 
 
 
 
 
 
 
 
 
d115284
25369b9
9814918
25369b9
 
d115284
 
 
 
67cda2a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch
from tuned_lens.nn.lenses import TunedLens, LogitLens
from transformers import AutoModelForCausalLM, AutoTokenizer
from tuned_lens.plotting import plot_lens
import gradio as gr
from plotly import graph_objects as go

# Inference is pinned to the CPU; the device is fixed, not auto-detected.
device = torch.device("cpu")
print(f"Using device {device} for inference")

# Hugging Face hub id of the inspected model; the model and its tokenizer
# are loaded from this same checkpoint.
_MODEL_ID = "EleutherAI/pythia-410m-deduped"
model = AutoModelForCausalLM.from_pretrained(_MODEL_ID).to(device)
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)

# The tuned lens is loaded by its short name onto the same device; the
# logit lens is built directly from the loaded model.
tuned_lens = TunedLens.load("pythia-410m-deduped", map_location=device)
logit_lens = LogitLens(model)

# UI label -> lens object. Insertion order is the order shown in the
# "Select Lens" dropdown.
lens_options_dict = {"Tuned Lens": tuned_lens, "Logit Lens": logit_lens}

# UI label -> statistic identifier passed through to ``plot_lens``.
statistic_options_dict = dict(
    zip(
        ("Entropy", "Cross Entropy", "Forward KL"),
        ("entropy", "ce", "forward_kl"),
    )
)


def _message_figure(message):
    """Return an empty plotly figure whose title shows ``message`` to the user."""
    return go.Figure(layout=dict(title=message))


def make_plot(lens, text, statistic, token_cutoff):
    """Plot a lens statistic over the last ``token_cutoff`` tokens of ``text``.

    Args:
        lens: Dropdown label selecting the lens; a key of ``lens_options_dict``.
        text: Raw input text, tokenized with the module-level ``tokenizer``.
        statistic: Dropdown label selecting the statistic; a key of
            ``statistic_options_dict``.
        token_cutoff: Number of trailing tokens to plot. Accepts an int or a
            float (as a UI slider may deliver it); it is truncated to an int.

    Returns:
        The figure produced by ``plot_lens``, or an empty figure whose title
        explains why the inputs were rejected.
    """
    # A cleared dropdown yields a value that is not one of the option keys;
    # report it to the user instead of raising KeyError inside the callback.
    if lens not in lens_options_dict:
        return _message_figure("Please select a lens.")
    if statistic not in statistic_options_dict:
        return _message_figure("Please select a statistic.")

    if not text:
        return _message_figure("Please enter some text.")
    input_ids = tokenizer.encode(text, return_tensors="pt")
    num_tokens = input_ids.shape[-1]
    if num_tokens == 0:
        return _message_figure("Please enter some text.")

    # Coerce to int so ``start_pos`` below is a valid integer index even if
    # the slider hands back a float; missing/non-numeric values are rejected.
    try:
        token_cutoff = int(token_cutoff)
    except (TypeError, ValueError, OverflowError):
        token_cutoff = 0
    if token_cutoff < 1:
        return _message_figure("Please provide valid token cut off.")

    return plot_lens(
        model,
        tokenizer,
        lens_options_dict[lens],
        layer_stride=2,
        input_ids=input_ids,
        # Plot only the trailing ``token_cutoff`` positions, or all of them
        # when the text is shorter than the cutoff.
        start_pos=max(num_tokens - token_cutoff, 0),
        statistic=statistic_options_dict[statistic],
    )


preamble = """
# The Tuned Lens 🔎

A tuned lens allows us to peek at the iterative computations a transformer uses to compute the next token.

A lens into a transformer with $n$ layers allows you to replace the last $m$ layers of the model with an [affine transformation](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) (we call these affine translators).

This essentially skips over these last few layers and lets you see the best prediction that can be made from the model's representations, i.e. the residual stream, at layer $n - m$. Since the representations may be rotated, shifted, or stretched from layer to layer, it's useful to train the lens's affine translators specifically for each layer. This training is what differentiates this method from simpler approaches that decode the residual stream of the network directly using the unembedding layer, i.e. the logit lens. We explain this process in [the paper](https://arxiv.org/abs/2303.08112).

## Usage
Since the tuned lens produces a distribution of predictions, we need to provide a summary statistic to visualize its output. The default is simply [entropy](https://en.wikipedia.org/wiki/Entropy_(information_theory)), but you can also choose the [cross entropy](https://en.wikipedia.org/wiki/Cross_entropy) with the target token, or the [KL divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence) between the model's predictions and the tuned lens' predictions. You can also hover over a token to see more of the distribution, i.e. the top 10 most probable tokens and their probabilities.

## Examples
Here are some interesting examples you can try.

### Copy-paste
```
Copy: A!2j!#u&NGApS&MkkHe8Gm!#
Paste: A!2j!#u&NGApS&MkkHe8Gm!#
```

### Trivial in-context learning
```
inc 1 2
inc 4 5
inc 13 
```

### Addition
```
add 1 1 2
add 3 4 7
add 13 2 
```
"""

with gr.Blocks() as demo:
    gr.Markdown(preamble)
    with gr.Column():
        text = gr.Textbox(
            label="Input Text",
            value="it was the best of times, it was the worst of times",
        )
        with gr.Row():
            # Dropdown choices come from the option dicts' keys, in order.
            lens_options = gr.Dropdown(
                choices=list(lens_options_dict),
                value="Tuned Lens",
                label="Select Lens",
            )
            statistic = gr.Dropdown(
                choices=list(statistic_options_dict),
                value="Entropy",
                label="Select Statistic",
            )
            token_cutoff = gr.Slider(
                minimum=2,
                maximum=20,
                step=1,
                value=10,
                label="Plot Last N Tokens",
            )
        examine_btn = gr.Button(value="Submit")
        plot = gr.Plot()

    # The same inputs drive both the Submit button and the page-load event,
    # so a plot of the default text is rendered as soon as the app opens.
    _plot_inputs = [lens_options, text, statistic, token_cutoff]
    examine_btn.click(make_plot, _plot_inputs, plot)
    demo.load(make_plot, _plot_inputs, plot)

if __name__ == "__main__":
    demo.launch()