1. LangGraph介绍
官方文档:点此跳转
正如数据结构中的链表和图的关系,传统的AI工作流为链式调用,这就好比一个链表,它很直观易懂,但是局限性太强,如果想要基于某个条件选择不同的Agent进行工作,链式工作流就捉襟见肘了,如果当我在某处触发某个条件需要回溯到先前的节点,链式工作流更是无法实现。
而LangGraph提供了一种图结构,将各个节点以图的形式组织起来,从而实现更灵活的调用方式。
与图类似的,LangGraph也有节点(Node)和边(Edge)两个概念,节点不一定是Agent,也可以是任意对象,如LLM、工作流、以及可运行的如函数等,边则表示两个节点组成了一个工作流,注意边是有向边,即有明确的调用顺序。
与图不同的是,LangGraph还有几个新概念,Action和Router,严格来说,Action和Router都是节点,但是它们有特殊的含义,Action表示一个可运行的函数,Router表示一个选择器,它根据某些条件选择下一个节点。
简单来说,Action表示的是一个节点的业务逻辑,Router则是一个特殊的节点,它包含的是路由逻辑而非业务逻辑,用炒菜来举个例子,假设我们有一个炒菜的Agent,它需要调用不同的工具和食材,那么这些工具和食材就是节点,而炒菜的过程就是边,而Action就是具体的炒菜步骤,Router就是选择食材和工具的步骤。
在LangGraph框架中,Router使用if...else的形式来决定路径,主要通过以下三种方式实现:
- 提示词工程:指示大模型以特定格式作出回应
- 输出解析器:使用后处理从大模型响应中提取结构化数据
- 工具调用:利用大模型的内置工具调用功能来生成结构化输出
2. LangGraph安装
在此我使用uv来进行python的包管理。
uv init testcd testuv add langgraph如果是pip,运行pip install langgraph即可
3. LangGraph state.py 源码解读
💡 通俗理解:这个文件就像是 LangGraph 的”大脑”,它定义了如何构建和执行一个工作流图。想象一下,你有一个复杂的任务需要多个步骤完成,每个步骤(节点)都需要从共享的”工作台”(状态)上读取信息,处理完后把结果放回工作台。这个文件就是定义这个工作台和工作流程的规则。
目录
导入部分
from __future__ import annotations# 启用延迟类型注解评估,允许在类型提示中使用前向引用
import inspect# 用于检查函数签名和参数信息
import logging# 用于日志记录
import typing# 用于类型检查和类型操作
import warnings# 用于发出警告信息
from collections import defaultdict# 提供默认值字典,用于自动创建缺失的键
from collections.abc import Awaitable, Hashable, Sequence# 抽象基类:Awaitable(可等待对象)、Hashable(可哈希对象)、Sequence(序列)
from functools import partial# 用于创建偏函数(部分应用函数)
from inspect import isclass, isfunction, ismethod, signature# isclass: 检查是否为类# isfunction: 检查是否为函数# ismethod: 检查是否为方法# signature: 获取函数签名
from types import FunctionType# 函数类型
from typing import ( Any, # 任意类型 Callable, # 可调用对象类型 Generic, # 泛型基类 Literal, # 字面量类型 NamedTuple, # 命名元组 Protocol, # 协议类型(结构子类型) Union, # 联合类型 cast, # 类型转换函数 get_args, # 获取泛型参数 get_origin, # 获取泛型原始类型 get_type_hints, # 获取类型提示 overload, # 函数重载装饰器)
from langchain_core.runnables import Runnable, RunnableConfig# Runnable: 可运行对象接口# RunnableConfig: 运行配置对象
from pydantic import BaseModel, TypeAdapter# BaseModel: Pydantic 基础模型类# TypeAdapter: 类型适配器,用于类型转换和验证
from typing_extensions import Self, TypeAlias, Unpack, is_typeddict# Self: 表示当前类自身的类型# TypeAlias: 类型别名# Unpack: 解包操作符# is_typeddict: 检查是否为 TypedDict
from langgraph._typing import UNSET, DeprecatedKwargs# UNSET: 未设置值的标记# DeprecatedKwargs: 已弃用的关键字参数类型
from langgraph.cache.base import BaseCache# 缓存基类
from langgraph.channels.base import BaseChannel# 通道基类,用于节点间通信# 💡 通俗理解:通道就像是一个"共享邮箱",节点可以把数据放进去,其他节点可以从里面读取。每个状态字段都有自己的"邮箱"(通道)。
from langgraph.channels.binop import BinaryOperatorAggregate# 二元操作符聚合通道
from langgraph.channels.ephemeral_value import EphemeralValue# 临时值通道
from langgraph.channels.last_value import LastValue, LastValueAfterFinish# LastValue: 最后值通道# LastValueAfterFinish: 完成后最后值通道# 💡 通俗理解:LastValue 就像是一个"黑板",每次写入都会覆盖之前的内容,只保留最新的值。LastValueAfterFinish 则是等所有节点都完成后才更新。
from langgraph.channels.named_barrier_value import ( NamedBarrierValue, NamedBarrierValueAfterFinish,)# NamedBarrierValue: 命名屏障值通道(用于等待多个节点完成)# NamedBarrierValueAfterFinish: 完成后命名屏障值通道# 💡 通俗理解:这就像是一个"集合点",多个节点必须都到达这里,才能继续执行下一个节点。就像团队开会,必须所有人都到齐才能开始。
from langgraph.checkpoint.base import Checkpoint# 检查点基类,用于保存和恢复图执行状态# 💡 通俗理解:检查点就像游戏存档,可以保存当前进度,如果出错了可以回到之前的状态继续执行,不用从头开始。
from langgraph.constants import ( EMPTY_SEQ, # 空序列常量 END, # 结束节点标识 INTERRUPT, # 中断标识 MISSING, # 缺失值标识 NS_END, # 命名空间结束符 NS_SEP, # 命名空间分隔符 START, # 开始节点标识 TAG_HIDDEN, # 隐藏标签 TASKS, # 任务标识)
from langgraph.errors import ( ErrorCode, # 错误代码 InvalidUpdateError, # 无效更新错误 ParentCommand, # 父命令异常 create_error_message, # 创建错误消息)
from langgraph.graph.branch import Branch# 分支类,用于条件边
from langgraph.managed.base import ( ManagedValueSpec, # 托管值规范 is_managed_value, # 检查是否为托管值)
from langgraph.pregel import Pregel# Pregel 图执行引擎基类# 💡 通俗理解:Pregel 是 Google 提出的图计算模型,LangGraph 用它来实际执行图。就像是一个"调度器",负责决定哪个节点什么时候执行,如何传递数据。
from langgraph.pregel.read import ChannelRead, PregelNode# ChannelRead: 通道读取器# PregelNode: Pregel 节点
from langgraph.pregel.write import ( ChannelWrite, # 通道写入器 ChannelWriteEntry, # 通道写入条目 ChannelWriteTupleEntry, # 通道写入元组条目)
from langgraph.store.base import BaseStore# 存储基类
from langgraph.types import ( All, # 全部标识 CachePolicy, # 缓存策略 Checkpointer, # 检查点器 Command, # 命令类型 RetryPolicy, # 重试策略 Send, # 发送操作 StreamWriter, # 流写入器)
from langgraph.typing import InputT, OutputT, StateT, StateT_contra# InputT: 输入类型变量# OutputT: 输出类型变量# StateT: 状态类型变量# StateT_contra: 状态类型逆变变量
from langgraph.utils.fields import ( get_cached_annotated_keys, # 获取缓存的注解键 get_field_default, # 获取字段默认值 get_update_as_tuples, # 获取更新元组)
from langgraph.utils.pydantic import create_model# 创建 Pydantic 模型
from langgraph.utils.runnable import coerce_to_runnable# 强制转换为可运行对象
from langgraph.warnings import LangGraphDeprecatedSinceV05# LangGraph 弃用警告
logger = logging.getLogger(__name__)# 创建日志记录器工具函数
_warn_invalid_state_schema
def _warn_invalid_state_schema(schema: type[Any] | Any) -> None: """ 警告无效的状态模式。
如果模式不是类型,且没有泛型参数,则发出警告。 这有助于用户发现状态模式定义错误。 """ if isinstance(schema, type): # 如果是类型,直接返回(有效) return if typing.get_args(schema): # 如果有泛型参数(如 Annotated),也认为是有效的 return # 否则发出警告 warnings.warn( f"Invalid state_schema: {schema}. Expected a type or Annotated[type, reducer]. " "Please provide a valid schema to ensure correct updates.\n" " See: https://langchain-ai.github.io/langgraph/reference/graphs/#stategraph" )协议定义
这些协议定义了节点函数的不同签名变体,支持不同的参数组合:
💡 通俗理解:协议就像是一份”合同”,规定了节点函数可以有哪些不同的”签名”(参数组合)。有些节点只需要状态,有些还需要配置、写入器或存储。就像不同的工作岗位有不同的职责要求一样。
_StateNode
class _StateNode(Protocol[StateT_contra]): """最基本的节点协议:只接受状态参数""" def __call__(self, state: StateT_contra) -> Any: ..._NodeWithConfig
class _NodeWithConfig(Protocol[StateT_contra]): """带配置的节点协议:接受状态和配置参数""" def __call__(self, state: StateT_contra, config: RunnableConfig) -> Any: ..._NodeWithWriter
class _NodeWithWriter(Protocol[StateT_contra]): """带写入器的节点协议:接受状态和流写入器""" def __call__(self, state: StateT_contra, *, writer: StreamWriter) -> Any: ..._NodeWithStore
class _NodeWithStore(Protocol[StateT_contra]): """带存储的节点协议:接受状态和存储对象""" def __call__(self, state: StateT_contra, *, store: BaseStore) -> Any: ..._NodeWithWriterStore
class _NodeWithWriterStore(Protocol[StateT_contra]): """带写入器和存储的节点协议""" def __call__( self, state: StateT_contra, *, writer: StreamWriter, store: BaseStore ) -> Any: ..._NodeWithConfigWriter
class _NodeWithConfigWriter(Protocol[StateT_contra]): """带配置和写入器的节点协议""" def __call__( self, state: StateT_contra, *, config: RunnableConfig, writer: StreamWriter ) -> Any: ..._NodeWithConfigStore
class _NodeWithConfigStore(Protocol[StateT_contra]): """带配置和存储的节点协议""" def __call__( self, state: StateT_contra, *, config: RunnableConfig, store: BaseStore ) -> Any: ..._NodeWithConfigWriterStore
class _NodeWithConfigWriterStore(Protocol[StateT_contra]): """带配置、写入器和存储的节点协议(最完整的签名)""" def __call__( self, state: StateT_contra, *, config: RunnableConfig, writer: StreamWriter, store: BaseStore, ) -> Any: ...类型别名与辅助函数
StateNode 类型别名
StateNode: TypeAlias = Union[ _StateNode[StateT_contra], # 基础节点 _NodeWithConfig[StateT_contra], # 带配置 _NodeWithWriter[StateT_contra], # 带写入器 _NodeWithStore[StateT_contra], # 带存储 _NodeWithWriterStore[StateT_contra], # 带写入器和存储 _NodeWithConfigWriter[StateT_contra], # 带配置和写入器 _NodeWithConfigStore[StateT_contra], # 带配置和存储 _NodeWithConfigWriterStore[StateT_contra], # 带所有参数 Runnable[StateT_contra, Any], # 可运行对象]说明:StateNode 是所有支持的节点类型的联合类型,允许节点函数有不同的签名。
💡 通俗理解:
StateNode就像是一个”万能接口”,不管你写节点函数时用了哪些参数(只要符合其中一种协议),都可以被 LangGraph 识别和使用。这给了开发者很大的灵活性。
_get_node_name 函数
def _get_node_name(node: StateNode) -> str: """ 获取节点名称。
尝试从节点的 __name__ 属性获取,如果没有则使用类名。 如果都获取不到,抛出类型错误。 """ try: return getattr(node, "__name__", node.__class__.__name__) except AttributeError: raise TypeError(f"Unsupported node type: {type(node)}")StateNodeSpec 命名元组
💡 通俗理解:
StateNodeSpec就像是节点的”身份证”,记录了节点的所有信息:它是什么(runnable)、需要什么输入(input)、失败后怎么办(retry_policy)、结果可以缓存吗(cache_policy)、执行完后可以去哪里(ends)等等。
class StateNodeSpec(NamedTuple): """ 节点规范,存储节点的所有配置信息。
字段说明: - runnable: 节点可运行对象(函数或 Runnable) - metadata: 节点元数据字典(可选) - input: 节点输入模式类型 - retry_policy: 重试策略(可选,可以是单个策略或策略序列) - cache_policy: 缓存策略(可选) - ends: 节点可以路由到的目标节点(可选) 可以是元组(节点名列表)或字典(节点名到标签的映射) - defer: 是否延迟执行(直到运行即将结束时才执行) """ runnable: StateNode metadata: dict[str, Any] | None input: type[Any] retry_policy: RetryPolicy | Sequence[RetryPolicy] | None cache_policy: CachePolicy | None ends: tuple[str, ...] | dict[str, str] | None = EMPTY_SEQ defer: bool = FalseStateGraph 类
StateGraph 是 LangGraph 的核心类,用于构建基于状态的有向图。
💡 通俗理解:
StateGraph就像是一个”建筑图纸”,你用它来设计工作流程。你可以添加节点(工作步骤)、连接边(执行顺序)、设置条件分支(根据情况选择路径)。设计完成后,调用compile()方法把它”建造”成可以实际运行的图。
类属性
edges: set[tuple[str, str]]# 边的集合,每个边是 (起始节点, 结束节点) 的元组
nodes: dict[str, StateNodeSpec]# 节点字典,键是节点名,值是节点规范
branches: defaultdict[str, dict[str, Branch]]# 分支字典,键是源节点名,值是条件名到分支对象的映射
channels: dict[str, BaseChannel]# 通道字典,键是状态键名,值是通道对象# 💡 通俗理解:每个状态字段都有自己的"通道"(邮箱),节点通过通道读写数据
managed: dict[str, ManagedValueSpec]# 托管值字典,键是状态键名,值是托管值规范
schemas: dict[type[Any], dict[str, BaseChannel | ManagedValueSpec]]# 模式字典,键是模式类型,值是键到通道或托管值的映射
waiting_edges: set[tuple[tuple[str, ...], str]]# 等待边集合,用于多源节点到单目标节点的边# 💡 通俗理解:就像"等所有人到齐再出发",多个节点必须都完成,才能执行下一个节点
compiled: bool# 是否已编译的标志
state_schema: type[StateT]# 状态模式类型
input_schema: type[InputT]# 输入模式类型
output_schema: type[OutputT]# 输出模式类型__init__ 方法
def __init__( self, state_schema: type[StateT], # 状态模式(必需) config_schema: type[Any] | None = None, # 配置模式(可选) *, input_schema: type[InputT] | None = None, # 输入模式(可选) output_schema: type[OutputT] | None = None, # 输出模式(可选) **kwargs: Unpack[DeprecatedKwargs], # 已弃用的关键字参数) -> None: """ 初始化状态图。
处理已弃用的参数(input/output),然后初始化所有数据结构。 最后添加状态、输入和输出模式。 """
> 💡 **通俗理解**:初始化就像准备一个空的工作台,设置好工作台的规格(状态模式),准备好各种"邮箱"(通道)。之后你就可以往上面添加工作步骤(节点)了。
# 处理已弃用的 input 参数 if (input_ := kwargs.get("input", UNSET)) is not UNSET: warnings.warn( "`input` is deprecated and will be removed. Please use `input_schema` instead.", category=LangGraphDeprecatedSinceV05, stacklevel=2, ) if input_schema is None: input_schema = cast(Union[type[InputT], None], input_)
# 处理已弃用的 output 参数 if (output := kwargs.get("output", UNSET)) is not UNSET: warnings.warn( "`output` is deprecated and will be removed. Please use `output_schema` instead.", category=LangGraphDeprecatedSinceV05, stacklevel=2, ) if output_schema is None: output_schema = cast(Union[type[OutputT], None], output)
# 初始化数据结构 self.nodes = {} self.edges = set() self.branches = defaultdict(dict) self.schemas = {} self.channels = {} self.managed = {} self.compiled = False self.waiting_edges = set()
# 设置模式 self.state_schema = state_schema self.input_schema = cast(type[InputT], input_schema or state_schema) self.output_schema = cast(type[OutputT], output_schema or state_schema) self.config_schema = config_schema
# 添加模式(这会解析模式并创建通道) self._add_schema(self.state_schema) self._add_schema(self.input_schema, allow_managed=False) self._add_schema(self.output_schema, allow_managed=False)_all_edges 属性
@propertydef _all_edges(self) -> set[tuple[str, str]]: """ 获取所有边的集合(包括普通边和等待边)。
等待边会被展开为多个普通边。 """ return self.edges | { (start, end) for starts, end in self.waiting_edges for start in starts }_add_schema 方法
def _add_schema(self, schema: type[Any], /, allow_managed: bool = True) -> None: """ 添加模式到图中。
解析模式中的字段,创建相应的通道和托管值。 如果模式已存在,则跳过。
参数: - schema: 要添加的模式类型 - allow_managed: 是否允许托管值(输入/输出模式不允许) """
> 💡 **通俗理解**:模式就像是一个"表格模板",定义了状态有哪些字段、每个字段是什么类型。这个方法会为每个字段创建一个"邮箱"(通道),节点可以通过这些邮箱传递数据。`allow_managed=False` 的意思是:输入/输出模式不允许使用"托管值"(一种特殊的通道类型,由系统自动管理)。
if schema not in self.schemas: # 警告无效的模式 _warn_invalid_state_schema(schema) # 获取通道、托管值和类型提示 channels, managed, type_hints = _get_channels(schema)
# 检查托管值是否允许 if managed and not allow_managed: names = ", ".join(managed) schema_name = getattr(schema, "__name__", "") raise ValueError( f"Invalid managed channels detected in {schema_name}: {names}." " Managed channels are not permitted in Input/Output schema." )
# 保存模式信息 self.schemas[schema] = {**channels, **managed}
# 添加通道 for key, channel in channels.items(): if key in self.channels: if self.channels[key] != channel: # 如果通道类型不同,且不是 LastValue,则报错 if isinstance(channel, LastValue): pass # LastValue 可以兼容 else: raise ValueError( f"Channel '{key}' already exists with a different type" ) else: self.channels[key] = channel
# 添加托管值 for key, managed in managed.items(): if key in self.managed: if self.managed[key] != managed: raise ValueError( f"Managed value '{key}' already exists with a different type" ) else: self.managed[key] = managedadd_node 方法
add_node 方法有两个重载版本和一个实现版本:
重载版本 1
@overloaddef add_node( self, node: StateNode[StateT], # 节点函数或可运行对象 *, defer: bool = False, metadata: dict[str, Any] | None = None, input_schema: type[Any] | None = None, retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None, cache_policy: CachePolicy | None = None, destinations: dict[str, str] | tuple[str, ...] | None = None, **kwargs: Unpack[DeprecatedKwargs],) -> Self: """添加节点,节点名从函数/可运行对象名称推断""" ...重载版本 2
@overloaddef add_node( self, node: str, # 节点名称 action: StateNode[StateT], # 节点函数或可运行对象 *, defer: bool = False, metadata: dict[str, Any] | None = None, input_schema: type[Any] | None = None, retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None, cache_policy: CachePolicy | None = None, destinations: dict[str, str] | tuple[str, ...] | None = None, **kwargs: Unpack[DeprecatedKwargs],) -> Self: """添加节点,显式指定节点名""" ...实现版本
def add_node( self, node: str | StateNode[StateT], action: StateNode[StateT] | None = None, *, defer: bool = False, metadata: dict[str, Any] | None = None, input_schema: type[Any] | None = None, retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None, cache_policy: CachePolicy | None = None, destinations: dict[str, str] | tuple[str, ...] | None = None, **kwargs: Unpack[DeprecatedKwargs],) -> Self: """ 添加节点到图中。
主要步骤: 1. 处理已弃用的参数 2. 确定节点名称和动作 3. 验证节点名称 4. 推断输入模式和目标节点(从类型提示) 5. 创建节点规范并添加到图中 """
> 💡 **通俗理解**:添加节点就像是在工作流程中添加一个工作步骤。这个方法会: > - 检查节点名是否合法(不能是保留字,不能重复) > - 自动识别节点函数的类型提示,推断它需要什么输入、可以输出到哪里 > - 创建一个"身份证"(StateNodeSpec)记录这个节点的所有信息 > - 把节点添加到图中
# 处理已弃用的 retry 参数 if (retry := kwargs.get("retry", UNSET)) is not UNSET: warnings.warn( "`retry` is deprecated and will be removed. Please use `retry_policy` instead.", category=LangGraphDeprecatedSinceV05, ) if retry_policy is None: retry_policy = retry
# 处理已弃用的 input 参数 if (input_ := kwargs.get("input", UNSET)) is not UNSET: warnings.warn( "`input` is deprecated and will be removed. Please use `input_schema` instead.", category=LangGraphDeprecatedSinceV05, ) if input_schema is None: input_schema = cast(Union[type[InputT], None], input_)
# 确定节点名称和动作 if not isinstance(node, str): action = node if isinstance(action, Runnable): node = action.get_name() else: node = getattr(action, "__name__", action.__class__.__name__) if node is None: raise ValueError( "Node name must be provided if action is not a function" )
# 如果图已编译,发出警告 if self.compiled: logger.warning( "Adding a node to a graph that has already been compiled. This will " "not be reflected in the compiled graph." )
# 再次检查节点名称(处理边界情况) if not isinstance(node, str): action = node node = cast(str, getattr(action, "name", getattr(action, "__name__", None))) if node is None: raise ValueError( "Node name must be provided if action is not a function" )
# 验证动作存在 if action is None: raise RuntimeError
# 验证节点名唯一性 if node in self.nodes: raise ValueError(f"Node `{node}` already present.")
# 验证节点名不是保留字 if node == END or node == START: raise ValueError(f"Node `{node}` is reserved.")
# 验证节点名不包含保留字符 for character in (NS_SEP, NS_END): if character in node: raise ValueError( f"'{character}' is a reserved character and is not allowed in the node names." )
# 尝试从类型提示推断目标节点 ends: tuple[str, ...] | dict[str, str] = EMPTY_SEQ try: if ( isfunction(action) or ismethod(action) or ismethod(getattr(action, "__call__", None)) ) and ( hints := get_type_hints(getattr(action, "__call__")) or get_type_hints(action) ): # 推断输入模式 if input_schema is None: first_parameter_name = next( iter( inspect.signature( cast(FunctionType, action) ).parameters.keys() ) ) if input_hint := hints.get(first_parameter_name): if isinstance(input_hint, type) and get_type_hints(input_hint): input_schema = input_hint
# 推断返回类型中的目标节点 if rtn := hints.get("return"): # 处理 Union 类型 rtn_origin = get_origin(rtn) if rtn_origin is Union: rtn_args = get_args(rtn) # 查找 Command 类型 for arg in rtn_args: arg_origin = get_origin(arg) if arg_origin is Command: rtn = arg rtn_origin = arg_origin break
# 检查是否为 Command 类型 if ( rtn_origin is Command and (rargs := get_args(rtn)) and get_origin(rargs[0]) is Literal and (vals := get_args(rargs[0])) ): ends = vals except (NameError, TypeError, StopIteration): pass
# 如果提供了 destinations,使用它 if destinations is not None: ends = destinations
# 添加输入模式(如果指定) if input_schema is not None: self._add_schema(input_schema)
# 创建节点规范并添加到图中 self.nodes[node] = StateNodeSpec( coerce_to_runnable(action, name=node, trace=False), metadata, input=input_schema or self.state_schema, retry_policy=retry_policy, cache_policy=cache_policy, ends=ends, defer=defer, ) return selfadd_edge 方法
def add_edge(self, start_key: str | list[str], end_key: str) -> Self: """ 添加有向边到图中。
支持两种模式: 1. 单源到单目标:start_key 是字符串 2. 多源到单目标:start_key 是字符串列表(等待所有源节点完成)
参数: - start_key: 起始节点(或节点列表) - end_key: 结束节点
返回: - Self: 允许方法链式调用 """
> 💡 **通俗理解**:添加边就像是在工作流程中画箭头,表示"先做这个,再做那个"。 > - 单源边:A → B(A 完成后执行 B) > - 多源边:[A, B, C] → D(A、B、C 都完成后才执行 D,就像等所有同事都到齐再开会)
if self.compiled: logger.warning( "Adding an edge to a graph that has already been compiled. This will " "not be reflected in the compiled graph." )
# 处理单源情况 if isinstance(start_key, str): if start_key == END: raise ValueError("END cannot be a start node") if end_key == START: raise ValueError("START cannot be an end node")
# 验证(仅对非 StateGraph 图) if not hasattr(self, "channels") and start_key in set( start for start, _ in self.edges ): raise ValueError( f"Already found path for node '{start_key}'.\n" "For multiple edges, use StateGraph with an Annotated state key." )
self.edges.add((start_key, end_key)) return self
# 处理多源情况 for start in start_key: if start == END: raise ValueError("END cannot be a start node") if start not in self.nodes: raise ValueError(f"Need to add_node `{start}` first") if end_key == START: raise ValueError("START cannot be an end node") if end_key != END and end_key not in self.nodes: raise ValueError(f"Need to add_node `{end_key}` first")
# 添加到等待边集合 self.waiting_edges.add((tuple(start_key), end_key)) return selfadd_conditional_edges 方法
def add_conditional_edges( self, source: str, # 源节点 path: Callable[..., Hashable | list[Hashable]] | Callable[..., Awaitable[Hashable | list[Hashable]]] | Runnable[Any, Hashable | list[Hashable]], # 路径函数 path_map: dict[Hashable, str] | list[str] | None = None, # 路径映射) -> Self: """ 添加条件边。
条件边允许根据状态动态选择下一个节点。
参数: - source: 源节点名 - path: 路径函数,根据状态返回下一个节点名(或节点名列表) - path_map: 可选的路径映射,将路径函数的返回值映射到节点名
返回: - Self: 允许方法链式调用 """
> 💡 **通俗理解**:条件边就像是一个"智能路口",根据当前情况(状态)决定走哪条路。 > - `path` 函数就像是一个"导航员",查看当前状态后决定下一步去哪里 > - `path_map` 就像一个"翻译表",把导航员的决定(可能是数字、字符串)翻译成实际的节点名 > - 例如:如果用户输入是问题,就去"回答节点";如果是命令,就去"执行节点"
if self.compiled: logger.warning( "Adding an edge to a graph that has already been compiled. This will " "not be reflected in the compiled graph." )
# 将路径函数转换为可运行对象 path = coerce_to_runnable(path, name=None, trace=True) name = path.name or "condition"
# 验证条件名唯一性 if name in self.branches[source]: raise ValueError( f"Branch with name `{path.name}` already exists for node `{source}`" )
# 创建分支并保存 self.branches[source][name] = Branch.from_path(path, path_map, True)
# 如果分支有输入模式,添加它 if schema := self.branches[source][name].input_schema: self._add_schema(schema) return selfadd_sequence 方法
def add_sequence( self, nodes: Sequence[StateNode[StateT] | tuple[str, StateNode[StateT]]],) -> Self: """ 添加节点序列。
按顺序添加多个节点,并在它们之间创建边。
参数: - nodes: 节点序列,可以是 StateNode 或 (名称, StateNode) 元组
返回: - Self: 允许方法链式调用 """ if len(nodes) < 1: raise ValueError("Sequence requires at least one node.")
previous_name: str | None = None for node in nodes: # 确定节点名称 if isinstance(node, tuple) and len(node) == 2: name, node = node else: name = _get_node_name(node)
# 验证节点名唯一性 if name in self.nodes: raise ValueError( f"Node names must be unique: node with the name '{name}' already exists. " "If you need to use two different runnables/callables with the same name (for example, using `lambda`), please provide them as tuples (name, runnable/callable)." )
# 添加节点 self.add_node(name, node)
# 如果不是第一个节点,添加边 if previous_name is not None: self.add_edge(previous_name, name)
previous_name = name
return selfset_entry_point 方法
def set_entry_point(self, key: str) -> Self: """ 设置图的入口点。
等价于调用 add_edge(START, key)。
参数: - key: 入口节点名
返回: - Self: 允许方法链式调用 """ return self.add_edge(START, key)set_conditional_entry_point 方法
def set_conditional_entry_point( self, path: Callable[..., Hashable | list[Hashable]] | Callable[..., Awaitable[Hashable | list[Hashable]]] | Runnable[Any, Hashable | list[Hashable]], path_map: dict[Hashable, str] | list[str] | None = None,) -> Self: """ 设置条件入口点。
允许根据输入动态选择第一个节点。
参数: - path: 路径函数 - path_map: 路径映射
返回: - Self: 允许方法链式调用 """ return self.add_conditional_edges(START, path, path_map)set_finish_point 方法
def set_finish_point(self, key: str) -> Self: """ 设置图的结束点。
等价于调用 add_edge(key, END)。
参数: - key: 结束节点名
返回: - Self: 允许方法链式调用 """ return self.add_edge(key, END)validate 方法
def validate(self, interrupt: Sequence[str] | None = None) -> Self: """ 验证图的结构。
检查: 1. 所有边的源节点都存在 2. 图有入口点 3. 所有边的目标节点都存在 4. 中断节点存在
参数: - interrupt: 中断节点列表(可选)
返回: - Self: 允许方法链式调用
副作用: - 设置 self.compiled = True """
> 💡 **通俗理解**:验证就像是在"施工前检查图纸",确保: > - 所有箭头都指向存在的节点(不能指向空) > - 有起点(不能没有入口) > - 所有目标节点都存在(不能指向不存在的节点) > - 如果设置了中断点,这些节点必须存在 > 如果检查通过,就标记为"已验证",可以编译了。
# 收集所有源节点 all_sources = {src for src, _ in self._all_edges} for start, branches in self.branches.items(): all_sources.add(start) for name, spec in self.nodes.items(): if spec.ends: all_sources.add(name)
# 验证源节点 for source in all_sources: if source not in self.nodes and source != START: raise ValueError(f"Found edge starting at unknown node '{source}'")
# 验证有入口点 if START not in all_sources: raise ValueError( "Graph must have an entrypoint: add at least one edge from START to another node" )
# 收集所有目标节点 all_targets = {end for _, end in self._all_edges} for start, branches in self.branches.items(): for cond, branch in branches.items(): if branch.ends is not None: for end in branch.ends.values(): if end not in self.nodes and end != END: raise ValueError( f"At '{start}' node, '{cond}' branch found unknown target '{end}'" ) all_targets.add(end) else: # 如果没有指定目标,可能路由到任何节点 all_targets.add(END) for node in self.nodes: if node != start: all_targets.add(node) for name, spec in self.nodes.items(): if spec.ends: all_targets.update(spec.ends)
# 验证目标节点 for target in all_targets: if target not in self.nodes and target != END: raise ValueError(f"Found edge ending at unknown node `{target}`")
# 验证中断节点 if interrupt: for node in interrupt: if node not in self.nodes: raise ValueError(f"Interrupt node `{node}` not found")
self.compiled = True return selfcompile 方法
def compile( self, checkpointer: Checkpointer = None, # 检查点器 *, cache: BaseCache | None = None, # 缓存 store: BaseStore | None = None, # 存储 interrupt_before: All | list[str] | None = None, # 执行前中断的节点 interrupt_after: All | list[str] | None = None, # 执行后中断的节点 debug: bool = False, # 调试模式 name: str | None = None, # 图名称) -> CompiledStateGraph[StateT, InputT, OutputT]: """ 编译状态图为 CompiledStateGraph。
编译后的图实现了 Runnable 接口,可以调用、流式传输、批处理和异步运行。
主要步骤: 1. 验证图结构 2. 准备输出通道和流通道 3. 创建 CompiledStateGraph 实例 4. 附加所有节点、边和分支
返回: - CompiledStateGraph: 编译后的状态图 """
> 💡 **通俗理解**:编译就像把"设计图纸"变成"实际可运行的机器"。 > - 先检查图纸有没有问题(validate) > - 准备好输出接口(output_channels)和流式接口(stream_channels) > - 创建一个"执行引擎"(CompiledStateGraph) > - 把所有节点、边、分支都"安装"到引擎上 > 编译完成后,你就可以调用 `invoke()` 或 `stream()` 来运行这个图了!
# 设置默认值 interrupt_before = interrupt_before or [] interrupt_after = interrupt_after or []
# 验证图 self.validate( interrupt=( (interrupt_before if interrupt_before != "*" else []) + interrupt_after if interrupt_after != "*" else [] ) )
# 准备输出通道 output_channels = ( "__root__" if len(self.schemas[self.output_schema]) == 1 and "__root__" in self.schemas[self.output_schema] else [ key for key, val in self.schemas[self.output_schema].items() if not is_managed_value(val) ] )
# 准备流通道 stream_channels = ( "__root__" if len(self.channels) == 1 and "__root__" in self.channels else [ key for key, val in self.channels.items() if not is_managed_value(val) ] )
# 创建编译后的图 compiled = CompiledStateGraph[StateT, InputT, OutputT]( builder=self, schema_to_mapper={}, config_type=self.config_schema, nodes={}, channels={ **self.channels, **self.managed, START: EphemeralValue(self.input_schema), }, input_channels=START, stream_mode="updates", output_channels=output_channels, stream_channels=stream_channels, checkpointer=checkpointer, interrupt_before_nodes=interrupt_before, interrupt_after_nodes=interrupt_after, auto_validate=False, debug=debug, store=store, cache=cache, name=name or "LangGraph", )
# 附加 START 节点 compiled.attach_node(START, None)
# 附加所有节点 for key, node in self.nodes.items(): compiled.attach_node(key, node)
# 附加所有边 for start, end in self.edges: compiled.attach_edge(start, end)
# 附加所有等待边 for starts, end in self.waiting_edges: compiled.attach_edge(starts, end)
# 附加所有分支 for start, branches in self.branches.items(): for name, branch in branches.items(): compiled.attach_branch(start, name, branch)
return compiled.validate()CompiledStateGraph 类
CompiledStateGraph 是编译后的状态图,继承自 Pregel,实现了实际的图执行逻辑。
💡 通俗理解:
CompiledStateGraph就是”实际运行的机器”,而StateGraph只是”设计图纸”。
StateGraph:你用它来设计工作流程(添加节点、边等)CompiledStateGraph:编译后得到的可执行对象,你可以调用它来实际运行工作流程 就像写代码和运行代码的区别一样。
类属性
class CompiledStateGraph( Pregel[StateT, InputT, OutputT], Generic[StateT, InputT, OutputT]): """ 编译后的状态图。
继承自 Pregel,实现了图的执行逻辑。 """ builder: StateGraph[StateT, InputT, OutputT] # 构建器引用,保存原始的状态图定义
schema_to_mapper: dict[type[Any], Callable[[Any], Any] | None] # 模式到映射函数的字典,用于将状态字典转换为模式对象__init__ 方法
def __init__( self, *, builder: StateGraph[StateT, InputT, OutputT], schema_to_mapper: dict[type[Any], Callable[[Any], Any] | None], **kwargs: Any,) -> None: """ 初始化编译后的状态图。
调用父类 Pregel 的初始化方法,然后保存构建器和模式映射器。 """ super().__init__(**kwargs) self.builder = builder self.schema_to_mapper = schema_to_mapperget_input_jsonschema 方法
def get_input_jsonschema( self, config: RunnableConfig | None = None) -> dict[str, Any]: """ 获取输入的 JSON Schema。
用于 API 文档生成和验证。 """ return _get_json_schema( typ=self.builder.input_schema, schemas=self.builder.schemas, channels=self.builder.channels, name=self.get_name("Input"), )get_output_jsonschema 方法
def get_output_jsonschema( self, config: RunnableConfig | None = None) -> dict[str, Any]: """ 获取输出的 JSON Schema。
用于 API 文档生成和验证。 """ return _get_json_schema( typ=self.builder.output_schema, schemas=self.builder.schemas, channels=self.builder.channels, name=self.get_name("Output"), )attach_node 方法
def attach_node(self, key: str, node: StateNodeSpec | None) -> None: """ 将节点附加到编译后的图中。
这是编译过程的关键步骤,将 StateGraph 的节点转换为 Pregel 节点。
主要工作: 1. 确定输出键 2. 创建更新函数(处理节点返回值) 3. 创建写入条目(状态更新和分支控制) 4. 创建 PregelNode 并添加到图中
参数: - key: 节点键(START 或节点名) - node: 节点规范(START 节点为 None) """
> 💡 **通俗理解**:这个方法把"设计图纸上的节点"转换成"实际可执行的节点"。 > - 确定节点可以更新哪些状态字段(output_keys) > - 创建一个"转换器"(_get_updates),把节点的返回值转换成状态更新 > - 创建"写入器"(write_entries),负责把更新写入状态,并控制下一步去哪里 > - 最后创建一个 PregelNode(实际执行单元)并添加到图中 > 就像把"工作步骤说明"转换成"实际的工作指令"。
# 确定输出键 if key == START: # START 节点输出输入模式的所有键 output_keys = [ k for k, v in self.builder.schemas[self.builder.input_schema].items() if not is_managed_value(v) ] else: # 普通节点输出所有通道和托管值 output_keys = list(self.builder.channels) + [ k for k, v in self.builder.managed.items() ]
def _get_updates( input: None | dict | Any, ) -> Sequence[tuple[str, Any]] | None: """ 将节点返回值转换为更新元组列表。
处理多种返回类型: - None: 无更新 - dict: 字典键值对 - Command: 命令对象 - list/tuple: 命令列表 - 带注解的对象: 使用 get_update_as_tuples """ if input is None: return None elif isinstance(input, dict): return [(k, v) for k, v in input.items() if k in output_keys] elif isinstance(input, Command): if input.graph == Command.PARENT: return None return [ (k, v) for k, v in input._update_as_tuples() if k in output_keys ] elif ( isinstance(input, (list, tuple)) and input and any(isinstance(i, Command) for i in input) ): updates: list[tuple[str, Any]] = [] for i in input: if isinstance(i, Command): if i.graph == Command.PARENT: continue updates.extend( (k, v) for k, v in i._update_as_tuples() if k in output_keys ) else: updates.extend(_get_updates(i) or ()) return updates elif (t := type(input)) and get_cached_annotated_keys(t): return get_update_as_tuples(input, output_keys) else: msg = create_error_message( message=f"Expected dict, got {input}", error_code=ErrorCode.INVALID_GRAPH_NODE_RETURN_VALUE, ) raise InvalidUpdateError(msg)
# 创建写入条目 write_entries: tuple[ChannelWriteEntry | ChannelWriteTupleEntry, ...] = ( ChannelWriteTupleEntry( mapper=_get_root if output_keys == ["__root__"] else _get_updates ), ChannelWriteTupleEntry( mapper=_control_branch, static=_control_static(node.ends) if node is not None and node.ends is not None else None, ), )
# 添加节点 if key == START: # START 节点 self.nodes[key] = PregelNode( tags=[TAG_HIDDEN], triggers=[START], channels=START, writers=[ChannelWrite(write_entries)], ) elif node is not None: # 普通节点 input_schema = node.input if node else self.builder.state_schema input_channels = list(self.builder.schemas[input_schema]) is_single_input = len(input_channels) == 1 and "__root__" in input_channels
# 获取或创建映射器 if input_schema in self.schema_to_mapper: mapper = self.schema_to_mapper[input_schema] else: mapper = _pick_mapper(input_channels, input_schema) self.schema_to_mapper[input_schema] = mapper
# 创建分支通道 branch_channel = CHANNEL_BRANCH_TO.format(key) self.channels[branch_channel] = ( LastValueAfterFinish(Any) if node.defer else EphemeralValue(Any, guard=False) )
# 创建 PregelNode self.nodes[key] = PregelNode( triggers=[branch_channel], channels=("__root__" if is_single_input else input_channels), mapper=mapper, writers=[ChannelWrite(write_entries)], metadata=node.metadata, retry_policy=node.retry_policy, cache_policy=node.cache_policy, bound=node.runnable, ) else: raise RuntimeErrorattach_edge 方法
def attach_edge(self, starts: str | Sequence[str], end: str) -> None: """ 将边附加到编译后的图中。
处理两种类型的边: 1. 单源边:直接写入目标节点的分支通道 2. 多源边:使用命名屏障值通道等待所有源节点完成
参数: - starts: 起始节点(或节点列表) - end: 结束节点 """
> 💡 **通俗理解**:这个方法把"设计图纸上的箭头"转换成"实际的执行路径"。 > - 单源边:A → B,直接在 A 的写入器中添加"完成后通知 B" > - 多源边:[A, B, C] → D,创建一个"集合点通道",A、B、C 都完成后才通知 D > 就像把"流程图上的箭头"转换成"实际的信号传递机制"。
if isinstance(starts, str): # 单源边:直接写入目标节点的分支通道 if end != END: self.nodes[starts].writers.append( ChannelWrite( (ChannelWriteEntry(CHANNEL_BRANCH_TO.format(end), None),) ) ) elif end != END: # 多源边:创建命名屏障值通道 channel_name = f"join:{'+'.join(starts)}:{end}"
# 注册通道 if self.builder.nodes[end].defer: # 如果目标节点延迟执行,使用 AfterFinish 版本 self.channels[channel_name] = NamedBarrierValueAfterFinish( str, set(starts) ) else: self.channels[channel_name] = NamedBarrierValue(str, set(starts))
# 订阅通道(目标节点等待) self.nodes[end].triggers.append(channel_name)
# 发布到通道(源节点写入) for start in starts: self.nodes[start].writers.append( ChannelWrite((ChannelWriteEntry(channel_name, start),)) )attach_branch 方法
def attach_branch( self, start: str, name: str, branch: Branch, *, with_reader: bool = True) -> None: """ 将分支附加到编译后的图中。
分支用于条件边,根据状态动态选择下一个节点。
参数: - start: 源节点名 - name: 分支名 - branch: 分支对象 - with_reader: 是否创建读取器(用于读取状态) """
> 💡 **通俗理解**:这个方法把"条件判断逻辑"转换成"实际的路由机制"。 > - 创建一个"读取器"(reader)来读取当前状态 > - 创建一个"写入器"(get_writes)来把分支决策转换成"通知哪个节点" > - 把分支逻辑附加到源节点上 > 就像把"如果...就..."的逻辑转换成"实际的信号路由"。
def get_writes( packets: Sequence[str | Send], static: bool = False ) -> Sequence[ChannelWriteEntry | Send]: """ 将分支返回的包转换为写入条目。
参数: - packets: 包列表(节点名或 Send 对象) - static: 是否为静态分支
返回: - 写入条目列表 """ writes = [ ( ChannelWriteEntry( p if p == END else CHANNEL_BRANCH_TO.format(p), None ) if not isinstance(p, Send) else p ) for p in packets if (True if static else p != END) ] if not writes: return [] return writes
if with_reader: # 获取模式 schema = branch.input_schema or ( self.builder.nodes[start].input if start in self.builder.nodes else self.builder.state_schema ) channels = list(self.builder.schemas[schema])
# 获取映射器 if schema in self.schema_to_mapper: mapper = self.schema_to_mapper[schema] else: mapper = _pick_mapper(channels, schema) self.schema_to_mapper[schema] = mapper
# 创建读取器 reader: Callable[[RunnableConfig], Any] | None = partial( ChannelRead.do_read, select=channels[0] if channels == ["__root__"] else channels, fresh=True, mapper=mapper, ) else: reader = None
# 附加分支发布器 self.nodes[start].writers.append(branch.run(get_writes, reader))_migrate_checkpoint 方法
def _migrate_checkpoint(self, checkpoint: Checkpoint) -> None: """ 迁移检查点到新的通道布局。
当图的内部结构改变时(如通道命名方式改变),需要迁移旧的检查点。
迁移规则: 1. start:node -> branch:to:node 2. branch:source:condition:node -> branch:to:node 3. node -> branch:to:node(对于所有目标节点)
参数: - checkpoint: 要迁移的检查点 """
> 💡 **通俗理解**:这个方法就像"游戏版本更新时的存档迁移"。 > - LangGraph 的内部实现可能会改变(比如通道的命名方式) > - 但旧的检查点(存档)还在用旧的命名方式 > - 这个方法会自动把旧的命名转换成新的命名,让旧存档可以在新版本中使用 > 就像游戏更新后,旧存档还能继续使用一样。
super()._migrate_checkpoint(checkpoint)
values = checkpoint["channel_values"] versions = checkpoint["channel_versions"] seen = checkpoint["versions_seen"]
# 空检查点不需要迁移 if not versions: return
# 当前版本(v3+)不需要迁移 if checkpoint["v"] >= 3: return
# 迁移 1: start:node -> branch:to:node for k in list(versions): if k.startswith("start:"): node = k.split(":")[1] if node not in self.nodes: continue new_k = f"branch:to:{node}" new_v = ( max(versions[new_k], versions.pop(k)) if new_k in versions else versions.pop(k) ) # 更新 seen for ss in (seen.get(node, {}), seen.get(INTERRUPT, {})): if k in ss: s = ss.pop(k) if new_k in ss: ss[new_k] = max(s, ss[new_k]) else: ss[new_k] = s # 更新值 if new_k not in values and k in values: values[new_k] = values.pop(k) # 更新版本 versions[new_k] = new_v
# 迁移 2: branch:source:condition:node -> branch:to:node for k in list(versions): if k.startswith("branch:") and k.count(":") == 3: node = k.split(":")[-1] if node not in self.nodes: continue new_k = f"branch:to:{node}" new_v = ( max(versions[new_k], versions.pop(k)) if new_k in versions else versions.pop(k) ) # 更新 seen for ss in (seen.get(node, {}), seen.get(INTERRUPT, {})): if k in ss: s = ss.pop(k) if new_k in ss: ss[new_k] = max(s, ss[new_k]) else: ss[new_k] = s # 更新值 if new_k not in values and k in values: values[new_k] = values.pop(k) # 更新版本 versions[new_k] = new_v
# 迁移 3: node -> branch:to:node(对于所有目标节点) if not set(self.nodes).isdisjoint(versions): source_to_target = defaultdict(list) for start, end in self.builder.edges: if start != START and end != END: source_to_target[start].append(end) for k in list(versions): if k == START: continue if k in self.nodes: v = versions.pop(k) c = values.pop(k, MISSING) for end in source_to_target[k]: new_k = f"branch:to:{end}" new_v = max(versions[new_k], v) if new_k in versions else v # 更新 seen for ss in (seen.get(end, {}), seen.get(INTERRUPT, {})): if k in ss: s = ss.pop(k) if new_k in ss: ss[new_k] = max(s, ss[new_k]) else: ss[new_k] = s # 更新值 if new_k not in values and c is not MISSING: values[new_k] = c # 更新版本 versions[new_k] = new_v # 弹出中断 seen if INTERRUPT in seen: seen[INTERRUPT].pop(k, MISSING)辅助函数
_pick_mapper 函数
def _pick_mapper( state_keys: Sequence[str], schema: type[Any]) -> Callable[[Any], Any] | None: """ 选择状态映射函数。
如果状态键是 __root__ 或模式是字典类型,返回 None(不需要映射)。 否则返回一个将字典转换为模式对象的函数。 """ if state_keys == ["__root__"]: return None if isclass(schema) and issubclass(schema, dict): return None return partial(_coerce_state, schema)_coerce_state 函数
def _coerce_state(schema: type[Any], input: dict[str, Any]) -> dict[str, Any]: """ 将字典强制转换为模式对象。
使用模式类的构造函数创建实例。 """ return schema(**input)_control_branch 函数
def _control_branch(value: Any) -> Sequence[tuple[str, Any]]: """ 处理分支控制值。
将 Command 对象或 Send 对象转换为写入条目。
参数: - value: 控制值(Command、Send 或列表)
返回: - 写入条目列表 """ if isinstance(value, Send): return ((TASKS, value),)
commands: list[Command] = [] if isinstance(value, Command): commands.append(value) elif isinstance(value, (list, tuple)): for cmd in value: if isinstance(cmd, Command): commands.append(cmd)
rtn: list[tuple[str, Any]] = [] for command in commands: # 如果是父命令,抛出异常 if command.graph == Command.PARENT: raise ParentCommand(command)
# 获取目标节点 goto_targets = ( [command.goto] if isinstance(command.goto, (Send, str)) else command.goto )
# 转换为写入条目 for go in goto_targets: if isinstance(go, Send): rtn.append((TASKS, go)) elif isinstance(go, str) and go != END: rtn.append((CHANNEL_BRANCH_TO.format(go), None)) return rtn_control_static 函数
def _control_static( ends: tuple[str, ...] | dict[str, str],) -> Sequence[tuple[str, Any, str | None]]: """ 创建静态分支控制条目。
用于节点规范中指定的目标节点。
参数: - ends: 目标节点(元组或字典)
返回: - 写入条目列表(包含标签) """ if isinstance(ends, dict): return [ (k if k == END else CHANNEL_BRANCH_TO.format(k), None, label) for k, label in ends.items() ] else: return [ (e if e == END else CHANNEL_BRANCH_TO.format(e), None, None) for e in ends ]_get_root 函数
def _get_root(input: Any) -> Sequence[tuple[str, Any]] | None: """ 处理根值更新。
当状态只有一个根键(__root__)时使用。
参数: - input: 输入值(Command、列表或普通值)
返回: - 更新元组列表 """ if isinstance(input, Command): if input.graph == Command.PARENT: return () return input._update_as_tuples() elif ( isinstance(input, (list, tuple)) and input and any(isinstance(i, Command) for i in input) ): updates: list[tuple[str, Any]] = [] for i in input: if isinstance(i, Command): if i.graph == Command.PARENT: continue updates.extend(i._update_as_tuples()) else: updates.append(("__root__", i)) return updates elif input is not None: return [("__root__", input)]_get_channels 函数
def _get_channels( schema: type[dict],) -> tuple[dict[str, BaseChannel], dict[str, ManagedValueSpec], dict[str, Any]]: """ 从模式中提取通道和托管值。
参数: - schema: 模式类型
返回: - (通道字典, 托管值字典, 类型提示字典) """
> 💡 **通俗理解**:这个方法就像"解析表格模板,为每个字段创建邮箱"。 > - 输入:一个状态模式(定义了有哪些字段) > - 输出:为每个字段创建的通道(邮箱)和托管值(特殊邮箱) > - 就像你定义了一个表格,系统自动为每列创建一个"数据传递通道"
# 如果没有注解,返回根通道 if not hasattr(schema, "__annotations__"): return ( {"__root__": _get_channel("__root__", schema, allow_managed=False)}, {}, {}, )
# 获取类型提示(包括额外信息) type_hints = get_type_hints(schema, include_extras=True)
# 为每个字段创建通道或托管值 all_keys = { name: _get_channel(name, typ) for name, typ in type_hints.items() if name != "__slots__" }
# 分离通道和托管值 return ( {k: v for k, v in all_keys.items() if isinstance(v, BaseChannel)}, {k: v for k, v in all_keys.items() if is_managed_value(v)}, type_hints, )_get_channel 函数
@overloaddef _get_channel( name: str, annotation: Any, *, allow_managed: Literal[False]) -> BaseChannel: ...
@overloaddef _get_channel( name: str, annotation: Any, *, allow_managed: Literal[True] = True) -> BaseChannel | ManagedValueSpec: ...
def _get_channel( name: str, annotation: Any, *, allow_managed: bool = True) -> BaseChannel | ManagedValueSpec: """ 从类型注解创建通道或托管值。
检查顺序: 1. 是否为托管值 2. 是否为显式通道 3. 是否为二元操作符聚合 4. 默认使用 LastValue
参数: - name: 字段名 - annotation: 类型注解 - allow_managed: 是否允许托管值
返回: - 通道或托管值规范 """
> 💡 **通俗理解**:这个方法就像"根据字段类型选择合适的邮箱类型"。 > - 如果字段标注了"托管值"(系统自动管理),就用托管值通道 > - 如果字段标注了特定通道类型,就用那个类型 > - 如果字段有 reducer(聚合函数),就用聚合通道(可以合并多个值) > - 否则默认用 LastValue(只保留最新值,覆盖旧值) > 就像根据不同的需求选择不同的"邮箱类型":有的邮箱只存最新邮件,有的可以合并多封邮件。
# 检查托管值 if manager := _is_field_managed_value(name, annotation): if allow_managed: return manager else: raise ValueError(f"This {annotation} not allowed in this position")
# 检查显式通道 elif channel := _is_field_channel(annotation): channel.key = name return channel
# 检查二元操作符聚合 elif channel := _is_field_binop(annotation): channel.key = name return channel
# 默认使用 LastValue fallback: LastValue = LastValue(annotation) fallback.key = name return fallback_is_field_channel 函数
def _is_field_channel(typ: type[Any]) -> BaseChannel | None: """ 检查类型注解是否为显式通道。
检查注解的元数据中是否包含 BaseChannel 实例或类。 """ if hasattr(typ, "__metadata__"): meta = typ.__metadata__ if len(meta) >= 1 and isinstance(meta[-1], BaseChannel): return meta[-1] elif len(meta) >= 1 and isclass(meta[-1]) and issubclass(meta[-1], BaseChannel): return meta[-1](typ.__origin__ if hasattr(typ, "__origin__") else typ) return None_is_field_binop 函数
def _is_field_binop(typ: type[Any]) -> BinaryOperatorAggregate | None: """ 检查类型注解是否为二元操作符聚合(reducer)。
检查注解的元数据中是否包含可调用对象,且签名符合 reducer 要求(两个位置参数)。 """ if hasattr(typ, "__metadata__"): meta = typ.__metadata__ if len(meta) >= 1 and callable(meta[-1]): sig = signature(meta[-1]) params = list(sig.parameters.values()) if ( sum( p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) for p in params ) == 2 ): return BinaryOperatorAggregate(typ, meta[-1]) else: raise ValueError( f"Invalid reducer signature. Expected (a, b) -> c. Got {sig}" ) return None_is_field_managed_value 函数
def _is_field_managed_value(name: str, typ: type[Any]) -> ManagedValueSpec | None: """ 检查类型注解是否为托管值。
检查注解的元数据中是否包含托管值装饰器。 """ if hasattr(typ, "__metadata__"): meta = typ.__metadata__ if len(meta) >= 1: decoration = get_origin(meta[-1]) or meta[-1] if is_managed_value(decoration): return decoration
return None_get_json_schema 函数
def _get_json_schema( typ: type, schemas: dict, channels: dict, name: str,) -> dict[str, Any]: """ 生成类型的 JSON Schema。
用于 API 文档和验证。
参数: - typ: 类型 - schemas: 模式字典 - channels: 通道字典 - name: Schema 名称
返回: - JSON Schema 字典 """ # Pydantic 模型 if isclass(typ) and issubclass(typ, BaseModel): return typ.model_json_schema()
# TypedDict elif is_typeddict(typ): return TypeAdapter(typ).json_schema()
# 其他类型 else: keys = list(schemas[typ].keys())
# 单根键 if len(keys) == 1 and keys[0] == "__root__": return create_model( name, root=(channels[keys[0]].UpdateType, None), ).model_json_schema()
# 多键 else: return create_model( name, field_definitions={ k: ( channels[k].UpdateType, ( get_field_default( k, channels[k].UpdateType, typ, ) ), ) for k in schemas[typ] if k in channels and isinstance(channels[k], BaseChannel) }, ).model_json_schema()常量
CHANNEL_BRANCH_TO = "branch:to:{}"# 分支通道名称格式模板# 用于创建节点分支通道的名称,例如 "branch:to:node_name"总结
state.py 文件实现了 LangGraph 的核心状态图功能:
- StateGraph 类:用于构建状态图,支持节点、边、条件边等
- CompiledStateGraph 类:编译后的图,实现实际执行逻辑
- 通道系统:通过通道实现节点间的状态共享
- 模式系统:支持 TypedDict、Pydantic 模型等多种状态模式
- 检查点迁移:支持旧版本检查点的自动迁移
关键概念:
- 节点(Node):图的执行单元,接收状态并返回部分状态更新
- 边(Edge):定义节点间的执行顺序
- 条件边(Conditional Edge):根据状态动态选择下一个节点
- 通道(Channel):节点间共享状态的机制
- 模式(Schema):定义状态的结构和类型
这个文件是 LangGraph 框架的核心,理解它有助于深入理解整个框架的工作原理。
💡 整体理解:用生活化的比喻
StateGraph(设计阶段):
- 就像画流程图,你定义有哪些步骤(节点)、步骤之间的顺序(边)、根据情况选择路径(条件边)
- 你还可以设置每个步骤的配置:失败后重试几次、结果是否缓存、执行完后可以去哪里
CompiledStateGraph(运行阶段):
- 就像把流程图转换成实际的工作流程系统
- 每个节点都有一个”邮箱”(通道)来传递数据
- 系统会自动调度:哪个节点先执行、什么时候执行、如何传递数据
通道系统:
- 就像公司里的”共享文件夹”或”公告板”
- 每个状态字段都有自己的”文件夹”(通道)
- 节点可以读取和写入这些文件夹
- 不同类型的通道有不同的行为:
LastValue:只保留最新值(覆盖旧值)NamedBarrierValue:等待多个节点都完成(集合点)BinaryOperatorAggregate:用函数合并多个值(如累加、列表合并)
执行流程:
- 你调用
graph.invoke(input)或graph.stream(input) - 系统从 START 节点开始
- 每个节点执行时:
- 从通道读取需要的状态字段
- 执行节点函数
- 把返回值写入通道(更新状态)
- 根据边或条件边决定下一个节点
- 重复步骤 3,直到到达 END 节点或没有更多节点可执行
- 返回最终状态
检查点系统:
- 就像游戏存档,可以在任意时刻保存当前状态
- 如果出错了,可以从检查点恢复,不用从头开始
- 支持版本迁移:即使 LangGraph 内部实现改变了,旧存档也能用
🎯 关键理解点
- 状态是共享的:所有节点共享同一个状态对象,通过通道来读写
- 节点只返回部分更新:节点不需要返回完整状态,只返回它要更新的字段
- 编译是必须的:
StateGraph只是设计,必须调用compile()才能运行 - 通道是核心:节点间的数据传递都通过通道,通道类型决定了数据如何合并
- 类型提示很重要:LangGraph 会从类型提示推断很多信息,写好类型提示可以让代码更清晰
希望这些解释能帮助你更好地理解 LangGraph 的工作原理!🚀
Some information may be outdated