# business_agent/agent.py from __future__ import annotations from typing import AsyncGenerator, Dict, Any, List, Literal import asyncio import uuid from contextlib import asynccontextmanager from openai import AsyncAzureOpenAI, RateLimitError, APIStatusError, BadRequestError, APITimeoutError, InternalServerError from langchain_openai import AzureChatOpenAI from langchain_core.messages import BaseMessage, convert_to_messages, AIMessage from langchain_core.prompts import ChatPromptTemplate from langgraph.graph import StateGraph, END from .schemas import SafetyOutput, IntentNEROutput from .prompts import SAFETY_SYSTEM_PROMPT, INTENT_SYSTEM_PROMPT from .config import AgentConfig class AgentError(Exception): """自訂統一錯誤,讓後端容易 catch""" def __init__(self, message: str, error_type: str, request_id: str | None = None): super().__init__(message) self.error_type = error_type self.request_id = request_id class BusinessAgent: def __init__( self, azure_openai_api_key: str, azure_openai_endpoint: str, azure_deployment: str, vector_search_client: Any, azure_openai_api_version: str = "2024-11-20", config: AgentConfig | None = None, *, timeout_seconds: float = 30.0, # 全流程強制 timeout retry_attempts: int = 2, # 建議 1~3 次,多了反而拖慢 ): self.vector_client = vector_search_client self.config = config or AgentConfig() self.timeout_seconds = timeout_seconds self.retry_attempts = max(0, retry_attempts) deployment = azure_deployment or self.config.model self.base_llm = AzureChatOpenAI( azure_deployment=deployment, api_version=azure_openai_api_version or self.config.azure_api_version, api_key=azure_openai_api_key, azure_endpoint=azure_openai_endpoint, temperature=0.0, max_tokens=self.config.max_tokens, timeout=20.0, # 單次 API call 也設 timeout,避免卡死 max_retries=0, # 我們自己控 retry,不要 LangChain 幫忙 ) self.safety_llm = self.base_llm.with_config({"temperature": 0.0, "streaming": False}) self.intent_llm = self.base_llm.with_config({"temperature": 0.0, "streaming": False}) self.rag_llm = self.base_llm.with_config({"temperature": self.config.temperature_rag, "streaming": True}) self.graph = self._build_graph() # ====================== 核心:乾淨的 astream + 多層錯誤處理 ====================== async def astream( self, history: List[Dict[str, str]], request_context: Dict[str, Any] | None = None, ) -> AsyncGenerator[Dict[str, Any], None]: ctx = self._normalize_context(request_context) try: # 全流程 timeout + retry 包裝 result_generator = await self._run_with_retry( self._execute_graph, history=history, request_context=ctx, max_attempts=self.retry_attempts + 1 ) async for item in result_generator: yield item except asyncio.TimeoutError: await self._handle_error( ctx=ctx, error_type="TIMEOUT", message=f"Agent 執行逾時({self.timeout_seconds}s)", yield_first=True ) raise AgentError( message=f"Agent timeout after {self.timeout_seconds}s", error_type="timeout", request_id=ctx.get("request_id") ) except RateLimitError as e: await self._handle_error(ctx, "RATE_LIMIT", str(e), yield_first=True) raise AgentError("OpenAI Rate limit exceeded", "rate_limit", ctx.get("request_id")) from e except (BadRequestError, APIStatusError) as e: if e.status_code >= 400 and e.status_code < 500: await self._handle_error(ctx, "BAD_REQUEST", str(e)) raise AgentError("Invalid request to OpenAI", "bad_request", ctx.get("request_id")) from e else: await self._handle_error(ctx, "OPENAI_ERROR", str(e)) raise AgentError("OpenAI service error", "openai_error", ctx.get("request_id")) from e except Exception as e: await self._handle_error(ctx, "UNEXPECTED_ERROR", str(e)) raise AgentError("Unexpected error in BusinessAgent", "unexpected", ctx.get("request_id")) from e # ====================== 私有:執行 graph(縮排超淺!) ====================== async def _execute_graph(self, history: List[Dict], request_context: Dict) -> AsyncGenerator: messages = convert_to_messages(history) async with asyncio.timeout(self.timeout_seconds): async for event in self.graph.astream_events( input={ "messages": messages, "request_context": request_context, "log_events": [], "is_safe": False, "intent": "", "entities": {}, "quick_reply": None, "search_results": None }, version="v2", ): kind = event["event"] if kind == "on_chat_model_stream": chunk = event["data"]["chunk"] if chunk.content: yield {"type": "token", "content": chunk.content} elif kind == "on_graph_update": state = event["data"]["state"] if state.get("log_events"): log = state["log_events"][-1].copy() log["request_id"] = request_context.get("request_id") yield {"type": "log", "data": log} elif kind == "on_graph_end": final_state = event["data"]["state"] answer = "".join(m.content for m in final_state["messages"] if isinstance(m, AIMessage)) yield { "type": "final", "answer": answer, "logs": final_state.get("log_events", []), "request_id": request_context.get("request_id"), } # ====================== 私有:自動 retry + timeout 內建 ====================== async def _run_with_retry(self, func, **kwargs): last_exception = None for attempt in range(self.retry_attempts + 1): try: async for item in func(**kwargs): yield item return # 成功就直接結束 except (RateLimitError, InternalServerError, APITimeoutError) as e: last_exception = e if attempt < self.retry_attempts: wait = 2 ** attempt # exponential backoff yield {"type": "retry", "attempt": attempt + 1, "wait_seconds": wait} await asyncio.sleep(wait) continue # 超過次數就拋出 raise last_exception or Exception("Unknown error in retry") # ====================== 私有:統一錯誤處理 + log ====================== async def _handle_error(self, ctx: Dict, error_type: str, message: str, *, yield_first: bool = False): error_log = { "step": "ERROR", "error_type": error_type, "message": message[:500], "request_id": ctx.get("request_id"), "timestamp": asyncio.get_event_loop().time(), } if yield_first: try: yield {"type": "error", "data": error_log} except: pass # ====================== 私有:正規化 context ====================== def _normalize_context(self, ctx_in: Dict | None) -> Dict: ctx = ctx_in.copy() if ctx_in else {} if "request_id" not in ctx: ctx["request_id"] = f"req_{uuid.uuid4().hex[:12]}" return ctx # ====================== Graph 建置(不變) ====================== def _build_graph(self): # ... 你原本的 _build_graph 完全不變 ... # (包含所有 node、edge、compile()) pass