OSDN Git Service

[mlir][Vector] Fix vector.transfer alignment calculation
authorNicolas Vasilache <ntv@google.com>
Thu, 28 May 2020 21:55:21 +0000 (17:55 -0400)
committerNicolas Vasilache <ntv@google.com>
Thu, 28 May 2020 21:58:51 +0000 (17:58 -0400)
https://reviews.llvm.org/D79246 introduces alignment propagation for vector transfer operations. Unfortunately, the alignment calculation is incorrect and can result in crashes.

This revision fixes the calculation by using the natural alignment of the memref elemental type, instead of the resulting vector type.

If more alignment is desired, it can be done in 2 ways:
1. use a proper vector.type_cast to transform a memref<axbxcxdxf32> into a memref<axbxvector<cxdxf32>> giving a natural alignment of vector<cxdxf32>
2. add an alignment attribute to vector transfer operations and propagate it.

With this change the alignment in the relevant tests goes down from 128 to 4.

Lastly, a few minor cleanups are performed and the custom `isMinorIdentityMap` is deprecated.

Differential Revision: https://reviews.llvm.org/D80734

mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
mlir/include/mlir/Dialect/Vector/VectorOps.h
mlir/include/mlir/InitAllPasses.h
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir
mlir/test/lib/Transforms/CMakeLists.txt
mlir/test/lib/Transforms/TestVectorToSCFConversion.cpp [deleted file]
mlir/tools/mlir-opt/mlir-opt.cpp

index 65d05a7..5d83184 100644 (file)
@@ -272,6 +272,20 @@ def ConvertStandardToSPIRV : Pass<"convert-std-to-spirv", "ModuleOp"> {
 }
 
 //===----------------------------------------------------------------------===//
+// VectorToSCF
+//===----------------------------------------------------------------------===//
+
+def ConvertVectorToSCF : FunctionPass<"convert-vector-to-scf"> {
+  let summary = "Lower the operations from the vector dialect into the SCF "
+                "dialect";
+  let constructor = "mlir::createConvertVectorToSCFPass()";
+  let options = [
+    Option<"fullUnroll", "full-unroll", "bool", /*default=*/"false",
+           "Perform full unrolling when converting vector transfers to SCF">,
+  ];
+}
+
+//===----------------------------------------------------------------------===//
 // VectorToLLVM
 //===----------------------------------------------------------------------===//
 
index d7a6f82..f34a576 100644 (file)
@@ -14,6 +14,7 @@
 namespace mlir {
 class MLIRContext;
 class OwningRewritePatternList;
+class Pass;
 
 /// Control whether unrolling is used when lowering vector transfer ops to SCF.
 ///
@@ -164,6 +165,10 @@ void populateVectorToSCFConversionPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context,
     const VectorTransferToSCFOptions &options = VectorTransferToSCFOptions());
 
+/// Create a pass to convert a subset of vector ops to SCF.
+std::unique_ptr<Pass> createConvertVectorToSCFPass(
+    const VectorTransferToSCFOptions &options = VectorTransferToSCFOptions());
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_VECTORTOSCF_VECTORTOSCF_H_
index 423c72d..8c8424e 100644 (file)
@@ -56,6 +56,11 @@ enum class VectorContractLowering {
 /// Structure to control the behavior of vector transform patterns.
 struct VectorTransformsOptions {
   VectorContractLowering vectorContractLowering = VectorContractLowering::FMA;
+  VectorTransformsOptions &
+  setVectorTransformsOptions(VectorContractLowering opt) {
+    vectorContractLowering = opt;
+    return *this;
+  }
 };
 
 /// Collect a set of transformation patterns that are related to contracting
index fb2ac1e..95f9ce1 100644 (file)
@@ -28,6 +28,7 @@
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
 #include "mlir/Dialect/Affine/Passes.h"
 #include "mlir/Dialect/GPU/Passes.h"
 #include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"
index 5b3a01c..4185eae 100644 (file)
@@ -124,6 +124,89 @@ static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
   return res;
 }
 
+template <typename TransferOp>
+LogicalResult getVectorTransferAlignment(LLVMTypeConverter &typeConverter,
+                                         TransferOp xferOp, unsigned &align) {
+  Type elementTy =
+      typeConverter.convertType(xferOp.getMemRefType().getElementType());
+  if (!elementTy)
+    return failure();
+
+  auto dataLayout = typeConverter.getDialect()->getLLVMModule().getDataLayout();
+  align = dataLayout.getPrefTypeAlignment(
+      elementTy.cast<LLVM::LLVMType>().getUnderlyingType());
+  return success();
+}
+
+static LogicalResult
+replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
+                                 LLVMTypeConverter &typeConverter, Location loc,
+                                 TransferReadOp xferOp,
+                                 ArrayRef<Value> operands, Value dataPtr) {
+  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr);
+  return success();
+}
+
+static LogicalResult
+replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
+                            LLVMTypeConverter &typeConverter, Location loc,
+                            TransferReadOp xferOp, ArrayRef<Value> operands,
+                            Value dataPtr, Value mask) {
+  auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
+  VectorType fillType = xferOp.getVectorType();
+  Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
+  fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);
+
+  Type vecTy = typeConverter.convertType(xferOp.getVectorType());
+  if (!vecTy)
+    return failure();
+
+  unsigned align;
+  if (failed(getVectorTransferAlignment(typeConverter, xferOp, align)))
+    return failure();
+
+  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
+      xferOp, vecTy, dataPtr, mask, ValueRange{fill},
+      rewriter.getI32IntegerAttr(align));
+  return success();
+}
+
+static LogicalResult
+replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
+                                 LLVMTypeConverter &typeConverter, Location loc,
+                                 TransferWriteOp xferOp,
+                                 ArrayRef<Value> operands, Value dataPtr) {
+  auto adaptor = TransferWriteOpOperandAdaptor(operands);
+  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr);
+  return success();
+}
+
+static LogicalResult
+replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
+                            LLVMTypeConverter &typeConverter, Location loc,
+                            TransferWriteOp xferOp, ArrayRef<Value> operands,
+                            Value dataPtr, Value mask) {
+  unsigned align;
+  if (failed(getVectorTransferAlignment(typeConverter, xferOp, align)))
+    return failure();
+
+  auto adaptor = TransferWriteOpOperandAdaptor(operands);
+  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
+      xferOp, adaptor.vector(), dataPtr, mask,
+      rewriter.getI32IntegerAttr(align));
+  return success();
+}
+
+static TransferReadOpOperandAdaptor
+getTransferOpAdapter(TransferReadOp xferOp, ArrayRef<Value> operands) {
+  return TransferReadOpOperandAdaptor(operands);
+}
+
+static TransferWriteOpOperandAdaptor
+getTransferOpAdapter(TransferWriteOp xferOp, ArrayRef<Value> operands) {
+  return TransferWriteOpOperandAdaptor(operands);
+}
+
 namespace {
 
 /// Conversion pattern for a vector.matrix_multiply.
@@ -767,108 +850,6 @@ public:
   }
 };
 
-LogicalResult getLLVMTypeAndAlignment(LLVMTypeConverter &typeConverter,
-                                      Type type, LLVM::LLVMType &llvmType,
-                                      unsigned &align) {
-  auto convertedType = typeConverter.convertType(type);
-  if (!convertedType)
-    return failure();
-
-  llvmType = convertedType.template cast<LLVM::LLVMType>();
-  auto dataLayout = typeConverter.getDialect()->getLLVMModule().getDataLayout();
-  align = dataLayout.getPrefTypeAlignment(llvmType.getUnderlyingType());
-  return success();
-}
-
-LogicalResult
-replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
-                                 LLVMTypeConverter &typeConverter, Location loc,
-                                 TransferReadOp xferOp,
-                                 ArrayRef<Value> operands, Value dataPtr) {
-  LLVM::LLVMType vecTy;
-  unsigned align;
-  if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
-                                     vecTy, align)))
-    return failure();
-  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr);
-  return success();
-}
-
-LogicalResult replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
-                                          LLVMTypeConverter &typeConverter,
-                                          Location loc, TransferReadOp xferOp,
-                                          ArrayRef<Value> operands,
-                                          Value dataPtr, Value mask) {
-  auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
-  VectorType fillType = xferOp.getVectorType();
-  Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
-  fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);
-
-  LLVM::LLVMType vecTy;
-  unsigned align;
-  if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
-                                     vecTy, align)))
-    return failure();
-
-  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
-      xferOp, vecTy, dataPtr, mask, ValueRange{fill},
-      rewriter.getI32IntegerAttr(align));
-  return success();
-}
-
-LogicalResult
-replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
-                                 LLVMTypeConverter &typeConverter, Location loc,
-                                 TransferWriteOp xferOp,
-                                 ArrayRef<Value> operands, Value dataPtr) {
-  auto adaptor = TransferWriteOpOperandAdaptor(operands);
-  LLVM::LLVMType vecTy;
-  unsigned align;
-  if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
-                                     vecTy, align)))
-    return failure();
-  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr);
-  return success();
-}
-
-LogicalResult replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
-                                          LLVMTypeConverter &typeConverter,
-                                          Location loc, TransferWriteOp xferOp,
-                                          ArrayRef<Value> operands,
-                                          Value dataPtr, Value mask) {
-  auto adaptor = TransferWriteOpOperandAdaptor(operands);
-  LLVM::LLVMType vecTy;
-  unsigned align;
-  if (failed(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
-                                     vecTy, align)))
-    return failure();
-
-  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
-      xferOp, adaptor.vector(), dataPtr, mask,
-      rewriter.getI32IntegerAttr(align));
-  return success();
-}
-
-static TransferReadOpOperandAdaptor
-getTransferOpAdapter(TransferReadOp xferOp, ArrayRef<Value> operands) {
-  return TransferReadOpOperandAdaptor(operands);
-}
-
-static TransferWriteOpOperandAdaptor
-getTransferOpAdapter(TransferWriteOp xferOp, ArrayRef<Value> operands) {
-  return TransferWriteOpOperandAdaptor(operands);
-}
-
-bool isMinorIdentity(AffineMap map, unsigned rank) {
-  if (map.getNumResults() < rank)
-    return false;
-  unsigned startDim = map.getNumDims() - rank;
-  for (unsigned i = 0; i < rank; ++i)
-    if (map.getResult(i) != getAffineDimExpr(startDim + i, map.getContext()))
-      return false;
-  return true;
-}
-
 /// Conversion pattern that converts a 1-D vector transfer read/write op in a
 /// sequence of:
 /// 1. Bitcast or addrspacecast to vector form.
@@ -892,8 +873,10 @@ public:
     if (xferOp.getVectorType().getRank() > 1 ||
         llvm::size(xferOp.indices()) == 0)
       return failure();
-    if (!isMinorIdentity(xferOp.permutation_map(),
-                         xferOp.getVectorType().getRank()))
+    if (xferOp.permutation_map() !=
+        AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
+                                       xferOp.getVectorType().getRank(),
+                                       op->getContext()))
       return failure();
 
     auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
index 8c72800..6816bc7 100644 (file)
@@ -13,6 +13,8 @@
 #include <type_traits>
 
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
+
+#include "../PassDetail.h"
 #include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
 #include "mlir/Dialect/SCF/EDSC/Builders.h"
 #include "mlir/Dialect/SCF/EDSC/Intrinsics.h"
@@ -29,6 +31,8 @@
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Types.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
 
 using namespace mlir;
 using namespace mlir::edsc;
@@ -349,7 +353,7 @@ LogicalResult NDTransferOpHelper<TransferWriteOp>::doReplace() {
 }
 
 } // namespace
-  
+
 /// Analyzes the `transfer` to find an access dimension along the fastest remote
 /// MemRef dimension. If such a dimension with coalescing properties is found,
 /// `pivs` and `vectorBoundsCapture` are swapped so that the invocation of
@@ -435,7 +439,7 @@ clip(TransferOpTy transfer, MemRefBoundsCapture &bounds, ArrayRef<Value> ivs) {
 }
 
 namespace mlir {
-  
+
 template <typename TransferOpTy>
 VectorTransferRewriter<TransferOpTy>::VectorTransferRewriter(
     VectorTransferToSCFOptions options, MLIRContext *context)
@@ -631,3 +635,28 @@ void populateVectorToSCFConversionPatterns(
 
 } // namespace mlir
 
+namespace {
+
+struct ConvertVectorToSCFPass
+    : public ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
+  ConvertVectorToSCFPass() = default;
+  ConvertVectorToSCFPass(const ConvertVectorToSCFPass &pass) {}
+  ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
+    this->fullUnroll = options.unroll;
+  }
+
+  void runOnFunction() override {
+    OwningRewritePatternList patterns;
+    auto *context = getFunction().getContext();
+    populateVectorToSCFConversionPatterns(
+        patterns, context, VectorTransferToSCFOptions().setUnroll(fullUnroll));
+    applyPatternsAndFoldGreedily(getFunction(), patterns);
+  }
+};
+
+} // namespace
+
+std::unique_ptr<Pass>
+mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
+  return std::make_unique<ConvertVectorToSCFPass>(options);
+}
index 6150ac7..3662c24 100644 (file)
@@ -818,7 +818,7 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
 //       CHECK: %[[PASS_THROUGH:.*]] =  llvm.mlir.constant(dense<7.000000e+00> :
 //  CHECK-SAME:  vector<17xf32>) : !llvm<"<17 x float>">
 //       CHECK: %[[loaded:.*]] = llvm.intr.masked.load %[[vecPtr]], %[[mask]],
-//  CHECK-SAME: %[[PASS_THROUGH]] {alignment = 128 : i32} :
+//  CHECK-SAME: %[[PASS_THROUGH]] {alignment = 4 : i32} :
 //  CHECK-SAME: (!llvm<"<17 x float>*">, !llvm<"<17 x i1>">, !llvm<"<17 x float>">) -> !llvm<"<17 x float>">
 
 //
@@ -850,7 +850,7 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
 //
 // 5. Rewrite as a masked write.
 //       CHECK: llvm.intr.masked.store %[[loaded]], %[[vecPtr_b]], %[[mask_b]]
-//  CHECK-SAME: {alignment = 128 : i32} :
+//  CHECK-SAME: {alignment = 4 : i32} :
 //  CHECK-SAME: !llvm<"<17 x float>">, !llvm<"<17 x i1>"> into !llvm<"<17 x float>*">
 
 func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: index) -> vector<17xf32> {
index dc35058..d4f22d2 100644 (file)
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-convert-vector-to-scf -split-input-file | FileCheck %s
-// RUN: mlir-opt %s -test-convert-vector-to-scf=full-unroll=true -split-input-file | FileCheck %s --check-prefix=FULL-UNROLL
+// RUN: mlir-opt %s -convert-vector-to-scf -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-vector-to-scf=full-unroll=true -split-input-file | FileCheck %s --check-prefix=FULL-UNROLL
 
 // CHECK-LABEL: func @materialize_read_1d() {
 func @materialize_read_1d() {
index 4ea7498..6069570 100644 (file)
@@ -20,7 +20,6 @@ add_mlir_library(MLIRTestTransforms
   TestMemRefBoundCheck.cpp
   TestMemRefDependenceCheck.cpp
   TestMemRefStrideCalculation.cpp
-  TestVectorToSCFConversion.cpp
   TestVectorTransforms.cpp
 
   EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Transforms/TestVectorToSCFConversion.cpp b/mlir/test/lib/Transforms/TestVectorToSCFConversion.cpp
deleted file mode 100644 (file)
index 7a83e20..0000000
+++ /dev/null
@@ -1,48 +0,0 @@
-//===- TestVectorToSCFConversion.cpp - Test VectorTransfers lowering ------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include <type_traits>
-
-#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/Passes.h"
-
-using namespace mlir;
-
-namespace {
-
-struct TestVectorToSCFPass
-    : public PassWrapper<TestVectorToSCFPass, FunctionPass> {
-  TestVectorToSCFPass() = default;
-  TestVectorToSCFPass(const TestVectorToSCFPass &pass) {}
-
-  Option<bool> fullUnroll{
-      *this, "full-unroll",
-      llvm::cl::desc(
-          "Perform full unrolling when converting vector transfers to SCF"),
-      llvm::cl::init(false)};
-
-  void runOnFunction() override {
-    OwningRewritePatternList patterns;
-    auto *context = &getContext();
-    populateVectorToSCFConversionPatterns(
-        patterns, context, VectorTransferToSCFOptions().setUnroll(fullUnroll));
-    applyPatternsAndFoldGreedily(getFunction(), patterns);
-  }
-};
-
-} // end anonymous namespace
-
-namespace mlir {
-void registerTestVectorToSCFPass() {
-  PassRegistration<TestVectorToSCFPass> pass(
-      "test-convert-vector-to-scf",
-      "Converts vector transfer ops to loops over scalars and vector casts");
-}
-} // namespace mlir
index 159a7fd..2764b23 100644 (file)
@@ -62,7 +62,6 @@ void registerTestOpaqueLoc();
 void registerTestParallelismDetection();
 void registerTestGpuParallelLoopMappingPass();
 void registerTestVectorConversions();
-void registerTestVectorToSCFPass();
 void registerVectorizerTestPass();
 } // namespace mlir
 
@@ -133,7 +132,6 @@ void registerTestPasses() {
   registerTestParallelismDetection();
   registerTestGpuParallelLoopMappingPass();
   registerTestVectorConversions();
-  registerTestVectorToSCFPass();
   registerVectorizerTestPass();
 }
 #endif