import os
import cv2
import glob
import json
import tqdm
import numpy as np
from scipy.spatial.transform import Rotation
from .network import AudioEncoder
import trimesh

import torch
from torch.utils.data import DataLoader

from .utils import get_audio_features, get_rays, get_bg_coords, AudDataset


# Pretrained audio-visual encoder weights (path is relative to the working directory).
_AVE_CHECKPOINT_PATH = './nerf_triplane/checkpoints/audio_visual_encoder.pth'


# ref: https://github.com/NVlabs/instant-ngp/blob/b76004c8cf478880227401ae763be4c02f80b62f/include/neural-graphics-primitives/nerf_loader.h#L50
def nerf_matrix_to_ngp(pose, scale=0.33, offset=(0, 0, 0)):
    """Convert a NeRF camera-to-world matrix to the instant-ngp convention.

    Row k of the result is row (k + 1) % 3 of ``pose`` with its second and third
    columns negated; the translation is multiplied by ``scale`` and shifted by
    ``offset``.

    Args:
        pose: [4, 4] camera-to-world matrix.
        scale: multiplier applied to the translation column.
        offset: 3-element offset added to the (permuted, scaled) translation.

    Returns:
        [4, 4] float32 numpy array.
    """
    new_pose = np.array([
        [pose[1, 0], -pose[1, 1], -pose[1, 2], pose[1, 3] * scale + offset[0]],
        [pose[2, 0], -pose[2, 1], -pose[2, 2], pose[2, 3] * scale + offset[1]],
        [pose[0, 0], -pose[0, 1], -pose[0, 2], pose[0, 3] * scale + offset[2]],
        [0, 0, 0, 1],
    ], dtype=np.float32)
    return new_pose


def smooth_camera_path(poses, kernel_size=5):
    """Smooth a camera trajectory with a centred moving-average window.

    Translations are averaged arithmetically and rotations with
    ``scipy.spatial.transform.Rotation.mean``; the window is truncated at both
    ends of the sequence.

    Args:
        poses: [N, 4, 4] numpy array. Modified in place.
        kernel_size: window width; frame i averages frames
            i - kernel_size // 2 .. i + kernel_size // 2.

    Returns:
        The same ``poses`` array, smoothed.
    """
    N = poses.shape[0]
    K = kernel_size // 2

    # Copy the originals so already-smoothed frames don't leak into later windows.
    trans = poses[:, :3, 3].copy()  # [N, 3]
    rots = poses[:, :3, :3].copy()  # [N, 3, 3]

    for i in range(N):
        start = max(0, i - K)
        end = min(N, i + K + 1)
        poses[i, :3, 3] = trans[start:end].mean(0)
        poses[i, :3, :3] = Rotation.from_matrix(rots[start:end]).mean().as_matrix()

    return poses


def polygon_area(x, y):
    """Return the area of the polygon with vertices (x[k], y[k]) via the shoelace formula.

    Coordinates are first centred on their mean (the area is translation
    invariant); the absolute value makes the result independent of vertex
    orientation.
    """
    x_ = x - x.mean()
    y_ = y - y.mean()
    # closing edge: last vertex -> first vertex
    correction = x_[-1] * y_[0] - y_[-1] * x_[0]
    main_area = np.dot(x_[:-1], y_[1:]) - np.dot(y_[:-1], x_[1:])
    return 0.5 * np.abs(main_area + correction)


def visualize_poses(poses, size=0.1):
    """Show camera poses in an interactive trimesh scene (debugging aid).

    World axes and a 2x2x2 box outline are drawn for reference.

    Args:
        poses: [B, 4, 4] camera-to-world matrices.
        size: half-extent of the drawn camera frustum.
    """
    print(f'[INFO] visualize poses: {poses.shape}')

    axes = trimesh.creation.axis(axis_length=4)
    box = trimesh.primitives.Box(extents=(2, 2, 2)).as_outline()
    box.colors = np.array([[128, 128, 128]] * len(box.entities))
    objects = [axes, box]

    for pose in poses:
        # a camera is visualized with 9 line segments: 4 from the centre to the
        # frustum corners, 4 around the corners, and 1 along the local +z axis.
        pos = pose[:3, 3]
        x_axis, y_axis, z_axis = pose[:3, 0], pose[:3, 1], pose[:3, 2]
        a = pos + size * x_axis + size * y_axis + size * z_axis
        b = pos - size * x_axis + size * y_axis + size * z_axis
        c = pos - size * x_axis - size * y_axis + size * z_axis
        d = pos + size * x_axis - size * y_axis + size * z_axis

        direction = (a + b + c + d) / 4 - pos
        direction = direction / (np.linalg.norm(direction) + 1e-8)
        o = pos + direction * 3

        segs = np.array([[pos, a], [pos, b], [pos, c], [pos, d], [a, b], [b, c], [c, d], [d, a], [pos, o]])
        objects.append(trimesh.load_path(segs))

    trimesh.Scene(objects).show()


def _read_rgb(path):
    """Read ``path`` with cv2.IMREAD_UNCHANGED, convert BGR->RGB, return float32 in [0, 1].

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    """
    image = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # [H, W, 3] or [H, W, 4]
    if image is None:
        raise FileNotFoundError(f'[ERROR] failed to read image: {path}')
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return image.astype(np.float32) / 255


def _read_rgba(path):
    """Read an RGBA image (BGRA on disk) as float32 [H, W, 4] in [0, 1].

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    """
    image = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # [H, W, 4]
    if image is None:
        raise FileNotFoundError(f'[ERROR] failed to read image: {path}')
    image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
    return image.astype(np.float32) / 255


def _read_face_mask(path):
    """Read a face-parsing image and return 1 - green / 255 as a float64 [H, W] array.

    Pixels whose green channel is 0 map to 1.0, pixels with green 255 map to 0.0.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    """
    image = cv2.imread(path)
    if image is None:
        raise FileNotFoundError(f'[ERROR] failed to read image: {path}')
    return (255 - image[:, :, 1]) / 255.0


def _rect_mask(rays, rect):
    """Boolean [B, N] mask of rays inside ``rect`` = [xmin, xmax, ymin, ymax].

    The x bounds are tested against ``rays['j']`` and the y bounds against
    ``rays['i']``; both ranges are half-open (upper bound excluded).
    """
    xmin, xmax, ymin, ymax = rect
    return (rays['j'] >= xmin) & (rays['j'] < xmax) & (rays['i'] >= ymin) & (rays['i'] < ymax)


def _extract_ave_features(wav_path, ckpt_path=_AVE_CHECKPOINT_PATH):
    """Encode ``wav_path`` with the pretrained audio-visual encoder.

    Returns:
        numpy array of encoder outputs (one row per AudDataset item), padded with
        two copies of the first row at the start and two of the last row at the end.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = AudioEncoder().to(device).eval()
    # map_location lets a checkpoint saved from GPU load on a CPU-only machine
    ckpt = torch.load(ckpt_path, map_location=device)
    model.load_state_dict({f'audio_encoder.{k}': v for k, v in ckpt.items()})

    data_loader = DataLoader(AudDataset(wav_path), batch_size=64, shuffle=False)
    outputs = []
    with torch.no_grad():
        for mel in data_loader:
            outputs.append(model(mel.to(device)))
    outputs = torch.cat(outputs, dim=0).cpu()

    first_frame, last_frame = outputs[:1], outputs[-1:]
    return torch.cat([first_frame.repeat(2, 1), outputs, last_frame.repeat(2, 1)], dim=0).numpy()


class NeRFDataset:
    """Talking-head NeRF dataset: camera poses, audio features, images and face regions.

    Construction reads the transforms_*.json split(s) under ``opt.path`` and builds:

    * ``poses``: [N, 4, 4] camera-to-world matrices (instant-ngp convention),
      smoothed along the trajectory when ``opt.smooth_path`` is set.
    * image data: ``torso_img`` always; ``gt_images`` / ``face_mask_imgs`` (and
      ``full_body_imgs`` with ``opt.fullbody``) in portrait mode; ``images``
      otherwise. With ``preload > 0`` these hold decoded arrays, with
      ``preload == 0`` file paths that :meth:`collate` reads on the fly.
    * ``auds``: per-frame audio features (self-driven), the raw novel features
      (cross-driven, ``opt.aud`` set) or ``None`` for live streaming (``opt.asr``).
    * ``eye_area``: per-frame eye / blendshape values when ``opt.exp_eye`` is set,
      optionally smoothed over a 3-frame window.
    * face / half-face / eye / lip rectangles derived from the landmark files.
    * ``bg_img``: 'white', 'black', or an image file resized to H x W.
    * ``intrinsics`` ([fl_x, fl_y, cx, cy]) and ``bg_coords`` (pixel grid in [-1, 1]^2).

    With ``preload > 1`` poses / audio are moved to ``device`` and images to
    ``device`` as float16.
    """

    def __init__(self, opt, device, type='train', downscale=1):
        super().__init__()

        self.opt = opt
        self.device = device
        self.type = type  # train, val, test
        self.downscale = downscale
        self.root_path = opt.path.strip()
        self.preload = opt.preload  # 0 = disk, 1 = cpu, 2 = gpu
        self.scale = opt.scale  # camera radius scale to make sure camera are inside the bounding box.
        self.offset = opt.offset  # camera offset
        self.bound = opt.bound  # bounding box half length, also used as the radius to random sample poses.
        self.fp16 = opt.fp16

        self.start_index = opt.data_range[0]
        self.end_index = opt.data_range[1]

        self.training = self.type in ['train', 'all', 'trainval']
        self.num_rays = self.opt.num_rays if self.training else -1

        # load nerf-compatible format data.
        print(f'[INFO] load type is {self.type}')
        transform = self._load_transform(type)

        # load image size
        if 'h' in transform and 'w' in transform:
            self.H = int(transform['h']) // downscale
            self.W = int(transform['w']) // downscale
        else:
            # no explicit size: assume the principal point is the image centre
            self.H = int(transform['cy']) * 2 // downscale
            self.W = int(transform['cx']) * 2 // downscale

        # load crop position
        if 'crop_offset_x' in transform and 'crop_offset_y' in transform:
            self.crop_offset_x = transform['crop_offset_x']
            self.crop_offset_y = transform['crop_offset_y']
        else:
            self.crop_offset_x = 0
            self.crop_offset_y = 0

        frames = transform["frames"]

        # use a slice of the dataset
        if self.end_index == -1:  # abuse...
            self.end_index = len(frames)
        frames = frames[self.start_index:self.end_index]

        # use a subset of dataset.
        if type == 'train':
            if self.opt.part:
                frames = frames[::10]  # 1/10 frames
            elif self.opt.part2:
                frames = frames[:375]  # first 15s
        elif type == 'val':
            frames = frames[:100]  # first 100 frames for val

        print(f'[INFO] load {len(frames)} {type} frames.')
        print(f'[INFO] load opt.asr : {self.opt.asr}, aud : {self.opt.aud}')

        # pre-calculated aud features; None when live-streaming (opt.asr)
        aud_features = self._load_audio_features()
        # eye / blendshape values indexed by img_id; only used with opt.exp_eye
        bs = self._load_blendshapes() if self.opt.exp_eye else None

        self.torso_img = []
        self.images = []  # gt_imgs frames (non-portrait mode); not returned by collate
        self.gt_images = []
        self.face_mask_imgs = []
        self.full_body_imgs = []
        self.poses = []
        self.exps = []
        self.auds = []
        self.face_rect = []
        self.lhalf_rect = []
        self.upface_rect = []
        self.lowface_rect = []
        self.lips_rect = []
        self.eye_area = []
        self.eye_rect = []

        for frame in tqdm.tqdm(frames, desc=f'Loading {type} data'):
            self._load_frame(frame, aud_features, bs)

        if not self.poses:
            raise RuntimeError(f'[ERROR] no valid {type} frames found in {self.root_path}')

        # load pre-extracted background image (should be the same size as training image...)
        self.bg_img = self._load_background()

        self.poses = np.stack(self.poses, axis=0)
        # smooth camera path...
        if self.opt.smooth_path:
            self.poses = smooth_camera_path(self.poses, self.opt.smooth_path_window)
        self.poses = torch.from_numpy(self.poses)  # [N, 4, 4]

        if self.preload > 0:
            self.torso_img = torch.from_numpy(np.stack(self.torso_img, axis=0))  # [N, H, W, 4]
            if self.opt.portrait:
                self.gt_images = torch.from_numpy(np.stack(self.gt_images, axis=0))  # [N, H, W, C]
                self.face_mask_imgs = torch.from_numpy(np.stack(self.face_mask_imgs, axis=0))  # [N, H, W]
        else:
            # file paths only; collate reads the images on the fly
            self.torso_img = np.array(self.torso_img)
            if self.opt.portrait:
                self.gt_images = np.array(self.gt_images)
                self.face_mask_imgs = np.array(self.face_mask_imgs)
        if self.opt.portrait and self.opt.fullbody:
            # kept as numpy (decoded images or paths); collate returns it as-is
            self.full_body_imgs = np.array(self.full_body_imgs)

        if self.opt.asr:
            # live streaming, no pre-calculated auds
            self.auds = None
        elif self.opt.aud == '':
            # auds corresponding to images
            self.auds = torch.stack(self.auds, dim=0)  # [N, ...]
        else:
            # auds is novel, may have a different length with images
            self.auds = aud_features

        self.bg_img = torch.from_numpy(self.bg_img)

        if self.opt.exp_eye:
            eye_area = np.array(self.eye_area, dtype=np.float32)  # [N] for au45, else [N, K]
            print(f'[INFO] eye_area: {eye_area.min()} - {eye_area.max()}')

            if self.opt.smooth_eye:
                # naive 3-frame moving average (frames i-1 .. i+1), computed per channel
                ori_eye = eye_area.copy()
                for i in range(ori_eye.shape[0]):
                    start = max(0, i - 1)
                    end = min(ori_eye.shape[0], i + 2)
                    eye_area[i] = ori_eye[start:end].mean(axis=0)

            # [N, 1] for au45, otherwise [N, K] (K = 7 'upper', 4 'single', 2 'eye')
            self.eye_area = torch.from_numpy(eye_area.reshape(eye_area.shape[0], -1))

        # [debug] call visualize_poses(self.poses.numpy()) to view all training poses.

        if self.preload > 1:
            self.poses = self.poses.to(self.device)
            if self.auds is not None:
                self.auds = self.auds.to(self.device)
            self.bg_img = self.bg_img.to(torch.half).to(self.device)
            self.torso_img = self.torso_img.to(torch.half).to(self.device)
            if self.opt.portrait:
                self.gt_images = self.gt_images.to(torch.half).to(self.device)
                self.face_mask_imgs = self.face_mask_imgs.to(torch.half).to(self.device)

        self.intrinsics = self._load_intrinsics(transform, downscale)

        # directly build the coordinate meshgrid in [-1, 1]^2
        self.bg_coords = get_bg_coords(self.H, self.W, self.device)  # [1, H*W, 2] in [-1, 1]

    def _load_transform(self, split_type):
        """Read (and for 'all' / 'trainval', merge) the transforms json for ``split_type``.

        Raises:
            FileNotFoundError: if ``split_type == 'all'`` and no ``*.json`` file exists.
        """
        if split_type == 'all':
            # load all splits (train/valid/test); sorted so frame order is deterministic
            transform = None
            for transform_path in sorted(glob.glob(os.path.join(self.root_path, '*.json'))):
                with open(transform_path, 'r') as f:
                    tmp_transform = json.load(f)
                if transform is None:
                    transform = tmp_transform
                else:
                    transform['frames'].extend(tmp_transform['frames'])
            if transform is None:
                raise FileNotFoundError(f'[ERROR] no *.json transforms found in {self.root_path}')
        elif split_type == 'trainval':
            # load train and val split
            with open(os.path.join(self.root_path, 'transforms_train.json'), 'r') as f:
                transform = json.load(f)
            with open(os.path.join(self.root_path, 'transforms_val.json'), 'r') as f:
                transform_val = json.load(f)
            transform['frames'].extend(transform_val['frames'])
        else:
            # only load one specified split; no test split, use val as test
            _split = 'val' if split_type == 'test' else split_type
            print(f'[INFO] load {_split} split')
            with open(os.path.join(self.root_path, f'transforms_{_split}.json'), 'r') as f:
                transform = json.load(f)
        return transform

    def _load_audio_features(self):
        """Load pre-computed audio features as a tensor, or return None when ``opt.asr`` is set.

        Self-driven (``opt.aud == ''``): the file in ``root_path`` matching
        ``opt.asr_model`` (or encode ``aud.wav`` for 'ave'). Cross-driven: encode
        ``opt.aud`` for 'ave', otherwise load ``<opt.aud without extension>_hu.npy``.

        Raises:
            Whatever loading / encoding raises, after printing a hint for cross-driven input.
        """
        if self.opt.asr:
            return None

        if self.opt.aud == '':
            # empty means the default self-driven extracted features.
            if 'esperanto' in self.opt.asr_model:
                aud_features = np.load(os.path.join(self.root_path, 'aud_eo.npy'))
            elif 'deepspeech' in self.opt.asr_model:
                aud_features = np.load(os.path.join(self.root_path, 'aud_ds.npy'))
            elif 'hubert' in self.opt.asr_model:
                aud_features = np.load(os.path.join(self.root_path, 'aud_hu.npy'))
            elif self.opt.asr_model == 'ave':
                aud_features = _extract_ave_features(os.path.join(self.root_path, 'aud.wav'))
            else:
                aud_features = np.load(os.path.join(self.root_path, 'aud.npy'))
        elif self.opt.asr_model == 'ave':
            # cross-driven: encode the given audio with the audio-visual encoder
            try:
                aud_features = _extract_ave_features(self.opt.aud)
            except Exception:
                print(f'[ERROR] If do not use Audio Visual Encoder, replace it with the npy file path.')
                raise
        else:
            # cross-driven: <name>.<ext> -> <name>_hu.npy
            base, _ = os.path.splitext(self.opt.aud)
            new_wav_name = f"{base}_hu.npy"
            print(f'[INFO] load {new_wav_name}')
            try:
                aud_features = np.load(new_wav_name)
            except Exception:
                print(f'[ERROR] If do not use Audio Visual Encoder, replace it with the npy file path.')
                raise

        aud_features = torch.from_numpy(aud_features)
        if self.opt.asr_model == 'ave':
            aud_features = aud_features.unsqueeze(0)
            permute_dims = (1, 0, 2)
        else:
            permute_dims = (0, 2, 1)  # [N, 16, 29] --> [N, 29, 16]

        # support both [N, 16] labels and [N, 16, K] logits
        if len(aud_features.shape) == 3:
            aud_features = aud_features.float().permute(*permute_dims)
            if self.opt.emb:
                print(f'[INFO] argmax to aud features {aud_features.shape} for --emb mode')
                aud_features = aud_features.argmax(1)  # [N, 16]
        else:
            assert self.opt.emb, "aud only provide labels, must use --emb"
            aud_features = aud_features.long()

        print(f'[INFO] load {self.opt.aud} aud_features: {aud_features.shape}')
        return aud_features

    def _load_blendshapes(self):
        """Load per-image eye values: AU45 intensities (opt.au45) or a column subset of bs.npy."""
        if self.opt.au45:
            import pandas as pd
            au_blink_info = pd.read_csv(os.path.join(self.root_path, 'au.csv'))
            return au_blink_info[' AU45_r'].values

        bs = np.load(os.path.join(self.root_path, 'bs.npy'))
        if self.opt.bs_area == "upper":
            bs = np.hstack((bs[:, 0:5], bs[:, 8:10]))
        elif self.opt.bs_area == "single":
            bs = bs[:, [0, 2, 3, 8]]
        elif self.opt.bs_area == "eye":
            bs = bs[:, 8:10]
        return bs

    def _load_frame(self, frame, aud_features, bs):
        """Append all per-frame data for one ``frames`` entry of the transforms json.

        A frame whose required image files are missing is skipped entirely (with a
        warning), so every per-frame list stays index-aligned with ``self.poses``.
        """
        img_id = str(frame['img_id'])

        # Check every required file before appending anything: appending the pose
        # and then skipping on a later missing file would misalign the lists.
        if self.opt.portrait:
            # portrait mode: ground truth from ori_imgs plus the parsed face mask
            gt_path = os.path.join(self.root_path, 'ori_imgs', img_id + '.jpg')
            face_mask_path = os.path.join(self.root_path, 'parsing', img_id + '_face.png')
            full_body_path = os.path.join(self.root_path, 'full_body_imgs', img_id + '.jpg')
            required = [gt_path, face_mask_path] + ([full_body_path] if self.opt.fullbody else [])
        else:
            f_path = os.path.join(self.root_path, 'gt_imgs', img_id + '.jpg')
            required = [f_path]

        for path in required:
            if not os.path.exists(path):
                print('[WARN]', path, 'NOT FOUND!')
                return

        pose = np.array(frame['transform_matrix'], dtype=np.float32)  # [4, 4]
        self.poses.append(nerf_matrix_to_ngp(pose, scale=self.scale, offset=self.offset))

        if self.opt.portrait:
            if self.preload > 0:
                self.gt_images.append(_read_rgb(gt_path))  # [H, W, 3/4]
                self.face_mask_imgs.append(_read_face_mask(face_mask_path))  # [H, W]
            else:
                self.gt_images.append(gt_path)
                self.face_mask_imgs.append(face_mask_path)
            if self.opt.fullbody:
                self.full_body_imgs.append(_read_rgb(full_body_path) if self.preload > 0 else full_body_path)
        else:
            self.images.append(_read_rgb(f_path) if self.preload > 0 else f_path)

        # load frame-wise torso image (RGBA, composited over the background in collate)
        torso_img_path = os.path.join(self.root_path, 'torso_imgs', img_id + '.png')
        self.torso_img.append(_read_rgba(torso_img_path) if self.preload > 0 else torso_img_path)

        # find the corresponding audio to the image frame
        if not self.opt.asr and self.opt.aud == '':
            # careful for the last frame...
            aud = aud_features[min(frame['aud_id'], aud_features.shape[0] - 1)]
            self.auds.append(aud)

        # load landmarks and extract face regions.
        # "x" is the row coordinate (lms[:, 1]) and "y" the column (lms[:, 0]).
        lms = np.loadtxt(os.path.join(self.root_path, 'ori_imgs', img_id + '.lms'))  # [68, 2]
        xmin, xmax = int(lms[:, 1].min()), int(lms[:, 1].max())
        ymin, ymax = int(lms[:, 0].min()), int(lms[:, 0].max())
        lh_xmin = int(lms[31:36, 1].min())  # actually lower half area
        mid_x = int(lms[30, 1])  # row of landmark 30 splits upper / lower face

        self.face_rect.append([xmin, xmax, ymin, ymax])
        self.lhalf_rect.append([lh_xmin, xmax, ymin, ymax])
        self.upface_rect.append([xmin, mid_x, ymin, ymax])
        self.lowface_rect.append([mid_x, xmax, ymin, ymax])

        if self.opt.exp_eye:
            area = bs[frame['img_id']]
            if self.opt.au45:
                area = np.clip(area, 0, 2) / 2
            self.eye_area.append(area)

            eye = lms[36:48]
            self.eye_rect.append([int(eye[:, 1].min()), int(eye[:, 1].max()),
                                  int(eye[:, 0].min()), int(eye[:, 0].max())])

        if self.opt.finetune_lips:
            lips = lms[48:60]
            lxmin, lxmax = int(lips[:, 1].min()), int(lips[:, 1].max())
            lymin, lymax = int(lips[:, 0].min()), int(lips[:, 0].max())

            # padding to H == W: square box around the lip centre, clipped to the image
            cx = (lxmin + lxmax) // 2
            cy = (lymin + lymax) // 2
            half = max(lxmax - lxmin, lymax - lymin) // 2
            self.lips_rect.append([max(0, cx - half), min(self.H, cx + half),
                                   max(0, cy - half), min(self.W, cy + half)])

    def _load_background(self):
        """Return the background as float32 [H, W, C] in [0, 1].

        'white' / 'black' are generated; '' means ``<root_path>/bc.jpg`` (written back
        to ``opt.bg_img``); anything else is an image path, resized to H x W if needed.

        Raises:
            FileNotFoundError: if the background image cannot be read.
        """
        if self.opt.bg_img == 'white':  # special
            return np.ones((self.H, self.W, 3), dtype=np.float32)
        if self.opt.bg_img == 'black':  # special
            return np.zeros((self.H, self.W, 3), dtype=np.float32)

        # load from file; default bg
        if self.opt.bg_img == '':
            self.opt.bg_img = os.path.join(self.root_path, 'bc.jpg')
        bg_img = cv2.imread(self.opt.bg_img, cv2.IMREAD_UNCHANGED)  # [H, W, 3]
        if bg_img is None:
            raise FileNotFoundError(f'[ERROR] failed to read background image: {self.opt.bg_img}')
        if bg_img.shape[0] != self.H or bg_img.shape[1] != self.W:
            bg_img = cv2.resize(bg_img, (self.W, self.H), interpolation=cv2.INTER_AREA)
        bg_img = cv2.cvtColor(bg_img, cv2.COLOR_BGR2RGB)
        return bg_img.astype(np.float32) / 255  # [H, W, 3/4]

    def _load_intrinsics(self, transform, downscale):
        """Return [fl_x, fl_y, cx, cy] in pixels of the downscaled H x W image.

        Raises:
            RuntimeError: if the transform holds no focal-length information.
        """
        if 'focal_len' in transform:
            # pixel-valued intrinsics in the json refer to the full-resolution image,
            # so scale them like fl_x / fl_y / cx / cy below
            fl_x = fl_y = transform['focal_len'] / downscale
        elif 'fl_x' in transform or 'fl_y' in transform:
            fl_x = (transform['fl_x'] if 'fl_x' in transform else transform['fl_y']) / downscale
            fl_y = (transform['fl_y'] if 'fl_y' in transform else transform['fl_x']) / downscale
        elif 'camera_angle_x' in transform or 'camera_angle_y' in transform:
            # blender, assert in radians. already downscaled since we use H/W
            fl_x = self.W / (2 * np.tan(transform['camera_angle_x'] / 2)) if 'camera_angle_x' in transform else None
            fl_y = self.H / (2 * np.tan(transform['camera_angle_y'] / 2)) if 'camera_angle_y' in transform else None
            if fl_x is None:
                fl_x = fl_y
            if fl_y is None:
                fl_y = fl_x
        else:
            raise RuntimeError('Failed to load focal length, please check the transforms.json!')

        cx = (transform['cx'] / downscale) if 'cx' in transform else (self.W / 2)
        cy = (transform['cy'] / downscale) if 'cy' in transform else (self.H / 2)
        return np.array([fl_x, fl_y, cx, cy])

    def mirror_index(self, index):
        """Map an unbounded index onto [0, len(poses)) ping-pong style (0..n-1, n-1..0, ...)."""
        size = self.poses.shape[0]
        turn = index // size
        res = index % size
        if turn % 2 == 0:
            return res
        else:
            return size - res - 1

    def collate(self, index):
        """Assemble one batch for frame ``index[0]``.

        Args:
            index (list): single-element list of dataset indices, as produced by
                :meth:`dataloader`. ``index[0]`` is replaced in place by its
                mirrored value (see :meth:`mirror_index`).

        Returns:
            dict with keys:
                'auds' (if audio features exist; looked up with the un-mirrored index),
                'rect' (training with opt.finetune_lips), 'up_rect', 'low_rect',
                'index', 'H', 'W', 'rays_o', 'rays_d',
                'face_mask', 'lhalf_mask', 'upface_mask', 'lowface_mask' (training),
                'eye' (None unless opt.exp_eye), 'eye_mask' (opt.exp_eye and training),
                'bg_color',
                'bg_gt_images', 'bg_face_mask' (portrait), 'full_body_img' (portrait + fullbody),
                'bg_coords', 'poses'.
        """
        B = len(index)  # a list of length 1
        results = {}

        # audio use the original index
        if self.auds is not None:
            auds = get_audio_features(self.auds, self.opt.att, index[0]).to(self.device)
            results['auds'] = auds

        # head pose and bg image may mirror (replay --> <-- --> <--).
        index[0] = self.mirror_index(index[0])
        frame_idx = index[0]

        poses = self.poses[index].to(self.device)  # [B, 4, 4]

        if self.training and self.opt.finetune_lips:
            rect = self.lips_rect[frame_idx]
            results['rect'] = rect
            rays = get_rays(poses, self.intrinsics, self.H, self.W, -1, rect=rect)
        else:
            rays = get_rays(poses, self.intrinsics, self.H, self.W, self.num_rays, self.opt.patch_size)

        results['up_rect'] = self.upface_rect[frame_idx]
        results['low_rect'] = self.lowface_rect[frame_idx]
        results['index'] = index  # for ind. code
        results['H'] = self.H
        results['W'] = self.W
        results['rays_o'] = rays['rays_o']
        results['rays_d'] = rays['rays_d']

        # masks for rays inside the face sub-rectangles
        if self.training:
            results['face_mask'] = _rect_mask(rays, self.face_rect[frame_idx])
            results['lhalf_mask'] = _rect_mask(rays, self.lhalf_rect[frame_idx])
            results['upface_mask'] = _rect_mask(rays, self.upface_rect[frame_idx])
            results['lowface_mask'] = _rect_mask(rays, self.lowface_rect[frame_idx])

        if self.opt.exp_eye:
            results['eye'] = self.eye_area[index].to(self.device)  # [B, K]
            if self.training:
                results['eye_mask'] = _rect_mask(rays, self.eye_rect[frame_idx])
        else:
            results['eye'] = None

        # load bg: alpha-composite the RGBA torso image over the background image
        bg_torso_img = self.torso_img[index]
        if self.preload == 0:  # on the fly loading
            bg_torso_img = torch.from_numpy(_read_rgba(bg_torso_img[0])).unsqueeze(0)
        alpha = bg_torso_img[..., 3:]
        bg_torso_img = bg_torso_img[..., :3] * alpha + self.bg_img * (1 - alpha)
        bg_torso_img = bg_torso_img.view(B, -1, 3).to(self.device)

        if not self.opt.torso:
            bg_img = bg_torso_img
        else:
            bg_img = self.bg_img.view(1, -1, 3).repeat(B, 1, 1).to(self.device)

        if self.training:
            bg_img = torch.gather(bg_img, 1, torch.stack(3 * [rays['inds']], -1))  # [B, N, 3]

        results['bg_color'] = bg_img

        if self.opt.portrait:
            bg_gt_images = self.gt_images[index]
            if self.preload == 0:
                bg_gt_images = torch.from_numpy(_read_rgb(bg_gt_images[0])).unsqueeze(0)
            results['bg_gt_images'] = bg_gt_images.to(self.device)

            bg_face_mask = self.face_mask_imgs[index]
            if self.preload == 0:
                bg_face_mask = torch.from_numpy(_read_face_mask(bg_face_mask[0])).unsqueeze(0)
            results['bg_face_mask'] = bg_face_mask.to(self.device)

            if self.opt.fullbody:
                full_body_images = self.full_body_imgs[index]
                if self.preload == 0:
                    full_body_images = np.expand_dims(_read_rgb(full_body_images[0]), axis=0)
                results['full_body_img'] = full_body_images  # numpy, [1, H, W, 3/4]

        if self.training:
            bg_coords = torch.gather(self.bg_coords, 1, torch.stack(2 * [rays['inds']], -1))  # [1, N, 2]
        else:
            bg_coords = self.bg_coords  # [1, N, 2]

        results['bg_coords'] = bg_coords
        results['poses'] = poses  # [B, 4, 4]

        return results

    def dataloader(self):
        """Return a DataLoader over frame indices that batches through :meth:`collate`.

        Training iterates (shuffled) over all poses. Otherwise the length follows the
        novel audio when present, or is ``2 * len(poses)`` for live streaming so the
        poses replay forward then backward via :meth:`mirror_index`.
        """
        if self.training:
            # training len(poses) == len(auds)
            size = self.poses.shape[0]
        elif self.auds is not None:
            # test with novel auds, then use its length
            size = self.auds.shape[0]
        else:
            # live stream test, use 2 * len(poses), so it naturally mirrors.
            size = 2 * self.poses.shape[0]

        loader = DataLoader(list(range(size)), batch_size=1, collate_fn=self.collate, shuffle=self.training, num_workers=0)
        loader._data = self  # an ugly fix... we need poses in trainer.
        # do evaluate if has gt images and use self-driven setting
        loader.has_gt = (self.opt.aud == '')
        return loader