How distributed training works in Pytorch: distributed data-parallel and mixed-precision training

Nikolas Adaloglou · 2022-04-14 · 5 mins
Data Processing, Software, Pytorch

In this tutorial, we will learn how to use nn.parallel.DistributedDataParallel to train our models on multiple GPUs. We will take a minimal example of training an image classifier and see how we can speed up the training.

Let’s start with some imports.

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time

We will use the CIFAR10 dataset in all our experiments, with a batch size of 256.

def create_data_loader_cifar10():
    transform = transforms.Compose(
        [
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    batch_size = 256
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=10, pin_memory=True)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=10)
    return trainloader, testloader

We will first train the model on a single Nvidia A100 GPU for 1 epoch. Standard Pytorch stuff here, nothing new. The tutorial is based on the official tutorial from Pytorch’s docs.

def train(net, trainloader):
    print("Start training...")
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    epochs = 1
    num_of_batches = len(trainloader)
    for epoch in range(epochs):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            images, labels = inputs.cuda(), labels.cuda()
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # print statistics
            running_loss += loss.item()
        print(f'[Epoch {epoch + 1}/{epochs}] loss: {running_loss / num_of_batches:.3f}')
    print('Finished Training')

The test function is similarly defined. The main script will just put everything together:
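A minimal sketch of such a test function, assuming it loads the saved weights and reports top-1 accuracy on the test set (the version in the accompanying repository may differ slightly):

def test(net, PATH, testloader):
    # load the weights saved after training
    net.load_state_dict(torch.load(PATH))
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].cuda(), data[1].cuda()
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')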

if __name__ == '__main__':
    start = time.time()
    PATH = './cifar_net.pth'
    trainloader, testloader = create_data_loader_cifar10()
    net = torchvision.models.resnet50(False).cuda()
    start_train = time.time()
    train(net, trainloader)
    end_train = time.time()
    # save
    torch.save(net.state_dict(), PATH)
    # test
    test(net, PATH, testloader)
    end = time.time()
    seconds = (end - start)
    seconds_train = (end_train - start_train)
    print(f"Total elapsed time: {seconds:.2f} seconds, Train 1 epoch {seconds_train:.2f} seconds")

We use a resnet50 to measure the performance of a decent-sized network.

Now let’s train the model:

$ python -m train_1gpu
Accuracy of the network on the 10000 test images: 27 %
Total elapsed time: 69.03 seconds, Train 1 epoch 13.08 seconds

Ok, time to get to optimization work.

Code is available on GitHub. If you are planning to solidify your Pytorch knowledge, there are two amazing books that we highly recommend: Deep learning with PyTorch from Manning Publications and Machine Learning with PyTorch and Scikit-Learn by Sebastian Raschka. You can always use the 35% discount code blaisummer21 for all Manning’s products.

torch.nn.DataParallel: no pain, no gain

DataParallel is single-process, multi-threaded, and only works on a single machine. For each GPU, we use the same model to do the forward pass. We scatter the data across the GPUs and perform forward passes in each one of them. Essentially, the batch is divided across the number of workers.

In this use case, this functionality provided no gain. That’s because the system that I am using has a CPU and hard disk bottleneck. Other machines that have very fast disk and CPU but struggle with the GPU speed (GPU bottleneck) may benefit from this functionality.

In practice, the only change you need to make in the code is the following:

net = torchvision.models.resnet50(False)
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # Batch size should be divisible by number of GPUs
    net = nn.DataParallel(net)

When using nn.DataParallel, the batch size should be divisible by the number of GPUs.

nn.DataParallel splits the batch and processes it independently on all the available GPUs. In each forward pass, the module is replicated on each GPU, which is a significant overhead. Each replica handles a portion of the batch (batch_size / gpus). During the backwards pass, gradients from each replica are summed into the original module.

More info in our previous article on data vs model parallelism.

A good practice when using multiple GPUs is to define in advance the GPUs that your script is going to use:

import os
os.environ['CUDA_VISIBLE_DEVICES'] = "0,1"

This should be done before any other CUDA-related import.

Even the Pytorch documentation makes it clear that this is a poor strategy:

It is recommended to use nn.DistributedDataParallel, instead of this class, to do multi-GPU training, even if there is only a single node.

The reason is that DistributedDataParallel uses one process per worker (GPU) while DataParallel encapsulates all the data communication in a single process.

According to the docs, the data can be on any device before it is passed into the model.

In my experiment, DataParallel was slower than training on a single GPU, even with 4 GPUs. Increasing the number of dataloader workers reduced the time, but it was still worse than a single GPU. I measure and report the time required to train the model for one epoch, that is, 50K 32x32 images.

Final note: to compare the performance with a single GPU, I multiplied the batch size by the number of workers, i.e. 4 for 4 GPUs, as shown below. Otherwise, DataParallel is more than 2X slower than a single GPU.
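In code, this simply amounts to scaling the batch size by the number of visible GPUs, for example:

# sketch: scale the global batch size with the number of GPUs so that
# each nn.DataParallel replica still sees 256 images per forward pass
batch_size = 256 * torch.cuda.device_count()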

This brings us to the hardcore topic of Distributed Data-Parallel.

Code is available on GitHub. You can always support our work by social media sharing, making a donation, and buying our book and e-course.

Pytorch Distributed Data-Parallel

Distributed data parallel is multi-process and works for both single-machine and multi-machine training. In pytorch, nn.parallel.DistributedDataParallel parallelizes the module by splitting the input across the specified devices. This module is suitable for multi-node, multi-GPU training as well. Here, I only experimented with a single node (1 machine with 4 GPUs).

The main difference here is that each GPU is handled by its own process. Parameters are never broadcast between processes, only gradients.

The module is replicated on each machine and each device. During the forward pass, each worker (GPU) processes its share of the data and computes its own gradients locally. During the backwards pass, the gradients are averaged across all workers. Because the averaged gradients are identical everywhere, each worker then performs the same parameter update locally, keeping the replicas in sync without exchanging parameters.

The module performs an all-reduce step on gradients and assumes that they will be modified by the optimizer in all processes in the same way.
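To make the all-reduce step more concrete, here is a minimal sketch (not part of the original code) of what DDP conceptually does with the gradients after the backward pass; in practice DDP overlaps this communication with the backward computation, so you never write it by hand:

import torch.distributed as dist

def average_gradients(model):
    # conceptually what DDP's all-reduce does: sum the gradient of every
    # process and divide by the world size, so all replicas end up with
    # the same averaged gradient before the optimizer step
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size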

Below are the guidelines for converting your single GPU script to multi-GPU training.

Step 1: Initialize the distributed learning processes

import os
import torch
import torch.distributed as dist

def init_distributed():
    # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
    dist_url = "env://"  # default
    # only works with torch.distributed.launch / torchrun
    rank = int(os.environ["RANK"])
    world_size = int(os.environ['WORLD_SIZE'])
    local_rank = int(os.environ['LOCAL_RANK'])
    dist.init_process_group(
        backend="nccl",
        init_method=dist_url,
        world_size=world_size,
        rank=rank)
    # this will make all .cuda() calls work properly
    torch.cuda.set_device(local_rank)
    # synchronizes all the processes to reach this point before moving on
    dist.barrier()

This initialization works when we launch our script with torch.distributed.launch (Pytorch 1.7 and 1.8) or torchrun / torch.distributed.run (Pytorch 1.9+) from each node (here 1).

Step 2: Wrap the model using DDP

net = torchvision.models.resnet50(False).cuda()
# Convert BatchNorm to SyncBatchNorm.
net = nn.SyncBatchNorm.convert_sync_batchnorm(net)
local_rank = int(os.environ['LOCAL_RANK'])
net = nn.parallel.DistributedDataParallel(net, device_ids=[local_rank])

If each process has the correct local rank, tensor.cuda() or model.cuda() can be called correctly throughout the script.

Step 3: Use a DistributedSampler in your DataLoader

import torch
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
import torch.nn as nn

def create_data_loader_cifar10():
    transform = transforms.Compose(
        [
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    batch_size = 256
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    train_sampler = DistributedSampler(dataset=trainset, shuffle=True)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              sampler=train_sampler, num_workers=10, pin_memory=True)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform)
    test_sampler = DistributedSampler(dataset=testset, shuffle=True)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, sampler=test_sampler, num_workers=10)
    return trainloader, testloader

In distributed mode, calling the data_loader.sampler.set_epoch() method at the beginning of each epoch before creating the DataLoader iterator is necessary to make shuffling work properly across multiple epochs. Otherwise, the same ordering will be always used.

def train(net, trainloader):
    print("Start training...")
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    epochs = 1
    num_of_batches = len(trainloader)
    for epoch in range(epochs):
        # NEW line added
        trainloader.sampler.set_epoch(epoch)
        # same as before .......

In a more general form:

for epoch in range(epochs):
    data_loader.sampler.set_epoch(epoch)
    # train epoch loop
    train_one_epoch(...)

Good practices for DDP

Any method that downloads data or performs file I/O should be isolated to the master (main) process.

import torch.distributed as dist
import torch

def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True

def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()

def is_main_process():
    return get_rank() == 0

def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)

Based on these functions, you can make sure that some commands are only executed from the main process:

if is_main_process():
    # do that ...
    # save, load models, download data etc.

Launch the script using torch.distributed.launch or torchrun

$ python -m torch.distributed.launch --nproc_per_node=4 main_script.py
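With newer Pytorch versions (1.10+), the equivalent launcher is torchrun:

$ torchrun --nproc_per_node=4 main_script.py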

Mistakes will occur. Be sure to kill any unwanted distributed training process by:

$ kill $(ps aux | grep main_script.py | grep -v grep | awk '{print $2}')

Replace main_script.py with your script’s name. Another, simpler option is $ kill -9 PID. Otherwise, you can resort to more advanced measures, like killing all CUDA-related GPU processes even when they are not shown in nvidia-smi:

lsof /dev/nvidia* | awk '{print $2}' | xargs -I {} kill {}

This is only for the case where you cannot find the PID of the process running on the GPU.

A very good book on distributed training is Distributed Machine Learning with Python: Accelerating model training and serving with distributed systems by Guanhua Wang.

Mixed-precision training in Pytorch

Mixed precision combines Floating Point (FP) 16 and FP32 in different steps of the training. Pure FP16 training, also known as half-precision training, comes with inferior model performance. Automatic mixed precision is literally the best of both worlds: reduced training time with performance comparable to FP32.

In mixed precision training, the computational operations (forward pass, backward pass, weight gradients) see the FP16-cast version of the tensors. To make this work, an FP32 master copy of the weights is kept, and the loss is computed in FP32 after the FP16 forward pass to avoid overflow and underflow. The weight gradients are cast back to FP32 to update the model’s weights. Moreover, the loss is scaled up before the backward pass in FP16 to avoid gradient underflow; to compensate, the gradients are unscaled by the same factor before the weight update.

Here are the changes in the train function:

fp16_scaler = torch.cuda.amp.GradScaler(enabled=True)

for epoch in range(epochs):  # loop over the dataset multiple times
    trainloader.sampler.set_epoch(epoch)
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        images, labels = inputs.cuda(), labels.cuda()
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward pass under autocast (FP16 where safe)
        with torch.cuda.amp.autocast():
            outputs = net(images)
            loss = criterion(outputs, labels)
        # mixed precision training
        # backward + optimizer step with loss scaling
        fp16_scaler.scale(loss).backward()
        fp16_scaler.step(optimizer)
        fp16_scaler.update()

Results and Sum up

In a utopian parallel world, N workers would give a speedup of N. Here you see that you need 4 GPUs in DistributedDataParallel mode to get a 2X speedup. Mixed precision training normally provides a substantial speedup, but on the A100 and other Ampere-based GPU architectures the gains are limited (as far as I have read online).

Results below report the time in seconds for 1 epoch on CIFAR10 with a resnet50 (batch size 256, NVidia A100 40GB GPU memory):

Setup                                              Time in seconds
Single GPU (baseline)                                         13.2
DataParallel 4 GPUs                                           19.1
DistributedDataParallel 2 GPUs                                 9.8
DistributedDataParallel 4 GPUs                                 6.1
DistributedDataParallel 4 GPUs + Mixed Precision               6.5

A very important note here is that DistributedDataParallel uses an effective batch size of 4*256=1024, so it makes 4X fewer model updates per epoch. That’s why I believe it scores a much lower validation accuracy (14% compared to 27% in the baseline).
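One common remedy, not applied in this experiment, is the linear scaling rule: multiply the base learning rate by the number of processes to compensate for the larger effective batch size. A minimal sketch, assuming the optimizer setup from the train function above:

import torch.distributed as dist
import torch.optim as optim

# hypothetical adjustment (not from the original tutorial): scale the
# learning rate linearly with the number of processes, since the
# effective batch size is 256 * world_size
base_lr = 0.001
scaled_lr = base_lr * dist.get_world_size()
optimizer = optim.SGD(net.parameters(), lr=scaled_lr, momentum=0.9)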

Code is available on GitHub if you want to play around. The results will vary based on your hardware. There is always a chance that I missed something in my experiments. If you find a flaw, please let me know on our Discord server.

These findings should give you a solid start to training your models. I hope you find them useful. Support us by sharing on social media, making a donation, or buying our book or e-course. Your support helps us produce more free and accessible AI content. As always, thank you for your interest in our blog.


* Disclosure: Please note that some of the links above might be affiliate links, and at no additional cost to you, we will earn a commission if you decide to make a purchase after clicking through.