From a36644e674b6929bf12e49b3ef86750ad6cfa068 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Tue, 11 Apr 2023 10:10:36 -0700 Subject: [PATCH] fix: Out-Of-Bounds bug in Unsqueeze - Implement support for negative-indexing in unsqueeze, in accordance with the `torch.unsqueeze` function - Fix bug for TFT model in PyTorch, which requires negative indexing in unsqueeze - Update core utility logic and comments --- core/util/trt_util.cpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/core/util/trt_util.cpp b/core/util/trt_util.cpp index 7982ffb846..77c88b465d 100644 --- a/core/util/trt_util.cpp +++ b/core/util/trt_util.cpp @@ -161,8 +161,14 @@ nvinfer1::Dims unpadDims(const nvinfer1::Dims& d) { } nvinfer1::Dims unsqueezeDims(const nvinfer1::Dims& d, int pos, int val, bool use_zeros) { - // acceptable range for pos is [0, d.nbDims] - TORCHTRT_ASSERT(pos >= 0 && pos <= d.nbDims, "ERROR: Index to unsqueeze is out of bounds."); + // Acceptable range for pos is [-d.nbDims - 1, d.nbDims] + TORCHTRT_ASSERT( + pos >= (-d.nbDims - 1) && pos <= d.nbDims, + "ERROR: Index to unsqueeze is out of bounds. " + << "Expected value in range [" << (-d.nbDims - 1) << ", " << d.nbDims << "], but got " << pos); + + // Unsqueeze with negative dimensions creates a new dimension at that index + pos = (pos < 0) ? (pos + d.nbDims + 1) : pos; nvinfer1::Dims dims; for (int i = 0, j = 0; j <= d.nbDims; j++) {