agent-task-executor/core/task_executor.py
2025-01-14 23:56:08 +08:00

200 lines
8.2 KiB
Python

import time
import uuid
from typing import List, Dict, Any
from dataclasses import dataclass
from enum import Enum
import json
import logging
import asyncio
from .llm_executor import LLMExecutor
# Task status enumeration
class TaskStatus(Enum):
    """Lifecycle states shared by tasks and their individual steps.

    The member values are the plain strings stored in status dictionaries.
    """

    # Not started yet.
    PENDING = "pending"
    # Currently running.
    IN_PROGRESS = "in_progress"
    # Finished successfully.
    COMPLETED = "completed"
    # Finished with an error.
    FAILED = "failed"
# Step state data class
@dataclass
class StepState:
    """Record describing the state of a single step.

    All six fields are required and are accepted by the generated
    constructor in declaration order.
    """

    # Step identifier.
    step_id: str
    # Current status of the step.
    status: TaskStatus
    # Information the step requires.
    required_info: List[str]
    # Information that is available.
    available_info: List[str]
    # Information that is missing.
    missing_info: List[str]
    # Resources used so far.
    resources_used: Dict[str, Any]
# Main task executor class
class TaskExecutor:
    """Run an ordered list of task steps through an LLM executor.

    The executor records every successful step in ``execution_path``,
    snapshots its state into ``checkpoints`` every ``CHECKPOINT_INTERVAL``
    successful steps, and retries failed steps until the shared retry
    budget (``MAX_RETRIES``) or the time budget (``TIMEOUT``) is exhausted.

    ``execute`` reads the steps from a ``task_steps`` attribute, which this
    class does not define itself; it has to be provided on the instance or
    by a subclass.
    """

    MAX_RETRIES = 3  # Failures tolerated over the executor's lifetime
    TIMEOUT = 300  # Time budget for the whole task, in seconds
    CHECKPOINT_INTERVAL = 5  # Checkpoint every N successful steps

    # Set by ``execute`` when the task as a whole fails. Kept as a class-level
    # default so instances that never ran ``execute`` still report sensibly.
    _task_failed = False

    def __init__(self, llm_model: str = None):
        """Initialize the task executor.

        Args:
            llm_model: Model identifier forwarded to ``LLMExecutor``.
        """
        self.task_id = str(uuid.uuid4())  # Unique task ID
        self.start_time = time.time()  # Reference point for the time budget
        self.checkpoints = []  # Snapshots made by create_checkpoint()
        self.execution_path = []  # One entry per successfully executed step
        self.current_step = None  # Dict describing the most recent step
        self.retry_count = 0  # Failures seen so far (never reset)
        self.llm_executor = LLMExecutor(model=llm_model)  # LLM executor
        self.task_input = None  # Input passed to execute()
        # Configure logging
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(f"TaskExecutor-{self.task_id}")

    def get_status_update(self) -> dict:
        """Build a status report for the current execution state.

        Returns:
            A dict with the task ID, a timestamp, the current step, the
            checkpoints, used/remaining time, the execution path and an
            overall ``status`` string. ``status`` is ``"failed"`` when the
            task or the current step failed, ``"completed"`` when the current
            step completed, and ``"in_progress"`` otherwise.
        """
        # Sample the clock once so "used" and "available" agree.
        elapsed = time.time() - self.start_time

        step_status = self.current_step.get("status") if self.current_step else None
        if self._task_failed or step_status == TaskStatus.FAILED.value:
            status = TaskStatus.FAILED.value
        elif step_status == TaskStatus.COMPLETED.value:
            status = TaskStatus.COMPLETED.value
        else:
            status = TaskStatus.IN_PROGRESS.value

        return {
            "task_id": self.task_id,
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
            "current_step": self.current_step,
            "checkpoints": self.checkpoints,
            "resources": {
                "used": {
                    "time": elapsed,
                    "memory": "N/A"  # TODO: memory usage not implemented
                },
                "available": {
                    # Never report a negative remaining budget.
                    "time": max(0.0, self.TIMEOUT - elapsed),
                    "memory": "N/A"  # TODO: available memory not implemented
                }
            },
            "execution_path": self.execution_path,
            "status": status
        }

    async def execute_step(self, step_id: str, step_data: Dict[str, Any]) -> bool:
        """Execute a single task step with the LLM, retrying on failure.

        Args:
            step_id: Identifier recorded in the execution path.
            step_data: Step description; the keys ``name``, ``instruction``,
                ``input`` and ``context`` are read, all optional.

        Returns:
            True once the step has succeeded. False when the step failed and
            no further attempt is allowed, either because ``handle_error``
            refused a retry or because the task time budget is used up.
        """
        # Initialize the current step state
        self.current_step = {
            "id": step_id,
            "name": step_data.get("name", "Unknown"),
            "status": TaskStatus.IN_PROGRESS.value,
            "progress": 0
        }

        while True:
            self.current_step["status"] = TaskStatus.IN_PROGRESS.value
            self.current_step["progress"] = 0
            try:
                # Check the time budget before every attempt
                if time.time() - self.start_time > self.TIMEOUT:
                    raise TimeoutError("任务执行超时")
                # Execute the step with the LLM
                step_result = await self.llm_executor.execute_step(
                    step_instruction=step_data.get("instruction", ""),
                    step_input=step_data.get("input", {}),
                    context=step_data.get("context", {})
                )
                if not step_result["success"]:
                    raise RuntimeError(f"步骤失败: {step_result.get('error')}")
                output = step_result["output"]
            except Exception as e:
                self.current_step["status"] = TaskStatus.FAILED.value
                retry_allowed = self.handle_error(e)
                # Elapsed time only grows, so retrying after the budget is
                # spent could never succeed.
                out_of_time = time.time() - self.start_time > self.TIMEOUT
                if out_of_time or not retry_allowed:
                    return False
                continue

            # Success: record the result and mark the step as done
            self.execution_path.append({
                "step_id": step_id,
                "result": output
            })
            self.current_step["status"] = TaskStatus.COMPLETED.value
            self.current_step["progress"] = 100
            # Create a checkpoint at every interval boundary
            if len(self.execution_path) % self.CHECKPOINT_INTERVAL == 0:
                self.create_checkpoint()
            return True

    def validate_input(self, input_data: Dict[str, Any]) -> bool:
        """Validate the task input data.

        The default implementation accepts everything; subclasses may
        override it.
        """
        return True

    def create_checkpoint(self):
        """Append a snapshot of the current execution state to ``checkpoints``."""
        # Copy the step dict so later status changes do not rewrite history.
        step_snapshot = None
        if self.current_step is not None:
            step_snapshot = dict(self.current_step)

        checkpoint = {
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
            "task_id": self.task_id,
            "current_step": step_snapshot,
            "execution_path": self.execution_path.copy(),
            "resources": {
                "used": {
                    "time": time.time() - self.start_time,
                    "memory": "N/A"  # TODO: memory usage not implemented
                }
            }
        }
        self.checkpoints.append(checkpoint)
        # default=str: step results may not be JSON-serializable, and a log
        # line must never make an otherwise successful step fail.
        self.logger.info(f"创建检查点: {json.dumps(checkpoint, indent=2, default=str)}")

    def rollback_to_checkpoint(self, checkpoint_index: int):
        """Restore the state saved in the given checkpoint.

        Args:
            checkpoint_index: Zero-based index into ``checkpoints``.

        Returns:
            True when the index was valid and ``execution_path`` and
            ``current_step`` were restored; False for an out-of-range index
            (negative indices are rejected too).
        """
        if not 0 <= checkpoint_index < len(self.checkpoints):
            return False  # Invalid checkpoint index

        checkpoint = self.checkpoints[checkpoint_index]
        saved_step = checkpoint["current_step"]
        # Restore copies so the stored checkpoint stays reusable.
        self.execution_path = list(checkpoint["execution_path"])
        self.current_step = dict(saved_step) if saved_step is not None else None
        self.logger.info(f"回滚到检查点: {checkpoint['timestamp']}")
        return True

    def get_next_actions(self) -> List[str]:
        """Return the actions that are possible given the current step status.

        Returns:
            ``["retry", "rollback", "abort"]`` after a failed step,
            ``["continue", "checkpoint"]`` after a completed step, and an
            empty list otherwise.
        """
        if not self.current_step:
            return []
        step_status = self.current_step["status"]
        if step_status == TaskStatus.FAILED.value:
            return ["retry", "rollback", "abort"]
        if step_status == TaskStatus.COMPLETED.value:
            return ["continue", "checkpoint"]
        return []

    def handle_error(self, error: Exception):
        """Log an execution error and decide whether a retry is allowed.

        Args:
            error: The exception that was raised.

        Returns:
            True while fewer than ``MAX_RETRIES`` failures have been counted,
            False once the limit is reached.
        """
        self.logger.error(f"发生错误: {str(error)}")
        self.retry_count += 1
        if self.retry_count >= self.MAX_RETRIES:
            self.logger.error("达到最大重试次数。终止执行。")
            return False
        # NOTE: no recovery action is taken here; the caller just retries.
        return True

    async def execute(self, task_input: Dict[str, Any]) -> Dict[str, Any]:
        """Execute the complete task.

        Args:
            task_input: Task input; stored on the instance and passed to
                ``validate_input``.

        Returns:
            The dict produced by ``get_status_update``. Failures are logged
            and reported through its ``status`` field, not raised.
        """
        self._task_failed = False
        try:
            # Store and validate the input
            self.task_input = task_input
            if not self.validate_input(task_input):
                raise ValueError("无效的任务输入")
            self.logger.info(f"开始任务执行: {self.task_id}")
            if not hasattr(self, 'task_steps'):
                raise ValueError("未定义任务步骤")
            # Run the steps in order. execute_step creates the periodic
            # checkpoints itself, so none are created here.
            for step in self.task_steps:
                if not await self.execute_step(step["id"], step):
                    raise RuntimeError(f"步骤 {step['id']} 失败")
        except Exception as e:
            self._task_failed = True
            self.logger.error(f"任务失败: {str(e)}")
        return self.get_status_update()