[BUG] Running copy-related operations on instances of user-defined TensorDict subclasses returns a TensorDict object
#1184
Comments
Hello!

```python
import torch

class IntTensor(torch.Tensor):
    def __new__(cls, *args, **kwargs):
        print('__new__')
        self = torch.empty((), dtype=torch.int)
        return super().__new__(cls, self)

    def __init__(self, tensor: torch.Tensor):
        print('__init__')
        super().__init__()
        t = tensor.to(torch.int)
        self.set_(t.untyped_storage(), storage_offset=t.storage_offset(), stride=t.stride(), size=t.shape)

inttensor = IntTensor(torch.zeros((), dtype=torch.float))
assert inttensor.dtype == torch.int
assert isinstance(inttensor / 2, IntTensor)
assert (inttensor / 2).dtype == torch.int
```

As you can see, ops on tensor subclasses give you tensor subclasses (which I understand is what you want). However, whatever I do in TensorDict is built purely in Python, so it will probably not be exactly the same, but in general I think that any solution we come up with will not call the subclass `__init__` (see the quick check after this comment).

To come back to your example, what I'm trying to say is that even if

Another question is: what should we do with this use case

Action items:

cc @albanD if you're interested in sharing your 2 cents on the topic of subclassing tensordict |
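For illustration, a quick check of the point above, assuming the `IntTensor` snippet from this comment has been run: the division goes through torch's default `__torch_function__` machinery, which re-wraps the result as `IntTensor` without ever calling `__new__` or `__init__`, so neither print statement fires.

```python
out = inttensor / 2
assert isinstance(out, IntTensor)  # the subclass type is preserved...
# ...but neither '__new__' nor '__init__' was printed: the constructor was never invoked.
```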
I am sure people who actually studied CS have much more structured thinking about this (things like LSP https://en.wikipedia.org/wiki/Liskov_substitution_principle that Ed sent me last time I asked haha). My 2 cents are: Tensors are a bit special because they are implemented in C++, and the default subclass behavior (via `__torch_function__`) is really about preserving the subclass type, not much else. In a more proper subclassing world, without the Tensor constraint that only plain Tensors exist in C++, I think you can make things better, but I would suggest following the regular subclassing behavior from Python rather than the weird torch.Tensor behavior.
It is very much expected for subclasses that they might need to implement specific methods or do adapters before calling the super() method. |
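As a plain-Python illustration of that last point (not tensordict-specific), a subclass can adapt its inputs before delegating to the parent constructor; the classes here are made up for the example.

```python
class Celsius:
    def __init__(self, degrees: float):
        self.degrees = degrees

class Fahrenheit(Celsius):
    def __init__(self, degrees_f: float):
        # adapter: convert to the representation the parent expects
        super().__init__((degrees_f - 32) * 5 / 9)

assert Fahrenheit(212).degrees == 100
```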
@vmoens In this specific scenario, I would not want to run the command for auto batch size or device. However, I would expect the values of these attributes to be passed from the original object instance to the new one. Additionally, and I think this is also very important, I would expect any other

My opinion would be that running

I do agree with this point from @albanD. It would seem normal if I had to implement a function to help with this "transfer" of attributes from one object to the new one (something like the sketch below). |
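A hypothetical sketch of what such a "transfer" helper could look like; `AutoTensorDict`, `_transfer_attrs`, and `auto_device` are illustrative names rather than tensordict API, and this assumes plain Python attributes can be set on the instance.

```python
from tensordict import TensorDict

class AutoTensorDict(TensorDict):
    def __init__(self, *args, auto_device=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.auto_device = auto_device

    def _transfer_attrs(self, other):
        # hand the subclass-specific attributes over to another instance
        other.auto_device = self.auto_device
        return other

    def clone(self, recurse: bool = True):
        # re-wrap the parent result in the subclass and move the attrs over
        out = type(self)(super().clone(recurse=recurse), batch_size=self.batch_size)
        return self._transfer_attrs(out)
```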
Do you mean that any op with this guy

```python
class SubTD(TensorDict):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.p = "hello!"
```

should return a `SubTD`?

What if the class looks like this?

```python
class SubTD(TensorDict):
    SOME_GLOBAL_COUNTER = 0

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        type(self).SOME_GLOBAL_COUNTER += 1
```

That, again, will only be called when you explicitly call the constructor.
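A plain-Python illustration of that behaviour: copying an object does not re-run `__init__`, so side effects like the counter below only happen on explicit construction.

```python
import copy

class Counted:
    SOME_GLOBAL_COUNTER = 0

    def __init__(self):
        type(self).SOME_GLOBAL_COUNTER += 1

a = Counted()
b = copy.copy(a)          # no __init__ call here
assert isinstance(b, Counted)
assert Counted.SOME_GLOBAL_COUNTER == 1
```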
That's not impossible, but it's going to be more opinionated than one would think. The relevant code is here (Lines 1555 to 1565 in a172326):
If we replace that by

```python
def _convert_to_tensordict(
    self, dict_value: dict[str, Any], non_blocking: bool | None = None
) -> T:
    return type(self)(
        dict_value,
        batch_size=self.batch_size,
        device=self.device,
        names=self._maybe_names(),
        lock=self.is_locked,
        non_blocking=non_blocking,
    )
```

we explicitly assume that the subclass constructor accepts the same arguments as `TensorDict`'s (see the sketch after this comment).

In other words, in this example `AutoTensorDict({"a": {"b": 0}})` we have no way to tell tensordict what to do with the sub-dict when a subclass is used. The only thing we could do is pretend it's an instance of the subclass too.

One rule of thumb in TensorDict that we try to obey is that features should be intuitive given a small set of axioms. I think that here we're in one of these regions where things are a bit blurry and every use case will differ slightly in what is to be considered a reasonable behavior... |
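To make the assumption concrete, here is a hypothetical subclass whose constructor deviates from `TensorDict`'s signature; with the replacement above, `type(self)(dict_value, batch_size=..., device=..., ...)` would raise a `TypeError` because `required_tag` is never passed.

```python
from tensordict import TensorDict

class StrictTD(TensorDict):
    def __init__(self, source, required_tag, **kwargs):
        # required_tag is a made-up extra argument for illustration
        super().__init__(source, **kwargs)
```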
That's indeed what I meant, but I see how this can create problems. To be honest, I mostly had the copy-related operations in mind (e.g. `clone()`, `to()`), where the expected behavior from a user's standpoint would be that the new tensordict has the same data, just in a different memory space. Also, I just realized that

I got confused with the last part: I thought that with the example replacement snippet,

Also, a side question: is there a particular reason why we can't do something like |
I would note that

```python
t = torch.randn(3)
t.stuff_that_should_not_be_there = 0
t.clone().stuff_that_should_not_be_there  # breaks
```

which to me seems like an appropriate behaviour. Indeed, `NonTensorData` should be your friend if you want to carry metadata through these ops.
What I meant is that we could do

```python
other = self.__new__(type(self))
other._new_unsafe(*args, **kwargs)
```

where |
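A plain-Python analogue of that pattern (the classes are made up): allocate the right class with `__new__` and fill in the state directly, so the subclass `__init__` and any side effects it has are never run.

```python
class Base:
    def __init__(self, data):
        self.data = data

class Sub(Base):
    def __init__(self, data):
        super().__init__(data)
        print("expensive __init__ work")

src = Sub([1, 2, 3])                 # prints once, on explicit construction
other = Sub.__new__(Sub)             # no print: __init__ is bypassed
other.__dict__.update(src.__dict__)  # state is copied over directly
assert isinstance(other, Sub) and other.data == [1, 2, 3]
```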
Yeah, the more we talk about this the more it makes sense, especially considering that there's already NonTensorData in order to carry metadata in subclasses.
Oh, I see, so that we can have an |
Cool! |
@vmoens everything seems to work just fine! Thanks for the super fast update! |
Describe the bug
I have created a `TensorDict` subclass named `AutoTensorDict` that fits my use case. However, trying to run `.to()`, `.clone()`, and other copy-related operations on instances of this class returns a `TensorDict` object and not an `AutoTensorDict` object.

To Reproduce
Here's an example to reproduce this:
This snippet prints:
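The original reproduction snippet and its output were not preserved in this copy of the issue; the following is only a minimal, hypothetical stand-in for the author's `AutoTensorDict`, which contains more logic.

```python
import torch
from tensordict import TensorDict

class AutoTensorDict(TensorDict):
    pass

td = AutoTensorDict({"a": torch.zeros(3)}, batch_size=[3])
print(type(td))                     # AutoTensorDict
print(type(td.clone()))             # reportedly TensorDict, not AutoTensorDict
print(type(td.to(torch.float64)))   # reportedly TensorDict, not AutoTensorDict
```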
Expected behavior
I would expect an instance of the newly defined class to be returned.
Reason and Possible fixes
At first, I thought this was related to the function `to_tensordict` being called around inside `TensorDict` and `TensorDictBase`. However, I tried to monkey-patch it, but nothing came from this.

Checklist