Multi-node training can be tricky. When training on multiple nodes, machines must communicate with one another, synchronizing gradients and state. Most frameworks handle most of the heavy lifting behind distributed training algorithms (like torch's DDP implementation). However, different systems may have different networking setups and quirks, leading to frustrating errors. This is a guide specific to UTexas's TACC vista.
While the docs have a section dedicated to multi-node training, I ran into some issues with it.1
- The guide requires using system managed MPI-enabled Python. If you want to use your own managed virtual environment, you're out of luck.
- The guide requires using system provided PyTorch. If you want to use a different version or use a virtual environment, once again, you're out of luck.
Here's how I got multi-node training working with my setup. I'm using Lightning, but if you're using vanilla PyTorch, most of the steps will be the same. For all I know, my troubles could be specific to Lightning. Make sure to follow all below steps.
SLURM Script
#!/bin/bash
#SBATCH -A <your project id>   # Project/account to charge
#SBATCH -J trainvol            # Job name
#SBATCH -p gh                  # Partition/queue name (Grace Hopper nodes for GPUs)
#SBATCH -N 4                   # Number of nodes
#SBATCH --ntasks-per-node=1    # One task (MPI rank) per node. Each task will manage its GPU(s).
#SBATCH -t 04:00:00            # Wall time (hh:mm:ss)
#SBATCH -o logs/train_%j.out   # Standard output file (%j = job ID)
#SBATCH -e logs/train_%j.err   # Standard error file

# Load necessary modules.
# Ensure the CUDA version matches what your PyTorch was compiled with in the venv.
module load gcc cuda/12.6

# Raise the open-file limit; multi-node NCCL/DDP can exhaust the default.
ulimit -n 4096

cd /home1/10846/armeet/research/mri3d/

echo "Starting training job on $(hostname) at $(date)"

# Convenience symlinks so logs/latest.{out,err} always point at this job's logs.
# NOTE: the link targets must match the -o/-e patterns above ("train_%j", not
# "trainvol_%j"), otherwise the symlinks dangle.
ln -sf "train_${SLURM_JOB_ID}.out" "logs/latest.out"
ln -sf "train_${SLURM_JOB_ID}.err" "logs/latest.err"

# ibrun is TACC's MPI launcher; it starts one task per node as requested above.
ibrun uv run scripts/train_volume.py
echo "Job finished at $(date)"

Importantly, to launch a job you must use the Vista-specific command ibrun <cmd>2. In my case, since I'm using uv to manage my environment, I run my scripts with uv run /path/to/script.py. If you're using plain Python, just use ibrun python /path/to/script.py.
Environment Variables
I found that I had to correctly set the following environment variables manually in my code:
RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR, MASTER_PORT
Here's how I did it.
import torch.distributed as dist

# OpenMPI (which TACC's ibrun uses) exposes each task's rank/size via
# OMPI_COMM_WORLD_* variables; their presence tells us we were launched by MPI.
if 'OMPI_COMM_WORLD_RANK' in os.environ:
    rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
    size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
    local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])

    # Translate the MPI variables into the names torch.distributed reads
    # when init_method='env://'.
    os.environ['RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(size)
    os.environ['LOCAL_RANK'] = str(local_rank)

    # Set master address and port
    if rank == 0:
        # Rank 0 hosts the rendezvous, so its own hostname is the master.
        import socket
        master_addr = socket.gethostname()
        os.environ['MASTER_ADDR'] = master_addr
        print(f"Rank {rank}: Master address: {master_addr}")
    else:
        # For non-zero ranks, use rank 0's hostname.
        # NOTE(review): this assumes SLURMD_NODENAME, as seen by these tasks,
        # resolves to the rank-0 node (e.g. because ibrun propagates the
        # launching node's environment) — confirm on your system. The
        # ".vista.tacc.utexas.edu" suffix is Vista-specific.
        slurm_node = os.environ.get("SLURMD_NODENAME", "localhost")
        master_addr = f"{slurm_node}.vista.tacc.utexas.edu"
        os.environ['MASTER_ADDR'] = master_addr
        print(f"Rank {rank}: Master address: {master_addr}")
    # Every rank must agree on the rendezvous port.
    os.environ['MASTER_PORT'] = '12355'
    print(f"Set environment: RANK={rank}, WORLD_SIZE={size}, LOCAL_RANK={local_rank}")

    # MANUALLY initialize distributed before Lightning
    print(f"Rank {rank}: Manually initializing distributed...")
    dist.init_process_group(
        backend='nccl',
        rank=rank,
        world_size=size,
        init_method='env://'
    )
    print(f"Rank {rank}: Distributed initialized successfully!")

    # Set device: one GPU per task, indexed by the node-local rank.
    torch.cuda.set_device(local_rank)
    print(f"Rank {rank}: CUDA device set to {local_rank}")
else:
    print(
        "\nERROR: This script must be run using an MPI launcher like `ibrun` or `mpirun`.\n"
    )
    sys.exit(1)

This next step is extremely important. Make sure to set process_group_backend=None when initializing your strategy.
print(f"Rank {rank}: Setting up DDP strategy...")
strategy = DDPStrategy(
    find_unused_parameters=False,
    # Don't let Lightning initialize distributed since we already did:
    # with process_group_backend=None it reuses the existing process group.
    process_group_backend=None,
)

That's it. If you do this, everything should work. Make sure to use the nccl backend for fast inter-device networking. You'll be transferring heaps of data between devices frequently, and this backend is the fastest.
Testing
The best way to test is to open a live node via idev 3. We want exactly 1 task per node (--tpn), otherwise Lightning will panic.
idev -p gh -N <num_nodes> -t 00:30:00 --tpn 1
Then from the live node, launch the job with the system MPI ibrun command.
ibrun <cmd>
If you're running into trouble, I had Cursor write me up some MPI test scripts. Run them in-order to figure out what's going wrong.
test_mpi.py
import os
import argparse
from mpi4py import MPI
import torch
import torch.distributed as dist
# Use mpi4py to get the world size and this task's rank.
WORLD_SIZE = MPI.COMM_WORLD.Get_size()
WORLD_RANK = MPI.COMM_WORLD.Get_rank()

# Map the global rank onto a local GPU index by the usual convention
# (rank modulo GPUs per node). Guard the modulo: on a node with zero GPUs
# the original `WORLD_RANK % torch.cuda.device_count()` raised
# ZeroDivisionError at import time, making the gloo fallback unreachable.
_NUM_GPUS = torch.cuda.device_count()
GPU_ID = WORLD_RANK % _NUM_GPUS if _NUM_GPUS > 0 else 0

name = MPI.Get_processor_name()
def run(backend):
    """Send a test tensor from rank 0 to every other rank using ``backend``.

    Relies on the module-level WORLD_RANK/WORLD_SIZE/GPU_ID/name globals and
    an already-initialized process group.
    """
    tensor = torch.randn(10000, 10000)
    # NCCL only communicates GPU tensors, so move it onto this rank's device.
    if backend == 'nccl':
        device = torch.device("cuda:{}".format(GPU_ID))
        tensor = tensor.to(device)
    # Only query the GPU name when CUDA is available: the gloo fallback may
    # run on nodes with no GPUs, where get_device_name() would raise.
    device_name = torch.cuda.get_device_name(GPU_ID) if torch.cuda.is_available() else "cpu"
    print("Starting process on " + name + ":" + device_name)
    if WORLD_RANK == 0:
        # Rank 0 sends the tensor to every other rank, one at a time.
        for rank_recv in range(1, WORLD_SIZE):
            dist.send(tensor=tensor, dst=rank_recv)
            print('worker_{} sent data to Rank {}\n'.format(0, rank_recv))
    else:
        dist.recv(tensor=tensor, src=0)
        print('worker_{} has received data from rank {}\n'.format(WORLD_RANK, 0))
def init_processes(backend, master_address):
    """Point torch.distributed at the master node, join the process group,
    then exercise it via run()."""
    banner = "World Rank: %s, World Size: %s, GPU_ID: %s" % (WORLD_RANK, WORLD_SIZE, GPU_ID)
    print(banner)
    print("Torch cuda available:", torch.cuda.is_available())
    # Rendezvous coordinates that every rank must agree on.
    os.environ["MASTER_ADDR"] = master_address
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group(backend, rank=WORLD_RANK, world_size=WORLD_SIZE)
    run(backend)
if __name__ == "__main__":
    # CLI: positional master hostname, optional communication backend.
    cli = argparse.ArgumentParser()
    cli.add_argument("master_node", type=str)
    cli.add_argument("--backend", type=str, default="nccl", choices=['nccl', 'gloo'])
    opts = cli.parse_args()
    chosen_backend = opts.backend
    # Without any GPUs, nccl cannot work — fall back to the CPU backend.
    if not torch.cuda.device_count():
        print("No gpu detected...switching to gloo for backend")
        chosen_backend = "gloo"
    init_processes(backend=chosen_backend, master_address=opts.master_node)
    dist.destroy_process_group()

test_simple_distributed.py
"""
Simple PyTorch distributed test to verify communication between nodes.
"""
import os
import torch
import torch.distributed as dist
def test_simple_distributed():
    """Bring up torch.distributed from OpenMPI env vars and smoke-test
    allreduce and broadcast across ranks.

    Expects to be launched under an MPI launcher (ibrun/mpirun) so that the
    OMPI_COMM_WORLD_* variables are set; prints a message and returns
    otherwise.
    """
    print("=== Simple Distributed Test ===")
    print(f"OMPI_COMM_WORLD_RANK: {os.environ.get('OMPI_COMM_WORLD_RANK', 'Not set')}")
    print(f"OMPI_COMM_WORLD_SIZE: {os.environ.get('OMPI_COMM_WORLD_SIZE', 'Not set')}")
    if 'OMPI_COMM_WORLD_RANK' in os.environ:
        rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        # Set the PyTorch distributed environment variables (read by
        # init_method='env://').
        os.environ['RANK'] = str(rank)
        os.environ['WORLD_SIZE'] = str(size)
        os.environ['LOCAL_RANK'] = str(local_rank)
        # Set master address and port.
        # Use a simple approach: rank 0 is always the master.
        if rank == 0:
            import socket
            master_addr = socket.gethostname()
            os.environ['MASTER_ADDR'] = master_addr
            print(f"Rank {rank}: Master address: {master_addr}")
        else:
            # Use SLURMD_NODENAME if set; fall back to "localhost" so we never
            # assign None to MASTER_ADDR (os.environ values must be strings —
            # the original os.environ.get(...) with no default could raise
            # TypeError here).
            master_addr = os.environ.get("SLURMD_NODENAME", "localhost")
            os.environ['MASTER_ADDR'] = master_addr
            print(f"Rank {rank}: Master address: {master_addr}")
        os.environ['MASTER_PORT'] = '12355'
        print(f"Rank {rank}: Setting up distributed...")
        # Initialize distributed
        dist.init_process_group(
            backend='nccl',  # Use nccl for GPU-based communication
            rank=rank,
            world_size=size,
            init_method='env://'
        )
        print(f"Rank {rank}: Distributed initialized successfully!")
        # Create a simple tensor on this rank's GPU
        tensor = torch.tensor([rank * 10.0], device=f'cuda:{local_rank}')
        print(f"Rank {rank}: Initial tensor: {tensor}")
        # Test allreduce operation (default op sums across ranks)
        dist.all_reduce(tensor)
        print(f"Rank {rank}: After allreduce: {tensor}")
        # Test broadcast operation
        if rank == 0:
            broadcast_tensor = torch.tensor([100.0], device=f'cuda:{local_rank}')
        else:
            broadcast_tensor = torch.tensor([0.0], device=f'cuda:{local_rank}')
        dist.broadcast(broadcast_tensor, src=0)
        print(f"Rank {rank}: After broadcast: {broadcast_tensor}")
        # Clean up
        dist.destroy_process_group()
        print(f"Rank {rank}: Test completed successfully!")
    else:
        print("Not running in MPI environment")
if __name__ == "__main__":
    test_simple_distributed()

test_lightning_mpi.py
"""
Test script to verify PyTorch Lightning MPI environment detection.
"""
import os
import torch
import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy
def test_mpi_environment():
    """Print the MPI/CUDA environment and verify that Lightning can build a
    multi-node DDP trainer from it."""
    print("=== MPI Environment Test ===")
    # Report the raw OpenMPI variables plus the torch rendezvous settings.
    env_report = (
        ("RANK", "OMPI_COMM_WORLD_RANK"),
        ("SIZE", "OMPI_COMM_WORLD_SIZE"),
        ("LOCAL_RANK", "OMPI_COMM_WORLD_LOCAL_RANK"),
        ("MASTER_ADDR", "MASTER_ADDR"),
        ("MASTER_PORT", "MASTER_PORT"),
    )
    for label, key in env_report:
        print(f"{label}: {os.environ.get(key, 'Not set')}")
    print(f"\n=== CUDA Info ===")
    print(f"CUDA available: {torch.cuda.is_available()}")
    print(f"CUDA device count: {torch.cuda.device_count()}")
    if torch.cuda.is_available():
        print(f"Current device: {torch.cuda.current_device()}")
        print(f"Device name: {torch.cuda.get_device_name()}")
    print(f"\n=== Lightning Environment ===")
    # Can Lightning construct a DDP strategy in this environment?
    strategy = DDPStrategy(
        find_unused_parameters=False,
        use_distributed_sampler=True,
    )
    print(f"Strategy: {strategy}")
    # And a multi-node trainer around it (no checkpointing/logging needed)?
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=1,
        num_nodes=2,
        strategy=strategy,
        max_epochs=1,
        enable_checkpointing=False,
        logger=False,
    )
    print(f"Trainer devices: {trainer.num_devices}")
    print(f"Trainer nodes: {trainer.num_nodes}")
    print("Lightning environment detection successful!")
if __name__ == "__main__":
    test_mpi_environment()

Footnotes