diff --git a/theano/tensor/basic.py b/theano/tensor/basic.py
index 60f77847709..bb2773957de 100644
--- a/theano/tensor/basic.py
+++ b/theano/tensor/basic.py
@@ -1662,21 +1662,29 @@ def __getitem__(self, args):
         # standard indexing is used; if it fails with
         # AdvancedIndexingError, advanced indexing
         advanced = False
-        for arg in args:
+        axis = None
+        for i, arg in enumerate(args):
             try:
                 arg == numpy.newaxis or Subtensor.convert(arg)
             except AdvancedIndexingError:
-                advanced = True
-                break
+                if advanced:
+                    axis = None
+                    break
+                else:
+                    advanced = True
+                    axis = i
 
         if advanced:
-            if (len(args) == 1
-                    and isinstance(args[0], (
+            if (axis is not None
+                    and all(a == slice(None) for a in args[:axis])
+                    and all(a == slice(None) for a in args[axis+1:])
+                    and isinstance(args[axis], (
+                        numpy.ndarray,
                         list,
                         TensorVariable,
                         TensorConstant,
                         theano.tensor.sharedvar.TensorSharedVariable))):
-                return advanced_subtensor1(self, *args)
+                return self.take(args[axis], axis)
             else:
                 return AdvancedSubtensor()(self, *args)
         else:
@@ -1705,6 +1713,9 @@ def __getitem__(self, args):
         return Subtensor(args)(self, *Subtensor.collapse(args,
             lambda entry: isinstance(entry, Variable)))
 
+    def take(self, indices, axis=None, mode='raise'):
+        return take(self, indices, axis, mode)
+
     # COPYING
     def copy(self):
         return tensor_copy(self)
@@ -6811,6 +6822,36 @@ def R_op(self, inputs, eval_points):
                               *inputs[2:]).outputs
 advanced_inc_subtensor = AdvancedIncSubtensor()
 
+def take(a, indices, axis=None, mode='raise'):
+    a = as_tensor_variable(a)
+    indices = as_tensor_variable(indices)
+    # Reuse advanced_subtensor1 if indices is a vector
+    if indices.ndim == 1:
+        if mode == 'clip':
+            indices = clip(indices, 0, a.shape[axis]-1)
+        elif mode == 'wrap':
+            indices = indices % a.shape[axis]
+        if axis is None:
+            return advanced_subtensor1(a.flatten(), indices)
+        elif axis == 0:
+            return advanced_subtensor1(a, indices)
+        else:
+            if axis < 0:
+                axis += a.ndim
+            assert axis >= 0
+            shuffle = range(a.ndim)
+            shuffle[0] = axis
+            shuffle[axis] = 0
+            return advanced_subtensor1(
+                a.dimshuffle(shuffle), indices).dimshuffle(shuffle)
+    if axis is None:
+        shape = indices.shape
+        ndim = indices.ndim
+    else:
+        shape = concatenate([a.shape[:axis], indices.shape, a.shape[axis+1:]])
+        ndim = a.ndim + indices.ndim - 1
+    return take(a, indices.flatten(), axis, mode).reshape(shape, ndim)
+
 #########################
 # Linalg : Dot
 #########################
diff --git a/theano/tensor/tests/test_basic.py b/theano/tensor/tests/test_basic.py
index 78a2fb60cc0..01ff5b01222 100644
--- a/theano/tensor/tests/test_basic.py
+++ b/theano/tensor/tests/test_basic.py
@@ -7167,6 +7167,26 @@ def test_diagonal(self):
         assert_array_equal(X.diagonal(offset, axis1, axis2).eval({X: x}),
                            x.diagonal(offset, axis1, axis2))
 
+    def test_take(self):
+        X, _ = self.vars
+        x, _ = self.vals
+        indices = [1, 0, 3]
+        assert_array_equal(X.take(indices).eval({X: x}), x.take(indices))
+        indices = [1, 0, 1]
+        assert_array_equal(X.take(indices, 1).eval({X: x}), x.take(indices, 1))
+        indices = [-10, 5, 12]
+        assert_array_equal(X.take(indices, 1, mode='wrap').eval({X: x}),
+                           x.take(indices, 1, mode='wrap'))
+        assert_array_equal(X.take(indices, -1, mode='wrap').eval({X: x}),
+                           x.take(indices, -1, mode='wrap'))
+        assert_array_equal(X.take(indices, 1, mode='clip').eval({X: x}),
+                           x.take(indices, 1, mode='clip'))
+        assert_array_equal(X.take(indices, -1, mode='clip').eval({X: x}),
+                           x.take(indices, -1, mode='clip'))
+        indices = [[1, 0, 1], [0, 1, 1]]
+        assert_array_equal(X.take(indices, 1).eval({X: x}), x.take(indices, 1))
+        # Test equivalent advanced indexing
+        assert_array_equal(X[:, indices].eval({X: x}), x[:, indices])
 
 
 if __name__ == '__main__':
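
Usage sketch (illustrative only, not part of the patch): a quick check of the new TensorVariable.take method against numpy.take, assuming the diff above has been applied. The variable names and the 3x4 example matrix are invented for this example.

    import numpy
    import theano
    import theano.tensor as T

    X = T.matrix('X')
    x = numpy.arange(12).reshape(3, 4).astype(theano.config.floatX)

    # Take whole rows along axis 0, mirroring numpy.take semantics.
    print(X.take([2, 0], axis=0).eval({X: x}))   # same as x.take([2, 0], axis=0)

    # Out-of-range indices wrap around when mode='wrap'.
    print(X.take([-1, 5], axis=1, mode='wrap').eval({X: x}))
    # same as x.take([-1, 5], axis=1, mode='wrap')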