[SelectionDAG][AArch64] Add dot product lowering in NEON for PARTIAL_REDUCE_*MLA ISD nodes by NickGuy-Arm · Pull Request #140075 · llvm/llvm-project · GitHub

[SelectionDAG][AArch64] Add dot product lowering in NEON for PARTIAL_REDUCE_*MLA ISD nodes #140075


Merged

Conversation

NickGuy-Arm
Contributor

Lowering for fixed-width vectors is added to TableGen.
There is also custom lowering to ensure that the USDOT patterns are
still lowered for fixed-width vectors, and that the
v16i8 -> v4i64 partial reduction case is lowered here instead of
being split (as there is no v2i64 dot product instruction).

@JamesChesterman is the original author.
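For illustration, the fixed-width pairing that becomes Legal here looks roughly like the IR below (a minimal sketch mirroring the existing udot test in neon-partial-reduce-dot-product.ll; the function name, intrinsic name and mangling are assumed from that file):

declare <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32>, <16 x i32>)

define <4 x i32> @udot_sketch(<4 x i32> %acc, <16 x i8> %a, <16 x i8> %b) {
  ; Widen, multiply, then partially reduce the sixteen i32 products into the
  ; four i32 accumulator lanes.
  %a.wide = zext <16 x i8> %a to <16 x i32>
  %b.wide = zext <16 x i8> %b to <16 x i32>
  %mult = mul nuw nsw <16 x i32> %a.wide, %b.wide
  %red = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mult)
  ret <4 x i32> %red
}

With -aarch64-enable-partial-reduce-nodes and +dotprod, the new TableGen pattern should select this to a single udot v0.4s, v1.16b, v2.16b rather than a widening multiply-accumulate sequence.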

JamesChesterman and others added 3 commits May 13, 2025 17:58
…REDUCE_*MLA ISD nodes

Lowering for fixed width vectors added to tablegen.
There is also custom lowering to ensure that the USDOT patterns are
still lowered for fixed width vectors. It also ensures that the
v16i8 -> v4i64 partial reduction case is lowered here instead of
being split (as there is not a v2i64 dot product instruction).
@llvmbot
Member
llvmbot commented May 15, 2025

@llvm/pr-subscribers-backend-aarch64

Author: Nicholas Guy (NickGuy-Arm)

Changes

Lowering for fixed width vectors added to tablegen.
There is also custom lowering to ensure that the USDOT patterns are
still lowered for fixed width vectors. It also ensures that the
v16i8 -> v4i64 partial reduction case is lowered here instead of
being split (as there is not a v2i64 dot product instruction).

@JamesChesterman is the original author.
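As a rough sketch of the wide-accumulator case handled by the custom lowering, shown for the v2i64/v16i8 pairing that the patch marks Custom (intrinsic name and mangling assumed from the tests below; function name is illustrative):

declare <2 x i64> @llvm.experimental.vector.partial.reduce.add.v2i64.v16i64(<2 x i64>, <16 x i64>)

define <2 x i64> @udot_8to64_sketch(<2 x i64> %acc, <16 x i8> %a, <16 x i8> %b) {
  ; Sixteen i8 products are folded into two i64 accumulator lanes. There is no
  ; v2i64 dot product, so the lowering dots into a zeroed v4i32 first and then
  ; widens the two i32 halves into the i64 accumulator.
  %a.wide = zext <16 x i8> %a to <16 x i64>
  %b.wide = zext <16 x i8> %b to <16 x i64>
  %mult = mul nuw nsw <16 x i64> %a.wide, %b.wide
  %red = call <2 x i64> @llvm.experimental.vector.partial.reduce.add.v2i64.v16i64(<2 x i64> %acc, <16 x i64> %mult)
  ret <2 x i64> %red
}

The expected shape of the selected code (register numbers illustrative) is roughly:

  movi   v3.2d, #0000000000000000
  udot   v3.4s, v1.16b, v2.16b
  uaddw  v0.2d, v0.2d, v3.2s
  uaddw2 v0.2d, v0.2d, v3.4s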


Patch is 37.03 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140075.diff

3 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+61-14)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.td (+11)
  • (modified) llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll (+342-129)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 13fb6a32233fe..f1354bf1147dd 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1872,6 +1872,15 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom);
   }
 
+  if (EnablePartialReduceNodes && Subtarget->hasNEON() &&
+      Subtarget->hasDotProd()) {
+    setPartialReduceMLAAction(MVT::v2i64, MVT::v8i16, Legal);
+    setPartialReduceMLAAction(MVT::v4i32, MVT::v16i8, Legal);
+    setPartialReduceMLAAction(MVT::v4i32, MVT::v16i8, Legal);
+    setPartialReduceMLAAction(MVT::v2i32, MVT::v8i8, Legal);
+    setPartialReduceMLAAction(MVT::v2i64, MVT::v16i8, Custom);
+  }
+
   // Handle operations that are only available in non-streaming SVE mode.
   if (Subtarget->isSVEAvailable()) {
     for (auto VT : {MVT::nxv16i8,  MVT::nxv8i16, MVT::nxv4i32,  MVT::nxv2i64,
@@ -7743,8 +7752,11 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
   case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
     return LowerVECTOR_HISTOGRAM(Op, DAG);
   case ISD::PARTIAL_REDUCE_SMLA:
-  case ISD::PARTIAL_REDUCE_UMLA:
-    return LowerPARTIAL_REDUCE_MLA(Op, DAG);
+  case ISD::PARTIAL_REDUCE_UMLA: {
+    if (SDValue Result = LowerPARTIAL_REDUCE_MLA(Op, DAG))
+      return Result;
+    return expandPartialReduceMLA(Op.getNode(), DAG);
+  }
   }
 }
 
@@ -27569,6 +27581,14 @@ void AArch64TargetLowering::ReplaceNodeResults(
     if (SDValue Res = LowerVECTOR_COMPRESS(SDValue(N, 0), DAG))
       Results.push_back(Res);
     return;
+  case ISD::PARTIAL_REDUCE_UMLA:
+  case ISD::PARTIAL_REDUCE_SMLA: {
+    if (SDValue Res = LowerPARTIAL_REDUCE_MLA(SDValue(N, 0), DAG))
+      Results.push_back(Res);
+    else
+      Results.push_back(expandPartialReduceMLA(N, DAG));
+    return;
+  }
   case ISD::ADD:
   case ISD::FADD:
     ReplaceAddWithADDP(N, Results, DAG, Subtarget);
@@ -29518,37 +29538,64 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
 }
 
 /// If a PARTIAL_REDUCE_MLA node comes in with an accumulator-input type pairing
-/// of nxv2i64/nxv16i8, we cannot directly lower it to a (u|s)dot. We can
+/// of v2i64/v16i8, we cannot directly lower it to a (u|s)dot. We can
 /// however still make use of the dot product instruction by instead
-/// accumulating over two steps: nxv16i8 -> nxv4i32 -> nxv2i64.
+/// accumulating over two steps: v16i8 -> v4i32 -> v2i64.
 SDValue
 AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
                                                SelectionDAG &DAG) const {
+  bool Scalable = Op.getValueType().isScalableVector();
+  if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
+    return SDValue();
+  if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
+    return SDValue();
+
   SDLoc DL(Op);
 
   SDValue Acc = Op.getOperand(0);
   SDValue LHS = Op.getOperand(1);
   SDValue RHS = Op.getOperand(2);
   EVT ResultVT = Op.getValueType();
-  assert(ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8);
 
-  SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, MVT::nxv4i32,
-                                DAG.getConstant(0, DL, MVT::nxv4i32), LHS, RHS);
+  assert((Scalable && ResultVT == MVT::nxv2i64 &&
+          LHS.getValueType() == MVT::nxv16i8) ||
+         (!Scalable && ResultVT == MVT::v2i64 &&
+          LHS.getValueType() == MVT::v16i8));
+
+  EVT DotVT = Scalable ? MVT::nxv4i32 : MVT::v4i32;
+  SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, DotVT,
+                                DAG.getConstant(0, DL, DotVT), LHS, RHS);
 
   bool IsUnsigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_UMLA;
-  if (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable()) {
+  if (Scalable &&
+      (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable())) {
     unsigned LoOpcode = IsUnsigned ? AArch64ISD::UADDWB : AArch64ISD::SADDWB;
     unsigned HiOpcode = IsUnsigned ? AArch64ISD::UADDWT : AArch64ISD::SADDWT;
     SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, Acc, DotNode);
     return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
   }
 
-  unsigned LoOpcode = IsUnsigned ? AArch64ISD::UUNPKLO : AArch64ISD::SUNPKLO;
-  unsigned HiOpcode = IsUnsigned ? AArch64ISD::UUNPKHI : AArch64ISD::SUNPKHI;
-  auto Lo = DAG.getNode(LoOpcode, DL, ResultVT, DotNode);
-  auto Hi = DAG.getNode(HiOpcode, DL, ResultVT, DotNode);
-  auto Extended = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, Hi);
-  return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended);
+  if (Scalable) {
+    unsigned LoOpcode = IsUnsigned ? AArch64ISD::UUNPKLO : AArch64ISD::SUNPKLO;
+    unsigned HiOpcode = IsUnsigned ? AArch64ISD::UUNPKHI : AArch64ISD::SUNPKHI;
+    auto Lo = DAG.getNode(LoOpcode, DL, ResultVT, DotNode);
+    auto Hi = DAG.getNode(HiOpcode, DL, ResultVT, DotNode);
+    auto Extended = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, Hi);
+    return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended);
+  }
+
+  // Fold v4i32 into v2i64
+  // SDValues
+  auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
+  if (IsUnsigned) {
+    DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, MVT::v2i64);
+    DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, MVT::v2i64);
+  } else {
+    DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, MVT::v2i64);
+    DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, MVT::v2i64);
+  }
+  auto Lo = DAG.getNode(ISD::ADD, DL, MVT::v2i64, Acc, DotNodeLo);
+  return DAG.getNode(ISD::ADD, DL, MVT::v2i64, Lo, DotNodeHi);
 }
 
 SDValue
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index b02a907f7439f..5cc6a38d55977 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -1474,6 +1474,17 @@ defm SDOTlane : SIMDThreeSameVectorDotIndex<0, 0, 0b10, "sdot", AArch64sdot>;
 defm UDOTlane : SIMDThreeSameVectorDotIndex<1, 0, 0b10, "udot", AArch64udot>;
 }
 
+let Predicates = [HasNEON, HasDotProd] in {
+  def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$Acc), (v16i8 V128:$MulLHS), (v16i8 V128:$MulRHS))),
+            (v4i32 (UDOTv16i8 V128:$Acc, V128:$MulLHS, V128:$MulRHS))>;
+  def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$Acc), (v16i8 V128:$MulLHS), (v16i8 V128:$MulRHS))),
+            (v4i32 (SDOTv16i8 V128:$Acc, V128:$MulLHS, V128:$MulRHS))>;
+  def : Pat<(v2i32 (partial_reduce_umla (v2i32 V64:$Acc), (v8i8 V64:$MulLHS), (v8i8 V64:$MulRHS))),
+            (v2i32 (UDOTv8i8 V64:$Acc, V64:$MulLHS, V64:$MulRHS))>;
+  def : Pat<(v2i32 (partial_reduce_smla (v2i32 V64:$Acc), (v8i8 V64:$MulLHS), (v8i8 V64:$MulRHS))),
+            (v2i32 (SDOTv8i8 V64:$Acc, V64:$MulLHS, V64:$MulRHS))>; 
+} // End HasNEON, HasDotProd
+
 // ARMv8.6-A BFloat
 let Predicates = [HasNEON, HasBF16] in {
 defm BFDOT       : SIMDThreeSameVectorBFDot<1, "bfdot">;
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index ab9813aa796e3..47a4796d0f9a1 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -2,7 +2,8 @@
 ; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-NOI8MM
 ; RUN: llc -mtriple aarch64 -mattr=+neon < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM,CHECK-NODOT
 ; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-I8MM
-; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm -aarch64-enable-partial-reduce-nodes < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM
+; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm -aarch64-enable-partial-reduce-nodes < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-I8MM,CHECK-NEWLOWERING-I8MM
+; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod -aarch64-enable-partial-reduce-nodes < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-NOI8MM,CHECK-NEWLOWERING-NOI8MM
 
 define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
 ; CHECK-DOT-LABEL: udot:
@@ -174,10 +175,17 @@ define <4 x i32> @usdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
 ; CHECK-NOI8MM-NEXT:    smlal2 v0.4s, v2.8h, v1.8h
 ; CHECK-NOI8MM-NEXT:    ret
 ;
-; CHECK-I8MM-LABEL: usdot:
-; CHECK-I8MM:       // %bb.0:
-; CHECK-I8MM-NEXT:    usdot v0.4s, v1.16b, v2.16b
-; CHECK-I8MM-NEXT:    ret
+; CHECK-NEWLOWERING-I8MM-LABEL: usdot:
+; CHECK-NEWLOWERING-I8MM:       // %bb.0:
+; CHECK-NEWLOWERING-I8MM-NEXT:    ushll v3.8h, v1.8b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT:    sshll v4.8h, v2.8b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT:    ushll2 v1.8h, v1.16b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT:    sshll2 v2.8h, v2.16b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v0.4s, v4.4h, v3.4h
+; CHECK-NEWLOWERING-I8MM-NEXT:    smlal2 v0.4s, v4.8h, v3.8h
+; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v0.4s, v2.4h, v1.4h
+; CHECK-NEWLOWERING-I8MM-NEXT:    smlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NEWLOWERING-I8MM-NEXT:    ret
   %u.wide = zext <16 x i8> %u to <16 x i32>
   %s.wide = sext <16 x i8> %s to <16 x i32>
   %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
@@ -209,21 +217,28 @@ define <4 x i32> @usdot_in_loop(ptr %p1, ptr %p2){
 ; CHECK-NOI8MM-NEXT:  // %bb.2: // %end
 ; CHECK-NOI8MM-NEXT:    ret
 ;
-; CHECK-I8MM-LABEL: usdot_in_loop:
-; CHECK-I8MM:       // %bb.0: // %entry
-; CHECK-I8MM-NEXT:    movi v1.2d, #0000000000000000
-; CHECK-I8MM-NEXT:    mov x8, xzr
-; CHECK-I8MM-NEXT:  .LBB6_1: // %vector.body
-; CHECK-I8MM-NEXT:    // =>This Inner Loop Header: Depth=1
-; CHECK-I8MM-NEXT:    ldr q2, [x0, x8]
-; CHECK-I8MM-NEXT:    ldr q3, [x1, x8]
-; CHECK-I8MM-NEXT:    mov v0.16b, v1.16b
-; CHECK-I8MM-NEXT:    add x8, x8, #16
-; CHECK-I8MM-NEXT:    usdot v1.4s, v3.16b, v2.16b
-; CHECK-I8MM-NEXT:    cmp x8, #16
-; CHECK-I8MM-NEXT:    b.ne .LBB6_1
-; CHECK-I8MM-NEXT:  // %bb.2: // %end
-; CHECK-I8MM-NEXT:    ret
+; CHECK-NEWLOWERING-I8MM-LABEL: usdot_in_loop:
+; CHECK-NEWLOWERING-I8MM:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-I8MM-NEXT:    movi v1.2d, #0000000000000000
+; CHECK-NEWLOWERING-I8MM-NEXT:    mov x8, xzr
+; CHECK-NEWLOWERING-I8MM-NEXT:  .LBB6_1: // %vector.body
+; CHECK-NEWLOWERING-I8MM-NEXT:    // =>This Inner Loop Header: Depth=1
+; CHECK-NEWLOWERING-I8MM-NEXT:    ldr q2, [x0, x8]
+; CHECK-NEWLOWERING-I8MM-NEXT:    ldr q3, [x1, x8]
+; CHECK-NEWLOWERING-I8MM-NEXT:    mov v0.16b, v1.16b
+; CHECK-NEWLOWERING-I8MM-NEXT:    add x8, x8, #16
+; CHECK-NEWLOWERING-I8MM-NEXT:    sshll v4.8h, v2.8b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT:    ushll v5.8h, v3.8b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT:    sshll2 v2.8h, v2.16b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT:    ushll2 v3.8h, v3.16b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT:    cmp x8, #16
+; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v1.4s, v4.4h, v5.4h
+; CHECK-NEWLOWERING-I8MM-NEXT:    smlal2 v1.4s, v4.8h, v5.8h
+; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v1.4s, v2.4h, v3.4h
+; CHECK-NEWLOWERING-I8MM-NEXT:    smlal2 v1.4s, v2.8h, v3.8h
+; CHECK-NEWLOWERING-I8MM-NEXT:    b.ne .LBB6_1
+; CHECK-NEWLOWERING-I8MM-NEXT:  // %bb.2: // %end
+; CHECK-NEWLOWERING-I8MM-NEXT:    ret
 entry:
   br label %vector.body
 
@@ -264,10 +279,22 @@ define <2 x i32> @usdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
 ; CHECK-NOI8MM-NEXT:    add v0.2s, v1.2s, v0.2s
 ; CHECK-NOI8MM-NEXT:    ret
 ;
-; CHECK-I8MM-LABEL: usdot_narrow:
-; CHECK-I8MM:       // %bb.0:
-; CHECK-I8MM-NEXT:    usdot v0.2s, v1.8b, v2.8b
-; CHECK-I8MM-NEXT:    ret
+; CHECK-NEWLOWERING-I8MM-LABEL: usdot_narrow:
+; CHECK-NEWLOWERING-I8MM:       // %bb.0:
+; CHECK-NEWLOWERING-I8MM-NEXT:    ushll v1.8h, v1.8b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT:    sshll v2.8h, v2.8b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NEWLOWERING-I8MM-NEXT:    smull v3.4s, v2.4h, v1.4h
+; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v0.4s, v2.4h, v1.4h
+; CHECK-NEWLOWERING-I8MM-NEXT:    ext v4.16b, v1.16b, v1.16b, #8
+; CHECK-NEWLOWERING-I8MM-NEXT:    ext v5.16b, v2.16b, v2.16b, #8
+; CHECK-NEWLOWERING-I8MM-NEXT:    smull2 v1.4s, v2.8h, v1.8h
+; CHECK-NEWLOWERING-I8MM-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NEWLOWERING-I8MM-NEXT:    ext v1.16b, v1.16b, v1.16b, #8
+; CHECK-NEWLOWERING-I8MM-NEXT:    add v0.2s, v3.2s, v0.2s
+; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v0.4s, v5.4h, v4.4h
+; CHECK-NEWLOWERING-I8MM-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NEWLOWERING-I8MM-NEXT:    ret
   %u.wide = zext <8 x i8> %u to <8 x i32>
   %s.wide = sext <8 x i8> %s to <8 x i32>
   %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
@@ -288,10 +315,17 @@ define <4 x i32> @sudot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) #0{
 ; CHECK-NOI8MM-NEXT:    smlal2 v0.4s, v2.8h, v1.8h
 ; CHECK-NOI8MM-NEXT:    ret
 ;
-; CHECK-I8MM-LABEL: sudot:
-; CHECK-I8MM:       // %bb.0:
-; CHECK-I8MM-NEXT:    usdot v0.4s, v2.16b, v1.16b
-; CHECK-I8MM-NEXT:    ret
+; CHECK-NEWLOWERING-I8MM-LABEL: sudot:
+; CHECK-NEWLOWERING-I8MM:       // %bb.0:
+; CHECK-NEWLOWERING-I8MM-NEXT:    sshll v3.8h, v1.8b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT:    ushll v4.8h, v2.8b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT:    sshll2 v1.8h, v1.16b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT:    ushll2 v2.8h, v2.16b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v0.4s, v4.4h, v3.4h
+; CHECK-NEWLOWERING-I8MM-NEXT:    smlal2 v0.4s, v4.8h, v3.8h
+; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v0.4s, v2.4h, v1.4h
+; CHECK-NEWLOWERING-I8MM-NEXT:    smlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NEWLOWERING-I8MM-NEXT:    ret
   %s.wide = sext <16 x i8> %u to <16 x i32>
   %u.wide = zext <16 x i8> %s to <16 x i32>
   %mult = mul nuw nsw <16 x i32> %u.wide, %s.wide
@@ -323,21 +357,28 @@ define <4 x i32> @sudot_in_loop(ptr %p1, ptr %p2){
 ; CHECK-NOI8MM-NEXT:  // %bb.2: // %end
 ; CHECK-NOI8MM-NEXT:    ret
 ;
-; CHECK-I8MM-LABEL: sudot_in_loop:
-; CHECK-I8MM:       // %bb.0: // %entry
-; CHECK-I8MM-NEXT:    movi v1.2d, #0000000000000000
-; CHECK-I8MM-NEXT:    mov x8, xzr
-; CHECK-I8MM-NEXT:  .LBB9_1: // %vector.body
-; CHECK-I8MM-NEXT:    // =>This Inner Loop Header: Depth=1
-; CHECK-I8MM-NEXT:    ldr q2, [x0, x8]
-; CHECK-I8MM-NEXT:    ldr q3, [x1, x8]
-; CHECK-I8MM-NEXT:    mov v0.16b, v1.16b
-; CHECK-I8MM-NEXT:    add x8, x8, #16
-; CHECK-I8MM-NEXT:    usdot v1.4s, v2.16b, v3.16b
-; CHECK-I8MM-NEXT:    cmp x8, #16
-; CHECK-I8MM-NEXT:    b.ne .LBB9_1
-; CHECK-I8MM-NEXT:  // %bb.2: // %end
-; CHECK-I8MM-NEXT:    ret
+; CHECK-NEWLOWERING-I8MM-LABEL: sudot_in_loop:
+; CHECK-NEWLOWERING-I8MM:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-I8MM-NEXT:    movi v1.2d, #0000000000000000
+; CHECK-NEWLOWERING-I8MM-NEXT:    mov x8, xzr
+; CHECK-NEWLOWERING-I8MM-NEXT:  .LBB9_1: // %vector.body
+; CHECK-NEWLOWERING-I8MM-NEXT:    // =>This Inner Loop Header: Depth=1
+; CHECK-NEWLOWERING-I8MM-NEXT:    ldr q2, [x0, x8]
+; CHECK-NEWLOWERING-I8MM-NEXT:    ldr q3, [x1, x8]
+; CHECK-NEWLOWERING-I8MM-NEXT:    mov v0.16b, v1.16b
+; CHECK-NEWLOWERING-I8MM-NEXT:    add x8, x8, #16
+; CHECK-NEWLOWERING-I8MM-NEXT:    ushll v4.8h, v2.8b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT:    sshll v5.8h, v3.8b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT:    ushll2 v2.8h, v2.16b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT:    sshll2 v3.8h, v3.16b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT:    cmp x8, #16
+; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v1.4s, v4.4h, v5.4h
+; CHECK-NEWLOWERING-I8MM-NEXT:    smlal2 v1.4s, v4.8h, v5.8h
+; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v1.4s, v2.4h, v3.4h
+; CHECK-NEWLOWERING-I8MM-NEXT:    smlal2 v1.4s, v2.8h, v3.8h
+; CHECK-NEWLOWERING-I8MM-NEXT:    b.ne .LBB9_1
+; CHECK-NEWLOWERING-I8MM-NEXT:  // %bb.2: // %end
+; CHECK-NEWLOWERING-I8MM-NEXT:    ret
 entry:
   br label %vector.body
 
@@ -378,10 +419,22 @@ define <2 x i32> @sudot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
 ; CHECK-NOI8MM-NEXT:    add v0.2s, v1.2s, v0.2s
 ; CHECK-NOI8MM-NEXT:    ret
 ;
-; CHECK-I8MM-LABEL: sudot_narrow:
-; CHECK-I8MM:       // %bb.0:
-; CHECK-I8MM-NEXT:    usdot v0.2s, v2.8b, v1.8b
-; CHECK-I8MM-NEXT:    ret
+; CHECK-NEWLOWERING-I8MM-LABEL: sudot_narrow:
+; CHECK-NEWLOWERING-I8MM:       // %bb.0:
+; CHECK-NEWLOWERING-I8MM-NEXT:    sshll v1.8h, v1.8b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT:    ushll v2.8h, v2.8b, #0
+; CHECK-NEWLOWERING-I8MM-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NEWLOWERING-I8MM-NEXT:    smull v3.4s, v2.4h, v1.4h
+; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v0.4s, v2.4h, v1.4h
+; CHECK-NEWLOWERING-I8MM-NEXT:    ext v4.16b, v1.16b, v1.16b, #8
+; CHECK-NEWLOWERING-I8MM-NEXT:    ext v5.16b, v2.16b, v2.16b, #8
+; CHECK-NEWLOWERING-I8MM-NEXT:    smull2 v1.4s, v2.8h, v1.8h
+; CHECK-NEWLOWERING-I8MM-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NEWLOWERING-I8MM-NEXT:    ext v1.16b, v1.16b, v1.16b, #8
+; CHECK-NEWLOWERING-I8MM-NEXT:    add v0.2s, v3.2s, v0.2s
+; CHECK-NEWLOWERING-I8MM-NEXT:    smlal v0.4s, v5.4h, v4.4h
+; CHECK-NEWLOWERING-I8MM-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NEWLOWERING-I8MM-NEXT:    ret
   %u.wide = sext <8 x i8> %u to <8 x i32>
   %s.wide = zext <8 x i8> %s to <8 x i32>
   %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
@@ -390,14 +443,6 @@ define <2 x i32> @sudot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
 }
 
 define <4 x i64> @udot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b) {
-; CHECK-DOT-LABEL: udot_8to64:
-; CHECK-DOT:       // %bb.0: // %entry
-; CHECK-DOT-NEXT:    movi v4.2d, #0000000000000000
-; CHECK-DOT-NEXT:    udot v4.4s, v2.16b, v3.16b
-; CHECK-DOT-NEXT:    saddw2 v1.2d, v1.2d, v4.4s
-; CHECK-DOT-NEXT:    saddw v0.2d, v0.2d, v4.2s
-; CHECK-DOT-NEXT:    ret
-;
 ; CHECK-NODOT-LABEL: udot_8to64:
 ; CHECK-NODOT:       // %bb.0: // %entry
 ; CHECK-NODOT-NEXT:    umull v4.8h, v2.8b, v3.8b
@@ -415,6 +460,22 @@ define <4 x i64> @udot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b) {
 ; CHECK-NODOT-NEXT:    uaddw2 v1.2d, v1.2d, v2.4s
 ; CHECK-NODOT-NEXT:    uaddw2 v0.2d, v0.2d, v4.4s
 ; CHECK-NODOT-NEXT:    ret
+;
+; CHECK-NEWLOWERING-I8MM-LABEL: udot_8to64:
+; CHECK-NEWLOWERING-I8MM:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-I8MM-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-I8MM-NEXT:    udot v4.4s, v2.16b, v3.16b
+; CHECK-NEWLOWERING-I8MM-NEXT:    uaddw v0.2d, v0.2d, v4.2s
+; CHECK-NEWLOWERING-I8MM-NEXT:    uaddw2 v0.2d, v0.2d, v4.4s
+; CHECK-NEWLOWERING-I8MM-NEXT:    ret
+;
+; CHECK-NEWLOWERING-NOI8MM-LABEL: udot_8to64:
+; CHECK-NEWLOWERING-NOI8MM:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NOI8MM-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-NOI8MM-NEXT:    udot v4.4s, v2.16b, v3.16b
+; CHECK-NEWLOWERING-NOI8MM-NEXT:    uaddw v0.2d, v0.2d, v4.2s
+; CHECK-NEWLOWERING-NOI8MM-NEXT:    uaddw2 v0.2d, v0.2d, v4.4s
+; CHECK-NEWLOWERING-NOI8MM-NEXT:    ret
 entry:
   %a.wide = zext <16 x i8> %a to <16 x i64>
   %b.wide = zext <16 x i8> %b to <16 x i64>
@@ -425,14 +486,6 @@ entry:
 }
 
 define <4 x i64> @sdot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b){
-; CHECK-DOT-LABEL: sdot_8to64:
-; CHECK-DOT:       // %bb.0: // %entry
-; CHECK-DOT-NEXT:    movi v4.2d, #0000000000000000
-; CHECK-DOT-NEXT:    sdot v4.4s, v2.16b, v3.16b
-; CHECK-DOT-NEXT:    saddw2 v1.2d, v1.2d, v4.4s
-; CHECK-DOT-NEXT:    saddw v0.2d, v0.2d, v4.2s
-; CHECK-DOT-NEXT:    ret
-;
 ; CHECK-NODOT-LABEL: sdot_8to64:
 ; CHECK-NODOT:       // %bb.0: // %entry
 ; CHECK-NODOT-NEXT:    smull v4.8h, v2.8b, v3.8b
@@ -450,6 +503,22 @@ define <4 x i64> @sdot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b){
 ; CHECK-NODOT-NEXT:    saddw2 v1.2d, v1.2d, v2.4s
 ; CHECK-NODOT-NEXT:    saddw2 v0.2d, v0.2d, v4.4s
 ; CHECK-NODOT-NEXT:    ret
+;
+; CHECK-NEWLOWERING-I8MM-LABEL: sdot_8to64:
+; CHECK-NEWLOWERING-I8MM:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-I8MM-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-I8MM-NEXT:    sdot v4.4s, v2.16b, v3.16b
+; CHECK-NEWLOWERING-I8MM-NEXT:    saddw v0.2d, v0.2d, v4.2s
+; CHECK-NEWLOWERING-I8MM-NEXT:    saddw2 v0.2d, v0.2d, v4.4s
+; CHECK-NEWLOWERING-I8MM-NEXT:    ret
+;
+; CHECK-NEWLOWERING-NOI8MM-LABEL: sdot_8to64:
+; CHECK-NEWLOWERING-NOI8MM:       // %bb.0: // %entry
+; CHECK-NEWLOWERING-NOI8MM-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-NEWLOWERING-NOI8MM-NEXT:    sdot v4.4s, v2.16b, v3.16b
+; CHECK-NEWLOWERING-NOI8MM-NEXT:    saddw v0.2d, v0.2d, v4.2s
+; CHECK-NEWLOWERING-NOI8MM-NEXT:    saddw2 v0.2d, v0.2d, v4.4s
+; CHECK-NEWLOWERING-NOI8MM-NEXT:    ret
...
[truncated]

Collaborator
@SamTebbs33 SamTebbs33 left a comment


LGTM with a couple of suggestions.

@@ -948,3 +1159,5 @@ end:
%2 = add <4 x i32> %psum2, %psum1
ret <4 x i32> %2
}
;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
; CHECK-I8MM: {{.*}}
Collaborator


I think this shows regressions have been introduced when the code does not use the new lowering?

Contributor Author
@NickGuy-Arm NickGuy-Arm May 20, 2025


I'm not sure whether these were regressions, or whether they were being masked by other prefixes fed to FileCheck. I've separated the relevant run line out with the prefix CHECK-DOT-I8MM, and we're at least seeing the output from it again.
As far as I could tell, running llc with the +i8mm attribute produced the correct asm (i.e. a usdot instruction), so it doesn't appear to have been a functional regression at least.

With some post-push hindsight, I could probably have removed the prefix from the succeeding run line, instead of adding a new one.
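For illustration, the separated run line might look something like the following (the exact attribute list and prefix set here are assumptions rather than the committed follow-up):

; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT-I8MM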

Member
@MacDue MacDue left a comment


LGTM, thanks!

;
; CHECK-NEWLOWERING-NOI8MM-LABEL: sdot_8to64:
; CHECK-NEWLOWERING-NOI8MM: // %bb.0: // %entry
; CHECK-NEWLOWERING-NOI8MM-NEXT: movi v4.2d, #0000000000000000
Contributor


Nit: I guess that's correct, but given that v4 is used as a vector of i32s, wouldn't it make more sense to have movi v4.4s, #0?

Comment on lines 27579 to 27584
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA: {
if (SDValue Res = LowerPARTIAL_REDUCE_MLA(SDValue(N, 0), DAG))
Results.push_back(Res);
return;
}
Collaborator


Sorry, just noticed this, but is this code actually used? (Normally this is only needed when the result type is not legal but the input is.)

Contributor Author


It definitely was at one point in history, to support the v16i8 -> v4i64 cases. But as we're now handling that differently (by splitting the accumulator into v2i64), this code is never hit with the current test cases. Removed.

Collaborator
@sdesmalen-arm sdesmalen-arm left a comment


LGTM!

@NickGuy-Arm NickGuy-Arm merged commit 26bae79 into llvm:main May 28, 2025
11 checks passed
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Jun 3, 2025
…REDUCE_*MLA ISD nodes (llvm#140075)

Lowering for fixed width vectors added to tablegen.
There is also custom lowering to ensure that the USDOT patterns are
still lowered for fixed width vectors. It also ensures that the
v16i8 -> v4i64 partial reduction case is lowered here instead of
being split (as there is not a v2i64 dot product instruction).

@JamesChesterman is the original author.

---------

Co-authored-by: James Chesterman <james.chesterman@arm.com>