|
from os.path import expanduser |
|
import torch |
|
import json |
|
import torchvision |
|
from general_utils import get_from_repository |
|
from general_utils import log |
|
from torchvision import transforms |
|
|
|
# Unseen-class folds for zero-shot Pascal VOC: each inner list is a pair of
# WordNet-style synset names held out together in one fold.
# NOTE(review): this constant is not referenced anywhere in this file —
# presumably consumed by the evaluation/training code; confirm against callers.
PASCAL_VOC_CLASSES_ZS = [['cattle.n.01', 'motorcycle.n.01'], ['aeroplane.n.01', 'sofa.n.01'],

                         ['cat.n.01', 'television.n.03'], ['train.n.01', 'bottle.n.01'],

                         ['chair.n.01', 'pot_plant.n.01']]
|
|
|
|
|
class PascalZeroShot(object):
    """Zero-shot Pascal VOC segmentation dataset.

    Thin wrapper around JoEm's ``VOCSegmentation`` (vendored under
    ``third_party/JoEm``) that resizes images and labels to a fixed square
    size and yields ``((image,), (label,))`` tuples.
    """

    def __init__(self, split, n_unseen, image_size=224) -> None:
        """Build the underlying VOC dataset for ``split``.

        Args:
            split: Either ``'train'`` or ``'val'``.
            n_unseen: Number of unseen classes; forwarded to JoEm's
                ``get_seen_idx`` / ``get_unseen_idx`` helpers.
            image_size: Side length (pixels) images and labels are resized to.

        Raises:
            ValueError: If ``split`` is not ``'train'`` or ``'val'``.
        """
        super().__init__()

        # Deferred import: the vendored JoEm package is only needed (and only
        # on sys.path) when this dataset is actually constructed.
        import sys
        sys.path.append('third_party/JoEm')
        from third_party.JoEm.data_loader.dataset import VOCSegmentation
        from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC

        self.pascal_classes = VOC
        self.image_size = image_size

        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
        ])

        if split == 'train':
            # Training uses JoEm's own augmentation pipeline (transform=True)
            # and drops images that contain unseen classes.
            self.voc = VOCSegmentation(get_unseen_idx(n_unseen), get_seen_idx(n_unseen),
                                       split=split, transform=True,
                                       transform_args=dict(base_size=312, crop_size=312),
                                       ignore_bg=False, ignore_unseen=False, remv_unseen_img=True)
        elif split == 'val':
            self.voc = VOCSegmentation(get_unseen_idx(n_unseen), get_seen_idx(n_unseen),
                                       split=split, transform=False,
                                       ignore_bg=False, ignore_unseen=False)
        else:
            # Fail fast: previously an unrecognized split silently left
            # self.voc unset and surfaced later as an AttributeError.
            raise ValueError(f"Unknown split: {split!r} (expected 'train' or 'val')")

        self.unseen_idx = get_unseen_idx(n_unseen)

    def __len__(self):
        """Return the number of samples in the underlying VOC dataset."""
        return len(self.voc)

    def __getitem__(self, i):
        """Return ``((image,), (label,))`` for sample ``i``.

        Both image and label are resized to ``(image_size, image_size)``;
        the label uses nearest-neighbor interpolation so integer class ids
        are never blended.
        """
        sample = self.voc[i]
        label = sample['label'].long()

        image = self.transform(sample['image'])

        # NEAREST keeps label values discrete (no interpolation between ids).
        label = transforms.Resize((self.image_size, self.image_size),
                                  interpolation=torchvision.transforms.InterpolationMode.NEAREST)(label.unsqueeze(0))[0]

        return (image,), (label,)
|
|
|
|
|
|