70 lines
2.4 KiB
Python
70 lines
2.4 KiB
Python
from fastapi import FastAPI, UploadFile, File, HTTPException
|
|
import os
|
|
import subprocess
|
|
import shutil
|
|
from pydantic import BaseModel
|
|
|
|
# ASGI application instance; the route handler in this file is registered on it
# through the ``@app.post`` decorator, and it is what uvicorn serves.
app = FastAPI()
|
|
|
|
class InferenceRequest(BaseModel):
    """Parameters describing a single inference job submitted to the API.

    Only ``audio_file`` and ``id`` are required; the remaining fields carry
    defaults.
    """

    # File name (relative to ``data/``) under which the uploaded audio is stored.
    audio_file: str

    # Dataset identifier; the endpoint passes ``data/<id>`` to ``main.py``.
    id: str

    # Directory forwarded as ``--workspace`` and where ``output.mp4`` is looked up.
    workspace: str = "workspace"

    # Value forwarded to ``main.py`` as ``--asr_model``.
    asr_model: str = "deepspeech"

    # When true, ``--portrait`` is passed to ``main.py``.
    portrait: bool = True

    # When true, ``--fullbody`` is passed to ``main.py``.
    fullbody: bool = False
|
|
|
|
def _build_inference_command(request: InferenceRequest, audio_path: str) -> list:
    """Assemble the argv list used to launch ``main.py`` for one request."""
    command = [
        "python", "main.py", f"data/{request.id}",
        "--workspace", request.workspace,
        "-O", "--test", "--test_train",
        "--asr_model", request.asr_model,
        "--aud", audio_path,
    ]
    # Optional switches are appended only when enabled. Emitting "" as a
    # placeholder would hand the child process an empty positional argument.
    if request.portrait:
        command.append("--portrait")
    if request.fullbody:
        command.append("--fullbody")
    return command


@app.post("/api/inference/")
async def inference(
    request: InferenceRequest,
    audio_file: UploadFile = File(...),
):
    """
    Perform inference on the uploaded audio file.

    The upload is written to ``data/<request.audio_file>``, processed by
    ``data_utils/hubert.py``, and then ``main.py`` is run in test mode against
    ``data/<request.id>``. On success the endpoint reports the path
    ``<request.workspace>/output.mp4``.

    Parameters:
    - request (InferenceRequest): Request parameters including audio_file, id,
      workspace, asr_model, portrait, fullbody.
    - audio_file (UploadFile): The uploaded audio file to be processed.

    Returns:
    A JSON response containing a success message and the path to the output video.

    Raises:
    - HTTPException(400): ``request.audio_file`` does not resolve to a file
      located inside the ``data`` directory.
    - HTTPException(500): saving the upload failed, a subprocess exited with a
      non-zero status, or no output video was produced.

    NOTE(review): this signature combines a JSON body model with a multipart
    ``File`` upload; FastAPI documents that the two cannot be sent in the same
    request. Confirm how clients are expected to call this endpoint.
    NOTE(review): ``subprocess.run`` is synchronous, so the event loop is blocked
    for the whole duration of both child processes.
    """
    audio_path = f"data/{request.audio_file}"

    # The file name is client-controlled: refuse anything that would resolve
    # outside ``data/`` (e.g. "../../etc/passwd") or to the directory itself.
    data_root = os.path.abspath("data")
    resolved_audio = os.path.abspath(audio_path)
    if not resolved_audio.startswith(data_root + os.sep):
        raise HTTPException(status_code=400, detail="Invalid audio file name.")

    try:
        # Save the uploaded audio file, creating the target directory if needed.
        os.makedirs(os.path.dirname(resolved_audio), exist_ok=True)
        with open(audio_path, "wb") as buffer:
            shutil.copyfileobj(audio_file.file, buffer)

        # Run hubert.py to process the audio file.
        subprocess.run(
            ["python", "data_utils/hubert.py", "--wav", audio_path],
            check=True,
        )

        # Run inference using main.py.
        subprocess.run(_build_inference_command(request, audio_path), check=True)

        # Assuming the output video is saved in the workspace directory.
        output_video = f"{request.workspace}/output.mp4"
        if not os.path.exists(output_video):
            raise HTTPException(
                status_code=500,
                detail="Inference failed to produce an output file.",
            )

        return {"message": "Inference successful", "video_path": output_video}

    except HTTPException:
        # Already a well-formed API error; re-wrapping it would garble the detail.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on the loopback interface.
    # uvicorn is imported here so that importing this module does not need it.
    import uvicorn

    uvicorn.run(
        app,
        host="127.0.0.1",
        port=8000,
    )
|