digitalhumantalk/app.py
2024-12-10 17:05:37 +08:00

70 lines
2.4 KiB
Python

import asyncio
import json
import os
import shutil
import subprocess

from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile
from pydantic import BaseModel
# The ASGI application object; route handlers register on it through
# decorators such as @app.post, and uvicorn serves it in the __main__ guard.
app: FastAPI = FastAPI()
class InferenceRequest(BaseModel):
    """Parameters for one call to the /api/inference/ endpoint.

    Each field is forwarded to the audio-processing / inference scripts that
    the endpoint launches (data_utils/hubert.py and main.py).
    """

    # File name for the uploaded audio; the endpoint writes the upload to
    # data/<audio_file> and passes that path to hubert.py (--wav) and
    # main.py (--aud).
    audio_file: str
    # Identifier of the subject; main.py receives data/<id> as its first
    # positional argument. (Name shadows the builtin only inside the class
    # body; kept because it is part of the request schema.)
    id: str
    # Passed to main.py as --workspace; the endpoint looks for the result
    # video at <workspace>/output.mp4.
    workspace: str = "workspace"
    # Passed to main.py as --asr_model.
    asr_model: str = "deepspeech"
    # When true, the endpoint adds --portrait to the main.py command.
    portrait: bool = True
    # When true, the endpoint adds --fullbody to the main.py command.
    fullbody: bool = False
def _parse_inference_request(request: str = Form(...)) -> InferenceRequest:
    """Decode the JSON-encoded ``request`` form field into an InferenceRequest.

    The endpoint also receives a file upload, so the HTTP body is
    ``multipart/form-data``. HTTP cannot carry a JSON body in the same
    request, so a bare ``InferenceRequest`` parameter declared next to
    ``File(...)`` can never validate. Clients therefore send the parameters
    as a JSON object in a form field named ``request``.

    Raises:
        HTTPException: status 422 if the field is not valid JSON, is not a
            JSON object, or fails InferenceRequest validation.
    """
    try:
        payload = json.loads(request)
        return InferenceRequest(**payload)
    except (ValueError, TypeError) as e:
        # json.JSONDecodeError and pydantic's ValidationError both subclass
        # ValueError; TypeError covers a JSON value that is not an object.
        raise HTTPException(
            status_code=422, detail=f"Invalid 'request' form field: {e}"
        ) from e


def _resolve_audio_path(audio_file: str) -> str:
    """Return ``data/<audio_file>``, rejecting names that escape ``data/``.

    ``audio_file`` comes from the client and is used as a write path, so a
    name like ``../app.py`` must not be able to overwrite files outside the
    data directory.

    Raises:
        HTTPException: status 400 if the name is empty or resolves outside
            the data directory.
    """
    audio_path = f"data/{audio_file}"
    data_dir = os.path.realpath("data")
    resolved = os.path.realpath(audio_path)
    if resolved == data_dir or os.path.commonpath([data_dir, resolved]) != data_dir:
        raise HTTPException(
            status_code=400,
            detail="audio_file must be a file name inside the data directory.",
        )
    return audio_path


def _build_inference_command(request: InferenceRequest, audio_path: str) -> list:
    """Build the argv list for main.py from the request parameters."""
    command = [
        "python", "main.py", f"data/{request.id}",
        "--workspace", request.workspace,
        "-O", "--test", "--test_train",
        "--asr_model", request.asr_model,
        "--aud", audio_path,
    ]
    # Append flags only when enabled: a placeholder "" would otherwise reach
    # main.py as an extra, empty command-line argument.
    if request.portrait:
        command.append("--portrait")
    if request.fullbody:
        command.append("--fullbody")
    return command


def _run_inference_pipeline(request: InferenceRequest, upload, audio_path: str) -> str:
    """Save the upload, run hubert.py then main.py; return the expected video path.

    Everything here blocks (file I/O and two subprocesses), so the endpoint
    runs it in a worker thread rather than on the event loop.

    Args:
        request: validated request parameters.
        upload: readable binary file object holding the uploaded audio.
        audio_path: destination path returned by _resolve_audio_path.

    Raises:
        OSError: if the audio file cannot be written or a script cannot start.
        subprocess.CalledProcessError: if either script exits non-zero.
    """
    os.makedirs(os.path.dirname(audio_path), exist_ok=True)
    with open(audio_path, "wb") as buffer:
        shutil.copyfileobj(upload, buffer)
    # Run hubert.py to process the audio file.
    subprocess.run(["python", "data_utils/hubert.py", "--wav", audio_path], check=True)
    subprocess.run(_build_inference_command(request, audio_path), check=True)
    # Assumes main.py writes its result to <workspace>/output.mp4 — TODO confirm.
    # NOTE(review): concurrent requests sharing a workspace would race on
    # this single output path.
    return f"{request.workspace}/output.mp4"


@app.post("/api/inference/")
async def inference(
    request: InferenceRequest = Depends(_parse_inference_request),
    audio_file: UploadFile = File(...),
):
    """
    Perform inference on the uploaded audio file.

    The request is ``multipart/form-data`` with two parts:
    - ``request``: a JSON object with the InferenceRequest fields
      (audio_file, id, workspace, asr_model, portrait, fullbody).
    - ``audio_file``: the audio file to process.

    The upload is saved as data/<request.audio_file>. It is then processed
    with data_utils/hubert.py, and main.py is run to produce the output video.

    Returns:
        A JSON object with a success message and the path to the output video.

    Raises:
        HTTPException: 422 for a malformed ``request`` field, 400 for an
            audio_file name outside the data directory, and 500 if saving,
            either script, or output generation fails.
    """
    audio_path = _resolve_audio_path(request.audio_file)
    loop = asyncio.get_running_loop()
    try:
        output_video = await loop.run_in_executor(
            None, _run_inference_pipeline, request, audio_file.file, audio_path
        )
    except Exception as e:
        # Boundary handler: report script failures and I/O errors as a 500,
        # keeping the original exception chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e
    if not os.path.exists(output_video):
        raise HTTPException(status_code=500, detail="Inference failed to produce an output file.")
    return {"message": "Inference successful", "video_path": output_video}
if __name__ == "__main__":
    # Development entry point: serve the app with uvicorn, bound to the
    # loopback interface so it is reachable only from this machine.
    import uvicorn

    uvicorn.run(
        app,
        port=8000,
        host="127.0.0.1",
    )