
Why Do We Need to Zero Out Gradients in PyTorch?

2025-01-21 | Research | #Words: 319 | Chinese original

While reading a book today, I noticed that gradients are zeroed out at the end of the optimization algorithm:

[Figure: code snippet from the book showing the gradient-zeroing operation]
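The book's snippet survives only as an image. Assuming it is the common kind of hand-written minibatch SGD update, a minimal sketch looks like the following. The function name sgd and its arguments are illustrative, not the book's exact code: each parameter is updated in place, and its gradient is then cleared.

```python
import torch

def sgd(params, lr, batch_size):
    """Hypothetical minibatch SGD step: update in place, then zero the gradients."""
    with torch.no_grad():  # the update itself must not be recorded by autograd
        for param in params:
            param -= lr * param.grad / batch_size
            param.grad.zero_()  # clear .grad so the next backward() starts fresh

# Tiny usage example: one parameter, loss = sum(w * x) over a batch of 2 samples
w = torch.tensor([1.0], requires_grad=True)
x = torch.tensor([3.0, 5.0])
loss = (w * x).sum()
loss.backward()                 # w.grad = 3 + 5 = 8
sgd([w], lr=0.1, batch_size=2)  # w = 1 - 0.1 * 8 / 2, approximately 0.6
print(w, w.grad)                # w.grad is back to zero after the step
```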

But the book didn't explain why. After looking around, I found that many optimizer implementations include this step as well. Why is that?

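First, optimizer.zero_grad() resets the gradients of all parameters the optimizer manages. (Since PyTorch 2.0 its default, set_to_none=True, sets each .grad to None instead of filling it with zeros; either way the next backward pass starts from a clean slate.) The question then becomes: why do we need to zero out gradients at all?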
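With torch.optim the reset is a single call. The sketch below assumes a toy nn.Linear model and a random input, both purely illustrative. It passes set_to_none=False explicitly so that the gradients are kept as tensors and filled with zeros, which makes the effect easy to inspect.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(3, 1)                       # toy model: 3 weights + 1 bias
opt = torch.optim.SGD(model.parameters(), lr=0.1)

loss = model(torch.randn(4, 3)).pow(2).mean()
loss.backward()                               # populates .grad on weight and bias
print([p.grad.abs().sum().item() for p in model.parameters()])

# Reset the gradients of every parameter the optimizer was constructed with.
opt.zero_grad(set_to_none=False)              # keep the tensors, fill them with zeros
print([p.grad.abs().sum().item() for p in model.parameters()])  # all 0.0
```

With the default zero_grad() on PyTorch 2.0 or later, the second print would fail because each .grad would be None rather than a zero tensor.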

1. Prevent Unnecessary Gradient Accumulation

When you create a tensor and set its .requires_grad attribute to True, autograd records the operations performed on it. During the backward pass, the gradients for such (leaf) tensors are accumulated, that is, added, into their .grad attribute rather than overwriting it. If you don't zero out the gradients, each backward pass adds the new gradients on top of those left over from earlier iterations, so the optimizer updates the parameters with a stale running sum instead of the gradient of the current mini-batch.
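A minimal sketch of this accumulation on a single scalar, using the function x² as an arbitrary example: calling backward() twice without zeroing doubles the stored gradient, and zeroing restores the true value.

```python
import torch

x = torch.tensor(2.0, requires_grad=True)

# First backward pass: d(x^2)/dx = 2x = 4
(x ** 2).backward()
first = x.grad.item()          # 4.0

# Second backward pass WITHOUT zeroing: the new 4 is added to the old 4
(x ** 2).backward()
accumulated = x.grad.item()    # 8.0

# Zero the gradient, then run backward again: back to the true value
x.grad.zero_()
(x ** 2).backward()
after_reset = x.grad.item()    # 4.0

print(first, accumulated, after_reset)
```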

2. Avoid Ambiguity in Current Gradient Status

Without zeroing, gradients keep accumulating over iterations, so .grad holds the sum of every past gradient rather than the gradient of the current step. You can't recover the current gradient from that sum, and the old contributions are usually irrelevant anyway: with a classic loss function, each update should use only the gradient of the loss on the current mini-batch.
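In a standard training loop this is why zero_grad() comes first in every iteration. The following is a minimal sketch, assuming a toy linear-regression problem (targets from y = 2x + 1) chosen only for illustration. Because .grad is cleared before each backward(), every step uses exactly the gradient of the current loss.

```python
import torch
from torch import nn

torch.manual_seed(0)
X = torch.linspace(-1, 1, 20).unsqueeze(1)   # 20 samples, 1 feature
y = 2 * X + 1                                 # targets from y = 2x + 1

model = nn.Linear(1, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

for step in range(200):
    opt.zero_grad()              # drop the previous step's gradients
    loss = loss_fn(model(X), y)  # loss on the current batch only
    loss.backward()              # .grad now holds exactly d(loss)/d(params)
    opt.step()                   # update with the current gradient

print(model.weight.item(), model.bias.item())  # close to 2 and 1
```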

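However, there are cases where gradient accumulation is exactly what you want, for example when summing the loss gradients across several mini-batches before a single update, as in some RNN training setups, or to emulate a larger batch than fits in memory. Whether and when to zero out gradients depends on your specific use case.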
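A minimal sketch of deliberate accumulation, assuming the goal is to emulate one batch of 8 samples with 4 micro-batches of 2 (sizes and random data are illustrative). Gradients are zeroed once, backward() is called on each micro-batch with its loss scaled by 1/k so the sum matches the full-batch mean, and only then would the optimizer step. The check compares the accumulated gradient with the full-batch gradient.

```python
import torch
from torch import nn

torch.manual_seed(0)
X, y = torch.randn(8, 3), torch.randn(8, 1)
model = nn.Linear(3, 1)
loss_fn = nn.MSELoss()

# Reference: gradient of the mean loss over the full batch of 8
model.zero_grad()
loss_fn(model(X), y).backward()
full_grad = model.weight.grad.clone()

# Accumulate over k = 4 equal micro-batches of 2 samples each
k = 4
model.zero_grad()                            # zero ONCE, before the micro-batches
for xb, yb in zip(X.chunk(k), y.chunk(k)):
    (loss_fn(model(xb), yb) / k).backward()  # each call adds into .grad
accum_grad = model.weight.grad.clone()

print(torch.allclose(full_grad, accum_grad, atol=1e-6))  # True
# An optimizer.step() would go here, followed by zero_grad() before the next group.
```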

I hope this helps someone in need~

References

There are various posts and resources explaining this topic, but they’re scattered and some are outdated. Here are two high-quality references if you want to learn more: