From 732643d6e5c0a81671bc00209db0e44812c891f3 Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Mon, 25 Mar 2024 20:16:46 +0900 Subject: [PATCH 1/4] feat: support aten.index_select converter --- .../dynamo/conversion/aten_ops_converters.py | 25 ++++++++++++ .../dynamo/conversion/impl/__init__.py | 1 + .../dynamo/conversion/impl/index.py | 28 ++++++++++++++ .../conversion/test_index_select_aten.py | 38 +++++++++++++++++++ 4 files changed, 92 insertions(+) create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/index.py create mode 100644 tests/py/dynamo/conversion/test_index_select_aten.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 0dd153d0aa..9f9096487d 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2782,3 +2782,28 @@ def aten_ops_roll( args[1], args_bounds_check(args, 2, []), ) + + +@dynamo_tensorrt_converter(torch.ops.aten.index_select.default) +@enforce_tensor_types( + { + 0: (TRTTensor,), + 2: (TRTTensor,), + } +) +def aten_ops_index_select( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.index.index_select( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + args[2], + ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index ca71cb0b0c..2cacdc9ae4 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -12,6 +12,7 @@ elementwise, embedding, grid, + index, linear, matmul, normalization, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/index.py b/py/torch_tensorrt/dynamo/conversion/impl/index.py new file mode 100644 index 0000000000..2b4a28d08d --- /dev/null +++ 
b/py/torch_tensorrt/dynamo/conversion/impl/index.py @@ -0,0 +1,28 @@ +from typing import Optional + +from torch.fx.node import Target +from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor +from torch_tensorrt.fx.converters.converter_utils import set_layer_name +from torch_tensorrt.fx.types import TRTTensor + + +def index_select( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + dim: int, + index: TRTTensor, +) -> TRTTensor: + input_tensor = get_trt_tensor(ctx, input, f"{name}_input") + index_tensor = get_trt_tensor(ctx, index, f"{name}_index") + + # The axis parameter specifies the dimension along which to index. + gather_layer = ctx.net.add_gather(input_tensor, index_tensor, axis=dim) + + set_layer_name(gather_layer, target, f"{name}_gather", source_ir) + + return gather_layer.get_output(0) diff --git a/tests/py/dynamo/conversion/test_index_select_aten.py b/tests/py/dynamo/conversion/test_index_select_aten.py new file mode 100644 index 0000000000..05386fc722 --- /dev/null +++ b/tests/py/dynamo/conversion/test_index_select_aten.py @@ -0,0 +1,38 @@ +import torch +import torch.nn as nn +from harness import DispatchTestCase +from parameterized import param, parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input + + +class TestIndexSelectConverter(DispatchTestCase): + @parameterized.expand( + [ + ("1d_input", (10,), 0, (1,)), + ("2d_input_dim_0", (10, 3), 0, (0, 2)), + ("2d_input_dim_1", (5, 10), 1, (1, 2, 3)), + ("3d_input_dim_0", (10, 5, 10), 0, (0, 5)), + ("3d_input_dim_2", (10, 5, 10), 2, (3, 3, 4)), + ] + ) + def test_index_select(self, _, source_shape, dim, indices_val): + class TestIndexSelect(torch.nn.Module): + def forward(self, source_tensor, indices_tensor): + return 
torch.ops.aten.index_select.default( + source_tensor, dim, indices_tensor + ) + + input = [ + torch.randn(*source_shape, dtype=torch.float32), + torch.tensor([*indices_val], dtype=torch.int32), + ] + + self.run_test( + TestIndexSelect(), + input, + ) + + +if __name__ == "__main__": + run_tests() From 56c6d9d7d52c6ef77b1b97e172a42af649958500 Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Tue, 26 Mar 2024 11:42:51 +0900 Subject: [PATCH 2/4] chore: remove ITensor convert function --- py/torch_tensorrt/dynamo/conversion/impl/index.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/index.py b/py/torch_tensorrt/dynamo/conversion/impl/index.py index 2b4a28d08d..555b37b7b4 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/index.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/index.py @@ -3,7 +3,6 @@ from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext -from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor from torch_tensorrt.fx.converters.converter_utils import set_layer_name from torch_tensorrt.fx.types import TRTTensor @@ -17,11 +16,8 @@ def index_select( dim: int, index: TRTTensor, ) -> TRTTensor: - input_tensor = get_trt_tensor(ctx, input, f"{name}_input") - index_tensor = get_trt_tensor(ctx, index, f"{name}_index") - # The axis parameter specifies the dimension along which to index. 
- gather_layer = ctx.net.add_gather(input_tensor, index_tensor, axis=dim) + gather_layer = ctx.net.add_gather(input, index, axis=dim) set_layer_name(gather_layer, target, f"{name}_gather", source_ir) From 20d43e054df04866314e83862c2019d0bd91fbd0 Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Mon, 1 Apr 2024 15:38:26 +0900 Subject: [PATCH 3/4] feat: support negative dim for index.select --- .../dynamo/conversion/aten_ops_converters.py | 2 +- .../dynamo/conversion/impl/index.py | 24 --------------- .../dynamo/conversion/impl/select.py | 30 +++++++++++++++---- .../conversion/test_index_select_aten.py | 9 ++++-- 4 files changed, 31 insertions(+), 34 deletions(-) delete mode 100644 py/torch_tensorrt/dynamo/conversion/impl/index.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 9f9096487d..c43b541ed0 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2798,7 +2798,7 @@ def aten_ops_index_select( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.index.index_select( + return impl.select.index_select( ctx, target, SourceIR.ATEN, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/index.py b/py/torch_tensorrt/dynamo/conversion/impl/index.py deleted file mode 100644 index 555b37b7b4..0000000000 --- a/py/torch_tensorrt/dynamo/conversion/impl/index.py +++ /dev/null @@ -1,24 +0,0 @@ -from typing import Optional - -from torch.fx.node import Target -from torch_tensorrt.dynamo._SourceIR import SourceIR -from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext -from torch_tensorrt.fx.converters.converter_utils import set_layer_name -from torch_tensorrt.fx.types import TRTTensor - - -def index_select( - ctx: ConversionContext, - target: Target, - source_ir: Optional[SourceIR], - name: str, - input: TRTTensor, - dim: int, - 
index: TRTTensor, -) -> TRTTensor: - # The axis parameter specifies the dimension along which to index. - gather_layer = ctx.net.add_gather(input, index, axis=dim) - - set_layer_name(gather_layer, target, f"{name}_gather", source_ir) - - return gather_layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index db586be65f..470abb8f48 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -90,7 +90,7 @@ def index( # is_numpy is a flag to specify if all the indices are numpy or torchTensor. # If any is not this flag will be set to False _LOGGER.debug( - f"Determining whether aten.index constant-index optimization can be invoked" + "Determining whether aten.index constant-index optimization can be invoked" ) is_numpy = all( isinstance(ind, (torch.Tensor, np.ndarray)) for ind in index if ind is not None @@ -123,7 +123,7 @@ def index( return identity_layer.get_output(0) elif len(tensor_indices) == 1: indices_tensor = get_trt_tensor( - ctx, tensor_indices[0], name + f"_parameter_to_fp32_tensor" + ctx, tensor_indices[0], name + "_parameter_to_fp32_tensor" ) index = adv_indx_indices[0] _LOGGER.debug(f"The advanced index indices is {adv_indx_indices}") @@ -204,7 +204,7 @@ def index( cum_adv_index = cum_adv_index + adv_index multiplier = multiplier * input_shape[adv_indx_indices[i]] cum_adv_index = get_trt_tensor( - ctx, cum_adv_index, name + f"_index_sum_intermediate" + ctx, cum_adv_index, name + "_index_sum_intermediate" ) else: multiplier = get_trt_tensor( @@ -263,7 +263,7 @@ def index( adv_indx_count == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1 ): - _LOGGER.debug(f"The indices are continuous in this case") + _LOGGER.debug("The indices are continuous in this case") concat_tensor_reshape.append( get_trt_tensor(ctx, -1, name + "_dynamic_concat") ) @@ -287,7 +287,7 @@ def index( source_ir, ) unfold_tensor 
= regular_index_shuffle_layer.get_output(0) - _LOGGER.debug(f"The tensor is unfolded now") + _LOGGER.debug("The tensor is unfolded now") _LOGGER.debug(f"The unfolded tensor shape is {unfold_tensor.shape}") # Transpose folded advanced indexed axis to its original location. @@ -342,7 +342,7 @@ def index( reshape_output = unfold_advanced_shuffle_layer.get_output(0) else: - _LOGGER.debug(f"The indices are not continuous in this case") + _LOGGER.debug("The indices are not continuous in this case") concat_final_tensor = [] concat_final_tensor.append(cum_adv_index_shape_tensor) for i in range(0, rank): @@ -370,3 +370,21 @@ def index( reshape_output = reshape_layer.get_output(0) return reshape_output + + +def index_select( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + dim: int, + index: TRTTensor, +) -> TRTTensor: + # The axis parameter specifies the dimension along which to index. + dim = get_positive_dim(dim, len(input.shape)) + gather_layer = ctx.net.add_gather(input, index, axis=dim) + + set_layer_name(gather_layer, target, f"{name}_gather", source_ir) + + return gather_layer.get_output(0) diff --git a/tests/py/dynamo/conversion/test_index_select_aten.py b/tests/py/dynamo/conversion/test_index_select_aten.py index 05386fc722..83eaedb944 100644 --- a/tests/py/dynamo/conversion/test_index_select_aten.py +++ b/tests/py/dynamo/conversion/test_index_select_aten.py @@ -1,9 +1,9 @@ import torch import torch.nn as nn -from harness import DispatchTestCase -from parameterized import param, parameterized +from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt import Input + +from .harness import DispatchTestCase class TestIndexSelectConverter(DispatchTestCase): @@ -12,8 +12,11 @@ class TestIndexSelectConverter(DispatchTestCase): ("1d_input", (10,), 0, (1,)), ("2d_input_dim_0", (10, 3), 0, (0, 2)), ("2d_input_dim_1", (5, 10), 1, (1, 2, 3)), + 
("2d_input_dim_-2", (5, 10), -2, (1, 2, 3)), ("3d_input_dim_0", (10, 5, 10), 0, (0, 5)), ("3d_input_dim_2", (10, 5, 10), 2, (3, 3, 4)), + ("3d_input_dim_-1", (10, 5, 10), -1, (3, 3, 4)), + ("3d_input_dim_-3", (10, 5, 10), -3, (5, 3, 4)), ] ) def test_index_select(self, _, source_shape, dim, indices_val): From 6b7ab0140a83ad6e3b40f91b1e0ed8d09660fe4d Mon Sep 17 00:00:00 2001 From: Hoonkyung Cho Date: Thu, 11 Apr 2024 14:09:00 +0900 Subject: [PATCH 4/4] chore: remove an unnecessary import --- py/torch_tensorrt/dynamo/conversion/impl/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index 2cacdc9ae4..ca71cb0b0c 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -12,7 +12,6 @@ elementwise, embedding, grid, - index, linear, matmul, normalization,