OSDN Git Service

MemCmpExpansion::getCompareLoadPairs - assert we find a comparison diff. NFCI.
[android-x86/external-llvm.git] / lib / CodeGen / ExpandMemCmp.cpp
1 //===--- ExpandMemCmp.cpp - Expand memcmp() to load/stores ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass tries to expand memcmp() calls into optimally-sized loads and
10 // compares for the target.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "llvm/ADT/Statistic.h"
15 #include "llvm/Analysis/ConstantFolding.h"
16 #include "llvm/Analysis/TargetLibraryInfo.h"
17 #include "llvm/Analysis/TargetTransformInfo.h"
18 #include "llvm/Analysis/ValueTracking.h"
19 #include "llvm/CodeGen/TargetLowering.h"
20 #include "llvm/CodeGen/TargetPassConfig.h"
21 #include "llvm/CodeGen/TargetSubtargetInfo.h"
22 #include "llvm/IR/IRBuilder.h"
23
24 using namespace llvm;
25
26 #define DEBUG_TYPE "expandmemcmp"
27
28 STATISTIC(NumMemCmpCalls, "Number of memcmp calls");
29 STATISTIC(NumMemCmpNotConstant, "Number of memcmp calls without constant size");
30 STATISTIC(NumMemCmpGreaterThanMax,
31           "Number of memcmp calls with size greater than max size");
32 STATISTIC(NumMemCmpInlined, "Number of inlined memcmp calls");
33
34 static cl::opt<unsigned> MemCmpEqZeroNumLoadsPerBlock(
35     "memcmp-num-loads-per-block", cl::Hidden, cl::init(1),
36     cl::desc("The number of loads per basic block for inline expansion of "
37              "memcmp that is only being compared against zero."));
38
39 static cl::opt<unsigned> MaxLoadsPerMemcmp(
40     "max-loads-per-memcmp", cl::Hidden,
41     cl::desc("Set maximum number of loads used in expanded memcmp"));
42
43 static cl::opt<unsigned> MaxLoadsPerMemcmpOptSize(
44     "max-loads-per-memcmp-opt-size", cl::Hidden,
45     cl::desc("Set maximum number of loads used in expanded memcmp for -Os/Oz"));
46
47 namespace {
48
49
50 // This class provides helper functions to expand a memcmp library call into an
51 // inline expansion.
52 class MemCmpExpansion {
53   struct ResultBlock {
54     BasicBlock *BB = nullptr;
55     PHINode *PhiSrc1 = nullptr;
56     PHINode *PhiSrc2 = nullptr;
57
58     ResultBlock() = default;
59   };
60
61   CallInst *const CI;
62   ResultBlock ResBlock;
63   const uint64_t Size;
64   unsigned MaxLoadSize;
65   uint64_t NumLoadsNonOneByte;
66   const uint64_t NumLoadsPerBlockForZeroCmp;
67   std::vector<BasicBlock *> LoadCmpBlocks;
68   BasicBlock *EndBlock;
69   PHINode *PhiRes;
70   const bool IsUsedForZeroCmp;
71   const DataLayout &DL;
72   IRBuilder<> Builder;
73   // Represents the decomposition in blocks of the expansion. For example,
74   // comparing 33 bytes on X86+sse can be done with 2x16-byte loads and
75   // 1x1-byte load, which would be represented as [{16, 0}, {16, 16}, {32, 1}.
76   struct LoadEntry {
77     LoadEntry(unsigned LoadSize, uint64_t Offset)
78         : LoadSize(LoadSize), Offset(Offset) {
79     }
80
81     // The size of the load for this block, in bytes.
82     unsigned LoadSize;
83     // The offset of this load from the base pointer, in bytes.
84     uint64_t Offset;
85   };
86   using LoadEntryVector = SmallVector<LoadEntry, 8>;
87   LoadEntryVector LoadSequence;
88
89   void createLoadCmpBlocks();
90   void createResultBlock();
91   void setupResultBlockPHINodes();
92   void setupEndBlockPHINodes();
93   Value *getCompareLoadPairs(unsigned BlockIndex, unsigned &LoadIndex);
94   void emitLoadCompareBlock(unsigned BlockIndex);
95   void emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
96                                          unsigned &LoadIndex);
97   void emitLoadCompareByteBlock(unsigned BlockIndex, unsigned OffsetBytes);
98   void emitMemCmpResultBlock();
99   Value *getMemCmpExpansionZeroCase();
100   Value *getMemCmpEqZeroOneBlock();
101   Value *getMemCmpOneBlock();
102   Value *getPtrToElementAtOffset(Value *Source, Type *LoadSizeType,
103                                  uint64_t OffsetBytes);
104
105   static LoadEntryVector
106   computeGreedyLoadSequence(uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes,
107                             unsigned MaxNumLoads, unsigned &NumLoadsNonOneByte);
108   static LoadEntryVector
109   computeOverlappingLoadSequence(uint64_t Size, unsigned MaxLoadSize,
110                                  unsigned MaxNumLoads,
111                                  unsigned &NumLoadsNonOneByte);
112
113 public:
114   MemCmpExpansion(CallInst *CI, uint64_t Size,
115                   const TargetTransformInfo::MemCmpExpansionOptions &Options,
116                   unsigned MaxNumLoads, const bool IsUsedForZeroCmp,
117                   unsigned MaxLoadsPerBlockForZeroCmp, const DataLayout &TheDataLayout);
118
119   unsigned getNumBlocks();
120   uint64_t getNumLoads() const { return LoadSequence.size(); }
121
122   Value *getMemCmpExpansion();
123 };
124
125 MemCmpExpansion::LoadEntryVector MemCmpExpansion::computeGreedyLoadSequence(
126     uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes,
127     const unsigned MaxNumLoads, unsigned &NumLoadsNonOneByte) {
128   NumLoadsNonOneByte = 0;
129   LoadEntryVector LoadSequence;
130   uint64_t Offset = 0;
131   while (Size && !LoadSizes.empty()) {
132     const unsigned LoadSize = LoadSizes.front();
133     const uint64_t NumLoadsForThisSize = Size / LoadSize;
134     if (LoadSequence.size() + NumLoadsForThisSize > MaxNumLoads) {
135       // Do not expand if the total number of loads is larger than what the
136       // target allows. Note that it's important that we exit before completing
137       // the expansion to avoid using a ton of memory to store the expansion for
138       // large sizes.
139       return {};
140     }
141     if (NumLoadsForThisSize > 0) {
142       for (uint64_t I = 0; I < NumLoadsForThisSize; ++I) {
143         LoadSequence.push_back({LoadSize, Offset});
144         Offset += LoadSize;
145       }
146       if (LoadSize > 1)
147         ++NumLoadsNonOneByte;
148       Size = Size % LoadSize;
149     }
150     LoadSizes = LoadSizes.drop_front();
151   }
152   return LoadSequence;
153 }
154
155 MemCmpExpansion::LoadEntryVector
156 MemCmpExpansion::computeOverlappingLoadSequence(uint64_t Size,
157                                                 const unsigned MaxLoadSize,
158                                                 const unsigned MaxNumLoads,
159                                                 unsigned &NumLoadsNonOneByte) {
160   // These are already handled by the greedy approach.
161   if (Size < 2 || MaxLoadSize < 2)
162     return {};
163
164   // We try to do as many non-overlapping loads as possible starting from the
165   // beginning.
166   const uint64_t NumNonOverlappingLoads = Size / MaxLoadSize;
167   assert(NumNonOverlappingLoads && "there must be at least one load");
168   // There remain 0 to (MaxLoadSize - 1) bytes to load, this will be done with
169   // an overlapping load.
170   Size = Size - NumNonOverlappingLoads * MaxLoadSize;
171   // Bail if we do not need an overloapping store, this is already handled by
172   // the greedy approach.
173   if (Size == 0)
174     return {};
175   // Bail if the number of loads (non-overlapping + potential overlapping one)
176   // is larger than the max allowed.
177   if ((NumNonOverlappingLoads + 1) > MaxNumLoads)
178     return {};
179
180   // Add non-overlapping loads.
181   LoadEntryVector LoadSequence;
182   uint64_t Offset = 0;
183   for (uint64_t I = 0; I < NumNonOverlappingLoads; ++I) {
184     LoadSequence.push_back({MaxLoadSize, Offset});
185     Offset += MaxLoadSize;
186   }
187
188   // Add the last overlapping load.
189   assert(Size > 0 && Size < MaxLoadSize && "broken invariant");
190   LoadSequence.push_back({MaxLoadSize, Offset - (MaxLoadSize - Size)});
191   NumLoadsNonOneByte = 1;
192   return LoadSequence;
193 }
194
195 // Initialize the basic block structure required for expansion of memcmp call
196 // with given maximum load size and memcmp size parameter.
197 // This structure includes:
198 // 1. A list of load compare blocks - LoadCmpBlocks.
199 // 2. An EndBlock, split from original instruction point, which is the block to
200 // return from.
201 // 3. ResultBlock, block to branch to for early exit when a
202 // LoadCmpBlock finds a difference.
203 MemCmpExpansion::MemCmpExpansion(
204     CallInst *const CI, uint64_t Size,
205     const TargetTransformInfo::MemCmpExpansionOptions &Options,
206     const unsigned MaxNumLoads, const bool IsUsedForZeroCmp,
207     const unsigned MaxLoadsPerBlockForZeroCmp, const DataLayout &TheDataLayout)
208     : CI(CI),
209       Size(Size),
210       MaxLoadSize(0),
211       NumLoadsNonOneByte(0),
212       NumLoadsPerBlockForZeroCmp(MaxLoadsPerBlockForZeroCmp),
213       IsUsedForZeroCmp(IsUsedForZeroCmp),
214       DL(TheDataLayout),
215       Builder(CI) {
216   assert(Size > 0 && "zero blocks");
217   // Scale the max size down if the target can load more bytes than we need.
218   llvm::ArrayRef<unsigned> LoadSizes(Options.LoadSizes);
219   while (!LoadSizes.empty() && LoadSizes.front() > Size) {
220     LoadSizes = LoadSizes.drop_front();
221   }
222   assert(!LoadSizes.empty() && "cannot load Size bytes");
223   MaxLoadSize = LoadSizes.front();
224   // Compute the decomposition.
225   unsigned GreedyNumLoadsNonOneByte = 0;
226   LoadSequence = computeGreedyLoadSequence(Size, LoadSizes, MaxNumLoads,
227                                            GreedyNumLoadsNonOneByte);
228   NumLoadsNonOneByte = GreedyNumLoadsNonOneByte;
229   assert(LoadSequence.size() <= MaxNumLoads && "broken invariant");
230   // If we allow overlapping loads and the load sequence is not already optimal,
231   // use overlapping loads.
232   if (Options.AllowOverlappingLoads &&
233       (LoadSequence.empty() || LoadSequence.size() > 2)) {
234     unsigned OverlappingNumLoadsNonOneByte = 0;
235     auto OverlappingLoads = computeOverlappingLoadSequence(
236         Size, MaxLoadSize, MaxNumLoads, OverlappingNumLoadsNonOneByte);
237     if (!OverlappingLoads.empty() &&
238         (LoadSequence.empty() ||
239          OverlappingLoads.size() < LoadSequence.size())) {
240       LoadSequence = OverlappingLoads;
241       NumLoadsNonOneByte = OverlappingNumLoadsNonOneByte;
242     }
243   }
244   assert(LoadSequence.size() <= MaxNumLoads && "broken invariant");
245 }
246
247 unsigned MemCmpExpansion::getNumBlocks() {
248   if (IsUsedForZeroCmp)
249     return getNumLoads() / NumLoadsPerBlockForZeroCmp +
250            (getNumLoads() % NumLoadsPerBlockForZeroCmp != 0 ? 1 : 0);
251   return getNumLoads();
252 }
253
254 void MemCmpExpansion::createLoadCmpBlocks() {
255   for (unsigned i = 0; i < getNumBlocks(); i++) {
256     BasicBlock *BB = BasicBlock::Create(CI->getContext(), "loadbb",
257                                         EndBlock->getParent(), EndBlock);
258     LoadCmpBlocks.push_back(BB);
259   }
260 }
261
262 void MemCmpExpansion::createResultBlock() {
263   ResBlock.BB = BasicBlock::Create(CI->getContext(), "res_block",
264                                    EndBlock->getParent(), EndBlock);
265 }
266
267 /// Return a pointer to an element of type `LoadSizeType` at offset
268 /// `OffsetBytes`.
269 Value *MemCmpExpansion::getPtrToElementAtOffset(Value *Source,
270                                                 Type *LoadSizeType,
271                                                 uint64_t OffsetBytes) {
272   if (OffsetBytes > 0) {
273     auto *ByteType = Type::getInt8Ty(CI->getContext());
274     Source = Builder.CreateGEP(
275         ByteType, Builder.CreateBitCast(Source, ByteType->getPointerTo()),
276         ConstantInt::get(ByteType, OffsetBytes));
277   }
278   return Builder.CreateBitCast(Source, LoadSizeType->getPointerTo());
279 }
280
281 // This function creates the IR instructions for loading and comparing 1 byte.
282 // It loads 1 byte from each source of the memcmp parameters with the given
283 // GEPIndex. It then subtracts the two loaded values and adds this result to the
284 // final phi node for selecting the memcmp result.
285 void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex,
286                                                unsigned OffsetBytes) {
287   Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
288   Type *LoadSizeType = Type::getInt8Ty(CI->getContext());
289   Value *Source1 =
290       getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType, OffsetBytes);
291   Value *Source2 =
292       getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType, OffsetBytes);
293
294   Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
295   Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);
296
297   LoadSrc1 = Builder.CreateZExt(LoadSrc1, Type::getInt32Ty(CI->getContext()));
298   LoadSrc2 = Builder.CreateZExt(LoadSrc2, Type::getInt32Ty(CI->getContext()));
299   Value *Diff = Builder.CreateSub(LoadSrc1, LoadSrc2);
300
301   PhiRes->addIncoming(Diff, LoadCmpBlocks[BlockIndex]);
302
303   if (BlockIndex < (LoadCmpBlocks.size() - 1)) {
304     // Early exit branch if difference found to EndBlock. Otherwise, continue to
305     // next LoadCmpBlock,
306     Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_NE, Diff,
307                                     ConstantInt::get(Diff->getType(), 0));
308     BranchInst *CmpBr =
309         BranchInst::Create(EndBlock, LoadCmpBlocks[BlockIndex + 1], Cmp);
310     Builder.Insert(CmpBr);
311   } else {
312     // The last block has an unconditional branch to EndBlock.
313     BranchInst *CmpBr = BranchInst::Create(EndBlock);
314     Builder.Insert(CmpBr);
315   }
316 }
317
318 /// Generate an equality comparison for one or more pairs of loaded values.
319 /// This is used in the case where the memcmp() call is compared equal or not
320 /// equal to zero.
321 Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex,
322                                             unsigned &LoadIndex) {
323   assert(LoadIndex < getNumLoads() &&
324          "getCompareLoadPairs() called with no remaining loads");
325   std::vector<Value *> XorList, OrList;
326   Value *Diff = nullptr;
327
328   const unsigned NumLoads =
329       std::min(getNumLoads() - LoadIndex, NumLoadsPerBlockForZeroCmp);
330
331   // For a single-block expansion, start inserting before the memcmp call.
332   if (LoadCmpBlocks.empty())
333     Builder.SetInsertPoint(CI);
334   else
335     Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
336
337   Value *Cmp = nullptr;
338   // If we have multiple loads per block, we need to generate a composite
339   // comparison using xor+or. The type for the combinations is the largest load
340   // type.
341   IntegerType *const MaxLoadType =
342       NumLoads == 1 ? nullptr
343                     : IntegerType::get(CI->getContext(), MaxLoadSize * 8);
344   for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) {
345     const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex];
346
347     IntegerType *LoadSizeType =
348         IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
349
350     Value *Source1 = getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType,
351                                              CurLoadEntry.Offset);
352     Value *Source2 = getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType,
353                                              CurLoadEntry.Offset);
354
355     // Get a constant or load a value for each source address.
356     Value *LoadSrc1 = nullptr;
357     if (auto *Source1C = dyn_cast<Constant>(Source1))
358       LoadSrc1 = ConstantFoldLoadFromConstPtr(Source1C, LoadSizeType, DL);
359     if (!LoadSrc1)
360       LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
361
362     Value *LoadSrc2 = nullptr;
363     if (auto *Source2C = dyn_cast<Constant>(Source2))
364       LoadSrc2 = ConstantFoldLoadFromConstPtr(Source2C, LoadSizeType, DL);
365     if (!LoadSrc2)
366       LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);
367
368     if (NumLoads != 1) {
369       if (LoadSizeType != MaxLoadType) {
370         LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
371         LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
372       }
373       // If we have multiple loads per block, we need to generate a composite
374       // comparison using xor+or.
375       Diff = Builder.CreateXor(LoadSrc1, LoadSrc2);
376       Diff = Builder.CreateZExt(Diff, MaxLoadType);
377       XorList.push_back(Diff);
378     } else {
379       // If there's only one load per block, we just compare the loaded values.
380       Cmp = Builder.CreateICmpNE(LoadSrc1, LoadSrc2);
381     }
382   }
383
384   auto pairWiseOr = [&](std::vector<Value *> &InList) -> std::vector<Value *> {
385     std::vector<Value *> OutList;
386     for (unsigned i = 0; i < InList.size() - 1; i = i + 2) {
387       Value *Or = Builder.CreateOr(InList[i], InList[i + 1]);
388       OutList.push_back(Or);
389     }
390     if (InList.size() % 2 != 0)
391       OutList.push_back(InList.back());
392     return OutList;
393   };
394
395   if (!Cmp) {
396     // Pairwise OR the XOR results.
397     OrList = pairWiseOr(XorList);
398
399     // Pairwise OR the OR results until one result left.
400     while (OrList.size() != 1) {
401       OrList = pairWiseOr(OrList);
402     }
403
404     assert(Diff && "Failed to find comparison diff");
405     Cmp = Builder.CreateICmpNE(OrList[0], ConstantInt::get(Diff->getType(), 0));
406   }
407
408   return Cmp;
409 }
410
411 void MemCmpExpansion::emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
412                                                         unsigned &LoadIndex) {
413   Value *Cmp = getCompareLoadPairs(BlockIndex, LoadIndex);
414
415   BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
416                            ? EndBlock
417                            : LoadCmpBlocks[BlockIndex + 1];
418   // Early exit branch if difference found to ResultBlock. Otherwise,
419   // continue to next LoadCmpBlock or EndBlock.
420   BranchInst *CmpBr = BranchInst::Create(ResBlock.BB, NextBB, Cmp);
421   Builder.Insert(CmpBr);
422
423   // Add a phi edge for the last LoadCmpBlock to Endblock with a value of 0
424   // since early exit to ResultBlock was not taken (no difference was found in
425   // any of the bytes).
426   if (BlockIndex == LoadCmpBlocks.size() - 1) {
427     Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
428     PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
429   }
430 }
431
432 // This function creates the IR intructions for loading and comparing using the
433 // given LoadSize. It loads the number of bytes specified by LoadSize from each
434 // source of the memcmp parameters. It then does a subtract to see if there was
435 // a difference in the loaded values. If a difference is found, it branches
436 // with an early exit to the ResultBlock for calculating which source was
437 // larger. Otherwise, it falls through to the either the next LoadCmpBlock or
438 // the EndBlock if this is the last LoadCmpBlock. Loading 1 byte is handled with
439 // a special case through emitLoadCompareByteBlock. The special handling can
440 // simply subtract the loaded values and add it to the result phi node.
441 void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) {
442   // There is one load per block in this case, BlockIndex == LoadIndex.
443   const LoadEntry &CurLoadEntry = LoadSequence[BlockIndex];
444
445   if (CurLoadEntry.LoadSize == 1) {
446     MemCmpExpansion::emitLoadCompareByteBlock(BlockIndex, CurLoadEntry.Offset);
447     return;
448   }
449
450   Type *LoadSizeType =
451       IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
452   Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
453   assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type");
454
455   Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
456
457   Value *Source1 = getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType,
458                                            CurLoadEntry.Offset);
459   Value *Source2 = getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType,
460                                            CurLoadEntry.Offset);
461
462   // Load LoadSizeType from the base address.
463   Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
464   Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);
465
466   if (DL.isLittleEndian()) {
467     Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
468                                                 Intrinsic::bswap, LoadSizeType);
469     LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
470     LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
471   }
472
473   if (LoadSizeType != MaxLoadType) {
474     LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
475     LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
476   }
477
478   // Add the loaded values to the phi nodes for calculating memcmp result only
479   // if result is not used in a zero equality.
480   if (!IsUsedForZeroCmp) {
481     ResBlock.PhiSrc1->addIncoming(LoadSrc1, LoadCmpBlocks[BlockIndex]);
482     ResBlock.PhiSrc2->addIncoming(LoadSrc2, LoadCmpBlocks[BlockIndex]);
483   }
484
485   Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, LoadSrc1, LoadSrc2);
486   BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
487                            ? EndBlock
488                            : LoadCmpBlocks[BlockIndex + 1];
489   // Early exit branch if difference found to ResultBlock. Otherwise, continue
490   // to next LoadCmpBlock or EndBlock.
491   BranchInst *CmpBr = BranchInst::Create(NextBB, ResBlock.BB, Cmp);
492   Builder.Insert(CmpBr);
493
494   // Add a phi edge for the last LoadCmpBlock to Endblock with a value of 0
495   // since early exit to ResultBlock was not taken (no difference was found in
496   // any of the bytes).
497   if (BlockIndex == LoadCmpBlocks.size() - 1) {
498     Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
499     PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
500   }
501 }
502
503 // This function populates the ResultBlock with a sequence to calculate the
504 // memcmp result. It compares the two loaded source values and returns -1 if
505 // src1 < src2 and 1 if src1 > src2.
506 void MemCmpExpansion::emitMemCmpResultBlock() {
507   // Special case: if memcmp result is used in a zero equality, result does not
508   // need to be calculated and can simply return 1.
509   if (IsUsedForZeroCmp) {
510     BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
511     Builder.SetInsertPoint(ResBlock.BB, InsertPt);
512     Value *Res = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 1);
513     PhiRes->addIncoming(Res, ResBlock.BB);
514     BranchInst *NewBr = BranchInst::Create(EndBlock);
515     Builder.Insert(NewBr);
516     return;
517   }
518   BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
519   Builder.SetInsertPoint(ResBlock.BB, InsertPt);
520
521   Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_ULT, ResBlock.PhiSrc1,
522                                   ResBlock.PhiSrc2);
523
524   Value *Res =
525       Builder.CreateSelect(Cmp, ConstantInt::get(Builder.getInt32Ty(), -1),
526                            ConstantInt::get(Builder.getInt32Ty(), 1));
527
528   BranchInst *NewBr = BranchInst::Create(EndBlock);
529   Builder.Insert(NewBr);
530   PhiRes->addIncoming(Res, ResBlock.BB);
531 }
532
533 void MemCmpExpansion::setupResultBlockPHINodes() {
534   Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
535   Builder.SetInsertPoint(ResBlock.BB);
536   // Note: this assumes one load per block.
537   ResBlock.PhiSrc1 =
538       Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src1");
539   ResBlock.PhiSrc2 =
540       Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src2");
541 }
542
543 void MemCmpExpansion::setupEndBlockPHINodes() {
544   Builder.SetInsertPoint(&EndBlock->front());
545   PhiRes = Builder.CreatePHI(Type::getInt32Ty(CI->getContext()), 2, "phi.res");
546 }
547
548 Value *MemCmpExpansion::getMemCmpExpansionZeroCase() {
549   unsigned LoadIndex = 0;
550   // This loop populates each of the LoadCmpBlocks with the IR sequence to
551   // handle multiple loads per block.
552   for (unsigned I = 0; I < getNumBlocks(); ++I) {
553     emitLoadCompareBlockMultipleLoads(I, LoadIndex);
554   }
555
556   emitMemCmpResultBlock();
557   return PhiRes;
558 }
559
560 /// A memcmp expansion that compares equality with 0 and only has one block of
561 /// load and compare can bypass the compare, branch, and phi IR that is required
562 /// in the general case.
563 Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() {
564   unsigned LoadIndex = 0;
565   Value *Cmp = getCompareLoadPairs(0, LoadIndex);
566   assert(LoadIndex == getNumLoads() && "some entries were not consumed");
567   return Builder.CreateZExt(Cmp, Type::getInt32Ty(CI->getContext()));
568 }
569
570 /// A memcmp expansion that only has one block of load and compare can bypass
571 /// the compare, branch, and phi IR that is required in the general case.
572 Value *MemCmpExpansion::getMemCmpOneBlock() {
573   Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8);
574   Value *Source1 = CI->getArgOperand(0);
575   Value *Source2 = CI->getArgOperand(1);
576
577   // Cast source to LoadSizeType*.
578   if (Source1->getType() != LoadSizeType)
579     Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
580   if (Source2->getType() != LoadSizeType)
581     Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());
582
583   // Load LoadSizeType from the base address.
584   Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
585   Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);
586
587   if (DL.isLittleEndian() && Size != 1) {
588     Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
589                                                 Intrinsic::bswap, LoadSizeType);
590     LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
591     LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
592   }
593
594   if (Size < 4) {
595     // The i8 and i16 cases don't need compares. We zext the loaded values and
596     // subtract them to get the suitable negative, zero, or positive i32 result.
597     LoadSrc1 = Builder.CreateZExt(LoadSrc1, Builder.getInt32Ty());
598     LoadSrc2 = Builder.CreateZExt(LoadSrc2, Builder.getInt32Ty());
599     return Builder.CreateSub(LoadSrc1, LoadSrc2);
600   }
601
602   // The result of memcmp is negative, zero, or positive, so produce that by
603   // subtracting 2 extended compare bits: sub (ugt, ult).
604   // If a target prefers to use selects to get -1/0/1, they should be able
605   // to transform this later. The inverse transform (going from selects to math)
606   // may not be possible in the DAG because the selects got converted into
607   // branches before we got there.
608   Value *CmpUGT = Builder.CreateICmpUGT(LoadSrc1, LoadSrc2);
609   Value *CmpULT = Builder.CreateICmpULT(LoadSrc1, LoadSrc2);
610   Value *ZextUGT = Builder.CreateZExt(CmpUGT, Builder.getInt32Ty());
611   Value *ZextULT = Builder.CreateZExt(CmpULT, Builder.getInt32Ty());
612   return Builder.CreateSub(ZextUGT, ZextULT);
613 }
614
615 // This function expands the memcmp call into an inline expansion and returns
616 // the memcmp result.
617 Value *MemCmpExpansion::getMemCmpExpansion() {
618   // Create the basic block framework for a multi-block expansion.
619   if (getNumBlocks() != 1) {
620     BasicBlock *StartBlock = CI->getParent();
621     EndBlock = StartBlock->splitBasicBlock(CI, "endblock");
622     setupEndBlockPHINodes();
623     createResultBlock();
624
625     // If return value of memcmp is not used in a zero equality, we need to
626     // calculate which source was larger. The calculation requires the
627     // two loaded source values of each load compare block.
628     // These will be saved in the phi nodes created by setupResultBlockPHINodes.
629     if (!IsUsedForZeroCmp) setupResultBlockPHINodes();
630
631     // Create the number of required load compare basic blocks.
632     createLoadCmpBlocks();
633
634     // Update the terminator added by splitBasicBlock to branch to the first
635     // LoadCmpBlock.
636     StartBlock->getTerminator()->setSuccessor(0, LoadCmpBlocks[0]);
637   }
638
639   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
640
641   if (IsUsedForZeroCmp)
642     return getNumBlocks() == 1 ? getMemCmpEqZeroOneBlock()
643                                : getMemCmpExpansionZeroCase();
644
645   if (getNumBlocks() == 1)
646     return getMemCmpOneBlock();
647
648   for (unsigned I = 0; I < getNumBlocks(); ++I) {
649     emitLoadCompareBlock(I);
650   }
651
652   emitMemCmpResultBlock();
653   return PhiRes;
654 }
655
656 // This function checks to see if an expansion of memcmp can be generated.
657 // It checks for constant compare size that is less than the max inline size.
658 // If an expansion cannot occur, returns false to leave as a library call.
659 // Otherwise, the library call is replaced with a new IR instruction sequence.
660 /// We want to transform:
661 /// %call = call signext i32 @memcmp(i8* %0, i8* %1, i64 15)
662 /// To:
663 /// loadbb:
664 ///  %0 = bitcast i32* %buffer2 to i8*
665 ///  %1 = bitcast i32* %buffer1 to i8*
666 ///  %2 = bitcast i8* %1 to i64*
667 ///  %3 = bitcast i8* %0 to i64*
668 ///  %4 = load i64, i64* %2
669 ///  %5 = load i64, i64* %3
670 ///  %6 = call i64 @llvm.bswap.i64(i64 %4)
671 ///  %7 = call i64 @llvm.bswap.i64(i64 %5)
672 ///  %8 = sub i64 %6, %7
673 ///  %9 = icmp ne i64 %8, 0
674 ///  br i1 %9, label %res_block, label %loadbb1
675 /// res_block:                                        ; preds = %loadbb2,
676 /// %loadbb1, %loadbb
677 ///  %phi.src1 = phi i64 [ %6, %loadbb ], [ %22, %loadbb1 ], [ %36, %loadbb2 ]
678 ///  %phi.src2 = phi i64 [ %7, %loadbb ], [ %23, %loadbb1 ], [ %37, %loadbb2 ]
679 ///  %10 = icmp ult i64 %phi.src1, %phi.src2
680 ///  %11 = select i1 %10, i32 -1, i32 1
681 ///  br label %endblock
682 /// loadbb1:                                          ; preds = %loadbb
683 ///  %12 = bitcast i32* %buffer2 to i8*
684 ///  %13 = bitcast i32* %buffer1 to i8*
685 ///  %14 = bitcast i8* %13 to i32*
686 ///  %15 = bitcast i8* %12 to i32*
687 ///  %16 = getelementptr i32, i32* %14, i32 2
688 ///  %17 = getelementptr i32, i32* %15, i32 2
689 ///  %18 = load i32, i32* %16
690 ///  %19 = load i32, i32* %17
691 ///  %20 = call i32 @llvm.bswap.i32(i32 %18)
692 ///  %21 = call i32 @llvm.bswap.i32(i32 %19)
693 ///  %22 = zext i32 %20 to i64
694 ///  %23 = zext i32 %21 to i64
695 ///  %24 = sub i64 %22, %23
696 ///  %25 = icmp ne i64 %24, 0
697 ///  br i1 %25, label %res_block, label %loadbb2
698 /// loadbb2:                                          ; preds = %loadbb1
699 ///  %26 = bitcast i32* %buffer2 to i8*
700 ///  %27 = bitcast i32* %buffer1 to i8*
701 ///  %28 = bitcast i8* %27 to i16*
702 ///  %29 = bitcast i8* %26 to i16*
703 ///  %30 = getelementptr i16, i16* %28, i16 6
704 ///  %31 = getelementptr i16, i16* %29, i16 6
705 ///  %32 = load i16, i16* %30
706 ///  %33 = load i16, i16* %31
707 ///  %34 = call i16 @llvm.bswap.i16(i16 %32)
708 ///  %35 = call i16 @llvm.bswap.i16(i16 %33)
709 ///  %36 = zext i16 %34 to i64
710 ///  %37 = zext i16 %35 to i64
711 ///  %38 = sub i64 %36, %37
712 ///  %39 = icmp ne i64 %38, 0
713 ///  br i1 %39, label %res_block, label %loadbb3
714 /// loadbb3:                                          ; preds = %loadbb2
715 ///  %40 = bitcast i32* %buffer2 to i8*
716 ///  %41 = bitcast i32* %buffer1 to i8*
717 ///  %42 = getelementptr i8, i8* %41, i8 14
718 ///  %43 = getelementptr i8, i8* %40, i8 14
719 ///  %44 = load i8, i8* %42
720 ///  %45 = load i8, i8* %43
721 ///  %46 = zext i8 %44 to i32
722 ///  %47 = zext i8 %45 to i32
723 ///  %48 = sub i32 %46, %47
724 ///  br label %endblock
725 /// endblock:                                         ; preds = %res_block,
726 /// %loadbb3
727 ///  %phi.res = phi i32 [ %48, %loadbb3 ], [ %11, %res_block ]
728 ///  ret i32 %phi.res
729 static bool expandMemCmp(CallInst *CI, const TargetTransformInfo *TTI,
730                          const TargetLowering *TLI, const DataLayout *DL) {
731   NumMemCmpCalls++;
732
733   // Early exit from expansion if -Oz.
734   if (CI->getFunction()->hasMinSize())
735     return false;
736
737   // Early exit from expansion if size is not a constant.
738   ConstantInt *SizeCast = dyn_cast<ConstantInt>(CI->getArgOperand(2));
739   if (!SizeCast) {
740     NumMemCmpNotConstant++;
741     return false;
742   }
743   const uint64_t SizeVal = SizeCast->getZExtValue();
744
745   if (SizeVal == 0) {
746     return false;
747   }
748   // TTI call to check if target would like to expand memcmp. Also, get the
749   // available load sizes.
750   const bool IsUsedForZeroCmp = isOnlyUsedInZeroEqualityComparison(CI);
751   const auto *const Options = TTI->enableMemCmpExpansion(IsUsedForZeroCmp);
752   if (!Options) return false;
753
754   const unsigned MaxNumLoads = CI->getFunction()->hasOptSize()
755       ? (MaxLoadsPerMemcmpOptSize.getNumOccurrences()
756          ? MaxLoadsPerMemcmpOptSize
757          : TLI->getMaxExpandSizeMemcmp(true))
758       : (MaxLoadsPerMemcmp.getNumOccurrences()
759          ? MaxLoadsPerMemcmp
760          : TLI->getMaxExpandSizeMemcmp(false));
761
762   unsigned NumLoadsPerBlock = MemCmpEqZeroNumLoadsPerBlock.getNumOccurrences()
763                                   ? MemCmpEqZeroNumLoadsPerBlock
764                                   : TLI->getMemcmpEqZeroLoadsPerBlock();
765
766   MemCmpExpansion Expansion(CI, SizeVal, *Options, MaxNumLoads,
767                             IsUsedForZeroCmp, NumLoadsPerBlock, *DL);
768
769   // Don't expand if this will require more loads than desired by the target.
770   if (Expansion.getNumLoads() == 0) {
771     NumMemCmpGreaterThanMax++;
772     return false;
773   }
774
775   NumMemCmpInlined++;
776
777   Value *Res = Expansion.getMemCmpExpansion();
778
779   // Replace call with result of expansion and erase call.
780   CI->replaceAllUsesWith(Res);
781   CI->eraseFromParent();
782
783   return true;
784 }
785
786
787
788 class ExpandMemCmpPass : public FunctionPass {
789 public:
790   static char ID;
791
792   ExpandMemCmpPass() : FunctionPass(ID) {
793     initializeExpandMemCmpPassPass(*PassRegistry::getPassRegistry());
794   }
795
796   bool runOnFunction(Function &F) override {
797     if (skipFunction(F)) return false;
798
799     auto *TPC = getAnalysisIfAvailable<TargetPassConfig>();
800     if (!TPC) {
801       return false;
802     }
803     const TargetLowering* TL =
804         TPC->getTM<TargetMachine>().getSubtargetImpl(F)->getTargetLowering();
805
806     const TargetLibraryInfo *TLI =
807         &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
808     const TargetTransformInfo *TTI =
809         &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
810     auto PA = runImpl(F, TLI, TTI, TL);
811     return !PA.areAllPreserved();
812   }
813
814 private:
815   void getAnalysisUsage(AnalysisUsage &AU) const override {
816     AU.addRequired<TargetLibraryInfoWrapperPass>();
817     AU.addRequired<TargetTransformInfoWrapperPass>();
818     FunctionPass::getAnalysisUsage(AU);
819   }
820
821   PreservedAnalyses runImpl(Function &F, const TargetLibraryInfo *TLI,
822                             const TargetTransformInfo *TTI,
823                             const TargetLowering* TL);
824   // Returns true if a change was made.
825   bool runOnBlock(BasicBlock &BB, const TargetLibraryInfo *TLI,
826                   const TargetTransformInfo *TTI, const TargetLowering* TL,
827                   const DataLayout& DL);
828 };
829
830 bool ExpandMemCmpPass::runOnBlock(
831     BasicBlock &BB, const TargetLibraryInfo *TLI,
832     const TargetTransformInfo *TTI, const TargetLowering* TL,
833     const DataLayout& DL) {
834   for (Instruction& I : BB) {
835     CallInst *CI = dyn_cast<CallInst>(&I);
836     if (!CI) {
837       continue;
838     }
839     LibFunc Func;
840     if (TLI->getLibFunc(ImmutableCallSite(CI), Func) &&
841         (Func == LibFunc_memcmp || Func == LibFunc_bcmp) &&
842         expandMemCmp(CI, TTI, TL, &DL)) {
843       return true;
844     }
845   }
846   return false;
847 }
848
849
850 PreservedAnalyses ExpandMemCmpPass::runImpl(
851     Function &F, const TargetLibraryInfo *TLI, const TargetTransformInfo *TTI,
852     const TargetLowering* TL) {
853   const DataLayout& DL = F.getParent()->getDataLayout();
854   bool MadeChanges = false;
855   for (auto BBIt = F.begin(); BBIt != F.end();) {
856     if (runOnBlock(*BBIt, TLI, TTI, TL, DL)) {
857       MadeChanges = true;
858       // If changes were made, restart the function from the beginning, since
859       // the structure of the function was changed.
860       BBIt = F.begin();
861     } else {
862       ++BBIt;
863     }
864   }
865   return MadeChanges ? PreservedAnalyses::none() : PreservedAnalyses::all();
866 }
867
868 } // namespace
869
870 char ExpandMemCmpPass::ID = 0;
871 INITIALIZE_PASS_BEGIN(ExpandMemCmpPass, "expandmemcmp",
872                       "Expand memcmp() to load/stores", false, false)
873 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
874 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
875 INITIALIZE_PASS_END(ExpandMemCmpPass, "expandmemcmp",
876                     "Expand memcmp() to load/stores", false, false)
877
878 FunctionPass *llvm::createExpandMemCmpPass() {
879   return new ExpandMemCmpPass();
880 }