Multi-Node on BCP

This tutorial introduces ways to run multi-node applications on NVIDIA Base Command Platform (BCP).

Objectives

Learn the basics of distributed environments, then run multi-node applications using the following:

  • MPI

  • PyTorch DDP

  • PyTorch Lightning

  • PyTorch + Ray Train

Requirements

Introduction to distributed environments

Multi-node, or distributed, computing is a model of computation that takes the independent tasks of an algorithm and executes them across multiple computers. In the deep learning world, the simplest version of this is the distributed data parallel model, where the dataset is split across the GPUs allocated to the job and a single model is trained collaboratively by all of them.
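This idea can be sketched in a few lines of Python: each rank (one process per GPU) is responsible for a disjoint slice of the dataset. The round-robin assignment below mimics what PyTorch's DistributedSampler does; it's an illustration, not the real implementation.

```python
def shard_indices(dataset_size: int, world_size: int, rank: int) -> list:
    """Return the dataset indices that `rank` is responsible for."""
    return list(range(rank, dataset_size, world_size))

# 8 samples split across 4 ranks: every sample is seen exactly once
shards = [shard_indices(8, 4, r) for r in range(4)]
print(shards)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```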

While DGX systems and DGX Cloud were designed for multi-GPU computing, a single process can’t span multiple computers, so code still needs to be written to coordinate the processing on each system.

To learn some basics about working in distributed environments, let’s start a 2-node job and explore.

# Make a workspace
ngc workspace create --name ${USER}_tutorial

# Launch a 2-node job with the workspace mounted
ngc batch run --name "N2-test" --total-runtime 1h \
        --instance dgxa100.80g.8.norm \
        --commandline "jupyter lab --ip=0.0.0.0 --allow-root \
            --no-browser --NotebookApp.token='' \
            --notebook-dir=/ --NotebookApp.allow_origin='*'" \
        --result /results --array-type "PYTORCH" --replicas "2" \
        --workspace ${USER}_tutorial:/workspace:RW \
        --image "nvidia/pytorch:23.03-py3" --port 8888

In this job, we’ll be using bcprun to launch our distributed processes.

bcprun help text
bcprun (NGC multi-node run utility) version 1.2
usage: bcprun [--nnodes <n>] [--npernode <p>] [--env <e>] [--workdir <w>] [--cmd <command-line>]
             [--async] [--debug] [--version] [--help] [--binding <b>]             
                                                                                                                            
required arguments:                                                                                                         
  -c <c>, --cmd <c>
                    Provide a command to run. (type: string)                                                                
                    Default value: (none)                                                                                   
                    Example: --cmd 'python train.py'                                                                        
                                                                                                                            
optional arguments:                                                                                                         
  -n <n>, --nnodes <n>                                                                                                      
                    Number of nodes to run on. (type: integer)                       
                    Range: min value: 1, max value: R,                                                                      
                    where R is max number of replicas requested by the NGC job.                
                    Default value: R                          
                    Example: --nnodes 2
  -p <p>, --npernode <p>
                    Number of tasks per node to run. (type: integer)
                    Range: min value: 1, max value: (none)
                    Default value: environment variable NGC_NTASKS_PER_NODE, if
                    set, otherwise 1.
                    Example: --npernode 8
  -e <e>, --env <e>
                    Environment variables to set with format 'key=value'. (type: string)
                    Each variable assignment requires a separate -e/--env flag.
                    Default value: (none)
                    Example: --env 'var1=value1' --env 'var2=value2'
  -w <w>, --workdir <w>
                    Base directory from which to run <cmd>. (type: string)
                    May include environment variables defined with --env.
                    Default value: environment variable PWD (current working directory)
                    Example: --workdir '$WORK_HOME/scripts' --env 'WORK_HOME=/mnt/workspace'
  -l <l>, --launcher <l>
                    Run <cmd> using an external launcher program. (type: string)
                    Supported launchers: mpirun, horovodrun
                    - mpirun: maps to OpenMPI options (https://www.open-mpi.org/)
                    - horovodrun: maps to Horovod options (https://horovod.ai/)
                    Note: This option assumes the launcher exists and is in PATH.
                    Launcher specific arguments (not part of bcprun options) can be provided as a suffix. E.g. --launcher 'mpirun --allow-run-as-root'
                    Default value: (none)
  -log <log>, --logdir <log>
                    Directory that stores bcprun.log. Also, in the case of PyTorch applications, it stores the logs per rank in each node. (type: string)
                    Default value: resultset mount path per node (env: NGC_RESULT_DIR)
  -a, --async
                    Run with asynchronous failure support enabled, i.e. a child process of bcprun can exit on failure without halting the program.
                    The program will continue while at least one child is running.
                    The default semantics of bcprun is to halt the program when any child process launched by bcprun exits with error.
   
  -b, --binding
                    Bind process to cpu-cores.

                    The following numa binding options are available.
                    - 'node': Processes are bound to cpus within a NUMA node. On GPU-enabled compute nodes,
                    a process is bound to all the cpus of the affined NUMA node (mapping local rank to GPU id), and the total number of ranks is limited to the total number of GPUs.
                    Example: Given 2 NUMA nodes N{0,1}, each with 4 GPUs and 32 CPUs C{0-31,32-63}, 8 processes P{0-7} will be mapped as: P{0-3}:N0:C{0-31}, P{4-7}:N1:C{32-63}

                    - 'exclusive': Processes are bound to exclusive sets of cpus within a NUMA node.
                    On GPU-enabled compute nodes, a process is bound to an exclusive cpu set within the affined NUMA node (mapping local rank to GPU id), and the total number of ranks is limited to the total number of GPUs.
                    Example: Given 2 NUMA nodes N(0,1), each with 4 GPUs and 32 CPUs C{0-31,32-63},
                    8 processes P{0-7} will be mapped as: P0:N0:C{0-7}, P1:N0:C{8-15}, P2:N0:C{16-23},
                    P3:N0:C{24-31}, P4:N1:C{32-39}, P5:N1:C{40-47}, P6:N1:C{48-55}, P7:N1:C{56-63}

                    - 'core-complex': Processes are bound to a core-complex, i.e. cpus sharing a last-level cache.
                    On GPU-enabled compute nodes, a process is bound to a core-complex of the
                    affined NUMA node (mapping local rank to GPU id), and
                    the total number of ranks is limited to the total number of GPUs.
                    Example: Given 2 NUMA nodes N{0,1}, each with 2 GPUs and 4 core-complexes X{0-3,4-7},
                    4 processes P{0-3} will be mapped as: P0:N0:X0, P1:N0:X1, P2:N1:X4, P3:N1:X5

                    - 'socket': Processes are bound to cpus within a socket. On GPU-enabled compute nodes,
                    a process is bound to the cpus of the socket containing the affined NUMA node (mapping local rank to GPU id), and the total number of ranks is limited to the total number of GPUs.
                    Example: Given 2 Sockets S(0,1), each with 4 GPUs and 64 CPUs C{0-63,64-127},
                    8 processes P{0-7} will be mapped as: P{0-3}:S0:C{0-63}, P{4-7}:S1:C{64-127}

                    Note:
                    --binding option is only applicable when arraytype is PYTORCH.

  -d, --debug
                    Print debug info and enable verbose mode.
  -j, --jsonlogs
                    Print the per-node aggregated logs in JSON format to joblog.log
  -no_redirect, --no_redirect
                    Print the logs to stdout/stderr instead of logging to files
  -v, --version
                    Print version info.
  -h, --help
                    Print this help message.

Note:
1. Local rank is passed to the python script using the flag argument --local-rank
   for PyTorch versions < 1.10. For all PyTorch versions >= 1.10, the --local-rank
   flag argument will NOT be passed to the python script by default.
   If you depend on parsing --local-rank in your script for PyTorch versions >= 1.10,
   you can override the default behavior by setting environment variable NGC_PYTORCH_USE_ENV=0.
   Conversely, setting environment variable NGC_PYTORCH_USE_ENV=1 for PyTorch versions < 1.10
   will suppress passing the --local-rank flag argument.
2. Environment variable LOCAL_RANK is always set regardless of PyTorch version.
   Reading LOCAL_RANK from the environment is the recommended method.
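Following note 2 above, a script launched by bcprun can pick up its local rank like this (the fallback to 0 is our addition so the snippet also runs as a single stand-alone process):

```python
import os

# LOCAL_RANK is set by bcprun regardless of PyTorch version; default to 0
# so the same script can also be run stand-alone for debugging.
local_rank = int(os.environ.get("LOCAL_RANK", 0))
print(f"local rank: {local_rank}")
```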

Once your job is running, connect to the Jupyter portal or use ngc batch exec to enter the job, then work through the following:

# Print out a message
#mpirun --allow-run-as-root echo hello
bcprun -no_redirect -c "echo hello"

# Print out the hostname of each process
#mpirun --allow-run-as-root hostname
bcprun -no_redirect -c "hostname"

# Run two processes per node
#mpirun -npernode 2 --allow-run-as-root hostname
bcprun -no_redirect -p 2 -c "hostname"

# Install a Python package on this node only, then try loading it on
# every node (this fails on the node where the package isn't installed)
pip install lightning
bcprun -no_redirect -c "python test.py"

# Install the package on all nodes, then try loading it again
bcprun -no_redirect -c "pip install lightning"
bcprun -no_redirect -c "python test.py"
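The contents of test.py aren't shown above; a minimal version (our assumption, not a file shipped with the container) would simply check that the package is importable on each node, making it obvious which replicas still lack it after a local-only pip install:

```python
#!/usr/bin/env python
# test.py (hypothetical contents): report whether a package is importable
# on this node, so `bcprun -c "python test.py"` shows which replicas
# still lack it after a local-only `pip install`.
import importlib.util
import socket

def is_installed(pkg: str) -> bool:
    """True when `pkg` can be imported on this node."""
    return importlib.util.find_spec(pkg) is not None

host = socket.gethostname()
status = "available" if is_installed("lightning") else "MISSING"
print(f"{host}: lightning is {status}")
```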

Traditionally, multi-node applications used Message Passing Interface (MPI) implementations such as MVAPICH for collective operations like broadcast, reduce, and allreduce. Many high-performance computing applications still use MPI, so this tutorial starts with how to run applications with the MPI launcher, mpirun.

MPI

MPI requires environment variables to be populated on each node. Schedulers like SLURM automatically populate these variables so MPI knows how many processes and threads to spawn, and how to communicate with other processes. On DGX Cloud with BCP, those variables get set by specifying the --array-type "MPI" argument when spawning a job.
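As a quick sanity check, OpenMPI exports rank information to every process it launches through OMPI_COMM_WORLD_* environment variables. A process can read them like this (the defaults are ours, so the snippet also runs outside mpirun as a single rank):

```python
import os

# OpenMPI sets these for each launched process; outside mpirun the
# defaults make this behave as a single "rank 0 of 1" process.
rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0))
world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", 1))
local_rank = int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", 0))
print(f"rank {rank}/{world_size} (local rank {local_rank})")
```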

For example, the NCCL all-reduce benchmark, all_reduce_perf_mpi, can be run across all the nodes in the job with:

mpirun --allow-run-as-root -x IBV_DRIVERS=/usr/lib/libibverbs/libmlx5 \
    -np ${NGC_ARRAY_SIZE} -npernode 1 \
    bash -c "all_reduce_perf_mpi -b 64M -e 4G -f 2 -c 0 -n 100 -g ${NGC_GPUS_PER_NODE}"

This can also be run with bcprun as follows:

NGC_ARRAY_TYPE=MPIJob bcprun -no_redirect \
    --launcher 'mpirun --allow-run-as-root' \
    -c "all_reduce_perf_mpi -b 64M -e 4G -f 2 -c 0 -n 100 -g ${NGC_GPUS_PER_NODE}"

You’ll notice that both the mpirun and bcprun commands use two environment variables that make these scripts work for jobs of any size.

NGC_ARRAY_SIZE    - Number of nodes allocated to job
NGC_GPUS_PER_NODE - Number of GPUs allocated per node

This was a 2-node job with 8 GPUs per node, and we can double-check these values with

$ env | grep -E "(NODE|SIZE)="
NGC_ARRAY_SIZE=2
NGC_GPUS_PER_NODE=8
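Together, these two variables are enough for a script to compute the total number of ranks (one per GPU) for any job size; the defaults below are ours and match this 2-node job when the variables aren't set:

```python
import os

# World size = nodes * GPUs per node; defaults mirror the 2x8 job above.
nnodes = int(os.environ.get("NGC_ARRAY_SIZE", 2))
gpus_per_node = int(os.environ.get("NGC_GPUS_PER_NODE", 8))
world_size = nnodes * gpus_per_node
print(f"{nnodes} nodes x {gpus_per_node} GPUs = {world_size} ranks")
```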

PyTorch DDP

PyTorch DistributedDataParallel (DDP) implements data parallelism at the module level, allowing training to take place across multiple devices or compute nodes. DDP uses collective communications in the torch.distributed package to synchronize gradients and buffers.
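The key collective here is an allreduce: after each backward pass, every rank contributes its local gradients and receives the average, so all model replicas apply identical updates. A pure-Python sketch of those semantics (not how NCCL actually implements it):

```python
def allreduce_mean(grads_per_rank):
    """Element-wise average of per-rank gradient vectors.

    Every rank "sends" its local gradient and "receives" the same mean,
    which is what keeps DDP model replicas in sync.
    """
    world_size = len(grads_per_rank)
    summed = [sum(vals) for vals in zip(*grads_per_rank)]
    return [s / world_size for s in summed]

# 4 ranks, each holding a 3-element local gradient
local_grads = [[1.0, 2.0, 3.0], [3.0, 2.0, 1.0], [2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]
print(allreduce_mean(local_grads))  # [2.0, 2.0, 2.0]
```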

To learn how to run PyTorch DDP training scripts and understand how scale affects the data processed by each rank, we’re going to use the following code.

#!/usr/bin/env python

# Adapted from https://github.com/pytorch/examples/blob/main/distributed/ddp-tutorial-series/multinode.py

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os

def ddp_setup():
    init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Added this class because datautils no longer exists
class MyTrainDataset(Dataset):
    def __init__(self, size):
        self.size = size
        self.data = [(torch.rand(20), torch.rand(1)) for _ in range(size)]
    def __len__(self):
        return self.size
    def __getitem__(self, index):
        return self.data[index]

class Trainer:
    def __init__(
        self,
        model: torch.nn.Module,
        train_data: DataLoader,
        optimizer: torch.optim.Optimizer,
        snapshot_path: str,
    ) -> None:
        self.local_rank = int(os.environ["LOCAL_RANK"])
        self.global_rank = int(os.environ["RANK"])
        self.model = model.to(self.local_rank)
        self.train_data = train_data
        self.optimizer = optimizer
        self.epochs_run = 0
        self.snapshot_path = snapshot_path
        # Resume from an existing snapshot if one was saved previously
        if os.path.exists(snapshot_path):
            self._load_snapshot(snapshot_path)
        self.model = DDP(self.model, device_ids=[self.local_rank])

    def _load_snapshot(self, snapshot_path):
        loc = f"cuda:{self.local_rank}"
        snapshot = torch.load(snapshot_path, map_location=loc)
        self.model.load_state_dict(snapshot["MODEL_STATE"])
        self.epochs_run = snapshot["EPOCHS_RUN"]
        print(f"Resuming training from snapshot at Epoch {self.epochs_run}")

    def _run_batch(self, source, targets):
        self.optimizer.zero_grad()
        output = self.model(source)
        loss = F.cross_entropy(output, targets)
        loss.backward()
        self.optimizer.step()

    def _run_epoch(self, epoch):
        b_sz = len(next(iter(self.train_data))[0])
        print(f"[GPU{self.global_rank}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}")
        self.train_data.sampler.set_epoch(epoch)
        for source, targets in self.train_data:
            source = source.to(self.local_rank)
            targets = targets.to(self.local_rank)
            self._run_batch(source, targets)

    def train(self, max_epochs: int):
        for epoch in range(self.epochs_run, max_epochs):
            self._run_epoch(epoch)

def load_train_objs():
    train_set = MyTrainDataset(2048*2048)  # load your dataset
    # Make layers for model
    layers = [torch.nn.Linear(20, 512),torch.nn.Linear(512,2048)]
    layers += [torch.nn.Linear(2048,2048) for i in range(13)]
    layers.append(torch.nn.Linear(2048, 1))
    model = torch.nn.Sequential(*layers)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    return train_set, model, optimizer

def prepare_dataloader(dataset: Dataset, batch_size: int):
    return DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,
        # Scales steps with number of workers
        sampler=DistributedSampler(dataset)
    )

def main(total_epochs: int, batch_size: int, snapshot_path: str = "snapshot.pt"):
    ddp_setup()
    dataset, model, optimizer = load_train_objs()
    train_data = prepare_dataloader(dataset, batch_size)
    trainer = Trainer(model, train_data, optimizer, snapshot_path)
    trainer.train(total_epochs)
    destroy_process_group()

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description='simple distributed training job')
    parser.add_argument('--epochs', default=5, type=int, help='Total epochs to train the model')
    parser.add_argument('--batch_size', default=512, type=int, help='Input batch size on each device (default: 512)')
    args = parser.parse_args()
    
    main(args.epochs, args.batch_size)

The data is random, so this is just a “noise in, noise out” model that is large enough to not instantly finish.
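Because DistributedSampler shards the dataset across ranks, the number of steps per epoch shrinks as the job grows. Assuming the 2048*2048-sample dataset and batch size 512 from the script above, you can predict what each rank should report:

```python
import math

def steps_per_epoch(dataset_size: int, world_size: int, batch_size: int) -> int:
    """Steps each rank runs per epoch when the dataset is sharded evenly."""
    samples_per_rank = math.ceil(dataset_size / world_size)
    return math.ceil(samples_per_rank / batch_size)

dataset_size = 2048 * 2048  # samples in MyTrainDataset
print(steps_per_epoch(dataset_size, 16, 512))  # 2 nodes x 8 GPUs -> 512
print(steps_per_epoch(dataset_size, 4, 512))   # 4 GPUs -> 2048
```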

Optional Exercises

  • What happens if you run on 4 GPUs?

  • Try increasing the batch size

PyTorch Lightning

PyTorch Lightning is a high-level framework for building and training PyTorch models with simplicity and scalability.

We can adapt the PyTorch DDP example for PyTorch Lightning as follows:

#!/usr/bin/env python

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import lightning as L

import os, argparse

# Added this class because datautils no longer exists
class MyTrainDataset(Dataset):
    def __init__(self, size):
        self.size = size
        self.data = [(torch.rand(20), torch.rand(1)) for _ in range(size)]
    def __len__(self):
        return self.size
    def __getitem__(self, index):
        return self.data[index]

class LanguageModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        layers = [torch.nn.Linear(20, 512),torch.nn.Linear(512,2048)]
        layers += [torch.nn.Linear(2048,2048) for i in range(13)]
        layers.append(torch.nn.Linear(2048, 1))
        self.model = torch.nn.Sequential(*layers)

    def training_step(self, batch, batch_idx):
        input, target = batch
        output = self.model(input)
        loss = F.cross_entropy(output, target)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.model.parameters(), lr=1e-3)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='simple distributed training job')
    parser.add_argument('--epochs', default=5, type=int, help='Total epochs to train the model')
    parser.add_argument('--batch_size', default=512, type=int, help='Input batch size on each device (default: 512)')
    parser.add_argument('-N', default=1, type=int, help='Number of nodes')
    parser.add_argument('-p', default=1, type=int, help='Number of gpus per node')

    args = parser.parse_args()
    
    train_dataloader = DataLoader(MyTrainDataset(2048*2048), batch_size=args.batch_size)
    model = LanguageModel()

    # Trainer
    trainer = L.Trainer(max_epochs=args.epochs, devices=args.p, accelerator="gpu", num_nodes=args.N)
    trainer.fit(model, train_dataloader)

You’ve probably noticed that the PyTorch Lightning version of this script is much shorter. PyTorch Lightning was designed to handle details like logging, distributed communication, and data loading in the background, so it’s easier to focus on designing models, utilizing accelerators, and scaling.

Similarly to DDP, this script can then be launched with the following:

#!/bin/bash

# Set address for main process
export NGC_MASTER_ADDR=launcher-svc-${NGC_JOB_ID}

# Launch PTL script on all nodes and GPUs in job
bcprun -no_redirect -n ${NGC_ARRAY_SIZE} -p ${NGC_GPUS_PER_NODE} -c "python ptl_ddp_example.py -N ${NGC_ARRAY_SIZE} -p ${NGC_GPUS_PER_NODE}"

Note that the address of the main process, which is the hostname of the launcher node, needs to be set before launching.

When this is run on 2 nodes with 8 GPUs each, you’ll see output like the following:

$ bcprun -no_redirect -n 2 -p 8 -c "python ptl_ddp_example.py -N 2 -p 8"

*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/16
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/16
Initializing distributed: GLOBAL_RANK: 7, MEMBER: 8/16            
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Initializing distributed: GLOBAL_RANK: 5, MEMBER: 6/16
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/16
Initializing distributed: GLOBAL_RANK: 3, MEMBER: 4/16
Initializing distributed: GLOBAL_RANK: 10, MEMBER: 11/16
Initializing distributed: GLOBAL_RANK: 8, MEMBER: 9/16
Initializing distributed: GLOBAL_RANK: 14, MEMBER: 15/16
Initializing distributed: GLOBAL_RANK: 6, MEMBER: 7/16
Initializing distributed: GLOBAL_RANK: 4, MEMBER: 5/16
Initializing distributed: GLOBAL_RANK: 15, MEMBER: 16/16
Initializing distributed: GLOBAL_RANK: 9, MEMBER: 10/16
Initializing distributed: GLOBAL_RANK: 12, MEMBER: 13/16
Initializing distributed: GLOBAL_RANK: 13, MEMBER: 14/16
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 16 processes
----------------------------------------------------------------------------------------------------

Initializing distributed: GLOBAL_RANK: 11, MEMBER: 12/16
LOCAL_RANK: 5 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
LOCAL_RANK: 6 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
LOCAL_RANK: 3 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
LOCAL_RANK: 4 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
LOCAL_RANK: 7 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
LOCAL_RANK: 2 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
LOCAL_RANK: 4 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
LOCAL_RANK: 3 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
LOCAL_RANK: 6 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
LOCAL_RANK: 5 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
LOCAL_RANK: 2 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
LOCAL_RANK: 7 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 55.6 M
-------------------------------------
55.6 M    Trainable params
0         Non-trainable params
55.6 M    Total params
222.464   Total estimated model params size (MB)
/usr/local/lib/python3.8/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument to `num_workers=30` in the `DataLoader` to improve performance.
Epoch 4: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 512/512 [00:03<00:00, 129.19it/s, v_num=2, train_loss=-]
`Trainer.fit` stopped: `max_epochs=5` reached.
Epoch 4: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 512/512 [00:04<00:00, 119.64it/s, v_num=2, train_loss=-]
Cleaning up
Cleaning up

Optional Exercises

  • What happens if you run on 4 GPUs?

  • Try decreasing the batch size

PyTorch + Ray Train

Ray is an open source framework for building and scaling ML and Python applications. At its core, it provides collective operations that run on Ray “clusters”, or collections of worker processes.

To use Ray across multiple nodes, a Ray cluster first needs to be started across the nodes. We recommend using the following helper script to start these processes in a BCP environment:

#!/bin/bash

if [ "$NGC_ARRAY_SIZE" -gt "1" ]; then
  export NGC_MASTER_ADDR=launcher-svc-${NGC_JOB_ID}
fi

export PORT=6379

if [ "$NGC_REPLICA_ID" -eq "0" ]; then
    # Head node: start Ray and listen on the chosen port
    ray start --head --node-ip-address=${NGC_MASTER_ADDR} --port=${PORT}
else
    # Workers: give the head a moment to start, then join it.
    # Note: `ray start --address` takes host:port; --port is head-only.
    sleep 10
    ray start --address=${NGC_MASTER_ADDR}:${PORT}
fi

and the bcprun command to run the script on each node:

bcprun -no_redirect -c 'bash start_ray_cluster.sh'

This starts the head process on replica 0 first; the other replicas sleep for 10 seconds and then join the cluster as workers. Once the cluster is started, you can submit jobs for execution on it simply by using the Ray library.

To illustrate this, the Train a PyTorch Model on Fashion MNIST example was modified to accept an argument for the number of workers:

#!/usr/bin/env python

# https://docs.ray.io/en/latest/train/examples/pytorch/torch_fashion_mnist_example.html

import os, argparse
from typing import Dict

import torch
from filelock import FileLock
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import Normalize, ToTensor
from tqdm import tqdm

import ray.train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer


def get_dataloaders(batch_size):
    # Transform to normalize the input images
    transform = transforms.Compose([ToTensor(), Normalize((0.5,), (0.5,))])

    with FileLock(os.path.expanduser("~/data.lock")):
        # Download training data from open datasets
        training_data = datasets.FashionMNIST(
            root="~/data",
            train=True,
            download=True,
            transform=transform,
        )

        # Download test data from open datasets
        test_data = datasets.FashionMNIST(
            root="~/data",
            train=False,
            download=True,
            transform=transform,
        )

    # Create data loaders
    train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=batch_size)

    return train_dataloader, test_dataloader


# Model Definition
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(512, 10),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


def train_func_per_worker(config: Dict):
    lr = config["lr"]
    epochs = config["epochs"]
    batch_size = config["batch_size_per_worker"]

    # Get dataloaders inside the worker training function
    train_dataloader, test_dataloader = get_dataloaders(batch_size=batch_size)

    # [1] Prepare Dataloader for distributed training
    # Shard the datasets among workers and move batches to the correct device
    # =======================================================================
    train_dataloader = ray.train.torch.prepare_data_loader(train_dataloader)
    test_dataloader = ray.train.torch.prepare_data_loader(test_dataloader)

    model = NeuralNetwork()

    # [2] Prepare and wrap your model with DistributedDataParallel
    # Move the model to the correct GPU/CPU device
    # ============================================================
    model = ray.train.torch.prepare_model(model)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    # Model training loop
    for epoch in range(epochs):
        if ray.train.get_context().get_world_size() > 1:
            # Required for the distributed sampler to shuffle properly across epochs.
            train_dataloader.sampler.set_epoch(epoch)

        model.train()
        for X, y in tqdm(train_dataloader, desc=f"Train Epoch {epoch}"):
            pred = model(X)
            loss = loss_fn(pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        model.eval()
        test_loss, num_correct, num_total = 0, 0, 0
        with torch.no_grad():
            for X, y in tqdm(test_dataloader, desc=f"Test Epoch {epoch}"):
                pred = model(X)
                loss = loss_fn(pred, y)

                test_loss += loss.item()
                num_total += y.shape[0]
                num_correct += (pred.argmax(1) == y).sum().item()

        test_loss /= len(test_dataloader)
        accuracy = num_correct / num_total

        # [3] Report metrics to Ray Train
        # ===============================
        ray.train.report(metrics={"loss": test_loss, "accuracy": accuracy})


def train_fashion_mnist(num_workers=2, use_gpu=False):
    global_batch_size = 32

    train_config = {
        "lr": 1e-3,
        "epochs": 10,
        "batch_size_per_worker": global_batch_size // num_workers,
    }

    # Configure computation resources
    scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)

    # Initialize a Ray TorchTrainer
    trainer = TorchTrainer(
        train_loop_per_worker=train_func_per_worker,
        train_loop_config=train_config,
        scaling_config=scaling_config,
    )

    # [4] Start distributed training
    # Run `train_func_per_worker` on all workers
    # =============================================
    result = trainer.fit()
    print(f"Training result: {result}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="PyTorch MNIST Example",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=4,
        help="Number of Ray workers to use for training.",
    )
    args = parser.parse_args()
    train_fashion_mnist(num_workers=args.num_workers, use_gpu=True)

The script can then be run across 2 nodes (16 GPUs) with the following command:

python ray+pt.py --num-workers=16

This can also scale with your job size by using the NGC_* environment variables to calculate the number of GPUs in your job.

python ray+pt.py --num-workers=$(( $NGC_ARRAY_SIZE * $NGC_GPUS_PER_NODE ))
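One thing to note when scaling this example: it keeps the global batch size fixed at 32 and divides it among workers, so each worker's batch shrinks as the worker count grows (and integer division silently drops any remainder):

```python
global_batch_size = 32  # value hard-coded in train_fashion_mnist above

for num_workers in (2, 4, 16):
    per_worker = global_batch_size // num_workers
    print(f"{num_workers} workers -> batch size {per_worker} per worker")
# 16 workers -> batch size 2 per worker
```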

Once you’re done, you can stop the Ray cluster with

ray stop -f

Next Steps