diff --git a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.ExpressionVisitors.cs b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.ExpressionVisitors.cs index 74f2a01b712..e58dec253d8 100644 --- a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.ExpressionVisitors.cs +++ b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.ExpressionVisitors.cs @@ -615,5 +615,33 @@ public override Expression Visit(Expression expression) return base.Visit(expression); } } + + private class SelfReferenceEntityQueryableRewritingExpressionVisitor : ExpressionVisitor + { + private readonly NavigationExpandingExpressionVisitor _navigationExpandingExpressionVisitor; + private readonly IEntityType _entityType; + + public SelfReferenceEntityQueryableRewritingExpressionVisitor( + NavigationExpandingExpressionVisitor navigationExpandingExpressionVisitor, + IEntityType entityType) + { + _navigationExpandingExpressionVisitor = navigationExpandingExpressionVisitor; + _entityType = entityType; + } + + protected override Expression VisitConstant(ConstantExpression constantExpression) + { + if (constantExpression.IsEntityQueryable()) + { + var entityType = _navigationExpandingExpressionVisitor._queryCompilationContext.Model.FindEntityType(((IQueryable)constantExpression.Value).ElementType); + if (entityType == _entityType) + { + return _navigationExpandingExpressionVisitor.CreateNavigationExpansionExpression(constantExpression, entityType); + } + } + + return base.VisitConstant(constantExpression); + } + } } } diff --git a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs index d7521ae4375..66dc1af5ebe 100644 --- a/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs +++ b/src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs @@ -126,6 +126,8 @@ protected override Expression VisitConstant(ConstantExpression constantExpressio { var processedDefiningQueryBody = _parameterExtractingExpressionVisitor.ExtractParameters(definingQuery.Body); processedDefiningQueryBody = _enumerableToQueryableMethodConvertingExpressionVisitor.Visit(processedDefiningQueryBody); + processedDefiningQueryBody = new SelfReferenceEntityQueryableRewritingExpressionVisitor(this, entityType).Visit(processedDefiningQueryBody); + navigationExpansionExpression = (NavigationExpansionExpression)Visit(processedDefiningQueryBody); var expanded = ExpandAndReduce(navigationExpansionExpression, applyInclude: false); @@ -170,6 +172,7 @@ private Expression ApplyQueryFilter(NavigationExpansionExpression navigationExpa _parameterizedQueryFilterPredicateCache[rootEntityType] = filterPredicate; } + filterPredicate = (LambdaExpression)new SelfReferenceEntityQueryableRewritingExpressionVisitor(this, entityType).Visit(filterPredicate); var sequenceType = navigationExpansionExpression.Type.GetSequenceType(); // if we are constructing EntityQueryable of a derived type, we need to re-map filter predicate to the correct derived type diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs index 91e5443edf9..f22186f9933 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs @@ -6212,6 +6212,124 @@ public PostDTO7983 From(Post7983 post) #endregion + #region Bug17253 + + [ConditionalFact] + public virtual void Self_reference_in_query_filter_works() + { + using (CreateDatabase17253()) + { + using (var context = new MyContext17253(_options)) + { + var query = context.EntitiesWithQueryFilterSelfReference.Where(e => e.Name != "Foo"); + var result = query.ToList(); + + AssertSql( + @"SELECT [e].[Id], [e].[Name] +FROM [EntitiesWithQueryFilterSelfReference] AS [e] +WHERE EXISTS ( + SELECT 1 + FROM [EntitiesWithQueryFilterSelfReference] AS [e0]) AND (([e].[Name] <> N'Foo') OR [e].[Name] IS NULL)"); + } + } + } + + [ConditionalFact] + public virtual void Self_reference_in_query_filter_works_when_nested() + { + using (CreateDatabase17253()) + { + using (var context = new MyContext17253(_options)) + { + var query = context.EntitiesReferencingEntityWithQueryFilterSelfReference.Where(e => e.Name != "Foo"); + var result = query.ToList(); + + AssertSql( + @"SELECT [e].[Id], [e].[Name] +FROM [EntitiesReferencingEntityWithQueryFilterSelfReference] AS [e] +WHERE EXISTS ( + SELECT 1 + FROM [EntitiesWithQueryFilterSelfReference] AS [e0] + WHERE EXISTS ( + SELECT 1 + FROM [EntitiesWithQueryFilterSelfReference] AS [e1])) AND (([e].[Name] <> N'Foo') OR [e].[Name] IS NULL)"); + } + } + } + + public class MyContext17253 : DbContext + { + public DbSet EntitiesWithQueryFilterSelfReference { get; set; } + public DbSet EntitiesReferencingEntityWithQueryFilterSelfReference { get; set; } + + public DbSet EntitiesWithQueryFilterCycle1 { get; set; } + public DbSet EntitiesWithQueryFilterCycle2 { get; set; } + public DbSet EntitiesWithQueryFilterCycle3 { get; set; } + + public MyContext17253(DbContextOptions options) : base(options) + { + } + + protected override void OnModelCreating(ModelBuilder modelBuilder) + { + modelBuilder.Entity().HasQueryFilter(e => EntitiesWithQueryFilterSelfReference.Any()); + modelBuilder.Entity().HasQueryFilter(e => Set().Any()); + + modelBuilder.Entity().HasQueryFilter(e => EntitiesWithQueryFilterCycle2.Any()); + modelBuilder.Entity().HasQueryFilter(e => Set().Any()); + modelBuilder.Entity().HasQueryFilter(e => EntitiesWithQueryFilterCycle1.Any()); + } + } + + private SqlServerTestStore CreateDatabase17253() + => CreateTestStore( + () => new MyContext17253(_options), + context => + { + context.EntitiesWithQueryFilterSelfReference.Add(new EntityWithQueryFilterSelfReference { Name = "EntityWithQueryFilterSelfReference" }); + context.EntitiesReferencingEntityWithQueryFilterSelfReference.Add(new EntityReferencingEntityWithQueryFilterSelfReference { Name = "EntityReferencingEntityWithQueryFilterSelfReference" }); + + context.EntitiesWithQueryFilterCycle1.Add(new EntityWithQueryFilterCycle1 { Name = "EntityWithQueryFilterCycle1_1" }); + context.EntitiesWithQueryFilterCycle2.Add(new EntityWithQueryFilterCycle2 { Name = "EntityWithQueryFilterCycle2_1" }); + context.EntitiesWithQueryFilterCycle3.Add(new EntityWithQueryFilterCycle3 { Name = "EntityWithQueryFilterCycle3_1" }); + + context.SaveChanges(); + + ClearLog(); + }); + + public class EntityWithQueryFilterSelfReference + { + public int Id { get; set; } + public string Name { get; set; } + } + + public class EntityReferencingEntityWithQueryFilterSelfReference + { + public int Id { get; set; } + public string Name { get; set; } + } + + public class EntityWithQueryFilterCycle1 + { + public int Id { get; set; } + public string Name { get; set; } + } + + public class EntityWithQueryFilterCycle2 + { + public int Id { get; set; } + public string Name { get; set; } + } + + public class EntityWithQueryFilterCycle3 + { + public int Id { get; set; } + public string Name { get; set; } + } + + #endregion + private DbContextOptions _options; private SqlServerTestStore CreateTestStore(