From a705e0edef18f8b292dbf26c88df26c10cad5294 Mon Sep 17 00:00:00 2001 From: Peter Collingbourne Date: Fri, 2 Dec 2016 03:20:58 +0000 Subject: [PATCH] IR: Move NumElements field from {Array,Vector}Type to SequentialType. Now that PointerType is no longer a SequentialType, all SequentialTypes have an associated number of elements, so we can move that information to the base class, allowing for a number of simplifications. Differential Revision: https://reviews.llvm.org/D27122 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@288464 91177308-0d34-0410-b5e6-96231b3b80d8 --- docs/ProgrammersManual.rst | 6 +++--- include/llvm/IR/DerivedTypes.h | 20 +++++++------------- include/llvm/IR/GetElementPtrTypeIterator.h | 9 +++------ lib/IR/ConstantFold.cpp | 9 ++------- lib/IR/Constants.cpp | 6 ++---- lib/IR/Type.cpp | 8 ++------ lib/Linker/IRMover.cpp | 8 +++----- lib/Transforms/IPO/GlobalOpt.cpp | 14 ++------------ lib/Transforms/Scalar/SROA.cpp | 9 ++------- lib/Transforms/Utils/FunctionComparator.cpp | 19 +++++++------------ 10 files changed, 33 insertions(+), 75 deletions(-) diff --git a/docs/ProgrammersManual.rst b/docs/ProgrammersManual.rst index 9375b625518..5ae54ffdff4 100644 --- a/docs/ProgrammersManual.rst +++ b/docs/ProgrammersManual.rst @@ -3283,13 +3283,13 @@ Important Derived Types * ``const Type * getElementType() const``: Returns the type of each of the elements in the sequential type. + * ``uint64_t getNumElements() const``: Returns the number of elements + in the sequential type. + ``ArrayType`` This is a subclass of SequentialType and defines the interface for array types. - * ``unsigned getNumElements() const``: Returns the number of elements - in the array. - ``PointerType`` Subclass of Type for pointer types. diff --git a/include/llvm/IR/DerivedTypes.h b/include/llvm/IR/DerivedTypes.h index 79820db051a..8892d3c244d 100644 --- a/include/llvm/IR/DerivedTypes.h +++ b/include/llvm/IR/DerivedTypes.h @@ -313,18 +313,21 @@ Type *Type::getStructElementType(unsigned N) const { /// identically. class SequentialType : public CompositeType { Type *ContainedType; ///< Storage for the single contained type. + uint64_t NumElements; SequentialType(const SequentialType &) = delete; const SequentialType &operator=(const SequentialType &) = delete; protected: - SequentialType(TypeID TID, Type *ElType) - : CompositeType(ElType->getContext(), TID), ContainedType(ElType) { + SequentialType(TypeID TID, Type *ElType, uint64_t NumElements) + : CompositeType(ElType->getContext(), TID), ContainedType(ElType), + NumElements(NumElements) { ContainedTys = &ContainedType; NumContainedTys = 1; } public: - Type *getElementType() const { return getSequentialElementType(); } + uint64_t getNumElements() const { return NumElements; } + Type *getElementType() const { return ContainedType; } /// Methods for support type inquiry through isa, cast, and dyn_cast. static inline bool classof(const Type *T) { @@ -334,8 +337,6 @@ public: /// Class to represent array types. class ArrayType : public SequentialType { - uint64_t NumElements; - ArrayType(const ArrayType &) = delete; const ArrayType &operator=(const ArrayType &) = delete; ArrayType(Type *ElType, uint64_t NumEl); @@ -347,8 +348,6 @@ public: /// Return true if the specified type is valid as a element type. static bool isValidElementType(Type *ElemTy); - uint64_t getNumElements() const { return NumElements; } - /// Methods for support type inquiry through isa, cast, and dyn_cast. static inline bool classof(const Type *T) { return T->getTypeID() == ArrayTyID; @@ -361,8 +360,6 @@ uint64_t Type::getArrayNumElements() const { /// Class to represent vector types. class VectorType : public SequentialType { - unsigned NumElements; - VectorType(const VectorType &) = delete; const VectorType &operator=(const VectorType &) = delete; VectorType(Type *ElType, unsigned NumEl); @@ -418,13 +415,10 @@ public: /// Return true if the specified type is valid as a element type. static bool isValidElementType(Type *ElemTy); - /// Return the number of elements in the Vector type. - unsigned getNumElements() const { return NumElements; } - /// Return the number of bits in the Vector type. /// Returns zero when the vector is a vector of pointers. unsigned getBitWidth() const { - return NumElements * getElementType()->getPrimitiveSizeInBits(); + return getNumElements() * getElementType()->getPrimitiveSizeInBits(); } /// Methods for support type inquiry through isa, cast, and dyn_cast. diff --git a/include/llvm/IR/GetElementPtrTypeIterator.h b/include/llvm/IR/GetElementPtrTypeIterator.h index d9904c529a6..75caee05b51 100644 --- a/include/llvm/IR/GetElementPtrTypeIterator.h +++ b/include/llvm/IR/GetElementPtrTypeIterator.h @@ -74,12 +74,9 @@ namespace llvm { generic_gep_type_iterator& operator++() { // Preincrement Type *Ty = getIndexedType(); - if (auto *ATy = dyn_cast(Ty)) { - CurTy = ATy->getElementType(); - NumElements = ATy->getNumElements(); - } else if (auto *VTy = dyn_cast(Ty)) { - CurTy = VTy->getElementType(); - NumElements = VTy->getNumElements(); + if (auto *STy = dyn_cast(Ty)) { + CurTy = STy->getElementType(); + NumElements = STy->getNumElements(); } else CurTy = dyn_cast(Ty); ++OpIt; diff --git a/lib/IR/ConstantFold.cpp b/lib/IR/ConstantFold.cpp index 6360b4503a4..e14ea78c27f 100644 --- a/lib/IR/ConstantFold.cpp +++ b/lib/IR/ConstantFold.cpp @@ -891,10 +891,8 @@ Constant *llvm::ConstantFoldInsertValueInstruction(Constant *Agg, unsigned NumElts; if (StructType *ST = dyn_cast(Agg->getType())) NumElts = ST->getNumElements(); - else if (ArrayType *AT = dyn_cast(Agg->getType())) - NumElts = AT->getNumElements(); else - NumElts = Agg->getType()->getVectorNumElements(); + NumElts = cast(Agg->getType())->getNumElements(); SmallVector Result; for (unsigned i = 0; i != NumElts; ++i) { @@ -2210,10 +2208,7 @@ Constant *llvm::ConstantFoldGetElementPtr(Type *PointeeTy, Constant *C, Unknown = true; continue; } - if (isIndexInRangeOfArrayType(isa(STy) - ? cast(STy)->getNumElements() - : cast(STy)->getNumElements(), - CI)) + if (isIndexInRangeOfArrayType(STy->getNumElements(), CI)) // It's in range, skip to the next index. continue; if (isa(Prev)) { diff --git a/lib/IR/Constants.cpp b/lib/IR/Constants.cpp index b6af6ed111a..6a6820234a0 100644 --- a/lib/IR/Constants.cpp +++ b/lib/IR/Constants.cpp @@ -794,10 +794,8 @@ UndefValue *UndefValue::getElementValue(unsigned Idx) const { unsigned UndefValue::getNumElements() const { Type *Ty = getType(); - if (auto *AT = dyn_cast(Ty)) - return AT->getNumElements(); - if (auto *VT = dyn_cast(Ty)) - return VT->getNumElements(); + if (auto *ST = dyn_cast(Ty)) + return ST->getNumElements(); return Ty->getStructNumElements(); } diff --git a/lib/IR/Type.cpp b/lib/IR/Type.cpp index 291d993ee37..ca866738f88 100644 --- a/lib/IR/Type.cpp +++ b/lib/IR/Type.cpp @@ -601,9 +601,7 @@ bool CompositeType::indexValid(unsigned Idx) const { //===----------------------------------------------------------------------===// ArrayType::ArrayType(Type *ElType, uint64_t NumEl) - : SequentialType(ArrayTyID, ElType) { - NumElements = NumEl; -} + : SequentialType(ArrayTyID, ElType, NumEl) {} ArrayType *ArrayType::get(Type *ElementType, uint64_t NumElements) { assert(isValidElementType(ElementType) && "Invalid type for array element!"); @@ -628,9 +626,7 @@ bool ArrayType::isValidElementType(Type *ElemTy) { //===----------------------------------------------------------------------===// VectorType::VectorType(Type *ElType, unsigned NumEl) - : SequentialType(VectorTyID, ElType) { - NumElements = NumEl; -} + : SequentialType(VectorTyID, ElType, NumEl) {} VectorType *VectorType::get(Type *ElementType, unsigned NumElements) { assert(NumElements > 0 && "#Elements of a VectorType must be greater than 0"); diff --git a/lib/Linker/IRMover.cpp b/lib/Linker/IRMover.cpp index ca91b1e8316..8a2aac3f74b 100644 --- a/lib/Linker/IRMover.cpp +++ b/lib/Linker/IRMover.cpp @@ -169,11 +169,9 @@ bool TypeMapTy::areTypesIsomorphic(Type *DstTy, Type *SrcTy) { if (DSTy->isLiteral() != SSTy->isLiteral() || DSTy->isPacked() != SSTy->isPacked()) return false; - } else if (ArrayType *DATy = dyn_cast(DstTy)) { - if (DATy->getNumElements() != cast(SrcTy)->getNumElements()) - return false; - } else if (VectorType *DVTy = dyn_cast(DstTy)) { - if (DVTy->getNumElements() != cast(SrcTy)->getNumElements()) + } else if (auto *DSeqTy = dyn_cast(DstTy)) { + if (DSeqTy->getNumElements() != + cast(SrcTy)->getNumElements()) return false; } diff --git a/lib/Transforms/IPO/GlobalOpt.cpp b/lib/Transforms/IPO/GlobalOpt.cpp index 1df9ee7a94f..5b0d5e3bc01 100644 --- a/lib/Transforms/IPO/GlobalOpt.cpp +++ b/lib/Transforms/IPO/GlobalOpt.cpp @@ -467,12 +467,7 @@ static GlobalVariable *SRAGlobal(GlobalVariable *GV, const DataLayout &DL) { NGV->setAlignment(NewAlign); } } else if (SequentialType *STy = dyn_cast(Ty)) { - unsigned NumElements = 0; - if (ArrayType *ATy = dyn_cast(STy)) - NumElements = ATy->getNumElements(); - else - NumElements = cast(STy)->getNumElements(); - + unsigned NumElements = STy->getNumElements(); if (NumElements > 16 && GV->hasNUsesOrMore(16)) return nullptr; // It's not worth it. NewGlobals.reserve(NumElements); @@ -2119,12 +2114,7 @@ static Constant *EvaluateStoreInto(Constant *Init, Constant *Val, ConstantInt *CI = cast(Addr->getOperand(OpNo)); SequentialType *InitTy = cast(Init->getType()); - - uint64_t NumElts; - if (ArrayType *ATy = dyn_cast(InitTy)) - NumElts = ATy->getNumElements(); - else - NumElts = InitTy->getVectorNumElements(); + uint64_t NumElts = InitTy->getNumElements(); // Break up the array into elements. for (uint64_t i = 0, e = NumElts; i != e; ++i) diff --git a/lib/Transforms/Scalar/SROA.cpp b/lib/Transforms/Scalar/SROA.cpp index 258e77e9260..1f9d08528ef 100644 --- a/lib/Transforms/Scalar/SROA.cpp +++ b/lib/Transforms/Scalar/SROA.cpp @@ -3222,13 +3222,8 @@ static Type *getTypePartition(const DataLayout &DL, Type *Ty, uint64_t Offset, Type *ElementTy = SeqTy->getElementType(); uint64_t ElementSize = DL.getTypeAllocSize(ElementTy); uint64_t NumSkippedElements = Offset / ElementSize; - if (ArrayType *ArrTy = dyn_cast(SeqTy)) { - if (NumSkippedElements >= ArrTy->getNumElements()) - return nullptr; - } else if (VectorType *VecTy = dyn_cast(SeqTy)) { - if (NumSkippedElements >= VecTy->getNumElements()) - return nullptr; - } + if (NumSkippedElements >= SeqTy->getNumElements()) + return nullptr; Offset -= NumSkippedElements * ElementSize; // First check if we need to recurse. diff --git a/lib/Transforms/Utils/FunctionComparator.cpp b/lib/Transforms/Utils/FunctionComparator.cpp index 1cb75b49c01..81a7c4ceffa 100644 --- a/lib/Transforms/Utils/FunctionComparator.cpp +++ b/lib/Transforms/Utils/FunctionComparator.cpp @@ -387,12 +387,6 @@ int FunctionComparator::cmpTypes(Type *TyL, Type *TyR) const { case Type::IntegerTyID: return cmpNumbers(cast(TyL)->getBitWidth(), cast(TyR)->getBitWidth()); - case Type::VectorTyID: { - VectorType *VTyL = cast(TyL), *VTyR = cast(TyR); - if (int Res = cmpNumbers(VTyL->getNumElements(), VTyR->getNumElements())) - return Res; - return cmpTypes(VTyL->getElementType(), VTyR->getElementType()); - } // TyL == TyR would have returned true earlier, because types are uniqued. case Type::VoidTyID: case Type::FloatTyID: @@ -445,12 +439,13 @@ int FunctionComparator::cmpTypes(Type *TyL, Type *TyR) const { return 0; } - case Type::ArrayTyID: { - ArrayType *ATyL = cast(TyL); - ArrayType *ATyR = cast(TyR); - if (ATyL->getNumElements() != ATyR->getNumElements()) - return cmpNumbers(ATyL->getNumElements(), ATyR->getNumElements()); - return cmpTypes(ATyL->getElementType(), ATyR->getElementType()); + case Type::ArrayTyID: + case Type::VectorTyID: { + auto *STyL = cast(TyL); + auto *STyR = cast(TyR); + if (STyL->getNumElements() != STyR->getNumElements()) + return cmpNumbers(STyL->getNumElements(), STyR->getNumElements()); + return cmpTypes(STyL->getElementType(), STyR->getElementType()); } } } -- 2.11.0