From 463b49e00bfed094d2fa9741b8d5fa300328707b Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Thu, 17 Dec 2020 16:45:53 +0200 Subject: [PATCH] Fully support generic Lists in value conversion (#1610) Fixes #1606 --- src/EFCore.PG/Extensions/TypeExtensions.cs | 12 ++++- .../Internal/NpgsqlArrayTranslator.cs | 4 +- .../Query/NpgsqlSqlExpressionFactory.cs | 17 ++----- .../ValueConversion/NpgsqlArrayConverter.cs | 51 ++++++++++++++----- .../NpgsqlValueConverterSelector.cs | 51 ++++++++++--------- .../Query/ArrayQueryTest.cs | 48 +++++++++++++++++ 6 files changed, 129 insertions(+), 54 deletions(-) diff --git a/src/EFCore.PG/Extensions/TypeExtensions.cs b/src/EFCore.PG/Extensions/TypeExtensions.cs index 510020065..14aa9143d 100644 --- a/src/EFCore.PG/Extensions/TypeExtensions.cs +++ b/src/EFCore.PG/Extensions/TypeExtensions.cs @@ -1,6 +1,6 @@ -using System; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; +using System.Linq; #nullable enable @@ -24,5 +24,15 @@ internal static bool TryGetElementType(this Type type, [NotNullWhen(true)] out T : null; return elementType != null; } + + public static PropertyInfo? FindIndexerProperty([NotNull] this Type type) + { + var defaultPropertyAttribute = type.GetCustomAttributes().FirstOrDefault(); + + return defaultPropertyAttribute is null + ? null + : type.GetRuntimeProperties() + .FirstOrDefault(pi => pi.Name == defaultPropertyAttribute.MemberName && pi.GetIndexParameters().Length == 1); + } } } diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlArrayTranslator.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlArrayTranslator.cs index bf9a32a81..d1cf7c571 100644 --- a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlArrayTranslator.cs +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlArrayTranslator.cs @@ -127,7 +127,7 @@ arrayOrList.TypeMapping is NpgsqlArrayTypeMapping { // When the array is a column, we translate to array @> ARRAY[item]. GIN indexes // on array are used, but null semantics is impossible without preventing index use. - case ColumnExpression _: + case ColumnExpression: if (item is SqlConstantExpression constant && constant.Value is null) { // We special-case null constant item and use array_position instead, since it does @@ -150,7 +150,7 @@ arrayOrList.TypeMapping is NpgsqlArrayTypeMapping // for that case: item IN (1, 2, 3). // After https://github.com/aspnet/EntityFrameworkCore/issues/16375 is done we may not need the // check any more. - case SqlConstantExpression _: + case SqlConstantExpression: return null; // For ParameterExpression, and for all other cases - e.g. array returned from some function - diff --git a/src/EFCore.PG/Query/NpgsqlSqlExpressionFactory.cs b/src/EFCore.PG/Query/NpgsqlSqlExpressionFactory.cs index dd11f4366..5a173190d 100644 --- a/src/EFCore.PG/Query/NpgsqlSqlExpressionFactory.cs +++ b/src/EFCore.PG/Query/NpgsqlSqlExpressionFactory.cs @@ -347,22 +347,17 @@ SqlExpression ApplyTypeMappingOnRegexMatch(PostgresRegexMatchExpression postgres SqlExpression ApplyTypeMappingOnAny(PostgresAnyExpression postgresAnyExpression) { - var (item, array) = ApplyTypeMappingsOnItemAndArray( - postgresAnyExpression.Item, - postgresAnyExpression.Array); + var (item, array) = ApplyTypeMappingsOnItemAndArray(postgresAnyExpression.Item, postgresAnyExpression.Array); return new PostgresAnyExpression(item, array, postgresAnyExpression.OperatorType, _boolTypeMapping); } SqlExpression ApplyTypeMappingOnAll(PostgresAllExpression postgresAllExpression) { - var (item, array) = ApplyTypeMappingsOnItemAndArray( - postgresAllExpression.Item, - postgresAllExpression.Array); + var (item, array) = ApplyTypeMappingsOnItemAndArray(postgresAllExpression.Item, postgresAllExpression.Array); return new PostgresAllExpression(item, array, postgresAllExpression.OperatorType, _boolTypeMapping); } - (SqlExpression, SqlExpression) ApplyTypeMappingsOnItemAndArray( - SqlExpression itemExpression, SqlExpression arrayExpression) + (SqlExpression, SqlExpression) ApplyTypeMappingsOnItemAndArray(SqlExpression itemExpression, SqlExpression arrayExpression) { // Attempt type inference either from the operand to the array or the other way around var arrayMapping = (NpgsqlArrayTypeMapping)arrayExpression.TypeMapping; @@ -373,12 +368,10 @@ SqlExpression ApplyTypeMappingOnAll(PostgresAllExpression postgresAllExpression) // Note that we provide both the array CLR type *and* an array store type constructed from the element's // store type. If we use only the array CLR type, byte[] will yield bytea which we don't want. - arrayMapping ??= (NpgsqlArrayTypeMapping)_typeMappingSource.FindMapping( - arrayExpression.Type, itemMapping.StoreType + "[]"); + arrayMapping ??= (NpgsqlArrayTypeMapping)_typeMappingSource.FindMapping(arrayExpression.Type, itemMapping.StoreType + "[]"); if (itemMapping == null || arrayMapping == null) - throw new InvalidOperationException( - "Couldn't find array or element type mapping in ArrayAnyAllExpression"); + throw new InvalidOperationException("Couldn't find array or element type mapping in ArrayAnyAllExpression"); return ( ApplyTypeMapping(itemExpression, itemMapping), diff --git a/src/EFCore.PG/Storage/ValueConversion/NpgsqlArrayConverter.cs b/src/EFCore.PG/Storage/ValueConversion/NpgsqlArrayConverter.cs index cdf89223c..4f5a8d7df 100644 --- a/src/EFCore.PG/Storage/ValueConversion/NpgsqlArrayConverter.cs +++ b/src/EFCore.PG/Storage/ValueConversion/NpgsqlArrayConverter.cs @@ -1,6 +1,8 @@ using System; +using System.Collections.Generic; using System.Diagnostics; using System.Linq.Expressions; +using System.Reflection; using JetBrains.Annotations; using Microsoft.EntityFrameworkCore.Storage.ValueConversion; @@ -15,13 +17,16 @@ public class NpgsqlArrayConverter : ValueConverter< public NpgsqlArrayConverter([NotNull] ValueConverter elementConverter) : base(ToProviderExpression(elementConverter), FromProviderExpression(elementConverter)) { - // TODO: List support - if (!typeof(TModelArray).IsArray || !typeof(TProviderArray).IsArray) + if (!typeof(TModelArray).TryGetElementType(out var modelElementType) || + !typeof(TProviderArray).TryGetElementType(out var providerElementType)) + { throw new ArgumentException("Can only convert between arrays"); - if (typeof(TModelArray).GetElementType() != elementConverter.ModelClrType) - throw new ArgumentException($"The element's value converter model type ({elementConverter.ModelClrType}), doesn't match the array's ({typeof(TModelArray).GetElementType()})"); - if (typeof(TProviderArray).GetElementType() != elementConverter.ProviderClrType) - throw new ArgumentException($"The element's value converter provider type ({elementConverter.ProviderClrType}), doesn't match the array's ({typeof(TProviderArray).GetElementType()})"); + } + + if (modelElementType != elementConverter.ModelClrType) + throw new ArgumentException($"The element's value converter model type ({elementConverter.ModelClrType}), doesn't match the array's ({modelElementType})"); + if (providerElementType != elementConverter.ProviderClrType) + throw new ArgumentException($"The element's value converter provider type ({elementConverter.ProviderClrType}), doesn't match the array's ({providerElementType})"); } static Expression> ToProviderExpression(ValueConverter elementConverter) @@ -36,8 +41,11 @@ static Expression> FromProviderExpression(Valu /// static Expression> ArrayConversionExpression(LambdaExpression elementConversionExpression) { - var outputElementType = typeof(TOutput).GetElementType(); - Debug.Assert(outputElementType != null); + Debug.Assert(typeof(TInput).IsArrayOrGenericList()); + Debug.Assert(typeof(TOutput).IsArrayOrGenericList()); + + var result = typeof(TOutput).TryGetElementType(out var outputElementType); + Debug.Assert(result); var inputArray = Expression.Parameter(typeof(TInput), "value"); var outputArray = Expression.Parameter(typeof(TOutput), "result"); @@ -53,22 +61,37 @@ static Expression> ArrayConversionExpression.Count))!)), + + // Allocate an output array or list + Expression.Assign(outputArray, typeof(TOutput).IsArray + ? Expression.NewArrayBounds(outputElementType, arrayLengthVariable) + : Expression.New(typeof(TOutput))), + + // Loop over the elements, applying the element converter on them one by one ForLoop( loopVar: loopVariable, initValue: Expression.Constant(0), - condition: Expression.LessThan(loopVariable, arrayLengthVariable), Expression.AddAssign(loopVariable, Expression.Constant(1)), + condition: Expression.LessThan(loopVariable, arrayLengthVariable), + increment: Expression.AddAssign(loopVariable, Expression.Constant(1)), loopContent: Expression.Assign( - Expression.ArrayAccess(outputArray, loopVariable), + AccessArrayOrList(outputArray, loopVariable), Expression.Invoke( elementConversionExpression, - Expression.ArrayAccess(inputArray, loopVariable)))), + AccessArrayOrList(inputArray, loopVariable)))), outputArray )), inputArray); + + static Expression AccessArrayOrList(Expression arrayOrList, Expression index) + => arrayOrList.Type.IsArray + ? Expression.ArrayAccess(arrayOrList, index) + : Expression.Property(arrayOrList, arrayOrList.Type.FindIndexerProperty()!, index); } static Expression ForLoop(ParameterExpression loopVar, Expression initValue, Expression condition, Expression increment, Expression loopContent) diff --git a/src/EFCore.PG/Storage/ValueConversion/NpgsqlValueConverterSelector.cs b/src/EFCore.PG/Storage/ValueConversion/NpgsqlValueConverterSelector.cs index 5df623333..a68d6e035 100644 --- a/src/EFCore.PG/Storage/ValueConversion/NpgsqlValueConverterSelector.cs +++ b/src/EFCore.PG/Storage/ValueConversion/NpgsqlValueConverterSelector.cs @@ -1,8 +1,8 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; -using System.Diagnostics; using System.Linq; +using System.Reflection; using JetBrains.Annotations; using Microsoft.EntityFrameworkCore.Storage.ValueConversion; @@ -17,36 +17,37 @@ public NpgsqlValueConverterSelector([NotNull] ValueConverterSelectorDependencies : base(dependencies) {} /// - public override IEnumerable Select(Type modelArrayClrType, Type providerArrayClrType = null) + public override IEnumerable Select(Type modelClrType, Type providerClrType = null) { - var converters = default(IEnumerable); - if (modelArrayClrType.IsArray && (providerArrayClrType == null || providerArrayClrType.IsArray)) - { - var modelElementType = modelArrayClrType.GetElementType(); - Debug.Assert(modelElementType != null); - - var providerElementType = default(Type); - if (providerArrayClrType != null) - { - providerElementType = providerArrayClrType.GetElementType(); - Debug.Assert(providerElementType != null); - } + var providerElementType = default(Type); - // For each ValueConverterInfo selected by the superclass for the element type, return a ValueConverterInfo for its array type - converters = base + if (modelClrType.TryGetElementType(out var modelElementType) && + (providerClrType == null || providerClrType.TryGetElementType(out providerElementType))) + { + // For each ValueConverterInfo selected by the superclass for the element type, + // return a ValueConverterInfo for its array type + return base .Select(modelElementType, providerElementType) - .Select(elementConverterInfo => _arrayConverters.GetOrAdd( - (elementConverterInfo.ModelClrType, elementConverterInfo.ProviderClrType), + .Select(elementConverterInfo => new + { + ModelArrayType = modelClrType, + ProviderArrayType = providerClrType ?? elementConverterInfo.ProviderClrType.MakeArrayType(), + ElementConverterInfo = elementConverterInfo + }) + .Select(x => _arrayConverters.GetOrAdd( + (x.ModelArrayType, x.ProviderArrayType), new ValueConverterInfo( - modelArrayClrType, - elementConverterInfo.ProviderClrType.MakeArrayType(), - ci => (ValueConverter)Activator.CreateInstance( - typeof(NpgsqlArrayConverter<,>).MakeGenericType(modelArrayClrType, elementConverterInfo.ProviderClrType.MakeArrayType()), - elementConverterInfo.Create())))); + x.ModelArrayType, + x.ProviderArrayType, + _ => (ValueConverter)Activator.CreateInstance( + typeof(NpgsqlArrayConverter<,>).MakeGenericType( + modelClrType, + x.ProviderArrayType), + x.ElementConverterInfo.Create())))) + .Concat(base.Select(modelClrType, providerClrType)); } - var baseConverters = base.Select(modelArrayClrType, providerArrayClrType); - return converters == null ? baseConverters : converters.Concat(baseConverters); + return base.Select(modelClrType, providerClrType); } } } diff --git a/test/EFCore.PG.FunctionalTests/Query/ArrayQueryTest.cs b/test/EFCore.PG.FunctionalTests/Query/ArrayQueryTest.cs index c67c22757..dbec7405d 100644 --- a/test/EFCore.PG.FunctionalTests/Query/ArrayQueryTest.cs +++ b/test/EFCore.PG.FunctionalTests/Query/ArrayQueryTest.cs @@ -531,6 +531,46 @@ SELECT s.""Id"" LIMIT 2"); } + [Fact] + public void Array_param_Contains_value_converted_column() + { + using var ctx = CreateContext(); + var list = new[] { Guid.Empty, Guid.NewGuid() }; + var id = ctx.SomeEntities + .Where(e => list.Contains(e.ValueConvertedGuid)) + .Select(e => e.Id) + .Single(); + + Assert.Equal(2, id); + AssertSql( + @"@__list_0='System.String[]' (DbType = Object) + +SELECT s.""Id"" +FROM ""SomeEntities"" AS s +WHERE s.""ValueConvertedGuid"" = ANY (@__list_0) +LIMIT 2"); + } + + [Fact] + public void List_param_Contains_value_converted_column() + { + using var ctx = CreateContext(); + var list = new List { Guid.Empty, Guid.NewGuid() }; + var id = ctx.SomeEntities + .Where(e => list.Contains(e.ValueConvertedGuid)) + .Select(e => e.Id) + .Single(); + + Assert.Equal(2, id); + AssertSql( + @"@__list_0='System.String[]' (DbType = Object) + +SELECT s.""Id"" +FROM ""SomeEntities"" AS s +WHERE s.""ValueConvertedGuid"" = ANY (@__list_0) +LIMIT 2"); + } + [Fact] public void Byte_array_parameter_contains_column() { @@ -888,6 +928,11 @@ public class ArrayArrayQueryContext : PoolableDbContext public ArrayArrayQueryContext(DbContextOptions options) : base(options) {} + protected override void OnModelCreating(ModelBuilder modelBuilder) + => modelBuilder.Entity() + .Property(e => e.ValueConvertedGuid) + .HasColumnType("text"); + public static void Seed(ArrayArrayQueryContext context) { context.SomeEntities.AddRange( @@ -905,6 +950,7 @@ public static void Seed(ArrayArrayQueryContext context) IntMatrix = new[,] { { 5, 6 }, { 7, 8 } }, NullableText = "foo", NonNullableText = "foo", + ValueConvertedGuid = Guid.Parse("54b46885-a17c-49f0-a12e-08ae7d7da5ca"), Byte = 10 }, new SomeArrayEntity @@ -921,6 +967,7 @@ public static void Seed(ArrayArrayQueryContext context) IntMatrix = new[,] { { 10, 11 }, { 12, 13 } }, NullableText = "bar", NonNullableText = "bar", + ValueConvertedGuid = Guid.Empty, Byte = 20 }); context.SaveChanges(); @@ -942,6 +989,7 @@ public class SomeArrayEntity public string NullableText { get; set; } [Required] public string NonNullableText { get; set; } + public Guid ValueConvertedGuid { get; set; } public byte Byte { get; set; } }