Settings for Reproducible Experiments in PyTorch

Author: Alxander@Zhihu (authorized)
Source: https://zhuanlan.zhihu.com/p/448284000
Editor: Jishi Platform

Jishi Guide

When training deep learning models, random initialization and random sample ordering cause repeated runs of the same experiment to give different results, sometimes with large variation. To keep the conclusions in papers rigorous, a fixed random seed is typically used to make the results deterministic. This article summarizes methods for achieving deterministic settings, with code.

Deterministic Settings

1 Random Seed Settings

Random number generation is the largest source of nondeterminism: it drives both the random initialization of model parameters and the shuffling of samples. Three generators need to be seeded:

  • PyTorch random seed
  • Python random seed
  • NumPy random seed

# PyTorch (seeds the generators for both CPU and CUDA devices)
import torch
torch.manual_seed(0)

# Python
import random
random.seed(0)

# NumPy
import numpy as np
np.random.seed(0)

On the CPU, setting the random seeds above is generally enough to reproduce an experiment.
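As a minimal sketch of this point, the three seeds can be wrapped in one helper and the effect checked by drawing from each generator twice. The names seed_everything and draw are hypothetical and chosen only for illustration; they are not part of PyTorch.

```python
import random

import numpy as np
import torch


def seed_everything(seed: int) -> None:
    """Seed the Python, NumPy, and PyTorch global random number generators."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def draw():
    """Draw one sample from each of the three generators."""
    return random.random(), float(np.random.rand()), torch.rand(3).tolist()


seed_everything(0)
first = draw()
seed_everything(0)
second = draw()

# Re-seeding with the same value replays the same random sequence.
print(first == second)  # True
```

Calling such a helper once at the start of a script, before the model and data loaders are created, keeps all later random draws tied to the same seed.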

On the GPU, many operations are implemented with nondeterministic algorithms, which are efficient but may return slightly different values on each run. The cause is floating-point rounding: floating-point addition is not associative, so when parallel threads accumulate the same numbers in a different order, the rounded result can differ in its last few digits.
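The effect of summation order can be seen without a GPU. The following plain-Python sketch adds the same three numbers in two different orders; it only illustrates the rounding behaviour and does not model how any particular CUDA kernel schedules its additions.

```python
# The same three numbers, grouped in two different orders.
left_to_right = (0.1 + 0.2) + 0.3
right_to_left = 0.1 + (0.2 + 0.3)

# Floating-point addition is not associative, so the two sums differ.
print(left_to_right == right_to_left)  # False

# The difference is tiny: it sits in the last digits of the result.
print(abs(left_to_right - right_to_left))
```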

2 Deterministic Implementation of GPU Algorithms

GPU algorithms have two sources of nondeterminism:

  • CUDA convolution benchmarking
  • Non-deterministic algorithms

CUDA convolution benchmarking aims to improve runtime efficiency: cuDNN tries several convolution algorithms on the actual input sizes and selects the fastest one. Because the timing measurements are noisy and depend on the hardware, different runs may select different algorithms, which makes the results nondeterministic.

Non-deterministic algorithms: The main strength of GPUs is parallel computation. When the order of operations does not need to be preserved, synchronization can be avoided, which greatly improves efficiency, so many operations have non-deterministic implementations. Calling torch.use_deterministic_algorithms(True) tells PyTorch to use deterministic algorithms where they are available.

# Disable benchmarking
torch.backends.cudnn.benchmark = False

# Choose deterministic algorithms
torch.use_deterministic_algorithms(True)
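The two settings can be collected into one function and verified afterwards. This is a sketch under stated assumptions: the name enable_determinism is hypothetical, and the CUBLAS_WORKSPACE_CONFIG value is the one suggested by the PyTorch reproducibility notes for CUDA 10.2 and later, which is not something the original article discusses.

```python
import os

import torch


def enable_determinism() -> None:
    """Turn off cuDNN benchmarking and request deterministic algorithms."""
    # Assumption: needed by some cuBLAS operations on CUDA >= 10.2; it must be
    # set before the first CUDA call, and it has no effect on CPU-only runs.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

    # Do not let cuDNN time several convolution algorithms and pick one.
    torch.backends.cudnn.benchmark = False

    # Raise an error whenever an operation has no deterministic implementation.
    torch.use_deterministic_algorithms(True)


enable_determinism()
print(torch.are_deterministic_algorithms_enabled())
print(torch.backends.cudnn.benchmark)
```

Deterministic algorithms are often slower, so it may be reasonable to enable them only for the runs whose numbers are reported.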

RUNTIME ERROR

If an operation has only a non-deterministic implementation and torch.use_deterministic_algorithms(True) is set, calling it raises a RuntimeError. For example:

>>> import torch
>>> torch.use_deterministic_algorithms(True)
>>> torch.randn(2, 2).cuda().index_add_(0, torch.tensor([0, 1]), torch.randn(2, 2))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: index_add_cuda_ does not have a deterministic implementation, but you set
'torch.use_deterministic_algorithms(True)'. ...

Reason for the error:

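In the PyTorch version covered by the references (1.10.1), the CUDA implementation of index_add_ has no deterministic variant. The error usually appears when tensor.index_add_ is called directly, or during the backward pass of torch.index_select, whose gradient is accumulated with index_add_.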

Solution:

Replace the offending call with an implementation built from deterministic operations. For torch.index_select, the following function can be used instead.

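def deterministic_index_select(input_tensor, dim, indices):
    """
    input_tensor: Tensor
    dim: dimension along which to select
    indices: 1D tensor of indices
    """
    # Move the target dimension to the front, index it, then move it back.
    tensor_transpose = torch.transpose(input_tensor, 0, dim)
    return tensor_transpose[indices].transpose(dim, 0)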
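A quick way to gain confidence in such a replacement is to compare it with torch.index_select on a small tensor. In the sketch below the helper is restated so that the snippet runs on its own; the test tensor and indices are arbitrary values chosen for illustration, and the check runs on the CPU.

```python
import torch


def deterministic_index_select(input_tensor, dim, indices):
    """Select entries along `dim` using a transpose and plain indexing."""
    tensor_transpose = torch.transpose(input_tensor, 0, dim)
    return tensor_transpose[indices].transpose(dim, 0)


x = torch.arange(24.0).reshape(2, 3, 4)
indices = torch.tensor([1, 0, 1])  # valid for every dimension of x

# Compare against the built-in operator along each dimension.
matches = [
    torch.equal(
        deterministic_index_select(x, dim, indices),
        torch.index_select(x, dim, indices),
    )
    for dim in range(x.dim())
]
print(matches)
```

The comparison only covers the forward values; whether the backward pass is acceptable under torch.use_deterministic_algorithms(True) should be confirmed on the target PyTorch version and device.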

3 Random Sample Reading

  1. Set the random seed of each worker process when data is loaded with multiple workers
  2. Set the generator that controls sample shuffling

import random
import numpy as np
import torch
from torch.utils.data import DataLoader

# Set the random seed for each data-loading worker
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
# Set the sample shuffle random seed as a parameter of DataLoader
g.manual_seed(0)

DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    worker_init_fn=seed_worker,
    generator=g,
)
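The effect of the generator can be checked on a toy dataset. The sketch below is an illustration under assumptions not in the original: the function name epoch_order is hypothetical, the dataset holds just the integers 0 to 9, shuffle=True is set explicitly, and num_workers is left at 0 so the check runs in a single process.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def epoch_order(seed: int) -> list:
    """Return the order in which samples are visited during one shuffled epoch."""
    dataset = TensorDataset(torch.arange(10))
    g = torch.Generator()
    g.manual_seed(seed)
    loader = DataLoader(dataset, batch_size=2, shuffle=True, generator=g)
    # Each batch is a one-element list because the dataset wraps one tensor.
    return [int(value) for (batch,) in loader for value in batch]


run_a = epoch_order(0)
run_b = epoch_order(0)

# Two loaders whose generators share a seed visit the samples in the same order.
print(run_a == run_b)  # True
```

With num_workers greater than 0, the same generator also determines the base seed passed to each worker, which is what seed_worker reads through torch.initial_seed().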

References

Reproducibility – PyTorch 1.10.1 documentation

torch.index_select – PyTorch 1.10.1 documentation
