From ad91a293a2db0018381b37b62de8ef9798f68fb3 Mon Sep 17 00:00:00 2001 From: Ed Page Date: Mon, 28 Aug 2023 14:20:44 -0500 Subject: [PATCH] refactor: Use more serde_untagged I felt this does a good job of cleaning up the code and by using it more, new uses are more likely to use it. Due to an error reporting limitation in `serde_untagged`, I'm not using this for some `MaybeWorkspace` types because it makes the errors worse. I also held off on some config visitors because they seemed more complicated and I didn't want to risk that code. --- src/cargo/util/interning.rs | 23 +--- src/cargo/util/semver_ext.rs | 24 +--- src/cargo/util/toml/mod.rs | 233 ++++++++++------------------------- 3 files changed, 72 insertions(+), 208 deletions(-) diff --git a/src/cargo/util/interning.rs b/src/cargo/util/interning.rs index 6d62b167f0e0..2e3848eaaf1f 100644 --- a/src/cargo/util/interning.rs +++ b/src/cargo/util/interning.rs @@ -1,4 +1,5 @@ use serde::{Serialize, Serializer}; +use serde_untagged::UntaggedEnumVisitor; use std::borrow::Borrow; use std::cmp::Ordering; use std::collections::HashSet; @@ -150,28 +151,14 @@ impl Serialize for InternedString { } } -struct InternedStringVisitor; - impl<'de> serde::Deserialize<'de> for InternedString { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { - deserializer.deserialize_str(InternedStringVisitor) - } -} - -impl<'de> serde::de::Visitor<'de> for InternedStringVisitor { - type Value = InternedString; - - fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { - formatter.write_str("an String like thing") - } - - fn visit_str(self, v: &str) -> Result - where - E: serde::de::Error, - { - Ok(InternedString::new(v)) + UntaggedEnumVisitor::new() + .expecting("an String like thing") + .string(|value| Ok(InternedString::new(value))) + .deserialize(deserializer) } } diff --git a/src/cargo/util/semver_ext.rs b/src/cargo/util/semver_ext.rs index 9e0d3e50efe0..5efa1d2f173e 100644 --- a/src/cargo/util/semver_ext.rs +++ b/src/cargo/util/semver_ext.rs @@ -1,4 +1,5 @@ use semver::{Comparator, Op, Version, VersionReq}; +use serde_untagged::UntaggedEnumVisitor; use std::fmt::{self, Display}; #[derive(PartialEq, Eq, Hash, Clone, Debug)] @@ -198,25 +199,10 @@ impl<'de> serde::Deserialize<'de> for PartialVersion { where D: serde::Deserializer<'de>, { - struct VersionVisitor; - - impl<'de> serde::de::Visitor<'de> for VersionVisitor { - type Value = PartialVersion; - - fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { - formatter.write_str("SemVer version") - } - - fn visit_str(self, string: &str) -> Result - where - E: serde::de::Error, - { - string.parse().map_err(serde::de::Error::custom) - } - } - - let s = String::deserialize(deserializer)?; - s.parse().map_err(serde::de::Error::custom) + UntaggedEnumVisitor::new() + .expecting("SemVer version") + .string(|value| value.parse().map_err(serde::de::Error::custom)) + .deserialize(deserializer) } } diff --git a/src/cargo/util/toml/mod.rs b/src/cargo/util/toml/mod.rs index 1c73341f9873..af58206cb3ac 100644 --- a/src/cargo/util/toml/mod.rs +++ b/src/cargo/util/toml/mod.rs @@ -1,7 +1,6 @@ use std::collections::{BTreeMap, BTreeSet, HashMap}; use std::ffi::OsStr; use std::fmt::{self, Display, Write}; -use std::marker::PhantomData; use std::path::{Path, PathBuf}; use std::rc::Rc; use std::str::{self, FromStr}; @@ -213,34 +212,14 @@ impl<'de, P: Deserialize<'de> + Clone> de::Deserialize<'de> for TomlDependency

, { - struct TomlDependencyVisitor

(PhantomData

); - - impl<'de, P: Deserialize<'de> + Clone> de::Visitor<'de> for TomlDependencyVisitor

{ - type Value = TomlDependency

; - - fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { - formatter.write_str( - "a version string like \"0.9.8\" or a \ + UntaggedEnumVisitor::new() + .expecting( + "a version string like \"0.9.8\" or a \ detailed dependency like { version = \"0.9.8\" }", - ) - } - - fn visit_str(self, s: &str) -> Result - where - E: de::Error, - { - Ok(TomlDependency::Simple(s.to_owned())) - } - - fn visit_map(self, map: V) -> Result - where - V: de::MapAccess<'de>, - { - let mvd = de::value::MapAccessDeserializer::new(map); - DetailedTomlDependency::deserialize(mvd).map(TomlDependency::Detailed) - } - } - deserializer.deserialize_any(TomlDependencyVisitor(PhantomData)) + ) + .string(|value| Ok(TomlDependency::Simple(value.to_owned()))) + .map(|value| value.deserialize().map(TomlDependency::Detailed)) + .deserialize(deserializer) } } @@ -400,39 +379,22 @@ impl<'de> de::Deserialize<'de> for TomlOptLevel { where D: de::Deserializer<'de>, { - struct Visitor; - - impl<'de> de::Visitor<'de> for Visitor { - type Value = TomlOptLevel; - - fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { - formatter.write_str("an optimization level") - } - - fn visit_i64(self, value: i64) -> Result - where - E: de::Error, - { - Ok(TomlOptLevel(value.to_string())) - } - - fn visit_str(self, value: &str) -> Result - where - E: de::Error, - { + use serde::de::Error as _; + UntaggedEnumVisitor::new() + .expecting("an optimization level") + .i64(|value| Ok(TomlOptLevel(value.to_string()))) + .string(|value| { if value == "s" || value == "z" { Ok(TomlOptLevel(value.to_string())) } else { - Err(E::custom(format!( + Err(serde_untagged::de::Error::custom(format!( "must be `0`, `1`, `2`, `3`, `s` or `z`, \ but found the string: \"{}\"", value ))) } - } - } - - d.deserialize_any(Visitor) + }) + .deserialize(d) } } @@ -477,58 +439,48 @@ impl<'de> de::Deserialize<'de> for TomlDebugInfo { where D: de::Deserializer<'de>, { - struct Visitor; - - impl<'de> de::Visitor<'de> for Visitor { - type Value = TomlDebugInfo; - - fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { - formatter.write_str( - "a boolean, 0, 1, 2, \"line-tables-only\", or \"line-directives-only\"", - ) - } - - fn visit_i64(self, value: i64) -> Result - where - E: de::Error, - { + use serde::de::Error as _; + let expecting = "a boolean, 0, 1, 2, \"line-tables-only\", or \"line-directives-only\""; + UntaggedEnumVisitor::new() + .expecting(expecting) + .bool(|value| { + Ok(if value { + TomlDebugInfo::Full + } else { + TomlDebugInfo::None + }) + }) + .i64(|value| { let debuginfo = match value { 0 => TomlDebugInfo::None, 1 => TomlDebugInfo::Limited, 2 => TomlDebugInfo::Full, - _ => return Err(de::Error::invalid_value(Unexpected::Signed(value), &self)), + _ => { + return Err(serde_untagged::de::Error::invalid_value( + Unexpected::Signed(value), + &expecting, + )) + } }; Ok(debuginfo) - } - - fn visit_bool(self, v: bool) -> Result - where - E: de::Error, - { - Ok(if v { - TomlDebugInfo::Full - } else { - TomlDebugInfo::None - }) - } - - fn visit_str(self, value: &str) -> Result - where - E: de::Error, - { + }) + .string(|value| { let debuginfo = match value { "none" => TomlDebugInfo::None, "limited" => TomlDebugInfo::Limited, "full" => TomlDebugInfo::Full, "line-directives-only" => TomlDebugInfo::LineDirectivesOnly, "line-tables-only" => TomlDebugInfo::LineTablesOnly, - _ => return Err(de::Error::invalid_value(Unexpected::Str(value), &self)), + _ => { + return Err(serde_untagged::de::Error::invalid_value( + Unexpected::Str(value), + &expecting, + )) + } }; Ok(debuginfo) - } - } - - d.deserialize_any(Visitor) + }) + .deserialize(d) } } @@ -927,32 +879,11 @@ impl<'de> de::Deserialize<'de> for StringOrVec { where D: de::Deserializer<'de>, { - struct Visitor; - - impl<'de> de::Visitor<'de> for Visitor { - type Value = StringOrVec; - - fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { - formatter.write_str("string or list of strings") - } - - fn visit_str(self, s: &str) -> Result - where - E: de::Error, - { - Ok(StringOrVec(vec![s.to_string()])) - } - - fn visit_seq(self, v: V) -> Result - where - V: de::SeqAccess<'de>, - { - let seq = de::value::SeqAccessDeserializer::new(v); - Vec::deserialize(seq).map(StringOrVec) - } - } - - deserializer.deserialize_any(Visitor) + UntaggedEnumVisitor::new() + .expecting("string or list of strings") + .string(|value| Ok(StringOrVec(vec![value.to_owned()]))) + .seq(|value| value.deserialize().map(StringOrVec)) + .deserialize(deserializer) } } @@ -975,8 +906,8 @@ impl<'de> Deserialize<'de> for StringOrBool { D: de::Deserializer<'de>, { UntaggedEnumVisitor::new() - .string(|s| Ok(StringOrBool::String(s.to_owned()))) .bool(|b| Ok(StringOrBool::Bool(b))) + .string(|s| Ok(StringOrBool::String(s.to_owned()))) .deserialize(deserializer) } } @@ -993,32 +924,11 @@ impl<'de> de::Deserialize<'de> for VecStringOrBool { where D: de::Deserializer<'de>, { - struct Visitor; - - impl<'de> de::Visitor<'de> for Visitor { - type Value = VecStringOrBool; - - fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { - formatter.write_str("a boolean or vector of strings") - } - - fn visit_seq(self, v: V) -> Result - where - V: de::SeqAccess<'de>, - { - let seq = de::value::SeqAccessDeserializer::new(v); - Vec::deserialize(seq).map(VecStringOrBool::VecString) - } - - fn visit_bool(self, b: bool) -> Result - where - E: de::Error, - { - Ok(VecStringOrBool::Bool(b)) - } - } - - deserializer.deserialize_any(Visitor) + UntaggedEnumVisitor::new() + .expecting("a boolean or vector of strings") + .bool(|value| Ok(VecStringOrBool::Bool(value))) + .seq(|value| value.deserialize().map(VecStringOrBool::VecString)) + .deserialize(deserializer) } } @@ -1026,35 +936,16 @@ fn version_trim_whitespace<'de, D>(deserializer: D) -> Result, { - struct Visitor; - - impl<'de> de::Visitor<'de> for Visitor { - type Value = MaybeWorkspaceSemverVersion; - - fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { - formatter.write_str("SemVer version") - } - - fn visit_str(self, string: &str) -> Result - where - E: de::Error, - { - match string.trim().parse().map_err(de::Error::custom) { + UntaggedEnumVisitor::new() + .expecting("SemVer version") + .string( + |value| match value.trim().parse().map_err(de::Error::custom) { Ok(parsed) => Ok(MaybeWorkspace::Defined(parsed)), Err(e) => Err(e), - } - } - - fn visit_map(self, map: V) -> Result - where - V: de::MapAccess<'de>, - { - let mvd = de::value::MapAccessDeserializer::new(map); - TomlWorkspaceField::deserialize(mvd).map(MaybeWorkspace::Workspace) - } - } - - deserializer.deserialize_any(Visitor) + }, + ) + .map(|value| value.deserialize().map(MaybeWorkspace::Workspace)) + .deserialize(deserializer) } /// This Trait exists to make [`MaybeWorkspace::Workspace`] generic. It makes deserialization of