Skip to content

Commit

Permalink
Merge pull request #200 from julesdesmit/fix_overflow
Browse files Browse the repository at this point in the history
Correct carry propagation when adding digits in macro (fixes #199)
  • Loading branch information
recmo authored Nov 1, 2022
2 parents 46937ac + b6c70c1 commit 68d4833
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 25 deletions.
4 changes: 4 additions & 0 deletions Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Fixed

- Edge case in which an overflow occurs when parsing a `Uint` with `uint!`.

## [1.6.0] — 2022-10-28

### Added
Expand Down
52 changes: 27 additions & 25 deletions ruint-macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ fn error(span: Span, message: &str) -> TokenStream {
}

/// Parse a value literal and bits suffix into a Uint literal.
fn parse(value: &str, bits: &str) -> Result<TokenStream, String> {
fn parse(value: &str, bits: &str) -> Result<(usize, Vec<u64>), String> {
// Parse bit length
let bits = bits
.parse::<usize>()
Expand Down Expand Up @@ -91,13 +91,12 @@ fn parse(value: &str, bits: &str) -> Result<TokenStream, String> {
#[allow(clippy::cast_lossless)]
if digit > base as u64 {
return Err(format!(
"Invalid digit {} in base {} (did you forget the `0x` prefix?)",
c, base
"Invalid digit {c} in base {base} (did you forget the `0x` prefix?)"
));
}

// Multiply result by base
let mut carry = 0_u64;
// Multiply result by base and add digit
let mut carry = digit;
#[allow(clippy::cast_lossless)]
#[allow(clippy::cast_possible_truncation)]
for limb in &mut limbs {
Expand All @@ -108,9 +107,6 @@ fn parse(value: &str, bits: &str) -> Result<TokenStream, String> {
if carry > 0 {
limbs.push(carry);
}

// Add digit to result
limbs[0] += digit; // Never carries
}

// Remove trailing zeros, pad with zeros
Expand All @@ -127,16 +123,19 @@ fn parse(value: &str, bits: &str) -> Result<TokenStream, String> {
return Err(format!("Value too large for Uint<{bits}>: {value}"));
}

Ok(construct(bits, &limbs))
Ok((bits, limbs))
}

/// Transforms a [`Literal`] and returns the substitute [`TokenTree`]
fn transform_literal(literal: Literal) -> TokenTree {
let source = literal.to_string();
if let Some((value, bits)) = source.split_once('U') {
let stream = parse(value, bits).unwrap_or_else(|e| error(literal.span(), &e));
let tokens = parse(value, bits).map_or_else(
|e| error(literal.span(), &e),
|(bits, limbs)| construct(bits, &limbs),
);

return TokenTree::Group(Group::new(Delimiter::None, stream));
return TokenTree::Group(Group::new(Delimiter::None, tokens));
}
TokenTree::Literal(literal)
}
Expand Down Expand Up @@ -176,26 +175,29 @@ pub fn uint(stream: TokenStream) -> TokenStream {

#[cfg(test)]
mod tests {
use ruint::{uint, Uint};
use super::*;

#[test]
fn test_zero_size() {
uint! {
assert_eq!(0_U0, Uint::ZERO);
assert_eq!(0000_U0, Uint::ZERO);
assert_eq!(0x00_U0, Uint::ZERO);
assert_eq!(0b0000_U0, Uint::ZERO);
assert_eq!(0b0000000_U0, Uint::ZERO);
}
assert_eq!(parse("0", "0"), Ok((0, vec![])));
assert_eq!(parse("00000", "0"), Ok((0, vec![])));
assert_eq!(parse("0x00", "0"), Ok((0, vec![])));
assert_eq!(parse("0b0000", "0"), Ok((0, vec![])));
assert_eq!(parse("0b0000000", "0"), Ok((0, vec![])));
}

#[test]
fn test_bases() {
uint! {
assert_eq!(10_U8, Uint::from(10));
assert_eq!(0x10_U8, 16_u64.try_into().unwrap());
assert_eq!(0b10_U8, 2_u64.try_into().unwrap());
assert_eq!(0o10_U8, 8_u64.try_into().unwrap());
}
assert_eq!(parse("10", "8"), Ok((8, vec![10])));
assert_eq!(parse("0x10", "8"), Ok((8, vec![16])));
assert_eq!(parse("0b10", "8"), Ok((8, vec![2])));
assert_eq!(parse("0o10", "8"), Ok((8, vec![8])));
}

#[test]
#[allow(clippy::unreadable_literal)]
fn test_overflow_during_parsing() {
assert_eq!(parse("258664426012969093929703085429980814127835149614277183275038967946009968870203535512256352201271898244626862047232", "384"), Ok((384, vec![0, 15125697203588300800, 6414901478162127871, 13296924585243691235, 13584922160258634318, 121098312706494698])));
assert_eq!(parse("2135987035920910082395021706169552114602704522356652769947041607822219725780640550022962086936576", "384"), Ok((384, vec![0, 0, 0, 0, 0, 1])));
}
}

0 comments on commit 68d4833

Please sign in to comment.