Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Array translation improvements #2028

Merged
merged 1 commit into from
Oct 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions EFCore.PG.sln.DotSettings
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@
<s:Boolean x:Key="/Default/UserDictionary/Words/=datetimeoffset/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=doesnt/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=fallbacks/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=ilike/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=initializers/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=keyless/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=materializer/@EntryIndexedValue">True</s:Boolean>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,57 @@ namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Inte
/// </remarks>
public class NpgsqlArrayTranslator : IMethodCallTranslator, IMemberTranslator
{
private static readonly MethodInfo SequenceEqual =
#region Methods

private static readonly MethodInfo Array_IndexOf1 =
Copy link

@hez2010 hez2010 Oct 8, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I recommend using below code to get definite method overload.

typeof(Array).GetMethod(nameof(Array.IndexOf), 1, BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly, null, CallingConventions.Any, new[] { Type.MakeGenericMethodParameter(0).MakeArrayType(), Type.MakeGenericMethodParameter(0) }, null)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @hez2010 you're right. This provider (as well as EF Core itself) has lots of places which aren't friendly to AOT/trimming, so I'll be doing a concentrated pass for this and other problematic practices. I've opened #2031 to track this specific suggestion.

typeof(Array).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Single(m => m.Name == nameof(Array.IndexOf) && m.IsGenericMethod && m.GetParameters().Length == 2);

private static readonly MethodInfo Array_IndexOf2 =
typeof(Array).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Single(m => m.Name == nameof(Array.IndexOf) && m.IsGenericMethod && m.GetParameters().Length == 3);

private static readonly MethodInfo Enumerable_Append =
typeof(Enumerable).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Single(m => m.Name == nameof(Enumerable.SequenceEqual) && m.GetParameters().Length == 2);
.Single(m => m.Name == nameof(Enumerable.Append) && m.GetParameters().Length == 2);

private static readonly MethodInfo Enumerable_AnyWithoutPredicate =
typeof(Enumerable).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Single(mi => mi.Name == nameof(Enumerable.Any) && mi.GetParameters().Length == 1);

private static readonly MethodInfo EnumerableContains =
private static readonly MethodInfo Enumerable_Concat =
typeof(Enumerable).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Single(m => m.Name == nameof(Enumerable.Concat) && m.GetParameters().Length == 2);

private static readonly MethodInfo Enumerable_Contains =
typeof(Enumerable).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Single(m => m.Name == nameof(Enumerable.Contains) && m.GetParameters().Length == 2);

private static readonly MethodInfo EnumerableAnyWithoutPredicate =
private static readonly MethodInfo Enumerable_SequenceEqual =
typeof(Enumerable).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Single(mi => mi.Name == nameof(Enumerable.Any) && mi.GetParameters().Length == 1);
.Single(m => m.Name == nameof(Enumerable.SequenceEqual) && m.GetParameters().Length == 2);

private static readonly MethodInfo String_Join1 =
typeof(string).GetMethod(nameof(string.Join), new[] { typeof(string), typeof(object[]) })!;

private static readonly MethodInfo String_Join2 =
typeof(string).GetMethod(nameof(string.Join), new[] { typeof(string), typeof(string[]) })!;

private static readonly MethodInfo String_Join3 =
typeof(string).GetMethod(nameof(string.Join), new[] { typeof(char), typeof(object[]) })!;

private static readonly MethodInfo String_Join4 =
typeof(string).GetMethod(nameof(string.Join), new[] { typeof(char), typeof(string[]) })!;

private static readonly MethodInfo String_Join_generic1 =
typeof(string).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Single(m => m.Name == nameof(string.Join) && m.IsGenericMethod && m.GetParameters().Length == 2 && m.GetParameters()[0].ParameterType == typeof(string));

private static readonly MethodInfo String_Join_generic2 =
typeof(string).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Single(m => m.Name == nameof(string.Join) && m.IsGenericMethod && m.GetParameters().Length == 2 && m.GetParameters()[0].ParameterType == typeof(char));

#endregion Methods

private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory;
private readonly NpgsqlJsonPocoTranslator _jsonPocoTranslator;
Expand Down Expand Up @@ -72,14 +112,31 @@ public NpgsqlArrayTranslator(
if (instance is null && arguments.Count > 0 && arguments[0].Type.IsArrayOrGenericList() && !IsMappedToNonArray(arguments[0]))
{
// Extension method over an array or list
if (method.IsClosedFormOf(SequenceEqual) && arguments[1].Type.IsArray)
if (method.IsClosedFormOf(Enumerable_SequenceEqual) && arguments[1].Type.IsArray)
{
return _sqlExpressionFactory.Equal(arguments[0], arguments[1]);
}

return TranslateCommon(arguments[0], arguments.Slice(1));
}

if (method.DeclaringType == typeof(string)
&& (method == String_Join1
|| method == String_Join2
|| method == String_Join3
|| method == String_Join4
|| method.IsClosedFormOf(String_Join_generic1)
|| method.IsClosedFormOf(String_Join_generic2))
&& !IsMappedToNonArray(arguments[0]))
{
return _sqlExpressionFactory.Function(
"array_to_string",
new[] { arguments[1], arguments[0], _sqlExpressionFactory.Constant("") },
nullable: true,
argumentsPropagateNullability: TrueArrays[3],
typeof(string));
}

// Not an array/list
return null;

Expand All @@ -92,7 +149,7 @@ static bool IsMappedToNonArray(SqlExpression arrayOrList)
SqlExpression? TranslateCommon(SqlExpression arrayOrList, IReadOnlyList<SqlExpression> arguments)
{
// Predicate-less Any - translate to a simple length check.
if (method.IsClosedFormOf(EnumerableAnyWithoutPredicate))
if (method.IsClosedFormOf(Enumerable_AnyWithoutPredicate))
{
return _sqlExpressionFactory.GreaterThan(
_jsonPocoTranslator.TranslateArrayLength(arrayOrList)
Expand All @@ -109,7 +166,7 @@ static bool IsMappedToNonArray(SqlExpression arrayOrList)
// is pattern-matched in AllAnyToContainsRewritingExpressionVisitor, which transforms it to
// new[] { "a", "b", "c" }.Contains(e.Some Text).

if ((method.IsClosedFormOf(EnumerableContains)
if ((method.IsClosedFormOf(Enumerable_Contains)
||
method.Name == nameof(List<int>.Contains)
&& method.DeclaringType.IsGenericList()
Expand Down Expand Up @@ -176,6 +233,77 @@ arrayOrList.TypeMapping is NpgsqlArrayTypeMapping or null
// Note: we also translate .Where(e => new[] { "a", "b", "c" }.Any(p => EF.Functions.Like(e.SomeText, p)))
// to LIKE ANY (...). See NpgsqlSqlTranslatingExpressionVisitor.VisitArrayMethodCall.

if (method.IsClosedFormOf(Enumerable_Append))
{
var (item, array) = _sqlExpressionFactory.ApplyTypeMappingsOnItemAndArray(arguments[0], arrayOrList);

return _sqlExpressionFactory.Function(
"array_append",
new[] { array, item },
nullable: true,
TrueArrays[2],
arrayOrList.Type,
arrayOrList.TypeMapping);
}

if (method.IsClosedFormOf(Enumerable_Concat))
{
var inferredMapping = ExpressionExtensions.InferTypeMapping(arrayOrList, arguments[0]);

return _sqlExpressionFactory.Function(
"array_cat",
new[]
{
_sqlExpressionFactory.ApplyTypeMapping(arrayOrList, inferredMapping),
_sqlExpressionFactory.ApplyTypeMapping(arguments[0], inferredMapping)
},
nullable: true,
TrueArrays[2],
arrayOrList.Type,
inferredMapping);
}

if (method.IsClosedFormOf(Array_IndexOf1)
||
method.Name == nameof(List<int>.IndexOf)
&& method.DeclaringType.IsGenericList()
&& method.GetParameters().Length == 1)
{
var (item, array) = _sqlExpressionFactory.ApplyTypeMappingsOnItemAndArray(arguments[0], arrayOrList);

return _sqlExpressionFactory.Coalesce(
_sqlExpressionFactory.Subtract(
_sqlExpressionFactory.Function(
"array_position",
new[] { array, item },
nullable: true,
TrueArrays[2],
arrayOrList.Type),
_sqlExpressionFactory.Constant(1)),
_sqlExpressionFactory.Constant(-1));
}

if (method.IsClosedFormOf(Array_IndexOf2)
||
method.Name == nameof(List<int>.IndexOf)
&& method.DeclaringType.IsGenericList()
&& method.GetParameters().Length == 2)
{
var (item, array) = _sqlExpressionFactory.ApplyTypeMappingsOnItemAndArray(arguments[0], arrayOrList);
var startIndex = _sqlExpressionFactory.GenerateOneBasedIndexExpression(arguments[1]);

return _sqlExpressionFactory.Coalesce(
_sqlExpressionFactory.Subtract(
_sqlExpressionFactory.Function(
"array_position",
new[] { array, item, startIndex },
nullable: true,
TrueArrays[3],
arrayOrList.Type),
_sqlExpressionFactory.Constant(1)),
_sqlExpressionFactory.Constant(-1));
}

return null;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
// Try translating ArrayIndex inside json column
_jsonPocoTranslator.TranslateMemberAccess(sqlLeft!, sqlRight!, binaryExpression.Type) ??
// Other types should be subscriptable - but PostgreSQL arrays are 1-based, so adjust the index.
_sqlExpressionFactory.ArrayIndex(sqlLeft!, GenerateOneBasedIndexExpression(sqlRight!));
_sqlExpressionFactory.ArrayIndex(sqlLeft!, _sqlExpressionFactory.GenerateOneBasedIndexExpression(sqlRight!));
}

return base.VisitBinary(binaryExpression);
Expand Down Expand Up @@ -509,15 +509,6 @@ bool TryTranslateArguments(out SqlExpression[] sqlArguments)
}
}

/// <summary>
/// PostgreSQL array indexing is 1-based. If the index happens to be a constant,
/// just increment it. Otherwise, append a +1 in the SQL.
/// </summary>
private SqlExpression GenerateOneBasedIndexExpression(SqlExpression expression)
=> expression is SqlConstantExpression constant
? _sqlExpressionFactory.Constant(Convert.ToInt32(constant.Value) + 1, constant.TypeMapping)
: _sqlExpressionFactory.Add(expression, _sqlExpressionFactory.Constant(1));

#region Copied from RelationalSqlTranslatingExpressionVisitor

private static Expression TryRemoveImplicitConvert(Expression expression)
Expand Down
26 changes: 22 additions & 4 deletions src/EFCore.PG/Query/NpgsqlSqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,9 @@ private SqlExpression ApplyTypeMappingOnAll(PostgresAllExpression postgresAllExp
return new PostgresAllExpression(item, array, postgresAllExpression.OperatorType, _boolTypeMapping);
}

private (SqlExpression, SqlExpression) ApplyTypeMappingsOnItemAndArray(SqlExpression itemExpression, SqlExpression arrayExpression)
public virtual (SqlExpression, SqlExpression) ApplyTypeMappingsOnItemAndArray(
SqlExpression itemExpression,
SqlExpression arrayExpression)
{
// Attempt type inference either from the operand to the array or the other way around
var arrayMapping = (NpgsqlArrayTypeMapping?)arrayExpression.TypeMapping;
Expand Down Expand Up @@ -464,16 +466,23 @@ private SqlExpression ApplyTypeMappingOnAll(PostgresAllExpression postgresAllExp
private SqlExpression ApplyTypeMappingOnArrayIndex(
PostgresArrayIndexExpression postgresArrayIndexExpression,
RelationalTypeMapping? typeMapping)
=> new PostgresArrayIndexExpression(
// TODO: Infer the array's mapping from the element
ApplyDefaultTypeMapping(postgresArrayIndexExpression.Array),
{
// If a (non-null) type mapping is being applied, it's to the element being indexed.
// Infer the array's mapping from that.
var (_, array) = typeMapping is not null
? ApplyTypeMappingsOnItemAndArray(Constant(null, typeMapping), postgresArrayIndexExpression.Array)
: (null, ApplyDefaultTypeMapping(postgresArrayIndexExpression.Array));

return new PostgresArrayIndexExpression(
array,
ApplyDefaultTypeMapping(postgresArrayIndexExpression.Index),
postgresArrayIndexExpression.Type,
// If the array has a type mapping (i.e. column), prefer that just like we prefer column mappings in general
postgresArrayIndexExpression.Array.TypeMapping is NpgsqlArrayTypeMapping arrayMapping
? arrayMapping.ElementMapping
: typeMapping
?? (RelationalTypeMapping?)_typeMappingSource.FindMapping(postgresArrayIndexExpression.Type, Dependencies.Model));
}

private SqlExpression ApplyTypeMappingOnILike(PostgresILikeExpression ilikeExpression)
{
Expand Down Expand Up @@ -749,5 +758,14 @@ private SqlExpression ApplyTypeMappingOnPostgresNewArray(
newExpressions ?? postgresNewArrayExpression.Expressions,
postgresNewArrayExpression.Type, arrayTypeMapping);
}

/// <summary>
/// PostgreSQL array indexing is 1-based. If the index happens to be a constant,
/// just increment it. Otherwise, append a +1 in the SQL.
/// </summary>
public virtual SqlExpression GenerateOneBasedIndexExpression(SqlExpression expression)
=> expression is SqlConstantExpression constant
? Constant(System.Convert.ToInt32(constant.Value) + 1, constant.TypeMapping)
: Add(expression, Constant(1));
}
}
Loading