OSDN Git Service

Implement comparison operators for BranchProbability in a way that can't overflow...
authorBenjamin Kramer <benny.kra@googlemail.com>
Mon, 24 Oct 2011 13:50:56 +0000 (13:50 +0000)
committerBenjamin Kramer <benny.kra@googlemail.com>
Mon, 24 Oct 2011 13:50:56 +0000 (13:50 +0000)
Add a test case for the edge case that triggers this. Thanks to Chandler for bringing this to my attention.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@142794 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Support/BranchProbability.h
unittests/Support/BlockFrequencyTest.cpp

index 6b83159..4b5d904 100644 (file)
@@ -29,10 +29,6 @@ class BranchProbability {
   // Denominator
   uint32_t D;
 
-  int64_t compare(BranchProbability RHS) const {
-    return (uint64_t)N * RHS.D - (uint64_t)D * RHS.N;
-  }
-
 public:
   BranchProbability(uint32_t n, uint32_t d) : N(n), D(d) {
     assert(d > 0 && "Denomiator cannot be 0!");
@@ -54,12 +50,24 @@ public:
 
   void dump() const;
 
-  bool operator==(BranchProbability RHS) const { return compare(RHS) == 0; }
-  bool operator!=(BranchProbability RHS) const { return compare(RHS) != 0; }
-  bool operator< (BranchProbability RHS) const { return compare(RHS) <  0; }
-  bool operator> (BranchProbability RHS) const { return compare(RHS) >  0; }
-  bool operator<=(BranchProbability RHS) const { return compare(RHS) <= 0; }
-  bool operator>=(BranchProbability RHS) const { return compare(RHS) >= 0; }
+  bool operator==(BranchProbability RHS) const {
+    return (uint64_t)N * RHS.D == (uint64_t)D * RHS.N;
+  }
+  bool operator!=(BranchProbability RHS) const {
+    return !(*this == RHS);
+  }
+  bool operator<(BranchProbability RHS) const {
+    return (uint64_t)N * RHS.D < (uint64_t)D * RHS.N;
+  }
+  bool operator>(BranchProbability RHS) const {
+    return RHS < *this;
+  }
+  bool operator<=(BranchProbability RHS) const {
+    return (uint64_t)N * RHS.D <= (uint64_t)D * RHS.N;
+  }
+  bool operator>=(BranchProbability RHS) const {
+    return RHS <= *this;
+  }
 };
 
 raw_ostream &operator<<(raw_ostream &OS, const BranchProbability &Prob);
index ac3cedf..df25642 100644 (file)
@@ -71,6 +71,15 @@ TEST(BlockFrequencyTest, ProbabilityCompare) {
   EXPECT_TRUE(B > C);
   EXPECT_FALSE(B <= C);
   EXPECT_TRUE(B >= C);
+
+  BranchProbability BigZero(0, UINT32_MAX);
+  BranchProbability BigOne(UINT32_MAX, UINT32_MAX);
+  EXPECT_FALSE(BigZero == BigOne);
+  EXPECT_TRUE(BigZero != BigOne);
+  EXPECT_TRUE(BigZero < BigOne);
+  EXPECT_FALSE(BigZero > BigOne);
+  EXPECT_TRUE(BigZero <= BigOne);
+  EXPECT_FALSE(BigZero >= BigOne);
 }
 
 }