diff --git a/py/mpz.c b/py/mpz.c index 3fb2548c4d..bb76479569 100644 --- a/py/mpz.c +++ b/py/mpz.c @@ -454,10 +454,8 @@ STATIC mp_uint_t mpn_mul(mpz_dig_t *idig, mpz_dig_t *jdig, mp_uint_t jlen, mpz_d assumes num_dig has enough memory to be extended by 1 digit assumes quo_dig has enough memory (as many digits as num) assumes quo_dig is filled with zeros - modifies den_dig memory, but restors it to original state at end */ - -STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig, mp_uint_t den_len, mpz_dig_t *quo_dig, mp_uint_t *quo_len) { +STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, const mpz_dig_t *den_dig, mp_uint_t den_len, mpz_dig_t *quo_dig, mp_uint_t *quo_len) { mpz_dig_t *orig_num_dig = num_dig; mpz_dig_t *orig_quo_dig = quo_dig; mpz_dig_t norm_shift = 0; @@ -478,6 +476,11 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig, } } + // We need to normalise the denominator (leading bit of leading digit is 1) + // so that the division routine works. Since the denominator memory is + // read-only we do the normalisation on the fly, each time a digit of the + // denominator is needed. All we need to know is how many bits to shift by. 
+ // count number of leading zeros in leading digit of denominator { mpz_dig_t d = den_dig[den_len - 1]; @@ -487,13 +490,6 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig, } } - // normalise denomenator (leading bit of leading digit is 1) - for (mpz_dig_t *den = den_dig, carry = 0; den < den_dig + den_len; ++den) { - mpz_dig_t d = *den; - *den = ((d << norm_shift) | carry) & DIG_MASK; - carry = (mpz_dbl_dig_t)d >> (DIG_SIZE - norm_shift); - } - // now need to shift numerator by same amount as denominator // first, increase length of numerator in case we need more room to shift num_dig[*num_len] = 0; @@ -505,7 +501,10 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig, } // cache the leading digit of the denominator - lead_den_digit = den_dig[den_len - 1]; + lead_den_digit = (mpz_dbl_dig_t)den_dig[den_len - 1] << norm_shift; + if (den_len >= 2) { + lead_den_digit |= (mpz_dbl_dig_t)den_dig[den_len - 2] >> (DIG_SIZE - norm_shift); + } // point num_dig to last digit in numerator num_dig += *num_len - 1; @@ -540,10 +539,13 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig, // round up). 
if (DIG_SIZE < 8 * sizeof(mpz_dbl_dig_t) / 2) { + const mpz_dig_t *d = den_dig; + mpz_dbl_dig_t d_norm = 0; mpz_dbl_dig_signed_t borrow = 0; - for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) { - borrow += (mpz_dbl_dig_t)*n - (mpz_dbl_dig_t)quo * (mpz_dbl_dig_t)*d; // will overflow if DIG_SIZE >= 8*sizeof(mpz_dbl_dig_t)/2 + for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) { + d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE); + borrow += (mpz_dbl_dig_t)*n - (mpz_dbl_dig_t)quo * (d_norm & DIG_MASK); // will overflow if DIG_SIZE >= 8*sizeof(mpz_dbl_dig_t)/2 *n = borrow & DIG_MASK; borrow >>= DIG_SIZE; } @@ -553,9 +555,12 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig, // adjust quotient if it is too big for (; borrow != 0; --quo) { + d = den_dig; + d_norm = 0; mpz_dbl_dig_t carry = 0; - for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) { - carry += (mpz_dbl_dig_t)*n + (mpz_dbl_dig_t)*d; + for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) { + d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE); + carry += (mpz_dbl_dig_t)*n + (d_norm & DIG_MASK); *n = carry & DIG_MASK; carry >>= DIG_SIZE; } @@ -566,10 +571,13 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig, borrow += carry; } } else { // DIG_SIZE == 8 * sizeof(mpz_dbl_dig_t) / 2 + const mpz_dig_t *d = den_dig; + mpz_dbl_dig_t d_norm = 0; mpz_dbl_dig_t borrow = 0; - for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) { - mpz_dbl_dig_t x = (mpz_dbl_dig_t)quo * (mpz_dbl_dig_t)(*d); + for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) { + d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE); + mpz_dbl_dig_t x = (mpz_dbl_dig_t)quo * (d_norm & DIG_MASK); if (x >= *n || *n - x <= borrow) { borrow += (mpz_dbl_dig_t)x - (mpz_dbl_dig_t)*n; *n = (-borrow) & DIG_MASK; @@ -590,9 +598,12 @@ STATIC void 
mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig, // adjust quotient if it is too big for (; borrow != 0; --quo) { + d = den_dig; + d_norm = 0; mpz_dbl_dig_t carry = 0; - for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) { - carry += (mpz_dbl_dig_t)*n + (mpz_dbl_dig_t)*d; + for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) { + d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE); + carry += (mpz_dbl_dig_t)*n + (d_norm & DIG_MASK); *n = carry & DIG_MASK; carry >>= DIG_SIZE; } @@ -614,13 +625,6 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig, --(*num_len); } - // unnormalise denomenator - for (mpz_dig_t *den = den_dig + den_len - 1, carry = 0; den >= den_dig; --den) { - mpz_dig_t d = *den; - *den = ((d >> norm_shift) | carry) & DIG_MASK; - carry = (mpz_dbl_dig_t)d << (DIG_SIZE - norm_shift); - } - // unnormalise numerator (remainder now) for (mpz_dig_t *num = orig_num_dig + *num_len - 1, carry = 0; num >= orig_num_dig; --num) { mpz_dig_t n = *num; @@ -1506,7 +1510,6 @@ void mpz_divmod_inpl(mpz_t *dest_quo, mpz_t *dest_rem, const mpz_t *lhs, const m dest_quo->len = 0; mpz_need_dig(dest_rem, lhs->len + 1); // +1 necessary? mpz_set(dest_rem, lhs); - //rhs->dig[rhs->len] = 0; mpn_div(dest_rem->dig, &dest_rem->len, rhs->dig, rhs->len, dest_quo->dig, &dest_quo->len); // check signs and do Python style modulo