Skip to content

Commit

Permalink
Support non-literal discriminant values (#96)
Browse files Browse the repository at this point in the history
While still favouring literal values where possible.

Fixes #94 

The story here is:
#89 tried to simplify the generated code (which, among other things, made rustc not stack overflow when codegening for large enums).
Unfortunately, it assumed that all explicit discriminants were literals (and non-negative ones, as `-123` parses as a `UnaryExpr` not a `Literal`).
And further unfortunately, because of a typo in a `#[cfg(feature)]` attached to the only tests we had for non-literal enums, we weren't running those tests (this PR re-enables them and adds some which aren't feature gated, and #95 will ensure we don't regress in this way again).

This PR attempts to preserve the "prefer just using literals rather than large chains of wrapping adds" property of #89, while also supporting non-literal cases.
  • Loading branch information
illicitonion committed Jan 15, 2023
1 parent 7acc582 commit 6604e2b
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 60 deletions.
13 changes: 13 additions & 0 deletions num_enum/tests/try_build/compile_fail/alternative_exprs.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
const THREE: u8 = 3; // Compile-fail fixture input: a named constant rather than an integer literal.

#[derive(num_enum::TryFromPrimitive)]
#[repr(i8)]
enum Numbers {
Zero = 0,
#[num_enum(alternatives = [-1, 2, THREE])]
One = 1, // The `alternatives` list above mixes literals with `THREE`; the expected error points at `THREE`.
}

fn main() {
// Deliberately empty: this file only has to fail to compile with the expected diagnostic on line 7.
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
error: Only literals are allowed as num_enum alternate values
--> tests/try_build/compile_fail/alternative_exprs.rs:7:39
|
7 | #[num_enum(alternatives = [-1, 2, THREE])]
| ^^^^^
98 changes: 86 additions & 12 deletions num_enum/tests/try_from_primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,79 @@ fn wrong_order() {
assert_eq!(four, Ok(Enum::Four));
}

#[cfg(feature = "complex-expression")]
#[test]
fn negative_values() {
    // Enum whose explicit discriminants include negative literals.
    #[derive(Debug, Eq, PartialEq, TryFromPrimitive)]
    #[repr(i8)]
    enum Enum {
        MinusTwo = -2,
        MinusOne = -1,
        Zero = 0,
        One = 1,
        Two = 2,
    }

    // Each raw primitive paired with the variant it must convert to.
    let expectations = vec![
        (-2i8, Enum::MinusTwo),
        (-1i8, Enum::MinusOne),
        (0i8, Enum::Zero),
        (1i8, Enum::One),
        (2i8, Enum::Two),
    ];

    for (raw, expected) in expectations {
        let converted: Result<Enum, _> = raw.try_into();
        assert_eq!(converted, Ok(expected));
    }
}

#[test]
fn discriminant_expressions() {
    const ONE: u8 = 1;

    // Mixes implicit discriminants, a named constant, a suffixed literal and
    // an arithmetic expression.
    #[derive(Debug, Eq, PartialEq, TryFromPrimitive)]
    #[repr(u8)]
    enum Enum {
        Zero,
        One = ONE,
        Two,
        Four = 4u8,
        Five,
        Six = ONE + ONE + 2u8 + 2,
    }

    // Every value that names a variant must convert to exactly that variant.
    let accepted = vec![
        (0u8, Enum::Zero),
        (1u8, Enum::One),
        (2u8, Enum::Two),
        (4u8, Enum::Four),
        (5u8, Enum::Five),
        (6u8, Enum::Six),
    ];
    for (raw, expected) in accepted {
        let converted: Result<Enum, _> = raw.try_into();
        assert_eq!(converted, Ok(expected));
    }

    // 3 is the one gap in the range and must be rejected with the standard message.
    let rejected: Result<Enum, _> = 3u8.try_into();
    assert_eq!(
        rejected.unwrap_err().to_string(),
        "No discriminant in enum `Enum` matches the value `3`",
    );
}

#[cfg(feature = "complex-expressions")]
mod complex {
use num_enum::TryFromPrimitive;
use std::convert::TryInto;

const ONE: u8 = 1;

#[derive(Debug, Eq, PartialEq, TryFromPrimitive)]
Expand Down Expand Up @@ -261,26 +332,29 @@ fn error_variant_is_allowed() {
#[test]
fn alternative_values() {
#[derive(Debug, Eq, PartialEq, TryFromPrimitive)]
#[repr(u8)]
#[repr(i8)]
enum Enum {
Zero = 0,
#[num_enum(alternatives = [2, 3])]
OneTwoOrThree = 1,
#[num_enum(alternatives = [-1, 2, 3])]
OneTwoThreeOrMinusOne = 1,
}

let zero: Result<Enum, _> = 0u8.try_into();
let minus_one: Result<Enum, _> = (-1i8).try_into();
assert_eq!(minus_one, Ok(Enum::OneTwoThreeOrMinusOne));

let zero: Result<Enum, _> = 0i8.try_into();
assert_eq!(zero, Ok(Enum::Zero));

let one: Result<Enum, _> = 1u8.try_into();
assert_eq!(one, Ok(Enum::OneTwoOrThree));
let one: Result<Enum, _> = 1i8.try_into();
assert_eq!(one, Ok(Enum::OneTwoThreeOrMinusOne));

let two: Result<Enum, _> = 2u8.try_into();
assert_eq!(two, Ok(Enum::OneTwoOrThree));
let two: Result<Enum, _> = 2i8.try_into();
assert_eq!(two, Ok(Enum::OneTwoThreeOrMinusOne));

let three: Result<Enum, _> = 3u8.try_into();
assert_eq!(three, Ok(Enum::OneTwoOrThree));
let three: Result<Enum, _> = 3i8.try_into();
assert_eq!(three, Ok(Enum::OneTwoThreeOrMinusOne));

let four: Result<Enum, _> = 4u8.try_into();
let four: Result<Enum, _> = 4i8.try_into();
assert_eq!(
four.unwrap_err().to_string(),
"No discriminant in enum `Enum` matches the value `4`"
Expand Down
149 changes: 101 additions & 48 deletions num_enum_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ use syn::{
parse::{Parse, ParseStream},
parse_macro_input, parse_quote,
spanned::Spanned,
Attribute, Data, DeriveInput, Error, Expr, Fields, Ident, Lit, LitInt, LitStr, Meta, Result,
Attribute, Data, DeriveInput, Error, Expr, ExprLit, ExprUnary, Fields, Ident, Lit, LitInt,
LitStr, Meta, Result, UnOp,
};

macro_rules! die {
Expand All @@ -26,20 +27,40 @@ macro_rules! die {
}

/// Builds the expression node for the integer literal `i`.
///
/// The node is constructed directly as an `Expr::Lit` wrapping a `LitInt`
/// (spanned at the macro call site, with no attributes) instead of being
/// produced through `parse_quote!`. Only this single body is kept: the
/// earlier `parse_quote!`-based body that preceded it was a discarded
/// statement with no effect on the returned value.
fn literal(i: i128) -> Expr {
    Expr::Lit(ExprLit {
        lit: Lit::Int(LitInt::new(&i.to_string(), Span::call_site())),
        attrs: vec![],
    })
}

fn expr_to_int(val_exp: &Expr) -> Result<i128> {
Ok(match val_exp {
Expr::Lit(ref val_exp_lit) => match val_exp_lit.lit {
Lit::Int(ref lit_int) => lit_int.base10_parse()?,
_ => die!(val_exp => "Expected integer"),
},
_ => die!(val_exp => "Expected literal"),
})
/// Result of classifying a discriminant (or alternative-value) expression.
enum DiscriminantValue {
/// An integer literal, optionally preceded by a unary minus, already parsed to its value.
Literal(i128),
/// Any other expression, kept as written without being evaluated.
Expr(Expr),
}

/// Classifies a discriminant expression as an integer literal or an opaque expression.
///
/// A single leading unary minus is peeled off first, because a negative
/// number such as `-123` parses as a unary expression rather than a literal.
/// If what remains is an integer literal, its signed value is returned as
/// `DiscriminantValue::Literal`; otherwise the original expression is cloned
/// unchanged into `DiscriminantValue::Expr`.
///
/// Returns an error only if the integer literal cannot be parsed as an `i128`.
fn parse_discriminant(val_exp: &Expr) -> Result<DiscriminantValue> {
    // Split the input into the (possibly negated) operand and its sign.
    let (magnitude_expr, sign): (&Expr, i128) = match val_exp {
        Expr::Unary(ExprUnary {
            op: UnOp::Neg(..),
            expr,
            ..
        }) => (&**expr, -1),
        other => (other, 1),
    };

    match magnitude_expr {
        Expr::Lit(ExprLit {
            lit: Lit::Int(int_lit),
            ..
        }) => {
            let magnitude = int_lit.base10_parse::<i128>()?;
            Ok(DiscriminantValue::Literal(sign * magnitude))
        }
        // Not a plain (optionally negated) integer literal: keep the full original expression.
        _ => Ok(DiscriminantValue::Expr(val_exp.clone())),
    }
}

mod kw {
Expand Down Expand Up @@ -307,7 +328,7 @@ impl Parse for EnumInfo {
let mut has_catch_all_variant: bool = false;

// Vec to keep track of the used discriminants and alt values.
let mut val_set: BTreeSet<i128> = BTreeSet::new();
let mut discriminant_int_val_set = BTreeSet::new();

let mut next_discriminant = literal(0);
for variant in data.variants.into_iter() {
Expand All @@ -319,7 +340,7 @@ impl Parse for EnumInfo {
};

let mut attr_spans: AttributeSpans = Default::default();
let mut alternative_values: Vec<Expr> = vec![];
let mut raw_alternative_values: Vec<Expr> = vec![];
// Keep the attribute around for better error reporting.
let mut alt_attr_ref: Vec<&Attribute> = vec![];

Expand Down Expand Up @@ -398,7 +419,7 @@ impl Parse for EnumInfo {
}
NumEnumVariantAttributeItem::Alternatives(alternatives) => {
attr_spans.alternatives.push(alternatives.span());
alternative_values.extend(alternatives.expressions);
raw_alternative_values.extend(alternatives.expressions);
alt_attr_ref.push(attribute);
}
}
Expand All @@ -422,75 +443,107 @@ impl Parse for EnumInfo {
}
}

let canonical_value = discriminant;
let canonical_value_int = expr_to_int(&canonical_value)?;
let discriminant_value = parse_discriminant(&discriminant)?;

// Check for collision.
if val_set.contains(&canonical_value_int) {
die!(ident => format!("The discriminant '{}' collides with a value attributed to a previous variant", canonical_value_int))
// We can't do const evaluation, or even compare arbitrary Exprs,
// so unfortunately we can't check for duplicates.
// That's not the end of the world, just we'll end up with compile errors for
// matches with duplicate branches in generated code instead of nice friendly error messages.
if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value {
if discriminant_int_val_set.contains(&canonical_value_int) {
die!(ident => format!("The discriminant '{}' collides with a value attributed to a previous variant", canonical_value_int))
}
}

// Deal with the alternative values.
let alt_val = alternative_values
let alternate_values = raw_alternative_values
.iter()
.map(expr_to_int)
.map(parse_discriminant)
.collect::<Result<Vec<_>>>()?;

debug_assert_eq!(alt_val.len(), alternative_values.len());

if !alt_val.is_empty() {
let mut alt_val_sorted = alt_val.clone();
alt_val_sorted.sort_unstable();
let alt_val_sorted = alt_val_sorted;

// check if the current discriminant is not in the alternative values.
if let Some(i) = alt_val.iter().position(|&x| x == canonical_value_int) {
die!(&alternative_values[i] => format!("'{}' in the alternative values is already attributed as the discriminant of this variant", canonical_value_int));
debug_assert_eq!(alternate_values.len(), raw_alternative_values.len());

if !alternate_values.is_empty() {
let mut sorted_alternate_int_values = alternate_values
.into_iter()
.map(|v| {
match v {
DiscriminantValue::Literal(value) => Ok(value),
DiscriminantValue::Expr(expr) => {
// We can't do uniqueness checking on non-literals, so we don't allow them as alternate values.
// We could probably allow them, but there doesn't seem to be much of a use-case,
// and it's easier to give good error messages about duplicate values this way,
// rather than rustc errors on conflicting match branches.
die!(expr => format!("Only literals are allowed as num_enum alternate values"))
},
}
})
.collect::<Result<Vec<i128>>>()?;
sorted_alternate_int_values.sort_unstable();
let sorted_alternate_int_values = sorted_alternate_int_values;

// Check if the current discriminant is not in the alternative values.
if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value {
if let Ok(index) =
sorted_alternate_int_values.binary_search(&canonical_value_int)
{
die!(&raw_alternative_values[index] => format!("'{}' in the alternative values is already attributed as the discriminant of this variant", canonical_value_int));

This comment has been minimized.

Copy link
@GilShoshan94

GilShoshan94 Jan 15, 2023

Contributor

Here I think there is a minor error: in the case of a die!, we won't necessarily have the correct alternative value underlined. That's because the index into sorted_alternate_int_values is not in sync with raw_alternative_values.

As an example, try it with out-of-order alternative values such as [9,1,5,6,0].

That's why, in my earlier code, I kept around what you now call alternate_values and cloned it into a sorted version that I used for other things; but here I did the if let Some(i) = alt_val.iter().position(|&x| x == canonical_value_int), where i is the index and is in sync with alternative_values (which you now call raw_alternative_values).

}
}

// Search for duplicates, the vec is sorted. Warn about them.
if (1..alt_val_sorted.len()).any(|i| alt_val_sorted[i] == alt_val_sorted[i - 1])
{
if (1..sorted_alternate_int_values.len()).any(|i| {
sorted_alternate_int_values[i] == sorted_alternate_int_values[i - 1]
}) {
let attr = *alt_attr_ref.last().unwrap();
die!(attr => "There is duplication in the alternative values");
}
// Search if those alt_val where already attributed.
// (The val_set is BTreeSet, and iter().next_back() is the is the maximum in the set.)
if let Some(last_upper_val) = val_set.iter().next_back() {
if alt_val_sorted.first().unwrap() <= last_upper_val {
for (i, val) in alt_val_sorted.iter().enumerate() {
if val_set.contains(val) {
die!(&alternative_values[i] => format!("'{}' in the alternative values is already attributed to a previous variant", val));
// Search if those discriminant_int_val_set where already attributed.
// (discriminant_int_val_set is BTreeSet, and iter().next_back() is the is the maximum in the set.)
if let Some(last_upper_val) = discriminant_int_val_set.iter().next_back() {
if sorted_alternate_int_values.first().unwrap() <= last_upper_val {
for (i, val) in sorted_alternate_int_values.iter().enumerate() {
if discriminant_int_val_set.contains(val) {
die!(&raw_alternative_values[i] => format!("'{}' in the alternative values is already attributed to a previous variant", val));

This comment has been minimized.

Copy link
@GilShoshan94

GilShoshan94 Jan 15, 2023

Contributor

Same here about the index being out of sync, but I made the same mistake before you too...

The for loop should be over the unsorted alternate_int_values so that the index stays in sync with raw_alternative_values.

}
}
}
}

// Reconstruct the alternative_values vec of Expr but sorted.
alternative_values = alt_val_sorted
raw_alternative_values = sorted_alternate_int_values
.iter()
.map(|val| literal(val.to_owned()))
.collect();

// Add the alternative values to the the set to keep track.
val_set.extend(alt_val_sorted);
discriminant_int_val_set.extend(sorted_alternate_int_values);
}

// Add the current discriminant to the the set to keep track.
let newly_inserted = val_set.insert(canonical_value_int);
debug_assert!(newly_inserted);
if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value {
discriminant_int_val_set.insert(canonical_value_int);
}

variants.push(VariantInfo {
ident,
attr_spans,
is_default,
is_catch_all,
canonical_value,
alternative_values,
canonical_value: discriminant,
alternative_values: raw_alternative_values,
});

// Get the next value for the discriminant.
next_discriminant = literal(canonical_value_int + 1);
next_discriminant = match discriminant_value {
DiscriminantValue::Literal(int_value) => literal(int_value.wrapping_add(1)),
DiscriminantValue::Expr(expr) => {
parse_quote! {
#repr::wrapping_add(#expr, 1)
}
}
}
}

EnumInfo {
Expand Down

3 comments on commit 6604e2b

@GilShoshan94
Copy link
Contributor

@GilShoshan94 GilShoshan94 commented on 6604e2b Jan 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi,

Good stuff, I didn't know about the use case of non-literal discriminant values.
Sorry about the regression.

I reviewed your code change and it seems correct so far, besides 2 minor issues.

@GilShoshan94
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@illicitonion
I just realized I commented on the commit and not on the PR... sorry.

@GilShoshan94
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I submitted a PR to fix the minor issues.
Have a nice day.

Please sign in to comment.