diff --git a/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs b/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs index 7e2fd3d47..0ef5b7e9b 100644 --- a/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs @@ -77,6 +77,65 @@ public NpgsqlSqlTranslatingExpressionVisitor( _timestampTzMapping = _typeMappingSource.FindMapping("timestamp with time zone")!; } + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitConditional(ConditionalExpression conditionalExpression) + { + var test = Visit(conditionalExpression.Test); + var ifTrue = Visit(conditionalExpression.IfTrue); + var ifFalse = Visit(conditionalExpression.IfFalse); + + if (TranslationFailed(conditionalExpression.Test, test, out var sqlTest) + || TranslationFailed(conditionalExpression.IfTrue, ifTrue, out var sqlIfTrue) + || TranslationFailed(conditionalExpression.IfFalse, ifFalse, out var sqlIfFalse)) + { + return QueryCompilationContext.NotTranslatedExpression; + } + + // Translate: + // a == b ? null : a -> NULLIF(a, b) + // a != b ? a : null -> NULLIF(a, b) + if (sqlTest is SqlBinaryExpression binary && sqlIfTrue is not null && sqlIfFalse is not null) + { + switch (binary.OperatorType) + { + case ExpressionType.Equal + when ifTrue is SqlConstantExpression { Value: null } && TryTranslateToNullIf(sqlIfFalse, out var nullIfTranslation): + case ExpressionType.NotEqual + when ifFalse is SqlConstantExpression { Value: null } && TryTranslateToNullIf(sqlIfTrue, out nullIfTranslation): + return nullIfTranslation; + } + } + + return _sqlExpressionFactory.Case([new CaseWhenClause(sqlTest!, sqlIfTrue!)], sqlIfFalse); + + bool TryTranslateToNullIf(SqlExpression conditionalResult, [NotNullWhen(true)] out Expression? nullIfTranslation) + { + var (left, right) = (binary.Left, binary.Right); + + if (left.Equals(conditionalResult)) + { + nullIfTranslation = _sqlExpressionFactory.Function( + "NULLIF", [left, right], true, [false, false], left.Type, left.TypeMapping); + return true; + } + + if (right.Equals(conditionalResult)) + { + nullIfTranslation = _sqlExpressionFactory.Function( + "NULLIF", [right, left], true, [false, false], right.Type, right.TypeMapping); + return true; + } + + nullIfTranslation = null; + return false; + } + } + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in diff --git a/test/EFCore.PG.FunctionalTests/Query/NorthwindFunctionsQueryNpgsqlTest.cs b/test/EFCore.PG.FunctionalTests/Query/NorthwindFunctionsQueryNpgsqlTest.cs index 43c78f8c2..c87a047ff 100644 --- a/test/EFCore.PG.FunctionalTests/Query/NorthwindFunctionsQueryNpgsqlTest.cs +++ b/test/EFCore.PG.FunctionalTests/Query/NorthwindFunctionsQueryNpgsqlTest.cs @@ -843,6 +843,70 @@ GROUP BY o."ProductID" #endregion Statistics + #region NullIf + + [Theory] + [MemberData(nameof(IsAsyncData))] + public async Task NullIf_with_equality_left_sided(bool async) + { + await AssertQuery( + async, + cs => cs.Set().Select(x => x.OrderID == 1 ? (int?)null : x.OrderID)); + + AssertSql( + """ +SELECT NULLIF(o."OrderID", 1) +FROM "Orders" AS o +"""); + } + + [Theory] + [MemberData(nameof(IsAsyncData))] + public async Task NullIf_with_equality_right_sided(bool async) + { + await AssertQuery( + async, + cs => cs.Set().Select(x => 1 == x.OrderID ? (int?)null : x.OrderID)); + + AssertSql( + """ +SELECT NULLIF(o."OrderID", 1) +FROM "Orders" AS o +"""); + } + + [Theory] + [MemberData(nameof(IsAsyncData))] + public async Task NullIf_with_inequality_left_sided(bool async) + { + await AssertQuery( + async, + cs => cs.Set().Select(x => x.OrderID != 1 ? x.OrderID : (int?)null)); + + AssertSql( + """ +SELECT NULLIF(o."OrderID", 1) +FROM "Orders" AS o +"""); + } + + [Theory] + [MemberData(nameof(IsAsyncData))] + public async Task NullIf_with_inequality_right_sided(bool async) + { + await AssertQuery( + async, + cs => cs.Set().Select(x => 1 != x.OrderID ? x.OrderID : (int?)null)); + + AssertSql( + """ +SELECT NULLIF(o."OrderID", 1) +FROM "Orders" AS o +"""); + } + + #endregion + #region Unsupported // PostgreSQL does not have strpos with starting position