"""Functions to help with searching codes using regex."""

import pickle
import re
from dataclasses import dataclass
from typing import Optional

import numpy as np
import torch
from tqdm import tqdm

import utils


def load_dataset_cache(cache_base_path):
    """Load cache files required for dataset from `cache_base_path`."""
    tokens_str = np.load(cache_base_path + "tokens_str.npy")
    tokens_text = np.load(cache_base_path + "tokens_text.npy")
    token_byte_pos = np.load(cache_base_path + "token_byte_pos.npy")
    return tokens_str, tokens_text, token_byte_pos


def load_code_search_cache(cache_base_path):
    """Load cache files required for code search from `cache_base_path`."""
    metrics = np.load(cache_base_path + "metrics.npy", allow_pickle=True).item()
    with open(cache_base_path + "cb_acts.pkl", "rb") as f:
        cb_acts = pickle.load(f)
    with open(cache_base_path + "act_count_ft_tkns.pkl", "rb") as f:
        act_count_ft_tkns = pickle.load(f)

    return cb_acts, act_count_ft_tkns, metrics

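# Both loaders take a common path *prefix* rather than a directory, so the cache
# files are expected to live at `<cache_base_path><file>`. A minimal usage sketch
# (the prefix is hypothetical):
#
#   tokens_str, tokens_text, token_byte_pos = load_dataset_cache("cache/run1_")
#   cb_acts, act_count_ft_tkns, metrics = load_code_search_cache("cache/run1_")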

def search_re(re_pattern, tokens_text):
    """Get list of (example_id, byte_offset) pairs where re_pattern matches in tokens_text."""
    # TODO: ensure that parentheses in the pattern are not escaped
    if re_pattern.find("(") == -1:
        re_pattern = f"({re_pattern})"
    return [
        (i, match.span(1)[0])
        for i, text in enumerate(tokens_text)
        for match in re.finditer(re_pattern, text)
        if match.span(1)[0] != match.span(1)[1]  # skip empty matches
    ]

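# A minimal commented example of `search_re` (hypothetical inputs): a pattern
# without a capture group is wrapped in one, and matches are reported as
# (example_id, byte_offset) of the group's start; empty matches are dropped:
#
#   search_re(r"cat", ["the cat sat", "dog"])  # -> [(0, 4)]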

def byte_id_to_token_pos_id(example_byte_id, token_byte_pos):
    """Get (example_id, token_pos_id) for given (example_id, byte_id)."""
    example_id, byte_id = example_byte_id
    index = np.searchsorted(token_byte_pos[example_id], byte_id, side="right")
    return (example_id, index)

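# Commented sketch of the byte-to-token mapping (hypothetical offsets), assuming
# `token_byte_pos[example_id]` holds the cumulative byte end-offset of each token:
#
#   token_byte_pos = [np.array([4, 8, 12])]  # example 0 has 3 tokens
#   byte_id_to_token_pos_id((0, 4), token_byte_pos)  # -> (0, 1)
#
# `side="right"` maps a byte offset sitting exactly on a token boundary to the
# token that starts there.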

def get_code_pr(token_pos_ids, codebook_acts, cb_act_counts=None):
    """Get codes, precision, and recall for given token_pos_ids and codebook_acts."""
    codes = np.array(
        [
            codebook_acts[example_id][token_pos_id]
            for example_id, token_pos_id in token_pos_ids
        ]
    )
    codes, counts = np.unique(codes, return_counts=True)
    recall = counts / len(token_pos_ids)
    # discard codes that fire on at most 1% of the matched tokens
    idx = recall > 0.01
    codes, counts, recall = codes[idx], counts[idx], recall[idx]
    if cb_act_counts is not None:
        # precision: fraction of a code's total activations that fall on matched tokens
        code_acts = np.array([cb_act_counts[code] for code in codes])
        prec = counts / code_acts
        sort_idx = np.argsort(prec)[::-1]
    else:
        code_acts = np.zeros_like(codes)
        prec = np.zeros_like(codes)
        sort_idx = np.argsort(recall)[::-1]
    codes, prec, recall = codes[sort_idx], prec[sort_idx], recall[sort_idx]
    code_acts = code_acts[sort_idx]
    return codes, prec, recall, code_acts

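# Interpretation of the values returned by `get_code_pr`, with made-up numbers:
# if a pattern matches 200 tokens, a code fires on 150 of them, and that code
# fires on 300 tokens across the whole dataset, then
#
#   recall = 150 / 200 = 0.75  # fraction of pattern tokens covered by the code
#   prec   = 150 / 300 = 0.50  # fraction of the code's firings on pattern tokens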

def get_neuron_pr(
    token_pos_ids, recall, neuron_acts_by_ex, neuron_sorted_acts, topk=10
):
    """Get the best neuron (by precision at the target recall) for given token_pos_ids."""
    # check if neuron_acts_by_ex is a torch tensor
    if isinstance(neuron_acts_by_ex, torch.Tensor):
        re_neuron_acts = torch.stack(
            [
                neuron_acts_by_ex[example_id, token_pos_id]
                for example_id, token_pos_id in token_pos_ids
            ],
            dim=-1,
        )  # (layers, 2, dim_size, matches)
        re_neuron_acts = torch.sort(re_neuron_acts, dim=-1).values
    else:
        re_neuron_acts = np.stack(
            [
                neuron_acts_by_ex[example_id, token_pos_id]
                for example_id, token_pos_id in token_pos_ids
            ],
            axis=-1,
        )  # (layers, 2, dim_size, matches)
        re_neuron_acts.sort(axis=-1)
        re_neuron_acts = torch.from_numpy(re_neuron_acts)
    # re_neuron_acts = re_neuron_acts[:, :, :, -int(recall * re_neuron_acts.shape[-1]) :]
    print("Examples for recall", recall, ":", int(recall * re_neuron_acts.shape[-1]))
    act_thresh = re_neuron_acts[:, :, :, -int(recall * re_neuron_acts.shape[-1])]
    # binary search act_thresh in neuron_sorted_acts
    assert neuron_sorted_acts.shape[:-1] == act_thresh.shape
    prec_den = torch.searchsorted(neuron_sorted_acts, act_thresh.unsqueeze(-1))
    prec_den = prec_den.squeeze(-1)
    prec_den = neuron_sorted_acts.shape[-1] - prec_den
    prec = int(recall * re_neuron_acts.shape[-1]) / prec_den
    assert (
        prec.shape == re_neuron_acts.shape[:-1]
    ), f"{prec.shape} != {re_neuron_acts.shape[:-1]}"

    best_neuron_idx = np.unravel_index(prec.argmax(), prec.shape)
    best_prec = prec[best_neuron_idx]
    print("max prec:", best_prec)
    best_neuron_act_thresh = act_thresh[best_neuron_idx].item()
    best_neuron_acts = neuron_acts_by_ex[
        :, :, best_neuron_idx[0], best_neuron_idx[1], best_neuron_idx[2]
    ]
    best_neuron_acts = best_neuron_acts >= best_neuron_act_thresh
    best_neuron_acts = np.stack(np.where(best_neuron_acts), axis=-1)

    return best_prec, best_neuron_acts, best_neuron_idx

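# The precision-at-recall computation above, in words (made-up numbers): for a
# target recall of 0.5 over 100 matched tokens, each neuron's threshold is its
# 50th-highest activation on the matched tokens. If a neuron clears that
# threshold on 250 tokens corpus-wide (counted via binary search in
# `neuron_sorted_acts`), its precision is 50 / 250 = 0.2; the neuron with the
# highest precision is returned.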

def convert_to_adv_name(name, cb_at, ccb=""):
    """Convert layer0_head0 to layer0_attn_preproj_ccb0."""
    if ccb:
        layer, head = name.split("_")
        return layer + f"_{cb_at}_ccb" + head[4:]
    else:
        return layer + "_" + cb_at


def convert_to_base_name(name, ccb=""):
    """Convert layer0_attn_preproj_ccb0 to layer0_head0."""
    split_name = name.split("_")
    layer, head = split_name[0], split_name[-1][3:]
    if "ccb" in name:
        return layer + "_head" + head
    else:
        return layer


def get_layer_head_from_base_name(name):
    """Convert layer0_head0 to 0, 0."""
    split_name = name.split("_")
    layer = int(split_name[0][5:])
    head = None
    if len(split_name) > 1:
        head = int(split_name[-1][4:])
    return layer, head


def get_layer_head_from_adv_name(name):
    """Convert layer0_attn_preproj_ccb0 to 0, 0."""
    base_name = convert_to_base_name(name)
    layer, head = get_layer_head_from_base_name(base_name)
    return layer, head

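# Round trip of the naming helpers (the codebook placement "attn_preproj" is a
# hypothetical value for `cb_at`):
#
#   convert_to_adv_name("layer0_head0", "attn_preproj", ccb="ccb")
#   # -> "layer0_attn_preproj_ccb0"
#   convert_to_base_name("layer0_attn_preproj_ccb0")  # -> "layer0_head0"
#   get_layer_head_from_base_name("layer0_head0")     # -> (0, 0)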

def get_codes_from_pattern(
    re_pattern,
    tokens_text,
    token_byte_pos,
    cb_acts,
    act_count_ft_tkns,
    ccb="",
    topk=5,
    prec_threshold=0.5,
):
    """Fetch codes from a given regex pattern."""
    byte_ids = search_re(re_pattern, tokens_text)
    token_pos_ids = [
        byte_id_to_token_pos_id(ex_byte_id, token_byte_pos) for ex_byte_id in byte_ids
    ]
    token_pos_ids = np.unique(token_pos_ids, axis=0)
    re_token_matches = len(token_pos_ids)
    codebook_wise_codes = {}
    for cb_name, cb in tqdm(cb_acts.items()):
        base_cb_name = convert_to_base_name(cb_name, ccb=ccb)
        codes, prec, recall, code_acts = get_code_pr(
            token_pos_ids,
            cb,
            cb_act_counts=act_count_ft_tkns[base_cb_name],
        )
        idx = np.arange(min(topk, len(codes)))
        idx = idx[prec[:topk] > prec_threshold]
        codes, prec, recall = codes[idx], prec[idx], recall[idx]
        code_acts = code_acts[idx]
        codes_pr = list(zip(codes, prec, recall, code_acts))
        codebook_wise_codes[base_cb_name] = codes_pr
    return codebook_wise_codes, re_token_matches

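# A commented sketch of a typical call (hypothetical arguments): the result maps
# each base codebook name to its top codes that clear `prec_threshold`:
#
#   codes, n_matches = get_codes_from_pattern(
#       r"cat", tokens_text, token_byte_pos, cb_acts, act_count_ft_tkns
#   )
#   # codes[base_cb_name] -> [(code, prec, recall, code_acts), ...]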

def get_neurons_from_pattern(
    re_pattern,
    tokens_text,
    token_byte_pos,
    neuron_acts_by_ex,
    neuron_sorted_acts,
    recall_threshold,
):
    """Fetch the best neuron (with act thresh given by recall) from a given regex pattern."""
    byte_ids = search_re(re_pattern, tokens_text)
    token_pos_ids = [
        byte_id_to_token_pos_id(ex_byte_id, token_byte_pos) for ex_byte_id in byte_ids
    ]
    token_pos_ids = np.unique(token_pos_ids, axis=0)
    re_token_matches = len(token_pos_ids)
    best_prec, best_neuron_acts, best_neuron_idx = get_neuron_pr(
        token_pos_ids,
        recall_threshold,
        neuron_acts_by_ex,
        neuron_sorted_acts,
    )
    return best_prec, best_neuron_acts, best_neuron_idx, re_token_matches


def compare_codes_with_neurons(
    best_codes_info,
    tokens_text,
    token_byte_pos,
    neuron_acts_by_ex,
    neuron_sorted_acts,
):
    """Compare codes with neurons."""
    assert isinstance(neuron_acts_by_ex, np.ndarray)
    (
        all_best_prec,
        all_best_neuron_acts,
        all_best_neuron_idxs,
        all_re_token_matches,
    ) = zip(
        *[
            get_neurons_from_pattern(
                code_info.re_pattern,
                tokens_text,
                token_byte_pos,
                neuron_acts_by_ex,
                neuron_sorted_acts,
                code_info.recall,
            )
            for code_info in tqdm(best_codes_info)
        ],
        strict=True,
    )
    code_best_precs = np.array(
        [code_info.prec for code_info in best_codes_info]
    )
    codes_better_than_neurons = code_best_precs > np.array(all_best_prec)
    return codes_better_than_neurons.mean()


def get_code_info_pr_from_str(code_txt, regex):
    """Extract code info fields from string."""
    code_txt = code_txt.strip()
    code_txt = code_txt.split(", ")
    code_txt = dict(txt.split(": ") for txt in code_txt)
    return utils.CodeInfo(**code_txt)


@dataclass
class ModelInfoForWebapp:
    """Model info for webapp."""

    model_name: str
    pretrained_path: str
    dataset_name: str
    num_codes: int
    cb_at: str
    ccb: str
    n_layers: int
    n_heads: Optional[int] = None
    seed: int = 42
    max_samples: int = 2000

    def __post_init__(self):
        """Convert to correct types."""
        self.num_codes = int(self.num_codes)
        self.n_layers = int(self.n_layers)
        if self.n_heads == "None":
            self.n_heads = None
        elif self.n_heads is not None:
            self.n_heads = int(self.n_heads)
        self.seed = int(self.seed)
        self.max_samples = int(self.max_samples)


def parse_model_info(path):
    """Parse model info from path."""
    with open(path + "info.txt", "r") as f:
        lines = f.readlines()
        lines = dict(line.strip().split(": ") for line in lines)
        return ModelInfoForWebapp(**lines)
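
# `parse_model_info` expects `info.txt` to hold one "key: value" pair per line,
# keyed by the ModelInfoForWebapp fields, e.g. (all values hypothetical):
#
#   model_name: my-model
#   pretrained_path: /path/to/checkpoint
#   dataset_name: my-dataset
#   num_codes: 1024
#   cb_at: attn_preproj
#   ccb: ccb
#   n_layers: 12
#   n_heads: 12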