diff --git a/core/executor/src/precompiles/modexp.rs b/core/executor/src/precompiles/modexp.rs index 6ef0ff310..814e51b49 100644 --- a/core/executor/src/precompiles/modexp.rs +++ b/core/executor/src/precompiles/modexp.rs @@ -24,8 +24,6 @@ impl PrecompileContract for ModExp { _context: &Context, _is_static: bool, ) -> Result<(PrecompileOutput, u64), PrecompileFailure> { - let large_number = LargeNumber::parse(input)?; - let gas = Self::gas_cost(input); if let Some(limit) = gas_limit { if limit < gas { @@ -33,6 +31,22 @@ impl PrecompileContract for ModExp { } } + let base_size = get_data(input, 0, 32).saturating_as::(); + let modulo_size = get_data(input, 64, 32).saturating_as::(); + + // Handle a special case when both the base and mod length is zero + if base_size == 0 && modulo_size == 0 { + return Ok(( + PrecompileOutput { + exit_status: ExitSucceed::Returned, + output: Vec::new(), + }, + gas, + )); + } + + let large_number = LargeNumber::parse(input, base_size, modulo_size)?; + let m_size = large_number.m_size; let mut res = large_number.calc()?.to_digits::(Order::MsfBe); let res_len = res.len(); @@ -64,21 +78,44 @@ impl PrecompileContract for ModExp { } fn gas_cost(input: &[u8]) -> u64 { - match LargeNumber::parse(input) { - Ok(large_number) => { - let dynamic_gas = - large_number.multiplication_complexity() * large_number.iterator_count() / 3u64; - - dynamic_gas - .max(Integer::from(Self::MIN_GAS)) - .saturating_as() - } - Err(_) => u64::MAX, + let base_size = get_data(input, 0, 32); + let modulo_size = get_data(input, 64, 32); + + // multiplication_complexity always zero + if base_size == 0 && modulo_size == 0 { + return Self::MIN_GAS; } + + let exponent_size = get_data(input, 32, 32); + + let data = if input.len() > 96 { + &input[96..] + } else { + &input[0..0] + }; + + let exponent = if exponent_size > 32 { + get_data(data, base_size.clone().saturating_as::(), 32) + } else { + get_data( + data, + base_size.clone().saturating_as::(), + exponent_size.clone().saturating_as::(), + ) + }; + + let multiplication_complexity = multiplication_complexity(base_size, modulo_size); + + let iterator_count = iterator_count(exponent_size, exponent); + + let dynamic_gas = multiplication_complexity * iterator_count / 3u64; + dynamic_gas + .max(Integer::from(Self::MIN_GAS)) + .saturating_as::() } } -fn get_data(data: &[u8], mut start: usize, size: usize) -> Result { +fn get_data(data: &[u8], mut start: usize, size: usize) -> Integer { let len = data.len(); if start > len { @@ -96,20 +133,15 @@ fn get_data(data: &[u8], mut start: usize, size: usize) -> Result Result { - let base_size = get_data(input, 0, 32)?.saturating_as::(); - let exponent_size = get_data(input, 32, 32)?.saturating_as::(); - let modulo_size = get_data(input, 64, 32)?.saturating_as::(); + fn parse( + input: &[u8], + base_size: usize, + modulo_size: usize, + ) -> Result { + let exponent_size = get_data(input, 32, 32).saturating_as::(); + let data = if input.len() > 96 { &input[96..] } else { @@ -128,41 +163,14 @@ impl LargeNumber { }; Ok(LargeNumber { - b_size: base_size, - e_size: exponent_size, m_size: modulo_size, - base: get_data(data, 0, base_size)?, - exponent: get_data(data, base_size, exponent_size)?, - modulo: get_data(data, base_size.wrapping_add(exponent_size), modulo_size)?, + base: get_data(data, 0, base_size), + exponent: get_data(data, base_size, exponent_size), + modulo: get_data(data, base_size.wrapping_add(exponent_size), modulo_size), }) } - fn multiplication_complexity(&self) -> Integer { - Integer::from((self.b_size.max(self.m_size) + 7) / 8).pow(2) - } - - fn iterator_count(&self) -> u64 { - let iter_count = if self.e_size <= 32 && self.exponent == Integer::ZERO { - 0 - } else if self.e_size <= 32 { - (self.exponent.significant_bits() - 1) as usize - } else { - let bytes: [u8; 32] = [0xFF; 32]; - let max_256_bit_uint = Integer::from_digits(&bytes, Order::MsfBe); - (8 * (self.e_size - 32)) - + ((self.exponent.clone().bitand(max_256_bit_uint)) - .significant_bits() - .saturating_sub(1)) as usize - }; - - iter_count.max(1) as u64 - } - fn calc(self) -> Result { - if self.b_size == 0 && self.m_size == 0 { - return Ok(Integer::ZERO); - } - // https://github.com/ethereum/go-ethereum/blob/a03490c6b2ff0e1d9a1274afdbe087a695d533eb/core/vm/contracts.go#L385 if self.modulo == Integer::ZERO { return Ok(Integer::ZERO); @@ -175,3 +183,28 @@ impl LargeNumber { .map_err(|_| err!(_, "Overflow")) } } + +fn multiplication_complexity(b_size: Integer, m_size: Integer) -> Integer { + let a = b_size.max(m_size); + let a: Integer = a + 7; + let a: Integer = a / 8; + a.pow(2) +} + +fn iterator_count(e_size: Integer, exponent: Integer) -> u64 { + let iter_count = if e_size <= 32 && exponent == Integer::ZERO { + 0 + } else if e_size <= 32 { + (exponent.significant_bits() - 1) as usize + } else { + let bytes: [u8; 32] = [0xFF; 32]; + let max_256_bit_uint = Integer::from_digits(&bytes, Order::MsfBe); + let a: Integer = 8 * (e_size - 32); + a.saturating_as::() + + ((exponent.bitand(max_256_bit_uint)) + .significant_bits() + .saturating_sub(1)) as usize + }; + + iter_count.max(1) as u64 +}