NCCL_SHM_USE_CUDA_MEMCPY=1 causes hang in PyTorch #803
Can you try with NCCL_CREATE_THREAD_CONTEXT=1?
Thanks for the quick reply! I tried that, no luck. Here is a full repro: NCCL version 2.17.1-1 with trunk PyTorch. It also repros on PyTorch 2.0 with the built-in NCCL. CUDA 11.8, 530.30.02 driver.

```python
import os
os.environ['NCCL_SHM_USE_CUDA_MEMCPY'] = '1'
os.environ['NCCL_CREATE_THREAD_CONTEXT'] = '1'

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def init_process(rank, size, backend='nccl'):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    buf0 = torch.randn(32768).cuda(rank)
    dist.all_reduce(buf0)
    print(f"{rank}: all reduce done")
    # NOTE: uncommenting this synchronize will fix it, but it shouldn't be needed
    #torch.cuda.synchronize(rank)
    buf0.cpu()
    print(f"{rank}: to cpu")

if __name__ == "__main__":
    p1 = mp.Process(target=init_process, args=(0, 2))
    p2 = mp.Process(target=init_process, args=(1, 2))
    p1.start()
    p2.start()
    p1.join()
    p2.join()
```
Ahh! This NCCL_CREATE_THREAD_CONTEXT seems to almost fix it. If I add:
to
So I believe the bug is deeper than that. This construct exists twice in proxy.cc:

```c
if (ncclSetThreadContext(comm) != ncclSuccess) {
  WARN("[Proxy Progress] Failed to set CUDA context on device %d", comm->cudaDev);
} else if (cudaSetDevice(comm->cudaDev) != cudaSuccess) {
  WARN("[Proxy Progress] Failed to set CUDA device %d", comm->cudaDev);
}
```

If ncclSetThreadContext succeeds, it also calls cudaSetDevice. But I don't think this is right: calling cudaSetDevice resets the context (https://stackoverflow.com/questions/62877646/what-does-cudasetdevice-do-to-a-cuda-devices-context-stack). I suspect what you meant was this:

```c
if (cudaSetDevice(comm->cudaDev) != cudaSuccess) {
  WARN("[Proxy Progress] Failed to set CUDA device %d", comm->cudaDev);
} else if (ncclSetThreadContext(comm) != ncclSuccess) {
  WARN("[Proxy Progress] Failed to set CUDA context on device %d", comm->cudaDev);
}
```

If I change it to that, the code works. (NOTE: this might be redundant; does the context contain the device?)
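For readers who want to see the context-clobbering behavior outside of NCCL, here is a minimal standalone sketch (not NCCL code; the file name and build line are just examples, and the exact semantics vary across CUDA versions). It mirrors the sequence above: make a dedicated driver-API context current, then call into the runtime via cudaSetDevice, and check whether the current context is still the one that was set.

```c
// ctx_check.cu -- hedged illustration, not NCCL source.
// Example build: nvcc ctx_check.cu -o ctx_check -lcuda
#include <stdio.h>
#include <cuda.h>
#include <cuda_runtime.h>

int main(void) {
  cuInit(0);
  CUdevice dev;
  cuDeviceGet(&dev, 0);

  // Create a dedicated driver-API context and make it current,
  // analogous to the proxy thread's dedicated context.
  CUcontext created = NULL, before = NULL, after = NULL;
  cuCtxCreate(&created, 0, dev);
  cuCtxGetCurrent(&before);   // expected: before == created

  // Now call the runtime API the way the proxy code does.
  cudaSetDevice(0);
  cudaFree(0);                // any runtime call; makes the binding observable
  cuCtxGetCurrent(&after);    // typically the device's primary context now

  printf("created=%p before=%p after=%p\n",
         (void*)created, (void*)before, (void*)after);

  cuCtxDestroy(created);
  return 0;
}
```

On the CUDA setups discussed in this thread, `after` no longer matches `created`, which matches the report above that cudaSetDevice replaces the context ncclSetThreadContext had just installed.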
Sorry for the delay. Indeed, this was broken at some point. We were only supposed to call cudaSetDevice (which would override the context we just set) if NCCL_CREATE_THREAD_CONTEXT was not set to 1. Can you check whether the attached patch works?

Your patch works, but it calls cudaSetDevice when we would not need to call it.
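To make the intended logic concrete, here is a hedged sketch of what the comment above describes, not the actual attached patch: only fall back to cudaSetDevice when a dedicated proxy-thread context was not requested. `createThreadContext` is a hypothetical flag standing in for however NCCL reads NCCL_CREATE_THREAD_CONTEXT; WARN, comm, and ncclSetThreadContext are the internals quoted from proxy.cc above.

```c
if (createThreadContext) {
  // NCCL_CREATE_THREAD_CONTEXT=1: bind the dedicated context and do not
  // call cudaSetDevice afterwards, so the context is not replaced.
  if (ncclSetThreadContext(comm) != ncclSuccess) {
    WARN("[Proxy Progress] Failed to set CUDA context on device %d", comm->cudaDev);
  }
} else if (cudaSetDevice(comm->cudaDev) != cudaSuccess) {
  // Default path: the device's primary context via cudaSetDevice is enough.
  WARN("[Proxy Progress] Failed to set CUDA device %d", comm->cudaDev);
}
```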
- Add support for IB SHARP to NVLS (NVLink SHARP algorithm).
- Add NVLS+Tree algorithm.
- Add support for memory management using cuMem* functions.
- Use all NICs for Send/Receive operations on systems with more than one NIC per GPU (#804).
- Add ncclCommSplit primitive, with resource sharing option in config.
- Fix alltoallv hang (#788).
- Increase number of channels on H100 when we're not limited by NVLink.
- Improve error reporting in case of IB failure, printing local and remote ID (#779).
- Add build option to allow compilation against RDMA includes instead of dynamically loading IB verbs symbols (#802).
- Fix context creation for progress thread (#803).
- NET/IB: add option to use multiple QPs in round-robin mode.
- Fix tree performance issue when NVB is disabled on HCM topologies.
Hi, I tried the script that @geohot provided with PyTorch 2.1.1+cu121 (NCCL version 2.18.6). It doesn't work for me; my program still hangs. Any ideas, comments, or suggestions? Thanks!
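One diagnostic worth running before digging further (a suggestion, not from the original thread): confirm which NCCL library the process actually loads, since the context fix landed after 2.17 and a stray older libnccl on the library path could reintroduce the hang. Setting NCCL_DEBUG=VERSION (or INFO) makes NCCL print its version at init; alternatively, a tiny standalone check against the public API, linked against the library you believe PyTorch is using:

```c
// nccl_version.c -- hedged standalone check.
// Example build: gcc nccl_version.c -o nccl_version -lnccl
#include <stdio.h>
#include <nccl.h>

int main(void) {
  int version = 0;
  if (ncclGetVersion(&version) != ncclSuccess) {
    fprintf(stderr, "ncclGetVersion failed\n");
    return 1;
  }
  // For NCCL >= 2.9 the encoding is major*10000 + minor*100 + patch.
  printf("NCCL %d.%d.%d\n", version / 10000, (version / 100) % 100, version % 100);
  return 0;
}
```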
Using CUDA memcpy is much faster, so I'm trying to make it work on a machine with two non-NVLinked 3090s.
If I comment out the torch.cuda.synchronize, it hangs; with it in, it works. I suspect it has to do with the stream created in shmSendProxyConnect (also, does that stream work with CUDA graphs?). If this is something you'd like to look into, I can help with a fuller reproduction.