Introduction
This article summarizes important points to consider when implementing checkpoint resume training with PyTorch, along with detailed code explanations.
Recently, while training a classifier on CIFAR10, I found that the dataset is fairly large and training takes a long time. Sometimes I want to pause training, but pausing meant starting over from scratch. My senior advised me to pay attention to how the epoch counter changes when implementing checkpoint resume training. This morning I organized my notes on the topic; they are not comprehensive and are for reference only.
Epoch: 9 | train loss: 0.3517 | test accuracy: 0.7184 | train time: 14215.1018 s
Epoch: 9 | train loss: 0.2471 | test accuracy: 0.7252 | train time: 14309.1216 s
Epoch: 9 | train loss: 0.4335 | test accuracy: 0.7201 | train time: 14403.2398 s
Epoch: 9 | train loss: 0.2186 | test accuracy: 0.7242 | train time: 14497.1921 s
Epoch: 9 | train loss: 0.2127 | test accuracy: 0.7196 | train time: 14591.4974 s
Epoch: 9 | train loss: 0.1624 | test accuracy: 0.7142 | train time: 14685.7034 s
Epoch: 9 | train loss: 0.1795 | test accuracy: 0.7170 | train time: 14780.2831 s
Despair! After training for this long, I realized I had run too few epochs, and whenever training was interrupted midway I had to start over from scratch.
1. Saving and Loading Models
PyTorch handles saving (serialization, from memory to disk) with torch.save and loading (deserialization, from disk to memory) with torch.load.
Main parameters of torch.save: obj (the object to save) and f (the output path).
Main parameters of torch.load: f (the file path) and map_location (the device to map storages to, CPU or GPU).
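As a quick illustration, here is a minimal sketch of the two calls (the file name ckpt.pth is just a placeholder):

# save any Python object (here a small dict) to disk
torch.save({"step": 100}, "ckpt.pth")

# load it back; map_location forces the target device, e.g. loading
# a GPU-trained checkpoint on a CPU-only machine
obj = torch.load("ckpt.pth", map_location="cpu")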
Two methods for saving models:
1. Save the entire Module
torch.save(net, path)
2. Save model parameters
state_dict = net.state_dict()
torch.save(state_dict, path)
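The corresponding loading code differs between the two methods. A minimal sketch, where Net() stands in for whatever model class you use:

# Method 1: load the entire Module back
net = torch.load(path)

# Method 2: create the model first, then load only its parameters
net = Net()  # hypothetical model class
state_dict = torch.load(path)
net.load_state_dict(state_dict)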
2. Saving During Training
checkpoint = { "net": model.state_dict(), 'optimizer': optimizer.state_dict(), "epoch": epoch }
This saves the weights of the network and the optimizer during training, as well as the epoch, to facilitate resuming training.
During training, you can save network parameters every few epochs or steps as needed to improve robustness and recovery.
checkpoint = { "net": model.state_dict(), 'optimizer': optimizer.state_dict(), "epoch": epoch } if not os.path.isdir("./models/checkpoint"): os.mkdir("./models/checkpoint") torch.save(checkpoint, './models/checkpoint/ckpt_best_%s.pth' %(str(epoch)))
By following the above process, you can automatically create a folder in the specified location during training and save checkpoint files.
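If you prefer to keep a checkpoint only every few epochs rather than every epoch, the save can be wrapped in a simple condition. A minimal sketch, where save_interval is a name chosen here for illustration:

save_interval = 5  # hypothetical: checkpoint every 5 epochs
if (epoch + 1) % save_interval == 0:
    checkpoint = {
        "net": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch
    }
    os.makedirs("./models/checkpoint", exist_ok=True)  # create the folder if needed
    torch.save(checkpoint, './models/checkpoint/ckpt_best_%s.pth' % (str(epoch)))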
3. Resume Training from Checkpoint
if RESUME:
    path_checkpoint = "./models/checkpoint/ckpt_best_1.pth"  # checkpoint path
    checkpoint = torch.load(path_checkpoint)  # load checkpoint

    model.load_state_dict(checkpoint['net'])  # load model learnable parameters
    optimizer.load_state_dict(checkpoint['optimizer'])  # load optimizer parameters
    start_epoch = checkpoint['epoch']  # set starting epoch
You can decide whether to resume training and where the checkpoint file lives via argparse directly from the command line (a rough sketch follows the links below), load it from a log file, or hard-code it in the script. For more on argparse, refer to my article:
HUST Beginner: argparse command line options, parameters, and sub-command parsers
https://zhuanlan.zhihu.com/p/133285373
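A rough sketch of the command-line approach; the flag names --resume and --checkpoint are examples chosen here, not taken from the original code:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--resume', action='store_true',
                    help='resume training from a checkpoint')
parser.add_argument('--checkpoint', type=str,
                    default='./models/checkpoint/ckpt_best_1.pth',
                    help='path to the checkpoint file')
args = parser.parse_args()

RESUME = args.resume
path_checkpoint = args.checkpoint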
4. Focus on Resuming Epoch
start_epoch = -1
if RESUME:
    path_checkpoint = "./models/checkpoint/ckpt_best_1.pth"  # checkpoint path
    checkpoint = torch.load(path_checkpoint)  # load checkpoint

    model.load_state_dict(checkpoint['net'])  # load model learnable parameters
    optimizer.load_state_dict(checkpoint['optimizer'])  # load optimizer parameters
    start_epoch = checkpoint['epoch']  # set starting epoch

for epoch in range(start_epoch + 1, EPOCH):
    # print('EPOCH:', epoch)
    for step, (b_img, b_label) in enumerate(train_loader):
        train_output = model(b_img)
        loss = loss_func(train_output, b_label)
        # losses.append(loss)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
By tracking the start_epoch variable, the resumed run continues from the correct epoch instead of restarting the count from zero.

Checkpoint Resume Training
1. Initialize Random Seed
import torch
import random
import numpy as np

def set_random_seed(seed=10, deterministic=False, benchmark=False):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
    if benchmark:
        torch.backends.cudnn.benchmark = True
For more on torch.backends.cudnn.deterministic and torch.backends.cudnn.benchmark, see
PyTorch Study 0.01: Setting cudnn.benchmark = True
https://www.cnblogs.com/captain-dl/p/11938864.html
pytorch: cudnn.benchmark and cudnn.deterministic, by zxyhhjs2017
https://blog.csdn.net/zxyhhjs2017/article/details/91348108
benchmark speeds up training when input sizes stay constant by letting cuDNN pick the fastest convolution algorithms, while deterministic fixes cuDNN's internal randomness so that results are reproducible.
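For resumable, reproducible training, a reasonable call would look like the following (the seed value is just an example):

# fix all seeds and make cuDNN deterministic so a resumed run is as reproducible as possible
set_random_seed(seed=10, deterministic=True, benchmark=False)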
2. Multi-Step SGD Resume Training
In simple tasks we train with a fixed step size (i.e., a fixed learning rate LR). However, if the learning rate is too low, convergence becomes very slow; if it is too high, the loss oscillates around the minimum and never converges. We therefore want different learning rates in different training phases, which speeds up training and helps the network converge.
torch.optim.lr_scheduler.MultiStepLR lets you drop the learning rate at chosen milestone epochs. For recommendations on using lr_scheduler, refer to the following tutorial:
【Reprint】 Learning Rate Adjustment in PyTorch lr_scheduler, ReduceLROnPlateau
https://www.cnblogs.com/devilmaycry812839668/p/10630302.html
So, when saving training parameters in the network, we also need to save the state_dict of the lr_scheduler, and when resuming training, we need to restore it.
# Decay the learning rate to 0.1 of its previous value (one order of magnitude)
# at epochs 10, 20, 30, 40 and 50
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
lr_schedule = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20, 30, 40, 50], gamma=0.1)

for epoch in range(start_epoch + 1, 80):
    optimizer.zero_grad()
    optimizer.step()
    lr_schedule.step()

    if epoch % 10 == 0:
        print('epoch:', epoch)
        print('learning rate:', optimizer.state_dict()['param_groups'][0]['lr'])
The learning rate changes as follows:
epoch: 10
learning rate: 0.1
epoch: 20
learning rate: 0.010000000000000002
epoch: 30
learning rate: 0.0010000000000000002
epoch: 40
learning rate: 0.00010000000000000003
epoch: 50
learning rate: 1.0000000000000004e-05
epoch: 60
learning rate: 1.0000000000000004e-06
epoch: 70
learning rate: 1.0000000000000004e-06
Concretely, the checkpoint dictionary gains an 'lr_schedule' entry, and the scheduler's state is restored when loading:
# Load and restore
if RESUME:
    path_checkpoint = "./model_parameter/test/ckpt_best_50.pth"  # checkpoint path
    checkpoint = torch.load(path_checkpoint)  # load checkpoint

    model.load_state_dict(checkpoint['net'])  # load model learnable parameters
    optimizer.load_state_dict(checkpoint['optimizer'])  # load optimizer parameters
    start_epoch = checkpoint['epoch']  # set starting epoch
    lr_schedule.load_state_dict(checkpoint['lr_schedule'])  # load lr_scheduler

# Save
for epoch in range(start_epoch + 1, 80):
    optimizer.zero_grad()
    optimizer.step()
    lr_schedule.step()

    if epoch % 10 == 0:
        print('epoch:', epoch)
        print('learning rate:', optimizer.state_dict()['param_groups'][0]['lr'])
        checkpoint = {
            "net": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
            "lr_schedule": lr_schedule.state_dict()
        }
        if not os.path.isdir("./model_parameter/test"):
            os.mkdir("./model_parameter/test")
        torch.save(checkpoint, './model_parameter/test/ckpt_best_%s.pth' % (str(epoch)))
3. Save the Best Results
Each epoch may yield a different result; you can keep the checkpoint that achieves the best result so far and use it for later training or evaluation, as sketched below.
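A minimal sketch of keeping only the best checkpoint, assuming a test accuracy acc is computed every epoch; best_acc and the file name ckpt_best.pth are names chosen here for illustration:

best_acc = 0.0
for epoch in range(start_epoch + 1, EPOCH):
    # ... training loop and evaluation producing `acc` for this epoch ...
    if acc > best_acc:
        best_acc = acc
        checkpoint = {
            "net": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
            "best_acc": best_acc
        }
        os.makedirs("./models/checkpoint", exist_ok=True)
        torch.save(checkpoint, "./models/checkpoint/ckpt_best.pth")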
First Experiment Code
RESUME = True
EPOCH = 40
LR = 0.0005
model = cifar10_cnn.CIFAR10_CNN()
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()
start_epoch = -1
if RESUME:
    path_checkpoint = "./models/checkpoint/ckpt_best_1.pth"  # checkpoint path
    checkpoint = torch.load(path_checkpoint)  # load checkpoint

    model.load_state_dict(checkpoint['net'])  # load model learnable parameters
    optimizer.load_state_dict(checkpoint['optimizer'])  # load optimizer parameters
    start_epoch = checkpoint['epoch']  # set starting epoch

for epoch in range(start_epoch + 1, EPOCH):
    # print('EPOCH:', epoch)
    for step, (b_img, b_label) in enumerate(train_loader):
        train_output = model(b_img)
        loss = loss_func(train_output, b_label)
        # losses.append(loss)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 100 == 0:
            now = time.time()
            print('EPOCH:', epoch, '| step :', step, '| loss :', loss.data.numpy(),
                  '| train time: %.4f' % (now - start_time))

    checkpoint = {
        "net": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch
    }
    if not os.path.isdir("./models/checkpoint"):
        os.mkdir("./models/checkpoint")
    torch.save(checkpoint, './models/checkpoint/ckpt_best_%s.pth' % (str(epoch)))
Updated Experiment Code
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
lr_schedule = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20, 30, 40, 50], gamma=0.1)
start_epoch = 9
# print(schedule)

if RESUME:
    path_checkpoint = "./model_parameter/test/ckpt_best_50.pth"  # checkpoint path
    checkpoint = torch.load(path_checkpoint)  # load checkpoint

    model.load_state_dict(checkpoint['net'])  # load model learnable parameters
    optimizer.load_state_dict(checkpoint['optimizer'])  # load optimizer parameters
    start_epoch = checkpoint['epoch']  # set starting epoch
    lr_schedule.load_state_dict(checkpoint['lr_schedule'])  # load lr_scheduler

for epoch in range(start_epoch + 1, 80):
    optimizer.zero_grad()
    optimizer.step()
    lr_schedule.step()

    if epoch % 10 == 0:
        print('epoch:', epoch)
        print('learning rate:', optimizer.state_dict()['param_groups'][0]['lr'])
        checkpoint = {
            "net": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
            "lr_schedule": lr_schedule.state_dict()
        }
        if not os.path.isdir("./model_parameter/test"):
            os.mkdir("./model_parameter/test")
        torch.save(checkpoint, './model_parameter/test/ckpt_best_%s.pth' % (str(epoch)))