From e89465f74870e294f6929d6eb6f5d324691660b2 Mon Sep 17 00:00:00 2001 From: Smit Patel Date: Mon, 23 May 2022 14:48:23 -0700 Subject: [PATCH] Query: Add support for custom aggregate operators Resolves #22957 - Introduce IAggregateMethodCallTranslator and family to translate aggregate methods - Currently we assume that either method instance or first argument (for static method) will be the one which can be of type enumerable, rest of the arguments if any are scalar. - Also refactor code in SqlTranslator.VisitMethodCall to avoid visiting grouping element source multiple times --- ...ntityFrameworkRelationalServicesBuilder.cs | 3 + .../Query/IAggregateMethodCallTranslator.cs | 32 + .../IAggregateMethodCallTranslatorPlugin.cs | 22 + .../IAggregateMethodCallTranslatorProvider.cs | 35 + .../Query/IMethodCallTranslator.cs | 1 - .../Query/IMethodCallTranslatorProvider.cs | 10 +- .../QueryableAggregateMethodTranslator.cs | 180 +++++ ...alAggregateMethodCallTranslatorProvider.cs | 89 +++ ...ethodCallTranslatorProviderDependencies.cs | 72 ++ ...elationalCompiledQueryCacheKeyGenerator.cs | 52 +- .../RelationalEvaluatableExpressionFilter.cs | 16 +- .../RelationalMemberTranslatorProvider.cs | 11 +- .../RelationalMethodCallTranslatorProvider.cs | 11 +- ...yableMethodTranslatingExpressionVisitor.cs | 42 +- ...ranslatingExpressionVisitorDependencies.cs | 16 +- ...lationalSqlTranslatingExpressionVisitor.cs | 626 +++++++++--------- ...ranslatingExpressionVisitorDependencies.cs | 13 +- ...lSqlTranslatingExpressionVisitorFactory.cs | 17 +- .../SqlServerServiceCollectionExtensions.cs | 1 + ...erAggregateMethodCallTranslatorProvider.cs | 30 + .../SqlServerLongCountMethodTranslator.cs | 74 +++ ...qlServerSqlTranslatingExpressionVisitor.cs | 15 - .../SqliteServiceCollectionExtensions.cs | 1 + ...teAggregateMethodCallTranslatorProvider.cs | 31 + ...qliteQueryableAggregateMethodTranslator.cs | 110 +++ .../SqliteSqlTranslatingExpressionVisitor.cs | 83 --- .../Query/CompiledQueryCacheKeyGenerator.cs | 58 +- .../Query/EvaluatableExpressionFilter.cs | 22 +- .../Query/ExpressionEqualityComparer.cs | 13 +- .../Query/ICompiledQueryCacheKeyGenerator.cs | 9 +- .../NorthwindGroupByQuerySqlServerTest.cs | 12 +- 31 files changed, 1102 insertions(+), 605 deletions(-) create mode 100644 src/EFCore.Relational/Query/IAggregateMethodCallTranslator.cs create mode 100644 src/EFCore.Relational/Query/IAggregateMethodCallTranslatorPlugin.cs create mode 100644 src/EFCore.Relational/Query/IAggregateMethodCallTranslatorProvider.cs create mode 100644 src/EFCore.Relational/Query/Internal/QueryableAggregateMethodTranslator.cs create mode 100644 src/EFCore.Relational/Query/RelationalAggregateMethodCallTranslatorProvider.cs create mode 100644 src/EFCore.Relational/Query/RelationalAggregateMethodCallTranslatorProviderDependencies.cs create mode 100644 src/EFCore.SqlServer/Query/Internal/SqlServerAggregateMethodCallTranslatorProvider.cs create mode 100644 src/EFCore.SqlServer/Query/Internal/SqlServerLongCountMethodTranslator.cs create mode 100644 src/EFCore.Sqlite.Core/Query/Internal/SqliteAggregateMethodCallTranslatorProvider.cs create mode 100644 src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryableAggregateMethodTranslator.cs diff --git a/src/EFCore.Relational/Infrastructure/EntityFrameworkRelationalServicesBuilder.cs b/src/EFCore.Relational/Infrastructure/EntityFrameworkRelationalServicesBuilder.cs index d4d43fbfc23..a30b3485fb7 100644 --- a/src/EFCore.Relational/Infrastructure/EntityFrameworkRelationalServicesBuilder.cs +++ b/src/EFCore.Relational/Infrastructure/EntityFrameworkRelationalServicesBuilder.cs @@ -66,6 +66,7 @@ public static readonly IDictionary RelationalServi { typeof(IModificationCommandBatchFactory), new ServiceCharacteristics(ServiceLifetime.Scoped) }, { typeof(IRelationalSqlTranslatingExpressionVisitorFactory), new ServiceCharacteristics(ServiceLifetime.Scoped) }, { typeof(IMethodCallTranslatorProvider), new ServiceCharacteristics(ServiceLifetime.Scoped) }, + { typeof(IAggregateMethodCallTranslatorProvider), new ServiceCharacteristics(ServiceLifetime.Scoped) }, { typeof(IMemberTranslatorProvider), new ServiceCharacteristics(ServiceLifetime.Scoped) }, { typeof(ISqlExpressionFactory), new ServiceCharacteristics(ServiceLifetime.Scoped) }, { typeof(IRelationalQueryStringFactory), new ServiceCharacteristics(ServiceLifetime.Scoped) }, @@ -171,6 +172,7 @@ public override EntityFrameworkServicesBuilder TryAddCoreServices() TryAdd(); TryAdd(); TryAdd(); + TryAdd(); TryAdd(); TryAdd(); TryAdd(); @@ -202,6 +204,7 @@ public override EntityFrameworkServicesBuilder TryAddCoreServices() .AddDependencyScoped() .AddDependencyScoped() .AddDependencyScoped() + .AddDependencyScoped() .AddDependencyScoped() .AddDependencyScoped() .AddDependencyScoped() diff --git a/src/EFCore.Relational/Query/IAggregateMethodCallTranslator.cs b/src/EFCore.Relational/Query/IAggregateMethodCallTranslator.cs new file mode 100644 index 00000000000..58d7f8b8334 --- /dev/null +++ b/src/EFCore.Relational/Query/IAggregateMethodCallTranslator.cs @@ -0,0 +1,32 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.EntityFrameworkCore.Query.SqlExpressions; + +namespace Microsoft.EntityFrameworkCore.Query; + +/// +/// +/// A SQL translator for LINQ expression represending an aggregate function. +/// +/// +/// This interface is typically used by database providers (and other extensions). It is generally +/// not used in application code. +/// +/// +public interface IAggregateMethodCallTranslator +{ + /// + /// Translates a LINQ to a SQL equivalent. + /// + /// The method info from . + /// The source on which the aggregate method is applied. + /// SQL representations of scalar . + /// The query logger to use. + /// A SQL translation of the . + SqlExpression? Translate( + MethodInfo method, + EnumerableExpression source, + IReadOnlyList arguments, + IDiagnosticsLogger logger); +} diff --git a/src/EFCore.Relational/Query/IAggregateMethodCallTranslatorPlugin.cs b/src/EFCore.Relational/Query/IAggregateMethodCallTranslatorPlugin.cs new file mode 100644 index 00000000000..7fa452b7491 --- /dev/null +++ b/src/EFCore.Relational/Query/IAggregateMethodCallTranslatorPlugin.cs @@ -0,0 +1,22 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.EntityFrameworkCore.Query; + +/// +/// Represents plugin for . +/// +/// +/// The service lifetime is and multiple registrations +/// are allowed. This means that each instance will use its own +/// set of instances of this service. +/// The implementations may depend on other services registered with any lifetime. +/// The implementations do not need to be thread-safe. +/// +public interface IAggregateMethodCallTranslatorPlugin +{ + /// + /// Gets the method call translators. + /// + IEnumerable Translators { get; } +} diff --git a/src/EFCore.Relational/Query/IAggregateMethodCallTranslatorProvider.cs b/src/EFCore.Relational/Query/IAggregateMethodCallTranslatorProvider.cs new file mode 100644 index 00000000000..217558ed506 --- /dev/null +++ b/src/EFCore.Relational/Query/IAggregateMethodCallTranslatorProvider.cs @@ -0,0 +1,35 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.EntityFrameworkCore.Query.SqlExpressions; + +namespace Microsoft.EntityFrameworkCore.Query; + +/// +/// Provides translations for LINQ expressions which represents aggregate methods. +/// +/// +/// The service lifetime is and multiple registrations +/// are allowed. This means that each instance will use its own +/// set of instances of this service. +/// The implementations may depend on other services registered with any lifetime. +/// The implementations do not need to be thread-safe. +/// +public interface IAggregateMethodCallTranslatorProvider +{ + /// + /// Translates a LINQ aggregate to a SQL equivalent. + /// + /// A model to use for translation. + /// The method info from . + /// The source on which the aggregate method is applied. + /// SQL representations of scalar . + /// The query logger to use. + /// A SQL translation of the . + SqlExpression? Translate( + IModel model, + MethodInfo method, + EnumerableExpression source, + IReadOnlyList arguments, + IDiagnosticsLogger logger); +} diff --git a/src/EFCore.Relational/Query/IMethodCallTranslator.cs b/src/EFCore.Relational/Query/IMethodCallTranslator.cs index 13432daa93f..a96a9e2ff69 100644 --- a/src/EFCore.Relational/Query/IMethodCallTranslator.cs +++ b/src/EFCore.Relational/Query/IMethodCallTranslator.cs @@ -1,7 +1,6 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Diagnostics.CodeAnalysis; using Microsoft.EntityFrameworkCore.Query.SqlExpressions; namespace Microsoft.EntityFrameworkCore.Query; diff --git a/src/EFCore.Relational/Query/IMethodCallTranslatorProvider.cs b/src/EFCore.Relational/Query/IMethodCallTranslatorProvider.cs index e0fda2ee45e..9a729cc2749 100644 --- a/src/EFCore.Relational/Query/IMethodCallTranslatorProvider.cs +++ b/src/EFCore.Relational/Query/IMethodCallTranslatorProvider.cs @@ -6,12 +6,14 @@ namespace Microsoft.EntityFrameworkCore.Query; /// -/// Provides translations for LINQ expressions. +/// Provides translations for LINQ expressions which represents scalar methods. /// /// -/// The service lifetime is . This means a single instance -/// is used by many instances. The implementation must be thread-safe. -/// This service cannot depend on services registered as . +/// The service lifetime is and multiple registrations +/// are allowed. This means that each instance will use its own +/// set of instances of this service. +/// The implementations may depend on other services registered with any lifetime. +/// The implementations do not need to be thread-safe. /// public interface IMethodCallTranslatorProvider { diff --git a/src/EFCore.Relational/Query/Internal/QueryableAggregateMethodTranslator.cs b/src/EFCore.Relational/Query/Internal/QueryableAggregateMethodTranslator.cs new file mode 100644 index 00000000000..b0592ca7323 --- /dev/null +++ b/src/EFCore.Relational/Query/Internal/QueryableAggregateMethodTranslator.cs @@ -0,0 +1,180 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.EntityFrameworkCore.Query.SqlExpressions; + +namespace Microsoft.EntityFrameworkCore.Query.Internal; + +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public class QueryableAggregateMethodTranslator : IAggregateMethodCallTranslator +{ + private readonly ISqlExpressionFactory _sqlExpressionFactory; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public QueryableAggregateMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) + { + _sqlExpressionFactory = sqlExpressionFactory; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual SqlExpression? Translate( + MethodInfo method, EnumerableExpression source, IReadOnlyList arguments, + IDiagnosticsLogger logger) + { + if (method.DeclaringType == typeof(Queryable)) + { + var methodInfo = method.IsGenericMethod + ? method.GetGenericMethodDefinition() + : method; + switch (methodInfo.Name) + { + case nameof(Queryable.Average) + when (QueryableMethods.IsAverageWithoutSelector(methodInfo) + || QueryableMethods.IsAverageWithSelector(methodInfo)) + && source.Selector is SqlExpression averageSqlExpression: + var averageInputType = averageSqlExpression.Type; + if (averageInputType == typeof(int) + || averageInputType == typeof(long)) + { + averageSqlExpression = _sqlExpressionFactory.ApplyDefaultTypeMapping( + _sqlExpressionFactory.Convert(averageSqlExpression, typeof(double))); + } + + averageSqlExpression = CombineTerms(source, averageSqlExpression); + return averageInputType == typeof(float) + ? _sqlExpressionFactory.Convert( + _sqlExpressionFactory.Function( + "AVG", + new[] { averageSqlExpression }, + nullable: true, + argumentsPropagateNullability: new[] { false }, + typeof(double)), + averageSqlExpression.Type, + averageSqlExpression.TypeMapping) + : _sqlExpressionFactory.Function( + "AVG", + new[] { averageSqlExpression }, + nullable: true, + argumentsPropagateNullability: new[] { false }, + averageSqlExpression.Type, + averageSqlExpression.TypeMapping); + + case nameof(Queryable.Count) + when methodInfo == QueryableMethods.CountWithoutPredicate + || methodInfo == QueryableMethods.CountWithPredicate: + var countSqlExpression = (source.Selector as SqlExpression) ?? _sqlExpressionFactory.Fragment("*"); + countSqlExpression = CombineTerms(source, countSqlExpression); + return _sqlExpressionFactory.ApplyDefaultTypeMapping( + _sqlExpressionFactory.Function( + "COUNT", + new[] { countSqlExpression }, + nullable: false, + argumentsPropagateNullability: new[] { false }, + typeof(int))); + + case nameof(Queryable.LongCount) + when methodInfo == QueryableMethods.LongCountWithoutPredicate + || methodInfo == QueryableMethods.LongCountWithPredicate: + var longCountSqlExpression = (source.Selector as SqlExpression) ?? _sqlExpressionFactory.Fragment("*"); + longCountSqlExpression = CombineTerms(source, longCountSqlExpression); + + return _sqlExpressionFactory.ApplyDefaultTypeMapping( + _sqlExpressionFactory.Function( + "COUNT", + new[] { longCountSqlExpression }, + nullable: false, + argumentsPropagateNullability: new[] { false }, + typeof(long))); + + case nameof(Queryable.Max) + when (methodInfo == QueryableMethods.MaxWithoutSelector + || methodInfo == QueryableMethods.MaxWithSelector) + && source.Selector is SqlExpression maxSqlExpression: + maxSqlExpression = CombineTerms(source, maxSqlExpression); + return _sqlExpressionFactory.Function( + "MAX", + new[] { maxSqlExpression }, + nullable: true, + argumentsPropagateNullability: new[] { false }, + maxSqlExpression.Type, + maxSqlExpression.TypeMapping); + + case nameof(Queryable.Min) + when (methodInfo == QueryableMethods.MinWithoutSelector + || methodInfo == QueryableMethods.MinWithSelector) + && source.Selector is SqlExpression minSqlExpression: + minSqlExpression = CombineTerms(source, minSqlExpression); + return _sqlExpressionFactory.Function( + "MIN", + new[] { minSqlExpression }, + nullable: true, + argumentsPropagateNullability: new[] { false }, + minSqlExpression.Type, + minSqlExpression.TypeMapping); + + case nameof(Queryable.Sum) + when (QueryableMethods.IsSumWithoutSelector(methodInfo) + || QueryableMethods.IsSumWithSelector(methodInfo)) + && source.Selector is SqlExpression sumSqlExpression: + sumSqlExpression = CombineTerms(source, sumSqlExpression); + var sumInputType = sumSqlExpression.Type; + return sumInputType == typeof(float) + ? _sqlExpressionFactory.Convert( + _sqlExpressionFactory.Function( + "SUM", + new[] { sumSqlExpression }, + nullable: true, + argumentsPropagateNullability: new[] { false }, + typeof(double)), + sumInputType, + sumSqlExpression.TypeMapping) + : _sqlExpressionFactory.Function( + "SUM", + new[] { sumSqlExpression }, + nullable: true, + argumentsPropagateNullability: new[] { false }, + sumInputType, + sumSqlExpression.TypeMapping); + } + } + + return null; + } + + private SqlExpression CombineTerms(EnumerableExpression enumerableExpression, SqlExpression sqlExpression) + { + if (enumerableExpression.Predicate != null) + { + if (sqlExpression is SqlFragmentExpression) + { + sqlExpression = _sqlExpressionFactory.Constant(1); + } + + sqlExpression = _sqlExpressionFactory.Case( + new List { new(enumerableExpression.Predicate, sqlExpression) }, + elseResult: null); + } + + if (enumerableExpression.IsDistinct) + { + sqlExpression = new DistinctExpression(sqlExpression); + } + + return sqlExpression; + } +} diff --git a/src/EFCore.Relational/Query/RelationalAggregateMethodCallTranslatorProvider.cs b/src/EFCore.Relational/Query/RelationalAggregateMethodCallTranslatorProvider.cs new file mode 100644 index 00000000000..b2cb7b5b450 --- /dev/null +++ b/src/EFCore.Relational/Query/RelationalAggregateMethodCallTranslatorProvider.cs @@ -0,0 +1,89 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.EntityFrameworkCore.Query.Internal; +using Microsoft.EntityFrameworkCore.Query.SqlExpressions; + +namespace Microsoft.EntityFrameworkCore.Query; + +/// +public class RelationalAggregateMethodCallTranslatorProvider : IAggregateMethodCallTranslatorProvider +{ + private readonly List _plugins = new(); + private readonly List _translators = new(); + private readonly ISqlExpressionFactory _sqlExpressionFactory; + + /// + /// Creates a new instance of the class. + /// + /// Parameter object containing dependencies for this class. + public RelationalAggregateMethodCallTranslatorProvider(RelationalAggregateMethodCallTranslatorProviderDependencies dependencies) + { + Dependencies = dependencies; + + _plugins.AddRange(dependencies.Plugins.SelectMany(p => p.Translators)); + + _sqlExpressionFactory = dependencies.SqlExpressionFactory; + + _translators.AddRange( + new IAggregateMethodCallTranslator[] + { + new QueryableAggregateMethodTranslator(_sqlExpressionFactory) + }); ; + } + + /// + /// Dependencies for this service. + /// + protected virtual RelationalAggregateMethodCallTranslatorProviderDependencies Dependencies { get; } + + /// + public virtual SqlExpression? Translate( + IModel model, + MethodInfo method, + EnumerableExpression source, + IReadOnlyList arguments, + IDiagnosticsLogger logger) + { + // TODO: Add support for user defined aggregate functions + //var dbFunction = model.FindDbFunction(method); + //if (dbFunction != null) + //{ + // if (dbFunction.Translation != null) + // { + // return dbFunction.Translation.Invoke( + // arguments.Select(e => _sqlExpressionFactory.ApplyDefaultTypeMapping(e)).ToList()); + // } + + // var argumentsPropagateNullability = dbFunction.Parameters.Select(p => p.PropagatesNullability); + + // return dbFunction.IsBuiltIn + // ? _sqlExpressionFactory.Function( + // dbFunction.Name, + // arguments, + // dbFunction.IsNullable, + // argumentsPropagateNullability, + // method.ReturnType.UnwrapNullableType(), + // dbFunction.TypeMapping) + // : _sqlExpressionFactory.Function( + // dbFunction.Schema, + // dbFunction.Name, + // arguments, + // dbFunction.IsNullable, + // argumentsPropagateNullability, + // method.ReturnType.UnwrapNullableType(), + // dbFunction.TypeMapping); + //} + + return _plugins.Concat(_translators) + .Select(t => t.Translate(method, source, arguments, logger)) + .FirstOrDefault(t => t != null); + } + + /// + /// Adds additional translators which will take priority over existing registered translators. + /// + /// Translators to add. + protected virtual void AddTranslators(IEnumerable translators) + => _translators.InsertRange(0, translators); +} diff --git a/src/EFCore.Relational/Query/RelationalAggregateMethodCallTranslatorProviderDependencies.cs b/src/EFCore.Relational/Query/RelationalAggregateMethodCallTranslatorProviderDependencies.cs new file mode 100644 index 00000000000..14dee9fde57 --- /dev/null +++ b/src/EFCore.Relational/Query/RelationalAggregateMethodCallTranslatorProviderDependencies.cs @@ -0,0 +1,72 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.EntityFrameworkCore.Query; + +/// +/// +/// Service dependencies parameter class for +/// +/// +/// This type is typically used by database providers (and other extensions). It is generally +/// not used in application code. +/// +/// +/// +/// +/// Do not construct instances of this class directly from either provider or application code as the +/// constructor signature may change as new dependencies are added. Instead, use this type in +/// your constructor so that an instance will be created and injected automatically by the +/// dependency injection container. To create an instance with some dependent services replaced, +/// first resolve the object from the dependency injection container, then replace selected +/// services using the 'With...' methods. Do not call the constructor at any point in this process. +/// +/// +/// The service lifetime is . This means that each +/// instance will use its own instance of this service. +/// The implementation may depend on other services registered with any lifetime. +/// The implementation does not need to be thread-safe. +/// +/// +public sealed record RelationalAggregateMethodCallTranslatorProviderDependencies +{ + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + /// + /// Do not call this constructor directly from either provider or application code as it may change + /// as new dependencies are added. Instead, use this type in your constructor so that an instance + /// will be created and injected automatically by the dependency injection container. To create + /// an instance with some dependent services replaced, first resolve the object from the dependency + /// injection container, then replace selected services using the 'With...' methods. Do not call + /// the constructor at any point in this process. + /// + [EntityFrameworkInternal] + public RelationalAggregateMethodCallTranslatorProviderDependencies( + ISqlExpressionFactory sqlExpressionFactory, + IEnumerable plugins, + IRelationalTypeMappingSource typeMappingSource) + { + SqlExpressionFactory = sqlExpressionFactory; + Plugins = plugins; + RelationalTypeMappingSource = typeMappingSource; + } + + /// + /// The expression factory.. + /// + public ISqlExpressionFactory SqlExpressionFactory { get; init; } + + /// + /// Registered plugins. + /// + public IEnumerable Plugins { get; init; } + + /// + /// Relational Type Mapping Source. + /// + public IRelationalTypeMappingSource RelationalTypeMappingSource { get; init; } +} diff --git a/src/EFCore.Relational/Query/RelationalCompiledQueryCacheKeyGenerator.cs b/src/EFCore.Relational/Query/RelationalCompiledQueryCacheKeyGenerator.cs index 5f8893e4f75..b193a826cd4 100644 --- a/src/EFCore.Relational/Query/RelationalCompiledQueryCacheKeyGenerator.cs +++ b/src/EFCore.Relational/Query/RelationalCompiledQueryCacheKeyGenerator.cs @@ -3,22 +3,7 @@ namespace Microsoft.EntityFrameworkCore.Query; -/// -/// -/// Creates keys that uniquely identifies a query. This is used to store and lookup -/// compiled versions of a query in a cache. -/// -/// -/// This type is typically used by database providers (and other extensions). It is generally -/// not used in application code. -/// -/// -/// -/// The service lifetime is . This means that each -/// instance will use its own instance of this service. -/// The implementation may depend on other services registered with any lifetime. -/// The implementation does not need to be thread-safe. -/// +/// public class RelationalCompiledQueryCacheKeyGenerator : CompiledQueryCacheKeyGenerator { /// @@ -39,12 +24,7 @@ public RelationalCompiledQueryCacheKeyGenerator( /// protected virtual RelationalCompiledQueryCacheKeyGeneratorDependencies RelationalDependencies { get; } - /// - /// Generates the cache key for the given query. - /// - /// The query to get the cache key for. - /// A value indicating whether the query will be executed asynchronously. - /// The cache key. + /// public override object GenerateCacheKey(Expression query, bool async) => GenerateCacheKeyCore(query, async); @@ -102,41 +82,19 @@ public RelationalCompiledQueryCacheKey( _shouldBuffer = shouldBuffer; } - /// - /// Determines if this key is equivalent to a given object (i.e. if they are keys for the same query). - /// - /// - /// The object to compare this key to. - /// - /// - /// if the object is a and is for the same query, - /// otherwise . - /// + /// public override bool Equals(object? obj) => obj is RelationalCompiledQueryCacheKey key && Equals(key); - /// - /// Indicates whether the current object is equal to another object of the same type. - /// - /// - /// An object to compare with this object. - /// - /// - /// if the current object is equal to the parameter; otherwise, . - /// + /// public bool Equals(RelationalCompiledQueryCacheKey other) => _compiledQueryCacheKey.Equals(other._compiledQueryCacheKey) && _useRelationalNulls == other._useRelationalNulls && _querySplittingBehavior == other._querySplittingBehavior && _shouldBuffer == other._shouldBuffer; - /// - /// Gets the hash code for the key. - /// - /// - /// The hash code for the key. - /// + /// public override int GetHashCode() => HashCode.Combine( _compiledQueryCacheKey, _useRelationalNulls, _querySplittingBehavior, _shouldBuffer); diff --git a/src/EFCore.Relational/Query/RelationalEvaluatableExpressionFilter.cs b/src/EFCore.Relational/Query/RelationalEvaluatableExpressionFilter.cs index c1a2e6e7286..29df9899207 100644 --- a/src/EFCore.Relational/Query/RelationalEvaluatableExpressionFilter.cs +++ b/src/EFCore.Relational/Query/RelationalEvaluatableExpressionFilter.cs @@ -3,14 +3,7 @@ namespace Microsoft.EntityFrameworkCore.Query; -/// -/// Represents a filter for evaluatable expressions. -/// -/// -/// The service lifetime is . This means a single instance -/// is used by many instances. The implementation must be thread-safe. -/// This service cannot depend on services registered as . -/// +/// public class RelationalEvaluatableExpressionFilter : EvaluatableExpressionFilter { /// @@ -37,12 +30,7 @@ public RelationalEvaluatableExpressionFilter( /// protected virtual RelationalEvaluatableExpressionFilterDependencies RelationalDependencies { get; } - /// - /// Checks whether the given expression can be evaluated. - /// - /// The expression. - /// The model. - /// if the expression can be evaluated; otherwise. + /// public override bool IsEvaluatableExpression(Expression expression, IModel model) { if (expression is MethodCallExpression methodCallExpression) diff --git a/src/EFCore.Relational/Query/RelationalMemberTranslatorProvider.cs b/src/EFCore.Relational/Query/RelationalMemberTranslatorProvider.cs index 358c75583c6..7ad3d3f30e6 100644 --- a/src/EFCore.Relational/Query/RelationalMemberTranslatorProvider.cs +++ b/src/EFCore.Relational/Query/RelationalMemberTranslatorProvider.cs @@ -6,16 +6,7 @@ namespace Microsoft.EntityFrameworkCore.Query; -/// -/// Provides translations for LINQ expressions by dispatching to multiple specialized member -/// translators. -/// -/// -/// The service lifetime is . This means that each -/// instance will use its own instance of this service. -/// The implementation may depend on other services registered with any lifetime. -/// The implementation does not need to be thread-safe. -/// +/// public class RelationalMemberTranslatorProvider : IMemberTranslatorProvider { private readonly List _plugins = new(); diff --git a/src/EFCore.Relational/Query/RelationalMethodCallTranslatorProvider.cs b/src/EFCore.Relational/Query/RelationalMethodCallTranslatorProvider.cs index ca8ba28bd61..566f3d12bf8 100644 --- a/src/EFCore.Relational/Query/RelationalMethodCallTranslatorProvider.cs +++ b/src/EFCore.Relational/Query/RelationalMethodCallTranslatorProvider.cs @@ -6,16 +6,7 @@ namespace Microsoft.EntityFrameworkCore.Query; -/// -/// Provides translations for LINQ expressions by dispatching to multiple specialized -/// method call translators. -/// -/// -/// The service lifetime is . This means that each -/// instance will use its own instance of this service. -/// The implementation may depend on other services registered with any lifetime. -/// The implementation does not need to be thread-safe. -/// +/// public class RelationalMethodCallTranslatorProvider : IMethodCallTranslatorProvider { private readonly List _plugins = new(); diff --git a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs index 31a745a013b..c1f94c9a06f 100644 --- a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs @@ -237,7 +237,7 @@ private static ShapedQueryExpression CreateShapedQueryExpression(IEntityType ent LambdaExpression? selector, Type resultType) => TranslateAggregateWithSelector( - source, selector, e => _sqlTranslator.TranslateAverage(e), throwWhenEmpty: true, resultType); + source, selector, e => TranslateAverage(e), throwWhenEmpty: true, resultType); /// protected override ShapedQueryExpression? TranslateCast(ShapedQueryExpression source, Type resultType) @@ -301,7 +301,7 @@ private static ShapedQueryExpression CreateShapedQueryExpression(IEntityType ent /// protected override ShapedQueryExpression? TranslateCount(ShapedQueryExpression source, LambdaExpression? predicate) - => TranslateAggregateWithPredicate(source, predicate, e => _sqlTranslator.TranslateCount(e), typeof(int)); + => TranslateAggregateWithPredicate(source, predicate, e => TranslateCount(e), typeof(int)); /// protected override ShapedQueryExpression? TranslateDefaultIfEmpty(ShapedQueryExpression source, Expression? defaultValue) @@ -623,15 +623,15 @@ private SqlExpression CreateJoinPredicate(Expression outerKey, Expression innerK /// protected override ShapedQueryExpression? TranslateLongCount(ShapedQueryExpression source, LambdaExpression? predicate) - => TranslateAggregateWithPredicate(source, predicate, e => _sqlTranslator.TranslateLongCount(e), typeof(long)); + => TranslateAggregateWithPredicate(source, predicate, e => TranslateLongCount(e), typeof(long)); /// protected override ShapedQueryExpression? TranslateMax(ShapedQueryExpression source, LambdaExpression? selector, Type resultType) - => TranslateAggregateWithSelector(source, selector, e => _sqlTranslator.TranslateMax(e), throwWhenEmpty: true, resultType); + => TranslateAggregateWithSelector(source, selector, e => TranslateMax(e), throwWhenEmpty: true, resultType); /// protected override ShapedQueryExpression? TranslateMin(ShapedQueryExpression source, LambdaExpression? selector, Type resultType) - => TranslateAggregateWithSelector(source, selector, e => _sqlTranslator.TranslateMin(e), throwWhenEmpty: true, resultType); + => TranslateAggregateWithSelector(source, selector, e => TranslateMin(e), throwWhenEmpty: true, resultType); /// protected override ShapedQueryExpression? TranslateOfType(ShapedQueryExpression source, Type resultType) @@ -888,7 +888,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp /// protected override ShapedQueryExpression? TranslateSum(ShapedQueryExpression source, LambdaExpression? selector, Type resultType) - => TranslateAggregateWithSelector(source, selector, e => _sqlTranslator.TranslateSum(e), throwWhenEmpty: false, resultType); + => TranslateAggregateWithSelector(source, selector, e => TranslateSum(e), throwWhenEmpty: false, resultType); /// protected override ShapedQueryExpression? TranslateTake(ShapedQueryExpression source, Expression count) @@ -1465,6 +1465,36 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape } } + private SqlExpression? TranslateAverage(SqlExpression sqlExpression) + => RelationalDependencies.AggregateMethodCallTranslatorProvider.Translate( + RelationalDependencies.Model, QueryableMethods.GetAverageWithoutSelector(sqlExpression.Type), + new EnumerableExpression(sqlExpression), Array.Empty(), _queryCompilationContext.Logger); + + private SqlExpression? TranslateCount(SqlExpression sqlExpression) + => RelationalDependencies.AggregateMethodCallTranslatorProvider.Translate( + RelationalDependencies.Model, QueryableMethods.CountWithoutPredicate, + new EnumerableExpression(sqlExpression), Array.Empty(), _queryCompilationContext.Logger); + + private SqlExpression? TranslateLongCount(SqlExpression sqlExpression) + => RelationalDependencies.AggregateMethodCallTranslatorProvider.Translate( + RelationalDependencies.Model, QueryableMethods.LongCountWithoutPredicate, + new EnumerableExpression(sqlExpression), Array.Empty(), _queryCompilationContext.Logger); + + private SqlExpression? TranslateMax(SqlExpression sqlExpression) + => RelationalDependencies.AggregateMethodCallTranslatorProvider.Translate( + RelationalDependencies.Model, QueryableMethods.MaxWithoutSelector, + new EnumerableExpression(sqlExpression), Array.Empty(), _queryCompilationContext.Logger); + + private SqlExpression? TranslateMin(SqlExpression sqlExpression) + => RelationalDependencies.AggregateMethodCallTranslatorProvider.Translate( + RelationalDependencies.Model, QueryableMethods.MinWithoutSelector, + new EnumerableExpression(sqlExpression), Array.Empty(), _queryCompilationContext.Logger); + + private SqlExpression? TranslateSum(SqlExpression sqlExpression) + => RelationalDependencies.AggregateMethodCallTranslatorProvider.Translate( + RelationalDependencies.Model, QueryableMethods.GetSumWithoutSelector(sqlExpression.Type), + new EnumerableExpression(sqlExpression), Array.Empty(), _queryCompilationContext.Logger); + private ShapedQueryExpression? TranslateAggregateWithPredicate( ShapedQueryExpression source, LambdaExpression? predicate, diff --git a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitorDependencies.cs b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitorDependencies.cs index 9fb23b65d9b..ad3fd864baf 100644 --- a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitorDependencies.cs +++ b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitorDependencies.cs @@ -47,10 +47,14 @@ public sealed record RelationalQueryableMethodTranslatingExpressionVisitorDepend [EntityFrameworkInternal] public RelationalQueryableMethodTranslatingExpressionVisitorDependencies( IRelationalSqlTranslatingExpressionVisitorFactory relationalSqlTranslatingExpressionVisitorFactory, - ISqlExpressionFactory sqlExpressionFactory) + ISqlExpressionFactory sqlExpressionFactory, + IModel model, + IAggregateMethodCallTranslatorProvider aggregateMethodCallTranslatorProvider) { RelationalSqlTranslatingExpressionVisitorFactory = relationalSqlTranslatingExpressionVisitorFactory; SqlExpressionFactory = sqlExpressionFactory; + Model = model; + AggregateMethodCallTranslatorProvider = aggregateMethodCallTranslatorProvider; } /// @@ -62,4 +66,14 @@ public RelationalQueryableMethodTranslatingExpressionVisitorDependencies( /// The SQL expression factory. /// public ISqlExpressionFactory SqlExpressionFactory { get; init; } + + /// + /// The model. + /// + public IModel Model { get; init; } + + /// + /// The aggregate method-call translation provider. + /// + public IAggregateMethodCallTranslatorProvider AggregateMethodCallTranslatorProvider { get; } } diff --git a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs index dd6922d2749..f8ccab890ad 100644 --- a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs @@ -3,6 +3,7 @@ using System.Collections; using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; using Microsoft.EntityFrameworkCore.Metadata.Internal; using Microsoft.EntityFrameworkCore.Query.SqlExpressions; @@ -161,6 +162,7 @@ protected virtual void AddTranslationErrorDetails(string details) /// /// An expression to translate Average over. /// A SQL translation of Average over the given expression. + [Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")] public virtual SqlExpression? TranslateAverage(SqlExpression sqlExpression) { var inputType = sqlExpression.Type; @@ -199,6 +201,7 @@ protected virtual void AddTranslationErrorDetails(string details) /// /// An expression to translate Count over. /// A SQL translation of Count over the given expression. + [Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")] public virtual SqlExpression? TranslateCount(SqlExpression sqlExpression) => _sqlExpressionFactory.ApplyDefaultTypeMapping( _sqlExpressionFactory.Function( @@ -213,6 +216,7 @@ protected virtual void AddTranslationErrorDetails(string details) /// /// An expression to translate LongCount over. /// A SQL translation of LongCount over the given expression. + [Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")] public virtual SqlExpression? TranslateLongCount(SqlExpression sqlExpression) => _sqlExpressionFactory.ApplyDefaultTypeMapping( _sqlExpressionFactory.Function( @@ -227,6 +231,7 @@ protected virtual void AddTranslationErrorDetails(string details) /// /// An expression to translate Max over. /// A SQL translation of Max over the given expression. + [Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")] public virtual SqlExpression? TranslateMax(SqlExpression sqlExpression) => sqlExpression != null ? _sqlExpressionFactory.Function( @@ -243,6 +248,7 @@ protected virtual void AddTranslationErrorDetails(string details) /// /// An expression to translate Min over. /// A SQL translation of Min over the given expression. + [Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")] public virtual SqlExpression? TranslateMin(SqlExpression sqlExpression) => sqlExpression != null ? _sqlExpressionFactory.Function( @@ -259,6 +265,7 @@ protected virtual void AddTranslationErrorDetails(string details) /// /// An expression to translate Sum over. /// A SQL translation of Sum over the given expression. + [Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")] public virtual SqlExpression? TranslateSum(SqlExpression sqlExpression) { var inputType = sqlExpression.Type; @@ -647,243 +654,30 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp // EF Indexer property if (methodCallExpression.TryGetIndexerArguments(_model, out source, out propertyName)) { - var result = TryBindMember(Visit(source), MemberIdentity.Create(propertyName)); - if (result != null) + if (TryBindMember(Visit(source), MemberIdentity.Create(propertyName)) is SqlExpression result) { return result; } } - // Subquery case - // TODO: Refactor in future to avoid repeated visitation. - // Specifically ordering of visiting aggregate chain, subquery, method arguments. - if (methodCallExpression.Method.IsStatic - && methodCallExpression.Arguments.Count > 0 - && methodCallExpression.Method.DeclaringType == typeof(Queryable)) - { - if (methodCallExpression.Method.IsGenericMethod - && methodCallExpression.Method.GetGenericMethodDefinition() == QueryableMethods.AsQueryable - && methodCallExpression.Arguments[0] is RelationalGroupByShaperExpression groupByShaperExpression) - { - return new EnumerableExpression(groupByShaperExpression.ElementSelector); - } - - var enumerableSource = Visit(methodCallExpression.Arguments[0]); - if (enumerableSource is EnumerableExpression enumerableExpression) - { - Expression? result = null; - switch (methodCallExpression.Method.Name) - { - case nameof(Queryable.Average): - if (methodCallExpression.Arguments.Count == 2) - { - enumerableExpression = ProcessSelector( - enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); - } - - result = TranslateAggregate(methodCallExpression.Method, enumerableExpression); - break; - - case nameof(Queryable.Count): - if (methodCallExpression.Arguments.Count == 2) - { - var newEnumerableExpression = ProcessPredicate( - enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); - if (newEnumerableExpression == null) - { - break; - } - - enumerableExpression = newEnumerableExpression; - } - - result = TranslateAggregate(methodCallExpression.Method, enumerableExpression); - break; - - - case nameof(Queryable.Distinct): - result = enumerableExpression.Selector is EntityShaperExpression entityShaperExpression - && entityShaperExpression.EntityType.FindPrimaryKey() != null - ? enumerableExpression - : !enumerableExpression.IsDistinct - ? enumerableExpression.ApplyDistinct() - : (Expression?)null; - break; - - case nameof(Queryable.LongCount): - if (methodCallExpression.Arguments.Count == 2) - { - var newEnumerableExpression = ProcessPredicate( - enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); - if (newEnumerableExpression == null) - { - break; - } - - enumerableExpression = newEnumerableExpression; - } - - result = TranslateAggregate(methodCallExpression.Method, enumerableExpression); - break; - - case nameof(Queryable.Max): - if (methodCallExpression.Arguments.Count == 2) - { - enumerableExpression = ProcessSelector( - enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); - } - - result = TranslateAggregate(methodCallExpression.Method, enumerableExpression); - break; - - case nameof(Queryable.Min): - if (methodCallExpression.Arguments.Count == 2) - { - enumerableExpression = ProcessSelector( - enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); - } - - result = TranslateAggregate(methodCallExpression.Method, enumerableExpression); - break; - - case nameof(Queryable.OrderBy): - result = ProcessOrderByThenBy( - enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), thenBy: false, ascending: true); - break; - - case nameof(Queryable.OrderByDescending): - result = ProcessOrderByThenBy( - enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), thenBy: false, ascending: false); - break; - - case nameof(Queryable.ThenBy): - result = ProcessOrderByThenBy( - enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), thenBy: true, ascending: true); - break; - - case nameof(Queryable.ThenByDescending): - result = ProcessOrderByThenBy( - enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), thenBy: true, ascending: false); - break; - - case nameof(Queryable.Select): - result = ProcessSelector(enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); - break; - - case nameof(Queryable.Sum): - if (methodCallExpression.Arguments.Count == 2) - { - enumerableExpression = ProcessSelector( - enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); - } - - result = TranslateAggregate(methodCallExpression.Method, enumerableExpression); - break; - - case nameof(Queryable.Where): - result = ProcessPredicate(enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); - break; - } - - if (result != null) - { - return result; - } - } - } - - var subqueryTranslation = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery(methodCallExpression); - if (subqueryTranslation != null) - { - if (subqueryTranslation.ResultCardinality == ResultCardinality.Enumerable) - { - return QueryCompilationContext.NotTranslatedExpression; - } - - var shaperExpression = subqueryTranslation.ShaperExpression; - ProjectionBindingExpression? mappedProjectionBindingExpression = null; - - var innerExpression = shaperExpression; - Type? convertedType = null; - if (shaperExpression is UnaryExpression unaryExpression - && unaryExpression.NodeType == ExpressionType.Convert) - { - convertedType = unaryExpression.Type; - innerExpression = unaryExpression.Operand; - } - - if (innerExpression is EntityShaperExpression ese - && (convertedType == null - || convertedType.IsAssignableFrom(ese.Type))) - { - return new EntityReferenceExpression(subqueryTranslation.UpdateShaperExpression(innerExpression)); - } - - if (innerExpression is ProjectionBindingExpression pbe - && (convertedType == null - || convertedType.MakeNullable() == innerExpression.Type)) - { - mappedProjectionBindingExpression = pbe; - } - - if (mappedProjectionBindingExpression == null - && shaperExpression is BlockExpression blockExpression - && blockExpression.Expressions.Count == 2 - && blockExpression.Expressions[0] is BinaryExpression binaryExpression - && binaryExpression.NodeType == ExpressionType.Assign - && binaryExpression.Right is ProjectionBindingExpression pbe2) - { - mappedProjectionBindingExpression = pbe2; - } - - if (mappedProjectionBindingExpression == null) - { - return QueryCompilationContext.NotTranslatedExpression; - } - - var subquery = (SelectExpression)subqueryTranslation.QueryExpression; - var projection = subquery.GetProjection(mappedProjectionBindingExpression); - if (projection is not SqlExpression sqlExpression) - { - return QueryCompilationContext.NotTranslatedExpression; - } - - if (subquery.Tables.Count == 0) - { - return sqlExpression; - } - - subquery.ReplaceProjection(new List { sqlExpression }); - subquery.ApplyProjection(); - - SqlExpression scalarSubqueryExpression = new ScalarSubqueryExpression(subquery); - - if (subqueryTranslation.ResultCardinality == ResultCardinality.SingleOrDefault - && !shaperExpression.Type.IsNullableType()) - { - scalarSubqueryExpression = _sqlExpressionFactory.Coalesce( - scalarSubqueryExpression, - (SqlExpression)Visit(shaperExpression.Type.GetDefaultValueConstant())); - } - - return scalarSubqueryExpression; - } - - SqlExpression? sqlObject = null; - SqlExpression[] arguments; var method = methodCallExpression.Method; + var arguments = methodCallExpression.Arguments; + EnumerableExpression? enumerableExpression = null; + var abortTranslation = false; + SqlExpression? sqlObject = null; + List scalarArguments; if (method.Name == nameof(object.Equals) && methodCallExpression.Object != null - && methodCallExpression.Arguments.Count == 1) + && arguments.Count == 1) { var left = Visit(methodCallExpression.Object); - var right = Visit(RemoveObjectConvert(methodCallExpression.Arguments[0])); + var right = Visit(RemoveObjectConvert(arguments[0])); if (TryRewriteEntityEquality( ExpressionType.Equal, left == QueryCompilationContext.NotTranslatedExpression ? methodCallExpression.Object : left, - right == QueryCompilationContext.NotTranslatedExpression ? methodCallExpression.Arguments[0] : right, + right == QueryCompilationContext.NotTranslatedExpression ? arguments[0] : right, equalsMethod: true, out var result)) { @@ -894,7 +688,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp && right is SqlExpression rightSql) { sqlObject = leftSql; - arguments = new[] { rightSql }; + scalarArguments = new List { rightSql }; } else { @@ -903,23 +697,23 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp } else if (method.Name == nameof(object.Equals) && methodCallExpression.Object == null - && methodCallExpression.Arguments.Count == 2) + && arguments.Count == 2) { - if (methodCallExpression.Arguments[0].Type == typeof(object[]) - && methodCallExpression.Arguments[0] is NewArrayExpression) + if (arguments[0].Type == typeof(object[]) + && arguments[0] is NewArrayExpression) { return Visit( ConvertObjectArrayEqualityComparison( - methodCallExpression.Arguments[0], methodCallExpression.Arguments[1])); + arguments[0], arguments[1])); } - var left = Visit(RemoveObjectConvert(methodCallExpression.Arguments[0])); - var right = Visit(RemoveObjectConvert(methodCallExpression.Arguments[1])); + var left = Visit(RemoveObjectConvert(arguments[0])); + var right = Visit(RemoveObjectConvert(arguments[1])); if (TryRewriteEntityEquality( ExpressionType.Equal, - left == QueryCompilationContext.NotTranslatedExpression ? methodCallExpression.Arguments[0] : left, - right == QueryCompilationContext.NotTranslatedExpression ? methodCallExpression.Arguments[1] : right, + left == QueryCompilationContext.NotTranslatedExpression ? arguments[0] : left, + right == QueryCompilationContext.NotTranslatedExpression ? arguments[1] : right, equalsMethod: true, out var result)) { @@ -929,7 +723,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp if (left is SqlExpression leftSql && right is SqlExpression rightSql) { - arguments = new[] { leftSql, rightSql }; + scalarArguments = new List { leftSql, rightSql }; } else { @@ -939,12 +733,12 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp else if (method.IsGenericMethod && method.GetGenericMethodDefinition().Equals(EnumerableMethods.Contains)) { - var enumerable = Visit(methodCallExpression.Arguments[0]); - var item = Visit(methodCallExpression.Arguments[1]); + var enumerable = Visit(arguments[0]); + var item = Visit(arguments[1]); if (TryRewriteContainsEntity( enumerable, - item == QueryCompilationContext.NotTranslatedExpression ? methodCallExpression.Arguments[1] : item, out var result)) + item == QueryCompilationContext.NotTranslatedExpression ? arguments[1] : item, out var result)) { return result; } @@ -952,22 +746,22 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp if (enumerable is SqlExpression sqlEnumerable && item is SqlExpression sqlItem) { - arguments = new[] { sqlEnumerable, sqlItem }; + scalarArguments = new List { sqlEnumerable, sqlItem }; } else { return QueryCompilationContext.NotTranslatedExpression; } } - else if (methodCallExpression.Arguments.Count == 1 + else if (arguments.Count == 1 && method.IsContainsMethod()) { var enumerable = Visit(methodCallExpression.Object); - var item = Visit(methodCallExpression.Arguments[0]); + var item = Visit(arguments[0]); if (TryRewriteContainsEntity( enumerable!, - item == QueryCompilationContext.NotTranslatedExpression ? methodCallExpression.Arguments[0] : item, out var result)) + item == QueryCompilationContext.NotTranslatedExpression ? arguments[0] : item, out var result)) { return result; } @@ -976,7 +770,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp && item is SqlExpression sqlItem) { sqlObject = sqlEnumerable; - arguments = new[] { sqlItem }; + scalarArguments = new List { sqlItem }; } else { @@ -985,31 +779,222 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp } else { - if (TranslationFailed(methodCallExpression.Object, Visit(methodCallExpression.Object), out sqlObject)) + var skipVisitChildren = false; + scalarArguments = new List(); + if (method.IsStatic + && arguments.Count > 0 + && method.DeclaringType == typeof(Queryable)) { - return QueryCompilationContext.NotTranslatedExpression; + var genericMethod = method.IsGenericMethod + ? method.GetGenericMethodDefinition() + : method; + if (genericMethod == QueryableMethods.AsQueryable + && arguments[0] is RelationalGroupByShaperExpression groupByShaperExpression) + { + return new EnumerableExpression(groupByShaperExpression.ElementSelector); + } + + var enumerableSource = Visit(arguments[0]); + if (enumerableSource is EnumerableExpression) + { + enumerableExpression = (EnumerableExpression)enumerableSource; + switch (method.Name) + { + case nameof(Queryable.Average) + when QueryableMethods.IsAverageWithoutSelector(genericMethod): + case nameof(Queryable.Max) + when genericMethod == QueryableMethods.MaxWithoutSelector: + case nameof(Queryable.Min) + when genericMethod == QueryableMethods.MinWithoutSelector: + case nameof(Queryable.Sum) + when QueryableMethods.IsSumWithoutSelector(genericMethod): + case nameof(Queryable.Count) + when genericMethod == QueryableMethods.CountWithoutPredicate: + case nameof(Queryable.LongCount) + when genericMethod == QueryableMethods.LongCountWithoutPredicate: + skipVisitChildren = true; + break; + + case nameof(Queryable.Average) + when QueryableMethods.IsAverageWithSelector(genericMethod): + case nameof(Queryable.Max) + when genericMethod == QueryableMethods.MaxWithSelector: + case nameof(Queryable.Min) + when genericMethod == QueryableMethods.MinWithSelector: + case nameof(Queryable.Sum) + when QueryableMethods.IsSumWithSelector(genericMethod): + enumerableExpression = ProcessSelector(enumerableExpression, arguments[1].UnwrapLambdaFromQuote()); + skipVisitChildren = true; + break; + + case nameof(Queryable.Count) + when genericMethod == QueryableMethods.CountWithPredicate: + case nameof(Queryable.LongCount) + when genericMethod == QueryableMethods.LongCountWithPredicate: + enumerableExpression = ProcessPredicate(enumerableExpression, arguments[1].UnwrapLambdaFromQuote()); + if (enumerableExpression == null) + { + abortTranslation = true; + } + else + { + skipVisitChildren = true; + } + break; + + case nameof(Queryable.Distinct) + when genericMethod == QueryableMethods.Distinct: + if (enumerableExpression.Selector is EntityShaperExpression entityShaperExpression + && entityShaperExpression.EntityType.FindPrimaryKey() != null) + { + return enumerableExpression; + } + + if (!enumerableExpression.IsDistinct) + { + return enumerableExpression.ApplyDistinct(); + } + + abortTranslation = true; + break; + + case nameof(Queryable.OrderBy) + when genericMethod == QueryableMethods.OrderBy: + enumerableExpression = ProcessOrderByThenBy( + enumerableExpression, arguments[1].UnwrapLambdaFromQuote(), thenBy: false, ascending: true); + if (enumerableExpression != null) + { + return enumerableExpression; + } + abortTranslation = true; + break; + + case nameof(Queryable.OrderByDescending) + when genericMethod == QueryableMethods.OrderByDescending: + enumerableExpression = ProcessOrderByThenBy( + enumerableExpression, arguments[1].UnwrapLambdaFromQuote(), thenBy: false, ascending: false); + if (enumerableExpression != null) + { + return enumerableExpression; + } + abortTranslation = true; + break; + + case nameof(Queryable.ThenBy) + when genericMethod == QueryableMethods.ThenBy: + enumerableExpression = ProcessOrderByThenBy( + enumerableExpression, arguments[1].UnwrapLambdaFromQuote(), thenBy: true, ascending: true); + if (enumerableExpression != null) + { + return enumerableExpression; + } + abortTranslation = true; + break; + + case nameof(Queryable.ThenByDescending) + when genericMethod == QueryableMethods.ThenByDescending: + enumerableExpression = ProcessOrderByThenBy( + enumerableExpression, arguments[1].UnwrapLambdaFromQuote(), thenBy: true, ascending: false); + if (enumerableExpression != null) + { + return enumerableExpression; + } + abortTranslation = true; + break; + + case nameof(Queryable.Select) + when genericMethod == QueryableMethods.Select: + return ProcessSelector(enumerableExpression, arguments[1].UnwrapLambdaFromQuote()); + + case nameof(Queryable.Where) + when genericMethod == QueryableMethods.Where: + enumerableExpression = ProcessPredicate(enumerableExpression, arguments[1].UnwrapLambdaFromQuote()); + if (enumerableExpression != null) + { + return enumerableExpression; + } + abortTranslation = true; + break; + + default: + abortTranslation = true; + break; + } + } } - arguments = new SqlExpression[methodCallExpression.Arguments.Count]; - for (var i = 0; i < arguments.Length; i++) + if (!abortTranslation + && !skipVisitChildren) { - var argument = methodCallExpression.Arguments[i]; - if (TranslationFailed(argument, Visit(argument), out var sqlArgument)) + var @object = Visit(methodCallExpression.Object); + if (@object is EnumerableExpression) + { + // This is safe since if enumerableExpression is non-null then it was static method + enumerableExpression = (EnumerableExpression)@object; + } + else if (TranslationFailed(methodCallExpression.Object, @object, out sqlObject)) { - return QueryCompilationContext.NotTranslatedExpression; + abortTranslation = true; } - arguments[i] = sqlArgument!; + if (!abortTranslation) + { + for (var i = 0; i < arguments.Count; i++) + { + var argument = arguments[i]; + var visitedArgument = Visit(argument); + if (i == 0 + && visitedArgument is EnumerableExpression) + { + if (enumerableExpression != null) + { + abortTranslation = true; + break; + } + + enumerableExpression = (EnumerableExpression)visitedArgument; + continue; + } + + if (TranslationFailed(argument, visitedArgument, out var sqlArgument)) + { + abortTranslation = true; + break; + } + + scalarArguments.Add(sqlArgument!); + } + } } } - var translation = Dependencies.MethodCallTranslatorProvider.Translate( - _model, sqlObject, methodCallExpression.Method, arguments, _queryCompilationContext.Logger); - - if (translation == null) + if (!abortTranslation) { - if (methodCallExpression.Method == StringEqualsWithStringComparison - || methodCallExpression.Method == StringEqualsWithStringComparisonStatic) + SqlExpression? translation; + if (enumerableExpression != null) + { + var selector = TranslateInternal(enumerableExpression.Selector); + if (selector != null) + { + enumerableExpression = enumerableExpression.ApplySelector(selector); + } + + translation = Dependencies.AggregateMethodCallTranslatorProvider.Translate( + _model, method, enumerableExpression, scalarArguments, _queryCompilationContext.Logger); + } + else + { + translation = Dependencies.MethodCallTranslatorProvider.Translate( + _model, sqlObject, method, scalarArguments, _queryCompilationContext.Logger); + } + + if (translation != null) + { + return translation; + } + + if (method == StringEqualsWithStringComparison + || method == StringEqualsWithStringComparisonStatic) { AddTranslationErrorDetails(CoreStrings.QueryUnableToTranslateStringEqualsWithStringComparison); } @@ -1017,12 +1002,90 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp { AddTranslationErrorDetails( CoreStrings.QueryUnableToTranslateMethod( - methodCallExpression.Method.DeclaringType?.DisplayName(), - methodCallExpression.Method.Name)); + method.DeclaringType?.DisplayName(), + method.Name)); + } + } + + // Subquery case + var subqueryTranslation = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery(methodCallExpression); + if (subqueryTranslation != null) + { + if (subqueryTranslation.ResultCardinality == ResultCardinality.Enumerable) + { + return QueryCompilationContext.NotTranslatedExpression; + } + + var shaperExpression = subqueryTranslation.ShaperExpression; + ProjectionBindingExpression? mappedProjectionBindingExpression = null; + + var innerExpression = shaperExpression; + Type? convertedType = null; + if (shaperExpression is UnaryExpression unaryExpression + && unaryExpression.NodeType == ExpressionType.Convert) + { + convertedType = unaryExpression.Type; + innerExpression = unaryExpression.Operand; + } + + if (innerExpression is EntityShaperExpression ese + && (convertedType == null + || convertedType.IsAssignableFrom(ese.Type))) + { + return new EntityReferenceExpression(subqueryTranslation.UpdateShaperExpression(innerExpression)); } + + if (innerExpression is ProjectionBindingExpression pbe + && (convertedType == null + || convertedType.MakeNullable() == innerExpression.Type)) + { + mappedProjectionBindingExpression = pbe; + } + + if (mappedProjectionBindingExpression == null + && shaperExpression is BlockExpression blockExpression + && blockExpression.Expressions.Count == 2 + && blockExpression.Expressions[0] is BinaryExpression binaryExpression + && binaryExpression.NodeType == ExpressionType.Assign + && binaryExpression.Right is ProjectionBindingExpression pbe2) + { + mappedProjectionBindingExpression = pbe2; + } + + if (mappedProjectionBindingExpression == null) + { + return QueryCompilationContext.NotTranslatedExpression; + } + + var subquery = (SelectExpression)subqueryTranslation.QueryExpression; + var projection = subquery.GetProjection(mappedProjectionBindingExpression); + if (projection is not SqlExpression sqlExpression) + { + return QueryCompilationContext.NotTranslatedExpression; + } + + if (subquery.Tables.Count == 0) + { + return sqlExpression; + } + + subquery.ReplaceProjection(new List { sqlExpression }); + subquery.ApplyProjection(); + + SqlExpression scalarSubqueryExpression = new ScalarSubqueryExpression(subquery); + + if (subqueryTranslation.ResultCardinality == ResultCardinality.SingleOrDefault + && !shaperExpression.Type.IsNullableType()) + { + scalarSubqueryExpression = _sqlExpressionFactory.Coalesce( + scalarSubqueryExpression, + (SqlExpression)Visit(shaperExpression.Type.GetDefaultValueConstant())); + } + + return scalarSubqueryExpression; } - return translation ?? QueryCompilationContext.NotTranslatedExpression; + return QueryCompilationContext.NotTranslatedExpression; } /// @@ -1211,7 +1274,7 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression) return QueryCompilationContext.NotTranslatedExpression; } - private Expression? TryBindMember(Expression? source, MemberIdentity member) + private SqlExpression? TryBindMember(Expression? source, MemberIdentity member) { if (source is not EntityReferenceExpression entityReferenceExpression) { @@ -1334,53 +1397,6 @@ private static EnumerableExpression ProcessSelector(EnumerableExpression enumera return enumerableExpression.ApplyPredicate(predicate); } - private SqlExpression? TranslateAggregate(MethodInfo methodInfo, EnumerableExpression enumerableExpression) - { - var selector = TranslateInternal(enumerableExpression.Selector); - if (selector == null) - { - if (methodInfo.IsGenericMethod - && PredicateAggregateMethodInfos.Contains(methodInfo.GetGenericMethodDefinition())) - { - selector = _sqlExpressionFactory.Fragment("*"); - } - else - { - return null; - } - } - enumerableExpression.ApplySelector(selector); - - if (enumerableExpression.Predicate != null) - { - if (selector is SqlFragmentExpression) - { - selector = _sqlExpressionFactory.Constant(1); - } - - selector = _sqlExpressionFactory.Case( - new List { new(enumerableExpression.Predicate, selector) }, - elseResult: null); - } - - if (enumerableExpression.IsDistinct) - { - selector = new DistinctExpression(selector); - } - - // TODO: Issue#22957 - return methodInfo.Name switch - { - nameof(Queryable.Average) => TranslateAverage(selector), - nameof(Queryable.Count) => TranslateCount(selector), - nameof(Queryable.LongCount) => TranslateLongCount(selector), - nameof(Queryable.Max) => TranslateMax(selector), - nameof(Queryable.Min) => TranslateMin(selector), - nameof(Queryable.Sum) => TranslateSum(selector), - _ => null, - }; - } - private static Expression TryRemoveImplicitConvert(Expression expression) { if (expression is UnaryExpression unaryExpression diff --git a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitorDependencies.cs b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitorDependencies.cs index 1df5e1a1118..e52a85c9cde 100644 --- a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitorDependencies.cs +++ b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitorDependencies.cs @@ -50,13 +50,15 @@ public RelationalSqlTranslatingExpressionVisitorDependencies( IModel model, IRelationalTypeMappingSource typeMappingSource, IMemberTranslatorProvider memberTranslatorProvider, - IMethodCallTranslatorProvider methodCallTranslatorProvider) + IMethodCallTranslatorProvider methodCallTranslatorProvider, + IAggregateMethodCallTranslatorProvider aggregateMethodCallTranslatorProvider) { SqlExpressionFactory = sqlExpressionFactory; Model = model; TypeMappingSource = typeMappingSource; MemberTranslatorProvider = memberTranslatorProvider; MethodCallTranslatorProvider = methodCallTranslatorProvider; + AggregateMethodCallTranslatorProvider = aggregateMethodCallTranslatorProvider; } /// @@ -65,7 +67,7 @@ public RelationalSqlTranslatingExpressionVisitorDependencies( public ISqlExpressionFactory SqlExpressionFactory { get; init; } /// - /// The expression factory. + /// The model. /// public IModel Model { get; init; } @@ -80,7 +82,12 @@ public RelationalSqlTranslatingExpressionVisitorDependencies( public IMemberTranslatorProvider MemberTranslatorProvider { get; init; } /// - /// The method-call translation provider. + /// The scalar method-call translation provider. /// public IMethodCallTranslatorProvider MethodCallTranslatorProvider { get; init; } + + /// + /// The aggregate method-call translation provider. + /// + public IAggregateMethodCallTranslatorProvider AggregateMethodCallTranslatorProvider { get; } } diff --git a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitorFactory.cs b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitorFactory.cs index bf8ddbeaa80..ad7d4727000 100644 --- a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitorFactory.cs +++ b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitorFactory.cs @@ -3,15 +3,7 @@ namespace Microsoft.EntityFrameworkCore.Query; -/// -/// A factory for creating instances. -/// -/// -/// The service lifetime is . This means that each -/// instance will use its own instance of this service. -/// The implementation may depend on other services registered with any lifetime. -/// The implementation does not need to be thread-safe. -/// +/// public class RelationalSqlTranslatingExpressionVisitorFactory : IRelationalSqlTranslatingExpressionVisitorFactory { /// @@ -29,12 +21,7 @@ public RelationalSqlTranslatingExpressionVisitorFactory( /// protected virtual RelationalSqlTranslatingExpressionVisitorDependencies Dependencies { get; } - /// - /// Creates a new . - /// - /// The query compilation context to use. - /// The visitor to use to translate subqueries. - /// A relational sql translating expression visitor. + /// public virtual RelationalSqlTranslatingExpressionVisitor Create( QueryCompilationContext queryCompilationContext, QueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpressionVisitor) diff --git a/src/EFCore.SqlServer/Extensions/SqlServerServiceCollectionExtensions.cs b/src/EFCore.SqlServer/Extensions/SqlServerServiceCollectionExtensions.cs index 6d79e7872fe..7588e14adec 100644 --- a/src/EFCore.SqlServer/Extensions/SqlServerServiceCollectionExtensions.cs +++ b/src/EFCore.SqlServer/Extensions/SqlServerServiceCollectionExtensions.cs @@ -120,6 +120,7 @@ public static IServiceCollection AddEntityFrameworkSqlServer(this IServiceCollec .TryAdd() .TryAdd() .TryAdd() + .TryAdd() .TryAdd() .TryAdd() .TryAdd() diff --git a/src/EFCore.SqlServer/Query/Internal/SqlServerAggregateMethodCallTranslatorProvider.cs b/src/EFCore.SqlServer/Query/Internal/SqlServerAggregateMethodCallTranslatorProvider.cs new file mode 100644 index 00000000000..ba39f384e0f --- /dev/null +++ b/src/EFCore.SqlServer/Query/Internal/SqlServerAggregateMethodCallTranslatorProvider.cs @@ -0,0 +1,30 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal; + +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public class SqlServerAggregateMethodCallTranslatorProvider : RelationalAggregateMethodCallTranslatorProvider +{ + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public SqlServerAggregateMethodCallTranslatorProvider(RelationalAggregateMethodCallTranslatorProviderDependencies dependencies) + : base(dependencies) + { + var sqlExpressionFactory = dependencies.SqlExpressionFactory; + AddTranslators( + new IAggregateMethodCallTranslator[] + { + new SqlServerLongCountMethodTranslator(sqlExpressionFactory) + }); + } +} diff --git a/src/EFCore.SqlServer/Query/Internal/SqlServerLongCountMethodTranslator.cs b/src/EFCore.SqlServer/Query/Internal/SqlServerLongCountMethodTranslator.cs new file mode 100644 index 00000000000..5ef4b1b45de --- /dev/null +++ b/src/EFCore.SqlServer/Query/Internal/SqlServerLongCountMethodTranslator.cs @@ -0,0 +1,74 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.EntityFrameworkCore.Query.SqlExpressions; + +namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal; + +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public class SqlServerLongCountMethodTranslator : IAggregateMethodCallTranslator +{ + private readonly ISqlExpressionFactory _sqlExpressionFactory; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public SqlServerLongCountMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) + { + _sqlExpressionFactory = sqlExpressionFactory; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual SqlExpression? Translate( + MethodInfo method, EnumerableExpression source, IReadOnlyList arguments, + IDiagnosticsLogger logger) + { + if (method.DeclaringType == typeof(Queryable) + && method.IsGenericMethod + && method.GetGenericMethodDefinition() is MethodInfo genericMethod + && (genericMethod == QueryableMethods.LongCountWithoutPredicate + || genericMethod == QueryableMethods.LongCountWithPredicate)) + { + var sqlExpression = (source.Selector as SqlExpression) ?? _sqlExpressionFactory.Fragment("*"); + if (source.Predicate != null) + { + if (sqlExpression is SqlFragmentExpression) + { + sqlExpression = _sqlExpressionFactory.Constant(1); + } + + sqlExpression = _sqlExpressionFactory.Case( + new List { new(source.Predicate, sqlExpression) }, + elseResult: null); + } + + if (source.IsDistinct) + { + sqlExpression = new DistinctExpression(sqlExpression); + } + + return _sqlExpressionFactory.ApplyDefaultTypeMapping( + _sqlExpressionFactory.Function( + "COUNT_BIG", + new[] { sqlExpression }, + nullable: false, + argumentsPropagateNullability: new[] { false }, + typeof(long))); + } + + return null; + } +} diff --git a/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs b/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs index 5aca4d68afe..1e27cda4b69 100644 --- a/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs @@ -152,21 +152,6 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression) return base.VisitUnary(unaryExpression); } - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public override SqlExpression? TranslateLongCount(SqlExpression sqlExpression) - => Dependencies.SqlExpressionFactory.ApplyDefaultTypeMapping( - Dependencies.SqlExpressionFactory.Function( - "COUNT_BIG", - new[] { sqlExpression }, - nullable: false, - argumentsPropagateNullability: new[] { false }, - typeof(long))); - private static string? GetProviderType(SqlExpression expression) => expression.TypeMapping?.StoreType; } diff --git a/src/EFCore.Sqlite.Core/Extensions/SqliteServiceCollectionExtensions.cs b/src/EFCore.Sqlite.Core/Extensions/SqliteServiceCollectionExtensions.cs index ffda8bb98dc..4f3a76f7460 100644 --- a/src/EFCore.Sqlite.Core/Extensions/SqliteServiceCollectionExtensions.cs +++ b/src/EFCore.Sqlite.Core/Extensions/SqliteServiceCollectionExtensions.cs @@ -112,6 +112,7 @@ public static IServiceCollection AddEntityFrameworkSqlite(this IServiceCollectio // New Query Pipeline .TryAdd() + .TryAdd() .TryAdd() .TryAdd() .TryAdd() diff --git a/src/EFCore.Sqlite.Core/Query/Internal/SqliteAggregateMethodCallTranslatorProvider.cs b/src/EFCore.Sqlite.Core/Query/Internal/SqliteAggregateMethodCallTranslatorProvider.cs new file mode 100644 index 00000000000..a836f543a5b --- /dev/null +++ b/src/EFCore.Sqlite.Core/Query/Internal/SqliteAggregateMethodCallTranslatorProvider.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.EntityFrameworkCore.Sqlite.Query.Internal; + +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public class SqliteAggregateMethodCallTranslatorProvider : RelationalAggregateMethodCallTranslatorProvider +{ + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public SqliteAggregateMethodCallTranslatorProvider(RelationalAggregateMethodCallTranslatorProviderDependencies dependencies) + : base(dependencies) + { + var sqlExpressionFactory = dependencies.SqlExpressionFactory; + + AddTranslators( + new IAggregateMethodCallTranslator[] + { + new SqliteQueryableAggregateMethodTranslator(sqlExpressionFactory) + }); + } +} diff --git a/src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryableAggregateMethodTranslator.cs b/src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryableAggregateMethodTranslator.cs new file mode 100644 index 00000000000..5ee3e9bacdd --- /dev/null +++ b/src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryableAggregateMethodTranslator.cs @@ -0,0 +1,110 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.EntityFrameworkCore.Query.SqlExpressions; +using Microsoft.EntityFrameworkCore.Sqlite.Internal; + +namespace Microsoft.EntityFrameworkCore.Sqlite.Query.Internal; + +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public class SqliteQueryableAggregateMethodTranslator : IAggregateMethodCallTranslator +{ + private readonly ISqlExpressionFactory _sqlExpressionFactory; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public SqliteQueryableAggregateMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) + { + _sqlExpressionFactory = sqlExpressionFactory; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual SqlExpression? Translate( + MethodInfo method, EnumerableExpression source, IReadOnlyList arguments, + IDiagnosticsLogger logger) + { + if (method.DeclaringType == typeof(Queryable)) + { + var methodInfo = method.IsGenericMethod + ? method.GetGenericMethodDefinition() + : method; + switch (methodInfo.Name) + { + case nameof(Queryable.Average) + when (QueryableMethods.IsAverageWithoutSelector(methodInfo) + || QueryableMethods.IsAverageWithSelector(methodInfo)) + && source.Selector is SqlExpression averageSqlExpression: + var averageArgumentType = GetProviderType(averageSqlExpression); + if (averageArgumentType == typeof(decimal)) + { + throw new NotSupportedException( + SqliteStrings.AggregateOperationNotSupported(nameof(Queryable.Average), averageArgumentType.ShortDisplayName())); + } + break; + + case nameof(Queryable.Max) + when (methodInfo == QueryableMethods.MaxWithoutSelector + || methodInfo == QueryableMethods.MaxWithSelector) + && source.Selector is SqlExpression maxSqlExpression: + var maxArgumentType = GetProviderType(maxSqlExpression); + if (maxArgumentType == typeof(DateTimeOffset) + || maxArgumentType == typeof(decimal) + || maxArgumentType == typeof(TimeSpan) + || maxArgumentType == typeof(ulong)) + { + throw new NotSupportedException( + SqliteStrings.AggregateOperationNotSupported(nameof(Queryable.Max), maxArgumentType.ShortDisplayName())); + } + break; + + case nameof(Queryable.Min) + when (methodInfo == QueryableMethods.MinWithoutSelector + || methodInfo == QueryableMethods.MinWithSelector) + && source.Selector is SqlExpression minSqlExpression: + var minArgumentType = GetProviderType(minSqlExpression); + if (minArgumentType == typeof(DateTimeOffset) + || minArgumentType == typeof(decimal) + || minArgumentType == typeof(TimeSpan) + || minArgumentType == typeof(ulong)) + { + throw new NotSupportedException( + SqliteStrings.AggregateOperationNotSupported(nameof(Queryable.Min), minArgumentType.ShortDisplayName())); + } + break; + + case nameof(Queryable.Sum) + when (QueryableMethods.IsSumWithoutSelector(methodInfo) + || QueryableMethods.IsSumWithSelector(methodInfo)) + && source.Selector is SqlExpression sumSqlExpression: + var sumArgumentType = GetProviderType(sumSqlExpression); + if (sumArgumentType == typeof(decimal)) + { + throw new NotSupportedException( + SqliteStrings.AggregateOperationNotSupported(nameof(Queryable.Sum), sumArgumentType.ShortDisplayName())); + } + break; + } + } + + return null; + } + + private static Type? GetProviderType(SqlExpression expression) + => expression.TypeMapping?.Converter?.ProviderClrType + ?? expression.TypeMapping?.ClrType + ?? expression.Type; +} diff --git a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs index dde75d7154d..665b25159d5 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. using Microsoft.EntityFrameworkCore.Query.SqlExpressions; -using Microsoft.EntityFrameworkCore.Sqlite.Internal; namespace Microsoft.EntityFrameworkCore.Sqlite.Query.Internal; @@ -212,88 +211,6 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) return visitedExpression; } - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public override SqlExpression? TranslateAverage(SqlExpression sqlExpression) - { - var visitedExpression = base.TranslateAverage(sqlExpression); - var argumentType = GetProviderType(visitedExpression); - if (argumentType == typeof(decimal)) - { - throw new NotSupportedException( - SqliteStrings.AggregateOperationNotSupported(nameof(Queryable.Average), argumentType.ShortDisplayName())); - } - - return visitedExpression; - } - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public override SqlExpression? TranslateMax(SqlExpression sqlExpression) - { - var visitedExpression = base.TranslateMax(sqlExpression); - var argumentType = GetProviderType(visitedExpression); - if (argumentType == typeof(DateTimeOffset) - || argumentType == typeof(decimal) - || argumentType == typeof(TimeSpan) - || argumentType == typeof(ulong)) - { - throw new NotSupportedException( - SqliteStrings.AggregateOperationNotSupported(nameof(Queryable.Max), argumentType.ShortDisplayName())); - } - - return visitedExpression; - } - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public override SqlExpression? TranslateMin(SqlExpression sqlExpression) - { - var visitedExpression = base.TranslateMin(sqlExpression); - var argumentType = GetProviderType(visitedExpression); - if (argumentType == typeof(DateTimeOffset) - || argumentType == typeof(decimal) - || argumentType == typeof(TimeSpan) - || argumentType == typeof(ulong)) - { - throw new NotSupportedException( - SqliteStrings.AggregateOperationNotSupported(nameof(Queryable.Min), argumentType.ShortDisplayName())); - } - - return visitedExpression; - } - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public override SqlExpression? TranslateSum(SqlExpression sqlExpression) - { - var visitedExpression = base.TranslateSum(sqlExpression); - var argumentType = GetProviderType(visitedExpression); - if (argumentType == typeof(decimal)) - { - throw new NotSupportedException( - SqliteStrings.AggregateOperationNotSupported(nameof(Queryable.Sum), argumentType.ShortDisplayName())); - } - - return visitedExpression; - } - private static Type? GetProviderType(SqlExpression? expression) => expression == null ? null diff --git a/src/EFCore/Query/CompiledQueryCacheKeyGenerator.cs b/src/EFCore/Query/CompiledQueryCacheKeyGenerator.cs index 1f7f1e84ff8..afbcf6ae9f9 100644 --- a/src/EFCore/Query/CompiledQueryCacheKeyGenerator.cs +++ b/src/EFCore/Query/CompiledQueryCacheKeyGenerator.cs @@ -3,28 +3,7 @@ namespace Microsoft.EntityFrameworkCore.Query; -/// -/// -/// Creates keys that uniquely identifies a query. This is used to store and lookup -/// compiled versions of a query in a cache. -/// -/// -/// This type is typically used by database providers (and other extensions). It is generally -/// not used in application code. -/// -/// -/// -/// -/// The service lifetime is . This means that each -/// instance will use its own instance of this service. -/// The implementation may depend on other services registered with any lifetime. -/// The implementation does not need to be thread-safe. -/// -/// -/// See Implementation of database providers and extensions -/// and How EF Core queries work for more information and examples. -/// -/// +/// public class CompiledQueryCacheKeyGenerator : ICompiledQueryCacheKeyGenerator { /// @@ -41,12 +20,7 @@ public CompiledQueryCacheKeyGenerator(CompiledQueryCacheKeyGeneratorDependencies /// protected virtual CompiledQueryCacheKeyGeneratorDependencies Dependencies { get; } - /// - /// Generates the cache key for the given query. - /// - /// The query to get the cache key for. - /// A value indicating whether the query will be executed asynchronously. - /// The cache key. + /// public virtual object GenerateCacheKey(Expression query, bool async) => GenerateCacheKeyCore(query, async); @@ -99,40 +73,18 @@ public CompiledQueryCacheKey( _async = async; } - /// - /// Determines if this key is equivalent to a given object (i.e. if they are keys for the same query). - /// - /// - /// The object to compare this key to. - /// - /// - /// if the object is a and is for the same query, otherwise - /// . - /// + /// public override bool Equals(object? obj) => obj is CompiledQueryCacheKey other && Equals(other); - /// - /// Indicates whether the current object is equal to another object of the same type. - /// - /// - /// An object to compare with this object. - /// - /// - /// if the current object is equal to the parameter; otherwise, . - /// + /// public bool Equals(CompiledQueryCacheKey other) => ReferenceEquals(_model, other._model) && _queryTrackingBehavior == other._queryTrackingBehavior && _async == other._async && ExpressionEqualityComparer.Instance.Equals(_query, other._query); - /// - /// Gets the hash code for the key. - /// - /// - /// The hash code for the key. - /// + /// public override int GetHashCode() { var hash = new HashCode(); diff --git a/src/EFCore/Query/EvaluatableExpressionFilter.cs b/src/EFCore/Query/EvaluatableExpressionFilter.cs index 0f2a3a5f32e..d3a2d0d9bf8 100644 --- a/src/EFCore/Query/EvaluatableExpressionFilter.cs +++ b/src/EFCore/Query/EvaluatableExpressionFilter.cs @@ -3,20 +3,7 @@ namespace Microsoft.EntityFrameworkCore.Query; -/// -/// Represents a filter for evaluatable expressions. -/// -/// -/// -/// The service lifetime is . This means a single instance -/// is used by many instances. The implementation must be thread-safe. -/// This service cannot depend on services registered as . -/// -/// -/// See Implementation of database providers and extensions -/// and How EF Core queries work for more information and examples. -/// -/// +/// public class EvaluatableExpressionFilter : IEvaluatableExpressionFilter { // This methods are non-deterministic and result varies based on time of running the query. @@ -69,12 +56,7 @@ public EvaluatableExpressionFilter( /// protected virtual EvaluatableExpressionFilterDependencies Dependencies { get; } - /// - /// Checks whether the given expression can be evaluated. - /// - /// The expression. - /// The model. - /// if the expression can be evaluated; otherwise. + /// public virtual bool IsEvaluatableExpression(Expression expression, IModel model) { switch (expression) diff --git a/src/EFCore/Query/ExpressionEqualityComparer.cs b/src/EFCore/Query/ExpressionEqualityComparer.cs index d8bdc13eccf..bd8934fee4d 100644 --- a/src/EFCore/Query/ExpressionEqualityComparer.cs +++ b/src/EFCore/Query/ExpressionEqualityComparer.cs @@ -28,11 +28,7 @@ private ExpressionEqualityComparer() /// public static ExpressionEqualityComparer Instance { get; } = new(); - /// - /// Returns the hash code for given expression. - /// - /// The obj to compute hash code for. - /// The hash code value for . + /// public int GetHashCode(Expression obj) { if (obj == null) @@ -276,12 +272,7 @@ void AddMemberBindingsToHash(IReadOnlyList memberBindings) } } - /// - /// Returns a value indicating whether the given expressions are equal. - /// - /// The left expression. - /// The right expression. - /// if the expressions are equal, otherwise. + /// public bool Equals(Expression? x, Expression? y) => new ExpressionComparer().Compare(x, y); diff --git a/src/EFCore/Query/ICompiledQueryCacheKeyGenerator.cs b/src/EFCore/Query/ICompiledQueryCacheKeyGenerator.cs index 81f8f30b8b4..43bbea6d271 100644 --- a/src/EFCore/Query/ICompiledQueryCacheKeyGenerator.cs +++ b/src/EFCore/Query/ICompiledQueryCacheKeyGenerator.cs @@ -4,7 +4,14 @@ namespace Microsoft.EntityFrameworkCore.Query; /// -/// A cache key generator for the compiled query cache. +/// +/// Creates keys that uniquely identifies a query. This is used to store and lookup +/// compiled versions of a query in a cache. +/// +/// +/// This type is typically used by database providers (and other extensions). It is generally +/// not used in application code. +/// /// /// /// diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindGroupByQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindGroupByQuerySqlServerTest.cs index 4244e4018bf..fef1f7dc887 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindGroupByQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindGroupByQuerySqlServerTest.cs @@ -626,9 +626,9 @@ public override async Task GroupBy_constant_with_where_on_grouping_with_aggregat WHEN 1 = [t].[Key] THEN [t].[OrderDate] END) AS [Max], COALESCE(SUM(CASE WHEN 1 = [t].[Key] THEN [t].[OrderID] -END), 0) AS [Sum], AVG(CAST(CASE - WHEN 1 = [t].[Key] THEN [t].[OrderID] -END AS float)) AS [Average] +END), 0) AS [Sum], AVG(CASE + WHEN 1 = [t].[Key] THEN CAST([t].[OrderID] AS float) +END) AS [Average] FROM ( SELECT [o].[OrderID], [o].[OrderDate], 1 AS [Key] FROM [Orders] AS [o] @@ -1782,9 +1782,9 @@ public override async Task GroupBy_Where_Average(bool async) await base.GroupBy_Where_Average(async); AssertSql( - @"SELECT AVG(CAST(CASE - WHEN [o].[OrderID] < 10300 THEN [o].[OrderID] -END AS float)) + @"SELECT AVG(CASE + WHEN [o].[OrderID] < 10300 THEN CAST([o].[OrderID] AS float) +END) FROM [Orders] AS [o] GROUP BY [o].[CustomerID]"); }