From 5501acebdb47a7e8b2f82edef5363773921cc696 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Mon, 5 Feb 2024 15:12:46 -0800 Subject: [PATCH 1/2] small fix: Index validator enable int64 - Repair test case --- .../dynamo/conversion/aten_ops_converters.py | 2 +- tests/py/dynamo/conversion/test_index_aten.py | 9 ++------- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 451fdbcb63..478cf98dea 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -397,7 +397,7 @@ def index_dtype_validator(node: Node) -> bool: for ind in index: if ind is not None: val = ind.meta.get("val") - if val is not None and val.dtype != torch.int32: + if val is not None and val.dtype not in (torch.int32, torch.int64): return False return True diff --git a/tests/py/dynamo/conversion/test_index_aten.py b/tests/py/dynamo/conversion/test_index_aten.py index df61a4b835..bf7769c608 100644 --- a/tests/py/dynamo/conversion/test_index_aten.py +++ b/tests/py/dynamo/conversion/test_index_aten.py @@ -1,10 +1,8 @@ -import operator - import torch import torch.nn as nn -from .harness import DispatchTestCase from torch.testing._internal.common_utils import run_tests -from torch_tensorrt import Input + +from .harness import DispatchTestCase class TestIndexConverter(DispatchTestCase): @@ -15,7 +13,6 @@ def __init__(self): super().__init__() def forward(self, x): - index0 = torch.randint(0, 1, (1, 1)) indices = [None, self.index0] out = torch.ops.aten.index.Tensor(x, indices) return out @@ -158,8 +155,6 @@ def __init__(self): super().__init__() def forward(self, x): - index0 = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]) - index1 = index0.unsqueeze(0).T.long() indices = [None, None, self.index0, self.index1] out = torch.ops.aten.index.Tensor(x, indices) return out From 7909916042b2206cc0db80e6cf9ea1bd10da1190 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Mon, 5 Feb 2024 16:29:18 -0800 Subject: [PATCH 2/2] feat: Add `dynamic=False` specification to examples --- examples/dynamo/torch_compile_advanced_usage.py | 7 +++++-- examples/dynamo/torch_compile_transformers_example.py | 1 + 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/dynamo/torch_compile_advanced_usage.py b/examples/dynamo/torch_compile_advanced_usage.py index 96146a43d8..8ebedab111 100644 --- a/examples/dynamo/torch_compile_advanced_usage.py +++ b/examples/dynamo/torch_compile_advanced_usage.py @@ -43,7 +43,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): # For the default settings, we can simply call torch.compile # with the backend "torch_tensorrt", and run the model on an # input to cause compilation, as so: -optimized_model = torch.compile(model, backend="torch_tensorrt") +optimized_model = torch.compile(model, backend="torch_tensorrt", dynamic=False) optimized_model(*sample_inputs) # %% @@ -81,7 +81,10 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): # Run the model on an input to cause compilation, as so: optimized_model_custom = torch.compile( - model_half, backend="torch_tensorrt", options=backend_kwargs + model_half, + backend="torch_tensorrt", + options=backend_kwargs, + dynamic=False, ) optimized_model_custom(*sample_inputs_half) diff --git a/examples/dynamo/torch_compile_transformers_example.py b/examples/dynamo/torch_compile_transformers_example.py index 5422f9cc1d..01d46e96f6 100644 --- a/examples/dynamo/torch_compile_transformers_example.py +++ b/examples/dynamo/torch_compile_transformers_example.py @@ -61,6 +61,7 @@ optimized_model = torch.compile( model, backend="torch_tensorrt", + dynamic=False, options=compilation_kwargs, ) optimized_model(*inputs)