Map-reduce operations are crucial for efficient task decomposition and parallel processing: break a task into smaller sub-tasks (map), process the sub-tasks in parallel, and aggregate their results (reduce).
Consider an example task: given a general topic from the user, generate a list of related topics, create a joke for each one, and select the best joke from the resulting list.
In this design pattern, the first node might generate a list of objects (e.g., related topics), and we want to apply other nodes (e.g., generating a joke) to all these objects (e.g., topics).
However, two main challenges arise.
(1) The number of objects (e.g., subjects) might be unknown in advance when we lay out the graph (which means the number of edges might be unknown).
(2) The input states of downstream nodes should be different (one for each generated object).
LangGraph addresses these challenges through its Send API. By utilizing conditional edges, Send can distribute different states (e.g., subjects) to multiple instances of nodes (e.g., joke generation).
Importantly, the state sent can differ from the core graph’s state, allowing for flexible and dynamic workflow management.
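Schematically, the pattern looks like this (a minimal sketch for illustration; fan_out, worker, and items are placeholder names, not part of the example below):
from langgraph.types import Send

# A conditional-edge function may return Send objects instead of a node name.
# Each Send pairs a target node with its own private input state, so one
# upstream node can fan out into N parallel invocations of the same node.
def fan_out(state):
    return [Send("worker", {"item": item}) for item in state["items"]]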
The flow is straightforward; let’s see how the Send API is actually used.
import operator
from typing import Annotated

from typing_extensions import TypedDict
from pydantic import BaseModel, Field

from langchain_ollama import ChatOllama
from langgraph.graph import END, START, StateGraph
from langgraph.types import Send

import base_conf

# Models and prompts
# Define the prompts we will use
subjects_prompt = """Generate 3 to 5 examples related to the following topic, separated by commas: {topic}."""
joke_prompt = """Generate a joke about {subject}"""
best_joke_prompt = """Here are some jokes about {topic}. Please select the best one! Return the ID of the best joke.{jokes}"""

# Define data models
class Subjects(BaseModel):
    subjects: list[str]

class Joke(BaseModel):
    joke: str

class BestJoke(BaseModel):
    id: int = Field(description="Index of the best joke, starting from 0", ge=0)

# Initialize the model
model = ChatOllama(base_url=base_conf.base_url, model=base_conf.model_name, temperature=0.7)

# Components of the graph: define the components that make up the state graph.
# This is the overall state of the main graph. It contains a topic (provided
# by the user), the list of generated subjects, the jokes generated for each
# subject, and the best joke selected at the end.
class OverallState(TypedDict):
    topic: str
    subjects: list
    # Note that we use operator.add here because we need to merge the jokes
    # generated by each parallel node into a single list.
    # This is the "reduce" part of the map-reduce pattern.
    jokes: Annotated[list, operator.add]
    best_selected_joke: str

# This is the private state of the joke-generation node; we will map each
# generated subject to one invocation of this node.
class JokeState(TypedDict):
    subject: str

# Function to generate subjects for a topic
def generate_topics(state: OverallState):
    prompt = subjects_prompt.format(topic=state["topic"])
    response = model.with_structured_output(Subjects).invoke(prompt)
    return {"subjects": response.subjects}

# Function to generate a joke for one subject
def generate_joke(state: JokeState):
    prompt = joke_prompt.format(subject=state["subject"])
    response = model.with_structured_output(Joke).invoke(prompt)
    return {"jokes": [response.joke]}

# Define the logic that maps generated subjects to joke-generation nodes.
# We will use this function as a conditional edge in the graph.
def continue_to_jokes(state: OverallState):
    # Return a list of `Send` objects.
    # Each `Send` names a node in the graph and carries the state sent to that node.
    return [Send("generate_joke", {"subject": s}) for s in state["subjects"]]

# Function to select the best joke
def best_joke(state: OverallState):
    jokes = "\n\n".join(state["jokes"])
    prompt = best_joke_prompt.format(topic=state["topic"], jokes=jokes)
    response = model.with_structured_output(BestJoke).invoke(prompt)
    return {"best_selected_joke": state["jokes"][response.id]}

# Build the state graph: combine all the parts
graph = StateGraph(OverallState)
graph.add_node("generate_topics", generate_topics)
graph.add_node("generate_joke", generate_joke)
graph.add_node("best_joke", best_joke)
graph.add_edge(START, "generate_topics")
graph.add_conditional_edges("generate_topics", continue_to_jokes, ["generate_joke"])
graph.add_edge("generate_joke", "best_joke")
graph.add_edge("best_joke", END)

# Compile the graph into a runnable application
app = graph.compile()

for s in app.stream({"topic": "Animals"}):
    print(s)
{'generate_topics': {'subjects': ['The importance of animal protection', 'The benefits of interacting with animals for humans', 'Technology applications in zoos']}}
{'generate_joke': {'jokes': ['Why don’t the animals in the zoo join protection organizations? Because they are already in the “circle”!']}}
{'generate_joke': {'jokes': ['Why do programmers like to work with cats? Because they come and go, making debugging easier!']}}
{'generate_joke': {'jokes': ['The technology in the zoo is so advanced, the giant panda can swipe into the bamboo forest with its phone, even the staff exclaims: Is this going to create a “panda economy”?']}}
{'best_joke': {'best_selected_joke': 'Why don’t the animals in the zoo join protection organizations? Because they are already in the “circle”!'}}
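Streaming prints each step as it completes. If you only need the final result, you can call invoke on the compiled graph instead, which returns the final OverallState as a dict:
result = app.invoke({"topic": "Animals"})
print(result["best_selected_joke"])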
The core of the pattern is this snippet:
def continue_to_jokes(state: OverallState):
    # Return a list of `Send` objects.
    # Each `Send` names a node in the graph and carries the state sent to that node.
    return [Send("generate_joke", {"subject": s}) for s in state["subjects"]]

graph.add_conditional_edges("generate_topics", continue_to_jokes, ["generate_joke"])
The core idea is that the conditional edge dynamically creates one Send per generated subject, so the single generate_joke node is invoked once for each subject, in parallel; the resulting jokes are then gathered, and the funniest one is selected.
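The "reduce" half happens automatically: because jokes is declared as Annotated[list, operator.add], LangGraph merges the partial updates returned by the parallel generate_joke invocations. Conceptually, the merge works like this (a simplified sketch, not LangGraph internals):
import operator

# Each parallel generate_joke call returns {"jokes": [one_joke]}.
# The reducer concatenates these partial lists into the shared state:
merged: list = []
for partial in [["joke A"], ["joke B"], ["joke C"]]:
    merged = operator.add(merged, partial)  # same as merged + partial
# merged == ["joke A", "joke B", "joke C"]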
Isn’t it interesting? Give it a try!
https://langchain-ai.github.io/langgraph/how-tos/map-reduce/