Skip to content

Commit

Permalink
Implement improvements from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
ranma42 committed Jul 12, 2024
1 parent bbe125a commit c2b1eff
Showing 1 changed file with 31 additions and 39 deletions.
70 changes: 31 additions & 39 deletions src/EFCore.Relational/Query/SqlNullabilityProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -579,23 +579,15 @@ protected virtual SqlExpression VisitCase(CaseExpression caseExpression, bool al
}

// optimize expressions such as expr != null ? expr : null and expr == null ? null : expr
if (testIsCondition && whenClauses is [var clause] && (elseResult == null || IsNull(clause.Result)))
if (testIsCondition && whenClauses is [var clause] && (elseResult is null || IsNull(clause.Result)))
{
HashSet<SqlExpression> nullPropagatedOperands = [];
SqlExpression test, expr;

if (elseResult == null)
{
expr = clause.Result;
test = clause.Test;
}
else
{
expr = elseResult;
test = _sqlExpressionFactory.Not(clause.Test);
}
var (test, expr) = elseResult is null
? (clause.Result, clause.Test)
: (_sqlExpressionFactory.Not(clause.Test), elseResult);

NullPropagatedOperands(expr, nullPropagatedOperands);
DetectNullPropagatingNodes(expr, nullPropagatedOperands);
test = DropNotNullChecks(test, nullPropagatedOperands);

if (IsTrue(test))
Expand Down Expand Up @@ -632,38 +624,38 @@ when nullPropagatedOperands.Contains(isNotNull.Operand)
};

// FIXME: unify nullability computations
static void NullPropagatedOperands(SqlExpression expression, HashSet<SqlExpression> operands)
static void DetectNullPropagatingNodes(SqlExpression expression, HashSet<SqlExpression> operands)
{
operands.Add(expression);

if (expression is SqlUnaryExpression unary
&& unary.OperatorType is ExpressionType.Not or ExpressionType.Negate or ExpressionType.Convert)
switch (expression)
{
NullPropagatedOperands(unary.Operand, operands);
}
else if (expression is SqlBinaryExpression binary
&& binary.OperatorType is not (ExpressionType.AndAlso or ExpressionType.OrElse))
{
NullPropagatedOperands(binary.Left, operands);
NullPropagatedOperands(binary.Right, operands);
}
else if (expression is SqlFunctionExpression { IsNullable: true } func)
{
if (func.InstancePropagatesNullability == true)
{
NullPropagatedOperands(func.Instance!, operands);
}
case SqlUnaryExpression { OperatorType: not (ExpressionType.Equal or ExpressionType.NotEqual) } unary:
DetectNullPropagatingNodes(unary.Operand, operands);
break;

if (!func.IsNiladic)
{
for (var i = 0; i < func.ArgumentsPropagateNullability.Count; i++)
case SqlBinaryExpression { OperatorType: not (ExpressionType.AndAlso or ExpressionType.OrElse) } binary:
DetectNullPropagatingNodes(binary.Left, operands);
DetectNullPropagatingNodes(binary.Right, operands);
break;

case SqlFunctionExpression { IsNullable: true } func:
if (func.InstancePropagatesNullability is true)
{
if (func.ArgumentsPropagateNullability[i])
DetectNullPropagatingNodes(func.Instance!, operands);
}

if (!func.IsNiladic)
{
for (var i = 0; i < func.ArgumentsPropagateNullability.Count; i++)
{
NullPropagatedOperands(func.Arguments[i], operands);
if (func.ArgumentsPropagateNullability[i])
{
DetectNullPropagatingNodes(func.Arguments[i], operands);
}
}
}
}
break;
}
}
}
Expand Down Expand Up @@ -913,9 +905,9 @@ protected virtual SqlExpression VisitIn(InExpression inExpression, bool allowOpt

return inExpression.Values! switch
{
[] => _sqlExpressionFactory.Constant(false, inExpression.TypeMapping),
[var v] => _sqlExpressionFactory.Equal(inExpression.Item, v),
[..] => inExpression
[] => _sqlExpressionFactory.Constant(false, inExpression.TypeMapping),
[var v] => _sqlExpressionFactory.Equal(inExpression.Item, v),
[..] => inExpression
};
}

Expand Down

0 comments on commit c2b1eff

Please sign in to comment.