aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r-- Include/cpython/longintrepr.h | 6
-rw-r--r-- Lib/test/test_pow.py | 22
-rw-r--r-- Objects/longobject.c | 108
3 files changed, 106 insertions, 30 deletions
diff --git a/Include/cpython/longintrepr.h b/Include/cpython/longintrepr.h
index ff4155f965..68dbf9c438 100644
--- a/Include/cpython/longintrepr.h
+++ b/Include/cpython/longintrepr.h
@@ -21,8 +21,6 @@ extern "C" {
PyLong_SHIFT. The majority of the code doesn't care about the precise
value of PyLong_SHIFT, but there are some notable exceptions:
- - long_pow() requires that PyLong_SHIFT be divisible by 5
-
- PyLong_{As,From}ByteArray require that PyLong_SHIFT be at least 8
- long_hash() requires that PyLong_SHIFT is *strictly* less than the number
@@ -63,10 +61,6 @@ typedef long stwodigits; /* signed variant of twodigits */
#define PyLong_BASE ((digit)1 << PyLong_SHIFT)
#define PyLong_MASK ((digit)(PyLong_BASE - 1))
-#if PyLong_SHIFT % 5 != 0
-#error "longobject.c requires that PyLong_SHIFT be divisible by 5"
-#endif
-
/* Long integer representation.
The absolute value of a number is equal to
SUM(for i=0 through abs(ob_size)-1) ob_digit[i] * 2**(SHIFT*i)
diff --git a/Lib/test/test_pow.py b/Lib/test/test_pow.py
index 660ff80bbf..5cea9ceb20 100644
--- a/Lib/test/test_pow.py
+++ b/Lib/test/test_pow.py
@@ -93,6 +93,28 @@ class PowTest(unittest.TestCase):
pow(int(i),j,k)
)
+ def test_big_exp(self):
+ import random
+ self.assertEqual(pow(2, 50000), 1 << 50000)
+ # Randomized modular tests, checking the identities
+ # a**(b1 + b2) == a**b1 * a**b2
+ # a**(b1 * b2) == (a**b1)**b2
+ prime = 1000000000039 # for speed, relatively small prime modulus
+ for i in range(10):
+ a = random.randrange(1000, 1000000)
+ bpower = random.randrange(1000, 50000)
+ b = random.randrange(1 << (bpower - 1), 1 << bpower)
+ b1 = random.randrange(1, b)
+ b2 = b - b1
+ got1 = pow(a, b, prime)
+ got2 = pow(a, b1, prime) * pow(a, b2, prime) % prime
+ if got1 != got2:
+ self.fail(f"{a=:x} {b1=:x} {b2=:x} {got1=:x} {got2=:x}")
+ got3 = pow(a, b1 * b2, prime)
+ got4 = pow(pow(a, b1, prime), b2, prime)
+ if got3 != got4:
+ self.fail(f"{a=:x} {b1=:x} {b2=:x} {got3=:x} {got4=:x}")
+
def test_bug643260(self):
class TestRpow:
def __rpow__(self, other):
diff --git a/Objects/longobject.c b/Objects/longobject.c
index 09ae9455c5..b5648fca7d 100644
--- a/Objects/longobject.c
+++ b/Objects/longobject.c
@@ -74,12 +74,34 @@ maybe_small_long(PyLongObject *v)
#define KARATSUBA_CUTOFF 70
#define KARATSUBA_SQUARE_CUTOFF (2 * KARATSUBA_CUTOFF)
-/* For exponentiation, use the binary left-to-right algorithm
- * unless the exponent contains more than FIVEARY_CUTOFF digits.
- * In that case, do 5 bits at a time. The potential drawback is that
- * a table of 2**5 intermediate results is computed.
+/* For exponentiation, use the binary left-to-right algorithm unless the
+ * exponent contains more than HUGE_EXP_CUTOFF bits. In that case, do
+ * (no more than) EXP_WINDOW_SIZE bits at a time. The potential drawback is
+ * that a table of 2**(EXP_WINDOW_SIZE - 1) intermediate results is
+ * precomputed.
*/
-#define FIVEARY_CUTOFF 8
+#define EXP_WINDOW_SIZE 5
+#define EXP_TABLE_LEN (1 << (EXP_WINDOW_SIZE - 1))
+/* Suppose the exponent has bit length e. All ways of doing this
+ * need e squarings. The binary method also needs a multiply for
+ * each bit set. In a k-ary method with window width w, a multiply
+ * for each non-zero window, so at worst (and likely!)
+ * ceiling(e/w). The k-ary sliding window method has the same
+ * worst case, but the window slides so it can sometimes skip
+ * over an all-zero window that the fixed-window method can't
+ * exploit. In addition, the windowing methods need multiplies
+ * to precompute a table of small powers.
+ *
+ * For the sliding window method with width 5, 16 precomputation
+ * multiplies are needed. Assuming about half the exponent bits
+ * are set, then, the binary method needs about e/2 extra mults
+ * and the window method about 16 + e/5.
+ *
+ * The latter is smaller for e > 53 1/3. We don't have direct
+ * access to the bit length, though, so call it 60, which is a
+ * multiple of a long digit's max bit length (15 or 30 so far).
+ */
+#define HUGE_EXP_CUTOFF 60
#define SIGCHECK(PyTryBlock) \
do { \
@@ -4172,14 +4194,15 @@ long_pow(PyObject *v, PyObject *w, PyObject *x)
int negativeOutput = 0; /* if x<0 return negative output */
PyLongObject *z = NULL; /* accumulated result */
- Py_ssize_t i, j, k; /* counters */
+ Py_ssize_t i, j; /* counters */
PyLongObject *temp = NULL;
+ PyLongObject *a2 = NULL; /* may temporarily hold a**2 % c */
- /* 5-ary values. If the exponent is large enough, table is
- * precomputed so that table[i] == a**i % c for i in range(32).
+ /* k-ary values. If the exponent is large enough, table is
+ * precomputed so that table[i] == a**(2*i+1) % c for i in
+ * range(EXP_TABLE_LEN).
*/
- PyLongObject *table[32] = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
- 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
+ PyLongObject *table[EXP_TABLE_LEN] = {0};
/* a, b, c = v, w, x */
CHECK_BINOP(v, w);
@@ -4332,7 +4355,7 @@ long_pow(PyObject *v, PyObject *w, PyObject *x)
}
/* else bi is 0, and z==1 is correct */
}
- else if (i <= FIVEARY_CUTOFF) {
+ else if (i <= HUGE_EXP_CUTOFF / PyLong_SHIFT ) {
/* Left-to-right binary exponentiation (HAC Algorithm 14.79) */
/* http://www.cacr.math.uwaterloo.ca/hac/about/chap14.pdf */
@@ -4366,23 +4389,59 @@ long_pow(PyObject *v, PyObject *w, PyObject *x)
}
}
else {
- /* Left-to-right 5-ary exponentiation (HAC Algorithm 14.82) */
- Py_INCREF(z); /* still holds 1L */
- table[0] = z;
- for (i = 1; i < 32; ++i)
- MULT(table[i-1], a, table[i]);
+ /* Left-to-right k-ary sliding window exponentiation
+ * (Handbook of Applied Cryptography (HAC) Algorithm 14.85)
+ */
+ Py_INCREF(a);
+ table[0] = a;
+ MULT(a, a, a2);
+ /* table[i] == a**(2*i + 1) % c */
+ for (i = 1; i < EXP_TABLE_LEN; ++i)
+ MULT(table[i-1], a2, table[i]);
+ Py_CLEAR(a2);
+
+ /* Repeatedly extract the next (no more than) EXP_WINDOW_SIZE bits
+ * into `pending`, starting with the next 1 bit. The current bit
+ * length of `pending` is `blen`.
+ */
+ int pending = 0, blen = 0;
+#define ABSORB_PENDING do { \
+ int ntz = 0; /* number of trailing zeroes in `pending` */ \
+ assert(pending && blen); \
+ assert(pending >> (blen - 1)); \
+ assert(pending >> blen == 0); \
+ while ((pending & 1) == 0) { \
+ ++ntz; \
+ pending >>= 1; \
+ } \
+ assert(ntz < blen); \
+ blen -= ntz; \
+ do { \
+ MULT(z, z, z); \
+ } while (--blen); \
+ MULT(z, table[pending >> 1], z); \
+ while (ntz-- > 0) \
+ MULT(z, z, z); \
+ assert(blen == 0); \
+ pending = 0; \
+ } while(0)
for (i = Py_SIZE(b) - 1; i >= 0; --i) {
const digit bi = b->ob_digit[i];
-
- for (j = PyLong_SHIFT - 5; j >= 0; j -= 5) {
- const int index = (bi >> j) & 0x1f;
- for (k = 0; k < 5; ++k)
+ for (j = PyLong_SHIFT - 1; j >= 0; --j) {
+ const int bit = (bi >> j) & 1;
+ pending = (pending << 1) | bit;
+ if (pending) {
+ ++blen;
+ if (blen == EXP_WINDOW_SIZE)
+ ABSORB_PENDING;
+ }
+ else /* absorb strings of 0 bits */
MULT(z, z, z);
- if (index)
- MULT(z, table[index], z);
}
}
+ if (pending)
+ ABSORB_PENDING;
}
if (negativeOutput && (Py_SIZE(z) != 0)) {
@@ -4399,13 +4458,14 @@ long_pow(PyObject *v, PyObject *w, PyObject *x)
Py_CLEAR(z);
/* fall through */
Done:
- if (Py_SIZE(b) > FIVEARY_CUTOFF) {
- for (i = 0; i < 32; ++i)
+ if (Py_SIZE(b) > HUGE_EXP_CUTOFF / PyLong_SHIFT) {
+ for (i = 0; i < EXP_TABLE_LEN; ++i)
Py_XDECREF(table[i]);
}
Py_DECREF(a);
Py_DECREF(b);
Py_XDECREF(c);
+ Py_XDECREF(a2);
Py_XDECREF(temp);
return (PyObject *)z;
}