import torch
import torchvision
import torch.nn.functional as F


def attn_cosine_sim(x, eps=1e-08):
    # x: [1, 1, t, d] -> pairwise cosine similarity between tokens, [1, t, t]
    x = x[0]  # TEMP: getting rid of redundant dimension, TBF
    norm1 = x.norm(dim=2, keepdim=True)
    factor = torch.clamp(norm1 @ norm1.permute(0, 2, 1), min=eps)
    sim_matrix = (x @ x.permute(0, 2, 1)) / factor
    return sim_matrix


class VitExtractor:
    """Extracts intermediate DINO ViT features (block outputs, attention maps,
    qkv projections) via forward hooks registered on the transformer blocks."""

    BLOCK_KEY = 'block'
    ATTN_KEY = 'attn'
    PATCH_IMD_KEY = 'patch_imd'
    QKV_KEY = 'qkv'
    KEY_LIST = [BLOCK_KEY, ATTN_KEY, PATCH_IMD_KEY, QKV_KEY]

    def __init__(self, model_name, device):
        self.model = torch.hub.load('facebookresearch/dino:main', model_name).to(device)
        self.model.eval()
        self.model_name = model_name
        self.hook_handlers = []
        self.layers_dict = {}
        self.outputs_dict = {}
        for key in VitExtractor.KEY_LIST:
            self.layers_dict[key] = []
            self.outputs_dict[key] = []
        self._init_hooks_data()

    def _init_hooks_data(self):
        # Hook every block (0-11) for each feature type and reset the outputs.
        self.layers_dict[VitExtractor.BLOCK_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
        self.layers_dict[VitExtractor.ATTN_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
        self.layers_dict[VitExtractor.QKV_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
        self.layers_dict[VitExtractor.PATCH_IMD_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
        for key in VitExtractor.KEY_LIST:
            self.outputs_dict[key] = []

    def _register_hooks(self):
        for block_idx, block in enumerate(self.model.blocks):
            if block_idx in self.layers_dict[VitExtractor.BLOCK_KEY]:
                self.hook_handlers.append(block.register_forward_hook(self._get_block_hook()))
            if block_idx in self.layers_dict[VitExtractor.ATTN_KEY]:
                self.hook_handlers.append(block.attn.attn_drop.register_forward_hook(self._get_attn_hook()))
            if block_idx in self.layers_dict[VitExtractor.QKV_KEY]:
                self.hook_handlers.append(block.attn.qkv.register_forward_hook(self._get_qkv_hook()))
            if block_idx in self.layers_dict[VitExtractor.PATCH_IMD_KEY]:
                self.hook_handlers.append(block.attn.register_forward_hook(self._get_patch_imd_hook()))

    def _clear_hooks(self):
        for handler in self.hook_handlers:
            handler.remove()
        self.hook_handlers = []

    def _get_block_hook(self):
        def _get_block_output(model, inp, output):
            self.outputs_dict[VitExtractor.BLOCK_KEY].append(output)

        return _get_block_output

    def _get_attn_hook(self):
        def _get_attn_output(model, inp, output):
            self.outputs_dict[VitExtractor.ATTN_KEY].append(output)

        return _get_attn_output

    def _get_qkv_hook(self):
        def _get_qkv_output(model, inp, output):
            self.outputs_dict[VitExtractor.QKV_KEY].append(output)

        return _get_qkv_output

    # TODO: CHECK ATTN OUTPUT TUPLE
    def _get_patch_imd_hook(self):
        def _get_attn_output(model, inp, output):
            self.outputs_dict[VitExtractor.PATCH_IMD_KEY].append(output[0])

        return _get_attn_output

    def get_feature_from_input(self, input_img):  # List([B, N, D])
        self._register_hooks()
        self.model(input_img)
        feature = self.outputs_dict[VitExtractor.BLOCK_KEY]
        self._clear_hooks()
        self._init_hooks_data()
        return feature

    def get_qkv_feature_from_input(self, input_img):
        self._register_hooks()
        self.model(input_img)
        feature = self.outputs_dict[VitExtractor.QKV_KEY]
        self._clear_hooks()
        self._init_hooks_data()
        return feature

    def get_attn_feature_from_input(self, input_img):
        self._register_hooks()
        self.model(input_img)
        feature = self.outputs_dict[VitExtractor.ATTN_KEY]
        self._clear_hooks()
        self._init_hooks_data()
        return feature

    def get_patch_size(self):
        return 8 if "8" in self.model_name else 16

    def get_width_patch_num(self, input_img_shape):
        b, c, h, w = input_img_shape
        patch_size = self.get_patch_size()
        return w // patch_size

    def get_height_patch_num(self, input_img_shape):
        b, c, h, w = input_img_shape
        patch_size = self.get_patch_size()
        return h // patch_size

    def get_patch_num(self, input_img_shape):
        # Number of tokens: one CLS token plus one token per patch.
        patch_num = 1 + (self.get_height_patch_num(input_img_shape) * self.get_width_patch_num(input_img_shape))
        return patch_num

    def get_head_num(self):
        # ViT-S has 6 attention heads; ViT-B has 12.
        if "dino" in self.model_name:
            return 6 if "s" in self.model_name else 12
        return 6 if "small" in self.model_name else 12

    def get_embedding_dim(self):
        # ViT-S has embedding dim 384; ViT-B has 768.
        if "dino" in self.model_name:
            return 384 if "s" in self.model_name else 768
        return 384 if "small" in self.model_name else 768

    def get_queries_from_qkv(self, qkv, input_img_shape):
        # qkv: [1, N, 3*D] -> [3, heads, N, D/heads]; assumes batch size 1.
        patch_num = self.get_patch_num(input_img_shape)
        head_num = self.get_head_num()
        embedding_dim = self.get_embedding_dim()
        q = qkv.reshape(patch_num, 3, head_num, embedding_dim // head_num).permute(1, 2, 0, 3)[0]
        return q

    def get_keys_from_qkv(self, qkv, input_img_shape):
        patch_num = self.get_patch_num(input_img_shape)
        head_num = self.get_head_num()
        embedding_dim = self.get_embedding_dim()
        k = qkv.reshape(patch_num, 3, head_num, embedding_dim // head_num).permute(1, 2, 0, 3)[1]
        return k

    def get_values_from_qkv(self, qkv, input_img_shape):
        patch_num = self.get_patch_num(input_img_shape)
        head_num = self.get_head_num()
        embedding_dim = self.get_embedding_dim()
        v = qkv.reshape(patch_num, 3, head_num, embedding_dim // head_num).permute(1, 2, 0, 3)[2]
        return v

    def get_keys_from_input(self, input_img, layer_num):
        qkv_features = self.get_qkv_feature_from_input(input_img)[layer_num]
        keys = self.get_keys_from_qkv(qkv_features, input_img.shape)
        return keys

    def get_keys_self_sim_from_input(self, input_img, layer_num):
        # Concatenate the per-head keys for each token, then take the cosine
        # self-similarity over tokens: [heads, tokens, head_dim] -> [1, tokens, tokens].
        keys = self.get_keys_from_input(input_img, layer_num=layer_num)
        h, t, d = keys.shape
        concatenated_keys = keys.transpose(0, 1).reshape(t, h * d)
        ssim_map = attn_cosine_sim(concatenated_keys[None, None, ...])
        return ssim_map
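
# Usage sketch (illustrative, not part of the extractor itself): pulling the
# layer-11 key self-similarity map for a single image. The shapes below assume
# the "dino_vitb8" model (patch size 8, 12 heads, embedding dim 768), so a
# 224x224 input yields 1 + (224 // 8) ** 2 = 785 tokens and a [1, 785, 785] map:
#
#     extractor = VitExtractor(model_name="dino_vitb8", device="cuda")
#     img = torch.rand(1, 3, 224, 224, device="cuda")  # stand-in for a normalized image
#     with torch.no_grad():
#         ssim_map = extractor.get_keys_self_sim_from_input(img, layer_num=11)
#     print(ssim_map.shape)  # torch.Size([1, 785, 785])
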
class DinoStructureLoss:
    """Structural loss: MSE between the DINO ViT key self-similarity matrices
    of the input and output images."""

    def __init__(self):
        self.extractor = VitExtractor(model_name="dino_vitb8", device="cuda")
        self.preprocess = torchvision.transforms.Compose([
            torchvision.transforms.Resize(224),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])

    def calculate_global_ssim_loss(self, outputs, inputs):
        loss = 0.0
        for a, b in zip(inputs, outputs):  # iterate per image to avoid memory limitations
            # The target self-similarity (from the input) needs no gradients;
            # the output branch stays differentiable.
            with torch.no_grad():
                target_keys_self_sim = self.extractor.get_keys_self_sim_from_input(a.unsqueeze(0), layer_num=11)
            keys_ssim = self.extractor.get_keys_self_sim_from_input(b.unsqueeze(0), layer_num=11)
            loss += F.mse_loss(keys_ssim, target_keys_self_sim)
        return loss
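
# Minimal smoke test: a sketch assuming a CUDA device and network access for
# the torch.hub download. The structural loss of a batch against itself should
# be ~0; real callers would first map PIL images through `self.preprocess`.
if __name__ == "__main__":
    dino_loss = DinoStructureLoss()
    batch = torch.rand(1, 3, 224, 224, device="cuda")  # dummy stand-in for a preprocessed batch
    loss = dino_loss.calculate_global_ssim_loss(batch, batch)
    print(f"self-loss (expected ~0): {loss.item():.6f}")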