# LN3Diff / datasets / shapenet.py
# Author: NIRVANALAN
# Commit 87c126b: "release file"
import os
import torchvision
import pickle
from typing import Any
import lmdb
import cv2
import imageio
import numpy as np
from PIL import Image
import Imath
import OpenEXR
from pdb import set_trace as st
from pathlib import Path
from functools import partial
import io
import gzip
import random
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torch.utils.data.distributed import DistributedSampler
from pathlib import Path
from guided_diffusion import logger
def load_dataset(
        file_path="",
        reso=64,
        reso_encoder=224,
        batch_size=1,
        # shuffle=True,
        num_workers=6,
        load_depth=False,
        preprocess=None,
        imgnet_normalize=True,
        dataset_size=-1,
        trainer_name='input_rec',
        use_lmdb=False,
        infi_sampler=True):
    """Build a sequential (non-shuffled) ``DataLoader`` over the training data.

    The dataset class is picked from two switches: ``use_lmdb`` selects the
    gzip-compressed LMDB readers instead of the on-disk readers, and a
    ``trainer_name`` containing ``'nv'`` selects the novel-view variant.

    Args:
        file_path: dataset root directory (or LMDB path when ``use_lmdb``).
        reso: resolution forwarded to the dataset for the target image.
        reso_encoder: resolution forwarded to the dataset for the encoder input.
        batch_size: loader batch size.
        num_workers: loader worker count; workers are kept alive between
            epochs whenever this is positive.
        load_depth, preprocess, imgnet_normalize, dataset_size: forwarded to
            the dataset constructor unchanged.
        trainer_name: trainer identifier, only inspected for the ``'nv'`` tag.
        use_lmdb: read from LMDB instead of individual files.
        infi_sampler: accepted for signature parity with ``load_data``; it is
            not used here.

    Returns:
        A ``DataLoader`` with ``shuffle=False`` and ``drop_last=False``.
    """
    if use_lmdb:
        logger.log('using LMDB dataset')
    novel_view = 'nv' in trainer_name

    if use_lmdb:
        # Author's note: 2.5-3 iter/s, but unstable, drops to 1 later.
        dataset_cls = (LMDBDataset_NV_Compressed
                       if novel_view else LMDBDataset_MV_Compressed)
    else:
        # Author's note on NovelViewDataset: 1.5-2 iter/s.
        dataset_cls = NovelViewDataset if novel_view else MultiViewDataset

    dataset_kwargs = dict(test=False,
                          preprocess=preprocess,
                          load_depth=load_depth,
                          imgnet_normalize=imgnet_normalize,
                          dataset_size=dataset_size)
    dataset = dataset_cls(file_path, reso, reso_encoder, **dataset_kwargs)

    # The message prints the trainer name under the 'dataset_cls' label.
    logger.log('dataset_cls: {}, dataset size: {}'.format(
        trainer_name, len(dataset)))

    return DataLoader(dataset,
                      batch_size=batch_size,
                      num_workers=num_workers,
                      drop_last=False,
                      pin_memory=True,
                      persistent_workers=num_workers > 0,
                      shuffle=False)
def load_data(
        file_path="",
        reso=64,
        reso_encoder=224,
        batch_size=1,
        # shuffle=True,
        num_workers=6,
        load_depth=False,
        preprocess=None,
        imgnet_normalize=True,
        dataset_size=-1,
        trainer_name='input_rec',
        use_lmdb=False,
        infi_sampler=True):
    """Create the training data source.

    With ``infi_sampler=True`` (the default) the result is a generator that
    yields batches forever from a ``DistributedSampler``-backed loader.
    With ``infi_sampler=False`` the dataset object itself is returned.

    This is deliberately a plain function that hands back an inner generator.
    If the function body itself contained ``yield``, a ``return dataset``
    would only end the generator and the caller would never receive the
    dataset. As a consequence the dataset is now constructed when
    ``load_data`` is called rather than on the first ``next()``.

    Args:
        file_path: dataset root directory (or LMDB path when ``use_lmdb``).
        reso: resolution forwarded to the dataset for the target image.
        reso_encoder: resolution forwarded to the dataset for the encoder input.
        batch_size: loader batch size (infinite mode only).
        num_workers: loader worker count (infinite mode only).
        load_depth, preprocess, imgnet_normalize, dataset_size: forwarded to
            the dataset constructor unchanged.
        trainer_name: a name containing ``'nv'`` selects the novel-view dataset.
        use_lmdb: read from the compressed LMDB datasets instead of files.
        infi_sampler: choose between the infinite batch generator and the
            bare dataset.

    Returns:
        A generator of batches, or the dataset instance.
    """
    if use_lmdb:
        logger.log('using LMDB dataset')
        # Author's note: 2.5-3 iter/s, but unstable, drops to 1 later.
        if 'nv' in trainer_name:
            dataset_cls = LMDBDataset_NV_Compressed
        else:
            dataset_cls = LMDBDataset_MV_Compressed
    else:
        if 'nv' in trainer_name:
            dataset_cls = NovelViewDataset  # author's note: 1.5-2 iter/s
        else:
            dataset_cls = MultiViewDataset

    dataset = dataset_cls(file_path,
                          reso,
                          reso_encoder,
                          test=False,
                          preprocess=preprocess,
                          load_depth=load_depth,
                          imgnet_normalize=imgnet_normalize,
                          dataset_size=dataset_size)

    logger.log('dataset_cls: {}, dataset size: {}'.format(
        trainer_name, len(dataset)))

    if not infi_sampler:
        # No debugger breakpoint here: it would stall non-interactive jobs.
        return dataset

    def _infinite_batches():
        # The sampler is created lazily so the distributed process group is
        # only required once iteration actually starts.
        # NOTE(review): set_epoch() is never called on the sampler, so each
        # pass presumably reuses the same shuffled order - confirm intended.
        train_sampler = DistributedSampler(dataset=dataset,
                                           shuffle=True,
                                           drop_last=True)
        loader = DataLoader(dataset,
                            batch_size=batch_size,
                            num_workers=num_workers,
                            drop_last=True,
                            pin_memory=True,
                            persistent_workers=num_workers > 0,
                            sampler=train_sampler)
        while True:
            yield from loader

    return _infinite_batches()
def load_eval_rays(file_path="",
                   reso=64,
                   reso_encoder=224,
                   imgnet_normalize=True):
    """Return one 25-element camera vector per pose of the last loaded instance.

    Each vector is the flattened 4x4 camera-to-world matrix (16 values)
    followed by the flattened 3x3 intrinsics (9 values), the same layout that
    ``MultiViewDataset.__getitem__`` stores under ``'c'``.

    Args:
        file_path: dataset root directory.
        reso: resolution forwarded to ``MultiViewDataset``.
        reso_encoder: encoder resolution forwarded to ``MultiViewDataset``.
        imgnet_normalize: forwarded to ``MultiViewDataset``.

    Returns:
        A list of 1-D tensors of length 25, one per entry of
        ``dataset.single_pose_list``.
    """
    dataset = MultiViewDataset(file_path,
                               reso,
                               reso_encoder,
                               imgnet_normalize=imgnet_normalize)

    # The dataset keeps its intrinsics as a NumPy array, and torch.cat only
    # accepts tensors, so convert once up front. as_tensor also tolerates the
    # attribute already being a tensor.
    intrinsics = torch.as_tensor(dataset.intrinsics)

    ray_list = []
    for pose_fname in dataset.single_pose_list:
        c2w = dataset.get_c2w(pose_fname).reshape(16)  # [4, 4] -> [16]
        c = torch.cat([c2w, intrinsics],
                      dim=0).reshape(25)  # 25, no leading '1' dim needed.
        ray_list.append(c)
    return ray_list
def load_eval_data(file_path="",
                   reso=64,
                   reso_encoder=224,
                   batch_size=1,
                   num_workers=1,
                   load_depth=False,
                   preprocess=None,
                   imgnet_normalize=True,
                   interval=1, **kwargs):
    """Build the evaluation ``DataLoader`` over a test-mode ``MultiViewDataset``.

    Args:
        file_path: dataset root directory.
        reso: resolution forwarded to the dataset for the target image.
        reso_encoder: resolution forwarded to the dataset for the encoder input.
        batch_size: loader batch size.
        num_workers: loader worker count.
        load_depth, preprocess, imgnet_normalize: forwarded to the dataset.
        interval: view subsampling stride forwarded to the dataset.
        **kwargs: accepted and ignored.

    Returns:
        A ``DataLoader`` that neither shuffles nor drops the last batch.
    """
    eval_set = MultiViewDataset(file_path,
                                reso,
                                reso_encoder,
                                preprocess=preprocess,
                                load_depth=load_depth,
                                test=True,
                                imgnet_normalize=imgnet_normalize,
                                interval=interval)
    print('eval dataset size: {}'.format(len(eval_set)))

    loader_options = dict(batch_size=batch_size,
                          num_workers=num_workers,
                          drop_last=False,
                          shuffle=False)
    return DataLoader(eval_set, **loader_options)
def load_memory_data(file_path="",
                     reso=64,
                     reso_encoder=224,
                     batch_size=1,
                     num_workers=1,
                     load_depth=True,
                     preprocess=None,
                     imgnet_normalize=True):
    """Yield random contiguous batches from a single instance held in memory.

    The whole overfitting dataset is collated once into one big batch; every
    subsequent step slices ``batch_size`` consecutive samples out of it at a
    random offset, which avoids any further disk IO.

    Args:
        file_path: dataset root directory.
        reso: resolution forwarded to the dataset for the target image.
        reso_encoder: resolution forwarded to the dataset for the encoder input.
        batch_size: size of every yielded batch; also passed to the dataset
            as ``overfitting_bs``.
        num_workers: worker count used for the one-off load.
        load_depth: accepted but not used; depth is always loaded.
        preprocess, imgnet_normalize: forwarded to the dataset.

    Yields:
        A dict with the same keys as the collated batch, each value sliced
        to ``batch_size`` consecutive samples.
    """
    memory_set = MultiViewDataset(file_path,
                                  reso,
                                  reso_encoder,
                                  preprocess=preprocess,
                                  load_depth=True,
                                  test=False,
                                  overfitting=True,
                                  imgnet_normalize=imgnet_normalize,
                                  overfitting_bs=batch_size)
    logger.log('!!!!!!! memory dataset size: {} !!!!!!'.format(
        len(memory_set)))

    # A single batch as large as the dataset pulls everything into memory.
    one_shot_loader = DataLoader(
        memory_set,
        batch_size=len(memory_set),
        num_workers=num_workers,
        drop_last=False,
        shuffle=False,
    )
    all_data: dict = next(iter(one_shot_loader))

    while True:
        begin = np.random.randint(0, len(memory_set) - batch_size + 1)
        end = begin + batch_size
        yield {name: tensor[begin:end] for name, tensor in all_data.items()}
class MultiViewDataset(Dataset):
    """Multi-view renders of ShapeNet instances read from individual files.

    Expected layout under ``file_path``, as read by this class::

        <file_path>/<instance>/rgb/<view>.png
        <file_path>/<instance>/rgb/<view>_depth0001.exr
        <file_path>/<instance>/pose/<view>.txt

    In training mode the instance names come from
    ``<file_path>/../<dataset_name>_train_list.txt``, where ``dataset_name``
    is the part of the directory stem before the first underscore. In test
    mode a single instance is taken from the sorted directory listing.

    Every sample is a dict with the encoder input, the target image in
    [-1, 1], a 128x128 version of it, and a 25-element camera vector
    (flattened 4x4 camera-to-world followed by flattened 3x3 intrinsics).
    Depth, foreground masks and a bounding box are added when ``load_depth``
    is set.
    """

    def __init__(self,
                 file_path,
                 reso,
                 reso_encoder,
                 preprocess=None,
                 classes=False,
                 load_depth=False,
                 test=False,
                 scene_scale=1,
                 overfitting=False,
                 imgnet_normalize=True,
                 dataset_size=-1,
                 overfitting_bs=-1,
                 interval=1):
        """Index every view of the selected instances.

        Args:
            file_path: dataset root directory.
            reso: side length of the target image.
            reso_encoder: side length of the encoder input image.
            preprocess: optional callable applied to a PIL image to produce
                the encoder input instead of the built-in resize + normalize.
            classes: ignored; class conditioning is not supported.
            load_depth: also return depth, masks and bounding box.
            test: use the evaluation instance instead of the training list.
            scene_scale: stored on the instance; not used in this class.
            overfitting: keep only the first instance and subsample its views.
            imgnet_normalize: use ImageNet mean/std for the encoder input,
                otherwise mean = std = 0.5.
            dataset_size: maximum number of training instances; a negative
                value keeps all of them.
            overfitting_bs: required when ``overfitting`` is set; controls
                the view subsampling stride.
            interval: keep every ``interval``-th view of each instance.
        """
        self.file_path = file_path
        self.overfitting = overfitting
        self.scene_scale = scene_scale
        self.reso = reso
        self.reso_encoder = reso_encoder
        self.classes = False
        self.load_depth = load_depth
        self.preprocess = preprocess

        assert not self.classes, "Not support class condition now."

        dataset_name = Path(self.file_path).stem.split('_')[0]
        self.dataset_name = dataset_name

        if test:
            if dataset_name == 'chair':
                self.ins_list = sorted(os.listdir(
                    self.file_path))[1:2]  # more diversity
            else:
                self.ins_list = sorted(os.listdir(self.file_path))[
                    0:1]  # the first 1 instance for evaluation reference.
        else:
            ins_list_file = Path(
                self.file_path).parent / f'{dataset_name}_train_list.txt'
            assert ins_list_file.exists(), 'add training list for ShapeNet'
            self.ins_list = self._read_instance_list(ins_list_file,
                                                     dataset_size)

        if overfitting:
            self.ins_list = self.ins_list[:1]

        self.rgb_list = []
        self.pose_list = []
        self.depth_list = []
        self.data_ins_list = []
        self.instance_data_length = -1
        for ins in self.ins_list:
            cur_rgb_path = os.path.join(self.file_path, ins, 'rgb')
            cur_pose_path = os.path.join(self.file_path, ins, 'pose')

            # Sort first, then stride: os.listdir order is arbitrary, so
            # striding before sorting would make the kept views depend on
            # the filesystem.
            cur_all_fname = sorted(
                t.split('.')[0] for t in os.listdir(cur_rgb_path)
                if 'depth' not in t)[::interval]

            # Every instance must contribute the same number of views;
            # NovelViewDataset relies on this to locate sibling views.
            if self.instance_data_length == -1:
                self.instance_data_length = len(cur_all_fname)
            else:
                assert len(cur_all_fname) == self.instance_data_length

            if test:
                # Put one fixed view (two thirds in) at the front as well.
                mid_index = len(cur_all_fname) // 3 * 2
                cur_all_fname.insert(0, cur_all_fname[mid_index])

            self.pose_list += [
                os.path.join(cur_pose_path, fname + '.txt')
                for fname in cur_all_fname
            ]
            self.rgb_list += [
                os.path.join(cur_rgb_path, fname + '.png')
                for fname in cur_all_fname
            ]
            self.depth_list += [
                os.path.join(cur_rgb_path, fname + '_depth0001.exr')
                for fname in cur_all_fname
            ]
            self.data_ins_list += [ins] * len(cur_all_fname)

        # validate overfitting on images
        if overfitting:
            assert overfitting_bs != -1
            # uniform pose sampling
            stride = 50 // overfitting_bs + 1
            self.pose_list = self.pose_list[::stride]
            self.rgb_list = self.rgb_list[::stride]
            self.depth_list = self.depth_list[::stride]

        # Pose files of the last instance visited by the loop above.
        self.single_pose_list = [
            os.path.join(cur_pose_path, fname + '.txt')
            for fname in cur_all_fname
        ]

        transformations = [
            transforms.ToTensor(),  # [0,1] range
        ]
        if imgnet_normalize:
            transformations.append(
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225))  # type: ignore
            )
        else:
            transformations.append(
                transforms.Normalize((0.5, 0.5, 0.5),
                                     (0.5, 0.5, 0.5)))  # type: ignore
        self.normalize = transforms.Compose(transformations)

        fx = fy = 525
        cx = cy = 256  # rendering default K
        factor = self.reso / (cx * 2)  # e.g. 128 / 512
        self.fx = fx * factor
        self.fy = fy * factor
        self.cx = cx * factor
        self.cy = cy * factor

        # ! fix scale for triplane ray_sampler(): intrinsics use the [0,1]
        # uv range, not the [0, w] image-space range.
        self.cx /= self.reso  # 0.5
        self.cy /= self.reso  # 0.5
        self.fx /= self.reso
        self.fy /= self.reso

        intrinsics = np.array([[self.fx, 0, self.cx], [0, self.fy, self.cy],
                               [0, 0, 1]]).reshape(9)
        # Kept as a NumPy array; consumers convert it to a tensor themselves.
        self.intrinsics = intrinsics

    @staticmethod
    def _read_instance_list(list_file, dataset_size=-1):
        """Read instance names, one per line, keeping at most ``dataset_size``.

        A negative ``dataset_size`` means "no limit". Slicing with the raw
        value would treat the default ``-1`` as "drop the last instance".
        """
        with open(list_file, 'r') as f:
            names = [line.strip() for line in f]
        if dataset_size is not None and dataset_size >= 0:
            names = names[:dataset_size]
        return names

    def __len__(self):
        """Number of indexed views."""
        return len(self.rgb_list)

    def get_c2w(self, pose_fname):
        """Read the first line of a pose file as a 4x4 float32 tensor."""
        with open(pose_fname, 'r') as f:
            cam2world = f.readline().strip()
        # split() without an argument tolerates repeated spaces and tabs.
        cam2world = [float(t) for t in cam2world.split()]
        c2w = torch.tensor(cam2world, dtype=torch.float32).reshape(4, 4)
        return c2w

    def gen_rays(self, c2w):
        """Generate per-pixel ray origins and unit directions.

        Args:
            c2w: camera-to-world matrices indexed as ``c2w[:, None, :3, :3]``,
                i.e. with a leading batch dimension.

        Returns:
            ``(origins, dirs)``, both reshaped to ``(-1, 3)``.

        NOTE(review): ``self.cx`` / ``self.fx`` are stored in normalised uv
        units while the pixel grid built here is in pixel units - confirm
        before relying on this method.
        """
        self.h = self.reso
        self.w = self.reso
        yy, xx = torch.meshgrid(
            torch.arange(self.h, dtype=torch.float32) + 0.5,
            torch.arange(self.w, dtype=torch.float32) + 0.5,
            indexing='ij')
        xx = (xx - self.cx) / self.fx
        yy = (yy - self.cy) / self.fy
        zz = torch.ones_like(xx)
        dirs = torch.stack((xx, yy, zz), dim=-1)  # OpenCV convention
        dirs /= torch.norm(dirs, dim=-1, keepdim=True)
        dirs = dirs.reshape(1, -1, 3, 1)
        del xx, yy, zz
        dirs = (c2w[:, None, :3, :3] @ dirs)[..., 0]

        origins = c2w[:, None, :3, 3].expand(-1, self.h * self.w,
                                             -1).contiguous()
        origins = origins.view(-1, 3)
        dirs = dirs.view(-1, 3)

        return origins, dirs

    def read_depth(self, idx):
        """Load the EXR depth map of view ``idx`` and derive foreground masks.

        Returns:
            ``(depth, mask_reso, mask_sr)``: the full-resolution depth as an
            (H, W) tensor, and boolean foreground masks (resized depth > 0)
            at ``self.reso`` and at 128x128.
        """
        depth_path = self.depth_list[idx]
        exr = OpenEXR.InputFile(depth_path)
        header = exr.header()
        size = (header['dataWindow'].max.x - header['dataWindow'].min.x + 1,
                header['dataWindow'].max.y - header['dataWindow'].min.y + 1)
        FLOAT = Imath.PixelType(Imath.PixelType.FLOAT)
        depth_str = exr.channel('B', FLOAT)
        # size is (W, H); the buffer is laid out row by row, hence (H, W).
        # No second reshape to `size`: it would reinterpret the rows as
        # (W, H) and scramble any non-square map.
        depth = np.frombuffer(depth_str,
                              dtype=np.float32).reshape(size[1],
                                                        size[0])  # H W
        # Infinite depth marks background; map it to 0 so masks exclude it.
        depth = np.nan_to_num(depth, posinf=0, neginf=0)

        def resize_depth_mask(depth_to_resize, resolution):
            depth_resized = cv2.resize(depth_to_resize,
                                       (resolution, resolution),
                                       interpolation=cv2.INTER_LANCZOS4)
            return depth_resized > 0  # type: ignore

        fg_mask_reso = resize_depth_mask(depth, self.reso)
        fg_mask_sr = resize_depth_mask(depth, 128)

        return torch.from_numpy(depth), torch.from_numpy(
            fg_mask_reso), torch.from_numpy(fg_mask_sr)

    def load_bbox(self, mask):
        """Bounding box of the non-zero entries of ``mask``.

        Returns:
            A float32 tensor ``[min_row, min_col, max_row, max_col]``. The
            locals named ``height`` / ``width`` hold the maximum indices, not
            extents.
        """
        nonzero_value = torch.nonzero(mask)
        height, width = nonzero_value.max(dim=0)[0]
        top, left = nonzero_value.min(dim=0)[0]
        bbox = torch.tensor([top, left, height, width], dtype=torch.float32)
        return bbox

    def __getitem__(self, idx):
        """Load view ``idx``.

        Returns:
            A dict with ``'img_to_encoder'``, ``'img'``, ``'c'`` and
            ``'img_sr'``; plus ``'depth'``, ``'depth_mask'``,
            ``'depth_mask_sr'`` and ``'bbox'`` when ``self.load_depth``.
        """
        rgb_fname = self.rgb_list[idx]
        pose_fname = self.pose_list[idx]

        raw_img = imageio.imread(rgb_fname)

        if self.preprocess is None:
            img_to_encoder = cv2.resize(raw_img,
                                        (self.reso_encoder, self.reso_encoder),
                                        interpolation=cv2.INTER_LANCZOS4)
            img_to_encoder = img_to_encoder[..., :3]  # drop alpha if present
            img_to_encoder = self.normalize(img_to_encoder)
        else:
            img_to_encoder = self.preprocess(Image.open(rgb_fname))  # clip

        img = cv2.resize(raw_img, (self.reso, self.reso),
                         interpolation=cv2.INTER_LANCZOS4)

        # just as refinement, since eg3d uses 64->128 final resolution
        img_sr = cv2.resize(raw_img, (128, 128),
                            interpolation=cv2.INTER_LANCZOS4)

        # [3, reso, reso], normalize to [-1,1], follow triplane range
        img = torch.from_numpy(img)[..., :3].permute(2, 0, 1) / 127.5 - 1
        img_sr = torch.from_numpy(img_sr)[..., :3].permute(2, 0,
                                                           1) / 127.5 - 1

        c2w = self.get_c2w(pose_fname).reshape(16)  # [4, 4] -> [16]
        c = torch.cat([c2w, torch.from_numpy(self.intrinsics)],
                      dim=0).reshape(25)  # 25, no '1' dim needed.

        ret_dict = {
            'img_to_encoder': img_to_encoder,
            'img': img,
            'c': c,
            'img_sr': img_sr,
        }
        if self.load_depth:
            depth, depth_mask, depth_mask_sr = self.read_depth(idx)
            bbox = self.load_bbox(depth_mask)
            ret_dict.update({
                'depth': depth,
                'depth_mask': depth_mask,
                'depth_mask_sr': depth_mask_sr,
                'bbox': bbox
            })
        return ret_dict
class MultiViewDatasetforLMDB(MultiViewDataset):
    """``MultiViewDataset`` variant that returns raw, unprocessed samples.

    Used when exporting the dataset to LMDB: resizing and normalisation are
    skipped so they can be done later, when the LMDB is read back.
    """

    def __init__(self,
                 file_path,
                 reso,
                 reso_encoder,
                 preprocess=None,
                 classes=False,
                 load_depth=False,
                 test=False,
                 scene_scale=1,
                 overfitting=False,
                 imgnet_normalize=True,
                 dataset_size=-1,
                 overfitting_bs=-1):
        """Forward every argument to ``MultiViewDataset`` unchanged."""
        super().__init__(file_path,
                         reso,
                         reso_encoder,
                         preprocess=preprocess,
                         classes=classes,
                         load_depth=load_depth,
                         test=test,
                         scene_scale=scene_scale,
                         overfitting=overfitting,
                         imgnet_normalize=imgnet_normalize,
                         dataset_size=dataset_size,
                         overfitting_bs=overfitting_bs)

    def __len__(self):
        """Same length as the parent dataset."""
        return super().__len__()

    def __getitem__(self, idx):
        """Return the raw RGB image, camera vector, depth and bbox of ``idx``.

        Depth is always read here, regardless of ``self.load_depth``.
        """
        image_path = self.rgb_list[idx]
        pose_path = self.pose_list[idx]

        # Keep only the first three channels of the source image.
        raw_img = imageio.imread(image_path)[..., :3]

        # 16 extrinsic values followed by 9 intrinsic values.
        extrinsics = self.get_c2w(pose_path).reshape(16)
        intrinsics = torch.from_numpy(self.intrinsics)
        c = torch.cat([extrinsics, intrinsics], dim=0).reshape(25)

        # The 128x128 mask returned by read_depth is not needed here.
        depth, depth_mask, _ = self.read_depth(idx)
        bbox = self.load_bbox(depth_mask)

        return {
            'raw_img': raw_img,
            'c': c,
            'depth': depth,
            'bbox': bbox,
        }
def load_data_dryrun(
        file_path="",
        reso=64,
        reso_encoder=224,
        batch_size=1,
        # shuffle=True,
        num_workers=6,
        load_depth=False,
        preprocess=None,
        imgnet_normalize=True):
    """Build a plain ``DataLoader`` over the training ``MultiViewDataset``.

    Args:
        file_path: dataset root directory.
        reso: resolution forwarded to the dataset for the target image.
        reso_encoder: resolution forwarded to the dataset for the encoder input.
        batch_size: loader batch size.
        num_workers: loader worker count.
        load_depth, preprocess, imgnet_normalize: forwarded to the dataset.

    Returns:
        A ``DataLoader`` with default ordering that keeps the last partial
        batch.
    """
    dryrun_set = MultiViewDataset(file_path,
                                  reso,
                                  reso_encoder,
                                  test=False,
                                  preprocess=preprocess,
                                  load_depth=load_depth,
                                  imgnet_normalize=imgnet_normalize)
    print('dataset size: {}'.format(len(dryrun_set)))

    return DataLoader(dryrun_set,
                      batch_size=batch_size,
                      num_workers=num_workers,
                      drop_last=False)
class NovelViewDataset(MultiViewDataset):
    """Novel-view prediction variant of ``MultiViewDataset``.

    Each sample carries the requested view plus a second, randomly chosen
    view of the same instance under ``nv_``-prefixed keys.
    """

    def __init__(self,
                 file_path,
                 reso,
                 reso_encoder,
                 preprocess=None,
                 classes=False,
                 load_depth=False,
                 test=False,
                 scene_scale=1,
                 overfitting=False,
                 imgnet_normalize=True,
                 dataset_size=-1,
                 overfitting_bs=-1):
        """Forward every argument to ``MultiViewDataset`` unchanged."""
        super().__init__(file_path,
                         reso,
                         reso_encoder,
                         preprocess=preprocess,
                         classes=classes,
                         load_depth=load_depth,
                         test=test,
                         scene_scale=scene_scale,
                         overfitting=overfitting,
                         imgnet_normalize=imgnet_normalize,
                         dataset_size=dataset_size,
                         overfitting_bs=overfitting_bs)

    def __getitem__(self, idx):
        """Return view ``idx`` merged with a random view of the same instance."""
        sample = super().__getitem__(idx)  # the input view

        # Views of one instance occupy a contiguous run of
        # instance_data_length indices; pick any index inside that run
        # (possibly idx itself).
        views_per_instance = self.instance_data_length
        instance_start = (idx // views_per_instance) * views_per_instance
        novel_idx = instance_start + random.randint(0, views_per_instance - 1)
        novel_view = super().__getitem__(novel_idx)

        for key, value in novel_view.items():
            sample[f'nv_{key}'] = value
        return sample
def load_data_for_lmdb(
        file_path="",
        reso=64,
        reso_encoder=224,
        batch_size=1,
        # shuffle=True,
        num_workers=6,
        load_depth=False,
        preprocess=None,
        imgnet_normalize=True,
        dataset_size=-1,
        trainer_name='input_rec'):
    """Build the sequential loader used to export the dataset to LMDB.

    Args:
        file_path: dataset root directory.
        reso: resolution forwarded to the dataset.
        reso_encoder: encoder resolution forwarded to the dataset.
        batch_size: loader batch size.
        num_workers: loader worker count; ``0`` loads in the main process.
        load_depth, preprocess, imgnet_normalize, dataset_size: forwarded to
            the dataset constructor unchanged.
        trainer_name: only used in the log message.

    Returns:
        ``(loader, dataset_name, num_samples)``.
    """
    dataset_cls = MultiViewDatasetforLMDB

    dataset = dataset_cls(file_path,
                          reso,
                          reso_encoder,
                          test=False,
                          preprocess=preprocess,
                          load_depth=load_depth,
                          imgnet_normalize=imgnet_normalize,
                          dataset_size=dataset_size)

    logger.log('dataset_cls: {}, dataset size: {}'.format(
        trainer_name, len(dataset)))

    loader_kwargs = dict(
        shuffle=False,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=False,
        pin_memory=True,
        # DataLoader rejects persistent_workers=True without workers.
        persistent_workers=num_workers > 0,
    )
    if num_workers > 0:
        # prefetch_factor is only passed when there are workers: whether an
        # explicit value is accepted with num_workers=0 depends on the
        # PyTorch version.
        loader_kwargs['prefetch_factor'] = 2

    loader = DataLoader(dataset, **loader_kwargs)
    return loader, dataset.dataset_name, len(dataset)
class LMDBDataset(Dataset):
    """Read-only dataset of pickled samples stored in an LMDB environment.

    Samples are looked up under the UTF-8 encoded decimal string of their
    index.
    """

    def __init__(self, lmdb_path):
        """Open ``lmdb_path`` read-only and record its entry count."""
        env_options = dict(
            readonly=True,
            max_readers=32,
            lock=False,
            readahead=False,
            meminit=False,
        )
        self.env = lmdb.open(lmdb_path, **env_options)
        self.num_samples = self.env.stat()['entries']

    def __len__(self):
        """Number of entries reported by the LMDB environment."""
        return self.num_samples

    def __getitem__(self, idx):
        """Fetch and unpickle the sample stored under key ``str(idx)``."""
        lookup_key = str(idx).encode('utf-8')
        with self.env.begin(write=False) as txn:
            payload = txn.get(lookup_key)
            sample = pickle.loads(payload)
        return sample
def resize_depth_mask(depth_to_resize, resolution):
    """Resize a depth map to a square and derive its foreground mask.

    Args:
        depth_to_resize: depth map accepted by ``cv2.resize``.
        resolution: side length of the square output.

    Returns:
        ``(depth_resized, mask)`` where ``mask`` is ``depth_resized > 0``.
    """
    target_size = (resolution, resolution)
    depth_resized = cv2.resize(depth_to_resize,
                               target_size,
                               interpolation=cv2.INTER_LANCZOS4)
    foreground = depth_resized > 0  # type: ignore
    return depth_resized, foreground
class LMDBDataset_MV(LMDBDataset):
    """LMDB-backed multi-view dataset that post-processes samples on the fly.

    The stored raw image and depth are turned into the encoder input, the
    target image, a resized depth map and its foreground mask when a sample
    is read.
    """

    def __init__(self,
                 lmdb_path,
                 reso,
                 reso_encoder,
                 imgnet_normalize=True,
                 **kwargs):
        """Open the LMDB and build the encoder-input normalisation.

        Args:
            lmdb_path: path of the LMDB environment.
            reso: side length of the target image and depth map.
            reso_encoder: side length of the encoder input.
            imgnet_normalize: use ImageNet mean/std, otherwise mean = std = 0.5.
            **kwargs: accepted and ignored.
        """
        super().__init__(lmdb_path)

        self.reso_encoder = reso_encoder
        self.reso = reso

        if imgnet_normalize:
            channel_norm = transforms.Normalize(
                (0.485, 0.456, 0.406),
                (0.229, 0.224, 0.225))  # type: ignore
        else:
            channel_norm = transforms.Normalize(
                (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # type: ignore

        # ToTensor first brings the image into the [0,1] range.
        self.normalize = transforms.Compose(
            [transforms.ToTensor(), channel_norm])

    def _post_process_sample(self, raw_img, depth):
        """Turn a stored raw image and depth map into training tensors.

        Returns:
            A dict with ``'img_to_encoder'``, ``'img'``, ``'depth_mask'``,
            ``'img_sr'`` and ``'depth'``.
        """
        encoder_size = (self.reso_encoder, self.reso_encoder)
        encoder_rgb = cv2.resize(raw_img,
                                 encoder_size,
                                 interpolation=cv2.INTER_LANCZOS4)[..., :3]
        img_to_encoder = self.normalize(encoder_rgb)

        img = cv2.resize(raw_img, (self.reso, self.reso),
                         interpolation=cv2.INTER_LANCZOS4)

        if img.shape[-1] == 4:
            # Composite RGBA onto a white background using a binary alpha.
            alpha_mask = img[..., -1:] > 0
            img = alpha_mask * img[..., :3] + (1-alpha_mask) * np.ones_like(img[..., :3]) * 255

        # [3, reso, reso], normalize to [-1,1], follow triplane range
        img = torch.from_numpy(img)[..., :3].permute(2, 0, 1) / 127.5 - 1

        # 'img_sr' is built from the raw image at its stored size.
        img_sr = torch.from_numpy(raw_img)[..., :3].permute(2, 0,
                                                            1) / 127.5 - 1

        depth_reso, fg_mask_reso = resize_depth_mask(depth, self.reso)

        processed = {}
        processed['img_to_encoder'] = img_to_encoder
        processed['img'] = img
        processed['depth_mask'] = fg_mask_reso
        processed['img_sr'] = img_sr
        processed['depth'] = depth_reso
        return processed

    def __getitem__(self, idx):
        """Load the pickled sample ``idx`` and post-process it online."""
        stored = super().__getitem__(idx)
        return self._post_process_sample(stored['raw_img'], stored['depth'])
def load_bytes(inp_bytes, dtype, shape):
    """Decode a raw byte buffer into an owned array of ``dtype`` and ``shape``.

    The result is copied out of the buffer, so it does not alias
    ``inp_bytes``.
    """
    flat_view = np.frombuffer(inp_bytes, dtype=dtype)
    shaped_view = flat_view.reshape(shape)
    return shaped_view.copy()
def decompress_and_open_image_gzip(compressed_data, is_img=False):
    """Gunzip ``compressed_data`` and optionally decode it as an image.

    Args:
        compressed_data: gzip-compressed bytes.
        is_img: when true, decode the decompressed bytes with imageio and
            return a copy of the image array.

    Returns:
        The decoded image when ``is_img`` is true, otherwise the raw
        decompressed bytes.
    """
    raw_bytes = gzip.decompress(compressed_data)
    if not is_img:
        return raw_bytes
    return imageio.v3.imread(io.BytesIO(raw_bytes)).copy()
def decompress_array(compressed_data, shape, dtype):
    """Gunzip ``compressed_data`` and decode it as an array.

    Args:
        compressed_data: gzip-compressed raw array bytes.
        shape: shape of the resulting array.
        dtype: element type of the stored bytes.

    Returns:
        The array produced by ``load_bytes`` from the decompressed bytes.
    """
    raw_bytes = gzip.decompress(compressed_data)
    return load_bytes(raw_bytes, dtype, shape)
class LMDBDataset_MV_Compressed(LMDBDataset_MV):
    """``LMDBDataset_MV`` reading gzip-compressed, per-field LMDB entries.

    Every sample is spread over four keys: ``<idx>-raw_img``, ``<idx>-depth``,
    ``<idx>-c`` and ``<idx>-bbox``. The sample count is stored under the key
    ``length``.
    """

    def __init__(self,
                 lmdb_path,
                 reso,
                 reso_encoder,
                 imgnet_normalize=True,
                 **kwargs):
        """Open the LMDB and read the stored dataset length."""
        super().__init__(lmdb_path, reso, reso_encoder, imgnet_normalize,
                         **kwargs)
        with self.env.begin(write=False) as txn:
            stored_length = txn.get('length'.encode('utf-8')).decode('utf-8')
            # NOTE(review): the reported length is 40 below the stored
            # value; the reason is not visible here - confirm with the
            # LMDB export code.
            self.length = int(stored_length) - 40

        self.load_image_fn = partial(decompress_and_open_image_gzip,
                                     is_img=True)

    def __len__(self):
        """Stored length minus the fixed offset applied in ``__init__``."""
        return self.length

    def _load_lmdb_data(self, idx):
        """Fetch and decompress the four fields of sample ``idx``.

        Returns:
            ``(raw_img, depth, c, bbox)``; depth is decoded as a 512x512
            float32 array, ``c`` as 25 float32 values and ``bbox`` as 4.
        """
        with self.env.begin(write=False) as txn:

            def fetch(field):
                return txn.get(f'{idx}-{field}'.encode('utf-8'))

            raw_img = self.load_image_fn(fetch('raw_img'))
            depth = decompress_array(fetch('depth'), (512, 512), np.float32)
            c = decompress_array(fetch('c'), (25, ), np.float32)
            bbox = decompress_array(fetch('bbox'), (4, ), np.float32)

        return raw_img, depth, c, bbox

    def __getitem__(self, idx):
        """Decompress sample ``idx`` and post-process it online."""
        raw_img, depth, c, bbox = self._load_lmdb_data(idx)

        sample = self._post_process_sample(raw_img, depth)
        sample['c'] = c
        # The stored bbox is rescaled by reso / 64.
        sample['bbox'] = bbox*(self.reso/64.0)
        return sample
class LMDBDataset_NV_Compressed(LMDBDataset_MV_Compressed):
    """Novel-view variant of ``LMDBDataset_MV_Compressed``.

    Each sample carries the requested view plus a second, randomly chosen
    view of the same instance under ``nv_``-prefixed keys.
    """

    def __init__(self, lmdb_path, reso, reso_encoder, imgnet_normalize=True, **kwargs):
        """Open the LMDB; every instance is assumed to hold 50 views."""
        super().__init__(lmdb_path, reso, reso_encoder, imgnet_normalize, **kwargs)
        self.instance_data_length = 50

    def __getitem__(self, idx):
        """Return view ``idx`` merged with a random view of the same instance.

        Raises:
            NotImplementedError: if loading the novel view fails; the
                original error is attached as ``__cause__``.
        """
        input_view = super().__getitem__(idx)  # the input view

        # Views of one instance occupy a contiguous run of
        # instance_data_length indices; pick any index inside that run.
        instance_start = (idx // self.instance_data_length
                          ) * self.instance_data_length
        novel_idx = instance_start + random.randint(
            0, self.instance_data_length - 1)

        try:
            novel_view = super().__getitem__(novel_idx)
        except Exception as e:
            # Same exception type and argument as before, with the cause
            # chained so the real failure stays visible.
            raise NotImplementedError(idx) from e

        # No 'ins_name' consistency check: the parent sample does not carry
        # that key, so checking it would raise KeyError on every sample.
        input_view.update({f'nv_{k}': v for k, v in novel_view.items()})
        return input_view