OSDN Git Service

enhance the new isel to use SelectNodeTo for most patterns,
authorChris Lattner <sabre@nondot.org>
Sun, 28 Feb 2010 20:49:53 +0000 (20:49 +0000)
committerChris Lattner <sabre@nondot.org>
Sun, 28 Feb 2010 20:49:53 +0000 (20:49 +0000)
even some the old isel didn't.  There are several parts of
this that make me feel dirty, but it's no worse than the
old isel.  I'll clean up the parts I can do without ripping
out the old one next.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@97415 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/CodeGen/DAGISelHeader.h
utils/TableGen/DAGISelEmitter.cpp
utils/TableGen/DAGISelMatcher.cpp
utils/TableGen/DAGISelMatcher.h
utils/TableGen/DAGISelMatcherEmitter.cpp
utils/TableGen/DAGISelMatcherGen.cpp
utils/TableGen/DAGISelMatcherOpt.cpp

index 67b4155..c4b6a71 100644 (file)
@@ -217,6 +217,51 @@ GetVBR(unsigned Val, const unsigned char *MatcherTable, unsigned &Idx) {
   return Val;
 }
 
+/// UpdateChainsAndFlags - When a match is complete, this method updates uses of
+/// interior flag and chain results to use the new flag and chain results.
+void UpdateChainsAndFlags(SDNode *NodeToMatch, SDValue InputChain,
+                          const SmallVectorImpl<SDNode*> &ChainNodesMatched,
+                          SDValue InputFlag,
+                          const SmallVectorImpl<SDNode*>&FlagResultNodesMatched,
+                          bool isSelectNodeTo) {
+  // Now that all the normal results are replaced, we replace the chain and
+  // flag results if present.
+  if (!ChainNodesMatched.empty()) {
+    assert(InputChain.getNode() != 0 &&
+           "Matched input chains but didn't produce a chain");
+    // Loop over all of the nodes we matched that produced a chain result.
+    // Replace all the chain results with the final chain we ended up with.
+    for (unsigned i = 0, e = ChainNodesMatched.size(); i != e; ++i) {
+      SDNode *ChainNode = ChainNodesMatched[i];
+      
+      // Don't replace the results of the root node if we're doing a
+      // SelectNodeTo.
+      if (ChainNode == NodeToMatch && isSelectNodeTo)
+        continue;
+      
+      SDValue ChainVal = SDValue(ChainNode, ChainNode->getNumValues()-1);
+      if (ChainVal.getValueType() == MVT::Flag)
+        ChainVal = ChainVal.getValue(ChainVal->getNumValues()-2);
+      assert(ChainVal.getValueType() == MVT::Other && "Not a chain?");
+      ReplaceUses(ChainVal, InputChain);
+    }
+  }
+  
+  // If the result produces a flag, update any flag results in the matched
+  // pattern with the flag result.
+  if (InputFlag.getNode() != 0) {
+    // Handle any interior nodes explicitly marked.
+    for (unsigned i = 0, e = FlagResultNodesMatched.size(); i != e; ++i) {
+      SDNode *FRN = FlagResultNodesMatched[i];
+      assert(FRN->getValueType(FRN->getNumValues()-1) == MVT::Flag &&
+             "Doesn't have a flag result");
+      ReplaceUses(SDValue(FRN, FRN->getNumValues()-1), InputFlag);
+    }
+  }
+  
+  DEBUG(errs() << "ISEL: Match complete!\n");
+}
+
 
 enum BuiltinOpcodes {
   OPC_Scope,
@@ -252,6 +297,7 @@ enum BuiltinOpcodes {
   OPC_EmitCopyToReg,
   OPC_EmitNodeXForm,
   OPC_EmitNode,
+  OPC_SelectNodeTo,
   OPC_MarkFlagResults,
   OPC_CompleteMatch
 };
@@ -741,7 +787,8 @@ SDNode *SelectCodeCommon(SDNode *NodeToMatch, const unsigned char *MatcherTable,
       continue;
     }
         
-    case OPC_EmitNode: {
+    case OPC_EmitNode:
+    case OPC_SelectNodeTo: {
       uint16_t TargetOpc = GetInt2(MatcherTable, MatcherIndex);
       unsigned EmitNodeInfo = MatcherTable[MatcherIndex++];
       // Get the result VT list.
@@ -794,14 +841,54 @@ SDNode *SelectCodeCommon(SDNode *NodeToMatch, const unsigned char *MatcherTable,
         Ops.push_back(InputFlag);
       
       // Create the node.
-      MachineSDNode *Res = CurDAG->getMachineNode(TargetOpc,
-                                                  NodeToMatch->getDebugLoc(),
-                                                  VTList,
-                                                  Ops.data(), Ops.size());
-      // Add all the non-flag/non-chain results to the RecordedNodes list.
-      for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
-        if (VTs[i] == MVT::Other || VTs[i] == MVT::Flag) break;
-        RecordedNodes.push_back(SDValue(Res, i));
+      SDNode *Res = 0;
+      if (Opcode == OPC_SelectNodeTo) {
+        // It is possible we're using SelectNodeTo to replace a node with no
+        // normal results with one that has a normal result (or we could be
+        // adding a chain) and the input could have flags and chains as well.
+        // In this case we need to shifting the operands down.
+        // FIXME: This is a horrible hack and broken in obscure cases, no worse
+        // than the old isel though.  We should sink this into SelectNodeTo.
+        int OldFlagResultNo = -1, OldChainResultNo = -1;
+        
+        unsigned NTMNumResults = NodeToMatch->getNumValues();
+        if (NodeToMatch->getValueType(NTMNumResults-1) == MVT::Flag) {
+          OldFlagResultNo = NTMNumResults-1;
+          if (NTMNumResults != 1 &&
+              NodeToMatch->getValueType(NTMNumResults-2) == MVT::Other)
+            OldChainResultNo = NTMNumResults-2;
+        } else if (NodeToMatch->getValueType(NTMNumResults-1) == MVT::Other)
+          OldChainResultNo = NTMNumResults-1;
+        
+        Res = CurDAG->SelectNodeTo(NodeToMatch, TargetOpc, VTList,
+                                   Ops.data(), Ops.size());
+        
+        // FIXME: Whether the selected node has a flag result should come from
+        // flags on the node.
+        unsigned ResNumResults = Res->getNumValues();
+        if (Res->getValueType(ResNumResults-1) == MVT::Flag) {
+          // Move the flag if needed.
+          if (OldFlagResultNo != -1 &&
+              (unsigned)OldFlagResultNo != ResNumResults-1)
+            ReplaceUses(SDValue(Res, OldFlagResultNo), 
+                        SDValue(Res, ResNumResults-1));
+          --ResNumResults;
+        }
+
+        // Move the chain reference if needed.
+        if ((EmitNodeInfo & OPFL_Chain) && OldChainResultNo != -1 &&
+            (unsigned)OldChainResultNo != ResNumResults-1)
+          ReplaceUses(SDValue(Res, OldChainResultNo), 
+                      SDValue(Res, ResNumResults-1));
+      } else {
+        Res = CurDAG->getMachineNode(TargetOpc, NodeToMatch->getDebugLoc(),
+                                     VTList, Ops.data(), Ops.size());
+      
+        // Add all the non-flag/non-chain results to the RecordedNodes list.
+        for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
+          if (VTs[i] == MVT::Other || VTs[i] == MVT::Flag) break;
+          RecordedNodes.push_back(SDValue(Res, i));
+        }
       }
       
       // If the node had chain/flag results, update our notion of the current
@@ -823,10 +910,22 @@ SDNode *SelectCodeCommon(SDNode *NodeToMatch, const unsigned char *MatcherTable,
         MachineSDNode::mmo_iterator MemRefs =
           MF->allocateMemRefsArray(MatchedMemRefs.size());
         std::copy(MatchedMemRefs.begin(), MatchedMemRefs.end(), MemRefs);
-        Res->setMemRefs(MemRefs, MemRefs + MatchedMemRefs.size());
+        cast<MachineSDNode>(Res)
+          ->setMemRefs(MemRefs, MemRefs + MatchedMemRefs.size());
+      }
+      
+      DEBUG(errs() << "  "
+                   << (Opcode == OPC_SelectNodeTo ? "Selected" : "Created")
+                   << " node: "; Res->dump(CurDAG); errs() << "\n");
+      
+      // If this was a SelectNodeTo then we're completely done!
+      if (Opcode == OPC_SelectNodeTo) {
+        // Update chain and flag uses.
+        UpdateChainsAndFlags(NodeToMatch, InputChain, ChainNodesMatched,
+                             InputFlag, FlagResultNodesMatched, true);
+        return Res;
       }
       
-      DEBUG(errs() << "  Created node: "; Res->dump(CurDAG); errs() << "\n");
       continue;
     }
         
@@ -875,47 +974,19 @@ SDNode *SelectCodeCommon(SDNode *NodeToMatch, const unsigned char *MatcherTable,
                "invalid replacement");
         ReplaceUses(SDValue(NodeToMatch, i), Res);
       }
-      
-      // Now that all the normal results are replaced, we replace the chain and
-      // flag results if present.
-      if (!ChainNodesMatched.empty()) {
-        assert(InputChain.getNode() != 0 &&
-               "Matched input chains but didn't produce a chain");
-        // Loop over all of the nodes we matched that produced a chain result.
-        // Replace all the chain results with the final chain we ended up with.
-        for (unsigned i = 0, e = ChainNodesMatched.size(); i != e; ++i) {
-          SDNode *ChainNode = ChainNodesMatched[i];
-          SDValue ChainVal = SDValue(ChainNode, ChainNode->getNumValues()-1);
-          if (ChainVal.getValueType() == MVT::Flag)
-            ChainVal = ChainVal.getValue(ChainVal->getNumValues()-2);
-          assert(ChainVal.getValueType() == MVT::Other && "Not a chain?");
-          ReplaceUses(ChainVal, InputChain);
-        }
-      }
 
-      // If the result produces a flag, update any flag results in the matched
-      // pattern with the flag result.
-      if (InputFlag.getNode() != 0) {
-        // Handle the root node:
-        if (NodeToMatch->getValueType(NodeToMatch->getNumValues()-1) ==
-              MVT::Flag)
-          ReplaceUses(SDValue(NodeToMatch, NodeToMatch->getNumValues()-1),
-                      InputFlag);
-        
-        // Handle any interior nodes explicitly marked.
-        for (unsigned i = 0, e = FlagResultNodesMatched.size(); i != e; ++i) {
-          SDNode *FRN = FlagResultNodesMatched[i];
-          assert(FRN->getValueType(FRN->getNumValues()-1) == MVT::Flag &&
-                 "Doesn't have a flag result");
-          ReplaceUses(SDValue(FRN, FRN->getNumValues()-1), InputFlag);
-        }
-      }
+      // If the root node defines a flag, add it to the flag nodes to update
+      // list.
+      if (NodeToMatch->getValueType(NodeToMatch->getNumValues()-1) == MVT::Flag)
+        FlagResultNodesMatched.push_back(NodeToMatch);
+      
+      // Update chain and flag uses.
+      UpdateChainsAndFlags(NodeToMatch, InputChain, ChainNodesMatched,
+                           InputFlag, FlagResultNodesMatched, false);
       
       assert(NodeToMatch->use_empty() &&
              "Didn't replace all uses of the node?");
       
-      DEBUG(errs() << "ISEL: Match complete!\n");
-      
       // FIXME: We just return here, which interacts correctly with SelectRoot
       // above.  We should fix this to not return an SDNode* anymore.
       return 0;
index 5e2b07d..2ea8bf0 100644 (file)
@@ -1966,7 +1966,7 @@ void DAGISelEmitter::run(raw_ostream &OS) {
   Matcher *TheMatcher = new ScopeMatcher(&PatternMatchers[0],
                                          PatternMatchers.size());
 
-  TheMatcher = OptimizeMatcher(TheMatcher);
+  TheMatcher = OptimizeMatcher(TheMatcher, CGP);
   //Matcher->dump();
   EmitMatcherTable(TheMatcher, OS);
   delete TheMatcher;
index d939edb..085682f 100644 (file)
@@ -39,8 +39,12 @@ ScopeMatcher::~ScopeMatcher() {
 
 void ScopeMatcher::printImpl(raw_ostream &OS, unsigned indent) const {
   OS.indent(indent) << "Scope\n";
-  for (unsigned i = 0, e = getNumChildren(); i != e; ++i)
-    getChild(i)->print(OS, indent+2);
+  for (unsigned i = 0, e = getNumChildren(); i != e; ++i) {
+    if (getChild(i) == 0)
+      OS.indent(indent+1) << "NULL POINTER\n";
+    else
+      getChild(i)->print(OS, indent+2);
+  }
 }
 
 void RecordMatcher::printImpl(raw_ostream &OS, unsigned indent) const {
index b91b591..0d674c5 100644 (file)
@@ -27,7 +27,7 @@ namespace llvm {
 
 Matcher *ConvertPatternToMatcher(const PatternToMatch &Pattern,
                                  const CodeGenDAGPatterns &CGP);
-Matcher *OptimizeMatcher(Matcher *Matcher);
+Matcher *OptimizeMatcher(Matcher *Matcher, const CodeGenDAGPatterns &CGP);
 void EmitMatcherTable(const Matcher *Matcher, raw_ostream &OS);
 
   
@@ -900,12 +900,25 @@ public:
     assert(i < VTs.size());
     return VTs[i];
   }
+
+  /// getNumNonChainFlagVTs - Return the number of normal results that this node
+  /// will have, ignoring flag and chain results.
+  unsigned getNumNonChainFlagVTs() const {
+    for (unsigned i = 0, e = getNumVTs(); i != e; ++i)
+      if (VTs[i] == MVT::Flag || VTs[i] == MVT::Other)
+        return i;
+    return getNumVTs();
+  }
   
   unsigned getNumOperands() const { return Operands.size(); }
   unsigned getOperand(unsigned i) const {
     assert(i < Operands.size());
     return Operands[i];
-  }  
+  }
+  
+  const SmallVectorImpl<MVT::SimpleValueType> &getVTList() const { return VTs; }
+  const SmallVectorImpl<unsigned> &getOperandList() const { return Operands; }
+
   
   bool hasChain() const { return HasChain; }
   bool hasFlag() const { return HasFlag; }
@@ -999,7 +1012,7 @@ class CompleteMatchMatcher : public Matcher {
   const PatternToMatch &Pattern;
 public:
   CompleteMatchMatcher(const unsigned *results, unsigned numresults,
-                           const PatternToMatch &pattern)
+                       const PatternToMatch &pattern)
   : Matcher(CompleteMatch), Results(results, results+numresults),
     Pattern(pattern) {}
 
index 942a612..aec1e18 100644 (file)
@@ -416,18 +416,23 @@ EmitMatcher(const Matcher *N, unsigned Indent, unsigned CurrentIdx,
     
     // Print the result #'s for EmitNode.
     if (const EmitNodeMatcher *E = dyn_cast<EmitNodeMatcher>(EN)) {
-      if (EN->getVT(0) != MVT::Flag && EN->getVT(0) != MVT::Other) {
+      if (unsigned NumResults = EN->getNumNonChainFlagVTs()) {
         OS.PadToColumn(CommentIndent) << "// Results = ";
         unsigned First = E->getFirstResultSlot();
-        for (unsigned i = 0, e = EN->getNumVTs(); i != e; ++i) {
-          if (EN->getVT(0) == MVT::Flag || EN->getVT(0) == MVT::Other)
-            break;
+        for (unsigned i = 0; i != NumResults; ++i)
           OS << "#" << First+i << " ";
-        }
       }
     }
-    
     OS << '\n';
+    
+    if (const SelectNodeToMatcher *SNT = dyn_cast<SelectNodeToMatcher>(N)) {
+      OS.PadToColumn(Indent*2) << "// Src: "
+      << *SNT->getPattern().getSrcPattern() << '\n';
+      OS.PadToColumn(Indent*2) << "// Dst: " 
+      << *SNT->getPattern().getDstPattern() << '\n';
+      
+    }
+    
     return 6+EN->getNumVTs()+NumOperandBytes;
   }
   case Matcher::MarkFlagResults: {
index 18735de..c558eba 100644 (file)
@@ -741,10 +741,6 @@ EmitResultInstructionAsOperand(const TreePatternNode *N,
   bool NodeHasMemRefs =
     isRoot && Pattern.getSrcPattern()->TreeHasProperty(SDNPMemOperand, CGP);
 
-  // FIXME: Eventually add a SelectNodeTo form.  It works if the new node has a
-  // superset of the results of the old node, in the same places.  E.g. turning
-  // (add (load)) -> add32rm is ok because result #0 is the result and result #1
-  // is new.
   AddMatcher(new EmitNodeMatcher(II.Namespace+"::"+II.TheDef->getName(),
                                  ResultVTs.data(), ResultVTs.size(),
                                  InstOps.data(), InstOps.size(),
@@ -757,9 +753,6 @@ EmitResultInstructionAsOperand(const TreePatternNode *N,
     if (ResultVTs[i] == MVT::Other || ResultVTs[i] == MVT::Flag) break;
     OutputOps.push_back(NextRecordedOperandNo++);
   }
-  
-  // FIXME2: Kill off all the SelectionDAG::SelectNodeTo and getMachineNode
-  // variants.  Call MorphNodeTo instead of SelectNodeTo.
 }
 
 void MatcherGen::
@@ -851,7 +844,7 @@ void MatcherGen::EmitResultCode() {
 
 
 Matcher *llvm::ConvertPatternToMatcher(const PatternToMatch &Pattern,
-                                           const CodeGenDAGPatterns &CGP) {
+                                       const CodeGenDAGPatterns &CGP) {
   MatcherGen Gen(Pattern, CGP);
 
   // Generate the code for the matcher.
index 12c4c1b..2ea178f 100644 (file)
@@ -13,6 +13,7 @@
 
 #define DEBUG_TYPE "isel-opt"
 #include "DAGISelMatcher.h"
+#include "CodeGenDAGPatterns.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
@@ -21,7 +22,8 @@ using namespace llvm;
 
 /// ContractNodes - Turn multiple matcher node patterns like 'MoveChild+Record'
 /// into single compound nodes like RecordChild.
-static void ContractNodes(OwningPtr<Matcher> &MatcherPtr) {
+static void ContractNodes(OwningPtr<Matcher> &MatcherPtr,
+                          const CodeGenDAGPatterns &CGP) {
   // If we reached the end of the chain, we're done.
   Matcher *N = MatcherPtr.get();
   if (N == 0) return;
@@ -30,7 +32,7 @@ static void ContractNodes(OwningPtr<Matcher> &MatcherPtr) {
   if (ScopeMatcher *Scope = dyn_cast<ScopeMatcher>(N)) {
     for (unsigned i = 0, e = Scope->getNumChildren(); i != e; ++i) {
       OwningPtr<Matcher> Child(Scope->takeChild(i));
-      ContractNodes(Child);
+      ContractNodes(Child, CGP);
       Scope->resetChild(i, Child.take());
     }
     return;
@@ -52,7 +54,7 @@ static void ContractNodes(OwningPtr<Matcher> &MatcherPtr) {
       MatcherPtr.reset(New);
       // Remove the old one.
       MC->setNext(MC->getNext()->takeNext());
-      return ContractNodes(MatcherPtr);
+      return ContractNodes(MatcherPtr, CGP);
     }
   }
   
@@ -61,17 +63,69 @@ static void ContractNodes(OwningPtr<Matcher> &MatcherPtr) {
     if (MoveParentMatcher *MP = 
           dyn_cast<MoveParentMatcher>(MC->getNext())) {
       MatcherPtr.reset(MP->takeNext());
-      return ContractNodes(MatcherPtr);
+      return ContractNodes(MatcherPtr, CGP);
     }
   
   // Turn EmitNode->CompleteMatch into SelectNodeTo if we can.
   if (EmitNodeMatcher *EN = dyn_cast<EmitNodeMatcher>(N))
     if (CompleteMatchMatcher *CM =
           dyn_cast<CompleteMatchMatcher>(EN->getNext())) {
-      (void)CM;
+      // We can only use SelectNodeTo if the result values match up.
+      unsigned RootResultFirst = EN->getFirstResultSlot();
+      bool ResultsMatch = true;
+      for (unsigned i = 0, e = CM->getNumResults(); i != e; ++i)
+        if (CM->getResult(i) != RootResultFirst+i)
+          ResultsMatch = false;
+      
+      // If the selected node defines a subset of the flag/chain results, we
+      // can't use SelectNodeTo.  For example, we can't use SelectNodeTo if the
+      // matched pattern has a chain but the root node doesn't.
+      const PatternToMatch &Pattern = CM->getPattern();
+      
+      if (!EN->hasChain() &&
+          Pattern.getSrcPattern()->NodeHasProperty(SDNPHasChain, CGP))
+        ResultsMatch = false;
+
+      // If the matched node has a flag and the output root doesn't, we can't
+      // use SelectNodeTo.
+      //
+      // NOTE: Strictly speaking, we don't have to check for the flag here
+      // because the code in the pattern generator doesn't handle it right.  We
+      // do it anyway for thoroughness.
+      if (!EN->hasFlag() &&
+          Pattern.getSrcPattern()->NodeHasProperty(SDNPOutFlag, CGP))
+        ResultsMatch = false;
+      
+      
+      // If the root result node defines more results than the source root node
+      // *and* has a chain or flag input, then we can't match it because it
+      // would end up replacing the extra result with the chain/flag.
+#if 0
+      if ((EN->hasFlag() || EN->hasChain()) &&
+          EN->getNumNonChainFlagVTs() > ... need to get no results reliably ...)
+        ResultMatch = false;
+#endif
+          
+      if (ResultsMatch) {
+        const SmallVectorImpl<MVT::SimpleValueType> &VTs = EN->getVTList();
+        const SmallVectorImpl<unsigned> &Operands = EN->getOperandList();
+        MatcherPtr.reset(new SelectNodeToMatcher(EN->getOpcodeName(),
+                                                 &VTs[0], VTs.size(),
+                                               Operands.data(), Operands.size(),
+                                                 EN->hasChain(), EN->hasFlag(),
+                                                 EN->hasMemRefs(),
+                                                 EN->getNumFixedArityOperands(),
+                                                 Pattern));
+        return;
+      }
+
+      // FIXME: Handle OPC_MarkFlagResults.
+      
+      // FIXME2: Kill off all the SelectionDAG::SelectNodeTo and getMachineNode
+      // variants.  Call MorphNodeTo instead of SelectNodeTo.
     }
   
-  ContractNodes(N->getNextPtr());
+  ContractNodes(N->getNextPtr(), CGP);
 }
 
 /// SinkPatternPredicates - Pattern predicates can be checked at any level of
@@ -253,6 +307,8 @@ static void FactorNodes(OwningPtr<Matcher> &MatcherPtr) {
 
   // Reassemble a new Scope node.
   assert(!NewOptionsToMatch.empty() && "where'd all our children go?");
+  if (NewOptionsToMatch.empty())
+    MatcherPtr.reset(0);
   if (NewOptionsToMatch.size() == 1)
     MatcherPtr.reset(NewOptionsToMatch[0]);
   else {
@@ -262,9 +318,10 @@ static void FactorNodes(OwningPtr<Matcher> &MatcherPtr) {
   }
 }
 
-Matcher *llvm::OptimizeMatcher(Matcher *TheMatcher) {
+Matcher *llvm::OptimizeMatcher(Matcher *TheMatcher,
+                               const CodeGenDAGPatterns &CGP) {
   OwningPtr<Matcher> MatcherPtr(TheMatcher);
-  ContractNodes(MatcherPtr);
+  ContractNodes(MatcherPtr, CGP);
   SinkPatternPredicates(MatcherPtr);
   FactorNodes(MatcherPtr);
   return MatcherPtr.take();