Implementing Checkpoint Resume Training with PyTorch


Introduction

This article summarizes important points to consider when implementing checkpoint resume training with PyTorch, along with detailed code explanations.

Recently, while training a classifier on CIFAR-10, I found that the dataset is fairly large and training takes a long time. Sometimes I want to pause training, but pausing meant starting over from scratch. A senior colleague advised me to pay special attention to how the epoch counter is handled when implementing checkpoint resume training. This morning I organized my notes on the topic; they are not comprehensive and are for reference only.

Epoch:  9 | train loss: 0.3517 | test accuracy: 0.7184 | train time: 14215.1018 s
Epoch:  9 | train loss: 0.2471 | test accuracy: 0.7252 | train time: 14309.1216 s
Epoch:  9 | train loss: 0.4335 | test accuracy: 0.7201 | train time: 14403.2398 s
Epoch:  9 | train loss: 0.2186 | test accuracy: 0.7242 | train time: 14497.1921 s
Epoch:  9 | train loss: 0.2127 | test accuracy: 0.7196 | train time: 14591.4974 s
Epoch:  9 | train loss: 0.1624 | test accuracy: 0.7142 | train time: 14685.7034 s
Epoch:  9 | train loss: 0.1795 | test accuracy: 0.7170 | train time: 14780.2831 s

Despair! After training for hours I realized I had set too few epochs, and whenever training was interrupted midway I had to start over from the beginning.

1. Saving and Loading Models

Saving (serialization, from memory to disk) and loading (deserialization, from disk to memory) in PyTorch

torch.save — main parameters: obj (the object to save) and f (the output path).

torch.load — main parameters: f (the file path) and map_location (where to place the loaded tensors, CPU or GPU).
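As a minimal sketch of both calls (the file name "net.pth" and the random tensor below are placeholders of my own, not from the article):

import torch

obj = torch.randn(3, 3)                       # any serializable object; a model works the same way
torch.save(obj, "net.pth")                    # serialize obj to disk at path "net.pth"

obj_cpu = torch.load("net.pth", map_location="cpu")       # deserialize, forcing all tensors onto the CPU
# obj_gpu = torch.load("net.pth", map_location="cuda:0")  # or map them onto a specific GPU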

Two methods for saving models:

1. Save the entire Module

torch.save(net, path)

2. Save model parameters

state_dict = net.state_dict()
torch.save(state_dict, path)
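The two saved objects are loaded back differently. A rough sketch, where Net stands in for your own model class (a placeholder name, not from the article):

# Method 1: the whole Module was saved, so torch.load returns the model object directly.
net = torch.load(path)

# Method 2: only the parameters were saved, so build the model first and then load the weights.
net = Net()                                  # Net is a placeholder for your own model class
net.load_state_dict(torch.load(path))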

2. Saving During Training

checkpoint = {
    "net": model.state_dict(),
    'optimizer': optimizer.state_dict(),
    "epoch": epoch
}

This saves the weights of the network and the optimizer during training, as well as the epoch, to facilitate resuming training.

During training, you can save network parameters every few epochs or steps as needed to improve robustness and recovery.

checkpoint = {
    "net": model.state_dict(),
    'optimizer': optimizer.state_dict(),
    "epoch": epoch
}
if not os.path.isdir("./models/checkpoint"):
    os.mkdir("./models/checkpoint")
torch.save(checkpoint, './models/checkpoint/ckpt_best_%s.pth' % (str(epoch)))
By following the above process, you can automatically create a folder in the specified location during training and save checkpoint files.


3. Resume Training from Checkpoint

if RESUME:
    path_checkpoint = "./models/checkpoint/ckpt_best_1.pth"  # checkpoint path
    checkpoint = torch.load(path_checkpoint)  # load the checkpoint

    model.load_state_dict(checkpoint['net'])  # load model learnable parameters
    optimizer.load_state_dict(checkpoint['optimizer'])  # load optimizer parameters
    start_epoch = checkpoint['epoch']  # set the starting epoch

You can specify whether to resume training and the location of the checkpoint file directly from the command line via argparse, read them from a log file, or set them in the code. For more on argparse, refer to my article:

HUST Beginner: argparse command line options, parameters, and sub-command parsers

https://zhuanlan.zhihu.com/p/133285373
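For reference, a minimal argparse sketch for the resume flag and checkpoint path (the option names --resume and --checkpoint are my own examples, not from the article):

import argparse

parser = argparse.ArgumentParser(description='checkpoint resume training')
parser.add_argument('--resume', action='store_true', help='resume training from a checkpoint')
parser.add_argument('--checkpoint', type=str, default='./models/checkpoint/ckpt_best_1.pth',
                    help='path of the checkpoint file to load')
args = parser.parse_args()

RESUME = args.resume
path_checkpoint = args.checkpoint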

4. Focus on Resuming Epoch

start_epoch = -1

if RESUME:
    path_checkpoint = "./models/checkpoint/ckpt_best_1.pth"  # checkpoint path
    checkpoint = torch.load(path_checkpoint)  # load the checkpoint

    model.load_state_dict(checkpoint['net'])  # load model learnable parameters
    optimizer.load_state_dict(checkpoint['optimizer'])  # load optimizer parameters
    start_epoch = checkpoint['epoch']  # set the starting epoch

for epoch in range(start_epoch + 1, EPOCH):
    # print('EPOCH:', epoch)
    for step, (b_img, b_label) in enumerate(train_loader):
        train_output = model(b_img)
        loss = loss_func(train_output, b_label)
        # losses.append(loss)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

By defining the start_epoch variable, resumed training picks up at the correct epoch instead of restarting the count from zero.


Checkpoint Resume Training

1. Initialize Random Seed

import torch
import random
import numpy as np

def set_random_seed(seed=10, deterministic=False, benchmark=False):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
    if benchmark:
        torch.backends.cudnn.benchmark = True

For more on torch.backends.cudnn.deterministic and torch.backends.cudnn.benchmark, see

PyTorch Study 0.01: Setting cudnn.benchmark = True

https://www.cnblogs.com/captain-dl/p/11938864.html

pytorch—cudnn.benchmark and cudnn.deterministic_ Artificial Intelligence Blog by zxyhhjs2017

https://blog.csdn.net/zxyhhjs2017/article/details/91348108


cudnn.benchmark speeds up training when the input size stays constant, while cudnn.deterministic forces cuDNN to use deterministic algorithms so that internal randomness is fixed.
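In practice the function above is called once at the very start of the script, before the model and DataLoader are built, so that weight initialization and data shuffling are reproducible across runs (a usage sketch, assuming the set_random_seed defined above):

set_random_seed(seed=10, deterministic=True, benchmark=False)  # fix all seeds before building the model and data loaders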

2. Multi-Step SGD Resume Training

In simple tasks we train with a fixed step size (i.e., a fixed learning rate LR). However, if the learning rate is too low, convergence is very slow; if it is too high, the loss oscillates around the minimum and never converges. Therefore, different training phases should use different learning rates, which speeds up both training and the convergence of the network.

The step size can be controlled with torch.optim.lr_scheduler, for example a multi-step schedule with several milestones. For recommendations on using lr_scheduler, refer to the following tutorial:

【Reprint】 Learning Rate Adjustment in PyTorch lr_scheduler, ReduceLROnPlateau

https://www.cnblogs.com/devilmaycry812839668/p/10630302.html

So, when saving training parameters in the network, we also need to save the state_dict of the lr_scheduler, and when resuming training, we need to restore it.

# At epochs 10, 20, 30, 40 and 50 the learning rate decays to 0.1 of its previous value,
# i.e., it drops by one order of magnitude at each milestone.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
lr_schedule = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20, 30, 40, 50], gamma=0.1)

for epoch in range(start_epoch + 1, 80):
    optimizer.zero_grad()
    optimizer.step()
    lr_schedule.step()

    if epoch % 10 == 0:
        print('epoch:', epoch)
        print('learning rate:', optimizer.state_dict()['param_groups'][0]['lr'])
The learning rate changes as follows:

epoch: 10
learning rate: 0.1
epoch: 20
learning rate: 0.010000000000000002
epoch: 30
learning rate: 0.0010000000000000002
epoch: 40
learning rate: 0.00010000000000000003
epoch: 50
learning rate: 1.0000000000000004e-05
epoch: 60
learning rate: 1.0000000000000004e-06
epoch: 70
learning rate: 1.0000000000000004e-06

When saving, we also need to save the state_dict of the lr_scheduler, and when resuming training, we also need to restore the lr_scheduler.

# Load and restore
if RESUME:
    path_checkpoint = "./model_parameter/test/ckpt_best_50.pth"  # checkpoint path
    checkpoint = torch.load(path_checkpoint)  # load the checkpoint

    model.load_state_dict(checkpoint['net'])  # load model learnable parameters
    optimizer.load_state_dict(checkpoint['optimizer'])  # load optimizer parameters
    start_epoch = checkpoint['epoch']  # set the starting epoch
    lr_schedule.load_state_dict(checkpoint['lr_schedule'])  # load the lr_scheduler state


# Save
for epoch in range(start_epoch + 1, 80):
    optimizer.zero_grad()
    optimizer.step()
    lr_schedule.step()

    if epoch % 10 == 0:
        print('epoch:', epoch)
        print('learning rate:', optimizer.state_dict()['param_groups'][0]['lr'])
        checkpoint = {
            "net": model.state_dict(),
            'optimizer': optimizer.state_dict(),
            "epoch": epoch,
            'lr_schedule': lr_schedule.state_dict()
        }
        if not os.path.isdir("./model_parameter/test"):
            os.mkdir("./model_parameter/test")
        torch.save(checkpoint, './model_parameter/test/ckpt_best_%s.pth' % (str(epoch)))

3. Save the Best Results

Each epoch (and even each step within an epoch) may yield a different result, so you can also keep the best result seen so far and use it for later training or evaluation, as in the sketch below.
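A rough sketch of this idea: track the best test accuracy and overwrite a single "best" checkpoint whenever it improves. Note that evaluate(), test_loader, and the extra best_acc field are my own placeholders, not part of the original code.

best_acc = 0.0  # best test accuracy seen so far

for epoch in range(start_epoch + 1, EPOCH):
    # ... training steps as in the loops above ...
    test_acc = evaluate(model, test_loader)   # evaluate() is a placeholder for your own test routine

    if test_acc > best_acc:
        best_acc = test_acc
        checkpoint = {
            "net": model.state_dict(),
            'optimizer': optimizer.state_dict(),
            "epoch": epoch,
            "best_acc": best_acc,
        }
        torch.save(checkpoint, './models/checkpoint/ckpt_best.pth')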

First Experiment Code

RESUME = True
EPOCH = 40
LR = 0.0005

model = cifar10_cnn.CIFAR10_CNN()
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()
start_epoch = -1

if RESUME:
    path_checkpoint = "./models/checkpoint/ckpt_best_1.pth"  # checkpoint path
    checkpoint = torch.load(path_checkpoint)  # load the checkpoint

    model.load_state_dict(checkpoint['net'])  # load model learnable parameters
    optimizer.load_state_dict(checkpoint['optimizer'])  # load optimizer parameters
    start_epoch = checkpoint['epoch']  # set the starting epoch

for epoch in range(start_epoch + 1, EPOCH):
    # print('EPOCH:', epoch)
    for step, (b_img, b_label) in enumerate(train_loader):
        train_output = model(b_img)
        loss = loss_func(train_output, b_label)
        # losses.append(loss)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step % 100 == 0:
            now = time.time()
            print('EPOCH:', epoch, '| step :', step, '| loss :', loss.data.numpy(), '| train time: %.4f' % (now - start_time))

    checkpoint = {
        "net": model.state_dict(),
        'optimizer': optimizer.state_dict(),
        "epoch": epoch
    }
    if not os.path.isdir("./models/checkpoint"):
        os.mkdir("./models/checkpoint")
    torch.save(checkpoint, './models/checkpoint/ckpt_best_%s.pth' % (str(epoch)))

Updated Experiment Code

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
lr_schedule = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20, 30, 40, 50], gamma=0.1)
start_epoch = 9
# print(schedule)

if RESUME:
    path_checkpoint = "./model_parameter/test/ckpt_best_50.pth"  # checkpoint path
    checkpoint = torch.load(path_checkpoint)  # load the checkpoint

    model.load_state_dict(checkpoint['net'])  # load model learnable parameters
    optimizer.load_state_dict(checkpoint['optimizer'])  # load optimizer parameters
    start_epoch = checkpoint['epoch']  # set the starting epoch
    lr_schedule.load_state_dict(checkpoint['lr_schedule'])

for epoch in range(start_epoch + 1, 80):
    optimizer.zero_grad()
    optimizer.step()
    lr_schedule.step()

    if epoch % 10 == 0:
        print('epoch:', epoch)
        print('learning rate:', optimizer.state_dict()['param_groups'][0]['lr'])
        checkpoint = {
            "net": model.state_dict(),
            'optimizer': optimizer.state_dict(),
            "epoch": epoch,
            'lr_schedule': lr_schedule.state_dict()
        }
        if not os.path.isdir("./model_parameter/test"):
            os.mkdir("./model_parameter/test")
        torch.save(checkpoint, './model_parameter/test/ckpt_best_%s.pth' % (str(epoch)))