[Feature] torch.export and onnx compatibility by vmoens · Pull Request #991 · pytorch/tensordict · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

[Feature] torch.export and onnx compatibility #991

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -1987,7 +1987,7 @@ def from_dict(
)

batch_size_set = torch.Size(()) if batch_size is None else batch_size
input_dict = copy(input_dict)
input_dict = dict(input_dict)
for key, value in list(input_dict.items()):
if isinstance(value, (dict,)):
# we don't know if another tensor of smaller size is coming
Expand Down
39 changes: 26 additions & 13 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,13 +390,7 @@ def __getitem__(self, index: IndexType) -> Any:
# _unravel_key_to_tuple will return an empty tuple if the index isn't a NestedKey
idx_unravel = _unravel_key_to_tuple(index)
if idx_unravel:
result = self._get_tuple(idx_unravel, NO_DEFAULT)
if is_non_tensor(result):
result_data = getattr(result, "data", NO_DEFAULT)
if result_data is NO_DEFAULT:
return result.tolist()
return result_data
return result
return self._get_tuple_maybe_non_tensor(idx_unravel, NO_DEFAULT)

if (istuple and not index) or (not istuple and index is Ellipsis):
# empty tuple returns self
Expand Down Expand Up @@ -4669,6 +4663,15 @@ def _get_str(self, key, default): ...
# Subclass hook: fetch the entry stored under a nested key given in tuple
# form (callers in this file pass the output of `_unravel_key_to_tuple`).
# Callers pass either NO_DEFAULT or a fallback value as `default`;
# presumably NO_DEFAULT means "raise if missing" -- confirm against the
# concrete implementations.
@abc.abstractmethod
def _get_tuple(self, key, default): ...

def _get_tuple_maybe_non_tensor(self, key, default):
    """Look up ``key`` via ``_get_tuple`` and unwrap non-tensor entries.

    If the fetched value is recognised by ``is_non_tensor``, its ``data``
    attribute is returned; when that attribute is absent, ``tolist()`` of the
    value is returned instead. Any other value is returned unchanged.
    """
    value = self._get_tuple(key, default)
    if not is_non_tensor(value):
        return value
    payload = getattr(value, "data", NO_DEFAULT)
    return value.tolist() if payload is NO_DEFAULT else payload

def get_at(
self, key: NestedKey, index: IndexType, default: CompatibleType = NO_DEFAULT
) -> CompatibleType:
Expand Down Expand Up @@ -8549,25 +8552,34 @@ def _remove_batch_dim(self, vmap_level, batch_size, out_dim): ...
def _maybe_remove_batch_dim(self, funcname, vmap_level, batch_size, out_dim): ...

# Validation and checks
def _convert_to_tensor(self, array: np.ndarray) -> Tensor:
def _convert_to_tensor(
self, array: Any
) -> Tensor | "NonTensorData" | TensorDictBase: # noqa: F821
# We are sure that array is not a dict or anything in _ACCEPTED_CLASSES
castable = None
if isinstance(array, (float, int, bool)):
pass
castable = True
elif isinstance(array, np.ndarray) and array.dtype.names is not None:
return TensorDictBase.from_struct_array(array, device=self.device)
elif isinstance(array, np.ndarray):
castable = array.dtype.kind in ("c", "i", "f", "b", "u")
elif isinstance(array, np.bool_):
castable = True
array = array.item()
elif isinstance(array, list):
elif isinstance(array, (list, tuple)):
array = np.asarray(array)
castable = array.dtype.kind in ("c", "i", "f", "b", "u")
elif hasattr(array, "numpy"):
# tf.Tensor with no shape can't be converted otherwise
array = array.numpy()
try:
castable = array.dtype.kind in ("c", "i", "f", "b", "u")
if castable:
return torch.as_tensor(array, device=self.device)
except Exception:
else:
from tensordict.tensorclass import NonTensorData

return NonTensorData(
array,
data=array,
batch_size=self.batch_size,
device=self.device,
names=self._maybe_names(),
Expand Down Expand Up @@ -8624,6 +8636,7 @@ def _validate_value(
)
is_tc = True
elif not issubclass(cls, _ACCEPTED_CLASSES):
# If cls is not a tensor
try:
value = self._convert_to_tensor(value)
except ValueError as err:
Expand Down
52 changes: 51 additions & 1 deletion tensordict/nn/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,8 +324,53 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
return func(_self, tensordict, *args, **kwargs)
return func(tensordict, *args, **kwargs)

return self._update_func_signature(func, wrapper)

def _update_func_signature(self, func, wrapper):
    """Attach to ``wrapper`` a signature derived from ``func``.

    Every positional parameter of ``func`` (positional-only or
    positional-or-keyword) that has no default is rewritten as a
    positional-or-keyword parameter whose default is ``None``. As a result,
    the published signature lets each of them be omitted or passed by
    keyword. All other parameters are copied unchanged:
    - positional parameters that already have a default;
    - variadic ``*args``;
    - keyword-only parameters;
    - ``**kwargs``.

    Args:
        func: callable whose signature serves as the template.
        wrapper: callable whose ``__signature__`` attribute is overwritten.

    Returns:
        ``wrapper`` itself, with ``__signature__`` set.
    """
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    params = []
    for param in inspect.signature(func).parameters.values():
        if param.kind in positional_kinds and param.default is inspect.Parameter.empty:
            # Only name, kind and default are carried over; any annotation
            # on the original parameter is dropped.
            param = inspect.Parameter(
                name=param.name,
                kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
                default=None,
            )
        # ``*args`` keeps its VAR_POSITIONAL kind here. Turning it into a
        # plain ``args=None`` parameter would misdescribe the callable.
        params.append(param)
    wrapper.__signature__ = inspect.Signature(params)
    return wrapper

def get_source(self, func, self_func):
    """Resolve the configured ``source``.

    When ``self.source`` is a string, it names an attribute of ``self_func``,
    and that attribute's value is returned. Any other value of
    ``self.source`` is returned as is. ``func`` is accepted but not used.
    """
    src = self.source
    return getattr(self_func, src) if isinstance(src, str) else src


class _OutKeysSelect:
def __init__(self, out_keys):
Expand Down Expand Up @@ -1226,7 +1271,12 @@ def forward(
tensors = ()
else:
# TODO: v0.7: remove the None
tensors = tuple(tensordict.get(in_key, None) for in_key in self.in_keys)
tensors = tuple(
tensordict._get_tuple_maybe_non_tensor(
_unravel_key_to_tuple(in_key), None
)
for in_key in self.in_keys
)
try:
tensors = self._call_module(tensors, **kwargs)
except Exception as err:
Expand Down
1 change: 1 addition & 0 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def __subclasscheck__(self, subclass):
"_get_names_idx", # no wrap output
"_get_str",
"_get_tuple",
"_get_tuple_maybe_non_tensor",
"_has_names",
"_items_list",
"_maybe_names",
Expand Down
129 changes: 129 additions & 0 deletions test/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
# LICENSE file in the root directory of this source tree.
import argparse
import contextlib
import importlib.util
import os
from pathlib import Path
from typing import Any

import pytest
Expand All @@ -14,9 +16,14 @@

from tensordict import assert_close, tensorclass, TensorDict, TensorDictParams
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
from torch.utils._pytree import tree_map

# Release string of the running torch (via packaging's `base_version`).
TORCH_VERSION = version.parse(torch.__version__).base_version

# Whether onnxruntime is importable; gates the TestONNXExport tests.
# NOTE(review): torch.onnx.dynamo_export presumably also needs `onnxscript`;
# consider checking for it here too -- confirm.
_has_onnx = importlib.util.find_spec("onnxruntime", None) is not None

# True on torch >= 2.5.0 (only the first three release components are
# compared). Gates TestExport and toggles fullgraph= in the dispatch tests.
_v2_5 = version.parse(".".join(TORCH_VERSION.split(".")[:3])) >= version.parse("2.5.0")


def test_vmap_compile():
# Since we monkey patch vmap we need to make sure compile is happy with it
Expand Down Expand Up @@ -605,6 +612,33 @@ def remove_hidden(td):
assert_close(module(td), module_compile(td))
assert module_compile(td) is not td

def test_dispatch_nontensor(self, mode):
    """Eager and compiled dispatch agree when one keyword input is ``None``."""
    torch._dynamo.reset_code_caches()

    # "y" is a non-tensor input (None): x[None, :] adds a leading dim of 1.
    inp_x = torch.randn(3)
    inp_y = None
    first = Mod(lambda x, y: x[y, :], in_keys=["x", "y"], out_keys=["_z"])
    second = Mod(lambda x, z: z * x, in_keys=["x", "_z"], out_keys=["out"])
    mod = Seq(first, second)
    eager_out = mod(x=inp_x, y=inp_y)
    assert eager_out[-1].shape == torch.Size((1, 3))
    mod_compile = torch.compile(mod, fullgraph=_v2_5, mode=mode)
    expected = mod(x=inp_x, y=inp_y)
    torch.testing.assert_close(expected, mod_compile(x=inp_x, y=inp_y))

def test_dispatch_tensor(self, mode):
    """Eager and compiled dispatch agree when all keyword inputs are tensors."""
    torch._dynamo.reset_code_caches()

    lhs = torch.randn(3)
    rhs = torch.randn(3)
    add_step = Mod(lambda x, y: x + y, in_keys=["x", "y"], out_keys=["z"])
    mul_step = Mod(lambda x, z: z * x, in_keys=["x", "z"], out_keys=["out"])
    mod = Seq(add_step, mul_step)
    # Eager call whose result is discarded: exercises the uncompiled path first.
    mod(x=lhs, y=rhs)
    mod_compile = torch.compile(mod, fullgraph=_v2_5, mode=mode)
    expected = mod(x=lhs, y=rhs)
    torch.testing.assert_close(expected, mod_compile(x=lhs, y=rhs))


@pytest.mark.skipif(not (TORCH_VERSION > "2.4.0"), reason="requires torch>2.4")
@pytest.mark.parametrize("mode", [None, "reduce-overhead"])
Expand Down Expand Up @@ -737,6 +771,101 @@ def call(x, td):
assert (td_zero == 0).all()


@pytest.mark.skipif(not _v2_5, reason="Requires PT>=2.5")
class TestExport:
    """``torch.export`` of tensordict modules called with keyword inputs."""

    def test_export_module(self):
        torch._dynamo.reset_code_caches()
        tdm = Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"])
        x = torch.randn(3)
        y = torch.randn(3)
        exported = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y})
        # assert_close (as in test_export_seq) also checks shape/dtype and
        # reports which elements differ, unlike an exact `(a == b).all()`.
        torch.testing.assert_close(exported.module()(x=x, y=y), tdm(x=x, y=y))

    def test_export_seq(self):
        torch._dynamo.reset_code_caches()
        tdm = Seq(
            Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"]),
            Mod(lambda z, x: z + x, in_keys=["z", "x"], out_keys=["out"]),
        )
        x = torch.randn(3)
        y = torch.randn(3)
        exported = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y})
        torch.testing.assert_close(exported.module()(x=x, y=y), tdm(x=x, y=y))


@pytest.mark.skipif(not _has_onnx, reason="ONNX is not available")
class TestONNXExport:
    """Round-trip tensordict modules through ``torch.onnx.dynamo_export`` and
    onnxruntime, and check the runtime outputs match eager execution."""

    @staticmethod
    def _run_with_onnxruntime(module, tmpdir, torch_input):
        """Export ``module`` and run the saved graph with onnxruntime on CPU.

        The module is exported with ``torch_input`` as keyword inputs, and the
        ONNX graph is saved under ``tmpdir`` before being run.

        Returns:
            The list returned by ``InferenceSession.run``.
        """
        onnx_program = torch.onnx.dynamo_export(module, **torch_input)
        onnx_input = onnx_program.adapt_torch_inputs_to_onnx(**torch_input)

        # Use a plain string path for both save() and InferenceSession:
        # older onnxruntime releases only accept str/bytes model sources.
        path = str(Path(tmpdir) / "file.onnx")
        onnx_program.save(path)
        import onnxruntime

        ort_session = onnxruntime.InferenceSession(
            path, providers=["CPUExecutionProvider"]
        )
        # detach() is harmless for tensors that do not require grad.
        onnxruntime_input = {
            k.name: v.detach().cpu().numpy()
            for k, v in zip(ort_session.get_inputs(), onnx_input)
        }
        return ort_session.run(None, onnxruntime_input)

    def test_onnx_export_module(self, tmpdir):
        tdm = Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"])
        x = torch.randn(3)
        y = torch.randn(3)
        onnxruntime_outputs = self._run_with_onnxruntime(
            tdm, tmpdir, {"x": x, "y": y}
        )
        # The module has a single out_key, so only the first output is compared.
        torch.testing.assert_close(
            torch.as_tensor(onnxruntime_outputs[0]), tdm(x=x, y=y)
        )

    def test_onnx_export_seq(self, tmpdir):
        tdm = Seq(
            Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"]),
            Mod(lambda z, x: z + x, in_keys=["z", "x"], out_keys=["out"]),
        )
        x = torch.randn(3)
        y = torch.randn(3)
        onnxruntime_outputs = self._run_with_onnxruntime(
            tdm, tmpdir, {"x": x, "y": y}
        )
        torch.testing.assert_close(
            tree_map(torch.as_tensor, onnxruntime_outputs), tdm(x=x, y=y)
        )


if __name__ == "__main__":
    # The parser declares no options, so every CLI flag ends up in `unknown`
    # and is forwarded verbatim to pytest.
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest_argv = [__file__, "--capture", "no", "--exitfirst", *unknown]
    pytest.main(pytest_argv)
Loading
0