In-Depth Analysis of PyTorch Dynamic Graphs


This article is adapted from: Deep Learning Matters

Background
PyTorch’s dynamic graph framework is implemented mainly in the code under torch/csrc/autograd. This directory defines three main base classes, Variable, Function, and Engine, which together form the foundation of PyTorch’s dynamic graph.
Why is it called a dynamic graph? The graph part is easy to understand: Functions are the nodes (vertices), and (Function, input_nr) pairs are the edges. The dynamic part is that a graph is constructed on every forward pass and, by default, destroyed during the backward pass. This article examines PyTorch’s dynamic graph system based on the code in torch/csrc/autograd; it may be the most detailed article on PyTorch’s dynamic graph available online.
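A minimal sketch of this "dynamic" behavior, using only the public Python API (the tensor names are illustrative, and a recent PyTorch release is assumed): every forward pass yields a fresh grad_fn node, and a second backward over the same graph fails because the first one released the saved buffers.

```python
import torch

x = torch.ones(2, 2, requires_grad=True)

# Each forward pass builds a new graph: the two results have distinct grad_fn nodes.
y1 = (x * x).sum()
y2 = (x * x).sum()
fresh_graph_each_time = y1.grad_fn is not y2.grad_fn

# The first backward consumes the graph (retain_graph is not set).
y1.backward()

# A second backward over the same graph raises, because the tensors saved
# for the multiplication were freed during the first backward pass.
try:
    y1.backward()
    second_backward_failed = False
except RuntimeError:
    second_backward_failed = True

print(fresh_graph_each_time, second_backward_failed)
```

Passing retain_graph=True to the first backward call keeps the buffers alive, in which case the second call succeeds.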
In the column article “PyTorch Initialization” (https://zhuanlan.zhihu.com/p/57571317), gemfield describes the initialization process of PyTorch, mentioning at the end the call to THPAutograd_initFunctions(): “The final THPAutograd_initFunctions() initializes PyTorch’s automatic differentiation system, which is the basis of the dynamic graph framework.” This article will start with THPAutograd_initFunctions and guide you into the world of PyTorch’s dynamic graphs, primarily introducing the inheritance system of the classes Function, Variable, and Engine.
Autograd Initialization
The function THPAutograd_initFunctions is implemented as follows:
void THPAutograd_initFunctions()
{
  THPObjectPtr module(PyModule_New("torch._C._functions"));
  ......
  generated::initialize_autogenerated_functions();
  auto c_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
}

This is used to initialize the cpp_function_types table, which maintains the mapping from C++ function types to Python types:
static std::unordered_map<std::type_index, THPObjectPtr> cpp_function_types;

This table stores the mapping of functions related to autograd. What is its purpose? For example, when I print a Variable’s grad_fn in Python:
>>> gemfield = torch.empty([2,2],requires_grad=True)
>>> syszux = gemfield * gemfield
>>> syszux.grad_fn
<ThMulBackward object at 0x7f111621c350>

The grad_fn is an instance of Function. We defined many backward functions in C++, but how do we access them in Python? It’s all thanks to the mapping in the table above. In fact, the cpp_function_types mapping table is designed to facilitate printing grad_fn in Python.
Variable
Reference: https://zhuanlan.zhihu.com/p/64135058
Using the following code snippet as an example:
gemfield = torch.ones(2, 2, requires_grad=True)
syszux = gemfield + 2
civilnet = syszux * syszux * 3
gemfieldout = civilnet.mean()
gemfieldout.backward()

Note that the dynamic graph is constructed during the forward pass. gemfieldout is the final output of the forward pass, but during backpropagation it is the initial input; in the dynamic graph, we call this the root. In the introduction to Engine below, gemfieldout is used as the root to construct a GraphRoot instance, which serves as the input for the graph.
Function
Before introducing Function, let’s continue with the previous code example. During a single forward pass, we will create the following instances of Variable and Function:
# Variable instance
gemfield --> grad_fn_ (Function instance) = None
         --> grad_accumulator_ (Function instance) = AccumulateGrad instance 0x55ca7f304500
         --> output_nr_ = 0

# Function instance, 0x55ca7f872e90
AddBackward0 --> sequence_nr_ (uint64_t) = 0
             --> next_edges_ (edge_list) --> std::vector<Edge> = [(AccumulateGrad instance, 0), (0, 0)]
             --> input_metadata_ --> [(type, shape, device)...] = [(CPUFloatType, [2, 2], cpu)]
             --> alpha (Scalar) = 1
             --> apply() --> uses AddBackward0's apply

# Variable instance
syszux --> grad_fn_ (Function instance) = AddBackward0 instance 0x55ca7f872e90
       --> output_nr_ = 0

# Function instance, 0x55ca7ebba2a0
MulBackward0 --> sequence_nr_ (uint64_t) = 1
             --> next_edges_ (edge_list) = [(AddBackward0 instance 0x55ca7f872e90, 0), (AddBackward0 instance 0x55ca7f872e90, 0)]
             --> input_metadata_ --> [(type, shape, device)...] = [(CPUFloatType, [2, 2], cpu)]
             --> self_ (SavedVariable) = shallow copy of syszux
             --> other_ (SavedVariable) = shallow copy of syszux
             --> apply() --> uses MulBackward0's apply

# Variable instance, tmp = syszux * syszux
tmp --> grad_fn_ (Function instance) = MulBackward0 instance 0x55ca7ebba2a0
    --> output_nr_ = 0

# Function instance, 0x55ca7fada2f0
MulBackward0 --> sequence_nr_ (uint64_t) = 2 (incremented within each thread)
             --> next_edges_ (edge_list) = [(MulBackward0 instance 0x55ca7ebba2a0, 0), (0, 0)]
             --> input_metadata_ --> [(type, shape, device)...] = [(CPUFloatType, [2, 2], cpu)]
             --> self_ (SavedVariable) = shallow copy of tmp
             --> other_ (SavedVariable) = shallow copy of 3
             --> apply() --> uses MulBackward0's apply

# Variable instance
civilnet --> grad_fn_ (Function instance) = MulBackward0 instance 0x55ca7fada2f0

# Function instance, 0x55ca7eb358b0
MeanBackward0 --> sequence_nr_ (uint64_t) = 3 (incremented within each thread)
              --> next_edges_ (edge_list) = [(MulBackward0 instance 0x55ca7fada2f0, 0)]
              --> input_metadata_ --> [(type, shape, device)...] = [(CPUFloatType, [], cpu)]
              --> self_sizes (std::vector<int64_t>) = (2, 2)
              --> self_numel = 4
              --> apply() --> uses MeanBackward0's apply

# Variable instance
gemfieldout --> grad_fn_ (Function instance) = MeanBackward0 instance 0x55ca7eb358b0
            --> output_nr_ = 0

These Function instances used for backward computation are connected through next_edges_. Because the Functions actually execute during backpropagation, the direction of data flow is exactly the reverse of the forward pass. Summarized as a graph, it looks like this:
[Figure: the Function instances above, connected through next_edges_ from MeanBackward0 down to AccumulateGrad]
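The same chain can be inspected from Python. The sketch below assumes a recent PyTorch release, where the Python wrapper of a Function exposes its next_edges_ as the next_functions attribute; the helper first_chain is a hypothetical name introduced here, and the exact class names may differ between versions.

```python
import torch

gemfield = torch.ones(2, 2, requires_grad=True)
syszux = gemfield + 2
civilnet = syszux * syszux * 3
gemfieldout = civilnet.mean()

def first_chain(fn):
    """Follow the first entry of next_functions from fn down to the leaf."""
    names = []
    while fn is not None:
        names.append(type(fn).__name__)
        # Each entry is a (Function or None, input_nr) pair, i.e. one edge.
        fn = fn.next_functions[0][0] if fn.next_functions else None
    return names

chain = first_chain(gemfieldout.grad_fn)
print(chain)
```

The walk starts at the grad_fn of the forward output and ends at the AccumulateGrad node of the leaf gemfield, which is the order in which backpropagation visits them.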
The next topic is how the Function class is abstracted.
#Function Base Class Definition
The data members of Function are as follows:
using edge_list = std::vector<Edge>;
using variable_list = std::vector<Variable>;
struct TORCH_API Function {
...
  virtual variable_list apply(variable_list&& inputs) = 0;
...
  const uint64_t sequence_nr_;
  edge_list next_edges_;
  PyObject* pyobj_ = nullptr; // weak reference
  std::unique_ptr<AnomalyMetadata> anomaly_metadata_ = nullptr;
  std::vector<std::unique_ptr<FunctionPreHook>> pre_hooks_;
  std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_;
  at::SmallVector<InputMetadata, 2> input_metadata_;
};

#Function Call
The Function class is an abstract base class representing an operation (op). Each op accepts zero, one, or more Variable instances (wrapped in a std::vector) and outputs zero, one, or more Variable instances. All functions used for backpropagation in PyTorch inherit from Function and override its pure virtual function apply. The Function class also implements the call operator:
variable_list operator()(variable_list&& inputs) {
  return apply(std::move(inputs));
}

Thanks to C++ polymorphism, calling an op therefore dispatches to the apply of its own subclass. The call operator is the most important method of the Function class: it receives a vector of Variable instances and returns a vector of Variable instances. The length of the input vector is given by num_inputs(), and the length of the output vector by num_outputs().
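The dispatch from the call operator to apply can be mimicked in a few lines of Python. This is only a toy model: the Function class below is a stand-in written for illustration, and NegBackward is a hypothetical op, not a PyTorch class.

```python
from abc import ABC, abstractmethod

class Function(ABC):
    """Toy stand-in for the C++ Function base class (not PyTorch's API)."""

    def __call__(self, inputs):
        # Mirrors operator(): forward the call to the subclass's apply.
        return self.apply(inputs)

    @abstractmethod
    def apply(self, inputs):
        """Pure virtual in C++; every concrete op must override it."""

class NegBackward(Function):
    """Hypothetical backward op: d(-x)/dx = -1, so gradients are negated."""

    def apply(self, grads):
        return [-g for g in grads]

op = NegBackward()
result = op([1.0, 2.5])  # calling the op runs NegBackward.apply
print(result)
```

As with the pure virtual function in C++, the base class cannot be instantiated directly; only subclasses that provide apply can be called.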
#Function Inputs
The member input_metadata_ of Function represents the meta information of the input data, defining the input of a Function:
struct InputMetadata {
...
  const at::Type* type_ = nullptr;
  at::DimVector shape_;
  at::Device device_ = at::kCPU;
};

#Edges and Vertices of the Autograd Graph
If we view PyTorch’s autograd system as a graph, each Function instance is a node (vertex), and the Function instances are connected by edges. An Edge is a struct that represents one edge of the graph as a (Function, input_nr) pair:
struct Edge {
...
  std::shared_ptr<Function> function;
  uint32_t input_nr;
};

The next_edges_ member of Function is a list of such Edge instances. It records which Function (and which of its inputs) each return value of this Function instance is sent to, so next_edges_ is the link between one Function and the next.
Both the inputs and the outputs of a Function are Variable instances, so when a graph is executed, Variable instances flow along these edges. When two or more edges point to the same Function (the in-degree of that node is greater than 1), the values arriving over those edges are implicitly summed before being passed to the target Function.
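The implicit summation can be pictured with a small buffer that has one slot per input of the target Function. The class below is a hypothetical toy with plain floats as gradients; it only illustrates the rule and is not the real implementation.

```python
class InputBuffer:
    """Toy buffer with one slot per input of the target Function."""

    def __init__(self, num_inputs):
        self.slots = [None] * num_inputs

    def add(self, input_nr, grad):
        # The first arrival fills the slot; later arrivals are summed into it.
        if self.slots[input_nr] is None:
            self.slots[input_nr] = grad
        else:
            self.slots[input_nr] = self.slots[input_nr] + grad

# syszux * syszux has two edges, both pointing at input 0 of AddBackward0,
# so the two partial gradients end up added together.
buf = InputBuffer(num_inputs=1)
buf.add(0, 9.0)
buf.add(0, 9.0)
print(buf.slots)
```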
Functions are linked to each other through next_edges_: add_next_edge() adds an edge to a Function, next_edge(index) retrieves the edge at the given index, and next_edges() gives access to all edges for iteration. Each Function also has a sequence number, which increases monotonically as Function instances are constructed; sequence_nr() returns it.
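A rough Python model of this interface, under stated assumptions: the method names are borrowed from the C++ API described above, but the classes are toys, and a single global counter stands in for the per-thread sequence counter.

```python
import itertools
from dataclasses import dataclass, field
from typing import List, Optional

_sequence = itertools.count()  # global counter, stands in for the per-thread one

@dataclass
class Edge:
    function: Optional["Function"]  # None models an edge that leads nowhere
    input_nr: int

@dataclass
class Function:
    name: str
    next_edges: List[Edge] = field(default_factory=list)
    # Assigned at construction time, so later instances get larger numbers.
    sequence_nr: int = field(default_factory=lambda: next(_sequence))

    def add_next_edge(self, edge):
        self.next_edges.append(edge)

    def next_edge(self, index):
        return self.next_edges[index]

add = Function("AddBackward0")
mul = Function("MulBackward0")
mul.add_next_edge(Edge(add, 0))
mul.add_next_edge(Edge(add, 0))
print(add.sequence_nr, mul.sequence_nr, len(mul.next_edges))
```

Because the forward pass creates AddBackward0 before MulBackward0, the former gets the smaller sequence number, matching the instance dump earlier in the article.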
Function Inheritance System
TraceableFunction and the following classes derive directly from the base class Function:
CopySlices : public Function 
DelayedError : public Function 
Error : public Function 
Gather : public Function 
GraphRoot : public Function 
Scatter : public Function 
AccumulateGrad : public Function 
AliasBackward : public Function 
AsStridedBackward : public Function 
CopyBackwards : public Function 
DiagonalBackward : public Function 
ExpandBackward : public Function 
IndicesBackward0 : public Function 
IndicesBackward1 : public Function 
PermuteBackward : public Function 
SelectBackward : public Function 
SliceBackward : public Function 
SqueezeBackward0 : public Function 
SqueezeBackward1 : public Function 
TBackward : public Function 
TransposeBackward0 : public Function 
UnbindBackward : public Function 
UnfoldBackward : public Function 
UnsqueezeBackward0 : public Function 
ValuesBackward0 : public Function 
ValuesBackward1 : public Function 
ViewBackward : public Function

Among these, the derived classes AccumulateGrad, TraceableFunction, and GraphRoot are particularly critical.
#Derived Class AccumulateGrad
First, let’s talk about AccumulateGrad, which is the type of the grad_accumulator_ member of Variable:
struct AccumulateGrad : public Function {
  explicit AccumulateGrad(Variable variable_);
  variable_list apply(variable_list&& grads) override;
  Variable variable;
};

As can be seen, an AccumulateGrad instance must be constructed from a Variable, and its apply receives a list of Variable instances (the incoming gradients). Both tie it to the Variable whose grad_accumulator_ it is.
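What such an accumulator does can be sketched as follows. Both classes are simplified stand-ins with float gradients, written for illustration only; the real apply handles details that are ignored here.

```python
class Variable:
    """Toy leaf variable holding a value and an accumulated gradient."""

    def __init__(self, data):
        self.data = data
        self.grad = None

class AccumulateGrad:
    """Toy version: bound to one Variable, adds incoming gradients to its grad."""

    def __init__(self, variable):
        self.variable = variable

    def apply(self, grads):
        (grad,) = grads  # exactly one incoming gradient is expected
        if self.variable.grad is None:
            self.variable.grad = grad
        else:
            self.variable.grad = self.variable.grad + grad
        return []  # a leaf: nothing flows further

gemfield = Variable(1.0)
acc = AccumulateGrad(gemfield)
acc.apply([4.5])
acc.apply([4.5])  # a second backward pass adds to the stored gradient
print(gemfield.grad)
```

This accumulation is why gradients of leaf tensors keep growing across backward calls until they are reset.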
#Derived Class GraphRoot
GraphRoot encapsulates the final output of the forward pass, which acts as the initial input of the backward pass:
struct GraphRoot : public Function {
  GraphRoot(edge_list functions, variable_list inputs)
      : Function(std::move(functions)),
        outputs(std::move(inputs)) {}
  variable_list apply(variable_list&& inputs) override {
    return outputs;
  }
  variable_list outputs;
};

The soul of a Function lies in its apply, yet GraphRoot’s apply merely returns the inputs it was constructed with!
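A direct Python transliteration of the struct above makes this behavior easy to see. It is a toy: edges are written as plain (name, input_nr) tuples and the seed gradient as a float.

```python
class GraphRoot:
    """Toy transliteration of the C++ GraphRoot shown above."""

    def __init__(self, functions, inputs):
        self.next_edges = list(functions)  # edges to the first backward Functions
        self.outputs = list(inputs)        # the seed gradients

    def apply(self, inputs):
        # Whatever is passed in is ignored; the stored seed is returned.
        return self.outputs

root = GraphRoot(functions=[("MeanBackward0", 0)], inputs=[1.0])
seed = root.apply([])
ignored = root.apply([123.0])
print(seed, ignored)
```

For a scalar output such as gemfieldout, the seed is the gradient of the output with respect to itself, which is 1.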
#Derived Class TraceableFunction
Now let’s discuss TraceableFunction:
struct TraceableFunction : public Function {
  using Function::Function;
  bool is_traceable() final {
    return true;
  }
};

TraceableFunction in turn has 372 subclasses (as of April 2019), all of which share the word Backward in their names. What does this indicate? That these functions are used only in backpropagation:
AbsBackward : public TraceableFunction 
AcosBackward : public TraceableFunction 
AdaptiveAvgPool2DBackward : public TraceableFunction 
AdaptiveAvgPool2DBackward : public TraceableFunction 
AdaptiveAvgPool3DBackward : public TraceableFunction 
AdaptiveAvgPool3DBackward : public TraceableFunction 
AdaptiveMaxPool2DBackward : public TraceableFunction 
AdaptiveMaxPool2DBackward : public TraceableFunction 
AdaptiveMaxPool3DBackward : public TraceableFunction 
AdaptiveMaxPool3DBackward : public TraceableFunction 
AddBackward0 : public TraceableFunction 
AddBackward1 : public TraceableFunction 
AddbmmBackward : public TraceableFunction 
AddcdivBackward : public TraceableFunction 
AddcmulBackward : public TraceableFunction 
AddmmBackward : public TraceableFunction 
AddmvBackward : public TraceableFunction 
AddrBackward : public TraceableFunction 
......

These 300+ Backward functions all override apply to implement their own backpropagation logic. For example, AddBackward0, the backward function for addition:
struct AddBackward0 : public TraceableFunction {
  using TraceableFunction::TraceableFunction;
  variable_list apply(variable_list&& grads) override;
  Scalar alpha;
};

These apply functions are the soul of Function and the core execution logic during backpropagation calculations.
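As an illustration of what such an apply computes, here is a toy version of the addition case. It assumes the forward op is out = self + alpha * other and uses plain floats instead of Variables; it shows the math only, not the generated C++ code.

```python
class AddBackward0:
    """Toy backward for out = self + alpha * other (floats instead of Variables)."""

    def __init__(self, alpha=1.0):
        self.alpha = alpha

    def apply(self, grads):
        (grad,) = grads
        # d(out)/d(self) = 1 and d(out)/d(other) = alpha
        return [grad, grad * self.alpha]

plain = AddBackward0(alpha=1.0).apply([18.0])
scaled = AddBackward0(alpha=2.0).apply([18.0])
print(plain, scaled)
```

One output is produced per input of the forward op, which is what lets each output travel along its own entry in next_edges_.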
Engine
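The Engine class implements backpropagation from the output variables (and their gradients) to the leaf variables, that is, the variables the user creates with requires_grad=True. The PyTorch source comments call these "root" variables; they are not to be confused with the graph root (gemfieldout) introduced earlier.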
gemfield = torch.ones(2, 2, requires_grad=True)
syszux = gemfield + 2
civilnet = syszux * syszux * 3
gemfieldout = civilnet.mean()
gemfieldout.backward()

Continuing with the code snippet above, the Engine implements backpropagation from gemfieldout to gemfield. This involves three questions:
1. How to construct GraphRoot based on gemfieldout;
2. How to build the graph based on these Function instances and their metadata;
3. How to implement the Queue to perform the backward computation work in multiple threads.
#Engine Class Definition
The definition of the Engine class is as follows:
struct Engine {
  using ready_queue_type = std::deque<std::pair<std::shared_ptr<Function>, InputBuffer>>;
  using dependencies_type = std::unordered_map<Function*, int>;
  virtual variable_list execute(const edge_list&& roots, const variable_list&& inputs, ... const edge_list&& outputs = {});
  void queue_callback(std::function<void()> callback);
protected:
  void compute_dependencies(Function* root, GraphTask& task);
  void evaluate_function(FunctionTask& task);
  void start_threads();
  virtual void thread_init(int device);
  virtual void thread_main(GraphTask *graph_task);
  std::vector<std::shared_ptr<ReadyQueue>> ready_queues;
};

The core is the execute function. It receives a set of edges, i.e. (Function, input number) pairs, as its roots, then repeatedly follows next_edges_ to find the next Functions to run, until the computation of the entire graph is complete.
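The traversal can be sketched in Python as a single-threaded toy. Several assumptions apply: Node, compute_dependencies and execute are hypothetical names that only loosely mirror the members listed in the class definition, gradients are plain floats, the forward pass is a scalar version of the running example (gemfield = 1), and the real Engine's threads and ready queues are replaced by one FIFO queue.

```python
from collections import deque

class Node:
    """Toy backward node: fn maps input grads to one output grad per next edge."""
    def __init__(self, name, fn, next_edges=(), num_inputs=1):
        self.name = name
        self.fn = fn
        self.next_edges = list(next_edges)  # (Node or None, input_nr) pairs
        self.num_inputs = num_inputs

def compute_dependencies(root):
    """Count, for every reachable node, how many edges point at it."""
    deps, seen, stack = {}, {root}, [root]
    while stack:
        node = stack.pop()
        for nxt, _ in node.next_edges:
            if nxt is None:
                continue
            deps[nxt] = deps.get(nxt, 0) + 1
            if nxt not in seen:
                seen.add(nxt)
                stack.append(nxt)
    return deps

def execute(root):
    """Run nodes once all their incoming gradients have arrived."""
    deps = compute_dependencies(root)
    buffers = {}
    ready = deque([(root, [])])
    order = []
    while ready:
        node, grads = ready.popleft()
        order.append(node.name)
        outputs = node.fn(grads)
        for (nxt, input_nr), out in zip(node.next_edges, outputs):
            if nxt is None:
                continue
            buf = buffers.setdefault(nxt, [None] * nxt.num_inputs)
            # Gradients arriving at the same input are summed.
            buf[input_nr] = out if buf[input_nr] is None else buf[input_nr] + out
            deps[nxt] -= 1
            if deps[nxt] == 0:
                ready.append((nxt, buf))
    return order

y = 3.0      # syszux = gemfield + 2 with gemfield = 1
tmp = y * y  # syszux * syszux
leaf_grads = {}

def accumulate(grads):
    leaf_grads["gemfield"] = grads[0]
    return []

acc = Node("AccumulateGrad", accumulate)
add = Node("AddBackward0", lambda g: [g[0], g[0]], [(acc, 0), (None, 0)])
mul1 = Node("MulBackward0", lambda g: [g[0] * y, g[0] * y], [(add, 0), (add, 0)])
mul2 = Node("MulBackward0", lambda g: [g[0] * 3.0, g[0] * tmp], [(mul1, 0), (None, 0)])
mean = Node("MeanBackward0", lambda g: [g[0] / 1], [(mul2, 0)])  # mean of one element
root = Node("GraphRoot", lambda g: [1.0], [(mean, 0)])

order = execute(root)
print(order, leaf_grads)
```

The dependency count is what makes AddBackward0 wait until both edges from the first MulBackward0 have delivered their gradients; the result, 18, equals the derivative of 3 * (x + 2)^2 at x = 1.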
#Derived Class PythonEngine
However, what we actually use is the derived class of Engine: PythonEngine. The PythonEngine subclass overrides the parent class’s execute function, but merely provides the functionality of translating C++ exceptions into Python exceptions; the core work is still performed by the base Engine class:
struct PythonEngine : public Engine

Throughout the entire PyTorch program, only one instance of Engine is maintained, which is the instance of PythonEngine.
Backward Propagation Call Stack
Since the Engine is used to compute the backward propagation of the network, let’s take a look at how this call stack reaches the Engine class. If we perform backward computation on gemfieldout, the call stack is as follows:
#torch/tensor.py, self is gemfieldout
def backward(self, gradient=None, retain_graph=None, create_graph=False)
|
V
#torch.autograd.backward(self, gradient, retain_graph, create_graph)
#torch/autograd/__init__.py
def backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False, grad_variables=None)
|
V
Variable._execution_engine.run_backward(tensors, grad_tensors, retain_graph, create_graph, allow_unreachable=True)
#transform to Variable._execution_engine.run_backward((gemfieldout,), (tensor(1.),), False, False, True)
|
V
#torch/csrc/autograd/python_engine.cpp
PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
|
V
#torch/csrc/autograd/python_engine.cpp
variable_list PythonEngine::execute(const edge_list&& roots, const variable_list&& inputs, bool keep_graph, bool create_graph, const edge_list&& outputs)
|
V
#torch/csrc/autograd/engine.cpp
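The first two levels of this call stack can be checked from Python with the public API alone. The sketch below assumes a recent PyTorch release; forward is a hypothetical helper that rebuilds the running example so that each call gets its own graph.

```python
import torch

def forward():
    gemfield = torch.ones(2, 2, requires_grad=True)
    syszux = gemfield + 2
    civilnet = syszux * syszux * 3
    return gemfield, civilnet.mean()

# Tensor.backward() with default arguments ...
a, out_a = forward()
out_a.backward()

# ... and the explicit call one level down, seeded with d(out)/d(out) = 1.
b, out_b = forward()
torch.autograd.backward((out_b,), (torch.tensor(1.0),))

same = torch.equal(a.grad, b.grad)
print(same, a.grad)
```

Each gradient entry is 3 * 2 * (1 + 2) / 4 = 4.5, since the mean spreads the seed gradient over four elements.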

Conclusion

In the next article, Gemfield will mainly introduce how the Engine class runs the PyTorch dynamic graph in gemfieldout.backward().