
How to Create Parallel Execution Branches

pip install -U langgraph
In this example, we fan out from node A to nodes B and C, then fan in to node D. In the state definition, we attach the reducer operator.add to the aggregate key.
This will merge or accumulate the value of a specific key in the state, rather than simply overwriting the existing value. For lists, this means joining the new list with the existing list.
Note that LangGraph uses Annotated types to specify the Reducer function for specific keys in the State: it maintains the original type (list) for type checking but allows the Reducer function (add) to be appended to that type without changing the type itself.
import operator
from typing import Annotated, Any, Sequence

from langgraph.graph import StateGraph, START, END
from typing_extensions import TypedDict
class State(TypedDict):
    """Graph state shared by every node."""

    # The operator.add reducer concatenates each node's returned list onto
    # the existing list instead of overwriting it.
    aggregate: Annotated[list, operator.add]
class ReturnNodeValue:
    """Callable graph node that contributes one fixed value to "aggregate"."""

    def __init__(self, node_secret: str):
        # The value this node emits every time it runs.
        self._value = node_secret

    def __call__(self, state: State) -> Any:
        """Print the aggregate seen by this node and return its state update."""
        print(f"Adding {self._value} to {state['aggregate']}")
        # Return a one-element list rather than the bare value so that a
        # list-concatenating reducer on "aggregate" can merge it.
        return {"aggregate": [self._value]}
builder = StateGraph(State)

# Register the four nodes; each one contributes its own label.
for name, label in [
    ("a", "I'm A"),
    ("b", "I'm B"),
    ("c", "I'm C"),
    ("d", "I'm D"),
]:
    builder.add_node(name, ReturnNodeValue(label))

# START -> a, fan out a -> {b, c}, fan in {b, c} -> d, then d -> END.
for source, target in [
    (START, "a"),
    ("a", "b"),
    ("a", "c"),
    ("b", "d"),
    ("c", "d"),
    ("d", END),
]:
    builder.add_edge(source, target)

graph = builder.compile()
result = graph.invoke({"aggregate": []}, {"configurable": {"thread_id": "foo"}})
print(result)
Adding I'm A to []
Adding I'm B to ["I'm A"]
Adding I'm C to ["I'm A"]
Adding I'm D to ["I'm A", "I'm B", "I'm C"]
{'aggregate': ["I'm A", "I'm B", "I'm C", "I'm D"]}
The above example demonstrates how to fan out and fan in when each path has only one step. But what if a path has multiple steps?
Everything else remains the same, only the construction of the graph changes:
builder = StateGraph(State)

# Branch "b" now has a second step, "b2"; branch "c" still has a single step.
for name, label in [
    ("a", "I'm A"),
    ("b", "I'm B"),
    ("b2", "I'm B2"),
    ("c", "I'm C"),
    ("d", "I'm D"),
]:
    builder.add_node(name, ReturnNodeValue(label))

builder.add_edge(START, "a")
builder.add_edge("a", "b")
builder.add_edge("a", "c")
builder.add_edge("b", "b2")
# One edge with a list of start nodes: "d" is entered from "b2" and "c"
# together rather than through two independent edges.
builder.add_edge(["b2", "c"], "d")
builder.add_edge("d", END)

graph = builder.compile()
final_state = graph.invoke({"aggregate": []})
print(final_state)

Adding I'm A to []
Adding I'm B to ["I'm A"]
Adding I'm C to ["I'm A"]
Adding I'm B2 to ["I'm A", "I'm B", "I'm C"]
Adding I'm D to ["I'm A", "I'm B", "I'm C", "I'm B2"]
{'aggregate': ["I'm A", "I'm B", "I'm C", "I'm B2", "I'm D"]}
class State(TypedDict):
    """Graph state extended with a key that selects the branch to take."""

    # Node outputs are concatenated by the operator.add reducer.
    aggregate: Annotated[list, operator.add]
    # New property added: names which pair of branches should run.
    which: str
class ReturnNodeValue:
    """Callable graph node that contributes one fixed value to "aggregate"."""

    def __init__(self, node_secret: str):
        # The value this node emits every time it runs.
        self._value = node_secret

    def __call__(self, state: State) -> Any:
        """Print the aggregate seen by this node and return its state update."""
        print(f"Adding {self._value} to {state['aggregate']}")
        # Only "aggregate" is returned; other state keys are left untouched.
        return {"aggregate": [self._value]}
builder = StateGraph(State)

# Five nodes this time: "e" is the node every branch ends in.
for name, label in [
    ("a", "I'm A"),
    ("b", "I'm B"),
    ("c", "I'm C"),
    ("d", "I'm D"),
    ("e", "I'm E"),
]:
    builder.add_node(name, ReturnNodeValue(label))

builder.add_edge(START, "a")
def route_bc_or_cd(state: State) -> Sequence[str]:
    """Route function: decide which branches to take after node "a".

    Returns ``["c", "d"]`` when ``state["which"]`` is ``"cd"``; any other
    value falls back to ``["b", "c"]``.
    """
    if state["which"] == "cd":
        return ["c", "d"]
    return ["b", "c"]
# Every node the route function may return; passed as the third argument
# of add_conditional_edges to list the candidate destinations.
intermediates = ["b", "c", "d"]
builder.add_conditional_edges(
    "a",
    route_bc_or_cd,
    intermediates,
)
# Whichever intermediate nodes run, each of them leads into "e".
for node in intermediates:
    builder.add_edge(node, "e")
builder.add_edge("e", END)

graph = builder.compile()
print(graph.invoke({"aggregate": [], "which": "bc"}))
Adding I'm A to []
Adding I'm B to ["I'm A"]
Adding I'm C to ["I'm A"]
Adding I'm E to ["I'm A", "I'm B", "I'm C"]
{'aggregate': ["I'm A", "I'm B", "I'm C", "I'm E"], 'which': 'bc'}

The dashed lines along the way indicate possible routes, while solid lines indicate mandatory routes.
Stable Sorting
Normally, after fanning out, the nodes run in parallel as a single “super step”. Once the super step has completed, the updates produced by its nodes are applied to the state one after another.
If we need the updates from a parallel super step to be ordered consistently and predictably, we should write each output (along with an identifying key) into a separate field in the state, and then combine them in the “sink” node that every fan-out node leads to:
import operator
from typing import Annotated, Sequence, Any
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, END, START
def reduce_fanouts(left, right):
    """Reducer for "fanout_values": accumulate updates, reset on an empty one.

    Args:
        left: The value accumulated so far; ``None`` is treated as ``[]``.
        right: The incoming update.

    Returns:
        ``[]`` when ``right`` is falsy (for example ``[]`` or ``None``),
        otherwise the concatenation ``left + right``.
    """
    if left is None:
        left = []
    if not right:
        # Overwrite: an empty update clears whatever was accumulated.
        return []
    return left + right
class State(TypedDict):
    """Graph state with a separate staging field for parallel outputs."""

    # Final, ordered result; updates are concatenated by operator.add.
    aggregate: Annotated[list, operator.add]
    # Staging area for the fan-out nodes; reduce_fanouts concatenates
    # updates and resets the field when it receives an empty update.
    fanout_values: Annotated[list, reduce_fanouts]
    # Names which pair of branches should run.
    which: str
class ReturnNodeValue:
    """Callable graph node that contributes one fixed value to "aggregate".

    Unlike the parallel fan-out nodes, this node carries no reliability
    score and writes straight to "aggregate".
    """

    def __init__(self, node_secret: str):
        # The value this node emits every time it runs.
        self._value = node_secret

    def __call__(self, state: State) -> Any:
        """Print the aggregate seen by this node and return its state update."""
        print(f"Adding {self._value} to {state['aggregate']}")
        # The previous body read self._reliability, which __init__ never
        # sets, so calling the node raised AttributeError.
        return {"aggregate": [self._value]}
# Start a new graph; node "a" is the entry point (START -> "a").
builder = StateGraph(State)
builder.add_node("a", ReturnNodeValue("I'm A"))
builder.add_edge(START, "a")
class ParallelReturnNodeValue:
    """Callable fan-out node that tags its value with a reliability score.

    The node writes to "fanout_values" rather than "aggregate", so a later
    node can order the collected entries before merging them.
    """

    def __init__(
        self,
        node_secret: str,
        reliability: float,
    ):
        self._value = node_secret
        # Assume we want to sort by reliability
        self._reliability = reliability

    def __call__(self, state: State) -> Any:
        """Print the aggregate seen by this node and return its tagged output."""
        print(f"Adding {self._value} to {state['aggregate']} in parallel.")
        return {
            "fanout_values": [
                {
                    "value": [self._value],
                    "reliability": self._reliability,
                }
            ]
        }
# Fan-out nodes, each with its own reliability score.
for name, label, score in [
    ("b", "I'm B", 0.9),
    ("c", "I'm C", 0.1),
    ("d", "I'm D", 0.3),
]:
    builder.add_node(name, ParallelReturnNodeValue(label, reliability=score))
def aggregate_fanout_values(state: State) -> Any:
    """Merge the staged fan-out entries into "aggregate" in a fixed order.

    Entries in ``state["fanout_values"]`` are ordered by descending
    "reliability"; their "value" items, followed by ``"I'm E"``, become the
    update for "aggregate". "fanout_values" is updated with an empty list.
    """
    # Sort by reliability, most reliable first. sorted() is stable, so
    # entries with equal reliability keep their original relative order.
    ranked_values = sorted(
        state["fanout_values"],
        key=operator.itemgetter("reliability"),
        reverse=True,
    )
    return {
        "aggregate": [entry["value"] for entry in ranked_values] + ["I'm E"],
        "fanout_values": [],
    }
# Finally aggregate here: node "e" runs aggregate_fanout_values, so the
# aggregation is sorted by reliability.
builder.add_node("e", aggregate_fanout_values)
def route_bc_or_cd(state: State) -> Sequence[str]:
    """Decide which branches to take after node "a".

    Returns ``["c", "d"]`` when ``state["which"]`` is ``"cd"``; any other
    value falls back to ``["b", "c"]``.
    """
    if state["which"] == "cd":
        return ["c", "d"]
    return ["b", "c"]
# Every node the route function may return; passed as the third argument
# of add_conditional_edges to list the candidate destinations.
intermediates = ["b", "c", "d"]
builder.add_conditional_edges("a", route_bc_or_cd, intermediates)
# Regular edges from each fan-out node into the node "e".
for node in intermediates:
    builder.add_edge(node, "e")
builder.add_edge("e", END)

graph = builder.compile()
print(graph.invoke({"aggregate": [], "fanout_values": [], "which": "bc"}))
Adding I'm A to []
Adding I'm B to ["I'm A"] in parallel.
Adding I'm C to ["I'm A"] in parallel.
{'aggregate': ["I'm A", ["I'm B"], ["I'm C"], "I'm E"], 'fanout_values': [], 'which': 'bc'}

From the code and the diagram, we can see that node a is the starting node, which then fans out into up to three branches: a → b → e, a → c → e, and a → d → e (the route function selects two of them on each run). Since the selected nodes among b, c, and d are executed concurrently, what we control here is the order in which their outputs flow into e.
Put plainly, it is about the order in which e receives its inputs. By the time e runs, nodes b, c, and d have already finished their work; but in some scenarios e needs to arrange their results in a deterministic order (here, by descending reliability). This is what is meant by stable sorting.
If you are interested, you can try it out yourself to deepen your understanding!
Reference Link:
https://langchain-ai.github.io/langgraph/how-tos/branching/