336 lines
17 KiB
Python
336 lines
17 KiB
Python
import argparse
|
|
import os
|
|
import hashlib
|
|
import torch
|
|
from yaml import safe_load
|
|
from nerf_triplane.provider import NeRFDataset
|
|
from nerf_triplane.utils import *
|
|
from nerf_triplane.network import NeRFNetwork
|
|
from data_utils.hubert import process_audio
|
|
|
|
# Directory (relative path) in which the cache helpers below store and look up
# their "<cache_key>_<data_type>.pt" files.
CACHE_DIR = './cache'
|
|
|
|
def parse_arguments(argv=None):
    """
    Build the command-line parser for the inference script and parse arguments.

    Parameters:
        argv: list of str or None
            Argument strings to parse. When None (the default) argparse reads
            sys.argv[1:], which matches the previous zero-argument behaviour.
            Passing an explicit list lets the parser be used programmatically.

    Returns:
        argparse.Namespace
            One attribute per option declared below.
    """
    parser = argparse.ArgumentParser(description="NeRF Inference Script")

    # File paths and general configuration options
    parser.add_argument('--path', type=str, help="Path to the input data")
    parser.add_argument('--workspace', type=str, default='workspace', help="Directory for storing intermediate and final results")
    parser.add_argument('--seed', type=int, default=0, help="Random seed for reproducibility")
    parser.add_argument('--config', type=str, default=None, help='Path to configuration file (YAML)')

    # New cache related arguments
    parser.add_argument('--cache', action='store_true', help="Enable caching for data and model loading")

    # Test modes
    parser.add_argument('--test', action='store_true', help="Test mode using the test dataset")
    parser.add_argument('--test_train', action='store_true', help="Test mode using the train dataset")

    # Data range
    parser.add_argument('--data_range', type=int, nargs=2, default=[0, -1], help="Range of data indices to use [start, end)")

    # Training options
    parser.add_argument('--iters', type=int, default=200000, help="Number of training iterations")
    parser.add_argument('--lr', type=float, default=1e-2, help="Initial learning rate for the main network")
    parser.add_argument('--lr_net', type=float, default=1e-3, help="Initial learning rate for other networks")

    # Checkpoint management
    parser.add_argument('--ckpt', type=str, default='latest', help="Checkpoint to load or save")

    # Ray sampling settings
    parser.add_argument('--num_rays', type=int, default=4096 * 16, help="Number of rays sampled per image for each training step")
    parser.add_argument('--cuda_ray', action='store_true', help="Use CUDA raymarching instead of PyTorch")
    parser.add_argument('--max_steps', type=int, default=16, help="Max num steps sampled per ray (only valid when using --cuda_ray)")
    parser.add_argument('--num_steps', type=int, default=16, help="Num steps sampled per ray (only valid when NOT using --cuda_ray)")
    parser.add_argument('--upsample_steps', type=int, default=0, help="Num steps up-sampled per ray (only valid when NOT using --cuda_ray)")
    parser.add_argument('--update_extra_interval', type=int, default=16, help="Iter interval to update extra status (only valid when using --cuda_ray)")
    parser.add_argument('--max_ray_batch', type=int, default=4096, help="Batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)")

    # Loss settings
    parser.add_argument('--warmup_step', type=int, default=10000, help="Number of warm up steps")
    parser.add_argument('--amb_aud_loss', action='store_true', help="Use ambient audio loss")
    parser.add_argument('--amb_eye_loss', action='store_true', help="Use ambient eye loss")
    parser.add_argument('--unc_loss', action='store_true', help="Use uncertainty loss")
    parser.add_argument('--lambda_amb', type=float, default=1e-4, help="Lambda for ambient loss")
    parser.add_argument('--pyramid_loss', action='store_true', help="Use perceptual loss")

    # Network backbone options
    parser.add_argument('--fp16', action='store_true', help="Use AMP mixed precision training")
    parser.add_argument('--bg_img', type=str, default='', help="Background image path")
    parser.add_argument('--fbg', action='store_true', help="Frame-wise background")
    parser.add_argument('--exp_eye', action='store_true', help="Explicitly control the eyes")
    parser.add_argument('--fix_eye', type=float, default=-1, help="Fixed eye area (negative to disable)")
    parser.add_argument('--smooth_eye', action='store_true', help="Smooth the eye area sequence")
    parser.add_argument('--bs_area', choices=['upper', 'eye'], default="upper", help="Area for background subtraction ('upper' or 'eye')")
    parser.add_argument('--au45', action='store_true', help="Use OpenFace AU45")
    parser.add_argument('--torso_shrink', type=float, default=0.8, help="Shrink bg coords to allow more flexibility in deform")

    # Dataset options
    parser.add_argument('--color_space', choices=['linear', 'srgb'], default='srgb', help="Color space (supports linear or srgb)")
    parser.add_argument('--preload', type=int, choices=[0, 1, 2], default=0, help="Preload data (0: on-the-fly, 1: CPU, 2: GPU)")
    parser.add_argument('--bound', type=float, default=1, help="Assume the scene is bounded in box[-bound, bound]^3")
    parser.add_argument('--scale', type=float, default=4, help="Scale camera location into box[-bound, bound]^3")
    parser.add_argument('--offset', type=float, nargs=3, default=[0, 0, 0], help="Offset of camera location [x, y, z]")
    parser.add_argument('--dt_gamma', type=float, default=1/256, help="Dt_gamma for adaptive ray marching")
    parser.add_argument('--min_near', type=float, default=0.05, help="Minimum near distance for camera")
    parser.add_argument('--density_thresh', type=float, default=10, help="Threshold for density grid to be occupied (sigma)")
    parser.add_argument('--density_thresh_torso', type=float, default=0.01, help="Threshold for density grid to be occupied (alpha)")
    parser.add_argument('--patch_size', type=int, choices=[1] + [64, 32, 16], default=1, help="Render patches in training for LPIPS loss")

    # Specific training options
    parser.add_argument('--init_lips', action='store_true', help="Initialize lips region")
    parser.add_argument('--finetune_lips', action='store_true', help="Fine tune lips region using LPIPS and landmarks")
    parser.add_argument('--smooth_lips', action='store_true', help="Smooth enc_a in a exponential decay way")
    parser.add_argument('--torso', action='store_true', help="Fix head and train torso")
    parser.add_argument('--head_ckpt', type=str, default='', help="Path to the pre-trained head model")

    # GUI options
    parser.add_argument('--gui', action='store_true', help="Start a GUI interface")
    parser.add_argument('--W', type=int, default=450, help="GUI width (pixels)")
    parser.add_argument('--H', type=int, default=450, help="GUI height (pixels)")
    parser.add_argument('--radius', type=float, default=3.35, help="Default GUI camera radius from center")
    parser.add_argument('--fovy', type=float, default=21.24, help="Default GUI camera field of view in degrees")
    parser.add_argument('--max_spp', type=int, default=1, help="Max samples per pixel for GUI rendering")

    # Other options
    parser.add_argument('--fullbody', action='store_true', help="Enable full body mode")
    parser.add_argument('--att', type=int, choices=[0, 1, 2], default=2, help="Audio attention mode (0 = off, 1 = left-direction, 2 = bi-direction)")
    parser.add_argument('--aud', type=str, default='', help="Path to audio source (empty for default)")
    parser.add_argument('--emb', action='store_true', help="Use audio class + embedding instead of logits")
    parser.add_argument('--portrait', action='store_true', help="Only render face")

    # Other options (continued)
    parser.add_argument('--ind_dim', type=int, default=4, help="Dimension of individual codes (0 to turn off)")
    parser.add_argument('--ind_num', type=int, default=20000, help="Number of individual codes (should be larger than training dataset size)")
    parser.add_argument('--ind_dim_torso', type=int, default=8, help="Dimension of torso individual codes (0 to turn off)")
    parser.add_argument('--amb_dim', type=int, default=2, help="Ambient dimension")
    parser.add_argument('--part', action='store_true', help="Use partial training data (1/10)")
    parser.add_argument('--part2', action='store_true', help="Use partial training data (first 15s)")
    parser.add_argument('--train_camera', action='store_true', help="Optimize camera pose")
    parser.add_argument('--smooth_path', action='store_true', help="Smooth camera pose trajectory with a window size")
    parser.add_argument('--smooth_path_window', type=int, default=7, help="Smoothing window size for camera path")

    # ASR settings
    parser.add_argument('--asr', action='store_true', help="Load ASR for real-time application")
    parser.add_argument('--asr_wav', type=str, default='', help="Path to WAV file for input")
    parser.add_argument('--asr_play', action='store_true', help="Play out the audio in real time")
    parser.add_argument('--asr_model', type=str, default='deepspeech', help="ASR model to use")
    parser.add_argument('--asr_save_feats', action='store_true', help="Save features extracted by the ASR model")

    # Audio processing settings
    parser.add_argument('--fps', type=int, default=50, help="Audio frames per second")
    parser.add_argument('-l', type=int, default=10, help="Length of sliding window (left) in 20ms units")
    parser.add_argument('-m', type=int, default=50, help="Length of sliding window (middle) in 20ms units")
    parser.add_argument('-r', type=int, default=10, help="Length of sliding window (right) in 20ms units")

    # Shortcut options for common combinations
    parser.add_argument('--O', action='store_true', help="Shortcut for --fp16 --cuda_ray --exp_eye")

    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)
|
|
|
|
def load_config(config_path):
    """
    Read a YAML configuration file.

    Parameters:
        config_path: str
            Path to a UTF-8 encoded YAML file.

    Returns:
        The object parsed by yaml.safe_load. An empty (or comment-only) file,
        for which safe_load yields None, is returned as an empty dict so the
        result can be merged like any other configuration.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config = safe_load(f)
    # safe_load returns None for an empty document; callers iterate the result
    # with .items(), so hand back an empty mapping instead.
    return {} if config is None else config
|
|
|
|
def merge_configs(base_config, override_config):
    """
    Recursively fold override_config into base_config.

    When both sides hold a dict under the same key the two dicts are merged
    key by key; in every other case the override value replaces (or adds) the
    entry. base_config is modified in place and is also the return value.
    """
    for name, incoming in override_config.items():
        current = base_config.get(name)
        # Only descend when both sides are dicts; a missing key yields None here.
        if not (isinstance(incoming, dict) and isinstance(current, dict)):
            base_config[name] = incoming
            continue
        merge_configs(current, incoming)
    return base_config
|
|
|
|
def setup_args(args=None, config_override=None):
    """
    Setup the configuration by merging default, file-based, command-line, and function parameters.

    Layers are applied in this order, later ones winning:
      1. default_config.yaml located next to this file (required),
      2. the YAML file named by args.config, if any,
      3. every attribute of args whose value is not None,
      4. config_override.

    Parameters:
        args: argparse.Namespace or None
            Command-line arguments parsed using argparse.
        config_override: dict or None
            Function-based overrides to merge into the final configuration.

    Returns:
        opt: argparse.Namespace
            The final configuration as an argparse.Namespace object.

    Raises:
        FileNotFoundError: if default_config.yaml does not exist.
    """
    # Load default configuration (mandatory)
    default_config_path = os.path.join(os.path.dirname(__file__), 'default_config.yaml')
    if not os.path.exists(default_config_path):
        raise FileNotFoundError("Default configuration file not found.")
    with open(default_config_path, 'r', encoding='utf-8') as f:
        config = safe_load(f)
    if config is None:
        # An empty YAML document parses to None; start from an empty mapping
        # so the layers below can still be applied.
        config = {}

    # Override with user-specified configuration file
    if args and args.config is not None:
        override_config = load_config(args.config)
        # Skip the merge for an empty override file (parsed as None).
        if override_config is not None:
            config = merge_configs(config, override_config)

    # Merge command-line arguments into the configuration
    # NOTE(review): every non-None attribute is copied, and parse_arguments
    # declares non-None defaults for most options, so argparse defaults also
    # overwrite values coming from the YAML files - confirm this is intended.
    if args:
        for key, value in vars(args).items():
            if value is not None:
                config[key] = value

    # Override with function-based parameters
    if config_override:
        config = merge_configs(config, config_override)

    return argparse.Namespace(**config)
|
|
|
|
def get_cache_key(opt):
    """
    Generate a cache key based on the current configuration.

    The key is the MD5 hex digest of every attribute of opt, so two runs share
    cache entries only when their whole configuration matches. Hashing only
    path and workspace (the previous behaviour) let runs that differed in
    other options reuse each other's cached artefacts.

    Parameters:
        opt: argparse.Namespace
            The configuration; any object accepted by vars() works.

    Returns:
        str
            A 32-character hexadecimal digest.
    """
    # Sorting by attribute name makes the key independent of the order in
    # which attributes were set; names are unique, so the tuple comparison
    # never has to compare the (possibly unorderable) values.
    fingerprint = repr(sorted(vars(opt).items()))
    # MD5 is used as a fingerprint here, not for security.
    return hashlib.md5(fingerprint.encode()).hexdigest()
|
|
|
|
def load_from_cache(cache_key, data_type):
    """
    Fetch a previously cached object, if one exists.

    The object is looked up at CACHE_DIR/<cache_key>_<data_type>.pt and read
    with torch.load. Returns None when no such file exists.
    """
    filename = f"{cache_key}_{data_type}.pt"
    cache_path = os.path.join(CACHE_DIR, filename)
    if not os.path.exists(cache_path):
        return None
    # torch.load unpickles the file, so only trusted cache directories should be used.
    return torch.load(cache_path)
|
|
|
|
def save_to_cache(cache_key, data, data_type):
    """
    Save data or model to cache.

    The object is written with torch.save to
    CACHE_DIR/<cache_key>_<data_type>.pt; CACHE_DIR is created when missing.

    Parameters:
        cache_key: str
            Key identifying the configuration (see get_cache_key).
        data: object
            Anything torch.save can serialise.
        data_type: str
            Short label for the kind of object, used in the file name.
    """
    # exist_ok avoids the check-then-create race of a separate existence test.
    os.makedirs(CACHE_DIR, exist_ok=True)
    cache_path = os.path.join(CACHE_DIR, f"{cache_key}_{data_type}.pt")
    torch.save(data, cache_path)
|
|
|
|
def _load_or_build_model(opt, cache_key):
    """Return the cached network when caching is on and a copy exists, else build (and cache) one."""
    model = load_from_cache(cache_key, 'model') if opt.cache else None
    if model is not None:
        return model

    model = NeRFNetwork(opt)
    if opt.torso and hasattr(opt, 'head_ckpt') and opt.head_ckpt != '':
        model_dict = torch.load(opt.head_ckpt, map_location='cpu')['model']
        missing_keys, unexpected_keys = model.load_state_dict(model_dict, strict=False)
        if len(missing_keys) > 0:
            print(f"[WARN] Missing keys: {missing_keys}")
        if len(unexpected_keys) > 0:
            print(f"[WARN] Unexpected keys: {unexpected_keys}")

        # Freeze every parameter that was loaded from the head checkpoint
        for k, v in model.named_parameters():
            if k in model_dict:
                print(f'[INFO] Freezing {k}, {v.shape}')
                v.requires_grad = False

    if opt.cache:
        save_to_cache(cache_key, model, 'model')
    return model


def _load_or_build_test_loader(opt, device, cache_key):
    """Return the cached test loader when caching is on and a copy exists, else build (and cache) one."""
    # Consult the cache first so the dataset is not constructed on a cache hit;
    # this mirrors how the model is handled.
    test_loader = load_from_cache(cache_key, 'test_loader') if opt.cache else None
    if test_loader is not None:
        return test_loader

    if getattr(opt, 'test_train', False):
        test_set = NeRFDataset(opt, device=device, type='train')
        # A manual fix to test on the training dataset
        test_set.training = False
        test_set.num_rays = -1
        test_loader = test_set.dataloader()
    else:
        test_loader = NeRFDataset(opt, device=device, type='test').dataloader()

    if opt.cache:
        save_to_cache(cache_key, test_loader, 'test_loader')
    return test_loader


def inference(opt):
    """
    Perform inference using the NeRF model.

    Parameters:
        opt: argparse.Namespace
            The options that define how to perform inference, similar to
            command-line arguments. opt is modified in place: fp16 and exp_eye
            are switched on when opt.O is set, and cuda_ray is always forced
            to True.

    Returns:
        results:
            Whatever trainer.test(test_loader) returns, or None when opt.gui
            is set (the GUI renders interactively instead).

    Raises:
        AssertionError: if patch_size > 1 and num_rays is not a multiple of
            patch_size ** 2.
    """
    # Close tf32 features. Fix low numerical accuracy on rtx30xx gpu.
    try:
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
    except AttributeError:
        print('Info. This PyTorch version is not support with tf32.')

    seed_everything(opt.seed)

    if opt.O:
        opt.fp16 = True
        opt.exp_eye = True

    # CUDA ray marching is always used, regardless of the command-line flag.
    opt.cuda_ray = True

    if opt.patch_size > 1:
        assert opt.num_rays % (opt.patch_size ** 2) == 0, "patch_size ** 2 should be dividable by num_rays."

    print(opt)

    cache_key = get_cache_key(opt)

    # Load audio features
    if opt.cache and opt.asr_model == 'hubert':
        cached_audio = load_from_cache(cache_key, 'audio')
        if cached_audio is not None:
            # NOTE(review): the meaning of process_audio's second argument is
            # not visible from this file - confirm it accepts cached features.
            process_audio(opt.aud, cached_audio)
        else:
            audio_data = process_audio(opt.aud)
            save_to_cache(cache_key, audio_data, 'audio')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = _load_or_build_model(opt, cache_key)

    criterion = torch.nn.L1Loss(reduction='none')
    metrics = [PSNRMeter(), LPIPSMeter(device=device), LMDMeter(backend='fan')]

    trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace.strip(),
                      criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=getattr(opt, 'ckpt', 'latest'))

    test_loader = _load_or_build_test_loader(opt, device, cache_key)

    # Hand the dataset's audio features and eye areas to the network.
    model.aud_features = test_loader._data.auds
    model.eye_areas = test_loader._data.eye_area

    if getattr(opt, 'gui', False):
        from nerf_triplane.gui import NeRFGUI
        with NeRFGUI(opt, trainer, test_loader) as gui:
            gui.render()
        # GUI doesn't return results directly
        return None

    # Test and save video (fast)
    return trainer.test(test_loader)
|
|
|
|
if __name__ == "__main__":
    args = parse_arguments()
    # NOTE(review): when --config is not given the parsed namespace is dropped,
    # so every other command-line flag is ignored and only default_config.yaml
    # is used - confirm this is intended.
    if args.config is None:
        args = None

    opt = setup_args(args)

    # Validate the arguments that must be set before inference can run.
    # Explicit raises are used because assert statements are removed under -O.
    if not os.path.exists(opt.workspace):
        raise FileNotFoundError("Workspace directory does not exist.")
    if not os.path.exists(opt.path):
        raise FileNotFoundError("Dataset path does not exist.")

    results = inference(opt)