From 3492cda48f85ca9048824baff18e24a19edaf7ed Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Sat, 7 Apr 2007 21:04:50 +0000 Subject: [PATCH] Change CastToCStr to take a pointer instead of a reference. Fix some miscompilations in fprintf optimizer. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@35753 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Transforms/IPO/SimplifyLibCalls.cpp | 158 +++++++++++++++----------------- 1 file changed, 76 insertions(+), 82 deletions(-) diff --git a/lib/Transforms/IPO/SimplifyLibCalls.cpp b/lib/Transforms/IPO/SimplifyLibCalls.cpp index 8aed1e8103b..6b55eb67597 100644 --- a/lib/Transforms/IPO/SimplifyLibCalls.cpp +++ b/lib/Transforms/IPO/SimplifyLibCalls.cpp @@ -393,7 +393,7 @@ namespace { // Forward declare utility functions. static bool GetConstantStringInfo(Value *V, ConstantArray *&Array, uint64_t &Length, uint64_t &StartIdx); -static Value *CastToCStr(Value *V, Instruction &IP); +static Value *CastToCStr(Value *V, Instruction *IP); /// This LibCallOptimization will find instances of a call to "exit" that occurs /// within the "main" function and change it to a simple "ret" instruction with @@ -1228,7 +1228,7 @@ public: return false; // printf("%s\n",str) -> puts(str) - new CallInst(SLC.get_puts(), CastToCStr(CI->getOperand(2), *CI), + new CallInst(SLC.get_puts(), CastToCStr(CI->getOperand(2), CI), CI->getName(), CI); return ReplaceCallWith(CI, 0); } @@ -1262,104 +1262,98 @@ public: "Number of 'fprintf' calls simplified") {} /// @brief Make sure that the "fprintf" function has the right prototype - virtual bool ValidateCalledFunction(const Function* f, SimplifyLibCalls& SLC){ - // Just make sure this has at least 2 arguments - return (f->arg_size() >= 2); + virtual bool ValidateCalledFunction(const Function *F, SimplifyLibCalls &SLC){ + const FunctionType *FT = F->getFunctionType(); + return FT->getNumParams() == 2 && // two fixed arguments. + FT->getParamType(1) == PointerType::get(Type::Int8Ty) && + isa(FT->getParamType(0)) && + isa(FT->getReturnType()); } /// @brief Perform the fprintf optimization. - virtual bool OptimizeCall(CallInst* ci, SimplifyLibCalls& SLC) { + virtual bool OptimizeCall(CallInst *CI, SimplifyLibCalls &SLC) { // If the call has more than 3 operands, we can't optimize it - if (ci->getNumOperands() > 4 || ci->getNumOperands() <= 2) - return false; - - // If the result of the fprintf call is used, none of these optimizations - // can be made. - if (!ci->use_empty()) + if (CI->getNumOperands() != 3 && CI->getNumOperands() != 4) return false; // All the optimizations depend on the length of the second argument and the // fact that it is a constant string array. Check that now - uint64_t len, StartIdx; - ConstantArray* CA = 0; - if (!GetConstantStringInfo(ci->getOperand(2), CA, len, StartIdx)) + uint64_t FormatLen, FormatStartIdx; + ConstantArray *CA = 0; + if (!GetConstantStringInfo(CI->getOperand(2), CA, FormatLen,FormatStartIdx)) return false; - if (ci->getNumOperands() == 3) { + // IF fthis is just a format string, turn it into fwrite. + if (CI->getNumOperands() == 3) { + if (!CA->isCString()) return false; + // Make sure there's no % in the constant array - for (unsigned i = 0; i < len; ++i) { - if (ConstantInt* CI = dyn_cast(CA->getOperand(i))) { - // Check for the null terminator - if (CI->getZExtValue() == '%') - return false; // we found end of string - } else { - return false; - } - } + std::string S = CA->getAsString(); - // fprintf(file,fmt) -> fwrite(fmt,strlen(fmt),file) - const Type* FILEptr_type = ci->getOperand(1)->getType(); + for (unsigned i = FormatStartIdx, e = S.size(); i != e; ++i) + if (S[i] == '%') + return false; // we found a format specifier - // Make sure that the fprintf() and fwrite() functions both take the - // same type of char pointer. - if (ci->getOperand(2)->getType() != PointerType::get(Type::Int8Ty)) - return false; + // fprintf(file,fmt) -> fwrite(fmt,strlen(fmt),file) + const Type *FILEty = CI->getOperand(1)->getType(); - Value* args[4] = { - ci->getOperand(2), - ConstantInt::get(SLC.getIntPtrType(),len), - ConstantInt::get(SLC.getIntPtrType(),1), - ci->getOperand(1) + Value *FWriteArgs[] = { + CI->getOperand(2), + ConstantInt::get(SLC.getIntPtrType(), FormatLen), + ConstantInt::get(SLC.getIntPtrType(), 1), + CI->getOperand(1) }; - new CallInst(SLC.get_fwrite(FILEptr_type), args, 4, ci->getName(), ci); - return ReplaceCallWith(ci, ConstantInt::get(Type::Int32Ty,len)); + new CallInst(SLC.get_fwrite(FILEty), FWriteArgs, 4, CI->getName(), CI); + return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), FormatLen)); } - - // The remaining optimizations require the format string to be length 2 + + // The remaining optimizations require the format string to be length 2: // "%s" or "%c". - if (len != 2) + if (FormatLen != 2) return false; - // The first character has to be a % - if (ConstantInt* CI = dyn_cast(CA->getOperand(0))) - if (CI->getZExtValue() != '%') - return false; + // The first character has to be a % for us to handle it. + if (cast(CA->getOperand(FormatStartIdx))->getZExtValue() !='%') + return false; // Get the second character and switch on its value - ConstantInt* CI = dyn_cast(CA->getOperand(1)); - switch (CI->getZExtValue()) { - case 's': { - uint64_t len, StartIdx; - ConstantArray* CA = 0; - if (GetConstantStringInfo(ci->getOperand(3), CA, len, StartIdx)) { - // fprintf(file,"%s",str) -> fwrite(str,strlen(str),1,file) - const Type* FILEptr_type = ci->getOperand(1)->getType(); - Value* args[4] = { - CastToCStr(ci->getOperand(3), *ci), - ConstantInt::get(SLC.getIntPtrType(), len), - ConstantInt::get(SLC.getIntPtrType(), 1), - ci->getOperand(1) - }; - new CallInst(SLC.get_fwrite(FILEptr_type), args, 4,ci->getName(), ci); - return ReplaceCallWith(ci, ConstantInt::get(Type::Int32Ty, len)); - } - // fprintf(file,"%s",str) -> fputs(str,file) - const Type* FILEptr_type = ci->getOperand(1)->getType(); - new CallInst(SLC.get_fputs(FILEptr_type), - CastToCStr(ci->getOperand(3), *ci), - ci->getOperand(1), ci->getName(),ci); - return ReplaceCallWith(ci, ConstantInt::get(Type::Int32Ty,len)); - } - case 'c': { - // fprintf(file,"%c",c) -> fputc(c,file) - const Type* FILEptr_type = ci->getOperand(1)->getType(); - CastInst* cast = CastInst::createSExtOrBitCast( - ci->getOperand(3), Type::Int32Ty, CI->getName()+".int", ci); - new CallInst(SLC.get_fputc(FILEptr_type), cast,ci->getOperand(1),"",ci); - return ReplaceCallWith(ci, ConstantInt::get(Type::Int32Ty,1)); + switch(cast(CA->getOperand(FormatStartIdx+1))->getZExtValue()){ + case 'c': { + // fprintf(file,"%c",c) -> fputc(c,file) + const Type *FILETy = CI->getOperand(1)->getType(); + Value *C = CastInst::createZExtOrBitCast(CI->getOperand(3), Type::Int32Ty, + CI->getName()+".int", CI); + new CallInst(SLC.get_fputc(FILETy), C, CI->getOperand(1), "", CI); + return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), 1)); + } + case 's': { + const Type *FILETy = CI->getOperand(1)->getType(); + uint64_t LitStrLen, LitStartIdx; + ConstantArray *CA = 0; + if (GetConstantStringInfo(CI->getOperand(3), CA, LitStrLen, LitStartIdx)){ + // fprintf(file,"%s",str) -> fwrite(str,strlen(str),1,file) + Value *FWriteArgs[] = { + CastToCStr(CI->getOperand(3), CI), + ConstantInt::get(SLC.getIntPtrType(), LitStrLen), + ConstantInt::get(SLC.getIntPtrType(), 1), + CI->getOperand(1) + }; + new CallInst(SLC.get_fwrite(FILETy), FWriteArgs, 4, CI->getName(), CI); + return ReplaceCallWith(CI, ConstantInt::get(Type::Int32Ty, LitStrLen)); } - default: + + // If the result of the fprintf call is used, we can't do this. + // TODO: we could insert a strlen call. + if (!CI->use_empty()) return false; + + // fprintf(file,"%s",str) -> fputs(str,file) + new CallInst(SLC.get_fputs(FILETy), CastToCStr(CI->getOperand(3), CI), + CI->getOperand(1), CI->getName(), CI); + return ReplaceCallWith(CI, 0); + } + default: + return false; } } } FPrintFOptimizer; @@ -1441,7 +1435,7 @@ public: case 's': { // sprintf(dest,"%s",str) -> llvm.memcpy(dest, str, strlen(str)+1, 1) Value *Len = new CallInst(SLC.get_strlen(), - CastToCStr(ci->getOperand(3), *ci), + CastToCStr(ci->getOperand(3), ci), ci->getOperand(3)->getName()+".len", ci); Value *Len1 = BinaryOperator::createAdd(Len, ConstantInt::get(Len->getType(), 1), @@ -1450,8 +1444,8 @@ public: Len1 = CastInst::createIntegerCast(Len1, SLC.getIntPtrType(), false, Len1->getName(), ci); Value *args[4] = { - CastToCStr(ci->getOperand(1), *ci), - CastToCStr(ci->getOperand(3), *ci), + CastToCStr(ci->getOperand(1), ci), + CastToCStr(ci->getOperand(3), ci), Len1, ConstantInt::get(Type::Int32Ty,1) }; @@ -1946,12 +1940,12 @@ static bool GetConstantStringInfo(Value *V, ConstantArray *&Array, /// CastToCStr - Return V if it is an sbyte*, otherwise cast it to sbyte*, /// inserting the cast before IP, and return the cast. /// @brief Cast a value to a "C" string. -static Value *CastToCStr(Value *V, Instruction &IP) { +static Value *CastToCStr(Value *V, Instruction *IP) { assert(isa(V->getType()) && "Can't cast non-pointer type to C string type"); const Type *SBPTy = PointerType::get(Type::Int8Ty); if (V->getType() != SBPTy) - return new BitCastInst(V, SBPTy, V->getName(), &IP); + return new BitCastInst(V, SBPTy, V->getName(), IP); return V; } -- 2.11.0