Skip to content

Commit

Permalink
Fully support generic Lists in value conversion (npgsql#1610)
Browse files Browse the repository at this point in the history
  • Loading branch information
roji authored Dec 17, 2020
1 parent 8c5e77d commit 463b49e
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 54 deletions.
12 changes: 11 additions & 1 deletion src/EFCore.PG/Extensions/TypeExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;

#nullable enable

Expand All @@ -24,5 +24,15 @@ internal static bool TryGetElementType(this Type type, [NotNullWhen(true)] out T
: null;
return elementType != null;
}

public static PropertyInfo? FindIndexerProperty([NotNull] this Type type)
{
var defaultPropertyAttribute = type.GetCustomAttributes<DefaultMemberAttribute>().FirstOrDefault();

return defaultPropertyAttribute is null
? null
: type.GetRuntimeProperties()
.FirstOrDefault(pi => pi.Name == defaultPropertyAttribute.MemberName && pi.GetIndexParameters().Length == 1);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ arrayOrList.TypeMapping is NpgsqlArrayTypeMapping
{
// When the array is a column, we translate to array @> ARRAY[item]. GIN indexes
// on array are used, but null semantics is impossible without preventing index use.
case ColumnExpression _:
case ColumnExpression:
if (item is SqlConstantExpression constant && constant.Value is null)
{
// We special-case null constant item and use array_position instead, since it does
Expand All @@ -150,7 +150,7 @@ arrayOrList.TypeMapping is NpgsqlArrayTypeMapping
// for that case: item IN (1, 2, 3).
// After https://github.com/aspnet/EntityFrameworkCore/issues/16375 is done we may not need the
// check any more.
case SqlConstantExpression _:
case SqlConstantExpression:
return null;

// For ParameterExpression, and for all other cases - e.g. array returned from some function -
Expand Down
17 changes: 5 additions & 12 deletions src/EFCore.PG/Query/NpgsqlSqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -347,22 +347,17 @@ SqlExpression ApplyTypeMappingOnRegexMatch(PostgresRegexMatchExpression postgres

SqlExpression ApplyTypeMappingOnAny(PostgresAnyExpression postgresAnyExpression)
{
var (item, array) = ApplyTypeMappingsOnItemAndArray(
postgresAnyExpression.Item,
postgresAnyExpression.Array);
var (item, array) = ApplyTypeMappingsOnItemAndArray(postgresAnyExpression.Item, postgresAnyExpression.Array);
return new PostgresAnyExpression(item, array, postgresAnyExpression.OperatorType, _boolTypeMapping);
}

SqlExpression ApplyTypeMappingOnAll(PostgresAllExpression postgresAllExpression)
{
var (item, array) = ApplyTypeMappingsOnItemAndArray(
postgresAllExpression.Item,
postgresAllExpression.Array);
var (item, array) = ApplyTypeMappingsOnItemAndArray(postgresAllExpression.Item, postgresAllExpression.Array);
return new PostgresAllExpression(item, array, postgresAllExpression.OperatorType, _boolTypeMapping);
}

(SqlExpression, SqlExpression) ApplyTypeMappingsOnItemAndArray(
SqlExpression itemExpression, SqlExpression arrayExpression)
(SqlExpression, SqlExpression) ApplyTypeMappingsOnItemAndArray(SqlExpression itemExpression, SqlExpression arrayExpression)
{
// Attempt type inference either from the operand to the array or the other way around
var arrayMapping = (NpgsqlArrayTypeMapping)arrayExpression.TypeMapping;
Expand All @@ -373,12 +368,10 @@ SqlExpression ApplyTypeMappingOnAll(PostgresAllExpression postgresAllExpression)

// Note that we provide both the array CLR type *and* an array store type constructed from the element's
// store type. If we use only the array CLR type, byte[] will yield bytea which we don't want.
arrayMapping ??= (NpgsqlArrayTypeMapping)_typeMappingSource.FindMapping(
arrayExpression.Type, itemMapping.StoreType + "[]");
arrayMapping ??= (NpgsqlArrayTypeMapping)_typeMappingSource.FindMapping(arrayExpression.Type, itemMapping.StoreType + "[]");

if (itemMapping == null || arrayMapping == null)
throw new InvalidOperationException(
"Couldn't find array or element type mapping in ArrayAnyAllExpression");
throw new InvalidOperationException("Couldn't find array or element type mapping in ArrayAnyAllExpression");

return (
ApplyTypeMapping(itemExpression, itemMapping),
Expand Down
51 changes: 37 additions & 14 deletions src/EFCore.PG/Storage/ValueConversion/NpgsqlArrayConverter.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq.Expressions;
using System.Reflection;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Storage.ValueConversion;

Expand All @@ -15,13 +17,16 @@ public class NpgsqlArrayConverter<TModelArray, TProviderArray> : ValueConverter<
public NpgsqlArrayConverter([NotNull] ValueConverter elementConverter)
: base(ToProviderExpression(elementConverter), FromProviderExpression(elementConverter))
{
// TODO: List support
if (!typeof(TModelArray).IsArray || !typeof(TProviderArray).IsArray)
if (!typeof(TModelArray).TryGetElementType(out var modelElementType) ||
!typeof(TProviderArray).TryGetElementType(out var providerElementType))
{
throw new ArgumentException("Can only convert between arrays");
if (typeof(TModelArray).GetElementType() != elementConverter.ModelClrType)
throw new ArgumentException($"The element's value converter model type ({elementConverter.ModelClrType}), doesn't match the array's ({typeof(TModelArray).GetElementType()})");
if (typeof(TProviderArray).GetElementType() != elementConverter.ProviderClrType)
throw new ArgumentException($"The element's value converter provider type ({elementConverter.ProviderClrType}), doesn't match the array's ({typeof(TProviderArray).GetElementType()})");
}

if (modelElementType != elementConverter.ModelClrType)
throw new ArgumentException($"The element's value converter model type ({elementConverter.ModelClrType}), doesn't match the array's ({modelElementType})");
if (providerElementType != elementConverter.ProviderClrType)
throw new ArgumentException($"The element's value converter provider type ({elementConverter.ProviderClrType}), doesn't match the array's ({providerElementType})");
}

static Expression<Func<TModelArray, TProviderArray>> ToProviderExpression(ValueConverter elementConverter)
Expand All @@ -36,8 +41,11 @@ static Expression<Func<TProviderArray, TModelArray>> FromProviderExpression(Valu
/// </summary>
static Expression<Func<TInput, TOutput>> ArrayConversionExpression<TInput, TOutput>(LambdaExpression elementConversionExpression)
{
var outputElementType = typeof(TOutput).GetElementType();
Debug.Assert(outputElementType != null);
Debug.Assert(typeof(TInput).IsArrayOrGenericList());
Debug.Assert(typeof(TOutput).IsArrayOrGenericList());

var result = typeof(TOutput).TryGetElementType(out var outputElementType);
Debug.Assert(result);

var inputArray = Expression.Parameter(typeof(TInput), "value");
var outputArray = Expression.Parameter(typeof(TOutput), "result");
Expand All @@ -53,22 +61,37 @@ static Expression<Func<TInput, TOutput>> ArrayConversionExpression<TInput, TOutp
typeof(TOutput),
new[] { outputArray, arrayLengthVariable, loopVariable },

// Get the length of the input array, allocate an output array and loop over the elements, converting them
Expression.Assign(arrayLengthVariable, Expression.ArrayLength(inputArray)),
Expression.Assign(outputArray, Expression.NewArrayBounds(outputElementType, arrayLengthVariable)),
// Get the length of the input array or list
Expression.Assign(arrayLengthVariable, typeof(TInput).IsArray
? Expression.ArrayLength(inputArray)
: Expression.Property(inputArray,
typeof(TInput).GetProperty(nameof(List<TModelArray>.Count))!)),

// Allocate an output array or list
Expression.Assign(outputArray, typeof(TOutput).IsArray
? Expression.NewArrayBounds(outputElementType, arrayLengthVariable)
: Expression.New(typeof(TOutput))),

// Loop over the elements, applying the element converter on them one by one
ForLoop(
loopVar: loopVariable,
initValue: Expression.Constant(0),
condition: Expression.LessThan(loopVariable, arrayLengthVariable), Expression.AddAssign(loopVariable, Expression.Constant(1)),
condition: Expression.LessThan(loopVariable, arrayLengthVariable),
increment: Expression.AddAssign(loopVariable, Expression.Constant(1)),
loopContent:
Expression.Assign(
Expression.ArrayAccess(outputArray, loopVariable),
AccessArrayOrList(outputArray, loopVariable),
Expression.Invoke(
elementConversionExpression,
Expression.ArrayAccess(inputArray, loopVariable)))),
AccessArrayOrList(inputArray, loopVariable)))),
outputArray
)),
inputArray);

static Expression AccessArrayOrList(Expression arrayOrList, Expression index)
=> arrayOrList.Type.IsArray
? Expression.ArrayAccess(arrayOrList, index)
: Expression.Property(arrayOrList, arrayOrList.Type.FindIndexerProperty()!, index);
}

static Expression ForLoop(ParameterExpression loopVar, Expression initValue, Expression condition, Expression increment, Expression loopContent)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Reflection;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Storage.ValueConversion;

Expand All @@ -17,36 +17,37 @@ public NpgsqlValueConverterSelector([NotNull] ValueConverterSelectorDependencies
: base(dependencies) {}

/// <inheritdoc />
public override IEnumerable<ValueConverterInfo> Select(Type modelArrayClrType, Type providerArrayClrType = null)
public override IEnumerable<ValueConverterInfo> Select(Type modelClrType, Type providerClrType = null)
{
var converters = default(IEnumerable<ValueConverterInfo>);
if (modelArrayClrType.IsArray && (providerArrayClrType == null || providerArrayClrType.IsArray))
{
var modelElementType = modelArrayClrType.GetElementType();
Debug.Assert(modelElementType != null);

var providerElementType = default(Type);
if (providerArrayClrType != null)
{
providerElementType = providerArrayClrType.GetElementType();
Debug.Assert(providerElementType != null);
}
var providerElementType = default(Type);

// For each ValueConverterInfo selected by the superclass for the element type, return a ValueConverterInfo for its array type
converters = base
if (modelClrType.TryGetElementType(out var modelElementType) &&
(providerClrType == null || providerClrType.TryGetElementType(out providerElementType)))
{
// For each ValueConverterInfo selected by the superclass for the element type,
// return a ValueConverterInfo for its array type
return base
.Select(modelElementType, providerElementType)
.Select(elementConverterInfo => _arrayConverters.GetOrAdd(
(elementConverterInfo.ModelClrType, elementConverterInfo.ProviderClrType),
.Select(elementConverterInfo => new
{
ModelArrayType = modelClrType,
ProviderArrayType = providerClrType ?? elementConverterInfo.ProviderClrType.MakeArrayType(),
ElementConverterInfo = elementConverterInfo
})
.Select(x => _arrayConverters.GetOrAdd(
(x.ModelArrayType, x.ProviderArrayType),
new ValueConverterInfo(
modelArrayClrType,
elementConverterInfo.ProviderClrType.MakeArrayType(),
ci => (ValueConverter)Activator.CreateInstance(
typeof(NpgsqlArrayConverter<,>).MakeGenericType(modelArrayClrType, elementConverterInfo.ProviderClrType.MakeArrayType()),
elementConverterInfo.Create()))));
x.ModelArrayType,
x.ProviderArrayType,
_ => (ValueConverter)Activator.CreateInstance(
typeof(NpgsqlArrayConverter<,>).MakeGenericType(
modelClrType,
x.ProviderArrayType),
x.ElementConverterInfo.Create()))))
.Concat(base.Select(modelClrType, providerClrType));
}

var baseConverters = base.Select(modelArrayClrType, providerArrayClrType);
return converters == null ? baseConverters : converters.Concat(baseConverters);
return base.Select(modelClrType, providerClrType);
}
}
}
48 changes: 48 additions & 0 deletions test/EFCore.PG.FunctionalTests/Query/ArrayQueryTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,46 @@ SELECT s.""Id""
LIMIT 2");
}

[Fact]
public void Array_param_Contains_value_converted_column()
{
using var ctx = CreateContext();
var list = new[] { Guid.Empty, Guid.NewGuid() };
var id = ctx.SomeEntities
.Where(e => list.Contains(e.ValueConvertedGuid))
.Select(e => e.Id)
.Single();

Assert.Equal(2, id);
AssertSql(
@"@__list_0='System.String[]' (DbType = Object)
SELECT s.""Id""
FROM ""SomeEntities"" AS s
WHERE s.""ValueConvertedGuid"" = ANY (@__list_0)
LIMIT 2");
}

[Fact]
public void List_param_Contains_value_converted_column()
{
using var ctx = CreateContext();
var list = new List<Guid> { Guid.Empty, Guid.NewGuid() };
var id = ctx.SomeEntities
.Where(e => list.Contains(e.ValueConvertedGuid))
.Select(e => e.Id)
.Single();

Assert.Equal(2, id);
AssertSql(
@"@__list_0='System.String[]' (DbType = Object)
SELECT s.""Id""
FROM ""SomeEntities"" AS s
WHERE s.""ValueConvertedGuid"" = ANY (@__list_0)
LIMIT 2");
}

[Fact]
public void Byte_array_parameter_contains_column()
{
Expand Down Expand Up @@ -888,6 +928,11 @@ public class ArrayArrayQueryContext : PoolableDbContext

public ArrayArrayQueryContext(DbContextOptions options) : base(options) {}

protected override void OnModelCreating(ModelBuilder modelBuilder)
=> modelBuilder.Entity<SomeArrayEntity>()
.Property(e => e.ValueConvertedGuid)
.HasColumnType("text");

public static void Seed(ArrayArrayQueryContext context)
{
context.SomeEntities.AddRange(
Expand All @@ -905,6 +950,7 @@ public static void Seed(ArrayArrayQueryContext context)
IntMatrix = new[,] { { 5, 6 }, { 7, 8 } },
NullableText = "foo",
NonNullableText = "foo",
ValueConvertedGuid = Guid.Parse("54b46885-a17c-49f0-a12e-08ae7d7da5ca"),
Byte = 10
},
new SomeArrayEntity
Expand All @@ -921,6 +967,7 @@ public static void Seed(ArrayArrayQueryContext context)
IntMatrix = new[,] { { 10, 11 }, { 12, 13 } },
NullableText = "bar",
NonNullableText = "bar",
ValueConvertedGuid = Guid.Empty,
Byte = 20
});
context.SaveChanges();
Expand All @@ -942,6 +989,7 @@ public class SomeArrayEntity
public string NullableText { get; set; }
[Required]
public string NonNullableText { get; set; }
public Guid ValueConvertedGuid { get; set; }
public byte Byte { get; set; }
}

Expand Down

0 comments on commit 463b49e

Please sign in to comment.