digitalhumantalk/gradio_app.py
2024-12-10 17:05:37 +08:00

192 lines
9.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import yaml
import gradio as gr
import asyncio
import requests
from inferenceapi import setup_args, inference
from thirdparty.edgetts.azuretts import textToSpeechAsync, SUPPORTED_VOICES as AZURE_SUPPORTED_VOICES
from thirdparty.aliyun.aliyuntts import aliyun_text_to_speech, SUPPORTED_VOICES as ALIYUN_SUPPORTED_VOICES
from thirdparty.ui.gradiotheme import EnhancedSeafoam
def load_avatar_configs(config_file="avatar_configs.yaml"):
    """Read the avatar YAML configuration file and return the parsed document.

    Requires the PyYAML library (install with ``pip install pyyaml``).
    Parsing goes through ``yaml.safe_load``.

    Args:
        config_file: Path of the YAML file. Relative paths resolve against
            the current working directory.

    Returns:
        The object produced by ``yaml.safe_load`` for the file contents.
    """
    with open(config_file, mode="r", encoding="utf-8") as handle:
        return yaml.safe_load(handle)
# Load the avatar configuration once, at import time.
avatar_data = load_avatar_configs()
avatars = avatar_data["avatars"]
# Global model cache dictionary (not referenced elsewhere in this file).
models = {}
# De-duplicated list of every audio file referenced by any avatar; used as the
# fallback default audio in update_params.
# dict.fromkeys keeps first-seen order. The previous list(set(...)) order
# depended on randomized string hashing, so all_audios[0] changed between runs.
all_audios = list(dict.fromkeys(
    aud for avatar in avatars for aud in avatar["config"]["audios"]
))
def run_inference(path, workspace, seed, test, test_train, aud):
    """Run avatar video inference with the parameters collected from the UI.

    Args:
        path: Dataset path of the selected avatar.
        workspace: Workspace directory of the selected avatar.
        seed: Random seed.
        test: Value forwarded as the ``test`` option.
        test_train: Value forwarded as the ``test_train`` option.
        aud: Driving audio passed through as the ``aud`` option.

    Returns:
        Whatever ``inference(opt)`` returns; the UI shows it in the output video.

    Raises:
        FileNotFoundError: If the workspace or dataset path resolved by
            ``setup_args`` does not exist.
    """
    args_dict = {
        "path": path,
        "workspace": workspace,
        "seed": seed,
        "test": test,
        "test_train": test_train,
        "aud": aud,
    }
    opt = setup_args(None, args_dict)
    # Explicit checks instead of ``assert``: asserts are stripped under
    # ``python -O``, which would let inference start on a missing directory.
    if not os.path.exists(opt.workspace):
        raise FileNotFoundError(f"Workspace directory does not exist: {opt.workspace}")
    if not os.path.exists(opt.path):
        raise FileNotFoundError(f"Dataset path does not exist: {opt.path}")
    return inference(opt)
def update_params(avatar_index):
    """Return the UI parameter values for the avatar at ``avatar_index``.

    ``avatar_index`` indexes ``avatars`` (the gallery select event passes
    ``evt.index``).

    Returns:
        A 4-tuple ``(path, workspace, seed, audio)``. ``audio`` is the avatar's
        first configured audio file, or ``all_audios[0]`` when the avatar
        lists none.
    """
    cfg = avatars[avatar_index]["config"]
    audio_list = cfg['audios']
    if audio_list:
        chosen_audio = audio_list[0]
    else:
        chosen_audio = all_audios[0]
    return cfg["path"], cfg["workspace"], cfg["seed"], chosen_audio
def find_examples(folder_path, extensions):
    """Return absolute paths of files in ``folder_path`` whose names end with ``extensions``.

    Args:
        folder_path: Directory to scan (non-recursive); it is made absolute first.
        extensions: A suffix string or tuple of suffixes, matched
            case-sensitively with ``str.endswith``.

    Returns:
        Matching regular-file paths, sorted by file name. ``os.listdir`` order
        is arbitrary, and sorting keeps the Gradio example list stable.

    Raises:
        OSError: (e.g. ``FileNotFoundError``) from ``os.listdir`` when
            ``folder_path`` is not a readable directory.
    """
    full_path = os.path.abspath(folder_path)
    examples = []
    for name in sorted(os.listdir(full_path)):
        if not name.endswith(extensions):
            continue
        file_path = os.path.join(full_path, name)
        if os.path.isfile(file_path):
            examples.append(file_path)
        elif not os.path.exists(file_path):
            # Dangling symlink, or the entry vanished after listdir.
            print(f"Warning: File not found: {file_path}")
        # Otherwise a directory whose name happens to match: skip it silently.
    return examples
# Directory containing this script; asset paths are resolved against it.
_script_path = os.path.abspath(__file__)
base_dir = os.path.dirname(_script_path)
# Audio clip suffixes offered as examples in the UI.
_example_extensions = ('.mp3', '.wav', '.ogg')
# The trailing "" component makes the joined folder path end with a separator.
audio_folder = os.path.join(base_dir, "assets", "")
audio_examples = find_examples(audio_folder, _example_extensions)
def save_uploaded_audio(file):
    """Write an uploaded audio stream to ``assets/temp_input.wav``.

    The ``assets`` directory is relative to the current working directory.
    The destination name is fixed, so each call overwrites the previous upload.

    Args:
        file: A readable binary file-like object (anything with ``read()``),
            or ``None`` when nothing was uploaded or the component was cleared.

    Returns:
        The relative path ``os.path.join("assets", "temp_input.wav")``, or
        ``None`` when ``file`` is ``None``.
    """
    if file is None:
        # Nothing to persist; previously this crashed on ``None.read()``.
        return None
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs("assets", exist_ok=True)
    file_path = os.path.join("assets", "temp_input.wav")
    with open(file_path, "wb") as f:
        f.write(file.read())
    return file_path
def textToSpeechUnifiedAzure(text, voice, rate, volume):
    """Synchronously synthesize ``text`` with the Azure/Edge TTS backend.

    Builds the ``textToSpeechAsync`` coroutine and drives it to completion
    with ``asyncio.run``. The coroutine's result is returned unchanged and the
    UI uses it as the audio component value.
    """
    synthesis = textToSpeechAsync(text, voice, rate, volume)
    return asyncio.run(synthesis)
def textToSpeechUnifiedAliyun(text, voice, rate, volume):
    """Synchronously synthesize ``text`` with the Aliyun TTS backend.

    Runs ``aliyun_text_to_speech`` via ``asyncio.run`` and returns element
    ``[1]`` of its result. Presumably this is the generated audio path used by
    the UI audio component (TODO confirm against aliyuntts).
    """
    outcome = asyncio.run(aliyun_text_to_speech(text, voice, rate, volume))
    return outcome[1]
def textToSpeechUnifiedCosy(text, sft_name="康辉", speaker_name="", prompt_text="", prompt_speech="",
                            url="http://localhost:8000/inference/tts", timeout=300):
    """Synthesize ``text`` through a CosyVoice HTTP service and save the audio.

    Args:
        text: Text to synthesize (sent as ``query``).
        sft_name: Built-in speaker name sent as ``sft_name``.
        speaker_name: Sent verbatim as ``speaker_name``.
        prompt_text: Sent verbatim as ``prompt_text``.
        prompt_speech: Sent verbatim as ``prompt_speech``.
        url: Endpoint of the CosyVoice service (defaults to the local deployment).
        timeout: Seconds ``requests.post`` waits for the service. The default
            is generous because synthesis can be slow. Pass ``None`` to wait
            indefinitely (the old behavior).

    Returns:
        ``os.path.join("assets", "cosy_tmp.wav")``, a path relative to the
        current working directory holding the response body. It is overwritten
        on every call.

    Raises:
        Exception: If the service answers with a non-200 status.
        requests.exceptions.RequestException: On connection failure or timeout.
    """
    headers = {"accept": "application/json", "Content-Type": "application/json"}
    data = {"query": text, "speaker_name": speaker_name, "sft_name": sft_name, "prompt_text": prompt_text, "prompt_speech": prompt_speech}
    # Without a timeout, requests blocks forever and a stalled service hangs the worker.
    response = requests.post(url, headers=headers, json=data, timeout=timeout)
    if response.status_code != 200:
        raise Exception(f"CosyVoice TTS request failed with status {response.status_code}")
    # The output directory may not exist yet in a fresh working directory.
    os.makedirs("assets", exist_ok=True)
    file_path = os.path.join("assets", "cosy_tmp.wav")
    with open(file_path, "wb") as f:
        f.write(response.content)
    return file_path
# Company logo displayed at the top of the UI, resolved next to this script.
_assets_dir = os.path.join(base_dir, "assets")
logo_path = os.path.join(_assets_dir, "company_logo_small_1.png")
# Theme instance passed to gr.Blocks in gradio_interface().
enhanced_seafoam = EnhancedSeafoam()
def get_version_details(path="version_history.md"):
    """Return the Markdown release notes displayed in the UI's version accordion.

    Args:
        path: Markdown file to read. The default is relative to the current
            working directory, as before.

    Returns:
        The file contents. If the file does not exist, a warning is printed and
        a short placeholder is returned, so that missing release notes do not
        abort building the whole interface.
    """
    try:
        with open(path, "r", encoding="utf-8") as markdown_file:
            return markdown_file.read()
    except FileNotFoundError:
        print(f"Warning: Version history not found: {path}")
        return "Version history is not available."
def gradio_interface():
    """Build and return the Gradio Blocks UI for the digital-human demo.

    The layout contains:
    - an avatar gallery beside the inference result video;
    - a driving-audio section: upload/record, bundled examples, or one of
      three TTS backends (CosyVoice, Aliyun, Azure);
    - collapsible inference and TTS parameters;
    - a run button that calls ``run_inference`` and shows the result video.

    Returns:
        The configured, not yet launched, ``gr.Blocks`` app.
    """
    with gr.Blocks(css="""footer {display: none !important;} #audio_button {color: black !important;} #output_video {color: blue !important;} """, theme=enhanced_seafoam) as app:
        gr.Image(logo_path, show_label=False, container=False)
        gr.Markdown('<h1 style="text-align: center;">真人复刻数字人 v1.2</h1>', elem_id="title")
        gr.Markdown('<p style="text-align: center;">Choose an avatar and configure parameters for generating animated models based on speech.</p>', elem_id="subtitle")
        with gr.Accordion("版本更新说明", open=False):
            gr.Markdown(get_version_details())
        # Own elem_id: "subtitle" is already used above and HTML ids must be unique.
        gr.Markdown('<p style="text-align: center;">本应用仅供测试体验,非法版权人像传播引起的一切后果请自行承担!</p>', elem_id="disclaimer")
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Choose Avatar")
                avatar_gallery = gr.Gallery(label="Avatar Preview", value=[(avatar['photo'], avatar['name']) for avatar in avatars])
            with gr.Column(scale=1):
                output_video = gr.Video(label="Inference Result", elem_id="output_video")
        gr.Markdown("### Input Audio驱动用的音频以下三种方式三选一")
        with gr.Row():
            # Shared sink: uploads, examples and all three TTS buttons write here,
            # and run_inference reads its value as the driving audio.
            audio_output = gr.Audio(label="Upload Audio上传或者麦克风录制语音", type="filepath")
            with gr.Column(scale=1):
                gr.Examples(examples=[[audio] for audio in audio_examples], inputs=[audio_output], label="Audio Examples", elem_id="audio_button")
            with gr.Column(scale=2):
                with gr.Tabs():
                    with gr.Tab("CosyVoice TTS声音复刻本地部署"):
                        with gr.Column():
                            cosy_input_text = gr.Textbox(label="Input Text for CosyVoice TTS")
                            cosy_speaker_name = gr.Dropdown(choices=["康辉", "rayray","其他可用说话人"], label="Speaker Name")
                            cosy_generate = gr.Button("点击语音合成")
                    with gr.Tab("Aliyun TTS推荐"):
                        with gr.Column():
                            aliyun_input_text = gr.Textbox(label="Input Text for Aliyun TTS")
                            aliyun_voice = gr.Dropdown(choices=list(ALIYUN_SUPPORTED_VOICES.keys()), label="Aliyun TTS Voice")
                            aliyun_generate = gr.Button("点击语音合成")
                    with gr.Tab("Azure TTS (海外不稳定)"):
                        with gr.Column():
                            azure_input_text = gr.Textbox(label="Input Text for Azure TTS")
                            azure_voice = gr.Dropdown(choices=list(AZURE_SUPPORTED_VOICES.keys()), label="Azure TTS Voice")
                            azure_generate = gr.Button("点击语音合成")
        # The dropdown value is passed positionally as sft_name.
        cosy_generate.click(textToSpeechUnifiedCosy, inputs=[cosy_input_text, cosy_speaker_name], outputs=audio_output)

        def handle_uploaded_audio(file):
            # With type="filepath" the value is normally a path string, or None
            # when the component is cleared. Pass both through unchanged; only
            # other values are treated as streams and saved to disk.
            if file is None or isinstance(file, str):
                return file
            return save_uploaded_audio(file)

        audio_output.change(fn=handle_uploaded_audio, inputs=[audio_output], outputs=[audio_output])
        with gr.Row():
            with gr.Accordion("Basic Parameters", open=False):
                # Defaults come from the first avatar; selecting a gallery item overwrites them.
                path_input = gr.Textbox(label="Data Path", value=avatars[0]["config"]["path"])
                workspace_input = gr.Textbox(label="Workspace Directory", value=avatars[0]["config"]["workspace"])
                seed_input = gr.Number(label="Random Seed", value=avatars[0]["config"]["seed"], precision=0)
                test_input = gr.Checkbox(label="Test Mode", value=True)
                test_train_input = gr.Checkbox(label="Test Train Mode", value=True)
            with gr.Accordion("TTS Advanced Settings语音TTS高级设置", open=False):
                # Shared by the Azure and Aliyun TTS buttons (not by CosyVoice).
                tts_rate = gr.Slider(-100, 100, value=0, label="Speech Rate")
                tts_volume = gr.Slider(-100, 100, value=0, label="Volume")
        run_button = gr.Button("生成")

        def select_avatar(evt: gr.SelectData):
            # Gradio injects the select event; its index picks the avatar config.
            return update_params(evt.index)

        avatar_gallery.select(fn=select_avatar, inputs=[], outputs=[path_input, workspace_input, seed_input, audio_output])
        azure_generate.click(
            textToSpeechUnifiedAzure,
            inputs=[azure_input_text, azure_voice, tts_rate, tts_volume],
            outputs=audio_output
        )
        aliyun_generate.click(
            textToSpeechUnifiedAliyun,
            inputs=[aliyun_input_text, aliyun_voice, tts_rate, tts_volume],
            outputs=audio_output
        )
        # Event binding: run inference
        run_button.click(
            fn=run_inference,
            inputs=[path_input, workspace_input, seed_input, test_input, test_train_input, audio_output],
            outputs=output_video
        )
        gr.HTML("""
<div align="center" style="text-align: center; font-size: 12px;">
<p style="display: inline; margin-right: 10px;">&copy; 2024 未来式智能</p>
<a href="https://agentspro.cn/" target="_blank" style="display: inline;">autoagents.ai</a>
</div>
""")
    return app
if __name__ == "__main__":
    # Start the Gradio service: build the UI, cap the request queue at 20
    # entries, and listen on all interfaces at port 8080. share=True requests
    # a public Gradio share link as well.
    app = gradio_interface()
    app.queue(max_size=20)
    app.launch(
        server_name="0.0.0.0",
        server_port=8080,
        share=True,
    )