TorchToLinalg: casting float to integer should round to nearest · Issue #4091 · llvm/torch-mlir · GitHub
Open
@bjacob

Description

This comes from debugging an IREE ONNX test suite failure: https://github.com/iree-org/iree/actions/runs/13796654893/job/38590777349#step:8:72

The failure message is:

 [FAILED] result[0]: element at index 1 (31) does not match the expected (32); expected that the view is equal to contents of a view of 3xi32
  expected:
3xi32=1 32 729
  actual:
3xi32=1 31 729

Notice: 32 != 31.

The test linked from that failure is:
https://github.com/iree-org/iree-test-suites/tree/main/onnx_ops/onnx/node/generated/test_pow_types_int32_float32

Its source code is:
https://github.com/iree-org/iree-test-suites/blob/main/onnx_ops/onnx/node/generated/test_pow_types_int32_float32/model.mlir

The relevant op is:

    %0 = torch.operator "onnx.Pow"(%arg0, %arg1) : (!torch.vtensor<[3],si32>, !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],si32> 

The relevant aspect is that the return type is integral, but the op internally expands to a math.powf, which produces a floating-point value that then needs to be cast to an integer type.

// -----// IR Dump After ConvertTorchToLinalg (convert-torch-to-linalg) //----- //
func.func @test_pow_types_int32_float32(%arg0: !torch.vtensor<[3],si32>, %arg1: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],si32> attributes {torch.assume_strict_symbolic_shapes, torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  %0 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[3],f32> -> tensor<3xf32>
  %1 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[3],si32> -> tensor<3xi32>
  %int3 = torch.constant.int 3
  %none = torch.constant.none
  %false = torch.constant.bool false
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %c3 = arith.constant 3 : index
  %c0_0 = arith.constant 0 : index
  %c3_1 = arith.constant 3 : index
  %2 = tensor.empty() : tensor<3xf64>
  %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%1, %0 : tensor<3xi32>, tensor<3xf32>) outs(%2 : tensor<3xf64>) {
  ^bb0(%in: i32, %in_6: f32, %out: f64):
    %7 = arith.sitofp %in : i32 to f64
    %8 = arith.extf %in_6 : f32 to f64
    %9 = math.powf %7, %8 : f64
    linalg.yield %9 : f64
  } -> tensor<3xf64>
  %cast = tensor.cast %3 : tensor<3xf64> to tensor<3xf64>
  %c1_2 = arith.constant 1 : index
  %c0_3 = arith.constant 0 : index
  %c3_4 = arith.constant 3 : index
  %4 = tensor.empty() : tensor<3xi32>
  %5 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%cast : tensor<3xf64>) outs(%4 : tensor<3xi32>) {
  ^bb0(%in: f64, %out: i32):
    %7 = arith.fptosi %in : f64 to i32
    linalg.yield %7 : i32
  } -> tensor<3xi32>
  %cast_5 = tensor.cast %5 : tensor<3xi32> to tensor<3xi32>
  %6 = torch_c.from_builtin_tensor %cast_5 : tensor<3xi32> -> !torch.vtensor<[3],si32>
  return %6 : !torch.vtensor<[3],si32>
}

The problem here is that arith.fptosi explicitly rounds toward zero:
https://mlir.llvm.org/docs/Dialects/ArithOps/#arithfptosi-arithfptosiop

As a result, any tiny floating-point error in the powf result, e.g. producing 31.9999 instead of 32.0, causes this test failure: 31.9999 gets rounded toward zero to 31.
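To make the failure mode concrete, here is a small Python emulation (a sketch, not the actual lowering): math.trunc stands in for arith.fptosi's truncation, and the input is a hand-picked value one ULP below 32.0, the kind of error any libm pow implementation is allowed to produce.

```python
import math

# A powf result that is one ULP below the exact value 32.0.
x = math.nextafter(32.0, 0.0)

# arith.fptosi semantics: truncate toward zero.
truncated = math.trunc(x)   # yields 31, the wrong value seen in the test

# Rounding to nearest first (Python's round() is round-half-to-even,
# matching math.roundeven semantics) recovers the expected value.
rounded = round(x)          # yields 32
```

This shows that truncation turns a one-ULP error into an off-by-one integer result, while round-to-nearest absorbs it.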

Instead, ConvertTorchToLinalg should emit some kind of round or roundeven op.
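A minimal sketch of what the body of the second linalg.generic could emit instead (hypothetical; assuming the f64-to-i32 conversion stays in that generic, and using the math.roundeven op from the upstream math dialect):

```mlir
^bb0(%in: f64, %out: i32):
  // Round to nearest, ties to even, before converting.
  %rounded = math.roundeven %in : f64
  %int = arith.fptosi %rounded : f64 to i32
  linalg.yield %int : i32
```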


Metadata

Labels: bug (Something isn't working)