Pytorch – Simple Implementation of Elastic Training


The MLNLP (Machine Learning Algorithms and Natural Language Processing) community is a well-known natural language processing community in China and abroad. Its members include NLP master's and doctoral students, university teachers, and corporate researchers. The community's vision is to promote communication between academia and industry in natural language processing and machine learning, and especially to help beginners make progress.

Source | Jishi Platform

Author | Yan Tingshuai @ Zhihu

Link | https://zhuanlan.zhihu.com/p/489892744

My work has recently required me to learn more about distributed training. After studying the theory I still felt unsatisfied, because there were many points I could not grasp precisely. For example, I did not know what distributed primitives such as scatter and all-reduce look like at the code level, how the ring all-reduce algorithm is used for gradient synchronization, or how a parameter server updates only part of the parameters.

The famous physicist and Nobel laureate Richard Feynman wrote on the blackboard in his office: “What I cannot create, I do not understand.” Programmers have a similar slogan: “show me the code”. I therefore plan to write a series of articles on distributed training that presents its abstract concepts as code. Every piece of code will be executable, verifiable, and reproducible, and I will share the source code so that readers can exchange ideas.

After some research I found that Pytorch provides good abstractions and interfaces for distributed training, so this series mainly uses Pytorch as its framework. Many of the examples come from the Pytorch documentation, and I have debugged and extended them.

Finally, there are already many theoretical introductions to distributed training online, so theory will not be the focus of this series. Instead, I will focus on the code. The previous articles in this series are:

Pytorch – Simple Experience of Distributed Training: https://zhuanlan.zhihu.com/p/477073906

Pytorch – Distributed Communication Primitives (with source code): https://zhuanlan.zhihu.com/p/478953028

Pytorch – Handwritten Allreduce Distributed Training (with source code): https://zhuanlan.zhihu.com/p/482557067

Pytorch – Simple Implementation of Operator-Level Parallelism (with source code): https://zhuanlan.zhihu.com/p/483640235

Pytorch – Simple Implementation of Multi-Machine Multi-Card (with source code): https://zhuanlan.zhihu.com/p/486130584

1

“Introduction”

Pytorch 1.9.0 introduced the elastic launcher torch.distributed.run as a replacement for torch.distributed.launch. Since 1.10 this launcher has been available as the torchrun command. torchrun keeps the functionality of torch.distributed.launch and adds two main features:

  • Failover: when a worker fails, all workers are automatically restarted and training continues;
  • Elastic: nodes can be added or removed dynamically. This article uses an example to illustrate elastic training.

The example first starts a worker group of 4 GPU workers on Node0. After this group has trained for a while, another 4 GPU workers are started on Node1. These workers join the workers on Node0 to form a new worker group, which results in distributed training on 2 machines with 8 GPUs.


2

“Model Construction”

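A simple fully connected neural network model. The imports at the top are also needed by the code in the following sections:

import os
import time

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP

class ToyModel(nn.Module):  # 10 -> 10 -> 5 fully connected network
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))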

3

“Checkpoint Handling”

All workers are killed and restarted whenever a node is added or removed. The training code therefore has to save the training state, so that training can continue from the last saved state after a restart.

The information that needs to be saved generally includes:

  • model: the model parameters
  • optimizer: the optimizer state
  • epoch: the current epoch number

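The saving and loading code relies on the following APIs:

  • torch.save: serializes Python objects with Python's pickle and writes them to a local file;
  • torch.load: deserializes a file written by torch.save and loads it into memory;
  • model.state_dict(): returns a dict that maps the name of each parameter (and buffer) of every layer to its tensor;
  • optimizer.state_dict(): returns the optimizer's internal state (for example momentum buffers) together with its hyperparameters.

import torch

def save_checkpoint(epoch, model, optimizer, path):
    # Bundle everything needed to resume training into one file
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimize_state_dict": optimizer.state_dict(),
    }, path)

def load_checkpoint(path):
    # map_location="cpu" keeps every rank from placing the tensors on the GPU of
    # whichever rank wrote the file; load_state_dict later copies the values onto
    # each worker's own device
    checkpoint = torch.load(path, map_location="cpu")
    return checkpoint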
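The training loop calls save_checkpoint on every worker with the same checkpoint.pt. As a result, several processes can write the file at once, and a worker that is killed during a restart may leave a truncated file behind. The sketch below shows one common variation. It rests on our own assumptions and is not the author's code:

  • Only global rank 0 writes, because after each DDP step all ranks hold the same parameters.
  • The file is first written to a temporary path in the same directory and then moved into place with os.replace, which is atomic on POSIX systems.
  • The stored epoch is the next one to run.

save_checkpoint_atomic and its rank and next_epoch arguments are hypothetical names. In any variant, the checkpoint path must be reachable from every node, for example on a shared filesystem. Otherwise a newly joined node has nothing to load.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


def save_checkpoint_atomic(rank, next_epoch, model, optimizer, path):
    """Hypothetical helper: write the checkpoint from rank 0 only, atomically."""
    if rank != 0:
        return  # DDP keeps parameters identical across ranks, so one copy is enough
    # Create the temporary file next to the target so os.replace stays on one filesystem
    fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(os.path.abspath(path)), suffix=".tmp")
    os.close(fd)
    torch.save({
        "epoch": next_epoch,  # the epoch to resume from, i.e. i + 1 in the training loop
        "model_state_dict": model.state_dict(),
        "optimize_state_dict": optimizer.state_dict(),  # key name kept from the article
    }, tmp_path)
    # Readers see either the previous complete file or the new one, never a partial write
    os.replace(tmp_path, path)


# CPU-only usage: a small stand-in model instead of ToyModel on a GPU
model = nn.Linear(10, 5)
optimizer = optim.SGD(model.parameters(), lr=0.001)
ckp_path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")

save_checkpoint_atomic(1, 36, model, optimizer, ckp_path)
written_by_rank1 = os.path.exists(ckp_path)  # rank 1 does not write
save_checkpoint_atomic(0, 36, model, optimizer, ckp_path)
checkpoint = torch.load(ckp_path, map_location="cpu")
print(written_by_rank1, checkpoint["epoch"])
```

With this variant, the resume line first_epoch = checkpoint["epoch"] starts at the following epoch instead of repeating the last one. If other ranks must read the file right after it is written, they have to wait for rank 0, for instance with dist.barrier().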

4

“Training Code”

The train() function starts with the following initialization logic:

  • Lines 1-3: read the key environment variables of the current worker (set by torchrun) and print them, for use in the result analysis later
  • Lines 5-8: create the model, wrap it in DDP, and create the loss function and the optimizer
  • Lines 10-12: set the maximum number of epochs, the starting epoch, and the checkpoint path
  • Lines 14-19: if a checkpoint exists, load it and restore the model, the optimizer, and first_epoch from it
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    print(f"[{os.getpid()}] (rank = {rank}, local_rank = {local_rank}) train worker starting...")
    
    model = ToyModel().cuda(local_rank)
    ddp_model = DDP(model, [local_rank])
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
    optimizer.zero_grad()
    max_epoch = 100
    first_epoch = 0
    ckp_path = "checkpoint.pt"
    
    if os.path.exists(ckp_path):
        print(f"load checkpoint from {ckp_path}")
        checkpoint = load_checkpoint(ckp_path)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimize_state_dict"])
        first_epoch = checkpoint["epoch"]
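You can exercise the resume logic above without GPUs or torchrun. The following stdlib-only sketch is a simulation with hypothetical names (run_worker, kill_after). It uses pickle in place of torch.save and a list append in place of a training step. The first call plays a worker that the elastic agent terminates right after the checkpoint for epoch 3 is written. The second call plays the restarted worker, which finds the checkpoint and resumes.

```python
import os
import pickle
import tempfile


def run_worker(ckp_path, max_epoch, kill_after=None):
    """Simulate one worker lifetime and return the epochs it executed.

    kill_after (hypothetical) mimics the elastic agent terminating the worker
    right after the checkpoint for that epoch has been written.
    """
    first_epoch = 0
    if os.path.exists(ckp_path):  # same check as the article's initialization logic
        with open(ckp_path, "rb") as f:
            first_epoch = pickle.load(f)["epoch"]
    executed = []
    for i in range(first_epoch, max_epoch):
        executed.append(i)  # stands in for forward / backward / optimizer.step()
        with open(ckp_path, "wb") as f:
            pickle.dump({"epoch": i}, f)  # stores i, like save_checkpoint(i, ...)
        if i == kill_after:
            break  # the worker is terminated here
    return executed


ckp_path = os.path.join(tempfile.mkdtemp(), "checkpoint.pkl")
before_restart = run_worker(ckp_path, max_epoch=6, kill_after=3)
after_restart = run_worker(ckp_path, max_epoch=6)
print(before_restart, after_restart)  # [0, 1, 2, 3] [3, 4, 5]
```

The checkpoint stores the epoch that has just finished, as save_checkpoint(i, ...) does in the training loop. The restarted worker therefore runs epoch 3 a second time, and storing i + 1 would make the second list start at 4. The node0 log in the result analysis shows the same repetition: epoch 35 is printed both before and after the restart.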

Training logic:

  • Line 1: epochs run from first_epoch to max_epoch, so a restarted worker resumes from the saved epoch instead of starting over;
  • Line 2: a sleep slows training down, to leave time to observe a node being added dynamically;
  • Lines 3-9: the training step: clear the gradients, run the forward pass on random data, compute the loss, backpropagate, print the loss, and update the parameters;
  • Line 10: for simplicity, every worker saves a checkpoint with the current epoch, model, and optimizer after each epoch. The saved value is i, the epoch that has just finished, so that epoch is executed once more after a restart.
    for i in range(first_epoch, max_epoch):
        time.sleep(1)  # slow down training so that the dynamic addition of a node can be observed
        optimizer.zero_grad()  # clear gradients left over from the previous iteration
        outputs = ddp_model(torch.randn(20, 10).to(local_rank))
        labels = torch.randn(20, 5).to(local_rank)
        loss = loss_fn(outputs, labels)
        loss.backward()
        print(f"[{os.getpid()}] epoch {i} (rank = {rank}, local_rank = {local_rank}) loss = {loss.item()}")
        optimizer.step()
        save_checkpoint(i, model, optimizer, ckp_path)
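The training loop has to clear gradients in every iteration because backward() adds to each parameter's .grad instead of overwriting it. The CPU-only check below uses a stand-in nn.Linear with fixed random inputs, not the article's model or data. It shows that a second backward pass without zeroing doubles the stored gradient. The same thing would happen across epochs if the gradients were zeroed only once before the loop. Here, calling zero_grad() on the module plays the role of optimizer.zero_grad(), because an optimizer over layer.parameters() would manage exactly these parameters.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(10, 5)
x = torch.randn(20, 10)
y = torch.randn(20, 5)
loss_fn = nn.MSELoss()


def backward_once():
    loss = loss_fn(layer(x), y)
    loss.backward()  # accumulates gradients into each parameter's .grad


backward_once()
first_grad = layer.weight.grad.clone()

backward_once()  # no zero_grad(): the new gradient is added to the old one
accumulated = layer.weight.grad.clone()

layer.zero_grad()  # clears the gradients before the next backward pass
backward_once()
fresh = layer.weight.grad.clone()

print(torch.allclose(accumulated, 2 * first_grad), torch.allclose(fresh, first_grad))  # True True
```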

5

“Startup Method”

torchrun launches the multi-machine, multi-GPU job and starts our Python script as one process per worker, so we do not need the spawn interface to start processes. We therefore call the train function above directly, initializing the process group before it and destroying the process group afterwards.

The following code calls the train function written above:

import os

import torch.distributed as dist

def run():
    # Print the variables that torchrun exported for this worker
    env_dict = {
        key: os.environ[key]
        for key in ("MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "LOCAL_WORLD_SIZE")
    }
    print(f"[{os.getpid()}] Initializing process group with: {env_dict}")
    # The default env:// init method reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE
    # from the environment, so no addresses or ranks are passed here
    dist.init_process_group(backend="nccl")
    train()
    dist.destroy_process_group()


if __name__ == "__main__":
    run()
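torchrun exports more per-worker variables than the four printed in run(). They include RANK, LOCAL_RANK, GROUP_RANK (the rank of the node when one worker group runs per node) and TORCHELASTIC_RESTART_COUNT (the number of worker group restarts so far). Logging them in one place makes restarts and rank changes easier to follow. The helper below is a hypothetical, stdlib-only sketch: the dataclass and function names are ours. The environment is passed in as a mapping so you can try the function without torchrun.

```python
import os
from dataclasses import dataclass
from typing import Mapping


@dataclass(frozen=True)
class WorkerEnv:
    rank: int
    local_rank: int
    world_size: int
    local_world_size: int
    group_rank: int
    restart_count: int


def read_worker_env(env: Mapping[str, str] = os.environ) -> WorkerEnv:
    """Hypothetical helper: collect the per-worker variables set by torchrun."""
    return WorkerEnv(
        rank=int(env["RANK"]),
        local_rank=int(env["LOCAL_RANK"]),
        world_size=int(env["WORLD_SIZE"]),
        local_world_size=int(env["LOCAL_WORLD_SIZE"]),
        group_rank=int(env["GROUP_RANK"]),
        # Default to 0 when the variable is absent (e.g. when run without torchrun)
        restart_count=int(env.get("TORCHELASTIC_RESTART_COUNT", "0")),
    )


# Example values resembling worker rank 5 on node1 in the logs; GROUP_RANK=1 is assumed
info = read_worker_env({
    "RANK": "5", "LOCAL_RANK": "1", "WORLD_SIZE": "8",
    "LOCAL_WORLD_SIZE": "4", "GROUP_RANK": "1",
})
print(info)
```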

This example uses torchrun to launch the multi-machine, multi-GPU training job. Note that torch.distributed.launch is deprecated in Pytorch and should no longer be used. The options of the startup script are explained below; node0 and node1 are both started with this same script:

  • --nnodes=1:3: the job accepts at least 1 and at most 3 nodes;
  • --nproc_per_node=4: each node runs 4 worker processes;
  • --max_restarts=3: the maximum number of times the worker group may be restarted. Note that a node failure, a scale-down, and a scale-up all trigger a restart;
  • --rdzv_id=1: a unique job ID, which all nodes of the job must share;
  • --rdzv_backend: the rendezvous backend, for example c10d or etcd. Rendezvous is the mechanism the nodes use to find and coordinate with each other;
  • --rdzv_endpoint: the rendezvous address, given as the host IP and port of one of the nodes;
torchrun \
    --nnodes=1:3 \
    --nproc_per_node=4 \
    --max_restarts=3 \
    --rdzv_id=1 \
    --rdzv_backend=c10d \
    --rdzv_endpoint="192.0.0.1:1234" \
    train_elastic.py
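--nnodes takes either a fixed count or a MIN:MAX range. With a fixed --nproc_per_node, this range determines which WORLD_SIZE values can occur. The helpers below are a hypothetical re-implementation for illustration only: torchrun parses the flag internally, and these are not its functions. With --nnodes=1:3 and --nproc_per_node=4, the job can run with 4, 8 or 12 workers. This is why WORLD_SIZE goes from 4 to 8 when node1 joins in the results below.

```python
from __future__ import annotations


def parse_nnodes(value: str) -> tuple[int, int]:
    """Hypothetical parser: '2' -> (2, 2), '1:3' -> (1, 3)."""
    parts = value.split(":")
    if len(parts) == 1:
        low = high = int(parts[0])
    elif len(parts) == 2:
        low, high = int(parts[0]), int(parts[1])
    else:
        raise ValueError(f"invalid --nnodes value: {value!r}")
    if not 1 <= low <= high:
        raise ValueError(f"invalid --nnodes range: {value!r}")
    return low, high


def possible_world_sizes(nnodes: str, nproc_per_node: int) -> list[int]:
    """WORLD_SIZE for every admissible node count, assuming identical nodes."""
    low, high = parse_nnodes(nnodes)
    return [n * nproc_per_node for n in range(low, high + 1)]


print(possible_world_sizes("1:3", 4))  # [4, 8, 12]
```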

6

“Result Analysis”

Code: BetterDL – train_elastic.py: https://github.com/tingshua-yts/BetterDL/blob/master/test/pytorch/DDP/train_elastic.py

Running environment: 2 machines, each with 4 V100 GPUs

Image: pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime

First execute the startup script on node0

torchrun \
    --nnodes=1:3 \
    --nproc_per_node=4 \
    --max_restarts=3 \
    --rdzv_id=1 \
    --rdzv_backend=c10d \
    --rdzv_endpoint="192.0.0.1:1234" \
    train_elastic.py

This produces the following output:

  • Lines 2-5: the job is currently a single-machine, 4-GPU job, so WORLD_SIZE is 4 and LOCAL_WORLD_SIZE is also 4
  • Lines 6-9: four ranks, rank0 to rank3, take part in the distributed training
  • Lines 10-18: rank0 to rank3 all start training from epoch 0
/workspace/DDP# sh run_elastic.sh
[4031] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '44901', 'WORLD_SIZE': '4', 'LOCAL_WORLD_SIZE': '4'}
[4029] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '44901', 'WORLD_SIZE': '4', 'LOCAL_WORLD_SIZE': '4'}
[4030] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '44901', 'WORLD_SIZE': '4', 'LOCAL_WORLD_SIZE': '4'}
[4032] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '44901', 'WORLD_SIZE': '4', 'LOCAL_WORLD_SIZE': '4'}
[4029] (rank = 0, local_rank = 0) train worker starting...
[4030] (rank = 1, local_rank = 1) train worker starting...
[4032] (rank = 3, local_rank = 3) train worker starting...
[4031] (rank = 2, local_rank = 2) train worker starting...
[4101] epoch 0 (rank = 1, local_rank = 1) loss = 0.9288564920425415
[4103] epoch 0 (rank = 3, local_rank = 3) loss = 0.9711472988128662
[4102] epoch 0 (rank = 2, local_rank = 2) loss = 1.0727070569992065
[4100] epoch 0 (rank = 0, local_rank = 0) loss = 0.9402943253517151
[4100] epoch 1 (rank = 0, local_rank = 0) loss = 1.0327017307281494
[4101] epoch 1 (rank = 1, local_rank = 1) loss = 1.4485043287277222
[4103] epoch 1 (rank = 3, local_rank = 3) loss = 1.0959293842315674
[4102] epoch 1 (rank = 2, local_rank = 2) loss = 1.0669530630111694
...

Execute the same script on node1

torchrun \
    --nnodes=1:3 \
    --nproc_per_node=4 \
    --max_restarts=3 \
    --rdzv_id=1 \
    --rdzv_backend=c10d \
    --rdzv_endpoint="192.0.0.1:1234" \
    train_elastic.py

The results on node1 are as follows:

  • Lines 2-5: with node1 added, the job is now a 2-machine, 8-GPU job, so WORLD_SIZE=8 and LOCAL_WORLD_SIZE=4
  • Lines 6-9: the workers on node1 have ranks rank4 to rank7
  • Lines 10-12: each worker loads the checkpoint
  • Lines 13-20: node1 joined while node0 was at epoch 35, so training continues from epoch 35
/workspace/DDP# sh run_elastic.sh
[696] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
[697] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
[695] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
[694] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
[697] (rank = 7, local_rank = 3) train worker starting...
[695] (rank = 5, local_rank = 1) train worker starting...
[694] (rank = 4, local_rank = 0) train worker starting...
[696] (rank = 6, local_rank = 2) train worker starting...
load checkpoint from checkpoint.ptload checkpoint from checkpoint.pt
load checkpoint from checkpoint.pt
load checkpoint from checkpoint.pt
[697] epoch 35 (rank = 7, local_rank = 3) loss = 1.1888569593429565
[694] epoch 35 (rank = 4, local_rank = 0) loss = 0.8916441202163696
[695] epoch 35 (rank = 5, local_rank = 1) loss = 1.5685604810714722
[696] epoch 35 (rank = 6, local_rank = 2) loss = 1.11683189868927
[696] epoch 36 (rank = 6, local_rank = 2) loss = 1.3724170923233032
[694] epoch 36 (rank = 4, local_rank = 0) loss = 1.061527967453003
[695] epoch 36 (rank = 5, local_rank = 1) loss = 0.96876460313797
[697] epoch 36 (rank = 7, local_rank = 3) loss = 0.8060566782951355
...
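The rank numbers in these logs follow a simple rule when every node runs the same number of workers. Nodes are ordered by group rank, and the workers on the node with group rank k receive the global ranks k * LOCAL_WORLD_SIZE through (k + 1) * LOCAL_WORLD_SIZE - 1, with local_rank = rank % LOCAL_WORLD_SIZE. The arithmetic is sketched below. We read off the logs that node0 has group rank 0 and node1 has group rank 1. In general, which node receives which group rank is decided at rendezvous and is not fixed in advance.

```python
from __future__ import annotations


def global_ranks(group_rank: int, local_world_size: int) -> list[int]:
    """Global ranks on one node, assuming every node runs local_world_size workers."""
    base = group_rank * local_world_size
    return [base + local_rank for local_rank in range(local_world_size)]


def local_rank_of(rank: int, local_world_size: int) -> int:
    """Inverse mapping under the same assumption."""
    return rank % local_world_size


nnodes, nproc_per_node = 2, 4
world_size = nnodes * nproc_per_node  # WORLD_SIZE after node1 joins
node0 = global_ranks(0, nproc_per_node)
node1 = global_ranks(1, nproc_per_node)
print(world_size, node0, node1)  # 8 [0, 1, 2, 3] [4, 5, 6, 7]
```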

The results on node0 are as follows:

  • Lines 6-9: while the node0 workers are at epoch 35, the training script is started on node1 and requests to join the job, so node0 sends SIGTERM to its current workers
  • Lines 10-13: all workers restart; with node1 added, the job is now a 2-machine, 8-GPU job, so WORLD_SIZE=8 and LOCAL_WORLD_SIZE=4
  • Lines 14-17: the workers on node0 have ranks rank0 to rank3
  • Lines 18-21: each worker loads the checkpoint
  • Lines 22-30: training continues from the model, optimizer, and epoch stored in the checkpoint
...
[4100] epoch 35 (rank = 0, local_rank = 0) loss = 1.0746158361434937
[4101] epoch 35 (rank = 1, local_rank = 1) loss = 1.1712706089019775
[4103] epoch 35 (rank = 3, local_rank = 3) loss = 1.1774182319641113
[4102] epoch 35 (rank = 2, local_rank = 2) loss = 1.0898035764694214
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 4100 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 4101 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 4102 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 4103 closing signal SIGTERM
[4164] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
[4165] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
[4162] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
[4163] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
[4162] (rank = 0, local_rank = 0) train worker starting...
[4163] (rank = 1, local_rank = 1) train worker starting...
[4164] (rank = 2, local_rank = 2) train worker starting...
[4165] (rank = 3, local_rank = 3) train worker starting...
load checkpoint from checkpoint.pt
load checkpoint from checkpoint.pt
load checkpoint from checkpoint.pt
load checkpoint from checkpoint.pt
[4165] epoch 35 (rank = 3, local_rank = 3) loss = 1.3437936305999756
[4162] epoch 35 (rank = 0, local_rank = 0) loss = 1.5693414211273193
[4163] epoch 35 (rank = 1, local_rank = 1) loss = 1.199862003326416
[4164] epoch 35 (rank = 2, local_rank = 2) loss = 1.0465545654296875
[4163] epoch 36 (rank = 1, local_rank = 1) loss = 0.9741991758346558
[4162] epoch 36 (rank = 0, local_rank = 0) loss = 1.3609280586242676
[4164] epoch 36 (rank = 2, local_rank = 2) loss = 0.9585908055305481
[4165] epoch 36 (rank = 3, local_rank = 3) loss = 0.9169824123382568
...
