A library for distributed PyTorch execution in Jupyter notebooks with seamless REPL-like behavior.
This library is being built to support my new course and is therefore changing constantly. It is "stable enough" for now, but the framework will keep expanding as the course calls for new features.
As a result it is not open to contributions at this time.
- Seamless Distributed Execution: Run PyTorch code across multiple GPUs directly from Jupyter notebooks
- REPL-like Behavior: See results immediately without explicit print statements
- Automatic GPU Management: Smart allocation of GPUs to worker processes
- Interactive Development: Real-time feedback and error reporting
- IDE Support: Namespace synchronization for code completion and type hints
- Robust Process Management: Graceful startup, monitoring, and shutdown
pip install nbdistributed
- Import and initialize in your Jupyter notebook:
%load_ext nbdistributed
%dist_init -n 4 # Start 4 worker processes
- Run code on all workers:
import torch
print(f"Rank {rank} running on {torch.cuda.get_device_name()}")
- Run code on specific ranks:
%%rank[0,1]
print(f"Running on rank {rank}")
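Because results are echoed back REPL-style, the trailing expression of a cell is displayed from every rank without an explicit print. A small illustrative cell (the variable names here are just an example, assuming the workers from step 1 are running):

```python
import torch

# Runs on every worker; `rank` is injected into each worker's namespace
t = torch.arange(4) * (rank + 1)
t.sum()   # trailing expression is displayed per rank, no print() needed
```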
The library consists of four main components:
- Provides IPython magic commands for interaction
- Manages automatic distributed execution
- Handles namespace synchronization
- Key commands:
  - %dist_init: Initialize workers
  - %%distributed: Execute on all ranks
  - %%rank[n]: Execute on specific ranks
  - %sync: Synchronize workers
  - %dist_status: Show worker status
  - %dist_mode: Toggle automatic mode
  - %dist_shutdown: Clean shutdown
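As a rough illustration of how these magics fit together in a notebook session (only the command names come from the list above; the cell bodies and comments are illustrative):

```python
%%distributed
# Explicit form: run this cell on every rank even when automatic mode is off
print(f"hello from rank {rank}")
```

```python
%dist_mode       # toggle automatic distributed execution for plain cells
%dist_status     # show worker status and GPU assignments
%sync            # synchronize the workers
%dist_shutdown   # shut the workers down cleanly when finished
```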
- Runs on each GPU/CPU
- Executes distributed PyTorch code
- Maintains isolated Python namespace
- Features:
  - REPL-like output capturing
  - Error handling and reporting
  - GPU device management
  - Namespace synchronization
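The REPL-like output capturing listed above can be pictured with a short sketch of the general technique (illustrative only, not the library's actual code): run every statement in the cell, but evaluate a trailing expression separately so its value can be sent back for display.

```python
import ast

def run_cell(source: str, namespace: dict):
    """Sketch of REPL-style execution: exec all statements, eval a trailing expression."""
    tree = ast.parse(source)
    result = None
    if tree.body and isinstance(tree.body[-1], ast.Expr):
        # Split off the final expression so its value can be captured
        last = ast.Expression(tree.body.pop().value)
        exec(compile(tree, "<cell>", "exec"), namespace)
        result = eval(compile(ast.fix_missing_locations(last), "<cell>", "eval"), namespace)
    else:
        exec(compile(tree, "<cell>", "exec"), namespace)
    return result
```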
- Manages worker lifecycle
- Handles GPU assignments
- Monitors process health
- Provides:
  - Clean process startup
  - Status monitoring
  - Graceful shutdown
  - GPU utilization tracking
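For example, mapping ranks to GPUs can be as simple as a round-robin assignment over the visible devices; the helper below is hypothetical and only illustrates the idea (passing an explicit list mirrors the -g flag shown later):

```python
import torch

def assign_gpus(num_workers: int, gpu_ids: list[int] | None = None) -> dict[int, int]:
    """Hypothetical helper: map each worker rank to a GPU index, round-robin."""
    devices = gpu_ids if gpu_ids is not None else list(range(torch.cuda.device_count()))
    if not devices:
        raise RuntimeError("no CUDA devices available")
    return {rank: devices[rank % len(devices)] for rank in range(num_workers)}

# assign_gpus(4) -> {0: 0, 1: 1, 2: 2, 3: 3} on a four-GPU machine
```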
- Coordinates inter-process communication
- Uses ZMQ for efficient messaging
- Features:
  - Asynchronous message handling
  - Reliable message delivery
  - Timeout management
  - Worker targeting
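A minimal sketch of how a ZMQ ROUTER socket supports worker targeting and timeouts (the address, framing, and helper name are assumptions, not the library's actual protocol):

```python
import zmq

ctx = zmq.Context()
coordinator = ctx.socket(zmq.ROUTER)      # ROUTER lets us address workers by identity
coordinator.bind("tcp://127.0.0.1:5555")  # workers would connect with DEALER sockets

def send_code(worker_id: bytes, code: str, timeout_ms: int = 5000) -> str:
    """Send a code string to one worker and wait for its reply, with a timeout."""
    coordinator.send_multipart([worker_id, code.encode()])
    if coordinator.poll(timeout_ms, zmq.POLLIN):
        _identity, reply = coordinator.recv_multipart()
        return reply.decode()
    raise TimeoutError(f"worker {worker_id!r} did not reply within {timeout_ms} ms")
```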
%dist_init -n 2 # Start 2 workers
import torch
import torch.distributed as dist
# Create tensor on each GPU
x = torch.randn(100, 100).cuda()
# All-reduce across GPUs
dist.all_reduce(x)
print(f"Rank {rank}: {x.mean():.3f}") # Same value on all ranks
%%rank[0]
# Only runs on rank 0
model = torch.nn.Linear(10, 10).cuda()
print("Model created on rank 0")
# In another cell: broadcast rank 0's parameters to all ranks
# (every rank needs its own model instance before it can receive the weights)
if rank != 0:
    model = torch.nn.Linear(10, 10).cuda()
for param in model.parameters():
    dist.broadcast(param.data, src=0)
print(f"Rank {rank} received model")
%dist_status
# Shows:
# - Process status
# - GPU assignments
# - Memory usage
# - Device names
Specify exact GPU-to-rank mapping:
%dist_init -n 4 -g "0,1,2,3" # Assign specific GPUs
The library automatically syncs worker namespaces to enable IDE features:
- Code completion
- Type hints
- Variable inspection
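One way to picture this: after each cell, a worker can report the names and types in its namespace so the notebook side has matching entries to complete against. The helper below is purely illustrative and not part of the library's API:

```python
def summarize_namespace(namespace: dict) -> dict[str, str]:
    """Report public variable names and their type names for completion/inspection."""
    return {
        name: type(value).__name__
        for name, value in namespace.items()
        if not name.startswith("_")
    }

# summarize_namespace({"x": torch.zeros(2), "_tmp": 1}) -> {"x": "Tensor"}
```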
Errors are caught and reported with:
- Full traceback
- Rank information
- GPU context
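A worker could bundle these three pieces into a single report, for example (hypothetical helper, shown only to illustrate what such a report contains):

```python
import traceback

import torch

def format_error(rank: int, exc: BaseException) -> dict:
    """Package an exception with its rank and GPU context for the notebook."""
    return {
        "rank": rank,
        "device": torch.cuda.get_device_name() if torch.cuda.is_available() else "cpu",
        "traceback": "".join(traceback.format_exception(type(exc), exc, exc.__traceback__)),
    }
```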
The library provides robust error recovery:
%dist_reset # Complete environment reset
%dist_init # Start fresh