# Gradio web UI for the avatar inference demo (source listing: 192 lines, 9.3 KiB, Python).
import asyncio
import os
import shutil

import gradio as gr
import requests
import yaml

from inferenceapi import setup_args, inference
from thirdparty.edgetts.azuretts import textToSpeechAsync, SUPPORTED_VOICES as AZURE_SUPPORTED_VOICES
from thirdparty.aliyun.aliyuntts import aliyun_text_to_speech, SUPPORTED_VOICES as ALIYUN_SUPPORTED_VOICES
from thirdparty.ui.gradiotheme import EnhancedSeafoam
|
||
|
||
# Install the PyYAML library:
# pip install pyyaml
|
||
|
||
# Load the avatar configuration file.
def load_avatar_configs(config_file="avatar_configs.yaml"):
    """Parse the avatar configuration YAML file and return its contents.

    Args:
        config_file: Path of the YAML file to read; it is opened as UTF-8 text.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file's contents.

    Raises:
        FileNotFoundError: If ``config_file`` does not exist.
    """
    with open(config_file, mode="r", encoding="utf-8") as stream:
        return yaml.safe_load(stream)
|
||
|
||
# Parsed avatar configuration; `avatars` is the list stored under its "avatars" key.
avatar_data = load_avatar_configs()
avatars = avatar_data["avatars"]

# Global model cache dictionary.
models = {}

# De-duplicated list of every avatar's audio files, for dropdown menus and upload components.
# dict.fromkeys keeps first-seen order, so the list (and the all_audios[0] fallback used by
# update_params) is identical on every run; set() ordering varies with string hash
# randomisation. `or ()` tolerates an avatar whose "audios" entry is empty or null.
all_audios = list(dict.fromkeys(
    aud
    for avatar in avatars
    for aud in (avatar["config"]["audios"] or ())
))
|
||
|
||
def run_inference(path, workspace, seed, test, test_train, aud):
    """Build inference options from the UI values, run inference and return the video path.

    Args:
        path: Dataset path, forwarded to ``setup_args`` under the key "path".
        workspace: Workspace directory, forwarded under the key "workspace".
        seed: Random seed, forwarded under the key "seed".
        test: Test-mode flag, forwarded under the key "test".
        test_train: Test-train-mode flag, forwarded under the key "test_train".
        aud: Driving audio, forwarded under the key "aud".

    Returns:
        The value returned by ``inference(opt)``.

    Raises:
        FileNotFoundError: If the resolved workspace or dataset path does not exist.
    """
    opt = setup_args(None, {
        "path": path,
        "workspace": workspace,
        "seed": seed,
        "test": test,
        "test_train": test_train,
        "aud": aud,
    })

    # Explicit checks instead of `assert`: assertions are removed under `python -O`,
    # which would let inference start on non-existent directories.
    if not os.path.exists(opt.workspace):
        raise FileNotFoundError("Workspace directory does not exist.")
    if not os.path.exists(opt.path):
        raise FileNotFoundError("Dataset path does not exist.")

    return inference(opt)
|
||
|
||
def update_params(avatar_index):
    """Return the UI parameter values for the avatar at ``avatar_index``.

    Args:
        avatar_index: Index into the module-level ``avatars`` list.

    Returns:
        A 4-tuple ``(path, workspace, seed, default_audio)`` taken from the avatar's
        "config" mapping. ``default_audio`` is the avatar's first audio if it has any,
        otherwise the first entry of ``all_audios``, otherwise ``None``.
    """
    config = avatars[avatar_index]["config"]
    own_audios = config["audios"]
    if own_audios:
        default_audio = own_audios[0]
    elif all_audios:
        default_audio = all_audios[0]
    else:
        # No audio is configured anywhere: indexing all_audios[0] would raise IndexError.
        default_audio = None
    return config["path"], config["workspace"], config["seed"], default_audio
|
||
|
||
def find_examples(folder_path, extensions):
    """List the files in ``folder_path`` whose names end with one of ``extensions``.

    Args:
        folder_path: Directory to scan (not recursive); made absolute before use.
        extensions: A suffix string or a tuple of suffix strings such as
            ``('.mp3', '.wav')``. Matching ignores case.

    Returns:
        Absolute paths of the matching regular files, sorted by file name.
        An empty list if the folder does not exist.
    """
    full_path = os.path.abspath(folder_path)
    # This function is called at import time; a missing folder must not crash the module.
    if not os.path.isdir(full_path):
        print(f"Warning: Examples folder not found: {full_path}")
        return []

    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    examples = []
    # os.listdir order is arbitrary; sort so the result is stable between runs.
    for name in sorted(os.listdir(full_path)):
        if not name.lower().endswith(suffixes):
            continue
        file_path = os.path.join(full_path, name)
        if os.path.isfile(file_path):
            examples.append(file_path)
        elif not os.path.exists(file_path):
            # e.g. a dangling symlink, or the entry vanished after listing.
            print(f"Warning: File not found: {file_path}")
    return examples
|
||
|
||
# Directory containing this script; bundled assets are resolved relative to it.
base_dir = os.path.dirname(os.path.abspath(__file__))
# The empty last component makes the joined path end with a path separator.
audio_folder = os.path.join(base_dir, "assets", "")
# Example audio clips found in the assets folder.
audio_examples = find_examples(
    folder_path=audio_folder,
    extensions=('.mp3', '.wav', '.ogg'),
)
|
||
|
||
def save_uploaded_audio(file):
    """Write an uploaded audio stream to ``assets/temp_input.wav``.

    Args:
        file: Binary file-like object exposing ``read()``.

    Returns:
        The path of the written file, relative to the current working directory.
        Any previous ``temp_input.wav`` is overwritten.
    """
    # exist_ok avoids the check-then-create race of `if not exists: makedirs`.
    os.makedirs("assets", exist_ok=True)
    file_path = os.path.join("assets", "temp_input.wav")
    with open(file_path, "wb") as f:
        # Copy in chunks rather than loading the whole upload into memory.
        shutil.copyfileobj(file, f)
    return file_path
|
||
|
||
def textToSpeechUnifiedAzure(text, voice, rate, volume):
    """Run ``textToSpeechAsync`` to completion and return its result.

    The four arguments are forwarded unchanged; ``asyncio.run`` drives the
    awaitable so this can be used as a plain synchronous callback.
    """
    synthesis = textToSpeechAsync(text, voice, rate, volume)
    return asyncio.run(synthesis)
|
||
|
||
def textToSpeechUnifiedAliyun(text, voice, rate, volume):
    """Run ``aliyun_text_to_speech`` to completion and return item 1 of its result.

    The four arguments are forwarded unchanged. Only the element at index 1 of
    the awaited result is returned (presumably the generated audio — confirm
    against ``aliyun_text_to_speech``).
    """
    outcome = asyncio.run(aliyun_text_to_speech(text, voice, rate, volume))
    return outcome[1]
|
||
|
||
def textToSpeechUnifiedCosy(text, sft_name="康辉", speaker_name="", prompt_text="", prompt_speech="",
                            url="http://localhost:8000/inference/tts", timeout=(5, 300)):
    """Request speech from the CosyVoice TTS HTTP service and save it to a WAV file.

    Args:
        text: Text to synthesise; sent as the "query" field.
        sft_name: Sent as the "sft_name" field.
        speaker_name: Sent as the "speaker_name" field.
        prompt_text: Sent as the "prompt_text" field.
        prompt_speech: Sent as the "prompt_speech" field.
        url: Endpoint to POST to; defaults to the local service address.
        timeout: ``requests`` timeout as ``(connect, read)`` seconds. Without a
            timeout an unresponsive service would block the caller indefinitely.

    Returns:
        The path ``assets/cosy_tmp.wav`` containing the response body.

    Raises:
        Exception: If the service answers with a status other than 200.
        requests.exceptions.RequestException: On connection failure or timeout.
    """
    headers = {"accept": "application/json", "Content-Type": "application/json"}
    data = {"query": text, "speaker_name": speaker_name, "sft_name": sft_name, "prompt_text": prompt_text, "prompt_speech": prompt_speech}
    response = requests.post(url, headers=headers, json=data, timeout=timeout)

    if response.status_code != 200:
        raise Exception(f"CosyVoice TTS request failed with status {response.status_code}")

    # The output directory may not exist yet on a fresh checkout.
    os.makedirs("assets", exist_ok=True)
    file_path = "assets/cosy_tmp.wav"
    with open(file_path, "wb") as f:
        f.write(response.content)
    return file_path
|
||
|
||
# Logo image located in the assets folder next to this script.
logo_path = os.path.join(
    base_dir,
    "assets",
    "company_logo_small_1.png",
)
# Theme instance handed to gr.Blocks when the interface is built.
enhanced_seafoam = EnhancedSeafoam()
|
||
def get_version_details(path="version_history.md"):
    """Return the text of the version-history Markdown file.

    Args:
        path: File to read as UTF-8; defaults to ``version_history.md`` in the
            current working directory.

    Returns:
        The file's contents, or an empty string if the file does not exist.
        A missing changelog should not prevent the interface from being built.
    """
    try:
        with open(path, "r", encoding="utf-8") as markdown_file:
            return markdown_file.read()
    except FileNotFoundError:
        print(f"Warning: File not found: {path}")
        return ""
|
||
|
||
def gradio_interface():
    """Build the Gradio Blocks application and return it (without launching it).

    The page contains: a logo and headings, a collapsible version-history
    section, an avatar gallery next to the output video, an audio input with
    example clips and three text-to-speech tabs (CosyVoice, Aliyun, Azure),
    collapsible parameter panels, and a button that runs inference.

    Returns:
        The constructed ``gr.Blocks`` instance.
    """
    with gr.Blocks(css="""footer {display: none !important;} #audio_button {color: black !important;} #output_video {color: blue !important;} """, theme=enhanced_seafoam) as app:
        gr.Image(logo_path, show_label=False, container=False)
        gr.Markdown('<h1 style="text-align: center;">真人复刻数字人 v1.2</h1>', elem_id="title")
        gr.Markdown('<p style="text-align: center;">Choose an avatar and configure parameters for generating animated models based on speech.</p>', elem_id="subtitle")
        with gr.Accordion("版本更新说明", open=False):
            gr.Markdown(get_version_details())
        # NOTE(review): elem_id "subtitle" is reused from the Markdown above — confirm intended.
        gr.Markdown('<p style="text-align: center;">本应用仅供测试体验,非法版权人像传播引起的一切后果请自行承担!</p>', elem_id="subtitle")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Choose Avatar")
                avatar_gallery = gr.Gallery(label="Avatar Preview", value=[(avatar['photo'], avatar['name']) for avatar in avatars])
            with gr.Column(scale=1):
                output_video = gr.Video(label="Inference Result", elem_id="output_video")

        gr.Markdown("### Input Audio驱动用的音频,以下三种方式三选一")
        with gr.Row():
            audio_output = gr.Audio(label="Upload Audio上传或者麦克风录制语音", type="filepath")
            with gr.Column(scale=1):
                gr.Examples(examples=[[audio] for audio in audio_examples], inputs=[audio_output], label="Audio Examples", elem_id="audio_button")
            with gr.Column(scale=2):
                with gr.Tabs():
                    with gr.Tab("CosyVoice TTS(声音复刻,本地部署)"):
                        with gr.Column():
                            cosy_input_text = gr.Textbox(label="Input Text for CosyVoice TTS")
                            cosy_speaker_name = gr.Dropdown(choices=["康辉", "rayray","其他可用说话人"], label="Speaker Name")
                            cosy_generate = gr.Button("点击语音合成")
                    with gr.Tab("Aliyun TTS(推荐)"):
                        with gr.Column():
                            aliyun_input_text = gr.Textbox(label="Input Text for Aliyun TTS")
                            aliyun_voice = gr.Dropdown(choices=list(ALIYUN_SUPPORTED_VOICES.keys()), label="Aliyun TTS Voice")
                            aliyun_generate = gr.Button("点击语音合成")
                    with gr.Tab("Azure TTS (海外不稳定)"):
                        with gr.Column():
                            azure_input_text = gr.Textbox(label="Input Text for Azure TTS")
                            azure_voice = gr.Dropdown(choices=list(AZURE_SUPPORTED_VOICES.keys()), label="Azure TTS Voice")
                            azure_generate = gr.Button("点击语音合成")

        # The second positional input (the speaker dropdown) lands on the sft_name parameter.
        cosy_generate.click(textToSpeechUnifiedCosy, inputs=[cosy_input_text, cosy_speaker_name], outputs=audio_output)

        def handle_uploaded_audio(file):
            """Normalise the audio component's value to a file path (or None)."""
            # An empty/cleared audio component yields None; passing that on to
            # save_uploaded_audio would fail on `None.read()`. Paths pass through as-is.
            if file is None or isinstance(file, str):
                return file
            return save_uploaded_audio(file)

        audio_output.change(fn=handle_uploaded_audio, inputs=[audio_output], outputs=[audio_output])

        with gr.Row():
            with gr.Accordion("Basic Parameters", open=False):
                path_input = gr.Textbox(label="Data Path", value=avatars[0]["config"]["path"])
                workspace_input = gr.Textbox(label="Workspace Directory", value=avatars[0]["config"]["workspace"])
                seed_input = gr.Number(label="Random Seed", value=avatars[0]["config"]["seed"], precision=0)
                test_input = gr.Checkbox(label="Test Mode", value=True)
                test_train_input = gr.Checkbox(label="Test Train Mode", value=True)
            with gr.Accordion("TTS Advanced Settings语音TTS高级设置", open=False):
                tts_rate = gr.Slider(-100, 100, value=0, label="Speech Rate")
                tts_volume = gr.Slider(-100, 100, value=0, label="Volume")

        run_button = gr.Button("生成")

        def select_avatar(evt: gr.SelectData):
            """Fill the parameter fields from the avatar selected in the gallery."""
            return update_params(evt.index)

        avatar_gallery.select(fn=select_avatar, inputs=[], outputs=[path_input, workspace_input, seed_input, audio_output])

        azure_generate.click(
            textToSpeechUnifiedAzure,
            inputs=[azure_input_text, azure_voice, tts_rate, tts_volume],
            outputs=audio_output
        )

        aliyun_generate.click(
            textToSpeechUnifiedAliyun,
            inputs=[aliyun_input_text, aliyun_voice, tts_rate, tts_volume],
            outputs=audio_output
        )

        # Event binding: run inference.
        run_button.click(
            fn=run_inference,
            inputs=[path_input, workspace_input, seed_input, test_input, test_train_input, audio_output],
            outputs=output_video
        )

        gr.HTML("""
        <div align="center" style="text-align: center; font-size: 12px;">
        <p style="display: inline; margin-right: 10px;">© 2024 未来式智能</p>
        <a href="https://agentspro.cn/" target="_blank" style="display: inline;">autoagents.ai</a>
        </div>
        """)
    return app
|
||
|
||
if __name__ == "__main__":
    # Start the Gradio service.
    app = gradio_interface()
    app.queue(max_size=20)
    # NOTE(review): share=True also requests a public Gradio share link in addition
    # to binding on all interfaces — confirm this exposure is intended.
    app.launch(
        server_name="0.0.0.0",
        server_port=8080,
        share=True,
    )