Skip to content
This repository has been archived by the owner on Nov 6, 2020. It is now read-only.

Commit

Permalink
fix modexp bug: return 0 if base=0
Browse files Browse the repository at this point in the history
  • Loading branch information
Hawstein committed Sep 2, 2017
1 parent 2faa28c commit 4794124
Showing 1 changed file with 63 additions and 29 deletions.
92 changes: 63 additions & 29 deletions ethcore/src/builtin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,34 @@ impl Impl for Ripemd160 {
}
}

// calculate modexp: exponentiation by squaring. the `num` crate has pow, but not modular.
fn modexp(mut base: BigUint, mut exp: BigUint, modulus: BigUint) -> BigUint {
use num::Integer;

match (base.is_zero(), exp.is_zero()) {
(_, true) => return BigUint::one(), // n^0 % m
(true, false) => return BigUint::zero(), // 0^n % m, n>0
(false, false) if modulus <= BigUint::one() => return BigUint::zero(), // a^b % 1 = 0.
_ => {}
}

let mut result = BigUint::one();
base = base % &modulus;

// fast path for base divisible by modulus.
if base.is_zero() { return BigUint::zero() }
while !exp.is_zero() {
if exp.is_odd() {
result = (result * &base) % &modulus;
}

exp = exp >> 1;
base = (base.clone() * base) % &modulus;
}

result
}

impl Impl for ModexpImpl {
fn execute(&self, input: &[u8], output: &mut BytesRef) -> Result<(), Error> {
let mut reader = input.chain(io::repeat(0));
Expand Down Expand Up @@ -295,34 +323,6 @@ impl Impl for ModexpImpl {
let exp = read_num(exp_len);
let modulus = read_num(mod_len);

// calculate modexp: exponentiation by squaring. the `num` crate has pow, but not modular.
fn modexp(mut base: BigUint, mut exp: BigUint, modulus: BigUint) -> BigUint {
use num::Integer;

match (base.is_zero(), exp.is_zero()) {
(_, true) => return BigUint::one(), // n^0 % m
(true, false) => return BigUint::zero(), // 0^n % m, n>0
(false, false) if modulus <= BigUint::one() => return BigUint::zero(), // a^b % 1 = 0.
_ => {}
}

let mut result = BigUint::one();
base = base % &modulus;

// fast path for base divisible by modulus.
if base.is_zero() { return result }
while !exp.is_zero() {
if exp.is_odd() {
result = (result * &base) % &modulus;
}

exp = exp >> 1;
base = (base.clone() * base) % &modulus;
}

result
}

// write output to given memory, left padded and same length as the modulus.
let bytes = modexp(base, exp, modulus).to_bytes_be();

Expand Down Expand Up @@ -504,10 +504,44 @@ impl Impl for Bn128PairingImpl {

#[cfg(test)]
mod tests {
use super::{Builtin, Linear, ethereum_builtin, Pricer, Modexp};
use super::{Builtin, Linear, ethereum_builtin, Pricer, Modexp, modexp as me};
use ethjson;
use util::{U256, BytesRef};
use rustc_hex::FromHex;
use num::{BigUint, Zero, One};

#[test]
fn modexp_func() {
// n^0 % m == 1
let mut base = BigUint::parse_bytes(b"12345", 10).unwrap();
let mut exp = BigUint::zero();
let mut modulus = BigUint::parse_bytes(b"789", 10).unwrap();
assert_eq!(me(base, exp, modulus), BigUint::one());

// 0^n % m == 0
base = BigUint::zero();
exp = BigUint::parse_bytes(b"12345", 10).unwrap();
modulus = BigUint::parse_bytes(b"789", 10).unwrap();
assert_eq!(me(base, exp, modulus), BigUint::zero());

// n^m % 1 == 0
base = BigUint::parse_bytes(b"12345", 10).unwrap();
exp = BigUint::parse_bytes(b"789", 10).unwrap();
modulus = BigUint::one();
assert_eq!(me(base, exp, modulus), BigUint::zero());

// if n % d == 0, then n^m % d == 0
base = BigUint::parse_bytes(b"12345", 10).unwrap();
exp = BigUint::parse_bytes(b"789", 10).unwrap();
modulus = BigUint::parse_bytes(b"15", 10).unwrap();
assert_eq!(me(base, exp, modulus), BigUint::zero());

// others
base = BigUint::parse_bytes(b"12345", 10).unwrap();
exp = BigUint::parse_bytes(b"789", 10).unwrap();
modulus = BigUint::parse_bytes(b"97", 10).unwrap();
assert_eq!(me(base, exp, modulus), BigUint::parse_bytes(b"55", 10).unwrap());
}

#[test]
fn identity() {
Expand Down

0 comments on commit 4794124

Please sign in to comment.