1729 lines
65 KiB
Python
1729 lines
65 KiB
Python
import os
|
||
import glob
|
||
import tqdm
|
||
import random
|
||
import tensorboardX
|
||
import librosa
|
||
import librosa.filters
|
||
from scipy import signal
|
||
from os.path import basename
|
||
import numpy as np
|
||
import time
|
||
import cv2
|
||
import matplotlib.pyplot as plt
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.optim as optim
|
||
import torch.nn.functional as F
|
||
|
||
import trimesh
|
||
import mcubes
|
||
from rich.console import Console
|
||
from torch_ema import ExponentialMovingAverage
|
||
|
||
from packaging import version as pver
|
||
import imageio
|
||
import lpips
|
||
|
||
def custom_meshgrid(*args):
|
||
# ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
|
||
if pver.parse(torch.__version__) < pver.parse('1.10'):
|
||
return torch.meshgrid(*args)
|
||
else:
|
||
return torch.meshgrid(*args, indexing='ij')
|
||
|
||
def blend_with_mask_cuda(src, dst, mask):
|
||
src = src.permute(2, 0, 1)
|
||
dst = dst.permute(2, 0, 1)
|
||
mask = mask.unsqueeze(0)
|
||
|
||
# Blending
|
||
blended = src * mask + dst * (1 - mask)
|
||
|
||
# Convert back to numpy and return
|
||
return blended.permute(1, 2, 0).detach().cpu().numpy()
|
||
|
||
|
||
def get_audio_features(features, att_mode, index):
|
||
if att_mode == 0:
|
||
return features[[index]]
|
||
elif att_mode == 1:
|
||
left = index - 8
|
||
pad_left = 0
|
||
if left < 0:
|
||
pad_left = -left
|
||
left = 0
|
||
auds = features[left:index]
|
||
if pad_left > 0:
|
||
# pad may be longer than auds, so do not use zeros_like
|
||
auds = torch.cat([torch.zeros(pad_left, *auds.shape[1:], device=auds.device, dtype=auds.dtype), auds], dim=0)
|
||
return auds
|
||
elif att_mode == 2:
|
||
left = index - 4
|
||
right = index + 4
|
||
pad_left = 0
|
||
pad_right = 0
|
||
if left < 0:
|
||
pad_left = -left
|
||
left = 0
|
||
if right > features.shape[0]:
|
||
pad_right = right - features.shape[0]
|
||
right = features.shape[0]
|
||
auds = features[left:right]
|
||
if pad_left > 0:
|
||
auds = torch.cat([torch.zeros_like(auds[:pad_left]), auds], dim=0)
|
||
if pad_right > 0:
|
||
auds = torch.cat([auds, torch.zeros_like(auds[:pad_right])], dim=0) # [8, 16]
|
||
return auds
|
||
else:
|
||
raise NotImplementedError(f'wrong att_mode: {att_mode}')
|
||
|
||
|
||
@torch.jit.script
|
||
def linear_to_srgb(x):
|
||
return torch.where(x < 0.0031308, 12.92 * x, 1.055 * x ** 0.41666 - 0.055)
|
||
|
||
|
||
@torch.jit.script
|
||
def srgb_to_linear(x):
|
||
return torch.where(x < 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4)
|
||
|
||
# copied from pytorch3d
|
||
def _angle_from_tan(
|
||
axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
|
||
) -> torch.Tensor:
|
||
"""
|
||
Extract the first or third Euler angle from the two members of
|
||
the matrix which are positive constant times its sine and cosine.
|
||
|
||
Args:
|
||
axis: Axis label "X" or "Y or "Z" for the angle we are finding.
|
||
other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
|
||
convention.
|
||
data: Rotation matrices as tensor of shape (..., 3, 3).
|
||
horizontal: Whether we are looking for the angle for the third axis,
|
||
which means the relevant entries are in the same row of the
|
||
rotation matrix. If not, they are in the same column.
|
||
tait_bryan: Whether the first and third axes in the convention differ.
|
||
|
||
Returns:
|
||
Euler Angles in radians for each matrix in data as a tensor
|
||
of shape (...).
|
||
"""
|
||
|
||
i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
|
||
if horizontal:
|
||
i2, i1 = i1, i2
|
||
even = (axis + other_axis) in ["XY", "YZ", "ZX"]
|
||
if horizontal == even:
|
||
return torch.atan2(data[..., i1], data[..., i2])
|
||
if tait_bryan:
|
||
return torch.atan2(-data[..., i2], data[..., i1])
|
||
return torch.atan2(data[..., i2], -data[..., i1])
|
||
|
||
|
||
def _index_from_letter(letter: str) -> int:
|
||
if letter == "X":
|
||
return 0
|
||
if letter == "Y":
|
||
return 1
|
||
if letter == "Z":
|
||
return 2
|
||
raise ValueError("letter must be either X, Y or Z.")
|
||
|
||
|
||
def matrix_to_euler_angles(matrix: torch.Tensor, convention: str = 'XYZ') -> torch.Tensor:
|
||
"""
|
||
Convert rotations given as rotation matrices to Euler angles in radians.
|
||
|
||
Args:
|
||
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
||
convention: Convention string of three uppercase letters.
|
||
|
||
Returns:
|
||
Euler angles in radians as tensor of shape (..., 3).
|
||
"""
|
||
# if len(convention) != 3:
|
||
# raise ValueError("Convention must have 3 letters.")
|
||
# if convention[1] in (convention[0], convention[2]):
|
||
# raise ValueError(f"Invalid convention {convention}.")
|
||
# for letter in convention:
|
||
# if letter not in ("X", "Y", "Z"):
|
||
# raise ValueError(f"Invalid letter {letter} in convention string.")
|
||
# if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
||
# raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
|
||
i0 = _index_from_letter(convention[0])
|
||
i2 = _index_from_letter(convention[2])
|
||
tait_bryan = i0 != i2
|
||
if tait_bryan:
|
||
central_angle = torch.asin(
|
||
matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
|
||
)
|
||
else:
|
||
central_angle = torch.acos(matrix[..., i0, i0])
|
||
|
||
o = (
|
||
_angle_from_tan(
|
||
convention[0], convention[1], matrix[..., i2], False, tait_bryan
|
||
),
|
||
central_angle,
|
||
_angle_from_tan(
|
||
convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
|
||
),
|
||
)
|
||
return torch.stack(o, -1)
|
||
|
||
@torch.cuda.amp.autocast(enabled=False)
|
||
def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
|
||
"""
|
||
Return the rotation matrices for one of the rotations about an axis
|
||
of which Euler angles describe, for each value of the angle given.
|
||
Args:
|
||
axis: Axis label "X" or "Y or "Z".
|
||
angle: any shape tensor of Euler angles in radians
|
||
Returns:
|
||
Rotation matrices as tensor of shape (..., 3, 3).
|
||
"""
|
||
|
||
cos = torch.cos(angle)
|
||
sin = torch.sin(angle)
|
||
one = torch.ones_like(angle)
|
||
zero = torch.zeros_like(angle)
|
||
|
||
if axis == "X":
|
||
R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
|
||
elif axis == "Y":
|
||
R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
|
||
elif axis == "Z":
|
||
R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
|
||
else:
|
||
raise ValueError("letter must be either X, Y or Z.")
|
||
|
||
return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
|
||
|
||
@torch.cuda.amp.autocast(enabled=False)
|
||
def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str='XYZ') -> torch.Tensor:
|
||
"""
|
||
Convert rotations given as Euler angles in radians to rotation matrices.
|
||
Args:
|
||
euler_angles: Euler angles in radians as tensor of shape (..., 3).
|
||
convention: Convention string of three uppercase letters from
|
||
{"X", "Y", and "Z"}.
|
||
Returns:
|
||
Rotation matrices as tensor of shape (..., 3, 3).
|
||
"""
|
||
|
||
# print(euler_angles, euler_angles.dtype)
|
||
|
||
if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
|
||
raise ValueError("Invalid input euler angles.")
|
||
if len(convention) != 3:
|
||
raise ValueError("Convention must have 3 letters.")
|
||
if convention[1] in (convention[0], convention[2]):
|
||
raise ValueError(f"Invalid convention {convention}.")
|
||
for letter in convention:
|
||
if letter not in ("X", "Y", "Z"):
|
||
raise ValueError(f"Invalid letter {letter} in convention string.")
|
||
matrices = [
|
||
_axis_angle_rotation(c, e)
|
||
for c, e in zip(convention, torch.unbind(euler_angles, -1))
|
||
]
|
||
|
||
return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2])
|
||
|
||
|
||
@torch.cuda.amp.autocast(enabled=False)
|
||
def convert_poses(poses):
|
||
# poses: [B, 4, 4]
|
||
# return [B, 3], 4 rot, 3 trans
|
||
out = torch.empty(poses.shape[0], 6, dtype=torch.float32, device=poses.device)
|
||
out[:, :3] = matrix_to_euler_angles(poses[:, :3, :3])
|
||
out[:, 3:] = poses[:, :3, 3]
|
||
return out
|
||
|
||
@torch.cuda.amp.autocast(enabled=False)
|
||
def get_bg_coords(H, W, device):
|
||
X = torch.arange(H, device=device) / (H - 1) * 2 - 1 # in [-1, 1]
|
||
Y = torch.arange(W, device=device) / (W - 1) * 2 - 1 # in [-1, 1]
|
||
xs, ys = custom_meshgrid(X, Y)
|
||
bg_coords = torch.cat([xs.reshape(-1, 1), ys.reshape(-1, 1)], dim=-1).unsqueeze(0) # [1, H*W, 2], in [-1, 1]
|
||
return bg_coords
|
||
|
||
|
||
@torch.cuda.amp.autocast(enabled=False)
|
||
def get_rays(poses, intrinsics, H, W, N=-1, patch_size=1, rect=None):
|
||
''' get rays
|
||
Args:
|
||
poses: [B, 4, 4], cam2world
|
||
intrinsics: [4]
|
||
H, W, N: int
|
||
Returns:
|
||
rays_o, rays_d: [B, N, 3]
|
||
inds: [B, N]
|
||
'''
|
||
|
||
device = poses.device
|
||
B = poses.shape[0]
|
||
fx, fy, cx, cy = intrinsics
|
||
|
||
if rect is not None:
|
||
xmin, xmax, ymin, ymax = rect
|
||
N = (xmax - xmin) * (ymax - ymin)
|
||
|
||
i, j = custom_meshgrid(torch.linspace(0, W-1, W, device=device), torch.linspace(0, H-1, H, device=device)) # float
|
||
i = i.t().reshape([1, H*W]).expand([B, H*W]) + 0.5
|
||
j = j.t().reshape([1, H*W]).expand([B, H*W]) + 0.5
|
||
|
||
results = {}
|
||
|
||
if N > 0:
|
||
N = min(N, H*W)
|
||
|
||
if patch_size > 1:
|
||
|
||
# random sample left-top cores.
|
||
# NOTE: this impl will lead to less sampling on the image corner pixels... but I don't have other ideas.
|
||
num_patch = N // (patch_size ** 2)
|
||
inds_x = torch.randint(0, H - patch_size, size=[num_patch], device=device)
|
||
inds_y = torch.randint(0, W - patch_size, size=[num_patch], device=device)
|
||
inds = torch.stack([inds_x, inds_y], dim=-1) # [np, 2]
|
||
# all_inds = torch.randperm((H - patch_size + 1) * (W - patch_size + 1), device=device)[:num_patch]
|
||
# all_inds, _ = torch.sort(all_inds)
|
||
#
|
||
# inds_x = all_inds // (W - patch_size)
|
||
# inds_y = all_inds % (W - patch_size)
|
||
# inds = torch.stack([inds_x, inds_y], dim=-1) # [np, 2]
|
||
|
||
# create meshgrid for each patch
|
||
pi, pj = custom_meshgrid(torch.arange(patch_size, device=device), torch.arange(patch_size, device=device))
|
||
offsets = torch.stack([pi.reshape(-1), pj.reshape(-1)], dim=-1) # [p^2, 2]
|
||
|
||
inds = inds.unsqueeze(1) + offsets.unsqueeze(0) # [np, p^2, 2]
|
||
inds = inds.view(-1, 2) # [N, 2]
|
||
inds = inds[:, 0] * W + inds[:, 1] # [N], flatten
|
||
|
||
inds = inds.expand([B, N])
|
||
|
||
# only get rays in the specified rect
|
||
elif rect is not None:
|
||
# assert B == 1
|
||
mask = torch.zeros(H, W, dtype=torch.bool, device=device)
|
||
xmin, xmax, ymin, ymax = rect
|
||
mask[xmin:xmax, ymin:ymax] = 1
|
||
inds = torch.where(mask.view(-1))[0] # [nzn]
|
||
inds = inds.unsqueeze(0) # [1, N]
|
||
|
||
else:
|
||
inds = torch.randint(0, H*W, size=[N], device=device) # may duplicate
|
||
inds = inds.expand([B, N])
|
||
|
||
# inds = torch.randperm(H * W, device=device)[:N]
|
||
# inds, _ = torch.sort(inds)
|
||
# inds = inds.expand([B, N])
|
||
|
||
i = torch.gather(i, -1, inds)
|
||
j = torch.gather(j, -1, inds)
|
||
|
||
|
||
else:
|
||
inds = torch.arange(H*W, device=device).expand([B, H*W])
|
||
|
||
results['i'] = i
|
||
results['j'] = j
|
||
results['inds'] = inds
|
||
|
||
zs = torch.ones_like(i)
|
||
xs = (i - cx) / fx * zs
|
||
ys = (j - cy) / fy * zs
|
||
directions = torch.stack((xs, ys, zs), dim=-1)
|
||
directions = directions / torch.norm(directions, dim=-1, keepdim=True)
|
||
|
||
rays_d = directions @ poses[:, :3, :3].transpose(-1, -2) # (B, N, 3)
|
||
|
||
rays_o = poses[..., :3, 3] # [B, 3]
|
||
rays_o = rays_o[..., None, :].expand_as(rays_d) # [B, N, 3]
|
||
|
||
results['rays_o'] = rays_o
|
||
results['rays_d'] = rays_d
|
||
|
||
return results
|
||
|
||
|
||
def seed_everything(seed):
|
||
random.seed(seed)
|
||
os.environ['PYTHONHASHSEED'] = str(seed)
|
||
np.random.seed(seed)
|
||
torch.manual_seed(seed)
|
||
torch.cuda.manual_seed(seed)
|
||
#torch.backends.cudnn.deterministic = True
|
||
#torch.backends.cudnn.benchmark = True
|
||
|
||
|
||
def torch_vis_2d(x, renormalize=False):
|
||
# x: [3, H, W] or [1, H, W] or [H, W]
|
||
import matplotlib.pyplot as plt
|
||
import numpy as np
|
||
import torch
|
||
|
||
if isinstance(x, torch.Tensor):
|
||
if len(x.shape) == 3:
|
||
x = x.permute(1,2,0).squeeze()
|
||
x = x.detach().cpu().numpy()
|
||
|
||
print(f'[torch_vis_2d] {x.shape}, {x.dtype}, {x.min()} ~ {x.max()}')
|
||
|
||
x = x.astype(np.float32)
|
||
|
||
# renormalize
|
||
if renormalize:
|
||
x = (x - x.min(axis=0, keepdims=True)) / (x.max(axis=0, keepdims=True) - x.min(axis=0, keepdims=True) + 1e-8)
|
||
|
||
plt.imshow(x)
|
||
plt.show()
|
||
|
||
|
||
def extract_fields(bound_min, bound_max, resolution, query_func, S=128):
|
||
|
||
X = torch.linspace(bound_min[0], bound_max[0], resolution).split(S)
|
||
Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(S)
|
||
Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(S)
|
||
|
||
u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
|
||
with torch.no_grad():
|
||
for xi, xs in enumerate(X):
|
||
for yi, ys in enumerate(Y):
|
||
for zi, zs in enumerate(Z):
|
||
xx, yy, zz = custom_meshgrid(xs, ys, zs)
|
||
pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [S, 3]
|
||
val = query_func(pts).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() # [S, 1] --> [x, y, z]
|
||
u[xi * S: xi * S + len(xs), yi * S: yi * S + len(ys), zi * S: zi * S + len(zs)] = val
|
||
return u
|
||
|
||
|
||
def extract_geometry(bound_min, bound_max, resolution, threshold, query_func):
|
||
#print('threshold: {}'.format(threshold))
|
||
u = extract_fields(bound_min, bound_max, resolution, query_func)
|
||
|
||
#print(u.shape, u.max(), u.min(), np.percentile(u, 50))
|
||
|
||
vertices, triangles = mcubes.marching_cubes(u, threshold)
|
||
|
||
b_max_np = bound_max.detach().cpu().numpy()
|
||
b_min_np = bound_min.detach().cpu().numpy()
|
||
|
||
vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
|
||
return vertices, triangles
|
||
|
||
def ssim_1d_loss(pred, true, C1=1e-4, C2=9e-4):
|
||
"""
|
||
Compute 1D SSIM loss between two signals.
|
||
Args:
|
||
pred: predicted signal, [1, 512*512, 3]
|
||
true: ground truth signal, [1, 512*512, 3]
|
||
Returns:
|
||
ssim_val: ssim index of two input signals
|
||
"""
|
||
if pred.size() != true.size():
|
||
raise ValueError(f'Expected input size ({pred.size()}) to match target size ({true.size()}).')
|
||
|
||
mu1 = pred.mean(dim=1, keepdim=True)
|
||
mu2 = true.mean(dim=1, keepdim=True)
|
||
|
||
mu1_sq = mu1.pow(2)
|
||
mu2_sq = mu2.pow(2)
|
||
mu1_mu2 = mu1 * mu2
|
||
|
||
sigma1_sq = (pred * pred).mean(dim=1, keepdim=True) - mu1_sq
|
||
sigma2_sq = (true * true).mean(dim=1, keepdim=True) - mu2_sq
|
||
sigma12 = (pred * true).mean(dim=1, keepdim=True) - mu1_mu2
|
||
|
||
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
|
||
ssim_val = ssim_map.mean()
|
||
|
||
return ssim_val
|
||
|
||
class PSNRMeter:
|
||
def __init__(self):
|
||
self.V = 0
|
||
self.N = 0
|
||
|
||
def clear(self):
|
||
self.V = 0
|
||
self.N = 0
|
||
|
||
def prepare_inputs(self, *inputs):
|
||
outputs = []
|
||
for i, inp in enumerate(inputs):
|
||
if torch.is_tensor(inp):
|
||
inp = inp.detach().cpu().numpy()
|
||
outputs.append(inp)
|
||
|
||
return outputs
|
||
|
||
def update(self, preds, truths):
|
||
preds, truths = self.prepare_inputs(preds, truths) # [B, N, 3] or [B, H, W, 3], range in [0, 1]
|
||
|
||
# simplified since max_pixel_value is 1 here.
|
||
psnr = -10 * np.log10(np.mean((preds - truths) ** 2))
|
||
|
||
self.V += psnr
|
||
self.N += 1
|
||
|
||
def measure(self):
|
||
return self.V / self.N
|
||
|
||
def write(self, writer, global_step, prefix=""):
|
||
writer.add_scalar(os.path.join(prefix, "PSNR"), self.measure(), global_step)
|
||
|
||
def report(self):
|
||
return f'PSNR = {self.measure():.6f}'
|
||
|
||
class LPIPSMeter:
|
||
def __init__(self, net='alex', device=None):
|
||
self.V = 0
|
||
self.N = 0
|
||
self.net = net
|
||
|
||
self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||
self.fn = lpips.LPIPS(net=net).eval().to(self.device)
|
||
|
||
def clear(self):
|
||
self.V = 0
|
||
self.N = 0
|
||
|
||
def prepare_inputs(self, *inputs):
|
||
outputs = []
|
||
for i, inp in enumerate(inputs):
|
||
inp = inp.permute(0, 3, 1, 2).contiguous() # [B, 3, H, W]
|
||
inp = inp.to(self.device)
|
||
outputs.append(inp)
|
||
return outputs
|
||
|
||
def update(self, preds, truths):
|
||
preds, truths = self.prepare_inputs(preds, truths) # [B, H, W, 3] --> [B, 3, H, W], range in [0, 1]
|
||
v = self.fn(truths, preds, normalize=True).item() # normalize=True: [0, 1] to [-1, 1]
|
||
self.V += v
|
||
self.N += 1
|
||
|
||
def measure(self):
|
||
return self.V / self.N
|
||
|
||
def write(self, writer, global_step, prefix=""):
|
||
writer.add_scalar(os.path.join(prefix, f"LPIPS ({self.net})"), self.measure(), global_step)
|
||
|
||
def report(self):
|
||
return f'LPIPS ({self.net}) = {self.measure():.6f}'
|
||
|
||
|
||
class LMDMeter:
|
||
def __init__(self, backend='dlib', region='mouth'):
|
||
self.backend = backend
|
||
self.region = region # mouth or face
|
||
|
||
if self.backend == 'dlib':
|
||
import dlib
|
||
|
||
# load checkpoint manually
|
||
self.predictor_path = './shape_predictor_68_face_landmarks.dat'
|
||
if not os.path.exists(self.predictor_path):
|
||
raise FileNotFoundError('Please download dlib checkpoint from http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2')
|
||
|
||
self.detector = dlib.get_frontal_face_detector()
|
||
self.predictor = dlib.shape_predictor(self.predictor_path)
|
||
|
||
else:
|
||
|
||
import face_alignment
|
||
try:
|
||
self.predictor = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False)
|
||
except:
|
||
self.predictor = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, flip_input=False)
|
||
|
||
self.V = 0
|
||
self.N = 0
|
||
|
||
def get_landmarks(self, img):
|
||
|
||
if self.backend == 'dlib':
|
||
dets = self.detector(img, 1)
|
||
for det in dets:
|
||
shape = self.predictor(img, det)
|
||
# ref: https://github.com/PyImageSearch/imutils/blob/c12f15391fcc945d0d644b85194b8c044a392e0a/imutils/face_utils/helpers.py
|
||
lms = np.zeros((68, 2), dtype=np.int32)
|
||
for i in range(0, 68):
|
||
lms[i, 0] = shape.part(i).x
|
||
lms[i, 1] = shape.part(i).y
|
||
break
|
||
|
||
else:
|
||
lms = self.predictor.get_landmarks(img)[-1]
|
||
|
||
# self.vis_landmarks(img, lms)
|
||
lms = lms.astype(np.float32)
|
||
|
||
return lms
|
||
|
||
def vis_landmarks(self, img, lms):
|
||
plt.imshow(img)
|
||
plt.plot(lms[48:68, 0], lms[48:68, 1], marker='o', markersize=1, linestyle='-', lw=2)
|
||
plt.show()
|
||
|
||
def clear(self):
|
||
self.V = 0
|
||
self.N = 0
|
||
|
||
def prepare_inputs(self, *inputs):
|
||
outputs = []
|
||
for i, inp in enumerate(inputs):
|
||
inp = inp.detach().cpu().numpy()
|
||
inp = (inp * 255).astype(np.uint8)
|
||
outputs.append(inp)
|
||
return outputs
|
||
|
||
def update(self, preds, truths):
|
||
# assert B == 1
|
||
preds, truths = self.prepare_inputs(preds[0], truths[0]) # [H, W, 3] numpy array
|
||
|
||
# get lms
|
||
lms_pred = self.get_landmarks(preds)
|
||
lms_truth = self.get_landmarks(truths)
|
||
|
||
if self.region == 'mouth':
|
||
lms_pred = lms_pred[48:68]
|
||
lms_truth = lms_truth[48:68]
|
||
|
||
# avarage
|
||
lms_pred = lms_pred - lms_pred.mean(0)
|
||
lms_truth = lms_truth - lms_truth.mean(0)
|
||
|
||
# distance
|
||
dist = np.sqrt(((lms_pred - lms_truth) ** 2).sum(1)).mean(0)
|
||
|
||
self.V += dist
|
||
self.N += 1
|
||
|
||
def measure(self):
|
||
return self.V / self.N
|
||
|
||
def write(self, writer, global_step, prefix=""):
|
||
writer.add_scalar(os.path.join(prefix, f"LMD ({self.backend})"), self.measure(), global_step)
|
||
|
||
def report(self):
|
||
return f'LMD ({self.backend}) = {self.measure():.6f}'
|
||
|
||
|
||
class Trainer(object):
|
||
def __init__(self,
|
||
name, # name of this experiment
|
||
opt, # extra conf
|
||
model, # network
|
||
criterion=None, # loss function, if None, assume inline implementation in train_step
|
||
optimizer=None, # optimizer
|
||
ema_decay=None, # if use EMA, set the decay
|
||
ema_update_interval=1000, # update ema per $ training steps.
|
||
lr_scheduler=None, # scheduler
|
||
metrics=[], # metrics for evaluation, if None, use val_loss to measure performance, else use the first metric.
|
||
local_rank=0, # which GPU am I
|
||
world_size=1, # total num of GPUs
|
||
device=None, # device to use, usually setting to None is OK. (auto choose device)
|
||
mute=False, # whether to mute all print
|
||
fp16=False, # amp optimize level
|
||
eval_interval=1, # eval once every $ epoch
|
||
max_keep_ckpt=50, # max num of saved ckpts in disk
|
||
workspace='workspace', # workspace to save logs & ckpts
|
||
best_mode='min', # the smaller/larger result, the better
|
||
use_loss_as_metric=True, # use loss as the first metric
|
||
report_metric_at_train=False, # also report metrics at training
|
||
use_checkpoint="latest", # which ckpt to use at init time
|
||
use_tensorboardX=True, # whether to use tensorboard for logging
|
||
scheduler_update_every_step=False, # whether to call scheduler.step() after every train step
|
||
):
|
||
|
||
self.name = name
|
||
self.opt = opt
|
||
self.mute = mute
|
||
self.metrics = metrics
|
||
self.local_rank = local_rank
|
||
self.world_size = world_size
|
||
self.workspace = workspace
|
||
self.ema_decay = ema_decay
|
||
self.ema_update_interval = ema_update_interval
|
||
self.fp16 = fp16
|
||
self.best_mode = best_mode
|
||
self.use_loss_as_metric = use_loss_as_metric
|
||
self.report_metric_at_train = report_metric_at_train
|
||
self.max_keep_ckpt = max_keep_ckpt
|
||
self.eval_interval = eval_interval
|
||
self.use_checkpoint = use_checkpoint
|
||
self.use_tensorboardX = use_tensorboardX
|
||
self.flip_finetune_lips = self.opt.finetune_lips
|
||
self.flip_init_lips = self.opt.init_lips
|
||
self.time_stamp = time.strftime("%Y-%m-%d_%H-%M-%S")
|
||
self.scheduler_update_every_step = scheduler_update_every_step
|
||
self.device = device if device is not None else torch.device(f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu')
|
||
self.console = Console()
|
||
|
||
model.to(self.device)
|
||
if self.world_size > 1:
|
||
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
|
||
self.model = model
|
||
|
||
if isinstance(criterion, nn.Module):
|
||
criterion.to(self.device)
|
||
self.criterion = criterion
|
||
self.criterionL1 = nn.L1Loss().to(self.device)
|
||
if optimizer is None:
|
||
self.optimizer = optim.Adam(self.model.parameters(), lr=0.001, weight_decay=5e-4) # naive adam
|
||
else:
|
||
self.optimizer = optimizer(self.model)
|
||
|
||
if lr_scheduler is None:
|
||
self.lr_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda epoch: 1) # fake scheduler
|
||
else:
|
||
self.lr_scheduler = lr_scheduler(self.optimizer)
|
||
|
||
if ema_decay is not None:
|
||
self.ema = ExponentialMovingAverage(self.model.parameters(), decay=ema_decay)
|
||
else:
|
||
self.ema = None
|
||
|
||
self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16)
|
||
|
||
# optionally use LPIPS loss for patch-based training
|
||
if self.opt.patch_size > 1 or self.opt.finetune_lips or True:
|
||
import lpips
|
||
# self.criterion_lpips_vgg = lpips.LPIPS(net='vgg').to(self.device)
|
||
self.criterion_lpips_alex = lpips.LPIPS(net='alex').to(self.device)
|
||
|
||
# variable init
|
||
self.epoch = 0
|
||
self.global_step = 0
|
||
self.local_step = 0
|
||
self.stats = {
|
||
"loss": [],
|
||
"valid_loss": [],
|
||
"results": [], # metrics[0], or valid_loss
|
||
"checkpoints": [], # record path of saved ckpt, to automatically remove old ckpt
|
||
"best_result": None,
|
||
}
|
||
|
||
# auto fix
|
||
if len(metrics) == 0 or self.use_loss_as_metric:
|
||
self.best_mode = 'min'
|
||
|
||
# workspace prepare
|
||
self.log_ptr = None
|
||
if self.workspace is not None:
|
||
os.makedirs(self.workspace, exist_ok=True)
|
||
self.log_path = os.path.join(workspace, f"log_{self.name}.txt")
|
||
self.log_ptr = open(self.log_path, "a+")
|
||
|
||
self.ckpt_path = os.path.join(self.workspace, 'checkpoints')
|
||
self.best_path = f"{self.ckpt_path}/{self.name}.pth"
|
||
os.makedirs(self.ckpt_path, exist_ok=True)
|
||
|
||
self.log(f'[INFO] Trainer: {self.name} | {self.time_stamp} | {self.device} | {"fp16" if self.fp16 else "fp32"} | {self.workspace}')
|
||
self.log(f'[INFO] #parameters: {sum([p.numel() for p in model.parameters() if p.requires_grad])}')
|
||
|
||
if self.workspace is not None:
|
||
if self.use_checkpoint == "scratch":
|
||
self.log("[INFO] Training from scratch ...")
|
||
elif self.use_checkpoint == "latest":
|
||
self.log("[INFO] Loading latest checkpoint ...")
|
||
self.load_checkpoint()
|
||
elif self.use_checkpoint == "latest_model":
|
||
self.log("[INFO] Loading latest checkpoint (model only)...")
|
||
self.load_checkpoint(model_only=True)
|
||
elif self.use_checkpoint == "best":
|
||
if os.path.exists(self.best_path):
|
||
self.log("[INFO] Loading best checkpoint ...")
|
||
self.load_checkpoint(self.best_path)
|
||
else:
|
||
self.log(f"[INFO] {self.best_path} not found, loading latest ...")
|
||
self.load_checkpoint()
|
||
else: # path to ckpt
|
||
self.log(f"[INFO] Loading {self.use_checkpoint} ...")
|
||
self.load_checkpoint(self.use_checkpoint)
|
||
|
||
def __del__(self):
|
||
if self.log_ptr:
|
||
self.log_ptr.close()
|
||
|
||
|
||
def log(self, *args, **kwargs):
|
||
if self.local_rank == 0:
|
||
if not self.mute:
|
||
#print(*args)
|
||
self.console.print(*args, **kwargs)
|
||
if self.log_ptr:
|
||
print(*args, file=self.log_ptr)
|
||
self.log_ptr.flush() # write immediately to file
|
||
|
||
### ------------------------------
|
||
|
||
def train_step(self, data):
|
||
|
||
rays_o = data['rays_o'] # [B, N, 3]
|
||
rays_d = data['rays_d'] # [B, N, 3]
|
||
bg_coords = data['bg_coords'] # [1, N, 2]
|
||
poses = data['poses'] # [B, 6]
|
||
face_mask = data['face_mask'] # [B, N]
|
||
upface_mask = data['upface_mask'] # [B, N]
|
||
lowface_mask = data['lowface_mask'] # [B, N]
|
||
eye_mask = data['eye_mask'] # [B, N]
|
||
lhalf_mask = data['lhalf_mask']
|
||
eye = data['eye'] # [B, 1]
|
||
auds = data['auds'] # [B, 29, 16]
|
||
index = data['index'] # [B]
|
||
loss_perception =0
|
||
|
||
if not self.opt.torso:
|
||
rgb = data['images'] # [B, N, 3]
|
||
else:
|
||
rgb = data['bg_torso_color']
|
||
|
||
B, N, C = rgb.shape
|
||
|
||
if self.opt.color_space == 'linear':
|
||
rgb[..., :3] = srgb_to_linear(rgb[..., :3])
|
||
|
||
bg_color = data['bg_color']
|
||
|
||
if not self.opt.torso:
|
||
outputs = self.model.render(rays_o, rays_d, auds, bg_coords, poses, eye=eye, index=index, staged=False, bg_color=bg_color, perturb=True, force_all_rays=False if (self.opt.patch_size <= 1 and not self.opt.train_camera) else True, **vars(self.opt))
|
||
else:
|
||
outputs = self.model.render_torso(rays_o, rays_d, auds, bg_coords, poses, eye=eye, index=index, staged=False, bg_color=bg_color, perturb=True, force_all_rays=False if (self.opt.patch_size <= 1 and not self.opt.train_camera) else True, **vars(self.opt))
|
||
|
||
if not self.opt.torso:
|
||
pred_rgb = outputs['image']
|
||
else:
|
||
pred_rgb = outputs['torso_color']
|
||
|
||
|
||
# loss factor
|
||
step_factor = min(self.global_step / self.opt.iters, 1.0)
|
||
# MSE loss
|
||
loss = self.criterion(pred_rgb, rgb).mean(-1) # [B, N, 3] --> [B, N]
|
||
|
||
if self.opt.torso:
|
||
loss = loss.mean()
|
||
loss += ((1 - self.model.anchor_points[:, 3])**2).mean()
|
||
return pred_rgb, rgb, loss
|
||
|
||
|
||
if self.opt.unc_loss and not self.flip_finetune_lips:
|
||
alpha = 0.2
|
||
uncertainty = outputs['uncertainty'] # [N], abs sum
|
||
beta = uncertainty + 1
|
||
|
||
unc_weight = F.softmax(uncertainty, dim=-1) * N
|
||
loss *= alpha + (1-alpha)*((1 - step_factor) + step_factor * unc_weight.detach()).clamp(0, 10)
|
||
|
||
beta = uncertainty + 1
|
||
norm_rgb = torch.norm((pred_rgb - rgb), dim=-1).detach()
|
||
loss_u = norm_rgb / (2*beta**2) + (torch.log(beta)**2) / 2
|
||
loss_u *= face_mask.view(-1)
|
||
|
||
loss += 0.01 * step_factor * loss_u
|
||
|
||
|
||
loss_static_uncertainty = (uncertainty * (~face_mask.view(-1)))
|
||
loss += 0.01 * step_factor * loss_static_uncertainty
|
||
|
||
# patch-based rendering
|
||
if self.opt.patch_size > 1 and not self.opt.finetune_lips:
|
||
rgb = rgb.view(-1, self.opt.patch_size, self.opt.patch_size, 3).permute(0, 3, 1, 2).contiguous() * 2 - 1
|
||
pred_rgb = pred_rgb.view(-1, self.opt.patch_size, self.opt.patch_size, 3).permute(0, 3, 1, 2).contiguous() * 2 - 1
|
||
|
||
|
||
loss_lpips = self.criterion_lpips_alex(pred_rgb, rgb)
|
||
|
||
loss = loss + 0.1 * loss_lpips
|
||
|
||
# lips finetune
|
||
if self.opt.finetune_lips:
|
||
xmin, xmax, ymin, ymax = data['rect']
|
||
rgb = rgb.view(-1, xmax - xmin, ymax - ymin, 3).permute(0, 3, 1, 2).contiguous() * 2 - 1
|
||
pred_rgb = pred_rgb.view(-1, xmax - xmin, ymax - ymin, 3).permute(0, 3, 1, 2).contiguous() * 2 - 1
|
||
|
||
padding_h = max(0, (32 - rgb.shape[-2] + 1) // 2)
|
||
padding_w = max(0, (32 - rgb.shape[-1] + 1) // 2)
|
||
|
||
if padding_w or padding_h:
|
||
rgb = torch.nn.functional.pad(rgb, (padding_w, padding_w, padding_h, padding_h))
|
||
pred_rgb = torch.nn.functional.pad(pred_rgb, (padding_w, padding_w, padding_h, padding_h))
|
||
|
||
loss = loss + 0.01 * self.criterion_lpips_alex(pred_rgb, rgb)
|
||
# flip every step... if finetune lips
|
||
if self.flip_finetune_lips:
|
||
self.opt.finetune_lips = not self.opt.finetune_lips
|
||
|
||
|
||
loss = loss.mean()
|
||
|
||
if self.opt.patch_size > 1 and not self.opt.finetune_lips:
|
||
if self.opt.pyramid_loss:
|
||
loss = loss + 0.1 * loss_perception
|
||
# print('loss', loss.item())
|
||
|
||
# weights_sum loss
|
||
# entropy to encourage weights_sum to be 0 or 1.
|
||
if self.opt.torso:
|
||
alphas = outputs['torso_alpha'].clamp(1e-5, 1 - 1e-5)
|
||
# alphas = alphas ** 2 # skewed entropy, favors 0 over 1
|
||
loss_ws = - alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas)
|
||
loss = loss + 1e-4 * loss_ws.mean()
|
||
|
||
else:
|
||
alphas = outputs['weights_sum'].clamp(1e-5, 1 - 1e-5)
|
||
loss_ws = - alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas)
|
||
loss = loss + 1e-4 * loss_ws.mean()
|
||
|
||
# aud att loss (regions out of face should be static)
|
||
if self.opt.amb_aud_loss and not self.opt.torso:
|
||
ambient_aud = outputs['ambient_aud']
|
||
loss_amb_aud = (ambient_aud * (~lowface_mask.view(-1))).mean()
|
||
# gradually increase it
|
||
lambda_amb = step_factor * self.opt.lambda_amb
|
||
loss += lambda_amb * loss_amb_aud
|
||
|
||
# eye att loss
|
||
if self.opt.amb_eye_loss and not self.opt.torso:
|
||
ambient_eye = outputs['ambient_eye']
|
||
loss_cross = ((ambient_eye)*(lowface_mask.view(-1))).mean()
|
||
lambda_amb = step_factor * self.opt.lambda_amb
|
||
loss += lambda_amb * loss_cross
|
||
|
||
# regularize
|
||
if self.global_step % 16 == 0 and not self.flip_finetune_lips:
|
||
xyzs, dirs, enc_a, ind_code, eye = outputs['rays']
|
||
xyz_delta = (torch.rand(size=xyzs.shape, dtype=xyzs.dtype, device=xyzs.device) * 2 - 1) * 1e-3
|
||
with torch.no_grad():
|
||
sigmas_raw, rgbs_raw, ambient_aud_raw, ambient_eye_raw, unc_raw = self.model(xyzs, dirs, enc_a.detach(), ind_code.detach(), eye)
|
||
sigmas_reg, rgbs_reg, ambient_aud_reg, ambient_eye_reg, unc_reg = self.model(xyzs+xyz_delta, dirs, enc_a.detach(), ind_code.detach(), eye)
|
||
|
||
lambda_reg = step_factor * 1e-5
|
||
reg_loss = 0
|
||
if self.opt.unc_loss:
|
||
reg_loss += self.criterion(unc_raw, unc_reg).mean()
|
||
if self.opt.amb_aud_loss:
|
||
reg_loss += self.criterion(ambient_aud_raw, ambient_aud_reg).mean()
|
||
if self.opt.amb_eye_loss:
|
||
reg_loss += self.criterion(ambient_eye_raw, ambient_eye_reg).mean()
|
||
|
||
loss += reg_loss * lambda_reg
|
||
|
||
return pred_rgb, rgb, loss
|
||
|
||
|
||
def eval_step(self, data):
|
||
|
||
rays_o = data['rays_o'] # [B, N, 3]
|
||
rays_d = data['rays_d'] # [B, N, 3]
|
||
bg_coords = data['bg_coords'] # [1, N, 2]
|
||
poses = data['poses'] # [B, 7]
|
||
|
||
images = data['images'] # [B, H, W, 3/4]
|
||
if self.opt.portrait:
|
||
images = data['bg_gt_images']
|
||
auds = data['auds']
|
||
index = data['index'] # [B]
|
||
eye = data['eye'] # [B, 1]
|
||
|
||
B, H, W, C = images.shape
|
||
|
||
if self.opt.color_space == 'linear':
|
||
images[..., :3] = srgb_to_linear(images[..., :3])
|
||
|
||
# eval with fixed background color
|
||
# bg_color = 1
|
||
bg_color = data['bg_color']
|
||
|
||
outputs = self.model.render(rays_o, rays_d, auds, bg_coords, poses, eye=eye, index=index, staged=True, bg_color=bg_color, perturb=False, **vars(self.opt))
|
||
|
||
pred_rgb = outputs['image'].reshape(B, H, W, 3)
|
||
pred_depth = outputs['depth'].reshape(B, H, W)
|
||
pred_ambient_aud = outputs['ambient_aud'].reshape(B, H, W)
|
||
pred_ambient_eye = outputs['ambient_eye'].reshape(B, H, W)
|
||
pred_uncertainty = outputs['uncertainty'].reshape(B, H, W)
|
||
|
||
loss_raw = self.criterion(pred_rgb, images)
|
||
loss = loss_raw.mean()
|
||
|
||
return pred_rgb, pred_depth, pred_ambient_aud, pred_ambient_eye, pred_uncertainty, images, loss, loss_raw
|
||
|
||
# 定义测试步骤函数,增加了对背景颜色和扰动参数的灵活控制
|
||
def test_step(self, data, bg_color=None, perturb=False):
|
||
"""
|
||
对输入数据进行渲染,返回预测的RGB图像和深度图。
|
||
|
||
Args:
|
||
data (dict): 包含射线起始点、方向、背景坐标、姿态、音频数据、索引、图像高度和宽度等信息的字典。
|
||
bg_color (Optional[torch.Tensor]): 背景颜色,形状为[3],表示RGB颜色值。默认为None,使用输入数据中的背景颜色。
|
||
perturb (bool): 是否在渲染时加入扰动。默认为False。
|
||
|
||
Returns:
|
||
Tuple[torch.Tensor, torch.Tensor]: 返回预测的RGB图像和深度图,形状分别为[B*N, H, W, 3]和[B*N, H, W]。
|
||
|
||
"""
|
||
|
||
# 提取输入数据中的射线起始点、方向、背景坐标和姿态信息
|
||
rays_o = data['rays_o'] # [B, N, 3]
|
||
rays_d = data['rays_d'] # [B, N, 3]
|
||
bg_coords = data['bg_coords'] # [1, N, 2]
|
||
poses = data['poses'] # [B, 7]
|
||
|
||
# 提取音频数据和索引
|
||
auds = data['auds'] # [B, 29, 16]
|
||
index = data['index']
|
||
|
||
# 提取图像高度和宽度
|
||
H, W = data['H'], data['W']
|
||
|
||
# 在测试时允许使用固定的眼部区域(避免眨眼效果)
|
||
if self.opt.exp_eye and self.opt.fix_eye >= 0:
|
||
eye = torch.FloatTensor([self.opt.fix_eye]).view(1, 1).to(self.device)
|
||
else:
|
||
eye = data['eye'] # [B, 1]
|
||
|
||
# 处理背景颜色参数
|
||
if bg_color is not None:
|
||
bg_color = bg_color.to(self.device)
|
||
else:
|
||
bg_color = data['bg_color']
|
||
|
||
# 设置模型为测试模式
|
||
self.model.testing = True
|
||
# 调用模型的渲染函数进行推理
|
||
outputs = self.model.render(rays_o, rays_d, auds, bg_coords, poses, eye=eye, index=index, staged=True, bg_color=bg_color, perturb=perturb, **vars(self.opt))
|
||
# 恢复模型为训练模式
|
||
self.model.testing = False
|
||
|
||
# 重塑输出的RGB图像和深度图
|
||
pred_rgb = outputs['image'].reshape(-1, H, W, 3)
|
||
pred_depth = outputs['depth'].reshape(-1, H, W)
|
||
|
||
# 返回预测的RGB图像和深度图
|
||
return pred_rgb, pred_depth
|
||
|
||
|
||
def save_mesh(self, save_path=None, resolution=256, threshold=10):
|
||
|
||
if save_path is None:
|
||
save_path = os.path.join(self.workspace, 'meshes', f'{self.name}_{self.epoch}.ply')
|
||
|
||
self.log(f"==> Saving mesh to {save_path}")
|
||
|
||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||
|
||
def query_func(pts):
|
||
with torch.no_grad():
|
||
with torch.cuda.amp.autocast(enabled=self.fp16):
|
||
sigma = self.model.density(pts.to(self.device))['sigma']
|
||
return sigma
|
||
|
||
vertices, triangles = extract_geometry(self.model.aabb_infer[:3], self.model.aabb_infer[3:], resolution=resolution, threshold=threshold, query_func=query_func)
|
||
|
||
mesh = trimesh.Trimesh(vertices, triangles, process=False) # important, process=True leads to seg fault...
|
||
mesh.export(save_path)
|
||
|
||
self.log(f"==> Finished saving mesh.")
|
||
|
||
### ------------------------------
|
||
|
||
def train(self, train_loader, valid_loader, max_epochs):
|
||
if self.use_tensorboardX and self.local_rank == 0:
|
||
self.writer = tensorboardX.SummaryWriter(os.path.join(self.workspace, "run", self.name))
|
||
|
||
# mark untrained region (i.e., not covered by any camera from the training dataset)
|
||
if self.model.cuda_ray:
|
||
self.model.mark_untrained_grid(train_loader._data.poses, train_loader._data.intrinsics)
|
||
|
||
for epoch in range(self.epoch + 1, max_epochs + 1):
|
||
self.epoch = epoch
|
||
|
||
self.train_one_epoch(train_loader)
|
||
|
||
if self.workspace is not None and self.local_rank == 0:
|
||
self.save_checkpoint(full=True, best=False)
|
||
|
||
if self.epoch % self.eval_interval == 0:
|
||
self.evaluate_one_epoch(valid_loader)
|
||
self.save_checkpoint(full=False, best=True)
|
||
|
||
if self.use_tensorboardX and self.local_rank == 0:
|
||
self.writer.close()
|
||
|
||
def evaluate(self, loader, name=None):
|
||
self.use_tensorboardX, use_tensorboardX = False, self.use_tensorboardX
|
||
self.evaluate_one_epoch(loader, name)
|
||
self.use_tensorboardX = use_tensorboardX
|
||
|
||
# Function to blend two images with a mask
|
||
|
||
def test(self, loader, save_path=None, name=None, write_image=False):
|
||
"""
|
||
对模型进行测试,并保存结果到指定路径。
|
||
|
||
Args:
|
||
loader (torch.utils.data.DataLoader): 数据加载器。
|
||
save_path (str, optional): 保存结果的路径。默认为None,此时将结果保存到self.workspace/results目录下。
|
||
name (str, optional): 结果文件的名称。默认为None,此时将使用self.name_ep{self.epoch:04d}作为文件名。
|
||
write_image (bool, optional): 是否将结果保存为图片。默认为False。
|
||
|
||
Returns:
|
||
None
|
||
"""
|
||
|
||
# 如果未指定保存路径,则默认保存到工作空间的结果目录下
|
||
if save_path is None:
|
||
save_path = os.path.join(self.workspace, 'results')
|
||
|
||
# 如果未指定文件名,则使用默认的命名规则
|
||
if name is None:
|
||
name = f'{self.name}_ep{self.epoch:04d}'
|
||
|
||
# 确保保存路径存在,如果不存在则创建
|
||
os.makedirs(save_path, exist_ok=True)
|
||
|
||
# 记录测试开始和结果保存路径
|
||
self.log(f"==> Start Test, save results to {save_path}")
|
||
|
||
# 初始化进度条
|
||
pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')
|
||
# 设定模型为评估模式
|
||
self.model.eval()
|
||
|
||
# 用于保存所有预测结果和深度图
|
||
all_preds = []
|
||
all_preds_depth = []
|
||
all_full_merged_imgs = []
|
||
|
||
# 禁止梯度计算,减少内存消耗
|
||
with torch.no_grad():
|
||
|
||
# 遍历数据加载器中的每个数据
|
||
for i, data in enumerate(loader):
|
||
|
||
# 使用混合精度进行测试
|
||
with torch.cuda.amp.autocast(enabled=self.fp16):
|
||
preds, preds_depth = self.test_step(data)
|
||
|
||
# 定义保存RGB图像和深度图的路径
|
||
path = os.path.join(save_path, f'{name}_{i:04d}_rgb.png')
|
||
path_depth = os.path.join(save_path, f'{name}_{i:04d}_depth.png')
|
||
|
||
# 如果需要,将线性颜色空间转换为sRGB
|
||
if self.opt.color_space == 'linear':
|
||
preds = linear_to_srgb(preds)
|
||
|
||
# 处理肖像模式下的预测结果
|
||
if self.opt.portrait:
|
||
pred = blend_with_mask_cuda(preds[0], data["bg_gt_images"].squeeze(0), data["bg_face_mask"].squeeze(0))
|
||
pred = (pred * 255).astype(np.uint8)
|
||
else:
|
||
# 非肖像模式下处理预测结果
|
||
pred = preds[0].detach().cpu().numpy()
|
||
pred = (pred * 255).astype(np.uint8)
|
||
|
||
#如果是fullbody模式,则需要将预测结果和背景图像进行融合
|
||
if self.opt.fullbody:
|
||
image_fullbody = (data['full_body_img'][0] * 255).astype(np.uint8)
|
||
start_x = loader._data.crop_offset_x # 合并后小图片的起始x坐标
|
||
start_y = loader._data.crop_offset_y # 合并后小图片的起始y坐标
|
||
# print(f"pred shape: {pred.shape}")
|
||
# print(f"image_fullbody shape: {image_fullbody.shape}")
|
||
image_fullbody[start_y : start_y + pred.shape[0], start_x : start_x + pred.shape[1]] = pred
|
||
all_full_merged_imgs.append(image_fullbody)
|
||
|
||
# 处理并保存深度图
|
||
pred_depth = preds_depth[0].detach().cpu().numpy()
|
||
pred_depth = (pred_depth * 255).astype(np.uint8)
|
||
|
||
# 如果需要,将预测结果和深度图保存为图片
|
||
if write_image:
|
||
imageio.imwrite(path, pred)
|
||
imageio.imwrite(path_depth, pred_depth)
|
||
|
||
# 将预测结果添加到列表中
|
||
all_preds.append(pred)
|
||
all_preds_depth.append(pred_depth)
|
||
|
||
# 更新进度条
|
||
pbar.update(loader.batch_size)
|
||
|
||
# 将所有预测结果和深度图保存为视频
|
||
all_preds = np.stack(all_preds, axis=0)
|
||
all_preds_depth = np.stack(all_preds_depth, axis=0)
|
||
if self.opt.fullbody:
|
||
imageio.mimwrite(os.path.join(save_path, f'{name}_fullbody.mp4'), all_full_merged_imgs, fps=25, quality=8, macro_block_size=1)
|
||
else:
|
||
imageio.mimwrite(os.path.join(save_path, f'{name}.mp4'), all_preds, fps=25, quality=8, macro_block_size=1)
|
||
imageio.mimwrite(os.path.join(save_path, f'{name}_depth.mp4'), all_preds_depth, fps=25, quality=8, macro_block_size=1)
|
||
|
||
# 如果指定了音频文件和ASR模型为'ave',则将音频合并到视频中
|
||
result_file = None
|
||
if self.opt.aud != '':
|
||
if self.opt.asr_model == 'ave':
|
||
result_file = os.path.join(save_path, f"{name}_audio.mp4")
|
||
os.system(f'ffmpeg -i {os.path.join(save_path, f"{name}.mp4")} -i {self.opt.aud} -strict -2 {result_file} -y')
|
||
elif self.opt.asr_model == 'hubert' and self.opt.fullbody:
|
||
result_file = os.path.join(save_path, f"{name}_full_audio.mp4")
|
||
os.system(f'ffmpeg -i {os.path.join(save_path, f"{name}_fullbody.mp4")} -i {self.opt.aud} -strict -2 {result_file} -y')
|
||
|
||
# 记录测试完成
|
||
self.log(f"==> Finished Test.")
|
||
return result_file
|
||
|
||
# [GUI] just train for 16 steps, without any other overhead that may slow down rendering.
|
||
def train_gui(self, train_loader, step=16):
|
||
|
||
self.model.train()
|
||
|
||
total_loss = torch.tensor([0], dtype=torch.float32, device=self.device)
|
||
|
||
loader = iter(train_loader)
|
||
|
||
# mark untrained grid
|
||
if self.global_step == 0:
|
||
self.model.mark_untrained_grid(train_loader._data.poses, train_loader._data.intrinsics)
|
||
|
||
for _ in range(step):
|
||
|
||
# mimic an infinite loop dataloader (in case the total dataset is smaller than step)
|
||
try:
|
||
data = next(loader)
|
||
except StopIteration:
|
||
loader = iter(train_loader)
|
||
data = next(loader)
|
||
|
||
# update grid every 16 steps
|
||
if self.model.cuda_ray and self.global_step % self.opt.update_extra_interval == 0:
|
||
with torch.cuda.amp.autocast(enabled=self.fp16):
|
||
self.model.update_extra_state()
|
||
|
||
self.global_step += 1
|
||
|
||
self.optimizer.zero_grad()
|
||
|
||
with torch.cuda.amp.autocast(enabled=self.fp16):
|
||
preds, truths, loss = self.train_step(data)
|
||
|
||
self.scaler.scale(loss).backward()
|
||
self.scaler.step(self.optimizer)
|
||
self.scaler.update()
|
||
|
||
if self.scheduler_update_every_step:
|
||
self.lr_scheduler.step()
|
||
|
||
total_loss += loss.detach()
|
||
|
||
if self.ema is not None and self.global_step % self.ema_update_interval == 0:
|
||
self.ema.update()
|
||
|
||
average_loss = total_loss.item() / step
|
||
|
||
if not self.scheduler_update_every_step:
|
||
if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
|
||
self.lr_scheduler.step(average_loss)
|
||
else:
|
||
self.lr_scheduler.step()
|
||
|
||
outputs = {
|
||
'loss': average_loss,
|
||
'lr': self.optimizer.param_groups[0]['lr'],
|
||
}
|
||
|
||
return outputs
|
||
|
||
# [GUI] test on a single image
|
||
def test_gui(self, pose, intrinsics, W, H, auds, eye=None, index=0, bg_color=None, spp=1, downscale=1):
|
||
|
||
# render resolution (may need downscale to for better frame rate)
|
||
rH = int(H * downscale)
|
||
rW = int(W * downscale)
|
||
intrinsics = intrinsics * downscale
|
||
|
||
if auds is not None:
|
||
auds = auds.to(self.device)
|
||
|
||
pose = torch.from_numpy(pose).unsqueeze(0).to(self.device)
|
||
rays = get_rays(pose, intrinsics, rH, rW, -1)
|
||
|
||
bg_coords = get_bg_coords(rH, rW, self.device)
|
||
|
||
if eye is not None:
|
||
eye = torch.FloatTensor([eye]).view(1, 1).to(self.device)
|
||
|
||
data = {
|
||
'rays_o': rays['rays_o'],
|
||
'rays_d': rays['rays_d'],
|
||
'H': rH,
|
||
'W': rW,
|
||
'auds': auds,
|
||
'index': [index], # support choosing index for individual codes
|
||
'eye': eye,
|
||
'poses': pose,
|
||
'bg_coords': bg_coords,
|
||
}
|
||
|
||
self.model.eval()
|
||
|
||
if self.ema is not None:
|
||
self.ema.store()
|
||
self.ema.copy_to()
|
||
|
||
with torch.no_grad():
|
||
with torch.cuda.amp.autocast(enabled=self.fp16):
|
||
# here spp is used as perturb random seed!
|
||
# face: do not perturb for the first spp, else lead to scatters.
|
||
preds, preds_depth = self.test_step(data, bg_color=bg_color, perturb=False if spp == 1 else spp)
|
||
|
||
if self.ema is not None:
|
||
self.ema.restore()
|
||
|
||
# interpolation to the original resolution
|
||
if downscale != 1:
|
||
# TODO: have to permute twice with torch...
|
||
preds = F.interpolate(preds.permute(0, 3, 1, 2), size=(H, W), mode='bilinear').permute(0, 2, 3, 1).contiguous()
|
||
preds_depth = F.interpolate(preds_depth.unsqueeze(1), size=(H, W), mode='nearest').squeeze(1)
|
||
|
||
if self.opt.color_space == 'linear':
|
||
preds = linear_to_srgb(preds)
|
||
|
||
pred = preds[0].detach().cpu().numpy()
|
||
pred_depth = preds_depth[0].detach().cpu().numpy()
|
||
|
||
outputs = {
|
||
'image': pred,
|
||
'depth': pred_depth,
|
||
}
|
||
|
||
return outputs
|
||
|
||
# [GUI] test with provided data
|
||
def test_gui_with_data(self, data, W, H):
|
||
|
||
self.model.eval()
|
||
|
||
if self.ema is not None:
|
||
self.ema.store()
|
||
self.ema.copy_to()
|
||
|
||
with torch.no_grad():
|
||
with torch.cuda.amp.autocast(enabled=self.fp16):
|
||
# here spp is used as perturb random seed!
|
||
# face: do not perturb for the first spp, else lead to scatters.
|
||
preds, preds_depth = self.test_step(data, perturb=False)
|
||
|
||
if self.ema is not None:
|
||
self.ema.restore()
|
||
|
||
if self.opt.color_space == 'linear':
|
||
preds = linear_to_srgb(preds)
|
||
|
||
# the H/W in data may be differnt to GUI, so we still need to resize...
|
||
preds = F.interpolate(preds.permute(0, 3, 1, 2), size=(H, W), mode='bilinear').permute(0, 2, 3, 1).contiguous()
|
||
preds_depth = F.interpolate(preds_depth.unsqueeze(1), size=(H, W), mode='nearest').squeeze(1)
|
||
|
||
pred = preds[0].detach().cpu().numpy()
|
||
pred_depth = preds_depth[0].detach().cpu().numpy()
|
||
|
||
outputs = {
|
||
'image': pred,
|
||
'depth': pred_depth,
|
||
}
|
||
|
||
return outputs
|
||
|
||
def train_one_epoch(self, loader):
|
||
self.log(f"==> Start Training Epoch {self.epoch}, lr={self.optimizer.param_groups[0]['lr']:.6f} ...")
|
||
|
||
total_loss = 0
|
||
if self.local_rank == 0 and self.report_metric_at_train:
|
||
for metric in self.metrics:
|
||
metric.clear()
|
||
|
||
self.model.train()
|
||
|
||
# distributedSampler: must call set_epoch() to shuffle indices across multiple epochs
|
||
# ref: https://pytorch.org/docs/stable/data.html
|
||
if self.world_size > 1:
|
||
loader.sampler.set_epoch(self.epoch)
|
||
|
||
if self.local_rank == 0:
|
||
pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, mininterval=1, bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')
|
||
|
||
self.local_step = 0
|
||
|
||
for data in loader:
|
||
# update grid every 16 steps
|
||
if self.model.cuda_ray and self.global_step % self.opt.update_extra_interval == 0:
|
||
with torch.cuda.amp.autocast(enabled=self.fp16):
|
||
self.model.update_extra_state()
|
||
|
||
self.local_step += 1
|
||
self.global_step += 1
|
||
|
||
self.optimizer.zero_grad()
|
||
|
||
with torch.cuda.amp.autocast(enabled=self.fp16):
|
||
preds, truths, loss = self.train_step(data)
|
||
|
||
self.scaler.scale(loss).backward()
|
||
self.scaler.step(self.optimizer)
|
||
self.scaler.update()
|
||
|
||
if self.scheduler_update_every_step:
|
||
self.lr_scheduler.step()
|
||
|
||
loss_val = loss.item()
|
||
total_loss += loss_val
|
||
|
||
if self.ema is not None and self.global_step % self.ema_update_interval == 0:
|
||
self.ema.update()
|
||
|
||
if self.local_rank == 0:
|
||
if self.report_metric_at_train:
|
||
for metric in self.metrics:
|
||
metric.update(preds, truths)
|
||
|
||
if self.use_tensorboardX:
|
||
self.writer.add_scalar("train/loss", loss_val, self.global_step)
|
||
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]['lr'], self.global_step)
|
||
|
||
if self.scheduler_update_every_step:
|
||
pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f}), lr={self.optimizer.param_groups[0]['lr']:.6f}")
|
||
else:
|
||
pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})")
|
||
pbar.update(loader.batch_size)
|
||
|
||
average_loss = total_loss / self.local_step
|
||
self.stats["loss"].append(average_loss)
|
||
|
||
if self.local_rank == 0:
|
||
pbar.close()
|
||
if self.report_metric_at_train:
|
||
for metric in self.metrics:
|
||
self.log(metric.report(), style="red")
|
||
if self.use_tensorboardX:
|
||
metric.write(self.writer, self.epoch, prefix="train")
|
||
metric.clear()
|
||
|
||
if not self.scheduler_update_every_step:
|
||
if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
|
||
self.lr_scheduler.step(average_loss)
|
||
else:
|
||
self.lr_scheduler.step()
|
||
self.log(f"loss={average_loss:.4f}")
|
||
self.log(f"==> Finished Epoch {self.epoch}.")
|
||
|
||
|
||
def evaluate_one_epoch(self, loader, name=None):
|
||
self.log(f"++> Evaluate at epoch {self.epoch} ...")
|
||
|
||
if name is None:
|
||
name = f'{self.name}_ep{self.epoch:04d}'
|
||
|
||
total_loss = 0
|
||
if self.local_rank == 0:
|
||
for metric in self.metrics:
|
||
metric.clear()
|
||
|
||
self.model.eval()
|
||
|
||
if self.ema is not None:
|
||
self.ema.store()
|
||
self.ema.copy_to()
|
||
|
||
if self.local_rank == 0:
|
||
pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')
|
||
|
||
with torch.no_grad():
|
||
self.local_step = 0
|
||
|
||
for data in loader:
|
||
self.local_step += 1
|
||
|
||
with torch.cuda.amp.autocast(enabled=self.fp16):
|
||
preds, preds_depth, pred_ambient_aud, pred_ambient_eye, pred_uncertainty, truths, loss, loss_raw = self.eval_step(data)
|
||
loss_val = loss.item()
|
||
total_loss += loss_val
|
||
|
||
# only rank = 0 will perform evaluation.
|
||
if self.local_rank == 0:
|
||
|
||
# save image
|
||
save_path = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_rgb.png')
|
||
save_path_depth = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_depth.png')
|
||
save_path_ambient_aud = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_aud.png')
|
||
save_path_ambient_eye = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_eye.png')
|
||
save_path_uncertainty = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_uncertainty.png')
|
||
|
||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||
|
||
if self.opt.color_space == 'linear':
|
||
preds = linear_to_srgb(preds)
|
||
|
||
if self.opt.portrait:
|
||
pred = blend_with_mask_cuda(preds[0], data["bg_gt_images"].squeeze(0),data["bg_face_mask"].squeeze(0))
|
||
preds = torch.from_numpy(pred).unsqueeze(0).to(self.device).float()
|
||
else:
|
||
pred = preds[0].detach().cpu().numpy()
|
||
pred_depth = preds_depth[0].detach().cpu().numpy()
|
||
|
||
for metric in self.metrics:
|
||
metric.update(preds, truths)
|
||
# loss_raw = loss_raw[0].mean(-1).detach().cpu().numpy()
|
||
# loss_raw = (loss_raw - np.min(loss_raw)) / (np.max(loss_raw) - np.min(loss_raw))
|
||
pred_ambient_aud = pred_ambient_aud[0].detach().cpu().numpy()
|
||
pred_ambient_aud /= np.max(pred_ambient_aud)
|
||
pred_ambient_eye = pred_ambient_eye[0].detach().cpu().numpy()
|
||
pred_ambient_eye /= np.max(pred_ambient_eye)
|
||
# pred_ambient = pred_ambient / 16
|
||
# print(pred_ambient.shape)
|
||
pred_uncertainty = pred_uncertainty[0].detach().cpu().numpy()
|
||
# pred_uncertainty = (pred_uncertainty - np.min(pred_uncertainty)) / (np.max(pred_uncertainty) - np.min(pred_uncertainty))
|
||
pred_uncertainty /= np.max(pred_uncertainty)
|
||
|
||
cv2.imwrite(save_path, cv2.cvtColor((pred * 255).astype(np.uint8), cv2.COLOR_RGB2BGR))
|
||
|
||
if not self.opt.torso:
|
||
cv2.imwrite(save_path_depth, (pred_depth * 255).astype(np.uint8))
|
||
# cv2.imwrite(save_path_error, (loss_raw * 255).astype(np.uint8))
|
||
cv2.imwrite(save_path_ambient_aud, (pred_ambient_aud * 255).astype(np.uint8))
|
||
cv2.imwrite(save_path_ambient_eye, (pred_ambient_eye * 255).astype(np.uint8))
|
||
cv2.imwrite(save_path_uncertainty, (pred_uncertainty * 255).astype(np.uint8))
|
||
#cv2.imwrite(save_path_gt, cv2.cvtColor((linear_to_srgb(truths[0].detach().cpu().numpy()) * 255).astype(np.uint8), cv2.COLOR_RGB2BGR))
|
||
|
||
pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})")
|
||
pbar.update(loader.batch_size)
|
||
|
||
|
||
average_loss = total_loss / self.local_step
|
||
self.stats["valid_loss"].append(average_loss)
|
||
|
||
if self.local_rank == 0:
|
||
pbar.close()
|
||
if not self.use_loss_as_metric and len(self.metrics) > 0:
|
||
result = self.metrics[0].measure()
|
||
self.stats["results"].append(result if self.best_mode == 'min' else - result) # if max mode, use -result
|
||
else:
|
||
self.stats["results"].append(average_loss) # if no metric, choose best by min loss
|
||
|
||
for metric in self.metrics:
|
||
self.log(metric.report(), style="blue")
|
||
if self.use_tensorboardX:
|
||
metric.write(self.writer, self.epoch, prefix="evaluate")
|
||
metric.clear()
|
||
|
||
if self.ema is not None:
|
||
self.ema.restore()
|
||
|
||
self.log(f"++> Evaluate epoch {self.epoch} Finished.")
|
||
|
||
def save_checkpoint(self, name=None, full=False, best=False, remove_old=True):
|
||
|
||
if name is None:
|
||
name = f'{self.name}_ep{self.epoch:04d}'
|
||
|
||
state = {
|
||
'epoch': self.epoch,
|
||
'global_step': self.global_step,
|
||
'stats': self.stats,
|
||
}
|
||
|
||
|
||
state['mean_count'] = self.model.mean_count
|
||
state['mean_density'] = self.model.mean_density
|
||
state['mean_density_torso'] = self.model.mean_density_torso
|
||
|
||
if full:
|
||
state['optimizer'] = self.optimizer.state_dict()
|
||
state['lr_scheduler'] = self.lr_scheduler.state_dict()
|
||
state['scaler'] = self.scaler.state_dict()
|
||
if self.ema is not None:
|
||
state['ema'] = self.ema.state_dict()
|
||
|
||
if not best:
|
||
|
||
state['model'] = self.model.state_dict()
|
||
|
||
file_path = f"{self.ckpt_path}/{name}.pth"
|
||
|
||
if remove_old:
|
||
self.stats["checkpoints"].append(file_path)
|
||
|
||
if len(self.stats["checkpoints"]) > self.max_keep_ckpt:
|
||
old_ckpt = self.stats["checkpoints"].pop(0)
|
||
if os.path.exists(old_ckpt):
|
||
os.remove(old_ckpt)
|
||
|
||
|
||
torch.save(state, file_path)
|
||
|
||
else:
|
||
if len(self.stats["results"]) > 0:
|
||
# always save new as best... (since metric cannot really reflect performance...)
|
||
if True:
|
||
|
||
# save ema results
|
||
if self.ema is not None:
|
||
self.ema.store()
|
||
self.ema.copy_to()
|
||
|
||
state['model'] = self.model.state_dict()
|
||
|
||
# we don't consider continued training from the best ckpt, so we discard the unneeded density_grid to save some storage (especially important for dnerf)
|
||
if 'density_grid' in state['model']:
|
||
del state['model']['density_grid']
|
||
|
||
if self.ema is not None:
|
||
self.ema.restore()
|
||
|
||
torch.save(state, self.best_path)
|
||
else:
|
||
self.log(f"[WARN] no evaluated results found, skip saving best checkpoint.")
|
||
|
||
def load_checkpoint(self, checkpoint=None, model_only=False):
|
||
if checkpoint is None:
|
||
checkpoint_list = sorted(glob.glob(f'{self.ckpt_path}/{self.name}_ep*.pth'))
|
||
if checkpoint_list:
|
||
checkpoint = checkpoint_list[-1]
|
||
self.log(f"[INFO] Latest checkpoint is {checkpoint}")
|
||
else:
|
||
self.log("[WARN] No checkpoint found, model randomly initialized.")
|
||
return
|
||
|
||
checkpoint_dict = torch.load(checkpoint, map_location=self.device)
|
||
|
||
if 'model' not in checkpoint_dict:
|
||
self.model.load_state_dict(checkpoint_dict)
|
||
self.log("[INFO] loaded bare model.")
|
||
return
|
||
|
||
missing_keys, unexpected_keys = self.model.load_state_dict(checkpoint_dict['model'], strict=False)
|
||
self.log("[INFO] loaded model.")
|
||
if len(missing_keys) > 0:
|
||
self.log(f"[WARN] missing keys: {missing_keys}")
|
||
if len(unexpected_keys) > 0:
|
||
self.log(f"[WARN] unexpected keys: {unexpected_keys}")
|
||
|
||
if self.ema is not None and 'ema' in checkpoint_dict:
|
||
self.ema.load_state_dict(checkpoint_dict['ema'])
|
||
|
||
|
||
if 'mean_count' in checkpoint_dict:
|
||
self.model.mean_count = checkpoint_dict['mean_count']
|
||
if 'mean_density' in checkpoint_dict:
|
||
self.model.mean_density = checkpoint_dict['mean_density']
|
||
if 'mean_density_torso' in checkpoint_dict:
|
||
self.model.mean_density_torso = checkpoint_dict['mean_density_torso']
|
||
|
||
if model_only:
|
||
return
|
||
|
||
self.stats = checkpoint_dict['stats']
|
||
self.epoch = checkpoint_dict['epoch']
|
||
self.global_step = checkpoint_dict['global_step']
|
||
self.log(f"[INFO] load at epoch {self.epoch}, global step {self.global_step}")
|
||
|
||
if self.optimizer and 'optimizer' in checkpoint_dict:
|
||
try:
|
||
self.optimizer.load_state_dict(checkpoint_dict['optimizer'])
|
||
self.log("[INFO] loaded optimizer.")
|
||
except:
|
||
self.log("[WARN] Failed to load optimizer.")
|
||
|
||
if self.lr_scheduler and 'lr_scheduler' in checkpoint_dict:
|
||
try:
|
||
self.lr_scheduler.load_state_dict(checkpoint_dict['lr_scheduler'])
|
||
self.log("[INFO] loaded scheduler.")
|
||
except:
|
||
self.log("[WARN] Failed to load scheduler.")
|
||
|
||
if self.scaler and 'scaler' in checkpoint_dict:
|
||
try:
|
||
self.scaler.load_state_dict(checkpoint_dict['scaler'])
|
||
self.log("[INFO] loaded scaler.")
|
||
except:
|
||
self.log("[WARN] Failed to load scaler.")
|
||
|
||
|
||
def load_wav(path, sr):
|
||
return librosa.core.load(path, sr=sr)[0]
|
||
|
||
|
||
def preemphasis(wav, k):
|
||
return signal.lfilter([1, -k], [1], wav)
|
||
|
||
|
||
def melspectrogram(wav):
|
||
D = _stft(preemphasis(wav, 0.97))
|
||
S = _amp_to_db(_linear_to_mel(np.abs(D))) - 20
|
||
|
||
return _normalize(S)
|
||
|
||
|
||
def _stft(y):
|
||
return librosa.stft(y=y, n_fft=800, hop_length=200, win_length=800)
|
||
|
||
|
||
def _linear_to_mel(spectogram):
|
||
global _mel_basis
|
||
_mel_basis = _build_mel_basis()
|
||
return np.dot(_mel_basis, spectogram)
|
||
|
||
|
||
def _build_mel_basis():
|
||
return librosa.filters.mel(sr=16000, n_fft=800, n_mels=80, fmin=55, fmax=7600)
|
||
|
||
|
||
def _amp_to_db(x):
|
||
min_level = np.exp(-5 * np.log(10))
|
||
return 20 * np.log10(np.maximum(min_level, x))
|
||
|
||
|
||
def _normalize(S):
|
||
return np.clip((2 * 4.) * ((S - -100) / (--100)) - 4., -4., 4.)
|
||
|
||
|
||
class AudDataset(object):
|
||
def __init__(self, wavpath):
|
||
wav = load_wav(wavpath, 16000)
|
||
|
||
self.orig_mel = melspectrogram(wav).T
|
||
self.data_len = int((self.orig_mel.shape[0] - 16) / 80. * float(25)) + 2
|
||
|
||
def get_frame_id(self, frame):
|
||
return int(basename(frame).split('.')[0])
|
||
|
||
def crop_audio_window(self, spec, start_frame):
|
||
if type(start_frame) == int:
|
||
start_frame_num = start_frame
|
||
else:
|
||
start_frame_num = self.get_frame_id(start_frame)
|
||
start_idx = int(80. * (start_frame_num / float(25)))
|
||
|
||
end_idx = start_idx + 16
|
||
if end_idx > spec.shape[0]:
|
||
# print(end_idx, spec.shape[0])
|
||
end_idx = spec.shape[0]
|
||
start_idx = end_idx - 16
|
||
|
||
return spec[start_idx: end_idx, :]
|
||
|
||
def __len__(self):
|
||
return self.data_len
|
||
|
||
def __getitem__(self, idx):
|
||
|
||
mel = self.crop_audio_window(self.orig_mel.copy(), idx)
|
||
if (mel.shape[0] != 16):
|
||
raise Exception('mel.shape[0] != 16')
|
||
mel = torch.FloatTensor(mel.T).unsqueeze(0)
|
||
|
||
return mel
|