diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs index 2cfad91bad0..094a16610f6 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs @@ -737,7 +737,22 @@ private ShapedQueryExpression AggregateResultShaper( Expression shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), projection.Type); - if (throwOnNullResult) + if (throwOnNullResult + && resultType.IsNullableType()) + { + var resultVariable = Expression.Variable(projection.Type, "result"); + + shaper = Expression.Block( + new[] { resultVariable }, + Expression.Assign(resultVariable, shaper), + Expression.Condition( + Expression.Equal(resultVariable, Expression.Default(projection.Type)), + Expression.Constant(null, resultType), + resultType != resultVariable.Type + ? Expression.Convert(resultVariable, resultType) + : (Expression)resultVariable)); + } + else if (throwOnNullResult) { var resultVariable = Expression.Variable(projection.Type, "result"); diff --git a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs index b43f9678350..f86f4251478 100644 --- a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs @@ -1096,7 +1096,22 @@ private ShapedQueryExpression AggregateResultShaper( Expression shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), projection.Type); - if (throwOnNullResult) + if (throwOnNullResult + && resultType.IsNullableType()) + { + var resultVariable = Expression.Variable(projection.Type, "result"); + + shaper = Expression.Block( + new[] { resultVariable }, + Expression.Assign(resultVariable, shaper), + Expression.Condition( + Expression.Equal(resultVariable, Expression.Default(projection.Type)), + Expression.Constant(null, resultType), + resultType != resultVariable.Type + ? Expression.Convert(resultVariable, resultType) + : (Expression)resultVariable)); + } + else if (throwOnNullResult) { var resultVariable = Expression.Variable(projection.Type, "result"); diff --git a/src/EFCore/Extensions/EntityFrameworkQueryableExtensions.cs b/src/EFCore/Extensions/EntityFrameworkQueryableExtensions.cs index 72fa5094a2d..d3dac52e862 100644 --- a/src/EFCore/Extensions/EntityFrameworkQueryableExtensions.cs +++ b/src/EFCore/Extensions/EntityFrameworkQueryableExtensions.cs @@ -26,8 +26,6 @@ public static class EntityFrameworkQueryableExtensions { #region Any/All - private static readonly MethodInfo _any = GetMethod(nameof(Queryable.Any)); - /// /// Asynchronously determines whether a sequence contains any elements. /// @@ -54,11 +52,9 @@ public static Task AnyAsync( { Check.NotNull(source, nameof(source)); - return ExecuteAsync>(_any, source, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.AnyWithoutPredicateMethodInfo, source, cancellationToken); } - private static readonly MethodInfo _anyPredicate = GetMethod(nameof(Queryable.Any), parameterCount: 1); - /// /// Asynchronously determines whether any element of a sequence satisfies a condition. /// @@ -89,11 +85,9 @@ public static Task AnyAsync( Check.NotNull(source, nameof(source)); Check.NotNull(predicate, nameof(predicate)); - return ExecuteAsync>(_anyPredicate, source, predicate, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.AnyWithPredicateMethodInfo, source, predicate, cancellationToken); } - private static readonly MethodInfo _allPredicate = GetMethod(nameof(Queryable.All), parameterCount: 1); - /// /// Asynchronously determines whether all the elements of a sequence satisfy a condition. /// @@ -124,15 +118,13 @@ public static Task AllAsync( Check.NotNull(source, nameof(source)); Check.NotNull(predicate, nameof(predicate)); - return ExecuteAsync>(_allPredicate, source, predicate, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.AllMethodInfo, source, predicate, cancellationToken); } #endregion #region Count/LongCount - private static readonly MethodInfo _count = GetMethod(nameof(Queryable.Count)); - /// /// Asynchronously returns the number of elements in a sequence. /// @@ -159,11 +151,9 @@ public static Task CountAsync( { Check.NotNull(source, nameof(source)); - return ExecuteAsync>(_count, source, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.CountWithoutPredicateMethodInfo, source, cancellationToken); } - private static readonly MethodInfo _countPredicate = GetMethod(nameof(Queryable.Count), parameterCount: 1); - /// /// Asynchronously returns the number of elements in a sequence that satisfy a condition. /// @@ -194,11 +184,9 @@ public static Task CountAsync( Check.NotNull(source, nameof(source)); Check.NotNull(predicate, nameof(predicate)); - return ExecuteAsync>(_countPredicate, source, predicate, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.CountWithPredicateMethodInfo, source, predicate, cancellationToken); } - private static readonly MethodInfo _longCount = GetMethod(nameof(Queryable.LongCount)); - /// /// Asynchronously returns an that represents the total number of elements in a sequence. /// @@ -225,11 +213,9 @@ public static Task LongCountAsync( { Check.NotNull(source, nameof(source)); - return ExecuteAsync>(_longCount, source, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.LongCountWithoutPredicateMethodInfo, source, cancellationToken); } - private static readonly MethodInfo _longCountPredicate = GetMethod(nameof(Queryable.LongCount), parameterCount: 1); - /// /// Asynchronously returns an that represents the number of elements in a sequence /// that satisfy a condition. @@ -261,15 +247,13 @@ public static Task LongCountAsync( Check.NotNull(source, nameof(source)); Check.NotNull(predicate, nameof(predicate)); - return ExecuteAsync>(_longCountPredicate, source, predicate, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.LongCountWithPredicateMethodInfo, source, predicate, cancellationToken); } #endregion #region First/FirstOrDefault - private static readonly MethodInfo _first = GetMethod(nameof(Queryable.First)); - /// /// Asynchronously returns the first element of a sequence. /// @@ -296,11 +280,9 @@ public static Task FirstAsync( { Check.NotNull(source, nameof(source)); - return ExecuteAsync>(_first, source, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.FirstWithoutPredicateMethodInfo, source, cancellationToken); } - private static readonly MethodInfo _firstPredicate = GetMethod(nameof(Queryable.First), parameterCount: 1); - /// /// Asynchronously returns the first element of a sequence that satisfies a specified condition. /// @@ -331,11 +313,9 @@ public static Task FirstAsync( Check.NotNull(source, nameof(source)); Check.NotNull(predicate, nameof(predicate)); - return ExecuteAsync>(_firstPredicate, source, predicate, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.FirstWithPredicateMethodInfo, source, predicate, cancellationToken); } - private static readonly MethodInfo _firstOrDefault = GetMethod(nameof(Queryable.FirstOrDefault)); - /// /// Asynchronously returns the first element of a sequence, or a default value if the sequence contains no elements. /// @@ -363,11 +343,9 @@ public static Task FirstOrDefaultAsync( { Check.NotNull(source, nameof(source)); - return ExecuteAsync>(_firstOrDefault, source, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.FirstOrDefaultWithoutPredicateMethodInfo, source, cancellationToken); } - private static readonly MethodInfo _firstOrDefaultPredicate = GetMethod(nameof(Queryable.FirstOrDefault), parameterCount: 1); - /// /// Asynchronously returns the first element of a sequence that satisfies a specified condition /// or a default value if no such element is found. @@ -400,15 +378,13 @@ public static Task FirstOrDefaultAsync( Check.NotNull(source, nameof(source)); Check.NotNull(predicate, nameof(predicate)); - return ExecuteAsync>(_firstOrDefaultPredicate, source, predicate, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.FirstOrDefaultWithPredicateMethodInfo, source, predicate, cancellationToken); } #endregion #region Last/LastOrDefault - private static readonly MethodInfo _last = GetMethod(nameof(Queryable.Last)); - /// /// Asynchronously returns the last element of a sequence. /// @@ -435,11 +411,9 @@ public static Task LastAsync( { Check.NotNull(source, nameof(source)); - return ExecuteAsync>(_last, source, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.LastWithoutPredicateMethodInfo, source, cancellationToken); } - private static readonly MethodInfo _lastPredicate = GetMethod(nameof(Queryable.Last), parameterCount: 1); - /// /// Asynchronously returns the last element of a sequence that satisfies a specified condition. /// @@ -470,11 +444,9 @@ public static Task LastAsync( Check.NotNull(source, nameof(source)); Check.NotNull(predicate, nameof(predicate)); - return ExecuteAsync>(_lastPredicate, source, predicate, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.LastWithPredicateMethodInfo, source, predicate, cancellationToken); } - private static readonly MethodInfo _lastOrDefault = GetMethod(nameof(Queryable.LastOrDefault)); - /// /// Asynchronously returns the last element of a sequence, or a default value if the sequence contains no elements. /// @@ -502,11 +474,9 @@ public static Task LastOrDefaultAsync( { Check.NotNull(source, nameof(source)); - return ExecuteAsync>(_lastOrDefault, source, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.LastOrDefaultWithoutPredicateMethodInfo, source, cancellationToken); } - private static readonly MethodInfo _lastOrDefaultPredicate = GetMethod(nameof(Queryable.LastOrDefault), parameterCount: 1); - /// /// Asynchronously returns the last element of a sequence that satisfies a specified condition /// or a default value if no such element is found. @@ -539,15 +509,13 @@ public static Task LastOrDefaultAsync( Check.NotNull(source, nameof(source)); Check.NotNull(predicate, nameof(predicate)); - return ExecuteAsync>(_lastOrDefaultPredicate, source, predicate, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.LastOrDefaultWithPredicateMethodInfo, source, predicate, cancellationToken); } #endregion #region Single/SingleOrDefault - private static readonly MethodInfo _single = GetMethod(nameof(Queryable.Single)); - /// /// Asynchronously returns the only element of a sequence, and throws an exception /// if there is not exactly one element in the sequence. @@ -575,11 +543,9 @@ public static Task SingleAsync( { Check.NotNull(source, nameof(source)); - return ExecuteAsync>(_single, source, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.SingleWithoutPredicateMethodInfo, source, cancellationToken); } - private static readonly MethodInfo _singlePredicate = GetMethod(nameof(Queryable.Single), parameterCount: 1); - /// /// Asynchronously returns the only element of a sequence that satisfies a specified condition, /// and throws an exception if more than one such element exists. @@ -611,11 +577,9 @@ public static Task SingleAsync( Check.NotNull(source, nameof(source)); Check.NotNull(predicate, nameof(predicate)); - return ExecuteAsync>(_singlePredicate, source, predicate, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.SingleWithPredicateMethodInfo, source, predicate, cancellationToken); } - private static readonly MethodInfo _singleOrDefault = GetMethod(nameof(Queryable.SingleOrDefault)); - /// /// Asynchronously returns the only element of a sequence, or a default value if the sequence is empty; /// this method throws an exception if there is more than one element in the sequence. @@ -645,11 +609,9 @@ public static Task SingleOrDefaultAsync( { Check.NotNull(source, nameof(source)); - return ExecuteAsync>(_singleOrDefault, source, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.SingleOrDefaultWithoutPredicateMethodInfo, source, cancellationToken); } - private static readonly MethodInfo _singleOrDefaultPredicate = GetMethod(nameof(Queryable.SingleOrDefault), parameterCount: 1); - /// /// Asynchronously returns the only element of a sequence that satisfies a specified condition or /// a default value if no such element exists; this method throws an exception if more than one element @@ -682,15 +644,13 @@ public static Task SingleOrDefaultAsync( Check.NotNull(source, nameof(source)); Check.NotNull(predicate, nameof(predicate)); - return ExecuteAsync>(_singleOrDefaultPredicate, source, predicate, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.SingleOrDefaultWithPredicateMethodInfo, source, predicate, cancellationToken); } #endregion #region Min - private static readonly MethodInfo _min = GetMethod(nameof(Queryable.Min), predicate: mi => mi.IsGenericMethod); - /// /// Asynchronously returns the minimum value of a sequence. /// @@ -717,11 +677,9 @@ public static Task MinAsync( { Check.NotNull(source, nameof(source)); - return ExecuteAsync>(_min, source, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.MinWithoutSelectorMethodInfo, source, cancellationToken); } - private static readonly MethodInfo _minSelector = GetMethod(nameof(Queryable.Min), parameterCount: 1, predicate: mi => mi.IsGenericMethod); - /// /// Asynchronously invokes a projection function on each element of a sequence and returns the minimum resulting value. /// @@ -754,15 +712,13 @@ public static Task MinAsync( Check.NotNull(source, nameof(source)); Check.NotNull(selector, nameof(selector)); - return ExecuteAsync>(_minSelector, source, selector, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.MinWithSelectorMethodInfo, source, selector, cancellationToken); } #endregion #region Max - private static readonly MethodInfo _max = GetMethod(nameof(Queryable.Max), predicate: mi => mi.IsGenericMethod); - /// /// Asynchronously returns the maximum value of a sequence. /// @@ -789,11 +745,9 @@ public static Task MaxAsync( { Check.NotNull(source, nameof(source)); - return ExecuteAsync>(_max, source, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.MaxWithoutSelectorMethodInfo, source, cancellationToken); } - private static readonly MethodInfo _maxSelector = GetMethod(nameof(Queryable.Max), parameterCount: 1, predicate: mi => mi.IsGenericMethod); - /// /// Asynchronously invokes a projection function on each element of a sequence and returns the maximum resulting value. /// @@ -826,15 +780,13 @@ public static Task MaxAsync( Check.NotNull(source, nameof(source)); Check.NotNull(selector, nameof(selector)); - return ExecuteAsync>(_maxSelector, source, selector, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.MaxWithSelectorMethodInfo, source, selector, cancellationToken); } #endregion #region Sum - private static readonly MethodInfo _sumDecimal = GetMethod(nameof(Queryable.Sum)); - /// /// Asynchronously computes the sum of a sequence of values. /// @@ -858,11 +810,9 @@ public static Task SumAsync( { Check.NotNull(source, nameof(source)); - return ExecuteAsync>(_sumDecimal, source, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.SumWithoutSelectorMethodInfos[typeof(decimal)], source, cancellationToken); } - private static readonly MethodInfo _sumNullableDecimal = GetMethod(nameof(Queryable.Sum)); - /// /// Asynchronously computes the sum of a sequence of values. /// @@ -886,11 +836,9 @@ public static Task SumAsync( { Check.NotNull(source, nameof(source)); - return ExecuteAsync>(_sumNullableDecimal, source, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.SumWithoutSelectorMethodInfos[typeof(decimal?)], source, cancellationToken); } - private static readonly MethodInfo _sumDecimalSelector = GetMethod(nameof(Queryable.Sum), parameterCount: 1); - /// /// Asynchronously computes the sum of the sequence of values that is obtained by invoking a projection function on /// each element of the input sequence. @@ -918,11 +866,9 @@ public static Task SumAsync( Check.NotNull(source, nameof(source)); Check.NotNull(selector, nameof(selector)); - return ExecuteAsync>(_sumDecimalSelector, source, selector, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.SumWithSelectorMethodInfos[typeof(decimal)], source, selector, cancellationToken); } - private static readonly MethodInfo _sumNullableDecimalSelector = GetMethod(nameof(Queryable.Sum), parameterCount: 1); - /// /// Asynchronously computes the sum of the sequence of values that is obtained by invoking a projection function on /// each element of the input sequence. @@ -950,11 +896,9 @@ public static Task SumAsync( Check.NotNull(source, nameof(source)); Check.NotNull(selector, nameof(selector)); - return ExecuteAsync>(_sumNullableDecimalSelector, source, selector, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.SumWithSelectorMethodInfos[typeof(decimal?)], source, selector, cancellationToken); } - private static readonly MethodInfo _sumInt = GetMethod(nameof(Queryable.Sum)); - /// /// Asynchronously computes the sum of a sequence of values. /// @@ -978,11 +922,9 @@ public static Task SumAsync( { Check.NotNull(source, nameof(source)); - return ExecuteAsync>(_sumInt, source, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.SumWithoutSelectorMethodInfos[typeof(int)], source, cancellationToken); } - private static readonly MethodInfo _sumNullableInt = GetMethod(nameof(Queryable.Sum)); - /// /// Asynchronously computes the sum of a sequence of values. /// @@ -1006,11 +948,9 @@ public static Task SumAsync( { Check.NotNull(source, nameof(source)); - return ExecuteAsync>(_sumNullableInt, source, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.SumWithoutSelectorMethodInfos[typeof(int?)], source, cancellationToken); } - private static readonly MethodInfo _sumIntSelector = GetMethod(nameof(Queryable.Sum), parameterCount: 1); - /// /// Asynchronously computes the sum of the sequence of values that is obtained by invoking a projection function on /// each element of the input sequence. @@ -1038,11 +978,9 @@ public static Task SumAsync( Check.NotNull(source, nameof(source)); Check.NotNull(selector, nameof(selector)); - return ExecuteAsync>(_sumIntSelector, source, selector, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.SumWithSelectorMethodInfos[typeof(int)], source, selector, cancellationToken); } - private static readonly MethodInfo _sumNullableIntSelector = GetMethod(nameof(Queryable.Sum), parameterCount: 1); - /// /// Asynchronously computes the sum of the sequence of values that is obtained by invoking a projection function on /// each element of the input sequence. @@ -1070,11 +1008,9 @@ public static Task SumAsync( Check.NotNull(source, nameof(source)); Check.NotNull(selector, nameof(selector)); - return ExecuteAsync>(_sumNullableIntSelector, source, selector, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.SumWithSelectorMethodInfos[typeof(int?)], source, selector, cancellationToken); } - private static readonly MethodInfo _sumLong = GetMethod(nameof(Queryable.Sum)); - /// /// Asynchronously computes the sum of a sequence of values. /// @@ -1098,11 +1034,9 @@ public static Task SumAsync( { Check.NotNull(source, nameof(source)); - return ExecuteAsync>(_sumLong, source, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.SumWithoutSelectorMethodInfos[typeof(long)], source, cancellationToken); } - private static readonly MethodInfo _sumNullableLong = GetMethod(nameof(Queryable.Sum)); - /// /// Asynchronously computes the sum of a sequence of values. /// @@ -1126,11 +1060,9 @@ public static Task SumAsync( { Check.NotNull(source, nameof(source)); - return ExecuteAsync>(_sumNullableLong, source, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.SumWithoutSelectorMethodInfos[typeof(long?)], source, cancellationToken); } - private static readonly MethodInfo _sumLongSelector = GetMethod(nameof(Queryable.Sum), parameterCount: 1); - /// /// Asynchronously computes the sum of the sequence of values that is obtained by invoking a projection function on /// each element of the input sequence. @@ -1158,11 +1090,9 @@ public static Task SumAsync( Check.NotNull(source, nameof(source)); Check.NotNull(selector, nameof(selector)); - return ExecuteAsync>(_sumLongSelector, source, selector, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.SumWithSelectorMethodInfos[typeof(long)], source, selector, cancellationToken); } - private static readonly MethodInfo _sumNullableLongSelector = GetMethod(nameof(Queryable.Sum), parameterCount: 1); - /// /// Asynchronously computes the sum of the sequence of values that is obtained by invoking a projection function on /// each element of the input sequence. @@ -1190,11 +1120,9 @@ public static Task SumAsync( Check.NotNull(source, nameof(source)); Check.NotNull(selector, nameof(selector)); - return ExecuteAsync>(_sumNullableLongSelector, source, selector, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.SumWithSelectorMethodInfos[typeof(long?)], source, selector, cancellationToken); } - private static readonly MethodInfo _sumDouble = GetMethod(nameof(Queryable.Sum)); - /// /// Asynchronously computes the sum of a sequence of values. /// @@ -1218,11 +1146,9 @@ public static Task SumAsync( { Check.NotNull(source, nameof(source)); - return ExecuteAsync>(_sumDouble, source, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.SumWithoutSelectorMethodInfos[typeof(double)], source, cancellationToken); } - private static readonly MethodInfo _sumNullableDouble = GetMethod(nameof(Queryable.Sum)); - /// /// Asynchronously computes the sum of a sequence of values. /// @@ -1246,11 +1172,9 @@ public static Task SumAsync( { Check.NotNull(source, nameof(source)); - return ExecuteAsync>(_sumNullableDouble, source, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.SumWithoutSelectorMethodInfos[typeof(double?)], source, cancellationToken); } - private static readonly MethodInfo _sumDoubleSelector = GetMethod(nameof(Queryable.Sum), parameterCount: 1); - /// /// Asynchronously computes the sum of the sequence of values that is obtained by invoking a projection function on /// each element of the input sequence. @@ -1278,11 +1202,9 @@ public static Task SumAsync( Check.NotNull(source, nameof(source)); Check.NotNull(selector, nameof(selector)); - return ExecuteAsync>(_sumDoubleSelector, source, selector, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.SumWithSelectorMethodInfos[typeof(double)], source, selector, cancellationToken); } - private static readonly MethodInfo _sumNullableDoubleSelector = GetMethod(nameof(Queryable.Sum), parameterCount: 1); - /// /// Asynchronously computes the sum of the sequence of values that is obtained by invoking a projection function on /// each element of the input sequence. @@ -1310,11 +1232,9 @@ public static Task SumAsync( Check.NotNull(source, nameof(source)); Check.NotNull(selector, nameof(selector)); - return ExecuteAsync>(_sumNullableDoubleSelector, source, selector, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.SumWithSelectorMethodInfos[typeof(double?)], source, selector, cancellationToken); } - private static readonly MethodInfo _sumFloat = GetMethod(nameof(Queryable.Sum)); - /// /// Asynchronously computes the sum of a sequence of values. /// @@ -1338,11 +1258,9 @@ public static Task SumAsync( { Check.NotNull(source, nameof(source)); - return ExecuteAsync>(_sumFloat, source, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.SumWithoutSelectorMethodInfos[typeof(float)], source, cancellationToken); } - private static readonly MethodInfo _sumNullableFloat = GetMethod(nameof(Queryable.Sum)); - /// /// Asynchronously computes the sum of a sequence of values. /// @@ -1366,11 +1284,9 @@ public static Task SumAsync( { Check.NotNull(source, nameof(source)); - return ExecuteAsync>(_sumNullableFloat, source, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.SumWithoutSelectorMethodInfos[typeof(float?)], source, cancellationToken); } - private static readonly MethodInfo _sumFloatSelector = GetMethod(nameof(Queryable.Sum), parameterCount: 1); - /// /// Asynchronously computes the sum of the sequence of values that is obtained by invoking a projection function on /// each element of the input sequence. @@ -1398,11 +1314,9 @@ public static Task SumAsync( Check.NotNull(source, nameof(source)); Check.NotNull(selector, nameof(selector)); - return ExecuteAsync>(_sumFloatSelector, source, selector, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.SumWithSelectorMethodInfos[typeof(float)], source, selector, cancellationToken); } - private static readonly MethodInfo _sumNullableFloatSelector = GetMethod(nameof(Queryable.Sum), parameterCount: 1); - /// /// Asynchronously computes the sum of the sequence of values that is obtained by invoking a projection function on /// each element of the input sequence. @@ -1430,26 +1344,13 @@ public static Task SumAsync( Check.NotNull(source, nameof(source)); Check.NotNull(selector, nameof(selector)); - return ExecuteAsync>(_sumNullableFloatSelector, source, selector, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.SumWithSelectorMethodInfos[typeof(float?)], source, selector, cancellationToken); } #endregion #region Average - private static MethodInfo GetAverageMethod(int parameterCount = 0) - => GetMethod( - nameof(Queryable.Average), - parameterCount, - mi => parameterCount == 0 - && mi.GetParameters()[0].ParameterType == typeof(IQueryable) - || mi.GetParameters().Length == 2 - && mi.GetParameters()[1] - .ParameterType.GenericTypeArguments[0] - .GenericTypeArguments[1] == typeof(TOperand)); - - private static readonly MethodInfo _averageDecimal = GetAverageMethod(); - /// /// Asynchronously computes the average of a sequence of values. /// @@ -1473,11 +1374,9 @@ public static Task AverageAsync( { Check.NotNull(source, nameof(source)); - return ExecuteAsync>(_averageDecimal, source, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.AverageWithoutSelectorMethodInfos[typeof(decimal)], source, cancellationToken); } - private static readonly MethodInfo _averageNullableDecimal = GetAverageMethod(); - /// /// Asynchronously computes the average of a sequence of values. /// @@ -1501,11 +1400,9 @@ public static Task AverageAsync( { Check.NotNull(source, nameof(source)); - return ExecuteAsync>(_averageNullableDecimal, source, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.AverageWithoutSelectorMethodInfos[typeof(decimal?)], source, cancellationToken); } - private static readonly MethodInfo _averageDecimalSelector = GetAverageMethod(parameterCount: 1); - /// /// Asynchronously computes the average of a sequence of values that is obtained /// by invoking a projection function on each element of the input sequence. @@ -1534,11 +1431,9 @@ public static Task AverageAsync( Check.NotNull(source, nameof(source)); Check.NotNull(selector, nameof(selector)); - return ExecuteAsync>(_averageDecimalSelector, source, selector, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.AverageWithSelectorMethodInfos[typeof(decimal)], source, selector, cancellationToken); } - private static readonly MethodInfo _averageNullableDecimalSelector = GetAverageMethod(parameterCount: 1); - /// /// Asynchronously computes the average of a sequence of values that is obtained /// by invoking a projection function on each element of the input sequence. @@ -1567,11 +1462,9 @@ public static Task AverageAsync( Check.NotNull(source, nameof(source)); Check.NotNull(selector, nameof(selector)); - return ExecuteAsync>(_averageNullableDecimalSelector, source, selector, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.AverageWithSelectorMethodInfos[typeof(decimal?)], source, selector, cancellationToken); } - private static readonly MethodInfo _averageInt = GetAverageMethod(); - /// /// Asynchronously computes the average of a sequence of values. /// @@ -1595,11 +1488,9 @@ public static Task AverageAsync( { Check.NotNull(source, nameof(source)); - return ExecuteAsync>(_averageInt, source, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.AverageWithoutSelectorMethodInfos[typeof(int)], source, cancellationToken); } - private static readonly MethodInfo _averageNullableInt = GetAverageMethod(); - /// /// Asynchronously computes the average of a sequence of values. /// @@ -1623,11 +1514,9 @@ public static Task AverageAsync( { Check.NotNull(source, nameof(source)); - return ExecuteAsync>(_averageNullableInt, source, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.AverageWithoutSelectorMethodInfos[typeof(int?)], source, cancellationToken); } - private static readonly MethodInfo _averageIntSelector = GetAverageMethod(parameterCount: 1); - /// /// Asynchronously computes the average of a sequence of values that is obtained /// by invoking a projection function on each element of the input sequence. @@ -1656,11 +1545,9 @@ public static Task AverageAsync( Check.NotNull(source, nameof(source)); Check.NotNull(selector, nameof(selector)); - return ExecuteAsync>(_averageIntSelector, source, selector, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.AverageWithSelectorMethodInfos[typeof(int)], source, selector, cancellationToken); } - private static readonly MethodInfo _averageNullableIntSelector = GetAverageMethod(parameterCount: 1); - /// /// Asynchronously computes the average of a sequence of values that is obtained /// by invoking a projection function on each element of the input sequence. @@ -1689,11 +1576,9 @@ public static Task AverageAsync( Check.NotNull(source, nameof(source)); Check.NotNull(selector, nameof(selector)); - return ExecuteAsync>(_averageNullableIntSelector, source, selector, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.AverageWithSelectorMethodInfos[typeof(int?)], source, selector, cancellationToken); } - private static readonly MethodInfo _averageLong = GetAverageMethod(); - /// /// Asynchronously computes the average of a sequence of values. /// @@ -1717,11 +1602,9 @@ public static Task AverageAsync( { Check.NotNull(source, nameof(source)); - return ExecuteAsync>(_averageLong, source, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.AverageWithoutSelectorMethodInfos[typeof(long)], source, cancellationToken); } - private static readonly MethodInfo _averageNullableLong = GetAverageMethod(); - /// /// Asynchronously computes the average of a sequence of values. /// @@ -1745,11 +1628,9 @@ public static Task AverageAsync( { Check.NotNull(source, nameof(source)); - return ExecuteAsync>(_averageNullableLong, source, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.AverageWithoutSelectorMethodInfos[typeof(long?)], source, cancellationToken); } - private static readonly MethodInfo _averageLongSelector = GetAverageMethod(parameterCount: 1); - /// /// Asynchronously computes the average of a sequence of values that is obtained /// by invoking a projection function on each element of the input sequence. @@ -1778,11 +1659,9 @@ public static Task AverageAsync( Check.NotNull(source, nameof(source)); Check.NotNull(selector, nameof(selector)); - return ExecuteAsync>(_averageLongSelector, source, selector, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.AverageWithSelectorMethodInfos[typeof(long)], source, selector, cancellationToken); } - private static readonly MethodInfo _averageNullableLongSelector = GetAverageMethod(parameterCount: 1); - /// /// Asynchronously computes the average of a sequence of values that is obtained /// by invoking a projection function on each element of the input sequence. @@ -1811,11 +1690,9 @@ public static Task AverageAsync( Check.NotNull(source, nameof(source)); Check.NotNull(selector, nameof(selector)); - return ExecuteAsync>(_averageNullableLongSelector, source, selector, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.AverageWithSelectorMethodInfos[typeof(long?)], source, selector, cancellationToken); } - private static readonly MethodInfo _averageDouble = GetAverageMethod(); - /// /// Asynchronously computes the average of a sequence of values. /// @@ -1839,11 +1716,9 @@ public static Task AverageAsync( { Check.NotNull(source, nameof(source)); - return ExecuteAsync>(_averageDouble, source, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.AverageWithoutSelectorMethodInfos[typeof(double)], source, cancellationToken); } - private static readonly MethodInfo _averageNullableDouble = GetAverageMethod(); - /// /// Asynchronously computes the average of a sequence of values. /// @@ -1867,11 +1742,9 @@ public static Task AverageAsync( { Check.NotNull(source, nameof(source)); - return ExecuteAsync>(_averageNullableDouble, source, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.AverageWithoutSelectorMethodInfos[typeof(double?)], source, cancellationToken); } - private static readonly MethodInfo _averageDoubleSelector = GetAverageMethod(parameterCount: 1); - /// /// Asynchronously computes the average of a sequence of values that is obtained /// by invoking a projection function on each element of the input sequence. @@ -1900,11 +1773,9 @@ public static Task AverageAsync( Check.NotNull(source, nameof(source)); Check.NotNull(selector, nameof(selector)); - return ExecuteAsync>(_averageDoubleSelector, source, selector, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.AverageWithSelectorMethodInfos[typeof(double)], source, selector, cancellationToken); } - private static readonly MethodInfo _averageNullableDoubleSelector = GetAverageMethod(parameterCount: 1); - /// /// Asynchronously computes the average of a sequence of values that is obtained /// by invoking a projection function on each element of the input sequence. @@ -1933,11 +1804,9 @@ public static Task AverageAsync( Check.NotNull(source, nameof(source)); Check.NotNull(selector, nameof(selector)); - return ExecuteAsync>(_averageNullableDoubleSelector, source, selector, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.AverageWithSelectorMethodInfos[typeof(double?)], source, selector, cancellationToken); } - private static readonly MethodInfo _averageFloat = GetAverageMethod(); - /// /// Asynchronously computes the average of a sequence of values. /// @@ -1961,11 +1830,9 @@ public static Task AverageAsync( { Check.NotNull(source, nameof(source)); - return ExecuteAsync>(_averageFloat, source, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.AverageWithoutSelectorMethodInfos[typeof(float)], source, cancellationToken); } - private static readonly MethodInfo _averageNullableFloat = GetAverageMethod(); - /// /// Asynchronously computes the average of a sequence of values. /// @@ -1989,11 +1856,9 @@ public static Task AverageAsync( { Check.NotNull(source, nameof(source)); - return ExecuteAsync>(_averageNullableFloat, source, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.AverageWithoutSelectorMethodInfos[typeof(float?)], source, cancellationToken); } - private static readonly MethodInfo _averageFloatSelector = GetAverageMethod(parameterCount: 1); - /// /// Asynchronously computes the average of a sequence of values that is obtained /// by invoking a projection function on each element of the input sequence. @@ -2022,11 +1887,9 @@ public static Task AverageAsync( Check.NotNull(source, nameof(source)); Check.NotNull(selector, nameof(selector)); - return ExecuteAsync>(_averageFloatSelector, source, selector, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.AverageWithSelectorMethodInfos[typeof(float)], source, selector, cancellationToken); } - private static readonly MethodInfo _averageNullableFloatSelector = GetAverageMethod(parameterCount: 1); - /// /// Asynchronously computes the average of a sequence of values that is obtained /// by invoking a projection function on each element of the input sequence. @@ -2055,15 +1918,13 @@ public static Task AverageAsync( Check.NotNull(source, nameof(source)); Check.NotNull(selector, nameof(selector)); - return ExecuteAsync>(_averageNullableFloatSelector, source, selector, cancellationToken); + return ExecuteAsync>(QueryableMethodProvider.AverageWithSelectorMethodInfos[typeof(float?)], source, selector, cancellationToken); } #endregion #region Contains - private static readonly MethodInfo _contains = GetMethod(nameof(Queryable.Contains), parameterCount: 1); - /// /// Asynchronously determines whether a sequence contains a specified element by using the default equality comparer. /// @@ -2093,7 +1954,7 @@ public static Task ContainsAsync( Check.NotNull(source, nameof(source)); return ExecuteAsync>( - _contains, + QueryableMethodProvider.ContainsMethodInfo, source, Expression.Constant(item, typeof(TSource)), cancellationToken); @@ -2979,21 +2840,6 @@ private static TResult ExecuteAsync( => ExecuteAsync( operatorMethodInfo, source, (Expression)null, cancellationToken); - private static MethodInfo GetMethod( - string name, int parameterCount = 0, Func predicate = null) - => GetMethod( - name, - parameterCount, - mi => mi.ReturnType == typeof(TResult) - && (predicate == null || predicate(mi))); - - private static MethodInfo GetMethod( - string name, int parameterCount = 0, Func predicate = null) - => typeof(Queryable).GetTypeInfo().GetDeclaredMethods(name) - .Single( - mi => mi.GetParameters().Length == parameterCount + 1 - && (predicate == null || predicate(mi))); - #endregion } } diff --git a/src/EFCore/Query/QueryableMethodProvider.cs b/src/EFCore/Query/QueryableMethodProvider.cs index 96e4291c16d..6f97ef2f38f 100644 --- a/src/EFCore/Query/QueryableMethodProvider.cs +++ b/src/EFCore/Query/QueryableMethodProvider.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; using System.Reflection; @@ -80,6 +81,11 @@ private static bool IsExpressionOfFunc(Type type, int funcGenericArgs = 2) public static MethodInfo GroupByWithKeyElementResultSelectorMethodInfo { get; } public static MethodInfo GroupByWithKeyResultSelectorMethodInfo { get; } + public static IDictionary SumWithoutSelectorMethodInfos { get; } + public static IDictionary SumWithSelectorMethodInfos { get; } + public static IDictionary AverageWithoutSelectorMethodInfos { get; } + public static IDictionary AverageWithSelectorMethodInfos { get; } + static QueryableMethodProvider() { var queryableMethods = typeof(Queryable).GetTypeInfo() @@ -202,6 +208,81 @@ static QueryableMethodProvider() mi => mi.Name == nameof(Queryable.GroupBy) && mi.GetParameters().Length == 4 && IsExpressionOfFunc(mi.GetParameters()[1].ParameterType) && IsExpressionOfFunc(mi.GetParameters()[2].ParameterType) && IsExpressionOfFunc(mi.GetParameters()[3].ParameterType, 3)); GroupByWithKeyResultSelectorMethodInfo = queryableMethods.Single( mi => mi.Name == nameof(Queryable.GroupBy) && mi.GetParameters().Length == 3 && IsExpressionOfFunc(mi.GetParameters()[1].ParameterType) && IsExpressionOfFunc(mi.GetParameters()[2].ParameterType, 3)); + + MethodInfo GetSumOrAverageWithoutSelector(string methodName) + => queryableMethods.Single( + mi => mi.Name == methodName + && mi.GetParameters().Length == 1 + && mi.GetParameters()[0].ParameterType.GetGenericArguments()[0] == typeof(T)); + + bool HasSelector(Type type) + => type.IsGenericType + && type.GetGenericTypeDefinition() == typeof(Expression<>) + && type.GetGenericArguments()[0].IsGenericType + && type.GetGenericArguments()[0].GetGenericArguments().Length == 2 + && type.GetGenericArguments()[0].GetGenericArguments()[1] == typeof(T); + + MethodInfo GetSumOrAverageWithSelector(string methodName) + => queryableMethods.Single( + mi => mi.Name == methodName + && mi.GetParameters().Length == 2 + && HasSelector(mi.GetParameters()[1].ParameterType)); + + SumWithoutSelectorMethodInfos = new Dictionary + { + { typeof(decimal), GetSumOrAverageWithoutSelector(nameof(Queryable.Sum)) }, + { typeof(long), GetSumOrAverageWithoutSelector(nameof(Queryable.Sum)) }, + { typeof(int), GetSumOrAverageWithoutSelector(nameof(Queryable.Sum)) }, + { typeof(double), GetSumOrAverageWithoutSelector(nameof(Queryable.Sum)) }, + { typeof(float), GetSumOrAverageWithoutSelector(nameof(Queryable.Sum)) }, + { typeof(decimal?), GetSumOrAverageWithoutSelector(nameof(Queryable.Sum)) }, + { typeof(long?), GetSumOrAverageWithoutSelector(nameof(Queryable.Sum)) }, + { typeof(int?), GetSumOrAverageWithoutSelector(nameof(Queryable.Sum)) }, + { typeof(double?), GetSumOrAverageWithoutSelector(nameof(Queryable.Sum)) }, + { typeof(float?), GetSumOrAverageWithoutSelector(nameof(Queryable.Sum)) }, + }; + + SumWithSelectorMethodInfos = new Dictionary + { + { typeof(decimal), GetSumOrAverageWithSelector(nameof(Queryable.Sum)) }, + { typeof(long), GetSumOrAverageWithSelector(nameof(Queryable.Sum)) }, + { typeof(int), GetSumOrAverageWithSelector(nameof(Queryable.Sum)) }, + { typeof(double), GetSumOrAverageWithSelector(nameof(Queryable.Sum)) }, + { typeof(float), GetSumOrAverageWithSelector(nameof(Queryable.Sum)) }, + { typeof(decimal?), GetSumOrAverageWithSelector(nameof(Queryable.Sum)) }, + { typeof(long?), GetSumOrAverageWithSelector(nameof(Queryable.Sum)) }, + { typeof(int?), GetSumOrAverageWithSelector(nameof(Queryable.Sum)) }, + { typeof(double?), GetSumOrAverageWithSelector(nameof(Queryable.Sum)) }, + { typeof(float?), GetSumOrAverageWithSelector(nameof(Queryable.Sum)) }, + }; + + AverageWithoutSelectorMethodInfos = new Dictionary + { + { typeof(decimal), GetSumOrAverageWithoutSelector(nameof(Queryable.Average)) }, + { typeof(long), GetSumOrAverageWithoutSelector(nameof(Queryable.Average)) }, + { typeof(int), GetSumOrAverageWithoutSelector(nameof(Queryable.Average)) }, + { typeof(double), GetSumOrAverageWithoutSelector(nameof(Queryable.Average)) }, + { typeof(float), GetSumOrAverageWithoutSelector(nameof(Queryable.Average)) }, + { typeof(decimal?), GetSumOrAverageWithoutSelector(nameof(Queryable.Average)) }, + { typeof(long?), GetSumOrAverageWithoutSelector(nameof(Queryable.Average)) }, + { typeof(int?), GetSumOrAverageWithoutSelector(nameof(Queryable.Average)) }, + { typeof(double?), GetSumOrAverageWithoutSelector(nameof(Queryable.Average)) }, + { typeof(float?), GetSumOrAverageWithoutSelector(nameof(Queryable.Average)) }, + }; + + AverageWithSelectorMethodInfos = new Dictionary + { + { typeof(decimal), GetSumOrAverageWithSelector(nameof(Queryable.Average)) }, + { typeof(long), GetSumOrAverageWithSelector(nameof(Queryable.Average)) }, + { typeof(int), GetSumOrAverageWithSelector(nameof(Queryable.Average)) }, + { typeof(double), GetSumOrAverageWithSelector(nameof(Queryable.Average)) }, + { typeof(float), GetSumOrAverageWithSelector(nameof(Queryable.Average)) }, + { typeof(decimal?), GetSumOrAverageWithSelector(nameof(Queryable.Average)) }, + { typeof(long?), GetSumOrAverageWithSelector(nameof(Queryable.Average)) }, + { typeof(int?), GetSumOrAverageWithSelector(nameof(Queryable.Average)) }, + { typeof(double?), GetSumOrAverageWithSelector(nameof(Queryable.Average)) }, + { typeof(float?), GetSumOrAverageWithSelector(nameof(Queryable.Average)) }, + }; } } } diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.ResultOperators.cs b/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.ResultOperators.cs index ce78d174d23..65f3acd907f 100644 --- a/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.ResultOperators.cs +++ b/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.ResultOperators.cs @@ -34,6 +34,12 @@ FROM root c WHERE (c[""Discriminator""] = ""Order"")"); } + [ConditionalTheory(Skip = "Issue#16146")] + public override Task Sum_with_no_data_cast_to_nullable(bool isAsync) + { + return base.Sum_with_no_data_cast_to_nullable(isAsync); + } + [ConditionalTheory(Skip = "Issue#16146")] public override async Task Sum_with_binary_expression(bool isAsync) { @@ -180,6 +186,17 @@ public override void Average_no_data() base.Average_no_data(); } + + [ConditionalFact(Skip = "Issue#16146")] + public override void Average_no_data_nullable() + { + } + + [ConditionalFact(Skip = "Issue#16146")] + public override void Average_no_data_cast_to_nullable() + { + } + [ConditionalFact(Skip = "Issue#16146")] public override void Min_no_data() { @@ -204,6 +221,16 @@ public override void Max_no_data_subquery() base.Max_no_data_subquery(); } + [ConditionalFact(Skip = "Issue#16146")] + public override void Max_no_data_nullable() + { + } + + [ConditionalFact(Skip = "Issue#16146")] + public override void Max_no_data_cast_to_nullable() + { + } + [ConditionalFact(Skip = "Issue#16146")] public override void Min_no_data_subquery() { @@ -403,6 +430,17 @@ FROM root c WHERE (c[""Discriminator""] = ""Order"")"); } + + [ConditionalFact(Skip = "Issue#16146")] + public override void Min_no_data_nullable() + { + } + + [ConditionalFact(Skip = "Issue#16146")] + public override void Min_no_data_cast_to_nullable() + { + } + [ConditionalTheory(Skip = "Issue#16146")] public override async Task Min_with_coalesce(bool isAsync) { diff --git a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.ResultOperators.cs b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.ResultOperators.cs index 009f5670651..e7b49846ce0 100644 --- a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.ResultOperators.cs +++ b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.ResultOperators.cs @@ -109,13 +109,23 @@ public virtual Task Sum_with_no_arg(bool isAsync) [ConditionalTheory] [MemberData(nameof(IsAsyncData))] - public virtual Task Sum_with_no_data_nullable(bool isAsync) + public virtual Task Sum_with_no_data_cast_to_nullable(bool isAsync) { return AssertSum( isAsync, os => os.Where(o => o.OrderID < 0).Select(o => (int?)o.OrderID)); } + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Sum_with_no_data_nullable(bool isAsync) + { + return AssertSum( + isAsync, + os => os, + selector: o => o.SupplierID); + } + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Sum_with_binary_expression(bool isAsync) @@ -411,6 +421,24 @@ public virtual void Min_no_data() } } + [ConditionalFact] + public virtual void Min_no_data_nullable() + { + using (var context = CreateContext()) + { + Assert.Null(context.Products.Where(o => o.SupplierID == -1).Min(o => o.SupplierID)); + } + } + + [ConditionalFact] + public virtual void Min_no_data_cast_to_nullable() + { + using (var context = CreateContext()) + { + Assert.Null(context.Orders.Where(o => o.OrderID == -1).Min(o => (int?)o.OrderID)); + } + } + [ConditionalFact] public virtual void Min_no_data_subquery() { @@ -430,6 +458,24 @@ public virtual void Max_no_data() } } + [ConditionalFact] + public virtual void Max_no_data_nullable() + { + using (var context = CreateContext()) + { + Assert.Null(context.Products.Where(o => o.SupplierID == -1).Max(o => o.SupplierID)); + } + } + + [ConditionalFact] + public virtual void Max_no_data_cast_to_nullable() + { + using (var context = CreateContext()) + { + Assert.Null(context.Orders.Where(o => o.OrderID == -1).Max(o => (int?)o.OrderID)); + } + } + [ConditionalFact] public virtual void Max_no_data_subquery() { @@ -449,6 +495,25 @@ public virtual void Average_no_data() } } + + [ConditionalFact] + public virtual void Average_no_data_nullable() + { + using (var context = CreateContext()) + { + Assert.Null(context.Products.Where(o => o.SupplierID == -1).Average(o => o.SupplierID)); + } + } + + [ConditionalFact] + public virtual void Average_no_data_cast_to_nullable() + { + using (var context = CreateContext()) + { + Assert.Null(context.Orders.Where(o => o.OrderID == -1).Average(o => (int?)o.OrderID)); + } + } + [ConditionalFact] public virtual void Average_no_data_subquery() {