OSDN Git Service

[SelectionDAG][X86] Explicitly store the scale in the gather/scatter ISD nodes
authorCraig Topper <craig.topper@intel.com>
Wed, 10 Jan 2018 19:16:05 +0000 (19:16 +0000)
committerCraig Topper <craig.topper@intel.com>
Wed, 10 Jan 2018 19:16:05 +0000 (19:16 +0000)
Currently we infer the scale at isel time by analyzing whether the base is a constant 0 or not. If it is we assume scale is 1, else we take it from the element size of the pass thru or stored value. This seems a little weird and I think it makes more sense to make it explicit in the DAG rather than doing tricky things in the backend.

Most of this patch is just making sure we copy the scale around everywhere.

Differential Revision: https://reviews.llvm.org/D40055

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@322210 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/CodeGen/SelectionDAGNodes.h
lib/CodeGen/SelectionDAG/DAGCombiner.cpp
lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
lib/CodeGen/SelectionDAG/SelectionDAG.cpp
lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
lib/Target/X86/X86ISelDAGToDAG.cpp
lib/Target/X86/X86ISelLowering.cpp
lib/Target/X86/X86ISelLowering.h
test/CodeGen/X86/masked_gather_scatter.ll

index 522c2f1..7eb4dbb 100644 (file)
@@ -2120,13 +2120,14 @@ public:
       : MemSDNode(NodeTy, Order, dl, VTs, MemVT, MMO) {}
 
   // In the both nodes address is Op1, mask is Op2:
-  // MaskedGatherSDNode  (Chain, src0, mask, base, index), src0 is a passthru value
-  // MaskedScatterSDNode (Chain, value, mask, base, index)
+  // MaskedGatherSDNode  (Chain, passthru, mask, base, index, scale)
+  // MaskedScatterSDNode (Chain, value, mask, base, index, scale)
   // Mask is a vector of i1 elements
   const SDValue &getBasePtr() const { return getOperand(3); }
   const SDValue &getIndex()   const { return getOperand(4); }
   const SDValue &getMask()    const { return getOperand(2); }
   const SDValue &getValue()   const { return getOperand(1); }
+  const SDValue &getScale()   const { return getOperand(5); }
 
   static bool classof(const SDNode *N) {
     return N->getOpcode() == ISD::MGATHER ||
index f229be1..1820d5f 100644 (file)
@@ -6726,6 +6726,7 @@ SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
   SDValue DataLo, DataHi;
   std::tie(DataLo, DataHi) = DAG.SplitVector(Data, DL);
 
+  SDValue Scale = MSC->getScale();
   SDValue BasePtr = MSC->getBasePtr();
   SDValue IndexLo, IndexHi;
   std::tie(IndexLo, IndexHi) = DAG.SplitVector(MSC->getIndex(), DL);
@@ -6735,11 +6736,11 @@ SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
                           MachineMemOperand::MOStore,  LoMemVT.getStoreSize(),
                           Alignment, MSC->getAAInfo(), MSC->getRanges());
 
-  SDValue OpsLo[] = { Chain, DataLo, MaskLo, BasePtr, IndexLo };
+  SDValue OpsLo[] = { Chain, DataLo, MaskLo, BasePtr, IndexLo, Scale };
   Lo = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataLo.getValueType(),
                             DL, OpsLo, MMO);
 
-  SDValue OpsHi[] = {Chain, DataHi, MaskHi, BasePtr, IndexHi};
+  SDValue OpsHi[] = { Chain, DataHi, MaskHi, BasePtr, IndexHi, Scale };
   Hi = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataHi.getValueType(),
                             DL, OpsHi, MMO);
 
@@ -6859,6 +6860,7 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) {
   EVT LoMemVT, HiMemVT;
   std::tie(LoMemVT, HiMemVT) = DAG.GetSplitDestVTs(MemoryVT);
 
+  SDValue Scale = MGT->getScale();
   SDValue BasePtr = MGT->getBasePtr();
   SDValue Index = MGT->getIndex();
   SDValue IndexLo, IndexHi;
@@ -6869,13 +6871,13 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) {
                           MachineMemOperand::MOLoad,  LoMemVT.getStoreSize(),
                           Alignment, MGT->getAAInfo(), MGT->getRanges());
 
-  SDValue OpsLo[] = { Chain, Src0Lo, MaskLo, BasePtr, IndexLo };
+  SDValue OpsLo[] = { Chain, Src0Lo, MaskLo, BasePtr, IndexLo, Scale };
   Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoVT, DL, OpsLo,
-                            MMO);
+                           MMO);
 
-  SDValue OpsHi[] = {Chain, Src0Hi, MaskHi, BasePtr, IndexHi};
+  SDValue OpsHi[] = { Chain, Src0Hi, MaskHi, BasePtr, IndexHi, Scale };
   Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiVT, DL, OpsHi,
-                            MMO);
+                           MMO);
 
   AddToWorklist(Lo.getNode());
   AddToWorklist(Hi.getNode());
index b603523..eaa8273 100644 (file)
@@ -501,7 +501,7 @@ SDValue DAGTypeLegalizer::PromoteIntRes_MGATHER(MaskedGatherSDNode *N) {
 
   SDLoc dl(N);
   SDValue Ops[] = {N->getChain(), ExtSrc0, N->getMask(), N->getBasePtr(),
-                   N->getIndex()};
+                   N->getIndex(), N->getScale() };
   SDValue Res = DAG.getMaskedGather(DAG.getVTList(NVT, MVT::Other),
                                     N->getMemoryVT(), dl, Ops,
                                     N->getMemOperand());
index ee9b3fd..69d9fe9 100644 (file)
@@ -1238,6 +1238,7 @@ void DAGTypeLegalizer::SplitVecRes_MGATHER(MaskedGatherSDNode *MGT,
   SDValue Mask = MGT->getMask();
   SDValue Src0 = MGT->getValue();
   SDValue Index = MGT->getIndex();
+  SDValue Scale = MGT->getScale();
   unsigned Alignment = MGT->getOriginalAlignment();
 
   // Split Mask operand
@@ -1269,11 +1270,11 @@ void DAGTypeLegalizer::SplitVecRes_MGATHER(MaskedGatherSDNode *MGT,
                          MachineMemOperand::MOLoad,  LoMemVT.getStoreSize(),
                          Alignment, MGT->getAAInfo(), MGT->getRanges());
 
-  SDValue OpsLo[] = {Ch, Src0Lo, MaskLo, Ptr, IndexLo};
+  SDValue OpsLo[] = {Ch, Src0Lo, MaskLo, Ptr, IndexLo, Scale};
   Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoVT, dl, OpsLo,
                            MMO);
 
-  SDValue OpsHi[] = {Ch, Src0Hi, MaskHi, Ptr, IndexHi};
+  SDValue OpsHi[] = {Ch, Src0Hi, MaskHi, Ptr, IndexHi, Scale};
   Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiVT, dl, OpsHi,
                            MMO);
 
@@ -1816,6 +1817,7 @@ SDValue DAGTypeLegalizer::SplitVecOp_MGATHER(MaskedGatherSDNode *MGT,
   SDValue Ch = MGT->getChain();
   SDValue Ptr = MGT->getBasePtr();
   SDValue Index = MGT->getIndex();
+  SDValue Scale = MGT->getScale();
   SDValue Mask = MGT->getMask();
   SDValue Src0 = MGT->getValue();
   unsigned Alignment = MGT->getOriginalAlignment();
@@ -1848,7 +1850,7 @@ SDValue DAGTypeLegalizer::SplitVecOp_MGATHER(MaskedGatherSDNode *MGT,
                          MachineMemOperand::MOLoad,  LoMemVT.getStoreSize(),
                          Alignment, MGT->getAAInfo(), MGT->getRanges());
 
-  SDValue OpsLo[] = {Ch, Src0Lo, MaskLo, Ptr, IndexLo};
+  SDValue OpsLo[] = {Ch, Src0Lo, MaskLo, Ptr, IndexLo, Scale};
   SDValue Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoVT, dl,
                                    OpsLo, MMO);
 
@@ -1858,7 +1860,7 @@ SDValue DAGTypeLegalizer::SplitVecOp_MGATHER(MaskedGatherSDNode *MGT,
                          Alignment, MGT->getAAInfo(),
                          MGT->getRanges());
 
-  SDValue OpsHi[] = {Ch, Src0Hi, MaskHi, Ptr, IndexHi};
+  SDValue OpsHi[] = {Ch, Src0Hi, MaskHi, Ptr, IndexHi, Scale};
   SDValue Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiVT, dl,
                                    OpsHi, MMO);
 
@@ -1941,6 +1943,7 @@ SDValue DAGTypeLegalizer::SplitVecOp_MSCATTER(MaskedScatterSDNode *N,
   SDValue Ptr = N->getBasePtr();
   SDValue Mask = N->getMask();
   SDValue Index = N->getIndex();
+  SDValue Scale = N->getScale();
   SDValue Data = N->getValue();
   EVT MemoryVT = N->getMemoryVT();
   unsigned Alignment = N->getOriginalAlignment();
@@ -1976,7 +1979,7 @@ SDValue DAGTypeLegalizer::SplitVecOp_MSCATTER(MaskedScatterSDNode *N,
                          MachineMemOperand::MOStore, LoMemVT.getStoreSize(),
                          Alignment, N->getAAInfo(), N->getRanges());
 
-  SDValue OpsLo[] = {Ch, DataLo, MaskLo, Ptr, IndexLo};
+  SDValue OpsLo[] = {Ch, DataLo, MaskLo, Ptr, IndexLo, Scale};
   Lo = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataLo.getValueType(),
                             DL, OpsLo, MMO);
 
@@ -1988,7 +1991,7 @@ SDValue DAGTypeLegalizer::SplitVecOp_MSCATTER(MaskedScatterSDNode *N,
   // The order of the Scatter operation after split is well defined. The "Hi"
   // part comes after the "Lo". So these two operations should be chained one
   // after another.
-  SDValue OpsHi[] = {Lo, DataHi, MaskHi, Ptr, IndexHi};
+  SDValue OpsHi[] = {Lo, DataHi, MaskHi, Ptr, IndexHi, Scale};
   return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataHi.getValueType(),
                               DL, OpsHi, MMO);
 }
@@ -2954,6 +2957,7 @@ SDValue DAGTypeLegalizer::WidenVecRes_MGATHER(MaskedGatherSDNode *N) {
   SDValue Mask = N->getMask();
   EVT MaskVT = Mask.getValueType();
   SDValue Src0 = GetWidenedVector(N->getValue());
+  SDValue Scale = N->getScale();
   unsigned NumElts = WideVT.getVectorNumElements();
   SDLoc dl(N);
 
@@ -2969,7 +2973,7 @@ SDValue DAGTypeLegalizer::WidenVecRes_MGATHER(MaskedGatherSDNode *N) {
                                      Index.getValueType().getScalarType(),
                                      NumElts);
   Index = ModifyToType(Index, WideIndexVT);
-  SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index };
+  SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index, Scale };
   SDValue Res = DAG.getMaskedGather(DAG.getVTList(WideVT, MVT::Other),
                                     N->getMemoryVT(), dl, Ops,
                                     N->getMemOperand());
@@ -3593,6 +3597,7 @@ SDValue DAGTypeLegalizer::WidenVecOp_MSCATTER(SDNode *N, unsigned OpNo) {
   SDValue DataOp = MSC->getValue();
   SDValue Mask = MSC->getMask();
   EVT MaskVT = Mask.getValueType();
+  SDValue Scale = MSC->getScale();
 
   // Widen the value.
   SDValue WideVal = GetWidenedVector(DataOp);
@@ -3612,7 +3617,8 @@ SDValue DAGTypeLegalizer::WidenVecOp_MSCATTER(SDNode *N, unsigned OpNo) {
                                      NumElts);
   Index = ModifyToType(Index, WideIndexVT);
 
-  SDValue Ops[] = {MSC->getChain(), WideVal, Mask, MSC->getBasePtr(), Index};
+  SDValue Ops[] = {MSC->getChain(), WideVal, Mask, MSC->getBasePtr(), Index,
+                   Scale};
   return DAG.getMaskedScatter(DAG.getVTList(MVT::Other),
                               MSC->getMemoryVT(), dl, Ops,
                               MSC->getMemOperand());
index c012205..837173e 100644 (file)
@@ -6208,7 +6208,7 @@ SDValue SelectionDAG::getMaskedStore(SDValue Chain, const SDLoc &dl,
 SDValue SelectionDAG::getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl,
                                       ArrayRef<SDValue> Ops,
                                       MachineMemOperand *MMO) {
-  assert(Ops.size() == 5 && "Incompatible number of operands");
+  assert(Ops.size() == 6 && "Incompatible number of operands");
 
   FoldingSetNodeID ID;
   AddNodeIDNode(ID, ISD::MGATHER, VTs, Ops);
@@ -6234,6 +6234,9 @@ SDValue SelectionDAG::getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl,
   assert(N->getIndex().getValueType().getVectorNumElements() ==
              N->getValueType(0).getVectorNumElements() &&
          "Vector width mismatch between index and data");
+  assert(isa<ConstantSDNode>(N->getScale()) &&
+         cast<ConstantSDNode>(N->getScale())->getAPIntValue().isPowerOf2() &&
+         "Scale should be a constant power of 2");
 
   CSEMap.InsertNode(N, IP);
   InsertNode(N);
@@ -6245,7 +6248,7 @@ SDValue SelectionDAG::getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl,
 SDValue SelectionDAG::getMaskedScatter(SDVTList VTs, EVT VT, const SDLoc &dl,
                                        ArrayRef<SDValue> Ops,
                                        MachineMemOperand *MMO) {
-  assert(Ops.size() == 5 && "Incompatible number of operands");
+  assert(Ops.size() == 6 && "Incompatible number of operands");
 
   FoldingSetNodeID ID;
   AddNodeIDNode(ID, ISD::MSCATTER, VTs, Ops);
@@ -6268,6 +6271,9 @@ SDValue SelectionDAG::getMaskedScatter(SDVTList VTs, EVT VT, const SDLoc &dl,
   assert(N->getIndex().getValueType().getVectorNumElements() ==
              N->getValue().getValueType().getVectorNumElements() &&
          "Vector width mismatch between index and data");
+  assert(isa<ConstantSDNode>(N->getScale()) &&
+         cast<ConstantSDNode>(N->getScale())->getAPIntValue().isPowerOf2() &&
+         "Scale should be a constant power of 2");
 
   CSEMap.InsertNode(N, IP);
   InsertNode(N);
index bd8d767..1c15a3c 100644 (file)
@@ -3867,7 +3867,7 @@ void SelectionDAGBuilder::visitMaskedStore(const CallInst &I,
 // extract the splat value and use it as a uniform base.
 // In all other cases the function returns 'false'.
 static bool getUniformBase(const Value* &Ptr, SDValue& Base, SDValue& Index,
-                           SelectionDAGBuilder* SDB) {
+                           SDValue &Scale, SelectionDAGBuilder* SDB) {
   SelectionDAG& DAG = SDB->DAG;
   LLVMContext &Context = *DAG.getContext();
 
@@ -3897,6 +3897,10 @@ static bool getUniformBase(const Value* &Ptr, SDValue& Base, SDValue& Index,
   if (!SDB->findValue(Ptr) || !SDB->findValue(IndexVal))
     return false;
 
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  const DataLayout &DL = DAG.getDataLayout();
+  Scale = DAG.getTargetConstant(DL.getTypeAllocSize(GEP->getResultElementType()),
+                                SDB->getCurSDLoc(), TLI.getPointerTy(DL));
   Base = SDB->getValue(Ptr);
   Index = SDB->getValue(IndexVal);
 
@@ -3926,8 +3930,9 @@ void SelectionDAGBuilder::visitMaskedScatter(const CallInst &I) {
 
   SDValue Base;
   SDValue Index;
+  SDValue Scale;
   const Value *BasePtr = Ptr;
-  bool UniformBase = getUniformBase(BasePtr, Base, Index, this);
+  bool UniformBase = getUniformBase(BasePtr, Base, Index, Scale, this);
 
   const Value *MemOpBasePtr = UniformBase ? BasePtr : nullptr;
   MachineMemOperand *MMO = DAG.getMachineFunction().
@@ -3935,10 +3940,11 @@ void SelectionDAGBuilder::visitMaskedScatter(const CallInst &I) {
                          MachineMemOperand::MOStore,  VT.getStoreSize(),
                          Alignment, AAInfo);
   if (!UniformBase) {
-    Base = DAG.getTargetConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout()));
+    Base = DAG.getConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout()));
     Index = getValue(Ptr);
+    Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout()));
   }
-  SDValue Ops[] = { getRoot(), Src0, Mask, Base, Index };
+  SDValue Ops[] = { getRoot(), Src0, Mask, Base, Index, Scale };
   SDValue Scatter = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), VT, sdl,
                                          Ops, MMO);
   DAG.setRoot(Scatter);
@@ -4025,8 +4031,9 @@ void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) {
   SDValue Root = DAG.getRoot();
   SDValue Base;
   SDValue Index;
+  SDValue Scale;
   const Value *BasePtr = Ptr;
-  bool UniformBase = getUniformBase(BasePtr, Base, Index, this);
+  bool UniformBase = getUniformBase(BasePtr, Base, Index, Scale, this);
   bool ConstantMemory = false;
   if (UniformBase &&
       AA && AA->pointsToConstantMemory(MemoryLocation(
@@ -4044,10 +4051,11 @@ void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) {
                          Alignment, AAInfo, Ranges);
 
   if (!UniformBase) {
-    Base = DAG.getTargetConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout()));
+    Base = DAG.getConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout()));
     Index = getValue(Ptr);
+    Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout()));
   }
-  SDValue Ops[] = { Root, Src0, Mask, Base, Index };
+  SDValue Ops[] = { Root, Src0, Mask, Base, Index, Scale };
   SDValue Gather = DAG.getMaskedGather(DAG.getVTList(VT, MVT::Other), VT, sdl,
                                        Ops, MMO);
 
index 775cc79..94363be 100644 (file)
@@ -1508,6 +1508,12 @@ bool X86DAGToDAGISel::matchAddressBase(SDValue N, X86ISelAddressMode &AM) {
 bool X86DAGToDAGISel::matchVectorAddress(SDValue N, X86ISelAddressMode &AM) {
   // TODO: Support other operations.
   switch (N.getOpcode()) {
+  case ISD::Constant: {
+    uint64_t Val = cast<ConstantSDNode>(N)->getSExtValue();
+    if (!foldOffsetIntoAddress(Val, AM))
+      return false;
+    break;
+  }
   case X86ISD::Wrapper:
     if (!matchWrapper(N, AM))
       return false;
@@ -1523,7 +1529,7 @@ bool X86DAGToDAGISel::selectVectorAddr(SDNode *Parent, SDValue N, SDValue &Base,
   X86ISelAddressMode AM;
   auto *Mgs = cast<X86MaskedGatherScatterSDNode>(Parent);
   AM.IndexReg = Mgs->getIndex();
-  AM.Scale = Mgs->getValue().getScalarValueSizeInBits() / 8;
+  AM.Scale = cast<ConstantSDNode>(Mgs->getScale())->getZExtValue();
 
   unsigned AddrSpace = cast<MemSDNode>(Parent)->getPointerInfo().getAddrSpace();
   // AddrSpace 256 -> GS, 257 -> FS, 258 -> SS.
@@ -1534,14 +1540,8 @@ bool X86DAGToDAGISel::selectVectorAddr(SDNode *Parent, SDValue N, SDValue &Base,
   if (AddrSpace == 258)
     AM.Segment = CurDAG->getRegister(X86::SS, MVT::i16);
 
-  // If Base is 0, the whole address is in index and the Scale is 1
-  if (isa<ConstantSDNode>(N)) {
-    assert(cast<ConstantSDNode>(N)->isNullValue() &&
-           "Unexpected base in gather/scatter");
-    AM.Scale = 1;
-  }
-  // Otherwise, try to match into the base and displacement fields.
-  else if (matchVectorAddress(N, AM))
+  // Try to match into the base and displacement fields.
+  if (matchVectorAddress(N, AM))
     return false;
 
   MVT VT = N.getSimpleValueType();
index 2623da5..16394f0 100644 (file)
@@ -24317,6 +24317,7 @@ static SDValue LowerMSCATTER(SDValue Op, const X86Subtarget &Subtarget,
   assert(VT.getScalarSizeInBits() >= 32 && "Unsupported scatter op");
   SDLoc dl(Op);
 
+  SDValue Scale = N->getScale();
   SDValue Index = N->getIndex();
   SDValue Mask = N->getMask();
   SDValue Chain = N->getChain();
@@ -24383,7 +24384,7 @@ static SDValue LowerMSCATTER(SDValue Op, const X86Subtarget &Subtarget,
 
   // The mask is killed by scatter, add it to the values
   SDVTList VTs = DAG.getVTList(Mask.getValueType(), MVT::Other);
-  SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index};
+  SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index, Scale};
   SDValue NewScatter = DAG.getTargetMemSDNode<X86MaskedScatterSDNode>(
       VTs, Ops, dl, N->getMemoryVT(), N->getMemOperand());
   DAG.ReplaceAllUsesWith(Op, SDValue(NewScatter.getNode(), 1));
@@ -24489,6 +24490,7 @@ static SDValue LowerMGATHER(SDValue Op, const X86Subtarget &Subtarget,
   MaskedGatherSDNode *N = cast<MaskedGatherSDNode>(Op.getNode());
   SDLoc dl(Op);
   MVT VT = Op.getSimpleValueType();
+  SDValue Scale = N->getScale();
   SDValue Index = N->getIndex();
   SDValue Mask = N->getMask();
   SDValue Src0 = N->getValue();
@@ -24509,7 +24511,8 @@ static SDValue LowerMGATHER(SDValue Op, const X86Subtarget &Subtarget,
     // the vector contains 8 elements, we just sign-extend the index
     if (NumElts == 8) {
       Index = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v8i64, Index);
-      SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index };
+      SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index,
+                        Scale };
       SDValue NewGather = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>(
           DAG.getVTList(VT, MaskVT, MVT::Other), Ops, dl, N->getMemoryVT(),
           N->getMemOperand());
@@ -24533,7 +24536,7 @@ static SDValue LowerMGATHER(SDValue Op, const X86Subtarget &Subtarget,
     MVT NewVT = MVT::getVectorVT(VT.getScalarType(), NumElts);
     Src0 = ExtendToType(Src0, NewVT, DAG);
 
-    SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index };
+    SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index, Scale };
     SDValue NewGather = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>(
         DAG.getVTList(NewVT, MaskVT, MVT::Other), Ops, dl, N->getMemoryVT(),
         N->getMemOperand());
@@ -24544,7 +24547,7 @@ static SDValue LowerMGATHER(SDValue Op, const X86Subtarget &Subtarget,
     return DAG.getMergeValues(RetOps, dl);
   }
 
-  SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index };
+  SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index, Scale };
   SDValue NewGather = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>(
       DAG.getVTList(VT, MaskVT, MVT::Other), Ops, dl, N->getMemoryVT(),
       N->getMemOperand());
@@ -25080,7 +25083,7 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,
         Mask = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v4i32, Mask);
       }
       SDValue Ops[] = { Gather->getChain(), Src0, Mask, Gather->getBasePtr(),
-                        Index };
+                        Index, Gather->getScale() };
       SDValue Res = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>(
         DAG.getVTList(MVT::v4f32, Mask.getValueType(), MVT::Other), Ops, dl,
         Gather->getMemoryVT(), Gather->getMemOperand());
@@ -25107,7 +25110,7 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,
           Mask = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v4i32, Mask);
         }
         SDValue Ops[] = { Gather->getChain(), Src0, Mask, Gather->getBasePtr(),
-                          Index };
+                          Index, Gather->getScale() };
         SDValue Res = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>(
           DAG.getVTList(MVT::v4i32, Mask.getValueType(), MVT::Other), Ops, dl,
           Gather->getMemoryVT(), Gather->getMemOperand());
@@ -25128,7 +25131,7 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,
       Mask = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i1, Mask,
                          DAG.getConstant(0, dl, MVT::v2i1));
       SDValue Ops[] = { Gather->getChain(), Src0, Mask, Gather->getBasePtr(),
-                        Index };
+                        Index, Gather->getScale() };
       SDValue Res = DAG.getMaskedGather(DAG.getVTList(MVT::v4i32, MVT::Other),
                                         Gather->getMemoryVT(), dl, Ops,
                                         Gather->getMemOperand());
index 7a981a7..462c59a 100644 (file)
@@ -1442,6 +1442,7 @@ namespace llvm {
     const SDValue &getIndex()   const { return getOperand(4); }
     const SDValue &getMask()    const { return getOperand(2); }
     const SDValue &getValue()   const { return getOperand(1); }
+    const SDValue &getScale()   const { return getOperand(5); }
 
     static bool classof(const SDNode *N) {
       return N->getOpcode() == X86ISD::MGATHER ||
index e63517d..2558efe 100644 (file)
@@ -2782,3 +2782,163 @@ define <16 x float> @zext_index(float* %base, <16 x i32> %ind) {
   %res = call <16 x float> @llvm.masked.gather.v16f32.v16p0f32(<16 x float*> %gep.random, i32 4, <16 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <16 x float> undef)
   ret <16 x float>%res
 }
+
+define <16 x double> @test_gather_setcc_split(double* %base, <16 x i32> %ind, <16 x i32> %cmp, <16 x double> %passthru) {
+; KNL_64-LABEL: test_gather_setcc_split:
+; KNL_64:       # %bb.0:
+; KNL_64-NEXT:    vextractf64x4 $1, %zmm0, %ymm4
+; KNL_64-NEXT:    vpxor %xmm5, %xmm5, %xmm5
+; KNL_64-NEXT:    vextracti64x4 $1, %zmm1, %ymm6
+; KNL_64-NEXT:    vpcmpeqd %zmm5, %zmm6, %k1
+; KNL_64-NEXT:    vpcmpeqd %zmm5, %zmm1, %k2
+; KNL_64-NEXT:    vgatherdpd (%rdi,%ymm0,8), %zmm2 {%k2}
+; KNL_64-NEXT:    vgatherdpd (%rdi,%ymm4,8), %zmm3 {%k1}
+; KNL_64-NEXT:    vmovapd %zmm2, %zmm0
+; KNL_64-NEXT:    vmovapd %zmm3, %zmm1
+; KNL_64-NEXT:    retq
+;
+; KNL_32-LABEL: test_gather_setcc_split:
+; KNL_32:       # %bb.0:
+; KNL_32-NEXT:    pushl %ebp
+; KNL_32-NEXT:    .cfi_def_cfa_offset 8
+; KNL_32-NEXT:    .cfi_offset %ebp, -8
+; KNL_32-NEXT:    movl %esp, %ebp
+; KNL_32-NEXT:    .cfi_def_cfa_register %ebp
+; KNL_32-NEXT:    andl $-64, %esp
+; KNL_32-NEXT:    subl $64, %esp
+; KNL_32-NEXT:    vmovapd 72(%ebp), %zmm3
+; KNL_32-NEXT:    movl 8(%ebp), %eax
+; KNL_32-NEXT:    vextractf64x4 $1, %zmm0, %ymm4
+; KNL_32-NEXT:    vpxor %xmm5, %xmm5, %xmm5
+; KNL_32-NEXT:    vextracti64x4 $1, %zmm1, %ymm6
+; KNL_32-NEXT:    vpcmpeqd %zmm5, %zmm6, %k1
+; KNL_32-NEXT:    vpcmpeqd %zmm5, %zmm1, %k2
+; KNL_32-NEXT:    vgatherdpd (%eax,%ymm0,8), %zmm2 {%k2}
+; KNL_32-NEXT:    vgatherdpd (%eax,%ymm4,8), %zmm3 {%k1}
+; KNL_32-NEXT:    vmovapd %zmm2, %zmm0
+; KNL_32-NEXT:    vmovapd %zmm3, %zmm1
+; KNL_32-NEXT:    movl %ebp, %esp
+; KNL_32-NEXT:    popl %ebp
+; KNL_32-NEXT:    retl
+;
+; SKX-LABEL: test_gather_setcc_split:
+; SKX:       # %bb.0:
+; SKX-NEXT:    vextractf64x4 $1, %zmm0, %ymm4
+; SKX-NEXT:    vextracti64x4 $1, %zmm1, %ymm5
+; SKX-NEXT:    vpxor %xmm6, %xmm6, %xmm6
+; SKX-NEXT:    vpcmpeqd %ymm6, %ymm5, %k1
+; SKX-NEXT:    vpcmpeqd %ymm6, %ymm1, %k2
+; SKX-NEXT:    vgatherdpd (%rdi,%ymm0,8), %zmm2 {%k2}
+; SKX-NEXT:    vgatherdpd (%rdi,%ymm4,8), %zmm3 {%k1}
+; SKX-NEXT:    vmovapd %zmm2, %zmm0
+; SKX-NEXT:    vmovapd %zmm3, %zmm1
+; SKX-NEXT:    retq
+;
+; SKX_32-LABEL: test_gather_setcc_split:
+; SKX_32:       # %bb.0:
+; SKX_32-NEXT:    pushl %ebp
+; SKX_32-NEXT:    .cfi_def_cfa_offset 8
+; SKX_32-NEXT:    .cfi_offset %ebp, -8
+; SKX_32-NEXT:    movl %esp, %ebp
+; SKX_32-NEXT:    .cfi_def_cfa_register %ebp
+; SKX_32-NEXT:    andl $-64, %esp
+; SKX_32-NEXT:    subl $64, %esp
+; SKX_32-NEXT:    vmovapd 72(%ebp), %zmm3
+; SKX_32-NEXT:    movl 8(%ebp), %eax
+; SKX_32-NEXT:    vextractf64x4 $1, %zmm0, %ymm4
+; SKX_32-NEXT:    vextracti64x4 $1, %zmm1, %ymm5
+; SKX_32-NEXT:    vpxor %xmm6, %xmm6, %xmm6
+; SKX_32-NEXT:    vpcmpeqd %ymm6, %ymm5, %k1
+; SKX_32-NEXT:    vpcmpeqd %ymm6, %ymm1, %k2
+; SKX_32-NEXT:    vgatherdpd (%eax,%ymm0,8), %zmm2 {%k2}
+; SKX_32-NEXT:    vgatherdpd (%eax,%ymm4,8), %zmm3 {%k1}
+; SKX_32-NEXT:    vmovapd %zmm2, %zmm0
+; SKX_32-NEXT:    vmovapd %zmm3, %zmm1
+; SKX_32-NEXT:    movl %ebp, %esp
+; SKX_32-NEXT:    popl %ebp
+; SKX_32-NEXT:    retl
+  %sext_ind = sext <16 x i32> %ind to <16 x i64>
+  %gep.random = getelementptr double, double *%base, <16 x i64> %sext_ind
+
+  %mask = icmp eq <16 x i32> %cmp, zeroinitializer
+  %res = call <16 x double> @llvm.masked.gather.v16f64.v16p0f64(<16 x double*> %gep.random, i32 4, <16 x i1> %mask, <16 x double> %passthru)
+  ret <16 x double>%res
+}
+
+define void @test_scatter_setcc_split(double* %base, <16 x i32> %ind, <16 x i32> %cmp, <16 x double> %src0)  {
+; KNL_64-LABEL: test_scatter_setcc_split:
+; KNL_64:       # %bb.0:
+; KNL_64-NEXT:    vextractf64x4 $1, %zmm0, %ymm4
+; KNL_64-NEXT:    vpxor %xmm5, %xmm5, %xmm5
+; KNL_64-NEXT:    vpcmpeqd %zmm5, %zmm1, %k1
+; KNL_64-NEXT:    vextracti64x4 $1, %zmm1, %ymm1
+; KNL_64-NEXT:    vpcmpeqd %zmm5, %zmm1, %k2
+; KNL_64-NEXT:    vscatterdpd %zmm3, (%rdi,%ymm4,8) {%k2}
+; KNL_64-NEXT:    vscatterdpd %zmm2, (%rdi,%ymm0,8) {%k1}
+; KNL_64-NEXT:    vzeroupper
+; KNL_64-NEXT:    retq
+;
+; KNL_32-LABEL: test_scatter_setcc_split:
+; KNL_32:       # %bb.0:
+; KNL_32-NEXT:    pushl %ebp
+; KNL_32-NEXT:    .cfi_def_cfa_offset 8
+; KNL_32-NEXT:    .cfi_offset %ebp, -8
+; KNL_32-NEXT:    movl %esp, %ebp
+; KNL_32-NEXT:    .cfi_def_cfa_register %ebp
+; KNL_32-NEXT:    andl $-64, %esp
+; KNL_32-NEXT:    subl $64, %esp
+; KNL_32-NEXT:    vmovapd 72(%ebp), %zmm3
+; KNL_32-NEXT:    movl 8(%ebp), %eax
+; KNL_32-NEXT:    vextractf64x4 $1, %zmm0, %ymm4
+; KNL_32-NEXT:    vpxor %xmm5, %xmm5, %xmm5
+; KNL_32-NEXT:    vpcmpeqd %zmm5, %zmm1, %k1
+; KNL_32-NEXT:    vextracti64x4 $1, %zmm1, %ymm1
+; KNL_32-NEXT:    vpcmpeqd %zmm5, %zmm1, %k2
+; KNL_32-NEXT:    vscatterdpd %zmm3, (%eax,%ymm4,8) {%k2}
+; KNL_32-NEXT:    vscatterdpd %zmm2, (%eax,%ymm0,8) {%k1}
+; KNL_32-NEXT:    movl %ebp, %esp
+; KNL_32-NEXT:    popl %ebp
+; KNL_32-NEXT:    vzeroupper
+; KNL_32-NEXT:    retl
+;
+; SKX-LABEL: test_scatter_setcc_split:
+; SKX:       # %bb.0:
+; SKX-NEXT:    vextractf64x4 $1, %zmm0, %ymm4
+; SKX-NEXT:    vpxor %xmm5, %xmm5, %xmm5
+; SKX-NEXT:    vpcmpeqd %ymm5, %ymm1, %k1
+; SKX-NEXT:    vextracti64x4 $1, %zmm1, %ymm1
+; SKX-NEXT:    vpcmpeqd %ymm5, %ymm1, %k2
+; SKX-NEXT:    vscatterdpd %zmm3, (%rdi,%ymm4,8) {%k2}
+; SKX-NEXT:    vscatterdpd %zmm2, (%rdi,%ymm0,8) {%k1}
+; SKX-NEXT:    vzeroupper
+; SKX-NEXT:    retq
+;
+; SKX_32-LABEL: test_scatter_setcc_split:
+; SKX_32:       # %bb.0:
+; SKX_32-NEXT:    pushl %ebp
+; SKX_32-NEXT:    .cfi_def_cfa_offset 8
+; SKX_32-NEXT:    .cfi_offset %ebp, -8
+; SKX_32-NEXT:    movl %esp, %ebp
+; SKX_32-NEXT:    .cfi_def_cfa_register %ebp
+; SKX_32-NEXT:    andl $-64, %esp
+; SKX_32-NEXT:    subl $64, %esp
+; SKX_32-NEXT:    vmovapd 72(%ebp), %zmm3
+; SKX_32-NEXT:    movl 8(%ebp), %eax
+; SKX_32-NEXT:    vextractf64x4 $1, %zmm0, %ymm4
+; SKX_32-NEXT:    vpxor %xmm5, %xmm5, %xmm5
+; SKX_32-NEXT:    vpcmpeqd %ymm5, %ymm1, %k1
+; SKX_32-NEXT:    vextracti64x4 $1, %zmm1, %ymm1
+; SKX_32-NEXT:    vpcmpeqd %ymm5, %ymm1, %k2
+; SKX_32-NEXT:    vscatterdpd %zmm3, (%eax,%ymm4,8) {%k2}
+; SKX_32-NEXT:    vscatterdpd %zmm2, (%eax,%ymm0,8) {%k1}
+; SKX_32-NEXT:    movl %ebp, %esp
+; SKX_32-NEXT:    popl %ebp
+; SKX_32-NEXT:    vzeroupper
+; SKX_32-NEXT:    retl
+  %sext_ind = sext <16 x i32> %ind to <16 x i64>
+  %gep.random = getelementptr double, double *%base, <16 x i64> %sext_ind
+
+  %mask = icmp eq <16 x i32> %cmp, zeroinitializer
+  call void @llvm.masked.scatter.v16f64.v16p0f64(<16 x double> %src0, <16 x double*> %gep.random, i32 4, <16 x i1> %mask)
+  ret void
+}