|  | 
|  | 1 | +//! The definition of the ResultLabels derive macro, see | 
|  | 2 | +//! autometrics::ResultLabels for more information. | 
|  | 3 | +
 | 
|  | 4 | +use proc_macro2::TokenStream; | 
|  | 5 | +use quote::quote; | 
|  | 6 | +use syn::{ | 
|  | 7 | +    punctuated::Punctuated, token::Comma, Attribute, Data, DataEnum, DeriveInput, Error, Ident, | 
|  | 8 | +    Lit, LitStr, Result, Variant, | 
|  | 9 | +}; | 
|  | 10 | + | 
|  | 11 | +// These labels must match autometrics::ERROR_KEY and autometrics::OK_KEY, | 
|  | 12 | +// to avoid a dependency loop just for 2 constants we recreate these here. | 
|  | 13 | +const OK_KEY: &str = "ok"; | 
|  | 14 | +const ERROR_KEY: &str = "error"; | 
|  | 15 | +const RESULT_KEY: &str = "result"; | 
|  | 16 | +const ATTR_LABEL: &str = "label"; | 
|  | 17 | +const ACCEPTED_LABELS: [&str; 2] = [ERROR_KEY, OK_KEY]; | 
|  | 18 | + | 
|  | 19 | +/// Entry point of the ResultLabels macro | 
|  | 20 | +pub(crate) fn expand(input: DeriveInput) -> Result<TokenStream> { | 
|  | 21 | +    let variants = match &input.data { | 
|  | 22 | +        Data::Enum(DataEnum { variants, .. }) => variants, | 
|  | 23 | +        _ => { | 
|  | 24 | +            return Err(Error::new_spanned( | 
|  | 25 | +                input, | 
|  | 26 | +                "ResultLabels only works with 'Enum's.", | 
|  | 27 | +            )) | 
|  | 28 | +        } | 
|  | 29 | +    }; | 
|  | 30 | +    let enum_name = &input.ident; | 
|  | 31 | +    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); | 
|  | 32 | +    let conditional_clauses_for_labels = conditional_label_clauses(variants, enum_name)?; | 
|  | 33 | + | 
|  | 34 | +    Ok(quote! { | 
|  | 35 | +        #[automatically_derived] | 
|  | 36 | +        impl #impl_generics ::autometrics::__private::GetLabels for #enum_name #ty_generics #where_clause { | 
|  | 37 | +            fn __autometrics_get_labels(&self) -> Option<&'static str> { | 
|  | 38 | +                #conditional_clauses_for_labels | 
|  | 39 | +            } | 
|  | 40 | +        } | 
|  | 41 | +    }) | 
|  | 42 | +} | 
|  | 43 | + | 
|  | 44 | +/// Build the list of match clauses for the generated code. | 
|  | 45 | +fn conditional_label_clauses( | 
|  | 46 | +    variants: &Punctuated<Variant, Comma>, | 
|  | 47 | +    enum_name: &Ident, | 
|  | 48 | +) -> Result<TokenStream> { | 
|  | 49 | +    let clauses: Vec<TokenStream> = variants | 
|  | 50 | +        .iter() | 
|  | 51 | +        .map(|variant| { | 
|  | 52 | +            let variant_name = &variant.ident; | 
|  | 53 | +            let variant_matcher: TokenStream = match variant.fields { | 
|  | 54 | +                syn::Fields::Named(_) => quote! { #variant_name {..} }, | 
|  | 55 | +                syn::Fields::Unnamed(_) => quote! { #variant_name (_) }, | 
|  | 56 | +                syn::Fields::Unit => quote! { #variant_name }, | 
|  | 57 | +            }; | 
|  | 58 | +            if let Some(key) = extract_label_attribute(&variant.attrs)? { | 
|  | 59 | +                Ok(quote! [ | 
|  | 60 | +                    else if ::std::matches!(self, & #enum_name :: #variant_matcher) { | 
|  | 61 | +                       Some(#key) | 
|  | 62 | +                    } | 
|  | 63 | +                ]) | 
|  | 64 | +            } else { | 
|  | 65 | +                // Let the code flow through the last value | 
|  | 66 | +                Ok(quote! {}) | 
|  | 67 | +            } | 
|  | 68 | +        }) | 
|  | 69 | +        .collect::<Result<Vec<_>>>()?; | 
|  | 70 | + | 
|  | 71 | +    Ok(quote! [ | 
|  | 72 | +        if false { | 
|  | 73 | +            None | 
|  | 74 | +        } | 
|  | 75 | +        #(#clauses)* | 
|  | 76 | +        else { | 
|  | 77 | +            None | 
|  | 78 | +        } | 
|  | 79 | +    ]) | 
|  | 80 | +} | 
|  | 81 | + | 
|  | 82 | +/// Extract the wanted label from the annotation in the variant, if present. | 
|  | 83 | +/// The function looks for `#[label(result = "ok")]` kind of labels. | 
|  | 84 | +/// | 
|  | 85 | +/// ## Error cases | 
|  | 86 | +/// | 
|  | 87 | +/// The function will error out with the smallest possible span when: | 
|  | 88 | +/// | 
|  | 89 | +/// - The attribute on a variant is not a "list" type (so `#[label]` is not allowed), | 
|  | 90 | +/// - The key in the key value pair is not "result", as it's the only supported keyword | 
|  | 91 | +///   for now (so `#[label(non_existing_label = "ok")]` is not allowed), | 
|  | 92 | +/// - The value for the "result" label is not in the autometrics supported set (so | 
|  | 93 | +///   `#[label(result = "random label that will break queries")]` is not allowed) | 
|  | 94 | +fn extract_label_attribute(attrs: &[Attribute]) -> Result<Option<LitStr>> { | 
|  | 95 | +    attrs | 
|  | 96 | +            .iter() | 
|  | 97 | +            .find_map(|att| match att.parse_meta() { | 
|  | 98 | +                Ok(meta) => match &meta { | 
|  | 99 | +                    syn::Meta::List(list) => { | 
|  | 100 | +                        // Ignore attribute if it's not `label(...)` | 
|  | 101 | +                        if list.path.segments.len() != 1 || list.path.segments[0].ident != ATTR_LABEL { | 
|  | 102 | +                            return None; | 
|  | 103 | +                        } | 
|  | 104 | + | 
|  | 105 | +                        // Only lists are allowed | 
|  | 106 | +                        let pair = match list.nested.first() { | 
|  | 107 | +                            Some(syn::NestedMeta::Meta(syn::Meta::NameValue(pair))) => pair, | 
|  | 108 | +                            _ => return Some(Err(Error::new_spanned( | 
|  | 109 | +                            meta, | 
|  | 110 | +                            format!("Only `{ATTR_LABEL}({RESULT_KEY} = \"RES\")` (RES can be {OK_KEY:?} or {ERROR_KEY:?}) is supported"), | 
|  | 111 | +                            ))), | 
|  | 112 | +                        }; | 
|  | 113 | + | 
|  | 114 | +                        // Inside list, only 'result = ...' are allowed | 
|  | 115 | +                        if pair.path.segments.len() != 1 || pair.path.segments[0].ident != RESULT_KEY { | 
|  | 116 | +                            return Some(Err(Error::new_spanned( | 
|  | 117 | +                                pair.path.clone(), | 
|  | 118 | +                            format!("Only `{RESULT_KEY} = \"RES\"` (RES can be {OK_KEY:?} or {ERROR_KEY:?}) is supported"), | 
|  | 119 | +                            ))); | 
|  | 120 | +                        } | 
|  | 121 | + | 
|  | 122 | +                        // Inside 'result = val', 'val' must be a string literal | 
|  | 123 | +                        let lit_str = match pair.lit { | 
|  | 124 | +                            Lit::Str(ref lit_str) => lit_str, | 
|  | 125 | +                            _ => { | 
|  | 126 | +                            return Some(Err(Error::new_spanned( | 
|  | 127 | +                                &pair.lit, | 
|  | 128 | +                            format!("Only {OK_KEY:?} or {ERROR_KEY:?}, as string literals, are accepted as result values"), | 
|  | 129 | +                            ))); | 
|  | 130 | +                        } | 
|  | 131 | +                        }; | 
|  | 132 | + | 
|  | 133 | +                        // Inside 'result = val', 'val' must be one of the allowed string literals | 
|  | 134 | +                        if !ACCEPTED_LABELS.contains(&lit_str.value().as_str()) { | 
|  | 135 | +                            return Some(Err(Error::new_spanned( | 
|  | 136 | +                                    lit_str, | 
|  | 137 | +                            format!("Only {OK_KEY:?} or {ERROR_KEY:?} are accepted as result values"), | 
|  | 138 | +                            ))); | 
|  | 139 | +                        } | 
|  | 140 | + | 
|  | 141 | +                        Some(Ok(lit_str.clone())) | 
|  | 142 | +                    }, | 
|  | 143 | +                    syn::Meta::NameValue(nv) if nv.path.segments.len() == 1 && nv.path.segments[0].ident == ATTR_LABEL => { | 
|  | 144 | +                        Some(Err(Error::new_spanned( | 
|  | 145 | +                            nv, | 
|  | 146 | +                            format!("Only `{ATTR_LABEL}({RESULT_KEY} = \"RES\")` (RES can be {OK_KEY:?} or {ERROR_KEY:?}) is supported"), | 
|  | 147 | +                        ))) | 
|  | 148 | +                    }, | 
|  | 149 | +                    syn::Meta::Path(p) if p.segments.len() == 1 && p.segments[0].ident == ATTR_LABEL => { | 
|  | 150 | +                        Some(Err(Error::new_spanned( | 
|  | 151 | +                            p, | 
|  | 152 | +                            format!("Only `{ATTR_LABEL}({RESULT_KEY} = \"RES\")` (RES can be {OK_KEY:?} or {ERROR_KEY:?}) is supported"), | 
|  | 153 | +                        ))) | 
|  | 154 | +                    }, | 
|  | 155 | +                    _ => None, | 
|  | 156 | +                }, | 
|  | 157 | +                Err(e) => Some(Err(Error::new_spanned( | 
|  | 158 | +                    att, | 
|  | 159 | +                    format!("could not parse the meta attribute: {e}"), | 
|  | 160 | +                ))), | 
|  | 161 | +            }) | 
|  | 162 | +            .transpose() | 
|  | 163 | +} | 
0 commit comments