Skip to content

Commit

Permalink
Correct nullability for set operations
Browse files Browse the repository at this point in the history
Fixes #18135
  • Loading branch information
roji committed Oct 17, 2019
1 parent 59fbd88 commit d3738d3
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ internal ColumnExpression(ProjectionExpression subqueryProjection, TableExpressi
{
}

// TODO: This returns true for all but the simplest cases, it should drill down instead.
// TODO: Duplication with SelectExpression.ApplySetOperation.isNullableProjection
private static bool IsNullableProjection(ProjectionExpression projectionExpression)
=> projectionExpression.Expression switch
{
Expand Down
42 changes: 34 additions & 8 deletions src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -527,10 +527,19 @@ private void ApplySetOperation(SetOperationType setOperationType, SelectExpressi
?? (innerColumn1 as ColumnExpression)?.Name
?? "c");

var innerProjection = new ProjectionExpression(innerColumn1, alias);
select1._projection.Add(innerProjection);
select2._projection.Add(new ProjectionExpression(innerColumn2, alias));
_projectionMapping[joinedMapping.Key] = new ColumnExpression(innerProjection, setExpression);
var innerProjection1 = new ProjectionExpression(innerColumn1, alias);
var innerProjection2 = new ProjectionExpression(innerColumn2, alias);
select1._projection.Add(innerProjection1);
select2._projection.Add(innerProjection2);
var outerProjection = new ColumnExpression(innerProjection1, setExpression);

if (isNullableProjection(innerProjection1)
|| isNullableProjection(innerProjection2))
{
outerProjection = outerProjection.MakeNullable();
}

_projectionMapping[joinedMapping.Key] = outerProjection;
continue;
}

Expand Down Expand Up @@ -575,10 +584,17 @@ ColumnExpression addSetOperationColumnProjections(
SelectExpression select2, ColumnExpression column2)
{
var alias = generateUniqueAlias(column1.Name);
var innerProjection = new ProjectionExpression(column1, alias);
select1._projection.Add(innerProjection);
select2._projection.Add(new ProjectionExpression(column2, alias));
var outerProjection = new ColumnExpression(innerProjection, setExpression);
var innerProjection1 = new ProjectionExpression(column1, alias);
var innerProjection2 = new ProjectionExpression(column2, alias);
select1._projection.Add(innerProjection1);
select2._projection.Add(innerProjection2);
var outerProjection = new ColumnExpression(innerProjection1, setExpression);
if (isNullableProjection(innerProjection1)
|| isNullableProjection(innerProjection2))
{
outerProjection = outerProjection.MakeNullable();
}

if (select1._identifier.Contains(column1))
{
_identifier.Add(outerProjection);
Expand All @@ -598,6 +614,16 @@ string generateUniqueAlias(string baseAlias)

return currentAlias;
}

// TODO: This returns true for all but the simplest cases, it should drill down instead.
// TODO: Duplication with IsNullableProjection.IsNullableProjection
static bool isNullableProjection(ProjectionExpression projectionExpression)
=> projectionExpression.Expression switch
{
ColumnExpression columnExpression => columnExpression.IsNullable,
SqlConstantExpression sqlConstantExpression => sqlConstantExpression.Value == null,
_ => true,
};
}

private ColumnExpression GenerateOuterColumn(SqlExpression projection, string alias = null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5709,5 +5709,21 @@ public virtual Task Null_check_different_structure_does_not_remove_null_checks(b
? null
: l1.OneToOne_Optional_FK1.OneToOne_Optional_FK2.OneToOne_Optional_FK3.Name) == "L4 01"));
}

[ConditionalFact]
public virtual void Union_over_entities_with_different_nullability()
{
using var ctx = CreateContext();

var query = ctx.Set<Level1>()
.GroupJoin(ctx.Set<Level2>(), l1 => l1.Id, l2 => l2.Level1_Optional_Id, (l1, l2s) => new { l1, l2s })
.SelectMany(g => g.l2s.DefaultIfEmpty(), (g, l2) => new { g.l1, l2 })
.Concat(ctx.Set<Level2>().GroupJoin(ctx.Set<Level1>(), l2 => l2.Level1_Optional_Id, l1 => l1.Id, (l2, l1s) => new { l2, l1s })
.SelectMany(g => g.l1s.DefaultIfEmpty(), (g, l1) => new { l1, g.l2 })
.Where(e => e.l1.Equals(null)))
.Select(e => e.l1.Id);

var result = query.ToList();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,17 @@ public virtual Task GroupBy_Select_Union(bool isAsync)
.Select(g => new { CustomerID = g.Key, Count = g.Count() })));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Union_over_columns_with_different_nullability(bool isAsync)
{
return AssertQuery(
isAsync, ss => ss.Set<Customer>()
.Select(c => "NonNullableConstant")
.Concat(ss.Set<Customer>()
.Select(c => (string)null)));
}

[ConditionalTheory]
#pragma warning disable xUnit1016 // MemberData must reference a public member
[MemberData(nameof(GetSetOperandTestCases))]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4357,6 +4357,24 @@ ELSE [l2].[Name]
END IS NOT NULL");
}

public override void Union_over_entities_with_different_nullability()
{
base.Union_over_entities_with_different_nullability();

AssertSql(
@"SELECT [t].[Id]
FROM (
SELECT [l].[Id], [l].[Date], [l].[Name], [l].[OneToMany_Optional_Self_Inverse1Id], [l].[OneToMany_Required_Self_Inverse1Id], [l].[OneToOne_Optional_Self1Id], [l0].[Id] AS [Id0], [l0].[Date] AS [Date0], [l0].[Level1_Optional_Id], [l0].[Level1_Required_Id], [l0].[Name] AS [Name0], [l0].[OneToMany_Optional_Inverse2Id], [l0].[OneToMany_Optional_Self_Inverse2Id], [l0].[OneToMany_Required_Inverse2Id], [l0].[OneToMany_Required_Self_Inverse2Id], [l0].[OneToOne_Optional_PK_Inverse2Id], [l0].[OneToOne_Optional_Self2Id]
FROM [LevelOne] AS [l]
LEFT JOIN [LevelTwo] AS [l0] ON [l].[Id] = [l0].[Level1_Optional_Id]
UNION ALL
SELECT [l2].[Id], [l2].[Date], [l2].[Name], [l2].[OneToMany_Optional_Self_Inverse1Id], [l2].[OneToMany_Required_Self_Inverse1Id], [l2].[OneToOne_Optional_Self1Id], [l1].[Id] AS [Id0], [l1].[Date] AS [Date0], [l1].[Level1_Optional_Id], [l1].[Level1_Required_Id], [l1].[Name] AS [Name0], [l1].[OneToMany_Optional_Inverse2Id], [l1].[OneToMany_Optional_Self_Inverse2Id], [l1].[OneToMany_Required_Inverse2Id], [l1].[OneToMany_Required_Self_Inverse2Id], [l1].[OneToOne_Optional_PK_Inverse2Id], [l1].[OneToOne_Optional_Self2Id]
FROM [LevelTwo] AS [l1]
LEFT JOIN [LevelOne] AS [l2] ON [l1].[Level1_Optional_Id] = [l2].[Id]
WHERE [l2].[Id] IS NULL
) AS [t]");
}

private void AssertSql(params string[] expected)
{
Fixture.TestSqlLoggerFactory.AssertBaseline(expected);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,18 @@ FROM [Customers] AS [c0]
GROUP BY [c0].[CustomerID]");
}

public override async Task Union_over_columns_with_different_nullability(bool isAsync)
{
await base.Union_over_columns_with_different_nullability(isAsync);

AssertSql(
@"SELECT N'NonNullableConstant' AS [c]
FROM [Customers] AS [c]
UNION ALL
SELECT NULL AS [c]
FROM [Customers] AS [c0]");
}

public override async Task Union_over_different_projection_types(bool isAsync, string leftType, string rightType)
{
await base.Union_over_different_projection_types(isAsync, leftType, rightType);
Expand Down

0 comments on commit d3738d3

Please sign in to comment.