diff --git a/src/EntityFrameworkCore.Generator.Core/Extensions/GenerationExtensions.cs b/src/EntityFrameworkCore.Generator.Core/Extensions/GenerationExtensions.cs index ea58269..079b251 100644 --- a/src/EntityFrameworkCore.Generator.Core/Extensions/GenerationExtensions.cs +++ b/src/EntityFrameworkCore.Generator.Core/Extensions/GenerationExtensions.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; @@ -46,6 +46,12 @@ public static class GenerationExtensions "class_initialize" }; + private static readonly List<string> _defaultUsings = new List<string>() + { + "System.Collections.Generic", + "System" + }; + private static readonly Dictionary<string, string> _csharpTypeAlias = new(16) { {"System.Int16", "short"}, @@ -115,6 +121,19 @@ public static string ToType(this Type type, CodeLanguage language = CodeLanguage { ArgumentNullException.ThrowIfNull(type); + if (type.IsGenericType) + { + var genericType = type.GetGenericTypeDefinition().FullName! + .Split('`')[0]; // trim the `1 bit + + genericType = ToType(genericType, language); + + var elementType = ToType(type.GetGenericArguments()[0].FullName!, language); + return language == CodeLanguage.VisualBasic + ? $"{genericType}(Of {elementType})" + : $"{genericType}<{elementType}>"; + } + return ToType(type.FullName ?? type.Name, language); } @@ -128,34 +147,26 @@ public static string ToType(this string type, CodeLanguage language = CodeLangua if (language == CodeLanguage.CSharp && _csharpTypeAlias.TryGetValue(type, out var t)) return t; - // drop system from namespace - var parts = type.Split('.'); - if (parts.Length == 2 && parts[0] == "System") - return parts[1]; + // drop common namespaces + foreach (var defaultUsing in _defaultUsings) + if (type.StartsWith(defaultUsing)) + return type.Remove(0, defaultUsing.Length + 1); return type; } public static string? ToNullableType(this Type type, bool isNullable = false, CodeLanguage language = CodeLanguage.CSharp) { - return ToNullableType(type.FullName, isNullable, language); - } - - public static string? ToNullableType(this string? type, bool isNullable = false, CodeLanguage language = CodeLanguage.CSharp) - { - if (string.IsNullOrEmpty(type)) - return null; - - bool isValueType = type.IsValueType(); + bool isValueType = type.IsValueType; - type = type.ToType(language); + var typeString = type.ToType(language); if (!isValueType || !isNullable) - return type; + return typeString; return language == CodeLanguage.VisualBasic ? $"Nullable(Of {type})" - : type + "?"; + : typeString + "?"; } public static bool IsValueType(this string? type)