Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor and fixes in source input model #3445

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions src/AutoRest.CSharp/Common/Generation/Types/CSharpType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -163,18 +163,6 @@ internal static CSharpType FromSystemType(Type type, string defaultNamespace, So
internal static CSharpType FromSystemType(BuildContext context, Type type)
=> FromSystemType(type, context.DefaultNamespace, context.SourceInputModel);

public bool IsCollectionType()
ArcturusZhang marked this conversation as resolved.
Show resolved Hide resolved
{
if (!IsFrameworkType)
return false;

return FrameworkType.Equals(typeof(IList<>)) ||
FrameworkType.Equals(typeof(IEnumerable<>)) ||
FrameworkType == typeof(IReadOnlyList<>) ||
FrameworkType.Equals(typeof(IDictionary<,>)) ||
FrameworkType == typeof(IReadOnlyDictionary<,>);
}

public CSharpType GetNonNullable()
{
if (!IsNullable)
Expand Down
121 changes: 121 additions & 0 deletions src/AutoRest.CSharp/Common/Input/Source/CodeGenAttributes.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using Azure.Core;
using Microsoft.CodeAnalysis;

namespace AutoRest.CSharp.Input.Source
{
public class CodeGenAttributes
{
public CodeGenAttributes(Compilation compilation)
{
CodeGenSuppressAttribute = GetSymbol(compilation, typeof(CodeGenSuppressAttribute));
CodeGenMemberAttribute = GetSymbol(compilation, typeof(CodeGenMemberAttribute));
CodeGenTypeAttribute = GetSymbol(compilation, typeof(CodeGenTypeAttribute));
CodeGenModelAttribute = GetSymbol(compilation, typeof(CodeGenModelAttribute));
CodeGenClientAttribute = GetSymbol(compilation, typeof(CodeGenClientAttribute));
CodeGenMemberSerializationAttribute = GetSymbol(compilation, typeof(CodeGenMemberSerializationAttribute));
CodeGenMemberSerializationHooksAttribute = GetSymbol(compilation, typeof(CodeGenMemberSerializationHooksAttribute));
}

public INamedTypeSymbol CodeGenSuppressAttribute { get; }

public INamedTypeSymbol CodeGenMemberAttribute { get; }

public INamedTypeSymbol CodeGenTypeAttribute { get; }

public INamedTypeSymbol CodeGenModelAttribute { get; }

public INamedTypeSymbol CodeGenClientAttribute { get; }

public INamedTypeSymbol CodeGenMemberSerializationAttribute { get; }

public INamedTypeSymbol CodeGenMemberSerializationHooksAttribute { get; }

private static INamedTypeSymbol GetSymbol(Compilation compilation, Type type) => compilation.GetTypeByMetadataName(type.FullName!) ?? throw new InvalidOperationException($"cannot load symbol of attribute {type}");

private static bool CheckAttribute(AttributeData attributeData, INamedTypeSymbol codeGenAttribute)
=> SymbolEqualityComparer.Default.Equals(attributeData.AttributeClass, codeGenAttribute);

public bool TryGetCodeGenMemberAttributeValue(AttributeData attributeData, [MaybeNullWhen(false)] out string name)
{
name = null;
if (!CheckAttribute(attributeData, CodeGenMemberAttribute))
return false;

name = attributeData.ConstructorArguments.FirstOrDefault().Value as string;
return name != null;
}

public bool TryGetCodeGenMemberSerializationAttributeValue(AttributeData attributeData, [MaybeNullWhen(false)] out string[] propertyNames)
{
propertyNames = null;
if (!CheckAttribute(attributeData, CodeGenMemberSerializationAttribute))
return false;

if (attributeData.ConstructorArguments.Length > 0)
{
propertyNames = ToStringArray(attributeData.ConstructorArguments[0].Values);
}

return propertyNames != null;
}

public bool TryGetCodeGenMemberSerializationHooksAttributeValue(AttributeData attributeData, out (string? SerializationHook, string? DeserializationHook) hooks)
{
hooks = default;
if (!CheckAttribute(attributeData, CodeGenMemberSerializationHooksAttribute))
return false;

string? serializationHook = null;
string? deserializationHook = null;

var arguments = attributeData.ConstructorArguments;
serializationHook = arguments[0].Value as string;
deserializationHook = arguments[1].Value as string;

hooks = (serializationHook, deserializationHook);
return serializationHook != null || deserializationHook != null;
}

public bool TryGetCodeGenModelAttributeValue(AttributeData attributeData, out string[]? usage, out string[]? formats)
lirenhe marked this conversation as resolved.
Show resolved Hide resolved
{
usage = null;
formats = null;
if (!CheckAttribute(attributeData, CodeGenModelAttribute))
return false;
foreach (var namedArgument in attributeData.NamedArguments)
{
switch (namedArgument.Key)
{
case nameof(Azure.Core.CodeGenModelAttribute.Usage):
usage = ToStringArray(namedArgument.Value.Values);
break;
case nameof(Azure.Core.CodeGenModelAttribute.Formats):
formats = ToStringArray(namedArgument.Value.Values);
break;
}
}

return usage != null || formats != null;
}

private static string[]? ToStringArray(ImmutableArray<TypedConstant> values)
{
if (values.IsDefaultOrEmpty)
{
return null;
lirenhe marked this conversation as resolved.
Show resolved Hide resolved
}

return values
.Select(v => (string?)v.Value)
.OfType<string>()
.ToArray();
}
}
}
88 changes: 17 additions & 71 deletions src/AutoRest.CSharp/Common/Input/Source/ModelTypeMapping.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,103 +20,49 @@ public class ModelTypeMapping
public string[]? Usage { get; }
public string[]? Formats { get; }

public ModelTypeMapping(INamedTypeSymbol modelAttribute, INamedTypeSymbol memberAttribute, INamedTypeSymbol serializationAttribute, INamedTypeSymbol serializationHooksAttribute, INamedTypeSymbol? existingType)
public ModelTypeMapping(CodeGenAttributes codeGenAttributes, INamedTypeSymbol existingType)
{
_existingType = existingType;
_propertyMappings = new();
_serializationMappings = new(SymbolEqualityComparer.Default);

foreach (ISymbol member in GetMembers(existingType))
{
string[]? serializationPath = null;
(string? SerializationHook, string? DeserializationHook)? serializationHooks = null;
foreach (var attributeData in member.GetAttributes())
{
var attributeTypeSymbol = attributeData.AttributeClass;
// handle CodeGenMember attribute
if (SymbolEqualityComparer.Default.Equals(attributeTypeSymbol, memberAttribute) && TryGetCodeGenMemberAttributeValue(member, attributeData, out var schemaMemberName))
if (codeGenAttributes.TryGetCodeGenMemberAttributeValue(attributeData, out var schemaMemberName))
{
_propertyMappings.Add(schemaMemberName, member);
}
string[]? serializationPath = null;
(string? SerializationHook, string? DeserializationHook)? serializationHooks = null;
if (SymbolEqualityComparer.Default.Equals(attributeTypeSymbol, serializationAttribute) && TryGetSerializationAttributeValue(member, attributeData, out var pathResult))
// handle CodeGenMemberSerialization attribute
if (codeGenAttributes.TryGetCodeGenMemberSerializationAttributeValue(attributeData, out var pathResult))
{
serializationPath = pathResult;
}
if (SymbolEqualityComparer.Default.Equals(attributeTypeSymbol, serializationHooksAttribute) && TryGetSerializationHooks(member, attributeData, out var hooks))
// handle CodeGenMemberSerializationHooks attribute
if (codeGenAttributes.TryGetCodeGenMemberSerializationHooksAttributeValue(attributeData, out var hooks))
{
serializationHooks = hooks;
}
if (serializationPath != null || serializationHooks != null)
{
_serializationMappings.Add(member, new SourcePropertySerializationMapping(member, serializationPath, serializationHooks?.SerializationHook, serializationHooks?.DeserializationHook));
}
}
}

if (existingType != null)
{
foreach (var attributeData in existingType.GetAttributes())
if (serializationPath != null || serializationHooks != null)
{
if (SymbolEqualityComparer.Default.Equals(attributeData.AttributeClass, modelAttribute))
{
foreach (var namedArgument in attributeData.NamedArguments)
{
switch (namedArgument.Key)
{
case nameof(CodeGenModelAttribute.Usage):
Usage = ToStringArray(namedArgument.Value.Values);
break;
case nameof(CodeGenModelAttribute.Formats):
Formats = ToStringArray(namedArgument.Value.Values);
break;
}
}
}
_serializationMappings.Add(member, new SourcePropertySerializationMapping(member, serializationPath, serializationHooks?.SerializationHook, serializationHooks?.DeserializationHook));
}
}
}

private static bool TryGetSerializationHooks(ISymbol symbol, AttributeData attributeData, out (string? SerializationHook, string? DeserializationHook) hooks)
{
string? serializationHook = null;
string? deserializationHook = null;

var arguments = attributeData.ConstructorArguments;
serializationHook = arguments[0].Value as string;
deserializationHook = arguments[1].Value as string;

hooks = (serializationHook, deserializationHook);
return serializationHook != null || deserializationHook != null;
}

private static bool TryGetCodeGenMemberAttributeValue(ISymbol symbol, AttributeData attributeData, [MaybeNullWhen(false)] out string name)
{
name = attributeData.ConstructorArguments.FirstOrDefault().Value as string;
return name != null;
}

private static bool TryGetSerializationAttributeValue(ISymbol symbol, AttributeData attributeData, [MaybeNullWhen(false)] out string[] propertyNames)
{
propertyNames = null;
if (attributeData.ConstructorArguments.Length > 0)
foreach (var attributeData in existingType.GetAttributes())
{
propertyNames = ToStringArray(attributeData.ConstructorArguments[0].Values);
}

return propertyNames != null;
}

private static string[]? ToStringArray(ImmutableArray<TypedConstant> values)
{
if (values.IsDefaultOrEmpty)
{
return null;
// handle CodeGenModel attribute
if (codeGenAttributes.TryGetCodeGenModelAttributeValue(attributeData, out var usage, out var formats))
{
Usage = usage;
Formats = formats;
}
}

return values
.Select(v => (string?)v.Value)
.OfType<string>()
.ToArray();
}

public SourceMemberMapping? GetForMember(string name)
Expand Down
25 changes: 9 additions & 16 deletions src/AutoRest.CSharp/Common/Input/Source/SourceInputModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,15 @@ public class SourceInputModel
{
private readonly Compilation _compilation;
private readonly CompilationInput? _existingCompilation;
private readonly INamedTypeSymbol _typeAttribute;
private readonly INamedTypeSymbol _modelAttribute;
private readonly INamedTypeSymbol _clientAttribute;
private readonly INamedTypeSymbol _schemaMemberNameAttribute;
private readonly INamedTypeSymbol _serializationAttribute;
private readonly INamedTypeSymbol _serializationHooksAttribute;
private readonly CodeGenAttributes _codeGenAttributes;
private readonly Dictionary<string, INamedTypeSymbol> _nameMap = new Dictionary<string, INamedTypeSymbol>(StringComparer.OrdinalIgnoreCase);

public SourceInputModel(Compilation compilation, CompilationInput? existingCompilation = null)
{
_compilation = compilation;
_existingCompilation = existingCompilation;

_schemaMemberNameAttribute = compilation.GetTypeByMetadataName(typeof(CodeGenMemberAttribute).FullName!)!;
_serializationAttribute = compilation.GetTypeByMetadataName(typeof(CodeGenMemberSerializationAttribute).FullName!)!;
_serializationHooksAttribute = compilation.GetTypeByMetadataName(typeof(CodeGenMemberSerializationHooksAttribute).FullName!)!;
_typeAttribute = compilation.GetTypeByMetadataName(typeof(CodeGenTypeAttribute).FullName!)!;
_modelAttribute = compilation.GetTypeByMetadataName(typeof(CodeGenModelAttribute).FullName!)!;
_clientAttribute = compilation.GetTypeByMetadataName(typeof(CodeGenClientAttribute).FullName!)!;
_codeGenAttributes = new CodeGenAttributes(compilation);

IAssemblySymbol assembly = _compilation.Assembly;

Expand All @@ -58,9 +48,12 @@ public SourceInputModel(Compilation compilation, CompilationInput? existingCompi
return osvAttribute?.ConstructorArguments[0].Values.Select(v => v.Value).OfType<string>().ToList();
}

public ModelTypeMapping CreateForModel(INamedTypeSymbol? symbol)
public ModelTypeMapping? CreateForModel(INamedTypeSymbol? symbol)
{
return new ModelTypeMapping(_modelAttribute, _schemaMemberNameAttribute, _serializationAttribute, _serializationHooksAttribute, symbol);
if (symbol == null)
return null;

return new ModelTypeMapping(_codeGenAttributes, symbol);
}

internal IMethodSymbol? FindMethod(string namespaceName, string typeName, string methodName, IEnumerable<CSharpType> parameters)
Expand All @@ -87,7 +80,7 @@ internal bool TryGetClientSourceInput(INamedTypeSymbol type, [NotNullWhen(true)]
var attributeType = attribute.AttributeClass;
while (attributeType != null)
{
if (SymbolEqualityComparer.Default.Equals(attributeType, _clientAttribute))
if (SymbolEqualityComparer.Default.Equals(attributeType, _codeGenAttributes.CodeGenClientAttribute))
{
INamedTypeSymbol? parentClientType = null;
foreach ((var argumentName, TypedConstant constant) in attribute.NamedArguments)
Expand Down Expand Up @@ -119,7 +112,7 @@ private bool TryGetName(ISymbol symbol, [NotNullWhen(true)] out string? name)
var type = attribute.AttributeClass;
while (type != null)
{
if (SymbolEqualityComparer.Default.Equals(type, _typeAttribute))
if (SymbolEqualityComparer.Default.Equals(type, _codeGenAttributes.CodeGenTypeAttribute))
{
if (attribute?.ConstructorArguments.Length > 0)
{
Expand Down