Distributed Training on PyTorch: A Complete Tutorial

In the era of deep learning, the scale of models and datasets is constantly expanding. Training large-scale models on a single machine can be extremely time-consuming and resource-intensive. Distributed training is a powerful technique that lets us train models across multiple machines, or multiple GPUs within a single machine, simultaneously. PyTorch, a popular deep learning framework, provides a comprehensive set of tools for distributed training. This tutorial walks through the fundamental concepts, usage, common practices, and best practices of distributed training in PyTorch.

Table of Contents

  1. Fundamental Concepts of Distributed Training
  2. Setting up the Environment
  3. Data Parallelism in PyTorch
  4. Model Parallelism in PyTorch
  5. Common Practices
  6. Best Practices
  7. Conclusion
  8. References

1. Fundamental Concepts of Distributed Training

1.1 Data Parallelism

Data parallelism is the most common form of distributed training. In data parallelism, the same model is replicated across multiple devices (GPUs or machines), and each device processes a different subset of the training data. Each device runs the forward pass on its subset, then computes gradients during the backward pass. These gradients are aggregated (typically averaged) across devices, and every replica applies the same parameter update, so the replicas stay in sync.
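The update rule above can be illustrated without any GPUs at all. The following minimal sketch (all names are illustrative, not part of any PyTorch API) models two replicas of a one-parameter linear model `y_hat = w * x`: each replica computes the gradient of a squared-error loss on its own data shard, the gradients are averaged, and both replicas apply the identical update.

```python
# Pure-Python sketch of the data-parallel update rule: per-replica gradients
# on disjoint data shards, averaged, then one shared parameter update.

def local_gradient(w, shard):
    """Gradient of mean squared error of y_hat = w * x over one data shard."""
    n = len(shard)
    return sum(2 * (w * x - y) * x for x, y in shard) / n

def data_parallel_step(w, shards, lr=0.1):
    grads = [local_gradient(w, shard) for shard in shards]  # one grad per replica
    avg_grad = sum(grads) / len(grads)                      # the "all-reduce" (mean)
    return w - lr * avg_grad                                # identical update everywhere

# Two shards drawn from the line y = 2x; the averaged-gradient updates
# converge w toward 2.0.
shards = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0), (4.0, 8.0)]]
w = 0.0
for _ in range(50):
    w = data_parallel_step(w, shards)
print(round(w, 3))  # -> 2.0
```

Because every replica sees the same averaged gradient, the result is mathematically equivalent to one large-batch step on the full dataset.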

1.2 Model Parallelism

Model parallelism involves splitting the model across multiple devices. Different parts of the model are placed on different devices, and the data flows through these devices during the forward and backward passes. This is useful when the model is too large to fit on a single device.

1.3 Collective Communication

Collective communication operations are essential for distributed training. These operations allow devices to exchange data, such as gradients, during the training process. Common collective operations include all-reduce (used to sum or average gradients across devices), broadcast (copy one device's data to all others), and gather (collect data from all devices onto one).
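To make the semantics of these three collectives concrete, here is a toy single-process sketch that models each "device" as an entry in a list. This is purely illustrative; in real training these operations run across processes via `torch.distributed` (e.g., `dist.all_reduce`).

```python
# Toy illustrations of the three collectives, with one list entry per "rank".

def all_reduce(values):
    """Every rank ends up with the sum of all ranks' values."""
    total = sum(values)
    return [total for _ in values]

def broadcast(values, src=0):
    """Every rank ends up with the value held by rank `src`."""
    return [values[src] for _ in values]

def gather(values, dst=0):
    """Rank `dst` receives all values; the other ranks receive nothing."""
    return [list(values) if rank == dst else None for rank in range(len(values))]

grads = [1.0, 2.0, 3.0, 4.0]      # one gradient per "device"
print(all_reduce(grads))           # [10.0, 10.0, 10.0, 10.0]
print(broadcast(grads, src=2))     # [3.0, 3.0, 3.0, 3.0]
print(gather(grads, dst=0))        # [[1.0, 2.0, 3.0, 4.0], None, None, None]
```

In data-parallel training, all-reduce is the workhorse: dividing the all-reduced sum by the world size yields the averaged gradient every replica applies.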

2. Setting up the Environment

Before we start with distributed training in PyTorch, we need to set up the environment. First, make sure you have PyTorch installed. You can install it using pip or conda depending on your preference.

# Using pip
pip install torch torchvision

# Using conda
conda install pytorch torchvision -c pytorch

If you are using multiple GPUs on a single machine, make sure CUDA and a compatible NVIDIA driver are properly installed and that `torch.cuda.is_available()` returns `True`.

3. Data Parallelism in PyTorch

3.1 DataParallel

DataParallel is the simplest way to perform data parallelism in PyTorch on a single machine. It splits each input batch across the available GPUs, replicates the model on each, and gathers the outputs (and gradients) back on the default GPU. It runs in a single process, which limits its efficiency.

import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

model = SimpleModel()
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs")
    model = nn.DataParallel(model)

model.to('cuda')

# Generate some dummy data
input_data = torch.randn(20, 10).to('cuda')
output = model(input_data)
print(output.size())

3.2 DistributedDataParallel

DistributedDataParallel is a more efficient alternative to DataParallel. It runs one process per GPU, uses collective communication primitives (such as NCCL) to synchronize gradients, and is designed for both single-machine multi-GPU and multi-machine training.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)


def demo_basic(rank, world_size):
    setup(rank, world_size)

    # create model and move it to GPU with id rank
    model = SimpleModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    inputs = torch.randn(20, 10).to(rank)
    labels = torch.randn(20, 1).to(rank)

    optimizer.zero_grad()
    outputs = ddp_model(inputs)
    loss = loss_fn(outputs, labels)
    loss.backward()
    optimizer.step()

    cleanup()


def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)


if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
    world_size = n_gpus
    run_demo(demo_basic, world_size)

4. Model Parallelism in PyTorch

Model parallelism can be implemented manually in PyTorch by splitting the model layers across different devices.

import torch
import torch.nn as nn


class ModelParallelMLP(nn.Module):
    def __init__(self):
        super(ModelParallelMLP, self).__init__()
        self.fc1 = nn.Linear(10, 20).to('cuda:0')
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(20, 1).to('cuda:1')

    def forward(self, x):
        x = x.to('cuda:0')
        x = self.relu(self.fc1(x))
        x = x.to('cuda:1')
        return self.fc2(x)


model = ModelParallelMLP()
input_data = torch.randn(20, 10).to('cuda:0')
output = model(input_data)
print(output.size())

5. Common Practices

5.1 Data Loading

When using DistributedDataParallel, it is important to use a DistributedSampler so that each process gets a disjoint subset of the data; otherwise every process would train on the full dataset and duplicate work.

from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader, Dataset

class MyDataset(Dataset):
    def __len__(self):
        return 100

    def __getitem__(self, idx):
        return torch.randn(10), torch.randn(1)


# Assume rank and world_size are already defined
dataset = MyDataset()
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
dataloader = DataLoader(dataset, batch_size=10, sampler=sampler)
# Call sampler.set_epoch(epoch) at the start of each epoch so shuffling
# differs between epochs

5.2 Checkpointing

Checkpointing is crucial in distributed training. You can save the model, optimizer, and other training states periodically to resume training in case of failures.

def save_checkpoint(model, optimizer, epoch, path):
    # With DDP, save from rank 0 only to avoid concurrent writes to the same file
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, path)


def load_checkpoint(model, optimizer, path, map_location=None):
    # map_location lets each rank load the checkpoint onto its own device
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch']

6. Best Practices

6.1 Use DistributedDataParallel over DataParallel

DistributedDataParallel is generally more efficient than DataParallel: it runs one process per GPU (avoiding Python's GIL and the per-batch scatter/gather overhead of a single process) and overlaps gradient all-reduce with the backward pass.

6.2 Optimize Collective Communication

Choose the appropriate collective communication backend (nccl for NVIDIA GPUs, gloo for CPU-only training) and optimize communication patterns to reduce overhead.
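One common pattern optimization is gradient bucketing: packing many small gradient tensors into a few flat buffers so that one all-reduce replaces many small ones (DDP does this internally, tunable via its `bucket_cap_mb` argument). The sketch below, with illustrative names and plain Python lists standing in for tensors, shows the packing idea.

```python
# Hedged sketch of gradient bucketing: pack per-parameter gradient lists into
# flat buckets of at most `bucket_size` elements, so each bucket needs only
# one communication call instead of one call per parameter.

def flatten_buckets(grads, bucket_size):
    """Pack gradient lists into flat buckets of at most `bucket_size` elements."""
    buckets, current = [], []
    for g in grads:
        # Start a new bucket if this gradient would overflow the current one
        if current and len(current) + len(g) > bucket_size:
            buckets.append(current)
            current = []
        current.extend(g)
    if current:
        buckets.append(current)
    return buckets

grads = [[0.1, 0.2], [0.3], [0.4, 0.5, 0.6], [0.7]]  # one list per parameter
print(flatten_buckets(grads, bucket_size=3))
# -> [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7]]
```

Fewer, larger messages amortize per-call latency, which is why bucketing helps most when a model has many small parameter tensors.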

6.3 Monitoring and Debugging

Use tools like TensorBoard to monitor the training process. You can log metrics such as loss, accuracy, and gradient norms.

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()

# During training
writer.add_scalar('Loss/train', loss.item(), epoch)

7. Conclusion

Distributed training in PyTorch is a powerful technique that allows you to train large-scale models efficiently. By understanding the fundamental concepts of data parallelism and model parallelism, and following the common and best practices, you can effectively utilize multiple GPUs or machines to speed up training. Whether you are working on a small-scale project or a large-scale research endeavor, PyTorch's distributed training capabilities can significantly enhance your productivity.

8. References

  1. PyTorch official documentation: https://pytorch.org/docs/stable/index.html
  2. Distributed training in PyTorch tutorial: https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
  3. PyTorch Distributed: Experiences on Accelerating Data Parallel Training: https://arxiv.org/abs/2006.15704