From 5430df7a0f0434b5c7e0877d6b75fca1f251c870 Mon Sep 17 00:00:00 2001 From: Kurtis Nusbaum Date: Wed, 1 Feb 2023 04:45:13 +0000 Subject: [PATCH] feat(derive): Support `#[group]` attributes This adds the ability derive additional options for the group creation. Fixes #4574 --- clap_derive/src/derives/args.rs | 8 +- clap_derive/src/item.rs | 45 +++++--- examples/tutorial_derive/04_03_relations.rs | 41 +++---- src/_derive/_tutorial.rs | 3 + src/_derive/mod.rs | 11 +- tests/derive/flatten.rs | 40 ------- tests/derive/groups.rs | 119 ++++++++++++++++++++ tests/derive_ui/group_name_attribute.rs | 23 ++++ tests/derive_ui/group_name_attribute.stderr | 5 + 9 files changed, 216 insertions(+), 79 deletions(-) create mode 100644 tests/derive_ui/group_name_attribute.rs create mode 100644 tests/derive_ui/group_name_attribute.stderr diff --git a/clap_derive/src/derives/args.rs b/clap_derive/src/derives/args.rs index ab2ce319270..e8611b6bdd6 100644 --- a/clap_derive/src/derives/args.rs +++ b/clap_derive/src/derives/args.rs @@ -14,7 +14,6 @@ use proc_macro2::{Ident, Span, TokenStream}; use quote::{format_ident, quote, quote_spanned}; -use syn::ext::IdentExt; use syn::{ punctuated::Punctuated, spanned::Spanned, token::Comma, Data, DataStruct, DeriveInput, Field, Fields, Generics, @@ -89,7 +88,7 @@ pub fn gen_for_struct( let group_id = if item.skip_group() { quote!(None) } else { - let group_id = item.ident().unraw().to_string(); + let group_id = item.group_id(); quote!(Some(clap::Id::from(#group_id))) }; @@ -368,7 +367,7 @@ pub fn gen_augment( let group_app_methods = if parent_item.skip_group() { quote!() } else { - let group_id = parent_item.ident().unraw().to_string(); + let group_id = parent_item.group_id(); let literal_group_members = fields .iter() .filter_map(|(_field, item)| { @@ -401,10 +400,13 @@ pub fn gen_augment( }}; } + let group_methods = parent_item.group_methods(); + quote!( .group( clap::ArgGroup::new(#group_id) .multiple(true) + #group_methods .args(#literal_group_members) ) ) diff --git a/clap_derive/src/item.rs b/clap_derive/src/item.rs index 3aab7ccc992..9b29ff9e8a8 100644 --- a/clap_derive/src/item.rs +++ b/clap_derive/src/item.rs @@ -32,7 +32,6 @@ pub const DEFAULT_ENV_CASING: CasingStyle = CasingStyle::ScreamingSnake; #[derive(Clone)] pub struct Item { name: Name, - ident: Ident, casing: Sp, env_casing: Sp, ty: Option, @@ -48,6 +47,8 @@ pub struct Item { is_enum: bool, is_positional: bool, skip_group: bool, + group_id: Name, + group_methods: Vec, kind: Sp, } @@ -254,9 +255,9 @@ impl Item { env_casing: Sp, kind: Sp, ) -> Self { + let group_id = Name::Derived(ident); Self { name, - ident, ty, casing, env_casing, @@ -272,6 +273,8 @@ impl Item { is_enum: false, is_positional: true, skip_group: false, + group_id, + group_methods: vec![], kind, } } @@ -294,10 +297,15 @@ impl Item { kind.as_str() ), }); + self.name = Name::Assigned(arg); + } + AttrKind::Group => { + self.group_id = Name::Assigned(arg); + } + AttrKind::Arg | AttrKind::Clap | AttrKind::StructOpt => { + self.name = Name::Assigned(arg); } - AttrKind::Group | AttrKind::Arg | AttrKind::Clap | AttrKind::StructOpt => {} } - self.name = Name::Assigned(arg); } else if name == "name" { match kind { AttrKind::Arg => { @@ -312,14 +320,13 @@ impl Item { kind.as_str() ), }); + self.name = Name::Assigned(arg); + } + AttrKind::Group => self.group_methods.push(Method::new(name, arg)), + AttrKind::Command | AttrKind::Value | AttrKind::Clap | AttrKind::StructOpt => { + self.name = Name::Assigned(arg); } - AttrKind::Group - | AttrKind::Command - | AttrKind::Value - | AttrKind::Clap - | AttrKind::StructOpt => {} } - self.name = Name::Assigned(arg); } else if name == "value_parser" { self.value_parser = Some(ValueParser::Explicit(Method::new(name, arg))); } else if name == "action" { @@ -328,7 +335,10 @@ impl Item { if name == "short" || name == "long" { self.is_positional = false; } - self.methods.push(Method::new(name, arg)); + match kind { + AttrKind::Group => self.group_methods.push(Method::new(name, arg)), + _ => self.methods.push(Method::new(name, arg)), + }; } } @@ -972,6 +982,15 @@ impl Item { quote!( #(#doc_comment)* #(#methods)* ) } + pub fn group_id(&self) -> TokenStream { + self.group_id.clone().raw() + } + + pub fn group_methods(&self) -> TokenStream { + let group_methods = &self.group_methods; + quote!( #(#group_methods)* ) + } + pub fn deprecations(&self) -> proc_macro2::TokenStream { let deprecations = &self.deprecations; quote!( #(#deprecations)* ) @@ -987,10 +1006,6 @@ impl Item { quote!( #(#next_help_heading)* ) } - pub fn ident(&self) -> &Ident { - &self.ident - } - pub fn id(&self) -> TokenStream { self.name.clone().raw() } diff --git a/examples/tutorial_derive/04_03_relations.rs b/examples/tutorial_derive/04_03_relations.rs index cbe491deb4b..8657ebe8372 100644 --- a/examples/tutorial_derive/04_03_relations.rs +++ b/examples/tutorial_derive/04_03_relations.rs @@ -1,13 +1,26 @@ -use clap::{ArgGroup, Parser}; +use clap::{Args, Parser}; #[derive(Parser)] #[command(author, version, about, long_about = None)] -#[command(group( - ArgGroup::new("vers") - .required(true) - .args(["set_ver", "major", "minor", "patch"]), - ))] struct Cli { + #[command(flatten)] + vers: Vers, + + /// some regular input + #[arg(group = "input")] + input_file: Option, + + /// some special input argument + #[arg(long, group = "input")] + spec_in: Option, + + #[arg(short, requires = "input")] + config: Option, +} + +#[derive(Args)] +#[group(required = true, multiple = false)] +struct Vers { /// set version manually #[arg(long, value_name = "VER")] set_ver: Option, @@ -23,17 +36,6 @@ struct Cli { /// auto inc patch #[arg(long)] patch: bool, - - /// some regular input - #[arg(group = "input")] - input_file: Option, - - /// some special input argument - #[arg(long, group = "input")] - spec_in: Option, - - #[arg(short, requires = "input")] - config: Option, } fn main() { @@ -45,11 +47,12 @@ fn main() { let mut patch = 3; // See if --set_ver was used to set the version manually - let version = if let Some(ver) = cli.set_ver.as_deref() { + let vers = &cli.vers; + let version = if let Some(ver) = vers.set_ver.as_deref() { ver.to_string() } else { // Increment the one requested (in a real program, we'd reset the lower numbers) - let (maj, min, pat) = (cli.major, cli.minor, cli.patch); + let (maj, min, pat) = (vers.major, vers.minor, vers.patch); match (maj, min, pat) { (true, _, _) => major += 1, (_, true, _) => minor += 1, diff --git a/src/_derive/_tutorial.rs b/src/_derive/_tutorial.rs index f3f55c3392d..8d00b03ec3f 100644 --- a/src/_derive/_tutorial.rs +++ b/src/_derive/_tutorial.rs @@ -202,6 +202,9 @@ //! want one of them to be required, but making all of them required isn't feasible because perhaps //! they conflict with each other. //! +//! [`ArgGroup`][crate::ArgGroup]s are automatically created for a `struct` with its +//! [`ArgGroup::id`][crate::ArgGroup::id] being the struct's name. +//! //! ```rust #![doc = include_str!("../../examples/tutorial_derive/04_03_relations.rs")] //! ``` diff --git a/src/_derive/mod.rs b/src/_derive/mod.rs index a92b7c87b04..6bde6033f10 100644 --- a/src/_derive/mod.rs +++ b/src/_derive/mod.rs @@ -194,7 +194,14 @@ //! These correspond to the [`ArgGroup`][crate::ArgGroup] which is implicitly created for each //! `Args` derive. //! -//! At the moment, only `#[group(skip)]` is supported +//! **Raw attributes:** Any [`ArgGroup` method][crate::ArgGroup] can also be used as an attribute, see [Terminology](#terminology) for syntax. +//! - e.g. `#[group(required = true)]` would translate to `arg_group.required(true)` +//! +//! **Magic attributes**: +//! - `id = `: [`ArgGroup::id`][crate::ArgGroup::id] +//! - When not present: struct's name is used +//! - `skip [= ]`: Ignore this field, filling in with `` +//! - Without ``: fills the field with `Default::default()` //! //! ### Arg Attributes //! @@ -205,7 +212,7 @@ //! //! **Magic attributes**: //! - `id = `: [`Arg::id`][crate::Arg::id] -//! - When not present: case-converted field name is used +//! - When not present: field's name is used //! - `value_parser [= ]`: [`Arg::value_parser`][crate::Arg::value_parser] //! - When not present: will auto-select an implementation based on the field type using //! [`value_parser!`][crate::value_parser!] diff --git a/tests/derive/flatten.rs b/tests/derive/flatten.rs index 86468eb5592..bdd0387338c 100644 --- a/tests/derive/flatten.rs +++ b/tests/derive/flatten.rs @@ -255,43 +255,3 @@ fn docstrings_ordering_with_multiple_clap_partial() { assert!(short_help.contains("This is the docstring for Flattened")); } - -#[test] -fn optional_flatten() { - #[derive(Parser, Debug, PartialEq, Eq)] - struct Opt { - #[command(flatten)] - source: Option, - } - - #[derive(clap::Args, Debug, PartialEq, Eq)] - struct Source { - crates: Vec, - #[arg(long)] - path: Option, - #[arg(long)] - git: Option, - } - - assert_eq!(Opt { source: None }, Opt::try_parse_from(["test"]).unwrap()); - assert_eq!( - Opt { - source: Some(Source { - crates: vec!["serde".to_owned()], - path: None, - git: None, - }), - }, - Opt::try_parse_from(["test", "serde"]).unwrap() - ); - assert_eq!( - Opt { - source: Some(Source { - crates: Vec::new(), - path: Some("./".into()), - git: None, - }), - }, - Opt::try_parse_from(["test", "--path=./"]).unwrap() - ); -} diff --git a/tests/derive/groups.rs b/tests/derive/groups.rs index a4750e2e9c4..43601c550de 100644 --- a/tests/derive/groups.rs +++ b/tests/derive/groups.rs @@ -92,6 +92,46 @@ fn skip_group_avoids_duplicate_ids() { assert_eq!(Opt::group_id(), None); } +#[test] +fn optional_flatten() { + #[derive(Parser, Debug, PartialEq, Eq)] + struct Opt { + #[command(flatten)] + source: Option, + } + + #[derive(clap::Args, Debug, PartialEq, Eq)] + struct Source { + crates: Vec, + #[arg(long)] + path: Option, + #[arg(long)] + git: Option, + } + + assert_eq!(Opt { source: None }, Opt::try_parse_from(["test"]).unwrap()); + assert_eq!( + Opt { + source: Some(Source { + crates: vec!["serde".to_owned()], + path: None, + git: None, + }), + }, + Opt::try_parse_from(["test", "serde"]).unwrap() + ); + assert_eq!( + Opt { + source: Some(Source { + crates: Vec::new(), + path: Some("./".into()), + git: None, + }), + }, + Opt::try_parse_from(["test", "--path=./"]).unwrap() + ); +} + #[test] #[should_panic = "\ Command clap: Argument group name must be unique @@ -120,3 +160,82 @@ fn helpful_panic_on_duplicate_groups() { use clap::CommandFactory; Opt::command().debug_assert(); } + +#[test] +fn custom_group_id() { + #[derive(Parser, Debug, PartialEq, Eq)] + struct Opt { + #[command(flatten)] + source: Option, + } + + #[derive(clap::Args, Debug, PartialEq, Eq)] + #[group(id = "source")] + struct Source { + crates: Vec, + #[arg(long)] + path: Option, + #[arg(long)] + git: Option, + } + + assert_eq!(Opt { source: None }, Opt::try_parse_from(["test"]).unwrap()); + assert_eq!( + Opt { + source: Some(Source { + crates: vec!["serde".to_owned()], + path: None, + git: None, + }), + }, + Opt::try_parse_from(["test", "serde"]).unwrap() + ); + assert_eq!( + Opt { + source: Some(Source { + crates: Vec::new(), + path: Some("./".into()), + git: None, + }), + }, + Opt::try_parse_from(["test", "--path=./"]).unwrap() + ); +} + +#[test] +fn required_group() { + #[derive(Parser, Debug, PartialEq, Eq)] + struct Opt { + #[command(flatten)] + source: Source, + } + + #[derive(clap::Args, Debug, PartialEq, Eq)] + #[group(required = true, multiple = false)] + struct Source { + #[arg(long)] + path: Option, + #[arg(long)] + git: Option, + } + + assert_eq!( + Opt { + source: Source { + path: Some("./".into()), + git: None, + }, + }, + Opt::try_parse_from(["test", "--path=./"]).unwrap() + ); + + const OUTPUT: &str = "\ +error: the following required arguments were not provided: + <--path |--git > + +Usage: test <--path |--git > + +For more information, try '--help'. +"; + assert_output::("test", OUTPUT, true); +} diff --git a/tests/derive_ui/group_name_attribute.rs b/tests/derive_ui/group_name_attribute.rs new file mode 100644 index 00000000000..059fc8921a5 --- /dev/null +++ b/tests/derive_ui/group_name_attribute.rs @@ -0,0 +1,23 @@ +use clap::Parser; + +#[derive(Parser, Debug)] +#[command(name = "basic")] +struct Opt { + #[command(flatten)] + source: Source, +} + +#[derive(clap::Args, Debug)] +#[group(required = true, name = "src")] +struct Source { + #[arg(short)] + git: String, + + #[arg(short)] + path: String, +} + +fn main() { + let opt = Opt::parse(); + println!("{:?}", opt); +} diff --git a/tests/derive_ui/group_name_attribute.stderr b/tests/derive_ui/group_name_attribute.stderr new file mode 100644 index 00000000000..67922f483b8 --- /dev/null +++ b/tests/derive_ui/group_name_attribute.stderr @@ -0,0 +1,5 @@ +error[E0599]: no method named `name` found for struct `ArgGroup` in the current scope + --> tests/derive_ui/group_name_attribute.rs:11:26 + | +11 | #[group(required = true, name = "src")] + | ^^^^ method not found in `ArgGroup`