Open
Description
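Describe the bug
Lowering the TritonGPU IR below to LLVM (target `cuda:89`) aborts with:

    LLVM ERROR: Unsupported DotOp found when converting TritonGPU to LLVM.

The likely trigger is the second `tt.dot` (`%275`). Its operands are `16x1` and `1x16` (K = 1) in `#ttg.dot_op` layouts, but its result uses the `#mma` (MMA v2) layout. A K = 1 dot appears to be treated as an outer product and not lowered through MMA. The FMA fallback seems to handle only a blocked result layout, so neither lowering path applies.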
// Layout aliases referenced by @_fwd_factorized below.
// The #blocked* aliases are distributed (register) layouts. #mma is the
// tt.dot result layout. The #shared* aliases are swizzled shared-memory
// layouts, all placed in the #smem memory space.

// 3-D layout of the 16x1x128 tensors (pointers, masked loads, masks).
// Its slice along dim 1 is the 16x128 reduction result (%153) that is
// stored to shared memory and read back as the A operand of the first dot.
#blocked = #ttg.blocked<{sizePerThread = [1, 1, 4], threadsPerWarp = [1, 1, 32], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>
// 3-D layout of the 128x1x16 tensors that are async-copied into #shared2
// and reduced over axis 1 to form the 128x16 B operand (%265) of the first dot.
#blocked1 = #ttg.blocked<{sizePerThread = [4, 1, 1], threadsPerWarp = [32, 1, 1], warpsPerCTA = [1, 1, 4], order = [0, 1, 2]}>
// 2-D layout of the 1x16 row tensors (pointers based on %arg29 and their
// masks) used by the 1x16 async copy into #shared3.
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
// 1-D layout of the 16- and 128-element index ranges, the f32 pointer
// vector based on %arg0, and the i1 pointer vector loaded in the inner loop.
#blocked3 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
// 2-D layout of the 16x1 column tensor loaded from %arg27-based pointers.
#blocked4 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1]}>
// 2-D layout of the 1x128 row tensors loaded from %arg3- and %arg19-based pointers.
#blocked5 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
// 3-D layout of the 16x1x128 pointer tensor based on %arg23.
// It is converted to #blocked before its async copy.
#blocked6 = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 1, 4], order = [2, 1, 0]}>
// 2-D layout of the 128x1 column tensor loaded from %arg11-based pointers.
#blocked7 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
// 3-D layout of the 128x1x16 pointer tensor based on %arg15.
// It is converted to #blocked1 before its async copy.
#blocked8 = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [2, 1, 16], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}>
// NVIDIA MMA layout (versionMajor = 2). It is the result layout of both
// tt.dot ops and the parent of every #ttg.dot_op operand layout (kWidth = 1).
// NOTE(review): the second dot (%275) contracts 16x1 by 1x16 operands
// (K = 1) yet still produces this layout. This looks like the op behind
// "Unsupported DotOp found when converting TritonGPU to LLVM" -- confirm.
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}>
// 16x128 buffer (%154) read back as the A operand of the first dot.
#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [1, 0]}>
// 16x1 buffer (%156) read back as the A operand of the K = 1 dot.
#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 2, order = [0, 1]}>
// Two-slot 2x128x1x16 mutable buffer (%160), the destination of the 128x1x16 async copies.
#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1, 2]}>
// Two-slot 2x1x16 mutable buffer (%161). Each slot is read back as the
// B operand of the K = 1 dot.
#shared3 = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 2, order = [1, 0]}>
// Two-slot 2x16x1x128 mutable buffer (%162), the destination of the 16x1x128 async copies.
#shared4 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [2, 0, 1]}>
// 128x16 buffer (%266) read back as the B operand of the first dot.
#shared5 = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [0, 1]}>
// 16x128 buffer (%269). Its consumers are not visible in this excerpt.
#shared6 = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 8, order = [1, 0]}>
// 16x16 buffer (%290) holding the exp2 result taken from the #mma layout.
#shared7 = #ttg.swizzled_shared<{vec = 16, perPhase = 8, maxPhase = 1, order = [1, 0]}>
// Memory space used by every !ttg.memdesc in this module.
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:89", "ttg.threads-per-warp" = 32 : i32} {
tt.func public @_fwd_factorized(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}, %arg19: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}, %arg21: i32 {tt.divisibility = 16 : i32}, %arg22: i32 {tt.divisibility = 16 : i32}, %arg23: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg24: i32 {tt.divisibility = 16 : i32}, %arg25: i32 {tt.divisibility = 16 : i32}, %arg26: i32 {tt.divisibility = 16 : i32}, %arg27: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg28: i32 {tt.divisibility = 16 : i32}, %arg29: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg30: i32 {tt.divisibility = 16 : i32}, %arg31: !tt.ptr<i1> {tt.divisibility = 16 : i32}, %arg32: i32 {tt.divisibility = 16 : i32}, %arg33: !tt.ptr<i1> {tt.divisibility = 16 : i32}, %arg34: i32 {tt.divisibility = 16 : i32}, %arg35: f32, %arg36: f32, %arg37: i32 {tt.divisibility = 16 : i32}, %arg38: i32 {tt.divisibility = 16 : i32}, %arg39: i32) attributes {noinline = false} {
%c32_i32 = arith.constant 32 : i32
%cst = arith.constant dense<16> : tensor<16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked}>}>>
%cst_0 = arith.constant dense<16> : tensor<16xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #blocked1}>}>>
%cst_1 = arith.constant dense<16> : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
%c2_i32 = arith.constant 2 : i32
%c1_i32 = arith.constant 1 : i32
%c-1_i32 = arith.constant -1 : i32
%cst_2 = arith.constant dense<1.44269502> : tensor<16x16xf32, #mma>
%cst_3 = arith.constant dense<0.000000e+00> : tensor<16x128xf32, #mma>
%cst_4 = arith.constant dense<0.000000e+00> : tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>>
%cst_5 = arith.constant dense<0xFF800000> : tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>>
%cst_6 = arith.constant dense<0.000000e+00> : tensor<128xf32, #ttg.slice<{dim = 0, parent = #mma}>>
%cst_7 = arith.constant dense<16> : tensor<1x16xi32, #blocked2>
%cst_8 = arith.constant dense<16> : tensor<16xi32, #blocked3>
%c0_i32 = arith.constant 0 : i32
%c16_i32 = arith.constant 16 : i32
%cst_9 = arith.constant dense<16> : tensor<16x1xi32, #blocked4>
%cst_10 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma>
%cst_11 = arith.constant dense<16> : tensor<16xi32, #ttg.slice<{dim = 1, parent = #mma}>>
%cst_12 = arith.constant dense<0> : tensor<16xi8, #ttg.slice<{dim = 1, parent = #mma}>>
%cst_13 = arith.constant dense<0> : tensor<16xi8, #ttg.slice<{dim = 0, parent = #mma}>>
%0 = tt.get_program_id x : i32
%1 = tt.get_program_id y : i32
%2 = arith.divsi %1, %arg39 : i32
%3 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #mma}>>
%4 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked}>}>>
%5 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked4}>>
%6 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma}>>
%7 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
%8 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #blocked1}>}>>
%9 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #blocked3>
%10 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked3>
%11 = arith.muli %1, %arg4 : i32
%12 = tt.addptr %arg3, %11 : !tt.ptr<f32>, i32
%13 = arith.muli %0, %arg5 : i32
%14 = tt.addptr %12, %13 : !tt.ptr<f32>, i32
%15 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked5}>>
%16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #blocked}>}>>
%17 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #blocked6}>}>>
%18 = tt.expand_dims %15 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked5}>> -> tensor<1x128xi32, #blocked5>
%19 = tt.expand_dims %16 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #blocked}>}>> -> tensor<1x128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
%20 = tt.expand_dims %17 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #blocked6}>}>> -> tensor<1x128xi32, #ttg.slice<{dim = 1, parent = #blocked6}>>
%21 = tt.splat %14 : !tt.ptr<f32> -> tensor<1x128x!tt.ptr<f32>, #blocked5>
%22 = tt.addptr %21, %18 : tensor<1x128x!tt.ptr<f32>, #blocked5>, tensor<1x128xi32, #blocked5>
%23 = arith.muli %1, %arg8 : i32
%24 = tt.addptr %arg7, %23 : !tt.ptr<f32>, i32
%25 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked6}>}>>
%26 = tt.expand_dims %5 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<16x1xi32, #blocked4>
%27 = tt.expand_dims %4 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked}>}>> -> tensor<16x1xi32, #ttg.slice<{dim = 2, parent = #blocked}>>
%28 = tt.expand_dims %25 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked6}>}>> -> tensor<16x1xi32, #ttg.slice<{dim = 2, parent = #blocked6}>>
%29 = tt.expand_dims %27 {axis = 2 : i32} : tensor<16x1xi32, #ttg.slice<{dim = 2, parent = #blocked}>> -> tensor<16x1x1xi32, #blocked>
%30 = tt.expand_dims %28 {axis = 2 : i32} : tensor<16x1xi32, #ttg.slice<{dim = 2, parent = #blocked6}>> -> tensor<16x1x1xi32, #blocked6>
%31 = tt.splat %arg9 : i32 -> tensor<16x1x1xi32, #blocked>
%32 = arith.muli %29, %31 : tensor<16x1x1xi32, #blocked>
%33 = tt.splat %24 : !tt.ptr<f32> -> tensor<16x1x1x!tt.ptr<f32>, #blocked>
%34 = tt.addptr %33, %32 : tensor<16x1x1x!tt.ptr<f32>, #blocked>, tensor<16x1x1xi32, #blocked>
%35 = tt.expand_dims %19 {axis = 1 : i32} : tensor<1x128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1x1x128xi32, #blocked>
%36 = tt.expand_dims %20 {axis = 1 : i32} : tensor<1x128xi32, #ttg.slice<{dim = 1, parent = #blocked6}>> -> tensor<1x1x128xi32, #blocked6>
%37 = tt.broadcast %34 : tensor<16x1x1x!tt.ptr<f32>, #blocked> -> tensor<16x1x128x!tt.ptr<f32>, #blocked>
%38 = tt.broadcast %35 : tensor<1x1x128xi32, #blocked> -> tensor<16x1x128xi32, #blocked>
%39 = tt.broadcast %36 : tensor<1x1x128xi32, #blocked6> -> tensor<16x1x128xi32, #blocked6>
%40 = tt.addptr %37, %38 : tensor<16x1x128x!tt.ptr<f32>, #blocked>, tensor<16x1x128xi32, #blocked>
%41 = arith.muli %1, %arg12 : i32
%42 = tt.addptr %arg11, %41 : !tt.ptr<f32>, i32
%43 = arith.muli %0, %arg13 : i32
%44 = tt.addptr %42, %43 : !tt.ptr<f32>, i32
%45 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked7}>>
%46 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked8}>}>>
%47 = tt.expand_dims %45 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked7}>> -> tensor<128x1xi32, #blocked7>
%48 = tt.expand_dims %46 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked8}>}>> -> tensor<128x1xi32, #ttg.slice<{dim = 2, parent = #blocked8}>>
%49 = tt.splat %44 : !tt.ptr<f32> -> tensor<128x1x!tt.ptr<f32>, #blocked7>
%50 = tt.addptr %49, %47 : tensor<128x1x!tt.ptr<f32>, #blocked7>, tensor<128x1xi32, #blocked7>
%51 = arith.muli %1, %arg16 : i32
%52 = tt.addptr %arg15, %51 : !tt.ptr<f32>, i32
%53 = tt.expand_dims %48 {axis = 2 : i32} : tensor<128x1xi32, #ttg.slice<{dim = 2, parent = #blocked8}>> -> tensor<128x1x1xi32, #blocked8>
%54 = tt.splat %52 : !tt.ptr<f32> -> tensor<128x1x1x!tt.ptr<f32>, #blocked8>
%55 = tt.addptr %54, %53 : tensor<128x1x1x!tt.ptr<f32>, #blocked8>, tensor<128x1x1xi32, #blocked8>
%56 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #blocked8}>}>>
%57 = tt.expand_dims %7 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x16xi32, #blocked2>
%58 = tt.expand_dims %56 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #blocked8}>}>> -> tensor<1x16xi32, #ttg.slice<{dim = 1, parent = #blocked8}>>
%59 = tt.expand_dims %58 {axis = 1 : i32} : tensor<1x16xi32, #ttg.slice<{dim = 1, parent = #blocked8}>> -> tensor<1x1x16xi32, #blocked8>
%60 = tt.splat %arg17 : i32 -> tensor<1x1x16xi32, #blocked8>
%61 = arith.muli %59, %60 : tensor<1x1x16xi32, #blocked8>
%62 = tt.broadcast %55 : tensor<128x1x1x!tt.ptr<f32>, #blocked8> -> tensor<128x1x16x!tt.ptr<f32>, #blocked8>
%63 = tt.broadcast %61 : tensor<1x1x16xi32, #blocked8> -> tensor<128x1x16xi32, #blocked8>
%64 = tt.addptr %62, %63 : tensor<128x1x16x!tt.ptr<f32>, #blocked8>, tensor<128x1x16xi32, #blocked8>
%65 = arith.muli %1, %arg28 : i32
%66 = tt.addptr %arg27, %65 : !tt.ptr<f32>, i32
%67 = tt.splat %66 : !tt.ptr<f32> -> tensor<16x1x!tt.ptr<f32>, #blocked4>
%68 = tt.addptr %67, %26 : tensor<16x1x!tt.ptr<f32>, #blocked4>, tensor<16x1xi32, #blocked4>
%69 = arith.muli %1, %arg30 : i32
%70 = tt.addptr %arg29, %69 : !tt.ptr<f32>, i32
%71 = tt.splat %70 : !tt.ptr<f32> -> tensor<1x16x!tt.ptr<f32>, #blocked2>
%72 = tt.addptr %71, %57 : tensor<1x16x!tt.ptr<f32>, #blocked2>, tensor<1x16xi32, #blocked2>
%73 = arith.muli %1, %arg20 : i32
%74 = tt.addptr %arg19, %73 : !tt.ptr<f32>, i32
%75 = arith.muli %0, %arg21 : i32
%76 = tt.addptr %74, %75 : !tt.ptr<f32>, i32
%77 = tt.splat %76 : !tt.ptr<f32> -> tensor<1x128x!tt.ptr<f32>, #blocked5>
%78 = tt.addptr %77, %18 : tensor<1x128x!tt.ptr<f32>, #blocked5>, tensor<1x128xi32, #blocked5>
%79 = arith.muli %1, %arg24 : i32
%80 = tt.addptr %arg23, %79 : !tt.ptr<f32>, i32
%81 = tt.splat %arg25 : i32 -> tensor<16x1x1xi32, #blocked6>
%82 = arith.muli %30, %81 : tensor<16x1x1xi32, #blocked6>
%83 = tt.splat %80 : !tt.ptr<f32> -> tensor<16x1x1x!tt.ptr<f32>, #blocked6>
%84 = tt.addptr %83, %82 : tensor<16x1x1x!tt.ptr<f32>, #blocked6>, tensor<16x1x1xi32, #blocked6>
%85 = tt.broadcast %84 : tensor<16x1x1x!tt.ptr<f32>, #blocked6> -> tensor<16x1x128x!tt.ptr<f32>, #blocked6>
%86 = tt.addptr %85, %39 : tensor<16x1x128x!tt.ptr<f32>, #blocked6>, tensor<16x1x128xi32, #blocked6>
%87 = arith.muli %2, %arg32 : i32
%88 = tt.addptr %arg31, %87 : !tt.ptr<i1>, i32
%89 = tt.splat %88 : !tt.ptr<i1> -> tensor<16x!tt.ptr<i1>, #blocked3>
%90 = tt.addptr %89, %9 : tensor<16x!tt.ptr<i1>, #blocked3>, tensor<16xi32, #blocked3>
%91 = arith.muli %2, %arg34 : i32
%92 = tt.addptr %arg33, %91 : !tt.ptr<i1>, i32
%93 = tt.splat %92 : !tt.ptr<i1> -> tensor<16x!tt.ptr<i1>, #ttg.slice<{dim = 1, parent = #mma}>>
%94 = tt.addptr %93, %3 : tensor<16x!tt.ptr<i1>, #ttg.slice<{dim = 1, parent = #mma}>>, tensor<16xi32, #ttg.slice<{dim = 1, parent = #mma}>>
%95 = arith.muli %1, %arg1 : i32
%96 = tt.addptr %arg0, %95 : !tt.ptr<f32>, i32
%97 = arith.muli %0, %arg2 : i32
%98 = tt.addptr %96, %97 : !tt.ptr<f32>, i32
%99 = tt.splat %98 : !tt.ptr<f32> -> tensor<128x!tt.ptr<f32>, #blocked3>
%100 = tt.addptr %99, %10 : tensor<128x!tt.ptr<f32>, #blocked3>, tensor<128xi32, #blocked3>
%101 = tt.load %22 : tensor<1x128x!tt.ptr<f32>, #blocked5>
%102 = tt.splat %arg35 : f32 -> tensor<1x128xf32, #blocked5>
%103 = arith.mulf %101, %102 : tensor<1x128xf32, #blocked5>
%104 = tt.load %50 : tensor<128x1x!tt.ptr<f32>, #blocked7>
%105 = ttg.convert_layout %104 : tensor<128x1xf32, #blocked7> -> tensor<128x1xf32, #ttg.slice<{dim = 2, parent = #blocked1}>>
%106 = tt.load %78 : tensor<1x128x!tt.ptr<f32>, #blocked5>
%107 = ttg.convert_layout %106 : tensor<1x128xf32, #blocked5> -> tensor<1x128xf32, #ttg.slice<{dim = 0, parent = #blocked}>>
%108 = tt.splat %arg37 : i32 -> tensor<16xi32, #ttg.slice<{dim = 1, parent = #mma}>>
%109 = tt.splat %arg37 : i32 -> tensor<16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked}>}>>
%110 = tt.splat %arg37 : i32 -> tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked4}>>
%111 = ttg.convert_layout %103 : tensor<1x128xf32, #blocked5> -> tensor<1x128xf32, #ttg.slice<{dim = 0, parent = #blocked}>>
%112 = tt.expand_dims %111 {axis = 0 : i32} : tensor<1x128xf32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x1x128xf32, #blocked>
%113 = tt.broadcast %112 : tensor<1x1x128xf32, #blocked> -> tensor<16x1x128xf32, #blocked>
%114 = tt.splat %arg38 : i32 -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma}>>
%115 = tt.splat %arg38 : i32 -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
%116 = tt.splat %arg38 : i32 -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #blocked1}>}>>
%117 = tt.splat %arg38 : i32 -> tensor<16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked}>}>>
%118 = tt.splat %arg38 : i32 -> tensor<16xi32, #blocked3>
%119 = tt.expand_dims %105 {axis = 2 : i32} : tensor<128x1xf32, #ttg.slice<{dim = 2, parent = #blocked1}>> -> tensor<128x1x1xf32, #blocked1>
%120 = tt.broadcast %119 : tensor<128x1x1xf32, #blocked1> -> tensor<128x1x16xf32, #blocked1>
%121 = tt.expand_dims %107 {axis = 0 : i32} : tensor<1x128xf32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x1x128xf32, #blocked>
%122 = tt.broadcast %121 : tensor<1x1x128xf32, #blocked> -> tensor<16x1x128xf32, #blocked>
%123 = tt.splat %arg36 : f32 -> tensor<16x16xf32, #mma>
%124 = arith.muli %arg17, %c16_i32 : i32
%125 = tt.splat %124 : i32 -> tensor<128x1x16xi32, #blocked1>
%126 = tt.splat %124 : i32 -> tensor<128x1x16xi32, #blocked8>
%127 = arith.muli %arg25, %c16_i32 : i32
%128 = tt.splat %127 : i32 -> tensor<16x1x128xi32, #blocked>
%129 = tt.splat %127 : i32 -> tensor<16x1x128xi32, #blocked6>
%130 = arith.muli %arg9, %c16_i32 : i32
%131 = tt.splat %130 : i32 -> tensor<16x1x128xi32, #blocked>
%132:8 = scf.for %arg40 = %c0_i32 to %arg37 step %c16_i32 iter_args(%arg41 = %64, %arg42 = %86, %arg43 = %72, %arg44 = %90, %arg45 = %cst_6, %arg46 = %40, %arg47 = %68, %arg48 = %94) -> (tensor<128x1x16x!tt.ptr<f32>, #blocked8>, tensor<16x1x128x!tt.ptr<f32>, #blocked6>, tensor<1x16x!tt.ptr<f32>, #blocked2>, tensor<16x!tt.ptr<i1>, #blocked3>, tensor<128xf32, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<16x1x128x!tt.ptr<f32>, #blocked>, tensor<16x1x!tt.ptr<f32>, #blocked4>, tensor<16x!tt.ptr<i1>, #ttg.slice<{dim = 1, parent = #mma}>>) : i32 {
%134 = tt.splat %arg40 : i32 -> tensor<16xi32, #ttg.slice<{dim = 1, parent = #mma}>>
%135 = tt.splat %arg40 : i32 -> tensor<16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked}>}>>
%136 = tt.splat %arg40 : i32 -> tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked4}>>
%137 = arith.addi %3, %134 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #mma}>>
%138 = arith.addi %4, %135 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked}>}>>
%139 = arith.addi %5, %136 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked4}>>
%140 = arith.cmpi slt, %137, %108 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #mma}>>
%141 = arith.cmpi slt, %138, %109 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked}>}>>
%142 = arith.cmpi slt, %139, %110 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked4}>>
%143 = tt.bitcast %arg48 : tensor<16x!tt.ptr<i1>, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<16x!tt.ptr<i8>, #ttg.slice<{dim = 1, parent = #mma}>>
%144 = tt.load %143, %140 : tensor<16x!tt.ptr<i8>, #ttg.slice<{dim = 1, parent = #mma}>>
%145 = arith.cmpi ne, %144, %cst_12 : tensor<16xi8, #ttg.slice<{dim = 1, parent = #mma}>>
%146 = tt.expand_dims %140 {axis = 1 : i32} : tensor<16xi1, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<16x1xi1, #mma>
%147 = tt.expand_dims %141 {axis = 1 : i32} : tensor<16xi1, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked}>}>> -> tensor<16x1xi1, #ttg.slice<{dim = 2, parent = #blocked}>>
%148 = tt.expand_dims %142 {axis = 1 : i32} : tensor<16xi1, #ttg.slice<{dim = 1, parent = #blocked4}>> -> tensor<16x1xi1, #blocked4>
%149 = tt.expand_dims %147 {axis = 2 : i32} : tensor<16x1xi1, #ttg.slice<{dim = 2, parent = #blocked}>> -> tensor<16x1x1xi1, #blocked>
%150 = tt.broadcast %149 : tensor<16x1x1xi1, #blocked> -> tensor<16x1x128xi1, #blocked>
%151 = tt.load %arg46, %150 : tensor<16x1x128x!tt.ptr<f32>, #blocked>
%152 = arith.mulf %113, %151 : tensor<16x1x128xf32, #blocked>
%153 = "tt.reduce"(%152) <{axis = 1 : i32}> ({
^bb0(%arg49: f32, %arg50: f32):
%240 = arith.addf %arg49, %arg50 : f32
tt.reduce.return %240 : f32
}) : (tensor<16x1x128xf32, #blocked>) -> tensor<16x128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
%154 = ttg.local_alloc %153 : (tensor<16x128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) -> !ttg.memdesc<16x128xf32, #shared, #smem>
%155 = tt.load %arg47, %148 : tensor<16x1x!tt.ptr<f32>, #blocked4>
%156 = ttg.local_alloc %155 : (tensor<16x1xf32, #blocked4>) -> !ttg.memdesc<16x1xf32, #shared1, #smem>
%157 = tt.expand_dims %145 {axis = 1 : i32} : tensor<16xi1, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<16x1xi1, #mma>
%158 = tt.broadcast %157 : tensor<16x1xi1, #mma> -> tensor<16x16xi1, #mma>
%159 = tt.broadcast %146 : tensor<16x1xi1, #mma> -> tensor<16x16xi1, #mma>
%160 = ttg.local_alloc : () -> !ttg.memdesc<2x128x1x16xf32, #shared2, #smem, mutable>
%161 = ttg.local_alloc : () -> !ttg.memdesc<2x1x16xf32, #shared3, #smem, mutable>
%162 = ttg.local_alloc : () -> !ttg.memdesc<2x16x1x128xf32, #shared4, #smem, mutable>
%163 = arith.cmpi sgt, %arg38, %c0_i32 : i32
%164 = arith.cmpi slt, %7, %115 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
%165 = arith.cmpi slt, %8, %116 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #blocked1}>}>>
%166 = arith.cmpi slt, %4, %117 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked}>}>>
%167 = tt.expand_dims %164 {axis = 0 : i32} : tensor<16xi1, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x16xi1, #blocked2>
%168 = tt.expand_dims %165 {axis = 0 : i32} : tensor<16xi1, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #blocked1}>}>> -> tensor<1x16xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>
%169 = tt.expand_dims %168 {axis = 1 : i32} : tensor<1x16xi1, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<1x1x16xi1, #blocked1>
%170 = tt.broadcast %169 : tensor<1x1x16xi1, #blocked1> -> tensor<128x1x16xi1, #blocked1>
%171 = ttg.convert_layout %arg41 : tensor<128x1x16x!tt.ptr<f32>, #blocked8> -> tensor<128x1x16x!tt.ptr<f32>, #blocked1>
%172 = ttg.memdesc_subview %160[%c0_i32, %c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x1x16xf32, #shared2, #smem, mutable> -> !ttg.memdesc<128x1x16xf32, #shared2, #smem, mutable, 2x128x1x16>
%173 = tt.splat %163 : i1 -> tensor<128x1x16xi1, #blocked1>
%174 = arith.andi %173, %170 : tensor<128x1x16xi1, #blocked1>
%175 = ttg.async_copy_global_to_local %171, %172 mask %174 : tensor<128x1x16x!tt.ptr<f32>, #blocked1> -> <128x1x16xf32, #shared2, #smem, mutable, 2x128x1x16>
%176 = ttg.async_commit_group %175
%177 = ttg.memdesc_subview %161[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x1x16xf32, #shared3, #smem, mutable> -> !ttg.memdesc<1x16xf32, #shared3, #smem, mutable, 2x1x16>
%178 = tt.splat %163 : i1 -> tensor<1x16xi1, #blocked2>
%179 = arith.andi %178, %167 : tensor<1x16xi1, #blocked2>
%180 = ttg.async_copy_global_to_local %arg43, %177 mask %179 : tensor<1x16x!tt.ptr<f32>, #blocked2> -> <1x16xf32, #shared3, #smem, mutable, 2x1x16>
%181 = ttg.async_commit_group %180
%182 = tt.expand_dims %166 {axis = 1 : i32} : tensor<16xi1, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked}>}>> -> tensor<16x1xi1, #ttg.slice<{dim = 2, parent = #blocked}>>
%183 = tt.expand_dims %182 {axis = 2 : i32} : tensor<16x1xi1, #ttg.slice<{dim = 2, parent = #blocked}>> -> tensor<16x1x1xi1, #blocked>
%184 = tt.broadcast %183 : tensor<16x1x1xi1, #blocked> -> tensor<16x1x128xi1, #blocked>
%185 = ttg.convert_layout %arg42 : tensor<16x1x128x!tt.ptr<f32>, #blocked6> -> tensor<16x1x128x!tt.ptr<f32>, #blocked>
%186 = ttg.memdesc_subview %162[%c0_i32, %c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x16x1x128xf32, #shared4, #smem, mutable> -> !ttg.memdesc<16x1x128xf32, #shared4, #smem, mutable, 2x16x1x128>
%187 = tt.splat %163 : i1 -> tensor<16x1x128xi1, #blocked>
%188 = arith.andi %187, %184 : tensor<16x1x128xi1, #blocked>
%189 = ttg.async_copy_global_to_local %185, %186 mask %188 : tensor<16x1x128x!tt.ptr<f32>, #blocked> -> <16x1x128xf32, #shared4, #smem, mutable, 2x16x1x128>
%190 = ttg.async_commit_group %189
%191 = arith.cmpi sgt, %arg38, %c16_i32 : i32
%192 = tt.addptr %171, %125 : tensor<128x1x16x!tt.ptr<f32>, #blocked1>, tensor<128x1x16xi32, #blocked1>
%193 = tt.addptr %arg41, %126 : tensor<128x1x16x!tt.ptr<f32>, #blocked8>, tensor<128x1x16xi32, #blocked8>
%194 = arith.select %163, %192, %171 : tensor<128x1x16x!tt.ptr<f32>, #blocked1>
%195 = arith.select %163, %193, %arg41 : tensor<128x1x16x!tt.ptr<f32>, #blocked8>
%196 = tt.addptr %185, %128 : tensor<16x1x128x!tt.ptr<f32>, #blocked>, tensor<16x1x128xi32, #blocked>
%197 = tt.addptr %arg42, %129 : tensor<16x1x128x!tt.ptr<f32>, #blocked6>, tensor<16x1x128xi32, #blocked6>
%198 = arith.select %163, %196, %185 : tensor<16x1x128x!tt.ptr<f32>, #blocked>
%199 = arith.select %163, %197, %arg42 : tensor<16x1x128x!tt.ptr<f32>, #blocked6>
%200 = tt.addptr %arg43, %cst_7 : tensor<1x16x!tt.ptr<f32>, #blocked2>, tensor<1x16xi32, #blocked2>
%201 = arith.select %163, %200, %arg43 : tensor<1x16x!tt.ptr<f32>, #blocked2>
%202 = arith.addi %7, %cst_1 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
%203 = arith.addi %8, %cst_0 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #blocked1}>}>>
%204 = arith.addi %4, %cst : tensor<16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked}>}>>
%205 = arith.cmpi slt, %202, %115 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
%206 = arith.cmpi slt, %203, %116 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #blocked1}>}>>
%207 = arith.cmpi slt, %204, %117 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked}>}>>
%208 = tt.expand_dims %205 {axis = 0 : i32} : tensor<16xi1, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x16xi1, #blocked2>
%209 = tt.expand_dims %206 {axis = 0 : i32} : tensor<16xi1, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #blocked1}>}>> -> tensor<1x16xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>
%210 = tt.expand_dims %209 {axis = 1 : i32} : tensor<1x16xi1, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<1x1x16xi1, #blocked1>
%211 = tt.broadcast %210 : tensor<1x1x16xi1, #blocked1> -> tensor<128x1x16xi1, #blocked1>
%212 = ttg.memdesc_subview %160[%c1_i32, %c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x1x16xf32, #shared2, #smem, mutable> -> !ttg.memdesc<128x1x16xf32, #shared2, #smem, mutable, 2x128x1x16>
%213 = tt.splat %191 : i1 -> tensor<128x1x16xi1, #blocked1>
%214 = arith.andi %213, %211 : tensor<128x1x16xi1, #blocked1>
%215 = ttg.async_copy_global_to_local %194, %212 mask %214 : tensor<128x1x16x!tt.ptr<f32>, #blocked1> -> <128x1x16xf32, #shared2, #smem, mutable, 2x128x1x16>
%216 = ttg.async_commit_group %215
%217 = ttg.memdesc_subview %161[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x1x16xf32, #shared3, #smem, mutable> -> !ttg.memdesc<1x16xf32, #shared3, #smem, mutable, 2x1x16>
%218 = tt.splat %191 : i1 -> tensor<1x16xi1, #blocked2>
%219 = arith.andi %218, %208 : tensor<1x16xi1, #blocked2>
%220 = ttg.async_copy_global_to_local %201, %217 mask %219 : tensor<1x16x!tt.ptr<f32>, #blocked2> -> <1x16xf32, #shared3, #smem, mutable, 2x1x16>
%221 = ttg.async_commit_group %220
%222 = tt.expand_dims %207 {axis = 1 : i32} : tensor<16xi1, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked}>}>> -> tensor<16x1xi1, #ttg.slice<{dim = 2, parent = #blocked}>>
%223 = tt.expand_dims %222 {axis = 2 : i32} : tensor<16x1xi1, #ttg.slice<{dim = 2, parent = #blocked}>> -> tensor<16x1x1xi1, #blocked>
%224 = tt.broadcast %223 : tensor<16x1x1xi1, #blocked> -> tensor<16x1x128xi1, #blocked>
%225 = ttg.memdesc_subview %162[%c1_i32, %c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x16x1x128xf32, #shared4, #smem, mutable> -> !ttg.memdesc<16x1x128xf32, #shared4, #smem, mutable, 2x16x1x128>
%226 = tt.splat %191 : i1 -> tensor<16x1x128xi1, #blocked>
%227 = arith.andi %226, %224 : tensor<16x1x128xi1, #blocked>
%228 = ttg.async_copy_global_to_local %198, %225 mask %227 : tensor<16x1x128x!tt.ptr<f32>, #blocked> -> <16x1x128xf32, #shared4, #smem, mutable, 2x16x1x128>
%229 = ttg.async_commit_group %228
%230:15 = scf.for %arg49 = %c0_i32 to %arg38 step %c16_i32 iter_args(%arg50 = %cst_4, %arg51 = %cst_3, %arg52 = %cst_5, %arg53 = %195, %arg54 = %199, %arg55 = %201, %arg56 = %arg44, %arg57 = %c1_i32, %arg58 = %c-1_i32, %arg59 = %181, %arg60 = %221, %arg61 = %190, %arg62 = %229, %arg63 = %194, %arg64 = %198) -> (tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>>, tensor<16x128xf32, #mma>, tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>>, tensor<128x1x16x!tt.ptr<f32>, #blocked8>, tensor<16x1x128x!tt.ptr<f32>, #blocked6>, tensor<1x16x!tt.ptr<f32>, #blocked2>, tensor<16x!tt.ptr<i1>, #blocked3>, i32, i32, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, tensor<128x1x16x!tt.ptr<f32>, #blocked1>, tensor<16x1x128x!tt.ptr<f32>, #blocked>) : i32 {
%240 = arith.subi %arg38, %c32_i32 : i32
%241 = arith.cmpi slt, %arg49, %240 : i32
%242 = arith.subi %arg38, %c16_i32 : i32
%243 = arith.cmpi slt, %arg49, %242 : i32
%244 = arith.addi %arg58, %c1_i32 : i32
%245 = arith.cmpi slt, %244, %c2_i32 : i32
%246 = arith.select %245, %244, %c0_i32 : i32
%247 = tt.splat %arg49 : i32 -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma}>>
%248 = tt.splat %arg49 : i32 -> tensor<16xi32, #blocked3>
%249 = arith.addi %6, %247 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma}>>
%250 = arith.addi %9, %248 : tensor<16xi32, #blocked3>
%251 = arith.cmpi slt, %249, %114 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #mma}>>
%252 = arith.cmpi slt, %250, %118 : tensor<16xi32, #blocked3>
%253 = tt.bitcast %arg56 : tensor<16x!tt.ptr<i1>, #blocked3> -> tensor<16x!tt.ptr<i8>, #blocked3>
%254 = tt.load %253, %252 : tensor<16x!tt.ptr<i8>, #blocked3>
%255 = ttg.convert_layout %254 : tensor<16xi8, #blocked3> -> tensor<16xi8, #ttg.slice<{dim = 0, parent = #mma}>>
%256 = arith.cmpi ne, %255, %cst_13 : tensor<16xi8, #ttg.slice<{dim = 0, parent = #mma}>>
%257 = tt.expand_dims %251 {axis = 0 : i32} : tensor<16xi1, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x16xi1, #mma>
%258 = ttg.async_wait %arg59, %arg61 {num = 3 : i32}
%259 = ttg.memdesc_subview %160[%246, %c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x1x16xf32, #shared2, #smem, mutable> -> !ttg.memdesc<128x1x16xf32, #shared2, #smem, mutable, 2x128x1x16>
%260 = ttg.local_load %259 token %258 : !ttg.memdesc<128x1x16xf32, #shared2, #smem, mutable, 2x128x1x16> -> tensor<128x1x16xf32, #blocked1>
%261 = ttg.memdesc_subview %161[%246, %c0_i32, %c0_i32] : !ttg.memdesc<2x1x16xf32, #shared3, #smem, mutable> -> !ttg.memdesc<1x16xf32, #shared3, #smem, mutable, 2x1x16>
%262 = ttg.memdesc_subview %162[%246, %c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x16x1x128xf32, #shared4, #smem, mutable> -> !ttg.memdesc<16x1x128xf32, #shared4, #smem, mutable, 2x16x1x128>
%263 = ttg.local_load %262 token %258 : !ttg.memdesc<16x1x128xf32, #shared4, #smem, mutable, 2x16x1x128> -> tensor<16x1x128xf32, #blocked>
%264 = arith.mulf %120, %260 : tensor<128x1x16xf32, #blocked1>
%265 = "tt.reduce"(%264) <{axis = 1 : i32}> ({
^bb0(%arg65: f32, %arg66: f32):
%348 = arith.addf %arg65, %arg66 : f32
tt.reduce.return %348 : f32
}) : (tensor<128x1x16xf32, #blocked1>) -> tensor<128x16xf32, #ttg.slice<{dim = 1, parent = #blocked1}>>
%266 = ttg.local_alloc %265 : (tensor<128x16xf32, #ttg.slice<{dim = 1, parent = #blocked1}>>) -> !ttg.memdesc<128x16xf32, #shared5, #smem>
%267 = arith.mulf %122, %263 : tensor<16x1x128xf32, #blocked>
%268 = "tt.reduce"(%267) <{axis = 1 : i32}> ({
^bb0(%arg65: f32, %arg66: f32):
%348 = arith.addf %arg65, %arg66 : f32
tt.reduce.return %348 : f32
}) : (tensor<16x1x128xf32, #blocked>) -> tensor<16x128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
%269 = ttg.local_alloc %268 : (tensor<16x128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) -> !ttg.memdesc<16x128xf32, #shared6, #smem>
%270 = ttg.local_load %154 : !ttg.memdesc<16x128xf32, #shared, #smem> -> tensor<16x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
%271 = ttg.local_load %266 : !ttg.memdesc<128x16xf32, #shared5, #smem> -> tensor<128x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>>
%272 = tt.dot %270, %271, %cst_10, inputPrecision = tf32 : tensor<16x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<128x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<16x16xf32, #mma>
%273 = ttg.local_load %156 : !ttg.memdesc<16x1xf32, #shared1, #smem> -> tensor<16x1xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
%274 = ttg.local_load %261 : !ttg.memdesc<1x16xf32, #shared3, #smem, mutable, 2x1x16> -> tensor<1x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>>
%275 = tt.dot %273, %274, %272, inputPrecision = tf32 : tensor<16x1xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<1x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<16x16xf32, #mma>
%276 = arith.mulf %275, %cst_2 : tensor<16x16xf32, #mma>
%277 = tt.expand_dims %256 {axis = 0 : i32} : tensor<16xi1, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x16xi1, #mma>
%278 = tt.broadcast %277 : tensor<1x16xi1, #mma> -> tensor<16x16xi1, #mma>
%279 = arith.andi %158, %278 : tensor<16x16xi1, #mma>
%280 = arith.select %279, %123, %276 : tensor<16x16xi1, #mma>, tensor<16x16xf32, #mma>
%281 = tt.broadcast %257 : tensor<1x16xi1, #mma> -> tensor<16x16xi1, #mma>
%282 = arith.andi %159, %281 : tensor<16x16xi1, #mma>
%283 = arith.select %282, %280, %123 : tensor<16x16xi1, #mma>, tensor<16x16xf32, #mma>
%284 = "tt.reduce"(%283) <{axis = 1 : i32}> ({
^bb0(%arg65: f32, %arg66: f32):
%348 = arith.maxnumf %arg65, %arg66 : f32
tt.reduce.return %348 : f32
}) : (tensor<16x16xf32, #mma>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>>
%285 = arith.maxnumf %arg52, %284 : tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>>
%286 = tt.expand_dims %285 {axis = 1 : i32} : tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<16x1xf32, #mma>
%287 = tt.broadcast %286 : tensor<16x1xf32, #mma> -> tensor<16x16xf32, #mma>
%288 = arith.subf %283, %287 : tensor<16x16xf32, #mma>
%289 = math.exp2 %288 : tensor<16x16xf32, #mma>
%290 = ttg.local_alloc %289 : (tensor<16x16xf32, #mma>) -> !ttg.memdesc<16x16xf32, #shared7, #smem>
%291 = "tt.reduce"(%289) <{axis = 1 : i32}> ({
^bb0(%arg65: f32, %arg66: f32):
%348 = arith.addf %arg65, %arg66 : f32
tt.reduce.return %348 : f32
}) : (tensor<16x16xf32, #mma>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>>
%292 = arith.subf %arg52, %285 : tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>>
%293 = math.exp2 %292 : tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>>
%294 = arith.mulf %arg50, %293 : tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>>
%295 = arith.addf %294, %291 : tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>>
%296 = tt.expand_dims %293 {axis = 1 : i32} : tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<16x1xf32, #mma>
%297 = tt.broadcast %296 : tensor<16x1xf32, #mma> -> tensor<16x128xf32, #mma>
%298 = arith.mulf %arg51, %297 : tensor<16x128xf32, #mma>
%299 = ttg.local_load %290 : !ttg.memdesc<16x16xf32, #shared7, #smem> -> tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
%300 = ttg.local_load %269 : !ttg.memdesc<16x128xf32, #shared6, #smem> -> tensor<16x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
%301 = tt.dot %299, %300, %298, inputPrecision = tf32 : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<16x128xf32, #mma>
%302 = tt.addptr %arg56, %cst_8 : tensor<16x!tt.ptr<i1>, #blocked3>, tensor<16xi32, #blocked3>
%303 = tt.addptr %arg63, %125 : tensor<128x1x16x!tt.ptr<f32>, #blocked1>, tensor<128x1x16xi32, #blocked1>
%304 = tt.addptr %arg53, %126 : tensor<128x1x16x!tt.ptr<f32>, #blocked8>, tensor<128x1x16xi32, #blocked8>
%305 = tt.addptr %arg64, %128 : tensor<16x1x128x!tt.ptr<f32>, #blocked>, tensor<16x1x128xi32, #blocked>
%306 = tt.addptr %arg54, %129 : tensor<16x1x128x!tt.ptr<f32>, #blocked6>, tensor<16x1x128xi32, #blocked6>
%307 = tt.addptr %arg55, %cst_7 : tensor<1x16x!tt.ptr<f32>, #blocked2>, tensor<1x16xi32, #blocked2>
%308 = arith.addi %arg57, %c1_i32 : i32
%309 = arith.cmpi slt, %308, %c2_i32 : i32
%310 = arith.select %309, %308, %c0_i32 : i32
%311 = arith.addi %arg49, %c32_i32 : i32
%312 = tt.splat %311 : i32 -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
%313 = tt.splat %311 : i32 -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #blocked1}>}>>
%314 = tt.splat %311 : i32 -> tensor<16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked}>}>>
%315 = arith.addi %7, %312 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
%316 = arith.addi %8, %313 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #blocked1}>}>>
%317 = arith.addi %4, %314 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked}>}>>
%318 = arith.cmpi slt, %315, %115 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
%319 = arith.cmpi slt, %316, %116 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #blocked1}>}>>
%320 = arith.cmpi slt, %317, %117 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked}>}>>
%321 = tt.expand_dims %318 {axis = 0 : i32} : tensor<16xi1, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x16xi1, #blocked2>
%322 = tt.expand_dims %319 {axis = 0 : i32} : tensor<16xi1, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #blocked1}>}>> -> tensor<1x16xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>
%323 = tt.expand_dims %322 {axis = 1 : i32} : tensor<1x16xi1, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<1x1x16xi1, #blocked1>
%324 = tt.broadcast %323 : tensor<1x1x16xi1, #blocked1> -> tensor<128x1x16xi1, #blocked1>
%325 = ttg.memdesc_subview %160[%310, %c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x1x16xf32, #shared2, #smem, mutable> -> !ttg.memdesc<128x1x16xf32, #shared2, #smem, mutable, 2x128x1x16>
%326 = tt.splat %241 : i1 -> tensor<128x1x16xi1, #blocked1>
%327 = arith.andi %326, %324 : tensor<128x1x16xi1, #blocked1>
%328 = ttg.async_copy_global_to_local %303, %325 mask %327 : tensor<128x1x16x!tt.ptr<f32>, #blocked1> -> <128x1x16xf32, #shared2, #smem, mutable, 2x128x1x16>
%329 = ttg.async_commit_group %328
%330 = ttg.memdesc_subview %161[%310, %c0_i32, %c0_i32] : !ttg.memdesc<2x1x16xf32, #shared3, #smem, mutable> -> !ttg.memdesc<1x16xf32, #shared3, #smem, mutable, 2x1x16>
%331 = tt.splat %241 : i1 -> tensor<1x16xi1, #blocked2>
%332 = arith.andi %331, %321 : tensor<1x16xi1, #blocked2>
%333 = ttg.async_copy_global_to_local %307, %330 mask %332 : tensor<1x16x!tt.ptr<f32>, #blocked2> -> <1x16xf32, #shared3, #smem, mutable, 2x1x16>
%334 = ttg.async_commit_group %333
%335 = tt.expand_dims %320 {axis = 1 : i32} : tensor<16xi1, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked}>}>> -> tensor<16x1xi1, #ttg.slice<{dim = 2, parent = #blocked}>>
%336 = tt.expand_dims %335 {axis = 2 : i32} : tensor<16x1xi1, #ttg.slice<{dim = 2, parent = #blocked}>> -> tensor<16x1x1xi1, #blocked>
%337 = tt.broadcast %336 : tensor<16x1x1xi1, #blocked> -> tensor<16x1x128xi1, #blocked>
%338 = ttg.memdesc_subview %162[%310, %c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x16x1x128xf32, #shared4, #smem, mutable> -> !ttg.memdesc<16x1x128xf32, #shared4, #smem, mutable, 2x16x1x128>
%339 = tt.splat %241 : i1 -> tensor<16x1x128xi1, #blocked>
%340 = arith.andi %339, %337 : tensor<16x1x128xi1, #blocked>
%341 = ttg.async_copy_global_to_local %305, %338 mask %340 : tensor<16x1x128x!tt.ptr<f32>, #blocked> -> <16x1x128xf32, #shared4, #smem, mutable, 2x16x1x128>
%342 = ttg.async_commit_group %341
%343 = arith.select %243, %303, %arg63 : tensor<128x1x16x!tt.ptr<f32>, #blocked1>
%344 = arith.select %243, %304, %arg53 : tensor<128x1x16x!tt.ptr<f32>, #blocked8>
%345 = arith.select %243, %305, %arg64 : tensor<16x1x128x!tt.ptr<f32>, #blocked>
%346 = arith.select %243, %306, %arg54 : tensor<16x1x128x!tt.ptr<f32>, #blocked6>
%347 = arith.select %243, %307, %arg55 : tensor<1x16x!tt.ptr<f32>, #blocked2>
scf.yield %295, %301, %285, %344, %346, %347, %302, %310, %246, %arg60, %334, %arg62, %342, %343, %345 : tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>>, tensor<16x128xf32, #mma>, tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>>, tensor<128x1x16x!tt.ptr<f32>, #blocked8>, tensor<16x1x128x!tt.ptr<f32>, #blocked6>, tensor<1x16x!tt.ptr<f32>, #blocked2>, tensor<16x!tt.ptr<i1>, #blocked3>, i32, i32, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, tensor<128x1x16x!tt.ptr<f32>, #blocked1>, tensor<16x1x128x!tt.ptr<f32>, #blocked>
} {tt.divisibility_arg1 = dense<16> : tensor<1xi32>}
%231 = ttg.async_wait {num = 0 : i32}
ttg.local_dealloc %162 : !ttg.memdesc<2x16x1x128xf32, #shared4, #smem, mutable>
ttg.local_dealloc %161 : !ttg.memdesc<2x1x16xf32, #shared3, #smem, mutable>
ttg.local_dealloc %160 : !ttg.memdesc<2x128x1x16xf32, #shared2, #smem, mutable>
%232 = tt.expand_dims %230#0 {axis = 1 : i32} : tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<16x1xf32, #mma>
%233 = tt.broadcast %232 : tensor<16x1xf32, #mma> -> tensor<16x128xf32, #mma>
%234 = arith.divf %230#1, %233 : tensor<16x128xf32, #mma>
%235 = "tt.reduce"(%234) <{axis = 0 : i32}> ({
^bb0(%arg49: f32, %arg50: f32):
%240 = arith.addf %arg49, %arg50 : f32
tt.reduce.return %240 : f32
}) : (tensor<16x128xf32, #mma>) -> tensor<128xf32, #ttg.slice<{dim = 0, parent = #mma}>>
%236 = arith.addf %arg45, %235 : tensor<128xf32, #ttg.slice<{dim = 0, parent = #mma}>>
%237 = tt.addptr %arg46, %131 : tensor<16x1x128x!tt.ptr<f32>, #blocked>, tensor<16x1x128xi32, #blocked>
%238 = tt.addptr %arg47, %cst_9 : tensor<16x1x!tt.ptr<f32>, #blocked4>, tensor<16x1xi32, #blocked4>
%239 = tt.addptr %arg48, %cst_11 : tensor<16x!tt.ptr<i1>, #ttg.slice<{dim = 1, parent = #mma}>>, tensor<16xi32, #ttg.slice<{dim = 1, parent = #mma}>>
scf.yield %230#3, %230#4, %230#5, %230#6, %236, %237, %238, %239 : tensor<128x1x16x!tt.ptr<f32>, #blocked8>, tensor<16x1x128x!tt.ptr<f32>, #blocked6>, tensor<1x16x!tt.ptr<f32>, #blocked2>, tensor<16x!tt.ptr<i1>, #blocked3>, tensor<128xf32, #ttg.slice<{dim = 0, parent = #mma}>>, tensor<16x1x128x!tt.ptr<f32>, #blocked>, tensor<16x1x!tt.ptr<f32>, #blocked4>, tensor<16x!tt.ptr<i1>, #ttg.slice<{dim = 1, parent = #mma}>>
} {tt.divisibility_arg1 = dense<16> : tensor<1xi32>}
%133 = ttg.convert_layout %132#4 : tensor<128xf32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<128xf32, #blocked3>
tt.store %100, %133 : tensor<128x!tt.ptr<f32>, #blocked3>
tt.return
}
}
{-#
external_resources: {
mlir_reproducer: {
pipeline: "builtin.module(triton-nvidia-mma-lowering, tritongpu-combine-tensor-select-and-if, tritongpu-allocate-warp-groups, convert-scf-to-cf, allocate-shared-memory, triton-tensor-memory-allocation, tritongpu-global-scratch-memory-allocation, convert-triton-gpu-to-llvm{compute-capability=89 ptx-version=84}, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, convert-nv-gpu-to-llvm, convert-warp-specialize-to-llvm, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, symbol-dce, enable-line-info)",
disable_threading: false,
verify_each: true
}
}
#-}
/home/pi-user/trifast/src/trifast/triton_factorized.py:19:0: error: Failures have been detected while processing an MLIR pass pipeline
/home/pi-user/trifast/src/trifast/triton_factorized.py:19:0: note: Pipeline failed while executing [`ConvertTritonGPUToLLVM` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.`
Traceback (most recent call last):
File "/home/pi-user/trifast/src/trifast/test_triton_factorized.py", line 146, in <module>
test_fwd_factorized_kernel()
File "/home/pi-user/trifast/src/trifast/test_triton_factorized.py", line 106, in test_fwd_factorized_kernel
wrap_triton(_fwd_factorized)[grid](
File "/home/pi-user/trifast/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py", line 1812, in __call__
return tracing_triton_hopifier_singleton.call_triton_kernel(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/pi-user/trifast/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py", line 1670, in call_triton_kernel
return self.call_HOP(variable, grids, combined_args_raw, tx)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/pi-user/trifast/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py", line 1766, in call_HOP
return triton_kernel_wrapper_mutation(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/pi-user/trifast/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py", line 783, in __call__
return super().__call__(
^^^^^^^^^^^^^^^^^
File "/home/pi-user/trifast/.venv/lib/python3.11/site-packages/torch/_ops.py", line 471, in __call__
return wrapper()
^^^^^^^^^
File "/home/pi-user/trifast/.venv/lib/python3.11/site-packages/torch/_ops.py", line 467, in wrapper
return self.dispatch(
^^^^^^^^^^^^^^
File "/home/pi-user/trifast/.venv/lib/python3.11/site-packages/torch/_ops.py", line 455, in dispatch
return kernel(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/pi-user/trifast/.venv/lib/python3.11/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py", line 886, in triton_kernel_wrapper_mutation_dense
kernel[grid_fn](*args, **kwargs, **constant_args)
File "/home/pi-user/trifast/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 347, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/pi-user/trifast/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 569, in run
kernel = self.compile(src, target=target, options=options.__dict__)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/pi-user/trifast/.venv/lib/python3.11/site-packages/triton/compiler/compiler.py", line 284, in compile
next_module = compile_ir(module, metadata)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/pi-user/trifast/.venv/lib/python3.11/site-packages/triton/backends/nvidia/compiler.py", line 450, in <lambda>
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/pi-user/trifast/.venv/lib/python3.11/site-packages/triton/backends/nvidia/compiler.py", line 341, in make_llir
pm.run(mod)
RuntimeError: PassManager::run failed
Environment details
Triton 3.3.0, NVIDIA L40S GPU (compute capability 8.9, matching `compute-capability=89` in the reproducer pipeline above)