200 lines
8.2 KiB
Python
200 lines
8.2 KiB
Python
import time
|
|
import uuid
|
|
from typing import List, Dict, Any
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
import json
|
|
import logging
|
|
import asyncio
|
|
from .llm_executor import LLMExecutor
|
|
|
|
# 任务状态枚举
|
|
class TaskStatus(Enum):
    """Lifecycle states used for tasks and for their individual steps.

    The string values are what gets stored in step/status dictionaries.
    """

    # Waiting to be started.
    PENDING = "pending"
    # Currently being executed.
    IN_PROGRESS = "in_progress"
    # Finished successfully.
    COMPLETED = "completed"
    # Finished with an error.
    FAILED = "failed"
|
|
|
|
# 步骤状态数据类
|
|
@dataclass
class StepState:
    """State record for a single step.

    Attributes:
        step_id: Identifier of the step.
        status: Current status of the step.
        required_info: Information items the step requires.
        available_info: Information items that are available.
        missing_info: Information items that are missing.
        resources_used: Resources used so far, keyed by resource name.
    """

    step_id: str
    status: TaskStatus
    required_info: List[str]
    available_info: List[str]
    missing_info: List[str]
    resources_used: Dict[str, Any]
|
|
|
|
# 任务执行器主类
|
|
class TaskExecutor:
    """Run a sequence of task steps through an LLM executor.

    The executor tracks the current step, the list of completed step results
    (``execution_path``), periodic checkpoints, and a cumulative failure
    counter. :meth:`execute` reads the steps from a ``task_steps`` attribute
    that must be set on the instance (or defined by a subclass) beforehand;
    each step is a dict with an ``"id"`` key and optional ``"name"``,
    ``"instruction"``, ``"input"`` and ``"context"`` keys.
    """

    # Number of recorded failures at which handle_error() stops allowing
    # retries. The counter is cumulative for the executor, not per step.
    MAX_RETRIES = 3
    # Wall-clock budget for the whole task, in seconds, measured from __init__.
    TIMEOUT = 300
    # A checkpoint is created every time this many steps have completed.
    CHECKPOINT_INTERVAL = 5

    def __init__(self, llm_model: str = None):
        """Initialise the executor state and its logger.

        Args:
            llm_model: Model identifier forwarded to ``LLMExecutor``; ``None``
                is passed through unchanged.
        """
        self.task_id = str(uuid.uuid4())  # unique id for this task run
        self.start_time = time.time()  # reference point for the timeout
        self.checkpoints = []  # snapshots created by create_checkpoint()
        self.execution_path = []  # one {"step_id", "result"} entry per completed step
        self.current_step = None  # dict describing the step being executed
        self.retry_count = 0  # failures recorded by handle_error()
        self.llm_executor = LLMExecutor(model=llm_model)
        self.task_input = None  # set by execute()

        # basicConfig() does nothing if the root logger already has handlers,
        # so an application-level logging setup takes precedence.
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(f"TaskExecutor-{self.task_id}")

    def get_status_update(self) -> dict:
        """Return a snapshot of the current execution state.

        Returns:
            A dict with the task id, a local-time timestamp, the current step,
            the checkpoints, used/remaining time, the execution path and a
            ``"status"`` string. The status is ``"completed"`` or ``"failed"``
            when the current step is in that state, otherwise
            ``"in_progress"``.
        """
        # Sample the clock once so "used" and "available" are consistent.
        elapsed = time.time() - self.start_time

        step_status = self.current_step.get("status") if self.current_step else None
        if step_status in (TaskStatus.COMPLETED.value, TaskStatus.FAILED.value):
            status = step_status
        else:
            status = TaskStatus.IN_PROGRESS.value

        return {
            "task_id": self.task_id,
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
            "current_step": self.current_step,
            "checkpoints": self.checkpoints,
            "resources": {
                "used": {
                    "time": elapsed,
                    "memory": "N/A"  # TODO: memory usage not implemented
                },
                "available": {
                    "time": self.TIMEOUT - elapsed,
                    "memory": "N/A"  # TODO: available memory not implemented
                }
            },
            "execution_path": self.execution_path,
            "status": status
        }

    async def execute_step(self, step_id: str, step_data: Dict[str, Any]) -> bool:
        """Execute one step via the LLM executor, retrying on failure.

        A failed attempt is reported to :meth:`handle_error`; the step is
        attempted again only while that method returns ``True``.

        Args:
            step_id: Identifier recorded in the execution path.
            step_data: Step description; ``"name"``, ``"instruction"``,
                ``"input"`` and ``"context"`` are read with defaults.

        Returns:
            ``True`` if the step completed and its result was appended to
            ``execution_path``; ``False`` if it failed and no further retry is
            allowed, or if the task time budget is exhausted.
        """
        # Bound the loop so an overridden handle_error() that always allows a
        # retry cannot spin forever.
        for _ in range(max(1, self.MAX_RETRIES)):
            # A fresh dict per attempt keeps earlier snapshots untouched.
            self.current_step = {
                "id": step_id,
                "name": step_data.get("name", "Unknown"),
                "status": TaskStatus.IN_PROGRESS.value,
                "progress": 0
            }

            # Retrying cannot help once the time budget is spent, so record
            # the error and stop immediately.
            if time.time() - self.start_time > self.TIMEOUT:
                self.current_step["status"] = TaskStatus.FAILED.value
                self.handle_error(TimeoutError("任务执行超时"))
                return False

            try:
                step_result = await self.llm_executor.execute_step(
                    step_instruction=step_data.get("instruction", ""),
                    step_input=step_data.get("input", {}),
                    context=step_data.get("context", {})
                )
                if not step_result["success"]:
                    raise RuntimeError(f"步骤失败: {step_result.get('error')}")
                output = step_result["output"]
            except Exception as e:
                self.current_step["status"] = TaskStatus.FAILED.value
                if not self.handle_error(e):
                    return False
                continue  # retry allowed: run the step again

            self.execution_path.append({
                "step_id": step_id,
                "result": output
            })
            self.current_step["status"] = TaskStatus.COMPLETED.value
            self.current_step["progress"] = 100

            if len(self.execution_path) % self.CHECKPOINT_INTERVAL == 0:
                self.create_checkpoint()

            return True

        return False

    def validate_input(self, input_data: Dict[str, Any]) -> bool:
        """Validate the task input.

        The default implementation accepts everything; subclasses may
        override it.
        """
        return True

    def create_checkpoint(self):
        """Append a snapshot of the current execution state to ``checkpoints``."""
        checkpoint = {
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
            "task_id": self.task_id,
            # Copy so later changes to the live step dict do not alter the snapshot.
            "current_step": dict(self.current_step) if self.current_step else self.current_step,
            "execution_path": self.execution_path.copy(),
            "resources": {
                "used": {
                    "time": time.time() - self.start_time,
                    "memory": "N/A"  # TODO: memory usage not implemented
                }
            }
        }
        self.checkpoints.append(checkpoint)
        # default=str: step results may not be JSON-serialisable, and a log
        # line must never make checkpoint creation fail.
        self.logger.info(f"创建检查点: {json.dumps(checkpoint, indent=2, default=str)}")

    def rollback_to_checkpoint(self, checkpoint_index: int):
        """Restore the execution state stored in a checkpoint.

        Args:
            checkpoint_index: Zero-based index into ``checkpoints``.

        Returns:
            ``True`` if the index was valid and ``execution_path`` and
            ``current_step`` were restored from the checkpoint, ``False``
            otherwise. The checkpoint list itself is left unchanged.
        """
        if not 0 <= checkpoint_index < len(self.checkpoints):
            return False

        checkpoint = self.checkpoints[checkpoint_index]
        # Restore copies so the stored checkpoint stays reusable.
        self.execution_path = list(checkpoint["execution_path"])
        saved_step = checkpoint["current_step"]
        self.current_step = dict(saved_step) if saved_step else saved_step
        self.logger.info(f"回滚到检查点: {checkpoint['timestamp']}")
        return True

    def get_next_actions(self) -> List[str]:
        """Return the actions that are possible given the current step status.

        Returns:
            ``["retry", "rollback", "abort"]`` after a failed step,
            ``["continue", "checkpoint"]`` after a completed step, and an
            empty list otherwise.
        """
        if not self.current_step:
            return []

        step_status = self.current_step["status"]
        if step_status == TaskStatus.FAILED.value:
            return ["retry", "rollback", "abort"]
        if step_status == TaskStatus.COMPLETED.value:
            return ["continue", "checkpoint"]
        return []

    def handle_error(self, error: Exception):
        """Record a failure and decide whether a retry is allowed.

        Args:
            error: The exception that caused the failure.

        Returns:
            ``True`` if the caller may retry, ``False`` once ``retry_count``
            has reached ``MAX_RETRIES``.
        """
        self.logger.error(f"发生错误: {str(error)}")
        self.retry_count += 1

        if self.retry_count >= self.MAX_RETRIES:
            self.logger.error("达到最大重试次数。终止执行。")
            return False

        return True

    async def execute(self, task_input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate the input and run every step in ``task_steps`` in order.

        Exceptions are not propagated; they are logged and reflected in the
        returned status.

        Args:
            task_input: Input stored on the instance and passed to
                :meth:`validate_input`.

        Returns:
            The dict from :meth:`get_status_update`. On failure its
            ``"status"`` is ``"failed"`` and an ``"error"`` key carries the
            error message.
        """
        try:
            self.task_input = task_input

            if not self.validate_input(task_input):
                raise ValueError("无效的任务输入")

            self.logger.info(f"开始任务执行: {self.task_id}")

            steps = getattr(self, "task_steps", None)
            if steps is None:
                raise ValueError("未定义任务步骤")

            # execute_step() creates the periodic checkpoints itself, so no
            # second checkpoint is taken here.
            for step in steps:
                if not await self.execute_step(step["id"], step):
                    raise RuntimeError(f"步骤 {step['id']} 失败")

            return self.get_status_update()

        except Exception as e:
            self.logger.error(f"任务失败: {str(e)}")
            failure = self.get_status_update()
            # The task as a whole failed even if no step is marked as failed
            # (e.g. invalid input or missing task_steps).
            failure["status"] = TaskStatus.FAILED.value
            failure["error"] = str(e)
            return failure
|