Skip to content

Commit

Permalink
Add NullIf Translation for Ternary expressions (#3403)
Browse files Browse the repository at this point in the history
Closes #596

Co-authored-by: Shay Rojansky <roji@roji.org>
  • Loading branch information
WhatzGames and roji authored Dec 11, 2024
1 parent c6fb5e2 commit 2397dd8
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,65 @@ public NpgsqlSqlTranslatingExpressionVisitor(
_timestampTzMapping = _typeMappingSource.FindMapping("timestamp with time zone")!;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override Expression VisitConditional(ConditionalExpression conditionalExpression)
{
var test = Visit(conditionalExpression.Test);
var ifTrue = Visit(conditionalExpression.IfTrue);
var ifFalse = Visit(conditionalExpression.IfFalse);

if (TranslationFailed(conditionalExpression.Test, test, out var sqlTest)
|| TranslationFailed(conditionalExpression.IfTrue, ifTrue, out var sqlIfTrue)
|| TranslationFailed(conditionalExpression.IfFalse, ifFalse, out var sqlIfFalse))
{
return QueryCompilationContext.NotTranslatedExpression;
}

// Translate:
// a == b ? null : a -> NULLIF(a, b)
// a != b ? a : null -> NULLIF(a, b)
if (sqlTest is SqlBinaryExpression binary && sqlIfTrue is not null && sqlIfFalse is not null)
{
switch (binary.OperatorType)
{
case ExpressionType.Equal
when ifTrue is SqlConstantExpression { Value: null } && TryTranslateToNullIf(sqlIfFalse, out var nullIfTranslation):
case ExpressionType.NotEqual
when ifFalse is SqlConstantExpression { Value: null } && TryTranslateToNullIf(sqlIfTrue, out nullIfTranslation):
return nullIfTranslation;
}
}

return _sqlExpressionFactory.Case([new CaseWhenClause(sqlTest!, sqlIfTrue!)], sqlIfFalse);

bool TryTranslateToNullIf(SqlExpression conditionalResult, [NotNullWhen(true)] out Expression? nullIfTranslation)
{
var (left, right) = (binary.Left, binary.Right);

if (left.Equals(conditionalResult))
{
nullIfTranslation = _sqlExpressionFactory.Function(
"NULLIF", [left, right], true, [false, false], left.Type, left.TypeMapping);
return true;
}

if (right.Equals(conditionalResult))
{
nullIfTranslation = _sqlExpressionFactory.Function(
"NULLIF", [right, left], true, [false, false], right.Type, right.TypeMapping);
return true;
}

nullIfTranslation = null;
return false;
}
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,70 @@ GROUP BY o."ProductID"

#endregion Statistics

#region NullIf

[Theory]
[MemberData(nameof(IsAsyncData))]
public async Task NullIf_with_equality_left_sided(bool async)
{
await AssertQuery(
async,
cs => cs.Set<Order>().Select(x => x.OrderID == 1 ? (int?)null : x.OrderID));

AssertSql(
"""
SELECT NULLIF(o."OrderID", 1)
FROM "Orders" AS o
""");
}

[Theory]
[MemberData(nameof(IsAsyncData))]
public async Task NullIf_with_equality_right_sided(bool async)
{
await AssertQuery(
async,
cs => cs.Set<Order>().Select(x => 1 == x.OrderID ? (int?)null : x.OrderID));

AssertSql(
"""
SELECT NULLIF(o."OrderID", 1)
FROM "Orders" AS o
""");
}

[Theory]
[MemberData(nameof(IsAsyncData))]
public async Task NullIf_with_inequality_left_sided(bool async)
{
await AssertQuery(
async,
cs => cs.Set<Order>().Select(x => x.OrderID != 1 ? x.OrderID : (int?)null));

AssertSql(
"""
SELECT NULLIF(o."OrderID", 1)
FROM "Orders" AS o
""");
}

[Theory]
[MemberData(nameof(IsAsyncData))]
public async Task NullIf_with_inequality_right_sided(bool async)
{
await AssertQuery(
async,
cs => cs.Set<Order>().Select(x => 1 != x.OrderID ? x.OrderID : (int?)null));

AssertSql(
"""
SELECT NULLIF(o."OrderID", 1)
FROM "Orders" AS o
""");
}

#endregion

#region Unsupported

// PostgreSQL does not have strpos with starting position
Expand Down

0 comments on commit 2397dd8

Please sign in to comment.