[core] Support broadcast and reduce collective for compiled graphs by jeffreyjeffreywang · Pull Request #53625 · ray-project/ray · GitHub
[core] Support broadcast and reduce collective for compiled graphs #53625

Open · wants to merge 5 commits into base: master
Changes from all commits
22 changes: 19 additions & 3 deletions python/ray/dag/collective_node.py
@@ -16,6 +16,8 @@
AllGatherOp,
AllReduceOp,
ReduceScatterOp,
BroadcastOp,
ReduceOperation,
jeffreyjeffreywang (Contributor Author) commented:
Not thrilled about this naming. Using ReduceOp for the collective could conflict with actual reduction ops like min, max, and sum (code). Open to better naming ideas.

)
from ray.util.annotations import DeveloperAPI
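To make the distinction behind the naming comment above concrete, here is a hypothetical stand-in sketch; these classes are illustrative substitutes, not Ray's actual definitions, and the constructor signature is an assumption. The point is that `ReduceOp` names the arithmetic while `ReduceOperation` names the collective that carries it.

```python
# Hypothetical stand-ins for illustration only; not Ray's actual classes.
from enum import Enum


class ReduceOp(Enum):
    # The arithmetic applied element-wise across ranks.
    SUM = "sum"
    PRODUCT = "product"
    MIN = "min"
    MAX = "max"


class ReduceOperation:
    # The collective itself: combine every rank's tensor onto a root rank.
    # It carries a ReduceOp, which execute() reads as `self._op.reduceOp`.
    def __init__(self, reduce_op: ReduceOp = ReduceOp.SUM):
        self.reduceOp = reduce_op
```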

@@ -33,19 +35,27 @@ class _CollectiveOperation:
1. Input nodes are unique.
2. Actor handles are unique.
3. Actor handles match the custom NCCL group if specified.
4. If root_node is specified, it must be an input node.
"""

def __init__(
self,
input_nodes: List[DAGNode],
op: _CollectiveOp,
root_node: Optional[DAGNode] = None,
transport: Optional[Union[str, Communicator]] = None,
):
if len(input_nodes) == 0:
raise ValueError("Expected input nodes for a collective operation")
if len(set(input_nodes)) != len(input_nodes):
raise ValueError("Expected unique input nodes for a collective operation")

self._root_actor_handle = (
root_node._get_actor_handle() if root_node is not None else None
)
if root_node is not None and root_node not in input_nodes:
raise ValueError("Expected the root node to be an input node")

self._actor_handles: List["ray.actor.ActorHandle"] = []
for input_node in input_nodes:
actor_handle = input_node._get_actor_handle()
@@ -135,6 +145,14 @@ def execute(self, send_buf: "torch.Tensor") -> "torch.Tensor":
device=send_buf.device,
)
communicator.reducescatter(send_buf, recv_buf, self._op.reduceOp)
elif isinstance(self._op, BroadcastOp):
recv_buf = torch.empty_like(send_buf)
root_rank = communicator.get_rank(self._root_actor_handle)
communicator.broadcast(send_buf, recv_buf, root_rank)
elif isinstance(self._op, ReduceOperation):
recv_buf = torch.empty_like(send_buf)
root_rank = communicator.get_rank(self._root_actor_handle)
communicator.reduce(send_buf, recv_buf, root_rank, self._op.reduceOp)
else:
raise ValueError("Expected a collective operation")
return recv_buf
@@ -147,9 +165,7 @@ class CollectiveOutputNode(ClassMethodNode):
def __init__(
self,
method_name: str,
method_args: Tuple[
DAGNode,
],
method_args: Tuple[DAGNode,],
method_kwargs: Dict[str, Any],
method_options: Dict[str, Any],
other_args_to_resolve: Dict[str, Any],
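Before the test diff below, a hedged usage sketch distilled from those tests shows how the new collectives are bound at the API level; the `Worker` actor and the import paths are assumptions rather than code from this PR, and at least two GPUs are required. The root node is passed as the first argument to `bind()` and must also be one of the input nodes, matching the new validation in `_CollectiveOperation.__init__`.

```python
# Hedged sketch assembled from the tests added in this PR; the imports and
# the Worker actor are assumptions, not an exact excerpt.
import ray
import torch
from ray.dag import InputNode, MultiOutputNode
from ray.experimental import collective  # assumed import path


@ray.remote(num_cpus=0, num_gpus=1)
class Worker:
    def compute(self, value: int) -> torch.Tensor:
        # Each rank produces a small GPU tensor filled with `value`.
        return torch.ones(10, dtype=torch.float16, device="cuda") * value

    def fetch(self, tensor: torch.Tensor) -> torch.Tensor:
        return tensor.to("cpu")


workers = [Worker.remote() for _ in range(2)]

with InputNode() as inp:
    computes = [worker.compute.bind(inp) for worker in workers]
    # Broadcast the tensor produced on the root node (computes[0]) to all
    # ranks. A reduce-to-root is bound the same way, with the arithmetic
    # passed explicitly:
    #   collective.reduce.bind(computes[0], computes, op=ReduceOp.SUM)
    bcast = collective.broadcast.bind(computes[0], computes)
    dag = MultiOutputNode(
        [worker.fetch.bind(out) for worker, out in zip(workers, bcast)]
    )

compiled_dag = dag.experimental_compile()
# Every output should equal the root's tensor (all threes here).
broadcasted = ray.get(compiled_dag.execute(3))
```

The indexed outputs (`bcast[i]`) follow the same pattern the existing collective tests use: each element is the collective's output node on the corresponding actor.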
200 changes: 200 additions & 0 deletions python/ray/dag/tests/experimental/test_torch_tensor_dag.py
@@ -547,6 +547,25 @@ def reducescatter(
self._inner.reducescatter(send_buf, recv_buf, op)
recv_buf += 1

def broadcast(
self,
send_buf: "torch.Tensor",
recv_buf: "torch.Tensor",
root_rank: int,
) -> None:
self._inner.broadcast(send_buf, recv_buf, root_rank)
recv_buf += 1

def reduce(
self,
send_buf: "torch.Tensor",
recv_buf: "torch.Tensor",
root_rank: int,
op: ReduceOp = ReduceOp.SUM,
) -> None:
self._inner.reduce(send_buf, recv_buf, root_rank, op)
recv_buf += 1

@property
def recv_stream(self):
return self._inner.recv_stream
@@ -687,6 +706,23 @@ def reducescatter(
) -> None:
raise NotImplementedError

def broadcast(
self,
send_buf: "torch.Tensor",
recv_buf: "torch.Tensor",
root_rank: int,
) -> None:
raise NotImplementedError

def reduce(
self,
send_buf: "torch.Tensor",
recv_buf: "torch.Tensor",
root_rank: int,
op: ReduceOp = ReduceOp.SUM,
) -> None:
raise NotImplementedError

@property
def recv_stream(self):
return AcceleratorContext.get().current_stream()
@@ -831,6 +867,23 @@ def reducescatter(
) -> None:
raise NotImplementedError

def broadcast(
self,
send_buf: "torch.Tensor",
recv_buf: "torch.Tensor",
root_rank: int,
) -> None:
raise NotImplementedError

def reduce(
self,
send_buf: "torch.Tensor",
recv_buf: "torch.Tensor",
root_rank: int,
op: ReduceOp = ReduceOp.SUM,
) -> None:
raise NotImplementedError

@property
def recv_stream(self):
return AcceleratorContext.get().current_stream()
@@ -988,6 +1041,23 @@ def reducescatter(
) -> None:
raise NotImplementedError

def broadcast(
self,
send_buf: "torch.Tensor",
recv_buf: "torch.Tensor",
root_rank: int,
) -> None:
raise NotImplementedError

def reduce(
self,
send_buf: "torch.Tensor",
recv_buf: "torch.Tensor",
root_rank: int,
op: ReduceOp = ReduceOp.SUM,
) -> None:
raise NotImplementedError

@property
def recv_stream(self):
return AcceleratorContext.get().current_stream()
@@ -1348,6 +1418,11 @@ def test_torch_tensor_explicit_communicator(ray_start_regular):
(collective.reducescatter, ReduceOp.PRODUCT),
(collective.reducescatter, ReduceOp.MIN),
(collective.reducescatter, ReduceOp.MAX),
(collective.broadcast, None),
(collective.reduce, ReduceOp.SUM),
(collective.reduce, ReduceOp.PRODUCT),
(collective.reduce, ReduceOp.MIN),
(collective.reduce, ReduceOp.MAX),
],
)
def test_torch_tensor_nccl_collective_ops(ray_start_regular, operation, reduce_op):
@@ -1368,6 +1443,10 @@ def test_torch_tensor_nccl_collective_ops(ray_start_regular, operation, reduce_o
]
if operation == collective.allgather:
collectives = operation.bind(computes)
elif operation == collective.broadcast:
collectives = operation.bind(computes[0], computes)
elif operation == collective.reduce:
collectives = operation.bind(computes[0], computes, op=reduce_op)
else:
collectives = operation.bind(computes, op=reduce_op)
recvs = [
@@ -1440,6 +1519,20 @@ def test_torch_tensor_nccl_collective_ops(ray_start_regular, operation, reduce_o
)
else:
raise ValueError(f"Unknown reduce_op: {reduce_op}")
elif operation == collective.broadcast:
expected_tensors = [input_tensors[0] for _ in range(num_workers)]
elif operation == collective.reduce:
# Only validate result of root node for NCCL reduce
if reduce_op == ReduceOp.SUM:
expected_tensors = [torch.sum(torch.stack(input_tensors), dim=0)]
elif reduce_op == ReduceOp.PRODUCT:
expected_tensors = [torch.prod(torch.stack(input_tensors), dim=0)]
elif reduce_op == ReduceOp.MIN:
expected_tensors = [torch.min(torch.stack(input_tensors), dim=0).values]
elif reduce_op == ReduceOp.MAX:
expected_tensors = [torch.max(torch.stack(input_tensors), dim=0).values]
else:
raise ValueError(f"Unknown reduce_op: {reduce_op}")
else:
raise ValueError(f"Unknown operation: {operation}")

@@ -1489,6 +1582,52 @@ def test_torch_tensor_nccl_all_reduce_get_partial(ray_start_regular):
assert torch.equal(tensor, expected_tensor_val)


@pytest.mark.skipif(not USE_GPU, reason="Skipping GPU Test")
@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True)
def test_torch_tensor_nccl_broadcast_get_partial(ray_start_regular):
"""
Test getting partial results from a broadcast does not hang.
"""
assert (
sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) > 1
), "This test requires at least 2 GPUs"

actor_cls = TorchTensorWorker.options(num_cpus=0, num_gpus=1)

num_workers = 2
workers = [actor_cls.remote() for _ in range(num_workers)]

shape = (10,)
dtype = torch.float16

with InputNode() as inp:
computes = [
worker.compute_with_tuple_args.bind(inp, i)
for i, worker in enumerate(workers)
]

collectives = collective.broadcast.bind(computes[0], computes)
recv_root = workers[0].recv.bind(collectives[0])
recv = workers[1].recv.bind(collectives[1])
tensor = workers[1].recv_tensor.bind(collectives[1])
dag = MultiOutputNode([recv_root, recv, tensor])

compiled_dag = dag.experimental_compile()

for i in range(3):
ref = compiled_dag.execute(
[(shape, dtype, i + idx + 1) for idx in range(num_workers)]
)
result = ray.get(ref)
_, metadata, tensor = result

root_val = i + 1
assert metadata == (root_val, shape, dtype)
tensor = tensor.to("cpu")
expected_tensor_val = torch.ones(shape, dtype=dtype) * root_val
assert torch.equal(tensor, expected_tensor_val)


@pytest.mark.skipif(not USE_GPU, reason="Skipping GPU Test")
@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True)
def test_torch_tensor_nccl_all_reduce_wrong_shape(ray_start_regular):
@@ -1635,6 +1774,25 @@ def reducescatter(
self._inner.reducescatter(send_buf, recv_buf, op)
recv_buf += 1

def broadcast(
self,
send_buf: "torch.Tensor",
recv_buf: "torch.Tensor",
root_rank: int,
) -> None:
self._inner.broadcast(send_buf, recv_buf, root_rank)
recv_buf += 1

def reduce(
self,
send_buf: "torch.Tensor",
recv_buf: "torch.Tensor",
root_rank: int,
op: ReduceOp = ReduceOp.SUM,
) -> None:
self._inner.reduce(send_buf, recv_buf, root_rank, op)
recv_buf += 1

@property
def recv_stream(self):
return self._inner.recv_stream
@@ -1896,6 +2054,48 @@ def test_torch_nccl_channel_with_all_local_readers(ray_start_regular):
dag.experimental_compile()


@pytest.mark.parametrize("ray_start_regular", [{"num_gpus": 4}], indirect=True)
@pytest.mark.parametrize("collective_op", [collective.broadcast, collective.reduce])
def test_torch_tensor_nccl_wrong_root_node(ray_start_regular, collective_op):
actor_cls = TorchTensorWorker.options(num_cpus=0, num_gpus=1)

num_workers = 2
root_worker = actor_cls.remote()
workers = [actor_cls.remote() for _ in range(num_workers)]

with pytest.raises(
ValueError,
match="Expected the root node to be an input node",
):
with InputNode() as inp:
root_compute = root_worker.compute_with_tuple_args.bind(inp, 0)
computes = [
worker.compute_with_tuple_args.bind(inp, i)
for i, worker in enumerate(workers)
]
collective_op.bind(root_compute, computes)


@pytest.mark.parametrize("ray_start_regular", [{"num_gpus": 4}], indirect=True)
@pytest.mark.parametrize("collective_op", [collective.broadcast, collective.reduce])
def test_torch_tensor_nccl_no_root_node(ray_start_regular, collective_op):
actor_cls = TorchTensorWorker.options(num_cpus=0, num_gpus=1)

num_workers = 2
workers = [actor_cls.remote() for _ in range(num_workers)]

with pytest.raises(
TypeError,
match="missing 1 required positional argument",
):
with InputNode() as inp:
computes = [
worker.compute_with_tuple_args.bind(inp, i)
for i, worker in enumerate(workers)
]
collective_op.bind(computes)


if __name__ == "__main__":
if os.environ.get("PARALLEL_CI"):
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
19 changes: 19 additions & 0 deletions python/ray/experimental/channel/communicator.py
@@ -173,6 +173,25 @@ def reducescatter(
"""
raise NotImplementedError

@abstractmethod
def broadcast(
self,
send_buf: "torch.Tensor",
recv_buf: "torch.Tensor",
root: int,
) -> None:
raise NotImplementedError

@abstractmethod
def reduce(
self,
send_buf: "torch.Tensor",
recv_buf: "torch.Tensor",
root: int,
op: ReduceOp,
) -> None:
raise NotImplementedError

@abstractmethod
def destroy(self) -> None:
"""
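For orientation, the following is a minimal sketch of what a concrete implementation of these two new hooks could look like, using `torch.distributed` rather than the NCCL-backed communicator this PR targets; the class name is hypothetical, it assumes an already-initialized process group, and it takes `torch.distributed.ReduceOp` instead of the enum used in the abstract signature.

```python
# Hypothetical illustration only; not part of this PR. Assumes
# torch.distributed.init_process_group() has already been called.
import torch
import torch.distributed as dist


class TorchDistCommunicator:  # hypothetical name, standalone sketch
    def broadcast(
        self,
        send_buf: "torch.Tensor",
        recv_buf: "torch.Tensor",
        root: int,
    ) -> None:
        # torch.distributed broadcasts in place, so stage the root's send
        # buffer into the receive buffer before broadcasting.
        if dist.get_rank() == root:
            recv_buf.copy_(send_buf)
        dist.broadcast(recv_buf, src=root)

    def reduce(
        self,
        send_buf: "torch.Tensor",
        recv_buf: "torch.Tensor",
        root: int,
        op: "dist.ReduceOp" = dist.ReduceOp.SUM,
    ) -> None:
        # torch.distributed reduces in place; only the root rank's recv_buf
        # holds the combined result afterwards, mirroring NCCL reduce.
        recv_buf.copy_(send_buf)
        dist.reduce(recv_buf, dst=root, op=op)
```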
17 changes: 17 additions & 0 deletions python/ray/experimental/channel/cpu_communicator.py
@@ -158,6 +158,23 @@ def reducescatter(
):
raise NotImplementedError

def broadcast(
self,
send_buf: "torch.Tensor",
recv_buf: "torch.Tensor",
root: int,
):
raise NotImplementedError

def reduce(
self,
send_buf: "torch.Tensor",
recv_buf: "torch.Tensor",
root: int,
op: ReduceOp = ReduceOp.SUM,
):
raise NotImplementedError

def destroy(self) -> None:
for barrier in self.barriers:
ray.kill(barrier)