384 lines
14 KiB
Python
384 lines
14 KiB
Python
from typing import Optional, Dict
|
|
from fastapi import FastAPI, Body, File, UploadFile, BackgroundTasks, HTTPException, Form, APIRouter
|
|
from fastapi.responses import JSONResponse
|
|
import asyncio
|
|
import os
|
|
import yaml
|
|
from thirdparty.edgetts.azuretts import textToSpeechAsync, SUPPORTED_VOICES as AZURE_SUPPORTED_VOICES
|
|
from thirdparty.aliyun.aliyuntts import aliyun_text_to_speech_stream, SUPPORTED_VOICES as ALIYUN_SUPPORTED_VOICES
|
|
from thirdparty.lingda.filestorage import upload_files
|
|
from pydantic import BaseModel, Field
|
|
import logging
|
|
from config import load_config, setup_logging
|
|
from pathlib import Path
|
|
from inferenceapi import setup_args, inference
|
|
import requests
|
|
from urllib.parse import urljoin
|
|
import json
|
|
|
|
# Read the default configuration file and configure logging from it.
config = load_config("default_config.yaml")
setup_logging(config)
logger = logging.getLogger(__name__)

# Path prefix shared by the documentation URLs below and by the API router.
API_PREFIX = config['server']['api_prefix']

# FastAPI application; the Swagger UI, ReDoc and OpenAPI schema are all
# served under API_PREFIX.
app = FastAPI(
    title="真人复刻数字人 API",
    description="API 文档",
    version="1.0",
    docs_url=f"{API_PREFIX}/docs",
    redoc_url=f"{API_PREFIX}/redoc",
    openapi_url=f"{API_PREFIX}/openapi.json",
)

# Router on which the endpoints of this module are registered.
api_router = APIRouter()
|
|
|
|
def load_avatar_configs(config_file: str) -> Dict:
    """
    Parse a YAML avatar configuration file.

    :param config_file: Path to the YAML file to read (opened as UTF-8).
    :return: The parsed document, exactly as produced by ``yaml.safe_load``.
    """
    with open(config_file, "r", encoding="utf-8") as stream:
        return yaml.safe_load(stream)
|
|
|
|
# Parse the avatar configuration file once at import time and keep the
# entries stored under its "avatars" key.
avatar_data = load_avatar_configs("avatar_configs.yaml")
avatars = avatar_data["avatars"]
|
|
|
|
# Request body accepted by the text-to-speech endpoint.
# The voice, rate and volume defaults are read from the "aliyuntts" section
# of the loaded configuration.
class TTSRequest(BaseModel):
    text: str = Field(
        ...,
        description="要转换为语音的文本",
        example="你好呀,我是未来式数字人.",
    )
    voice: str = Field(
        config['aliyuntts']['default_voice'],
        description="使用的语音模型",
        example=config['aliyuntts']['default_voice'],
    )
    rate: int = Field(
        config['aliyuntts']['default_rate'],
        description="语速 (-100 到 100)",
        example=0,
    )
    volume: int = Field(
        config['aliyuntts']['default_volume'],
        description="音量 (-100 到 100)",
        example=0,
    )
|
|
|
|
# Request body accepted by the video generation endpoint.
# Only avatar_name and audio_url are required; every other field defaults to
# None.
class VideoGenerationRequest(BaseModel):
    avatar_name: str = Field(
        ...,
        description="头像名称",
        example="康辉",
    )
    audio_url: str = Field(
        ...,
        description="音频文件的 URL",
        example="https://test.agentspro.cn/api/fs/6747041e1f8ad06695190b0d.wav",
    )
    path: Optional[str] = Field(
        None,
        description="音频文件路径",
        example="",
    )
    workspace: Optional[str] = Field(
        None,
        description="工作区",
        example="",
    )
    seed: Optional[int] = Field(
        None,
        description="随机种子",
        example=1234,
    )
    test: Optional[bool] = Field(
        None,
        description="是否为测试模式",
        example=True,
    )
    test_train: Optional[bool] = Field(
        None,
        description="是否为测试训练模式",
        example=True,
    )
|
|
|
|
@api_router.get("/configs", responses={
    200: {
        "description": "Successful retrieval of avatar configurations",
        "content": {
            "application/json": {
                # The handler returns the avatar entries themselves, not an
                # object wrapping them under an "avatars" key.
                "example": [
                    {
                        "name": "avatar1",
                        "config": {
                            "some_param": "value"
                        }
                    },
                    {
                        "name": "avatar2",
                        "config": {
                            "some_param": "value"
                        }
                    }
                ]
            }
        }
    }
})
async def get_configs():
    """
    Return every available avatar configuration.

    Example:
    ```
    GET /configs
    ```

    :return: The avatar entries loaded from the avatar configuration file,
        i.e. the module-level ``avatars`` value, unmodified.
    """
    return avatars
|
|
|
|
@api_router.post("/text_to_speech/", responses={
    200: {
        "description": "Successful text-to-speech conversion",
        "content": {
            "application/json": {
                "example": {
                    "audio_url": "http://example.com/generated_audio.wav"
                }
            }
        }
    },
    400: {"description": "Bad Request, unsupported voice"},
    500: {"description": "Internal Server Error"}
})
async def text_to_speech(request: TTSRequest, background_tasks: BackgroundTasks):
    """
    Convert text to speech and return the URL of the uploaded audio.

    The requested voice selects the backend: voices in
    ``AZURE_SUPPORTED_VOICES`` are synthesised with ``textToSpeechAsync`` and
    voices in ``ALIYUN_SUPPORTED_VOICES`` with ``aliyun_text_to_speech_stream``.

    Example:
    ```
    POST /text_to_speech/
    {
        "text": "你好呀,我是未来式数字人.",
        "voice": "知冰_多情感",
        "rate": 0,
        "volume": 0
    }
    ```

    :param request: TTSRequest carrying the text, voice, rate and volume.
    :param background_tasks: BackgroundTasks used to remove the temporary
        audio file after the response has been sent.
    :return: JSON response of the form ``{"audio_url": <url>}``.
    :raises HTTPException: 400 when the voice is in neither supported list;
        500 when synthesis, saving or uploading fails.
    """
    logger.info(f"Received TTS request. Params: {request}")

    # Pick the backend up front. An unknown voice is a client error, so it is
    # rejected with 400 here instead of being converted to a 500 by the
    # generic handler below.
    if request.voice in AZURE_SUPPORTED_VOICES:
        synthesize, provider = textToSpeechAsync, "Azure"
    elif request.voice in ALIYUN_SUPPORTED_VOICES:
        synthesize, provider = aliyun_text_to_speech_stream, "Aliyun"
    else:
        logger.warning(f"Unsupported voice requested: {request.voice}")
        raise HTTPException(status_code=400, detail="Unsupported voice")

    try:
        audio_data = await synthesize(request.text, request.voice, request.rate, request.volume)
        logger.info(f"Used {provider} TTS")

        # Save and upload the generated audio file
        temp_audio_path = save_audio(audio_data)
        try:
            audio_url = upload_file(temp_audio_path)
        except Exception:
            # The background removal below is only scheduled on success, so
            # delete the temporary file here when the upload fails.
            if os.path.exists(temp_audio_path):
                os.remove(temp_audio_path)
            raise

        # Schedule background task to remove the temporary audio file
        background_tasks.add_task(os.remove, temp_audio_path)

        logger.info(f"Audio uploaded to: {audio_url}")
        return JSONResponse(content={"audio_url": audio_url})

    except Exception as e:
        logger.error(f"Error in TTS: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail="Internal Server Error") from e
|
|
|
|
def save_audio(audio_data) -> str:
    """
    Save audio data to a uniquely named temporary file under ``output/tmp``.

    Each call writes to its own file so that concurrent requests cannot
    overwrite or delete each other's audio. The caller is responsible for
    removing the returned file once it is no longer needed.

    :param audio_data: Raw audio as a bytes-like object, or a file-like
        object exposing ``seek`` and either ``getvalue`` or ``read``.
    :return: Path to the saved audio file
    """
    # Imported locally so this function does not depend on the module's
    # top-of-file import list.
    import uuid

    if isinstance(audio_data, (bytes, bytearray, memoryview)):
        payload = audio_data
    else:
        audio_data.seek(0)
        getvalue = getattr(audio_data, "getvalue", None)
        payload = getvalue() if getvalue is not None else audio_data.read()

    save_dir = Path("output/tmp")
    save_dir.mkdir(exist_ok=True, parents=True)
    temp_audio_path = save_dir / f"output_tts_{uuid.uuid4().hex}.wav"

    # The payload is resolved before the file is opened so that invalid input
    # does not leave an empty file behind.
    with open(temp_audio_path, "wb") as f:
        f.write(payload)

    logger.info(f"TTS audio generated: {temp_audio_path}")
    return str(temp_audio_path)
|
|
|
|
def upload_file(file_path: str) -> str:
    """
    Upload a local file through ``upload_files`` and return its URL.

    The file is read fully into memory and sent under its base name.

    :param file_path: Path to the file to be uploaded
    :return: First entry of the ``upload_files`` result (presumably the URL
        of the uploaded file; confirm against ``upload_files``).
    :raises RuntimeError: If ``upload_files`` returns an empty result.
    """
    with open(file_path, "rb") as source:
        files = [{'name': os.path.basename(file_path), 'data': source.read()}]

    uploaded_files = upload_files(files)
    if not uploaded_files:
        # Fail with a clear message instead of an IndexError/TypeError below.
        raise RuntimeError(f"Upload of {file_path} returned no URL")

    file_url = uploaded_files[0]
    # This helper is used for audio and video alike, so the message is
    # deliberately generic.
    logger.info(f"File uploaded to: {file_url}")
    return file_url
|
|
@api_router.post("/cosyvoice/", responses={
    200: {
        "description": "Successful cosyvoice text-to-speech conversion",
        "content": {
            "application/json": {
                "example": {
                    "audio_url": "http://example.com/generated_audio.wav"
                }
            }
        }
    },
    500: {"description": "Internal Server Error"}
})
async def cosy_tts(
    background_tasks: BackgroundTasks,
    text: str = Form(..., example="Hello, this is a test."),
    speaker_name: str = Form(..., example="speaker1")
):
    """
    Synthesise speech with the CosyVoice TTS service.

    Example:
    ```
    POST /cosyvoice/
    text="Hello, this is a test."
    speaker_name="speaker1"
    ```

    :param background_tasks: BackgroundTasks used to remove the temporary
        audio file after the response has been sent.
    :param text: Text to convert.
    :param speaker_name: Speaker to use; forwarded to the service as
        ``sft_name``.
    :return: JSON response of the form ``{"audio_url": <url>}``.
    :raises HTTPException: With the upstream status code when the CosyVoice
        service answers with an HTTP error; 500 for any other failure.
    """
    # Extract host and endpoint from config
    cosy_config = config['cosyvoice']
    host = cosy_config['host']
    endpoint = cosy_config['endpoint']
    # Optional "timeout" key (seconds); without a timeout a hung service would
    # block this request forever. Assumes the config section is a dict.
    timeout = cosy_config.get('timeout', 120)

    # Ensure the host has a trailing slash so urljoin keeps its last segment
    if not host.endswith('/'):
        host += '/'
    url = urljoin(host, endpoint)
    logger.info(f"CosyVoice TTS request: {url}, {text}, {speaker_name}")

    headers = {"accept": "application/json", "Content-Type": "application/json"}
    # The form's speaker name is sent as "sft_name"; "speaker_name" is sent empty.
    data = {"query": text, "speaker_name": "", "sft_name": speaker_name}
    payload = json.dumps(data)

    try:
        # Only the call to the TTS service is covered by the HTTPError
        # handler, so an upload failure is never reported with the TTS
        # response's status code.
        # NOTE(review): requests.post is blocking and runs on the event loop.
        try:
            response = requests.post(url, headers=headers, data=payload, timeout=timeout)
            response.raise_for_status()
        except requests.HTTPError as e:
            logger.error(f"CosyVoice TTS request failed with HTTP error: {e}", exc_info=True)
            status_code = e.response.status_code if e.response is not None else 502
            raise HTTPException(status_code=status_code, detail="CosyVoice TTS service failed.") from e

        temp_audio_path = save_audio(response.content)
        try:
            audio_url = upload_file(temp_audio_path)
        except Exception:
            # The background removal below is only scheduled on success, so
            # delete the temporary file here when the upload fails.
            if os.path.exists(temp_audio_path):
                os.remove(temp_audio_path)
            raise

        # Schedule background task to remove the temporary audio file
        background_tasks.add_task(os.remove, temp_audio_path)

        logger.info(f"CosyVoice TTS audio generated: {audio_url}")
        return JSONResponse(content={"audio_url": audio_url})

    except HTTPException:
        # Already a well-formed API error; do not rewrap it as a 500.
        raise
    except Exception as e:
        logger.error(f"CosyVoice TTS request failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Internal Server Error") from e
|
|
|
|
@api_router.post("/generate_video", responses={
    200: {
        "description": "Successful video generation",
        "content": {
            "application/json": {
                "example": {
                    "video_url": "http://example.com/generated_video.mp4"
                }
            }
        }
    },
    400: {"description": "Bad Request, Avatar not found"},
    500: {"description": "Internal Server Error"}
})
async def run_inference(
    request_data: VideoGenerationRequest, background_tasks: BackgroundTasks
):
    """
    Generate a video from the given parameters and return its URL.

    The named avatar's stored ``config`` is merged with the request fields;
    request fields that are neither ``None`` nor empty strings take priority.

    Example:
    ```
    POST /generate_video
    {
        "avatar_name": "avatar1",
        "audio_url": "https://test.agentspro.cn/api/fs/6747041e1f8ad06695190b0d.wav",
        "path": "output",
        "workspace": "workspace1",
        "seed": 1234,
        "test": false,
        "test_train": false
    }
    ```

    :param request_data: VideoGenerationRequest describing the video to
        generate.
    :param background_tasks: BackgroundTasks used to remove the generated
        video file after the response has been sent.
    :return: JSON response of the form ``{"video_url": <url>}``.
    :raises HTTPException: 400 when no avatar has the requested name; 500
        for any other failure.
    """
    try:
        # Convert request data to dictionary
        request_data_dict = request_data.dict()

        # Find the avatar configuration
        avatar_config = next(
            (avatar for avatar in avatars if avatar['name'] == request_data.avatar_name),
            None,
        )
        if not avatar_config:
            raise HTTPException(status_code=400, detail="Avatar not found.")

        # Merge the avatar's configuration with the request data, giving
        # priority to request data. The name differs from the module-level
        # ``config`` on purpose, to avoid shadowing it.
        overrides = {k: v for k, v in request_data_dict.items() if v is not None and v != ''}
        infer_config = {**avatar_config.get("config", {}), **overrides}
        logger.info(f"Infer Config: {infer_config}")

        # Download the audio file
        aud_file_path = await download_audio_from_url(infer_config["audio_url"])

        args_dict = {
            "path": infer_config["path"],
            "workspace": infer_config["workspace"],
            "seed": infer_config.get("seed", 42),
            "test": infer_config.get("test", True),
            "test_train": infer_config.get("test_train", True),
            "aud": aud_file_path,
        }

        # Setup arguments and run inference
        # NOTE(review): inference() is called synchronously inside this
        # coroutine; if it is long-running it blocks the event loop.
        opt = setup_args(None, args_dict)
        video_path = inference(opt)
        logger.info(f"Generated video: {video_path}")

        # Upload the generated video and return the URL
        video_url = upload_file(video_path)
        # Schedule background task to remove the temporary video file
        background_tasks.add_task(os.remove, video_path)

        logger.info(f"Video uploaded to: {video_url}")
        return JSONResponse(content={"video_url": video_url})

    except HTTPException:
        # Let deliberate API errors (the 400 above) through; without this the
        # generic handler below would turn them into a 500.
        raise
    except Exception as e:
        logger.error(f"Error in generate_video: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail="Internal Server Error") from e
|
|
|
|
async def download_audio_from_url(url: str, timeout: float = 60.0) -> str:
    """
    Download an audio file from a URL and save it locally.

    :param url: URL of the audio file
    :param timeout: Seconds to wait for the server before giving up.
    :return: Path to the saved audio file
    :raises requests.HTTPError: If the server answers with an error status.
    """
    # NOTE(review): requests.get is blocking even though this is a coroutine.
    response = requests.get(url, timeout=timeout)
    # Fail loudly rather than saving an error page as if it were audio.
    response.raise_for_status()

    # NOTE(review): the file name is fixed, so concurrent downloads overwrite
    # each other's file.
    file_path = "output/tmp/downloaded_audio.wav"
    # The directory is not guaranteed to exist before the first download.
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, "wb") as f:
        f.write(response.content)
    return file_path
|
|
|
|
async def save_uploaded_audio(file: UploadFile) -> str:
    """
    Save an uploaded audio file locally as ``assets/temp_input.wav``.

    :param file: Uploaded audio file
    :return: Path to the saved audio file
    """
    # exist_ok avoids the race between checking for the directory and
    # creating it.
    os.makedirs("assets", exist_ok=True)

    # NOTE(review): the file name is fixed, so concurrent uploads overwrite
    # each other's file.
    file_path = os.path.join("assets", "temp_input.wav")

    contents = await file.read()
    with open(file_path, "wb") as f:
        f.write(contents)

    logger.info(f"Uploaded audio saved to: {file_path}")
    return file_path
|
|
|
|
# Attach the router to the main application under the configured path prefix.
app.include_router(api_router, prefix=API_PREFIX)

# When executed as a script, serve the app with uvicorn on the configured
# host and port.
if __name__ == "__main__":
    import uvicorn

    server_settings = config['server']
    uvicorn.run(app, host=server_settings['host'], port=server_settings['port'])