Skip to content

Commit

Permalink
Generate condition for required property in optional dependents (#29206)
Browse files Browse the repository at this point in the history
  • Loading branch information
smitpatel authored Sep 28, 2022
1 parent 7dcfd35 commit a27f04f
Show file tree
Hide file tree
Showing 9 changed files with 368 additions and 110 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,11 @@ public static IEnumerable<ITableMappingBase> GetViewOrTableMappings(this IEntity
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public static IEnumerable<IProperty> GetNonPrincipalSharedNonPkProperties(this IEntityType entityType, ITableBase table)
public static List<IProperty> GetNonPrincipalSharedNonPkProperties(this IEntityType entityType, ITableBase table)
{
var principalEntityTypes = new HashSet<IEntityType>();
PopulatePrincipalEntityTypes(table, entityType, principalEntityTypes);
var properties = new List<IProperty>();
foreach (var property in entityType.GetProperties())
{
if (property.IsPrimaryKey())
Expand All @@ -104,9 +105,11 @@ public static IEnumerable<IProperty> GetNonPrincipalSharedNonPkProperties(this I
continue;
}

yield return property;
properties.Add(property);
}

return properties;

static void PopulatePrincipalEntityTypes(ITableBase table, IEntityType entityType, HashSet<IEntityType> entityTypes)
{
foreach (var linkingFk in table.GetRowInternalForeignKeys(entityType))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ protected override LambdaExpression GenerateMaterializationCondition(IEntityType

var allNonPrincipalSharedNonPkProperties = entityType.GetNonPrincipalSharedNonPkProperties(table);
// We don't need condition for nullable property if there exist at least one required property which is non shared.
if (allNonPrincipalSharedNonPkProperties.Any()
if (allNonPrincipalSharedNonPkProperties.Count != 0
&& allNonPrincipalSharedNonPkProperties.All(p => p.IsNullable))
{
var atLeastOneNonNullValueInNullablePropertyCondition = allNonPrincipalSharedNonPkProperties
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System.Collections;
using System.Diagnostics.CodeAnalysis;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;

Expand Down Expand Up @@ -1184,18 +1185,43 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression)
}

// this is optional dependent sharing table
var nonPrincipalSharedNonPkProperties = entityType.GetNonPrincipalSharedNonPkProperties(table).ToList();
var nonPrincipalSharedNonPkProperties = entityType.GetNonPrincipalSharedNonPkProperties(table);
if (nonPrincipalSharedNonPkProperties.Contains(property))
{
// The column is not being shared with principal side so we can always use directly
return propertyAccess;
}

var condition = nonPrincipalSharedNonPkProperties
.Where(e => !e.IsNullable)
.Select(p => entityProjectionExpression.BindProperty(p))
.Select(c => (SqlExpression)_sqlExpressionFactory.NotEqual(c, _sqlExpressionFactory.Constant(null)))
.Aggregate((a, b) => _sqlExpressionFactory.AndAlso(a, b));
SqlExpression? condition = null;
// Property is being shared with principal side, so we need to make it conditional access
var allRequiredNonPkProperties =
entityType.GetProperties().Where(p => !p.IsNullable && !p.IsPrimaryKey()).ToList();
if (allRequiredNonPkProperties.Count > 0)
{
condition = allRequiredNonPkProperties.Select(p => entityProjectionExpression.BindProperty(p))
.Select(c => (SqlExpression)_sqlExpressionFactory.NotEqual(c, _sqlExpressionFactory.Constant(null)))
.Aggregate((a, b) => _sqlExpressionFactory.AndAlso(a, b));
}

if (nonPrincipalSharedNonPkProperties.Count != 0
&& nonPrincipalSharedNonPkProperties.All(p => p.IsNullable))
{
// If all non principal shared properties are nullable then we need additional condition
var atLeastOneNonNullValueInNullableColumnsCondition = nonPrincipalSharedNonPkProperties
.Select(p => entityProjectionExpression.BindProperty(p))
.Select(c => (SqlExpression)_sqlExpressionFactory.NotEqual(c, _sqlExpressionFactory.Constant(null)))
.Aggregate((a, b) => _sqlExpressionFactory.OrElse(a, b));

condition = condition == null
? atLeastOneNonNullValueInNullableColumnsCondition
: _sqlExpressionFactory.AndAlso(condition, atLeastOneNonNullValueInNullableColumnsCondition);
}

if (condition == null)
{
// if we cannot compute condition then we just return property access (and hope for the best)
return propertyAccess;
}

return _sqlExpressionFactory.Case(
new List<CaseWhenClause> { new(condition, propertyAccess) },
Expand Down Expand Up @@ -1610,17 +1636,47 @@ private bool TryRewriteEntityEquality(
?? nullComparedEntityType.GetDefaultMappings().Single().Table;
if (table.IsOptional(nullComparedEntityType))
{
var condition = nullComparedEntityType.GetNonPrincipalSharedNonPkProperties(table)
.Where(e => !e.IsNullable)
.Select(
p => Infrastructure.ExpressionExtensions.CreateEqualsExpression(
CreatePropertyAccessExpression(nonNullEntityReference, p),
Expression.Constant(null, p.ClrType.MakeNullable()),
nodeType != ExpressionType.Equal))
.Aggregate((l, r) => nodeType == ExpressionType.Equal ? Expression.OrElse(l, r) : Expression.AndAlso(l, r));

result = Visit(condition);
return true;
Expression? condition = null;
// Optional dependent sharing table
var requiredNonPkProperties = nullComparedEntityType.GetProperties().Where(p => !p.IsNullable && !p.IsPrimaryKey()).ToList();
if (requiredNonPkProperties.Count > 0)
{
condition = requiredNonPkProperties.Select(
p => Infrastructure.ExpressionExtensions.CreateEqualsExpression(
CreatePropertyAccessExpression(nonNullEntityReference, p),
Expression.Constant(null, p.ClrType.MakeNullable()),
nodeType != ExpressionType.Equal))
.Aggregate((l, r) => nodeType == ExpressionType.Equal ? Expression.OrElse(l, r) : Expression.AndAlso(l, r));
}

var allNonPrincipalSharedNonPkProperties = nullComparedEntityType.GetNonPrincipalSharedNonPkProperties(table);
// We don't need condition for nullable property if there exist at least one required property which is non shared.
if (allNonPrincipalSharedNonPkProperties.Count != 0
&& allNonPrincipalSharedNonPkProperties.All(p => p.IsNullable))
{
var atLeastOneNonNullValueInNullablePropertyCondition = allNonPrincipalSharedNonPkProperties
.Select(
p => Infrastructure.ExpressionExtensions.CreateEqualsExpression(
CreatePropertyAccessExpression(nonNullEntityReference, p),
Expression.Constant(null, p.ClrType.MakeNullable()),
nodeType != ExpressionType.Equal))
.Aggregate((l, r) => nodeType == ExpressionType.Equal ? Expression.OrElse(l, r) : Expression.AndAlso(l, r));

condition = condition == null
? atLeastOneNonNullValueInNullablePropertyCondition
: nodeType == ExpressionType.Equal
? Expression.OrElse(condition, atLeastOneNonNullValueInNullablePropertyCondition)
: Expression.AndAlso(condition, atLeastOneNonNullValueInNullablePropertyCondition);
}

if (condition != null)
{
result = Visit(condition);
return true;
}

result = null;
return false;
}
}

Expand Down
31 changes: 26 additions & 5 deletions src/EFCore.Relational/Query/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -645,12 +645,33 @@ private void AddConditions(SelectExpression selectExpression, IEntityType entity
var table = (firstTable as FromSqlExpression)?.Table ?? ((ITableBasedExpression)firstTable).Table;
if (table.IsOptional(entityType))
{
SqlExpression? predicate = null;
var entityProjectionExpression = GetMappedEntityProjectionExpression(selectExpression);
var predicate = entityType.GetNonPrincipalSharedNonPkProperties(table)
.Where(e => !e.IsNullable)
.Select(e => IsNotNull(e, entityProjectionExpression))
.Aggregate((l, r) => AndAlso(l, r));
selectExpression.ApplyPredicate(predicate);
var requiredNonPkProperties = entityType.GetProperties().Where(p => !p.IsNullable && !p.IsPrimaryKey()).ToList();
if (requiredNonPkProperties.Count > 0)
{
predicate = requiredNonPkProperties.Select(e => IsNotNull(e, entityProjectionExpression))
.Aggregate((l, r) => AndAlso(l, r));
}

var allNonSharedNonPkProperties = entityType.GetNonPrincipalSharedNonPkProperties(table);
// We don't need condition for nullable property if there exist at least one required property which is non shared.
if (allNonSharedNonPkProperties.Count != 0
&& allNonSharedNonPkProperties.All(p => p.IsNullable))
{
var atLeastOneNonNullValueInNullablePropertyCondition = allNonSharedNonPkProperties
.Select(e => IsNotNull(e, entityProjectionExpression))
.Aggregate((a, b) => OrElse(a, b));

predicate = predicate == null
? atLeastOneNonNullValueInNullablePropertyCondition
: AndAlso(predicate, atLeastOneNonNullValueInNullablePropertyCondition);
}

if (predicate != null)
{
selectExpression.ApplyPredicate(predicate);
}
}

bool HasSiblings(IEntityType entityType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,55 @@ public virtual async Task Owned_entity_with_all_null_properties_materializes_whe
t =>
{
Assert.Equal("Buyer2", t.Buyer);
// Cannot verify owned entities here since they differ between relational/in-memory
Assert.Null(t.Rot);
Assert.Null(t.Rut);
});
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Owned_entity_with_all_null_properties_entity_equality_when_not_containing_another_owned_entity(bool async)
{
var contextFactory = await InitializeAsync<MyContext28247>(seed: c => c.Seed());

using var context = contextFactory.CreateContext();
var query = context.RotRutCases.AsNoTracking().Select(e => e.Rot).Where(e => e != null);

var result = async
? await query.ToListAsync()
: query.ToList();

Assert.Collection(
result,
t =>
{
Assert.Equal(1, t.ServiceType);
Assert.Equal("1", t.ApartmentNo);
});
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Owned_entity_with_all_null_properties_property_access_when_not_containing_another_owned_entity(bool async)
{
var contextFactory = await InitializeAsync<MyContext28247>(seed: c => c.Seed());

using var context = contextFactory.CreateContext();
var query = context.RotRutCases.AsNoTracking().Select(e => e.Rot.ApartmentNo);

var result = async
? await query.ToListAsync()
: query.ToList();

Assert.Collection(
result,
t =>
{
Assert.Equal("1", t);
},
t =>
{
Assert.Null(t);
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,96 @@ await TestHelpers.ExecuteWithStrategyInTransactionAsync(
}
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Optional_dependent_without_required_property(bool async)
{
var contextFactory = await InitializeAsync<Context29196>(
onConfiguring: e => e.ConfigureWarnings(w => w.Log(RelationalEventId.OptionalDependentWithoutIdentifyingPropertyWarning)));

using (var context = contextFactory.CreateContext())
{
var query = context.DetailedOrders.Where(o => o.Status == OrderStatus.Pending);

var result = async
? await query.ToListAsync()
: query.ToList();
}
}

protected class Context29196 : DbContext
{
public Context29196(DbContextOptions options)
: base(options)
{
}

public DbSet<Order> Orders => Set<Order>();

public DbSet<DetailedOrder> DetailedOrders => Set<DetailedOrder>();

protected override void OnModelCreating(ModelBuilder modelBuilder)
{
base.OnModelCreating(modelBuilder);

modelBuilder.Entity<DetailedOrder>(
dob =>
{
dob.ToTable("Orders");
dob.Property(o => o.Status).HasColumnName("Status");
dob.Property(o => o.Version).IsRowVersion().HasColumnName("Version");
});

modelBuilder.Entity<Order>(
ob =>
{
ob.ToTable("Orders");
ob.Property(o => o.Status).HasColumnName("Status");
ob.HasOne(o => o.DetailedOrder).WithOne().HasForeignKey<DetailedOrder>(o => o.Id);
ob.Property<byte[]>("Version").IsRowVersion().HasColumnName("Version");
});
}

public void Seed()
{
Add(
new Order
{
Status = OrderStatus.Pending,
DetailedOrder = new DetailedOrder
{
Status = OrderStatus.Pending,
ShippingAddress = "221 B Baker St, London",
BillingAddress = "11 Wall Street, New York"
}
});

SaveChanges();
}
}

public class DetailedOrder
{
public int Id { get; set; }
public OrderStatus? Status { get; set; }
public string BillingAddress { get; set; }
public string ShippingAddress { get; set; }
public byte[] Version { get; set; }
}

public class Order
{
public int Id { get; set; }
public OrderStatus? Status { get; set; }
public DetailedOrder DetailedOrder { get; set; }
}

public enum OrderStatus
{
Pending,
Shipped
}

public void UseTransaction(DatabaseFacade facade, IDbContextTransaction transaction)
=> facade.UseTransaction(transaction.GetDbTransaction());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,13 @@ public override async Task Normal_entity_owning_a_split_reference_with_main_frag

AssertSql(
@"SELECT [e].[Id], CASE
WHEN [e].[OwnedReference_Id] IS NOT NULL AND [e].[OwnedReference_OwnedIntValue1] IS NOT NULL AND [e].[OwnedReference_OwnedIntValue2] IS NOT NULL THEN [o].[OwnedIntValue4]
WHEN [e].[OwnedReference_Id] IS NOT NULL AND [e].[OwnedReference_OwnedIntValue1] IS NOT NULL AND [e].[OwnedReference_OwnedIntValue2] IS NOT NULL AND [o0].[OwnedIntValue3] IS NOT NULL AND [o].[OwnedIntValue4] IS NOT NULL THEN [o].[OwnedIntValue4]
END AS [OwnedIntValue4], CASE
WHEN [e].[OwnedReference_Id] IS NOT NULL AND [e].[OwnedReference_OwnedIntValue1] IS NOT NULL AND [e].[OwnedReference_OwnedIntValue2] IS NOT NULL THEN [o].[OwnedStringValue4]
WHEN [e].[OwnedReference_Id] IS NOT NULL AND [e].[OwnedReference_OwnedIntValue1] IS NOT NULL AND [e].[OwnedReference_OwnedIntValue2] IS NOT NULL AND [o0].[OwnedIntValue3] IS NOT NULL AND [o].[OwnedIntValue4] IS NOT NULL THEN [o].[OwnedStringValue4]
END AS [OwnedStringValue4]
FROM [EntityOnes] AS [e]
LEFT JOIN [OwnedReferenceExtras2] AS [o] ON [e].[Id] = [o].[EntityOneId]");
LEFT JOIN [OwnedReferenceExtras2] AS [o] ON [e].[Id] = [o].[EntityOneId]
LEFT JOIN [OwnedReferenceExtras1] AS [o0] ON [e].[Id] = [o0].[EntityOneId]");
}

[ConditionalTheory(Skip = "Issue29075")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,4 +139,23 @@ public override async Task Owned_entity_with_all_null_properties_materializes_wh
FROM [RotRutCases] AS [r]
ORDER BY [r].[Buyer]");
}

public override async Task Owned_entity_with_all_null_properties_entity_equality_when_not_containing_another_owned_entity(bool async)
{
await base.Owned_entity_with_all_null_properties_entity_equality_when_not_containing_another_owned_entity(async);

AssertSql(
@"SELECT [r].[Id], [r].[Rot_ApartmentNo], [r].[Rot_ServiceType]
FROM [RotRutCases] AS [r]
WHERE [r].[Rot_ApartmentNo] IS NOT NULL AND [r].[Rot_ServiceType] IS NOT NULL");
}

public override async Task Owned_entity_with_all_null_properties_property_access_when_not_containing_another_owned_entity(bool async)
{
await base.Owned_entity_with_all_null_properties_property_access_when_not_containing_another_owned_entity(async);

AssertSql(
@"SELECT [r].[Rot_ApartmentNo]
FROM [RotRutCases] AS [r]");
}
}
Loading

0 comments on commit a27f04f

Please sign in to comment.