import os
import torchvision
import pickle
from typing import Any
import lmdb
import cv2
import imageio
import numpy as np
from PIL import Image
import Imath
import OpenEXR
from pdb import set_trace as st
from pathlib import Path

from functools import partial
import io
import gzip
import random

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torch.utils.data.distributed import DistributedSampler
from pathlib import Path

from guided_diffusion import logger


def _build_normalize(imgnet_normalize=True):
    """Return the encoder-input transform: ToTensor ([0, 1]) then Normalize.

    Uses ImageNet mean/std when ``imgnet_normalize`` is true, otherwise
    mean=std=0.5 (i.e. maps [0, 1] to [-1, 1]).
    """
    if imgnet_normalize:
        mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
    else:
        mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
    return transforms.Compose([
        transforms.ToTensor(),  # [0,1] range
        transforms.Normalize(mean, std),  # type: ignore
    ])


def _select_dataset_cls(trainer_name, use_lmdb):
    """Pick the dataset class; novel-view variants when 'nv' is in ``trainer_name``."""
    if use_lmdb:
        logger.log('using LMDB dataset')
        # 2.5-3iter/s, but unstable, drops to 1 later.
        if 'nv' in trainer_name:
            return LMDBDataset_NV_Compressed
        return LMDBDataset_MV_Compressed
    # raw-file datasets: 1.5-2iter/s
    if 'nv' in trainer_name:
        return NovelViewDataset
    return MultiViewDataset


def _build_train_dataset(file_path, reso, reso_encoder, load_depth, preprocess,
                         imgnet_normalize, dataset_size, trainer_name,
                         use_lmdb):
    """Instantiate the training dataset chosen by ``_select_dataset_cls`` and log its size."""
    dataset_cls = _select_dataset_cls(trainer_name, use_lmdb)
    dataset = dataset_cls(file_path,
                          reso,
                          reso_encoder,
                          test=False,
                          preprocess=preprocess,
                          load_depth=load_depth,
                          imgnet_normalize=imgnet_normalize,
                          dataset_size=dataset_size)
    logger.log('dataset_cls: {}, dataset size: {}'.format(
        trainer_name, len(dataset)))
    return dataset


def load_dataset(
        file_path="",
        reso=64,
        reso_encoder=224,
        batch_size=1,
        num_workers=6,
        load_depth=False,
        preprocess=None,
        imgnet_normalize=True,
        dataset_size=-1,
        trainer_name='input_rec',
        use_lmdb=False,
        infi_sampler=True):
    """Return a single-pass, unshuffled DataLoader over the training dataset.

    ``infi_sampler`` is accepted for signature parity with ``load_data`` and
    is not used here.
    """
    dataset = _build_train_dataset(file_path, reso, reso_encoder, load_depth,
                                   preprocess, imgnet_normalize, dataset_size,
                                   trainer_name, use_lmdb)
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        num_workers=num_workers,
                        drop_last=False,
                        pin_memory=True,
                        persistent_workers=num_workers > 0,
                        shuffle=False)
    return loader


def load_data(
        file_path="",
        reso=64,
        reso_encoder=224,
        batch_size=1,
        num_workers=6,
        load_depth=False,
        preprocess=None,
        imgnet_normalize=True,
        dataset_size=-1,
        trainer_name='input_rec',
        use_lmdb=False,
        infi_sampler=True):
    """Generator yielding training batches.

    With ``infi_sampler`` (default) batches come from a shuffled
    ``DistributedSampler`` and the generator never ends; the sampler epoch is
    advanced each pass so the shuffle order differs between passes. Without
    it, a single unshuffled pass over the dataset is yielded.
    """
    dataset = _build_train_dataset(file_path, reso, reso_encoder, load_depth,
                                   preprocess, imgnet_normalize, dataset_size,
                                   trainer_name, use_lmdb)

    if infi_sampler:
        train_sampler = DistributedSampler(dataset=dataset,
                                           shuffle=True,
                                           drop_last=True)
        loader = DataLoader(dataset,
                            batch_size=batch_size,
                            num_workers=num_workers,
                            drop_last=True,
                            pin_memory=True,
                            persistent_workers=num_workers > 0,
                            sampler=train_sampler)
        epoch = 0
        while True:
            # Without set_epoch the sampler reuses the same permutation forever.
            train_sampler.set_epoch(epoch)
            yield from loader
            epoch += 1
    else:
        # Previously this branch hit a pdb breakpoint and `return`ed the
        # dataset from inside a generator (callers received nothing).
        loader = DataLoader(dataset,
                            batch_size=batch_size,
                            num_workers=num_workers,
                            drop_last=False,
                            pin_memory=True,
                            persistent_workers=num_workers > 0,
                            shuffle=False)
        yield from loader


def load_eval_rays(file_path="",
                   reso=64,
                   reso_encoder=224,
                   imgnet_normalize=True):
    """Return 25-d camera labels (flattened c2w + flattened intrinsics) for
    every pose file in ``MultiViewDataset.single_pose_list``."""
    dataset = MultiViewDataset(file_path,
                               reso,
                               reso_encoder,
                               imgnet_normalize=imgnet_normalize)
    # dataset.intrinsics is a numpy array; torch.cat needs a tensor
    # (same conversion as MultiViewDataset.__getitem__).
    intrinsics = torch.from_numpy(dataset.intrinsics)
    ray_list = []
    for pose_fname in dataset.single_pose_list:
        c2w = dataset.get_c2w(pose_fname).reshape(16)  # [4, 4] -> [16]
        c = torch.cat([c2w, intrinsics], dim=0).reshape(25)  # 25, no '1' dim needed.
        ray_list.append(c)
    return ray_list


def load_eval_data(file_path="",
                   reso=64,
                   reso_encoder=224,
                   batch_size=1,
                   num_workers=1,
                   load_depth=False,
                   preprocess=None,
                   imgnet_normalize=True,
                   interval=1,
                   **kwargs):
    """Return an unshuffled DataLoader over the test-mode MultiViewDataset."""
    dataset = MultiViewDataset(file_path,
                               reso,
                               reso_encoder,
                               preprocess=preprocess,
                               load_depth=load_depth,
                               test=True,
                               imgnet_normalize=imgnet_normalize,
                               interval=interval)
    print('eval dataset size: {}'.format(len(dataset)))
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=False,
        shuffle=False,
    )
    return loader


def load_memory_data(file_path="",
                     reso=64,
                     reso_encoder=224,
                     batch_size=1,
                     num_workers=1,
                     load_depth=True,
                     preprocess=None,
                     imgnet_normalize=True):
    """Load a single (overfitting) instance into memory once and yield random
    contiguous slices of ``batch_size`` samples forever.

    Note: depth is always loaded (``load_depth=True`` is passed regardless of
    the argument).
    """
    dataset = MultiViewDataset(file_path,
                               reso,
                               reso_encoder,
                               preprocess=preprocess,
                               load_depth=True,
                               test=False,
                               overfitting=True,
                               imgnet_normalize=imgnet_normalize,
                               overfitting_bs=batch_size)
    logger.log('!!!!!!! memory dataset size: {} !!!!!!'.format(len(dataset)))
    loader = DataLoader(
        dataset,
        batch_size=len(dataset),
        num_workers=num_workers,
        drop_last=False,
        shuffle=False,
    )

    all_data: dict = next(iter(loader))
    # If batch_size exceeds the cached sample count, always start at 0 (the
    # slice then returns everything) instead of calling randint with high<=0.
    max_start = max(len(dataset) - batch_size, 0)
    while True:
        start_idx = np.random.randint(0, max_start + 1)
        yield {
            k: v[start_idx:start_idx + batch_size]
            for k, v in all_data.items()
        }


class MultiViewDataset(Dataset):
    """Per-view dataset over ``<file_path>/<instance>/{rgb,pose}`` folders.

    Each item is one rendered view: encoder input, [-1, 1] images at ``reso``
    and 128, a 25-d camera label ``c`` and, optionally, depth/masks/bbox.
    """

    def __init__(self,
                 file_path,
                 reso,
                 reso_encoder,
                 preprocess=None,
                 classes=False,
                 load_depth=False,
                 test=False,
                 scene_scale=1,
                 overfitting=False,
                 imgnet_normalize=True,
                 dataset_size=-1,
                 overfitting_bs=-1,
                 interval=1):
        self.file_path = file_path
        self.overfitting = overfitting
        self.scene_scale = scene_scale
        self.reso = reso
        self.reso_encoder = reso_encoder
        self.classes = False
        self.load_depth = load_depth
        self.preprocess = preprocess
        assert not self.classes, "Not support class condition now."

        dataset_name = Path(self.file_path).stem.split('_')[0]
        self.dataset_name = dataset_name

        if test:
            if dataset_name == 'chair':
                self.ins_list = sorted(os.listdir(
                    self.file_path))[1:2]  # more diversity
            else:
                # the first instance is the evaluation reference.
                self.ins_list = sorted(os.listdir(self.file_path))[0:1]
        else:
            ins_list_file = Path(
                self.file_path).parent / f'{dataset_name}_train_list.txt'
            assert ins_list_file.exists(), 'add training list for ShapeNet'
            with open(ins_list_file, 'r') as f:
                ins_list = [name.strip() for name in f if name.strip()]
            # dataset_size == -1 (default) means "all"; slicing with [:-1]
            # would silently drop the last instance.
            if dataset_size >= 0:
                ins_list = ins_list[:dataset_size]
            self.ins_list = ins_list

        if overfitting:
            self.ins_list = self.ins_list[:1]

        self.rgb_list = []
        self.pose_list = []
        self.depth_list = []
        self.data_ins_list = []
        # pose files of the last listed instance, consumed by load_eval_rays()
        self.single_pose_list = []
        self.instance_data_length = -1
        for ins in self.ins_list:
            cur_rgb_path = os.path.join(self.file_path, ins, 'rgb')
            cur_pose_path = os.path.join(self.file_path, ins, 'pose')

            # Sort BEFORE striding: os.listdir order is arbitrary, so striding
            # first would keep an arbitrary subset of views.
            cur_all_fname = sorted(
                t.split('.')[0] for t in os.listdir(cur_rgb_path)
                if 'depth' not in t)[::interval]

            if self.instance_data_length == -1:
                self.instance_data_length = len(cur_all_fname)
            else:
                assert len(cur_all_fname) == self.instance_data_length  # ! check filtered data

            if test:
                # prepend the view at 2/3 of the sequence as the reference view
                mid_index = len(cur_all_fname) // 3 * 2
                cur_all_fname.insert(0, cur_all_fname[mid_index])

            self.pose_list += [
                os.path.join(cur_pose_path, fname + '.txt')
                for fname in cur_all_fname
            ]
            self.rgb_list += [
                os.path.join(cur_rgb_path, fname + '.png')
                for fname in cur_all_fname
            ]
            self.depth_list += [
                os.path.join(cur_rgb_path, fname + '_depth0001.exr')
                for fname in cur_all_fname
            ]
            self.data_ins_list += [ins] * len(cur_all_fname)
            self.single_pose_list = [
                os.path.join(cur_pose_path, fname + '.txt')
                for fname in cur_all_fname
            ]

        # validate overfitting on images: uniform pose sampling of the single
        # instance so that roughly `overfitting_bs` views remain.
        if overfitting:
            assert overfitting_bs != -1
            # previously hard-coded 50 (views per instance)
            n_views = max(self.instance_data_length, 0)
            stride = n_views // overfitting_bs + 1
            self.pose_list = self.pose_list[::stride]
            self.rgb_list = self.rgb_list[::stride]
            self.depth_list = self.depth_list[::stride]
            self.data_ins_list = self.data_ins_list[::stride]

        self.normalize = _build_normalize(imgnet_normalize)

        fx = fy = 525
        cx = cy = 256  # rendering default K
        factor = self.reso / (cx * 2)  # e.g. 128 / 512
        self.fx = fx * factor
        self.fy = fy * factor
        self.cx = cx * factor
        self.cy = cy * factor

        # ! fix scale for triplane ray_sampler(), here we adopt [0,1] uv range,
        # not [0, w] img space range.
        self.cx /= self.reso  # 0.5
        self.cy /= self.reso  # 0.5
        self.fx /= self.reso
        self.fy /= self.reso

        intrinsics = np.array([[self.fx, 0, self.cx], [0, self.fy, self.cy],
                               [0, 0, 1]]).reshape(9)
        # kept as numpy; converted with torch.from_numpy at use sites
        self.intrinsics = intrinsics

    def __len__(self):
        return len(self.rgb_list)

    def get_c2w(self, pose_fname):
        """Parse the first line of ``pose_fname`` (16 space-separated floats)
        into a float32 4x4 camera-to-world tensor."""
        with open(pose_fname, 'r') as f:
            cam2world = f.readline().strip()
        cam2world = [float(t) for t in cam2world.split(' ')]
        return torch.tensor(cam2world, dtype=torch.float32).reshape(4, 4)

    def gen_rays(self, c2w):
        """Generate per-pixel ray origins and unit directions.

        ``c2w`` is indexed as ``c2w[:, :3, :3]`` / ``c2w[:, :3, 3]`` (a batch
        of 4x4 poses). Returns ``(origins, dirs)``, each viewed as ``(-1, 3)``.
        Also sets ``self.h`` / ``self.w`` to ``self.reso``.
        """
        self.h = self.reso
        self.w = self.reso
        # The intrinsics are stored in normalized [0, 1] uv units (see
        # __init__), so pixel centres must be normalized the same way.
        yy, xx = torch.meshgrid(
            (torch.arange(self.h, dtype=torch.float32) + 0.5) / self.h,
            (torch.arange(self.w, dtype=torch.float32) + 0.5) / self.w,
            indexing='ij')
        xx = (xx - self.cx) / self.fx
        yy = (yy - self.cy) / self.fy
        zz = torch.ones_like(xx)
        dirs = torch.stack((xx, yy, zz), dim=-1)  # OpenCV convention
        dirs /= torch.norm(dirs, dim=-1, keepdim=True)
        dirs = dirs.reshape(1, -1, 3, 1)
        del xx, yy, zz
        dirs = (c2w[:, None, :3, :3] @ dirs)[..., 0]

        origins = c2w[:, None, :3, 3].expand(-1, self.h * self.w,
                                             -1).contiguous()
        origins = origins.view(-1, 3)
        dirs = dirs.view(-1, 3)
        return origins, dirs

    def read_depth(self, idx):
        """Read channel 'B' of the EXR depth for item ``idx``.

        Returns ``(depth, fg_mask_reso, fg_mask_sr)`` as tensors: the (H, W)
        float32 depth with NaN/inf replaced by 0, and boolean foreground masks
        (resized depth > 0) at ``self.reso`` and at 128.
        """
        depth_path = self.depth_list[idx]
        exr = OpenEXR.InputFile(depth_path)
        try:
            header = exr.header()
            size = (header['dataWindow'].max.x - header['dataWindow'].min.x + 1,
                    header['dataWindow'].max.y - header['dataWindow'].min.y + 1)
            FLOAT = Imath.PixelType(Imath.PixelType.FLOAT)
            depth_str = exr.channel('B', FLOAT)
        finally:
            exr.close()  # do not leak the file handle per sample

        depth = np.frombuffer(depth_str, dtype=np.float32).reshape(
            size[1], size[0])  # H W
        # (a former extra `depth.reshape(size)` turned this into W x H,
        # scrambling non-square depth maps)
        depth = np.nan_to_num(depth, posinf=0, neginf=0)

        _, fg_mask_reso = resize_depth_mask(depth, self.reso)
        _, fg_mask_sr = resize_depth_mask(depth, 128)

        return torch.from_numpy(depth), torch.from_numpy(
            fg_mask_reso), torch.from_numpy(fg_mask_sr)

    def load_bbox(self, mask):
        """Return ``[top, left, max_row, max_col]`` of the non-zero pixels of a
        2-D ``mask`` as float32; all zeros when the mask is empty."""
        nonzero_value = torch.nonzero(mask)
        if nonzero_value.numel() == 0:
            # max()/min() over an empty tensor would raise.
            return torch.zeros(4, dtype=torch.float32)
        bottom, right = nonzero_value.max(dim=0)[0]
        top, left = nonzero_value.min(dim=0)[0]
        return torch.tensor([top, left, bottom, right], dtype=torch.float32)

    def __getitem__(self, idx):
        rgb_fname = self.rgb_list[idx]
        pose_fname = self.pose_list[idx]

        raw_img = imageio.imread(rgb_fname)

        if self.preprocess is None:
            img_to_encoder = cv2.resize(raw_img,
                                        (self.reso_encoder, self.reso_encoder),
                                        interpolation=cv2.INTER_LANCZOS4)
            img_to_encoder = img_to_encoder[..., :3]  # drop alpha (HWC)
            # -> [3, reso_encoder, reso_encoder]
            img_to_encoder = self.normalize(img_to_encoder)
        else:
            img_to_encoder = self.preprocess(Image.open(rgb_fname))  # clip

        img = cv2.resize(raw_img, (self.reso, self.reso),
                         interpolation=cv2.INTER_LANCZOS4)
        # just as refinement, since eg3d uses 64->128 final resolution
        img_sr = cv2.resize(raw_img, (128, 128),
                            interpolation=cv2.INTER_LANCZOS4)

        # normalize to [-1,1], follow triplane range
        img = torch.from_numpy(img)[..., :3].permute(
            2, 0, 1) / 127.5 - 1  # [3, reso, reso]
        img_sr = torch.from_numpy(img_sr)[..., :3].permute(
            2, 0, 1) / 127.5 - 1  # [3, 128, 128]

        c2w = self.get_c2w(pose_fname).reshape(16)  # [4, 4] -> [16]
        # camera label: flattened c2w (16) + flattened intrinsics (9)
        c = torch.cat([c2w, torch.from_numpy(self.intrinsics)],
                      dim=0).reshape(25)  # 25, no '1' dim needed.

        ret_dict = {
            'img_to_encoder': img_to_encoder,
            'img': img,
            'c': c,
            'img_sr': img_sr,
        }

        if self.load_depth:
            depth, depth_mask, depth_mask_sr = self.read_depth(idx)
            bbox = self.load_bbox(depth_mask)
            ret_dict.update({
                'depth': depth,
                'depth_mask': depth_mask,
                'depth_mask_sr': depth_mask_sr,
                'bbox': bbox
            })

        return ret_dict


class MultiViewDatasetforLMDB(MultiViewDataset):
    """Variant returning raw (un-resized) RGB, camera, depth and bbox — the
    fields written into the LMDB cache."""

    def __init__(self,
                 file_path,
                 reso,
                 reso_encoder,
                 preprocess=None,
                 classes=False,
                 load_depth=False,
                 test=False,
                 scene_scale=1,
                 overfitting=False,
                 imgnet_normalize=True,
                 dataset_size=-1,
                 overfitting_bs=-1,
                 interval=1):
        super().__init__(file_path, reso, reso_encoder, preprocess, classes,
                         load_depth, test, scene_scale, overfitting,
                         imgnet_normalize, dataset_size, overfitting_bs,
                         interval)

    def __len__(self):
        return super().__len__()

    def __getitem__(self, idx):
        rgb_fname = self.rgb_list[idx]
        pose_fname = self.pose_list[idx]
        raw_img = imageio.imread(rgb_fname)[..., :3]

        c2w = self.get_c2w(pose_fname).reshape(16)  # [4, 4] -> [16]
        c = torch.cat([c2w, torch.from_numpy(self.intrinsics)],
                      dim=0).reshape(25)  # 25, no '1' dim needed.

        depth, depth_mask, depth_mask_sr = self.read_depth(idx)
        bbox = self.load_bbox(depth_mask)
        ret_dict = {
            'raw_img': raw_img,
            'c': c,
            'depth': depth,
            'bbox': bbox
        }
        return ret_dict


def load_data_dryrun(
        file_path="",
        reso=64,
        reso_encoder=224,
        batch_size=1,
        num_workers=6,
        load_depth=False,
        preprocess=None,
        imgnet_normalize=True):
    """Return a plain (unshuffled, no sampler) DataLoader over MultiViewDataset."""
    dataset = MultiViewDataset(file_path,
                               reso,
                               reso_encoder,
                               test=False,
                               preprocess=preprocess,
                               load_depth=load_depth,
                               imgnet_normalize=imgnet_normalize)
    print('dataset size: {}'.format(len(dataset)))
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=False,
    )
    return loader


class NovelViewDataset(MultiViewDataset):
    """novel view prediction version.

    Each item is an input view plus a random view of the same instance, whose
    keys are prefixed with ``nv_``.
    """

    def __init__(self,
                 file_path,
                 reso,
                 reso_encoder,
                 preprocess=None,
                 classes=False,
                 load_depth=False,
                 test=False,
                 scene_scale=1,
                 overfitting=False,
                 imgnet_normalize=True,
                 dataset_size=-1,
                 overfitting_bs=-1,
                 interval=1):
        super().__init__(file_path, reso, reso_encoder, preprocess, classes,
                         load_depth, test, scene_scale, overfitting,
                         imgnet_normalize, dataset_size, overfitting_bs,
                         interval)

    def __getitem__(self, idx):
        input_view = super().__getitem__(idx)  # get previous input view results

        # get novel view of the same instance: views of one instance occupy a
        # contiguous block of `instance_data_length` indices.
        novel_view = super().__getitem__(
            (idx // self.instance_data_length) * self.instance_data_length +
            random.randint(0, self.instance_data_length - 1))

        input_view.update({f'nv_{k}': v for k, v in novel_view.items()})
        return input_view


def load_data_for_lmdb(
        file_path="",
        reso=64,
        reso_encoder=224,
        batch_size=1,
        num_workers=6,
        load_depth=False,
        preprocess=None,
        imgnet_normalize=True,
        dataset_size=-1,
        trainer_name='input_rec'):
    """Return ``(loader, dataset_name, len(dataset))`` for MultiViewDatasetforLMDB,
    used when building the LMDB cache."""
    dataset = MultiViewDatasetforLMDB(file_path,
                                      reso,
                                      reso_encoder,
                                      test=False,
                                      preprocess=preprocess,
                                      load_depth=load_depth,
                                      imgnet_normalize=imgnet_normalize,
                                      dataset_size=dataset_size)

    logger.log('dataset_cls: {}, dataset size: {}'.format(
        trainer_name, len(dataset)))

    # DataLoader rejects prefetch_factor / persistent_workers=True when
    # num_workers == 0, so only set them for multi-process loading.
    loader_kwargs = {}
    if num_workers > 0:
        loader_kwargs['prefetch_factor'] = 2
    loader = DataLoader(
        dataset,
        shuffle=False,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=False,
        pin_memory=True,
        persistent_workers=num_workers > 0,
        **loader_kwargs,
    )
    return loader, dataset.dataset_name, len(dataset)


class LMDBDataset(Dataset):
    """Read-only LMDB dataset whose values are pickled samples keyed by str(idx)."""

    def __init__(self, lmdb_path):
        self.env = lmdb.open(
            lmdb_path,
            readonly=True,
            max_readers=32,
            lock=False,
            readahead=False,
            meminit=False,
        )
        self.num_samples = self.env.stat()['entries']

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        with self.env.begin(write=False) as txn:
            key = str(idx).encode('utf-8')
            value = txn.get(key)
            # NOTE(review): pickle.loads can execute arbitrary code; only open
            # LMDB files produced by trusted tooling.
            sample = pickle.loads(value)
        return sample


def resize_depth_mask(depth_to_resize, resolution):
    """Lanczos-resize a depth map to ``resolution`` x ``resolution``; return
    ``(resized_depth, resized_depth > 0)``."""
    depth_resized = cv2.resize(depth_to_resize, (resolution, resolution),
                               interpolation=cv2.INTER_LANCZOS4)
    return depth_resized, depth_resized > 0  # type: ignore


class LMDBDataset_MV(LMDBDataset):
    """LMDB dataset that performs the image/depth transforms online."""

    def __init__(self,
                 lmdb_path,
                 reso,
                 reso_encoder,
                 imgnet_normalize=True,
                 **kwargs):
        super().__init__(lmdb_path)
        self.reso_encoder = reso_encoder
        self.reso = reso
        self.normalize = _build_normalize(imgnet_normalize)

    def _post_process_sample(self, raw_img, depth):
        """Build encoder input, [-1, 1] images and depth/mask at ``self.reso``."""
        # NOTE(review): unlike `img` below, the encoder input drops alpha
        # without compositing onto white — confirm this is intended.
        img_to_encoder = cv2.resize(raw_img,
                                    (self.reso_encoder, self.reso_encoder),
                                    interpolation=cv2.INTER_LANCZOS4)
        img_to_encoder = img_to_encoder[..., :3]
        img_to_encoder = self.normalize(
            img_to_encoder)  # [3, reso_encoder, reso_encoder]

        img = cv2.resize(raw_img, (self.reso, self.reso),
                         interpolation=cv2.INTER_LANCZOS4)

        if img.shape[-1] == 4:
            # ! set bg to white
            alpha_mask = img[..., -1:] > 0
            img = alpha_mask * img[..., :3] + (
                1 - alpha_mask) * np.ones_like(img[..., :3]) * 255

        # normalize to [-1,1], follow triplane range
        img = torch.from_numpy(img)[..., :3].permute(
            2, 0, 1) / 127.5 - 1  # [3, reso, reso]
        img_sr = torch.from_numpy(raw_img)[..., :3].permute(
            2, 0, 1) / 127.5 - 1  # full raw resolution

        depth_reso, fg_mask_reso = resize_depth_mask(depth, self.reso)

        return {
            'img_to_encoder': img_to_encoder,
            'img': img,
            'depth_mask': fg_mask_reso,
            'img_sr': img_sr,
            'depth': depth_reso,
        }

    def __getitem__(self, idx):
        sample = super().__getitem__(idx)
        # do transformations online
        return self._post_process_sample(sample['raw_img'], sample['depth'])


def load_bytes(inp_bytes, dtype, shape):
    """Return a writable numpy copy of ``inp_bytes`` viewed as ``dtype``/``shape``."""
    return np.frombuffer(inp_bytes, dtype=dtype).reshape(shape).copy()


def decompress_and_open_image_gzip(compressed_data, is_img=False):
    """Gunzip ``compressed_data``; if ``is_img``, decode it with imageio and
    return the (copied) image array, otherwise return the raw bytes."""
    decompressed_data = gzip.decompress(compressed_data)
    if is_img:
        image = imageio.v3.imread(io.BytesIO(decompressed_data)).copy()
        return image
    return decompressed_data


def decompress_array(compressed_data, shape, dtype):
    """Gunzip ``compressed_data`` and decode it into a numpy array."""
    decompressed_data = gzip.decompress(compressed_data)
    return load_bytes(decompressed_data, dtype, shape)


class LMDBDataset_MV_Compressed(LMDBDataset_MV):
    """LMDB dataset whose fields are stored gzip-compressed under
    ``'{idx}-raw_img'``, ``'{idx}-depth'``, ``'{idx}-c'``, ``'{idx}-bbox'``."""

    def __init__(self,
                 lmdb_path,
                 reso,
                 reso_encoder,
                 imgnet_normalize=True,
                 **kwargs):
        super().__init__(lmdb_path, reso, reso_encoder, imgnet_normalize,
                         **kwargs)
        with self.env.begin(write=False) as txn:
            # NOTE(review): the stored length minus 40 — reason for the offset
            # is not visible here; confirm against the LMDB writer.
            self.length = int(
                txn.get('length'.encode('utf-8')).decode('utf-8')) - 40

        self.load_image_fn = partial(decompress_and_open_image_gzip,
                                     is_img=True)

    def __len__(self):
        return self.length

    def _load_lmdb_data(self, idx):
        """Fetch and decompress (raw_img, depth[512x512], c[25], bbox[4]) for ``idx``."""
        with self.env.begin(write=False) as txn:
            raw_img_key = f'{idx}-raw_img'.encode('utf-8')
            raw_img = self.load_image_fn(txn.get(raw_img_key))

            depth_key = f'{idx}-depth'.encode('utf-8')
            depth = decompress_array(txn.get(depth_key), (512, 512),
                                     np.float32)

            c_key = f'{idx}-c'.encode('utf-8')
            c = decompress_array(txn.get(c_key), (25, ), np.float32)

            bbox_key = f'{idx}-bbox'.encode('utf-8')
            bbox = decompress_array(txn.get(bbox_key), (4, ), np.float32)

        return raw_img, depth, c, bbox

    def __getitem__(self, idx):
        # do gzip uncompress online
        raw_img, depth, c, bbox = self._load_lmdb_data(idx)

        return {
            **self._post_process_sample(raw_img, depth),
            'c': c,
            # stored bbox is rescaled from a 64-pixel reference to self.reso
            'bbox': bbox * (self.reso / 64.0),
        }


class LMDBDataset_NV_Compressed(LMDBDataset_MV_Compressed):
    """Novel-view variant: pairs each view with a random view of the same
    instance (``nv_``-prefixed keys)."""

    def __init__(self,
                 lmdb_path,
                 reso,
                 reso_encoder,
                 imgnet_normalize=True,
                 **kwargs):
        super().__init__(lmdb_path, reso, reso_encoder, imgnet_normalize,
                         **kwargs)
        self.instance_data_length = 50  # views per instance

    def __getitem__(self, idx):
        input_view = super().__getitem__(idx)  # get previous input view results

        # get novel view of the same instance (contiguous block of indices)
        try:
            novel_view = super().__getitem__(
                (idx // self.instance_data_length) * self.instance_data_length +
                random.randint(0, self.instance_data_length - 1))
        except Exception as e:
            raise NotImplementedError(idx) from e

        # The compressed samples carry no 'ins_name' key, so the former
        # `input_view['ins_name'] == novel_view['ins_name']` assertion raised
        # KeyError on every access; same-instance pairing relies on the
        # block indexing above.
        input_view.update({f'nv_{k}': v for k, v in novel_view.items()})
        return input_view