OSDN Git Service

[X86][SSE] Pulled out splat detection helper from LowerScalarVariableShift (NFCI)
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Wed, 30 May 2018 19:16:59 +0000 (19:16 +0000)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Wed, 30 May 2018 19:16:59 +0000 (19:16 +0000)
Created the IsSplatValue helper from the splat detection code in LowerScalarVariableShift as a first NFC step towards improving support for splat rotations, which is an extension of PR37426.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@333580 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/X86/X86ISelLowering.cpp

index b50376b..7f1432c 100644 (file)
@@ -23122,6 +23122,40 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
   return SDValue();
 }
 
+// Determine if V is a splat value, and return the scalar.
+// TODO: Add support for SUB(SPLAT_CST, SPLAT) cases to support rotate patterns.
+static SDValue IsSplatValue(SDValue V, const SDLoc &dl, SelectionDAG &DAG) {
+  // Check if this is a splat build_vector node.
+  if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(V)) {
+    SDValue SplatAmt = BV->getSplatValue();
+    if (SplatAmt && SplatAmt.isUndef())
+      return SDValue();
+    return SplatAmt;
+  }
+
+  // Check if this is a shuffle node doing a splat.
+  ShuffleVectorSDNode *SVN = dyn_cast<ShuffleVectorSDNode>(V);
+  if (!SVN || !SVN->isSplat())
+    return SDValue();
+
+  unsigned SplatIdx = (unsigned)SVN->getSplatIndex();
+  SDValue InVec = V.getOperand(0);
+  if (InVec.getOpcode() == ISD::BUILD_VECTOR) {
+    assert((SplatIdx < InVec.getSimpleValueType().getVectorNumElements()) &&
+           "Unexpected shuffle index found!");
+    return InVec.getOperand(SplatIdx);
+  } else if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT) {
+    if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(InVec.getOperand(2)))
+      if (C->getZExtValue() == SplatIdx)
+        return InVec.getOperand(1);
+  }
+
+  // Avoid introducing an extract element from a shuffle.
+  return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
+                     V.getValueType().getVectorElementType(), InVec,
+                     DAG.getIntPtrConstant(SplatIdx, dl));
+}
+
 static SDValue LowerScalarVariableShift(SDValue Op, SelectionDAG &DAG,
                                         const X86Subtarget &Subtarget) {
   MVT VT = Op.getSimpleValueType();
@@ -23135,44 +23169,13 @@ static SDValue LowerScalarVariableShift(SDValue Op, SelectionDAG &DAG,
   unsigned X86OpcV = (Op.getOpcode() == ISD::SHL) ? X86ISD::VSHL :
     (Op.getOpcode() == ISD::SRL) ? X86ISD::VSRL : X86ISD::VSRA;
 
-  if (SupportedVectorShiftWithBaseAmnt(VT, Subtarget, Op.getOpcode())) {
-    SDValue BaseShAmt;
-    MVT EltVT = VT.getVectorElementType();
-
-    if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(Amt)) {
-      // Check if this build_vector node is doing a splat.
-      // If so, then set BaseShAmt equal to the splat value.
-      BaseShAmt = BV->getSplatValue();
-      if (BaseShAmt && BaseShAmt.isUndef())
-        BaseShAmt = SDValue();
-    } else {
-      if (Amt.getOpcode() == ISD::EXTRACT_SUBVECTOR)
-        Amt = Amt.getOperand(0);
-
-      ShuffleVectorSDNode *SVN = dyn_cast<ShuffleVectorSDNode>(Amt);
-      if (SVN && SVN->isSplat()) {
-        unsigned SplatIdx = (unsigned)SVN->getSplatIndex();
-        SDValue InVec = Amt.getOperand(0);
-        if (InVec.getOpcode() == ISD::BUILD_VECTOR) {
-          assert((SplatIdx < InVec.getSimpleValueType().getVectorNumElements()) &&
-                 "Unexpected shuffle index found!");
-          BaseShAmt = InVec.getOperand(SplatIdx);
-        } else if (InVec.getOpcode() == ISD::INSERT_VECTOR_ELT) {
-           if (ConstantSDNode *C =
-               dyn_cast<ConstantSDNode>(InVec.getOperand(2))) {
-             if (C->getZExtValue() == SplatIdx)
-               BaseShAmt = InVec.getOperand(1);
-           }
-        }
-
-        if (!BaseShAmt)
-          // Avoid introducing an extract element from a shuffle.
-          BaseShAmt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, InVec,
-                                  DAG.getIntPtrConstant(SplatIdx, dl));
-      }
-    }
+  // Peek through any EXTRACT_SUBVECTORs.
+  while (Amt.getOpcode() == ISD::EXTRACT_SUBVECTOR)
+    Amt = Amt.getOperand(0);
 
-    if (BaseShAmt.getNode()) {
+  if (SupportedVectorShiftWithBaseAmnt(VT, Subtarget, Op.getOpcode())) {
+    if (SDValue BaseShAmt = IsSplatValue(Amt, dl, DAG)) {
+      MVT EltVT = VT.getVectorElementType();
       assert(EltVT.bitsLE(MVT::i64) && "Unexpected element type!");
       if (EltVT != MVT::i64 && EltVT.bitsGT(MVT::i32))
         BaseShAmt = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i64, BaseShAmt);