# File: agent-task-executor/core/task_executor.py (191 lines, 6.9 KiB, Python)

import asyncio
import json
import logging
import time
import uuid
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional

from .llm_executor import LLMExecutor


class TaskStatus(Enum):
    """Closed set of lifecycle states.

    TaskExecutor writes the ``.value`` strings (not the members) into its
    status and step dictionaries.
    """

    PENDING = "pending"  # not referenced by TaskExecutor in this module
    IN_PROGRESS = "in_progress"  # step/task currently running
    COMPLETED = "completed"  # finished successfully
    FAILED = "failed"  # finished unsuccessfully
@dataclass
class StepState:
    """Plain record describing the state of one step.

    NOTE(review): not referenced elsewhere in the visible module; the meaning
    of the three ``*_info`` lists is inferred from their names only — confirm
    with the code that builds these records.
    """

    # Identifier of the step this record describes.
    step_id: str
    # Lifecycle state of the step.
    status: TaskStatus
    # Presumably the information the step needs — TODO confirm.
    required_info: List[str]
    # Presumably the information already on hand — TODO confirm.
    available_info: List[str]
    # Presumably required minus available — TODO confirm.
    missing_info: List[str]
    # Free-form resource accounting keyed by resource name.
    resources_used: Dict[str, Any]
class TaskExecutor:
    """Run a sequence of LLM-backed steps with a timeout, retries and checkpoints.

    Steps are read from a ``task_steps`` attribute that this class never
    assigns: a subclass (class attribute) or the caller (instance attribute)
    must provide it before :meth:`execute` is awaited. Each step is a mapping
    with an ``"id"`` key; ``"name"``, ``"instruction"``, ``"input"`` and
    ``"context"`` are optional.
    """

    # Failed attempts (counted across the whole executor, never reset) after
    # which handle_error stops allowing retries.
    MAX_RETRIES = 3
    TIMEOUT = 300  # seconds, measured from self.start_time (set in __init__)
    CHECKPOINT_INTERVAL = 5  # steps

    def __init__(self, llm_model: Optional[str] = None):
        """Initialize TaskExecutor.

        Args:
            llm_model: Forwarded unchanged to ``LLMExecutor(model=...)``.
        """
        self.task_id = str(uuid.uuid4())
        # Wall-clock start; TIMEOUT and reported resource usage count from here.
        self.start_time = time.time()
        self.checkpoints = []
        self.execution_path = []
        self.current_step = None
        self.retry_count = 0
        self.llm_executor = LLMExecutor(model=llm_model)
        self.task_input = None
        # Configure logging. NOTE(review): basicConfig configures the root
        # logger (no-op once it has handlers); applications usually own this.
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        # NOTE(review): one logger per task_id; the logging module keeps every
        # named logger alive, so very many executors accumulate loggers.
        self.logger = logging.getLogger(f"TaskExecutor-{self.task_id}")

    def _elapsed(self) -> float:
        """Return wall-clock seconds since ``self.start_time``."""
        return time.time() - self.start_time

    @staticmethod
    def _timestamp() -> str:
        """Return the current local time as ``YYYY-MM-DDTHH:MM:SS`` (no zone)."""
        return time.strftime("%Y-%m-%dT%H:%M:%S")

    def get_status_update(self) -> dict:
        """Generate a status update for the current execution state.

        Returns:
            A snapshot dict (containers are copies, so callers cannot mutate
            executor state through it). ``"status"`` is ``"failed"`` when the
            current step failed, ``"completed"`` when it completed, and
            ``"in_progress"`` otherwise, including before any step has run.
        """
        elapsed = self._elapsed()
        step_status = self.current_step.get("status") if self.current_step else None
        if step_status == TaskStatus.FAILED.value:
            status = TaskStatus.FAILED.value
        elif step_status == TaskStatus.COMPLETED.value:
            status = TaskStatus.COMPLETED.value
        else:
            status = TaskStatus.IN_PROGRESS.value
        return {
            "task_id": self.task_id,
            "timestamp": self._timestamp(),
            "current_step": dict(self.current_step) if self.current_step is not None else None,
            "checkpoints": list(self.checkpoints),
            "resources": {
                "used": {
                    "time": elapsed,
                    "memory": "N/A"  # To be implemented
                },
                "available": {
                    # Clamp so an overrun reports 0 rather than a negative budget.
                    "time": max(0.0, self.TIMEOUT - elapsed),
                    "memory": "N/A"  # To be implemented
                }
            },
            "execution_path": list(self.execution_path),
            "status": status,
        }

    async def _run_step_once(self, step_id: str, step_data: Dict[str, Any]) -> None:
        """Make one attempt at a step; raise on timeout or an unsuccessful result."""
        self.current_step = {
            "id": step_id,
            "name": step_data.get("name", "Unknown"),
            "status": TaskStatus.IN_PROGRESS.value,
            "progress": 0,
        }
        remaining = self.TIMEOUT - self._elapsed()
        if remaining <= 0:
            raise TimeoutError("Task execution timeout")
        try:
            # Bound the LLM call by what is left of the budget, so a hung call
            # cannot run past TIMEOUT.
            step_result = await asyncio.wait_for(
                self.llm_executor.execute_step(
                    step_instruction=step_data.get("instruction", ""),
                    step_input=step_data.get("input", {}),
                    context=step_data.get("context", {}),
                ),
                timeout=remaining,
            )
        except asyncio.TimeoutError:
            # Before Python 3.11 asyncio.TimeoutError is not the builtin one.
            raise TimeoutError("Task execution timeout") from None
        # NOTE(review): the result is assumed to be a mapping with "success",
        # "output" and (on failure) "error" keys — confirm against LLMExecutor.
        if not step_result.get("success"):
            raise Exception(f"Step failed: {step_result.get('error')}")
        self.execution_path.append({
            "step_id": step_id,
            "result": step_result["output"],
        })
        self.current_step["status"] = TaskStatus.COMPLETED.value
        self.current_step["progress"] = 100

    async def execute_step(self, step_id: str, step_data: Dict[str, Any]) -> bool:
        """Execute a single step of the task using the LLM, retrying on failure.

        Every failed attempt is passed to :meth:`handle_error`; the step is
        attempted again while that returns ``True``. Any ``TimeoutError``
        (task budget exhausted, or raised by the LLM client) is not retried.

        Args:
            step_id: Identifier recorded in ``execution_path``.
            step_data: Step mapping; ``"name"``, ``"instruction"``, ``"input"``
                and ``"context"`` are read with defaults.

        Returns:
            ``True`` if the step succeeded, ``False`` if it failed and no
            further attempt is allowed.
        """
        while True:
            try:
                await self._run_step_once(step_id, step_data)
                break
            except Exception as e:
                if self.current_step is not None:
                    self.current_step["status"] = TaskStatus.FAILED.value
                may_retry = self.handle_error(e)
                if not may_retry or isinstance(e, TimeoutError):
                    return False
                self.logger.warning("Retrying step %s", step_id)
        # Checkpoint outside the retry loop so a checkpoint problem can never
        # make an already-recorded step run (and be appended) twice.
        if len(self.execution_path) % self.CHECKPOINT_INTERVAL == 0:
            self.create_checkpoint()
        return True

    def validate_input(self, input_data: Dict[str, Any]) -> bool:
        """Validate the input data for the task.

        Currently a placeholder hook that accepts everything; override to add
        real checks (returning ``False`` makes :meth:`execute` fail).
        """
        return True

    def create_checkpoint(self):
        """Append a snapshot of the current execution state to ``checkpoints``."""
        checkpoint = {
            "timestamp": self._timestamp(),
            "task_id": self.task_id,
            # Copy so later mutation of the live step dict cannot alter history.
            "current_step": dict(self.current_step) if self.current_step is not None else None,
            "execution_path": self.execution_path.copy(),
            "resources": {
                "used": {
                    "time": self._elapsed(),
                    "memory": "N/A"  # To be implemented
                }
            }
        }
        self.checkpoints.append(checkpoint)
        if self.logger.isEnabledFor(logging.INFO):
            # default=str: step outputs are arbitrary LLM results and may not be
            # JSON-serializable; logging must never fail a successful step.
            self.logger.info(
                "Created checkpoint: %s",
                json.dumps(checkpoint, indent=2, default=str),
            )

    def rollback_to_checkpoint(self, checkpoint_index: int) -> bool:
        """Rollback the execution to a specific checkpoint.

        Restores ``execution_path`` and ``current_step`` from the snapshot.
        ``checkpoints``, ``retry_count`` and ``start_time`` are left untouched,
        so later checkpoints remain available.

        Args:
            checkpoint_index: Position in ``self.checkpoints``; negative
                indices are rejected.

        Returns:
            ``True`` if the index was valid and state was restored, else ``False``.
        """
        if not 0 <= checkpoint_index < len(self.checkpoints):
            return False
        checkpoint = self.checkpoints[checkpoint_index]
        self.logger.info("Rolling back to checkpoint: %s", checkpoint["timestamp"])
        # Copies, so continued execution does not mutate the stored snapshot.
        self.execution_path = list(checkpoint["execution_path"])
        saved_step = checkpoint["current_step"]
        self.current_step = dict(saved_step) if saved_step is not None else None
        return True

    def get_next_actions(self) -> List[str]:
        """Determine the next possible actions based on the current step's status."""
        step_status = self.current_step["status"] if self.current_step else None
        if step_status == TaskStatus.FAILED.value:
            return ["retry", "rollback", "abort"]
        if step_status == TaskStatus.COMPLETED.value:
            return ["continue", "checkpoint"]
        return []

    def handle_error(self, error: Exception) -> bool:
        """Record a failed attempt and report whether another attempt is allowed.

        Returns:
            ``True`` while fewer than ``MAX_RETRIES`` failures have been
            recorded on this executor, ``False`` once the limit is reached.
        """
        self.logger.error("Error occurred: %s", error)
        self.retry_count += 1
        if self.retry_count >= self.MAX_RETRIES:
            self.logger.error("Max retries reached. Terminating execution.")
            return False
        return True

    async def execute(self, task_input: Dict[str, Any]) -> Dict[str, Any]:
        """Execute the complete task.

        Ordinary errors are not raised: they are logged and reported through
        the returned status dict.

        Returns:
            A :meth:`get_status_update` snapshot. On success ``"status"`` is
            ``"completed"``; on any failure it is ``"failed"`` and an extra
            ``"error"`` key carries the exception message.
        """
        try:
            # Store task input
            self.task_input = task_input
            if not self.validate_input(task_input):
                raise ValueError("Invalid task input")
            self.logger.info("Starting task execution: %s", self.task_id)
            # task_steps may come from a subclass attribute or be set by the caller.
            task_steps = getattr(self, "task_steps", None)
            if task_steps is None:
                raise ValueError("No task steps defined")
            for step in task_steps:
                # Checkpointing is done inside execute_step; doing it here too
                # produced duplicate checkpoints.
                if not await self.execute_step(step["id"], step):
                    raise RuntimeError(f"Step {step['id']} failed")
        except Exception as e:
            self.logger.error("Task failed: %s", e)
            report = self.get_status_update()
            report["status"] = TaskStatus.FAILED.value
            report["error"] = str(e)
            return report
        report = self.get_status_update()
        report["status"] = TaskStatus.COMPLETED.value
        return report