Function &Func;
const DataLayout &DL;
const TargetTransformInfo &TTI;
- AliasAnalysis &AA;
- DominatorTree &DT;
- LoopInfo &LI;
- OptimizationRemarkEmitter &ORE;
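+ // Optional analyses; may be null when running the minimal lowering (e.g. at -O0).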
+ AliasAnalysis *AA;
+ DominatorTree *DT;
+ LoopInfo *LI;
+ OptimizationRemarkEmitter *ORE;
/// Contains estimates of the number of operations (loads, stores, compute)
/// required to lower a matrix operation.
struct OpInfoTy {
public:
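// A sketch of the counters this struct plausibly holds (field names are
// assumptions for illustration, not the verbatim upstream definition):
//   unsigned NumStores = 0;     // stores emitted for the lowered operation
//   unsigned NumLoads = 0;      // loads emitted
//   unsigned NumComputeOps = 0; // arithmetic instructions emitted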
LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
- AliasAnalysis &AA, DominatorTree &DT, LoopInfo &LI,
- OptimizationRemarkEmitter &ORE)
+ AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI,
+ OptimizationRemarkEmitter *ORE)
: Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT),
LI(LI), ORE(ORE) {}
Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
}
- RemarkGenerator RemarkGen(Inst2ColumnMatrix, ORE, Func);
- RemarkGen.emitRemarks();
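+ // Remarks are only emitted when an OptimizationRemarkEmitter is available;
+ // the minimal lowering passes a null ORE.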
+ if (ORE) {
+ RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
+ RemarkGen.emitRemarks();
+ }
for (Instruction *Inst : reverse(ToRemove))
Inst->eraseFromParent();
MemoryLocation StoreLoc = MemoryLocation::get(Store);
MemoryLocation LoadLoc = MemoryLocation::get(Load);
- AliasResult LdAliased = AA.alias(LoadLoc, StoreLoc);
+ AliasResult LdAliased = AA->alias(LoadLoc, StoreLoc);
// If we can statically determine noalias we're good.
if (!LdAliased)
// as we adjust Check0 and Check1's branches.
SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
for (BasicBlock *Succ : successors(Check0))
- DTUpdates.push_back({DT.Delete, Check0, Succ});
+ DTUpdates.push_back({DT->Delete, Check0, Succ});
- BasicBlock *Check1 = SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI,
+ BasicBlock *Check1 = SplitBlock(MatMul->getParent(), MatMul, nullptr, LI,
nullptr, "alias_cont");
BasicBlock *Copy =
- SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI, nullptr, "copy");
- BasicBlock *Fusion = SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI,
+ SplitBlock(MatMul->getParent(), MatMul, nullptr, LI, nullptr, "copy");
+ BasicBlock *Fusion = SplitBlock(MatMul->getParent(), MatMul, nullptr, LI,
nullptr, "no_alias");
// Check if the loaded memory location begins before the end of the store
PHI->addIncoming(NewLd, Copy);
// Adjust DT.
- DTUpdates.push_back({DT.Insert, Check0, Check1});
- DTUpdates.push_back({DT.Insert, Check0, Fusion});
- DTUpdates.push_back({DT.Insert, Check1, Copy});
- DTUpdates.push_back({DT.Insert, Check1, Fusion});
- DT.applyUpdates(DTUpdates);
+ DTUpdates.push_back({DT->Insert, Check0, Check1});
+ DTUpdates.push_back({DT->Insert, Check0, Fusion});
+ DTUpdates.push_back({DT->Insert, Check1, Copy});
+ DTUpdates.push_back({DT->Insert, Check1, Fusion});
+ DT->applyUpdates(DTUpdates);
return PHI;
}
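// Per the DominatorTree updates above, the runtime alias check produces this
// CFG: Check0 and Check1 compare the load and store bounds, Copy makes a
// private copy of the loaded matrix, and the PHI in Fusion picks between the
// original load and the copy:
//
//          Check0
//          /    \
//      Check1    \
//      /    \     |
//   Copy     \    |
//      \      |   |
//       v     v   v
//          Fusion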
void LowerMatrixMultiplyFused(CallInst *MatMul,
SmallPtrSetImpl<Instruction *> &FusedInsts) {
if (!FuseMatrix || !MatMul->hasOneUse() ||
- MatrixLayout != MatrixLayoutTy::ColumnMajor)
+ MatrixLayout != MatrixLayoutTy::ColumnMajor || !DT)
return;
auto *LoadOp0 = dyn_cast<LoadInst>(MatMul->getOperand(0));
// we create invalid IR.
// FIXME: See if we can hoist the store address computation.
auto *AddrI = dyn_cast<Instruction>(Store->getOperand(1));
- if (AddrI && (!DT.dominates(AddrI, MatMul)))
+ if (AddrI && (!DT->dominates(AddrI, MatMul)))
return;
emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
auto &LI = AM.getResult<LoopAnalysis>(F);
- LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
+ LowerMatrixIntrinsics LMT(F, TTI, &AA, &DT, &LI, &ORE);
if (LMT.Visit()) {
PreservedAnalyses PA;
PA.preserveSet<CFGAnalyses>();
auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
- LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
+ LowerMatrixIntrinsics LMT(F, TTI, &AA, &DT, &LI, &ORE);
bool C = LMT.Visit();
return C;
}
Pass *llvm::createLowerMatrixIntrinsicsPass() {
return new LowerMatrixIntrinsicsLegacyPass();
}
+
+namespace {
+
+/// A lightweight version of the matrix lowering pass that only requires TTI.
+/// Advanced features that require DT, AA, or ORE, such as tiling, are
+/// disabled. This is used to lower matrix intrinsics if the main lowering
+/// pass is not run, for example at -O0.
+class LowerMatrixIntrinsicsMinimalLegacyPass : public FunctionPass {
+public:
+ static char ID;
+
+ LowerMatrixIntrinsicsMinimalLegacyPass() : FunctionPass(ID) {
+ initializeLowerMatrixIntrinsicsMinimalLegacyPassPass(
+ *PassRegistry::getPassRegistry());
+ }
+
+ bool runOnFunction(Function &F) override {
+ auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
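+ // Run the shared lowering with all optional analyses absent; the paths
+ // that need AA/DT/LI/ORE (fusion, tiling, remarks) bail out on null.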
+ LowerMatrixIntrinsics LMT(F, TTI, nullptr, nullptr, nullptr, nullptr);
+ bool C = LMT.Visit();
+ return C;
+ }
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.addRequired<TargetTransformInfoWrapperPass>();
+ AU.setPreservesCFG();
+ }
+};
+} // namespace
+
+static const char pass_name_minimal[] = "Lower the matrix intrinsics (minimal)";
+char LowerMatrixIntrinsicsMinimalLegacyPass::ID = 0;
+INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsMinimalLegacyPass,
+ "lower-matrix-intrinsics-minimal", pass_name_minimal,
+ false, false)
+INITIALIZE_PASS_END(LowerMatrixIntrinsicsMinimalLegacyPass,
+ "lower-matrix-intrinsics-minimal", pass_name_minimal, false,
+ false)
+
+Pass *llvm::createLowerMatrixIntrinsicsMinimalPass() {
+ return new LowerMatrixIntrinsicsMinimalLegacyPass();
+}
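
How a build of the -O0 pipeline might schedule the minimal pass (a sketch;
the hook and the EnableMatrix flag below are assumptions, not part of this
patch):

  // Hypothetical pipeline registration, guarded by a matrix-support flag:
  if (EnableMatrix && OptLevel == CodeGenOpt::None)
    PM.add(createLowerMatrixIntrinsicsMinimalPass());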
--- /dev/null
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -lower-matrix-intrinsics-minimal -fuse-matrix-tile-size=2 -matrix-allow-contract -force-fuse-matrix -instcombine -verify-dom-info %s -S | FileCheck %s
+
+; Test for the minimal version of the matrix lowering pass, which does not
+; require DT or AA. Make sure no tiling happens, even though it was
+; requested.
+
+; REQUIRES: aarch64-registered-target
+
+target datalayout = "e-m:o-i64:64-f80:128-n8:8:32:64-S128"
+target triple = "aarch64-apple-ios"
+
+define void @multiply(<8 x double>* %A, <8 x double>* %B, <4 x double>* %C) {
+; CHECK-LABEL: @multiply(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[VEC_CAST:%.*]] = bitcast <8 x double>* [[A:%.*]] to <2 x double>*
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST]], align 8
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr <8 x double>, <8 x double>* [[A]], i64 0, i64 2
+; CHECK-NEXT: [[VEC_CAST1:%.*]] = bitcast double* [[VEC_GEP]] to <2 x double>*
+; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST1]], align 8
+; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr <8 x double>, <8 x double>* [[A]], i64 0, i64 4
+; CHECK-NEXT: [[VEC_CAST4:%.*]] = bitcast double* [[VEC_GEP3]] to <2 x double>*
+; CHECK-NEXT: [[COL_LOAD5:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST4]], align 8
+; CHECK-NEXT: [[VEC_GEP6:%.*]] = getelementptr <8 x double>, <8 x double>* [[A]], i64 0, i64 6
+; CHECK-NEXT: [[VEC_CAST7:%.*]] = bitcast double* [[VEC_GEP6]] to <2 x double>*
+; CHECK-NEXT: [[COL_LOAD8:%.*]] = load <2 x double>, <2 x double>* [[VEC_CAST7]], align 8
+; CHECK-NEXT: [[VEC_CAST9:%.*]] = bitcast <8 x double>* [[B:%.*]] to <4 x double>*
+; CHECK-NEXT: [[COL_LOAD10:%.*]] = load <4 x double>, <4 x double>* [[VEC_CAST9]], align 8
+; CHECK-NEXT: [[VEC_GEP11:%.*]] = getelementptr <8 x double>, <8 x double>* [[B]], i64 0, i64 4
+; CHECK-NEXT: [[VEC_CAST12:%.*]] = bitcast double* [[VEC_GEP11]] to <4 x double>*
+; CHECK-NEXT: [[COL_LOAD13:%.*]] = load <4 x double>, <4 x double>* [[VEC_CAST12]], align 8
+; CHECK-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <4 x double> [[COL_LOAD10]], <4 x double> undef, <2 x i32> zeroinitializer
+; CHECK-NEXT: [[TMP0:%.*]] = fmul <2 x double> [[COL_LOAD]], [[SPLAT_SPLAT]]
+; CHECK-NEXT: [[SPLAT_SPLAT16:%.*]] = shufflevector <4 x double> [[COL_LOAD10]], <4 x double> undef, <2 x i32> <i32 1, i32 1>
+; CHECK-NEXT: [[TMP1:%.*]] = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[COL_LOAD2]], <2 x double> [[SPLAT_SPLAT16]], <2 x double> [[TMP0]])
+; CHECK-NEXT: [[SPLAT_SPLAT19:%.*]] = shufflevector <4 x double> [[COL_LOAD10]], <4 x double> undef, <2 x i32> <i32 2, i32 2>
+; CHECK-NEXT: [[TMP2:%.*]] = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[COL_LOAD5]], <2 x double> [[SPLAT_SPLAT19]], <2 x double> [[TMP1]])
+; CHECK-NEXT: [[SPLAT_SPLAT22:%.*]] = shufflevector <4 x double> [[COL_LOAD10]], <4 x double> undef, <2 x i32> <i32 3, i32 3>
+; CHECK-NEXT: [[TMP3:%.*]] = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[COL_LOAD8]], <2 x double> [[SPLAT_SPLAT22]], <2 x double> [[TMP2]])
+; CHECK-NEXT: [[SPLAT_SPLAT25:%.*]] = shufflevector <4 x double> [[COL_LOAD13]], <4 x double> undef, <2 x i32> zeroinitializer
+; CHECK-NEXT: [[TMP4:%.*]] = fmul <2 x double> [[COL_LOAD]], [[SPLAT_SPLAT25]]
+; CHECK-NEXT: [[SPLAT_SPLAT28:%.*]] = shufflevector <4 x double> [[COL_LOAD13]], <4 x double> undef, <2 x i32> <i32 1, i32 1>
+; CHECK-NEXT: [[TMP5:%.*]] = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[COL_LOAD2]], <2 x double> [[SPLAT_SPLAT28]], <2 x double> [[TMP4]])
+; CHECK-NEXT: [[SPLAT_SPLAT31:%.*]] = shufflevector <4 x double> [[COL_LOAD13]], <4 x double> undef, <2 x i32> <i32 2, i32 2>
+; CHECK-NEXT: [[TMP6:%.*]] = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[COL_LOAD5]], <2 x double> [[SPLAT_SPLAT31]], <2 x double> [[TMP5]])
+; CHECK-NEXT: [[SPLAT_SPLAT34:%.*]] = shufflevector <4 x double> [[COL_LOAD13]], <4 x double> undef, <2 x i32> <i32 3, i32 3>
+; CHECK-NEXT: [[TMP7:%.*]] = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[COL_LOAD8]], <2 x double> [[SPLAT_SPLAT34]], <2 x double> [[TMP6]])
+; CHECK-NEXT: [[VEC_CAST35:%.*]] = bitcast <4 x double>* [[C:%.*]] to <2 x double>*
+; CHECK-NEXT: store <2 x double> [[TMP3]], <2 x double>* [[VEC_CAST35]], align 8
+; CHECK-NEXT: [[VEC_GEP36:%.*]] = getelementptr <4 x double>, <4 x double>* [[C]], i64 0, i64 2
+; CHECK-NEXT: [[VEC_CAST37:%.*]] = bitcast double* [[VEC_GEP36]] to <2 x double>*
+; CHECK-NEXT: store <2 x double> [[TMP7]], <2 x double>* [[VEC_CAST37]], align 8
+; CHECK-NEXT: ret void
+;
+entry:
+ %a = load <8 x double>, <8 x double>* %A, align 8
+ %b = load <8 x double>, <8 x double>* %B, align 8
+
+ %c = call <4 x double> @llvm.matrix.multiply(<8 x double> %a, <8 x double> %b, i32 2, i32 4, i32 2)
+
+ store <4 x double> %c, <4 x double>* %C, align 8
+ ret void
+}
+
+declare <4 x double> @llvm.matrix.multiply(<8 x double>, <8 x double>, i32, i32, i32)
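+
+; For comparison, the full -lower-matrix-intrinsics pass with the same flags
+; would guard the multiply with a runtime alias check and emit tiled code;
+; the minimal pass deliberately skips both.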