import os
import shutil
import subprocess

from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel

app = FastAPI()

# Directory (relative to the working directory) where uploads are stored and
# where main.py is pointed at for the per-id dataset (``data/<id>``).
DATA_DIR = "data"


class InferenceRequest(BaseModel):
    """Parameters that control a single inference run."""

    # File name under ``data/`` that the uploaded audio is written to.
    audio_file: str
    # Dataset identifier; passed to main.py as ``data/<id>``.
    id: str
    # Directory passed to main.py via ``--workspace``; the endpoint looks for
    # ``<workspace>/output.mp4`` afterwards.
    workspace: str = "workspace"
    # Value forwarded to main.py via ``--asr_model``.
    asr_model: str = "deepspeech"
    # When true, ``--portrait`` is added to the main.py command line.
    portrait: bool = True
    # When true, ``--fullbody`` is added to the main.py command line.
    fullbody: bool = False


def _safe_component(value: str, field: str) -> str:
    """Return *value* if it is a single plain path component, else raise 400.

    Rejects empty strings, ``.``/``..`` and anything containing a path
    separator so a client cannot escape ``DATA_DIR``.
    """
    if (
        not value
        or value in {".", ".."}
        or "/" in value
        or "\\" in value
    ):
        raise HTTPException(
            status_code=400,
            detail=f"Invalid value for '{field}': must be a plain name without path separators.",
        )
    return value


def _safe_relative_path(value: str, field: str) -> str:
    """Return *value* if it is a relative path without ``..``, else raise 400.

    A leading ``-`` is also rejected because the value is placed on a command
    line right after an option flag and could be parsed as another option.
    """
    parts = value.replace("\\", "/").split("/")
    if (
        not value
        or os.path.isabs(value)
        or ".." in parts
        or value.startswith("-")
    ):
        raise HTTPException(
            status_code=400,
            detail=f"Invalid value for '{field}': must be a relative path without '..'.",
        )
    return value


def _run_pipeline(
    upload,
    audio_path: str,
    data_id: str,
    workspace: str,
    asr_model: str,
    portrait: bool,
    fullbody: bool,
) -> str:
    """Save the upload, run hubert.py and main.py, and return the output path.

    This is blocking (file copy + two subprocesses) and is meant to be run in
    a worker thread, not directly on the event loop.
    """
    os.makedirs(DATA_DIR, exist_ok=True)

    # Persist the uploaded audio where the downstream scripts expect it.
    with open(audio_path, "wb") as buffer:
        shutil.copyfileobj(upload, buffer)

    # Run hubert.py to process the audio file.
    subprocess.run(
        ["python", "data_utils/hubert.py", "--wav", audio_path],
        check=True,
    )

    # Run inference using main.py.
    inference_command = [
        "python", "main.py",
        f"{DATA_DIR}/{data_id}",
        "--workspace", workspace,
        "-O", "--test", "--test_train",
        "--asr_model", asr_model,
        "--aud", audio_path,
    ]
    # Append optional flags only when enabled; an empty-string placeholder
    # would be passed to main.py as a stray (empty) positional argument.
    if portrait:
        inference_command.append("--portrait")
    if fullbody:
        inference_command.append("--fullbody")
    subprocess.run(inference_command, check=True)

    # Assuming the output video is saved in the workspace directory.
    return f"{workspace}/output.mp4"


@app.post("/api/inference/")
async def inference(
    request: InferenceRequest,
    audio_file: UploadFile = File(...),
):
    """
    Perform inference on the uploaded audio file.

    The upload is written to ``data/<request.audio_file>``, processed with
    ``data_utils/hubert.py``, and then ``main.py`` is run against
    ``data/<request.id>``. The endpoint then checks that
    ``<request.workspace>/output.mp4`` exists and returns its path (the video
    itself is not streamed back).

    NOTE(review): this signature combines a Pydantic body model with a
    multipart file upload. FastAPI's "Request Files" docs state that JSON body
    fields cannot be combined with File/Form parameters -- confirm how clients
    are expected to send ``request`` and consider Form fields or ``Depends()``.

    Parameters:
    - request (InferenceRequest): audio_file, id, workspace, asr_model,
      portrait, fullbody.
    - audio_file (UploadFile): The uploaded audio file to be processed.

    Returns:
        A JSON object with ``message`` and ``video_path`` keys.

    Raises:
        HTTPException: 400 when ``audio_file``, ``id`` or ``workspace`` would
        escape their directories; 500 when a subprocess fails, the upload
        cannot be saved, or no output file is produced.
    """
    # Validate client-controlled path pieces before touching the filesystem.
    audio_name = _safe_component(request.audio_file, "audio_file")
    data_id = _safe_component(request.id, "id")
    workspace = _safe_relative_path(request.workspace, "workspace")
    audio_path = f"{DATA_DIR}/{audio_name}"

    try:
        # The pipeline blocks for the whole inference run, so execute it in a
        # worker thread to keep the event loop responsive.
        output_video = await run_in_threadpool(
            _run_pipeline,
            audio_file.file,
            audio_path,
            data_id,
            workspace,
            request.asr_model,
            request.portrait,
            request.fullbody,
        )
    except HTTPException:
        # Already a well-formed HTTP error; do not re-wrap it.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Checked outside the try block so this error is not swallowed and
    # re-wrapped by the generic handler above.
    if not os.path.exists(output_video):
        raise HTTPException(
            status_code=500,
            detail="Inference failed to produce an output file.",
        )

    return {"message": "Inference successful", "video_path": output_video}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="127.0.0.1", port=8000)