[BUG] Running copy-related operations on instances of user-defined TensorDict subclasses returns a TensorDict object #1184

Closed
3 tasks done
alex-bene opened this issue Jan 15, 2025 · 9 comments · Fixed by #1186 or #1197
Assignees
Labels
bug Something isn't working

Comments

@alex-bene

Describe the bug

I have created a TensorDict subclass named AutoTensorDict that fits my use case. However, running .to(), .clone(), and other copy-related operations on instances of this class returns a TensorDict object rather than an AutoTensorDict object.

To Reproduce

Here's an example to reproduce this:

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import torch
from tensordict.tensordict import TensorDict

if TYPE_CHECKING:
    from collections.abc import Sequence

    from tensordict._nestedkey import NestedKey
    from tensordict.base import CompatibleType, T
    from tensordict.utils import DeviceType, IndexType
    from torch import Size


class AutoTensorDict(TensorDict):
    def __init__(
        self,
        source: T | dict[NestedKey, CompatibleType] = None,
        batch_size: Sequence[int] | Size | int | None = None,
        device: DeviceType | None = None,
        names: Sequence[str] | None = None,
        non_blocking: bool | None = None,
        lock: bool = False,
        **kwargs: dict[str, Any] | None,
    ) -> None:
        super().__init__(source, batch_size, device, names, non_blocking, lock, **kwargs)
        self.auto_batch_size_(1)
        if self.device is None:
            self.auto_device_()

    def __setitem__(self, key: IndexType, value: Any) -> None:
        super().__setitem__(key, value)
        if self.device is None:
            self.auto_device_()
        if not self.batch_size:
            self.auto_batch_size_(1)


if __name__ == "__main__":
    tt = AutoTensorDict()
    tt["a"] = torch.rand(3, 4)
    print(tt.to("mps"))

This snippet prints:

TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 4]), device=mps:0, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=mps,
    is_shared=False)

Expected behavior

I would expect an instance of the newly defined class to be returned.

Reason and Possible fixes

At first, I thought this was related to the to_tensordict function being called internally by TensorDict and TensorDictBase. However, I tried to monkey-patch it and nothing came of it.
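For illustration, a monkey-patch along these lines would be one way to attempt it (a rough sketch only; the helper name is made up and this is not the eventual fix):

from tensordict import TensorDict

# Keep a reference to the original method and re-wrap its output in the
# caller's class. `_patched_to_tensordict` is a hypothetical helper.
_original_to_tensordict = TensorDict.to_tensordict

def _patched_to_tensordict(self, *args, **kwargs):
    out = _original_to_tensordict(self, *args, **kwargs)
    return out if type(self) is TensorDict else type(self)(out)

TensorDict.to_tensordict = _patched_to_tensordict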

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
@alex-bene alex-bene added the bug Something isn't working label Jan 15, 2025
@vmoens
Collaborator
vmoens commented Jan 16, 2025

Hello!
Good question, but before we address it, let's see how a tensor behaves:

import torch

class IntTensor(torch.Tensor):
    def __new__(cls, *args, **kwargs):
        print('__new__')
        self = torch.empty((), dtype=torch.int)
        return super().__new__(cls, self)

    def __init__(self, tensor: torch.Tensor):
        print('__init__')
        super().__init__()
        t = tensor.to(torch.int)
        self.set_(t.untyped_storage(), storage_offset=t.storage_offset(), stride=t.stride(), size=t.shape)

inttensor = IntTensor(torch.zeros((), dtype=torch.float))
assert inttensor.dtype == torch.int
assert isinstance(inttensor/2, IntTensor)
assert (inttensor/2).dtype == torch.int

As you can see, ops on tensor subclasses give you tensor subclasses (which I understand is what you want); however, whatever I do in __new__ and __init__ will be ignored.
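A quick check makes that visible (reusing the IntTensor defined above; nothing should be printed by __new__ or __init__ here, yet the result keeps the subclass):

half = inttensor / 2       # no '__new__' / '__init__' output
print(type(half))          # <class '__main__.IntTensor'>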

TensorDict is purely built in Python, so it will probably not be exactly the same, but in general I think that any solution we come up with will not call the __init__ of your subclass during calls to to and similar ops (for efficiency reasons, and because your __init__ will not, in general, have the same signature as the parent class's).

To come back to your example, what I'm trying to say is that even if AutoTensorDict(...).to("cuda") returns an AutoTensorDict, we won't be setting the batch-size and device automatically within it. Would that fix your issue, or are you expecting all the AutoTensorDict instances to have an auto batch-size and device?

Another question is: what should we do with this use case: AutoTensorDict({"a": {"b": 0}})? One could say: if you want an inner and outer AutoTensorDict, you should do AutoTensorDict({"a": AutoTensorDict({"b": 0})}), but we could also argue that tensordict should assume that all nested entries are of the same type.

Action items:

cc @albanD if you're interested in sharing your 2 cents on the topic of subclassing tensordict

@vmoens vmoens linked a pull request Jan 16, 2025 that will close this issue
@albanD
Contributor
albanD commented Jan 16, 2025

I am sure people who have actually studied CS have much more structured thinking about this (things like the Liskov substitution principle, https://en.wikipedia.org/wiki/Liskov_substitution_principle, which Ed sent me last time I asked, haha).

My 2 cents are: Tensor is a bit special because it is implemented in C++, and the default subclass behavior (via __torch_function__) is really about preserving the subclass type, not much else.

In a more proper subclassing world, without the Tensor constraint that only plain Tensors exist in C++, I think you can make things better, but I would suggest following the regular subclassing behavior from Python rather than the weird torch.Tensor behavior.
In particular, I would say that methods on the object should either:

  • work with the input in an abstract manner, without needing to know whether it's a plain TensorDict or a subclass, doing all operations through other methods. If you want a fast path, you can even branch there between plain TensorDict and other types.
  • raise an error saying that the subclass must implement this method.

It is very much expected for subclasses that they might need to implement specific methods or add adapters before calling the super() method.
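As a rough sketch of that adapter pattern (MyTD and its clone override are hypothetical; at the time of this issue super().clone() returns a plain TensorDict, so the subclass re-wraps the result itself):

from tensordict import TensorDict

class MyTD(TensorDict):
    def clone(self, recurse: bool = True):
        out = super().clone(recurse=recurse)
        # re-wrap the plain TensorDict in the subclass, carrying over
        # batch_size and device explicitly
        return type(self)(out, batch_size=out.batch_size, device=out.device)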

@alex-bene
Author

To come back to your example, what I'm trying to say is that even if AutoTensorDict(...).to("cuda") returns an AutoTensorDict, we won't be setting the batch-size and device automatically within it. Would that fix your issue, or are you expecting all the AutoTensorDict instances to have an auto batch-size and device?

@vmoens In the specific scenario, I would not want to run the command for auto batch size or device. However, I would expect the value of these attributes to be passed from the original object instance to the new one. Additionally, and I think this is also very important, I would expect any other self attributes I have in the original object to be transferred to the new one.

Another question is: what should we do with this use case: AutoTensorDict({"a": {"b": 0}})? One could say: if you want an inner and outer AutoTensorDict, you should do AutoTensorDict({"a": AutoTensorDict({"b": 0})}), but we could also argue that tensordict should assume that all nested entries are of the same type.

My opinion would be that running AutoTensorDict({"a": {"b": 0}}) implies that the nested dicts should also be AutoTensorDict, since this is the constructor we call. If I wanted the inner ones to be TensorDict, I should call AutoTensorDict({"a": TensorDict({"b": 0})}).

It is very much expected for subclasses that they might need to implement specific methods or add adapters before calling the super() method.

I do agree with this point from @albanD. It would seem normal if I had to implement a function to help with this "transfer" of attributes from one object to the new one.

@vmoens
Collaborator
vmoens commented Jan 16, 2025

@vmoens In the specific scenario, I would not want to run the command for auto batch size or device. However, I would expect the value of these attributes to be passed from the original object instance to the new one. Additionally, and I think this is also very important, I would expect any other self attributes I have in the original object to be transferred to the new one.

Do you mean that any op with this guy

class SubTD(TensorDict):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.p = "hello!"

should return a SubTD with the p attribute set to "hello!"? That would mean scanning through every instance during every op and copying all the attributes, which sounds a bit ambitious (and perhaps dangerous) to me. Inheriting the class is achievable, but copy-pasting all attributes is harder (think about the IntTensor I gave earlier: what should you do when you divide an odd number by an even number?).

What if the __init__ has a Python side effect?

class SubTD(TensorDict):
    SOME_GLOBAL_COUNTER = 0
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        type(self).SOME_GLOBAL_COUNTER += 1

That again will only be called when you explicitly call __init__.

My opinion would be that running AutoTensorDict({"a": {"b": 0}}) implies that the nested dicts should also be AutoTensorDict, since this is the constructor we call. If I wanted the inner ones to be TensorDict, I should call AutoTensorDict({"a": TensorDict({"b": 0})}).

That's not impossible, but it's going to be more opinionated than one would think. The __init__ of the subclass can presumably be anything. The line where we convert the sub-dictionary into a TensorDict is here:

tensordict/tensordict/_td.py

Lines 1555 to 1565 in a172326

    def _convert_to_tensordict(
        self, dict_value: dict[str, Any], non_blocking: bool | None = None
    ) -> T:
        return TensorDict(
            dict_value,
            batch_size=self.batch_size,
            device=self.device,
            names=self._maybe_names(),
            lock=self.is_locked,
            non_blocking=non_blocking,
        )

If we replace that by

    def _convert_to_tensordict(
        self, dict_value: dict[str, Any], non_blocking: bool | None = None
    ) -> T:
        return type(self)(
            dict_value,
            batch_size=self.batch_size,
            device=self.device,
            names=self._maybe_names(),
            lock=self.is_locked,
            non_blocking=non_blocking,
        )

we explicitly assume that the __init__ signature of the subclass matches the one from TensorDict (and we can't use _new_unsafe because we don't know what's in dict_value).
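To illustrate the signature problem, here is a hypothetical subclass whose __init__ takes an extra leading argument (LabeledTD is made up for this example):

from tensordict import TensorDict

class LabeledTD(TensorDict):
    # `label` comes first, so the signature no longer matches TensorDict's
    def __init__(self, label, source=None, batch_size=None, **kwargs):
        super().__init__(source, batch_size=batch_size, **kwargs)
        self.label = label

td = LabeledTD("train", {"a": {"b": 0}}, batch_size=[])

# With the type(self)(dict_value, batch_size=..., ...) variant above, the
# nested {"b": 0} would be bound to `label` instead of `source`, so the
# sub-dictionary would end up in the wrong parameter.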

In other words: in this example

AutoTensorDict({"a": {"b": 0}})

we have no way to tell tensordict what to do with the sub-dict when a subclass is used. The only thing we could do is pretend it's an AutoTensorDict but actually use TensorDict.__init__ to build it.

One rule of thumb in TensorDict that we try to obey is that features should be intuitive given a small set of axioms. I think that here we're in one of those regions where things are a bit blurry, and every use case will differ slightly in what is considered reasonable behavior...

@alex-bene
Author

should return a SubTD with the p attribute set to "hello!"? That would mean scanning through every instance during every op and copying all the attributes, which sounds a bit ambitious (and perhaps dangerous) to me. Inheriting the class is achievable, but copy-pasting all attributes is harder (think about the IntTensor I gave earlier: what should you do when you divide an odd number by an even number?).

That's indeed what I meant, but I see how this can create problems. To be honest, I mostly had the copy-related operations in mind (e.g. clone(), to()), where the expected behavior from a user's standpoint would be that the new tensor has the same data, just in a different memory space. Also, I just realized that TensorDicts can accept non-tensor values, which probably solves my issue with copying parameters (in my use case), but still not the "wrong class" situation.

The only thing we could do is pretend it's an AutoTensorDict but actually use TensorDict.__init__ to build it.

I got confused by the last part; I thought that with the example replacement snippet, AutoTensorDict.__init__ would be used. Is that not the case?


Also, a side question: is there a particular reason why we can't do something like TensorDict({"a": 0, "b": 1}, c=2) in general?

@vmoens
Collaborator
vmoens commented Jan 17, 2025

That's indeed what I meant, but I see how this can create problems. To be honest, I mostly had the copy-related operations in mind (e.g. clone(), to()), where the expected behavior from a user's standpoint would be that the new tensor has the same data, just in a different memory space.

I would note that

t = torch.randn(3)
t.stuff_that_should_not_be_there = 0
t.clone().stuff_that_should_not_be_there # breaks

which to me seems like an appropriate behaviour.

Indeed, NonTensorData should be your friend if you want to carry metadata through clone and to ops.
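A rough sketch of that (assuming non-tensor values assigned through __setitem__ are wrapped as NonTensorData, as in recent tensordict versions):

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.zeros(3)}, batch_size=[])
td["meta"] = "some metadata"   # stored as NonTensorData
print(td.clone()["meta"])      # the metadata survives the clone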

The only thing we could do is pretend it's an AutoTensorDict

What I meant is that we could do

other = self.__new__(type(self))
other._new_unsafe(*args, **kwargs)

where ._new_unsafe is a fast proxy to TensorDict.__init__ that should not be overridden.
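For reference, the same idea in plain Python terms (no tensordict involved): allocating via __new__ preserves the class but skips the subclass __init__ entirely.

class Sub(dict):
    def __init__(self):
        super().__init__()
        self.tag = "set by __init__"

obj = Sub.__new__(Sub)        # allocate without initialising
print(isinstance(obj, Sub))   # True: the class is preserved
print(hasattr(obj, "tag"))    # False: __init__ never ran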

@alex-bene
Author

which to me seems like an appropriate behaviour.

Yeah, the more we talk about this, the more it makes sense, especially considering that there's already NonTensorData to carry metadata in subclasses.

where ._new_unsafe is a fast proxy to TensorDict.__init__ that should not be overridden.

Oh, I see, so that we can have an __init__ in subclasses whose signature does not fully match the one from TensorDict; that makes sense. Still, if the batch_size and/or device of self is set, I believe the same should be true for the resulting object.

@vmoens
Collaborator
vmoens commented Jan 20, 2025

Cool!
So I will merge #1186, which solves the problem of inheriting the class through ops like to or copy!
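For completeness, a quick check of the expected behavior once that lands (reusing the AutoTensorDict class from the report; the comments show the expectation, not a verified run):

tt = AutoTensorDict()
tt["a"] = torch.rand(3, 4)
print(type(tt.clone()))    # expected: <class '__main__.AutoTensorDict'>
print(type(tt.to("cpu")))  # expected: <class '__main__.AutoTensorDict'>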

@alex-bene
Author

@vmoens everything seems to work just fine! Thanks for the super fast update!
I'm closing the issue for now and will re-open it if I find any related problems.
