Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ public virtual Expression Translate(SelectExpression selectExpression, Expressio
}

return new ProjectionBindingExpression(
_selectExpression, _selectExpression.AddToProjection(translation), expression.Type.MakeNullable());
_selectExpression, _selectExpression.AddToProjection(translation), expression.Type);
}
else
{
Expand All @@ -149,7 +149,7 @@ public virtual Expression Translate(SelectExpression selectExpression, Expressio

_projectionMapping[_projectionMembers.Peek()] = translation;

return new ProjectionBindingExpression(_selectExpression, _projectionMembers.Peek(), expression.Type.MakeNullable());
return new ProjectionBindingExpression(_selectExpression, _projectionMembers.Peek(), expression.Type);
}
}

Expand Down Expand Up @@ -793,12 +793,17 @@ private void VerifySelectExpression(ProjectionBindingExpression projectionBindin

private static Expression MatchTypes(Expression expression, Type targetType)
{
if (targetType != expression.Type
&& targetType.TryGetSequenceType() == null)
if (targetType != expression.Type
&& targetType.TryGetSequenceType() == null)
{
Check.DebugAssert(targetType.MakeNullable() == expression.Type, "expression.Type must be nullable of targetType");

expression = Expression.Convert(expression, targetType);
if (expression is ProjectionBindingExpression projectionBindingExpression)
{
return projectionBindingExpression.UpdateType(targetType);
}
if (targetType.MakeNullable() == expression.Type)
{
expression = Expression.Convert(expression, targetType);
}
}

return expression;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,23 @@ public virtual Expression Translate(SelectExpression selectExpression, Expressio
case ParameterExpression parameterExpression:
throw new InvalidOperationException(CoreStrings.TranslationFailed(parameterExpression.Print()));

case ProjectionBindingExpression projectionBindingExpression:
return _selectExpression.GetProjection(projectionBindingExpression) switch
{
StructuralTypeProjectionExpression projection => AddClientProjection(projection, typeof(ValueBuffer)),
SqlExpression mappedSqlExpression => AddClientProjection(mappedSqlExpression, expression.Type.MakeNullable()),
_ => throw new InvalidOperationException(CoreStrings.TranslationFailed(projectionBindingExpression.Print()))
};
case SqlExpression mappedSqlExpression:
var isNullable = mappedSqlExpression switch
{
ColumnExpression c => c.IsNullable,
SqlFunctionExpression f => f.IsNullable,
AtTimeZoneExpression a => a.IsNullable,
_ => true
};

// Only mark as nullable if the target type can actually be nullable
var shouldBeNullable = isNullable
&& expression.Type.IsValueType
&& !expression.Type.IsNullableType();

return AddClientProjection(
mappedSqlExpression,
shouldBeNullable ? expression.Type.MakeNullable() : expression.Type);

case MaterializeCollectionNavigationExpression materializeCollectionNavigationExpression:
if (materializeCollectionNavigationExpression.Navigation.TargetEntityType.IsMappedToJson())
Expand Down Expand Up @@ -240,11 +250,23 @@ public virtual Expression Translate(SelectExpression selectExpression, Expressio
{
switch (_sqlTranslator.TranslateProjection(expression))
{
case SqlExpression sqlExpression:
_projectionMapping[_projectionMembers.Peek()] = sqlExpression;
return new ProjectionBindingExpression(
_selectExpression, _projectionMembers.Peek(), expression.Type.MakeNullable());

case SqlExpression mappedSqlExpression:
var isNullable = mappedSqlExpression switch
{
ColumnExpression c => c.IsNullable,
SqlFunctionExpression f => f.IsNullable,
AtTimeZoneExpression a => a.IsNullable,
_ => true
};


var shouldBeNullable = isNullable
&& expression.Type.IsValueType
&& !expression.Type.IsNullableType();

return AddClientProjection(
mappedSqlExpression,
shouldBeNullable ? expression.Type.MakeNullable() : expression.Type);
// This handles the case of a complex type being projected out of a Select.
// Note that an entity type being projected is (currently) handled differently
case RelationalStructuralTypeShaperExpression { StructuralType: IComplexType } shaper:
Expand Down Expand Up @@ -687,18 +709,16 @@ private static Expression MatchTypes(Expression expression, Type targetType)
if (targetType != expression.Type
&& targetType.TryGetElementType(typeof(IQueryable<>)) == null)
{
Check.DebugAssert(
targetType.MakeNullable() == expression.Type,
$"expression has type {expression.Type.Name}, but must be nullable over {targetType.Name}");

return expression switch
if (expression is ProjectionBindingExpression projectionBindingExpression)
{
#pragma warning disable EF1001
RelationalStructuralTypeShaperExpression structuralShaper => structuralShaper.MakeClrTypeNonNullable(),
#pragma warning restore EF1001
return projectionBindingExpression.UpdateType(targetType);
}

_ => Expression.Convert(expression, targetType),
};
if (expression is RelationalStructuralTypeShaperExpression structuralShaper)
{
return structuralShaper.MakeClrTypeNonNullable();
}
return Expression.Convert(expression, targetType);
}

return expression;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2908,7 +2908,14 @@ private object GetProjectionIndex(ProjectionBindingExpression projectionBindingE
=> _selectExpression.GetProjection(projectionBindingExpression).GetConstantValue<object>();

private static bool IsNullableProjection(ProjectionExpression projection)
=> projection.Expression is not ColumnExpression column || column.IsNullable;
=> projection.Expression switch
{
ColumnExpression column => column.IsNullable,
SqlFunctionExpression function => function.IsNullable,
AtTimeZoneExpression atTimeZone => atTimeZone.IsNullable,
JsonScalarExpression jsonScalar => jsonScalar.IsNullable,
_ => true
};

private Expression CreateGetValueExpression(
ParameterExpression dbDataReader,
Expand All @@ -2918,6 +2925,10 @@ private Expression CreateGetValueExpression(
Type type,
IPropertyBase? property = null)
{
Check.DebugAssert(
property != null || !nullable || type.IsNullableType(),
"Must read nullable value from database if property is not specified and nullable is true.");

var getMethod = typeMapping.GetDataReaderMethod();

Expression indexExpression = Constant(index);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,19 @@ public AtTimeZoneExpression(
/// </summary>
public virtual SqlExpression TimeZone { get; }

/// <summary>
/// A bool value indicating if this SQL expression is nullable.
/// </summary>
public virtual bool IsNullable
=> Operand switch
{
ColumnExpression c => c.IsNullable,
SqlFunctionExpression f => f.IsNullable,
JsonScalarExpression j => j.IsNullable,
AtTimeZoneExpression a => a.IsNullable,
_ => true
};

/// <inheritdoc />
protected override Expression VisitChildren(ExpressionVisitor visitor)
{
Expand Down
16 changes: 16 additions & 0 deletions src/EFCore/Query/ProjectionBindingExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,22 @@ void IPrintableExpression.Print(ExpressionPrinter expressionPrinter)
expressionPrinter.Append(Index.ToString()!);
}
}
/// <summary>
/// Creates a new instance of the <see cref="ProjectionBindingExpression" /> class with a new type.
/// </summary>
/// <param name="type">The new clr type of value being read.</param>
/// <returns>A new projection binding expression with the updated type.</returns>
public virtual ProjectionBindingExpression UpdateType(Type type)
{
if (type == Type)
{
return this;
}

return Index != null
? new ProjectionBindingExpression(QueryExpression, Index.Value, type)
: new ProjectionBindingExpression(QueryExpression, ProjectionMember!, type);
}

/// <inheritdoc />
public override bool Equals(object? obj)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.Data.SqlClient;

using Microsoft.EntityFrameworkCore.TestModels.Northwind;
namespace Microsoft.EntityFrameworkCore.Query;

#nullable disable
Expand Down Expand Up @@ -70,6 +70,26 @@ public override async Task SqlQuery_over_int_with_parameter(bool async)
""");
}

[ConditionalFact]
public virtual void Projection_binding_clean_up_non_nullable_value_type()
{
using var context = Fixture.CreateContext();

var query = context.Set<Customer>()
.Select(c => c.CustomerID)
.ToQueryString();
var result = context.Set<Customer>().Select(c => c.CustomerID).ToList();
Assert.NotEmpty(result);
}

[ConditionalFact]
public virtual void Projection_binding_stays_nullable_for_nullable_types()
{
using var context = Fixture.CreateContext();
var result = context.Set<Employee>().Select(e => e.ReportsTo).ToList();

Assert.Contains(null, result);
}
protected override DbParameter CreateDbParameter(string name, object value)
=> new SqlParameter { ParameterName = name, Value = value };

Expand Down
Loading