Take op [WIP] by abalkin · Pull Request #1127 · Theano/Theano · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Take op [WIP] #1127

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Dec 12, 2012
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 47 additions & 6 deletions theano/tensor/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1662,21 +1662,29 @@ def __getitem__(self, args):
# standard indexing is used; if it fails with
# AdvancedIndexingError, advanced indexing
advanced = False
for arg in args:
axis = None
for i, arg in enumerate(args):
try:
arg == numpy.newaxis or Subtensor.convert(arg)
except AdvancedIndexingError:
advanced = True
break
if advanced:
axis = None
break
else:
advanced = True
axis = i

if advanced:
if (len(args) == 1
and isinstance(args[0], (
if (axis is not None
and numpy.all(a == slice(None) for a in args[:axis])
and numpy.all(a == slice(None) for a in args[axis+1:])
and isinstance(args[axis], (
numpy.ndarray,
list,
TensorVariable,
TensorConstant,
theano.tensor.sharedvar.TensorSharedVariable))):
return advanced_subtensor1(self, *args)
return self.take(arg, axis)
else:
return AdvancedSubtensor()(self, *args)
else:
Expand Down Expand Up @@ -1705,6 +1713,9 @@ def __getitem__(self, args):
return Subtensor(args)(self, *Subtensor.collapse(args,
lambda entry: isinstance(entry, Variable)))

def take(self, indices, axis=None, mode='raise'):
    """Select elements of this tensor by index, like ``numpy.ndarray.take``.

    Forwards `indices`, `axis` and `mode` unchanged to the module-level
    ``take`` function, passing this tensor as the source array.
    """
    return take(self, indices, axis=axis, mode=mode)

# COPYING
def copy(self):
    """Return the result of applying ``tensor_copy`` to this variable."""
    duplicate = tensor_copy(self)
    return duplicate
Expand Down Expand Up @@ -6811,6 +6822,36 @@ def R_op(self, inputs, eval_points):
*inputs[2:]).outputs
# Module-level instance of the AdvancedIncSubtensor op, created once here
# so callers can apply it directly.
advanced_inc_subtensor = AdvancedIncSubtensor()

def take(a, indices, axis=None, mode='raise'):
    """Take elements from `a` along `axis`, mirroring ``numpy.take``.

    Parameters
    ----------
    a : tensor-like
        Source tensor; converted with ``as_tensor_variable``.
    indices : tensor-like of ints
        Indices to select; may have any number of dimensions.
    axis : int or None
        Axis along which to select. ``None`` selects from the flattened
        `a`. Negative values count from the last axis.
    mode : {'raise', 'wrap', 'clip'}
        Handling of out-of-range indices: 'wrap' reduces them modulo the
        axis length, 'clip' clamps them to ``[0, length - 1]``, and
        'raise' passes them through unchanged.

    Returns
    -------
    A symbolic tensor whose shape is
    ``a.shape[:axis] + indices.shape + a.shape[axis + 1:]``, or
    ``indices.shape`` when `axis` is None.

    Raises
    ------
    ValueError
        If `mode` is not supported or `axis` is out of range for `a`.
    """
    if mode not in ('raise', 'wrap', 'clip'):
        raise ValueError("mode must be one of 'raise', 'wrap' or 'clip', "
                         "got %r" % (mode,))
    a = as_tensor_variable(a)
    indices = as_tensor_variable(indices)

    if axis is None:
        # Taking from the flattened array is the same as taking along axis 0
        # of the 1-d view. Flattening first means a.shape[axis] below is the
        # total size rather than the meaningless a.shape[None].
        a = a.flatten()
        axis = 0
    else:
        # Normalize negative axes up front: the general branch below slices
        # a.shape[axis + 1:], which would be a.shape[0:] for axis == -1.
        requested_axis = axis
        if axis < 0:
            axis += a.ndim
        if not 0 <= axis < a.ndim:
            raise ValueError("axis %d is out of range for a %d-d tensor"
                             % (requested_axis, a.ndim))

    if indices.ndim == 1:
        # Fast path: reuse advanced_subtensor1 for vector indices.
        if mode == 'clip':
            indices = clip(indices, 0, a.shape[axis] - 1)
        elif mode == 'wrap':
            indices = indices % a.shape[axis]
        if axis == 0:
            return advanced_subtensor1(a, indices)
        # advanced_subtensor1 indexes the leading axis, so swap `axis` with
        # 0, index, and swap back (a transposition is its own inverse).
        # list() keeps this mutable on Python 3 as well.
        shuffle = list(range(a.ndim))
        shuffle[0] = axis
        shuffle[axis] = 0
        return advanced_subtensor1(
            a.dimshuffle(shuffle), indices).dimshuffle(shuffle)

    # General case (scalar or n-d indices): take with the flattened indices
    # via the vector path, then reshape so the index dimensions replace
    # dimension `axis` of `a`.
    shape = concatenate([a.shape[:axis], indices.shape, a.shape[axis + 1:]])
    ndim = a.ndim + indices.ndim - 1
    return take(a, indices.flatten(), axis, mode).reshape(shape, ndim)

#########################
# Linalg : Dot
#########################
Expand Down
20 changes: 20 additions & 0 deletions theano/tensor/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7167,6 +7167,26 @@ def test_diagonal(self):
assert_array_equal(X.diagonal(offset, axis1, axis2).eval({X: x}),
x.diagonal(offset, axis1, axis2))

def test_take(self):
    """X.take (and equivalent advanced indexing) must match numpy."""
    X, _ = self.vars
    x, _ = self.vals

    def check(symbolic, expected):
        assert_array_equal(symbolic.eval({X: x}), expected)

    # Flattened take (no axis).
    flat_idx = [1, 0, 3]
    check(X.take(flat_idx), x.take(flat_idx))

    # Vector indices along axis 1.
    vec_idx = [1, 0, 1]
    check(X.take(vec_idx, 1), x.take(vec_idx, 1))

    # Out-of-range indices under 'wrap' then 'clip', for axis 1 and -1.
    oob_idx = [-10, 5, 12]
    for take_mode in ('wrap', 'clip'):
        for ax in (1, -1):
            check(X.take(oob_idx, ax, mode=take_mode),
                  x.take(oob_idx, ax, mode=take_mode))

    # Matrix indices along axis 1, plus the equivalent advanced indexing.
    mat_idx = [[1, 0, 1], [0, 1, 1]]
    check(X.take(mat_idx, 1), x.take(mat_idx, 1))
    check(X[:, mat_idx], x[:, mat_idx])

if __name__ == '__main__':

Expand Down
0