"""Loader for the PhraseCut referring-expression segmentation dataset (VGPhraseCut_v0)."""

import os
from os.path import join, isdir, isfile, expanduser

import numpy as np
import torch
from torch.nn import functional as nnf
from PIL import Image
from torchvision import transforms
from skimage.draw import polygon2mask

from general_utils import get_from_repository


def random_crop_slices(origin_size, target_size):
    """Return (row_slice, col_slice) for a random crop of size `target_size`
    taken from an array of size `origin_size`; both are (height, width) tuples."""
    assert origin_size[0] >= target_size[0] and origin_size[1] >= target_size[1], \
        f'actual size: {origin_size}, target size: {target_size}'

    # sample the top-left corner uniformly over all valid positions
    offset_y = torch.randint(0, origin_size[0] - target_size[0] + 1, (1,)).item()
    offset_x = torch.randint(0, origin_size[1] - target_size[1] + 1, (1,)).item()

    return slice(offset_y, offset_y + target_size[0]), slice(offset_x, offset_x + target_size[1])
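
# Illustrative use (hypothetical sizes): cut a 224x224 patch out of a 480x640 array:
#   sly, slx = random_crop_slices((480, 640), (224, 224))
#   patch = array[sly, slx]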


def find_crop(seg, image_size, iterations=1000, min_frac=None, best_of=None):
    """Randomly search for a crop of size `image_size` that contains a sufficient
    fraction of the mask `seg`. Returns (row_slice, col_slice, exceeded), where
    `exceeded` indicates that no crop reached the required mask coverage."""
    best_crops = []
    best_crop_not_ok = float('-inf'), None, None
    min_sum = 0

    seg = seg.astype('bool')

    if min_frac is not None:
        # minimum number of mask pixels the crop must contain
        min_sum = seg.shape[0] * seg.shape[1] * min_frac

    for _ in range(iterations):
        sl_y, sl_x = random_crop_slices(seg.shape, image_size)
        seg_ = seg[sl_y, sl_x]
        sum_seg_ = seg_.sum()

        if sum_seg_ > min_sum:
            if best_of is None:
                return sl_y, sl_x, False
            else:
                best_crops += [(sum_seg_, sl_y, sl_x)]
                if len(best_crops) >= best_of:
                    # enough candidates: return the one with the highest coverage
                    best_crops.sort(key=lambda x: x[0], reverse=True)
                    sl_y, sl_x = best_crops[0][1:]
                    return sl_y, sl_x, False
        else:
            if sum_seg_ > best_crop_not_ok[0]:
                best_crop_not_ok = sum_seg_, sl_y, sl_x

    if best_crops:
        # fewer than `best_of` crops reached the coverage threshold: use the best of them
        best_crops.sort(key=lambda x: x[0], reverse=True)
        return best_crops[0][1], best_crops[0][2], False

    # no crop reached the coverage threshold: return the best failing crop
    return best_crop_not_ok[1:] + (best_crop_not_ok[0] <= min_sum,)
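
# Illustrative use: search (up to 50 tries) for a 300x300 crop that contains at
# least 5% mask pixels, as done in PhraseCut.load_sample below:
#   sly, slx, exceeded = find_crop(seg, (300, 300), iterations=50, min_frac=0.05)
#   crop = seg[sly, slx]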


class PhraseCut(object):
    """PhraseCut referring-expression segmentation dataset: yields an image,
    a binary segmentation mask and the corresponding phrase."""

    def __init__(self, split, image_size=400, negative_prob=0, aug=None, aug_color=False, aug_crop=True,
                 min_size=0, remove_classes=None, with_visual=False, only_visual=False, mask=None):
        super().__init__()

        self.negative_prob = negative_prob
        self.image_size = image_size
        self.with_visual = with_visual
        self.only_visual = only_visual
        self.phrase_form = '{}'  # template applied to every phrase
        self.mask = mask
        self.aug_crop = aug_crop
        # note: `aug` is accepted for interface compatibility but is not used

        if aug_color:
            self.aug_color = transforms.Compose([
                transforms.ColorJitter(0.5, 0.5, 0.2, 0.05),
            ])
        else:
            self.aug_color = None

        # download and verify the dataset if it is not available locally
        get_from_repository('PhraseCut', ['PhraseCut.tar'], integrity_check=lambda local_dir: all([
            isdir(join(local_dir, 'VGPhraseCut_v0')),
            isdir(join(local_dir, 'VGPhraseCut_v0', 'images')),
            isfile(join(local_dir, 'VGPhraseCut_v0', 'refer_train.json')),
            len(os.listdir(join(local_dir, 'VGPhraseCut_v0', 'images'))) in {108250, 108249}
        ]))

        from third_party.PhraseCutDataset.utils.refvg_loader import RefVGLoader
        self.refvg_loader = RefVGLoader(split=split)

        # image ids that are excluded from sampling
        invalid_img_ids = {150417, 285665, 498246, 61564, 285743, 498269, 498010, 150516, 150344, 286093,
                           61530, 150333, 286065, 285814, 498187, 285761, 498042}

        # ImageNet statistics used for input normalization
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        self.normalize = transforms.Normalize(mean, std)

        # one sample per (image id, phrase index) pair
        self.sample_ids = [(i, j)
                           for i in self.refvg_loader.img_ids
                           for j in range(len(self.refvg_loader.get_img_ref_data(i)['phrases']))
                           if i not in invalid_img_ids]
        from nltk.stem import WordNetLemmatizer
        wnl = WordNetLemmatizer()

        # optionally remove samples whose phrases refer to held-out classes
        if remove_classes is not None:
            from datasets.generate_lvis_oneshot import traverse_lemmas_hypo
            from nltk.corpus import wordnet

            print('remove pascal classes...')

            get_data = self.refvg_loader.get_img_ref_data
            keep_sids = None

            if remove_classes[0] == 'pas5i':
                # remove the classes of one PASCAL-5i fold
                subset_id = remove_classes[1]
                from datasets.generate_lvis_oneshot import PASCAL_5I_SYNSETS_ORDERED, PASCAL_5I_CLASS_IDS
                avoid = [PASCAL_5I_SYNSETS_ORDERED[i] for i in range(20) if i+1 not in PASCAL_5I_CLASS_IDS[subset_id]]

            elif remove_classes[0] == 'zs':
                # remove the first `stop` class groups of the PASCAL-VOC zero-shot split
                stop = remove_classes[1]
                from datasets.pascal_zeroshot import PASCAL_VOC_CLASSES_ZS
                avoid = [c for class_set in PASCAL_VOC_CLASSES_ZS[:stop] for c in class_set]
                print(avoid)

            elif remove_classes[0] == 'aff':
                # remove affordance-related words (matched word by word, without lemmatization)
                avoid = ['drink', 'drinks', 'drinking', 'sit', 'sits', 'sitting',
                         'ride', 'rides', 'riding',
                         'fly', 'flies', 'flying', 'drive', 'drives', 'driving', 'driven',
                         'swim', 'swims', 'swimming',
                         'wheels', 'wheel', 'legs', 'leg', 'ear', 'ears']
                keep_sids = [(i, j) for i, j in self.sample_ids if
                             all(x not in avoid for x in get_data(i)['phrases'][j].split(' '))]
            else:
                raise ValueError(f'unknown remove_classes mode: {remove_classes[0]}')

            print('avoid classes:', avoid)

            if keep_sids is None:
                # expand every synset to the lemmas of all its hyponyms and
                # match those against the phrases
                all_lemmas = [s for ps in avoid for s in traverse_lemmas_hypo(wordnet.synset(ps), max_depth=None)]
                all_lemmas = list(set(all_lemmas))
                all_lemmas = [h.replace('_', ' ').lower() for h in all_lemmas]
                all_lemmas = set(all_lemmas)

                # single-word lemmas are matched against lemmatized phrase words,
                # multi-word lemmas by substring search
                all_lemmas_s = set(l for l in all_lemmas if ' ' not in l)
                all_lemmas_m = set(l for l in all_lemmas if l not in all_lemmas_s)

                phrases = [get_data(i)['phrases'][j] for i, j in self.sample_ids]
                remove_sids = set((i, j) for (i, j), phrase in zip(self.sample_ids, phrases)
                                  if any(l in phrase for l in all_lemmas_m) or
                                  len(set(wnl.lemmatize(w) for w in phrase.split(' ')).intersection(all_lemmas_s)) > 0)
                keep_sids = [(i, j) for i, j in self.sample_ids if (i, j) not in remove_sids]

            print(f'Reduced to {len(keep_sids) / len(self.sample_ids):.3f}')
            removed_ids = set(self.sample_ids) - set(keep_sids)

            print('Examples of removed', len(removed_ids))
            for i, j in list(removed_ids)[:20]:
                print(i, get_data(i)['phrases'][j])

            self.sample_ids = keep_sids
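
        # Illustrative `remove_classes` values (hypothetical): ('pas5i', k) removes
        # the classes of PASCAL-5i fold k; ('zs', 2) removes the first two zero-shot
        # class groups; ('aff',) removes affordance-related words.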
        # group sample ids by phrase (used for negative and visual-prompt sampling)
        from itertools import groupby
        samples_by_phrase = [(self.refvg_loader.get_img_ref_data(i)['phrases'][j], (i, j))
                             for i, j in self.sample_ids]
        samples_by_phrase = sorted(samples_by_phrase)
        samples_by_phrase = groupby(samples_by_phrase, key=lambda x: x[0])

        self.samples_by_phrase = {prompt: [s[1] for s in prompt_sample_ids]
                                  for prompt, prompt_sample_ids in samples_by_phrase}

        self.all_phrases = list(set(self.samples_by_phrase.keys()))

        if self.only_visual:
            # keep only phrases that occur in more than one sample, so that a
            # visual support sample from another image is always available
            assert self.with_visual
            self.sample_ids = [(i, j) for i, j in self.sample_ids
                               if len(self.samples_by_phrase[self.refvg_loader.get_img_ref_data(i)['phrases'][j]]) > 1]

        # relative object size: summed box area (w * h) divided by image area
        sizes = [self.refvg_loader.get_img_ref_data(i)['gt_boxes'][j] for i, j in self.sample_ids]
        image_sizes = [self.refvg_loader.get_img_ref_data(i)['width'] * self.refvg_loader.get_img_ref_data(i)['height']
                       for i, j in self.sample_ids]
        self.sizes = [sum([s[2] * s[3] for s in size]) / img_size for size, img_size in zip(sizes, image_sizes)]

        if min_size:
            print('filter by size')
            self.sample_ids = [self.sample_ids[i] for i in range(len(self.sample_ids)) if self.sizes[i] > min_size]

        self.base_path = expanduser('~/datasets/PhraseCut/VGPhraseCut_v0/images/')

    def __len__(self):
        return len(self.sample_ids)
    def load_sample(self, sample_i, j):
        img_ref_data = self.refvg_loader.get_img_ref_data(sample_i)

        polys_phrase0 = img_ref_data['gt_Polygons'][j]
        phrase = img_ref_data['phrases'][j]
        phrase = self.phrase_form.format(phrase)

        # rasterize all instance polygons and take their union as the target mask
        masks = []
        for polys in polys_phrase0:
            for poly in polys:
                poly = [p[::-1] for p in poly]  # polygon2mask expects (row, col) points
                masks += [polygon2mask((img_ref_data['height'], img_ref_data['width']), poly)]

        seg = np.stack(masks).max(0)
        img = np.array(Image.open(join(self.base_path, str(img_ref_data['image_id']) + '.jpg')))

        min_shape = min(img.shape[:2])

        if self.aug_crop:
            # random square crop that contains at least 5% mask pixels, if possible
            sly, slx, exceed = find_crop(seg, (min_shape, min_shape), iterations=50, min_frac=0.05)
        else:
            sly, slx = slice(0, None), slice(0, None)

        seg = seg[sly, slx]
        img = img[sly, slx]

        seg = seg.astype('uint8')
        seg = torch.from_numpy(seg).view(1, 1, *seg.shape)

        if img.ndim == 2:
            # grayscale image: replicate to three channels
            img = np.dstack([img] * 3)

        img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float()

        # resize the mask with nearest neighbour and the image bilinearly
        seg = nnf.interpolate(seg, (self.image_size, self.image_size), mode='nearest')[0, 0]
        img = nnf.interpolate(img, (self.image_size, self.image_size), mode='bilinear', align_corners=True)[0]

        img = img / 255.0

        if self.aug_color is not None:
            img = self.aug_color(img)

        img = self.normalize(img)

        return img, seg, phrase
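
    # Returned values (for the default image_size=400):
    #   img: float tensor (3, 400, 400), ImageNet-normalized
    #   seg: uint8 tensor (400, 400) with values in {0, 1}
    #   phrase: the referring expression as a string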
    def __getitem__(self, i):
        sample_i, j = self.sample_ids[i]

        img, seg, phrase = self.load_sample(sample_i, j)

        if self.negative_prob > 0:
            if torch.rand((1,)).item() < self.negative_prob:
                # negative sample: swap in a different phrase and empty the target mask
                new_phrase = None
                while new_phrase is None or new_phrase == phrase:
                    idx = torch.randint(0, len(self.all_phrases), (1,)).item()
                    new_phrase = self.all_phrases[idx]
                phrase = new_phrase
                seg = torch.zeros_like(seg)

        if self.with_visual:
            # find a visual support sample that shares the phrase
            if phrase in self.samples_by_phrase and len(self.samples_by_phrase[phrase]) > 1:
                # may occasionally pick the query sample itself
                idx = torch.randint(0, len(self.samples_by_phrase[phrase]), (1,)).item()
                other_sample = self.samples_by_phrase[phrase][idx]
                img_s, seg_s, _ = self.load_sample(*other_sample)

                from datasets.utils import blend_image_segmentation

                if self.mask in {'separate', 'text_and_separate'}:
                    # pass the support image and its mask as separate inputs
                    add_phrase = [phrase] if self.mask == 'text_and_separate' else []
                    vis_s = add_phrase + [img_s, seg_s, True]
                else:
                    if self.mask.startswith('text_and_'):
                        mask_mode = self.mask[len('text_and_'):]
                        label_add = [phrase]
                    else:
                        mask_mode = self.mask
                        label_add = []

                    # blend the support image with its mask into a single input
                    masked_img_s = torch.from_numpy(blend_image_segmentation(img_s, seg_s, mode=mask_mode, image_size=self.image_size)[0])
                    vis_s = label_add + [masked_img_s, True]

            else:
                # the phrase is unique: pass an all-zero support image (and mask)
                vis_s = torch.zeros_like(img)

                if self.mask in {'separate', 'text_and_separate'}:
                    add_phrase = [phrase] if self.mask == 'text_and_separate' else []
                    vis_s = add_phrase + [vis_s, torch.zeros(*vis_s.shape[1:], dtype=torch.uint8), False]
                elif self.mask.startswith('text_and_'):
                    vis_s = [phrase, vis_s, False]
                else:
                    vis_s = [vis_s, False]
        else:
            assert self.mask == 'text'
            vis_s = [phrase]

        seg = seg.unsqueeze(0).float()

        data_x = (img,) + tuple(vis_s)

        return data_x, (seg, torch.zeros(0), i)
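
    # data_x layout (illustrative): with mask='text' it is (img, phrase); with
    # mask='text_and_separate' and with_visual=True it is
    # (img, phrase, support_img, support_mask, has_support).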


class PhraseCutPlus(PhraseCut):
    """PhraseCut variant that provides visual support samples and draws
    negative phrases with a probability of 0.2."""

    def __init__(self, split, image_size=400, aug=None, aug_color=False, aug_crop=True, min_size=0,
                 remove_classes=None, only_visual=False, mask=None):
        super().__init__(split, image_size=image_size, negative_prob=0.2, aug=aug, aug_color=aug_color,
                         aug_crop=aug_crop, min_size=min_size, remove_classes=remove_classes, with_visual=True,
                         only_visual=only_visual, mask=mask)
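

if __name__ == '__main__':
    # Illustrative usage sketch: assumes the PhraseCut archive can be obtained
    # via get_from_repository and that the nltk wordnet data is installed.
    dataset = PhraseCut('train', image_size=400, mask='text')
    print(len(dataset), 'samples')
    (img, phrase), (seg, _, sample_idx) = dataset[0]
    print(phrase, tuple(img.shape), tuple(seg.shape))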