NCCL allreduce is slower than others in certain process groups #820
Comments
It could be. Could you set ...?
@sjeaugey Thanks for your reply. Would you please tell me how to set it?
Actually, it seems there is no difference between GPUs 0/1 and the others. Each pair of GPUs (0-1, 2-3, 4-5 and 6-7) has 2 local NICs. I see nothing unusual here and it should work fine. What could happen here is that the two slow flows (node 0 GPU 0 <-> node 1 GPU 0 and node 0 GPU 1 <-> node 1 GPU 1) end up using the same network link at some point in the fabric. Are the two nodes connected to the same switch, or to different switches through a multi-level network fabric? Do you have a rail-optimized fabric (a different switch per NIC)? If not, are you using adaptive routing? Without a rail-optimized fabric and without adaptive routing, this kind of route collision is quite common.
@sjeaugey We have a rail-optimized fabric and we use a similar network solution to Selene. We suspect that using both local NICs per GPU might cause some collisions locally. Is there any way to limit each GPU to using only one local NIC?
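For reference, a minimal sketch of how one might approximate a per-GPU NIC limit, assuming one process per GPU, an OpenMPI launcher, and the mlx5_* device names seen in the logs (the wrapper script name and rank-to-NIC mapping are assumptions; NCCL_IB_HCA filters NICs for the whole process, so this only works with one rank per GPU):

#!/bin/bash
# Hypothetical wrapper: pin each local rank to a single IB device.
# Launch as: mpirun -np 16 -N 8 ./nic_per_rank.sh ./my_app
NICS=(mlx5_0 mlx5_1 mlx5_2 mlx5_3 mlx5_4 mlx5_5 mlx5_6 mlx5_7)
export NCCL_IB_HCA=${NICS[$OMPI_COMM_WORLD_LOCAL_RANK]}:1
exec "$@"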
Ok, thanks. The fact that GPU 0 uses both NICs is normal. If you dump the rings topology with ..., you will see that. So it's normal for GPU 0 to communicate with NET 0 and NET 1; that's how we guarantee traffic stays local to each rail on a communicator with more than one GPU per node. The logs you pasted here are for the communicator with all GPUs (16 channels, using all 8 NICs). Could you get the log for the communicator with only one GPU per node, and check that this communicator is indeed using a different NIC on each GPU?
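As an illustrative sketch (not necessarily the exact variable elided above), one common way to capture per-rank ring/graph information is NCCL's debug variables; the subsystem list and file pattern below are assumptions:

# Sketch: write detailed NCCL logs, including ring/graph setup, per rank.
# %h expands to the hostname and %p to the PID, so each rank gets its own file.
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=INIT,GRAPH,ENV
export NCCL_DEBUG_FILE=/tmp/nccl.%h.%p.log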
@sjeaugey Thanks for your advice.
Ok, interesting. To better understand what's going on, it would be helpful if you could run with 8 GPUs per node and set ... . The NCCL code is supposed to already detect PEX PCI switches and flatten the PCI topology, so it should not be a problem in theory. It could be a NIC ordering issue, or something else ... which the log would hopefully tell me.
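One possible way to reproduce such a run and collect the log, sketched here with the nccl-tests all_reduce_perf benchmark; the host names, binary path and message sizes are placeholders:

# Sketch: 2 nodes x 8 GPUs, one rank per GPU, with NCCL debug output enabled.
mpirun -np 16 -H node0:8,node1:8 \
    -x NCCL_DEBUG=INFO -x NCCL_DEBUG_SUBSYS=INIT,GRAPH \
    ./build/all_reduce_perf -b 8M -e 1G -f 2 -g 1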
Sure. I just need to get 2 nodes available first.
It seems this is using NCCL 2.10 which is not detecting the topology correctly:
Previous logs were using NCCL 2.12; it would be good to use a recent enough version. Regardless of that, indeed mlx5_1 and mlx5_0 are not enumerated in PCI order. If that's the way it is enumerated on all nodes, it should not be a problem, but if it is not consistent, it could indeed explain the lower performance. Can you confirm whether the nodes have a different NIC enumeration order or not? If not, can you check whether reordering the NICs in the XML with mlx5_0 first (and re-injecting the XML with ...) solves the problem?
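A hedged sketch of the dump/edit/re-inject workflow being discussed; NCCL_TOPO_DUMP_FILE and NCCL_TOPO_FILE are NCCL's standard variables for this, but the file paths and the manual edit step are assumptions:

# 1) Dump the topology NCCL detects (written as an XML file).
export NCCL_TOPO_DUMP_FILE=/tmp/nccl_topo.xml
# ... run the job once to produce the dump ...

# 2) Edit the XML by hand so mlx5_0 appears before mlx5_1,
#    then feed the corrected file back to NCCL on the next run.
export NCCL_TOPO_FILE=/tmp/nccl_topo_fixed.xml
unset NCCL_TOPO_DUMP_FILE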
We have tried NCCL 2.15 in the job environment. It gives the same result as 2.10. The issue happens when the NICs are added to the XML, and there seems to be no relevant difference between the versions.
Yes, we fixed it by reordering the NICs and it works. It has proved to be OK at a larger scale of 16 nodes. We will launch on hundreds of nodes later.
The NICs are all enumerated in PCI order in the system, but NCCL doesn't seem to sort them that way; the reason is the irregular topology introduced by the PEX88096 PCIe switch. In this case it causes NIC 0 to go unused during the multi-allreduce test, so the achieved communication bandwidth is halved in scenarios where we do model parallelism intra-node and pipeline or data parallelism inter-node.
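To confirm which NICs actually carry traffic during such a test, one option is to watch the standard InfiniBand port counters in sysfs while the allreduce runs; the device glob below is an assumption based on the mlx5_* names in this thread:

# Sketch: sample transmitted data per NIC (port_xmit_data counts 4-byte words).
for dev in /sys/class/infiniband/mlx5_*; do
    echo "$(basename "$dev"): $(cat "$dev"/ports/1/counters/port_xmit_data)"
done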
I'm failing to see why the NICs get added in reverse order. Can you get a log with ...?
Actually, never mind, I figured out why that happens. It's due to the fact that mlx5_1 is on the same switch as GPU 0, and since we add GPUs first and NICs after that, once the PEX switch is flattened, mlx5_1 ends up before mlx5_0. Now I need to figure out why it is a problem and whether we should reorder them.
After investigating, I still don't see why it would be a problem. When we create a communicator with 1 GPU per node, GPU 0 should get NIC 1 and GPU 1 should get NIC 0. Would you be able to provide the log for GPU 1? The only log was for GPU 0, and it was using NIC 1 as expected.
I don't have stable access to available machines. I'll get clearer logs next time if you need them.
Can you run again with a recent version of NCCL? The NIC selection code for rings was changed in NCCL 2.11 so that two GPUs on the same PCI switch with 2 NICs do not use the same NIC. Your original bug description mentioned NCCL 2.10, 2.12 and 2.15. Please use NCCL 2.15 if you can, or 2.12 in the worst case, but do not use NCCL 2.10, because NCCL 2.10 is expected to show the performance issue you are reporting.
Ok, problem understood. When we run only with GPU 1, mlx5_1 is no longer on the same switch as a GPU which was added before, and the reordering does not occur. Now we need to find out how to end up with a consistent topology graph, regardless of which GPUs are part of the communicator.
Can you check whether the issue is fixed with the attached patch (it applies on top of 2.18 and would probably also work with previous versions)? Thanks!
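For anyone following along, a hedged sketch of applying such a patch and rebuilding NCCL from source; the checkout tag and the patch file path are placeholders, not the actual attachment:

# Sketch: apply a patch on top of an NCCL 2.18 checkout and rebuild.
git clone https://github.com/NVIDIA/nccl.git && cd nccl
git checkout v2.18.3-1              # hypothetical 2.18 tag
git apply /path/to/attached.patch   # placeholder path for the attached patch
make -j src.build                   # builds the library into ./build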
OK. I'll verify it when test resources are ready.
nccl_test.txt
Fix NVLS search (issue NVIDIA#931). Increase max IB NICs to 32. Fix inconsistent device ordering (issue NVIDIA#820). Try to use different devices for different GPUs in systems with more than one NIC per GPU.
I have 4 nodes and only use the first 2 nodes to create 8 process groups to perform all-reduce tests. Each node has 8 NVIDIA A100-SXM4-80GB GPUs and 8 200Gb/s IB NICs. The i-th process group contains all of the i-th GPUs in the first 2 nodes. All of these process groups perform all-reduce operations simultaneously with the same data size. But it is strange that the time cost of each process group is not the same. The results are as follows, and show that the all-reduce operations on GPU 0 and GPU 1 are slower than the others.
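A hedged shell sketch of one way to approximate this kind of test with nccl-tests (the original runs used framework process groups, so this is only an analogue); the host names, message size and output paths are placeholders:

# Sketch: 8 concurrent 2-rank allreduces, group i using GPU i on both nodes.
for i in 0 1 2 3 4 5 6 7; do
    mpirun -np 2 -H node0:1,node1:1 \
        -x CUDA_VISIBLE_DEVICES=$i \
        ./build/all_reduce_perf -b 1G -e 1G -g 1 > allreduce_gpu$i.log 2>&1 &
done
wait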
I tried to use NCCL v2.12.10-1, v2.12.12-1 and v2.15.5-1, with export NCCL_CROSS_NIC=0 or unset NCCL_CROSS_NIC, but none of them solved my problem. Here is the information about my environment.
lspci -tv: the PCIe topology of the GPU-NIC pair for GPU 0 and GPU 1 is different from that of the other GPUs.
export NCCL_DEBUG=INFO logs:
GPU 0:
GPU 7:
We found in the NCCL logs that GPU 0 is connected to two IB NICs. Could this be the reason why GPU 0 and GPU 1 are slower than the others?
Please give us some suggestions on how to solve this problem. Thank you.