diff --git a/py/mpz.c b/py/mpz.c index 16198730af..79d200d930 100644 --- a/py/mpz.c +++ b/py/mpz.c @@ -173,6 +173,69 @@ STATIC uint mpn_sub(mpz_dig_t *idig, const mpz_dig_t *jdig, uint jlen, const mpz return idig + 1 - oidig; } +/* computes i = j & k + returns number of digits in i + assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen + can have i, j, k pointing to same memory +*/ +STATIC uint mpn_and(mpz_dig_t *idig, const mpz_dig_t *jdig, uint jlen, const mpz_dig_t *kdig, uint klen) { + mpz_dig_t *oidig = idig; + + jlen -= klen; + + for (; klen > 0; --klen, ++idig, ++jdig, ++kdig) { + *idig = *jdig & *kdig; + } + + for (; jlen > 0; --jlen, ++idig) { + *idig = 0; + } + + return idig - oidig; +} + +/* computes i = j | k + returns number of digits in i + assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen + can have i, j, k pointing to same memory +*/ +STATIC uint mpn_or(mpz_dig_t *idig, const mpz_dig_t *jdig, uint jlen, const mpz_dig_t *kdig, uint klen) { + mpz_dig_t *oidig = idig; + + jlen -= klen; + + for (; klen > 0; --klen, ++idig, ++jdig, ++kdig) { + *idig = *jdig | *kdig; + } + + for (; jlen > 0; --jlen, ++idig, ++jdig) { + *idig = *jdig; + } + + return idig - oidig; +} + +/* computes i = j ^ k + returns number of digits in i + assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen + can have i, j, k pointing to same memory +*/ +STATIC uint mpn_xor(mpz_dig_t *idig, const mpz_dig_t *jdig, uint jlen, const mpz_dig_t *kdig, uint klen) { + mpz_dig_t *oidig = idig; + + jlen -= klen; + + for (; klen > 0; --klen, ++idig, ++jdig, ++kdig) { + *idig = *jdig ^ *kdig; + } + + for (; jlen > 0; --jlen, ++idig, ++jdig) { + *idig = *jdig; + } + + return idig - oidig; +} + /* computes i = i * d1 + d2 returns number of digits in i assumes enough memory in i; assumes normalised i; assumes dmul != 0 @@ -805,6 +868,75 @@ void mpz_sub_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) { } } +/* computes dest = lhs & rhs + can have dest, lhs, rhs the same +*/ +void mpz_and_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) { + if (mpn_cmp(lhs->dig, lhs->len, rhs->dig, rhs->len) < 0) { + const mpz_t *temp = lhs; + lhs = rhs; + rhs = temp; + } + + if (lhs->neg == rhs->neg) { + mpz_need_dig(dest, lhs->len); + dest->len = mpn_and(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len); + } else { + mpz_need_dig(dest, lhs->len); + // TODO + assert(0); +// dest->len = mpn_and_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len); + } + + dest->neg = lhs->neg; +} + +/* computes dest = lhs | rhs + can have dest, lhs, rhs the same +*/ +void mpz_or_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) { + if (mpn_cmp(lhs->dig, lhs->len, rhs->dig, rhs->len) < 0) { + const mpz_t *temp = lhs; + lhs = rhs; + rhs = temp; + } + + if (lhs->neg == rhs->neg) { + mpz_need_dig(dest, lhs->len); + dest->len = mpn_or(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len); + } else { + mpz_need_dig(dest, lhs->len); + // TODO + assert(0); +// dest->len = mpn_or_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len); + } + + dest->neg = lhs->neg; +} + +/* computes dest = lhs ^ rhs + can have dest, lhs, rhs the same +*/ +void mpz_xor_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) { + if (mpn_cmp(lhs->dig, lhs->len, rhs->dig, rhs->len) < 0) { + const mpz_t *temp = lhs; + lhs = rhs; + rhs = temp; + } + + if (lhs->neg == rhs->neg) { + mpz_need_dig(dest, lhs->len); + dest->len = mpn_xor(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len); + } else { + mpz_need_dig(dest, lhs->len); + // TODO + assert(0); +// dest->len = mpn_xor_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len); + } + + dest->neg = 0; +} + /* computes dest = lhs * rhs can have dest, lhs, rhs the same */ diff --git a/py/mpz.h b/py/mpz.h index afd46cfdea..7778893229 100644 --- a/py/mpz.h +++ b/py/mpz.h @@ -59,6 +59,9 @@ void mpz_add_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs); void mpz_sub_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs); void mpz_mul_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs); void mpz_pow_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs); +void mpz_and_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs); +void mpz_or_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs); +void mpz_xor_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs); mpz_t *mpz_gcd(const mpz_t *z1, const mpz_t *z2); mpz_t *mpz_lcm(const mpz_t *z1, const mpz_t *z2); diff --git a/py/objint_mpz.c b/py/objint_mpz.c index 05e300ddc9..25d8af70ee 100644 --- a/py/objint_mpz.c +++ b/py/objint_mpz.c @@ -78,7 +78,7 @@ mp_obj_t int_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) { return mp_obj_new_float(flhs / frhs); #endif - } else if (op <= RT_BINARY_OP_POWER) { + } else if (op <= RT_BINARY_OP_INPLACE_POWER) { mp_obj_int_t *res = mp_obj_int_new_mpz(); switch (op) { @@ -119,12 +119,18 @@ mp_obj_t int_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) { break; } - //case RT_BINARY_OP_AND: - //case RT_BINARY_OP_INPLACE_AND: - //case RT_BINARY_OP_OR: - //case RT_BINARY_OP_INPLACE_OR: - //case RT_BINARY_OP_XOR: - //case RT_BINARY_OP_INPLACE_XOR: + case RT_BINARY_OP_AND: + case RT_BINARY_OP_INPLACE_AND: + mpz_and_inpl(&res->mpz, zlhs, zrhs); + break; + case RT_BINARY_OP_OR: + case RT_BINARY_OP_INPLACE_OR: + mpz_or_inpl(&res->mpz, zlhs, zrhs); + break; + case RT_BINARY_OP_XOR: + case RT_BINARY_OP_INPLACE_XOR: + mpz_xor_inpl(&res->mpz, zlhs, zrhs); + break; case RT_BINARY_OP_LSHIFT: case RT_BINARY_OP_INPLACE_LSHIFT: