OSDN Git Service

[NFC][ARM][ParallelDSP] Refactor narrow sequence
authorSam Parker <sam.parker@arm.com>
Thu, 30 May 2019 15:26:37 +0000 (15:26 +0000)
committerSam Parker <sam.parker@arm.com>
Thu, 30 May 2019 15:26:37 +0000 (15:26 +0000)
Most of the code used for finding a 'narrow' sequence is not used,
so I've removed it and simplified the calls from the smlad matcher.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@362104 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/ARM/ARMParallelDSP.cpp

index beb44fb..3cff9b5 100644 (file)
@@ -248,45 +248,6 @@ namespace {
   };
 }
 
-// MaxBitwidth: the maximum supported bitwidth of the elements in the DSP
-// instructions, which is set to 16. So here we should collect all i8 and i16
-// narrow operations.
-// TODO: we currently only collect i16, and will support i8 later, so that's
-// why we check that types are equal to MaxBitWidth, and not <= MaxBitWidth.
-template<unsigned MaxBitWidth>
-static bool IsNarrowSequence(Value *V, ValueList &VL) {
-  ConstantInt *CInt;
-
-  if (match(V, m_ConstantInt(CInt))) {
-    // TODO: if a constant is used, it needs to fit within the bit width.
-    return false;
-  }
-
-  auto *I = dyn_cast<Instruction>(V);
-  if (!I)
-   return false;
-
-  Value *Val, *LHS, *RHS;
-  if (match(V, m_Trunc(m_Value(Val)))) {
-    if (cast<TruncInst>(I)->getDestTy()->getIntegerBitWidth() == MaxBitWidth)
-      return IsNarrowSequence<MaxBitWidth>(Val, VL);
-  } else if (match(V, m_Add(m_Value(LHS), m_Value(RHS)))) {
-    // TODO: we need to implement sadd16/sadd8 for this, which enables to
-    // also do the rewrite for smlad8.ll, but it is unsupported for now.
-    return false;
-  } else if (match(V, m_ZExtOrSExt(m_Value(Val)))) {
-    if (cast<CastInst>(I)->getSrcTy()->getIntegerBitWidth() != MaxBitWidth)
-      return false;
-
-    if (match(Val, m_Load(m_Value()))) {
-      VL.push_back(Val);
-      VL.push_back(I);
-      return true;
-    }
-  }
-  return false;
-}
-
 template<typename MemInst>
 static bool AreSequentialAccesses(MemInst *MemOp0, MemInst *MemOp1,
                                   const DataLayout &DL, ScalarEvolution &SE) {
@@ -507,6 +468,18 @@ bool ARMParallelDSP::InsertParallelMACs(Reduction &Reduction) {
   return false;
 }
 
+template<typename InstType, unsigned BitWidth>
+bool IsExtendingLoad(Value *V) {
+  auto *I = dyn_cast<InstType>(V);
+  if (!I)
+    return false;
+
+  if (I->getSrcTy()->getIntegerBitWidth() != BitWidth)
+    return false;
+
+  return isa<LoadInst>(I->getOperand(0));
+}
+
 static void MatchParallelMACSequences(Reduction &R,
                                       OpChainList &Candidates) {
   Instruction *Acc = R.AccIntAdd;
@@ -526,15 +499,13 @@ static void MatchParallelMACSequences(Reduction &R,
         return true;
       break;
     case Instruction::Mul: {
-      Value *MulOp0 = I->getOperand(0);
-      Value *MulOp1 = I->getOperand(1);
-      if (isa<SExtInst>(MulOp0) && isa<SExtInst>(MulOp1)) {
-        ValueList LHS;
-        ValueList RHS;
-        if (IsNarrowSequence<16>(MulOp0, LHS) &&
-            IsNarrowSequence<16>(MulOp1, RHS)) {
-          Candidates.push_back(make_unique<BinOpChain>(I, LHS, RHS));
-        }
+      Value *Op0 = I->getOperand(0);
+      Value *Op1 = I->getOperand(1);
+      if (IsExtendingLoad<SExtInst, 16>(Op0) &&
+          IsExtendingLoad<SExtInst, 16>(Op1)) {
+        ValueList LHS = { cast<SExtInst>(Op0)->getOperand(0), Op0 };
+        ValueList RHS = { cast<SExtInst>(Op1)->getOperand(0), Op1 };
+        Candidates.push_back(make_unique<BinOpChain>(I, LHS, RHS));
       }
       return false;
     }