OSDN Git Service

[mlir] NFC - Add debug information for Linalg transformations.
authorNicolas Vasilache <ntv@google.com>
Fri, 29 May 2020 22:07:39 +0000 (18:07 -0400)
committerNicolas Vasilache <ntv@google.com>
Fri, 29 May 2020 22:35:22 +0000 (18:35 -0400)
Address post-commit review of https://reviews.llvm.org/D79518

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/IR/PatternMatch.h
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

index 2e06737..2e6a859 100644 (file)
@@ -521,7 +521,7 @@ struct LinalgCopyVTWForwardingPattern
 LogicalResult applyStagedPatterns(
     Operation *op, ArrayRef<OwningRewritePatternList> stage1Patterns,
     const OwningRewritePatternList &stage2Patterns,
-    llvm::function_ref<LogicalResult(Operation *)> stage3Lambda = nullptr);
+    function_ref<LogicalResult(Operation *)> stage3Lambda = nullptr);
 } // namespace linalg
 } // namespace mlir
 
index 6b124e0..8178f71 100644 (file)
@@ -394,7 +394,7 @@ public:
   /// type `T`.
   template <typename T>
   OwningRewritePatternList(T &&t) {
-    patterns.emplace_back(std::make_unique<T>(t));
+    patterns.emplace_back(std::make_unique<T>(std::forward<T>(t)));
   }
 
   PatternListT::iterator begin() { return patterns.begin(); }
index 527d162..76e118e 100644 (file)
@@ -37,6 +37,8 @@ using namespace mlir::linalg;
 
 using llvm::dbgs;
 
+#define DEBUG_TYPE "linalg-transforms"
+
 //===----------------------------------------------------------------------===//
 // Transformations exposed as rewrite patterns.
 //===----------------------------------------------------------------------===//
@@ -45,13 +47,13 @@ const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
     "__internal_linalg_transform__";
 
 mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction,
-                                         llvm::Optional<StringRef> replacement)
+                                         Optional<StringRef> replacement)
     : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
       replacement(replacement) {}
 
 mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction,
                                          StringRef replacement)
-    : LinalgMarker(matchDisjunction, llvm::Optional<StringRef>{replacement}) {}
+    : LinalgMarker(matchDisjunction, Optional<StringRef>{replacement}) {}
 
 LogicalResult
 mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
@@ -72,7 +74,7 @@ mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
     // 3. Has no marker but was expecting a marker.
     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
       diag << " does not have any marker from list: ";
-      llvm::interleaveComma(matchDisjunction, diag);
+      interleaveComma(matchDisjunction, diag);
     });
   }
 
@@ -84,7 +86,7 @@ mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
   // 5. Fail to match.
   return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
     diag << " does not have any marker from list: ";
-    llvm::interleaveComma(matchDisjunction, diag);
+    interleaveComma(matchDisjunction, diag);
   });
 }
 
@@ -105,7 +107,7 @@ mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
     OpBuilder::InsertionGuard guard(b);
     b.setInsertionPointToStart(
         &op->getParentOfType<FuncOp>().getBody().front());
-    return llvm::to_vector<4>(llvm::map_range(tileSizes, [&](int64_t s) {
+    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
       Value v = b.create<ConstantIndexOp>(op->getLoc(), s);
       return v;
     }));
@@ -217,19 +219,33 @@ LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
 LogicalResult mlir::linalg::applyStagedPatterns(
     Operation *op, ArrayRef<OwningRewritePatternList> stage1Patterns,
     const OwningRewritePatternList &stage2Patterns,
-    llvm::function_ref<LogicalResult(Operation *)> stage3Lambda) {
+    function_ref<LogicalResult(Operation *)> stage3Lambda) {
+  unsigned iteration = 0;
+  (void)iteration;
+  StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
+  (void)dbgPref;
   for (const auto &patterns : stage1Patterns) {
     if (!applyPatternsAndFoldGreedily(op, patterns)) {
-      llvm::dbgs() << "Underlying first stage rewrite did not converge";
+      dbgs() << "Underlying first stage rewrite did not converge";
       return failure();
     }
+    LLVM_DEBUG(dbgs()
+               << dbgPref << "After 1st stage, iter: " << ++iteration << "\n"
+               << *op);
     if (!applyPatternsAndFoldGreedily(op, stage2Patterns)) {
-      llvm::dbgs() << "Underlying second stage rewrite did not converge";
+      LLVM_DEBUG(dbgs()
+                 << dbgPref << "Underlying 2nd stage rewrite did not converge");
       return failure();
     }
+    LLVM_DEBUG(dbgs()
+               << dbgPref << "After 2nd stage, iter : " << iteration << "\n"
+               << *op);
     if (stage3Lambda) {
       if (failed(stage3Lambda(op)))
         return failure();
+      LLVM_DEBUG(dbgs()
+                 << dbgPref << "After 3rd stage, iter : " << iteration << "\n"
+                 << *op);
     }
   }
   return success();