Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: reduce mem cost on modexp #1300

Merged
merged 1 commit into from
Aug 4, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 88 additions & 55 deletions core/executor/src/precompiles/modexp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,29 @@ impl PrecompileContract for ModExp {
_context: &Context,
_is_static: bool,
) -> Result<(PrecompileOutput, u64), PrecompileFailure> {
let large_number = LargeNumber::parse(input)?;

let gas = Self::gas_cost(input);
if let Some(limit) = gas_limit {
if limit < gas {
return err!();
}
}

let base_size = get_data(input, 0, 32).saturating_as::<usize>();
let modulo_size = get_data(input, 64, 32).saturating_as::<usize>();

// Handle a special case when both the base and mod length is zero
if base_size == 0 && modulo_size == 0 {
return Ok((
PrecompileOutput {
exit_status: ExitSucceed::Returned,
output: Vec::new(),
},
gas,
));
}

let large_number = LargeNumber::parse(input, base_size, modulo_size)?;

let m_size = large_number.m_size;
let mut res = large_number.calc()?.to_digits::<u8>(Order::MsfBe);
let res_len = res.len();
Expand Down Expand Up @@ -64,21 +78,44 @@ impl PrecompileContract for ModExp {
}

fn gas_cost(input: &[u8]) -> u64 {
match LargeNumber::parse(input) {
Ok(large_number) => {
let dynamic_gas =
large_number.multiplication_complexity() * large_number.iterator_count() / 3u64;

dynamic_gas
.max(Integer::from(Self::MIN_GAS))
.saturating_as()
}
Err(_) => u64::MAX,
let base_size = get_data(input, 0, 32);
let modulo_size = get_data(input, 64, 32);

// multiplication_complexity always zero
if base_size == 0 && modulo_size == 0 {
return Self::MIN_GAS;
}

let exponent_size = get_data(input, 32, 32);

let data = if input.len() > 96 {
&input[96..]
} else {
&input[0..0]
};

let exponent = if exponent_size > 32 {
get_data(data, base_size.clone().saturating_as::<usize>(), 32)
} else {
get_data(
data,
base_size.clone().saturating_as::<usize>(),
exponent_size.clone().saturating_as::<usize>(),
)
};

let multiplication_complexity = multiplication_complexity(base_size, modulo_size);

let iterator_count = iterator_count(exponent_size, exponent);

let dynamic_gas = multiplication_complexity * iterator_count / 3u64;
dynamic_gas
.max(Integer::from(Self::MIN_GAS))
.saturating_as::<u64>()
}
}

fn get_data(data: &[u8], mut start: usize, size: usize) -> Result<Integer, PrecompileFailure> {
fn get_data(data: &[u8], mut start: usize, size: usize) -> Integer {
let len = data.len();

if start > len {
Expand All @@ -96,73 +133,44 @@ fn get_data(data: &[u8], mut start: usize, size: usize) -> Result<Integer, Preco
Vec::new()
};

padded
.try_reserve_exact(size)
.map_err(|_| PrecompileFailure::Error {
exit_status: ExitError::StackOverflow,
})?;
// may panic here when memory doesn't enough
padded.reserve_exact(size);

padded.extend(std::iter::repeat(0).take(size - (end.saturating_sub(start))));

Ok(Integer::from_digits(&padded, Order::MsfBe))
Integer::from_digits(&padded, Order::MsfBe)
}

struct LargeNumber {
b_size: usize,
e_size: usize,
m_size: usize,
base: Integer,
exponent: Integer,
modulo: Integer,
}

impl LargeNumber {
fn parse(input: &[u8]) -> Result<Self, PrecompileFailure> {
let base_size = get_data(input, 0, 32)?.saturating_as::<usize>();
let exponent_size = get_data(input, 32, 32)?.saturating_as::<usize>();
let modulo_size = get_data(input, 64, 32)?.saturating_as::<usize>();
fn parse(
input: &[u8],
base_size: usize,
modulo_size: usize,
) -> Result<Self, PrecompileFailure> {
let exponent_size = get_data(input, 32, 32).saturating_as::<usize>();

let data = if input.len() > 96 {
&input[96..]
} else {
&input[0..0]
};

Ok(LargeNumber {
b_size: base_size,
e_size: exponent_size,
m_size: modulo_size,
base: get_data(data, 0, base_size)?,
exponent: get_data(data, base_size, exponent_size)?,
modulo: get_data(data, base_size.wrapping_add(exponent_size), modulo_size)?,
base: get_data(data, 0, base_size),
exponent: get_data(data, base_size, exponent_size),
modulo: get_data(data, base_size.wrapping_add(exponent_size), modulo_size),
})
}

fn multiplication_complexity(&self) -> Integer {
Integer::from((self.b_size.max(self.m_size) + 7) / 8).pow(2)
}

fn iterator_count(&self) -> u64 {
let iter_count = if self.e_size <= 32 && self.exponent == Integer::ZERO {
0
} else if self.e_size <= 32 {
(self.exponent.significant_bits() - 1) as usize
} else {
let bytes: [u8; 32] = [0xFF; 32];
let max_256_bit_uint = Integer::from_digits(&bytes, Order::MsfBe);
(8 * (self.e_size - 32))
+ ((self.exponent.clone().bitand(max_256_bit_uint))
.significant_bits()
.saturating_sub(1)) as usize
};

iter_count.max(1) as u64
}

fn calc(self) -> Result<Integer, PrecompileFailure> {
if self.b_size == 0 && self.m_size == 0 {
return Ok(Integer::ZERO);
}

// https://github.com/ethereum/go-ethereum/blob/a03490c6b2ff0e1d9a1274afdbe087a695d533eb/core/vm/contracts.go#L385
if self.modulo == Integer::ZERO {
return Ok(Integer::ZERO);
Expand All @@ -175,3 +183,28 @@ impl LargeNumber {
.map_err(|_| err!(_, "Overflow"))
}
}

fn multiplication_complexity(b_size: Integer, m_size: Integer) -> Integer {
let a = b_size.max(m_size);
let a: Integer = a + 7;
let a: Integer = a / 8;
a.pow(2)
}

fn iterator_count(e_size: Integer, exponent: Integer) -> u64 {
let iter_count = if e_size <= 32 && exponent == Integer::ZERO {
0
} else if e_size <= 32 {
(exponent.significant_bits() - 1) as usize
} else {
let bytes: [u8; 32] = [0xFF; 32];
let max_256_bit_uint = Integer::from_digits(&bytes, Order::MsfBe);
let a: Integer = 8 * (e_size - 32);
a.saturating_as::<usize>()
+ ((exponent.bitand(max_256_bit_uint))
.significant_bits()
.saturating_sub(1)) as usize
};

iter_count.max(1) as u64
}