OSDN Git Service

[AMDGPU] Add support for TFE/LWE in image intrinsics. 2nd try
[android-x86/external-llvm.git] / lib / Target / AMDGPU / SIISelLowering.cpp
index 6374792..9e16737 100644 (file)
@@ -216,6 +216,7 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
 
   setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::v2f16, Custom);
   setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::v4f16, Custom);
+  setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::v8f16, Custom);
   setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
 
   setOperationAction(ISD::INTRINSIC_VOID, MVT::Other, Custom);
@@ -813,6 +814,47 @@ unsigned SITargetLowering::getVectorTypeBreakdownForCallingConv(
     Context, CC, VT, IntermediateVT, NumIntermediates, RegisterVT);
 }
 
+static MVT memVTFromAggregate(Type *Ty) {
+  // Only limited forms of aggregate type currently expected.
+  assert(Ty->isStructTy() && "Expected struct type");
+
+
+  Type *ElementType = nullptr;
+  unsigned NumElts;
+  if (Ty->getContainedType(0)->isVectorTy()) {
+    VectorType *VecComponent = cast<VectorType>(Ty->getContainedType(0));
+    ElementType = VecComponent->getElementType();
+    NumElts = VecComponent->getNumElements();
+  } else {
+    ElementType = Ty->getContainedType(0);
+    NumElts = 1;
+  }
+
+  assert((Ty->getContainedType(1) && Ty->getContainedType(1)->isIntegerTy(32)) && "Expected int32 type");
+
+  // Calculate the size of the memVT type from the aggregate
+  unsigned Pow2Elts = 0;
+  unsigned ElementSize;
+  switch (ElementType->getTypeID()) {
+    default:
+      llvm_unreachable("Unknown type!");
+    case Type::IntegerTyID:
+      ElementSize = cast<IntegerType>(ElementType)->getBitWidth();
+      break;
+    case Type::HalfTyID:
+      ElementSize = 16;
+      break;
+    case Type::FloatTyID:
+      ElementSize = 32;
+      break;
+  }
+  unsigned AdditionalElts = ElementSize == 16 ? 2 : 1;
+  Pow2Elts = 1 << Log2_32_Ceil(NumElts + AdditionalElts);
+
+  return MVT::getVectorVT(MVT::getVT(ElementType, false),
+                          Pow2Elts);
+}
+
 bool SITargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
                                           const CallInst &CI,
                                           MachineFunction &MF,
@@ -840,7 +882,12 @@ bool SITargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
     Info.flags = MachineMemOperand::MODereferenceable;
     if (Attr.hasFnAttribute(Attribute::ReadOnly)) {
       Info.opc = ISD::INTRINSIC_W_CHAIN;
-      Info.memVT = MVT::getVT(CI.getType());
+      Info.memVT = MVT::getVT(CI.getType(), true);
+      if (Info.memVT == MVT::Other) {
+        // Some intrinsics return an aggregate type - special case to work out
+        // the correct memVT
+        Info.memVT = memVTFromAggregate(CI.getType());
+      }
       Info.flags |= MachineMemOperand::MOLoad;
     } else if (Attr.hasFnAttribute(Attribute::WriteOnly)) {
       Info.opc = ISD::INTRINSIC_VOID;
@@ -4613,6 +4660,109 @@ static bool parseCachePolicy(SDValue CachePolicy, SelectionDAG &DAG,
   return Value == 0;
 }
 
+// Re-construct the required return value for a image load intrinsic.
+// This is more complicated due to the optional use TexFailCtrl which means the required
+// return type is an aggregate
+static SDValue constructRetValue(SelectionDAG &DAG,
+                                 MachineSDNode *Result,
+                                 ArrayRef<EVT> ResultTypes,
+                                 bool IsTexFail, bool Unpacked, bool IsD16,
+                                 int DMaskPop, int NumVDataDwords,
+                                 const SDLoc &DL, LLVMContext &Context) {
+  // Determine the required return type. This is the same regardless of IsTexFail flag
+  EVT ReqRetVT = ResultTypes[0];
+  EVT ReqRetEltVT = ReqRetVT.isVector() ? ReqRetVT.getVectorElementType() : ReqRetVT;
+  int ReqRetNumElts = ReqRetVT.isVector() ? ReqRetVT.getVectorNumElements() : 1;
+  EVT AdjEltVT = Unpacked && IsD16 ? MVT::i32 : ReqRetEltVT;
+  EVT AdjVT = Unpacked ? ReqRetNumElts > 1 ? EVT::getVectorVT(Context, AdjEltVT, ReqRetNumElts)
+                                           : AdjEltVT
+                       : ReqRetVT;
+
+  // Extract data part of the result
+  // Bitcast the result to the same type as the required return type
+  int NumElts;
+  if (IsD16 && !Unpacked)
+    NumElts = NumVDataDwords << 1;
+  else
+    NumElts = NumVDataDwords;
+
+  EVT CastVT = NumElts > 1 ? EVT::getVectorVT(Context, AdjEltVT, NumElts)
+                           : AdjEltVT;
+
+  // Special case for v8f16. Rather than add support for this, use v4i32 to
+  // extract the data elements
+  bool V8F16Special = false;
+  if (CastVT == MVT::v8f16) {
+    CastVT = MVT::v4i32;
+    DMaskPop >>= 1;
+    ReqRetNumElts >>= 1;
+    V8F16Special = true;
+    AdjVT = MVT::v2i32;
+  }
+
+  SDValue N = SDValue(Result, 0);
+  SDValue CastRes = DAG.getNode(ISD::BITCAST, DL, CastVT, N);
+
+  // Iterate over the result
+  SmallVector<SDValue, 4> BVElts;
+
+  if (CastVT.isVector()) {
+    DAG.ExtractVectorElements(CastRes, BVElts, 0, DMaskPop);
+  } else {
+    BVElts.push_back(CastRes);
+  }
+  int ExtraElts = ReqRetNumElts - DMaskPop;
+  while(ExtraElts--)
+    BVElts.push_back(DAG.getUNDEF(AdjEltVT));
+
+  SDValue PreTFCRes;
+  if (ReqRetNumElts > 1) {
+    SDValue NewVec = DAG.getBuildVector(AdjVT, DL, BVElts);
+    if (IsD16 && Unpacked)
+      PreTFCRes = adjustLoadValueTypeImpl(NewVec, ReqRetVT, DL, DAG, Unpacked);
+    else
+      PreTFCRes = NewVec;
+  } else {
+    PreTFCRes = BVElts[0];
+  }
+
+  if (V8F16Special)
+    PreTFCRes = DAG.getNode(ISD::BITCAST, DL, MVT::v4f16, PreTFCRes);
+
+  if (!IsTexFail) {
+    if (Result->getNumValues() > 1)
+      return DAG.getMergeValues({PreTFCRes, SDValue(Result, 1)}, DL);
+    else
+      return PreTFCRes;
+  }
+
+  // Extract the TexFail result and insert into aggregate return
+  SmallVector<SDValue, 1> TFCElt;
+  DAG.ExtractVectorElements(N, TFCElt, DMaskPop, 1);
+  SDValue TFCRes = DAG.getNode(ISD::BITCAST, DL, ResultTypes[1], TFCElt[0]);
+  return DAG.getMergeValues({PreTFCRes, TFCRes, SDValue(Result, 1)}, DL);
+}
+
+static bool parseTexFail(SDValue TexFailCtrl, SelectionDAG &DAG, SDValue *TFE,
+                         SDValue *LWE, bool &IsTexFail) {
+  auto TexFailCtrlConst = dyn_cast<ConstantSDNode>(TexFailCtrl.getNode());
+  if (!TexFailCtrlConst)
+    return false;
+
+  uint64_t Value = TexFailCtrlConst->getZExtValue();
+  if (Value) {
+    IsTexFail = true;
+  }
+
+  SDLoc DL(TexFailCtrlConst);
+  *TFE = DAG.getTargetConstant((Value & 0x1) ? 1 : 0, DL, MVT::i32);
+  Value &= ~(uint64_t)0x1;
+  *LWE = DAG.getTargetConstant((Value & 0x2) ? 1 : 0, DL, MVT::i32);
+  Value &= ~(uint64_t)0x2;
+
+  return Value == 0;
+}
+
 SDValue SITargetLowering::lowerImage(SDValue Op,
                                      const AMDGPU::ImageDimIntrinsicInfo *Intr,
                                      SelectionDAG &DAG) const {
@@ -4626,13 +4776,17 @@ SDValue SITargetLowering::lowerImage(SDValue Op,
       AMDGPU::getMIMGLZMappingInfo(Intr->BaseOpcode);
   unsigned IntrOpcode = Intr->BaseOpcode;
 
-  SmallVector<EVT, 2> ResultTypes(Op->value_begin(), Op->value_end());
+  SmallVector<EVT, 3> ResultTypes(Op->value_begin(), Op->value_end());
+  SmallVector<EVT, 3> OrigResultTypes(Op->value_begin(), Op->value_end());
   bool IsD16 = false;
   bool IsA16 = false;
   SDValue VData;
   int NumVDataDwords;
+  bool AdjustRetType = false;
+
   unsigned AddrIdx; // Index of first address argument
   unsigned DMask;
+  unsigned DMaskLanes = 0;
 
   if (BaseOpcode->Atomic) {
     VData = Op.getOperand(2);
@@ -4655,7 +4809,12 @@ SDValue SITargetLowering::lowerImage(SDValue Op,
       AddrIdx = 3;
     }
   } else {
-    unsigned DMaskIdx;
+    unsigned DMaskIdx = BaseOpcode->Store ? 3 : isa<MemSDNode>(Op) ? 2 : 1;
+    auto DMaskConst = dyn_cast<ConstantSDNode>(Op.getOperand(DMaskIdx));
+    if (!DMaskConst)
+      return Op;
+    DMask = DMaskConst->getZExtValue();
+    DMaskLanes = BaseOpcode->Gather4 ? 4 : countPopulation(DMask);
 
     if (BaseOpcode->Store) {
       VData = Op.getOperand(2);
@@ -4671,37 +4830,32 @@ SDValue SITargetLowering::lowerImage(SDValue Op,
       }
 
       NumVDataDwords = (VData.getValueType().getSizeInBits() + 31) / 32;
-      DMaskIdx = 3;
     } else {
-      MVT LoadVT = Op.getSimpleValueType();
+      // Work out the num dwords based on the dmask popcount and underlying type
+      // and whether packing is supported.
+      MVT LoadVT = ResultTypes[0].getSimpleVT();
       if (LoadVT.getScalarType() == MVT::f16) {
         if (Subtarget->getGeneration() < AMDGPUSubtarget::VOLCANIC_ISLANDS ||
             !BaseOpcode->HasD16)
           return Op; // D16 is unsupported for this instruction
 
         IsD16 = true;
-        if (LoadVT.isVector() && Subtarget->hasUnpackedD16VMem())
-          ResultTypes[0] = (LoadVT == MVT::v2f16) ? MVT::v2i32 : MVT::v4i32;
       }
 
-      NumVDataDwords = (ResultTypes[0].getSizeInBits() + 31) / 32;
-      DMaskIdx = isa<MemSDNode>(Op) ? 2 : 1;
-    }
+      // Confirm that the return type is large enough for the dmask specified
+      if ((LoadVT.isVector() && LoadVT.getVectorNumElements() < DMaskLanes) ||
+          (!LoadVT.isVector() && DMaskLanes > 1))
+          return Op;
 
-    auto DMaskConst = dyn_cast<ConstantSDNode>(Op.getOperand(DMaskIdx));
-    if (!DMaskConst)
-      return Op;
+      if (IsD16 && !Subtarget->hasUnpackedD16VMem())
+        NumVDataDwords = (DMaskLanes + 1) / 2;
+      else
+        NumVDataDwords = DMaskLanes;
 
-    AddrIdx = DMaskIdx + 1;
-    DMask = DMaskConst->getZExtValue();
-    if (!DMask && !BaseOpcode->Store) {
-      // Eliminate no-op loads. Stores with dmask == 0 are *not* no-op: they
-      // store the channels' default values.
-      SDValue Undef = DAG.getUNDEF(Op.getValueType());
-      if (isa<MemSDNode>(Op))
-        return DAG.getMergeValues({Undef, Op.getOperand(0)}, DL);
-      return Undef;
+      AdjustRetType = true;
     }
+
+    AddrIdx = DMaskIdx + 1;
   }
 
   unsigned NumGradients = BaseOpcode->Gradients ? DimInfo->NumGradients : 0;
@@ -4780,11 +4934,53 @@ SDValue SITargetLowering::lowerImage(SDValue Op,
     CtrlIdx = AddrIdx + NumVAddrs + 3;
   }
 
+  SDValue TFE;
+  SDValue LWE;
   SDValue TexFail = Op.getOperand(CtrlIdx);
-  auto TexFailConst = dyn_cast<ConstantSDNode>(TexFail.getNode());
-  if (!TexFailConst || TexFailConst->getZExtValue() != 0)
+  bool IsTexFail = false;
+  if (!parseTexFail(TexFail, DAG, &TFE, &LWE, IsTexFail))
     return Op;
 
+  if (IsTexFail) {
+    if (!DMaskLanes) {
+      // Expecting to get an error flag since TFC is on - and dmask is 0
+      // Force dmask to be at least 1 otherwise the instruction will fail
+      DMask = 0x1;
+      DMaskLanes = 1;
+      NumVDataDwords = 1;
+    }
+    NumVDataDwords += 1;
+    AdjustRetType = true;
+  }
+
+  // Has something earlier tagged that the return type needs adjusting
+  // This happens if the instruction is a load or has set TexFailCtrl flags
+  if (AdjustRetType) {
+    // NumVDataDwords reflects the true number of dwords required in the return type
+    if (DMaskLanes == 0 && !BaseOpcode->Store) {
+      // This is a no-op load. This can be eliminated
+      SDValue Undef = DAG.getUNDEF(Op.getValueType());
+      if (isa<MemSDNode>(Op))
+        return DAG.getMergeValues({Undef, Op.getOperand(0)}, DL);
+      return Undef;
+    }
+
+    // Have to use a power of 2 number of dwords
+    NumVDataDwords = 1 << Log2_32_Ceil(NumVDataDwords);
+
+    EVT NewVT = NumVDataDwords > 1 ?
+                  EVT::getVectorVT(*DAG.getContext(), MVT::f32, NumVDataDwords)
+                : MVT::f32;
+
+    ResultTypes[0] = NewVT;
+    if (ResultTypes.size() == 3) {
+      // Original result was aggregate type used for TexFailCtrl results
+      // The actual instruction returns as a vector type which has now been
+      // created. Remove the aggregate result.
+      ResultTypes.erase(&ResultTypes[1]);
+    }
+  }
+
   SDValue GLC;
   SDValue SLC;
   if (BaseOpcode->Atomic) {
@@ -4809,8 +5005,8 @@ SDValue SITargetLowering::lowerImage(SDValue Op,
   Ops.push_back(SLC);
   Ops.push_back(IsA16 &&  // a16 or r128
                 ST->hasFeature(AMDGPU::FeatureR128A16) ? True : False);
-  Ops.push_back(False); // tfe
-  Ops.push_back(False); // lwe
+  Ops.push_back(TFE); // tfe
+  Ops.push_back(LWE); // lwe
   Ops.push_back(DimInfo->DA ? True : False);
   if (BaseOpcode->HasD16)
     Ops.push_back(IsD16 ? True : False);
@@ -4838,11 +5034,12 @@ SDValue SITargetLowering::lowerImage(SDValue Op,
     SmallVector<SDValue, 1> Elt;
     DAG.ExtractVectorElements(SDValue(NewNode, 0), Elt, 0, 1);
     return DAG.getMergeValues({Elt[0], SDValue(NewNode, 1)}, DL);
-  } else if (IsD16 && !BaseOpcode->Store) {
-    MVT LoadVT = Op.getSimpleValueType();
-    SDValue Adjusted = adjustLoadValueTypeImpl(
-        SDValue(NewNode, 0), LoadVT, DL, DAG, Subtarget->hasUnpackedD16VMem());
-    return DAG.getMergeValues({Adjusted, SDValue(NewNode, 1)}, DL);
+  } else if (!BaseOpcode->Store) {
+    return constructRetValue(DAG, NewNode,
+                             OrigResultTypes, IsTexFail,
+                             Subtarget->hasUnpackedD16VMem(), IsD16,
+                             DMaskLanes, NumVDataDwords, DL,
+                             *DAG.getContext());
   }
 
   return SDValue(NewNode, 0);
@@ -8753,6 +8950,7 @@ static unsigned SubIdx2Lane(unsigned Idx) {
   case AMDGPU::sub1: return 1;
   case AMDGPU::sub2: return 2;
   case AMDGPU::sub3: return 3;
+  case AMDGPU::sub4: return 4; // Possible with TFE/LWE
   }
 }
 
@@ -8766,11 +8964,16 @@ SDNode *SITargetLowering::adjustWritemask(MachineSDNode *&Node,
   if (D16Idx >= 0 && Node->getConstantOperandVal(D16Idx))
     return Node; // not implemented for D16
 
-  SDNode *Users[4] = { nullptr };
+  SDNode *Users[5] = { nullptr };
   unsigned Lane = 0;
   unsigned DmaskIdx = AMDGPU::getNamedOperandIdx(Opcode, AMDGPU::OpName::dmask) - 1;
   unsigned OldDmask = Node->getConstantOperandVal(DmaskIdx);
   unsigned NewDmask = 0;
+  unsigned TFEIdx = AMDGPU::getNamedOperandIdx(Opcode, AMDGPU::OpName::tfe) - 1;
+  unsigned LWEIdx = AMDGPU::getNamedOperandIdx(Opcode, AMDGPU::OpName::lwe) - 1;
+  bool UsesTFC = (Node->getConstantOperandVal(TFEIdx) ||
+                  Node->getConstantOperandVal(LWEIdx)) ? 1 : 0;
+  unsigned TFCLane = 0;
   bool HasChain = Node->getNumValues() > 1;
 
   if (OldDmask == 0) {
@@ -8778,6 +8981,12 @@ SDNode *SITargetLowering::adjustWritemask(MachineSDNode *&Node,
     return Node;
   }
 
+  unsigned OldBitsSet = countPopulation(OldDmask);
+  // Work out which is the TFE/LWE lane if that is enabled.
+  if (UsesTFC) {
+    TFCLane = OldBitsSet;
+  }
+
   // Try to figure out the used register components
   for (SDNode::use_iterator I = Node->use_begin(), E = Node->use_end();
        I != E; ++I) {
@@ -8797,28 +9006,49 @@ SDNode *SITargetLowering::adjustWritemask(MachineSDNode *&Node,
     // set, etc.
     Lane = SubIdx2Lane(I->getConstantOperandVal(1));
 
-    // Set which texture component corresponds to the lane.
-    unsigned Comp;
-    for (unsigned i = 0, Dmask = OldDmask; (i <= Lane) && (Dmask != 0); i++) {
-      Comp = countTrailingZeros(Dmask);
-      Dmask &= ~(1 << Comp);
-    }
+    // Check if the use is for the TFE/LWE generated result at VGPRn+1.
+    if (UsesTFC && Lane == TFCLane) {
+      Users[Lane] = *I;
+    } else {
+      // Set which texture component corresponds to the lane.
+      unsigned Comp;
+      for (unsigned i = 0, Dmask = OldDmask; (i <= Lane) && (Dmask != 0); i++) {
+        Comp = countTrailingZeros(Dmask);
+        Dmask &= ~(1 << Comp);
+      }
 
-    // Abort if we have more than one user per component
-    if (Users[Lane])
-      return Node;
+      // Abort if we have more than one user per component.
+      if (Users[Lane])
+        return Node;
 
-    Users[Lane] = *I;
-    NewDmask |= 1 << Comp;
+      Users[Lane] = *I;
+      NewDmask |= 1 << Comp;
+    }
   }
 
+  // Don't allow 0 dmask, as hardware assumes one channel enabled.
+  bool NoChannels = !NewDmask;
+  if (NoChannels) {
+    // If the original dmask has one channel - then nothing to do
+    if (OldBitsSet == 1)
+      return Node;
+    // Use an arbitrary dmask - required for the instruction to work
+    NewDmask = 1;
+  }
   // Abort if there's no change
   if (NewDmask == OldDmask)
     return Node;
 
   unsigned BitsSet = countPopulation(NewDmask);
 
-  int NewOpcode = AMDGPU::getMaskedMIMGOp(Node->getMachineOpcode(), BitsSet);
+  // Check for TFE or LWE - increase the number of channels by one to account
+  // for the extra return value
+  // This will need adjustment for D16 if this is also included in
+  // adjustWriteMask (this function) but at present D16 are excluded.
+  unsigned NewChannels = BitsSet + UsesTFC;
+
+  int NewOpcode =
+      AMDGPU::getMaskedMIMGOp(Node->getMachineOpcode(), NewChannels);
   assert(NewOpcode != -1 &&
          NewOpcode != static_cast<int>(Node->getMachineOpcode()) &&
          "failed to find equivalent MIMG op");
@@ -8831,8 +9061,9 @@ SDNode *SITargetLowering::adjustWritemask(MachineSDNode *&Node,
 
   MVT SVT = Node->getValueType(0).getVectorElementType().getSimpleVT();
 
-  MVT ResultVT = BitsSet == 1 ?
-    SVT : MVT::getVectorVT(SVT, BitsSet == 3 ? 4 : BitsSet);
+  MVT ResultVT = NewChannels == 1 ?
+    SVT : MVT::getVectorVT(SVT, NewChannels == 3 ? 4 :
+                           NewChannels == 5 ? 8 : NewChannels);
   SDVTList NewVTList = HasChain ?
     DAG.getVTList(ResultVT, MVT::Other) : DAG.getVTList(ResultVT);
 
@@ -8846,7 +9077,7 @@ SDNode *SITargetLowering::adjustWritemask(MachineSDNode *&Node,
     DAG.ReplaceAllUsesOfValueWith(SDValue(Node, 1), SDValue(NewNode, 1));
   }
 
-  if (BitsSet == 1) {
+  if (NewChannels == 1) {
     assert(Node->hasNUsesOfValue(1, 0));
     SDNode *Copy = DAG.getMachineNode(TargetOpcode::COPY,
                                       SDLoc(Node), Users[Lane]->getValueType(0),
@@ -8856,19 +9087,24 @@ SDNode *SITargetLowering::adjustWritemask(MachineSDNode *&Node,
   }
 
   // Update the users of the node with the new indices
-  for (unsigned i = 0, Idx = AMDGPU::sub0; i < 4; ++i) {
+  for (unsigned i = 0, Idx = AMDGPU::sub0; i < 5; ++i) {
     SDNode *User = Users[i];
-    if (!User)
-      continue;
-
-    SDValue Op = DAG.getTargetConstant(Idx, SDLoc(User), MVT::i32);
-    DAG.UpdateNodeOperands(User, SDValue(NewNode, 0), Op);
+    if (!User) {
+      // Handle the special case of NoChannels. We set NewDmask to 1 above, but
+      // Users[0] is still nullptr because channel 0 doesn't really have a use.
+      if (i || !NoChannels)
+        continue;
+    } else {
+      SDValue Op = DAG.getTargetConstant(Idx, SDLoc(User), MVT::i32);
+      DAG.UpdateNodeOperands(User, SDValue(NewNode, 0), Op);
+    }
 
     switch (Idx) {
     default: break;
     case AMDGPU::sub0: Idx = AMDGPU::sub1; break;
     case AMDGPU::sub1: Idx = AMDGPU::sub2; break;
     case AMDGPU::sub2: Idx = AMDGPU::sub3; break;
+    case AMDGPU::sub3: Idx = AMDGPU::sub4; break;
     }
   }