digitalhumantalk/data_utils/hubert.py
2024-12-10 17:05:37 +08:00

110 lines
4.3 KiB
Python

from transformers import Wav2Vec2Processor, HubertModel
import soundfile as sf
import numpy as np
import torch
import os
from pathlib import Path
import soundfile as sf
import numpy as np
import torch
from argparse import ArgumentParser
import librosa
# The processor and the HuBERT weights are both loaded from the same local
# checkpoint directory. This happens once, at import time, so every function
# below can use the module-level `wav2vec2_processor` and `hubert_model`.
_HUBERT_CHECKPOINT = "./model/facebook/hubert-large-ls960-ft"

print("Loading the Wav2Vec2 Processor...")
wav2vec2_processor = Wav2Vec2Processor.from_pretrained(_HUBERT_CHECKPOINT)

print("Loading the HuBERT Model...")
hubert_model = HubertModel.from_pretrained(_HUBERT_CHECKPOINT)
def get_hubert_from_16k_wav(wav_16k_name):
    """Read a wav file and return its HuBERT features.

    The samples returned by ``sf.read`` are handed unchanged to
    ``get_hubert_from_16k_speech``. The sample rate reported by the file is
    discarded, so the caller is responsible for supplying 16 kHz audio.

    Args:
        wav_16k_name: Path of the wav file to read.

    Returns:
        Whatever ``get_hubert_from_16k_speech`` returns for the samples.
    """
    samples = sf.read(wav_16k_name)[0]  # sf.read -> (data, samplerate)
    return get_hubert_from_16k_speech(samples)
@torch.no_grad()
def get_hubert_from_16k_speech(speech, device="cuda:0"):
    """Extract HuBERT last-layer hidden states from a 16 kHz waveform.

    The waveform is preprocessed by the module-level ``wav2vec2_processor``
    and run through the module-level ``hubert_model`` clip by clip, so a long
    recording never has to go through the model in a single pass.

    Args:
        speech: Array of samples indexed as ``[T]`` or ``[T, C]``. For a 2-D
            input only column 0 is used. The audio is passed to the processor
            as 16 kHz; the actual rate is not checked.
        device: Torch device the model and the inputs are moved to.

    Returns:
        A CPU tensor with ``(T - 80) // 320`` rows (one per HuBERT time step)
        and one column per hidden unit. If the input is shorter than one
        400-sample receptive field, no time step exists and a tensor with
        zero rows is returned instead of failing inside ``torch.cat``.

    Side effects:
        Moves the global ``hubert_model`` to ``device``.
    """
    print("Getting HuBERT from 16k wav...")
    global hubert_model
    hubert_model = hubert_model.to(device)
    if speech.ndim == 2:
        speech = speech[:, 0]  # [T, C] ==> [T,]
    input_values_all = wav2vec2_processor(
        speech, return_tensors="pt", sampling_rate=16000
    ).input_values  # [1, T]
    input_values_all = input_values_all.to(device)

    # For a long audio sequence we cannot process everything in one run
    # because of memory limits.
    # HuBERT's CNN front end uses strides [5,2,2,2,2,2,2] (total stride 320)
    # and kernels [10,3,3,3,3,2,2], so 400 samples are needed for one time
    # step. The CNN is therefore equal to one big Conv1D with kernel k=400
    # and stride s=320, and the number of output steps for t samples is
    #     T = floor((t - k) / s) + 1 = floor((t - (k - s)) / s).
    # To get exactly N steps from a clip it must be k + s * (N - 1) samples
    # long; consecutive clips start s * N samples apart so that their outputs
    # neither overlap nor leave a gap.
    kernel = 400
    stride = 320
    clip_length = stride * 1000
    total_samples = input_values_all.shape[1]

    if total_samples < kernel:
        # Not even one receptive field: there is nothing to run the model on.
        return torch.zeros((0, hubert_model.config.hidden_size))

    num_iter = total_samples // clip_length
    expected_T = (total_samples - (kernel - stride)) // stride

    res_lst = []
    for i in range(num_iter):
        start_idx = clip_length * i
        # Slicing past the end of the tensor is safe; the clip is just shorter.
        end_idx = start_idx + (clip_length - stride + kernel)
        input_values = input_values_all[:, start_idx:end_idx]
        hidden_states = hubert_model(input_values).last_hidden_state  # [1, T, hid]
        res_lst.append(hidden_states[0])

    # Remainder after the last full clip (the whole input when num_iter == 0).
    input_values = input_values_all[:, clip_length * num_iter:]
    if input_values.shape[1] >= kernel:  # a tail shorter than the kernel yields no step
        hidden_states = hubert_model(input_values).last_hidden_state  # [1, T, hid]
        res_lst.append(hidden_states[0])

    ret = torch.cat(res_lst, dim=0).cpu()  # [T, hid]
    # Clip boundaries can make the concatenated length differ from the
    # single-pass length by at most one step; fix that up below.
    assert abs(ret.shape[0] - expected_T) <= 1
    if ret.shape[0] < expected_T:
        ret = torch.nn.functional.pad(ret, (0, 0, 0, expected_T - ret.shape[0]))
    else:
        ret = ret[:expected_T]
    return ret
def make_even_first_dim(tensor):
    """Drop the last entry along axis 0 when that axis has odd length.

    Only ``.shape`` and slicing are used, so both torch tensors and NumPy
    arrays are accepted.

    Args:
        tensor: Array-like object with at least one dimension.

    Returns:
        ``tensor`` itself when its first dimension is already even, otherwise
        a slice with the final row removed.
    """
    rows = tensor.shape[0]
    if rows % 2 == 1:
        return tensor[:rows - 1]
    return tensor
def process_audio(wav_path):
    """Extract HuBERT features for one audio file and save them next to it.

    The audio is read, resampled to 16 kHz, passed through
    ``get_hubert_from_16k_speech``, trimmed to an even number of time steps
    and grouped into pairs of consecutive steps. The resulting array is
    written to ``<wav_path without extension>_hu.npy`` and its shape printed.

    Args:
        wav_path: Path of the audio file to process.

    Returns:
        None. If ``wav_path`` does not exist a message is printed and nothing
        is written.
    """
    print(f"Processing audio: {wav_path}")
    # Check that the wav file exists before trying to read it.
    if not os.path.exists(wav_path):
        print(f"Wav file {wav_path} does not exist!")
        return

    speech, sr = sf.read(wav_path)
    if speech.ndim == 2:
        # Multi-channel audio is read as [T, C], but librosa.resample works on
        # the last axis by default and would resample across channels instead
        # of across time. Keep the first channel (the one the feature
        # extractor uses) so that the time axis is what gets resampled.
        speech = speech[:, 0]
    speech_16k = librosa.resample(speech, orig_sr=sr, target_sr=16000)
    print("SR: {} to {}".format(sr, 16000))

    hubert_hidden = get_hubert_from_16k_speech(speech_16k)
    hubert_hidden = make_even_first_dim(hubert_hidden)
    # Pair up consecutive time steps: [T, hid] -> [T // 2, 2, hid]. Explicit
    # dimensions keep the reshape well defined even when there are no frames.
    hubert_hidden = hubert_hidden.reshape(
        hubert_hidden.shape[0] // 2, 2, hubert_hidden.shape[-1]
    )

    # Replace the audio extension with the "_hu.npy" suffix.
    base, _ = os.path.splitext(wav_path)
    new_wav_name = f"{base}_hu.npy"
    features = hubert_hidden.detach().numpy()
    np.save(new_wav_name, features)
    print(features.shape)
if __name__ == '__main__':
    # Command-line entry point: extract and save HuBERT features for one file.
    parser = ArgumentParser()
    # --wav is mandatory: without it process_audio would receive None and
    # fail inside os.path.exists with an unhelpful TypeError.
    parser.add_argument(
        '--wav',
        type=str,
        required=True,
        help='path of the audio file to extract HuBERT features from',
    )
    args = parser.parse_args()
    process_audio(args.wav)