8000 [Refactor, Tests] Move TestCudagraphs by vmoens · Pull Request #1007 · pytorch/tensordict · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

[Refactor, Tests] Move TestCudagraphs #1007

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit on Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 194 additions & 4 deletions test/test_compile.py
10000
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,34 @@
import argparse
import contextlib
import importlib.util
import inspect
import os
from pathlib import Path
from typing import Any
from typing import Any, Callable

import pytest

import torch
from packaging import version

from tensordict import assert_close, tensorclass, TensorDict, TensorDictParams
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
from torch.utils._pytree import tree_map
from tensordict import (
assert_close,
PYTREE_REGISTERED_LAZY_TDS,
PYTREE_REGISTERED_TDS,
tensorclass,
TensorDict,
TensorDictParams,
)
from tensordict.nn import (
CudaGraphModule,
TensorDictModule,
TensorDictModule as Mod,
TensorDictSequential as Seq,
)

from tensordict.nn.functional_modules import _exclude_td_from_pytree

from torch.utils._pytree import SUPPORTED_NODES, tree_map

TORCH_VERSION = version.parse(torch.__version__).base_version

Expand Down Expand Up @@ -871,3 +887,177 @@ def to_numpy(tensor):
if __name__ == "__main__":
    # Parse known args and hand any unrecognised CLI flags straight to pytest.
    parsed_args, extra_args = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + extra_args)


@pytest.mark.skipif(
    # NOTE: compare parsed versions, not strings — the original string comparison
    # `TORCH_VERSION <= "2.4.1"` is lexicographic, so e.g. "2.10.0" <= "2.4.1"
    # would be True and the tests would be wrongly skipped on torch >= 2.10.
    version.parse(TORCH_VERSION) <= version.parse("2.4.1"),
    reason="requires torch>=2.5",
)
@pytest.mark.parametrize("compiled", [False, True])
class TestCudaGraphs:
    """Tests for ``CudaGraphModule``, run both with and without ``torch.compile``.

    Each test is parametrized over ``compiled``. When CUDA is unavailable,
    ``CudaGraphModule`` is expected to emit a ``UserWarning`` (see
    ``_make_cudagraph``) and the tests exercise the eager fallback path.
    """

    @pytest.fixture(scope="class", autouse=True)
    def _set_cuda_device(self):
        """Class-scoped setup: exclude TDs from pytree and default to CUDA.

        Remembers the current default device, temporarily removes tensordict
        classes from torch's pytree registry (if registered), and restores both
        on teardown.
        """
        device = torch.get_default_device()
        do_unset = False
        # If any tensordict class is registered as a pytree node, exclude them
        # for the duration of this test class — presumably CudaGraphModule needs
        # to see tensordicts as leaves, not pytree containers (TODO confirm).
        for tdtype in PYTREE_REGISTERED_TDS + PYTREE_REGISTERED_LAZY_TDS:
            if tdtype in SUPPORTED_NODES:
                do_unset = True
                excluder = _exclude_td_from_pytree()
                excluder.set()
                break
        if torch.cuda.is_available():
            torch.set_default_device("cuda:0")
        yield
        # Teardown: undo the pytree exclusion (if it was applied) and restore
        # the original default device.
        if do_unset:
            excluder.unset()
        torch.set_default_device(device)

    def test_cudagraphs_random(self, compiled):
        """Random ops inside the captured function must be refreshed on replay."""

        def func(x):
            return x + torch.randn_like(x)

        if compiled:
            func = torch.compile(func)

        # Without CUDA, CudaGraphModule warns (eager fallback); with CUDA it
        # should construct silently.
        with (
            pytest.warns(UserWarning)
            if not torch.cuda.is_available()
            else contextlib.nullcontext()
        ):
            func = CudaGraphModule(func)

        x = torch.randn(10)
        # Warm up past the capture threshold.
        for _ in range(10):
            func(x)
        assert isinstance(func(torch.zeros(10)), torch.Tensor)
        # The output must not be identically zero: randomness is applied.
        assert (func(torch.zeros(10)) != 0).any()
        y0 = func(x)
        y1 = func(x + 1)
        # Two calls draw different random values, so y0 and y1 + 1 must differ.
        with pytest.raises(AssertionError):
            torch.testing.assert_close(y0, y1 + 1)

    @staticmethod
    def _make_cudagraph(
        func: Callable, compiled: bool, *args, **kwargs
    ) -> CudaGraphModule:
        """Wrap *func* in a CudaGraphModule, optionally compiling it first.

        Expects a ``UserWarning`` at construction when CUDA is unavailable.
        Extra ``*args``/``**kwargs`` are forwarded to ``CudaGraphModule``.
        """
        if compiled:
            func = torch.compile(func)
        with (
            pytest.warns(UserWarning)
            if not torch.cuda.is_available()
            else contextlib.nullcontext()
        ):
            func = CudaGraphModule(func, *args, **kwargs)
        return func

    @staticmethod
    def check_types(func, *args, **kwargs):
        """Raise TypeError if any bound argument violates *func*'s annotations.

        Only parameters with an annotation are checked; the annotation is used
        directly as an ``isinstance`` target.
        """
        signature = inspect.signature(func)
        bound_args = signature.bind(*args, **kwargs)
        bound_args.apply_defaults()
        for param_name, param in signature.parameters.items():
            arg_value = bound_args.arguments[param_name]
            if param.annotation != param.empty:
                if not isinstance(arg_value, param.annotation):
                    raise TypeError(
                        f"Argument '{param_name}' should be of type {param.annotation}, but is of type {type(arg_value)}"
                    )

    def test_signature(self, compiled):
        """check_types rejects mis-typed args and accepts well-typed ones."""
        if compiled:
            # Type checking is independent of compilation; run it once only.
            pytest.skip()

        def func(x: torch.Tensor):
            return x + torch.randn_like(x)

        with pytest.raises(TypeError):
            self.check_types(func, "a string")
        # A tensor argument must pass without raising.
        self.check_types(func, torch.ones(()))

    def test_backprop(self, compiled):
        """An optimizer step (zero_grad/backward/step) works under cuda graphs."""
        x = torch.nn.Parameter(torch.ones(3))
        y = torch.nn.Parameter(torch.ones(3))
        optimizer = torch.optim.SGD([x, y], lr=1)

        def func():
            optimizer.zero_grad()
            z = x + y
            z = z.sum()
            z.backward()
            optimizer.step()

        func = self._make_cudagraph(func, compiled, warmup=4)

        for i in range(1, 11):
            torch.compiler.cudagraph_mark_step_begin()
            func()

        # With lr=1 and grad 1 on each parameter, each step subtracts 1;
        # after the loop i == 10, so both parameters should equal 1 - 10.
        assert (x == 1 - i).all(), i
        assert (y == 1 - i).all(), i
        # NOTE(review): grads are zeroed inside func, so these cannot be
        # asserted after the loop — kept for reference.
        # assert (x.grad == 1).all()
        # assert (y.grad == 1).all()

    def test_tdmodule(self, compiled):
        """CudaGraphModule supports the various TensorDictModule call styles."""
        # 1) td-in / td-out (in-place update of the input tensordict).
        tdmodule = TensorDictModule(lambda x: x + 1, in_keys=["x"], out_keys=["y"])
        tdmodule = self._make_cudagraph(tdmodule, compiled)
        assert tdmodule._is_tensordict_module
        for i in range(10):
            td = TensorDict(x=torch.randn(()))
            tdmodule(td)
            assert td["y"] == td["x"] + 1, i

        # 2) keyword-argument dispatch: tdmodule(x=...) returns the output value.
        tdmodule = TensorDictModule(lambda x: x + 1, in_keys=["x"], out_keys=["y"])
        tdmodule = self._make_cudagraph(tdmodule, compiled)
        assert tdmodule._is_tensordict_module
        for _ in range(10):
            x = torch.randn(())
            y = tdmodule(x=x)
            assert y == x + 1

        # 3) explicit tensordict_out: outputs land in the provided td only.
        tdmodule = TensorDictModule(lambda x: x + 1, in_keys=["x"], out_keys=["y"])
        tdmodule = self._make_cudagraph(tdmodule, compiled)
        assert tdmodule._is_tensordict_module
        for _ in range(10):
            td = TensorDict(x=torch.randn(()))
            tdout = TensorDict()
            tdmodule(td, tensordict_out=tdout)
            assert tdout is not td
            assert "x" not in tdout
            assert tdout["y"] == td["x"] + 1

        # 4) plain callable flagged as a td-module via in_keys/out_keys=[];
        # output matches input, so new keys are selected after warmup.
        tdmodule = lambda td: td.set("y", td.get("x") + 1)
        tdmodule = self._make_cudagraph(tdmodule, compiled, in_keys=[], out_keys=[])
        assert tdmodule._is_tensordict_module
        for i in range(10):
            td = TensorDict(x=torch.randn(()))
            tdmodule(td)
            assert tdmodule._out_matches_in
            if i >= tdmodule._warmup and torch.cuda.is_available():
                assert tdmodule._selected_keys == ["y"]
            assert td["y"] == td["x"] + 1

        # 5) plain callable with explicit in/out keys.
        tdmodule = lambda td: td.set("y", td.get("x") + 1)
        tdmodule = self._make_cudagraph(
            tdmodule, compiled, in_keys=["x"], out_keys=["y"]
        )
        assert tdmodule._is_tensordict_module
        for _ in range(10):
            td = TensorDict(x=torch.randn(()))
            tdmodule(td)
            assert td["y"] == td["x"] + 1

        # 6) callable returning a *copy*: the input td must stay untouched.
        tdmodule = lambda td: td.copy().set("y", td.get("x") + 1)
        tdmodule = self._make_cudagraph(tdmodule, compiled, in_keys=[], out_keys=[])
        assert tdmodule._is_tensordict_module
        for _ in range(10):
            td = TensorDict(x=torch.randn(()))
            tdout = tdmodule(td)
            assert tdout is not td
            assert "y" not in td
            assert tdout["y"] == td["x"] + 1

    def test_td_input_non_tdmodule(self, compiled):
        """A non-td function fed a TensorDict is not treated as a td-module."""
        func = lambda x: x + 1
        func = self._make_cudagraph(func, compiled)
        for i in range(10):
            td = TensorDict(a=1)
            func(td)
            if i == 5:
                assert not func._is_tensordict_module
Loading
Loading
0