diff --git a/crates/rattler_conda_types/src/match_spec/mod.rs b/crates/rattler_conda_types/src/match_spec/mod.rs index cdc8078e9..acdc3b84e 100644 --- a/crates/rattler_conda_types/src/match_spec/mod.rs +++ b/crates/rattler_conda_types/src/match_spec/mod.rs @@ -72,7 +72,7 @@ use matcher::StringMatcher; /// /// let spec = MatchSpec::from_str("foo=1.0=py27_0").unwrap(); /// assert_eq!(spec.name, Some("foo".to_string())); -/// assert_eq!(spec.version, Some(VersionSpec::from_str("1.0.*").unwrap())); +/// assert_eq!(spec.version, Some(VersionSpec::from_str("==1.0").unwrap())); /// assert_eq!(spec.build, Some(StringMatcher::from_str("py27_0").unwrap())); /// /// let spec = MatchSpec::from_str("conda-forge::foo[version=\"1.0.*\"]").unwrap(); diff --git a/crates/rattler_conda_types/src/match_spec/parse.rs b/crates/rattler_conda_types/src/match_spec/parse.rs index 64a8fcd8d..274cde275 100644 --- a/crates/rattler_conda_types/src/match_spec/parse.rs +++ b/crates/rattler_conda_types/src/match_spec/parse.rs @@ -373,6 +373,30 @@ fn parse(input: &str) -> Result { Cow::Borrowed(version_str) }; + // Special case handling for version strings that start with `=`. + let version_str = if let (Some(version_str), true) = + (version_str.strip_prefix("=="), build_str.is_none()) + { + // If the version starts with `==` and the build string is none we strip the `==` part. + Cow::Borrowed(version_str) + } else if let Some(version_str_part) = version_str.strip_prefix('=') { + let not_a_group = !version_str_part.contains(['=', ',', '|']); + if not_a_group { + // If the version starts with `=`, is not part of a group (e.g. 1|2) we append a * + // if it doesnt have one already. + if build_str.is_none() && !version_str_part.ends_with('*') { + Cow::Owned(format!("{version_str_part}*")) + } else { + Cow::Borrowed(version_str_part) + } + } else { + // Version string is part of a group, return the non-stripped version string + version_str + } + } else { + version_str + }; + // Parse the version spec match_spec.version = Some( VersionSpec::from_str(version_str.as_ref()) @@ -389,6 +413,8 @@ fn parse(input: &str) -> Result { #[cfg(test)] mod tests { + use serde::Serialize; + use std::collections::BTreeMap; use std::str::FromStr; use super::{ @@ -462,27 +488,6 @@ mod tests { assert_eq!(split_version_and_build("* *"), Ok(("*", Some("*")))); } - #[test] - fn test_match_spec() { - insta::assert_yaml_snapshot!([ - MatchSpec::from_str("python 3.8.* *_cpython").unwrap(), - MatchSpec::from_str("foo=1.0=py27_0").unwrap(), - MatchSpec::from_str("foo==1.0=py27_0").unwrap(), - ], - @r###" - --- - - name: python - version: 3.8.* - build: "*_cpython" - - name: foo - version: 1.0.* - build: py27_0 - - name: foo - version: "==1.0" - build: py27_0 - "###); - } - #[test] fn test_nameless_match_spec() { insta::assert_yaml_snapshot!([ @@ -566,4 +571,39 @@ mod tests { &[("version", "1.3,2.0")] ); } + + #[test] + fn test_from_str() { + // A list of matchspecs to parse. + // Please keep this list sorted. + let specs = [ + "blas *.* mkl", + "foo=1.0=py27_0", + "foo==1.0=py27_0", + "python 3.8.* *_cpython", + "pytorch=*=cuda*", + ]; + + #[derive(Serialize)] + #[serde(untagged)] + enum MatchSpecOrError { + Error { error: String }, + MatchSpec(MatchSpec), + } + + let evaluated: BTreeMap<_, _> = specs + .iter() + .map(|spec| { + ( + spec, + MatchSpec::from_str(spec) + .map(MatchSpecOrError::MatchSpec) + .unwrap_or_else(|err| MatchSpecOrError::Error { + error: err.to_string(), + }), + ) + }) + .collect(); + insta::assert_yaml_snapshot!("parsed matchspecs", evaluated); + } } diff --git a/crates/rattler_conda_types/src/match_spec/snapshots/rattler_conda_types__match_spec__parse__tests__parsed matchspecs.snap b/crates/rattler_conda_types/src/match_spec/snapshots/rattler_conda_types__match_spec__parse__tests__parsed matchspecs.snap new file mode 100644 index 000000000..e457cc479 --- /dev/null +++ b/crates/rattler_conda_types/src/match_spec/snapshots/rattler_conda_types__match_spec__parse__tests__parsed matchspecs.snap @@ -0,0 +1,25 @@ +--- +source: crates/rattler_conda_types/src/match_spec/parse.rs +expression: evaluated +--- +blas *.* mkl: + name: blas + version: "*" + build: mkl +foo=1.0=py27_0: + name: foo + version: "==1.0" + build: py27_0 +foo==1.0=py27_0: + name: foo + version: "==1.0" + build: py27_0 +python 3.8.* *_cpython: + name: python + version: 3.8.* + build: "*_cpython" +pytorch=*=cuda*: + name: pytorch + version: "*" + build: cuda* + diff --git a/crates/rattler_conda_types/src/version_spec/version_tree.rs b/crates/rattler_conda_types/src/version_spec/version_tree.rs index e8f370eda..502865c3c 100644 --- a/crates/rattler_conda_types/src/version_spec/version_tree.rs +++ b/crates/rattler_conda_types/src/version_spec/version_tree.rs @@ -88,8 +88,8 @@ fn recognize_constraint<'a, E: ParseError<&'a str> + ContextError<&'a str>>( input: &'a str, ) -> Result<(&'a str, &'a str), nom::Err> { alt(( - // Any - tag("*"), + // Any (* or *.*) + terminated(tag("*"), cut(opt(tag(".*")))), // Regex recognize(delimited(tag("^"), not(tag("$")), tag("$"))), // Version with optional operator followed by optional glob.