OSDN Git Service

Minor cleanup; multiplication.
authorSimon Forman <sforman@hushmail.com>
Wed, 5 Oct 2022 05:06:49 +0000 (22:06 -0700)
committerSimon Forman <sforman@hushmail.com>
Wed, 5 Oct 2022 05:06:49 +0000 (22:06 -0700)
I forgot to commit after cleanup but before implementing multiplication
so this commit is kind of a mess.

Anyway, it works.  :D

bigjoyints/big.py

index fd6f342..11adab1 100644 (file)
@@ -1,13 +1,10 @@
 from copy import copy
+from random import randint
 from itertools import zip_longest
 from pprint import pprint as P
 import  unittest
 
 
-def is_i32(n):
-    return -2**31 <= n < 2**31
-
-
 class BigInt:
 
     def __init__(self, initial=0):
@@ -26,10 +23,6 @@ class BigInt:
     def digitize(n):
         if n < 0:
             raise ValueError(f'Non-negative only: {n}')
-        #if not n:
-        #    yield OberonInt(0)
-        #    return  # Not strictly needed as the following while
-        #            # will not do anything for n == 0.
         while n:
             n, digit = divmod(n, 2**31)
             yield OberonInt(digit)
@@ -61,19 +54,79 @@ class BigInt:
             other = BigInt(other)
         if self.sign == other.sign:
             return self.add_like_signs(other)
-
         return self.add_unlike_signs(other)
 
     def __sub__(self, other):
         if not isinstance(other, BigInt):
-            other = BigInt(other)
-        #print(23)
-        #print(self.to_int(), '-', other.to_int())
-        z = copy(other)
+            z = BigInt(other)
+        else:
+            z = copy(other)
         z.sign = not z.sign
-        #print(self.to_int(), '+', z.to_int(), 'sub')
         return self + z
 
+    def __mul__(self, other):
+        if not isinstance(other, BigInt):
+            other = BigInt(other)
+
+        if len(self.digits) < len(other.digits):
+            return other.__mul__(self)
+
+        # We now multiple the digits of self by the digits of other.
+        #
+        #     128
+        #    * 12
+        #   ------
+        #
+        #     128
+        #    *  2
+        #   ------
+        #       8
+        #       2
+        #       -
+        #      16
+        #      2|
+        #      2|
+        #      -|
+        #      46
+        # carry1
+        #      56
+        #     1||
+        #     2||
+        #     -||
+        #     256
+        #
+        # Hmm...
+
+        acc = BigInt()
+        for i, digit in enumerate(other.digits):
+            intermediate_result = self._mul_one_digit(i, digit)
+            #print(intermediate_result)
+            acc = acc + intermediate_result
+        acc.sign = not (self.sign ^ other.sign)
+        return acc
+
+    def _mul_one_digit(self, power, n):
+        # Some of this should go in a method of OberonInt?
+        out = [zero] * power
+        carry = zero
+        for digit in self.digits:
+            # In the Oberon RISC the high half of multiplication
+            # is put into the special H register.
+            H, product = digit * n
+            c, p = product + carry
+            out.append(p)
+            carry = H
+            if c:
+                z, carry = carry + one
+                assert not z, repr(z)
+        if carry.value:
+            assert carry.value > 0
+            out.append(carry)
+        result = BigInt()
+        result.digits = out
+        return result
+
+
     def add_like_signs(self, other):
         '''
         Add a BigInt of the same sign as self.
@@ -105,12 +158,7 @@ class BigInt:
         # So we have -a and +b
         #         or +a and -b
 
-        if self.sign:
-            a, b = self, other
-        else:
-            b, a = self, other
-
-        #print(a.to_int(), '+', b.to_int(), 'add_unlike_signs')
+        a, b = (self, other) if self.sign else (other, self)
 
         # So now we have:
         #     a + (-b) == a - b
@@ -123,18 +171,9 @@ class BigInt:
         #
         # I.e. 9 - 17 == -(17 - 9)
 
-
-        #if abs(a) < abs(b):
-        if not a.abs_gt_abs(b):
-            #print(f'abs({a.to_int()}) < abs({b.to_int()})')
-            x = b._subtract_smaller(a)
-            #x.sign = not x.sign
-            return x
-        #print(f'abs({a.to_int()}) > abs({b.to_int()})')
-        return a._subtract_smaller(b)
+        return a._subtract_smaller(b) if a.abs_gt_abs(b) else b._subtract_smaller(a)
 
     def _subtract_smaller(self, other):
-        assert self.abs_gt_abs(other)
         out = []
         carry = 0
         Z = zip_longest(
@@ -161,41 +200,12 @@ class BigInt:
             return False
         return self.digits[-1] > other.digits[-1]
 
+    def __eq__(self, other):
+        return self.sign == other.sign and self.digits == other.digits
 
-##        result = BigInt()
-##        result.sign = self.sign
-##        result.digits = (
-##            self.subtract_digits(other)
-##            if self.sign else
-##            other.subtract_digits(self)
-##            )
-##        return result
-
-##    def subtract_digits(self, other):
-##        return []
-
-##def _sort_key(list_of_obint):
-##    n = len(list_of_obint)
-##    last = list_of_obint[-1] if n else None
-##    return n, zero
-
-##def subtract_list_of_obints(A, B):
-##    L = [A, B]
-##    K = sorted(L, key=_sort_key)
-##    A, B = K
-##    swapped = L != K
-##    carry = 0
-##    out = []
-##    for a, b in zip_longest(A, B, fillvalue=zero):
-##        carry, digit = a.sub_with_carry(b, carry)
-##        out.append(digit)
-##    if carry:
-##        out.append(one)
-##    result = BigInt()
-##    result.sign = self.sign
-##    result.digits = out
-##    return result
 
+def is_i32(n):
+    return -2**31 <= n < 2**31
 
 
 class OberonInt:
@@ -204,6 +214,10 @@ class OberonInt:
     32-bit, two's complement.
     '''
 
+    def __init__(self, initial=0):
+        assert is_i32(initial)
+        self.value = initial
+
     def add_with_carry(self, other, carry):
         '''
         In terms of single base-10 skool arithmetic:
@@ -221,27 +235,12 @@ class OberonInt:
         return c, digit
 
     def sub_with_carry(self, other, carry):
-        '''
-        In terms of single base-10 skool arithmetic:
-
-        a, b in {0..9}
-        carry in {0..1}
-
-        0 - 9 - 1
-
-        9 + 9 + 1 = 18 + 1  = 19
-        aka       = 1,(8+1) = 1, 9
-        '''
         c, digit = self - other
         if carry:
             z, digit = digit - one
             assert not z, repr(z)
         return c, digit
 
-    def __init__(self, initial=0):
-        assert is_i32(initial)
-        self.value = initial
-
     def __add__(self, other):
         '''
         Return carry bit and new value.
@@ -252,7 +251,7 @@ class OberonInt:
         carry = not is_i32(n)
         if carry:
             n &= 2**31 - 1
-        return int(carry), OberonInt(n)
+        return carry, OberonInt(n)
 
     __radd__ = __add__
 
@@ -272,13 +271,47 @@ class OberonInt:
     __rsub__ = __sub__
 
     def __repr__(self):
-        #b = bin(self.value.value & (2**32-1))
         return f'OberonInt({self.value})'
 
     def __eq__(self, other):
         assert isinstance(other, OberonInt)
         return self.value == other.value
 
+    def __gt__(self, other):
+        assert isinstance(other, OberonInt)
+        return self.value > other.value
+
+    def __mul__(self, other):
+        assert isinstance(other, OberonInt)
+        product = self.value * other.value
+        high = OberonInt(product >> 31)
+        low = OberonInt(product & (2**31 - 1))
+        return high, low
+
+##            # I think we want to put the 32nd bit of product
+##            # into the first bit of H, left-shifting H by one first.
+##            c = (H << 1) & (product >> 31)  # What about H[32]?
+##            product &= 0x7fffffff  # Zero out that 32nd bit.
+##
+##            if carry:
+##                digit += one
+
+
+        ##    >>> n = obmax.value
+        ##    >>> n*n
+        ##    4611686014132420609
+        ##    >>> bin(n*n)
+        ##    '0b11111111111111111111111111111100000000000000000000000000000001'
+        ##    >>> bin(n)
+        ##    '0b1111111111111111111111111111111'
+        ##    >>> bin(0b1111111111111111111111111111111 * 0b1111111111111111111111111111111)
+        ##    '0b11111111111111111111111111111100000000000000000000000000000001'
+
+        ##    >>> '0b00_111111 11111111 11111111 11111111|00000000 00000000 00000000 00000001'
+        # So we can see that multiplying obmax by itself leave two empty bits in the top half
+        # If we perform the above c = (H << 1) & (product >> 31) we get:
+        # c = 0b0_1111111 11111111 11111111 11111110
+        # p = 0b_00000000 00000000 00000000 00000001'
 
 obmin, zero, one, obmax = map(OberonInt, (
     -(2**31),
@@ -322,6 +355,18 @@ class OberonIntTest(unittest.TestCase):
         self.assertTrue(carry)
         self.assertEqual(n, zero)
 
+    def test_mul(self):
+        h, l = obmax * obmax
+        B = BigInt(obmax.value * obmax.value)
+        self.assertEqual([l, h], B.digits)
+
+
+N = 100
+rand = lambda: randint(0, 10**N) - (10**N)//2
+# For some reason randint(-(10**100), 10**100) wasn't returning negative numbers.
+# Above my pay grade.  I don't even know if that's a bug,
+# there are a /lot/ of numbers up around ten-to-the-hundreth-power, eh?
+
 
 class BigIntTest(unittest.TestCase):
 
@@ -340,29 +385,17 @@ class BigIntTest(unittest.TestCase):
     def test_Addition(self):
         n = 12345678901234567898090123445678990
         m = 901234567898090
-        x = BigInt(n)
-        y = BigInt(m)
-        z = x + y
-        t = z.to_int()
-        self.assertEqual(t, n + m)
+        self._test_add(n, m)
 
     def test_Addition_of_two_negatives(self):
         n = -12345678901234567898090123445678990
         m = -901234567898090
-        x = BigInt(n)
-        y = BigInt(m)
-        z = x + y
-        t = z.to_int()
-        self.assertEqual(t, n + m)
+        self._test_add(n, m)
 
     def test_Addition_of_unlike_signs(self):
         n = 12345678901234567898090123445678990
         m = -901234567898090
-        x = BigInt(n)
-        y = BigInt(m)
-        z = x + y
-        t = z.to_int()
-        self.assertEqual(t, n + m)
+        self._test_add(n, m)
 
     def _test_invert(self):
         n = 7 * (2**16)
@@ -377,31 +410,78 @@ class BigIntTest(unittest.TestCase):
     def test_Subtraction_small_from_large(self):
         n = 12345678901234567898090123445678990
         m = 901234567898090
-        x = BigInt(n)
-        y = BigInt(m)
-        z = x - y
-        t = z.to_int()
-        self.assertEqual(t, n - m)
+        self._test_sub(n, m)
 
     def test_Subtraction_large_from_small(self):
         n = 901234567898090
         m = 12345678901234567898090123445678990
+        self._test_sub(n, m)
+
+    def test_Subtraction_neg_small_from_large(self):
+        n = 12345678901234567898090123445678990
+        m = -901234567898090
+        self._test_sub(n, m)
+
+    def test_Subtraction_neg_large_from_small(self):
+        n = 901234567898090
+        m = -12345678901234567898090123445678990
+        self._test_sub(n, m)
+
+    def test_Subtraction_small_from_neg_large(self):
+        n = -12345678901234567898090123445678990
+        m = 901234567898090
+        self._test_sub(n, m)
+
+    def test_Subtraction_large_from_neg_small(self):
+        n = -901234567898090
+        m = 12345678901234567898090123445678990
+        self._test_sub(n, m)
+
+    def test_Subtraction_neg_small_from_neg_large(self):
+        n = -12345678901234567898090123445678990
+        m = -901234567898090
+        self._test_sub(n, m)
+
+    def test_Subtraction_neg_large_from_neg_small(self):
+        n = -901234567898090
+        m = -12345678901234567898090123445678990
+        self._test_sub(n, m)
+
+    def _test_add(self, n, m):
         x = BigInt(n)
         y = BigInt(m)
-        z = x - y
+        z = x + y
         t = z.to_int()
-        self.assertEqual(t, n - m)
-
-
-if __name__ == '__main__':
-    unittest.main()
+        self.assertEqual(t, n + m, f'{x} + {y}')
 
+    def _test_sub(self, n, m):
+        x = BigInt(n)
+        y = BigInt(m)
+        z = x - y
+        t = z.to_int()
+        self.assertEqual(t, n - m, f'{x} - {y}')
 
+    def _test_mul(self, n, m):
+        x = BigInt(n)
+        y = BigInt(m)
+        z = x * y
+        t = z.to_int()
+        self.assertEqual(t, n * m, f'{x} * {y}')
 
+    def test_mul(self):
+        a = 2063400293
+        b = -1483898257
+        self._test_mul(a, b)
 
+    def test_random_add_sub(self):
+        for _ in range(100):
+            a = rand()
+            b = rand()
+            #print(a, b)
+            self._test_add(a, b)
+            self._test_sub(a, b)
+            self._test_mul(a, b)
 
 
-##        if initial >= 2**31:
-##            raise ValueError(f'too big: {initial!r}')
-##        if initial < -2**31:
-##            raise ValueError(f'too small: {initial!r}')
+if __name__ == '__main__':
+    unittest.main()