diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs b/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs index a4858168717..da5f81b9c14 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs @@ -527,7 +527,7 @@ public virtual void AddLeftJoin( _projectionMapping = projectionMapping; } - public virtual void AddCrossJoin(InMemoryQueryExpression innerQueryExpression, Type transparentIdentifierType) + public virtual void AddSelectMany(InMemoryQueryExpression innerQueryExpression, Type transparentIdentifierType, bool innerNullable) { var outerParameter = Parameter(typeof(ValueBuffer), "outer"); var innerParameter = Parameter(typeof(ValueBuffer), "inner"); @@ -564,6 +564,7 @@ public virtual void AddCrossJoin(InMemoryQueryExpression innerQueryExpression, T } var innerMemberInfo = transparentIdentifierType.GetTypeInfo().GetDeclaredField("Inner"); + var nullableReadValueExpressionVisitor = new NullableReadValueExpressionVisitor(); foreach (var projection in innerQueryExpression._projectionMapping) { if (projection.Value is EntityProjectionExpression entityProjection) @@ -571,17 +572,27 @@ public virtual void AddCrossJoin(InMemoryQueryExpression innerQueryExpression, T var readExpressionMap = new Dictionary(); foreach (var property in GetAllPropertiesInHierarchy(entityProjection.EntityType)) { - resultValueBufferExpressions.Add(replacingVisitor.Visit(entityProjection.BindProperty(property))); - readExpressionMap[property] = CreateReadValueExpression(property.ClrType, index++, property); + var replacedExpression = replacingVisitor.Visit(entityProjection.BindProperty(property)); + if (innerNullable) + { + replacedExpression = nullableReadValueExpressionVisitor.Visit(replacedExpression); + } + resultValueBufferExpressions.Add(replacedExpression); + readExpressionMap[property] = CreateReadValueExpression(replacedExpression.Type, index++, property); } projectionMapping[projection.Key.Prepend(innerMemberInfo)] = new EntityProjectionExpression(entityProjection.EntityType, readExpressionMap); } else { - resultValueBufferExpressions.Add(replacingVisitor.Visit(projection.Value)); + var replacedExpression = replacingVisitor.Visit(projection.Value); + if (innerNullable) + { + replacedExpression = nullableReadValueExpressionVisitor.Visit(replacedExpression); + } + resultValueBufferExpressions.Add(replacedExpression); projectionMapping[projection.Key.Prepend(innerMemberInfo)] - = CreateReadValueExpression(projection.Value.Type, index++, InferPropertyFromInner(projection.Value)); + = CreateReadValueExpression(replacedExpression.Type, index++, InferPropertyFromInner(projection.Value)); } } diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs index bcd471504ca..be86e1e5367 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs @@ -438,81 +438,73 @@ protected override ShapedQueryExpression TranslateSelect(ShapedQueryExpression s return source; } - private static readonly MethodInfo _defaultIfEmptyWithoutArgMethodInfo = typeof(Enumerable).GetTypeInfo() - .GetDeclaredMethods(nameof(Enumerable.DefaultIfEmpty)).Single(mi => mi.GetParameters().Length == 1); - protected override ShapedQueryExpression TranslateSelectMany( ShapedQueryExpression source, LambdaExpression collectionSelector, LambdaExpression resultSelector) { - var collectionSelectorBody = collectionSelector.Body; - //var defaultIfEmpty = false; + var defaultIfEmpty = new DefaultIfEmptyFindingExpressionVisitor().IsOptional(collectionSelector); + var collectionSelectorBody = RemapLambdaBody(source, collectionSelector); - if (collectionSelectorBody is MethodCallExpression collectionEndingMethod - && collectionEndingMethod.Method.IsGenericMethod - && collectionEndingMethod.Method.GetGenericMethodDefinition() == _defaultIfEmptyWithoutArgMethodInfo) + if (Visit(collectionSelectorBody) is ShapedQueryExpression inner) { - //defaultIfEmpty = true; - collectionSelectorBody = collectionEndingMethod.Arguments[0]; - } - - var correlated = new CorrelationFindingExpressionVisitor().IsCorrelated(collectionSelectorBody, collectionSelector.Parameters[0]); - if (correlated) - { - // TODO visit inner with outer parameter; - // See #17236 - throw new InvalidOperationException(CoreStrings.TranslationFailed( - collectionSelector.Print() + "; " + resultSelector.Print())); - - } - else - { - if (Visit(collectionSelectorBody) is ShapedQueryExpression inner) - { - var transparentIdentifierType = TransparentIdentifierFactory.Create( - resultSelector.Parameters[0].Type, - resultSelector.Parameters[1].Type); - - ((InMemoryQueryExpression)source.QueryExpression).AddCrossJoin( - (InMemoryQueryExpression)inner.QueryExpression, transparentIdentifierType); - - return TranslateResultSelectorForJoin( - source, - resultSelector, - inner.ShaperExpression, - transparentIdentifierType); - } + var transparentIdentifierType = TransparentIdentifierFactory.Create( + resultSelector.Parameters[0].Type, + resultSelector.Parameters[1].Type); + + var innerShaperExpression = defaultIfEmpty + ? MarkShaperNullable(inner.ShaperExpression) + : inner.ShaperExpression; + + ((InMemoryQueryExpression)source.QueryExpression).AddSelectMany( + (InMemoryQueryExpression)inner.QueryExpression, transparentIdentifierType, defaultIfEmpty); + + return TranslateResultSelectorForJoin( + source, + resultSelector, + innerShaperExpression, + transparentIdentifierType); } return null; } - private class CorrelationFindingExpressionVisitor : ExpressionVisitor + private class DefaultIfEmptyFindingExpressionVisitor : ExpressionVisitor { - private ParameterExpression _outerParameter; - private bool _isCorrelated; - public bool IsCorrelated(Expression tree, ParameterExpression outerParameter) + private bool _defaultIfEmpty; + + public bool IsOptional(LambdaExpression lambdaExpression) { - _isCorrelated = false; - _outerParameter = outerParameter; + _defaultIfEmpty = false; - Visit(tree); + Visit(lambdaExpression.Body); - return _isCorrelated; + return _defaultIfEmpty; } - protected override Expression VisitParameter(ParameterExpression parameterExpression) + protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) { - if (parameterExpression == _outerParameter) + if (methodCallExpression.Method.IsGenericMethod + && methodCallExpression.Method.GetGenericMethodDefinition() == QueryableMethods.DefaultIfEmptyWithoutArgument) { - _isCorrelated = true; + _defaultIfEmpty = true; } - return base.VisitParameter(parameterExpression); + return base.VisitMethodCall(methodCallExpression); } } protected override ShapedQueryExpression TranslateSelectMany(ShapedQueryExpression source, LambdaExpression selector) - => null; + { + var innerParameter = Expression.Parameter(selector.ReturnType.TryGetSequenceType(), "i"); + var resultSelector = Expression.Lambda( + innerParameter, + new[] + { + Expression.Parameter(source.Type.TryGetSequenceType()), + innerParameter + }); + + return TranslateSelectMany(source, selector, resultSelector); + } protected override ShapedQueryExpression TranslateSingleOrDefault(ShapedQueryExpression source, LambdaExpression predicate, Type returnType, bool returnDefault) { @@ -617,9 +609,7 @@ private Expression TranslateExpression(Expression expression) private LambdaExpression TranslateLambdaExpression( ShapedQueryExpression shapedQueryExpression, LambdaExpression lambdaExpression) { - var lambdaBody = ReplacingExpressionVisitor.Replace( - lambdaExpression.Parameters.Single(), shapedQueryExpression.ShaperExpression, lambdaExpression.Body); - lambdaBody = TranslateExpression(lambdaBody); + var lambdaBody = TranslateExpression(RemapLambdaBody(shapedQueryExpression, lambdaExpression)); return lambdaBody != null ? Expression.Lambda(lambdaBody, @@ -627,6 +617,12 @@ private LambdaExpression TranslateLambdaExpression( : null; } + private static Expression RemapLambdaBody(ShapedQueryExpression shapedQueryExpression, LambdaExpression lambdaExpression) + { + return ReplacingExpressionVisitor.Replace( + lambdaExpression.Parameters.Single(), shapedQueryExpression.ShaperExpression, lambdaExpression.Body); + } + private ShapedQueryExpression TranslateScalarAggregate( ShapedQueryExpression source, LambdaExpression selector, string methodName) { diff --git a/test/EFCore.InMemory.FunctionalTests/Query/SimpleQueryInMemoryTest.cs b/test/EFCore.InMemory.FunctionalTests/Query/SimpleQueryInMemoryTest.cs index 04e9a2e07b1..4c951aac29d 100644 --- a/test/EFCore.InMemory.FunctionalTests/Query/SimpleQueryInMemoryTest.cs +++ b/test/EFCore.InMemory.FunctionalTests/Query/SimpleQueryInMemoryTest.cs @@ -373,16 +373,6 @@ public override Task Join_with_default_if_empty_on_both_sources(bool isAsync) #region SelectMany - public override Task Multiple_select_many_with_predicate(bool isAsync) - { - return Task.CompletedTask; - } - - public override Task SelectMany_Joined(bool isAsync) - { - return Task.CompletedTask; - } - public override Task SelectMany_Joined_DefaultIfEmpty(bool isAsync) { return Task.CompletedTask; @@ -393,26 +383,6 @@ public override Task SelectMany_Joined_DefaultIfEmpty2(bool isAsync) return Task.CompletedTask; } - public override Task SelectMany_Joined_Take(bool isAsync) - { - return Task.CompletedTask; - } - - public override Task SelectMany_correlated_subquery_simple(bool isAsync) - { - return Task.CompletedTask; - } - - public override Task SelectMany_correlated_with_outer_1(bool isAsync) - { - return Task.CompletedTask; - } - - public override Task SelectMany_correlated_with_outer_2(bool isAsync) - { - return Task.CompletedTask; - } - public override Task SelectMany_correlated_with_outer_3(bool isAsync) { return Task.CompletedTask; @@ -423,26 +393,6 @@ public override Task SelectMany_correlated_with_outer_4(bool isAsync) return Task.CompletedTask; } - public override Task SelectMany_without_result_selector_collection_navigation_composed(bool isAsync) - { - return Task.CompletedTask; - } - - public override Task SelectMany_without_result_selector_naked_collection_navigation(bool isAsync) - { - return Task.CompletedTask; - } - - public override Task Select_DTO_with_member_init_distinct_in_subquery_translated_to_server(bool isAsync) - { - return Task.CompletedTask; - } - - public override Task Select_DTO_with_member_init_distinct_in_subquery_translated_to_server_2(bool isAsync) - { - return Task.CompletedTask; - } - #endregion #region NullableError