diff --git a/graphql_client/Cargo.toml b/graphql_client/Cargo.toml index 680d7068..6a8a74b1 100644 --- a/graphql_client/Cargo.toml +++ b/graphql_client/Cargo.toml @@ -24,7 +24,10 @@ graphql_query_derive = { path = "../graphql_query_derive", version = "0.14.0", o reqwest-crate = { package = "reqwest", version = ">=0.11, <=0.12", features = ["json"], default-features = false, optional = true } [features] -default = ["graphql_query_derive"] +default = ["derive-integer-id"] reqwest = ["reqwest-crate", "reqwest-crate/default-tls"] reqwest-rustls = ["reqwest-crate", "reqwest-crate/rustls-tls"] reqwest-blocking = ["reqwest-crate/blocking"] + +derive = ["graphql_query_derive"] +derive-integer-id = ["derive","graphql_query_derive/integer-id"] \ No newline at end of file diff --git a/graphql_client/src/serde_with.rs b/graphql_client/src/serde_with.rs index dcd5cd98..e1a58008 100644 --- a/graphql_client/src/serde_with.rs +++ b/graphql_client/src/serde_with.rs @@ -1,41 +1,440 @@ //! Helpers for overriding default serde implementations. -use serde::{Deserialize, Deserializer}; +use serde::de::{self, Deserializer, SeqAccess, Visitor}; +use serde::Deserialize; +use std::fmt; +use std::marker::PhantomData; -#[derive(Deserialize)] -#[serde(untagged)] -enum IntOrString { - Int(i64), - Str(String), +/// Our own visitor trait that allows us to deserialize GraphQL IDs. +/// +/// This is used by the codegen to enable String IDs to be deserialized from +/// either Strings or Integers even if they are nested in a list or optional. +/// +/// We can't use the Visitor since we want to override the default deserialization +/// behavior for base types and automatic nesting support. +pub trait GraphQLVisitor<'de>: Sized { + /// The name of the type that we are deserializing. + fn type_name() -> &'static str; + + /// Visit an integer + fn visit_i64(v: i64) -> Result + where + E: de::Error, + { + Err(de::Error::invalid_type( + de::Unexpected::Signed(v), + &Self::type_name(), + )) + } + + /// Visit an integer + fn visit_u64(v: u64) -> Result + where + E: de::Error, + { + Err(de::Error::invalid_type( + de::Unexpected::Unsigned(v), + &Self::type_name(), + )) + } + + /// Visit a borrowed string + fn visit_str(v: &str) -> Result + where + E: de::Error, + { + Err(de::Error::invalid_type( + de::Unexpected::Str(v), + &Self::type_name(), + )) + } + + /// Visit a string + fn visit_string(v: String) -> Result + where + E: de::Error, + { + Err(de::Error::invalid_type( + de::Unexpected::Str(&v), + &Self::type_name(), + )) + } + + /// Visit a missing optional value + fn visit_none() -> Result + where + E: de::Error, + { + Err(de::Error::invalid_type( + de::Unexpected::Option, + &Self::type_name(), + )) + } + + /// Visit a null value + fn visit_unit() -> Result + where + E: de::Error, + { + Err(de::Error::invalid_type( + de::Unexpected::Unit, + &Self::type_name(), + )) + } + + /// Visit a sequence + fn visit_seq(seq: A) -> Result + where + A: SeqAccess<'de>, + { + let _ = seq; + Err(de::Error::invalid_type( + de::Unexpected::Seq, + &Self::type_name(), + )) + } +} + +impl GraphQLVisitor<'_> for String { + fn type_name() -> &'static str { + "an ID" + } + + fn visit_i64(v: i64) -> Result + where + E: de::Error, + { + Ok(v.to_string()) + } + + fn visit_u64(v: u64) -> Result + where + E: de::Error, + { + Ok(v.to_string()) + } + + fn visit_str(v: &str) -> Result + where + E: de::Error, + { + Ok(v.to_string()) + } + + fn visit_string(v: String) -> Result + where + E: de::Error, + { + Ok(v) + } +} + +impl<'de, T: GraphQLVisitor<'de>> GraphQLVisitor<'de> for Option { + fn type_name() -> &'static str { + "an optional ID or sequence of IDs" + } + + fn visit_i64(v: i64) -> Result + where + E: de::Error, + { + T::visit_i64(v).map(Some) + } + + fn visit_u64(v: u64) -> Result + where + E: de::Error, + { + T::visit_u64(v).map(Some) + } + + fn visit_str(v: &str) -> Result + where + E: de::Error, + { + T::visit_str(v).map(Some) + } + + fn visit_string(v: String) -> Result + where + E: de::Error, + { + T::visit_string(v).map(Some) + } + + fn visit_none() -> Result + where + E: de::Error, + { + Ok(None) + } + + fn visit_unit() -> Result + where + E: de::Error, + { + Ok(None) + } + + fn visit_seq(seq: A) -> Result + where + A: SeqAccess<'de>, + { + T::visit_seq(seq).map(Some) + } } -impl From for String { - fn from(value: IntOrString) -> Self { - match value { - IntOrString::Int(n) => n.to_string(), - IntOrString::Str(s) => s, +impl<'de, T: GraphQLVisitor<'de>> GraphQLVisitor<'de> for Vec { + fn type_name() -> &'static str { + "a sequence of IDs" + } + + fn visit_seq(mut seq: A) -> Result + where + A: SeqAccess<'de>, + { + struct Id(T); + + impl<'de, T> Deserialize<'de> for Id + where + T: GraphQLVisitor<'de>, + { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserialize_id(deserializer).map(Id) + } } + + let mut vec = Vec::with_capacity(seq.size_hint().unwrap_or(0)); + while let Some(Id(elem)) = seq.next_element()? { + vec.push(elem); + } + Ok(vec) } } -/// Deserialize an optional ID type from either a String or an Integer representation. -/// -/// This is used by the codegen to enable String IDs to be deserialized from -/// either Strings or Integers. -pub fn deserialize_option_id<'de, D>(deserializer: D) -> Result, D::Error> +struct IdVisitor { + phantom: PhantomData, +} + +impl<'de, T> Visitor<'de> for IdVisitor where - D: Deserializer<'de>, + T: GraphQLVisitor<'de>, { - Option::::deserialize(deserializer).map(|opt| opt.map(String::from)) + type Value = T; + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + formatter.write_str("a string, integer, null, or a sequence of IDs") + } + + fn visit_i64(self, value: i64) -> Result + where + E: de::Error, + { + T::visit_i64(value) + } + + fn visit_u64(self, value: u64) -> Result + where + E: de::Error, + { + T::visit_u64(value) + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + T::visit_str(value) + } + + fn visit_string(self, value: String) -> Result + where + E: de::Error, + { + T::visit_str(&value) + } + + fn visit_none(self) -> Result + where + E: de::Error, + { + T::visit_none() + } + + fn visit_unit(self) -> Result + where + E: de::Error, + { + T::visit_unit() + } + + fn visit_some(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_any(self) + } + + fn visit_seq(self, seq: A) -> Result + where + A: SeqAccess<'de>, + { + T::visit_seq(seq) + } } -/// Deserialize an ID type from either a String or an Integer representation. +/// Generic deserializer for GraphQL ID types. /// -/// This is used by the codegen to enable String IDs to be deserialized from -/// either Strings or Integers. -pub fn deserialize_id<'de, D>(deserializer: D) -> Result +/// It can deserialize IDs from a string or an integer. +/// It supports optional values and lists of IDs. +pub fn deserialize_id<'de, D, T>(deserializer: D) -> Result where D: Deserializer<'de>, + T: GraphQLVisitor<'de>, { - IntOrString::deserialize(deserializer).map(String::from) + deserializer.deserialize_any(IdVisitor { + phantom: PhantomData, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(Debug, PartialEq, Deserialize)] + struct Test { + #[serde(deserialize_with = "deserialize_id")] + pub id: String, + #[serde(deserialize_with = "deserialize_id")] + pub id_opt: Option, + #[serde(deserialize_with = "deserialize_id")] + pub id_seq: Vec, + #[serde(deserialize_with = "deserialize_id")] + pub id_opt_seq: Option>, + #[serde(deserialize_with = "deserialize_id")] + pub id_opt_seq_opt: Option>>, + } + + #[derive(Debug, PartialEq, Deserialize)] + struct NestedTest { + #[serde(deserialize_with = "deserialize_id")] + pub nested: Vec>, + #[serde(deserialize_with = "deserialize_id")] + pub opt_nested: Option>>>>, + } + + #[derive(Debug, PartialEq, Deserialize)] + struct Test2 { + #[serde(deserialize_with = "deserialize_id")] + pub id_opt_seq_opt: Option>>, + } + + #[test] + fn test_deserialize_string() { + let test = serde_json::from_str::( + r#"{"id": "123", "id_opt": "123", "id_seq": ["123", "456"], "id_opt_seq": ["123"], "id_opt_seq_opt": ["123", "456"]}"#, + ).unwrap(); + assert_eq!(test.id, "123".to_string()); + assert_eq!(test.id_opt, Some("123".to_string())); + assert_eq!(test.id_seq, vec!["123".to_string(), "456".to_string()]); + assert_eq!(test.id_opt_seq, Some(vec!["123".to_string()])); + assert_eq!( + test.id_opt_seq_opt, + Some(vec![Some("123".to_string()), Some("456".to_string())]) + ); + } + + #[test] + fn test_deserialize_integer() { + let test = serde_json::from_str::( + r#"{"id": 123, "id_opt": 123, "id_seq": [123, 456], "id_opt_seq": [123], "id_opt_seq_opt": [123, 456]}"#, + ).unwrap(); + assert_eq!(test.id, "123".to_string()); + assert_eq!(test.id_opt, Some("123".to_string())); + assert_eq!(test.id_seq, vec!["123".to_string(), "456".to_string()]); + assert_eq!(test.id_opt_seq, Some(vec!["123".to_string()])); + assert_eq!( + test.id_opt_seq_opt, + Some(vec![Some("123".to_string()), Some("456".to_string())]) + ); + } + + #[test] + fn test_deserialize_mixed() { + let test = serde_json::from_str::( + r#"{"id": 123, "id_opt": null, "id_seq": [123, "456"], "id_opt_seq": null, "id_opt_seq_opt": [123, null, "456"]}"#, + ) + .unwrap(); + assert_eq!(test.id, "123".to_string()); + assert_eq!(test.id_opt, None); + assert_eq!(test.id_seq, vec!["123".to_string(), "456".to_string()]); + assert_eq!(test.id_opt_seq, None); + assert_eq!( + test.id_opt_seq_opt, + Some(vec![Some("123".to_string()), None, Some("456".to_string())]) + ); + } + + #[test] + fn test_deserialize_unexpected_list_id_null() { + let test = serde_json::from_str::( + r#"{"id": "123", "id_opt": "123", "id_seq": ["123", null, "456"], "id_opt_seq": null, "id_opt_seq_opt": ["123", null, "456"]}"#, + ) + .unwrap_err(); + assert_eq!( + test.to_string(), + "invalid type: null, expected an ID at line 1 column 53" + ); + } + + #[test] + fn test_deserialize_unexpected_list_null() { + let test = serde_json::from_str::( + r#"{"id": "123", "id_opt": "123", "id_seq": null, "id_opt_seq": null, "id_opt_seq_opt": ["123", null, "456"]}"#, + ) + .unwrap_err(); + assert_eq!( + test.to_string(), + "invalid type: null, expected a sequence of IDs at line 1 column 45" + ); + } + + #[test] + fn test_deserialize_unexpected_id_null() { + let test = serde_json::from_str::( + r#"{"id": null, "id_opt": "123", "id_seq": null, "id_opt_seq": null, "id_opt_seq_opt": ["123", null, "456"]}"#, + ) + .unwrap_err(); + assert_eq!( + test.to_string(), + "invalid type: null, expected an ID at line 1 column 11" + ); + } + + #[test] + fn test_deserialize_nested() { + let test = serde_json::from_str::( + r#"{"nested": [["123", 789, "456"]], "opt_nested": [["123", null, "456"]]}"#, + ) + .unwrap(); + assert_eq!( + test.nested, + vec![vec![ + "123".to_string(), + "789".to_string(), + "456".to_string() + ]] + ); + assert_eq!( + test.opt_nested, + Some(vec![Some(vec![ + Some("123".to_string()), + None, + Some("456".to_string()) + ])]) + ); + } } diff --git a/graphql_client_codegen/Cargo.toml b/graphql_client_codegen/Cargo.toml index 8f30cbb4..1dd76f40 100644 --- a/graphql_client_codegen/Cargo.toml +++ b/graphql_client_codegen/Cargo.toml @@ -7,6 +7,9 @@ license = "Apache-2.0 OR MIT" repository = "https://github.com/graphql-rust/graphql-client" edition = "2018" +[features] +integer-id = [] + [dependencies] graphql-introspection-query = { version = "0.2.0", path = "../graphql-introspection-query" } graphql-parser = "0.4" @@ -16,4 +19,4 @@ proc-macro2 = { version = "^1.0", features = [] } quote = "^1.0" serde_json = "1.0" serde = { version = "^1.0", features = ["derive"] } -syn = { version = "^2.0", features = [ "full" ] } +syn = { version = "^2.0", features = ["full"] } diff --git a/graphql_client_codegen/src/codegen/selection.rs b/graphql_client_codegen/src/codegen/selection.rs index ec1703b8..760e820b 100644 --- a/graphql_client_codegen/src/codegen/selection.rs +++ b/graphql_client_codegen/src/codegen/selection.rs @@ -12,8 +12,7 @@ use crate::{ }, schema::{Schema, TypeId}, type_qualifiers::GraphqlTypeQualifier, - GraphQLClientCodegenOptions, - GeneralError, + GeneralError, GraphQLClientCodegenOptions, }; use heck::*; use proc_macro2::{Ident, Span, TokenStream}; @@ -43,12 +42,27 @@ pub(crate) fn render_response_data_fields<'a>( if let Some(custom_response_type) = options.custom_response_type() { if operation.selection_set.len() == 1 { let selection_id = operation.selection_set[0]; - let selection_field = query.query.get_selection(selection_id).as_selected_field() - .ok_or_else(|| GeneralError(format!("Custom response type {custom_response_type} will only work on fields")))?; - calculate_custom_response_type_selection(&mut expanded_selection, response_data_type_id, custom_response_type, selection_id, selection_field); + let selection_field = query + .query + .get_selection(selection_id) + .as_selected_field() + .ok_or_else(|| { + GeneralError(format!( + "Custom response type {custom_response_type} will only work on fields" + )) + })?; + calculate_custom_response_type_selection( + &mut expanded_selection, + response_data_type_id, + custom_response_type, + selection_id, + selection_field, + ); return Ok(expanded_selection); } else { - return Err(GeneralError(format!("Custom response type {custom_response_type} requires single selection field"))); + return Err(GeneralError(format!( + "Custom response type {custom_response_type} requires single selection field" + ))); } } @@ -68,8 +82,8 @@ fn calculate_custom_response_type_selection<'a>( struct_id: ResponseTypeId, custom_response_type: &'a String, selection_id: SelectionId, - field: &'a SelectedField) -{ + field: &'a SelectedField, +) { let (graphql_name, rust_name) = context.field_name(field); let struct_name_string = full_path_prefix(selection_id, context.query); let field = context.query.schema.get_field(field.field_id); @@ -451,15 +465,8 @@ impl ExpandedField<'_> { }; let is_id = self.field_type == "ID"; - let is_required = self - .field_type_qualifiers - .contains(&GraphqlTypeQualifier::Required); - let id_deserialize_with = if is_id && is_required { + let id_deserialize_with = if is_id && cfg!(feature = "integer-id") { Some(quote!(#[serde(deserialize_with = "graphql_client::serde_with::deserialize_id")])) - } else if is_id { - Some( - quote!(#[serde(deserialize_with = "graphql_client::serde_with::deserialize_option_id")]), - ) } else { None }; diff --git a/graphql_query_derive/Cargo.toml b/graphql_query_derive/Cargo.toml index b48f5845..beec1c87 100644 --- a/graphql_query_derive/Cargo.toml +++ b/graphql_query_derive/Cargo.toml @@ -10,6 +10,9 @@ edition = "2018" [lib] proc-macro = true +[features] +integer-id = ["graphql_client_codegen/integer-id"] + [dependencies] syn = { version = "^2.0", features = ["extra-traits"] } proc-macro2 = { version = "^1.0", features = [] }