Armeet Singh Jatyani

Founder · AI Researcher & Engineer

Multi-Node Training on TACC Vista

Multi-node training can be tricky. When training on multiple nodes, the machines must communicate with one another to synchronize gradients and state. Frameworks handle most of the heavy lifting behind distributed training algorithms (for example, PyTorch's DDP implementation). However, different systems have different networking setups and quirks, which can lead to frustrating errors. This guide is specific to Vista, a cluster at UT Austin's Texas Advanced Computing Center (TACC).

While the docs have a section dedicated to multi-node training, I ran into two issues with it [1]:

  • The guide requires using the system-managed, MPI-enabled Python. If you want to use your own managed virtual environment, you're out of luck.
  • The guide requires using the system-provided PyTorch. If you want to use a different version or a virtual environment, once again, you're out of luck.

Here's how I got multi-node training working with my setup. I'm using Lightning, but if you're using vanilla PyTorch, most of the steps will be the same. For all I know, my troubles could be specific to Lightning. Make sure to follow all of the steps below.


SLURM Script

Shell
#!/bin/bash
#SBATCH -A <your project id>          # Project/account to charge
#SBATCH -J trainvol                   # Job name
#SBATCH -p gh                         # Partition/queue name (Grace Hopper nodes for GPUs)
#SBATCH -N 4                          # Number of nodes
#SBATCH --ntasks-per-node=1           # One task (MPI rank) per node. Each task will manage its GPU(s).
#SBATCH -t 04:00:00                   # Wall time (hh:mm:ss)
#SBATCH -o logs/train_%j.out          # Standard output file (%j = job ID)
#SBATCH -e logs/train_%j.err          # Standard error file

# Load necessary modules
# Ensure CUDA version matches what your PyTorch was compiled with in the venv
module load gcc cuda/12.6
ulimit -n 4096

cd /home1/10846/armeet/research/mri3d/

echo "Starting training job on $(hostname) at $(date)"

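# Point logs/latest.* at this job's logs. The link target is resolved relative
# to logs/, so it must match the file names given to -o/-e above.
ln -sf "train_${SLURM_JOB_ID}.out" "logs/latest.out"
ln -sf "train_${SLURM_JOB_ID}.err" "logs/latest.err"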

ibrun uv run scripts/train_volume.py

echo "Job finished at $(date)"

Importantly, to launch a job you must use the TACC-provided command ibrun <cmd> [2]. Since I'm using uv to manage my environment, I run my scripts with uv run /path/to/script.py. If you're using plain Python, use ibrun python /path/to/script.py instead.


Environment Variables

I found that I had to correctly set the following environment variables manually in my code:

  • RANK
  • WORLD_SIZE
  • LOCAL_RANK
  • MASTER_ADDR
  • MASTER_PORT

Here's how I did it.

Python
import os
import sys

import torch
import torch.distributed as dist

if 'OMPI_COMM_WORLD_RANK' in os.environ:
    rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
    size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
    local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
    
    os.environ['RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(size)
    os.environ['LOCAL_RANK'] = str(local_rank)
    
    # Set master address and port
    if rank == 0:
        import socket
        master_addr = socket.gethostname()
        os.environ['MASTER_ADDR'] = master_addr
        print(f"Rank {rank}: Master address: {master_addr}")
    else:
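        # All other ranks: build the master's hostname from SLURMD_NODENAME.
        # This assumes SLURMD_NODENAME resolves to rank 0's node in every task;
        # check the address printed below on each rank.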
        slurm_node = os.environ.get("SLURMD_NODENAME", "localhost")
        master_addr = f"{slurm_node}.vista.tacc.utexas.edu"
        os.environ['MASTER_ADDR'] = master_addr
        print(f"Rank {rank}: Master address: {master_addr}")
    
    os.environ['MASTER_PORT'] = '12355'
    
    print(f"Set environment: RANK={rank}, WORLD_SIZE={size}, LOCAL_RANK={local_rank}")
    
    # MANUALLY initialize distributed before Lightning
    print(f"Rank {rank}: Manually initializing distributed...")
    dist.init_process_group(
        backend='nccl',
        rank=rank,
        world_size=size,
        init_method='env://'
    )
    print(f"Rank {rank}: Distributed initialized successfully!")
    
    # Set device
    torch.cuda.set_device(local_rank)
    print(f"Rank {rank}: CUDA device set to {local_rank}")
else:
    print(
        "\nERROR: This script must be run using an MPI launcher like `ibrun` or `mpirun`.\n"
    )
    sys.exit(1)
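
Every rank has to agree on rank 0's hostname for MASTER_ADDR. As a minimal sketch of one common alternative, the helper below asks SLURM for the first node in the allocation using scontrol show hostnames. The name resolve_master_addr and its env and run parameters are hypothetical, and the sketch assumes rank 0 runs on the first node listed in SLURM_JOB_NODELIST, so treat it as a starting point rather than a verified recipe for Vista.

```python
import os
import subprocess


def resolve_master_addr(env=None, run=subprocess.run):
    """Return the first hostname in the SLURM allocation (hypothetical helper).

    Assumes rank 0 runs on the first node listed in SLURM_JOB_NODELIST.
    Falls back to "localhost" when not running inside a SLURM job.
    """
    env = os.environ if env is None else env
    nodelist = env.get("SLURM_JOB_NODELIST")
    if not nodelist:
        return "localhost"
    # `scontrol show hostnames` expands a compressed list such as
    # "c612-[141-142]" into one hostname per line.
    result = run(
        ["scontrol", "show", "hostnames", nodelist],
        capture_output=True,
        text=True,
        check=True,
    )
    hosts = result.stdout.split()
    return hosts[0] if hosts else "localhost"
```

Each rank could then set os.environ['MASTER_ADDR'] = resolve_master_addr() before calling dist.init_process_group, so that all ranks compute the same address the same way.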

This next step is extremely important. Make sure to set process_group_backend=None when initializing your strategy.

Python
from pytorch_lightning.strategies import DDPStrategy

print(f"Rank {rank}: Setting up DDP strategy...")
strategy = DDPStrategy(
    find_unused_parameters=False,
    # Don't let Lightning initialize distributed since we already did
    process_group_backend=None,
)

That's it. If you do this, everything should work. Make sure to use the nccl backend: you'll be transferring heaps of data between GPUs frequently, and NCCL is the fastest backend for GPU-to-GPU communication.
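
Since a single missing variable is enough to make initialization hang or fail, a quick pre-flight check can save a debugging round trip. This is a minimal sketch; the names REQUIRED_VARS and missing_dist_vars are hypothetical and not part of PyTorch or Lightning.

```python
import os

# The five variables listed in the "Environment Variables" section.
REQUIRED_VARS = ("RANK", "WORLD_SIZE", "LOCAL_RANK", "MASTER_ADDR", "MASTER_PORT")


def missing_dist_vars(env=None):
    """Return the names of required variables that are unset or empty."""
    env = os.environ if env is None else env
    return [name for name in REQUIRED_VARS if not env.get(name)]
```

Calling missing_dist_vars() on each rank just before dist.init_process_group and printing the result shows exactly which variables still need to be set.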


Testing

The best way to test is to open an interactive session on a compute node via idev [3]. We want exactly one task per node (--tpn 1); otherwise Lightning will panic.

idev -p gh -N <num_nodes> -t 00:30:00 --tpn 1

Then from the live node, launch the job with the system MPI ibrun command.

ibrun <cmd>

If you're running into trouble, I had Cursor write me up some MPI test scripts. Run them in order to figure out what's going wrong.

test_mpi.py
Python
import os
import argparse

from mpi4py import MPI

import torch
import torch.distributed as dist

# use mpi4py to get the world size and this task's rank
WORLD_SIZE = MPI.COMM_WORLD.Get_size()
WORLD_RANK = MPI.COMM_WORLD.Get_rank()

# use the convention that gets the local rank based on how many
# GPUs there are on the node (fall back to 0 when there are none,
# so the gloo path below is reachable on CPU-only nodes).
NUM_GPUS = torch.cuda.device_count()
GPU_ID = WORLD_RANK % NUM_GPUS if NUM_GPUS > 0 else 0
name = MPI.Get_processor_name()

def run(backend):
    tensor = torch.randn(10000,10000)

    # Need to put tensor on a GPU device for nccl backend
    if backend == 'nccl':
        device = torch.device("cuda:{}".format(GPU_ID))
        tensor = tensor.to(device)
        print("Starting process on " + name + ":" + torch.cuda.get_device_name(GPU_ID))
    else:
        print("Starting process on " + name + ":cpu")
    if WORLD_RANK == 0:
        for rank_recv in range(1, WORLD_SIZE):
            dist.send(tensor=tensor, dst=rank_recv)
            print('worker_{} sent data to Rank {}\n'.format(0, rank_recv))
    else:
        dist.recv(tensor=tensor, src=0)
        print('worker_{} has received data from rank {}\n'.format(WORLD_RANK,0))

def init_processes(backend, master_address):
    print("World Rank: %s, World Size: %s, GPU_ID: %s"%(WORLD_RANK,WORLD_SIZE,GPU_ID))
    print("Torch cuda available:", torch.cuda.is_available())
    os.environ["MASTER_ADDR"] = master_address
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group(backend, rank=WORLD_RANK, world_size=WORLD_SIZE)
    run(backend)

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("master_node", type=str)
    parser.add_argument("--backend", type=str, default="nccl", choices=['nccl', 'gloo'])
    args = parser.parse_args()
    backend=args.backend
    if torch.cuda.device_count() == 0:
        print("No gpu detected...switching to gloo for backend")
        backend="gloo"
    init_processes(backend=backend,master_address=args.master_node)
    dist.destroy_process_group()

test_simple_distributed.py
Python
"""
Simple PyTorch distributed test to verify communication between nodes.
"""

import os
import torch
import torch.distributed as dist

def test_simple_distributed():
    """Test simple PyTorch distributed communication."""
    print("=== Simple Distributed Test ===")
    print(f"OMPI_COMM_WORLD_RANK: {os.environ.get('OMPI_COMM_WORLD_RANK', 'Not set')}")
    print(f"OMPI_COMM_WORLD_SIZE: {os.environ.get('OMPI_COMM_WORLD_SIZE', 'Not set')}")
    
    if 'OMPI_COMM_WORLD_RANK' in os.environ:
        rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        
        # Set PyTorch distributed environment variables
        os.environ['RANK'] = str(rank)
        os.environ['WORLD_SIZE'] = str(size)
        os.environ['LOCAL_RANK'] = str(local_rank)
        
        # Set master address and port
        # Use a simple approach: rank 0 is always the master
        if rank == 0:
            import socket
            master_addr = socket.gethostname()
            os.environ['MASTER_ADDR'] = master_addr
            print(f"Rank {rank}: Master address: {master_addr}")
        else:
            # Use the SLURMD_NODENAME environment variable, or fall back to
            # localhost so MASTER_ADDR is never set to None
            master_addr = os.environ.get("SLURMD_NODENAME", "localhost")
            os.environ['MASTER_ADDR'] = master_addr
            print(f"Rank {rank}: Master address: {master_addr}")
        os.environ['MASTER_PORT'] = '12355'
        
        print(f"Rank {rank}: Setting up distributed...")
        
        # Initialize distributed
        dist.init_process_group(
            backend='nccl',  # Use nccl for GPU-based communication
            rank=rank,
            world_size=size,
            init_method='env://'
        )
        
        print(f"Rank {rank}: Distributed initialized successfully!")
        
        # Create a simple tensor on GPU
        tensor = torch.tensor([rank * 10.0], device=f'cuda:{local_rank}')
        print(f"Rank {rank}: Initial tensor: {tensor}")
        
        # Test allreduce operation
        dist.all_reduce(tensor)
        print(f"Rank {rank}: After allreduce: {tensor}")
        
        # Test broadcast operation
        if rank == 0:
            broadcast_tensor = torch.tensor([100.0], device=f'cuda:{local_rank}')
        else:
            broadcast_tensor = torch.tensor([0.0], device=f'cuda:{local_rank}')
        
        dist.broadcast(broadcast_tensor, src=0)
        print(f"Rank {rank}: After broadcast: {broadcast_tensor}")
        
        # Clean up
        dist.destroy_process_group()
        print(f"Rank {rank}: Test completed successfully!")
    else:
        print("Not running in MPI environment")

if __name__ == "__main__":
    test_simple_distributed() 
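
To tell a healthy run from a silently broken one, it helps to know what the script above should print. Each rank contributes rank * 10.0 and all_reduce sums by default, so every rank should report the same total, and after the broadcast every rank should report 100.0. The helper name expected_allreduce below is hypothetical; it only reproduces that arithmetic.

```python
def expected_allreduce(world_size):
    """Sum of rank * 10.0 over all ranks, matching the default SUM reduction."""
    return sum(rank * 10.0 for rank in range(world_size))


for n in (2, 4):
    print(f"world_size={n}: every rank should print {expected_allreduce(n)}")
# world_size=2: every rank should print 10.0
# world_size=4: every rank should print 60.0
```
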

test_lightning_mpi.py
Python
"""
Test script to verify PyTorch Lightning MPI environment detection.
"""

import os
import torch
import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy

def test_mpi_environment():
    """Test MPI environment detection."""
    print("=== MPI Environment Test ===")
    print(f"RANK: {os.environ.get('OMPI_COMM_WORLD_RANK', 'Not set')}")
    print(f"SIZE: {os.environ.get('OMPI_COMM_WORLD_SIZE', 'Not set')}")
    print(f"LOCAL_RANK: {os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK', 'Not set')}")
    print(f"MASTER_ADDR: {os.environ.get('MASTER_ADDR', 'Not set')}")
    print(f"MASTER_PORT: {os.environ.get('MASTER_PORT', 'Not set')}")
    
    print(f"\n=== CUDA Info ===")
    print(f"CUDA available: {torch.cuda.is_available()}")
    print(f"CUDA device count: {torch.cuda.device_count()}")
    if torch.cuda.is_available():
        print(f"Current device: {torch.cuda.current_device()}")
        print(f"Device name: {torch.cuda.get_device_name()}")
    
    print(f"\n=== Lightning Environment ===")
    # Test if Lightning can detect the environment
    strategy = DDPStrategy(
        find_unused_parameters=False,
    )
    print(f"Strategy: {strategy}")
    
    # Test trainer initialization
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=1,
        num_nodes=2,  # must match the number of nodes in your allocation
        strategy=strategy,
        # use_distributed_sampler is a Trainer argument, not a DDPStrategy one
        use_distributed_sampler=True,
        max_epochs=1,
        enable_checkpointing=False,
        logger=False,
    )
    print(f"Trainer devices: {trainer.num_devices}")
    print(f"Trainer nodes: {trainer.num_nodes}")
    print("Lightning environment detection successful!")

if __name__ == "__main__":
    test_mpi_environment() 
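
The Lightning test passes a fixed node count to the Trainer, which has to match the allocation. As a small sketch, the node count can instead be read from SLURM_NNODES, which SLURM sets inside a job. The helper name slurm_num_nodes and its default of 1 are assumptions for illustration.

```python
import os


def slurm_num_nodes(env=None, default=1):
    """Return the number of nodes in the current SLURM allocation.

    Falls back to `default` when SLURM_NNODES is not set (e.g. on a laptop).
    """
    env = os.environ if env is None else env
    return int(env.get("SLURM_NNODES", default))
```

The result could be passed as num_nodes=slurm_num_nodes() when constructing the Trainer, so the same script works for any value of -N.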


Footnotes

  1. If I'm mistaken, let me know and I'll update this. I had previously tried a hacky solution of trying to set environment variables to use the system managed python interpreter while looking in the environment folder for package files. This didn't really work.

  2. ibrun is TACC's job launcher and replaces srun: [docs]

  3. idev is also a TACC-specific tool.