Skip to content

Commit

Permalink
Use null propagation to optimize away IS NOT NULL checks
Browse files Browse the repository at this point in the history
When a CASE expression simply replicates SQL null propagation, simplify it.
  • Loading branch information
ranma42 committed Jul 3, 2024
1 parent 43d19de commit 738b4f8
Showing 1 changed file with 87 additions and 0 deletions.
87 changes: 87 additions & 0 deletions src/EFCore.Relational/Query/SqlNullabilityProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,94 @@ protected virtual SqlExpression VisitCase(CaseExpression caseExpression, bool al
elseResult = null;
}

// optimize expressions such as expr != null ? expr : null and expr == null ? null : expr
if (testIsCondition && whenClauses is [var clause] && (elseResult == null || IsNull(clause.Result)))
{
HashSet<SqlExpression> nullPropagatedOperands = [];
SqlExpression test, expr;

if (elseResult == null)
{
expr = clause.Result;
test = clause.Test;
}
else
{
expr = elseResult;
test = _sqlExpressionFactory.Not(clause.Test);
}

NullPropagatedOperands(expr, nullPropagatedOperands);
test = DropNotNullChecks(test, nullPropagatedOperands);

if (IsTrue(test))
{
return expr;
}

if (elseResult != null)
{
test = _sqlExpressionFactory.Not(test);
}

whenClauses = [new(test, clause.Result)];
}

return caseExpression.Update(operand, whenClauses, elseResult);

SqlExpression DropNotNullChecks(SqlExpression expression, HashSet<SqlExpression> nullPropagatedOperands)
=> expression switch
{
SqlUnaryExpression { OperatorType: ExpressionType.NotEqual } isNotNull
when nullPropagatedOperands.Contains(isNotNull.Operand)
=> _sqlExpressionFactory.Constant(true, expression.Type, expression.TypeMapping),

SqlBinaryExpression { OperatorType: ExpressionType.AndAlso } binary
=> _sqlExpressionFactory.MakeBinary(
ExpressionType.AndAlso,
DropNotNullChecks(binary.Left, nullPropagatedOperands),
DropNotNullChecks(binary.Right, nullPropagatedOperands),
expression.TypeMapping,
expression)!,

_ => expression,
};

// FIXME: unify nullability computations
static void NullPropagatedOperands(SqlExpression expression, HashSet<SqlExpression> operands)
{
operands.Add(expression);

if (expression is SqlUnaryExpression unary
&& unary.OperatorType is ExpressionType.Not or ExpressionType.Negate or ExpressionType.Convert)
{
NullPropagatedOperands(unary.Operand, operands);
}
else if (expression is SqlBinaryExpression binary
&& binary.OperatorType is not (ExpressionType.AndAlso or ExpressionType.OrElse))
{
NullPropagatedOperands(binary.Left, operands);
NullPropagatedOperands(binary.Right, operands);
}
else if (expression is SqlFunctionExpression { IsNullable: true } func)
{
if (func.InstancePropagatesNullability == true)
{
NullPropagatedOperands(func.Instance!, operands);
}

if (!func.IsNiladic)
{
for (var i = 0; i < func.ArgumentsPropagateNullability.Count; i++)
{
if (func.ArgumentsPropagateNullability[i])
{
NullPropagatedOperands(func.Arguments[i], operands);
}
}
}
}
}
}

/// <summary>
Expand Down

0 comments on commit 738b4f8

Please sign in to comment.