From cd3f1dda516cd80aa6ba9285e797b9218f8d2f17 Mon Sep 17 00:00:00 2001 From: Husqvik Date: Tue, 21 Nov 2023 06:08:59 +0100 Subject: [PATCH] #151 proper deserialization of union sub types --- .../GenerationContext.cs | 33 ++----- .../GraphQlGenerator.cs | 91 +++++++++++++------ src/GraphQlClientGenerator/NamingHelper.cs | 10 +- .../Unions | 21 ++--- 4 files changed, 88 insertions(+), 67 deletions(-) diff --git a/src/GraphQlClientGenerator/GenerationContext.cs b/src/GraphQlClientGenerator/GenerationContext.cs index a6f97aa..709d7bf 100644 --- a/src/GraphQlClientGenerator/GenerationContext.cs +++ b/src/GraphQlClientGenerator/GenerationContext.cs @@ -48,8 +48,12 @@ private bool FilterIfAllFieldsDeprecated(GraphQlField field) return true; var graphQlType = _complexTypes[fieldType.Name]; + var nestedTypeFields = + graphQlType.Kind is GraphQlTypeKind.Union + ? graphQlType.PossibleTypes.Select(t => _complexTypes[t.Name]).SelectMany(t => t.Fields) + : graphQlType.Fields; - return GetFields(graphQlType).Any(FilterIfDeprecated); + return nestedTypeFields.Any(FilterIfDeprecated); } public void Initialize(GraphQlGeneratorConfiguration configuration) @@ -111,21 +115,7 @@ public void Initialize(GraphQlGeneratorConfiguration configuration) public abstract void AfterGeneration(); protected internal List GetFieldsToGenerate(GraphQlType type) => - GetFields(type)?.Where(FilterIfDeprecated).Where(FilterIfAllFieldsDeprecated).ToList(); - - private IEnumerable GetFields(GraphQlType type) - { - if (type.Kind != GraphQlTypeKind.Union) - return type.Fields; - - var unionFields = new List(); - var unionFieldNames = new HashSet(); - foreach (var possibleType in type.PossibleTypes) - if (_complexTypes.TryGetValue(possibleType.Name, out var consistOfType) && consistOfType.Fields is not null) - unionFields.AddRange(consistOfType.Fields.Where(f => unionFieldNames.Add(f.Name))); - - return unionFields; - } + type.Fields?.Where(FilterIfDeprecated).Where(FilterIfAllFieldsDeprecated).ToList(); protected internal IEnumerable GetFragments(GraphQlType type) { @@ -162,10 +152,7 @@ protected internal ScalarFieldTypeDescription GetDataPropertyType(GraphQlType ba case GraphQlTypeKind.Union: case GraphQlTypeKind.InputObject: var fieldTypeName = GetCSharpClassName(fieldType.Name); - var propertyType = $"{Configuration.ClassPrefix}{fieldTypeName}{Configuration.ClassSuffix}"; - if (fieldType.Kind == GraphQlTypeKind.Interface) - propertyType = $"I{propertyType}"; - + var propertyType = GetFullyQualifiedNetTypeName(fieldTypeName, fieldType.Kind); return ScalarFieldTypeDescription.FromNetTypeName(AddQuestionMarkIfNullableReferencesEnabled(propertyType)); case GraphQlTypeKind.Enum: @@ -178,7 +165,7 @@ protected internal ScalarFieldTypeDescription GetDataPropertyType(GraphQlType ba var netItemType = IsUnknownObjectScalar(baseType, member.Name, itemType) ? "object" - : $"{(unwrappedItemType.Kind == GraphQlTypeKind.Interface ? "I" : null)}{Configuration.ClassPrefix}{itemTypeName}{Configuration.ClassSuffix}"; + : GetFullyQualifiedNetTypeName(itemTypeName, unwrappedItemType.Kind); var suggestedScalarNetType = ResolveScalarNetType(baseType, member.Name, itemType, true).NetTypeName.TrimEnd('?'); if (!String.Equals(suggestedScalarNetType, "object") && !suggestedScalarNetType.TrimEnd().EndsWith("System.Object")) @@ -246,7 +233,7 @@ internal ScalarFieldTypeDescription ResolveScalarNetType(GraphQlType baseType, s }; internal string GetFullyQualifiedNetTypeName(string baseTypeName, GraphQlTypeKind kind) => - $"{(kind is GraphQlTypeKind.Interface ? "I" : null)}{Configuration.ClassPrefix}{baseTypeName}{Configuration.ClassSuffix}"; + $"{(kind is GraphQlTypeKind.Interface or GraphQlTypeKind.Union ? "I" : null)}{Configuration.ClassPrefix}{baseTypeName}{Configuration.ClassSuffix}"; private ScalarFieldTypeDescription GetBooleanNetType(GraphQlType baseType, GraphQlTypeBase valueType, string valueName, bool alwaysNullable) => Configuration.BooleanTypeMapping switch @@ -327,7 +314,7 @@ private void ResolveNameCollisions() var propertyNamesToGenerate = new List(); if (isInputObject) propertyNamesToGenerate.AddRange(graphQlType.InputFields.Select(f => NamingHelper.ToPascalCase(f.Name))); - else if (graphQlType.Kind.IsComplex()) + else if (graphQlType.Kind is GraphQlTypeKind.Object or GraphQlTypeKind.Interface) propertyNamesToGenerate.AddRange(GetFieldsToGenerate(graphQlType).Select(f => NamingHelper.ToPascalCase(f.Name))); else continue; diff --git a/src/GraphQlClientGenerator/GraphQlGenerator.cs b/src/GraphQlClientGenerator/GraphQlGenerator.cs index 22ed3c4..bebcc14 100644 --- a/src/GraphQlClientGenerator/GraphQlGenerator.cs +++ b/src/GraphQlClientGenerator/GraphQlGenerator.cs @@ -50,6 +50,17 @@ private static HttpClient CreateHttpClient(HttpMessageHandler messageHandler = n Converters = { new StringEnumConverter() } }; + private static readonly GraphQlField TypeNameField = + new() + { + Name = NamingHelper.MetadataFieldTypeName, + Type = new GraphQlFieldType + { + Kind = GraphQlTypeKind.Scalar, + Name = "String" + } + }; + private readonly GraphQlGeneratorConfiguration _configuration; public GraphQlGenerator(GraphQlGeneratorConfiguration configuration = null) => @@ -356,38 +367,62 @@ private void GenerateDataClasses(GenerationContext context) context.BeforeDataClassesGeneration(); + var unionLookup = + complexTypes.Values + .Where(t => t.Kind is GraphQlTypeKind.Union) + .SelectMany(u => u.PossibleTypes.Select(t => (UnionName: u.Name, PossibleTypeName: t.Name))) + .ToLookup(x => x.PossibleTypeName, x => x.UnionName); + foreach (var complexType in complexTypes.Values) { + var csharpUnionTypeName = context.GetCSharpClassName(complexType.Name); + if (complexType.Kind is GraphQlTypeKind.Union) + { + GenerateFileMember( + context, + csharpUnionTypeName, + complexType, + null, + () => GenerateDataClassBody(complexType, Array.Empty(), context, true)); + + continue; + } + var hasInputReference = complexType.Kind is GraphQlTypeKind.InputObject && context.ReferencedObjectTypes.Contains(complexType.Name); var fieldsToGenerate = context.GetFieldsToGenerate(complexType); - var isInterface = complexType.Kind is GraphQlTypeKind.Interface; var csharpTypeName = context.GetCSharpClassName(complexType.Name); - var interfacesToImplement = new List(); - if (isInterface) - { - interfacesToImplement.Add(GenerateFileMember(context, csharpTypeName, complexType, null, () => GenerateBody(true))); - } - else if (complexType.Interfaces?.Count > 0) + var interfacesToImplement = new HashSet(unionLookup[complexType.Name].Select(n => context.GetFullyQualifiedNetTypeName(n, GraphQlTypeKind.Interface))); + //var implementsUnion = interfacesToImplement.Any(); + if (complexType.Interfaces?.Count > 0) { var fieldNames = new HashSet(fieldsToGenerate.Select(f => f.Name)); foreach (var @interface in complexType.Interfaces) { var csharpInterfaceName = context.GetCSharpClassName(@interface.Name, false); - var interfaceName = $"I{_configuration.ClassPrefix}{csharpInterfaceName}{_configuration.ClassSuffix}"; + var interfaceName = context.GetFullyQualifiedNetTypeName(csharpInterfaceName, GraphQlTypeKind.Interface); interfacesToImplement.Add(interfaceName); foreach (var interfaceField in complexTypes[@interface.Name].Fields.Where(context.FilterIfDeprecated)) if (fieldNames.Add(interfaceField.Name)) fieldsToGenerate.Add(interfaceField); } + + //fieldsToGenerate.Insert(0, TypeNameField); } + /*else if (implementsUnion) + fieldsToGenerate.Insert(0, TypeNameField);*/ if (hasInputReference) interfacesToImplement.Add("IGraphQlInputObject"); - if (!isInterface && fieldsToGenerate.Any()) - GenerateFileMember(context, csharpTypeName, complexType, String.Join(", ", interfacesToImplement), () => GenerateBody(false)); + if (fieldsToGenerate.Any()) + GenerateFileMember( + context, + csharpTypeName, + complexType, + String.Join(", ", interfacesToImplement), + () => GenerateBody(complexType.Kind is GraphQlTypeKind.Interface)); continue; @@ -395,7 +430,7 @@ void GenerateBody(bool isInterfaceMember) { if (hasInputReference) GenerateInputDataClassBody(complexType, fieldsToGenerate, context); - else if (fieldsToGenerate is not null) + else GenerateDataClassBody(complexType, fieldsToGenerate, context, isInterfaceMember); } } @@ -567,7 +602,7 @@ private void GenerateInputDataClassBody(GraphQlType type, IEnumerable @@ -637,7 +670,7 @@ internal static string AddQuestionMarkIfNullableReferencesEnabled(GraphQlGenerat configuration.CSharpVersion == CSharpVersion.NewestWithNullableReferences ? $"{dataTypeIdentifier}?" : dataTypeIdentifier; private string GetMemberAccessibility() => - _configuration.MemberAccessibility == MemberAccessibility.Internal ? "internal" : "public"; + _configuration.MemberAccessibility is MemberAccessibility.Internal ? "internal" : "public"; private void GenerateDataProperty( GraphQlType baseType, @@ -662,21 +695,21 @@ private void GenerateDataProperty( if (decorateWithJsonPropertyAttribute) { decorateWithJsonPropertyAttribute = - _configuration.JsonPropertyGeneration == JsonPropertyGenerationOption.Always || + _configuration.JsonPropertyGeneration is JsonPropertyGenerationOption.Always || !String.Equals( member.Name, propertyContext.PropertyName.TrimStart('@'), - _configuration.JsonPropertyGeneration == JsonPropertyGenerationOption.CaseInsensitive ? StringComparison.OrdinalIgnoreCase : StringComparison.Ordinal); + _configuration.JsonPropertyGeneration is JsonPropertyGenerationOption.CaseInsensitive ? StringComparison.OrdinalIgnoreCase : StringComparison.Ordinal); if (_configuration.JsonPropertyGeneration is JsonPropertyGenerationOption.Never or JsonPropertyGenerationOption.UseDefaultAlias) decorateWithJsonPropertyAttribute = false; } - var isInterfaceMember = baseType.Kind == GraphQlTypeKind.Interface; + var isInterfaceMember = baseType.Kind is GraphQlTypeKind.Interface; var fieldType = member.Type.UnwrapIfNonNull(); var isGraphQlInterfaceJsonConverterRequired = - fieldType.Kind == GraphQlTypeKind.Interface || - fieldType.Kind == GraphQlTypeKind.List && UnwrapListItemType(fieldType, false, out _).UnwrapIfNonNull().Kind == GraphQlTypeKind.Interface; + fieldType.Kind is GraphQlTypeKind.Interface or GraphQlTypeKind.Union || + fieldType.Kind is GraphQlTypeKind.List && UnwrapListItemType(fieldType, false, out _).UnwrapIfNonNull().Kind is GraphQlTypeKind.Interface or GraphQlTypeKind.Union; var isBaseTypeInputObject = baseType.Kind == GraphQlTypeKind.InputObject; var isPreprocessorDirectiveDisableNewtonsoftJsonRequired = !isInterfaceMember && decorateWithJsonPropertyAttribute || isGraphQlInterfaceJsonConverterRequired || isBaseTypeInputObject; @@ -808,7 +841,8 @@ private void GenerateQueryBuilder(GenerationContext context, GraphQlType graphQl writer.Write('"'); var csharpPropertyName = NamingHelper.ToPascalCase(field.Name); - if (_configuration.JsonPropertyGeneration == JsonPropertyGenerationOption.UseDefaultAlias && !String.Equals(field.Name, csharpPropertyName, StringComparison.OrdinalIgnoreCase)) + if (_configuration.JsonPropertyGeneration is JsonPropertyGenerationOption.UseDefaultAlias && + !String.Equals(field.Name, csharpPropertyName, StringComparison.OrdinalIgnoreCase)) { writer.Write(", DefaultAlias = \""); writer.Write(NamingHelper.LowerFirst(csharpPropertyName)); @@ -867,13 +901,13 @@ private void GenerateQueryBuilder(GenerationContext context, GraphQlType graphQl writer.WriteLine(); writer.Write(indentation); writer.Write(memberIndentation); - writer.WriteLine("{"); + writer.WriteLine('{'); writer.Write(indentation); writer.Write(memberIndentation); writer.WriteLine(" WithTypeName();"); writer.Write(indentation); writer.Write(memberIndentation); - writer.WriteLine("}"); + writer.WriteLine('}'); } else writer.WriteLine(" => WithTypeName();"); @@ -919,7 +953,7 @@ private void GenerateQueryBuilder(GenerationContext context, GraphQlType graphQl } var fragments = context.GetFragments(graphQlType); - fields ??= new List(); + fields ??= []; var firstFragmentIndex = fields.Count; fields.AddRange(fragments); var csharpNameLookup = fields.ToLookup(f => NamingHelper.ToPascalCase(f.Name)); @@ -1110,7 +1144,8 @@ void WriteAliasParameter() writer.Write(stringDataType); writer.Write(" alias = "); - if (_configuration.JsonPropertyGeneration == JsonPropertyGenerationOption.UseDefaultAlias && !String.Equals(field.Name, csharpPropertyName, StringComparison.OrdinalIgnoreCase)) + if (_configuration.JsonPropertyGeneration is JsonPropertyGenerationOption.UseDefaultAlias && + !String.Equals(field.Name, csharpPropertyName, StringComparison.OrdinalIgnoreCase)) { writer.Write('"'); writer.Write(NamingHelper.LowerFirst(csharpPropertyName)); diff --git a/src/GraphQlClientGenerator/NamingHelper.cs b/src/GraphQlClientGenerator/NamingHelper.cs index 2c9f73d..632dcd8 100644 --- a/src/GraphQlClientGenerator/NamingHelper.cs +++ b/src/GraphQlClientGenerator/NamingHelper.cs @@ -5,9 +5,10 @@ namespace GraphQlClientGenerator; internal static class NamingHelper { + internal const string MetadataFieldTypeName = "__typename"; + private static readonly HashSet CSharpKeywords = - new() - { + [ "abstract", "as", "base", @@ -86,7 +87,7 @@ internal static class NamingHelper "void", "volatile", "while", - }; + ]; public static string LowerFirst(string value) => Char.ToLowerInvariant(value[0]) + value.Substring(1); @@ -103,6 +104,9 @@ internal static class NamingHelper /// https://stackoverflow.com/questions/18627112/how-can-i-convert-text-to-pascal-case> public static string ToPascalCase(string text) { + if (text is MetadataFieldTypeName) + return "TypeName"; + var textWithoutWhiteSpace = RegexInvalidCharacters.Replace(RegexWhiteSpace.Replace(text, String.Empty), String.Empty); if (textWithoutWhiteSpace.All(c => c == '_')) return textWithoutWhiteSpace; diff --git a/test/GraphQlClientGenerator.Test/ExpectedSingleFileGenerationContext/Unions b/test/GraphQlClientGenerator.Test/ExpectedSingleFileGenerationContext/Unions index 0f0949e..ae6cfde 100644 --- a/test/GraphQlClientGenerator.Test/ExpectedSingleFileGenerationContext/Unions +++ b/test/GraphQlClientGenerator.Test/ExpectedSingleFileGenerationContext/Unions @@ -659,16 +659,8 @@ public partial class SimpleObjectType public ICollection? StringArrayValue { get; set; } } -public partial class UnionType +public partial interface IUnionType { - public string? Name { get; set; } - public string? ConcreteType1Field { get; set; } - public string? Value { get; set; } - public string? ConcreteType2Field { get; set; } - public string? value { get; set; } - public string? ConcreteType3Field { get; set; } - public string? VALUE { get; set; } - public string? Function { get; set; } } public partial interface INamedType @@ -677,7 +669,7 @@ public partial interface INamedType } [GraphQlObjectType("ConcreteType1")] -public partial class ConcreteType1 : INamedType +public partial class ConcreteType1 : IUnionType, INamedType { public string? Name { get; set; } public string? ConcreteType1Field { get; set; } @@ -685,7 +677,7 @@ public partial class ConcreteType1 : INamedType } [GraphQlObjectType("ConcreteType2")] -public partial class ConcreteType2 : INamedType +public partial class ConcreteType2 : IUnionType, INamedType { public string? Name { get; set; } public string? ConcreteType2Field { get; set; } @@ -693,7 +685,7 @@ public partial class ConcreteType2 : INamedType } [GraphQlObjectType("ConcreteType3")] -public partial class ConcreteType3 : INamedType +public partial class ConcreteType3 : IUnionType, INamedType { public string? Name { get; set; } public string? ConcreteType3Field { get; set; } @@ -705,7 +697,10 @@ public partial class Query { public DateTimeOffset? ScalarValue { get; set; } public SimpleObjectType? SimpleObject { get; set; } - public ICollection? Union { get; set; } + #if !GRAPHQL_GENERATOR_DISABLE_NEWTONSOFT_JSON + [JsonConverter(typeof(GraphQlInterfaceJsonConverter))] + #endif + public ICollection? Union { get; set; } #if !GRAPHQL_GENERATOR_DISABLE_NEWTONSOFT_JSON [JsonConverter(typeof(GraphQlInterfaceJsonConverter))] #endif