Skip to content

Commit

Permalink
#151 proper deserialization of union sub types
Browse files Browse the repository at this point in the history
Husqvik committed Nov 21, 2023
1 parent a55762d commit cd3f1dd
Showing 4 changed files with 88 additions and 67 deletions.
33 changes: 10 additions & 23 deletions src/GraphQlClientGenerator/GenerationContext.cs
Original file line number Diff line number Diff line change
@@ -48,8 +48,12 @@ private bool FilterIfAllFieldsDeprecated(GraphQlField field)
return true;

var graphQlType = _complexTypes[fieldType.Name];
var nestedTypeFields =
graphQlType.Kind is GraphQlTypeKind.Union
? graphQlType.PossibleTypes.Select(t => _complexTypes[t.Name]).SelectMany(t => t.Fields)
: graphQlType.Fields;

return GetFields(graphQlType).Any(FilterIfDeprecated);
return nestedTypeFields.Any(FilterIfDeprecated);
}

public void Initialize(GraphQlGeneratorConfiguration configuration)
@@ -111,21 +115,7 @@ public void Initialize(GraphQlGeneratorConfiguration configuration)
public abstract void AfterGeneration();

protected internal List<GraphQlField> GetFieldsToGenerate(GraphQlType type) =>
GetFields(type)?.Where(FilterIfDeprecated).Where(FilterIfAllFieldsDeprecated).ToList();

private IEnumerable<GraphQlField> GetFields(GraphQlType type)
{
if (type.Kind != GraphQlTypeKind.Union)
return type.Fields;

var unionFields = new List<GraphQlField>();
var unionFieldNames = new HashSet<string>();
foreach (var possibleType in type.PossibleTypes)
if (_complexTypes.TryGetValue(possibleType.Name, out var consistOfType) && consistOfType.Fields is not null)
unionFields.AddRange(consistOfType.Fields.Where(f => unionFieldNames.Add(f.Name)));

return unionFields;
}
type.Fields?.Where(FilterIfDeprecated).Where(FilterIfAllFieldsDeprecated).ToList();

protected internal IEnumerable<GraphQlField> GetFragments(GraphQlType type)
{
@@ -162,10 +152,7 @@ protected internal ScalarFieldTypeDescription GetDataPropertyType(GraphQlType ba
case GraphQlTypeKind.Union:
case GraphQlTypeKind.InputObject:
var fieldTypeName = GetCSharpClassName(fieldType.Name);
var propertyType = $"{Configuration.ClassPrefix}{fieldTypeName}{Configuration.ClassSuffix}";
if (fieldType.Kind == GraphQlTypeKind.Interface)
propertyType = $"I{propertyType}";

var propertyType = GetFullyQualifiedNetTypeName(fieldTypeName, fieldType.Kind);
return ScalarFieldTypeDescription.FromNetTypeName(AddQuestionMarkIfNullableReferencesEnabled(propertyType));

case GraphQlTypeKind.Enum:
@@ -178,7 +165,7 @@ protected internal ScalarFieldTypeDescription GetDataPropertyType(GraphQlType ba
var netItemType =
IsUnknownObjectScalar(baseType, member.Name, itemType)
? "object"
: $"{(unwrappedItemType.Kind == GraphQlTypeKind.Interface ? "I" : null)}{Configuration.ClassPrefix}{itemTypeName}{Configuration.ClassSuffix}";
: GetFullyQualifiedNetTypeName(itemTypeName, unwrappedItemType.Kind);

var suggestedScalarNetType = ResolveScalarNetType(baseType, member.Name, itemType, true).NetTypeName.TrimEnd('?');
if (!String.Equals(suggestedScalarNetType, "object") && !suggestedScalarNetType.TrimEnd().EndsWith("System.Object"))
@@ -246,7 +233,7 @@ internal ScalarFieldTypeDescription ResolveScalarNetType(GraphQlType baseType, s
};

internal string GetFullyQualifiedNetTypeName(string baseTypeName, GraphQlTypeKind kind) =>
$"{(kind is GraphQlTypeKind.Interface ? "I" : null)}{Configuration.ClassPrefix}{baseTypeName}{Configuration.ClassSuffix}";
$"{(kind is GraphQlTypeKind.Interface or GraphQlTypeKind.Union ? "I" : null)}{Configuration.ClassPrefix}{baseTypeName}{Configuration.ClassSuffix}";

private ScalarFieldTypeDescription GetBooleanNetType(GraphQlType baseType, GraphQlTypeBase valueType, string valueName, bool alwaysNullable) =>
Configuration.BooleanTypeMapping switch
@@ -327,7 +314,7 @@ private void ResolveNameCollisions()
var propertyNamesToGenerate = new List<string>();
if (isInputObject)
propertyNamesToGenerate.AddRange(graphQlType.InputFields.Select(f => NamingHelper.ToPascalCase(f.Name)));
else if (graphQlType.Kind.IsComplex())
else if (graphQlType.Kind is GraphQlTypeKind.Object or GraphQlTypeKind.Interface)
propertyNamesToGenerate.AddRange(GetFieldsToGenerate(graphQlType).Select(f => NamingHelper.ToPascalCase(f.Name)));
else
continue;
91 changes: 63 additions & 28 deletions src/GraphQlClientGenerator/GraphQlGenerator.cs
Original file line number Diff line number Diff line change
@@ -50,6 +50,17 @@ private static HttpClient CreateHttpClient(HttpMessageHandler messageHandler = n
Converters = { new StringEnumConverter() }
};

private static readonly GraphQlField TypeNameField =
new()
{
Name = NamingHelper.MetadataFieldTypeName,
Type = new GraphQlFieldType
{
Kind = GraphQlTypeKind.Scalar,
Name = "String"
}
};

private readonly GraphQlGeneratorConfiguration _configuration;

public GraphQlGenerator(GraphQlGeneratorConfiguration configuration = null) =>
@@ -356,46 +367,70 @@ private void GenerateDataClasses(GenerationContext context)

context.BeforeDataClassesGeneration();

var unionLookup =
complexTypes.Values
.Where(t => t.Kind is GraphQlTypeKind.Union)
.SelectMany(u => u.PossibleTypes.Select(t => (UnionName: u.Name, PossibleTypeName: t.Name)))
.ToLookup(x => x.PossibleTypeName, x => x.UnionName);

foreach (var complexType in complexTypes.Values)
{
var csharpUnionTypeName = context.GetCSharpClassName(complexType.Name);
if (complexType.Kind is GraphQlTypeKind.Union)
{
GenerateFileMember(
context,
csharpUnionTypeName,
complexType,
null,
() => GenerateDataClassBody(complexType, Array.Empty<GraphQlField>(), context, true));

continue;
}

var hasInputReference = complexType.Kind is GraphQlTypeKind.InputObject && context.ReferencedObjectTypes.Contains(complexType.Name);
var fieldsToGenerate = context.GetFieldsToGenerate(complexType);
var isInterface = complexType.Kind is GraphQlTypeKind.Interface;
var csharpTypeName = context.GetCSharpClassName(complexType.Name);
var interfacesToImplement = new List<string>();
if (isInterface)
{
interfacesToImplement.Add(GenerateFileMember(context, csharpTypeName, complexType, null, () => GenerateBody(true)));
}
else if (complexType.Interfaces?.Count > 0)
var interfacesToImplement = new HashSet<string>(unionLookup[complexType.Name].Select(n => context.GetFullyQualifiedNetTypeName(n, GraphQlTypeKind.Interface)));
//var implementsUnion = interfacesToImplement.Any();
if (complexType.Interfaces?.Count > 0)
{
var fieldNames = new HashSet<string>(fieldsToGenerate.Select(f => f.Name));

foreach (var @interface in complexType.Interfaces)
{
var csharpInterfaceName = context.GetCSharpClassName(@interface.Name, false);
var interfaceName = $"I{_configuration.ClassPrefix}{csharpInterfaceName}{_configuration.ClassSuffix}";
var interfaceName = context.GetFullyQualifiedNetTypeName(csharpInterfaceName, GraphQlTypeKind.Interface);
interfacesToImplement.Add(interfaceName);

foreach (var interfaceField in complexTypes[@interface.Name].Fields.Where(context.FilterIfDeprecated))
if (fieldNames.Add(interfaceField.Name))
fieldsToGenerate.Add(interfaceField);
}

//fieldsToGenerate.Insert(0, TypeNameField);
}
/*else if (implementsUnion)
fieldsToGenerate.Insert(0, TypeNameField);*/

if (hasInputReference)
interfacesToImplement.Add("IGraphQlInputObject");

if (!isInterface && fieldsToGenerate.Any())
GenerateFileMember(context, csharpTypeName, complexType, String.Join(", ", interfacesToImplement), () => GenerateBody(false));
if (fieldsToGenerate.Any())
GenerateFileMember(
context,
csharpTypeName,
complexType,
String.Join(", ", interfacesToImplement),
() => GenerateBody(complexType.Kind is GraphQlTypeKind.Interface));

continue;

void GenerateBody(bool isInterfaceMember)
{
if (hasInputReference)
GenerateInputDataClassBody(complexType, fieldsToGenerate, context);
else if (fieldsToGenerate is not null)
else
GenerateDataClassBody(complexType, fieldsToGenerate, context, isInterfaceMember);
}
}
@@ -567,7 +602,7 @@ private void GenerateInputDataClassBody(GraphQlType type, IEnumerable<IGraphQlMe
writer.WriteLine(" }");
}

private string GenerateFileMember(GenerationContext context, string typeName, GraphQlType graphQlType, string baseTypeName, Action generateFileMemberBody)
private void GenerateFileMember(GenerationContext context, string typeName, GraphQlType graphQlType, string baseTypeName, Action generateFileMemberBody)
{
typeName = context.GetFullyQualifiedNetTypeName(typeName, graphQlType.Kind);

@@ -601,7 +636,7 @@ private string GenerateFileMember(GenerationContext context, string typeName, Gr
if (_configuration.GeneratePartialClasses)
writer.Write("partial ");

writer.Write(graphQlType.Kind is GraphQlTypeKind.Interface ? "interface" : "class");
writer.Write(graphQlType.Kind is GraphQlTypeKind.Interface or GraphQlTypeKind.Union ? "interface" : "class");
writer.Write(' ');
writer.Write(typeName);

@@ -613,21 +648,19 @@ private string GenerateFileMember(GenerationContext context, string typeName, Gr

writer.WriteLine();
writer.Write(indentation);
writer.WriteLine("{");
writer.WriteLine('{');

generateFileMemberBody();

writer.Write(indentation);
writer.WriteLine("}");
writer.WriteLine('}');

context.AfterDataClassGeneration(
new ObjectGenerationContext
{
GraphQlType = graphQlType,
CSharpTypeName = typeName
});

return typeName;
}

private string AddQuestionMarkIfNullableReferencesEnabled(string dataTypeIdentifier) =>
@@ -637,7 +670,7 @@ internal static string AddQuestionMarkIfNullableReferencesEnabled(GraphQlGenerat
configuration.CSharpVersion == CSharpVersion.NewestWithNullableReferences ? $"{dataTypeIdentifier}?" : dataTypeIdentifier;

private string GetMemberAccessibility() =>
_configuration.MemberAccessibility == MemberAccessibility.Internal ? "internal" : "public";
_configuration.MemberAccessibility is MemberAccessibility.Internal ? "internal" : "public";

private void GenerateDataProperty(
GraphQlType baseType,
@@ -662,21 +695,21 @@ private void GenerateDataProperty(
if (decorateWithJsonPropertyAttribute)
{
decorateWithJsonPropertyAttribute =
_configuration.JsonPropertyGeneration == JsonPropertyGenerationOption.Always ||
_configuration.JsonPropertyGeneration is JsonPropertyGenerationOption.Always ||
!String.Equals(
member.Name,
propertyContext.PropertyName.TrimStart('@'),
_configuration.JsonPropertyGeneration == JsonPropertyGenerationOption.CaseInsensitive ? StringComparison.OrdinalIgnoreCase : StringComparison.Ordinal);
_configuration.JsonPropertyGeneration is JsonPropertyGenerationOption.CaseInsensitive ? StringComparison.OrdinalIgnoreCase : StringComparison.Ordinal);

if (_configuration.JsonPropertyGeneration is JsonPropertyGenerationOption.Never or JsonPropertyGenerationOption.UseDefaultAlias)
decorateWithJsonPropertyAttribute = false;
}

var isInterfaceMember = baseType.Kind == GraphQlTypeKind.Interface;
var isInterfaceMember = baseType.Kind is GraphQlTypeKind.Interface;
var fieldType = member.Type.UnwrapIfNonNull();
var isGraphQlInterfaceJsonConverterRequired =
fieldType.Kind == GraphQlTypeKind.Interface ||
fieldType.Kind == GraphQlTypeKind.List && UnwrapListItemType(fieldType, false, out _).UnwrapIfNonNull().Kind == GraphQlTypeKind.Interface;
fieldType.Kind is GraphQlTypeKind.Interface or GraphQlTypeKind.Union ||
fieldType.Kind is GraphQlTypeKind.List && UnwrapListItemType(fieldType, false, out _).UnwrapIfNonNull().Kind is GraphQlTypeKind.Interface or GraphQlTypeKind.Union;

var isBaseTypeInputObject = baseType.Kind == GraphQlTypeKind.InputObject;
var isPreprocessorDirectiveDisableNewtonsoftJsonRequired = !isInterfaceMember && decorateWithJsonPropertyAttribute || isGraphQlInterfaceJsonConverterRequired || isBaseTypeInputObject;
@@ -808,7 +841,8 @@ private void GenerateQueryBuilder(GenerationContext context, GraphQlType graphQl
writer.Write('"');

var csharpPropertyName = NamingHelper.ToPascalCase(field.Name);
if (_configuration.JsonPropertyGeneration == JsonPropertyGenerationOption.UseDefaultAlias && !String.Equals(field.Name, csharpPropertyName, StringComparison.OrdinalIgnoreCase))
if (_configuration.JsonPropertyGeneration is JsonPropertyGenerationOption.UseDefaultAlias &&
!String.Equals(field.Name, csharpPropertyName, StringComparison.OrdinalIgnoreCase))
{
writer.Write(", DefaultAlias = \"");
writer.Write(NamingHelper.LowerFirst(csharpPropertyName));
@@ -867,13 +901,13 @@ private void GenerateQueryBuilder(GenerationContext context, GraphQlType graphQl
writer.WriteLine();
writer.Write(indentation);
writer.Write(memberIndentation);
writer.WriteLine("{");
writer.WriteLine('{');
writer.Write(indentation);
writer.Write(memberIndentation);
writer.WriteLine(" WithTypeName();");
writer.Write(indentation);
writer.Write(memberIndentation);
writer.WriteLine("}");
writer.WriteLine('}');
}
else
writer.WriteLine(" => WithTypeName();");
@@ -919,7 +953,7 @@ private void GenerateQueryBuilder(GenerationContext context, GraphQlType graphQl
}

var fragments = context.GetFragments(graphQlType);
fields ??= new List<GraphQlField>();
fields ??= [];
var firstFragmentIndex = fields.Count;
fields.AddRange(fragments);
var csharpNameLookup = fields.ToLookup(f => NamingHelper.ToPascalCase(f.Name));
@@ -1110,7 +1144,8 @@ void WriteAliasParameter()
writer.Write(stringDataType);
writer.Write(" alias = ");

if (_configuration.JsonPropertyGeneration == JsonPropertyGenerationOption.UseDefaultAlias && !String.Equals(field.Name, csharpPropertyName, StringComparison.OrdinalIgnoreCase))
if (_configuration.JsonPropertyGeneration is JsonPropertyGenerationOption.UseDefaultAlias &&
!String.Equals(field.Name, csharpPropertyName, StringComparison.OrdinalIgnoreCase))
{
writer.Write('"');
writer.Write(NamingHelper.LowerFirst(csharpPropertyName));
10 changes: 7 additions & 3 deletions src/GraphQlClientGenerator/NamingHelper.cs
Original file line number Diff line number Diff line change
@@ -5,9 +5,10 @@ namespace GraphQlClientGenerator;

internal static class NamingHelper
{
internal const string MetadataFieldTypeName = "__typename";

private static readonly HashSet<string> CSharpKeywords =
new()
{
[
"abstract",
"as",
"base",
@@ -86,7 +87,7 @@ internal static class NamingHelper
"void",
"volatile",
"while",
};
];

public static string LowerFirst(string value) => Char.ToLowerInvariant(value[0]) + value.Substring(1);

@@ -103,6 +104,9 @@ internal static class NamingHelper
/// <remarks>https://stackoverflow.com/questions/18627112/how-can-i-convert-text-to-pascal-case</remarks>>
public static string ToPascalCase(string text)
{
if (text is MetadataFieldTypeName)
return "TypeName";

var textWithoutWhiteSpace = RegexInvalidCharacters.Replace(RegexWhiteSpace.Replace(text, String.Empty), String.Empty);
if (textWithoutWhiteSpace.All(c => c == '_'))
return textWithoutWhiteSpace;
Original file line number Diff line number Diff line change
@@ -659,16 +659,8 @@ public partial class SimpleObjectType
public ICollection<string>? StringArrayValue { get; set; }
}

public partial class UnionType
public partial interface IUnionType
{
public string? Name { get; set; }
public string? ConcreteType1Field { get; set; }
public string? Value { get; set; }
public string? ConcreteType2Field { get; set; }
public string? value { get; set; }
public string? ConcreteType3Field { get; set; }
public string? VALUE { get; set; }
public string? Function { get; set; }
}

public partial interface INamedType
@@ -677,23 +669,23 @@ public partial interface INamedType
}

[GraphQlObjectType("ConcreteType1")]
public partial class ConcreteType1 : INamedType
public partial class ConcreteType1 : IUnionType, INamedType
{
public string? Name { get; set; }
public string? ConcreteType1Field { get; set; }
public string? Value { get; set; }
}

[GraphQlObjectType("ConcreteType2")]
public partial class ConcreteType2 : INamedType
public partial class ConcreteType2 : IUnionType, INamedType
{
public string? Name { get; set; }
public string? ConcreteType2Field { get; set; }
public string? Value { get; set; }
}

[GraphQlObjectType("ConcreteType3")]
public partial class ConcreteType3 : INamedType
public partial class ConcreteType3 : IUnionType, INamedType
{
public string? Name { get; set; }
public string? ConcreteType3Field { get; set; }
@@ -705,7 +697,10 @@ public partial class Query
{
public DateTimeOffset? ScalarValue { get; set; }
public SimpleObjectType? SimpleObject { get; set; }
public ICollection<UnionType>? Union { get; set; }
#if !GRAPHQL_GENERATOR_DISABLE_NEWTONSOFT_JSON
[JsonConverter(typeof(GraphQlInterfaceJsonConverter))]
#endif
public ICollection<IUnionType>? Union { get; set; }
#if !GRAPHQL_GENERATOR_DISABLE_NEWTONSOFT_JSON
[JsonConverter(typeof(GraphQlInterfaceJsonConverter))]
#endif

0 comments on commit cd3f1dd

Please sign in to comment.