Skip to content

Commit

Permalink
EnumDiscriminant inherits the repr and discriminant values (#288)
Browse files Browse the repository at this point in the history
* moved repr extraction to helpers

* repr pass-through added

* add discriminant pass through

* remove dev artifact

---------

Co-authored-by: Jason Scatena <jscatena@amazon.com>
  • Loading branch information
jscatena88 and Jason Scatena authored Oct 29, 2023
1 parent d32af44 commit e8b2ff1
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 51 deletions.
12 changes: 12 additions & 0 deletions strum_macros/src/helpers/type_props.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pub struct StrumTypeProperties {
pub discriminant_others: Vec<TokenStream>,
pub discriminant_vis: Option<Visibility>,
pub use_phf: bool,
pub enum_repr: Option<TokenStream>,
}

impl HasTypeProperties for DeriveInput {
Expand Down Expand Up @@ -103,6 +104,17 @@ impl HasTypeProperties for DeriveInput {
}
}

let attrs = &self.attrs;
for attr in attrs {
if let Ok(list) = attr.meta.require_list() {
if let Some(ident) = list.path.get_ident() {
if ident == "repr" {
output.enum_repr = Some(list.tokens.clone())
}
}
}
}

Ok(output)
}
}
Expand Down
9 changes: 8 additions & 1 deletion strum_macros/src/macros/enum_discriminants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,16 @@ pub fn enum_discriminants_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
// Pass through all other attributes
let pass_though_attributes = type_properties.discriminant_others;

let repr = type_properties.enum_repr.map(|repr| quote!(#[repr(#repr)]));

// Add the variants without fields, but exclude the `strum` meta item
let mut discriminants = Vec::new();
for variant in variants {
let ident = &variant.ident;
let discriminant = variant
.discriminant
.as_ref()
.map(|(_, expr)| quote!( = #expr));

// Don't copy across the "strum" meta attribute. Only passthrough the whitelisted
// attributes and proxy `#[strum_discriminants(...)]` attributes
Expand Down Expand Up @@ -81,7 +87,7 @@ pub fn enum_discriminants_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
})
.collect::<Result<Vec<_>, _>>()?;

discriminants.push(quote! { #(#attrs)* #ident });
discriminants.push(quote! { #(#attrs)* #ident #discriminant});
}

// Ideally:
Expand Down Expand Up @@ -153,6 +159,7 @@ pub fn enum_discriminants_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
Ok(quote! {
/// Auto-generated discriminant enum variants
#derives
#repr
#(#[ #pass_though_attributes ])*
#discriminants_vis enum #discriminants_name {
#(#discriminants),*
Expand Down
66 changes: 18 additions & 48 deletions strum_macros/src/macros/from_repr.rs
Original file line number Diff line number Diff line change
@@ -1,62 +1,32 @@
use heck::ToShoutySnakeCase;
use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote, ToTokens};
use syn::{Data, DeriveInput, Fields, PathArguments, Type, TypeParen};
use quote::{format_ident, quote};
use syn::{Data, DeriveInput, Fields, Type};

use crate::helpers::{non_enum_error, HasStrumVariantProperties};
use crate::helpers::{non_enum_error, HasStrumVariantProperties, HasTypeProperties};

pub fn from_repr_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
let name = &ast.ident;
let gen = &ast.generics;
let (impl_generics, ty_generics, where_clause) = gen.split_for_impl();
let vis = &ast.vis;
let attrs = &ast.attrs;

let mut discriminant_type: Type = syn::parse("usize".parse().unwrap()).unwrap();
for attr in attrs {
let path = attr.path();

let mut ts = if let Ok(ts) = attr
.meta
.require_list()
.map(|metas| metas.to_token_stream().into_iter())
{
ts
} else {
continue;
};
// Discard the path
let _ = ts.next();
let tokens: TokenStream = ts.collect();

if path.leading_colon.is_some() {
continue;
}
if path.segments.len() != 1 {
continue;
}
let segment = path.segments.first().unwrap();
if segment.ident != "repr" {
continue;
}
if segment.arguments != PathArguments::None {
continue;
}
let typ_paren = match syn::parse2::<Type>(tokens.clone()) {
Ok(Type::Paren(TypeParen { elem, .. })) => *elem,
_ => continue,
};
let inner_path = match &typ_paren {
Type::Path(t) => t,
_ => continue,
};
if let Some(seg) = inner_path.path.segments.last() {
for t in &[
"u8", "u16", "u32", "u64", "usize", "i8", "i16", "i32", "i64", "isize",
] {
if seg.ident == t {
discriminant_type = typ_paren;
break;
if let Some(type_path) = ast
.get_type_properties()
.ok()
.and_then(|tp| tp.enum_repr)
.and_then(|repr_ts| syn::parse2::<Type>(repr_ts).ok())
{
if let Type::Path(path) = type_path.clone() {
if let Some(seg) = path.path.segments.last() {
for t in &[
"u8", "u16", "u32", "u64", "usize", "i8", "i16", "i32", "i64", "isize",
] {
if seg.ident == t {
discriminant_type = type_path;
break;
}
}
}
}
Expand Down
62 changes: 60 additions & 2 deletions strum_tests/tests/enum_discriminants.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use enum_variant_type::EnumVariantType;
use strum::{Display, EnumDiscriminants, EnumIter, EnumMessage, EnumString, IntoEnumIterator};
use std::mem::{align_of, size_of};

use enum_variant_type::EnumVariantType;
use strum::{
Display, EnumDiscriminants, EnumIter, EnumMessage, EnumString, FromRepr, IntoEnumIterator,
};

mod core {} // ensure macros call `::core`

Expand Down Expand Up @@ -305,3 +308,58 @@ fn crate_module_path_test() {

assert_eq!(expected, discriminants);
}

#[allow(dead_code)]
#[derive(EnumDiscriminants)]
#[repr(u16)]
enum WithReprUInt {
Variant0,
Variant1,
}

#[test]
fn with_repr_uint() {
// These tests would not be proof of proper functioning on a 16 bit system
assert_eq!(size_of::<u16>(), size_of::<WithReprUIntDiscriminants>());
assert_eq!(
size_of::<WithReprUInt>(),
size_of::<WithReprUIntDiscriminants>()
)
}

#[allow(dead_code)]
#[derive(EnumDiscriminants)]
#[repr(align(16), u8)]
enum WithReprAlign {
Variant0,
Variant1,
}

#[test]
fn with_repr_align() {
assert_eq!(
align_of::<WithReprAlign>(),
align_of::<WithReprAlignDiscriminants>()
);
assert_eq!(16, align_of::<WithReprAlignDiscriminants>());
}

#[allow(dead_code)]
#[derive(EnumDiscriminants)]
#[strum_discriminants(derive(FromRepr))]
enum WithExplicitDicriminantValue {
Variant0 = 42 + 100,
Variant1 = 11,
}

#[test]
fn with_explicit_discriminant_value() {
assert_eq!(
WithExplicitDicriminantValueDiscriminants::from_repr(11),
Some(WithExplicitDicriminantValueDiscriminants::Variant1)
);
assert_eq!(
142,
WithExplicitDicriminantValueDiscriminants::Variant0 as u8
);
}

0 comments on commit e8b2ff1

Please sign in to comment.