# LLM client abstractions: provider enum, request configuration, and an OpenAI-compatible async client.
import json
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Any, List, Optional

import httpx
class LLMProvider(Enum):
    """Identifiers for the LLM backends a configuration can name.

    Each member's value is the lowercase string form of the provider, so a
    member can be looked up from text with ``LLMProvider("openai")``.
    """

    OPENAI = "openai"
    AZURE = "azure"
    ANTHROPIC = "anthropic"
    DEEPSEEK = "deepseek"
    CUSTOM = "custom"
@dataclass
class LLMConfig:
    """Connection and sampling settings for a single LLM endpoint.

    Attributes:
        provider: Which backend the settings are for.
        api_key: Secret sent as the bearer token on each request.
        api_base: Base URL of the API; ``None`` leaves the choice of
            endpoint to the client.
        model: Model name placed in the request body.
        max_tokens: Default ``max_tokens`` value sent with each request.
        temperature: Default sampling temperature sent with each request.
        timeout: Timeout handed to the HTTP client (httpx interprets a bare
            number as seconds).
    """

    provider: LLMProvider
    # Kept out of the generated __repr__ so that logging or printing a config
    # (or a traceback that includes one) never exposes the secret.
    api_key: str = field(repr=False)
    api_base: Optional[str] = None
    model: str = "gpt-4"
    max_tokens: int = 2000
    temperature: float = 0.7
    timeout: int = 30
class BaseLLMClient(ABC):
    """Abstract interface that every LLM client implementation must satisfy."""

    @abstractmethod
    async def generate(self,
                       messages: List[Dict[str, str]],
                       **kwargs) -> Dict[str, Any]:
        """Produce a completion for a chat-style conversation.

        Args:
            messages: Conversation so far, one dictionary per message.
            **kwargs: Implementation-specific request options.

        Returns:
            A dictionary describing the provider's response.
        """
        pass
class OpenAICompatibleClient(BaseLLMClient):
    """Async client for endpoints that speak the OpenAI chat-completions format.

    A fresh ``httpx.AsyncClient`` is opened for every call, so an instance
    holds no state beyond its configuration.
    """

    # Used when the configuration does not provide ``api_base``.
    _DEFAULT_API_BASE = "https://api.openai.com/v1"

    def __init__(self, config: LLMConfig):
        """Remember the configuration applied to every request."""
        self.config = config

    async def generate(self,
                       messages: List[Dict[str, str]],
                       **kwargs) -> Dict[str, Any]:
        """
        Generate a response using OpenAI-compatible API.

        Args:
            messages: List of message dictionaries with 'role' and 'content'
            **kwargs: Additional parameters to pass to the API; they are
                merged last, so they override the configured model,
                max_tokens and temperature.

        Returns:
            Dict containing the API response. When the response carries a
            ``usage`` object it is reduced to ``prompt_tokens``,
            ``completion_tokens`` and ``total_tokens`` (missing counters
            default to 0). On any failure the method does not raise; it
            returns ``{"error": "<message>"}`` instead.
        """
        request_data = {
            "model": self.config.model,
            "messages": messages,
            "max_tokens": self.config.max_tokens,
            "temperature": self.config.temperature,
            **kwargs,
        }

        try:
            async with httpx.AsyncClient(
                timeout=self.config.timeout,
                headers={
                    "Authorization": f"Bearer {self.config.api_key}",
                    "Content-Type": "application/json",
                },
            ) as client:
                response = await client.post(
                    self._completions_url(),
                    json=request_data,
                )

            response.raise_for_status()
            response_data = response.json()

        except httpx.HTTPStatusError as status_err:
            # Only status errors carry a response whose body can be inspected.
            return {"error": self._describe_status_error(status_err)}

        except httpx.HTTPError as http_err:
            # Transport-level failures (timeouts, connection errors, ...).
            return {"error": f"HTTP error occurred: {str(http_err)}"}

        except Exception as e:
            # Deliberately broad: callers rely on receiving an error dict
            # rather than an exception (e.g. when the body is not JSON).
            return {"error": f"Error generating response: {str(e)}"}

        # Normalise usage only when the API really sent an object; a null or
        # absent field must not turn a successful call into an error.
        usage = response_data.get("usage") if isinstance(response_data, dict) else None
        if isinstance(usage, dict):
            response_data["usage"] = {
                "prompt_tokens": usage.get("prompt_tokens", 0),
                "completion_tokens": usage.get("completion_tokens", 0),
                "total_tokens": usage.get("total_tokens", 0),
            }

        return response_data

    def _completions_url(self) -> str:
        """Build the absolute chat-completions URL for the configured base."""
        base_url = (self.config.api_base or self._DEFAULT_API_BASE).rstrip("/")
        # httpx appends a relative request path to base_url's own path, so a
        # base that already ends in /v1 (including the default) used to yield
        # ".../v1/v1/chat/completions". Building the absolute URL here avoids
        # depending on that merge behaviour and adds /v1 only when missing.
        if base_url.endswith("/v1"):
            return f"{base_url}/chat/completions"
        return f"{base_url}/v1/chat/completions"

    @staticmethod
    def _describe_status_error(status_err: "httpx.HTTPStatusError") -> str:
        """Turn a non-2xx response into the most specific message available."""
        error_msg = f"HTTP error occurred: {str(status_err)}"
        try:
            error_data = status_err.response.json()
        except ValueError:
            # Body is not JSON (e.g. an HTML error page from a proxy).
            return error_msg

        if isinstance(error_data, dict) and "error" in error_data:
            detail = error_data["error"]
            # The "error" field may be an object with a "message" or a bare
            # string; handle both.
            if isinstance(detail, dict):
                detail = detail.get("message", str(detail))
            error_msg = f"API error: {detail}"
        return error_msg
def create_llm_client(config: LLMConfig) -> BaseLLMClient:
    """Factory function to create appropriate LLM client.

    Args:
        config: Settings for the client; ``config.provider`` selects the
            implementation.

    Returns:
        An ``OpenAICompatibleClient`` for the OPENAI, AZURE, DEEPSEEK and
        CUSTOM providers.

    Raises:
        ValueError: If the provider has no client implementation here.
    """
    openai_compatible = (
        LLMProvider.OPENAI,
        LLMProvider.AZURE,
        LLMProvider.DEEPSEEK,
        LLMProvider.CUSTOM,
    )
    if config.provider not in openai_compatible:
        raise ValueError(f"Unsupported LLM provider: {config.provider}")
    return OpenAICompatibleClient(config)