digitalhumantalk/fastapi_server.py
2024-12-10 17:05:37 +08:00

384 lines
14 KiB
Python

from typing import Optional, Dict
from fastapi import FastAPI, Body, File, UploadFile, BackgroundTasks, HTTPException, Form, APIRouter
from fastapi.responses import JSONResponse
import asyncio
import os
import yaml
from thirdparty.edgetts.azuretts import textToSpeechAsync, SUPPORTED_VOICES as AZURE_SUPPORTED_VOICES
from thirdparty.aliyun.aliyuntts import aliyun_text_to_speech_stream, SUPPORTED_VOICES as ALIYUN_SUPPORTED_VOICES
from thirdparty.lingda.filestorage import upload_files
from pydantic import BaseModel, Field
import logging
from config import load_config, setup_logging
from pathlib import Path
from inferenceapi import setup_args, inference
import requests
from urllib.parse import urljoin
import json
# Load the default configuration and configure logging from it.
config = load_config("default_config.yaml")
setup_logging(config)
logger = logging.getLogger(__name__)
# Path prefix shared by every API route and by the generated documentation pages.
API_PREFIX = config['server']['api_prefix']
# Create the FastAPI application; Swagger UI, ReDoc and the OpenAPI schema are
# also served under API_PREFIX.
app = FastAPI(
    title="真人复刻数字人 API",  # API title ("real-person replica digital human API")
    description="API 文档",  # API description ("API documentation")
    version="1.0",  # API version
    docs_url=f"{API_PREFIX}/docs",
    redoc_url=f"{API_PREFIX}/redoc",
    openapi_url=f"{API_PREFIX}/openapi.json",
)
# Create an APIRouter that collects the endpoints defined in this module; it is
# mounted on `app` with API_PREFIX at the end of the module.
api_router = APIRouter()
def load_avatar_configs(config_file: str) -> Dict:
    """Parse the avatar YAML file and return the resulting document.

    The file is decoded as UTF-8 and parsed with ``yaml.safe_load``, which
    only builds plain Python types (no arbitrary object construction).

    :param config_file: Path of the YAML file to read.
    :return: The parsed document; callers in this module expect a mapping
        with an ``"avatars"`` key.
    """
    with open(config_file, encoding="utf-8") as stream:
        return yaml.safe_load(stream)
# Load the avatar configurations once, at import time. The handlers below
# iterate `avatars` as a sequence of mappings, each with a "name" key and an
# optional "config" mapping.
avatar_data = load_avatar_configs("avatar_configs.yaml")
avatars = avatar_data["avatars"]
class TTSRequest(BaseModel):
    """Request body for the text-to-speech endpoint.

    The defaults for ``voice``, ``rate`` and ``volume`` are read from the
    ``aliyuntts`` section of the configuration loaded at import time.
    """
    # Text to synthesize.
    text: str = Field(..., description="要转换为语音的文本", example="你好呀,我是未来式数字人.")
    # Voice name; the /text_to_speech/ handler checks it against the Azure and
    # Aliyun SUPPORTED_VOICES lists to choose the synthesis backend.
    voice: str = Field(config['aliyuntts']['default_voice'], description="使用的语音模型", example=config['aliyuntts']['default_voice'])
    # Speech rate; the description documents -100..100 but no range is enforced here.
    rate: int = Field(config['aliyuntts']['default_rate'], description="语速 (-100 到 100)", example=0)
    # Volume; the description documents -100..100 but no range is enforced here.
    volume: int = Field(config['aliyuntts']['default_volume'], description="音量 (-100 到 100)", example=0)
class VideoGenerationRequest(BaseModel):
    """Request body for the video generation endpoint.

    Optional fields left as ``None`` (or sent as empty strings) are ignored by
    the handler, which then falls back to the value in the matching avatar's
    ``config`` entry.
    """
    # Must equal the "name" of one entry in avatar_configs.yaml.
    avatar_name: str = Field(..., description="头像名称", example="康辉")
    # Audio to drive the video; the handler downloads it before inference.
    audio_url: str = Field(..., description="音频文件的 URL", example="https://test.agentspro.cn/api/fs/6747041e1f8ad06695190b0d.wav")
    # Passed to setup_args as "path". NOTE(review): described as an audio file
    # path, but the audio itself comes from audio_url ("aud") — confirm meaning.
    path: Optional[str] = Field(None, description="音频文件路径", example="")
    # Passed to setup_args as "workspace".
    workspace: Optional[str] = Field(None, description="工作区", example="")
    # Random seed; the handler uses 42 if neither request nor avatar config sets it.
    seed: Optional[int] = Field(None, description="随机种子", example=1234)
    # Test mode flag; the handler uses True if neither request nor avatar config sets it.
    test: Optional[bool] = Field(None, description="是否为测试模式", example=True)
    # Test-train mode flag; the handler uses True if neither request nor avatar config sets it.
    test_train: Optional[bool] = Field(None, description="是否为测试训练模式", example=True)
@api_router.get("/configs", responses={
    200: {
        "description": "Successful retrieval of avatar configurations",
        "content": {
            "application/json": {
                # The handler returns the avatar list itself (not wrapped in an
                # {"avatars": ...} object), so the documented example is a list.
                "example": [
                    {
                        "name": "avatar1",
                        "config": {
                            "some_param": "value"
                        }
                    },
                    {
                        "name": "avatar2",
                        "config": {
                            "some_param": "value"
                        }
                    }
                ]
            }
        }
    }
})
async def get_configs():
    """Return every avatar configuration loaded from ``avatar_configs.yaml``.

    Example::

        GET {API_PREFIX}/configs

    :return: The list stored under the ``avatars`` key of the avatar config
        file; each entry has a ``name`` and, optionally, a ``config`` mapping.
    """
    return avatars
@api_router.post("/text_to_speech/", responses={
    200: {
        "description": "Successful text-to-speech conversion",
        "content": {
            "application/json": {
                "example": {
                    "audio_url": "http://example.com/generated_audio.wav"
                }
            }
        }
    },
    400: {"description": "Bad Request, unsupported voice"},
    500: {"description": "Internal Server Error"}
})
async def text_to_speech(request: TTSRequest, background_tasks: BackgroundTasks):
    """Synthesize speech from text and return the URL of the uploaded audio.

    The voice selects the backend: voices in the Azure SUPPORTED_VOICES list
    go through ``textToSpeechAsync``, voices in the Aliyun list through
    ``aliyun_text_to_speech_stream``. The audio is written to a temporary
    file and uploaded; the temporary file is deleted in the background after
    the response has been sent.

    Example::

        POST {API_PREFIX}/text_to_speech/
        {
            "text": "你好呀,我是未来式数字人.",
            "voice": "知冰_多情感",
            "rate": 0,
            "volume": 0
        }

    :param request: Text, voice, rate and volume for the synthesis.
    :param background_tasks: Used to schedule removal of the temporary file.
    :return: JSON ``{"audio_url": ...}`` with the URL of the uploaded audio.
    :raises HTTPException: 400 if neither backend supports the voice,
        500 if synthesis, saving or uploading fails.
    """
    logger.info(f"Received TTS request. Params: {request}")
    # Choose the backend before the generic error handler below, so an unknown
    # voice is reported as a client error (400) rather than a server error (500).
    if request.voice in AZURE_SUPPORTED_VOICES:
        synthesize, provider = textToSpeechAsync, "Azure"
    elif request.voice in ALIYUN_SUPPORTED_VOICES:
        synthesize, provider = aliyun_text_to_speech_stream, "Aliyun"
    else:
        logger.warning(f"Unsupported voice requested: {request.voice}")
        raise HTTPException(status_code=400, detail=f"Unsupported voice: {request.voice}")

    temp_audio_path = None
    try:
        audio_data = await synthesize(request.text, request.voice, request.rate, request.volume)
        logger.info(f"Used {provider} TTS")
        # Save and upload the generated audio file
        temp_audio_path = save_audio(audio_data)
        audio_url = upload_file(temp_audio_path)
    except Exception as e:
        logger.error(f"Error in TTS: {str(e)}", exc_info=True)
        # The background cleanup is only scheduled on success, so remove the
        # temporary file here when the upload (or anything after saving) fails.
        if temp_audio_path is not None:
            try:
                os.remove(temp_audio_path)
            except OSError:
                pass  # best-effort cleanup; the original error is reported below
        raise HTTPException(status_code=500, detail="Internal Server Error") from e

    # Remove the temporary audio file once the response has been sent.
    background_tasks.add_task(os.remove, temp_audio_path)
    logger.info(f"Audio uploaded to: {audio_url}")
    return JSONResponse(content={"audio_url": audio_url})
def save_audio(audio_data) -> str:
    """Write synthesized audio to a uniquely named WAV file under ``output/tmp``.

    Every call gets its own file name. Overlapping requests therefore never
    overwrite each other's audio, and the background ``os.remove`` scheduled
    by one request cannot delete another request's file.

    :param audio_data: Raw audio as ``bytes``/``bytearray``, or a buffer
        object exposing ``seek()`` and ``getvalue()`` (e.g. ``io.BytesIO``).
    :return: Path of the written file, as a string.
    """
    import uuid

    save_dir = Path("output/tmp")
    save_dir.mkdir(parents=True, exist_ok=True)
    temp_audio_path = save_dir / f"output_tts_{uuid.uuid4().hex}.wav"
    # Extract the bytes before opening the file so that an unusable buffer
    # does not leave an empty file behind.
    if isinstance(audio_data, (bytes, bytearray)):
        payload = bytes(audio_data)
    else:
        audio_data.seek(0)
        payload = audio_data.getvalue()
    with open(temp_audio_path, "wb") as f:
        f.write(payload)
    logger.info(f"TTS audio generated: {temp_audio_path}")
    return str(temp_audio_path)
def upload_file(file_path: str) -> str:
    """Send a local file to remote storage and return the resulting URL.

    The whole file is read into memory and passed to ``upload_files`` under
    its base name; the first entry of the returned list is taken as the URL.

    :param file_path: Local path of the file to upload.
    :return: URL reported by the storage service for the uploaded file.
    """
    with open(file_path, "rb") as handle:
        contents = handle.read()
    payload = [{'name': os.path.basename(file_path), 'data': contents}]
    remote_url = upload_files(payload)[0]
    logger.info(f"Audio uploaded to: {remote_url}")
    return remote_url
@api_router.post("/cosyvoice/", responses={
    200: {
        "description": "Successful cosyvoice text-to-speech conversion",
        "content": {
            "application/json": {
                "example": {
                    "audio_url": "http://example.com/generated_audio.wav"
                }
            }
        }
    },
    500: {"description": "Internal Server Error"}
})
async def cosy_tts(
    background_tasks: BackgroundTasks,
    text: str = Form(..., example="Hello, this is a test."),
    speaker_name: str = Form(..., example="speaker1")
):
    """Synthesize speech with the CosyVoice service and return the uploaded audio URL.

    The text is POSTed as JSON to ``config['cosyvoice']['host']`` joined with
    ``config['cosyvoice']['endpoint']``; ``speaker_name`` is sent as the
    ``sft_name`` field. The response body is saved as a temporary audio file
    and uploaded. The temporary file is removed in the background after the
    response has been sent.

    Example (form fields)::

        POST {API_PREFIX}/cosyvoice/
        text="Hello, this is a test."
        speaker_name="speaker1"

    :param background_tasks: Used to schedule removal of the temporary file.
    :param text: Text to synthesize.
    :param speaker_name: Speaker to use, forwarded as ``sft_name``.
    :return: JSON ``{"audio_url": ...}`` with the URL of the uploaded audio.
    :raises HTTPException: with CosyVoice's status code if it answers with an
        HTTP error status, 500 for any other failure.
    """
    host = config['cosyvoice']['host']
    endpoint = config['cosyvoice']['endpoint']
    if not host.endswith('/'):
        host += '/'
    # urljoin treats an endpoint that starts with "/" as absolute and would drop
    # any path component of the host, so join the endpoint as a relative path.
    url = urljoin(host, endpoint.lstrip('/'))
    logger.info(f"CosyVoice TTS request: {url}, {text}, {speaker_name}")
    headers = {"accept": "application/json", "Content-Type": "application/json"}
    payload = json.dumps({"query": text, "speaker_name": "", "sft_name": speaker_name})

    try:
        # requests is blocking; run it in the default executor so a slow
        # synthesis does not stall the event loop for all other requests.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None, lambda: requests.post(url, headers=headers, data=payload)
        )
        response.raise_for_status()
    except requests.HTTPError as e:
        logger.error(f"CosyVoice TTS request failed with HTTP error: {e}", exc_info=True)
        upstream = e.response
        status = upstream.status_code if upstream is not None else 500
        raise HTTPException(status_code=status, detail="CosyVoice TTS service failed.") from e
    except Exception as e:
        logger.error(f"CosyVoice TTS request failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Internal Server Error") from e

    # Saving/uploading is handled separately so that an HTTPError raised by the
    # upload cannot be misreported with CosyVoice's (successful) status code.
    temp_audio_path = None
    try:
        temp_audio_path = save_audio(response.content)
        audio_url = upload_file(temp_audio_path)
    except Exception as e:
        logger.error(f"CosyVoice TTS request failed: {e}", exc_info=True)
        if temp_audio_path is not None:
            try:
                os.remove(temp_audio_path)
            except OSError:
                pass  # best-effort cleanup; the original error is reported below
        raise HTTPException(status_code=500, detail="Internal Server Error") from e

    # Remove the temporary audio file once the response has been sent.
    background_tasks.add_task(os.remove, temp_audio_path)
    logger.info(f"CosyVoice TTS audio generated: {audio_url}")
    return JSONResponse(content={"audio_url": audio_url})
@api_router.post("/generate_video", responses={
    200: {
        "description": "Successful video generation",
        "content": {
            "application/json": {
                "example": {
                    "video_url": "http://example.com/generated_video.mp4"
                }
            }
        }
    },
    400: {"description": "Bad Request, Avatar not found"},
    500: {"description": "Internal Server Error"}
})
async def run_inference(
    request_data: VideoGenerationRequest, background_tasks: BackgroundTasks
):
    """Generate a video for the requested avatar from the given audio and return its URL.

    Steps:

    1. Find the avatar whose ``name`` equals ``avatar_name`` (400 if none).
    2. Merge that avatar's ``config`` with the request. Request fields that are
       ``None`` or ``""`` are ignored; all others override the avatar's values.
    3. Download ``audio_url`` to a local file.
    4. Build the options with ``setup_args`` and run ``inference``.
    5. Upload the produced video; the local file is deleted in the background.

    Example::

        POST {API_PREFIX}/generate_video
        {
            "avatar_name": "avatar1",
            "audio_url": "https://test.agentspro.cn/api/fs/6747041e1f8ad06695190b0d.wav",
            "path": "output",
            "workspace": "workspace1",
            "seed": 1234,
            "test": false,
            "test_train": false
        }

    :param request_data: Avatar name, audio URL and optional overrides.
    :param background_tasks: Used to schedule removal of the local video file.
    :return: JSON ``{"video_url": ...}`` with the URL of the uploaded video.
    :raises HTTPException: 400 if the avatar is unknown, 500 on any other failure.
    """
    try:
        request_data_dict = request_data.dict()
        avatar_config = next(
            (avatar for avatar in avatars if avatar['name'] == request_data.avatar_name),
            None,
        )
        if avatar_config is None:
            raise HTTPException(status_code=400, detail="Avatar not found.")
        # Request values take priority; None/"" mean "not provided".
        overrides = {k: v for k, v in request_data_dict.items() if v is not None and v != ''}
        # Named infer_config (not config) so it does not shadow the module-level
        # server configuration; `or {}` tolerates an explicit `config: null`.
        infer_config = {**(avatar_config.get("config") or {}), **overrides}
        logger.info(f"Infer Config: {infer_config}")
        # NOTE(review): the download (blocking requests.get) and inference() both
        # run synchronously on the event-loop thread, so other requests wait while
        # a video is generated. Offloading would also need per-request temp paths.
        aud_file_path = await download_audio_from_url(infer_config["audio_url"])
        args_dict = {
            "path": infer_config["path"],
            "workspace": infer_config["workspace"],
            "seed": infer_config.get("seed", 42),
            "test": infer_config.get("test", True),
            "test_train": infer_config.get("test_train", True),
            "aud": aud_file_path,
        }
        opt = setup_args(None, args_dict)
        video_path = inference(opt)
        logger.info(f"Generated video: {video_path}")
        video_url = upload_file(video_path)
        # Remove the local video file once the response has been sent.
        background_tasks.add_task(os.remove, video_path)
        logger.info(f"Video uploaded to: {video_url}")
        return JSONResponse(content={"video_url": video_url})
    except HTTPException:
        # Deliberate client errors (e.g. unknown avatar) keep their status code
        # instead of being converted into a generic 500 below.
        raise
    except Exception as e:
        logger.error(f"Error in generate_video: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail="Internal Server Error") from e
async def download_audio_from_url(url: str, timeout: Optional[float] = None) -> str:
    """Download an audio file and store it at ``output/tmp/downloaded_audio.wav``.

    The destination path is fixed, so each call overwrites the previous
    download. The HTTP request uses blocking ``requests`` even though this is
    a coroutine.

    :param url: URL of the audio file.
    :param timeout: Optional ``requests`` timeout in seconds; ``None`` (the
        default) waits indefinitely, as before.
    :return: Local path of the saved file.
    :raises requests.HTTPError: if the server answers with an error status,
        so an error page is never saved and processed as audio.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    # The directory may not exist yet if save_audio has never run in this process.
    Path("output/tmp").mkdir(parents=True, exist_ok=True)
    file_path = "output/tmp/downloaded_audio.wav"
    with open(file_path, "wb") as f:
        f.write(response.content)
    return file_path
async def save_uploaded_audio(file: UploadFile) -> str:
    """Persist an uploaded audio file as ``assets/temp_input.wav``.

    The ``assets`` directory is created on first use. The destination name is
    fixed, so every call overwrites the previous upload.

    :param file: Uploaded audio file.
    :return: Path of the saved file.
    """
    assets_dir = Path("assets")
    if not assets_dir.exists():
        assets_dir.mkdir(parents=True)
    destination = str(assets_dir / "temp_input.wav")
    with open(destination, "wb") as sink:
        sink.write(await file.read())
    logger.info(f"Uploaded audio saved to: {destination}")
    return destination
# Mount every route registered on api_router under the configured API prefix.
app.include_router(api_router, prefix=API_PREFIX)
if __name__ == "__main__":
    # Run the app directly with uvicorn using the server settings from the config.
    import uvicorn

    server_settings = config['server']
    uvicorn.run(app, host=server_settings['host'], port=server_settings['port'])