Advanced Practice: Adding Conditional Branches to LangGraph

  • Hello everyone, I am student Xiao Zhang, sharing AI knowledge and practical cases daily.

  • Welcome to like + follow 👏, continue learning, and keep delivering valuable content.

  • +v: jasper_8017 Let’s communicate 💬 and progress together 💪.


This article continues from the previous one (【AI Agent】【LangGraph】0. Quick Start: Collaborating with LangChain, LangGraph helps you easily build multi-agent systems using graph structures), which covered the concepts behind LangGraph and how to build a basic graph. Today we look at a more advanced feature: adding conditions to edges, known as conditional edges.

LangGraph builds a graph out of nodes and edges, and an edge can carry a condition: instead of always leading to the same node, a conditional edge calls a function on the current state and uses its return value to decide which node runs next. Conditional edges are added with the add_conditional_edges function.
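To make the difference between a fixed edge and a conditional edge concrete, here is a toy, library-free sketch. ToyGraph, next_node and the END string are hypothetical stand-ins for illustration, not LangGraph's API: a fixed edge always returns the same successor, while a conditional edge calls a function on the state and looks its return value up in a mapping.

```python
from typing import Callable, Dict, List, Optional, Tuple

END = "__end__"  # stand-in for langgraph's END sentinel (assumption for this sketch)


class ToyGraph:
    """Hypothetical toy graph with fixed and conditional edges (not LangGraph's API)."""

    def __init__(self) -> None:
        self.edges: Dict[str, str] = {}  # start node -> its single fixed successor
        self.branches: Dict[str, Tuple[Callable[[List[str]], str], Dict[str, str]]] = {}

    def add_edge(self, start: str, end: str) -> None:
        self.edges[start] = end

    def add_conditional_edges(self, start: str, condition, mapping: Dict[str, str]) -> None:
        self.branches[start] = (condition, mapping)

    def next_node(self, current: str, state: List[str]) -> Optional[str]:
        if current in self.edges:
            # Fixed edge: always the same successor.
            return self.edges[current]
        if current in self.branches:
            # Conditional edge: compute a key from the state, then map it to a node.
            condition, mapping = self.branches[current]
            return mapping[condition(state)]
        return None  # no outgoing edge registered


g = ToyGraph()
g.add_edge("multiply", END)
g.add_conditional_edges(
    "oracle",
    # Hypothetical condition: pretend a message starting with "CALL" is a tool call.
    lambda state: "multiply" if state[-1].startswith("CALL") else "end",
    {"multiply": "multiply", "end": END},
)
print(g.next_node("oracle", ["CALL multiply(2, 3)"]))  # multiply
print(g.next_node("oracle", ["plain answer"]))  # __end__
```

The mapping decouples the labels the condition returns ("end") from the actual destinations (END), which is the same role the third argument plays in the article's code.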

1. Complete Code and Execution

Without further ado, here is the complete code. Let’s run it and look at the result. (Running it requires an OpenAI API key, which ChatOpenAI reads from the OPENAI_API_KEY environment variable.)

import json
from typing import List

from langchain_core.messages import BaseMessage, HumanMessage, ToolMessage
from langchain_core.tools import tool
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, MessageGraph

@tool
def multiply(first_number: int, second_number: int):
    """Multiplies two numbers together."""
    return first_number * second_number

model = ChatOpenAI(temperature=0)
model_with_tools = model.bind(tools=[convert_to_openai_tool(multiply)])

graph = MessageGraph()

def invoke_model(state: List[BaseMessage]):
    return model_with_tools.invoke(state)

graph.add_node("oracle", invoke_model)

def invoke_tool(state: List[BaseMessage]):
    tool_calls = state[-1].additional_kwargs.get("tool_calls", [])
    multiply_call = None

    for tool_call in tool_calls:
        if tool_call.get("function").get("name") == "multiply":
            multiply_call = tool_call

    if multiply_call is None:
        raise Exception("No multiply tool call found.")

    res = multiply.invoke(
        json.loads(multiply_call.get("function").get("arguments"))
    )

    # ToolMessage content is expected to be a string, so convert the int result
    return ToolMessage(
        tool_call_id=multiply_call.get("id"),
        content=str(res)
    )

graph.add_node("multiply", invoke_tool)

graph.add_edge("multiply", END)

graph.set_entry_point("oracle")

def router(state: List[BaseMessage]):
    tool_calls = state[-1].additional_kwargs.get("tool_calls", [])
    if len(tool_calls):
        return "multiply"
    else:
        return "end"

graph.add_conditional_edges("oracle", router, {
    "multiply": "multiply",
    "end": END,
})

runnable = graph.compile()

response = runnable.invoke(HumanMessage("What is 123 * 456?"))
print(response)

The execution result is as follows:

[Screenshot: execution result of the code above]

2. Code Explanation

Now let’s explain the above code in detail.

2.1 add_conditional_edges

As mentioned above, conditional edges are added with add_conditional_edges. The relevant code is:

graph.add_conditional_edges("oracle", router, {
    "multiply": "multiply",
    "end": END,
})

add_conditional_edges takes three parameters:

  • The first is the name of the start node, i.e. the node the edges leave from.

  • The second is the condition: a function that receives the current state and returns a key.

  • The third is a mapping from each key the condition can return to the name of the destination node.

In the code above, the conditional edges leave the “oracle” node, so that node has two possible next steps: the “multiply” node or “END”. Which one is taken is decided by the condition, router (explained below): if router returns “multiply”, execution moves to the “multiply” node; if it returns “end”, the graph goes to “END” and stops.

Let’s look at the source code of this function (as it appears in the langgraph version used in this article):

def add_conditional_edges(
    self,
    start_key: str,
    condition: Callable[..., str],
    conditional_edge_mapping: Optional[Dict[str, str]] = None,
) -> None:
    if self.compiled:
        logger.warning(
            "Adding an edge to a graph that has already been compiled. This will "
            "not be reflected in the compiled graph."
        )
    if start_key not in self.nodes:
        raise ValueError(f"Need to add_node `{start_key}` first")
    if iscoroutinefunction(condition):
        raise ValueError("Condition cannot be a coroutine function")
    if conditional_edge_mapping and set(
        conditional_edge_mapping.values()
    ).difference([END]).difference(self.nodes):
        raise ValueError(
            f"Missing nodes which are in conditional edge mapping. Mapping "
            f"contains possible destinations: "
            f"{list(conditional_edge_mapping.values())}. Possible nodes are "
            f"{list(self.nodes.keys())}."
        )

    self.branches[start_key].append(Branch(condition, conditional_edge_mapping))

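The key line is self.branches[start_key].append(Branch(condition, conditional_edge_mapping)), which registers a new branch on the start node. The checks before it make sure that the start node has already been added, that the condition is a regular (not async) function, and that every destination in the mapping is either an existing node or END. If the graph has already been compiled, a warning is logged, because the new branch will not be reflected in the compiled graph.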
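The checks in that method can be re-created in a few lines to see which mistakes they catch. MiniGraph below is a hypothetical toy class, not langgraph's Graph; it keeps only the node registry, the per-node branch list and the three validations from the source above, and END is again a stand-in string.

```python
from collections import defaultdict
from inspect import iscoroutinefunction
from typing import Callable, Dict, List, Optional, Tuple

END = "__end__"  # stand-in for langgraph's END sentinel (assumption for this sketch)


class MiniGraph:
    """Hypothetical toy class that re-creates only the checks in add_conditional_edges."""

    def __init__(self) -> None:
        self.nodes: Dict[str, Callable] = {}
        # A start node may own several branches, so each key maps to a list.
        self.branches: Dict[str, List[Tuple[Callable, Optional[Dict[str, str]]]]] = defaultdict(list)

    def add_node(self, key: str, action: Callable) -> None:
        self.nodes[key] = action

    def add_conditional_edges(
        self,
        start_key: str,
        condition: Callable[..., str],
        conditional_edge_mapping: Optional[Dict[str, str]] = None,
    ) -> None:
        # 1. The start node must already exist.
        if start_key not in self.nodes:
            raise ValueError(f"Need to add_node `{start_key}` first")
        # 2. The condition must be a plain (synchronous) function.
        if iscoroutinefunction(condition):
            raise ValueError("Condition cannot be a coroutine function")
        # 3. Every destination must be a registered node or END.
        if conditional_edge_mapping:
            unknown = set(conditional_edge_mapping.values()) - {END} - set(self.nodes)
            if unknown:
                raise ValueError(f"Unknown destinations: {sorted(unknown)}")
        # 4. Only then is the branch recorded on the start node.
        self.branches[start_key].append((condition, conditional_edge_mapping))


g = MiniGraph()
g.add_node("oracle", lambda state: state)
g.add_node("multiply", lambda state: state)
g.add_conditional_edges("oracle", lambda state: "end", {"multiply": "multiply", "end": END})
print(len(g.branches["oracle"]))  # 1

try:
    g.add_conditional_edges("oracle", lambda state: "x", {"x": "adder"})
except ValueError as err:
    print(err)  # Unknown destinations: ['adder']
```

Validating at registration time means a typo in a node name fails immediately, rather than only when that branch happens to be taken at run time.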

2.2 Conditional Router

The condition function, router, inspects the last message in the state (the model’s reply). If its additional_kwargs contain tool_calls, it returns “multiply”; otherwise it returns “end”.

def router(state: List[BaseMessage]):
    tool_calls = state[-1].additional_kwargs.get("tool_calls", [])
    if len(tool_calls):
        return "multiply"
    else:
        return "end"
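The router’s decision can be checked without calling a model. The sketch below uses a hypothetical FakeAIMessage dataclass in place of LangChain’s message classes; it carries only the additional_kwargs dictionary that router reads.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class FakeAIMessage:
    """Hypothetical stand-in for AIMessage; router only reads additional_kwargs."""

    content: str = ""
    additional_kwargs: Dict[str, Any] = field(default_factory=dict)


def router(state: List[FakeAIMessage]) -> str:
    # Same rule as the article: only the newest message decides the branch.
    tool_calls = state[-1].additional_kwargs.get("tool_calls", [])
    return "multiply" if tool_calls else "end"


with_call = FakeAIMessage(additional_kwargs={"tool_calls": [{"id": "call_1"}]})
plain = FakeAIMessage(content="Hello!")
print(router([with_call]))  # multiply
print(router([plain]))  # end
```

For a list, `if tool_calls` behaves the same as the article’s `if len(tool_calls)`: an empty list routes to “end”.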

2.3 Definition of Each Node

(1) Starting node: oracle

@tool
def multiply(first_number: int, second_number: int):
    """Multiplies two numbers together."""
    return first_number * second_number

model = ChatOpenAI(temperature=0)
model_with_tools = model.bind(tools=[convert_to_openai_tool(multiply)])

graph = MessageGraph()

def invoke_model(state: List[BaseMessage]):
    return model_with_tools.invoke(state)

graph.add_node("oracle", invoke_model)

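This node is a ChatOpenAI model with Tools bound to it. For a detailed tutorial on using Tools in LangChain, see this article: 【AI LLM Application Development】【LangChain Series】5. LangChain Basics: A Hands-On Walkthrough of the Agents Module. In short, the node returns the model’s reply; if the model decides to call the bound multiply tool, the reply carries the call (tool name plus JSON-encoded arguments) in additional_kwargs["tool_calls"], which is exactly what router checks.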
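To see what router and the multiply node actually read, here is a hypothetical example of one entry in additional_kwargs["tool_calls"]. The shape follows the OpenAI Chat Completions tool-call format (an id, a type, and a function with name and arguments); the id value is made up. Note that arguments is a JSON string, not a dict.

```python
import json

# Hypothetical tool-call entry in the OpenAI tool-call format; the id is made up.
tool_call = {
    "id": "call_abc123",
    "type": "function",
    "function": {
        "name": "multiply",
        # The arguments arrive as a JSON *string* and must be parsed.
        "arguments": '{"first_number": 123, "second_number": 456}',
    },
}

args = json.loads(tool_call["function"]["arguments"])
print(tool_call["function"]["name"], args)
# multiply {'first_number': 123, 'second_number': 456}
```

This is why invoke_tool calls json.loads before handing the arguments to multiply.invoke.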

(2) Tool node: multiply

def invoke_tool(state: List[BaseMessage]):
    tool_calls = state[-1].additional_kwargs.get("tool_calls", [])
    multiply_call = None

    for tool_call in tool_calls:
        if tool_call.get("function").get("name") == "multiply":
            multiply_call = tool_call

    if multiply_call is None:
        raise Exception("No multiply tool call found.")

    res = multiply.invoke(
        json.loads(multiply_call.get("function").get("arguments"))
    )

    # ToolMessage content is expected to be a string, so convert the int result
    return ToolMessage(
        tool_call_id=multiply_call.get("id"),
        content=str(res)
    )

graph.add_node("multiply", invoke_tool)

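This node executes the tool: it looks for the multiply call in the last message’s tool_calls, parses the JSON-encoded arguments, invokes multiply, and wraps the result in a ToolMessage whose tool_call_id matches the id of the call.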
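The steps of invoke_tool can be traced in plain Python. In this sketch, multiply is an ordinary function standing in for the @tool version, and the hypothetical helper run_multiply_call returns a dict holding the two fields the article passes to ToolMessage, instead of a ToolMessage object.

```python
import json
from typing import Any, Dict, List


def multiply(first_number: int, second_number: int) -> int:
    """Plain-Python stand-in for the @tool-decorated multiply."""
    return first_number * second_number


def run_multiply_call(tool_calls: List[Dict[str, Any]]) -> Dict[str, str]:
    """Mirrors invoke_tool: pick the multiply call, parse args, return ToolMessage fields."""
    # Keep the last matching call, as the article's loop does.
    multiply_call = None
    for call in tool_calls:
        if call.get("function", {}).get("name") == "multiply":
            multiply_call = call
    if multiply_call is None:
        raise ValueError("No multiply tool call found.")
    args = json.loads(multiply_call["function"]["arguments"])
    result = multiply(**args)
    # tool_call_id links the result back to the request; content is sent as text.
    return {"tool_call_id": multiply_call["id"], "content": str(result)}


calls = [{
    "id": "call_1",
    "type": "function",
    "function": {"name": "multiply", "arguments": '{"first_number": 123, "second_number": 456}'},
}]
print(run_multiply_call(calls))
# {'tool_call_id': 'call_1', 'content': '56088'}
```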

2.4 Overall Process

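[Figure: overall flow of the graph]

The graph starts at the oracle node. router then inspects oracle’s reply: if it contains tool_calls, execution moves to the multiply node and then along the fixed edge to END; otherwise it goes straight from oracle to END.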
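The overall flow can be simulated without any library. In this sketch fake_oracle is a hypothetical rule-based stand-in for the LLM (it requests multiply only when the question contains “*” and two numbers), messages are plain dicts, and END is a stand-in string. The run function mirrors what the compiled graph does: start at oracle, append each node’s output to the message list, then follow the conditional edge out of oracle and the fixed edge out of multiply until END.

```python
import json
import re
from typing import Any, Callable, Dict, List, Tuple

END = "__end__"  # stand-in for langgraph's END sentinel (assumption)
Message = Dict[str, Any]  # toy message: role, content, optional tool_calls


def fake_oracle(state: List[Message]) -> Message:
    """Hypothetical stand-in for the LLM: requests multiply when it sees 'a * b'."""
    text = state[-1]["content"]
    numbers = [int(n) for n in re.findall(r"\d+", text)]
    if "*" in text and len(numbers) == 2:
        args = json.dumps({"first_number": numbers[0], "second_number": numbers[1]})
        call = {"id": "call_1", "function": {"name": "multiply", "arguments": args}}
        return {"role": "ai", "content": "", "tool_calls": [call]}
    return {"role": "ai", "content": "I can only multiply two numbers.", "tool_calls": []}


def multiply_node(state: List[Message]) -> Message:
    call = state[-1]["tool_calls"][0]
    args = json.loads(call["function"]["arguments"])
    product = args["first_number"] * args["second_number"]
    return {"role": "tool", "tool_call_id": call["id"], "content": str(product)}


def router(state: List[Message]) -> str:
    return "multiply" if state[-1].get("tool_calls") else "end"


NODES: Dict[str, Callable[[List[Message]], Message]] = {"oracle": fake_oracle, "multiply": multiply_node}
EDGES: Dict[str, str] = {"multiply": END}  # fixed edge: multiply -> END
BRANCHES: Dict[str, Tuple[Callable, Dict[str, str]]] = {
    "oracle": (router, {"multiply": "multiply", "end": END}),  # conditional edge
}


def run(user_text: str) -> List[Message]:
    state: List[Message] = [{"role": "human", "content": user_text}]
    current = "oracle"  # entry point
    while current != END:
        # As in MessageGraph, each node's returned message is appended to the list.
        state.append(NODES[current](state))
        if current in BRANCHES:
            condition, mapping = BRANCHES[current]
            current = mapping[condition(state)]
        else:
            current = EDGES[current]
    return state


for msg in run("What is 123 * 456?"):
    print(msg["role"], repr(msg["content"]))
# human 'What is 123 * 456?'
# ai ''
# tool '56088'
```

A message with no product in it, such as "Hello", takes the other branch and ends right after oracle, leaving only the human and ai messages.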

If you find this article helpful, please give it a like and follow ~~~
