11395 words
57 minutes
LangGraph底层源码解析

1. LangGraph介绍#

官方文档:点此跳转

正如数据结构中的链表和图的关系,传统的AI工作流为链式调用,这就好比一个链表,它很直观易懂,但是局限性太强,如果想要基于某个条件选择不同的Agent进行工作,链式工作流就捉襟见肘了,如果当我在某处触发某个条件需要回溯到先前的节点,链式工作流更是无法实现。

而LangGraph提供了一种图结构,将各个节点以图的形式组织起来,从而实现更灵活的调用方式。

与图类似的,LangGraph也有节点(Node)和边(Edge)两个概念,节点不一定是Agent,也可以是任意对象,如LLM、工作流、以及可运行的如函数等,则表示两个节点组成了一个工作流,注意边是有向边,即有明确的调用顺序。

与图不同的是,LangGraph还有几个新概念,ActionRouter,严格来说,Action和Router都是节点,但是它们有特殊的含义,Action表示一个可运行的函数,Router表示一个选择器,它根据某些条件选择下一个节点。

简单来说,Action表示的是一个节点的业务逻辑,Router则是一个特殊的节点,它包含的是路由逻辑而非业务逻辑,用炒菜来举个例子,假设我们有一个炒菜的Agent,它需要调用不同的工具和食材,那么这些工具和食材就是节点,而炒菜的过程就是边,而Action就是具体的炒菜步骤,Router就是选择食材和工具的步骤。

在LangGraph框架中,Router使用if...else的形式来决定路径,主要通过以下三种方式实现:

  • 提示词工程:指示大模型以特定格式作出回应
  • 输出解析器:使用后处理从大模型响应中提取结构化数据
  • 工具调用:利用大模型的内置工具调用功能来生成结构化输出

2. LangGraph安装#

在此我使用uv来进行python的包管理。

Terminal window
uv init test
cd test
uv add langgraph

如果是pip,运行pip install langgraph即可

3. LangGraph state.py 源码解读#

💡 通俗理解:这个文件就像是 LangGraph 的”大脑”,它定义了如何构建和执行一个工作流图。想象一下,你有一个复杂的任务需要多个步骤完成,每个步骤(节点)都需要从共享的”工作台”(状态)上读取信息,处理完后把结果放回工作台。这个文件就是定义这个工作台和工作流程的规则。

目录#

  1. 导入部分
  2. 工具函数
  3. 协议定义
  4. 类型别名与辅助函数
  5. StateNodeSpec 命名元组
  6. StateGraph 类
  7. CompiledStateGraph 类
  8. 辅助函数

导入部分#

from __future__ import annotations
# 启用延迟类型注解评估,允许在类型提示中使用前向引用
import inspect
# 用于检查函数签名和参数信息
import logging
# 用于日志记录
import typing
# 用于类型检查和类型操作
import warnings
# 用于发出警告信息
from collections import defaultdict
# 提供默认值字典,用于自动创建缺失的键
from collections.abc import Awaitable, Hashable, Sequence
# 抽象基类:Awaitable(可等待对象)、Hashable(可哈希对象)、Sequence(序列)
from functools import partial
# 用于创建偏函数(部分应用函数)
from inspect import isclass, isfunction, ismethod, signature
# isclass: 检查是否为类
# isfunction: 检查是否为函数
# ismethod: 检查是否为方法
# signature: 获取函数签名
from types import FunctionType
# 函数类型
from typing import (
Any, # 任意类型
Callable, # 可调用对象类型
Generic, # 泛型基类
Literal, # 字面量类型
NamedTuple, # 命名元组
Protocol, # 协议类型(结构子类型)
Union, # 联合类型
cast, # 类型转换函数
get_args, # 获取泛型参数
get_origin, # 获取泛型原始类型
get_type_hints, # 获取类型提示
overload, # 函数重载装饰器
)
from langchain_core.runnables import Runnable, RunnableConfig
# Runnable: 可运行对象接口
# RunnableConfig: 运行配置对象
from pydantic import BaseModel, TypeAdapter
# BaseModel: Pydantic 基础模型类
# TypeAdapter: 类型适配器,用于类型转换和验证
from typing_extensions import Self, TypeAlias, Unpack, is_typeddict
# Self: 表示当前类自身的类型
# TypeAlias: 类型别名
# Unpack: 解包操作符
# is_typeddict: 检查是否为 TypedDict
from langgraph._typing import UNSET, DeprecatedKwargs
# UNSET: 未设置值的标记
# DeprecatedKwargs: 已弃用的关键字参数类型
from langgraph.cache.base import BaseCache
# 缓存基类
from langgraph.channels.base import BaseChannel
# 通道基类,用于节点间通信
# 💡 通俗理解:通道就像是一个"共享邮箱",节点可以把数据放进去,其他节点可以从里面读取。每个状态字段都有自己的"邮箱"(通道)。
from langgraph.channels.binop import BinaryOperatorAggregate
# 二元操作符聚合通道
from langgraph.channels.ephemeral_value import EphemeralValue
# 临时值通道
from langgraph.channels.last_value import LastValue, LastValueAfterFinish
# LastValue: 最后值通道
# LastValueAfterFinish: 完成后最后值通道
# 💡 通俗理解:LastValue 就像是一个"黑板",每次写入都会覆盖之前的内容,只保留最新的值。LastValueAfterFinish 则是等所有节点都完成后才更新。
from langgraph.channels.named_barrier_value import (
NamedBarrierValue,
NamedBarrierValueAfterFinish,
)
# NamedBarrierValue: 命名屏障值通道(用于等待多个节点完成)
# NamedBarrierValueAfterFinish: 完成后命名屏障值通道
# 💡 通俗理解:这就像是一个"集合点",多个节点必须都到达这里,才能继续执行下一个节点。就像团队开会,必须所有人都到齐才能开始。
from langgraph.checkpoint.base import Checkpoint
# 检查点基类,用于保存和恢复图执行状态
# 💡 通俗理解:检查点就像游戏存档,可以保存当前进度,如果出错了可以回到之前的状态继续执行,不用从头开始。
from langgraph.constants import (
EMPTY_SEQ, # 空序列常量
END, # 结束节点标识
INTERRUPT, # 中断标识
MISSING, # 缺失值标识
NS_END, # 命名空间结束符
NS_SEP, # 命名空间分隔符
START, # 开始节点标识
TAG_HIDDEN, # 隐藏标签
TASKS, # 任务标识
)
from langgraph.errors import (
ErrorCode, # 错误代码
InvalidUpdateError, # 无效更新错误
ParentCommand, # 父命令异常
create_error_message, # 创建错误消息
)
from langgraph.graph.branch import Branch
# 分支类,用于条件边
from langgraph.managed.base import (
ManagedValueSpec, # 托管值规范
is_managed_value, # 检查是否为托管值
)
from langgraph.pregel import Pregel
# Pregel 图执行引擎基类
# 💡 通俗理解:Pregel 是 Google 提出的图计算模型,LangGraph 用它来实际执行图。就像是一个"调度器",负责决定哪个节点什么时候执行,如何传递数据。
from langgraph.pregel.read import ChannelRead, PregelNode
# ChannelRead: 通道读取器
# PregelNode: Pregel 节点
from langgraph.pregel.write import (
ChannelWrite, # 通道写入器
ChannelWriteEntry, # 通道写入条目
ChannelWriteTupleEntry, # 通道写入元组条目
)
from langgraph.store.base import BaseStore
# 存储基类
from langgraph.types import (
All, # 全部标识
CachePolicy, # 缓存策略
Checkpointer, # 检查点器
Command, # 命令类型
RetryPolicy, # 重试策略
Send, # 发送操作
StreamWriter, # 流写入器
)
from langgraph.typing import InputT, OutputT, StateT, StateT_contra
# InputT: 输入类型变量
# OutputT: 输出类型变量
# StateT: 状态类型变量
# StateT_contra: 状态类型逆变变量
from langgraph.utils.fields import (
get_cached_annotated_keys, # 获取缓存的注解键
get_field_default, # 获取字段默认值
get_update_as_tuples, # 获取更新元组
)
from langgraph.utils.pydantic import create_model
# 创建 Pydantic 模型
from langgraph.utils.runnable import coerce_to_runnable
# 强制转换为可运行对象
from langgraph.warnings import LangGraphDeprecatedSinceV05
# LangGraph 弃用警告
logger = logging.getLogger(__name__)
# 创建日志记录器

工具函数#

_warn_invalid_state_schema#

def _warn_invalid_state_schema(schema: type[Any] | Any) -> None:
"""
警告无效的状态模式。
如果模式不是类型,且没有泛型参数,则发出警告。
这有助于用户发现状态模式定义错误。
"""
if isinstance(schema, type):
# 如果是类型,直接返回(有效)
return
if typing.get_args(schema):
# 如果有泛型参数(如 Annotated),也认为是有效的
return
# 否则发出警告
warnings.warn(
f"Invalid state_schema: {schema}. Expected a type or Annotated[type, reducer]. "
"Please provide a valid schema to ensure correct updates.\n"
" See: https://langchain-ai.github.io/langgraph/reference/graphs/#stategraph"
)

协议定义#

这些协议定义了节点函数的不同签名变体,支持不同的参数组合:

💡 通俗理解:协议就像是一份”合同”,规定了节点函数可以有哪些不同的”签名”(参数组合)。有些节点只需要状态,有些还需要配置、写入器或存储。就像不同的工作岗位有不同的职责要求一样。

_StateNode#

class _StateNode(Protocol[StateT_contra]):
"""最基本的节点协议:只接受状态参数"""
def __call__(self, state: StateT_contra) -> Any: ...

_NodeWithConfig#

class _NodeWithConfig(Protocol[StateT_contra]):
"""带配置的节点协议:接受状态和配置参数"""
def __call__(self, state: StateT_contra, config: RunnableConfig) -> Any: ...

_NodeWithWriter#

class _NodeWithWriter(Protocol[StateT_contra]):
"""带写入器的节点协议:接受状态和流写入器"""
def __call__(self, state: StateT_contra, *, writer: StreamWriter) -> Any: ...

_NodeWithStore#

class _NodeWithStore(Protocol[StateT_contra]):
"""带存储的节点协议:接受状态和存储对象"""
def __call__(self, state: StateT_contra, *, store: BaseStore) -> Any: ...

_NodeWithWriterStore#

class _NodeWithWriterStore(Protocol[StateT_contra]):
"""带写入器和存储的节点协议"""
def __call__(
self, state: StateT_contra, *, writer: StreamWriter, store: BaseStore
) -> Any: ...

_NodeWithConfigWriter#

class _NodeWithConfigWriter(Protocol[StateT_contra]):
"""带配置和写入器的节点协议"""
def __call__(
self, state: StateT_contra, *, config: RunnableConfig, writer: StreamWriter
) -> Any: ...

_NodeWithConfigStore#

class _NodeWithConfigStore(Protocol[StateT_contra]):
"""带配置和存储的节点协议"""
def __call__(
self, state: StateT_contra, *, config: RunnableConfig, store: BaseStore
) -> Any: ...

_NodeWithConfigWriterStore#

class _NodeWithConfigWriterStore(Protocol[StateT_contra]):
"""带配置、写入器和存储的节点协议(最完整的签名)"""
def __call__(
self,
state: StateT_contra,
*,
config: RunnableConfig,
writer: StreamWriter,
store: BaseStore,
) -> Any: ...

类型别名与辅助函数#

StateNode 类型别名#

StateNode: TypeAlias = Union[
_StateNode[StateT_contra], # 基础节点
_NodeWithConfig[StateT_contra], # 带配置
_NodeWithWriter[StateT_contra], # 带写入器
_NodeWithStore[StateT_contra], # 带存储
_NodeWithWriterStore[StateT_contra], # 带写入器和存储
_NodeWithConfigWriter[StateT_contra], # 带配置和写入器
_NodeWithConfigStore[StateT_contra], # 带配置和存储
_NodeWithConfigWriterStore[StateT_contra], # 带所有参数
Runnable[StateT_contra, Any], # 可运行对象
]

说明StateNode 是所有支持的节点类型的联合类型,允许节点函数有不同的签名。

💡 通俗理解StateNode 就像是一个”万能接口”,不管你写节点函数时用了哪些参数(只要符合其中一种协议),都可以被 LangGraph 识别和使用。这给了开发者很大的灵活性。

_get_node_name 函数#

def _get_node_name(node: StateNode) -> str:
"""
获取节点名称。
尝试从节点的 __name__ 属性获取,如果没有则使用类名。
如果都获取不到,抛出类型错误。
"""
try:
return getattr(node, "__name__", node.__class__.__name__)
except AttributeError:
raise TypeError(f"Unsupported node type: {type(node)}")

StateNodeSpec 命名元组#

💡 通俗理解StateNodeSpec 就像是节点的”身份证”,记录了节点的所有信息:它是什么(runnable)、需要什么输入(input)、失败后怎么办(retry_policy)、结果可以缓存吗(cache_policy)、执行完后可以去哪里(ends)等等。

class StateNodeSpec(NamedTuple):
"""
节点规范,存储节点的所有配置信息。
字段说明:
- runnable: 节点可运行对象(函数或 Runnable)
- metadata: 节点元数据字典(可选)
- input: 节点输入模式类型
- retry_policy: 重试策略(可选,可以是单个策略或策略序列)
- cache_policy: 缓存策略(可选)
- ends: 节点可以路由到的目标节点(可选)
可以是元组(节点名列表)或字典(节点名到标签的映射)
- defer: 是否延迟执行(直到运行即将结束时才执行)
"""
runnable: StateNode
metadata: dict[str, Any] | None
input: type[Any]
retry_policy: RetryPolicy | Sequence[RetryPolicy] | None
cache_policy: CachePolicy | None
ends: tuple[str, ...] | dict[str, str] | None = EMPTY_SEQ
defer: bool = False

StateGraph 类#

StateGraph 是 LangGraph 的核心类,用于构建基于状态的有向图。

💡 通俗理解StateGraph 就像是一个”建筑图纸”,你用它来设计工作流程。你可以添加节点(工作步骤)、连接边(执行顺序)、设置条件分支(根据情况选择路径)。设计完成后,调用 compile() 方法把它”建造”成可以实际运行的图。

类属性#

edges: set[tuple[str, str]]
# 边的集合,每个边是 (起始节点, 结束节点) 的元组
nodes: dict[str, StateNodeSpec]
# 节点字典,键是节点名,值是节点规范
branches: defaultdict[str, dict[str, Branch]]
# 分支字典,键是源节点名,值是条件名到分支对象的映射
channels: dict[str, BaseChannel]
# 通道字典,键是状态键名,值是通道对象
# 💡 通俗理解:每个状态字段都有自己的"通道"(邮箱),节点通过通道读写数据
managed: dict[str, ManagedValueSpec]
# 托管值字典,键是状态键名,值是托管值规范
schemas: dict[type[Any], dict[str, BaseChannel | ManagedValueSpec]]
# 模式字典,键是模式类型,值是键到通道或托管值的映射
waiting_edges: set[tuple[tuple[str, ...], str]]
# 等待边集合,用于多源节点到单目标节点的边
# 💡 通俗理解:就像"等所有人到齐再出发",多个节点必须都完成,才能执行下一个节点
compiled: bool
# 是否已编译的标志
state_schema: type[StateT]
# 状态模式类型
input_schema: type[InputT]
# 输入模式类型
output_schema: type[OutputT]
# 输出模式类型

__init__ 方法#

def __init__(
self,
state_schema: type[StateT], # 状态模式(必需)
config_schema: type[Any] | None = None, # 配置模式(可选)
*,
input_schema: type[InputT] | None = None, # 输入模式(可选)
output_schema: type[OutputT] | None = None, # 输出模式(可选)
**kwargs: Unpack[DeprecatedKwargs], # 已弃用的关键字参数
) -> None:
"""
初始化状态图。
处理已弃用的参数(input/output),然后初始化所有数据结构。
最后添加状态、输入和输出模式。
"""
> 💡 **通俗理解**:初始化就像准备一个空的工作台,设置好工作台的规格(状态模式),准备好各种"邮箱"(通道)。之后你就可以往上面添加工作步骤(节点)了。
# 处理已弃用的 input 参数
if (input_ := kwargs.get("input", UNSET)) is not UNSET:
warnings.warn(
"`input` is deprecated and will be removed. Please use `input_schema` instead.",
category=LangGraphDeprecatedSinceV05,
stacklevel=2,
)
if input_schema is None:
input_schema = cast(Union[type[InputT], None], input_)
# 处理已弃用的 output 参数
if (output := kwargs.get("output", UNSET)) is not UNSET:
warnings.warn(
"`output` is deprecated and will be removed. Please use `output_schema` instead.",
category=LangGraphDeprecatedSinceV05,
stacklevel=2,
)
if output_schema is None:
output_schema = cast(Union[type[OutputT], None], output)
# 初始化数据结构
self.nodes = {}
self.edges = set()
self.branches = defaultdict(dict)
self.schemas = {}
self.channels = {}
self.managed = {}
self.compiled = False
self.waiting_edges = set()
# 设置模式
self.state_schema = state_schema
self.input_schema = cast(type[InputT], input_schema or state_schema)
self.output_schema = cast(type[OutputT], output_schema or state_schema)
self.config_schema = config_schema
# 添加模式(这会解析模式并创建通道)
self._add_schema(self.state_schema)
self._add_schema(self.input_schema, allow_managed=False)
self._add_schema(self.output_schema, allow_managed=False)

_all_edges 属性#

@property
def _all_edges(self) -> set[tuple[str, str]]:
"""
获取所有边的集合(包括普通边和等待边)。
等待边会被展开为多个普通边。
"""
return self.edges | {
(start, end) for starts, end in self.waiting_edges for start in starts
}

_add_schema 方法#

def _add_schema(self, schema: type[Any], /, allow_managed: bool = True) -> None:
"""
添加模式到图中。
解析模式中的字段,创建相应的通道和托管值。
如果模式已存在,则跳过。
参数:
- schema: 要添加的模式类型
- allow_managed: 是否允许托管值(输入/输出模式不允许)
"""
> 💡 **通俗理解**:模式就像是一个"表格模板",定义了状态有哪些字段、每个字段是什么类型。这个方法会为每个字段创建一个"邮箱"(通道),节点可以通过这些邮箱传递数据。`allow_managed=False` 的意思是:输入/输出模式不允许使用"托管值"(一种特殊的通道类型,由系统自动管理)。
if schema not in self.schemas:
# 警告无效的模式
_warn_invalid_state_schema(schema)
# 获取通道、托管值和类型提示
channels, managed, type_hints = _get_channels(schema)
# 检查托管值是否允许
if managed and not allow_managed:
names = ", ".join(managed)
schema_name = getattr(schema, "__name__", "")
raise ValueError(
f"Invalid managed channels detected in {schema_name}: {names}."
" Managed channels are not permitted in Input/Output schema."
)
# 保存模式信息
self.schemas[schema] = {**channels, **managed}
# 添加通道
for key, channel in channels.items():
if key in self.channels:
if self.channels[key] != channel:
# 如果通道类型不同,且不是 LastValue,则报错
if isinstance(channel, LastValue):
pass # LastValue 可以兼容
else:
raise ValueError(
f"Channel '{key}' already exists with a different type"
)
else:
self.channels[key] = channel
# 添加托管值
for key, managed in managed.items():
if key in self.managed:
if self.managed[key] != managed:
raise ValueError(
f"Managed value '{key}' already exists with a different type"
)
else:
self.managed[key] = managed

add_node 方法#

add_node 方法有两个重载版本和一个实现版本:

重载版本 1#

@overload
def add_node(
self,
node: StateNode[StateT], # 节点函数或可运行对象
*,
defer: bool = False,
metadata: dict[str, Any] | None = None,
input_schema: type[Any] | None = None,
retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None,
cache_policy: CachePolicy | None = None,
destinations: dict[str, str] | tuple[str, ...] | None = None,
**kwargs: Unpack[DeprecatedKwargs],
) -> Self:
"""添加节点,节点名从函数/可运行对象名称推断"""
...

重载版本 2#

@overload
def add_node(
self,
node: str, # 节点名称
action: StateNode[StateT], # 节点函数或可运行对象
*,
defer: bool = False,
metadata: dict[str, Any] | None = None,
input_schema: type[Any] | None = None,
retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None,
cache_policy: CachePolicy | None = None,
destinations: dict[str, str] | tuple[str, ...] | None = None,
**kwargs: Unpack[DeprecatedKwargs],
) -> Self:
"""添加节点,显式指定节点名"""
...

实现版本#

def add_node(
self,
node: str | StateNode[StateT],
action: StateNode[StateT] | None = None,
*,
defer: bool = False,
metadata: dict[str, Any] | None = None,
input_schema: type[Any] | None = None,
retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None,
cache_policy: CachePolicy | None = None,
destinations: dict[str, str] | tuple[str, ...] | None = None,
**kwargs: Unpack[DeprecatedKwargs],
) -> Self:
"""
添加节点到图中。
主要步骤:
1. 处理已弃用的参数
2. 确定节点名称和动作
3. 验证节点名称
4. 推断输入模式和目标节点(从类型提示)
5. 创建节点规范并添加到图中
"""
> 💡 **通俗理解**:添加节点就像是在工作流程中添加一个工作步骤。这个方法会:
> - 检查节点名是否合法(不能是保留字,不能重复)
> - 自动识别节点函数的类型提示,推断它需要什么输入、可以输出到哪里
> - 创建一个"身份证"(StateNodeSpec)记录这个节点的所有信息
> - 把节点添加到图中
# 处理已弃用的 retry 参数
if (retry := kwargs.get("retry", UNSET)) is not UNSET:
warnings.warn(
"`retry` is deprecated and will be removed. Please use `retry_policy` instead.",
category=LangGraphDeprecatedSinceV05,
)
if retry_policy is None:
retry_policy = retry
# 处理已弃用的 input 参数
if (input_ := kwargs.get("input", UNSET)) is not UNSET:
warnings.warn(
"`input` is deprecated and will be removed. Please use `input_schema` instead.",
category=LangGraphDeprecatedSinceV05,
)
if input_schema is None:
input_schema = cast(Union[type[InputT], None], input_)
# 确定节点名称和动作
if not isinstance(node, str):
action = node
if isinstance(action, Runnable):
node = action.get_name()
else:
node = getattr(action, "__name__", action.__class__.__name__)
if node is None:
raise ValueError(
"Node name must be provided if action is not a function"
)
# 如果图已编译,发出警告
if self.compiled:
logger.warning(
"Adding a node to a graph that has already been compiled. This will "
"not be reflected in the compiled graph."
)
# 再次检查节点名称(处理边界情况)
if not isinstance(node, str):
action = node
node = cast(str, getattr(action, "name", getattr(action, "__name__", None)))
if node is None:
raise ValueError(
"Node name must be provided if action is not a function"
)
# 验证动作存在
if action is None:
raise RuntimeError
# 验证节点名唯一性
if node in self.nodes:
raise ValueError(f"Node `{node}` already present.")
# 验证节点名不是保留字
if node == END or node == START:
raise ValueError(f"Node `{node}` is reserved.")
# 验证节点名不包含保留字符
for character in (NS_SEP, NS_END):
if character in node:
raise ValueError(
f"'{character}' is a reserved character and is not allowed in the node names."
)
# 尝试从类型提示推断目标节点
ends: tuple[str, ...] | dict[str, str] = EMPTY_SEQ
try:
if (
isfunction(action)
or ismethod(action)
or ismethod(getattr(action, "__call__", None))
) and (
hints := get_type_hints(getattr(action, "__call__"))
or get_type_hints(action)
):
# 推断输入模式
if input_schema is None:
first_parameter_name = next(
iter(
inspect.signature(
cast(FunctionType, action)
).parameters.keys()
)
)
if input_hint := hints.get(first_parameter_name):
if isinstance(input_hint, type) and get_type_hints(input_hint):
input_schema = input_hint
# 推断返回类型中的目标节点
if rtn := hints.get("return"):
# 处理 Union 类型
rtn_origin = get_origin(rtn)
if rtn_origin is Union:
rtn_args = get_args(rtn)
# 查找 Command 类型
for arg in rtn_args:
arg_origin = get_origin(arg)
if arg_origin is Command:
rtn = arg
rtn_origin = arg_origin
break
# 检查是否为 Command 类型
if (
rtn_origin is Command
and (rargs := get_args(rtn))
and get_origin(rargs[0]) is Literal
and (vals := get_args(rargs[0]))
):
ends = vals
except (NameError, TypeError, StopIteration):
pass
# 如果提供了 destinations,使用它
if destinations is not None:
ends = destinations
# 添加输入模式(如果指定)
if input_schema is not None:
self._add_schema(input_schema)
# 创建节点规范并添加到图中
self.nodes[node] = StateNodeSpec(
coerce_to_runnable(action, name=node, trace=False),
metadata,
input=input_schema or self.state_schema,
retry_policy=retry_policy,
cache_policy=cache_policy,
ends=ends,
defer=defer,
)
return self

add_edge 方法#

def add_edge(self, start_key: str | list[str], end_key: str) -> Self:
"""
添加有向边到图中。
支持两种模式:
1. 单源到单目标:start_key 是字符串
2. 多源到单目标:start_key 是字符串列表(等待所有源节点完成)
参数:
- start_key: 起始节点(或节点列表)
- end_key: 结束节点
返回:
- Self: 允许方法链式调用
"""
> 💡 **通俗理解**:添加边就像是在工作流程中画箭头,表示"先做这个,再做那个"。
> - 单源边:A → B(A 完成后执行 B)
> - 多源边:[A, B, C] → D(A、B、C 都完成后才执行 D,就像等所有同事都到齐再开会)
if self.compiled:
logger.warning(
"Adding an edge to a graph that has already been compiled. This will "
"not be reflected in the compiled graph."
)
# 处理单源情况
if isinstance(start_key, str):
if start_key == END:
raise ValueError("END cannot be a start node")
if end_key == START:
raise ValueError("START cannot be an end node")
# 验证(仅对非 StateGraph 图)
if not hasattr(self, "channels") and start_key in set(
start for start, _ in self.edges
):
raise ValueError(
f"Already found path for node '{start_key}'.\n"
"For multiple edges, use StateGraph with an Annotated state key."
)
self.edges.add((start_key, end_key))
return self
# 处理多源情况
for start in start_key:
if start == END:
raise ValueError("END cannot be a start node")
if start not in self.nodes:
raise ValueError(f"Need to add_node `{start}` first")
if end_key == START:
raise ValueError("START cannot be an end node")
if end_key != END and end_key not in self.nodes:
raise ValueError(f"Need to add_node `{end_key}` first")
# 添加到等待边集合
self.waiting_edges.add((tuple(start_key), end_key))
return self

add_conditional_edges 方法#

def add_conditional_edges(
self,
source: str, # 源节点
path: Callable[..., Hashable | list[Hashable]]
| Callable[..., Awaitable[Hashable | list[Hashable]]]
| Runnable[Any, Hashable | list[Hashable]], # 路径函数
path_map: dict[Hashable, str] | list[str] | None = None, # 路径映射
) -> Self:
"""
添加条件边。
条件边允许根据状态动态选择下一个节点。
参数:
- source: 源节点名
- path: 路径函数,根据状态返回下一个节点名(或节点名列表)
- path_map: 可选的路径映射,将路径函数的返回值映射到节点名
返回:
- Self: 允许方法链式调用
"""
> 💡 **通俗理解**:条件边就像是一个"智能路口",根据当前情况(状态)决定走哪条路。
> - `path` 函数就像是一个"导航员",查看当前状态后决定下一步去哪里
> - `path_map` 就像一个"翻译表",把导航员的决定(可能是数字、字符串)翻译成实际的节点名
> - 例如:如果用户输入是问题,就去"回答节点";如果是命令,就去"执行节点"
if self.compiled:
logger.warning(
"Adding an edge to a graph that has already been compiled. This will "
"not be reflected in the compiled graph."
)
# 将路径函数转换为可运行对象
path = coerce_to_runnable(path, name=None, trace=True)
name = path.name or "condition"
# 验证条件名唯一性
if name in self.branches[source]:
raise ValueError(
f"Branch with name `{path.name}` already exists for node `{source}`"
)
# 创建分支并保存
self.branches[source][name] = Branch.from_path(path, path_map, True)
# 如果分支有输入模式,添加它
if schema := self.branches[source][name].input_schema:
self._add_schema(schema)
return self

add_sequence 方法#

def add_sequence(
self,
nodes: Sequence[StateNode[StateT] | tuple[str, StateNode[StateT]]],
) -> Self:
"""
添加节点序列。
按顺序添加多个节点,并在它们之间创建边。
参数:
- nodes: 节点序列,可以是 StateNode 或 (名称, StateNode) 元组
返回:
- Self: 允许方法链式调用
"""
if len(nodes) < 1:
raise ValueError("Sequence requires at least one node.")
previous_name: str | None = None
for node in nodes:
# 确定节点名称
if isinstance(node, tuple) and len(node) == 2:
name, node = node
else:
name = _get_node_name(node)
# 验证节点名唯一性
if name in self.nodes:
raise ValueError(
f"Node names must be unique: node with the name '{name}' already exists. "
"If you need to use two different runnables/callables with the same name (for example, using `lambda`), please provide them as tuples (name, runnable/callable)."
)
# 添加节点
self.add_node(name, node)
# 如果不是第一个节点,添加边
if previous_name is not None:
self.add_edge(previous_name, name)
previous_name = name
return self

set_entry_point 方法#

def set_entry_point(self, key: str) -> Self:
"""
设置图的入口点。
等价于调用 add_edge(START, key)。
参数:
- key: 入口节点名
返回:
- Self: 允许方法链式调用
"""
return self.add_edge(START, key)

set_conditional_entry_point 方法#

def set_conditional_entry_point(
self,
path: Callable[..., Hashable | list[Hashable]]
| Callable[..., Awaitable[Hashable | list[Hashable]]]
| Runnable[Any, Hashable | list[Hashable]],
path_map: dict[Hashable, str] | list[str] | None = None,
) -> Self:
"""
设置条件入口点。
允许根据输入动态选择第一个节点。
参数:
- path: 路径函数
- path_map: 路径映射
返回:
- Self: 允许方法链式调用
"""
return self.add_conditional_edges(START, path, path_map)

set_finish_point 方法#

def set_finish_point(self, key: str) -> Self:
"""
设置图的结束点。
等价于调用 add_edge(key, END)。
参数:
- key: 结束节点名
返回:
- Self: 允许方法链式调用
"""
return self.add_edge(key, END)

validate 方法#

def validate(self, interrupt: Sequence[str] | None = None) -> Self:
"""
验证图的结构。
检查:
1. 所有边的源节点都存在
2. 图有入口点
3. 所有边的目标节点都存在
4. 中断节点存在
参数:
- interrupt: 中断节点列表(可选)
返回:
- Self: 允许方法链式调用
副作用:
- 设置 self.compiled = True
"""
> 💡 **通俗理解**:验证就像是在"施工前检查图纸",确保:
> - 所有箭头都指向存在的节点(不能指向空)
> - 有起点(不能没有入口)
> - 所有目标节点都存在(不能指向不存在的节点)
> - 如果设置了中断点,这些节点必须存在
> 如果检查通过,就标记为"已验证",可以编译了。
# 收集所有源节点
all_sources = {src for src, _ in self._all_edges}
for start, branches in self.branches.items():
all_sources.add(start)
for name, spec in self.nodes.items():
if spec.ends:
all_sources.add(name)
# 验证源节点
for source in all_sources:
if source not in self.nodes and source != START:
raise ValueError(f"Found edge starting at unknown node '{source}'")
# 验证有入口点
if START not in all_sources:
raise ValueError(
"Graph must have an entrypoint: add at least one edge from START to another node"
)
# 收集所有目标节点
all_targets = {end for _, end in self._all_edges}
for start, branches in self.branches.items():
for cond, branch in branches.items():
if branch.ends is not None:
for end in branch.ends.values():
if end not in self.nodes and end != END:
raise ValueError(
f"At '{start}' node, '{cond}' branch found unknown target '{end}'"
)
all_targets.add(end)
else:
# 如果没有指定目标,可能路由到任何节点
all_targets.add(END)
for node in self.nodes:
if node != start:
all_targets.add(node)
for name, spec in self.nodes.items():
if spec.ends:
all_targets.update(spec.ends)
# 验证目标节点
for target in all_targets:
if target not in self.nodes and target != END:
raise ValueError(f"Found edge ending at unknown node `{target}`")
# 验证中断节点
if interrupt:
for node in interrupt:
if node not in self.nodes:
raise ValueError(f"Interrupt node `{node}` not found")
self.compiled = True
return self

compile 方法#

def compile(
self,
checkpointer: Checkpointer = None, # 检查点器
*,
cache: BaseCache | None = None, # 缓存
store: BaseStore | None = None, # 存储
interrupt_before: All | list[str] | None = None, # 执行前中断的节点
interrupt_after: All | list[str] | None = None, # 执行后中断的节点
debug: bool = False, # 调试模式
name: str | None = None, # 图名称
) -> CompiledStateGraph[StateT, InputT, OutputT]:
"""
编译状态图为 CompiledStateGraph。
编译后的图实现了 Runnable 接口,可以调用、流式传输、批处理和异步运行。
主要步骤:
1. 验证图结构
2. 准备输出通道和流通道
3. 创建 CompiledStateGraph 实例
4. 附加所有节点、边和分支
返回:
- CompiledStateGraph: 编译后的状态图
"""
> 💡 **通俗理解**:编译就像把"设计图纸"变成"实际可运行的机器"。
> - 先检查图纸有没有问题(validate)
> - 准备好输出接口(output_channels)和流式接口(stream_channels)
> - 创建一个"执行引擎"(CompiledStateGraph)
> - 把所有节点、边、分支都"安装"到引擎上
> 编译完成后,你就可以调用 `invoke()` 或 `stream()` 来运行这个图了!
# 设置默认值
interrupt_before = interrupt_before or []
interrupt_after = interrupt_after or []
# 验证图
self.validate(
interrupt=(
(interrupt_before if interrupt_before != "*" else []) + interrupt_after
if interrupt_after != "*"
else []
)
)
# 准备输出通道
output_channels = (
"__root__"
if len(self.schemas[self.output_schema]) == 1
and "__root__" in self.schemas[self.output_schema]
else [
key
for key, val in self.schemas[self.output_schema].items()
if not is_managed_value(val)
]
)
# 准备流通道
stream_channels = (
"__root__"
if len(self.channels) == 1 and "__root__" in self.channels
else [
key for key, val in self.channels.items() if not is_managed_value(val)
]
)
# 创建编译后的图
compiled = CompiledStateGraph[StateT, InputT, OutputT](
builder=self,
schema_to_mapper={},
config_type=self.config_schema,
nodes={},
channels={
**self.channels,
**self.managed,
START: EphemeralValue(self.input_schema),
},
input_channels=START,
stream_mode="updates",
output_channels=output_channels,
stream_channels=stream_channels,
checkpointer=checkpointer,
interrupt_before_nodes=interrupt_before,
interrupt_after_nodes=interrupt_after,
auto_validate=False,
debug=debug,
store=store,
cache=cache,
name=name or "LangGraph",
)
# 附加 START 节点
compiled.attach_node(START, None)
# 附加所有节点
for key, node in self.nodes.items():
compiled.attach_node(key, node)
# 附加所有边
for start, end in self.edges:
compiled.attach_edge(start, end)
# 附加所有等待边
for starts, end in self.waiting_edges:
compiled.attach_edge(starts, end)
# 附加所有分支
for start, branches in self.branches.items():
for name, branch in branches.items():
compiled.attach_branch(start, name, branch)
return compiled.validate()

CompiledStateGraph 类#

CompiledStateGraph 是编译后的状态图,继承自 Pregel,实现了实际的图执行逻辑。

💡 通俗理解CompiledStateGraph 就是”实际运行的机器”,而 StateGraph 只是”设计图纸”。

  • StateGraph:你用它来设计工作流程(添加节点、边等)
  • CompiledStateGraph:编译后得到的可执行对象,你可以调用它来实际运行工作流程 就像写代码和运行代码的区别一样。

类属性#

class CompiledStateGraph(
Pregel[StateT, InputT, OutputT], Generic[StateT, InputT, OutputT]
):
"""
编译后的状态图。
继承自 Pregel,实现了图的执行逻辑。
"""
builder: StateGraph[StateT, InputT, OutputT]
# 构建器引用,保存原始的状态图定义
schema_to_mapper: dict[type[Any], Callable[[Any], Any] | None]
# 模式到映射函数的字典,用于将状态字典转换为模式对象

__init__ 方法#

def __init__(
self,
*,
builder: StateGraph[StateT, InputT, OutputT],
schema_to_mapper: dict[type[Any], Callable[[Any], Any] | None],
**kwargs: Any,
) -> None:
"""
初始化编译后的状态图。
调用父类 Pregel 的初始化方法,然后保存构建器和模式映射器。
"""
super().__init__(**kwargs)
self.builder = builder
self.schema_to_mapper = schema_to_mapper

get_input_jsonschema 方法#

def get_input_jsonschema(
self, config: RunnableConfig | None = None
) -> dict[str, Any]:
"""
获取输入的 JSON Schema。
用于 API 文档生成和验证。
"""
return _get_json_schema(
typ=self.builder.input_schema,
schemas=self.builder.schemas,
channels=self.builder.channels,
name=self.get_name("Input"),
)

get_output_jsonschema 方法#

def get_output_jsonschema(
self, config: RunnableConfig | None = None
) -> dict[str, Any]:
"""
获取输出的 JSON Schema。
用于 API 文档生成和验证。
"""
return _get_json_schema(
typ=self.builder.output_schema,
schemas=self.builder.schemas,
channels=self.builder.channels,
name=self.get_name("Output"),
)

attach_node 方法#

def attach_node(self, key: str, node: StateNodeSpec | None) -> None:
"""
将节点附加到编译后的图中。
这是编译过程的关键步骤,将 StateGraph 的节点转换为 Pregel 节点。
主要工作:
1. 确定输出键
2. 创建更新函数(处理节点返回值)
3. 创建写入条目(状态更新和分支控制)
4. 创建 PregelNode 并添加到图中
参数:
- key: 节点键(START 或节点名)
- node: 节点规范(START 节点为 None)
"""
> 💡 **通俗理解**:这个方法把"设计图纸上的节点"转换成"实际可执行的节点"。
> - 确定节点可以更新哪些状态字段(output_keys)
> - 创建一个"转换器"(_get_updates),把节点的返回值转换成状态更新
> - 创建"写入器"(write_entries),负责把更新写入状态,并控制下一步去哪里
> - 最后创建一个 PregelNode(实际执行单元)并添加到图中
> 就像把"工作步骤说明"转换成"实际的工作指令"。
# 确定输出键
if key == START:
# START 节点输出输入模式的所有键
output_keys = [
k
for k, v in self.builder.schemas[self.builder.input_schema].items()
if not is_managed_value(v)
]
else:
# 普通节点输出所有通道和托管值
output_keys = list(self.builder.channels) + [
k for k, v in self.builder.managed.items()
]
def _get_updates(
input: None | dict | Any,
) -> Sequence[tuple[str, Any]] | None:
"""
将节点返回值转换为更新元组列表。
处理多种返回类型:
- None: 无更新
- dict: 字典键值对
- Command: 命令对象
- list/tuple: 命令列表
- 带注解的对象: 使用 get_update_as_tuples
"""
if input is None:
return None
elif isinstance(input, dict):
return [(k, v) for k, v in input.items() if k in output_keys]
elif isinstance(input, Command):
if input.graph == Command.PARENT:
return None
return [
(k, v) for k, v in input._update_as_tuples() if k in output_keys
]
elif (
isinstance(input, (list, tuple))
and input
and any(isinstance(i, Command) for i in input)
):
updates: list[tuple[str, Any]] = []
for i in input:
if isinstance(i, Command):
if i.graph == Command.PARENT:
continue
updates.extend(
(k, v) for k, v in i._update_as_tuples() if k in output_keys
)
else:
updates.extend(_get_updates(i) or ())
return updates
elif (t := type(input)) and get_cached_annotated_keys(t):
return get_update_as_tuples(input, output_keys)
else:
msg = create_error_message(
message=f"Expected dict, got {input}",
error_code=ErrorCode.INVALID_GRAPH_NODE_RETURN_VALUE,
)
raise InvalidUpdateError(msg)
# 创建写入条目
write_entries: tuple[ChannelWriteEntry | ChannelWriteTupleEntry, ...] = (
ChannelWriteTupleEntry(
mapper=_get_root if output_keys == ["__root__"] else _get_updates
),
ChannelWriteTupleEntry(
mapper=_control_branch,
static=_control_static(node.ends)
if node is not None and node.ends is not None
else None,
),
)
# 添加节点
if key == START:
# START 节点
self.nodes[key] = PregelNode(
tags=[TAG_HIDDEN],
triggers=[START],
channels=START,
writers=[ChannelWrite(write_entries)],
)
elif node is not None:
# 普通节点
input_schema = node.input if node else self.builder.state_schema
input_channels = list(self.builder.schemas[input_schema])
is_single_input = len(input_channels) == 1 and "__root__" in input_channels
# 获取或创建映射器
if input_schema in self.schema_to_mapper:
mapper = self.schema_to_mapper[input_schema]
else:
mapper = _pick_mapper(input_channels, input_schema)
self.schema_to_mapper[input_schema] = mapper
# 创建分支通道
branch_channel = CHANNEL_BRANCH_TO.format(key)
self.channels[branch_channel] = (
LastValueAfterFinish(Any)
if node.defer
else EphemeralValue(Any, guard=False)
)
# 创建 PregelNode
self.nodes[key] = PregelNode(
triggers=[branch_channel],
channels=("__root__" if is_single_input else input_channels),
mapper=mapper,
writers=[ChannelWrite(write_entries)],
metadata=node.metadata,
retry_policy=node.retry_policy,
cache_policy=node.cache_policy,
bound=node.runnable,
)
else:
raise RuntimeError

attach_edge 方法#

def attach_edge(self, starts: str | Sequence[str], end: str) -> None:
"""
将边附加到编译后的图中。
处理两种类型的边:
1. 单源边:直接写入目标节点的分支通道
2. 多源边:使用命名屏障值通道等待所有源节点完成
参数:
- starts: 起始节点(或节点列表)
- end: 结束节点
"""
> 💡 **通俗理解**:这个方法把"设计图纸上的箭头"转换成"实际的执行路径"。
> - 单源边:A → B,直接在 A 的写入器中添加"完成后通知 B"
> - 多源边:[A, B, C] → D,创建一个"集合点通道",A、B、C 都完成后才通知 D
> 就像把"流程图上的箭头"转换成"实际的信号传递机制"。
if isinstance(starts, str):
# 单源边:直接写入目标节点的分支通道
if end != END:
self.nodes[starts].writers.append(
ChannelWrite(
(ChannelWriteEntry(CHANNEL_BRANCH_TO.format(end), None),)
)
)
elif end != END:
# 多源边:创建命名屏障值通道
channel_name = f"join:{'+'.join(starts)}:{end}"
# 注册通道
if self.builder.nodes[end].defer:
# 如果目标节点延迟执行,使用 AfterFinish 版本
self.channels[channel_name] = NamedBarrierValueAfterFinish(
str, set(starts)
)
else:
self.channels[channel_name] = NamedBarrierValue(str, set(starts))
# 订阅通道(目标节点等待)
self.nodes[end].triggers.append(channel_name)
# 发布到通道(源节点写入)
for start in starts:
self.nodes[start].writers.append(
ChannelWrite((ChannelWriteEntry(channel_name, start),))
)

attach_branch 方法#

def attach_branch(
self, start: str, name: str, branch: Branch, *, with_reader: bool = True
) -> None:
"""
将分支附加到编译后的图中。
分支用于条件边,根据状态动态选择下一个节点。
参数:
- start: 源节点名
- name: 分支名
- branch: 分支对象
- with_reader: 是否创建读取器(用于读取状态)
"""
> 💡 **通俗理解**:这个方法把"条件判断逻辑"转换成"实际的路由机制"。
> - 创建一个"读取器"(reader)来读取当前状态
> - 创建一个"写入器"(get_writes)来把分支决策转换成"通知哪个节点"
> - 把分支逻辑附加到源节点上
> 就像把"如果...就..."的逻辑转换成"实际的信号路由"。
def get_writes(
packets: Sequence[str | Send], static: bool = False
) -> Sequence[ChannelWriteEntry | Send]:
"""
将分支返回的包转换为写入条目。
参数:
- packets: 包列表(节点名或 Send 对象)
- static: 是否为静态分支
返回:
- 写入条目列表
"""
writes = [
(
ChannelWriteEntry(
p if p == END else CHANNEL_BRANCH_TO.format(p), None
)
if not isinstance(p, Send)
else p
)
for p in packets
if (True if static else p != END)
]
if not writes:
return []
return writes
if with_reader:
# 获取模式
schema = branch.input_schema or (
self.builder.nodes[start].input
if start in self.builder.nodes
else self.builder.state_schema
)
channels = list(self.builder.schemas[schema])
# 获取映射器
if schema in self.schema_to_mapper:
mapper = self.schema_to_mapper[schema]
else:
mapper = _pick_mapper(channels, schema)
self.schema_to_mapper[schema] = mapper
# 创建读取器
reader: Callable[[RunnableConfig], Any] | None = partial(
ChannelRead.do_read,
select=channels[0] if channels == ["__root__"] else channels,
fresh=True,
mapper=mapper,
)
else:
reader = None
# 附加分支发布器
self.nodes[start].writers.append(branch.run(get_writes, reader))

_migrate_checkpoint 方法#

def _migrate_checkpoint(self, checkpoint: Checkpoint) -> None:
"""
迁移检查点到新的通道布局。
当图的内部结构改变时(如通道命名方式改变),需要迁移旧的检查点。
迁移规则:
1. start:node -> branch:to:node
2. branch:source:condition:node -> branch:to:node
3. node -> branch:to:node(对于所有目标节点)
参数:
- checkpoint: 要迁移的检查点
"""
> 💡 **通俗理解**:这个方法就像"游戏版本更新时的存档迁移"。
> - LangGraph 的内部实现可能会改变(比如通道的命名方式)
> - 但旧的检查点(存档)还在用旧的命名方式
> - 这个方法会自动把旧的命名转换成新的命名,让旧存档可以在新版本中使用
> 就像游戏更新后,旧存档还能继续使用一样。
super()._migrate_checkpoint(checkpoint)
values = checkpoint["channel_values"]
versions = checkpoint["channel_versions"]
seen = checkpoint["versions_seen"]
# 空检查点不需要迁移
if not versions:
return
# 当前版本(v3+)不需要迁移
if checkpoint["v"] >= 3:
return
# 迁移 1: start:node -> branch:to:node
for k in list(versions):
if k.startswith("start:"):
node = k.split(":")[1]
if node not in self.nodes:
continue
new_k = f"branch:to:{node}"
new_v = (
max(versions[new_k], versions.pop(k))
if new_k in versions
else versions.pop(k)
)
# 更新 seen
for ss in (seen.get(node, {}), seen.get(INTERRUPT, {})):
if k in ss:
s = ss.pop(k)
if new_k in ss:
ss[new_k] = max(s, ss[new_k])
else:
ss[new_k] = s
# 更新值
if new_k not in values and k in values:
values[new_k] = values.pop(k)
# 更新版本
versions[new_k] = new_v
# 迁移 2: branch:source:condition:node -> branch:to:node
for k in list(versions):
if k.startswith("branch:") and k.count(":") == 3:
node = k.split(":")[-1]
if node not in self.nodes:
continue
new_k = f"branch:to:{node}"
new_v = (
max(versions[new_k], versions.pop(k))
if new_k in versions
else versions.pop(k)
)
# 更新 seen
for ss in (seen.get(node, {}), seen.get(INTERRUPT, {})):
if k in ss:
s = ss.pop(k)
if new_k in ss:
ss[new_k] = max(s, ss[new_k])
else:
ss[new_k] = s
# 更新值
if new_k not in values and k in values:
values[new_k] = values.pop(k)
# 更新版本
versions[new_k] = new_v
# 迁移 3: node -> branch:to:node(对于所有目标节点)
if not set(self.nodes).isdisjoint(versions):
source_to_target = defaultdict(list)
for start, end in self.builder.edges:
if start != START and end != END:
source_to_target[start].append(end)
for k in list(versions):
if k == START:
continue
if k in self.nodes:
v = versions.pop(k)
c = values.pop(k, MISSING)
for end in source_to_target[k]:
new_k = f"branch:to:{end}"
new_v = max(versions[new_k], v) if new_k in versions else v
# 更新 seen
for ss in (seen.get(end, {}), seen.get(INTERRUPT, {})):
if k in ss:
s = ss.pop(k)
if new_k in ss:
ss[new_k] = max(s, ss[new_k])
else:
ss[new_k] = s
# 更新值
if new_k not in values and c is not MISSING:
values[new_k] = c
# 更新版本
versions[new_k] = new_v
# 弹出中断 seen
if INTERRUPT in seen:
seen[INTERRUPT].pop(k, MISSING)

辅助函数#

_pick_mapper 函数#

def _pick_mapper(
state_keys: Sequence[str], schema: type[Any]
) -> Callable[[Any], Any] | None:
"""
选择状态映射函数。
如果状态键是 __root__ 或模式是字典类型,返回 None(不需要映射)。
否则返回一个将字典转换为模式对象的函数。
"""
if state_keys == ["__root__"]:
return None
if isclass(schema) and issubclass(schema, dict):
return None
return partial(_coerce_state, schema)

_coerce_state 函数#

def _coerce_state(schema: type[Any], input: dict[str, Any]) -> dict[str, Any]:
"""
将字典强制转换为模式对象。
使用模式类的构造函数创建实例。
"""
return schema(**input)

_control_branch 函数#

def _control_branch(value: Any) -> Sequence[tuple[str, Any]]:
"""
处理分支控制值。
将 Command 对象或 Send 对象转换为写入条目。
参数:
- value: 控制值(Command、Send 或列表)
返回:
- 写入条目列表
"""
if isinstance(value, Send):
return ((TASKS, value),)
commands: list[Command] = []
if isinstance(value, Command):
commands.append(value)
elif isinstance(value, (list, tuple)):
for cmd in value:
if isinstance(cmd, Command):
commands.append(cmd)
rtn: list[tuple[str, Any]] = []
for command in commands:
# 如果是父命令,抛出异常
if command.graph == Command.PARENT:
raise ParentCommand(command)
# 获取目标节点
goto_targets = (
[command.goto] if isinstance(command.goto, (Send, str)) else command.goto
)
# 转换为写入条目
for go in goto_targets:
if isinstance(go, Send):
rtn.append((TASKS, go))
elif isinstance(go, str) and go != END:
rtn.append((CHANNEL_BRANCH_TO.format(go), None))
return rtn

_control_static 函数#

def _control_static(
ends: tuple[str, ...] | dict[str, str],
) -> Sequence[tuple[str, Any, str | None]]:
"""
创建静态分支控制条目。
用于节点规范中指定的目标节点。
参数:
- ends: 目标节点(元组或字典)
返回:
- 写入条目列表(包含标签)
"""
if isinstance(ends, dict):
return [
(k if k == END else CHANNEL_BRANCH_TO.format(k), None, label)
for k, label in ends.items()
]
else:
return [
(e if e == END else CHANNEL_BRANCH_TO.format(e), None, None) for e in ends
]

_get_root 函数#

def _get_root(input: Any) -> Sequence[tuple[str, Any]] | None:
"""
处理根值更新。
当状态只有一个根键(__root__)时使用。
参数:
- input: 输入值(Command、列表或普通值)
返回:
- 更新元组列表
"""
if isinstance(input, Command):
if input.graph == Command.PARENT:
return ()
return input._update_as_tuples()
elif (
isinstance(input, (list, tuple))
and input
and any(isinstance(i, Command) for i in input)
):
updates: list[tuple[str, Any]] = []
for i in input:
if isinstance(i, Command):
if i.graph == Command.PARENT:
continue
updates.extend(i._update_as_tuples())
else:
updates.append(("__root__", i))
return updates
elif input is not None:
return [("__root__", input)]

_get_channels 函数#

def _get_channels(
schema: type[dict],
) -> tuple[dict[str, BaseChannel], dict[str, ManagedValueSpec], dict[str, Any]]:
"""
从模式中提取通道和托管值。
参数:
- schema: 模式类型
返回:
- (通道字典, 托管值字典, 类型提示字典)
"""
> 💡 **通俗理解**:这个方法就像"解析表格模板,为每个字段创建邮箱"。
> - 输入:一个状态模式(定义了有哪些字段)
> - 输出:为每个字段创建的通道(邮箱)和托管值(特殊邮箱)
> - 就像你定义了一个表格,系统自动为每列创建一个"数据传递通道"
# 如果没有注解,返回根通道
if not hasattr(schema, "__annotations__"):
return (
{"__root__": _get_channel("__root__", schema, allow_managed=False)},
{},
{},
)
# 获取类型提示(包括额外信息)
type_hints = get_type_hints(schema, include_extras=True)
# 为每个字段创建通道或托管值
all_keys = {
name: _get_channel(name, typ)
for name, typ in type_hints.items()
if name != "__slots__"
}
# 分离通道和托管值
return (
{k: v for k, v in all_keys.items() if isinstance(v, BaseChannel)},
{k: v for k, v in all_keys.items() if is_managed_value(v)},
type_hints,
)

_get_channel 函数#

@overload
def _get_channel(
name: str, annotation: Any, *, allow_managed: Literal[False]
) -> BaseChannel: ...
@overload
def _get_channel(
name: str, annotation: Any, *, allow_managed: Literal[True] = True
) -> BaseChannel | ManagedValueSpec: ...
def _get_channel(
name: str, annotation: Any, *, allow_managed: bool = True
) -> BaseChannel | ManagedValueSpec:
"""
从类型注解创建通道或托管值。
检查顺序:
1. 是否为托管值
2. 是否为显式通道
3. 是否为二元操作符聚合
4. 默认使用 LastValue
参数:
- name: 字段名
- annotation: 类型注解
- allow_managed: 是否允许托管值
返回:
- 通道或托管值规范
"""
> 💡 **通俗理解**:这个方法就像"根据字段类型选择合适的邮箱类型"。
> - 如果字段标注了"托管值"(系统自动管理),就用托管值通道
> - 如果字段标注了特定通道类型,就用那个类型
> - 如果字段有 reducer(聚合函数),就用聚合通道(可以合并多个值)
> - 否则默认用 LastValue(只保留最新值,覆盖旧值)
> 就像根据不同的需求选择不同的"邮箱类型":有的邮箱只存最新邮件,有的可以合并多封邮件。
# 检查托管值
if manager := _is_field_managed_value(name, annotation):
if allow_managed:
return manager
else:
raise ValueError(f"This {annotation} not allowed in this position")
# 检查显式通道
elif channel := _is_field_channel(annotation):
channel.key = name
return channel
# 检查二元操作符聚合
elif channel := _is_field_binop(annotation):
channel.key = name
return channel
# 默认使用 LastValue
fallback: LastValue = LastValue(annotation)
fallback.key = name
return fallback

_is_field_channel 函数#

def _is_field_channel(typ: type[Any]) -> BaseChannel | None:
"""
检查类型注解是否为显式通道。
检查注解的元数据中是否包含 BaseChannel 实例或类。
"""
if hasattr(typ, "__metadata__"):
meta = typ.__metadata__
if len(meta) >= 1 and isinstance(meta[-1], BaseChannel):
return meta[-1]
elif len(meta) >= 1 and isclass(meta[-1]) and issubclass(meta[-1], BaseChannel):
return meta[-1](typ.__origin__ if hasattr(typ, "__origin__") else typ)
return None

_is_field_binop 函数#

def _is_field_binop(typ: type[Any]) -> BinaryOperatorAggregate | None:
"""
检查类型注解是否为二元操作符聚合(reducer)。
检查注解的元数据中是否包含可调用对象,且签名符合 reducer 要求(两个位置参数)。
"""
if hasattr(typ, "__metadata__"):
meta = typ.__metadata__
if len(meta) >= 1 and callable(meta[-1]):
sig = signature(meta[-1])
params = list(sig.parameters.values())
if (
sum(
p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
for p in params
)
== 2
):
return BinaryOperatorAggregate(typ, meta[-1])
else:
raise ValueError(
f"Invalid reducer signature. Expected (a, b) -> c. Got {sig}"
)
return None

_is_field_managed_value 函数#

def _is_field_managed_value(name: str, typ: type[Any]) -> ManagedValueSpec | None:
"""
检查类型注解是否为托管值。
检查注解的元数据中是否包含托管值装饰器。
"""
if hasattr(typ, "__metadata__"):
meta = typ.__metadata__
if len(meta) >= 1:
decoration = get_origin(meta[-1]) or meta[-1]
if is_managed_value(decoration):
return decoration
return None

_get_json_schema 函数#

def _get_json_schema(
typ: type,
schemas: dict,
channels: dict,
name: str,
) -> dict[str, Any]:
"""
生成类型的 JSON Schema。
用于 API 文档和验证。
参数:
- typ: 类型
- schemas: 模式字典
- channels: 通道字典
- name: Schema 名称
返回:
- JSON Schema 字典
"""
# Pydantic 模型
if isclass(typ) and issubclass(typ, BaseModel):
return typ.model_json_schema()
# TypedDict
elif is_typeddict(typ):
return TypeAdapter(typ).json_schema()
# 其他类型
else:
keys = list(schemas[typ].keys())
# 单根键
if len(keys) == 1 and keys[0] == "__root__":
return create_model(
name,
root=(channels[keys[0]].UpdateType, None),
).model_json_schema()
# 多键
else:
return create_model(
name,
field_definitions={
k: (
channels[k].UpdateType,
(
get_field_default(
k,
channels[k].UpdateType,
typ,
)
),
)
for k in schemas[typ]
if k in channels and isinstance(channels[k], BaseChannel)
},
).model_json_schema()

常量#

CHANNEL_BRANCH_TO = "branch:to:{}"
# 分支通道名称格式模板
# 用于创建节点分支通道的名称,例如 "branch:to:node_name"

总结#

state.py 文件实现了 LangGraph 的核心状态图功能:

  1. StateGraph 类:用于构建状态图,支持节点、边、条件边等
  2. CompiledStateGraph 类:编译后的图,实现实际执行逻辑
  3. 通道系统:通过通道实现节点间的状态共享
  4. 模式系统:支持 TypedDict、Pydantic 模型等多种状态模式
  5. 检查点迁移:支持旧版本检查点的自动迁移

关键概念:

  • 节点(Node):图的执行单元,接收状态并返回部分状态更新
  • 边(Edge):定义节点间的执行顺序
  • 条件边(Conditional Edge):根据状态动态选择下一个节点
  • 通道(Channel):节点间共享状态的机制
  • 模式(Schema):定义状态的结构和类型

这个文件是 LangGraph 框架的核心,理解它有助于深入理解整个框架的工作原理。


💡 整体理解:用生活化的比喻#

StateGraph(设计阶段)

  • 就像画流程图,你定义有哪些步骤(节点)、步骤之间的顺序(边)、根据情况选择路径(条件边)
  • 你还可以设置每个步骤的配置:失败后重试几次、结果是否缓存、执行完后可以去哪里

CompiledStateGraph(运行阶段)

  • 就像把流程图转换成实际的工作流程系统
  • 每个节点都有一个”邮箱”(通道)来传递数据
  • 系统会自动调度:哪个节点先执行、什么时候执行、如何传递数据

通道系统

  • 就像公司里的”共享文件夹”或”公告板”
  • 每个状态字段都有自己的”文件夹”(通道)
  • 节点可以读取和写入这些文件夹
  • 不同类型的通道有不同的行为:
    • LastValue:只保留最新值(覆盖旧值)
    • NamedBarrierValue:等待多个节点都完成(集合点)
    • BinaryOperatorAggregate:用函数合并多个值(如累加、列表合并)

执行流程

  1. 你调用 graph.invoke(input)graph.stream(input)
  2. 系统从 START 节点开始
  3. 每个节点执行时:
    • 从通道读取需要的状态字段
    • 执行节点函数
    • 把返回值写入通道(更新状态)
    • 根据边或条件边决定下一个节点
  4. 重复步骤 3,直到到达 END 节点或没有更多节点可执行
  5. 返回最终状态

检查点系统

  • 就像游戏存档,可以在任意时刻保存当前状态
  • 如果出错了,可以从检查点恢复,不用从头开始
  • 支持版本迁移:即使 LangGraph 内部实现改变了,旧存档也能用

🎯 关键理解点#

  1. 状态是共享的:所有节点共享同一个状态对象,通过通道来读写
  2. 节点只返回部分更新:节点不需要返回完整状态,只返回它要更新的字段
  3. 编译是必须的StateGraph 只是设计,必须调用 compile() 才能运行
  4. 通道是核心:节点间的数据传递都通过通道,通道类型决定了数据如何合并
  5. 类型提示很重要:LangGraph 会从类型提示推断很多信息,写好类型提示可以让代码更清晰

希望这些解释能帮助你更好地理解 LangGraph 的工作原理!🚀

LangGraph底层源码解析
https://hygenluo.github.io/posts/agent/langgraph/原理/002底层源码解析/
Author
Hygen
Published at
2025-11-08
License
CC BY-NC-SA 4.0

Some information may be outdated

封面
示例歌曲
示例艺术家
封面
示例歌曲
示例艺术家
0:00 / 0:00