agent-task-executor/llm_client.py
zhukang dd797ab5e4 feat: initial commit of agent task executor framework
- Add core task execution framework
- Add LLM integration with DeepSeek
- Add text analysis task implementation
- Add configuration management
- Add tests and documentation
2025-01-14 20:53:09 +08:00

111 lines
3.7 KiB
Python

import json
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Any, List, Optional

import httpx
class LLMProvider(Enum):
    """Closed set of LLM backend identifiers a config can target.

    Every member's value is its lowercase name, so a plain string such as
    ``"deepseek"`` can be turned into a member with ``LLMProvider("deepseek")``.
    """

    OPENAI = "openai"
    AZURE = "azure"
    ANTHROPIC = "anthropic"
    DEEPSEEK = "deepseek"
    # Catch-all for endpoints not listed above.
    CUSTOM = "custom"
@dataclass
class LLMConfig:
    """Connection and sampling settings for one LLM client.

    ``api_key`` is excluded from the generated ``repr()`` so that printing or
    logging a config object does not expose the credential. It is still a
    required field and still participates in equality comparison.
    """

    # Backend family; the client factory dispatches on this value.
    provider: LLMProvider
    # Secret credential sent with each request. repr=False keeps it out of
    # logs, tracebacks and debugger output.
    api_key: str = field(repr=False)
    # Root URL of the API; None lets the client fall back to its own default.
    api_base: Optional[str] = None
    # Model identifier placed in the request body.
    model: str = "gpt-4"
    # Upper bound on generated tokens per request.
    max_tokens: int = 2000
    # Sampling temperature placed in the request body.
    temperature: float = 0.7
    # HTTP timeout handed to the HTTP client (httpx treats numbers as seconds).
    timeout: int = 30
class BaseLLMClient(ABC):
    """Abstract interface every LLM client implements.

    The class cannot be instantiated directly; subclasses must provide the
    asynchronous ``generate`` coroutine.
    """

    @abstractmethod
    async def generate(
        self,
        messages: List[Dict[str, str]],
        **kwargs,
    ) -> Dict[str, Any]:
        """Produce a model response for a chat-style conversation.

        Args:
            messages: Conversation turns, each a mapping of string keys to
                string values (presumably 'role' and 'content').
            **kwargs: Implementation-specific request options.

        Returns:
            A dictionary describing the response.
        """
class OpenAICompatibleClient(BaseLLMClient):
    """Client for any backend exposing an OpenAI-style chat-completions API.

    Each ``generate`` call opens its own ``httpx.AsyncClient``, sends
    ``Authorization: Bearer <config.api_key>``, and POSTs to the
    ``/v1/chat/completions`` endpoint under ``config.api_base``. Failures are
    reported as ``{"error": <message>}`` rather than raised.
    """

    def __init__(self, config: LLMConfig):
        self.config = config

    @staticmethod
    def _chat_completions_url(api_base: Optional[str]) -> str:
        """Return the absolute chat-completions URL for ``api_base``.

        Both ``https://host`` and ``https://host/v1`` resolve to
        ``https://host/v1/chat/completions``. Appending the fixed
        ``/v1/chat/completions`` path to a base that already ends in ``/v1``
        (including the default OpenAI base) would otherwise produce
        ``.../v1/v1/chat/completions``.
        """
        base = (api_base or "https://api.openai.com/v1").rstrip("/")
        if base.endswith("/v1"):
            return f"{base}/chat/completions"
        return f"{base}/v1/chat/completions"

    @staticmethod
    def _normalize_usage(response_data: Any) -> None:
        """Reduce ``response_data["usage"]`` in place to the three token counters.

        Missing counters default to 0. A non-dict body, or a ``usage`` value
        that is not a dict (e.g. JSON ``null``), is left untouched instead of
        turning a successful call into an error.
        """
        if not isinstance(response_data, dict):
            return
        usage = response_data.get("usage")
        if isinstance(usage, dict):
            response_data["usage"] = {
                "prompt_tokens": usage.get("prompt_tokens", 0),
                "completion_tokens": usage.get("completion_tokens", 0),
                "total_tokens": usage.get("total_tokens", 0),
            }

    @staticmethod
    def _describe_http_error(http_err: "httpx.HTTPError") -> str:
        """Build a readable message for an HTTP-level failure.

        Prefers the API's own ``error.message`` (or ``error`` itself when it
        is not an object) if the server returned a JSON error body, and
        otherwise falls back to the exception text.
        """
        error_msg = f"HTTP error occurred: {str(http_err)}"
        # Only status errors carry a response. Transport errors such as
        # timeouts or refused connections have no body to inspect.
        if not isinstance(http_err, httpx.HTTPStatusError):
            return error_msg
        try:
            error_data = http_err.response.json()
        except ValueError:
            # The body is not valid JSON (e.g. an HTML page from a proxy).
            return error_msg
        if isinstance(error_data, dict) and "error" in error_data:
            api_error = error_data["error"]
            if isinstance(api_error, dict):
                detail = api_error.get("message", str(api_error))
            else:
                detail = str(api_error)
            error_msg = f"API error: {detail}"
        return error_msg

    async def generate(self,
                       messages: List[Dict[str, str]],
                       **kwargs) -> Dict[str, Any]:
        """
        Generate a response using OpenAI-compatible API.

        Args:
            messages: List of message dictionaries with 'role' and 'content'
            **kwargs: Additional parameters merged into the request body.
                Because they are merged last, they override ``model``,
                ``max_tokens`` and ``temperature`` from the config.

        Returns:
            The parsed JSON response, with ``usage`` reduced to
            ``prompt_tokens``/``completion_tokens``/``total_tokens`` when
            present. Returns ``{"error": <message>}`` if anything fails.
        """
        try:
            request_data = {
                "model": self.config.model,
                "messages": messages,
                "max_tokens": self.config.max_tokens,
                "temperature": self.config.temperature,
                **kwargs,
            }
            async with httpx.AsyncClient(
                timeout=self.config.timeout,
                headers={
                    "Authorization": f"Bearer {self.config.api_key}",
                    "Content-Type": "application/json",
                },
            ) as client:
                response = await client.post(
                    self._chat_completions_url(self.config.api_base),
                    json=request_data,
                )
                response.raise_for_status()
                response_data = response.json()
            self._normalize_usage(response_data)
            return response_data
        except httpx.HTTPError as http_err:
            return {"error": self._describe_http_error(http_err)}
        except Exception as e:
            # Deliberate best-effort boundary: callers get an error dict
            # instead of an exception.
            return {"error": f"Error generating response: {str(e)}"}
def create_llm_client(config: LLMConfig) -> BaseLLMClient:
    """Return the client implementation matching ``config.provider``.

    OpenAI, Azure, DeepSeek and custom providers all share the
    OpenAI-compatible HTTP client.

    Raises:
        ValueError: if the provider has no client implementation
            (among the visible enum members, that is ANTHROPIC).
    """
    openai_compatible = (
        LLMProvider.OPENAI,
        LLMProvider.AZURE,
        LLMProvider.DEEPSEEK,
        LLMProvider.CUSTOM,
    )
    if config.provider not in openai_compatible:
        raise ValueError(f"Unsupported LLM provider: {config.provider}")
    return OpenAICompatibleClient(config)