diff --git a/num_enum/tests/try_build/compile_fail/alternative_exprs.rs b/num_enum/tests/try_build/compile_fail/alternative_exprs.rs new file mode 100644 index 0000000..046d9c1 --- /dev/null +++ b/num_enum/tests/try_build/compile_fail/alternative_exprs.rs @@ -0,0 +1,13 @@ +const THREE: u8 = 3; + +#[derive(num_enum::TryFromPrimitive)] +#[repr(i8)] +enum Numbers { + Zero = 0, + #[num_enum(alternatives = [-1, 2, THREE])] + One = 1, +} + +fn main() { + +} diff --git a/num_enum/tests/try_build/compile_fail/alternative_exprs.stderr b/num_enum/tests/try_build/compile_fail/alternative_exprs.stderr new file mode 100644 index 0000000..33540c5 --- /dev/null +++ b/num_enum/tests/try_build/compile_fail/alternative_exprs.stderr @@ -0,0 +1,5 @@ +error: Only literals are allowed as num_enum alternate values + --> tests/try_build/compile_fail/alternative_exprs.rs:7:39 + | +7 | #[num_enum(alternatives = [-1, 2, THREE])] + | ^^^^^ diff --git a/num_enum/tests/try_from_primitive.rs b/num_enum/tests/try_from_primitive.rs index ee76e0e..ffccf3a 100644 --- a/num_enum/tests/try_from_primitive.rs +++ b/num_enum/tests/try_from_primitive.rs @@ -120,8 +120,79 @@ fn wrong_order() { assert_eq!(four, Ok(Enum::Four)); } -#[cfg(feature = "complex-expression")] +#[test] +fn negative_values() { + #[derive(Debug, Eq, PartialEq, TryFromPrimitive)] + #[repr(i8)] + enum Enum { + MinusTwo = -2, + MinusOne = -1, + Zero = 0, + One = 1, + Two = 2, + } + + let minus_two: Result = (-2i8).try_into(); + assert_eq!(minus_two, Ok(Enum::MinusTwo)); + + let minus_one: Result = (-1i8).try_into(); + assert_eq!(minus_one, Ok(Enum::MinusOne)); + + let zero: Result = 0i8.try_into(); + assert_eq!(zero, Ok(Enum::Zero)); + + let one: Result = 1i8.try_into(); + assert_eq!(one, Ok(Enum::One)); + + let two: Result = 2i8.try_into(); + assert_eq!(two, Ok(Enum::Two)); +} + +#[test] +fn discriminant_expressions() { + const ONE: u8 = 1; + + #[derive(Debug, Eq, PartialEq, TryFromPrimitive)] + #[repr(u8)] + enum Enum { + Zero, + One = ONE, + Two, + Four = 4u8, + Five, + Six = ONE + ONE + 2u8 + 2, + } + + let zero: Result = 0u8.try_into(); + assert_eq!(zero, Ok(Enum::Zero)); + + let one: Result = 1u8.try_into(); + assert_eq!(one, Ok(Enum::One)); + + let two: Result = 2u8.try_into(); + assert_eq!(two, Ok(Enum::Two)); + + let three: Result = 3u8.try_into(); + assert_eq!( + three.unwrap_err().to_string(), + "No discriminant in enum `Enum` matches the value `3`", + ); + + let four: Result = 4u8.try_into(); + assert_eq!(four, Ok(Enum::Four)); + + let five: Result = 5u8.try_into(); + assert_eq!(five, Ok(Enum::Five)); + + let six: Result = 6u8.try_into(); + assert_eq!(six, Ok(Enum::Six)); +} + +#[cfg(feature = "complex-expressions")] mod complex { + use num_enum::TryFromPrimitive; + use std::convert::TryInto; + const ONE: u8 = 1; #[derive(Debug, Eq, PartialEq, TryFromPrimitive)] @@ -261,26 +332,29 @@ fn error_variant_is_allowed() { #[test] fn alternative_values() { #[derive(Debug, Eq, PartialEq, TryFromPrimitive)] - #[repr(u8)] + #[repr(i8)] enum Enum { Zero = 0, - #[num_enum(alternatives = [2, 3])] - OneTwoOrThree = 1, + #[num_enum(alternatives = [-1, 2, 3])] + OneTwoThreeOrMinusOne = 1, } - let zero: Result = 0u8.try_into(); + let minus_one: Result = (-1i8).try_into(); + assert_eq!(minus_one, Ok(Enum::OneTwoThreeOrMinusOne)); + + let zero: Result = 0i8.try_into(); assert_eq!(zero, Ok(Enum::Zero)); - let one: Result = 1u8.try_into(); - assert_eq!(one, Ok(Enum::OneTwoOrThree)); + let one: Result = 1i8.try_into(); + assert_eq!(one, Ok(Enum::OneTwoThreeOrMinusOne)); - let two: Result = 2u8.try_into(); - assert_eq!(two, Ok(Enum::OneTwoOrThree)); + let two: Result = 2i8.try_into(); + assert_eq!(two, Ok(Enum::OneTwoThreeOrMinusOne)); - let three: Result = 3u8.try_into(); - assert_eq!(three, Ok(Enum::OneTwoOrThree)); + let three: Result = 3i8.try_into(); + assert_eq!(three, Ok(Enum::OneTwoThreeOrMinusOne)); - let four: Result = 4u8.try_into(); + let four: Result = 4i8.try_into(); assert_eq!( four.unwrap_err().to_string(), "No discriminant in enum `Enum` matches the value `4`" diff --git a/num_enum_derive/src/lib.rs b/num_enum_derive/src/lib.rs index 157e19e..5de3402 100644 --- a/num_enum_derive/src/lib.rs +++ b/num_enum_derive/src/lib.rs @@ -8,7 +8,8 @@ use syn::{ parse::{Parse, ParseStream}, parse_macro_input, parse_quote, spanned::Spanned, - Attribute, Data, DeriveInput, Error, Expr, Fields, Ident, Lit, LitInt, LitStr, Meta, Result, + Attribute, Data, DeriveInput, Error, Expr, ExprLit, ExprUnary, Fields, Ident, Lit, LitInt, + LitStr, Meta, Result, UnOp, }; macro_rules! die { @@ -26,20 +27,40 @@ macro_rules! die { } fn literal(i: i128) -> Expr { - let literal = LitInt::new(&i.to_string(), Span::call_site()); - parse_quote! { - #literal - } + Expr::Lit(ExprLit { + lit: Lit::Int(LitInt::new(&i.to_string(), Span::call_site())), + attrs: vec![], + }) } -fn expr_to_int(val_exp: &Expr) -> Result { - Ok(match val_exp { - Expr::Lit(ref val_exp_lit) => match val_exp_lit.lit { - Lit::Int(ref lit_int) => lit_int.base10_parse()?, - _ => die!(val_exp => "Expected integer"), - }, - _ => die!(val_exp => "Expected literal"), - }) +enum DiscriminantValue { + Literal(i128), + Expr(Expr), +} + +fn parse_discriminant(val_exp: &Expr) -> Result { + let mut sign = 1; + let mut unsigned_expr = val_exp; + if let Expr::Unary(ExprUnary { + op: UnOp::Neg(..), + expr, + .. + }) = val_exp + { + unsigned_expr = expr; + sign = -1; + } + if let Expr::Lit(ExprLit { + lit: Lit::Int(ref lit_int), + .. + }) = unsigned_expr + { + Ok(DiscriminantValue::Literal( + sign * lit_int.base10_parse::()?, + )) + } else { + Ok(DiscriminantValue::Expr(val_exp.clone())) + } } mod kw { @@ -307,7 +328,7 @@ impl Parse for EnumInfo { let mut has_catch_all_variant: bool = false; // Vec to keep track of the used discriminants and alt values. - let mut val_set: BTreeSet = BTreeSet::new(); + let mut discriminant_int_val_set = BTreeSet::new(); let mut next_discriminant = literal(0); for variant in data.variants.into_iter() { @@ -319,7 +340,7 @@ impl Parse for EnumInfo { }; let mut attr_spans: AttributeSpans = Default::default(); - let mut alternative_values: Vec = vec![]; + let mut raw_alternative_values: Vec = vec![]; // Keep the attribute around for better error reporting. let mut alt_attr_ref: Vec<&Attribute> = vec![]; @@ -398,7 +419,7 @@ impl Parse for EnumInfo { } NumEnumVariantAttributeItem::Alternatives(alternatives) => { attr_spans.alternatives.push(alternatives.span()); - alternative_values.extend(alternatives.expressions); + raw_alternative_values.extend(alternatives.expressions); alt_attr_ref.push(attribute); } } @@ -422,75 +443,107 @@ impl Parse for EnumInfo { } } - let canonical_value = discriminant; - let canonical_value_int = expr_to_int(&canonical_value)?; + let discriminant_value = parse_discriminant(&discriminant)?; // Check for collision. - if val_set.contains(&canonical_value_int) { - die!(ident => format!("The discriminant '{}' collides with a value attributed to a previous variant", canonical_value_int)) + // We can't do const evaluation, or even compare arbitrary Exprs, + // so unfortunately we can't check for duplicates. + // That's not the end of the world, just we'll end up with compile errors for + // matches with duplicate branches in generated code instead of nice friendly error messages. + if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value { + if discriminant_int_val_set.contains(&canonical_value_int) { + die!(ident => format!("The discriminant '{}' collides with a value attributed to a previous variant", canonical_value_int)) + } } // Deal with the alternative values. - let alt_val = alternative_values + let alternate_values = raw_alternative_values .iter() - .map(expr_to_int) + .map(parse_discriminant) .collect::>>()?; - debug_assert_eq!(alt_val.len(), alternative_values.len()); - - if !alt_val.is_empty() { - let mut alt_val_sorted = alt_val.clone(); - alt_val_sorted.sort_unstable(); - let alt_val_sorted = alt_val_sorted; - - // check if the current discriminant is not in the alternative values. - if let Some(i) = alt_val.iter().position(|&x| x == canonical_value_int) { - die!(&alternative_values[i] => format!("'{}' in the alternative values is already attributed as the discriminant of this variant", canonical_value_int)); + debug_assert_eq!(alternate_values.len(), raw_alternative_values.len()); + + if !alternate_values.is_empty() { + let mut sorted_alternate_int_values = alternate_values + .into_iter() + .map(|v| { + match v { + DiscriminantValue::Literal(value) => Ok(value), + DiscriminantValue::Expr(expr) => { + // We can't do uniqueness checking on non-literals, so we don't allow them as alternate values. + // We could probably allow them, but there doesn't seem to be much of a use-case, + // and it's easier to give good error messages about duplicate values this way, + // rather than rustc errors on conflicting match branches. + die!(expr => format!("Only literals are allowed as num_enum alternate values")) + }, + } + }) + .collect::>>()?; + sorted_alternate_int_values.sort_unstable(); + let sorted_alternate_int_values = sorted_alternate_int_values; + + // Check if the current discriminant is not in the alternative values. + if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value { + if let Ok(index) = + sorted_alternate_int_values.binary_search(&canonical_value_int) + { + die!(&raw_alternative_values[index] => format!("'{}' in the alternative values is already attributed as the discriminant of this variant", canonical_value_int)); + } } // Search for duplicates, the vec is sorted. Warn about them. - if (1..alt_val_sorted.len()).any(|i| alt_val_sorted[i] == alt_val_sorted[i - 1]) - { + if (1..sorted_alternate_int_values.len()).any(|i| { + sorted_alternate_int_values[i] == sorted_alternate_int_values[i - 1] + }) { let attr = *alt_attr_ref.last().unwrap(); die!(attr => "There is duplication in the alternative values"); } - // Search if those alt_val where already attributed. - // (The val_set is BTreeSet, and iter().next_back() is the is the maximum in the set.) - if let Some(last_upper_val) = val_set.iter().next_back() { - if alt_val_sorted.first().unwrap() <= last_upper_val { - for (i, val) in alt_val_sorted.iter().enumerate() { - if val_set.contains(val) { - die!(&alternative_values[i] => format!("'{}' in the alternative values is already attributed to a previous variant", val)); + // Search if those discriminant_int_val_set where already attributed. + // (discriminant_int_val_set is BTreeSet, and iter().next_back() is the is the maximum in the set.) + if let Some(last_upper_val) = discriminant_int_val_set.iter().next_back() { + if sorted_alternate_int_values.first().unwrap() <= last_upper_val { + for (i, val) in sorted_alternate_int_values.iter().enumerate() { + if discriminant_int_val_set.contains(val) { + die!(&raw_alternative_values[i] => format!("'{}' in the alternative values is already attributed to a previous variant", val)); } } } } // Reconstruct the alternative_values vec of Expr but sorted. - alternative_values = alt_val_sorted + raw_alternative_values = sorted_alternate_int_values .iter() .map(|val| literal(val.to_owned())) .collect(); // Add the alternative values to the the set to keep track. - val_set.extend(alt_val_sorted); + discriminant_int_val_set.extend(sorted_alternate_int_values); } // Add the current discriminant to the the set to keep track. - let newly_inserted = val_set.insert(canonical_value_int); - debug_assert!(newly_inserted); + if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value { + discriminant_int_val_set.insert(canonical_value_int); + } variants.push(VariantInfo { ident, attr_spans, is_default, is_catch_all, - canonical_value, - alternative_values, + canonical_value: discriminant, + alternative_values: raw_alternative_values, }); // Get the next value for the discriminant. - next_discriminant = literal(canonical_value_int + 1); + next_discriminant = match discriminant_value { + DiscriminantValue::Literal(int_value) => literal(int_value.wrapping_add(1)), + DiscriminantValue::Expr(expr) => { + parse_quote! { + #repr::wrapping_add(#expr, 1) + } + } + } } EnumInfo {