Multi-Task Loss in Pytorch: Add or Backward Separately?


This article is reproduced from | Jishi Platform

Author | Yai Gang Xiaozhang@Zhihu

Source | https://zhuanlan.zhihu.com/p/451441329

1. 『The Original Intention of Writing This Article』

Recently, while reproducing the training code of a paper, I found that the total loss in the paper is the sum of several individual losses. With a single loss you can simply call loss.backward(), but here there is more than one, and at first I did not know where to put the backward() call.

loss = 0
for j in range(len(output)):
    loss += criterion(output[j], target_var)

We know that the traditional gradient backpropagation steps are as follows:

outputs = model(images)              # forward pass
loss = criterion(outputs, target)    # compute the loss

optimizer.zero_grad()                # clear the previous gradients
loss.backward()                      # backpropagate to compute gradients
optimizer.step()                     # update the parameters
  • First, the model computes the loss from the input images and their labels;
  • Then the previous gradients are cleared with optimizer.zero_grad();
  • Next, loss.backward() performs backpropagation and computes the current gradients;
  • Finally, optimizer.step() updates the network parameters from those gradients. In general, one gradient is computed per incoming batch of data, and then the network is updated.

Now I need to compute the loss in a for loop, so I wondered: do I need to call backward() inside the loop?

loss = 0
for j in range(len(output)):
    loss += criterion(output[j], target_var)
    loss.backward()

However, when I called backward() after each loss this way, I got an error: RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

The reason is that in PyTorch, a computation graph allows only one backward pass by default: after each backward, the graph's intermediate buffers are freed. If you try to backpropagate through the same graph again within the same batch, the intermediate graph is already gone, and naturally the gradient can no longer be computed.
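
This is easy to reproduce outside any training loop. The following minimal sketch (my own construction, not from the original code) triggers the same error:

import torch

x = torch.tensor([1.0], requires_grad=True)
y = x * 2        # builds a tiny computation graph

y.backward()     # first backward succeeds, then the graph's buffers are freed
y.backward()     # RuntimeError: Trying to backward through the graph a second time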

One solution I saw online is to pass retain_graph=True to backward, that is, to call backward(retain_graph=True).

This tells PyTorch not to release the computation graph yet, so the graph survives the subsequent backward calls and keeps occupying memory; as training progresses, this can lead to an OOM. Therefore, the last backward call must drop retain_graph=True and use a plain backward(), so that only the final loss releases the graph's resources, while the earlier losses skip that release step.
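
In code, that pattern looks roughly like the following sketch (reusing output, target_var, and criterion from the snippets above):

losses = [criterion(output[j], target_var) for j in range(len(output))]

for partial_loss in losses[:-1]:
    partial_loss.backward(retain_graph=True)   # keep the shared graph alive
losses[-1].backward()                          # last backward frees the graph

Because .grad buffers accumulate across backward calls, the parameter gradients after this loop match those from backpropagating the summed loss once.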

loss = 0
for j in range(len(output)):
    loss += criterion(output[j], target_var)
loss.backward()

Some may ask: why not just write it this way? I used to do exactly that, but I found that the loss did not decrease. I went looking for the cause in the loss itself, yet I still do not understand why it failed to decrease; I hope readers who do understand will get in touch~

In fact, in situations like this the cleanest approach is to write the losses out separately and then sum them into a total loss for a single backward call. For example:

loss1 = Loss(output[0], target)
loss2 = Loss(output[1], target)
loss3 = Loss(output[2], target)
loss4 = Loss(output[3], target)

loss = loss1 + loss2 + loss3 + loss4
loss.backward()

When I wrote it this way, the loss decreased normally, which was a relief.

2. 『Other Possible Reasons for Errors』

While looking things up, I found that even computing a single loss can trigger errors like this.

  • One tensor may be computed on the cpu and another on the gpu; simply place them on the same device.
  • In multi-iteration loops, some inputs may not need gradient computation; in that case, set the input's requires_grad to False.
  • About the requires_grad attribute of tensors: if a tensor has requires_grad=True, calling backward will compute its gradient during backpropagation. Note, however, that the computed gradient is not necessarily kept: only for leaf nodes with requires_grad=True is the gradient preserved in the grad attribute. For non-leaf nodes, i.e. intermediate nodes, the gradient is generally released once it has been computed and used, for more efficient memory utilization (the sketch after this list illustrates the distinction).
  • When using LSTM or GRU networks, I think the issue is that gradients flow not only backward through the layers but also back through time steps, so the propagation paths may overlap; you may therefore need detach to truncate them. In the source code, the docstring of detach reads: "Returns a new Variable, detached from the current graph." It turns a node into a variable that does not require gradients, detaching it from the current computation graph, so during backpropagation gradients stop at that node and do not propagate further back.
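
As a small sketch of that leaf-versus-intermediate distinction (the values are my own, illustrative choices):

import torch

x = torch.tensor([2.0], requires_grad=True)   # leaf node
m = x * 3                                     # intermediate (non-leaf) node
y = (m ** 2).sum()

m.retain_grad()    # explicitly ask to keep the gradient of a non-leaf node
y.backward()

print(x.grad)      # tensor([36.]) -- kept automatically: x is a leaf
print(m.grad)      # tensor([12.]) -- kept only because of retain_grad()
print(m.is_leaf)   # False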

detach() and detach_()

PyTorch has two functions, detach() and detach_(), with similar names and purposes: both are used to cut off the backpropagation of gradients. When are they needed?

When training a network, you may want to keep some parameters unchanged and adjust only part of them, or train only a branch of the network without letting its gradients affect the main network. In such cases, these two functions can truncate the backpropagation of gradients.

The difference between the two is that detach_() modifies the tensor in place, while detach() returns a new tensor.

Calling detach() returns a new tensor. Although it is detached from the current computation graph, it still points to the original tensor's storage, sharing the same memory. After detach, its requires_grad attribute is False, meaning it no longer requires gradient computation; even if you later set requires_grad back to True, it will still receive no grad from the original graph. We then continue the computation with this new tensor, and during backpropagation, gradients are computed only up to the node where detach() was called, where propagation stops.

However, because the returned tensor and the original node share the same memory, modifying it in place after detach may cause an error when backward() is later called.
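
The following sketch (my own construction) hits exactly this problem, because sigmoid needs its saved output to compute its gradient:

import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
m = x.sigmoid()     # sigmoid saves its output for the backward pass
d = m.detach()      # new tensor, but it shares m's underlying storage

d.zero_()           # in-place edit silently corrupts m's saved values

m.sum().backward()  # RuntimeError: one of the variables needed for gradient
                    # computation has been modified by an inplace operation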

Calling tensor.detach_() detaches the tensor from the graph that created it and turns it into a leaf node. For example:

Suppose the initial variable relationship is x -> m -> y; the leaf node here is x. Calling m.detach_() first removes the association between m and its predecessor x, and m's grad_fn becomes None. The relationship then becomes x, m -> y, and m is now a leaf node. Since m's requires_grad attribute is now False, calling backward() on y will not compute a gradient for m.
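
Here is a sketch of that x -> m -> y example with concrete numbers (my own); I re-enable requires_grad on m afterwards just to show where gradients stop (if it stays False, y.backward() has nothing to differentiate at all):

import torch

x = torch.tensor([1.0], requires_grad=True)
m = x * 2                # x -> m: m is a non-leaf node with a grad_fn

m.detach_()              # cut the link between m and x, in place
print(m.grad_fn)         # None
print(m.is_leaf)         # True -- m is now a leaf node

m.requires_grad_(True)   # re-enable gradients on m itself
y = m * 3                # m -> y, built after the detach
y.backward()

print(m.grad)            # tensor([3.]) -- the gradient stops here
print(x.grad)            # None -- nothing flows back to x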

3. 『How to Write a More Memory-Efficient Backward』

Speaking of gradient backpropagation, I have also seen a memory-saving pattern that people use online:

for i, (images, target) in enumerate(train_loader):
    images = images.cuda(non_blocking=True)
    target = torch.from_numpy(np.array(target)).float().cuda(non_blocking=True)

    outputs = model(images)                  # forward pass
    loss = criterion(outputs, target)
    loss = loss / accumulation_steps         # normalize for the accumulation

    loss.backward()                          # gradients accumulate in .grad

    if (i + 1) % accumulation_steps == 0:    # update once every accumulation_steps batches
        optimizer.step()
        optimizer.zero_grad()
  • First, forward propagation: feed the data into the network to obtain predictions
  • Feed the predictions and the label into the loss function to compute the loss
  • Backpropagate to compute gradients
  • Repeat the steps above without clearing the gradients; instead, accumulate them, and once they have accumulated for a fixed number of steps, update the network parameters and then zero the gradients

Gradient accumulation means taking one batch at a time and computing its gradients once, but not clearing them; the gradients keep accumulating, and after a certain number of steps, the network parameters are updated and the gradients are zeroed for the next loop.

Through this delayed parameter update, it achieves an effect similar to using a large batch size. In my own experiments I generally use gradient accumulation, and in most cases models trained with it perform noticeably better than those trained with the small batch size alone.

Up to a point, the larger the batch size, the better the training results, and gradient accumulation effectively enlarges the batch size: if accumulation_steps is 8, the effective batch size is 8 times larger. When using it, note that the learning rate should also be raised appropriately, because the gradients computed from more samples are more stable.
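
As a rough sketch of that advice (the linear scaling used here is a common heuristic, not something this article prescribes; model is the network from the snippet above):

accumulation_steps = 8
base_lr = 1e-3   # learning rate tuned for the original small batch

# scale the learning rate with the effective batch size
optimizer = torch.optim.SGD(model.parameters(), lr=base_lr * accumulation_steps)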

Some may ask: why not just sum the losses of multiple batches first, average them, and then do a single backpropagation and update?

As I understand it, the answer is memory. If we sum the losses over multiple batches and average before backpropagating, we perform accumulation_steps forward passes, and each forward pass builds its own computation graph; in other words, accumulation_steps computation graphs exist simultaneously before the backward call.

In contrast, with the code above, backward runs right after each batch's forward pass, and the graph is released as soon as it finishes. Since the gradient computations accumulate identically in both cases, the results are the same, but this way at most one computation graph is alive at any moment, which reduces memory consumption during training.
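
For contrast, the memory-hungry variant described in the question would look roughly like this sketch (same model, criterion, optimizer, and accumulation_steps as above; device transfers omitted):

total_loss = 0.0
for i, (images, target) in enumerate(train_loader):
    outputs = model(images)
    total_loss = total_loss + criterion(outputs, target)   # graphs pile up here

    if (i + 1) % accumulation_steps == 0:
        (total_loss / accumulation_steps).backward()   # all graphs freed at once
        optimizer.step()
        optimizer.zero_grad()
        total_loss = 0.0   # start a fresh sum for the next group of batches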

4. 『Conclusion』

In truth, after working through all of this I can only say my understanding is slightly deeper; I still have not fully grasped the underlying principles. For example, how autograd tracks operations, the behavior of in-place operations, when requires_grad is True and when it is False, and when gradients get overwritten still confuse me. Especially for the writing style above, I do not understand why the loss suddenly started to decrease, so I need more study and practice to remember and understand these things deeply.

I have also seen many articles mention that most common writing styles are already efficient enough, so unless memory pressure is severe, one generally does not need such complicated tricks.

My knowledge is limited, so if there are any errors in the above, please point them out; I look forward to discussing with everyone. Thank you!
