diff --git a/src/AutoRest.CSharp/Common/Generation/Types/CSharpType.cs b/src/AutoRest.CSharp/Common/Generation/Types/CSharpType.cs index e74879bfa87..0dfbcd1e0f3 100644 --- a/src/AutoRest.CSharp/Common/Generation/Types/CSharpType.cs +++ b/src/AutoRest.CSharp/Common/Generation/Types/CSharpType.cs @@ -163,18 +163,6 @@ internal static CSharpType FromSystemType(Type type, string defaultNamespace, So internal static CSharpType FromSystemType(BuildContext context, Type type) => FromSystemType(type, context.DefaultNamespace, context.SourceInputModel); - public bool IsCollectionType() - { - if (!IsFrameworkType) - return false; - - return FrameworkType.Equals(typeof(IList<>)) || - FrameworkType.Equals(typeof(IEnumerable<>)) || - FrameworkType == typeof(IReadOnlyList<>) || - FrameworkType.Equals(typeof(IDictionary<,>)) || - FrameworkType == typeof(IReadOnlyDictionary<,>); - } - public CSharpType GetNonNullable() { if (!IsNullable) diff --git a/src/AutoRest.CSharp/Common/Input/Source/CodeGenAttributes.cs b/src/AutoRest.CSharp/Common/Input/Source/CodeGenAttributes.cs new file mode 100644 index 00000000000..f7d68d4d541 --- /dev/null +++ b/src/AutoRest.CSharp/Common/Input/Source/CodeGenAttributes.cs @@ -0,0 +1,121 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Immutable; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using Azure.Core; +using Microsoft.CodeAnalysis; + +namespace AutoRest.CSharp.Input.Source +{ + public class CodeGenAttributes + { + public CodeGenAttributes(Compilation compilation) + { + CodeGenSuppressAttribute = GetSymbol(compilation, typeof(CodeGenSuppressAttribute)); + CodeGenMemberAttribute = GetSymbol(compilation, typeof(CodeGenMemberAttribute)); + CodeGenTypeAttribute = GetSymbol(compilation, typeof(CodeGenTypeAttribute)); + CodeGenModelAttribute = GetSymbol(compilation, typeof(CodeGenModelAttribute)); + CodeGenClientAttribute = GetSymbol(compilation, typeof(CodeGenClientAttribute)); + CodeGenMemberSerializationAttribute = GetSymbol(compilation, typeof(CodeGenMemberSerializationAttribute)); + CodeGenMemberSerializationHooksAttribute = GetSymbol(compilation, typeof(CodeGenMemberSerializationHooksAttribute)); + } + + public INamedTypeSymbol CodeGenSuppressAttribute { get; } + + public INamedTypeSymbol CodeGenMemberAttribute { get; } + + public INamedTypeSymbol CodeGenTypeAttribute { get; } + + public INamedTypeSymbol CodeGenModelAttribute { get; } + + public INamedTypeSymbol CodeGenClientAttribute { get; } + + public INamedTypeSymbol CodeGenMemberSerializationAttribute { get; } + + public INamedTypeSymbol CodeGenMemberSerializationHooksAttribute { get; } + + private static INamedTypeSymbol GetSymbol(Compilation compilation, Type type) => compilation.GetTypeByMetadataName(type.FullName!) ?? throw new InvalidOperationException($"cannot load symbol of attribute {type}"); + + private static bool CheckAttribute(AttributeData attributeData, INamedTypeSymbol codeGenAttribute) + => SymbolEqualityComparer.Default.Equals(attributeData.AttributeClass, codeGenAttribute); + + public bool TryGetCodeGenMemberAttributeValue(AttributeData attributeData, [MaybeNullWhen(false)] out string name) + { + name = null; + if (!CheckAttribute(attributeData, CodeGenMemberAttribute)) + return false; + + name = attributeData.ConstructorArguments.FirstOrDefault().Value as string; + return name != null; + } + + public bool TryGetCodeGenMemberSerializationAttributeValue(AttributeData attributeData, [MaybeNullWhen(false)] out string[] propertyNames) + { + propertyNames = null; + if (!CheckAttribute(attributeData, CodeGenMemberSerializationAttribute)) + return false; + + if (attributeData.ConstructorArguments.Length > 0) + { + propertyNames = ToStringArray(attributeData.ConstructorArguments[0].Values); + } + + return propertyNames != null; + } + + public bool TryGetCodeGenMemberSerializationHooksAttributeValue(AttributeData attributeData, out (string? SerializationHook, string? DeserializationHook) hooks) + { + hooks = default; + if (!CheckAttribute(attributeData, CodeGenMemberSerializationHooksAttribute)) + return false; + + string? serializationHook = null; + string? deserializationHook = null; + + var arguments = attributeData.ConstructorArguments; + serializationHook = arguments[0].Value as string; + deserializationHook = arguments[1].Value as string; + + hooks = (serializationHook, deserializationHook); + return serializationHook != null || deserializationHook != null; + } + + public bool TryGetCodeGenModelAttributeValue(AttributeData attributeData, out string[]? usage, out string[]? formats) + { + usage = null; + formats = null; + if (!CheckAttribute(attributeData, CodeGenModelAttribute)) + return false; + foreach (var namedArgument in attributeData.NamedArguments) + { + switch (namedArgument.Key) + { + case nameof(Azure.Core.CodeGenModelAttribute.Usage): + usage = ToStringArray(namedArgument.Value.Values); + break; + case nameof(Azure.Core.CodeGenModelAttribute.Formats): + formats = ToStringArray(namedArgument.Value.Values); + break; + } + } + + return usage != null || formats != null; + } + + private static string[]? ToStringArray(ImmutableArray values) + { + if (values.IsDefaultOrEmpty) + { + return null; + } + + return values + .Select(v => (string?)v.Value) + .OfType() + .ToArray(); + } + } +} diff --git a/src/AutoRest.CSharp/Common/Input/Source/ModelTypeMapping.cs b/src/AutoRest.CSharp/Common/Input/Source/ModelTypeMapping.cs index 29693fe10f0..1ef950415b7 100644 --- a/src/AutoRest.CSharp/Common/Input/Source/ModelTypeMapping.cs +++ b/src/AutoRest.CSharp/Common/Input/Source/ModelTypeMapping.cs @@ -20,7 +20,7 @@ public class ModelTypeMapping public string[]? Usage { get; } public string[]? Formats { get; } - public ModelTypeMapping(INamedTypeSymbol modelAttribute, INamedTypeSymbol memberAttribute, INamedTypeSymbol serializationAttribute, INamedTypeSymbol serializationHooksAttribute, INamedTypeSymbol? existingType) + public ModelTypeMapping(CodeGenAttributes codeGenAttributes, INamedTypeSymbol existingType) { _existingType = existingType; _propertyMappings = new(); @@ -28,95 +28,41 @@ public ModelTypeMapping(INamedTypeSymbol modelAttribute, INamedTypeSymbol member foreach (ISymbol member in GetMembers(existingType)) { + string[]? serializationPath = null; + (string? SerializationHook, string? DeserializationHook)? serializationHooks = null; foreach (var attributeData in member.GetAttributes()) { - var attributeTypeSymbol = attributeData.AttributeClass; // handle CodeGenMember attribute - if (SymbolEqualityComparer.Default.Equals(attributeTypeSymbol, memberAttribute) && TryGetCodeGenMemberAttributeValue(member, attributeData, out var schemaMemberName)) + if (codeGenAttributes.TryGetCodeGenMemberAttributeValue(attributeData, out var schemaMemberName)) { _propertyMappings.Add(schemaMemberName, member); } - string[]? serializationPath = null; - (string? SerializationHook, string? DeserializationHook)? serializationHooks = null; - if (SymbolEqualityComparer.Default.Equals(attributeTypeSymbol, serializationAttribute) && TryGetSerializationAttributeValue(member, attributeData, out var pathResult)) + // handle CodeGenMemberSerialization attribute + if (codeGenAttributes.TryGetCodeGenMemberSerializationAttributeValue(attributeData, out var pathResult)) { serializationPath = pathResult; } - if (SymbolEqualityComparer.Default.Equals(attributeTypeSymbol, serializationHooksAttribute) && TryGetSerializationHooks(member, attributeData, out var hooks)) + // handle CodeGenMemberSerializationHooks attribute + if (codeGenAttributes.TryGetCodeGenMemberSerializationHooksAttributeValue(attributeData, out var hooks)) { serializationHooks = hooks; } - if (serializationPath != null || serializationHooks != null) - { - _serializationMappings.Add(member, new SourcePropertySerializationMapping(member, serializationPath, serializationHooks?.SerializationHook, serializationHooks?.DeserializationHook)); - } } - } - - if (existingType != null) - { - foreach (var attributeData in existingType.GetAttributes()) + if (serializationPath != null || serializationHooks != null) { - if (SymbolEqualityComparer.Default.Equals(attributeData.AttributeClass, modelAttribute)) - { - foreach (var namedArgument in attributeData.NamedArguments) - { - switch (namedArgument.Key) - { - case nameof(CodeGenModelAttribute.Usage): - Usage = ToStringArray(namedArgument.Value.Values); - break; - case nameof(CodeGenModelAttribute.Formats): - Formats = ToStringArray(namedArgument.Value.Values); - break; - } - } - } + _serializationMappings.Add(member, new SourcePropertySerializationMapping(member, serializationPath, serializationHooks?.SerializationHook, serializationHooks?.DeserializationHook)); } } - } - - private static bool TryGetSerializationHooks(ISymbol symbol, AttributeData attributeData, out (string? SerializationHook, string? DeserializationHook) hooks) - { - string? serializationHook = null; - string? deserializationHook = null; - - var arguments = attributeData.ConstructorArguments; - serializationHook = arguments[0].Value as string; - deserializationHook = arguments[1].Value as string; - hooks = (serializationHook, deserializationHook); - return serializationHook != null || deserializationHook != null; - } - - private static bool TryGetCodeGenMemberAttributeValue(ISymbol symbol, AttributeData attributeData, [MaybeNullWhen(false)] out string name) - { - name = attributeData.ConstructorArguments.FirstOrDefault().Value as string; - return name != null; - } - - private static bool TryGetSerializationAttributeValue(ISymbol symbol, AttributeData attributeData, [MaybeNullWhen(false)] out string[] propertyNames) - { - propertyNames = null; - if (attributeData.ConstructorArguments.Length > 0) + foreach (var attributeData in existingType.GetAttributes()) { - propertyNames = ToStringArray(attributeData.ConstructorArguments[0].Values); - } - - return propertyNames != null; - } - - private static string[]? ToStringArray(ImmutableArray values) - { - if (values.IsDefaultOrEmpty) - { - return null; + // handle CodeGenModel attribute + if (codeGenAttributes.TryGetCodeGenModelAttributeValue(attributeData, out var usage, out var formats)) + { + Usage = usage; + Formats = formats; + } } - - return values - .Select(v => (string?)v.Value) - .OfType() - .ToArray(); } public SourceMemberMapping? GetForMember(string name) diff --git a/src/AutoRest.CSharp/Common/Input/Source/SourceInputModel.cs b/src/AutoRest.CSharp/Common/Input/Source/SourceInputModel.cs index 109ff97b66b..8e8c84e472f 100644 --- a/src/AutoRest.CSharp/Common/Input/Source/SourceInputModel.cs +++ b/src/AutoRest.CSharp/Common/Input/Source/SourceInputModel.cs @@ -15,12 +15,7 @@ public class SourceInputModel { private readonly Compilation _compilation; private readonly CompilationInput? _existingCompilation; - private readonly INamedTypeSymbol _typeAttribute; - private readonly INamedTypeSymbol _modelAttribute; - private readonly INamedTypeSymbol _clientAttribute; - private readonly INamedTypeSymbol _schemaMemberNameAttribute; - private readonly INamedTypeSymbol _serializationAttribute; - private readonly INamedTypeSymbol _serializationHooksAttribute; + private readonly CodeGenAttributes _codeGenAttributes; private readonly Dictionary _nameMap = new Dictionary(StringComparer.OrdinalIgnoreCase); public SourceInputModel(Compilation compilation, CompilationInput? existingCompilation = null) @@ -28,12 +23,7 @@ public SourceInputModel(Compilation compilation, CompilationInput? existingCompi _compilation = compilation; _existingCompilation = existingCompilation; - _schemaMemberNameAttribute = compilation.GetTypeByMetadataName(typeof(CodeGenMemberAttribute).FullName!)!; - _serializationAttribute = compilation.GetTypeByMetadataName(typeof(CodeGenMemberSerializationAttribute).FullName!)!; - _serializationHooksAttribute = compilation.GetTypeByMetadataName(typeof(CodeGenMemberSerializationHooksAttribute).FullName!)!; - _typeAttribute = compilation.GetTypeByMetadataName(typeof(CodeGenTypeAttribute).FullName!)!; - _modelAttribute = compilation.GetTypeByMetadataName(typeof(CodeGenModelAttribute).FullName!)!; - _clientAttribute = compilation.GetTypeByMetadataName(typeof(CodeGenClientAttribute).FullName!)!; + _codeGenAttributes = new CodeGenAttributes(compilation); IAssemblySymbol assembly = _compilation.Assembly; @@ -58,9 +48,12 @@ public SourceInputModel(Compilation compilation, CompilationInput? existingCompi return osvAttribute?.ConstructorArguments[0].Values.Select(v => v.Value).OfType().ToList(); } - public ModelTypeMapping CreateForModel(INamedTypeSymbol? symbol) + public ModelTypeMapping? CreateForModel(INamedTypeSymbol? symbol) { - return new ModelTypeMapping(_modelAttribute, _schemaMemberNameAttribute, _serializationAttribute, _serializationHooksAttribute, symbol); + if (symbol == null) + return null; + + return new ModelTypeMapping(_codeGenAttributes, symbol); } internal IMethodSymbol? FindMethod(string namespaceName, string typeName, string methodName, IEnumerable parameters) @@ -87,7 +80,7 @@ internal bool TryGetClientSourceInput(INamedTypeSymbol type, [NotNullWhen(true)] var attributeType = attribute.AttributeClass; while (attributeType != null) { - if (SymbolEqualityComparer.Default.Equals(attributeType, _clientAttribute)) + if (SymbolEqualityComparer.Default.Equals(attributeType, _codeGenAttributes.CodeGenClientAttribute)) { INamedTypeSymbol? parentClientType = null; foreach ((var argumentName, TypedConstant constant) in attribute.NamedArguments) @@ -119,7 +112,7 @@ private bool TryGetName(ISymbol symbol, [NotNullWhen(true)] out string? name) var type = attribute.AttributeClass; while (type != null) { - if (SymbolEqualityComparer.Default.Equals(type, _typeAttribute)) + if (SymbolEqualityComparer.Default.Equals(type, _codeGenAttributes.CodeGenTypeAttribute)) { if (attribute?.ConstructorArguments.Length > 0) {