# - Add core task execution framework
# - Add LLM integration with DeepSeek
# - Add text analysis task implementation
# - Add configuration management
# - Add tests and documentation
# (file-viewer metadata from the original paste: 191 lines, 6.9 KiB, Python)
import asyncio
import json
import logging
import time
import uuid
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional

from llm_executor import LLMExecutor


class TaskStatus(Enum):
    """Lifecycle states used for both whole tasks and individual steps.

    Executors store the string ``.value`` of a member (not the member itself)
    inside step dictionaries and status updates.
    """

    PENDING = "pending"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    FAILED = "failed"


@dataclass
class StepState:
    """Plain record describing the bookkeeping state of one step.

    A bare ``@dataclass``: no defaults, no validation, and no derived fields,
    so callers must supply every attribute and keep them consistent.
    """

    step_id: str  # identifier of the step
    status: TaskStatus  # lifecycle state of the step
    required_info: List[str]  # presumably names of inputs the step needs — confirm with callers
    available_info: List[str]  # presumably inputs already on hand — confirm with callers
    missing_info: List[str]  # presumably required minus available; not computed or enforced here
    resources_used: Dict[str, Any]  # free-form resource accounting; schema not defined here


class TaskExecutor:
    """Run a sequence of LLM-backed steps with retries, checkpoints and a time budget.

    ``execute`` reads the steps from ``self.task_steps``. This class never
    sets that attribute; presumably a subclass or the caller assigns it. Each
    step is a dict with an ``"id"`` key and optional ``"name"``,
    ``"instruction"``, ``"input"`` and ``"context"`` keys.
    """

    # Failures tolerated across the whole task before handle_error refuses
    # further retries; also caps the number of attempts for any single step.
    MAX_RETRIES = 3
    TIMEOUT = 300  # seconds; wall-clock budget for one execute() run
    CHECKPOINT_INTERVAL = 5  # checkpoint after every N completed steps

    def __init__(self, llm_model: Optional[str] = None):
        """Create an executor with a fresh task id and empty execution state.

        Args:
            llm_model: Forwarded unchanged to ``LLMExecutor(model=...)``.
        """
        self.task_id = str(uuid.uuid4())
        self.start_time = time.time()
        self.checkpoints = []
        self.execution_path = []
        self.current_step = None
        self.retry_count = 0
        # Overall task state, reported by get_status_update().
        self.task_status = TaskStatus.PENDING
        self.llm_executor = LLMExecutor(model=llm_model)
        self.task_input = None

        # Configure logging (basicConfig is a no-op if the root logger
        # already has handlers).
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(f"TaskExecutor-{self.task_id}")

    def get_status_update(self) -> dict:
        """Return a snapshot of the current execution state.

        ``status`` is the overall task state (pending / in_progress /
        completed / failed), not merely the state of the most recent step.
        Remaining time is clamped at zero once the budget is spent.
        """
        # Read the clock once so "used" and "available" are consistent.
        elapsed = time.time() - self.start_time
        return {
            "task_id": self.task_id,
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
            "current_step": self.current_step,
            "checkpoints": self.checkpoints,
            "resources": {
                "used": {
                    "time": elapsed,
                    "memory": "N/A"  # To be implemented
                },
                "available": {
                    "time": max(0.0, self.TIMEOUT - elapsed),
                    "memory": "N/A"  # To be implemented
                }
            },
            "execution_path": self.execution_path,
            "status": self.task_status.value
        }

    async def _run_step_once(self, step_data: Dict[str, Any]) -> Any:
        """Run a single LLM attempt for a step and return its output.

        Raises:
            TimeoutError: The task budget is already used up, or the LLM call
                did not finish within what remains of it
                (``asyncio.TimeoutError`` on Python < 3.11).
            RuntimeError: The LLM executor reported a falsy ``"success"``.
        """
        remaining = self.TIMEOUT - (time.time() - self.start_time)
        if remaining <= 0:
            raise TimeoutError("Task execution timeout")

        # Bound the call itself; checking only before the call would let a
        # slow LLM request overrun the budget indefinitely.
        step_result = await asyncio.wait_for(
            self.llm_executor.execute_step(
                step_instruction=step_data.get("instruction", ""),
                step_input=step_data.get("input", {}),
                context=step_data.get("context", {})
            ),
            timeout=remaining,
        )

        if not step_result["success"]:
            raise RuntimeError(f"Step failed: {step_result['error']}")
        return step_result["output"]

    async def execute_step(self, step_id: str, step_data: Dict[str, Any]) -> bool:
        """Execute one step through the LLM executor, retrying failed attempts.

        Each failed attempt is passed to ``handle_error``. The step is retried
        while that returns True, up to ``MAX_RETRIES`` attempts for this step.
        Timeouts are not retried, because the task budget is already exhausted.

        Args:
            step_id: Identifier recorded in ``current_step`` and ``execution_path``.
            step_data: Step definition; ``"name"``, ``"instruction"``,
                ``"input"`` and ``"context"`` are read if present.

        Returns:
            True only if the step actually succeeded, otherwise False.
        """
        if self.task_status is TaskStatus.PENDING:
            self.task_status = TaskStatus.IN_PROGRESS

        for _attempt in range(max(1, self.MAX_RETRIES)):
            self.current_step = {
                "id": step_id,
                "name": step_data.get("name", "Unknown"),
                "status": TaskStatus.IN_PROGRESS.value,
                "progress": 0
            }
            try:
                output = await self._run_step_once(step_data)
            except (TimeoutError, asyncio.TimeoutError) as e:
                # Another attempt cannot succeed once the budget is gone.
                self.current_step["status"] = TaskStatus.FAILED.value
                self.logger.error(f"Step {step_id} timed out: {str(e)}")
                return False
            except Exception as e:
                self.current_step["status"] = TaskStatus.FAILED.value
                if not self.handle_error(e):
                    return False
                continue

            self.execution_path.append({
                "step_id": step_id,
                "result": output
            })
            self.current_step["status"] = TaskStatus.COMPLETED.value
            self.current_step["progress"] = 100

            # The only place periodic checkpoints are taken (execute() relies on it).
            if len(self.execution_path) % self.CHECKPOINT_INTERVAL == 0:
                self.create_checkpoint()
            return True

        # Per-step attempt cap reached while handle_error still allowed retries.
        return False

    def validate_input(self, input_data: Dict[str, Any]) -> bool:
        """Validate the task input; this base implementation accepts anything.

        Presumably intended to be overridden with real checks.
        """
        return True

    def create_checkpoint(self):
        """Append a snapshot of the current step and execution path to ``checkpoints``.

        The snapshot holds copies, so later mutation of ``current_step`` or
        ``execution_path`` does not alter a saved checkpoint.
        """
        checkpoint = {
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
            "task_id": self.task_id,
            "current_step": (
                dict(self.current_step) if self.current_step is not None else None
            ),
            "execution_path": self.execution_path.copy(),
            "resources": {
                "used": {
                    "time": time.time() - self.start_time,
                    "memory": "N/A"  # To be implemented
                }
            }
        }
        self.checkpoints.append(checkpoint)
        # default=str: step outputs come from the LLM executor and may not be
        # JSON-serializable; logging must not turn a successful step into a failure.
        self.logger.info(
            f"Created checkpoint: {json.dumps(checkpoint, indent=2, default=str)}"
        )

    def rollback_to_checkpoint(self, checkpoint_index: int) -> bool:
        """Restore ``execution_path`` and ``current_step`` from a saved checkpoint.

        Args:
            checkpoint_index: Non-negative index into ``checkpoints``.

        Returns:
            True if the checkpoint existed and was restored, otherwise False.
        """
        if not 0 <= checkpoint_index < len(self.checkpoints):
            return False

        checkpoint = self.checkpoints[checkpoint_index]
        self.logger.info(f"Rolling back to checkpoint: {checkpoint['timestamp']}")
        # Copy so that continuing execution does not mutate the stored checkpoint.
        self.execution_path = list(checkpoint["execution_path"])
        saved_step = checkpoint["current_step"]
        self.current_step = dict(saved_step) if saved_step is not None else None
        return True

    def get_next_actions(self) -> List[str]:
        """List the actions available given the state of the current step."""
        actions = []
        if self.current_step and self.current_step["status"] == TaskStatus.FAILED.value:
            actions.extend(["retry", "rollback", "abort"])
        elif self.current_step and self.current_step["status"] == TaskStatus.COMPLETED.value:
            actions.extend(["continue", "checkpoint"])
        return actions

    def handle_error(self, error: Exception) -> bool:
        """Log a failed attempt and decide whether another attempt is allowed.

        ``retry_count`` is shared by all steps of the task and is never reset here.

        Returns:
            True while ``retry_count`` is below ``MAX_RETRIES``; False once it
            reaches it.
        """
        self.logger.error(f"Error occurred: {str(error)}")
        self.retry_count += 1

        if self.retry_count >= self.MAX_RETRIES:
            self.logger.error("Max retries reached. Terminating execution.")
            return False
        return True

    async def execute(self, task_input: Dict[str, Any]) -> Dict[str, Any]:
        """Run every step in ``self.task_steps`` in order and report the outcome.

        Any ``Exception`` raised during the run is logged, and the task is
        marked failed rather than the exception propagating. A status update
        (see ``get_status_update``) is returned in every case.

        Args:
            task_input: Stored on ``self.task_input`` and passed to ``validate_input``.
        """
        self.task_input = task_input
        # Measure the TIMEOUT budget from the start of this run, not from construction.
        self.start_time = time.time()
        self.task_status = TaskStatus.IN_PROGRESS

        try:
            if not self.validate_input(task_input):
                raise ValueError("Invalid task input")

            self.logger.info(f"Starting task execution: {self.task_id}")

            if not hasattr(self, 'task_steps'):
                raise ValueError("No task steps defined")

            # No checkpointing here: execute_step already checkpoints every
            # CHECKPOINT_INTERVAL steps, and repeating it recorded duplicates.
            for step in self.task_steps:
                if not await self.execute_step(step["id"], step):
                    raise RuntimeError(f"Step {step['id']} failed")
        except Exception as e:
            self.task_status = TaskStatus.FAILED
            self.logger.error(f"Task failed: {str(e)}")
        else:
            self.task_status = TaskStatus.COMPLETED

        return self.get_status_update()