"""Util functions for codebook features.""" import re import typing from dataclasses import dataclass from functools import partial from typing import Optional import numpy as np import torch import torch.nn.functional as F from termcolor import colored from tqdm import tqdm @dataclass class CodeInfo: """Dataclass for codebook info.""" code: int layer: int head: Optional[int] cb_at: Optional[str] = None # for patching interventions pos: Optional[int] = None code_pos: Optional[int] = -1 # for description & regex-based interpretation description: Optional[str] = None regex: Optional[str] = None prec: Optional[float] = None recall: Optional[float] = None num_acts: Optional[int] = None def __post_init__(self): """Convert to appropriate types.""" self.code = int(self.code) self.layer = int(self.layer) if self.head: self.head = int(self.head) if self.pos: self.pos = int(self.pos) if self.code_pos: self.code_pos = int(self.code_pos) if self.prec: self.prec = float(self.prec) assert 0 <= self.prec <= 1 if self.recall: self.recall = float(self.recall) assert 0 <= self.recall <= 1 if self.num_acts: self.num_acts = int(self.num_acts) def check_description_info(self): """Check if the regex info is present.""" assert self.num_acts is not None and self.description is not None if self.regex is not None: assert self.prec is not None and self.recall is not None def check_patch_info(self): """Check if the patch info is present.""" # TODO: pos can be none for patching assert self.pos is not None and self.code_pos is not None def __repr__(self): """Return the string representation.""" repr = f"CodeInfo(code={self.code}, layer={self.layer}, head={self.head}, cb_at={self.cb_at}" if self.pos is not None or self.code_pos is not None: repr += f", pos={self.pos}, code_pos={self.code_pos}" if self.description is not None: repr += f", description={self.description}" if self.regex is not None: repr += f", regex={self.regex}, prec={self.prec}, recall={self.recall}" if self.num_acts is not None: repr += f", num_acts={self.num_acts}" repr += ")" return repr def logits_to_pred(logits, tokenizer, k=5): """Convert logits to top-k predictions.""" sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) probs = sorted_logits.softmax(dim=-1) topk_preds = [tokenizer.convert_ids_to_tokens(e) for e in sorted_indices[:, -1, :k]] topk_preds = [ tokenizer.convert_tokens_to_string([e]) for batch in topk_preds for e in batch ] return [(topk_preds[i], probs[:, -1, i].item()) for i in range(len(topk_preds))] def patch_codebook_ids( corrupted_codebook_ids, hook, pos, cache, cache_pos=None, code_idx=None ): """Patch codebook ids with cached ids.""" if cache_pos is None: cache_pos = pos if code_idx is None: corrupted_codebook_ids[:, pos] = cache[hook.name][:, cache_pos] else: for code_id in range(32): if code_id in code_idx: corrupted_codebook_ids[:, pos, code_id] = cache[hook.name][ :, cache_pos, code_id ] else: corrupted_codebook_ids[:, pos, code_id] = -1 return corrupted_codebook_ids def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False): """Calculate the average logit difference between the answer and the other token.""" # Only the final logits are relevant for the answer final_logits = logits[:, -1, :] answer_logits = final_logits.gather(dim=-1, index=answer_tokens) answer_logit_diff = answer_logits[:, 0] - answer_logits[:, 1] if per_prompt: return answer_logit_diff else: return answer_logit_diff.mean() def normalize_patched_logit_diff( patched_logit_diff, base_average_logit_diff, 
def normalize_patched_logit_diff(
    patched_logit_diff,
    base_average_logit_diff,
    corrupted_average_logit_diff,
):
    """Normalize the patched logit difference."""
    # Subtract the corrupted logit diff to measure the improvement, then divide
    # by the total improvement from corrupted to clean to normalise:
    # 0 means no change, negative means actively made worse,
    # 1 means fully recovered clean performance, and >1 means improved beyond
    # clean performance.
    return (patched_logit_diff - corrupted_average_logit_diff) / (
        base_average_logit_diff - corrupted_average_logit_diff
    )


def features_to_tokens(cb_key, cb_acts, num_codes, code=None):
    """Return the (example, position) pairs that each codebook feature activates on."""
    codebook_ids = cb_acts[cb_key]
    if code is None:
        features_tokens = [[] for _ in range(num_codes)]
        for i in tqdm(range(codebook_ids.shape[0])):
            for j in range(codebook_ids.shape[1]):
                for k in range(codebook_ids.shape[2]):
                    features_tokens[codebook_ids[i, j, k]].append((i, j))
    else:
        idx0, idx1, _ = np.where(codebook_ids == code)
        features_tokens = list(zip(idx0, idx1))
    return features_tokens


def color_str(s: str, color: str, html: bool):
    """Color the string for html or terminal."""
    if html:
        return f"<span style='color:{color}'>{s}</span>"
    else:
        return colored(s, color)


def color_tokens_red_automata(tokens, red_idx, html=False):
    """Separate states with a dash and color red the tokens in red_idx."""
    ret_string = ""
    itr_over_red_idx = 0
    tokens_enumerate = enumerate(tokens)
    if tokens[0] == "<|endoftext|>":
        next(tokens_enumerate)
        if red_idx[0] == 0:
            itr_over_red_idx += 1
    for i, c in tokens_enumerate:
        if i % 2 == 1:
            ret_string += "-"
        if itr_over_red_idx < len(red_idx) and i == red_idx[itr_over_red_idx]:
            ret_string += color_str(c, "red", html)
            itr_over_red_idx += 1
        else:
            ret_string += c
    return ret_string


def color_tokens_red(tokens, red_idx, n=3, html=False):
    """Color red the tokens in red_idx, keeping up to `n` context tokens."""
    ret_string = ""
    last_colored_token_idx = -1
    for i in red_idx:
        c_str = tokens[i]
        if i <= last_colored_token_idx + 2 * n + 1:
            ret_string += "".join(tokens[last_colored_token_idx + 1 : i])
        else:
            ret_string += "".join(
                tokens[last_colored_token_idx + 1 : last_colored_token_idx + n + 1]
            )
            ret_string += " ... "
            ret_string += "".join(tokens[i - n : i])
        ret_string += color_str(c_str, "red", html)
        last_colored_token_idx = i
    ret_string += "".join(
        tokens[
            last_colored_token_idx + 1 : min(last_colored_token_idx + n, len(tokens))
        ]
    )
    return ret_string
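
# Illustrative sketch for `color_tokens_red` (the token strings are made up):
# tokens at indices 2 and 10 are colored red, up to `n` context tokens are kept
# around each match, and " ... " marks the elided gap in between.
#
#   tokens = [f" tok{i}" for i in range(20)]
#   print(color_tokens_red(tokens, red_idx=[2, 10], n=3))
#   # shows tok0-tok1, red tok2, tok3-tok5, " ... ", tok7-tok9, red tok10, tok11-tok12
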
" if html else "\n") ) return example_output def tkn_print( ll, tokens, separate_states, n=3, max_examples=100, randomize=False, html=False, return_example_list=False, ): """Format and prints the tokens in ll.""" if randomize: raise NotImplementedError("Randomize not yet implemented.") indices = range(len(ll)) print_output = [] if return_example_list else "" curr_ex = ll[0][0] total_examples = 0 tokens_to_color_red = [] color_red_fn = ( color_tokens_red_automata if separate_states else partial(color_tokens_red, n=n) ) for idx in indices: if total_examples > max_examples: break i, j = ll[idx] if i != curr_ex and curr_ex >= 0: curr_ex_output = prepare_example_print( curr_ex, tokens[curr_ex], tokens_to_color_red, html, color_red_fn, ) total_examples += 1 if return_example_list: print_output.append((curr_ex_output, len(tokens_to_color_red))) else: print_output += curr_ex_output curr_ex = i tokens_to_color_red = [] tokens_to_color_red.append(j) curr_ex_output = prepare_example_print( curr_ex, tokens[curr_ex], tokens_to_color_red, html, color_red_fn, ) if return_example_list: print_output.append((curr_ex_output, len(tokens_to_color_red))) else: print_output += curr_ex_output asterisk_str = "********************************************" print_output += color_str(asterisk_str, "green", html) total_examples += 1 return print_output def print_ft_tkns( ft_tkns, tokens, separate_states=False, n=3, start=0, stop=1000, indices=None, max_examples=100, freq_filter=None, randomize=False, html=False, return_example_list=False, ): """Print the tokens for the codebook features.""" indices = list(range(start, stop)) if indices is None else indices num_tokens = len(tokens) * len(tokens[0]) codes, token_act_freqs, token_acts = [], [], [] for i in indices: tkns = ft_tkns[i] freq = (len(tkns), 100 * len(tkns) / num_tokens) if freq_filter is not None and freq[1] > freq_filter: continue codes.append(i) token_act_freqs.append(freq) if len(tkns) > 0: tkn_acts = tkn_print( tkns, tokens, separate_states, n=n, max_examples=max_examples, randomize=randomize, html=html, return_example_list=return_example_list, ) token_acts.append(tkn_acts) else: token_acts.append("") return codes, token_act_freqs, token_acts def patch_in_codes(run_cb_ids, hook, pos, code, code_pos=None): """Patch in the `code` at `run_cb_ids`.""" pos = slice(None) if pos is None else pos code_pos = slice(None) if code_pos is None else code_pos if code_pos == "append": assert pos == slice(None) run_cb_ids = F.pad(run_cb_ids, (0, 1), mode="constant", value=code) if isinstance(pos, typing.Iterable) or isinstance(pos, typing.Iterable): for p in pos: run_cb_ids[:, p, code_pos] = code else: run_cb_ids[:, pos, code_pos] = code return run_cb_ids def get_cb_layer_name(cb_at, layer_idx, head_idx=None): """Get the layer name used to store hooks/cache.""" if head_idx is None: return f"blocks.{layer_idx}.{cb_at}.codebook_layer.hook_codebook_ids" else: return f"blocks.{layer_idx}.{cb_at}.codebook_layer.codebook.{head_idx}.hook_codebook_ids" def get_cb_layer_names(layer, patch_types, n_heads): """Get the layer names used to store hooks/cache.""" layer_names = [] attn_added, mlp_added = False, False if "attn_out" in patch_types: attn_added = True for head in range(n_heads): layer_names.append( f"blocks.{layer}.attn.codebook_layer.codebook.{head}.hook_codebook_ids" ) if "mlp_out" in patch_types: mlp_added = True layer_names.append(f"blocks.{layer}.mlp.codebook_layer.hook_codebook_ids") for patch_type in patch_types: # match patch_type of the pattern attn_\d_head_\d 
def get_cb_layer_names(layer, patch_types, n_heads):
    """Get the layer names used to store hooks/cache."""
    layer_names = []
    attn_added, mlp_added = False, False
    if "attn_out" in patch_types:
        attn_added = True
        for head in range(n_heads):
            layer_names.append(
                f"blocks.{layer}.attn.codebook_layer.codebook.{head}.hook_codebook_ids"
            )
    if "mlp_out" in patch_types:
        mlp_added = True
        layer_names.append(f"blocks.{layer}.mlp.codebook_layer.hook_codebook_ids")

    for patch_type in patch_types:
        # match patch_types of the form attn_<layer>_head_<head> and mlp_<layer>
        attn_head = re.match(r"attn_(\d+)_head_(\d+)", patch_type)
        if (not attn_added) and attn_head and attn_head[1] == str(layer):
            layer_names.append(
                f"blocks.{layer}.attn.codebook_layer.codebook.{attn_head[2]}.hook_codebook_ids"
            )
        mlp = re.match(r"mlp_(\d+)", patch_type)
        if (not mlp_added) and mlp and mlp[1] == str(layer):
            layer_names.append(f"blocks.{layer}.mlp.codebook_layer.hook_codebook_ids")

    return layer_names


def cb_layer_name_to_info(layer_name):
    """Get the layer info from the layer name."""
    layer_name_split = layer_name.split(".")
    layer_idx = int(layer_name_split[1])
    cb_at = layer_name_split[2]
    if cb_at == "mlp":
        head_idx = None
    else:
        head_idx = int(layer_name_split[5])
    return cb_at, layer_idx, head_idx


def get_hooks(code, cb_at, layer_idx, head_idx=None, pos=None):
    """Get the hooks for the codebook features."""
    hook_fns = [
        partial(patch_in_codes, pos=pos, code=code[i]) for i in range(len(code))
    ]
    return [
        (get_cb_layer_name(cb_at[i], layer_idx[i], head_idx[i]), hook_fns[i])
        for i in range(len(code))
    ]


def run_with_codes(
    input, cb_model, code, cb_at, layer_idx, head_idx=None, pos=None, prepend_bos=True
):
    """Run the model with the codebook features patched in."""
    hook_fns = [
        partial(patch_in_codes, pos=pos, code=code[i]) for i in range(len(code))
    ]
    cb_model.reset_codebook_metrics()
    cb_model.reset_hook_kwargs()
    fwd_hooks = [
        (get_cb_layer_name(cb_at[i], layer_idx[i], head_idx[i]), hook_fns[i])
        for i in range(len(cb_at))
    ]
    with cb_model.hooks(fwd_hooks, [], True, False) as hooked_model:
        patched_logits, patched_cache = hooked_model.run_with_cache(
            input, prepend_bos=prepend_bos
        )
    return patched_logits, patched_cache


def in_hook_list(list_of_arg_tuples, layer, head=None):
    """Check if the component specified by `layer` and `head` is in the `list_of_arg_tuples`."""
    # if head is not provided, then check the MLP components
    for arg_tuple in list_of_arg_tuples:
        if head is None:
            if arg_tuple.cb_at == "mlp" and arg_tuple.layer == layer:
                return True
        else:
            if (
                arg_tuple.cb_at == "attn"
                and arg_tuple.layer == layer
                and arg_tuple.head == head
            ):
                return True
    return False


def generate_with_codes(
    input,
    cb_model,
    list_of_code_infos=(),
    disable_other_comps=False,
    automata=None,
    generate_kwargs=None,
):
    """Model's generation with the codebook features patched in."""
    if generate_kwargs is None:
        generate_kwargs = {}
    hook_fns = [
        partial(patch_in_codes, pos=tupl.pos, code=tupl.code)
        for tupl in list_of_code_infos
    ]
    fwd_hooks = [
        (get_cb_layer_name(tupl.cb_at, tupl.layer, tupl.head), hook_fns[i])
        for i, tupl in enumerate(list_of_code_infos)
    ]
    cb_model.reset_hook_kwargs()
    if disable_other_comps:
        # disable every codebook component that is not being patched
        for layer, cb in cb_model.all_codebooks.items():
            for head_idx, head in enumerate(cb[0].codebook):
                if not in_hook_list(list_of_code_infos, layer, head_idx):
                    head.set_hook_kwargs(
                        disable_topk=1, disable_for_tkns=[-1], keep_k_codes=False
                    )
            if not in_hook_list(list_of_code_infos, layer):
                cb[1].set_hook_kwargs(
                    disable_topk=1, disable_for_tkns=[-1], keep_k_codes=False
                )
    with cb_model.hooks(fwd_hooks, [], True, False) as hooked_model:
        gen = hooked_model.generate(input, **generate_kwargs)
    return automata.seq_to_traj(gen)[0] if automata is not None else gen
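
# A hedged usage sketch for `generate_with_codes` (assumes `cb_model` is a
# codebook-wrapped HookedTransformer-style model and that code 123 exists on
# layer 3, head 2; `input_ids`, the code, and the layer/head are hypothetical):
#
#   code_infos = [CodeInfo(code=123, layer=3, head=2, cb_at="attn")]
#   gen = generate_with_codes(
#       input_ids, cb_model, code_infos, generate_kwargs={"max_new_tokens": 20}
#   )
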
def kl_div(logits1, logits2, pos=-1, reduction="batchmean"):
    """Calculate the KL divergence between the logits at `pos`."""
    logits1_last, logits2_last = logits1[:, pos, :], logits2[:, pos, :]
    # calculate the kl divergence between the clean and modified logits at `pos`
    return F.kl_div(
        F.log_softmax(logits1_last, dim=-1),
        F.log_softmax(logits2_last, dim=-1),
        log_target=True,
        reduction=reduction,
    )


def JSD(logits1, logits2, pos=-1, reduction="batchmean"):
    """Compute the Jensen-Shannon divergence between two distributions."""
    if len(logits1.shape) == 3:
        logits1, logits2 = logits1[:, pos, :], logits2[:, pos, :]

    probs1 = F.softmax(logits1, dim=-1)
    probs2 = F.softmax(logits2, dim=-1)
    total_m = (0.5 * (probs1 + probs2)).log()

    loss = 0.0
    loss += F.kl_div(
        total_m,
        F.log_softmax(logits1, dim=-1),
        log_target=True,
        reduction=reduction,
    )
    loss += F.kl_div(
        total_m,
        F.log_softmax(logits2, dim=-1),
        log_target=True,
        reduction=reduction,
    )
    return 0.5 * loss


def residual_stream_patching_hook(resid_pre, hook, cache, position: int):
    """Patch in the residual stream at `position` from `cache`."""
    clean_resid_pre = cache[hook.name]
    resid_pre[:, position, :] = clean_resid_pre[:, position, :]
    return resid_pre


def find_code_changes(cache1, cache2, pos=None):
    """Find the codebook codes that are different between the two caches."""
    for k in cache1.keys():
        if "codebook" in k:
            c1 = cache1[k][0, pos]
            c2 = cache2[k][0, pos]
            if not torch.all(c1 == c2):
                print(cb_layer_name_to_info(k), c1.tolist(), c2.tolist())


def common_codes_in_cache(cache_codes, threshold=0.0):
    """Get the common codes in the cache along with their frequency (in %)."""
    codes, counts = torch.unique(cache_codes, return_counts=True, sorted=True)
    counts = counts.float() * 100
    counts /= cache_codes.shape[1]
    counts, indices = torch.sort(counts, descending=True)
    codes = codes[indices]
    indices = counts > threshold
    codes, counts = codes[indices], counts[indices]
    return codes, counts


def parse_code_info_string(
    info_str: str, cb_at="attn", pos=None, code_pos=-1
) -> CodeInfo:
    """Parse the code info string.

    The format of the `info_str` is:
    `code: 0, layer: 0, head: 0, occ_freq: 0.0, train_act_freq: 0.0`.
    """
    code, layer, head, occ_freq, train_act_freq = info_str.split(", ")
    code = int(code.split(": ")[1])
    layer = int(layer.split(": ")[1])
    head = int(head.split(": ")[1]) if head else None
    # occ_freq and train_act_freq are parsed for validation but not stored
    occ_freq = float(occ_freq.split(": ")[1])
    train_act_freq = float(train_act_freq.split(": ")[1])
    return CodeInfo(code, layer, head, pos=pos, code_pos=code_pos, cb_at=cb_at)


def parse_concept_codes_string(info_str: str, pos=None, code_append=False):
    """Parse the concept codes string."""
    code_info_strs = info_str.strip().split("\n")
    concept_codes = []
    layer, head = None, None
    code_pos = "append" if code_append else -1
    for code_info_str in code_info_strs:
        concept_codes.append(
            parse_code_info_string(code_info_str, pos=pos, code_pos=code_pos)
        )
        if code_append:
            continue
        # consecutive codes on the same layer & head get successive code positions
        if layer == concept_codes[-1].layer and head == concept_codes[-1].head:
            code_pos -= 1
        else:
            code_pos = -1
        concept_codes[-1].code_pos = code_pos
        layer, head = concept_codes[-1].layer, concept_codes[-1].head
    return concept_codes
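
# Sanity-check sketch for `kl_div` and `JSD` on random logits (shapes follow
# the (batch, seq, vocab) convention used above):
#
#   logits_a, logits_b = torch.randn(2, 10, 50), torch.randn(2, 10, 50)
#   assert kl_div(logits_a, logits_a).abs().item() < 1e-6  # KL(p || p) == 0
#   assert JSD(logits_a, logits_b).item() >= 0             # JSD is non-negative
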
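
# Round-trip sketch for the parsers above (the info string is hypothetical but
# follows the documented format; occ_freq/train_act_freq are parsed and
# discarded):
#
#   s = "code: 42, layer: 3, head: 1, occ_freq: 0.5, train_act_freq: 0.2"
#   parse_code_info_string(s)
#   # -> CodeInfo(code=42, layer=3, head=1, cb_at=attn, pos=None, code_pos=-1)
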