[Torch] Add index_put operator by apivovarov · Pull Request #7465 · apache/tvm · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

[Torch] Add index_put operator #7465

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2010,6 +2010,32 @@ def scatter(self, inputs, input_types):
src = inputs[3]
return _op.transform.scatter(data, index, src, axis)

def index_put(self, inputs, input_types):
    """Convert torch.index_put / torch.index_put_ to relay scatter_nd.

    inputs: [input tensor, list of index tensors, values tensor, accumulate flag].
    Returns a relay expression produced by scatter_nd.
    """
    source_tensor, index_tensors, update_values, accumulate = inputs[:4]
    # accumulate parameter is ignored.
    # torch.index_put default is False but Relay.scatter_nd accumulates values.
    # We assume there is no duplicate indices in torch.index_put input
    if not accumulate:
        logging.warning(
            "torch.index_put accumulate parameter is False. "
            "TVM uses tvm.relay.scatter_nd operator which accumulates values. "
            "Make sure there is no duplicate indices in torch.index_put input."
        )
    # Relay scatter_nd does not support input tensor
    # We assume that torch.index_put is used with empty zero-values input tensor
    # scatter_nd will create empty zero-values tensor with a given shape
    logging.warning(
        "tvm.relay.scatter_nd operator does not support input tensor parameter. "
        "TVM assumes that torch.index_put is used with empty zero-values input tensor"
    )
    result_shape = self.infer_shape(source_tensor)
    # Combine array of index tensors into one index tensor with shape (N,_)
    stacked_indices = _op.stack(index_tensors, axis=0)
    return _op.transform.scatter_nd(update_values, stacked_indices, result_shape)

def scalar_tensor(self, inputs, input_types):
data = inputs[0]
cast_map = {
Expand Down Expand Up @@ -2326,6 +2352,8 @@ def create_convert_map(self):
"aten::nonzero": self.nonzero,
"aten::nonzero_numpy": self.nonzero_numpy,
"aten::scatter": self.scatter,
"aten::index_put": self.index_put,
"aten::index_put_": self.index_put,
"aten::scalar_tensor": self.scalar_tensor,
"aten::__interpolate": self.interpolate,
"aten::IntImplicit": self.identity,
Expand Down
1 change: 1 addition & 0 deletions tests/python/driver/tvmc/test_frontends.py
def test_load_model__pth(pytorch_resnet18):
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def test_load_model___wrong_language__to_onnx(tflite_mobilenet_v1_1_quant):
tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant, model_format="onnx")


@pytest.mark.skip(reason="https://github.com/apache/tvm/issues/7455")
# some CI environments won't offer torch, so skip in case it is not present
pytest.importorskip("torch")
Expand Down
32 changes: 32 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3327,6 +3327,38 @@ def test_fn_scatter_add(dim):
verify_trace_model(test_fn_scatter_add(1), [in_data, in_index, in_src], targets)


def test_forward_index_put():
    """Check torch.index_put conversion for a 2D tensor (default accumulate)
    and a 3D tensor with accumulate=True."""

    # torch.index_put for 2D tensor and default accumulate (False)
    def make_put_2d():
        def put_2d(data, xidx, yidx, values):
            return torch.index_put(data, indices=[xidx, yidx], values=values)

        return put_2d

    # torch.index_put for 3D tensor and accumulate=True
    def make_put_3d_accumulate():
        def put_3d(data, xidx, yidx, zidx, values):
            return torch.index_put(
                data, indices=[xidx, yidx, zidx], values=values, accumulate=True
            )

        return put_3d

    targets = ["llvm", "cuda"]

    # 2D case: scatter four values into a zeroed (3, 5) tensor.
    zeros_2d = torch.zeros((3, 5))
    rows = torch.tensor([0, 1, 2, 2])
    cols = torch.tensor([0, 1, 3, 4])
    vals_2d = torch.tensor([2.0, 4.0, 7.0, 9.0])
    verify_trace_model(make_put_2d(), [zeros_2d, rows, cols, vals_2d], targets)

    # 3D case: the first and last updates hit index (0, 0, 0), exercising accumulation.
    zeros_3d = torch.zeros((3, 5, 3))
    rows = torch.tensor([0, 1, 2, 2, 0])
    cols = torch.tensor([0, 1, 3, 4, 0])
    depth = torch.tensor([0, 1, 1, 2, 0])
    vals_3d = torch.tensor([2.0, 4.0, 7.0, 9.0, 1.0])
    verify_trace_model(
        make_put_3d_accumulate(), [zeros_3d, rows, cols, depth, vals_3d], targets
    )


def test_numel():
class Numel(Module):
def forward(self, data):
Expand Down
0