diff --git a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs index 94f6707fecb..982c8b18dfb 100644 --- a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs @@ -430,98 +430,100 @@ Expression ProcessGetType(EntityReferenceExpression entityReferenceExpression, T var derivedType = entityType.GetDerivedTypesInclusive().SingleOrDefault(et => et.ClrType == comparisonType); // If no derived type matches then fail the translation - if (derivedType != null) + if (derivedType == null) { - // If the derived type is abstract type then predicate will always be false - if (derivedType.IsAbstract()) - { - return _sqlExpressionFactory.Constant(!match); - } + return QueryCompilationContext.NotTranslatedExpression; + } - // Or add predicate for matching that particular type discriminator value - var discriminatorProperty = entityType.FindDiscriminatorProperty(); - if (discriminatorProperty == null) - { - // TPT or TPC - var discriminatorValue = derivedType.ShortName(); - if (entityReferenceExpression.SubqueryEntity != null) - { - var entityShaper = (EntityShaperExpression)entityReferenceExpression.SubqueryEntity.ShaperExpression; - var entityProjection = (EntityProjectionExpression)Visit(entityShaper.ValueBufferExpression); - var subSelectExpression = (SelectExpression)entityReferenceExpression.SubqueryEntity.QueryExpression; + // If the derived type is abstract type then predicate will always be false + if (derivedType.IsAbstract()) + { + return _sqlExpressionFactory.Constant(!match); + } - var predicate = GeneratePredicateTpt(entityProjection); + // Or add predicate for matching that particular type discriminator value + var discriminatorProperty = entityType.FindDiscriminatorProperty(); + if (discriminatorProperty == null) + { + // TPT or TPC + var discriminatorValue = derivedType.ShortName(); + if (entityReferenceExpression.SubqueryEntity != null) + { + var entityShaper = (EntityShaperExpression)entityReferenceExpression.SubqueryEntity.ShaperExpression; + var entityProjection = (EntityProjectionExpression)Visit(entityShaper.ValueBufferExpression); + var subSelectExpression = (SelectExpression)entityReferenceExpression.SubqueryEntity.QueryExpression; - subSelectExpression.ApplyPredicate(predicate); - subSelectExpression.ReplaceProjection(new List()); - subSelectExpression.ApplyProjection(); - if (subSelectExpression.Limit == null - && subSelectExpression.Offset == null) - { - subSelectExpression.ClearOrdering(); - } + var predicate = GeneratePredicateTpt(entityProjection); - return _sqlExpressionFactory.Exists(subSelectExpression, false); + subSelectExpression.ApplyPredicate(predicate); + subSelectExpression.ReplaceProjection(new List()); + subSelectExpression.ApplyProjection(); + if (subSelectExpression.Limit == null + && subSelectExpression.Offset == null) + { + subSelectExpression.ClearOrdering(); } - if (entityReferenceExpression.ParameterEntity != null) - { - var entityProjection = (EntityProjectionExpression)Visit( - entityReferenceExpression.ParameterEntity.ValueBufferExpression); + return _sqlExpressionFactory.Exists(subSelectExpression, false); + } - return GeneratePredicateTpt(entityProjection); - } + if (entityReferenceExpression.ParameterEntity != null) + { + var entityProjection = (EntityProjectionExpression)Visit( + entityReferenceExpression.ParameterEntity.ValueBufferExpression); + + return GeneratePredicateTpt(entityProjection); + } - SqlExpression GeneratePredicateTpt(EntityProjectionExpression entityProjectionExpression) + SqlExpression GeneratePredicateTpt(EntityProjectionExpression entityProjectionExpression) + { + if (entityProjectionExpression.DiscriminatorExpression is CaseExpression caseExpression) { - if (entityProjectionExpression.DiscriminatorExpression is CaseExpression caseExpression) + // TPT case + // Most root type doesn't have matching case + // All derived types needs to be excluded + var derivedTypeValues = derivedType.GetDerivedTypes().Where(e => !e.IsAbstract()).Select(e => e.ShortName()) + .ToList(); + var predicates = new List(); + foreach (var caseWhenClause in caseExpression.WhenClauses) { - // TPT case - // Most root type doesn't have matching case - // All derived types needs to be excluded - var derivedTypeValues = derivedType.GetDerivedTypes().Where(e => !e.IsAbstract()).Select(e => e.ShortName()) - .ToList(); - var predicates = new List(); - foreach (var caseWhenClause in caseExpression.WhenClauses) + var value = (string)((SqlConstantExpression)caseWhenClause.Result).Value!; + if (value == discriminatorValue) { - var value = (string)((SqlConstantExpression)caseWhenClause.Result).Value!; - if (value == discriminatorValue) - { - predicates.Add(caseWhenClause.Test); - } - else if (derivedTypeValues.Contains(value)) - { - predicates.Add(_sqlExpressionFactory.Not(caseWhenClause.Test)); - } + predicates.Add(caseWhenClause.Test); + } + else if (derivedTypeValues.Contains(value)) + { + predicates.Add(_sqlExpressionFactory.Not(caseWhenClause.Test)); } - - var result = predicates.Aggregate((a, b) => _sqlExpressionFactory.AndAlso(a, b)); - - return match ? result : _sqlExpressionFactory.Not(result); } - return match - ? _sqlExpressionFactory.Equal( - entityProjectionExpression.DiscriminatorExpression!, - _sqlExpressionFactory.Constant(discriminatorValue)) - : _sqlExpressionFactory.NotEqual( - entityProjectionExpression.DiscriminatorExpression!, - _sqlExpressionFactory.Constant(discriminatorValue)); + var result = predicates.Aggregate((a, b) => _sqlExpressionFactory.AndAlso(a, b)); + + return match ? result : _sqlExpressionFactory.Not(result); } + + return match + ? _sqlExpressionFactory.Equal( + entityProjectionExpression.DiscriminatorExpression!, + _sqlExpressionFactory.Constant(discriminatorValue)) + : _sqlExpressionFactory.NotEqual( + entityProjectionExpression.DiscriminatorExpression!, + _sqlExpressionFactory.Constant(discriminatorValue)); } - else + } + else + { + var discriminatorColumn = BindProperty(entityReferenceExpression, discriminatorProperty); + if (discriminatorColumn != null) { - var discriminatorColumn = BindProperty(entityReferenceExpression, discriminatorProperty); - if (discriminatorColumn != null) - { - return match - ? _sqlExpressionFactory.Equal( - discriminatorColumn, - _sqlExpressionFactory.Constant(derivedType.GetDiscriminatorValue())) - : _sqlExpressionFactory.NotEqual( - discriminatorColumn, - _sqlExpressionFactory.Constant(derivedType.GetDiscriminatorValue())); - } + return match + ? _sqlExpressionFactory.Equal( + discriminatorColumn, + _sqlExpressionFactory.Constant(derivedType.GetDiscriminatorValue())) + : _sqlExpressionFactory.NotEqual( + discriminatorColumn, + _sqlExpressionFactory.Constant(derivedType.GetDiscriminatorValue())); } } @@ -727,19 +729,16 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp } // EF Indexer property - if (methodCallExpression.TryGetIndexerArguments(_model, out source, out propertyName)) + if (methodCallExpression.TryGetIndexerArguments(_model, out source, out propertyName) + && TryBindMember(Visit(source), MemberIdentity.Create(propertyName)) is SqlExpression indexerResult) { - if (TryBindMember(Visit(source), MemberIdentity.Create(propertyName)) is SqlExpression result) - { - return result; - } + return indexerResult; } var method = methodCallExpression.Method; var arguments = methodCallExpression.Arguments; EnumerableExpression? enumerableExpression = null; - var abortTranslation = false; SqlExpression? sqlObject = null; List scalarArguments; @@ -866,75 +865,65 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp return translatedAggregate; } - abortTranslation = true; + goto SubqueryTranslation; } scalarArguments = new List(); - if (!abortTranslation) + if (!TryTranslateAsEnumerableExpression(methodCallExpression.Object, out enumerableExpression) + && TranslationFailed(methodCallExpression.Object, Visit(methodCallExpression.Object), out sqlObject)) { - if (!TryTranslateAsEnumerableExpression(methodCallExpression.Object, out enumerableExpression) - && TranslationFailed(methodCallExpression.Object, Visit(methodCallExpression.Object), out sqlObject)) - { - abortTranslation = true; - } + goto SubqueryTranslation; + } - if (!abortTranslation) + for (var i = 0; i < arguments.Count; i++) + { + var argument = arguments[i]; + if (TryTranslateAsEnumerableExpression(argument, out var eea)) { - for (var i = 0; i < arguments.Count; i++) + if (enumerableExpression != null) { - var argument = arguments[i]; - if (TryTranslateAsEnumerableExpression(argument, out var eea)) - { - if (enumerableExpression != null) - { - abortTranslation = true; - break; - } - - enumerableExpression = eea; - continue; - } + goto SubqueryTranslation; + } - var visitedArgument = Visit(argument); - if (TranslationFailed(argument, visitedArgument, out var sqlArgument)) - { - abortTranslation = true; - break; - } + enumerableExpression = eea; + continue; + } - scalarArguments.Add(sqlArgument!); - } + var visitedArgument = Visit(argument); + if (TranslationFailed(argument, visitedArgument, out var sqlArgument)) + { + goto SubqueryTranslation; } + + scalarArguments.Add(sqlArgument!); } } - if (!abortTranslation) - { - var translation = enumerableExpression != null - ? TranslateAggregateMethod(enumerableExpression, method, scalarArguments) - : Dependencies.MethodCallTranslatorProvider.Translate( - _model, sqlObject, method, scalarArguments, _queryCompilationContext.Logger); + var translation = enumerableExpression != null + ? TranslateAggregateMethod(enumerableExpression, method, scalarArguments) + : Dependencies.MethodCallTranslatorProvider.Translate( + _model, sqlObject, method, scalarArguments, _queryCompilationContext.Logger); - if (translation != null) - { - return translation; - } + if (translation != null) + { + return translation; + } - if (method == StringEqualsWithStringComparison - || method == StringEqualsWithStringComparisonStatic) - { - AddTranslationErrorDetails(CoreStrings.QueryUnableToTranslateStringEqualsWithStringComparison); - } - else - { - AddTranslationErrorDetails( - CoreStrings.QueryUnableToTranslateMethod( - method.DeclaringType?.DisplayName(), - method.Name)); - } + if (method == StringEqualsWithStringComparison + || method == StringEqualsWithStringComparisonStatic) + { + AddTranslationErrorDetails(CoreStrings.QueryUnableToTranslateStringEqualsWithStringComparison); + } + else + { + AddTranslationErrorDetails( + CoreStrings.QueryUnableToTranslateMethod( + method.DeclaringType?.DisplayName(), + method.Name)); } // Subquery case + SubqueryTranslation: var subqueryTranslation = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery(methodCallExpression); return subqueryTranslation == null @@ -961,109 +950,111 @@ protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExp { var innerExpression = Visit(typeBinaryExpression.Expression); - if (typeBinaryExpression.NodeType == ExpressionType.TypeIs - && innerExpression is EntityReferenceExpression entityReferenceExpression) + if (typeBinaryExpression.NodeType != ExpressionType.TypeIs + || innerExpression is not EntityReferenceExpression entityReferenceExpression) { - var entityType = entityReferenceExpression.EntityType; - if (entityType.GetAllBaseTypesInclusive().Any(et => et.ClrType == typeBinaryExpression.TypeOperand)) + return QueryCompilationContext.NotTranslatedExpression; + } + var entityType = entityReferenceExpression.EntityType; + if (entityType.GetAllBaseTypesInclusive().Any(et => et.ClrType == typeBinaryExpression.TypeOperand)) + { + return _sqlExpressionFactory.Constant(true); + } + + var derivedType = entityType.GetDerivedTypes().SingleOrDefault(et => et.ClrType == typeBinaryExpression.TypeOperand); + if (derivedType == null) + { + return QueryCompilationContext.NotTranslatedExpression; + } + var discriminatorProperty = entityType.FindDiscriminatorProperty(); + if (discriminatorProperty == null) + { + if (entityType.GetMappingStrategy() == RelationalAnnotationNames.TpcMappingStrategy + && entityType.GetDerivedTypesInclusive().Count(e => !e.IsAbstract()) == 1) { return _sqlExpressionFactory.Constant(true); } - var derivedType = entityType.GetDerivedTypes().SingleOrDefault(et => et.ClrType == typeBinaryExpression.TypeOperand); - if (derivedType != null) + // TPT or TPC + var discriminatorValues = derivedType.GetConcreteDerivedTypesInclusive() + .Select(e => (string)e.GetDiscriminatorValue()!).ToList(); + if (entityReferenceExpression.SubqueryEntity != null) { - var discriminatorProperty = entityType.FindDiscriminatorProperty(); - if (discriminatorProperty == null) - { - if (entityType.GetMappingStrategy() == RelationalAnnotationNames.TpcMappingStrategy - && entityType.GetDerivedTypesInclusive().Count(e => !e.IsAbstract()) == 1) - { - return _sqlExpressionFactory.Constant(true); - } - - // TPT or TPC - var discriminatorValues = derivedType.GetConcreteDerivedTypesInclusive() - .Select(e => (string)e.GetDiscriminatorValue()!).ToList(); - if (entityReferenceExpression.SubqueryEntity != null) - { - var entityShaper = (EntityShaperExpression)entityReferenceExpression.SubqueryEntity.ShaperExpression; - var entityProjection = (EntityProjectionExpression)Visit(entityShaper.ValueBufferExpression); - var subSelectExpression = (SelectExpression)entityReferenceExpression.SubqueryEntity.QueryExpression; - - var predicate = GeneratePredicateTpt(entityProjection); + var entityShaper = (EntityShaperExpression)entityReferenceExpression.SubqueryEntity.ShaperExpression; + var entityProjection = (EntityProjectionExpression)Visit(entityShaper.ValueBufferExpression); + var subSelectExpression = (SelectExpression)entityReferenceExpression.SubqueryEntity.QueryExpression; - subSelectExpression.ApplyPredicate(predicate); - subSelectExpression.ReplaceProjection(new List()); - subSelectExpression.ApplyProjection(); - if (subSelectExpression.Limit == null - && subSelectExpression.Offset == null) - { - subSelectExpression.ClearOrdering(); - } + var predicate = GeneratePredicateTpt(entityProjection); - return _sqlExpressionFactory.Exists(subSelectExpression, false); - } + subSelectExpression.ApplyPredicate(predicate); + subSelectExpression.ReplaceProjection(new List()); + subSelectExpression.ApplyProjection(); + if (subSelectExpression.Limit == null + && subSelectExpression.Offset == null) + { + subSelectExpression.ClearOrdering(); + } - if (entityReferenceExpression.ParameterEntity != null) - { - var entityProjection = (EntityProjectionExpression)Visit( - entityReferenceExpression.ParameterEntity.ValueBufferExpression); + return _sqlExpressionFactory.Exists(subSelectExpression, false); + } - return GeneratePredicateTpt(entityProjection); - } + if (entityReferenceExpression.ParameterEntity != null) + { + var entityProjection = (EntityProjectionExpression)Visit( + entityReferenceExpression.ParameterEntity.ValueBufferExpression); - SqlExpression GeneratePredicateTpt(EntityProjectionExpression entityProjectionExpression) - { - if (entityProjectionExpression.DiscriminatorExpression is CaseExpression caseExpression) - { - var matchingCaseWhenClauses = caseExpression.WhenClauses - .Where(wc => discriminatorValues.Contains((string)((SqlConstantExpression)wc.Result).Value!)) - .ToList(); - - return matchingCaseWhenClauses.Count == 1 - ? matchingCaseWhenClauses[0].Test - : matchingCaseWhenClauses.Select(e => e.Test) - .Aggregate((l, r) => _sqlExpressionFactory.OrElse(l, r)); - } + return GeneratePredicateTpt(entityProjection); + } - return discriminatorValues.Count == 1 - ? _sqlExpressionFactory.Equal( - entityProjectionExpression.DiscriminatorExpression!, - _sqlExpressionFactory.Constant(discriminatorValues[0])) - : _sqlExpressionFactory.In( - entityProjectionExpression.DiscriminatorExpression!, - _sqlExpressionFactory.Constant(discriminatorValues), - negated: false); - } + SqlExpression GeneratePredicateTpt(EntityProjectionExpression entityProjectionExpression) + { + if (entityProjectionExpression.DiscriminatorExpression is CaseExpression caseExpression) + { + var matchingCaseWhenClauses = caseExpression.WhenClauses + .Where(wc => discriminatorValues.Contains((string)((SqlConstantExpression)wc.Result).Value!)) + .ToList(); + + return matchingCaseWhenClauses.Count == 1 + ? matchingCaseWhenClauses[0].Test + : matchingCaseWhenClauses.Select(e => e.Test) + .Aggregate((l, r) => _sqlExpressionFactory.OrElse(l, r)); } - else + + return discriminatorValues.Count == 1 + ? _sqlExpressionFactory.Equal( + entityProjectionExpression.DiscriminatorExpression!, + _sqlExpressionFactory.Constant(discriminatorValues[0])) + : _sqlExpressionFactory.In( + entityProjectionExpression.DiscriminatorExpression!, + _sqlExpressionFactory.Constant(discriminatorValues), + negated: false); + } + } + else + { + if (!derivedType.GetRootType().GetIsDiscriminatorMappingComplete() + || !derivedType.GetAllBaseTypesInclusiveAscending() + .All(e => (e == derivedType || e.IsAbstract()) && !HasSiblings(e))) + { + var concreteEntityTypes = derivedType.GetConcreteDerivedTypesInclusive().ToList(); + var discriminatorColumn = BindProperty(entityReferenceExpression, discriminatorProperty); + if (discriminatorColumn != null) { - if (!derivedType.GetRootType().GetIsDiscriminatorMappingComplete() - || !derivedType.GetAllBaseTypesInclusiveAscending() - .All(e => (e == derivedType || e.IsAbstract()) && !HasSiblings(e))) - { - var concreteEntityTypes = derivedType.GetConcreteDerivedTypesInclusive().ToList(); - var discriminatorColumn = BindProperty(entityReferenceExpression, discriminatorProperty); - if (discriminatorColumn != null) - { - return concreteEntityTypes.Count == 1 - ? _sqlExpressionFactory.Equal( - discriminatorColumn, - _sqlExpressionFactory.Constant(concreteEntityTypes[0].GetDiscriminatorValue())) - : _sqlExpressionFactory.In( - discriminatorColumn, - _sqlExpressionFactory.Constant( - concreteEntityTypes.Select(et => et.GetDiscriminatorValue()).ToList()), - negated: false); - } - } - else - { - return _sqlExpressionFactory.Constant(true); - } + return concreteEntityTypes.Count == 1 + ? _sqlExpressionFactory.Equal( + discriminatorColumn, + _sqlExpressionFactory.Constant(concreteEntityTypes[0].GetDiscriminatorValue())) + : _sqlExpressionFactory.In( + discriminatorColumn, + _sqlExpressionFactory.Constant( + concreteEntityTypes.Select(et => et.GetDiscriminatorValue()).ToList()), + negated: false); } } + else + { + return _sqlExpressionFactory.Constant(true); + } } return QueryCompilationContext.NotTranslatedExpression; @@ -1545,8 +1536,7 @@ when sqlParameterExpression.Name.StartsWith(QueryCompilationContext.QueryParamet QueryCompilationContext.QueryContextParameter, Expression.Constant(sqlParameterExpression.Name, typeof(string)), Expression.Constant(property, typeof(IProperty))), - QueryCompilationContext.QueryContextParameter - ); + QueryCompilationContext.QueryContextParameter); var newParameterName = $"{RuntimeParameterPrefix}" @@ -1836,17 +1826,9 @@ public Expression Convert(Type type) private sealed class SqlTypeMappingVerifyingExpressionVisitor : ExpressionVisitor { protected override Expression VisitExtension(Expression extensionExpression) - { - if (extensionExpression is SqlExpression sqlExpression - && extensionExpression is not SqlFragmentExpression) - { - if (sqlExpression.TypeMapping == null) - { - throw new InvalidOperationException(RelationalStrings.NullTypeMappingInSqlTree(sqlExpression.Print())); - } - } - - return base.VisitExtension(extensionExpression); - } + => extensionExpression is SqlExpression { TypeMapping: null } sqlExpression + && extensionExpression is not SqlFragmentExpression + ? throw new InvalidOperationException(RelationalStrings.NullTypeMappingInSqlTree(sqlExpression.Print())) + : base.VisitExtension(extensionExpression); } }