diff --git a/src/EFCore.Relational/Infrastructure/EntityFrameworkRelationalServicesBuilder.cs b/src/EFCore.Relational/Infrastructure/EntityFrameworkRelationalServicesBuilder.cs index d4d43fbfc23..a30b3485fb7 100644 --- a/src/EFCore.Relational/Infrastructure/EntityFrameworkRelationalServicesBuilder.cs +++ b/src/EFCore.Relational/Infrastructure/EntityFrameworkRelationalServicesBuilder.cs @@ -66,6 +66,7 @@ public static readonly IDictionary RelationalServi { typeof(IModificationCommandBatchFactory), new ServiceCharacteristics(ServiceLifetime.Scoped) }, { typeof(IRelationalSqlTranslatingExpressionVisitorFactory), new ServiceCharacteristics(ServiceLifetime.Scoped) }, { typeof(IMethodCallTranslatorProvider), new ServiceCharacteristics(ServiceLifetime.Scoped) }, + { typeof(IAggregateMethodCallTranslatorProvider), new ServiceCharacteristics(ServiceLifetime.Scoped) }, { typeof(IMemberTranslatorProvider), new ServiceCharacteristics(ServiceLifetime.Scoped) }, { typeof(ISqlExpressionFactory), new ServiceCharacteristics(ServiceLifetime.Scoped) }, { typeof(IRelationalQueryStringFactory), new ServiceCharacteristics(ServiceLifetime.Scoped) }, @@ -171,6 +172,7 @@ public override EntityFrameworkServicesBuilder TryAddCoreServices() TryAdd(); TryAdd(); TryAdd(); + TryAdd(); TryAdd(); TryAdd(); TryAdd(); @@ -202,6 +204,7 @@ public override EntityFrameworkServicesBuilder TryAddCoreServices() .AddDependencyScoped() .AddDependencyScoped() .AddDependencyScoped() + .AddDependencyScoped() .AddDependencyScoped() .AddDependencyScoped() .AddDependencyScoped() diff --git a/src/EFCore.Relational/Query/IAggregateMethodCallTranslator.cs b/src/EFCore.Relational/Query/IAggregateMethodCallTranslator.cs new file mode 100644 index 00000000000..58d7f8b8334 --- /dev/null +++ b/src/EFCore.Relational/Query/IAggregateMethodCallTranslator.cs @@ -0,0 +1,32 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.EntityFrameworkCore.Query.SqlExpressions; + +namespace Microsoft.EntityFrameworkCore.Query; + +/// +/// +/// A SQL translator for LINQ expression represending an aggregate function. +/// +/// +/// This interface is typically used by database providers (and other extensions). It is generally +/// not used in application code. +/// +/// +public interface IAggregateMethodCallTranslator +{ + /// + /// Translates a LINQ to a SQL equivalent. + /// + /// The method info from . + /// The source on which the aggregate method is applied. + /// SQL representations of scalar . + /// The query logger to use. + /// A SQL translation of the . + SqlExpression? Translate( + MethodInfo method, + EnumerableExpression source, + IReadOnlyList arguments, + IDiagnosticsLogger logger); +} diff --git a/src/EFCore.Relational/Query/IAggregateMethodCallTranslatorPlugin.cs b/src/EFCore.Relational/Query/IAggregateMethodCallTranslatorPlugin.cs new file mode 100644 index 00000000000..7fa452b7491 --- /dev/null +++ b/src/EFCore.Relational/Query/IAggregateMethodCallTranslatorPlugin.cs @@ -0,0 +1,22 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.EntityFrameworkCore.Query; + +/// +/// Represents plugin for . +/// +/// +/// The service lifetime is and multiple registrations +/// are allowed. This means that each instance will use its own +/// set of instances of this service. +/// The implementations may depend on other services registered with any lifetime. +/// The implementations do not need to be thread-safe. +/// +public interface IAggregateMethodCallTranslatorPlugin +{ + /// + /// Gets the method call translators. + /// + IEnumerable Translators { get; } +} diff --git a/src/EFCore.Relational/Query/IAggregateMethodCallTranslatorProvider.cs b/src/EFCore.Relational/Query/IAggregateMethodCallTranslatorProvider.cs new file mode 100644 index 00000000000..217558ed506 --- /dev/null +++ b/src/EFCore.Relational/Query/IAggregateMethodCallTranslatorProvider.cs @@ -0,0 +1,35 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.EntityFrameworkCore.Query.SqlExpressions; + +namespace Microsoft.EntityFrameworkCore.Query; + +/// +/// Provides translations for LINQ expressions which represents aggregate methods. +/// +/// +/// The service lifetime is and multiple registrations +/// are allowed. This means that each instance will use its own +/// set of instances of this service. +/// The implementations may depend on other services registered with any lifetime. +/// The implementations do not need to be thread-safe. +/// +public interface IAggregateMethodCallTranslatorProvider +{ + /// + /// Translates a LINQ aggregate to a SQL equivalent. + /// + /// A model to use for translation. + /// The method info from . + /// The source on which the aggregate method is applied. + /// SQL representations of scalar . + /// The query logger to use. + /// A SQL translation of the . + SqlExpression? Translate( + IModel model, + MethodInfo method, + EnumerableExpression source, + IReadOnlyList arguments, + IDiagnosticsLogger logger); +} diff --git a/src/EFCore.Relational/Query/IMethodCallTranslator.cs b/src/EFCore.Relational/Query/IMethodCallTranslator.cs index 13432daa93f..a96a9e2ff69 100644 --- a/src/EFCore.Relational/Query/IMethodCallTranslator.cs +++ b/src/EFCore.Relational/Query/IMethodCallTranslator.cs @@ -1,7 +1,6 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Diagnostics.CodeAnalysis; using Microsoft.EntityFrameworkCore.Query.SqlExpressions; namespace Microsoft.EntityFrameworkCore.Query; diff --git a/src/EFCore.Relational/Query/IMethodCallTranslatorProvider.cs b/src/EFCore.Relational/Query/IMethodCallTranslatorProvider.cs index e0fda2ee45e..9a729cc2749 100644 --- a/src/EFCore.Relational/Query/IMethodCallTranslatorProvider.cs +++ b/src/EFCore.Relational/Query/IMethodCallTranslatorProvider.cs @@ -6,12 +6,14 @@ namespace Microsoft.EntityFrameworkCore.Query; /// -/// Provides translations for LINQ expressions. +/// Provides translations for LINQ expressions which represents scalar methods. /// /// -/// The service lifetime is . This means a single instance -/// is used by many instances. The implementation must be thread-safe. -/// This service cannot depend on services registered as . +/// The service lifetime is and multiple registrations +/// are allowed. This means that each instance will use its own +/// set of instances of this service. +/// The implementations may depend on other services registered with any lifetime. +/// The implementations do not need to be thread-safe. /// public interface IMethodCallTranslatorProvider { diff --git a/src/EFCore.Relational/Query/Internal/QueryableAggregateMethodTranslator.cs b/src/EFCore.Relational/Query/Internal/QueryableAggregateMethodTranslator.cs new file mode 100644 index 00000000000..b0592ca7323 --- /dev/null +++ b/src/EFCore.Relational/Query/Internal/QueryableAggregateMethodTranslator.cs @@ -0,0 +1,180 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.EntityFrameworkCore.Query.SqlExpressions; + +namespace Microsoft.EntityFrameworkCore.Query.Internal; + +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public class QueryableAggregateMethodTranslator : IAggregateMethodCallTranslator +{ + private readonly ISqlExpressionFactory _sqlExpressionFactory; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public QueryableAggregateMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) + { + _sqlExpressionFactory = sqlExpressionFactory; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual SqlExpression? Translate( + MethodInfo method, EnumerableExpression source, IReadOnlyList arguments, + IDiagnosticsLogger logger) + { + if (method.DeclaringType == typeof(Queryable)) + { + var methodInfo = method.IsGenericMethod + ? method.GetGenericMethodDefinition() + : method; + switch (methodInfo.Name) + { + case nameof(Queryable.Average) + when (QueryableMethods.IsAverageWithoutSelector(methodInfo) + || QueryableMethods.IsAverageWithSelector(methodInfo)) + && source.Selector is SqlExpression averageSqlExpression: + var averageInputType = averageSqlExpression.Type; + if (averageInputType == typeof(int) + || averageInputType == typeof(long)) + { + averageSqlExpression = _sqlExpressionFactory.ApplyDefaultTypeMapping( + _sqlExpressionFactory.Convert(averageSqlExpression, typeof(double))); + } + + averageSqlExpression = CombineTerms(source, averageSqlExpression); + return averageInputType == typeof(float) + ? _sqlExpressionFactory.Convert( + _sqlExpressionFactory.Function( + "AVG", + new[] { averageSqlExpression }, + nullable: true, + argumentsPropagateNullability: new[] { false }, + typeof(double)), + averageSqlExpression.Type, + averageSqlExpression.TypeMapping) + : _sqlExpressionFactory.Function( + "AVG", + new[] { averageSqlExpression }, + nullable: true, + argumentsPropagateNullability: new[] { false }, + averageSqlExpression.Type, + averageSqlExpression.TypeMapping); + + case nameof(Queryable.Count) + when methodInfo == QueryableMethods.CountWithoutPredicate + || methodInfo == QueryableMethods.CountWithPredicate: + var countSqlExpression = (source.Selector as SqlExpression) ?? _sqlExpressionFactory.Fragment("*"); + countSqlExpression = CombineTerms(source, countSqlExpression); + return _sqlExpressionFactory.ApplyDefaultTypeMapping( + _sqlExpressionFactory.Function( + "COUNT", + new[] { countSqlExpression }, + nullable: false, + argumentsPropagateNullability: new[] { false }, + typeof(int))); + + case nameof(Queryable.LongCount) + when methodInfo == QueryableMethods.LongCountWithoutPredicate + || methodInfo == QueryableMethods.LongCountWithPredicate: + var longCountSqlExpression = (source.Selector as SqlExpression) ?? _sqlExpressionFactory.Fragment("*"); + longCountSqlExpression = CombineTerms(source, longCountSqlExpression); + + return _sqlExpressionFactory.ApplyDefaultTypeMapping( + _sqlExpressionFactory.Function( + "COUNT", + new[] { longCountSqlExpression }, + nullable: false, + argumentsPropagateNullability: new[] { false }, + typeof(long))); + + case nameof(Queryable.Max) + when (methodInfo == QueryableMethods.MaxWithoutSelector + || methodInfo == QueryableMethods.MaxWithSelector) + && source.Selector is SqlExpression maxSqlExpression: + maxSqlExpression = CombineTerms(source, maxSqlExpression); + return _sqlExpressionFactory.Function( + "MAX", + new[] { maxSqlExpression }, + nullable: true, + argumentsPropagateNullability: new[] { false }, + maxSqlExpression.Type, + maxSqlExpression.TypeMapping); + + case nameof(Queryable.Min) + when (methodInfo == QueryableMethods.MinWithoutSelector + || methodInfo == QueryableMethods.MinWithSelector) + && source.Selector is SqlExpression minSqlExpression: + minSqlExpression = CombineTerms(source, minSqlExpression); + return _sqlExpressionFactory.Function( + "MIN", + new[] { minSqlExpression }, + nullable: true, + argumentsPropagateNullability: new[] { false }, + minSqlExpression.Type, + minSqlExpression.TypeMapping); + + case nameof(Queryable.Sum) + when (QueryableMethods.IsSumWithoutSelector(methodInfo) + || QueryableMethods.IsSumWithSelector(methodInfo)) + && source.Selector is SqlExpression sumSqlExpression: + sumSqlExpression = CombineTerms(source, sumSqlExpression); + var sumInputType = sumSqlExpression.Type; + return sumInputType == typeof(float) + ? _sqlExpressionFactory.Convert( + _sqlExpressionFactory.Function( + "SUM", + new[] { sumSqlExpression }, + nullable: true, + argumentsPropagateNullability: new[] { false }, + typeof(double)), + sumInputType, + sumSqlExpression.TypeMapping) + : _sqlExpressionFactory.Function( + "SUM", + new[] { sumSqlExpression }, + nullable: true, + argumentsPropagateNullability: new[] { false }, + sumInputType, + sumSqlExpression.TypeMapping); + } + } + + return null; + } + + private SqlExpression CombineTerms(EnumerableExpression enumerableExpression, SqlExpression sqlExpression) + { + if (enumerableExpression.Predicate != null) + { + if (sqlExpression is SqlFragmentExpression) + { + sqlExpression = _sqlExpressionFactory.Constant(1); + } + + sqlExpression = _sqlExpressionFactory.Case( + new List { new(enumerableExpression.Predicate, sqlExpression) }, + elseResult: null); + } + + if (enumerableExpression.IsDistinct) + { + sqlExpression = new DistinctExpression(sqlExpression); + } + + return sqlExpression; + } +} diff --git a/src/EFCore.Relational/Query/RelationalAggregateMethodCallTranslatorProvider.cs b/src/EFCore.Relational/Query/RelationalAggregateMethodCallTranslatorProvider.cs new file mode 100644 index 00000000000..b2cb7b5b450 --- /dev/null +++ b/src/EFCore.Relational/Query/RelationalAggregateMethodCallTranslatorProvider.cs @@ -0,0 +1,89 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.EntityFrameworkCore.Query.Internal; +using Microsoft.EntityFrameworkCore.Query.SqlExpressions; + +namespace Microsoft.EntityFrameworkCore.Query; + +/// +public class RelationalAggregateMethodCallTranslatorProvider : IAggregateMethodCallTranslatorProvider +{ + private readonly List _plugins = new(); + private readonly List _translators = new(); + private readonly ISqlExpressionFactory _sqlExpressionFactory; + + /// + /// Creates a new instance of the class. + /// + /// Parameter object containing dependencies for this class. + public RelationalAggregateMethodCallTranslatorProvider(RelationalAggregateMethodCallTranslatorProviderDependencies dependencies) + { + Dependencies = dependencies; + + _plugins.AddRange(dependencies.Plugins.SelectMany(p => p.Translators)); + + _sqlExpressionFactory = dependencies.SqlExpressionFactory; + + _translators.AddRange( + new IAggregateMethodCallTranslator[] + { + new QueryableAggregateMethodTranslator(_sqlExpressionFactory) + }); ; + } + + /// + /// Dependencies for this service. + /// + protected virtual RelationalAggregateMethodCallTranslatorProviderDependencies Dependencies { get; } + + /// + public virtual SqlExpression? Translate( + IModel model, + MethodInfo method, + EnumerableExpression source, + IReadOnlyList arguments, + IDiagnosticsLogger logger) + { + // TODO: Add support for user defined aggregate functions + //var dbFunction = model.FindDbFunction(method); + //if (dbFunction != null) + //{ + // if (dbFunction.Translation != null) + // { + // return dbFunction.Translation.Invoke( + // arguments.Select(e => _sqlExpressionFactory.ApplyDefaultTypeMapping(e)).ToList()); + // } + + // var argumentsPropagateNullability = dbFunction.Parameters.Select(p => p.PropagatesNullability); + + // return dbFunction.IsBuiltIn + // ? _sqlExpressionFactory.Function( + // dbFunction.Name, + // arguments, + // dbFunction.IsNullable, + // argumentsPropagateNullability, + // method.ReturnType.UnwrapNullableType(), + // dbFunction.TypeMapping) + // : _sqlExpressionFactory.Function( + // dbFunction.Schema, + // dbFunction.Name, + // arguments, + // dbFunction.IsNullable, + // argumentsPropagateNullability, + // method.ReturnType.UnwrapNullableType(), + // dbFunction.TypeMapping); + //} + + return _plugins.Concat(_translators) + .Select(t => t.Translate(method, source, arguments, logger)) + .FirstOrDefault(t => t != null); + } + + /// + /// Adds additional translators which will take priority over existing registered translators. + /// + /// Translators to add. + protected virtual void AddTranslators(IEnumerable translators) + => _translators.InsertRange(0, translators); +} diff --git a/src/EFCore.Relational/Query/RelationalAggregateMethodCallTranslatorProviderDependencies.cs b/src/EFCore.Relational/Query/RelationalAggregateMethodCallTranslatorProviderDependencies.cs new file mode 100644 index 00000000000..14dee9fde57 --- /dev/null +++ b/src/EFCore.Relational/Query/RelationalAggregateMethodCallTranslatorProviderDependencies.cs @@ -0,0 +1,72 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.EntityFrameworkCore.Query; + +/// +/// +/// Service dependencies parameter class for +/// +/// +/// This type is typically used by database providers (and other extensions). It is generally +/// not used in application code. +/// +/// +/// +/// +/// Do not construct instances of this class directly from either provider or application code as the +/// constructor signature may change as new dependencies are added. Instead, use this type in +/// your constructor so that an instance will be created and injected automatically by the +/// dependency injection container. To create an instance with some dependent services replaced, +/// first resolve the object from the dependency injection container, then replace selected +/// services using the 'With...' methods. Do not call the constructor at any point in this process. +/// +/// +/// The service lifetime is . This means that each +/// instance will use its own instance of this service. +/// The implementation may depend on other services registered with any lifetime. +/// The implementation does not need to be thread-safe. +/// +/// +public sealed record RelationalAggregateMethodCallTranslatorProviderDependencies +{ + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + /// + /// Do not call this constructor directly from either provider or application code as it may change + /// as new dependencies are added. Instead, use this type in your constructor so that an instance + /// will be created and injected automatically by the dependency injection container. To create + /// an instance with some dependent services replaced, first resolve the object from the dependency + /// injection container, then replace selected services using the 'With...' methods. Do not call + /// the constructor at any point in this process. + /// + [EntityFrameworkInternal] + public RelationalAggregateMethodCallTranslatorProviderDependencies( + ISqlExpressionFactory sqlExpressionFactory, + IEnumerable plugins, + IRelationalTypeMappingSource typeMappingSource) + { + SqlExpressionFactory = sqlExpressionFactory; + Plugins = plugins; + RelationalTypeMappingSource = typeMappingSource; + } + + /// + /// The expression factory.. + /// + public ISqlExpressionFactory SqlExpressionFactory { get; init; } + + /// + /// Registered plugins. + /// + public IEnumerable Plugins { get; init; } + + /// + /// Relational Type Mapping Source. + /// + public IRelationalTypeMappingSource RelationalTypeMappingSource { get; init; } +} diff --git a/src/EFCore.Relational/Query/RelationalCompiledQueryCacheKeyGenerator.cs b/src/EFCore.Relational/Query/RelationalCompiledQueryCacheKeyGenerator.cs index 5f8893e4f75..b193a826cd4 100644 --- a/src/EFCore.Relational/Query/RelationalCompiledQueryCacheKeyGenerator.cs +++ b/src/EFCore.Relational/Query/RelationalCompiledQueryCacheKeyGenerator.cs @@ -3,22 +3,7 @@ namespace Microsoft.EntityFrameworkCore.Query; -/// -/// -/// Creates keys that uniquely identifies a query. This is used to store and lookup -/// compiled versions of a query in a cache. -/// -/// -/// This type is typically used by database providers (and other extensions). It is generally -/// not used in application code. -/// -/// -/// -/// The service lifetime is . This means that each -/// instance will use its own instance of this service. -/// The implementation may depend on other services registered with any lifetime. -/// The implementation does not need to be thread-safe. -/// +/// public class RelationalCompiledQueryCacheKeyGenerator : CompiledQueryCacheKeyGenerator { /// @@ -39,12 +24,7 @@ public RelationalCompiledQueryCacheKeyGenerator( /// protected virtual RelationalCompiledQueryCacheKeyGeneratorDependencies RelationalDependencies { get; } - /// - /// Generates the cache key for the given query. - /// - /// The query to get the cache key for. - /// A value indicating whether the query will be executed asynchronously. - /// The cache key. + /// public override object GenerateCacheKey(Expression query, bool async) => GenerateCacheKeyCore(query, async); @@ -102,41 +82,19 @@ public RelationalCompiledQueryCacheKey( _shouldBuffer = shouldBuffer; } - /// - /// Determines if this key is equivalent to a given object (i.e. if they are keys for the same query). - /// - /// - /// The object to compare this key to. - /// - /// - /// if the object is a and is for the same query, - /// otherwise . - /// + /// public override bool Equals(object? obj) => obj is RelationalCompiledQueryCacheKey key && Equals(key); - /// - /// Indicates whether the current object is equal to another object of the same type. - /// - /// - /// An object to compare with this object. - /// - /// - /// if the current object is equal to the parameter; otherwise, . - /// + /// public bool Equals(RelationalCompiledQueryCacheKey other) => _compiledQueryCacheKey.Equals(other._compiledQueryCacheKey) && _useRelationalNulls == other._useRelationalNulls && _querySplittingBehavior == other._querySplittingBehavior && _shouldBuffer == other._shouldBuffer; - /// - /// Gets the hash code for the key. - /// - /// - /// The hash code for the key. - /// + /// public override int GetHashCode() => HashCode.Combine( _compiledQueryCacheKey, _useRelationalNulls, _querySplittingBehavior, _shouldBuffer); diff --git a/src/EFCore.Relational/Query/RelationalEvaluatableExpressionFilter.cs b/src/EFCore.Relational/Query/RelationalEvaluatableExpressionFilter.cs index c1a2e6e7286..29df9899207 100644 --- a/src/EFCore.Relational/Query/RelationalEvaluatableExpressionFilter.cs +++ b/src/EFCore.Relational/Query/RelationalEvaluatableExpressionFilter.cs @@ -3,14 +3,7 @@ namespace Microsoft.EntityFrameworkCore.Query; -/// -/// Represents a filter for evaluatable expressions. -/// -/// -/// The service lifetime is . This means a single instance -/// is used by many instances. The implementation must be thread-safe. -/// This service cannot depend on services registered as . -/// +/// public class RelationalEvaluatableExpressionFilter : EvaluatableExpressionFilter { /// @@ -37,12 +30,7 @@ public RelationalEvaluatableExpressionFilter( /// protected virtual RelationalEvaluatableExpressionFilterDependencies RelationalDependencies { get; } - /// - /// Checks whether the given expression can be evaluated. - /// - /// The expression. - /// The model. - /// if the expression can be evaluated; otherwise. + /// public override bool IsEvaluatableExpression(Expression expression, IModel model) { if (expression is MethodCallExpression methodCallExpression) diff --git a/src/EFCore.Relational/Query/RelationalMemberTranslatorProvider.cs b/src/EFCore.Relational/Query/RelationalMemberTranslatorProvider.cs index 358c75583c6..7ad3d3f30e6 100644 --- a/src/EFCore.Relational/Query/RelationalMemberTranslatorProvider.cs +++ b/src/EFCore.Relational/Query/RelationalMemberTranslatorProvider.cs @@ -6,16 +6,7 @@ namespace Microsoft.EntityFrameworkCore.Query; -/// -/// Provides translations for LINQ expressions by dispatching to multiple specialized member -/// translators. -/// -/// -/// The service lifetime is . This means that each -/// instance will use its own instance of this service. -/// The implementation may depend on other services registered with any lifetime. -/// The implementation does not need to be thread-safe. -/// +/// public class RelationalMemberTranslatorProvider : IMemberTranslatorProvider { private readonly List _plugins = new(); diff --git a/src/EFCore.Relational/Query/RelationalMethodCallTranslatorProvider.cs b/src/EFCore.Relational/Query/RelationalMethodCallTranslatorProvider.cs index ca8ba28bd61..566f3d12bf8 100644 --- a/src/EFCore.Relational/Query/RelationalMethodCallTranslatorProvider.cs +++ b/src/EFCore.Relational/Query/RelationalMethodCallTranslatorProvider.cs @@ -6,16 +6,7 @@ namespace Microsoft.EntityFrameworkCore.Query; -/// -/// Provides translations for LINQ expressions by dispatching to multiple specialized -/// method call translators. -/// -/// -/// The service lifetime is . This means that each -/// instance will use its own instance of this service. -/// The implementation may depend on other services registered with any lifetime. -/// The implementation does not need to be thread-safe. -/// +/// public class RelationalMethodCallTranslatorProvider : IMethodCallTranslatorProvider { private readonly List _plugins = new(); diff --git a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs index 31a745a013b..c1f94c9a06f 100644 --- a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs @@ -237,7 +237,7 @@ private static ShapedQueryExpression CreateShapedQueryExpression(IEntityType ent LambdaExpression? selector, Type resultType) => TranslateAggregateWithSelector( - source, selector, e => _sqlTranslator.TranslateAverage(e), throwWhenEmpty: true, resultType); + source, selector, e => TranslateAverage(e), throwWhenEmpty: true, resultType); /// protected override ShapedQueryExpression? TranslateCast(ShapedQueryExpression source, Type resultType) @@ -301,7 +301,7 @@ private static ShapedQueryExpression CreateShapedQueryExpression(IEntityType ent /// protected override ShapedQueryExpression? TranslateCount(ShapedQueryExpression source, LambdaExpression? predicate) - => TranslateAggregateWithPredicate(source, predicate, e => _sqlTranslator.TranslateCount(e), typeof(int)); + => TranslateAggregateWithPredicate(source, predicate, e => TranslateCount(e), typeof(int)); /// protected override ShapedQueryExpression? TranslateDefaultIfEmpty(ShapedQueryExpression source, Expression? defaultValue) @@ -623,15 +623,15 @@ private SqlExpression CreateJoinPredicate(Expression outerKey, Expression innerK /// protected override ShapedQueryExpression? TranslateLongCount(ShapedQueryExpression source, LambdaExpression? predicate) - => TranslateAggregateWithPredicate(source, predicate, e => _sqlTranslator.TranslateLongCount(e), typeof(long)); + => TranslateAggregateWithPredicate(source, predicate, e => TranslateLongCount(e), typeof(long)); /// protected override ShapedQueryExpression? TranslateMax(ShapedQueryExpression source, LambdaExpression? selector, Type resultType) - => TranslateAggregateWithSelector(source, selector, e => _sqlTranslator.TranslateMax(e), throwWhenEmpty: true, resultType); + => TranslateAggregateWithSelector(source, selector, e => TranslateMax(e), throwWhenEmpty: true, resultType); /// protected override ShapedQueryExpression? TranslateMin(ShapedQueryExpression source, LambdaExpression? selector, Type resultType) - => TranslateAggregateWithSelector(source, selector, e => _sqlTranslator.TranslateMin(e), throwWhenEmpty: true, resultType); + => TranslateAggregateWithSelector(source, selector, e => TranslateMin(e), throwWhenEmpty: true, resultType); /// protected override ShapedQueryExpression? TranslateOfType(ShapedQueryExpression source, Type resultType) @@ -888,7 +888,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp /// protected override ShapedQueryExpression? TranslateSum(ShapedQueryExpression source, LambdaExpression? selector, Type resultType) - => TranslateAggregateWithSelector(source, selector, e => _sqlTranslator.TranslateSum(e), throwWhenEmpty: false, resultType); + => TranslateAggregateWithSelector(source, selector, e => TranslateSum(e), throwWhenEmpty: false, resultType); /// protected override ShapedQueryExpression? TranslateTake(ShapedQueryExpression source, Expression count) @@ -1465,6 +1465,36 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape } } + private SqlExpression? TranslateAverage(SqlExpression sqlExpression) + => RelationalDependencies.AggregateMethodCallTranslatorProvider.Translate( + RelationalDependencies.Model, QueryableMethods.GetAverageWithoutSelector(sqlExpression.Type), + new EnumerableExpression(sqlExpression), Array.Empty(), _queryCompilationContext.Logger); + + private SqlExpression? TranslateCount(SqlExpression sqlExpression) + => RelationalDependencies.AggregateMethodCallTranslatorProvider.Translate( + RelationalDependencies.Model, QueryableMethods.CountWithoutPredicate, + new EnumerableExpression(sqlExpression), Array.Empty(), _queryCompilationContext.Logger); + + private SqlExpression? TranslateLongCount(SqlExpression sqlExpression) + => RelationalDependencies.AggregateMethodCallTranslatorProvider.Translate( + RelationalDependencies.Model, QueryableMethods.LongCountWithoutPredicate, + new EnumerableExpression(sqlExpression), Array.Empty(), _queryCompilationContext.Logger); + + private SqlExpression? TranslateMax(SqlExpression sqlExpression) + => RelationalDependencies.AggregateMethodCallTranslatorProvider.Translate( + RelationalDependencies.Model, QueryableMethods.MaxWithoutSelector, + new EnumerableExpression(sqlExpression), Array.Empty(), _queryCompilationContext.Logger); + + private SqlExpression? TranslateMin(SqlExpression sqlExpression) + => RelationalDependencies.AggregateMethodCallTranslatorProvider.Translate( + RelationalDependencies.Model, QueryableMethods.MinWithoutSelector, + new EnumerableExpression(sqlExpression), Array.Empty(), _queryCompilationContext.Logger); + + private SqlExpression? TranslateSum(SqlExpression sqlExpression) + => RelationalDependencies.AggregateMethodCallTranslatorProvider.Translate( + RelationalDependencies.Model, QueryableMethods.GetSumWithoutSelector(sqlExpression.Type), + new EnumerableExpression(sqlExpression), Array.Empty(), _queryCompilationContext.Logger); + private ShapedQueryExpression? TranslateAggregateWithPredicate( ShapedQueryExpression source, LambdaExpression? predicate, diff --git a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitorDependencies.cs b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitorDependencies.cs index 9fb23b65d9b..ad3fd864baf 100644 --- a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitorDependencies.cs +++ b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitorDependencies.cs @@ -47,10 +47,14 @@ public sealed record RelationalQueryableMethodTranslatingExpressionVisitorDepend [EntityFrameworkInternal] public RelationalQueryableMethodTranslatingExpressionVisitorDependencies( IRelationalSqlTranslatingExpressionVisitorFactory relationalSqlTranslatingExpressionVisitorFactory, - ISqlExpressionFactory sqlExpressionFactory) + ISqlExpressionFactory sqlExpressionFactory, + IModel model, + IAggregateMethodCallTranslatorProvider aggregateMethodCallTranslatorProvider) { RelationalSqlTranslatingExpressionVisitorFactory = relationalSqlTranslatingExpressionVisitorFactory; SqlExpressionFactory = sqlExpressionFactory; + Model = model; + AggregateMethodCallTranslatorProvider = aggregateMethodCallTranslatorProvider; } /// @@ -62,4 +66,14 @@ public RelationalQueryableMethodTranslatingExpressionVisitorDependencies( /// The SQL expression factory. /// public ISqlExpressionFactory SqlExpressionFactory { get; init; } + + /// + /// The model. + /// + public IModel Model { get; init; } + + /// + /// The aggregate method-call translation provider. + /// + public IAggregateMethodCallTranslatorProvider AggregateMethodCallTranslatorProvider { get; } } diff --git a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs index dd6922d2749..f8ccab890ad 100644 --- a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs @@ -3,6 +3,7 @@ using System.Collections; using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; using Microsoft.EntityFrameworkCore.Metadata.Internal; using Microsoft.EntityFrameworkCore.Query.SqlExpressions; @@ -161,6 +162,7 @@ protected virtual void AddTranslationErrorDetails(string details) /// /// An expression to translate Average over. /// A SQL translation of Average over the given expression. + [Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")] public virtual SqlExpression? TranslateAverage(SqlExpression sqlExpression) { var inputType = sqlExpression.Type; @@ -199,6 +201,7 @@ protected virtual void AddTranslationErrorDetails(string details) /// /// An expression to translate Count over. /// A SQL translation of Count over the given expression. + [Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")] public virtual SqlExpression? TranslateCount(SqlExpression sqlExpression) => _sqlExpressionFactory.ApplyDefaultTypeMapping( _sqlExpressionFactory.Function( @@ -213,6 +216,7 @@ protected virtual void AddTranslationErrorDetails(string details) /// /// An expression to translate LongCount over. /// A SQL translation of LongCount over the given expression. + [Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")] public virtual SqlExpression? TranslateLongCount(SqlExpression sqlExpression) => _sqlExpressionFactory.ApplyDefaultTypeMapping( _sqlExpressionFactory.Function( @@ -227,6 +231,7 @@ protected virtual void AddTranslationErrorDetails(string details) /// /// An expression to translate Max over. /// A SQL translation of Max over the given expression. + [Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")] public virtual SqlExpression? TranslateMax(SqlExpression sqlExpression) => sqlExpression != null ? _sqlExpressionFactory.Function( @@ -243,6 +248,7 @@ protected virtual void AddTranslationErrorDetails(string details) /// /// An expression to translate Min over. /// A SQL translation of Min over the given expression. + [Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")] public virtual SqlExpression? TranslateMin(SqlExpression sqlExpression) => sqlExpression != null ? _sqlExpressionFactory.Function( @@ -259,6 +265,7 @@ protected virtual void AddTranslationErrorDetails(string details) /// /// An expression to translate Sum over. /// A SQL translation of Sum over the given expression. + [Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")] public virtual SqlExpression? TranslateSum(SqlExpression sqlExpression) { var inputType = sqlExpression.Type; @@ -647,243 +654,30 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp // EF Indexer property if (methodCallExpression.TryGetIndexerArguments(_model, out source, out propertyName)) { - var result = TryBindMember(Visit(source), MemberIdentity.Create(propertyName)); - if (result != null) + if (TryBindMember(Visit(source), MemberIdentity.Create(propertyName)) is SqlExpression result) { return result; } } - // Subquery case - // TODO: Refactor in future to avoid repeated visitation. - // Specifically ordering of visiting aggregate chain, subquery, method arguments. - if (methodCallExpression.Method.IsStatic - && methodCallExpression.Arguments.Count > 0 - && methodCallExpression.Method.DeclaringType == typeof(Queryable)) - { - if (methodCallExpression.Method.IsGenericMethod - && methodCallExpression.Method.GetGenericMethodDefinition() == QueryableMethods.AsQueryable - && methodCallExpression.Arguments[0] is RelationalGroupByShaperExpression groupByShaperExpression) - { - return new EnumerableExpression(groupByShaperExpression.ElementSelector); - } - - var enumerableSource = Visit(methodCallExpression.Arguments[0]); - if (enumerableSource is EnumerableExpression enumerableExpression) - { - Expression? result = null; - switch (methodCallExpression.Method.Name) - { - case nameof(Queryable.Average): - if (methodCallExpression.Arguments.Count == 2) - { - enumerableExpression = ProcessSelector( - enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); - } - - result = TranslateAggregate(methodCallExpression.Method, enumerableExpression); - break; - - case nameof(Queryable.Count): - if (methodCallExpression.Arguments.Count == 2) - { - var newEnumerableExpression = ProcessPredicate( - enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); - if (newEnumerableExpression == null) - { - break; - } - - enumerableExpression = newEnumerableExpression; - } - - result = TranslateAggregate(methodCallExpression.Method, enumerableExpression); - break; - - - case nameof(Queryable.Distinct): - result = enumerableExpression.Selector is EntityShaperExpression entityShaperExpression - && entityShaperExpression.EntityType.FindPrimaryKey() != null - ? enumerableExpression - : !enumerableExpression.IsDistinct - ? enumerableExpression.ApplyDistinct() - : (Expression?)null; - break; - - case nameof(Queryable.LongCount): - if (methodCallExpression.Arguments.Count == 2) - { - var newEnumerableExpression = ProcessPredicate( - enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); - if (newEnumerableExpression == null) - { - break; - } - - enumerableExpression = newEnumerableExpression; - } - - result = TranslateAggregate(methodCallExpression.Method, enumerableExpression); - break; - - case nameof(Queryable.Max): - if (methodCallExpression.Arguments.Count == 2) - { - enumerableExpression = ProcessSelector( - enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); - } - - result = TranslateAggregate(methodCallExpression.Method, enumerableExpression); - break; - - case nameof(Queryable.Min): - if (methodCallExpression.Arguments.Count == 2) - { - enumerableExpression = ProcessSelector( - enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); - } - - result = TranslateAggregate(methodCallExpression.Method, enumerableExpression); - break; - - case nameof(Queryable.OrderBy): - result = ProcessOrderByThenBy( - enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), thenBy: false, ascending: true); - break; - - case nameof(Queryable.OrderByDescending): - result = ProcessOrderByThenBy( - enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), thenBy: false, ascending: false); - break; - - case nameof(Queryable.ThenBy): - result = ProcessOrderByThenBy( - enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), thenBy: true, ascending: true); - break; - - case nameof(Queryable.ThenByDescending): - result = ProcessOrderByThenBy( - enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), thenBy: true, ascending: false); - break; - - case nameof(Queryable.Select): - result = ProcessSelector(enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); - break; - - case nameof(Queryable.Sum): - if (methodCallExpression.Arguments.Count == 2) - { - enumerableExpression = ProcessSelector( - enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); - } - - result = TranslateAggregate(methodCallExpression.Method, enumerableExpression); - break; - - case nameof(Queryable.Where): - result = ProcessPredicate(enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); - break; - } - - if (result != null) - { - return result; - } - } - } - - var subqueryTranslation = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery(methodCallExpression); - if (subqueryTranslation != null) - { - if (subqueryTranslation.ResultCardinality == ResultCardinality.Enumerable) - { - return QueryCompilationContext.NotTranslatedExpression; - } - - var shaperExpression = subqueryTranslation.ShaperExpression; - ProjectionBindingExpression? mappedProjectionBindingExpression = null; - - var innerExpression = shaperExpression; - Type? convertedType = null; - if (shaperExpression is UnaryExpression unaryExpression - && unaryExpression.NodeType == ExpressionType.Convert) - { - convertedType = unaryExpression.Type; - innerExpression = unaryExpression.Operand; - } - - if (innerExpression is EntityShaperExpression ese - && (convertedType == null - || convertedType.IsAssignableFrom(ese.Type))) - { - return new EntityReferenceExpression(subqueryTranslation.UpdateShaperExpression(innerExpression)); - } - - if (innerExpression is ProjectionBindingExpression pbe - && (convertedType == null - || convertedType.MakeNullable() == innerExpression.Type)) - { - mappedProjectionBindingExpression = pbe; - } - - if (mappedProjectionBindingExpression == null - && shaperExpression is BlockExpression blockExpression - && blockExpression.Expressions.Count == 2 - && blockExpression.Expressions[0] is BinaryExpression binaryExpression - && binaryExpression.NodeType == ExpressionType.Assign - && binaryExpression.Right is ProjectionBindingExpression pbe2) - { - mappedProjectionBindingExpression = pbe2; - } - - if (mappedProjectionBindingExpression == null) - { - return QueryCompilationContext.NotTranslatedExpression; - } - - var subquery = (SelectExpression)subqueryTranslation.QueryExpression; - var projection = subquery.GetProjection(mappedProjectionBindingExpression); - if (projection is not SqlExpression sqlExpression) - { - return QueryCompilationContext.NotTranslatedExpression; - } - - if (subquery.Tables.Count == 0) - { - return sqlExpression; - } - - subquery.ReplaceProjection(new List { sqlExpression }); - subquery.ApplyProjection(); - - SqlExpression scalarSubqueryExpression = new ScalarSubqueryExpression(subquery); - - if (subqueryTranslation.ResultCardinality == ResultCardinality.SingleOrDefault - && !shaperExpression.Type.IsNullableType()) - { - scalarSubqueryExpression = _sqlExpressionFactory.Coalesce( - scalarSubqueryExpression, - (SqlExpression)Visit(shaperExpression.Type.GetDefaultValueConstant())); - } - - return scalarSubqueryExpression; - } - - SqlExpression? sqlObject = null; - SqlExpression[] arguments; var method = methodCallExpression.Method; + var arguments = methodCallExpression.Arguments; + EnumerableExpression? enumerableExpression = null; + var abortTranslation = false; + SqlExpression? sqlObject = null; + List scalarArguments; if (method.Name == nameof(object.Equals) && methodCallExpression.Object != null - && methodCallExpression.Arguments.Count == 1) + && arguments.Count == 1) { var left = Visit(methodCallExpression.Object); - var right = Visit(RemoveObjectConvert(methodCallExpression.Arguments[0])); + var right = Visit(RemoveObjectConvert(arguments[0])); if (TryRewriteEntityEquality( ExpressionType.Equal, left == QueryCompilationContext.NotTranslatedExpression ? methodCallExpression.Object : left, - right == QueryCompilationContext.NotTranslatedExpression ? methodCallExpression.Arguments[0] : right, + right == QueryCompilationContext.NotTranslatedExpression ? arguments[0] : right, equalsMethod: true, out var result)) { @@ -894,7 +688,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp && right is SqlExpression rightSql) { sqlObject = leftSql; - arguments = new[] { rightSql }; + scalarArguments = new List { rightSql }; } else { @@ -903,23 +697,23 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp } else if (method.Name == nameof(object.Equals) && methodCallExpression.Object == null - && methodCallExpression.Arguments.Count == 2) + && arguments.Count == 2) { - if (methodCallExpression.Arguments[0].Type == typeof(object[]) - && methodCallExpression.Arguments[0] is NewArrayExpression) + if (arguments[0].Type == typeof(object[]) + && arguments[0] is NewArrayExpression) { return Visit( ConvertObjectArrayEqualityComparison( - methodCallExpression.Arguments[0], methodCallExpression.Arguments[1])); + arguments[0], arguments[1])); } - var left = Visit(RemoveObjectConvert(methodCallExpression.Arguments[0])); - var right = Visit(RemoveObjectConvert(methodCallExpression.Arguments[1])); + var left = Visit(RemoveObjectConvert(arguments[0])); + var right = Visit(RemoveObjectConvert(arguments[1])); if (TryRewriteEntityEquality( ExpressionType.Equal, - left == QueryCompilationContext.NotTranslatedExpression ? methodCallExpression.Arguments[0] : left, - right == QueryCompilationContext.NotTranslatedExpression ? methodCallExpression.Arguments[1] : right, + left == QueryCompilationContext.NotTranslatedExpression ? arguments[0] : left, + right == QueryCompilationContext.NotTranslatedExpression ? arguments[1] : right, equalsMethod: true, out var result)) { @@ -929,7 +723,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp if (left is SqlExpression leftSql && right is SqlExpression rightSql) { - arguments = new[] { leftSql, rightSql }; + scalarArguments = new List { leftSql, rightSql }; } else { @@ -939,12 +733,12 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp else if (method.IsGenericMethod && method.GetGenericMethodDefinition().Equals(EnumerableMethods.Contains)) { - var enumerable = Visit(methodCallExpression.Arguments[0]); - var item = Visit(methodCallExpression.Arguments[1]); + var enumerable = Visit(arguments[0]); + var item = Visit(arguments[1]); if (TryRewriteContainsEntity( enumerable, - item == QueryCompilationContext.NotTranslatedExpression ? methodCallExpression.Arguments[1] : item, out var result)) + item == QueryCompilationContext.NotTranslatedExpression ? arguments[1] : item, out var result)) { return result; } @@ -952,22 +746,22 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp if (enumerable is SqlExpression sqlEnumerable && item is SqlExpression sqlItem) { - arguments = new[] { sqlEnumerable, sqlItem }; + scalarArguments = new List { sqlEnumerable, sqlItem }; } else { return QueryCompilationContext.NotTranslatedExpression; } } - else if (methodCallExpression.Arguments.Count == 1 + else if (arguments.Count == 1 && method.IsContainsMethod()) { var enumerable = Visit(methodCallExpression.Object); - var item = Visit(methodCallExpression.Arguments[0]); + var item = Visit(arguments[0]); if (TryRewriteContainsEntity( enumerable!, - item == QueryCompilationContext.NotTranslatedExpression ? methodCallExpression.Arguments[0] : item, out var result)) + item == QueryCompilationContext.NotTranslatedExpression ? arguments[0] : item, out var result)) { return result; } @@ -976,7 +770,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp && item is SqlExpression sqlItem) { sqlObject = sqlEnumerable; - arguments = new[] { sqlItem }; + scalarArguments = new List { sqlItem }; } else { @@ -985,31 +779,222 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp } else { - if (TranslationFailed(methodCallExpression.Object, Visit(methodCallExpression.Object), out sqlObject)) + var skipVisitChildren = false; + scalarArguments = new List(); + if (method.IsStatic + && arguments.Count > 0 + && method.DeclaringType == typeof(Queryable)) { - return QueryCompilationContext.NotTranslatedExpression; + var genericMethod = method.IsGenericMethod + ? method.GetGenericMethodDefinition() + : method; + if (genericMethod == QueryableMethods.AsQueryable + && arguments[0] is RelationalGroupByShaperExpression groupByShaperExpression) + { + return new EnumerableExpression(groupByShaperExpression.ElementSelector); + } + + var enumerableSource = Visit(arguments[0]); + if (enumerableSource is EnumerableExpression) + { + enumerableExpression = (EnumerableExpression)enumerableSource; + switch (method.Name) + { + case nameof(Queryable.Average) + when QueryableMethods.IsAverageWithoutSelector(genericMethod): + case nameof(Queryable.Max) + when genericMethod == QueryableMethods.MaxWithoutSelector: + case nameof(Queryable.Min) + when genericMethod == QueryableMethods.MinWithoutSelector: + case nameof(Queryable.Sum) + when QueryableMethods.IsSumWithoutSelector(genericMethod): + case nameof(Queryable.Count) + when genericMethod == QueryableMethods.CountWithoutPredicate: + case nameof(Queryable.LongCount) + when genericMethod == QueryableMethods.LongCountWithoutPredicate: + skipVisitChildren = true; + break; + + case nameof(Queryable.Average) + when QueryableMethods.IsAverageWithSelector(genericMethod): + case nameof(Queryable.Max) + when genericMethod == QueryableMethods.MaxWithSelector: + case nameof(Queryable.Min) + when genericMethod == QueryableMethods.MinWithSelector: + case nameof(Queryable.Sum) + when QueryableMethods.IsSumWithSelector(genericMethod): + enumerableExpression = ProcessSelector(enumerableExpression, arguments[1].UnwrapLambdaFromQuote()); + skipVisitChildren = true; + break; + + case nameof(Queryable.Count) + when genericMethod == QueryableMethods.CountWithPredicate: + case nameof(Queryable.LongCount) + when genericMethod == QueryableMethods.LongCountWithPredicate: + enumerableExpression = ProcessPredicate(enumerableExpression, arguments[1].UnwrapLambdaFromQuote()); + if (enumerableExpression == null) + { + abortTranslation = true; + } + else + { + skipVisitChildren = true; + } + break; + + case nameof(Queryable.Distinct) + when genericMethod == QueryableMethods.Distinct: + if (enumerableExpression.Selector is EntityShaperExpression entityShaperExpression + && entityShaperExpression.EntityType.FindPrimaryKey() != null) + { + return enumerableExpression; + } + + if (!enumerableExpression.IsDistinct) + { + return enumerableExpression.ApplyDistinct(); + } + + abortTranslation = true; + break; + + case nameof(Queryable.OrderBy) + when genericMethod == QueryableMethods.OrderBy: + enumerableExpression = ProcessOrderByThenBy( + enumerableExpression, arguments[1].UnwrapLambdaFromQuote(), thenBy: false, ascending: true); + if (enumerableExpression != null) + { + return enumerableExpression; + } + abortTranslation = true; + break; + + case nameof(Queryable.OrderByDescending) + when genericMethod == QueryableMethods.OrderByDescending: + enumerableExpression = ProcessOrderByThenBy( + enumerableExpression, arguments[1].UnwrapLambdaFromQuote(), thenBy: false, ascending: false); + if (enumerableExpression != null) + { + return enumerableExpression; + } + abortTranslation = true; + break; + + case nameof(Queryable.ThenBy) + when genericMethod == QueryableMethods.ThenBy: + enumerableExpression = ProcessOrderByThenBy( + enumerableExpression, arguments[1].UnwrapLambdaFromQuote(), thenBy: true, ascending: true); + if (enumerableExpression != null) + { + return enumerableExpression; + } + abortTranslation = true; + break; + + case nameof(Queryable.ThenByDescending) + when genericMethod == QueryableMethods.ThenByDescending: + enumerableExpression = ProcessOrderByThenBy( + enumerableExpression, arguments[1].UnwrapLambdaFromQuote(), thenBy: true, ascending: false); + if (enumerableExpression != null) + { + return enumerableExpression; + } + abortTranslation = true; + break; + + case nameof(Queryable.Select) + when genericMethod == QueryableMethods.Select: + return ProcessSelector(enumerableExpression, arguments[1].UnwrapLambdaFromQuote()); + + case nameof(Queryable.Where) + when genericMethod == QueryableMethods.Where: + enumerableExpression = ProcessPredicate(enumerableExpression, arguments[1].UnwrapLambdaFromQuote()); + if (enumerableExpression != null) + { + return enumerableExpression; + } + abortTranslation = true; + break; + + default: + abortTranslation = true; + break; + } + } } - arguments = new SqlExpression[methodCallExpression.Arguments.Count]; - for (var i = 0; i < arguments.Length; i++) + if (!abortTranslation + && !skipVisitChildren) { - var argument = methodCallExpression.Arguments[i]; - if (TranslationFailed(argument, Visit(argument), out var sqlArgument)) + var @object = Visit(methodCallExpression.Object); + if (@object is EnumerableExpression) + { + // This is safe since if enumerableExpression is non-null then it was static method + enumerableExpression = (EnumerableExpression)@object; + } + else if (TranslationFailed(methodCallExpression.Object, @object, out sqlObject)) { - return QueryCompilationContext.NotTranslatedExpression; + abortTranslation = true; } - arguments[i] = sqlArgument!; + if (!abortTranslation) + { + for (var i = 0; i < arguments.Count; i++) + { + var argument = arguments[i]; + var visitedArgument = Visit(argument); + if (i == 0 + && visitedArgument is EnumerableExpression) + { + if (enumerableExpression != null) + { + abortTranslation = true; + break; + } + + enumerableExpression = (EnumerableExpression)visitedArgument; + continue; + } + + if (TranslationFailed(argument, visitedArgument, out var sqlArgument)) + { + abortTranslation = true; + break; + } + + scalarArguments.Add(sqlArgument!); + } + } } } - var translation = Dependencies.MethodCallTranslatorProvider.Translate( - _model, sqlObject, methodCallExpression.Method, arguments, _queryCompilationContext.Logger); - - if (translation == null) + if (!abortTranslation) { - if (methodCallExpression.Method == StringEqualsWithStringComparison - || methodCallExpression.Method == StringEqualsWithStringComparisonStatic) + SqlExpression? translation; + if (enumerableExpression != null) + { + var selector = TranslateInternal(enumerableExpression.Selector); + if (selector != null) + { + enumerableExpression = enumerableExpression.ApplySelector(selector); + } + + translation = Dependencies.AggregateMethodCallTranslatorProvider.Translate( + _model, method, enumerableExpression, scalarArguments, _queryCompilationContext.Logger); + } + else + { + translation = Dependencies.MethodCallTranslatorProvider.Translate( + _model, sqlObject, method, scalarArguments, _queryCompilationContext.Logger); + } + + if (translation != null) + { + return translation; + } + + if (method == StringEqualsWithStringComparison + || method == StringEqualsWithStringComparisonStatic) { AddTranslationErrorDetails(CoreStrings.QueryUnableToTranslateStringEqualsWithStringComparison); } @@ -1017,12 +1002,90 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp { AddTranslationErrorDetails( CoreStrings.QueryUnableToTranslateMethod( - methodCallExpression.Method.DeclaringType?.DisplayName(), - methodCallExpression.Method.Name)); + method.DeclaringType?.DisplayName(), + method.Name)); + } + } + + // Subquery case + var subqueryTranslation = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery(methodCallExpression); + if (subqueryTranslation != null) + { + if (subqueryTranslation.ResultCardinality == ResultCardinality.Enumerable) + { + return QueryCompilationContext.NotTranslatedExpression; + } + + var shaperExpression = subqueryTranslation.ShaperExpression; + ProjectionBindingExpression? mappedProjectionBindingExpression = null; + + var innerExpression = shaperExpression; + Type? convertedType = null; + if (shaperExpression is UnaryExpression unaryExpression + && unaryExpression.NodeType == ExpressionType.Convert) + { + convertedType = unaryExpression.Type; + innerExpression = unaryExpression.Operand; + } + + if (innerExpression is EntityShaperExpression ese + && (convertedType == null + || convertedType.IsAssignableFrom(ese.Type))) + { + return new EntityReferenceExpression(subqueryTranslation.UpdateShaperExpression(innerExpression)); } + + if (innerExpression is ProjectionBindingExpression pbe + && (convertedType == null + || convertedType.MakeNullable() == innerExpression.Type)) + { + mappedProjectionBindingExpression = pbe; + } + + if (mappedProjectionBindingExpression == null + && shaperExpression is BlockExpression blockExpression + && blockExpression.Expressions.Count == 2 + && blockExpression.Expressions[0] is BinaryExpression binaryExpression + && binaryExpression.NodeType == ExpressionType.Assign + && binaryExpression.Right is ProjectionBindingExpression pbe2) + { + mappedProjectionBindingExpression = pbe2; + } + + if (mappedProjectionBindingExpression == null) + { + return QueryCompilationContext.NotTranslatedExpression; + } + + var subquery = (SelectExpression)subqueryTranslation.QueryExpression; + var projection = subquery.GetProjection(mappedProjectionBindingExpression); + if (projection is not SqlExpression sqlExpression) + { + return QueryCompilationContext.NotTranslatedExpression; + } + + if (subquery.Tables.Count == 0) + { + return sqlExpression; + } + + subquery.ReplaceProjection(new List { sqlExpression }); + subquery.ApplyProjection(); + + SqlExpression scalarSubqueryExpression = new ScalarSubqueryExpression(subquery); + + if (subqueryTranslation.ResultCardinality == ResultCardinality.SingleOrDefault + && !shaperExpression.Type.IsNullableType()) + { + scalarSubqueryExpression = _sqlExpressionFactory.Coalesce( + scalarSubqueryExpression, + (SqlExpression)Visit(shaperExpression.Type.GetDefaultValueConstant())); + } + + return scalarSubqueryExpression; } - return translation ?? QueryCompilationContext.NotTranslatedExpression; + return QueryCompilationContext.NotTranslatedExpression; } /// @@ -1211,7 +1274,7 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression) return QueryCompilationContext.NotTranslatedExpression; } - private Expression? TryBindMember(Expression? source, MemberIdentity member) + private SqlExpression? TryBindMember(Expression? source, MemberIdentity member) { if (source is not EntityReferenceExpression entityReferenceExpression) { @@ -1334,53 +1397,6 @@ private static EnumerableExpression ProcessSelector(EnumerableExpression enumera return enumerableExpression.ApplyPredicate(predicate); } - private SqlExpression? TranslateAggregate(MethodInfo methodInfo, EnumerableExpression enumerableExpression) - { - var selector = TranslateInternal(enumerableExpression.Selector); - if (selector == null) - { - if (methodInfo.IsGenericMethod - && PredicateAggregateMethodInfos.Contains(methodInfo.GetGenericMethodDefinition())) - { - selector = _sqlExpressionFactory.Fragment("*"); - } - else - { - return null; - } - } - enumerableExpression.ApplySelector(selector); - - if (enumerableExpression.Predicate != null) - { - if (selector is SqlFragmentExpression) - { - selector = _sqlExpressionFactory.Constant(1); - } - - selector = _sqlExpressionFactory.Case( - new List { new(enumerableExpression.Predicate, selector) }, - elseResult: null); - } - - if (enumerableExpression.IsDistinct) - { - selector = new DistinctExpression(selector); - } - - // TODO: Issue#22957 - return methodInfo.Name switch - { - nameof(Queryable.Average) => TranslateAverage(selector), - nameof(Queryable.Count) => TranslateCount(selector), - nameof(Queryable.LongCount) => TranslateLongCount(selector), - nameof(Queryable.Max) => TranslateMax(selector), - nameof(Queryable.Min) => TranslateMin(selector), - nameof(Queryable.Sum) => TranslateSum(selector), - _ => null, - }; - } - private static Expression TryRemoveImplicitConvert(Expression expression) { if (expression is UnaryExpression unaryExpression diff --git a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitorDependencies.cs b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitorDependencies.cs index 1df5e1a1118..e52a85c9cde 100644 --- a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitorDependencies.cs +++ b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitorDependencies.cs @@ -50,13 +50,15 @@ public RelationalSqlTranslatingExpressionVisitorDependencies( IModel model, IRelationalTypeMappingSource typeMappingSource, IMemberTranslatorProvider memberTranslatorProvider, - IMethodCallTranslatorProvider methodCallTranslatorProvider) + IMethodCallTranslatorProvider methodCallTranslatorProvider, + IAggregateMethodCallTranslatorProvider aggregateMethodCallTranslatorProvider) { SqlExpressionFactory = sqlExpressionFactory; Model = model; TypeMappingSource = typeMappingSource; MemberTranslatorProvider = memberTranslatorProvider; MethodCallTranslatorProvider = methodCallTranslatorProvider; + AggregateMethodCallTranslatorProvider = aggregateMethodCallTranslatorProvider; } /// @@ -65,7 +67,7 @@ public RelationalSqlTranslatingExpressionVisitorDependencies( public ISqlExpressionFactory SqlExpressionFactory { get; init; } /// - /// The expression factory. + /// The model. /// public IModel Model { get; init; } @@ -80,7 +82,12 @@ public RelationalSqlTranslatingExpressionVisitorDependencies( public IMemberTranslatorProvider MemberTranslatorProvider { get; init; } /// - /// The method-call translation provider. + /// The scalar method-call translation provider. /// public IMethodCallTranslatorProvider MethodCallTranslatorProvider { get; init; } + + /// + /// The aggregate method-call translation provider. + /// + public IAggregateMethodCallTranslatorProvider AggregateMethodCallTranslatorProvider { get; } } diff --git a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitorFactory.cs b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitorFactory.cs index bf8ddbeaa80..ad7d4727000 100644 --- a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitorFactory.cs +++ b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitorFactory.cs @@ -3,15 +3,7 @@ namespace Microsoft.EntityFrameworkCore.Query; -/// -/// A factory for creating instances. -/// -/// -/// The service lifetime is . This means that each -/// instance will use its own instance of this service. -/// The implementation may depend on other services registered with any lifetime. -/// The implementation does not need to be thread-safe. -/// +/// public class RelationalSqlTranslatingExpressionVisitorFactory : IRelationalSqlTranslatingExpressionVisitorFactory { /// @@ -29,12 +21,7 @@ public RelationalSqlTranslatingExpressionVisitorFactory( /// protected virtual RelationalSqlTranslatingExpressionVisitorDependencies Dependencies { get; } - /// - /// Creates a new . - /// - /// The query compilation context to use. - /// The visitor to use to translate subqueries. - /// A relational sql translating expression visitor. + /// public virtual RelationalSqlTranslatingExpressionVisitor Create( QueryCompilationContext queryCompilationContext, QueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpressionVisitor) diff --git a/src/EFCore.SqlServer/Extensions/SqlServerServiceCollectionExtensions.cs b/src/EFCore.SqlServer/Extensions/SqlServerServiceCollectionExtensions.cs index 6d79e7872fe..7588e14adec 100644 --- a/src/EFCore.SqlServer/Extensions/SqlServerServiceCollectionExtensions.cs +++ b/src/EFCore.SqlServer/Extensions/SqlServerServiceCollectionExtensions.cs @@ -120,6 +120,7 @@ public static IServiceCollection AddEntityFrameworkSqlServer(this IServiceCollec .TryAdd() .TryAdd() .TryAdd() + .TryAdd() .TryAdd() .TryAdd() .TryAdd() diff --git a/src/EFCore.SqlServer/Query/Internal/SqlServerAggregateMethodCallTranslatorProvider.cs b/src/EFCore.SqlServer/Query/Internal/SqlServerAggregateMethodCallTranslatorProvider.cs new file mode 100644 index 00000000000..ba39f384e0f --- /dev/null +++ b/src/EFCore.SqlServer/Query/Internal/SqlServerAggregateMethodCallTranslatorProvider.cs @@ -0,0 +1,30 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal; + +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public class SqlServerAggregateMethodCallTranslatorProvider : RelationalAggregateMethodCallTranslatorProvider +{ + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public SqlServerAggregateMethodCallTranslatorProvider(RelationalAggregateMethodCallTranslatorProviderDependencies dependencies) + : base(dependencies) + { + var sqlExpressionFactory = dependencies.SqlExpressionFactory; + AddTranslators( + new IAggregateMethodCallTranslator[] + { + new SqlServerLongCountMethodTranslator(sqlExpressionFactory) + }); + } +} diff --git a/src/EFCore.SqlServer/Query/Internal/SqlServerLongCountMethodTranslator.cs b/src/EFCore.SqlServer/Query/Internal/SqlServerLongCountMethodTranslator.cs new file mode 100644 index 00000000000..5ef4b1b45de --- /dev/null +++ b/src/EFCore.SqlServer/Query/Internal/SqlServerLongCountMethodTranslator.cs @@ -0,0 +1,74 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.EntityFrameworkCore.Query.SqlExpressions; + +namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal; + +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public class SqlServerLongCountMethodTranslator : IAggregateMethodCallTranslator +{ + private readonly ISqlExpressionFactory _sqlExpressionFactory; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public SqlServerLongCountMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) + { + _sqlExpressionFactory = sqlExpressionFactory; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual SqlExpression? Translate( + MethodInfo method, EnumerableExpression source, IReadOnlyList arguments, + IDiagnosticsLogger logger) + { + if (method.DeclaringType == typeof(Queryable) + && method.IsGenericMethod + && method.GetGenericMethodDefinition() is MethodInfo genericMethod + && (genericMethod == QueryableMethods.LongCountWithoutPredicate + || genericMethod == QueryableMethods.LongCountWithPredicate)) + { + var sqlExpression = (source.Selector as SqlExpression) ?? _sqlExpressionFactory.Fragment("*"); + if (source.Predicate != null) + { + if (sqlExpression is SqlFragmentExpression) + { + sqlExpression = _sqlExpressionFactory.Constant(1); + } + + sqlExpression = _sqlExpressionFactory.Case( + new List { new(source.Predicate, sqlExpression) }, + elseResult: null); + } + + if (source.IsDistinct) + { + sqlExpression = new DistinctExpression(sqlExpression); + } + + return _sqlExpressionFactory.ApplyDefaultTypeMapping( + _sqlExpressionFactory.Function( + "COUNT_BIG", + new[] { sqlExpression }, + nullable: false, + argumentsPropagateNullability: new[] { false }, + typeof(long))); + } + + return null; + } +} diff --git a/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs b/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs index 5aca4d68afe..1e27cda4b69 100644 --- a/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs @@ -152,21 +152,6 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression) return base.VisitUnary(unaryExpression); } - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public override SqlExpression? TranslateLongCount(SqlExpression sqlExpression) - => Dependencies.SqlExpressionFactory.ApplyDefaultTypeMapping( - Dependencies.SqlExpressionFactory.Function( - "COUNT_BIG", - new[] { sqlExpression }, - nullable: false, - argumentsPropagateNullability: new[] { false }, - typeof(long))); - private static string? GetProviderType(SqlExpression expression) => expression.TypeMapping?.StoreType; } diff --git a/src/EFCore.Sqlite.Core/Extensions/SqliteServiceCollectionExtensions.cs b/src/EFCore.Sqlite.Core/Extensions/SqliteServiceCollectionExtensions.cs index ffda8bb98dc..4f3a76f7460 100644 --- a/src/EFCore.Sqlite.Core/Extensions/SqliteServiceCollectionExtensions.cs +++ b/src/EFCore.Sqlite.Core/Extensions/SqliteServiceCollectionExtensions.cs @@ -112,6 +112,7 @@ public static IServiceCollection AddEntityFrameworkSqlite(this IServiceCollectio // New Query Pipeline .TryAdd() + .TryAdd() .TryAdd() .TryAdd() .TryAdd() diff --git a/src/EFCore.Sqlite.Core/Query/Internal/SqliteAggregateMethodCallTranslatorProvider.cs b/src/EFCore.Sqlite.Core/Query/Internal/SqliteAggregateMethodCallTranslatorProvider.cs new file mode 100644 index 00000000000..a836f543a5b --- /dev/null +++ b/src/EFCore.Sqlite.Core/Query/Internal/SqliteAggregateMethodCallTranslatorProvider.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.EntityFrameworkCore.Sqlite.Query.Internal; + +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public class SqliteAggregateMethodCallTranslatorProvider : RelationalAggregateMethodCallTranslatorProvider +{ + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public SqliteAggregateMethodCallTranslatorProvider(RelationalAggregateMethodCallTranslatorProviderDependencies dependencies) + : base(dependencies) + { + var sqlExpressionFactory = dependencies.SqlExpressionFactory; + + AddTranslators( + new IAggregateMethodCallTranslator[] + { + new SqliteQueryableAggregateMethodTranslator(sqlExpressionFactory) + }); + } +} diff --git a/src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryableAggregateMethodTranslator.cs b/src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryableAggregateMethodTranslator.cs new file mode 100644 index 00000000000..5ee3e9bacdd --- /dev/null +++ b/src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryableAggregateMethodTranslator.cs @@ -0,0 +1,110 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.EntityFrameworkCore.Query.SqlExpressions; +using Microsoft.EntityFrameworkCore.Sqlite.Internal; + +namespace Microsoft.EntityFrameworkCore.Sqlite.Query.Internal; + +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public class SqliteQueryableAggregateMethodTranslator : IAggregateMethodCallTranslator +{ + private readonly ISqlExpressionFactory _sqlExpressionFactory; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public SqliteQueryableAggregateMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) + { + _sqlExpressionFactory = sqlExpressionFactory; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual SqlExpression? Translate( + MethodInfo method, EnumerableExpression source, IReadOnlyList arguments, + IDiagnosticsLogger logger) + { + if (method.DeclaringType == typeof(Queryable)) + { + var methodInfo = method.IsGenericMethod + ? method.GetGenericMethodDefinition() + : method; + switch (methodInfo.Name) + { + case nameof(Queryable.Average) + when (QueryableMethods.IsAverageWithoutSelector(methodInfo) + || QueryableMethods.IsAverageWithSelector(methodInfo)) + && source.Selector is SqlExpression averageSqlExpression: + var averageArgumentType = GetProviderType(averageSqlExpression); + if (averageArgumentType == typeof(decimal)) + { + throw new NotSupportedException( + SqliteStrings.AggregateOperationNotSupported(nameof(Queryable.Average), averageArgumentType.ShortDisplayName())); + } + break; + + case nameof(Queryable.Max) + when (methodInfo == QueryableMethods.MaxWithoutSelector + || methodInfo == QueryableMethods.MaxWithSelector) + && source.Selector is SqlExpression maxSqlExpression: + var maxArgumentType = GetProviderType(maxSqlExpression); + if (maxArgumentType == typeof(DateTimeOffset) + || maxArgumentType == typeof(decimal) + || maxArgumentType == typeof(TimeSpan) + || maxArgumentType == typeof(ulong)) + { + throw new NotSupportedException( + SqliteStrings.AggregateOperationNotSupported(nameof(Queryable.Max), maxArgumentType.ShortDisplayName())); + } + break; + + case nameof(Queryable.Min) + when (methodInfo == QueryableMethods.MinWithoutSelector + || methodInfo == QueryableMethods.MinWithSelector) + && source.Selector is SqlExpression minSqlExpression: + var minArgumentType = GetProviderType(minSqlExpression); + if (minArgumentType == typeof(DateTimeOffset) + || minArgumentType == typeof(decimal) + || minArgumentType == typeof(TimeSpan) + || minArgumentType == typeof(ulong)) + { + throw new NotSupportedException( + SqliteStrings.AggregateOperationNotSupported(nameof(Queryable.Min), minArgumentType.ShortDisplayName())); + } + break; + + case nameof(Queryable.Sum) + when (QueryableMethods.IsSumWithoutSelector(methodInfo) + || QueryableMethods.IsSumWithSelector(methodInfo)) + && source.Selector is SqlExpression sumSqlExpression: + var sumArgumentType = GetProviderType(sumSqlExpression); + if (sumArgumentType == typeof(decimal)) + { + throw new NotSupportedException( + SqliteStrings.AggregateOperationNotSupported(nameof(Queryable.Sum), sumArgumentType.ShortDisplayName())); + } + break; + } + } + + return null; + } + + private static Type? GetProviderType(SqlExpression expression) + => expression.TypeMapping?.Converter?.ProviderClrType + ?? expression.TypeMapping?.ClrType + ?? expression.Type; +} diff --git a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs index dde75d7154d..665b25159d5 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. using Microsoft.EntityFrameworkCore.Query.SqlExpressions; -using Microsoft.EntityFrameworkCore.Sqlite.Internal; namespace Microsoft.EntityFrameworkCore.Sqlite.Query.Internal; @@ -212,88 +211,6 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) return visitedExpression; } - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public override SqlExpression? TranslateAverage(SqlExpression sqlExpression) - { - var visitedExpression = base.TranslateAverage(sqlExpression); - var argumentType = GetProviderType(visitedExpression); - if (argumentType == typeof(decimal)) - { - throw new NotSupportedException( - SqliteStrings.AggregateOperationNotSupported(nameof(Queryable.Average), argumentType.ShortDisplayName())); - } - - return visitedExpression; - } - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public override SqlExpression? TranslateMax(SqlExpression sqlExpression) - { - var visitedExpression = base.TranslateMax(sqlExpression); - var argumentType = GetProviderType(visitedExpression); - if (argumentType == typeof(DateTimeOffset) - || argumentType == typeof(decimal) - || argumentType == typeof(TimeSpan) - || argumentType == typeof(ulong)) - { - throw new NotSupportedException( - SqliteStrings.AggregateOperationNotSupported(nameof(Queryable.Max), argumentType.ShortDisplayName())); - } - - return visitedExpression; - } - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public override SqlExpression? TranslateMin(SqlExpression sqlExpression) - { - var visitedExpression = base.TranslateMin(sqlExpression); - var argumentType = GetProviderType(visitedExpression); - if (argumentType == typeof(DateTimeOffset) - || argumentType == typeof(decimal) - || argumentType == typeof(TimeSpan) - || argumentType == typeof(ulong)) - { - throw new NotSupportedException( - SqliteStrings.AggregateOperationNotSupported(nameof(Queryable.Min), argumentType.ShortDisplayName())); - } - - return visitedExpression; - } - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public override SqlExpression? TranslateSum(SqlExpression sqlExpression) - { - var visitedExpression = base.TranslateSum(sqlExpression); - var argumentType = GetProviderType(visitedExpression); - if (argumentType == typeof(decimal)) - { - throw new NotSupportedException( - SqliteStrings.AggregateOperationNotSupported(nameof(Queryable.Sum), argumentType.ShortDisplayName())); - } - - return visitedExpression; - } - private static Type? GetProviderType(SqlExpression? expression) => expression == null ? null diff --git a/src/EFCore/Query/CompiledQueryCacheKeyGenerator.cs b/src/EFCore/Query/CompiledQueryCacheKeyGenerator.cs index 1f7f1e84ff8..afbcf6ae9f9 100644 --- a/src/EFCore/Query/CompiledQueryCacheKeyGenerator.cs +++ b/src/EFCore/Query/CompiledQueryCacheKeyGenerator.cs @@ -3,28 +3,7 @@ namespace Microsoft.EntityFrameworkCore.Query; -/// -/// -/// Creates keys that uniquely identifies a query. This is used to store and lookup -/// compiled versions of a query in a cache. -/// -/// -/// This type is typically used by database providers (and other extensions). It is generally -/// not used in application code. -/// -/// -/// -/// -/// The service lifetime is . This means that each -/// instance will use its own instance of this service. -/// The implementation may depend on other services registered with any lifetime. -/// The implementation does not need to be thread-safe. -/// -/// -/// See Implementation of database providers and extensions -/// and How EF Core queries work for more information and examples. -/// -/// +/// public class CompiledQueryCacheKeyGenerator : ICompiledQueryCacheKeyGenerator { /// @@ -41,12 +20,7 @@ public CompiledQueryCacheKeyGenerator(CompiledQueryCacheKeyGeneratorDependencies /// protected virtual CompiledQueryCacheKeyGeneratorDependencies Dependencies { get; } - /// - /// Generates the cache key for the given query. - /// - /// The query to get the cache key for. - /// A value indicating whether the query will be executed asynchronously. - /// The cache key. + /// public virtual object GenerateCacheKey(Expression query, bool async) => GenerateCacheKeyCore(query, async); @@ -99,40 +73,18 @@ public CompiledQueryCacheKey( _async = async; } - /// - /// Determines if this key is equivalent to a given object (i.e. if they are keys for the same query). - /// - /// - /// The object to compare this key to. - /// - /// - /// if the object is a and is for the same query, otherwise - /// . - /// + /// public override bool Equals(object? obj) => obj is CompiledQueryCacheKey other && Equals(other); - /// - /// Indicates whether the current object is equal to another object of the same type. - /// - /// - /// An object to compare with this object. - /// - /// - /// if the current object is equal to the parameter; otherwise, . - /// + /// public bool Equals(CompiledQueryCacheKey other) => ReferenceEquals(_model, other._model) && _queryTrackingBehavior == other._queryTrackingBehavior && _async == other._async && ExpressionEqualityComparer.Instance.Equals(_query, other._query); - /// - /// Gets the hash code for the key. - /// - /// - /// The hash code for the key. - /// + /// public override int GetHashCode() { var hash = new HashCode(); diff --git a/src/EFCore/Query/EvaluatableExpressionFilter.cs b/src/EFCore/Query/EvaluatableExpressionFilter.cs index 0f2a3a5f32e..d3a2d0d9bf8 100644 --- a/src/EFCore/Query/EvaluatableExpressionFilter.cs +++ b/src/EFCore/Query/EvaluatableExpressionFilter.cs @@ -3,20 +3,7 @@ namespace Microsoft.EntityFrameworkCore.Query; -/// -/// Represents a filter for evaluatable expressions. -/// -/// -/// -/// The service lifetime is . This means a single instance -/// is used by many instances. The implementation must be thread-safe. -/// This service cannot depend on services registered as . -/// -/// -/// See Implementation of database providers and extensions -/// and How EF Core queries work for more information and examples. -/// -/// +/// public class EvaluatableExpressionFilter : IEvaluatableExpressionFilter { // This methods are non-deterministic and result varies based on time of running the query. @@ -69,12 +56,7 @@ public EvaluatableExpressionFilter( /// protected virtual EvaluatableExpressionFilterDependencies Dependencies { get; } - /// - /// Checks whether the given expression can be evaluated. - /// - /// The expression. - /// The model. - /// if the expression can be evaluated; otherwise. + /// public virtual bool IsEvaluatableExpression(Expression expression, IModel model) { switch (expression) diff --git a/src/EFCore/Query/ExpressionEqualityComparer.cs b/src/EFCore/Query/ExpressionEqualityComparer.cs index d8bdc13eccf..bd8934fee4d 100644 --- a/src/EFCore/Query/ExpressionEqualityComparer.cs +++ b/src/EFCore/Query/ExpressionEqualityComparer.cs @@ -28,11 +28,7 @@ private ExpressionEqualityComparer() /// public static ExpressionEqualityComparer Instance { get; } = new(); - /// - /// Returns the hash code for given expression. - /// - /// The obj to compute hash code for. - /// The hash code value for . + /// public int GetHashCode(Expression obj) { if (obj == null) @@ -276,12 +272,7 @@ void AddMemberBindingsToHash(IReadOnlyList memberBindings) } } - /// - /// Returns a value indicating whether the given expressions are equal. - /// - /// The left expression. - /// The right expression. - /// if the expressions are equal, otherwise. + /// public bool Equals(Expression? x, Expression? y) => new ExpressionComparer().Compare(x, y); diff --git a/src/EFCore/Query/ICompiledQueryCacheKeyGenerator.cs b/src/EFCore/Query/ICompiledQueryCacheKeyGenerator.cs index 81f8f30b8b4..43bbea6d271 100644 --- a/src/EFCore/Query/ICompiledQueryCacheKeyGenerator.cs +++ b/src/EFCore/Query/ICompiledQueryCacheKeyGenerator.cs @@ -4,7 +4,14 @@ namespace Microsoft.EntityFrameworkCore.Query; /// -/// A cache key generator for the compiled query cache. +/// +/// Creates keys that uniquely identifies a query. This is used to store and lookup +/// compiled versions of a query in a cache. +/// +/// +/// This type is typically used by database providers (and other extensions). It is generally +/// not used in application code. +/// /// /// /// diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindGroupByQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindGroupByQuerySqlServerTest.cs index 4244e4018bf..fef1f7dc887 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindGroupByQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindGroupByQuerySqlServerTest.cs @@ -626,9 +626,9 @@ public override async Task GroupBy_constant_with_where_on_grouping_with_aggregat WHEN 1 = [t].[Key] THEN [t].[OrderDate] END) AS [Max], COALESCE(SUM(CASE WHEN 1 = [t].[Key] THEN [t].[OrderID] -END), 0) AS [Sum], AVG(CAST(CASE - WHEN 1 = [t].[Key] THEN [t].[OrderID] -END AS float)) AS [Average] +END), 0) AS [Sum], AVG(CASE + WHEN 1 = [t].[Key] THEN CAST([t].[OrderID] AS float) +END) AS [Average] FROM ( SELECT [o].[OrderID], [o].[OrderDate], 1 AS [Key] FROM [Orders] AS [o] @@ -1782,9 +1782,9 @@ public override async Task GroupBy_Where_Average(bool async) await base.GroupBy_Where_Average(async); AssertSql( - @"SELECT AVG(CAST(CASE - WHEN [o].[OrderID] < 10300 THEN [o].[OrderID] -END AS float)) + @"SELECT AVG(CASE + WHEN [o].[OrderID] < 10300 THEN CAST([o].[OrderID] AS float) +END) FROM [Orders] AS [o] GROUP BY [o].[CustomerID]"); }