Take op [WIP] #1127
I don't have strong feelings about allowing a symbolic axis. Could you also add a test to check that it works correctly with negative values? I tend to use those a lot, and they let you write a greater range of generic code without resorting to symbolics.
def take(a, indices, axis=None, mode='raise'):
    a = as_tensor_variable(a)
    indices = as_tensor_variable(indices)
    # Reuse advanced indexing in supported cases.
I think that we can reuse AdvancedSubtensor1 for all cases (or at least many cases), in particular if mode == 'raise'.
When axis != -1, we can use DimShuffle to move the desired axis to the last position, and then dimshuffle back.
If mode == 'wrap', I think we can use indices % a.shape[axis], and when mode == 'clip', we could use tensor.clip(indices, 0, a.shape[axis] - 1), at least when axis is not None.
Did I miss a case where there is no way of specifying the appropriate computation using existing ops? Or would it be more efficient to have a new Op?
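As a rough illustration of the mode handling suggested in this comment, here is a minimal plain-Python sketch (not Theano code). It normalizes each index first and then falls back on ordinary indexing, which is the role AdvancedSubtensor1 would play in the graph. The helper names `normalize_index` and `take_1d` are hypothetical, and the sketch only covers a 1-D list.

```python
def normalize_index(i, n, mode="raise"):
    """Map index i into [0, n), following the mode names of numpy.take."""
    if mode == "wrap":
        # Python's % returns a non-negative result for positive n.
        return i % n
    if mode == "clip":
        # Clamp the index (not the data) to the valid range [0, n - 1].
        return max(0, min(i, n - 1))
    if mode == "raise":
        if not -n <= i < n:
            raise IndexError(f"index {i} is out of bounds for size {n}")
        # Negative indices count from the end, as in ordinary indexing.
        return i + n if i < 0 else i
    raise ValueError(f"unknown mode: {mode!r}")


def take_1d(a, indices, mode="raise"):
    """Hypothetical helper: gather from a list after normalizing indices."""
    n = len(a)
    return [a[normalize_index(i, n, mode)] for i in indices]
```

Note that in 'clip' mode negative indices are clamped to 0 rather than counted from the end, so a test with negative values needs separate expectations for each mode.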
I agree with @lamblin: AdvancedSubtensor1 can be reused for all cases unless we want to support symbolic axis. I'll update the code and write more tests.
For some reason AdvancedSubtensor1 is limited to the case of rank-1 (vector) indices. NumPy's take does not have this limitation:
>>> np.array([42]).take(np.zeros((2,3,4), dtype=int)).shape
(2, 3, 4)
Yes, AdvancedSubtensor1 has that limitation. I did not realize that numpy.take did not have it. Then, I guess we would have to use a new Op, or actually finish implementing AdvancedSubtensor.
I guess this should be coordinated with @jsalvatier's work on gh-1083. In my view, Take is a simpler Op than advanced indexing and advanced indexing should reuse Take Op in special cases rather than the other way around.
On second thought, higher-rank indices can be supported simply as x.take(i.flatten()).reshape(i.shape)
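The flatten-then-reshape idea can be modelled with nested Python lists. This is only a sketch of the semantics for a 1-D x, not the Theano implementation; the helpers `flatten`, `reshape_like` and `take_any_rank` are hypothetical names introduced for illustration.

```python
def flatten(nested):
    """Flatten an arbitrarily nested list of ints into one flat list."""
    if not isinstance(nested, list):
        return [nested]
    flat = []
    for item in nested:
        flat.extend(flatten(item))
    return flat


def reshape_like(flat, template):
    """Rebuild the nesting of template, consuming values of flat in order."""
    values = iter(flat)

    def build(node):
        if isinstance(node, list):
            return [build(child) for child in node]
        return next(values)

    return build(template)


def take_any_rank(x, i):
    """Rank-1 gather wrapped in flatten/reshape, for a 1-D list x."""
    gathered = [x[k] for k in flatten(i)]  # the vector-only take
    return reshape_like(gathered, i)       # restore the shape of i
```

For example, `take_any_rank([10, 20, 30], [[2, 0], [1, -1]])` gathers with the flat index list `[2, 0, 1, -1]` and then restores the 2 x 2 nesting of the indices.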
I think the important thing is that the same graph is generated in those cases, whether it was created by advanced indexing (getitem) or by take(). This makes it easier for optimizations to handle.
OK, now I understand what the actual restriction of numpy.take is: you cannot iterate jointly over different axes, but you can put the indices in an array of any number of dimensions, and the resulting tensor would have x.ndim - 1 + a.ndim dimensions, where a here denotes the array of indices.
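The dimension count above follows from how numpy.take builds its output shape: the indexed axis of x is replaced by the full shape of the index array. A small sketch of that rule, using a hypothetical helper `take_output_shape` that works on shape tuples only:

```python
def take_output_shape(x_shape, idx_shape, axis=None):
    """Shape of x.take(idx, axis=axis), computed from the input shapes."""
    if axis is None:
        # x is flattened first, so only the shape of the indices remains.
        return tuple(idx_shape)
    ndim = len(x_shape)
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of range for {ndim} dimensions")
    axis %= ndim  # allow negative axes
    # Replace the indexed axis with the dimensions of the index array.
    return tuple(x_shape[:axis]) + tuple(idx_shape) + tuple(x_shape[axis + 1:])
```

With x of shape (5, 6, 7), indices of shape (2, 3) and axis=1, the result has 3 - 1 + 2 = 4 dimensions.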
Thanks for the addition, it looks fine to me.
This pull request is related to issue #1080.
I am sharing work in progress to start a discussion on how x.take() should interplay with advanced indexing. Note that currently I don't allow symbolic values for the axis argument. This is not a necessary restriction, so comments are welcome.