OSDN Git Service

InferAddressSpaces: Handle ptrmask intrinsic
authorMatt Arsenault <Matthew.Arsenault@amd.com>
Fri, 15 May 2020 18:54:51 +0000 (14:54 -0400)
committerMatt Arsenault <Matthew.Arsenault@amd.com>
Thu, 28 May 2020 14:04:02 +0000 (10:04 -0400)
This one is slightly odd since it counts as an address expression,
which previously could never fail. Allow the existing TTI hook to
return the value to use, and re-use it for handling how to handle
ptrmask.

Handles the no-op addrspacecasts for AMDGPU. We could probably do
something better based on analysis of the mask value based on the
address space, but leave that for now.

llvm/include/llvm/Analysis/TargetTransformInfo.h
llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
llvm/include/llvm/CodeGen/BasicTTIImpl.h
llvm/lib/Analysis/TargetTransformInfo.cpp
llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll

index c50c696..51aa1cb 100644 (file)
@@ -379,9 +379,10 @@ public:
   /// Rewrite intrinsic call \p II such that \p OldV will be replaced with \p
   /// NewV, which has a different address space. This should happen for every
   /// operand index that collectFlatAddressOperands returned for the intrinsic.
-  /// \returns true if the intrinsic /// was handled.
-  bool rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
-                                        Value *NewV) const;
+  /// \returns nullptr if the intrinsic was not handled. Otherwise, returns the
+  /// new value (which may be the original \p II with modified operands).
+  Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
+                                          Value *NewV) const;
 
   /// Test whether calls to a function lower to actual program function
   /// calls.
@@ -1236,8 +1237,9 @@ public:
   virtual unsigned getFlatAddressSpace() = 0;
   virtual bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
                                           Intrinsic::ID IID) const = 0;
-  virtual bool rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
-                                                Value *NewV) const = 0;
+  virtual Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
+                                                  Value *OldV,
+                                                  Value *NewV) const = 0;
   virtual bool isLoweredToCall(const Function *F) = 0;
   virtual void getUnrollingPreferences(Loop *L, ScalarEvolution &,
                                        UnrollingPreferences &UP) = 0;
@@ -1505,8 +1507,8 @@ public:
     return Impl.collectFlatAddressOperands(OpIndexes, IID);
   }
 
-  bool rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
-                                        Value *NewV) const override {
+  Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
+                                          Value *NewV) const override {
     return Impl.rewriteIntrinsicWithAddressSpace(II, OldV, NewV);
   }
 
index f98b8bf..0e8fc5d 100644 (file)
@@ -86,9 +86,9 @@ public:
     return false;
   }
 
-  bool rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
-                                        Value *NewV) const {
-    return false;
+  Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
+                                          Value *NewV) const {
+    return nullptr;
   }
 
   bool isLoweredToCall(const Function *F) {
index cc751a5..c751c37 100644 (file)
@@ -222,9 +222,9 @@ public:
     return false;
   }
 
-  bool rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
-                                        Value *OldV, Value *NewV) const {
-    return false;
+  Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
+                                          Value *NewV) const {
+    return nullptr;
   }
 
   bool isLegalAddImmediate(int64_t imm) {
index 9f319c4..0c34050 100644 (file)
@@ -290,9 +290,8 @@ bool TargetTransformInfo::collectFlatAddressOperands(
   return TTIImpl->collectFlatAddressOperands(OpIndexes, IID);
 }
 
-bool TargetTransformInfo::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
-                                                           Value *OldV,
-                                                           Value *NewV) const {
+Value *TargetTransformInfo::rewriteIntrinsicWithAddressSpace(
+    IntrinsicInst *II, Value *OldV, Value *NewV) const {
   return TTIImpl->rewriteIntrinsicWithAddressSpace(II, OldV, NewV);
 }
 
index 2405a24..324dcba 100644 (file)
@@ -851,8 +851,9 @@ bool GCNTTIImpl::collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
   }
 }
 
-bool GCNTTIImpl::rewriteIntrinsicWithAddressSpace(
-  IntrinsicInst *II, Value *OldV, Value *NewV) const {
+Value *GCNTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
+                                                    Value *OldV,
+                                                    Value *NewV) const {
   auto IntrID = II->getIntrinsicID();
   switch (IntrID) {
   case Intrinsic::amdgcn_atomic_inc:
@@ -862,7 +863,7 @@ bool GCNTTIImpl::rewriteIntrinsicWithAddressSpace(
   case Intrinsic::amdgcn_ds_fmax: {
     const ConstantInt *IsVolatile = cast<ConstantInt>(II->getArgOperand(4));
     if (!IsVolatile->isZero())
-      return false;
+      return nullptr;
     Module *M = II->getParent()->getParent()->getParent();
     Type *DestTy = II->getType();
     Type *SrcTy = NewV->getType();
@@ -870,7 +871,7 @@ bool GCNTTIImpl::rewriteIntrinsicWithAddressSpace(
         Intrinsic::getDeclaration(M, II->getIntrinsicID(), {DestTy, SrcTy});
     II->setArgOperand(0, NewV);
     II->setCalledFunction(NewDecl);
-    return true;
+    return II;
   }
   case Intrinsic::amdgcn_is_shared:
   case Intrinsic::amdgcn_is_private: {
@@ -880,12 +881,25 @@ bool GCNTTIImpl::rewriteIntrinsicWithAddressSpace(
     LLVMContext &Ctx = NewV->getType()->getContext();
     ConstantInt *NewVal = (TrueAS == NewAS) ?
       ConstantInt::getTrue(Ctx) : ConstantInt::getFalse(Ctx);
-    II->replaceAllUsesWith(NewVal);
-    II->eraseFromParent();
-    return true;
+    return NewVal;
+  }
+  case Intrinsic::ptrmask: {
+    unsigned OldAS = OldV->getType()->getPointerAddressSpace();
+    unsigned NewAS = NewV->getType()->getPointerAddressSpace();
+    if (!getTLI()->isNoopAddrSpaceCast(OldAS, NewAS))
+      return nullptr;
+
+    Module *M = II->getParent()->getParent()->getParent();
+    Value *MaskOp = II->getArgOperand(1);
+    Type *MaskTy = MaskOp->getType();
+    Function *NewDecl = Intrinsic::getDeclaration(M, Intrinsic::ptrmask,
+                                                  {NewV->getType(), MaskTy});
+    CallInst *NewCall = CallInst::Create(NewDecl->getFunctionType(), NewDecl,
+                                         {NewV, MaskOp}, "", II);
+    return NewCall;
   }
   default:
-    return false;
+    return nullptr;
   }
 }
 
index bc74965..72c040f 100644 (file)
@@ -211,8 +211,8 @@ public:
 
   bool collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
                                   Intrinsic::ID IID) const;
-  bool rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
-                                        Value *OldV, Value *NewV) const;
+  Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
+                                          Value *NewV) const;
 
   unsigned getVectorSplitCost() { return 0; }
 
index 6fd6b84..d407d04 100644 (file)
@@ -175,6 +175,11 @@ private:
 
   bool isSafeToCastConstAddrSpace(Constant *C, unsigned NewAS) const;
 
+  Value *cloneInstructionWithNewAddressSpace(
+      Instruction *I, unsigned NewAddrSpace,
+      const ValueToValueMapTy &ValueWithNewAddrSpace,
+      SmallVectorImpl<const Use *> *UndefUsesToFix) const;
+
   // Changes the flat address expressions in function F to point to specific
   // address spaces if InferredAddrSpace says so. Postorder is the postorder of
   // all flat expressions in the use-def graph of function F.
@@ -218,20 +223,24 @@ INITIALIZE_PASS(InferAddressSpaces, DEBUG_TYPE, "Infer address spaces",
 // TODO: Currently, we consider only phi, bitcast, addrspacecast, and
 // getelementptr operators.
 static bool isAddressExpression(const Value &V) {
-  if (!isa<Operator>(V))
+  const Operator *Op = dyn_cast<Operator>(&V);
+  if (!Op)
     return false;
 
-  const Operator &Op = cast<Operator>(V);
-  switch (Op.getOpcode()) {
+  switch (Op->getOpcode()) {
   case Instruction::PHI:
-    assert(Op.getType()->isPointerTy());
+    assert(Op->getType()->isPointerTy());
     return true;
   case Instruction::BitCast:
   case Instruction::AddrSpaceCast:
   case Instruction::GetElementPtr:
     return true;
   case Instruction::Select:
-    return Op.getType()->isPointerTy();
+    return Op->getType()->isPointerTy();
+  case Instruction::Call: {
+    const IntrinsicInst *II = dyn_cast<IntrinsicInst>(&V);
+    return II && II->getIntrinsicID() == Intrinsic::ptrmask;
+  }
   default:
     return false;
   }
@@ -254,12 +263,17 @@ static SmallVector<Value *, 2> getPointerOperands(const Value &V) {
     return {Op.getOperand(0)};
   case Instruction::Select:
     return {Op.getOperand(1), Op.getOperand(2)};
+  case Instruction::Call: {
+    const IntrinsicInst &II = cast<IntrinsicInst>(Op);
+    assert(II.getIntrinsicID() == Intrinsic::ptrmask &&
+           "unexpected intrinsic call");
+    return {II.getArgOperand(0)};
+  }
   default:
     llvm_unreachable("Unexpected instruction type.");
   }
 }
 
-// TODO: Move logic to TTI?
 bool InferAddressSpaces::rewriteIntrinsicOperands(IntrinsicInst *II,
                                                   Value *OldV,
                                                   Value *NewV) const {
@@ -275,8 +289,17 @@ bool InferAddressSpaces::rewriteIntrinsicOperands(IntrinsicInst *II,
     II->setCalledFunction(NewDecl);
     return true;
   }
-  default:
-    return TTI->rewriteIntrinsicWithAddressSpace(II, OldV, NewV);
+  case Intrinsic::ptrmask:
+    // This is handled as an address expression, not as a use memory operation.
+    return false;
+  default: {
+    Value *Rewrite = TTI->rewriteIntrinsicWithAddressSpace(II, OldV, NewV);
+    if (!Rewrite)
+      return false;
+    if (Rewrite != II)
+      II->replaceAllUsesWith(Rewrite);
+    return true;
+  }
   }
 }
 
@@ -285,6 +308,7 @@ void InferAddressSpaces::collectRewritableIntrinsicOperands(
     DenseSet<Value *> &Visited) const {
   auto IID = II->getIntrinsicID();
   switch (IID) {
+  case Intrinsic::ptrmask:
   case Intrinsic::objectsize:
     appendsFlatAddressExpressionToPostorderStack(II->getArgOperand(0),
                                                  PostorderStack, Visited);
@@ -438,10 +462,13 @@ static Value *operandWithNewAddressSpaceOrCreateUndef(
 // Note that we do not necessarily clone `I`, e.g., if it is an addrspacecast
 // from a pointer whose type already matches. Therefore, this function returns a
 // Value* instead of an Instruction*.
-static Value *cloneInstructionWithNewAddressSpace(
+//
+// This may also return nullptr in the case the instruction could not be
+// rewritten.
+Value *InferAddressSpaces::cloneInstructionWithNewAddressSpace(
     Instruction *I, unsigned NewAddrSpace,
     const ValueToValueMapTy &ValueWithNewAddrSpace,
-    SmallVectorImpl<const Use *> *UndefUsesToFix) {
+    SmallVectorImpl<const Use *> *UndefUsesToFix) const {
   Type *NewPtrType =
       I->getType()->getPointerElementType()->getPointerTo(NewAddrSpace);
 
@@ -456,6 +483,23 @@ static Value *cloneInstructionWithNewAddressSpace(
     return Src;
   }
 
+  if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
+    // Technically the intrinsic ID is a pointer typed argument, so specially
+    // handle calls early.
+    assert(II->getIntrinsicID() == Intrinsic::ptrmask);
+    Value *NewPtr = operandWithNewAddressSpaceOrCreateUndef(
+        II->getArgOperandUse(0), NewAddrSpace, ValueWithNewAddrSpace,
+        UndefUsesToFix);
+    Value *Rewrite =
+        TTI->rewriteIntrinsicWithAddressSpace(II, II->getArgOperand(0), NewPtr);
+    if (Rewrite) {
+      assert(Rewrite != II && "cannot modify this pointer operation in place");
+      return Rewrite;
+    }
+
+    return nullptr;
+  }
+
   // Computes the converted pointer operands.
   SmallVector<Value *, 4> NewPointerOperands;
   for (const Use &OperandUse : I->operands()) {
@@ -591,7 +635,7 @@ Value *InferAddressSpaces::cloneValueWithNewAddressSpace(
   if (Instruction *I = dyn_cast<Instruction>(V)) {
     Value *NewV = cloneInstructionWithNewAddressSpace(
       I, NewAddrSpace, ValueWithNewAddrSpace, UndefUsesToFix);
-    if (Instruction *NewI = dyn_cast<Instruction>(NewV)) {
+    if (Instruction *NewI = dyn_cast_or_null<Instruction>(NewV)) {
       if (NewI->getParent() == nullptr) {
         NewI->insertBefore(I);
         NewI->takeName(I);
@@ -879,8 +923,10 @@ bool InferAddressSpaces::rewriteWithNewAddressSpaces(
   for (Value* V : Postorder) {
     unsigned NewAddrSpace = InferredAddrSpace.lookup(V);
     if (V->getType()->getPointerAddressSpace() != NewAddrSpace) {
-      ValueWithNewAddrSpace[V] = cloneValueWithNewAddressSpace(
-        V, NewAddrSpace, ValueWithNewAddrSpace, &UndefUsesToFix);
+      Value *New = cloneValueWithNewAddressSpace(
+          V, NewAddrSpace, ValueWithNewAddrSpace, &UndefUsesToFix);
+      if (New)
+        ValueWithNewAddrSpace[V] = New;
     }
   }
 
@@ -890,7 +936,10 @@ bool InferAddressSpaces::rewriteWithNewAddressSpaces(
   // Fixes all the undef uses generated by cloneInstructionWithNewAddressSpace.
   for (const Use *UndefUse : UndefUsesToFix) {
     User *V = UndefUse->getUser();
-    User *NewV = cast<User>(ValueWithNewAddrSpace.lookup(V));
+    User *NewV = cast_or_null<User>(ValueWithNewAddrSpace.lookup(V));
+    if (!NewV)
+      continue;
+
     unsigned OperandNo = UndefUse->getOperandNo();
     assert(isa<UndefValue>(NewV->getOperand(OperandNo)));
     NewV->setOperand(OperandNo, ValueWithNewAddrSpace.lookup(UndefUse->get()));
index 40daa48..6f25a3e 100644 (file)
@@ -42,9 +42,8 @@ define i8 @ptrmask_cast_region_to_flat(i8 addrspace(2)* %src.ptr, i64 %mask) {
 
 define i8 @ptrmask_cast_global_to_flat(i8 addrspace(1)* %src.ptr, i64 %mask) {
 ; CHECK-LABEL: @ptrmask_cast_global_to_flat(
-; CHECK-NEXT:    [[CAST:%.*]] = addrspacecast i8 addrspace(1)* [[SRC_PTR:%.*]] to i8*
-; CHECK-NEXT:    [[MASKED:%.*]] = call i8* @llvm.ptrmask.p0i8.i64(i8* [[CAST]], i64 [[MASK:%.*]])
-; CHECK-NEXT:    [[LOAD:%.*]] = load i8, i8* [[MASKED]], align 1
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 addrspace(1)* @llvm.ptrmask.p1i8.i64(i8 addrspace(1)* [[SRC_PTR:%.*]], i64 [[MASK:%.*]])
+; CHECK-NEXT:    [[LOAD:%.*]] = load i8, i8 addrspace(1)* [[TMP1]], align 1
 ; CHECK-NEXT:    ret i8 [[LOAD]]
 ;
   %cast = addrspacecast i8 addrspace(1)* %src.ptr to i8*
@@ -55,9 +54,8 @@ define i8 @ptrmask_cast_global_to_flat(i8 addrspace(1)* %src.ptr, i64 %mask) {
 
 define i8 @ptrmask_cast_999_to_flat(i8 addrspace(999)* %src.ptr, i64 %mask) {
 ; CHECK-LABEL: @ptrmask_cast_999_to_flat(
-; CHECK-NEXT:    [[CAST:%.*]] = addrspacecast i8 addrspace(999)* [[SRC_PTR:%.*]] to i8*
-; CHECK-NEXT:    [[MASKED:%.*]] = call i8* @llvm.ptrmask.p0i8.i64(i8* [[CAST]], i64 [[MASK:%.*]])
-; CHECK-NEXT:    [[LOAD:%.*]] = load i8, i8* [[MASKED]], align 1
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 addrspace(999)* @llvm.ptrmask.p999i8.i64(i8 addrspace(999)* [[SRC_PTR:%.*]], i64 [[MASK:%.*]])
+; CHECK-NEXT:    [[LOAD:%.*]] = load i8, i8 addrspace(999)* [[TMP1]], align 1
 ; CHECK-NEXT:    ret i8 [[LOAD]]
 ;
   %cast = addrspacecast i8 addrspace(999)* %src.ptr to i8*
@@ -121,8 +119,8 @@ define i8 @ptrmask_cast_local_to_flat_global(i64 %mask) {
 
 define i8 @ptrmask_cast_global_to_flat_global(i64 %mask) {
 ; CHECK-LABEL: @ptrmask_cast_global_to_flat_global(
-; CHECK-NEXT:    [[MASKED:%.*]] = call i8* @llvm.ptrmask.p0i8.i64(i8* addrspacecast (i8 addrspace(1)* @gv to i8*), i64 [[MASK:%.*]])
-; CHECK-NEXT:    [[LOAD:%.*]] = load i8, i8* [[MASKED]], align 1
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 addrspace(1)* @llvm.ptrmask.p1i8.i64(i8 addrspace(1)* @gv, i64 [[MASK:%.*]])
+; CHECK-NEXT:    [[LOAD:%.*]] = load i8, i8 addrspace(1)* [[TMP1]], align 1
 ; CHECK-NEXT:    ret i8 [[LOAD]]
 ;
   %masked = call i8* @llvm.ptrmask.p0i8.i64(i8* addrspacecast (i8 addrspace(1)* @gv to i8*), i64 %mask)
@@ -132,10 +130,9 @@ define i8 @ptrmask_cast_global_to_flat_global(i64 %mask) {
 
 define i8 @multi_ptrmask_cast_global_to_flat(i8 addrspace(1)* %src.ptr, i64 %mask) {
 ; CHECK-LABEL: @multi_ptrmask_cast_global_to_flat(
-; CHECK-NEXT:    [[CAST:%.*]] = addrspacecast i8 addrspace(1)* [[SRC_PTR:%.*]] to i8*
-; CHECK-NEXT:    [[LOAD0:%.*]] = load i8, i8 addrspace(1)* [[SRC_PTR]], align 1
-; CHECK-NEXT:    [[MASKED:%.*]] = call i8* @llvm.ptrmask.p0i8.i64(i8* [[CAST]], i64 [[MASK:%.*]])
-; CHECK-NEXT:    [[LOAD1:%.*]] = load i8, i8* [[MASKED]], align 1
+; CHECK-NEXT:    [[LOAD0:%.*]] = load i8, i8 addrspace(1)* [[SRC_PTR:%.*]], align 1
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 addrspace(1)* @llvm.ptrmask.p1i8.i64(i8 addrspace(1)* [[SRC_PTR]], i64 [[MASK:%.*]])
+; CHECK-NEXT:    [[LOAD1:%.*]] = load i8, i8 addrspace(1)* [[TMP1]], align 1
 ; CHECK-NEXT:    [[ADD:%.*]] = add i8 [[LOAD0]], [[LOAD1]]
 ; CHECK-NEXT:    ret i8 [[ADD]]
 ;