# business_agent/agent.py from __future__ import annotations from typing import AsyncGenerator, Dict, Any, List, Literal import asyncio from openai import AsyncOpenAI from langchain_openai import ChatOpenAI from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, convert_to_messages from langchain_core.prompts import ChatPromptTemplate from langgraph.graph import StateGraph, END from langgraph.checkpoint.memory import MemorySaver from .schemas import SafetyOutput, IntentNEROutput from .prompts import SAFETY_SYSTEM_PROMPT, INTENT_SYSTEM_PROMPT from .config import AgentConfig # ====================== State 定義 ====================== from typing import TypedDict, Annotated import operator class AgentState(TypedDict): messages: Annotated[List[BaseMessage], operator.add] metadata: Dict[str, Any] # 執行過程中的中間結果 is_safe: bool intent: str entities: Dict[str, Any] quick_reply: str | None search_results: List[Dict] | None # 所有 log 最後會從這裡吐出 log_events: Annotated[List[Dict[str, Any]], operator.add] # ====================== 主 SDK Class ====================== class BusinessAgent: def __init__( self, openai_client: AsyncOpenAI, # 後端必須傳 openai>=1.53.0 的 client vector_search_client: Any, # 你的第三方搜尋 client(有 .search() 方法即可) config: AgentConfig | None = None, ): self.openai_client = openai_client self.vector_client = vector_search_client self.config = config or AgentConfig() # LangChain LLM(自動走最新 Responses API) self.llm = ChatOpenAI( model=self.config.model, temperature=0.0, openai_client=openai_client, streaming=True, max_tokens=self.config.max_tokens, ) # 不同階段不同 temperature self.safety_llm = self.llm.with_config({"temperature": self.config.temperature_safety}) self.intent_llm = self.llm.with_config({"temperature": self.config.temperature_intent}) self.rag_llm = self.llm.with_config({"temperature": self.config.temperature_rag}) self.graph = self._build_graph() # ====================== Graph 建置 ====================== def _build_graph(self): workflow = StateGraph(AgentState) workflow.add_node("safety_check", self._safety_check) workflow.add_node("intent_ner", self._intent_ner) workflow.add_node("quick_reply", self._quick_reply) workflow.add_node("vector_search", self._vector_search) workflow.add_node("rag_generate", self._rag_generate) workflow.add_node("reject", self._reject) workflow.set_entry_point("safety_check") workflow.add_conditional_edges( "safety_check", self._route_safety, {"safe": "intent_ner", "reject": "reject"} ) workflow.add_conditional_edges( "intent_ner", self._route_intent, { "quick_reply": "quick_reply", "rag_needed": "vector_search", "off_topic": "reject", } ) workflow.add_edge("vector_search", "rag_generate") workflow.add_edge("quick_reply", END) workflow.add_edge("rag_generate", END) workflow.add_edge("reject", END) # 每次 thread_id 不同 → 達成 stateless memory = MemorySaver() return workflow.compile(checkpointer=memory) # ====================== 對外 Streaming 入口 ====================== async def astream( self, history: List[Dict[str, str]], metadata: Dict[str, Any] | None = None, ) -> AsyncGenerator[Dict[str, Any], None]: """ 後端唯一呼叫點,返回 async generator 會即時 yield token + log,結束時一定會有 final 事件(含完整 logs) """ messages = convert_to_messages(history) thread_id = f"stateless_{asyncio.current_task().__hash__()}" # 保證每次不同 async for event in self.graph.astream_events( input={ "messages": messages, "metadata": metadata or {}, "log_events": [], "is_safe": False, "intent": "", "entities": {}, "quick_reply": None, "search_results": None, }, config={"configurable": {"thread_id": thread_id}}, version="v2", ): kind = event["event"] if kind == "on_chat_model_stream": chunk = event["data"]["chunk"] if chunk.content: yield {"type": "token", "content": chunk.content} elif kind == "on_graph_update": state = event["data"]["state"] if state.get("log_events"): latest_log = state["log_events"][-1] yield {"type": "log", "data": latest_log} elif kind == "on_graph_end": final_state = event["data"]["state"] final_answer = "".join(m.content for m in final_state["messages"] if isinstance(m, AIMessage)) yield { "type": "final", "answer": final_answer, "logs": final_state.get("log_events", []) } # ====================== Nodes ====================== async def _safety_check(self, state: AgentState) -> Dict[str, Any]: user_msg = state["messages"][-1].content prompt = ChatPromptTemplate.from_messages([ ("system", SAFETY_SYSTEM_PROMPT), ("human", user_msg), ]) chain = prompt | self.safety_llm.with_structured_output(SafetyOutput, method="json_schema") result = await chain.ainvoke({}) log = {"step": "safety_check", "is_safe": result.is_safe, "reason": result.reason or ""} return {"is_safe": result.is_safe, "log_events": [log]} def _route_safety(self, state: AgentState) -> Literal["safe", "reject"]: return "safe" if state["is_safe"] else "reject" async def _intent_ner(self, state: AgentState) -> Dict[str, Any]: user_msg = state["messages"][-1].content prompt = ChatPromptTemplate.from_messages([ ("system", INTENT_SYSTEM_PROMPT), ("human", user_msg), ]) chain = prompt | self.intent_llm.with_structured_output(IntentNEROutput, method="json_schema") result: IntentNEROutput = await chain.ainvoke({}) log = { "step": "intent_ner", "intent": result.intent, "entities": result.entities, "has_quick_reply": bool(result.quick_reply), } return { "intent": result.intent, "entities": result.entities.dict(), "quick_reply": result.quick_reply, "log_events": [log], } def _route_intent(self, state: AgentState) -> str: if state["quick_reply"]: return "quick_reply" if state["intent"] == "off_topic": return "off_topic" return "rag_needed" async def _quick_reply(self, state: AgentState) -> Dict[str, Any]: reply = state["quick_reply"] or "您好!有什麼可以幫您的?" return {"messages": [AIMessage(content=reply)]} async def _vector_search(self, state: AgentState) -> Dict[str, Any]: # 你們的第三方 API query = self._build_query(state["entities"], state["messages"][-1].content) results = await self.vector_client.search(query) # 假設是 async log = {"step": "vector_search", "query": query, "hit_count": len(results)} return {"search_results": results, "log_events": [log]} async def _rag_generate(self, state: AgentState) -> Dict[str, Any]: context = "\n\n".join([doc.get("content", "") for doc in state["search_results"] or []]) user_question = state["messages"][-1].content prompt = ChatPromptTemplate.from_messages([ ("system", "你是一個專業友善的業務助理,請根據以下資料回答問題:\n{context}"), ("human", user_question), ]) chain = prompt | self.rag_llm response = await chain.ainvoke({"context": context}) log = {"step": "rag_generate", "context_length": len(context)} return { "messages": [response], "log_events": [log], } async def _reject(self, state: AgentState) -> Dict[str, Any]: reply = "抱歉,我無法回答這個問題。" return {"messages": [AIMessage(content=reply)]} # ====================== 輔助方法 ====================== def _build_query(self, entities: dict, question: str) -> str: # 你自己實作 query 組合邏輯 return question