OSDN Git Service

[DAGCombiner] Initial support for the fast-math flag contract
authorAdam Nemet <anemet@apple.com>
Thu, 30 Mar 2017 18:53:04 +0000 (18:53 +0000)
committerAdam Nemet <anemet@apple.com>
Thu, 30 Mar 2017 18:53:04 +0000 (18:53 +0000)
Now alternatively to the TargetOption.AllowFPOpFusion global flag, FMUL->FADD
can also use the per operation FMF to allow fusion.

The idea here is not to port everything to the new scheme (e.g. fused
multiply-and-sub will be ported later) but that this work all the way from
clang.

The transformation is conditionalized on *both* the FADD and the FMUL having
the FMF contract flag.

Differential Revision: https://reviews.llvm.org/D31169

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@299096 91177308-0d34-0410-b5e6-96231b3b80d8

lib/CodeGen/SelectionDAG/DAGCombiner.cpp
test/CodeGen/AArch64/neon-fma-FMF.ll [new file with mode: 0644]
test/CodeGen/PowerPC/fma-aggr-FMF.ll [new file with mode: 0644]

index 9260c9e..901c0e9 100644 (file)
@@ -8720,6 +8720,11 @@ ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) {
   return DAG.getBuildVector(VT, DL, Ops);
 }
 
+static bool isContractable(SDNode *N) {
+  SDNodeFlags F = cast<BinaryWithFlagsSDNode>(N)->Flags;
+  return F.hasAllowContract() || F.hasUnsafeAlgebra();
+}
+
 /// Try to perform FMA combining on a given FADD node.
 SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
   SDValue N0 = N->getOperand(0);
@@ -8728,24 +8733,27 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
   SDLoc SL(N);
 
   const TargetOptions &Options = DAG.getTarget().Options;
-  bool AllowFusion =
-      (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath);
 
   // Floating-point multiply-add with intermediate rounding.
   bool HasFMAD = (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT));
 
   // Floating-point multiply-add without intermediate rounding.
   bool HasFMA =
-      AllowFusion && TLI.isFMAFasterThanFMulAndFAdd(VT) &&
+      TLI.isFMAFasterThanFMulAndFAdd(VT) &&
       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
 
   // No valid opcode, do not combine.
   if (!HasFMAD && !HasFMA)
     return SDValue();
 
+  bool AllowFusionGlobally = (Options.AllowFPOpFusion == FPOpFusion::Fast ||
+                              Options.UnsafeFPMath || HasFMAD);
+  // If the addition is not contractable, do not combine.
+  if (!AllowFusionGlobally && !isContractable(N))
+    return SDValue();
+
   const SelectionDAGTargetInfo *STI = DAG.getSubtarget().getSelectionDAGInfo();
-  ;
-  if (AllowFusion && STI && STI->generateFMAsInMachineCombiner(OptLevel))
+  if (STI && STI->generateFMAsInMachineCombiner(OptLevel))
     return SDValue();
 
   // Always prefer FMAD to FMA for precision.
@@ -8753,35 +8761,39 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
   bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
   bool LookThroughFPExt = TLI.isFPExtFree(VT);
 
+  // Is the node an FMUL and contractable either due to global flags or
+  // SDNodeFlags.
+  auto isContractableFMUL = [AllowFusionGlobally](SDValue N) {
+    if (N.getOpcode() != ISD::FMUL)
+      return false;
+    return AllowFusionGlobally || isContractable(N.getNode());
+  };
   // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
   // prefer to fold the multiply with fewer uses.
-  if (Aggressive && N0.getOpcode() == ISD::FMUL &&
-      N1.getOpcode() == ISD::FMUL) {
+  if (Aggressive && isContractableFMUL(N0) && isContractableFMUL(N1)) {
     if (N0.getNode()->use_size() > N1.getNode()->use_size())
       std::swap(N0, N1);
   }
 
   // fold (fadd (fmul x, y), z) -> (fma x, y, z)
-  if (N0.getOpcode() == ISD::FMUL &&
-      (Aggressive || N0->hasOneUse())) {
+  if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) {
     return DAG.getNode(PreferredFusedOpcode, SL, VT,
                        N0.getOperand(0), N0.getOperand(1), N1);
   }
 
   // fold (fadd x, (fmul y, z)) -> (fma y, z, x)
   // Note: Commutes FADD operands.
-  if (N1.getOpcode() == ISD::FMUL &&
-      (Aggressive || N1->hasOneUse())) {
+  if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) {
     return DAG.getNode(PreferredFusedOpcode, SL, VT,
                        N1.getOperand(0), N1.getOperand(1), N0);
   }
 
   // Look through FP_EXTEND nodes to do more combining.
-  if (AllowFusion && LookThroughFPExt) {
+  if (LookThroughFPExt) {
     // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
     if (N0.getOpcode() == ISD::FP_EXTEND) {
       SDValue N00 = N0.getOperand(0);
-      if (N00.getOpcode() == ISD::FMUL)
+      if (isContractableFMUL(N00))
         return DAG.getNode(PreferredFusedOpcode, SL, VT,
                            DAG.getNode(ISD::FP_EXTEND, SL, VT,
                                        N00.getOperand(0)),
@@ -8793,7 +8805,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
     // Note: Commutes FADD operands.
     if (N1.getOpcode() == ISD::FP_EXTEND) {
       SDValue N10 = N1.getOperand(0);
-      if (N10.getOpcode() == ISD::FMUL)
+      if (isContractableFMUL(N10))
         return DAG.getNode(PreferredFusedOpcode, SL, VT,
                            DAG.getNode(ISD::FP_EXTEND, SL, VT,
                                        N10.getOperand(0)),
@@ -8834,7 +8846,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
                                      N0));
     }
 
-    if (AllowFusion && LookThroughFPExt) {
+    if (/*AllowFusion &&*/ LookThroughFPExt) {
       // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
       //   -> (fma x, y, (fma (fpext u), (fpext v), z))
       auto FoldFAddFMAFPExtFMul = [&] (
@@ -8849,7 +8861,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
         SDValue N02 = N0.getOperand(2);
         if (N02.getOpcode() == ISD::FP_EXTEND) {
           SDValue N020 = N02.getOperand(0);
-          if (N020.getOpcode() == ISD::FMUL)
+          if (isContractableFMUL(N020))
             return FoldFAddFMAFPExtFMul(N0.getOperand(0), N0.getOperand(1),
                                         N020.getOperand(0), N020.getOperand(1),
                                         N1);
@@ -8875,7 +8887,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
         SDValue N00 = N0.getOperand(0);
         if (N00.getOpcode() == PreferredFusedOpcode) {
           SDValue N002 = N00.getOperand(2);
-          if (N002.getOpcode() == ISD::FMUL)
+          if (isContractableFMUL(N002))
             return FoldFAddFPExtFMAFMul(N00.getOperand(0), N00.getOperand(1),
                                         N002.getOperand(0), N002.getOperand(1),
                                         N1);
@@ -8888,7 +8900,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
         SDValue N12 = N1.getOperand(2);
         if (N12.getOpcode() == ISD::FP_EXTEND) {
           SDValue N120 = N12.getOperand(0);
-          if (N120.getOpcode() == ISD::FMUL)
+          if (isContractableFMUL(N120))
             return FoldFAddFMAFPExtFMul(N1.getOperand(0), N1.getOperand(1),
                                         N120.getOperand(0), N120.getOperand(1),
                                         N0);
@@ -8904,7 +8916,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
         SDValue N10 = N1.getOperand(0);
         if (N10.getOpcode() == PreferredFusedOpcode) {
           SDValue N102 = N10.getOperand(2);
-          if (N102.getOpcode() == ISD::FMUL)
+          if (isContractableFMUL(N102))
             return FoldFAddFPExtFMAFMul(N10.getOperand(0), N10.getOperand(1),
                                         N102.getOperand(0), N102.getOperand(1),
                                         N0);
diff --git a/test/CodeGen/AArch64/neon-fma-FMF.ll b/test/CodeGen/AArch64/neon-fma-FMF.ll
new file mode 100644 (file)
index 0000000..f1e9d4f
--- /dev/null
@@ -0,0 +1,27 @@
+; RUN: llc < %s -verify-machineinstrs -mtriple=aarch64-none-linux-gnu -mattr=+neon | FileCheck %s
+
+define <2 x float> @fma(<2 x float> %A, <2 x float> %B, <2 x float> %C) {
+; CHECK-LABEL: fma:
+; CHECK: fmla {{v[0-9]+}}.2s, {{v[0-9]+}}.2s, {{v[0-9]+}}.2s
+       %tmp1 = fmul contract <2 x float> %A, %B;
+       %tmp2 = fadd contract <2 x float> %C, %tmp1;
+       ret <2 x float> %tmp2
+}
+
+define <2 x float> @no_fma_1(<2 x float> %A, <2 x float> %B, <2 x float> %C) {
+; CHECK-LABEL: no_fma_1:
+; CHECK: fmul
+; CHECK: fadd
+       %tmp1 = fmul contract <2 x float> %A, %B;
+       %tmp2 = fadd <2 x float> %C, %tmp1;
+       ret <2 x float> %tmp2
+}
+
+define <2 x float> @no_fma_2(<2 x float> %A, <2 x float> %B, <2 x float> %C) {
+; CHECK-LABEL: no_fma_2:
+; CHECK: fmul
+; CHECK: fadd
+       %tmp1 = fmul <2 x float> %A, %B;
+       %tmp2 = fadd contract <2 x float> %C, %tmp1;
+       ret <2 x float> %tmp2
+}
diff --git a/test/CodeGen/PowerPC/fma-aggr-FMF.ll b/test/CodeGen/PowerPC/fma-aggr-FMF.ll
new file mode 100644 (file)
index 0000000..8e97115
--- /dev/null
@@ -0,0 +1,35 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -verify-machineinstrs -mtriple=powerpc64le-linux-gnu | FileCheck %s
+
+define float @can_fma_with_fewer_uses(float %f1, float %f2, float %f3, float %f4) {
+; CHECK-LABEL: can_fma_with_fewer_uses:
+; CHECK:       # BB#0:
+; CHECK-NEXT:    xsmulsp 0, 1, 2
+; CHECK-NEXT:    fmr 1, 0
+; CHECK-NEXT:    xsmaddasp 1, 3, 4
+; CHECK-NEXT:    xsdivsp 1, 0, 1
+; CHECK-NEXT:    blr
+  %mul1 = fmul contract float %f1, %f2
+  %mul2 = fmul contract float %f3, %f4
+  %add = fadd contract float %mul1, %mul2
+  %second_use_of_mul1 = fdiv float %mul1, %add
+  ret float %second_use_of_mul1
+}
+
+; There is no contract on the mul with no extra use so we can't fuse that.
+; Since we are fusing with the mul with an extra use, the fmul needs to stick
+; around beside the fma.
+define float @no_fma_with_fewer_uses(float %f1, float %f2, float %f3, float %f4) {
+; CHECK-LABEL: no_fma_with_fewer_uses:
+; CHECK:       # BB#0:
+; CHECK-NEXT:    xsmulsp 0, 3, 4
+; CHECK-NEXT:    xsmulsp 13, 1, 2
+; CHECK-NEXT:    xsmaddasp 0, 1, 2
+; CHECK-NEXT:    xsdivsp 1, 13, 0
+; CHECK-NEXT:    blr
+  %mul1 = fmul contract float %f1, %f2
+  %mul2 = fmul float %f3, %f4
+  %add = fadd contract float %mul1, %mul2
+  %second_use_of_mul1 = fdiv float %mul1, %add
+  ret float %second_use_of_mul1
+}