From 406c0440c5fc7bca370641f3f0a2d36940c88861 Mon Sep 17 00:00:00 2001
From: Roman Tereshin
Date: Wed, 25 Jul 2018 21:33:00 +0000
Subject: [PATCH] [LSV] Look through selects for consecutive addresses

In some cases LSV sees (load/store _ (select _ )) patterns in input IR,
often due to sinking and other forms of CFG simplification, sometimes
interspersed with bitcasts and all-constant-indices GEPs.

With this patch, the `areConsecutivePointers` method attempts to handle
select instructions. This leads to an increased number of successful
vectorizations.

Technically, select instructions could appear in index arithmetic as
well; however, we don't see those in our test suites / benchmarks.
Also, there is a lot more freedom in the IR shapes that compute integral
indices than in what is common for pointer computations, and it appears
that anything short of making select instructions first-class citizens
of Scalar Evolution would be unreliable, which for the purposes of this
patch is definitely overkill.

Reviewed By: rampitec

Differential Revision: https://reviews.llvm.org/D49428

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@337965 91177308-0d34-0410-b5e6-96231b3b80d8
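For illustration only (not part of the patch; a minimal sketch distilled
from the @base_case test added below), the kind of pattern this targets
is a pair of loads whose addresses are selects over the same condition
with pairwise-consecutive arms:

  %p0 = select i1 %cnd, i32 addrspace(1)* %a, i32 addrspace(1)* %b
  %a1 = getelementptr inbounds i32, i32 addrspace(1)* %a, i64 1
  %b1 = getelementptr inbounds i32, i32 addrspace(1)* %b, i64 1
  %p1 = select i1 %cnd, i32 addrspace(1)* %a1, i32 addrspace(1)* %b1
  ; %p0 and %p1 are consecutive whichever way %cnd goes, so the two
  ; scalar loads below should now be mergeable into one vector load.
  %v0 = load i32, i32 addrspace(1)* %p0, align 4
  %v1 = load i32, i32 addrspace(1)* %p1, align 4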
---
 lib/Transforms/Vectorize/LoadStoreVectorizer.cpp   | 77 ++++++++++++++----
 .../LoadStoreVectorizer/AMDGPU/selects.ll          | 95 ++++++++++++++++++++++
 2 files changed, 157 insertions(+), 15 deletions(-)
 create mode 100644 test/Transforms/LoadStoreVectorizer/AMDGPU/selects.ll

diff --git a/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp b/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
index b56e731c991..719df55347a 100644
--- a/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
+++ b/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
@@ -97,8 +97,16 @@ static const unsigned StackAdjustedAlignment = 4;
 
 namespace {
 
+/// ChainID is an arbitrary token that is allowed to be different only for the
+/// accesses that are guaranteed to be considered non-consecutive by
+/// Vectorizer::isConsecutiveAccess. It's used for grouping instructions
+/// together and reducing the number of instructions the main search operates
+/// on at a time, i.e. this is to reduce compile time and nothing else, as the
+/// main search has O(n^2) time complexity. The underlying type of ChainID
+/// should not be relied upon.
+using ChainID = const Value *;
 using InstrList = SmallVector<Instruction *, 8>;
-using InstrListMap = MapVector<Value *, InstrList>;
+using InstrListMap = MapVector<ChainID, InstrList>;
 
 class Vectorizer {
   Function &F;
@@ -136,9 +144,15 @@ private:
     return DL.getABITypeAlignment(SI->getValueOperand()->getType());
   }
 
+  static const unsigned MaxDepth = 3;
+
   bool isConsecutiveAccess(Value *A, Value *B);
-  bool areConsecutivePointers(Value *PtrA, Value *PtrB, APInt Size);
-  bool lookThroughComplexAddresses(Value *PtrA, Value *PtrB, APInt PtrDelta);
+  bool areConsecutivePointers(Value *PtrA, Value *PtrB, APInt PtrDelta,
+                              unsigned Depth = 0) const;
+  bool lookThroughComplexAddresses(Value *PtrA, Value *PtrB, APInt PtrDelta,
+                                   unsigned Depth) const;
+  bool lookThroughSelects(Value *PtrA, Value *PtrB, APInt PtrDelta,
+                          unsigned Depth) const;
 
   /// After vectorization, reorder the instructions that I depends on
   /// (the instructions defining its operands), to ensure they dominate I.
@@ -304,7 +318,8 @@ bool Vectorizer::isConsecutiveAccess(Value *A, Value *B) {
   return areConsecutivePointers(PtrA, PtrB, Size);
 }
 
-bool Vectorizer::areConsecutivePointers(Value *PtrA, Value *PtrB, APInt Size) {
+bool Vectorizer::areConsecutivePointers(Value *PtrA, Value *PtrB,
+                                        APInt PtrDelta, unsigned Depth) const {
   unsigned PtrBitWidth = DL.getPointerTypeSizeInBits(PtrA->getType());
   APInt OffsetA(PtrBitWidth, 0);
   APInt OffsetB(PtrBitWidth, 0);
@@ -316,11 +331,11 @@ bool Vectorizer::areConsecutivePointers(Value *PtrA, Value *PtrB, APInt Size) {
   // Check if they are based on the same pointer. That makes the offsets
   // sufficient.
   if (PtrA == PtrB)
-    return OffsetDelta == Size;
+    return OffsetDelta == PtrDelta;
 
   // Compute the necessary base pointer delta to have the necessary final delta
-  // equal to the size.
-  APInt BaseDelta = Size - OffsetDelta;
+  // equal to the pointer delta requested.
+  APInt BaseDelta = PtrDelta - OffsetDelta;
 
   // Compute the distance with SCEV between the base pointers.
   const SCEV *PtrSCEVA = SE.getSCEV(PtrA);
@@ -341,15 +356,16 @@ bool Vectorizer::areConsecutivePointers(Value *PtrA, Value *PtrB, APInt Size) {
   // Sometimes even this doesn't work, because SCEV can't always see through
   // patterns that look like (gep (ext (add (shl X, C1), C2))). Try checking
   // things the hard way.
-  return lookThroughComplexAddresses(PtrA, PtrB, BaseDelta);
+  return lookThroughComplexAddresses(PtrA, PtrB, BaseDelta, Depth);
 }
 
 bool Vectorizer::lookThroughComplexAddresses(Value *PtrA, Value *PtrB,
-                                             APInt PtrDelta) {
+                                             APInt PtrDelta,
+                                             unsigned Depth) const {
   auto *GEPA = dyn_cast<GetElementPtrInst>(PtrA);
   auto *GEPB = dyn_cast<GetElementPtrInst>(PtrB);
   if (!GEPA || !GEPB)
-    return false;
+    return lookThroughSelects(PtrA, PtrB, PtrDelta, Depth);
 
   // Look through GEPs after checking they're the same except for the last
   // index.
@@ -434,6 +450,23 @@ bool Vectorizer::lookThroughComplexAddresses(Value *PtrA, Value *PtrB,
   return X == OffsetSCEVB;
 }
 
+bool Vectorizer::lookThroughSelects(Value *PtrA, Value *PtrB, APInt PtrDelta,
+                                    unsigned Depth) const {
+  if (Depth++ == MaxDepth)
+    return false;
+
+  if (auto *SelectA = dyn_cast<SelectInst>(PtrA)) {
+    if (auto *SelectB = dyn_cast<SelectInst>(PtrB)) {
+      return SelectA->getCondition() == SelectB->getCondition() &&
+             areConsecutivePointers(SelectA->getTrueValue(),
+                                    SelectB->getTrueValue(), PtrDelta, Depth) &&
+             areConsecutivePointers(SelectA->getFalseValue(),
+                                    SelectB->getFalseValue(), PtrDelta, Depth);
+    }
+  }
+  return false;
+}
+
 void Vectorizer::reorder(Instruction *I) {
   OrderedBasicBlock OBB(I->getParent());
   SmallPtrSet<Instruction *, 16> InstructionsToMove;
@@ -656,6 +689,20 @@ Vectorizer::getVectorizablePrefix(ArrayRef<Instruction *> Chain) {
   return Chain.slice(0, ChainIdx);
 }
 
+static ChainID getChainID(const Value *Ptr, const DataLayout &DL) {
+  const Value *ObjPtr = GetUnderlyingObject(Ptr, DL);
+  if (const auto *Sel = dyn_cast<SelectInst>(ObjPtr)) {
+    // The selects themselves are distinct instructions even if they share the
+    // same condition and evaluate to consecutive pointers for the true and
+    // false values of the condition. Therefore, using the selects themselves
+    // for grouping instructions would put consecutive accesses into different
+    // lists, and they wouldn't even be checked for being consecutive, and
+    // wouldn't be vectorized.
+    return Sel->getCondition();
+  }
+  return ObjPtr;
+}
+
 std::pair<InstrListMap, InstrListMap>
 Vectorizer::collectInstructions(BasicBlock *BB) {
   InstrListMap LoadRefs;
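(Editorial aside, not part of the patch: with getChainID above, accesses
whose underlying object is a select are grouped by the select's
condition rather than by the select instruction itself. So in a fragment
like the following, hypothetical one, both pointers map to the same
ChainID, %cnd, and the accesses through them end up in one candidate
list where they can be checked for consecutiveness.)

  %p0 = select i1 %cnd, i32 addrspace(1)* %a, i32 addrspace(1)* %b    ; ChainID: %cnd
  %p1 = select i1 %cnd, i32 addrspace(1)* %a1, i32 addrspace(1)* %b1  ; ChainID: %cnd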
@@ -710,8 +757,8 @@ Vectorizer::collectInstructions(BasicBlock *BB) {
         continue;
 
       // Save the load locations.
-      Value *ObjPtr = GetUnderlyingObject(Ptr, DL);
-      LoadRefs[ObjPtr].push_back(LI);
+      const ChainID ID = getChainID(Ptr, DL);
+      LoadRefs[ID].push_back(LI);
     } else if (StoreInst *SI = dyn_cast<StoreInst>(&I)) {
       if (!SI->isSimple())
         continue;
@@ -756,8 +803,8 @@ Vectorizer::collectInstructions(BasicBlock *BB) {
         continue;
 
       // Save store location.
-      Value *ObjPtr = GetUnderlyingObject(Ptr, DL);
-      StoreRefs[ObjPtr].push_back(SI);
+      const ChainID ID = getChainID(Ptr, DL);
+      StoreRefs[ID].push_back(SI);
     }
   }
 
@@ -767,7 +814,7 @@ Vectorizer::collectInstructions(BasicBlock *BB) {
 bool Vectorizer::vectorizeChains(InstrListMap &Map) {
   bool Changed = false;
 
-  for (const std::pair<Value *, InstrList> &Chain : Map) {
+  for (const std::pair<ChainID, InstrList> &Chain : Map) {
     unsigned Size = Chain.second.size();
     if (Size < 2)
       continue;
diff --git a/test/Transforms/LoadStoreVectorizer/AMDGPU/selects.ll b/test/Transforms/LoadStoreVectorizer/AMDGPU/selects.ll
new file mode 100644
index 00000000000..32fe5eb9ce2
--- /dev/null
+++ b/test/Transforms/LoadStoreVectorizer/AMDGPU/selects.ll
@@ -0,0 +1,95 @@
+; RUN: opt -mtriple=amdgcn-amd-amdhsa -load-store-vectorizer -dce -S -o - %s | FileCheck %s
+
+target datalayout = "e-p:32:32-p1:64:64-p2:64:64-p3:32:32-p4:64:64-p5:32:32-p24:64:64-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64"
+
+define void @base_case(i1 %cnd, i32 addrspace(1)* %a, i32 addrspace(1)* %b, <3 x i32> addrspace(1)* %out) {
+; CHECK-LABEL: @base_case
+; CHECK: load <3 x i32>
+entry:
+  %gep1 = getelementptr inbounds i32, i32 addrspace(1)* %a, i64 1
+  %gep2 = getelementptr inbounds i32, i32 addrspace(1)* %a, i64 2
+  %gep4 = getelementptr inbounds i32, i32 addrspace(1)* %b, i64 1
+  %gep5 = getelementptr inbounds i32, i32 addrspace(1)* %b, i64 2
+  %selected = select i1 %cnd, i32 addrspace(1)* %a, i32 addrspace(1)* %b
+  %selected14 = select i1 %cnd, i32 addrspace(1)* %gep1, i32 addrspace(1)* %gep4
+  %selected25 = select i1 %cnd, i32 addrspace(1)* %gep2, i32 addrspace(1)* %gep5
+  %val0 = load i32, i32 addrspace(1)* %selected, align 4
+  %val1 = load i32, i32 addrspace(1)* %selected14, align 4
+  %val2 = load i32, i32 addrspace(1)* %selected25, align 4
+  %t0 = insertelement <3 x i32> undef, i32 %val0, i32 0
+  %t1 = insertelement <3 x i32> %t0, i32 %val1, i32 1
+  %t2 = insertelement <3 x i32> %t1, i32 %val2, i32 2
+  store <3 x i32> %t2, <3 x i32> addrspace(1)* %out
+  ret void
+}
+
+define void @scev_targeting_complex_case(i1 %cnd, i32 addrspace(1)* %a, i32 addrspace(1)* %b, i32 %base, <2 x i32> addrspace(1)* %out) {
+; CHECK-LABEL: @scev_targeting_complex_case
+; CHECK: load <2 x i32>
+entry:
+  %base.x4 = shl i32 %base, 2
+  %base.x4.p1 = add i32 %base.x4, 1
+  %base.x4.p2 = add i32 %base.x4, 2
+  %base.x4.p3 = add i32 %base.x4, 3
+  %zext.x4 = zext i32 %base.x4 to i64
+  %zext.x4.p1 = zext i32 %base.x4.p1 to i64
+  %zext.x4.p2 = zext i32 %base.x4.p2 to i64
+  %zext.x4.p3 = zext i32 %base.x4.p3 to i64
+  %base.x16 = mul i64 %zext.x4, 4
+  %base.x16.p4 = shl i64 %zext.x4.p1, 2
+  %base.x16.p8 = shl i64 %zext.x4.p2, 2
+  %base.x16.p12 = mul i64 %zext.x4.p3, 4
+  %a.pi8 = bitcast i32 addrspace(1)* %a to i8 addrspace(1)*
+  %b.pi8 = bitcast i32 addrspace(1)* %b to i8 addrspace(1)*
+  %gep.a.base.x16 = getelementptr inbounds i8, i8 addrspace(1)* %a.pi8, i64 %base.x16
+  %gep.b.base.x16.p4 = getelementptr inbounds i8, i8 addrspace(1)* %b.pi8, i64 %base.x16.p4
+  %gep.a.base.x16.p8 = getelementptr inbounds i8, i8 addrspace(1)* %a.pi8, i64 %base.x16.p8
+  %gep.b.base.x16.p12 = getelementptr inbounds i8, i8 addrspace(1)* %b.pi8, i64 %base.x16.p12
+  %a.base.x16 = bitcast i8 addrspace(1)* %gep.a.base.x16 to i32 addrspace(1)*
+  %b.base.x16.p4 = bitcast i8 addrspace(1)* %gep.b.base.x16.p4 to i32 addrspace(1)*
+  %selected.base.x16.p0.or.4 = select i1 %cnd, i32 addrspace(1)* %a.base.x16, i32 addrspace(1)* %b.base.x16.p4
+  %gep.selected.base.x16.p8.or.12 = select i1 %cnd, i8 addrspace(1)* %gep.a.base.x16.p8, i8 addrspace(1)* %gep.b.base.x16.p12
+  %selected.base.x16.p8.or.12 = bitcast i8 addrspace(1)* %gep.selected.base.x16.p8.or.12 to i32 addrspace(1)*
+  %selected.base.x16.p40.or.44 = getelementptr inbounds i32, i32 addrspace(1)* %selected.base.x16.p0.or.4, i64 10
+  %selected.base.x16.p44.or.48 = getelementptr inbounds i32, i32 addrspace(1)* %selected.base.x16.p8.or.12, i64 9
+  %val0 = load i32, i32 addrspace(1)* %selected.base.x16.p40.or.44, align 4
+  %val1 = load i32, i32 addrspace(1)* %selected.base.x16.p44.or.48, align 4
+  %t0 = insertelement <2 x i32> undef, i32 %val0, i32 0
+  %t1 = insertelement <2 x i32> %t0, i32 %val1, i32 1
+  store <2 x i32> %t1, <2 x i32> addrspace(1)* %out
+  ret void
+}
+
+define void @nested_selects(i1 %cnd0, i1 %cnd1, i32 addrspace(1)* %a, i32 addrspace(1)* %b, i32 %base, <2 x i32> addrspace(1)* %out) {
+; CHECK-LABEL: @nested_selects
+; CHECK: load <2 x i32>
+entry:
+  %base.p1 = add nsw i32 %base, 1
+  %base.p2 = add i32 %base, 2
+  %base.p3 = add nsw i32 %base, 3
+  %base.x4 = mul i32 %base, 4
+  %base.x4.p5 = add i32 %base.x4, 5
+  %base.x4.p6 = add i32 %base.x4, 6
+  %sext = sext i32 %base to i64
+  %sext.p1 = sext i32 %base.p1 to i64
+  %sext.p2 = sext i32 %base.p2 to i64
+  %sext.p3 = sext i32 %base.p3 to i64
+  %sext.x4.p5 = sext i32 %base.x4.p5 to i64
+  %sext.x4.p6 = sext i32 %base.x4.p6 to i64
+  %gep.a.base = getelementptr inbounds i32, i32 addrspace(1)* %a, i64 %sext
+  %gep.a.base.p1 = getelementptr inbounds i32, i32 addrspace(1)* %a, i64 %sext.p1
+  %gep.a.base.p2 = getelementptr inbounds i32, i32 addrspace(1)* %a, i64 %sext.p2
+  %gep.a.base.p3 = getelementptr inbounds i32, i32 addrspace(1)* %a, i64 %sext.p3
+  %gep.b.base.x4.p5 = getelementptr inbounds i32, i32 addrspace(1)* %a, i64 %sext.x4.p5
+  %gep.b.base.x4.p6 = getelementptr inbounds i32, i32 addrspace(1)* %a, i64 %sext.x4.p6
+  %selected.1.L = select i1 %cnd1, i32 addrspace(1)* %gep.a.base.p2, i32 addrspace(1)* %gep.b.base.x4.p5
+  %selected.1.R = select i1 %cnd1, i32 addrspace(1)* %gep.a.base.p3, i32 addrspace(1)* %gep.b.base.x4.p6
+  %selected.0.L = select i1 %cnd0, i32 addrspace(1)* %gep.a.base, i32 addrspace(1)* %selected.1.L
+  %selected.0.R = select i1 %cnd0, i32 addrspace(1)* %gep.a.base.p1, i32 addrspace(1)* %selected.1.R
+  %val0 = load i32, i32 addrspace(1)* %selected.0.L, align 4
+  %val1 = load i32, i32 addrspace(1)* %selected.0.R, align 4
+  %t0 = insertelement <2 x i32> undef, i32 %val0, i32 0
+  %t1 = insertelement <2 x i32> %t0, i32 %val1, i32 1
+  store <2 x i32> %t1, <2 x i32> addrspace(1)* %out
+  ret void
+}
-- 
2.11.0