[Feature] Subclass conservation in td ops by vmoens · Pull Request #1186 · pytorch/tensordict · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

[Feature] Subclass conservation in td ops #1186

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def _new_unsafe(
if source: # faster than calling items
for key, value in source.items():
if nested and isinstance(value, dict):
value = TensorDict._new_unsafe(
value = cls._new_unsafe(
source=value,
batch_size=self._batch_size,
device=self._device,
Expand Down Expand Up @@ -374,7 +374,7 @@ def from_module(
filter_empty=filter_empty,
)
if result is None:
result = TensorDict._new_unsafe({}, batch_size=torch.Size(()))
result = cls._new_unsafe({}, batch_size=torch.Size(()))
if lock:
result.lock_()
return result
Expand Down Expand Up @@ -419,7 +419,7 @@ def _from_module(
destination = hook_result
if not filter_empty or destination:
destination_set = True
destination = TensorDict._new_unsafe(destination, batch_size=torch.Size(()))
destination = cls._new_unsafe(destination, batch_size=torch.Size(()))
else:
destination_set = False
for name, submodule in module._modules.items():
Expand All @@ -433,7 +433,7 @@ def _from_module(
)
if subtd is not None:
if not destination_set:
destination = TensorDict._new_unsafe(batch_size=torch.Size(()))
destination = cls._new_unsafe(batch_size=torch.Size(()))
destination_set = True
destination._set_str(
name, subtd, validated=True, inplace=False, non_blocking=False
Expand Down Expand Up @@ -610,7 +610,7 @@ def _quick_set(swap_dict, swap_td):
_quick_set(_swap, swap_dest)
return swap_dest
else:
return TensorDict._new_unsafe(_swap, batch_size=torch.Size(()))
return self._new_unsafe(_swap, batch_size=torch.Size(()))

@_maybe_broadcast_other("__ne__")
def __ne__(self, other: Any) -> T | bool:
Expand Down Expand Up @@ -1479,7 +1479,7 @@ def _add_batch_dim_wrapper(key, value):
return value
return _add_batch_dim(value, in_dim, vmap_level)

out = TensorDict._new_unsafe(
out = self._new_unsafe(
{key: _add_batch_dim_wrapper(key, value) for key, value in td.items()},
batch_size=torch.Size(
[b for i, b in enumerate(td.batch_size) if i != in_dim]
Expand Down Expand Up @@ -1613,7 +1613,7 @@ def _check_for_invalid_index(index):
)
else:
source[key] = _get_item(item, index)
result = TensorDict._new_unsafe(
result = self._new_unsafe(
source=source,
batch_size=batch_size,
device=self.device,
Expand Down Expand Up @@ -1694,7 +1694,7 @@ def empty(
is_shared=is_shared,
is_memmap=is_memmap,
):
result = TensorDict._new_unsafe(
result = self._new_unsafe(
{}, batch_size=batch_size, names=names, device=device
)
result._is_shared = is_shared
Expand Down Expand Up @@ -3231,7 +3231,7 @@ def _clone(self, recurse: bool = True) -> T:
if recurse and self.device is not None:
return self._clone_recurse()

result = TensorDict._new_unsafe(
result = self._new_unsafe(
source={key: _clone_value(value, recurse) for key, value in self.items()},
batch_size=self.batch_size,
device=self.device,
Expand All @@ -3248,7 +3248,7 @@ def contiguous(self) -> T:
source = {key: value.contiguous() for key, value in self.items()}
batch_size = self.batch_size
device = self.device
out = TensorDict._new_unsafe(
out = self._new_unsafe(
source=source,
batch_size=batch_size,
device=device,
Expand All @@ -3260,7 +3260,7 @@ def empty(
self, recurse=False, *, batch_size=None, device=NO_DEFAULT, names=NO_DEFAULT
) -> T:
if not recurse:
return TensorDict._new_unsafe(
return self._new_unsafe(
device=self._device if device is NO_DEFAULT else device,
batch_size=(
self._batch_size if batch_size is None else torch.Size(batch_size)
Expand Down Expand Up @@ -3309,7 +3309,7 @@ def _select(
*val, strict=strict, inplace=inplace, set_shared=set_shared
)

result = TensorDict._new_unsafe(
result = self._new_unsafe(
device=self.device,
batch_size=self.batch_size,
source=source,
Expand Down Expand Up @@ -3358,7 +3358,7 @@ def _exclude(
_tensordict[key] = val
if inplace:
return self
result = TensorDict._new_unsafe(
result = self._new_unsafe(
_tensordict,
batch_size=self.batch_size,
device=self.device,
Expand Down Expand Up @@ -4059,7 +4059,7 @@ def is_contiguous(self) -> bool:
return all(value.is_contiguous() for value in self.values())

def contiguous(self) -> T:
return TensorDict._new_unsafe(
return self._new_unsafe(
batch_size=self.batch_size,
source={key: value.contiguous() for key, value in self.items()},
device=self.device,
Expand Down
32 changes: 21 additions & 11 deletions tensordict/_torch_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
from tensordict.utils import (
_check_keys,
_ErrorInteceptor,
_is_tensorclass,
_pass_through,
_shape,
_zip_strict,
DeviceType,
is_tensorclass,
lazy_legacy,
set_lazy_legacy,
)
Expand Down Expand Up @@ -138,12 +138,12 @@ def _gather_tensor(tensor, dest_container=None, dest_key=None):
return out

if out is None:
if len(index.shape) == input.ndim and input._has_names():
names = input.names
if len(index.shape) == input.ndim:
names = input._maybe_names()
else:
names = None
device = input.device
return TensorDict(
return type(input)._new_unsafe(
{
key: _gather_tensor(value)
for key, value in input.items(is_leaf=_is_leaf_nontensor)
Expand Down Expand Up @@ -300,6 +300,7 @@ def _cat(
raise RuntimeError("list_of_tensordicts cannot be empty")

batch_size = list(list_of_tensordicts[0].batch_size)
tdtype = type(list_of_tensordicts[0])
if dim < 0:
dim = len(batch_size) + dim
if dim >= len(batch_size):
Expand Down Expand Up @@ -334,9 +335,13 @@ def _cat(
names = None
if list_of_tensordicts[0]._has_names():
names = list_of_tensordicts[0].names
return TensorDict._new_unsafe(
out, device=device, batch_size=batch_size, names=names
)
# if we have a TD subclass, use _new_unsafe bc we know it exists. Otherwise, use
# TensorDict's one
if issubclass(tdtype, TensorDict):
clz = tdtype
else:
clz = TensorDict
return clz._new_unsafe(out, device=device, batch_size=batch_size, names=names)
else:
if out.batch_size != batch_size:
raise RuntimeError(
Expand Down Expand Up @@ -453,14 +458,19 @@ def _stack(
raise RuntimeError("list_of_tensordicts cannot be empty")
if maybe_dense_stack is None:
maybe_dense_stack = lazy_legacy()
is_tc = any(is_tensorclass(td) for td in list_of_tensordicts)
td_types = [type(td) for td in list_of_tensordicts]
is_tc = any(_is_tensorclass(td_type) for td_type in td_types)
if all(_pass_through(td) for td in list_of_tensordicts):
return type(list_of_tensordicts[0])._stack_non_tensor(
list_of_tensordicts, dim=dim
)
if is_tc:
tc_type = type(list_of_tensordicts[0])
list_of_tensordicts = [tc._tensordict for tc in list_of_tensordicts]
clz = type(list_of_tensordicts[0])
elif issubclass(td_types[0], TensorDict):
clz = td_types[0]
else:
clz = TensorDict

batch_size = list_of_tensordicts[0].batch_size
if dim < 0:
Expand Down Expand Up @@ -617,15 +627,15 @@ def stack_fn(key, values, is_not_init, is_tensor):
for key, (values, is_not_init, is_tensor) in out.items()
}

result = TensorDict._new_unsafe(
result = clz._new_unsafe(
out,
batch_size=LazyStackedTensorDict._compute_batch_size(
batch_size, dim, len(list_of_tensordicts)
),
device=device,
)
if is_tc:
return tc_type._from_tensordict(result)
return td_types[0]._from_tensordict(result)
return result
else:
out = LazyStackedTensorDict(
Expand Down
10 changes: 10 additions & 0 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,16 @@ class TensorDictBase(MutableMapping):
_memmap_prefix = None
_stream: torch.cuda.Stream | None = None

@classmethod
def _new_unsafe(cls, *args, **kwargs):
# This to make sure all TensorDictBase subclasses have a proper fallback if they don't have a _new_unsafe
# In other words, only TensorDict subclasses will have their type preserved, others will become TensorDict
# instances (note that TensorDictBase should not be directly subclassed outside of this codebase, as it is
# highly abstract).
from tensordict._td import TensorDict

return TensorDict._new_unsafe(*args, **kwargs)

def __bool__(self) -> bool:
raise RuntimeError("Converting a tensordict to boolean value is not permitted")

Expand Down
22 changes: 13 additions & 9 deletions tensordict/nn/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def _new_unsafe(
cls,
parameters: TensorDictBase,
*,
no_convert=False,
no_convert=None,
lock: bool = False,
params: dict | None = None,
buffers: dict | None = None,
Expand All @@ -399,24 +399,28 @@ def _new_unsafe(
if is_compiling():
return TensorDictParams(parameters, no_convert="skip", lock=lock)

self = TensorDictParams.__new__(cls)
nn.Module.__init__(self)

if parameters is None:
parameters = kwargs
elif kwargs:
raise TypeError(
f"parameters cannot be passed along with extra keyword arguments, but got {kwargs.keys()} extra args."
)

if isinstance(parameters, dict):
parameters = TensorDict._new_unsafe(parameters)
parameters = TensorDict._new_unsafe(parameters, **kwargs)
if no_convert is None:
# Then _new_unsafe is called from somewhere that doesn't know
# that it's a TDParams and we return a TensorDict (eg, torch.gather)
return parameters
elif isinstance(parameters, TensorDictParams):
if kwargs:
raise TypeError(
f"parameters cannot be passed along with extra keyword arguments, but got {kwargs.keys()} extra args."
)
params = dict(parameters._parameters)
buffers = dict(parameters._buffers)
parameters = parameters._param_td
no_convert = "skip"

self = TensorDictParams.__new__(cls)
nn.Module.__init__(self)

self._param_td = parameters
self.no_convert = no_convert
if no_convert != "skip":
Expand Down
32 changes: 32 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2709,6 +2709,38 @@ def test_reduction_feature_full(self, reduction):
assert getattr(td, reduction)(reduce=True, dim="feature").shape == (3, 4)
assert getattr(td, reduction)(reduce=True, dim=1).shape == (3, 5)

def test_subclassing(self):
class SubTD(TensorDict): ...

t = SubTD(a=torch.randn(3))
assert isinstance(t + t, SubTD)
assert isinstance(t / 2, SubTD)
assert isinstance(2 / t, SubTD)
assert isinstance(t.to(torch.float), SubTD)
assert isinstance(t.to("cpu"), SubTD)
assert isinstance(torch.zeros_like(t), SubTD)
assert isinstance(t.copy(), SubTD)
assert isinstance(t.clone(), SubTD)
assert isinstance(t.empty(), SubTD)
assert isinstance(t.select(), SubTD)
assert isinstance(t.exclude("a"), SubTD)
assert isinstance(t.split_keys({"a"})[0], SubTD)
assert isinstance(t.flatten_keys(), SubTD)
assert isinstance(t.unflatten_keys(), SubTD)
stack = torch.stack([t, t])
assert isinstance(stack, SubTD)
assert isinstance(stack[0], SubTD)
assert isinstance(stack.unbind(0)[0], SubTD)
assert isinstance(stack.split(1)[0], SubTD)
assert isinstance(stack.gather(0, torch.ones((1,), dtype=torch.long)), SubTD)
unsqueeze = stack.unsqueeze(0)
assert isinstance(unsqueeze, SubTD)
assert isinstance(unsqueeze.transpose(1, 0), SubTD)
assert isinstance(unsqueeze.permute(1, 0), SubTD)
assert isinstance(unsqueeze.squeeze(), SubTD)
assert isinstance(unsqueeze.reshape(-1), SubTD)
assert isinstance(unsqueeze.view(-1), SubTD)

@pytest.mark.parametrize("device", get_available_devices())
def test_subtensordict_construction(self, device):
torch.manual_seed(1)
Expand Down
Loading
0