In this post, we’ll explore the crucial role of distributed training in scaling Deep Neural Networks (DNNs) to handle large datasets and complex models. We’ll take an in-depth look at data parallel training, the most widely used technique in this domain, and dive into its implementation to provide an intuitive understanding of how it enhances efficiency in deep learning.

Why Do We Need Distributed Training?

There are several advantages to using distributed training, but in my opinion, these two are the most important ones that we will focus on:

Training or Fine-Tuning Large Models

In some cases, the model itself is too large to fit on a single GPU. Consider Llama 3.1 405B, a model with an astounding 405 billion parameters: just storing its weights in the original BF16 precision requires approximately 810GB of memory (ref: Hugging Face blog on Llama 3.1’s memory requirements). Even the most advanced GPUs today don’t have the capacity to handle such a model alone. In these scenarios, we must distribute the model across multiple GPUs to make training possible. This technique is referred to as model parallelism.
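
As a quick sanity check of the 810GB figure, the sketch below multiplies the parameter count by 2 bytes per BF16 value. It then divides the result by the capacity of a hypothetical 80GB GPU. It counts the weights only. Gradients, optimizer states, and activations needed for training would raise the requirement considerably.

```python
import math

def weight_memory_gb(num_params: float, bytes_per_param: int = 2) -> float:
    """Memory needed just to hold the weights, in GB (1 GB = 1e9 bytes). BF16 = 2 bytes."""
    return num_params * bytes_per_param / 1e9

def min_gpus_for_weights(num_params: float, gpu_memory_gb: float, bytes_per_param: int = 2) -> int:
    """Lower bound on GPUs needed if the weights alone were split perfectly evenly."""
    return math.ceil(weight_memory_gb(num_params, bytes_per_param) / gpu_memory_gb)

llama_405b = 405e9
print(weight_memory_gb(llama_405b))            # 810.0
print(min_gpus_for_weights(llama_405b, 80.0))  # 11 (hypothetical 80 GB GPUs, weights only)
```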

Growth in model size over time, illustrating the rapid increase in the number of parameters in natural language models from 2018 to 2022 (ref: Hugging Face blog - Large Language Models: A New Moore’s Law?)

A historical example of the need for model parallelism is the groundbreaking AlexNet model, which stunned the world with its performance on the ImageNet dataset in 2012. Due to the limited GPU resources available at the time, the creators had to split the model into two parts to fit it onto their GPUs. This approach laid the foundation for the distributed training techniques we rely on today.

Network architecture of AlexNet (ref: AlexNet and ImageNet: The Birth of Deep Learning)

Leveraging Data Parallelism with GPUs

GPUs have revolutionized deep learning by enabling efficient processing through data parallelism. This means they can handle multiple data points simultaneously, making them incredibly effective for tensor operations. During model training, we typically feed batches of data into the network to take full advantage of this parallelism, which keeps the GPU’s cores busy and increases throughput.

When training on enormous datasets, we would like to use large batches, but simply increasing the batch size isn’t always feasible. The activations for every sample in the batch must fit in GPU memory alongside the model, so beyond a certain batch size, training fails with a CUDA out-of-memory error. But what if we could spread this workload across multiple GPUs? By distributing the data across different GPUs, we can maintain large batch sizes without exceeding the memory limits of individual devices. This approach, known as data parallelism, effectively increases computational power and efficiency.

(ref: Model Parallelism)

Data Parallel Training

Data parallelism is more common and easier to implement compared to model parallelism. Most popular deep learning frameworks, like PyTorch and TensorFlow, offer built-in support for data parallelism, allowing you to make your training loop data parallel with just a few additional lines of code.

So, what happens during the typical training of a neural network? The process starts with data being fed into the network, where each layer computes and caches its output. During the back-propagation phase, the gradient of the loss with respect to the model weights is calculated. Finally, these gradients are used to update the model weights.

Forward and backward path in DNNs.
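
To make these three phases concrete, here is a minimal pure-Python sketch for a toy linear model $\hat{y} = w x + b$ trained with mean squared error. The model, data, and learning rate are illustrative assumptions, not part of any framework.

```python
def train_step(w, b, xs, ys, lr=0.1):
    """One iteration: forward pass, backward pass, and an SGD weight update."""
    n = len(xs)
    # Forward pass: compute predictions and cache the residuals for the backward pass.
    preds = [w * x + b for x in xs]
    residuals = [p - y for p, y in zip(preds, ys)]
    loss = sum(r * r for r in residuals) / n
    # Backward pass: gradients of the mean squared error w.r.t. w and b,
    # reusing the cached residuals instead of recomputing the forward pass.
    grad_w = sum(2 * r * x for r, x in zip(residuals, xs)) / n
    grad_b = sum(2 * r for r in residuals) / n
    # Update: plain SGD step.
    return w - lr * grad_w, b - lr * grad_b, loss

xs, ys = [0.0, 1.0, 2.0, 3.0], [1.0, 3.0, 5.0, 7.0]  # toy data from y = 2x + 1
w, b = 0.0, 0.0
for _ in range(300):
    w, b, loss = train_step(w, b, xs, ys)
print(w, b)  # approaches 2.0 and 1.0
```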

In this setup, the upper limit of the batch size is determined by the GPU memory and the model size.
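
One rough way to see this limit is to start from the GPU's memory and subtract what the model-related tensors take: weights, gradients, and optimizer state. Dividing what remains by the activation memory a single sample needs gives a ceiling on the batch size. The sketch below does exactly that. Every number in it is a hypothetical placeholder, and real frameworks also need memory for temporary buffers, so treat the result as an upper bound.

```python
def max_batch_size(gpu_mem_gb, num_params, bytes_per_param_total, act_gb_per_sample):
    """Upper bound on batch size from a simple memory budget (all inputs hypothetical)."""
    # Weights + gradients + optimizer state, in GB.
    model_gb = num_params * bytes_per_param_total / 1e9
    free_gb = gpu_mem_gb - model_gb
    if free_gb <= 0:
        return 0  # the model alone does not fit: model parallelism territory
    return int(free_gb // act_gb_per_sample)

# Hypothetical: a 1B-parameter model trained in FP32 with Adam
# (4 B weights + 4 B gradients + 8 B for Adam's two moments = 16 B/param)
# on a 24 GB GPU, with ~0.5 GB of activations per sample.
print(max_batch_size(24, 1e9, 16, 0.5))  # 16
```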

Increasing the batch size offers two key advantages. First, it can reduce training time, since each epoch needs fewer optimizer steps and each step makes better use of the GPU’s parallelism. Second, it produces less noisy gradient estimates and smoother convergence, which might help the model reach an optimal point faster. However, batch size is limited by GPU memory. When working with a single GPU, how can we increase the batch size without triggering an out-of-memory error? One effective technique is gradient accumulation.
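
The “less noisy” part can be illustrated with a small simulation. Assume, purely for illustration, that per-sample gradients are independent draws from a normal distribution with standard deviation 2. Under that assumption, the spread of the averaged mini-batch gradient shrinks roughly like $2/\sqrt{B}$ for batch size $B$:

```python
import random
import statistics

def batch_gradient_std(batch_size, trials=1000, mean=1.0, std=2.0, seed=0):
    """Std. dev. of the batch-averaged gradient across many simulated batches.

    Per-sample gradients are modeled (hypothetically) as i.i.d. normal draws.
    """
    rng = random.Random(seed)
    batch_means = [
        statistics.fmean(rng.gauss(mean, std) for _ in range(batch_size))
        for _ in range(trials)
    ]
    return statistics.stdev(batch_means)

stds = {b: batch_gradient_std(b) for b in (1, 16, 256)}
for b, s in stds.items():
    print(b, round(s, 3))  # shrinks roughly like 2 / sqrt(b)
```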

Gradient Accumulation

Gradient accumulation is a technique that allows us to effectively increase the batch size without running into GPU memory limitations. Here’s how it works: instead of processing a large batch all at once, you break it into smaller mini-batches and feed them into the network. However, instead of updating the model weights after each mini-batch, you accumulate the gradients over several mini-batches. Once all the gradients are collected, they are averaged, and the result is treated as the gradient for the entire batch.

Gradient accumulation technique.
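
Here is a minimal plain-Python sketch of the idea for a toy linear model. The helper names are hypothetical. The sketch assumes the batch splits evenly into mini-batches of equal size, which is one of the assumptions behind its equivalence with large-batch training.

```python
def mse_grads(w, b, xs, ys):
    """Gradients of the mean squared error of y_hat = w * x + b on one mini-batch."""
    n = len(xs)
    residuals = [w * x + b - y for x, y in zip(xs, ys)]
    grad_w = sum(2 * r * x for r, x in zip(residuals, xs)) / n
    grad_b = sum(2 * r for r in residuals) / n
    return grad_w, grad_b

def accumulated_step(w, b, xs, ys, mini_batch_size, lr=0.01):
    """One SGD update using gradients accumulated over equally sized mini-batches."""
    acc_w, acc_b, num_mini_batches = 0.0, 0.0, 0
    for start in range(0, len(xs), mini_batch_size):
        end = start + mini_batch_size
        # Only this mini-batch has to be processed (held in memory) at a time.
        gw, gb = mse_grads(w, b, xs[start:end], ys[start:end])
        acc_w += gw
        acc_b += gb
        num_mini_batches += 1
    # Average the accumulated gradients, then apply a single weight update.
    return w - lr * acc_w / num_mini_batches, b - lr * acc_b / num_mini_batches

xs = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
ys = [1.0, 2.5, 5.5, 6.0, 9.5, 10.0, 13.0, 15.5]  # made-up targets
print(accumulated_step(0.5, 0.0, xs, ys, mini_batch_size=len(xs)))  # one full batch of 8
print(accumulated_step(0.5, 0.0, xs, ys, mini_batch_size=2))        # 4 accumulated mini-batches of 2
```

Calling `accumulated_step` with `mini_batch_size=len(xs)` gives the ordinary full-batch update, so the two printed results can be compared directly. They agree up to floating-point rounding.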

Mathematically, it can be shown that under certain assumptions, the result of gradient accumulation across mini-batches is equivalent to computing the gradient from a single large batch.

For many common loss functions, such as mean squared error or mean cross-entropy, the loss for a batch of $N$ samples is simply the average of the per-sample losses. Suppose the batch is split into $\#\text{Nodes}$ mini-batches of equal size; we use this name because each mini-batch will later run on its own node. The loss for the entire batch can then be written as the average of the mini-batch losses:

$$ \begin{align*} L_{\text{Batch}} &= \frac{1}{N} \sum_{i=1}^{N} \text{loss}(\text{input}_i, \text{target}_i) \\ &= \frac{1}{\#\text{Nodes}} \times \sum_{i=1}^{\#\text{Nodes}} \frac{1}{\text{mini-batch-size}} \sum_{j=1}^{\text{mini-batch-size}} \text{loss}(\text{input}_{ij}, \text{target}_{ij}) \\ &= \frac{1}{\#\text{Nodes}} \sum_{i=1}^{\#\text{Nodes}} L_{\text{mini-batch } i} \end{align*} $$

Because differentiation is linear, the gradient of each mini-batch term can be computed independently, either sequentially on one GPU or in parallel on separate machines. The results are then averaged before updating the weights:

$$ \frac{\partial L_{\text{Batch}}}{\partial W} = \frac{1}{\#\text{Nodes}} \times \sum_{i=1}^{\#\text{Nodes}} \frac{\partial L_{\text{mini-batch } i}}{\partial W} $$

However, there are a few assumptions and considerations to keep in mind:

  1. Batch Size Consistency: All mini-batches must have the same size, so that $N = \# \text{Nodes} \times \text{mini-batch-size}$; otherwise the above equations do not hold.
  2. Sensitivity to Input Batch: Some layers, like batch normalization, are sensitive to the composition and size of the batch. This means that splitting the batch across different devices can lead to slight variations in output, depending on the specific content of each mini-batch. However, in practice, these differences are usually negligible.
  3. Floating Point Precision: In the real world, floating-point addition is commutative but not associative, meaning the result of $A + (B + C)$ is not always exactly the same as $(A + B) + C$. This discrepancy can lead to minor differences in results, so it’s important to account for this when debugging or evaluating your models.
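
The first and third caveats are easy to reproduce with plain Python floats. The per-sample loss values below are made up for illustration.

```python
# 1. Unequal mini-batch sizes: the mean of mini-batch means is not the batch mean.
losses = [1.0, 2.0, 3.0, 10.0]             # hypothetical per-sample losses
batch_mean = sum(losses) / len(losses)     # 16 / 4 = 4.0
mini_batches = [losses[:3], losses[3:]]    # sizes 3 and 1 instead of 2 and 2
mean_of_means = sum(sum(m) / len(m) for m in mini_batches) / len(mini_batches)  # (2 + 10) / 2
print(batch_mean, mean_of_means)  # 4.0 6.0

# 3. Floating-point addition is not associative.
a, b, c = 0.1, 0.2, 0.3
print((a + b) + c == a + (b + c))  # False
print((a + b) + c, a + (b + c))    # 0.6000000000000001 0.6
```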

Distributed Training

With gradient accumulation, we can effectively increase the batch size without running into memory limitations. However, one drawback of this technique is that it doesn’t reduce training time. The mini-batches are processed one after another, so the time per optimizer step grows linearly with the effective batch size. The main advantage of increasing the batch size is leveraging the parallel processing power of GPUs to make training more efficient. As we’ve seen, the gradient of one mini-batch does not depend on the others. That opens the door to distributed training, where the mini-batches are computed on different devices at the same time for even greater efficiency.

Data parallel distributed training.

The general steps of the distributed training algorithm are as follows:

  1. Copy the model to all devices.
  2. Split the data across devices.
  3. Perform forward and backward passes on each device and calculate local gradients for the mini-batch.
  4. Use an AllReduce operation to collect the local gradients from all devices and compute the whole-batch gradient on each device.
  5. Update the model on each device using the whole-batch gradient.
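
The five steps can be simulated in a single process to see why the replicas never drift apart. In this sketch, the “devices” are just list entries and the model is again a toy linear regressor. `all_reduce_mean` is a hypothetical stand-in for a real AllReduce, not any framework’s API.

```python
def local_gradients(params, xs, ys):
    """Forward + backward pass for y_hat = w * x + b with mean squared error on one shard."""
    w, b = params
    n = len(xs)
    residuals = [w * x + b - y for x, y in zip(xs, ys)]
    grad_w = sum(2 * r * x for r, x in zip(residuals, xs)) / n
    grad_b = sum(2 * r for r in residuals) / n
    return grad_w, grad_b

def all_reduce_mean(grads_per_device):
    """Hypothetical AllReduce stand-in: element-wise mean, same result handed to every device."""
    n = len(grads_per_device)
    mean = tuple(sum(g[k] for g in grads_per_device) / n for k in range(len(grads_per_device[0])))
    return [mean] * n

def data_parallel_train(xs, ys, num_devices, steps=1000, lr=0.02):
    # Step 1: copy the same initial model to every device.
    replicas = [(0.0, 0.0)] * num_devices
    # Step 2: split the data into equally sized shards, one per device.
    shard = len(xs) // num_devices
    shards = [(xs[i * shard:(i + 1) * shard], ys[i * shard:(i + 1) * shard]) for i in range(num_devices)]
    for _ in range(steps):
        # Step 3: each device computes local gradients on its own shard.
        local = [local_gradients(p, sx, sy) for p, (sx, sy) in zip(replicas, shards)]
        # Step 4: AllReduce gives every device the whole-batch gradient.
        synced = all_reduce_mean(local)
        # Step 5: every device applies the identical update.
        replicas = [(w - lr * gw, b - lr * gb) for (w, b), (gw, gb) in zip(replicas, synced)]
    return replicas

xs = [float(i) for i in range(8)]
ys = [2.0 * x + 1.0 for x in xs]  # toy data from y = 2x + 1
replicas = data_parallel_train(xs, ys, num_devices=4)
print(all(r == replicas[0] for r in replicas))  # True: replicas stay in sync
print(replicas[0])                              # close to (2.0, 1.0)
```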

Steps 1 and 2 are performed only once at the start of training, while steps 3 to 5 are repeated in every training loop iteration. After step 4, all devices have equal gradients for the whole batch. Therefore, the updated model after step 5 will be the same and consistent across all devices.

The standard AllReduce operation, which sums or averages the gradients, carries out the reduction in the same order for every device, so all devices end up with identical results. Additionally, in PyTorch, you can convert all BatchNorm layers to SyncBatchNorm with the torch.nn.SyncBatchNorm.convert_sync_batchnorm method. This method synchronizes the batch statistics across devices and ensures consistent results across all GPUs.
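
To see why batch normalization needs extra synchronization, the sketch below first computes per-device means for some made-up activations. It then recovers the global mean and variance from per-device (count, sum, sum of squares) triples, which are the kind of quantity an AllReduce can sum. This mirrors the idea behind synchronized BatchNorm; it is not PyTorch’s actual implementation.

```python
def local_stats(values):
    """Per-device contribution: (count, sum, sum of squares)."""
    return len(values), sum(values), sum(v * v for v in values)

def global_mean_var(stats_per_device):
    """Combine per-device triples, as an AllReduce(sum) would, into global mean and variance."""
    count = sum(s[0] for s in stats_per_device)
    total = sum(s[1] for s in stats_per_device)
    total_sq = sum(s[2] for s in stats_per_device)
    mean = total / count
    # Biased (population) variance, which BatchNorm uses for normalization.
    # E[x^2] - E[x]^2 is fine for a sketch but can lose precision on real data.
    var = total_sq / count - mean * mean
    return mean, var

device_batches = [[1.0, 2.0], [3.0, 4.0], [10.0, 12.0]]  # made-up activations, one list per device
local_means = [sum(b) / len(b) for b in device_batches]
print(local_means)  # [1.5, 3.5, 11.0]: each device would normalize with different statistics
print(global_mean_var([local_stats(b) for b in device_batches]))  # statistics of all six values
```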

Model parameters can be synchronized instead of gradients if the optimizer is stateless, like plain SGD. However, a stateful optimizer (e.g., Adam or SGD with momentum) on each device updates its state with its own local gradients. The optimizer states therefore diverge over time, even if the model parameters are synchronized. Additionally, during fine-tuning, we often freeze certain parts of the network, so exchanging gradients typically involves transferring fewer values between devices. Therefore, synchronizing gradients is generally a better approach than synchronizing model parameters, and it is the default in most deep learning frameworks.
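
A one-step toy example shows the difference. It assumes two devices with made-up scalar gradients and uses the standard Adam update with its common default hyperparameters. For plain SGD, averaging the locally updated parameters matches an update with the averaged gradient, because the update is linear in the gradient. Adam’s first step behaves roughly like `lr * sign(grad)`, so the two local updates below cancel out instead.

```python
import math

def sgd_step(param, grad, lr=0.1):
    return param - lr * grad

def adam_first_step(param, grad, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8):
    """First Adam update from zero-initialized moments; returns (new_param, (m, v))."""
    m = (1 - beta1) * grad             # first-moment (momentum) state
    v = (1 - beta2) * grad * grad      # second-moment state
    m_hat = m / (1 - beta1)            # bias correction at step t = 1
    v_hat = v / (1 - beta2)
    return param - lr * m_hat / (math.sqrt(v_hat) + eps), (m, v)

p0 = 0.0
device_grads = [1.0, -3.0]                         # made-up gradients on two devices
mean_grad = sum(device_grads) / len(device_grads)  # -1.0

# SGD: averaging locally updated parameters equals updating with the averaged gradient.
sgd_param_avg = sum(sgd_step(p0, g) for g in device_grads) / len(device_grads)
sgd_grad_avg = sgd_step(p0, mean_grad)

# Adam: the two approaches disagree, and each device keeps different (m, v) state.
adam_local = [adam_first_step(p0, g) for g in device_grads]
adam_param_avg = sum(p for p, _ in adam_local) / len(adam_local)
adam_grad_avg, _ = adam_first_step(p0, mean_grad)

print(sgd_param_avg, sgd_grad_avg)         # both about 0.1
print(adam_param_avg, adam_grad_avg)       # about 0.0 versus about 0.1
print([state for _, state in adam_local])  # per-device optimizer states differ
```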

More about AllReduce

AllReduce is a collective operation commonly used in Single Program Multiple Data (SPMD) applications. Its primary function in distributed training is to collect the individual results (in this case, gradients) from each node, compute the reduction (such as averaging the gradients), and then communicate the result back to all nodes.

There are various implementations of the AllReduce method. The most naive approach involves each node fetching the full buffers of all other nodes and computing the reduction locally. However, this method is neither efficient nor scalable. With $N$ nodes, each node receives $N-1$ full gradient buffers, so the network traffic per node grows linearly with the number of nodes.
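
Below is a minimal single-process sketch of the naive scheme with made-up buffers. It also counts how many elements each node receives. Each node adds the values in a different order, so with general floating-point data the nodes' results could differ in the last bits.

```python
def naive_all_reduce(buffers):
    """Sum-AllReduce where every node pulls every other node's full buffer."""
    n = len(buffers)
    results, received = [], []
    for i in range(n):
        # Node i fetches the full buffer from each of the other n - 1 nodes.
        incoming = [buffers[j] for j in range(n) if j != i]
        received.append(sum(len(buf) for buf in incoming))
        # Local reduction over its own buffer and everything it fetched.
        results.append([sum(vals) for vals in zip(buffers[i], *incoming)])
    return results, received

buffers = [[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0], [100.0, 200.0, 300.0, 400.0]]
results, received = naive_all_reduce(buffers)
print(results[0])  # [111.0, 222.0, 333.0, 444.0]
print(received)    # [8, 8, 8]: (n - 1) * buffer_size elements per node
```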

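A more efficient implementation is Ring AllReduce. The $N$ devices (where $N$ is the number of devices) are arranged in a logical ring. Each device only sends data to the next device and receives data from the previous one. Each device also splits its gradient buffer into $N$ equal chunks. In the first phase (reduce-scatter), at every time step each device sends one chunk to the next device, which adds it to its own copy of that chunk. After $N-1$ steps, each device holds one chunk that has been fully reduced over all devices. In the second phase (all-gather), the fully reduced chunks are passed around the ring for another $N-1$ steps, until every device has the complete result. This method is much more bandwidth-efficient, and ring-based AllReduce is widely used in practice. For example, NCCL, the library PyTorch uses for GPU collectives, implements it. For a more detailed explanation of Ring AllReduce, you can check out this blog post. Below is an intuitive visualization of how this algorithm works: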

Ring AllReduce algorithm.
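
The sketch below simulates the standard two-phase formulation of Ring AllReduce (sum) in a single process. A reduce-scatter phase is followed by an all-gather phase. It assumes every buffer splits evenly into $N$ chunks and that all sends within a step happen simultaneously. It ignores real networking, pipelining, and overlap with computation. Because every chunk is reduced exactly once and then copied, all devices end up with identical values.

```python
def ring_all_reduce(buffers):
    """Sum-AllReduce over simulated devices arranged in a ring; returns (results, steps)."""
    n = len(buffers)
    size = len(buffers[0])
    assert size % n == 0, "this sketch assumes each buffer splits evenly into n chunks"
    k = size // n
    # chunks[i][c] is device i's working copy of chunk c.
    chunks = [[list(buf[c * k:(c + 1) * k]) for c in range(n)] for buf in buffers]
    steps = 0

    # Phase 1, reduce-scatter: at step s, device i sends chunk (i - s) % n to its
    # successor, which adds it element-wise to its own copy of that chunk.
    for s in range(n - 1):
        messages = [(i, (i - s) % n, list(chunks[i][(i - s) % n])) for i in range(n)]
        for i, c, payload in messages:
            dst = (i + 1) % n
            chunks[dst][c] = [a + b for a, b in zip(chunks[dst][c], payload)]
        steps += 1
    # Device i now holds the fully reduced chunk (i + 1) % n.

    # Phase 2, all-gather: at step s, device i forwards chunk (i + 1 - s) % n,
    # and its successor overwrites its stale copy with the reduced values.
    for s in range(n - 1):
        messages = [(i, (i + 1 - s) % n, list(chunks[i][(i + 1 - s) % n])) for i in range(n)]
        for i, c, payload in messages:
            chunks[(i + 1) % n][c] = payload
        steps += 1

    return [[x for chunk in dev for x in chunk] for dev in chunks], steps

buffers = [
    [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
    [10.0, 20.0, 30.0, 40.0, 50.0, 60.0],
    [100.0, 200.0, 300.0, 400.0, 500.0, 600.0],
]
results, steps = ring_all_reduce(buffers)
print(results[0])  # [111.0, 222.0, 333.0, 444.0, 555.0, 666.0], same on every device
print(steps)       # 4 == 2 * (n - 1)
```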

In the Ring AllReduce method, the data transfer takes $2(N-1)$ time steps: $N-1$ to reduce the chunks and $N-1$ to pass the reduced chunks around the ring. Each device sends $\frac{2(N-1)}{N}$ times the size of the gradient buffer. For the same model, this stays nearly constant (below twice the buffer size) regardless of the number of devices. The graph below shows measured throughput, in samples processed per second, for a 300-million-parameter language model trained synchronously with Ring AllReduce. Performance scales linearly with the number of GPUs:
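
The cost formulas can be tabulated quickly. This sketch counts per-device traffic in units of one full gradient buffer. It compares the naive scheme, which receives $N-1$ full buffers, with the ring scheme at $2(N-1)/N$. It ignores per-step latency, which grows with the ring's $2(N-1)$ steps.

```python
def naive_traffic(n):
    """Full buffers received per device when fetching from every other device."""
    return n - 1

def ring_traffic(n):
    """Buffers sent per device in Ring AllReduce: n - 1 chunks of size 1/n in each of two phases."""
    return 2 * (n - 1) / n

for n in (2, 4, 8, 16, 64):
    # devices, naive traffic, ring traffic, ring steps
    print(n, naive_traffic(n), round(ring_traffic(n), 3), 2 * (n - 1))
```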

(Ref: Bringing HPC Techniques to Deep Learning)

Conclusion

In this post, we explored how distributed training enables us to handle models larger than the memory capacity of a single GPU and achieve faster training times by increasing batch sizes beyond what a single GPU can manage. We examined gradient accumulation as a technique to increase batch size on a single GPU and demonstrated that, under certain assumptions, it is mathematically equivalent to training with larger batches.

We then extended this concept by applying gradient accumulation within a distributed training framework. We detailed how this technique scales across multiple GPUs, allowing us to efficiently handle larger batch sizes and synchronize gradients with methods like Ring AllReduce.

As deep learning models continue to grow and evolve, mastering these distributed training techniques will be essential for pushing the boundaries of what’s possible and achieving state-of-the-art results.