Skip to content

Commit

Permalink
Fix nullability in entity equality
Browse files Browse the repository at this point in the history
Fixes #16564
  • Loading branch information
roji committed Aug 5, 2019
1 parent 8744174 commit aa0440e
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,9 @@ protected virtual Expression VisitContainsMethodCall(MethodCallExpression method
var keyListType = typeof(List<>).MakeGenericType(keyProperty.ClrType);
var lambda = Expression.Lambda(
Expression.Call(
_parameterListValueExtractor.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType),
(keyProperty.ClrType.IsValueType
? _parameterListValueTypeExtractor
: _parameterListRefTypeExtractor).MakeGenericMethod(entityType.ClrType, keyProperty.ClrType),
QueryCompilationContext.QueryContextParameter,
Expression.Constant(listParam.Name, typeof(string)),
Expression.Constant(keyProperty, typeof(IProperty))),
Expand All @@ -394,7 +396,7 @@ protected virtual Expression VisitContainsMethodCall(MethodCallExpression method
var param = Expression.Parameter(entityType.ClrType, "v");
var keySelector = Expression.Lambda(CreatePropertyAccessExpression(param, keyProperty), param);
rewrittenSource = Expression.Call(
QueryableMethodProvider.SelectMethodInfo.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType),
QueryableMethodProvider.SelectMethodInfo.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType.MakeNullable()),
Unwrap(newSource),
Expression.Quote(keySelector));
}
Expand All @@ -407,7 +409,7 @@ protected virtual Expression VisitContainsMethodCall(MethodCallExpression method
return Expression.Call(
(Unwrap(newSource).Type.IsQueryableType()
? QueryableMethodProvider.ContainsMethodInfo
: _enumerableContainsMethodInfo).MakeGenericMethod(keyProperty.ClrType),
: _enumerableContainsMethodInfo).MakeGenericMethod(keyProperty.ClrType.MakeNullable()),
rewrittenSource,
rewrittenItem
);
Expand Down Expand Up @@ -466,7 +468,7 @@ protected virtual Expression VisitOrderingMethodCall(MethodCallExpression method
var orderingMethodInfo = GetOrderingMethodInfo(firstOrdering, isAscending);

expression = Expression.Call(
orderingMethodInfo.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType),
orderingMethodInfo.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType.MakeNullable()),
expression,
Expression.Quote(rewrittenKeySelector)
);
Expand Down Expand Up @@ -753,7 +755,7 @@ private Expression RewriteNullEquality(
// (this is also why we can do it even over a subquery with a composite key)
return Expression.MakeBinary(
equality ? ExpressionType.Equal : ExpressionType.NotEqual,
CreatePropertyAccessExpression(nonNullExpression, keyProperties[0], makeNullable: true),
CreatePropertyAccessExpression(nonNullExpression, keyProperties[0]),
Expression.Constant(null));
}

Expand Down Expand Up @@ -853,12 +855,12 @@ protected virtual Expression CreateKeyAccessExpression(
.Cast<Expression>()
.ToArray()));

private Expression CreatePropertyAccessExpression(Expression target, IProperty property, bool makeNullable = false)
private Expression CreatePropertyAccessExpression(Expression target, IProperty property)
{
// The target is a constant - evaluate the property immediately and return the result
if (target is ConstantExpression constantExpression)
{
return Expression.Constant(property.GetGetter().GetClrValue(constantExpression.Value), property.ClrType);
return Expression.Constant(property.GetGetter().GetClrValue(constantExpression.Value), property.ClrType.MakeNullable());
}

// If the target is a query parameter, we can't simply add a property access over it, but must instead cause a new
Expand All @@ -877,10 +879,10 @@ private Expression CreatePropertyAccessExpression(Expression target, IProperty p
QueryCompilationContext.QueryContextParameter);

var newParameterName = $"{RuntimeParameterPrefix}{baseParameterExpression.Name.Substring(CompiledQueryCache.CompiledQueryParameterPrefix.Length)}_{property.Name}";
return _queryCompilationContext.RegisterRuntimeParameter(newParameterName, lambda, property.ClrType);
return _queryCompilationContext.RegisterRuntimeParameter(newParameterName, lambda, property.ClrType.MakeNullable());
}

return target.CreateEFPropertyExpression(property, makeNullable);
return target.CreateEFPropertyExpression(property, true);
}

private static object ParameterValueExtractor(QueryContext context, string baseParameterName, IProperty property)
Expand All @@ -898,7 +900,8 @@ private static readonly MethodInfo _parameterValueExtractor
/// Extracts the list parameter with name <paramref name="baseParameterName"/> from <paramref name="context"/> and returns a
/// projection to its elements' <paramref name="property"/> values.
/// </summary>
private static object ParameterListValueExtractor<TEntity, TProperty>(QueryContext context, string baseParameterName, IProperty property)
private static List<TProperty> ParameterListRefTypeExtractor<TEntity, TProperty>(QueryContext context, string baseParameterName, IProperty property)
where TProperty : class
{
Debug.Assert(property.ClrType == typeof(TProperty));

Expand All @@ -912,10 +915,34 @@ private static object ParameterListValueExtractor<TEntity, TProperty>(QueryConte
return baseListParameter.Select(e => (TProperty)getter.GetClrValue(e)).ToList();
}

private static readonly MethodInfo _parameterListValueExtractor
private static readonly MethodInfo _parameterListRefTypeExtractor
= typeof(EntityEqualityRewritingExpressionVisitor)
.GetTypeInfo()
.GetDeclaredMethod(nameof(ParameterListValueExtractor));
.GetDeclaredMethod(nameof(ParameterListRefTypeExtractor));

/// <summary>
/// Extracts the list parameter with name <paramref name="baseParameterName"/> from <paramref name="context"/> and returns a
/// projection to its elements' <paramref name="property"/> values.
/// </summary>
private static List<TProperty?> ParameterListValueTypeExtractor<TEntity, TProperty>(QueryContext context, string baseParameterName, IProperty property)
where TProperty : struct
{
Debug.Assert(property.ClrType == typeof(TProperty));

var baseListParameter = context.ParameterValues[baseParameterName] as IEnumerable<TEntity>;
if (baseListParameter == null)
{
return null;
}

var getter = property.GetGetter();
return baseListParameter.Select(e => (TProperty?)getter.GetClrValue(e)).ToList();
}

private static readonly MethodInfo _parameterListValueTypeExtractor
= typeof(EntityEqualityRewritingExpressionVisitor)
.GetTypeInfo()
.GetDeclaredMethod(nameof(ParameterListValueTypeExtractor));

protected static Expression UnwrapLastNavigation(Expression expression)
=> (expression as MemberExpression)?.Expression
Expand Down Expand Up @@ -1063,16 +1090,16 @@ public virtual void Print(ExpressionPrinter expressionPrinter)

if (IsEntityType)
{
expressionPrinter.StringBuilder.Append($".EntityType({EntityType})");
expressionPrinter.Append($".EntityType({EntityType})");
}
else if (IsDtoType)
{
expressionPrinter.StringBuilder.Append(".DTO");
expressionPrinter.Append(".DTO");
}

if (SubqueryTraversed)
{
expressionPrinter.StringBuilder.Append(".SubqueryTraversed");
expressionPrinter.Append(".SubqueryTraversed");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,12 +229,6 @@ public override Task Union_over_different_projection_types(bool isAsync, string

#endregion

[ConditionalFact(Skip = "Issue#16564")]
public override void Contains_over_entityType_with_null_should_rewrite_to_identity_equality()
{
base.Contains_over_entityType_with_null_should_rewrite_to_identity_equality();
}

[ConditionalTheory(Skip = "Issue#15711")]
public override Task Include_with_orderby_skip_preserves_ordering(bool isAsync)
{
Expand Down Expand Up @@ -265,24 +259,6 @@ public override void Select_nested_collection_multi_level6()
base.Select_nested_collection_multi_level6();
}

[ConditionalTheory(Skip = "Issue#16564")]
public override Task Where_subquery_FirstOrDefault_compared_to_entity(bool isAsync)
{
return base.Where_subquery_FirstOrDefault_compared_to_entity(isAsync);
}

[ConditionalTheory(Skip = "Issue#16564")]
public override Task Where_query_composition_entity_equality_one_element_FirstOrDefault(bool isAsync)
{
return base.Where_query_composition_entity_equality_one_element_FirstOrDefault(isAsync);
}

[ConditionalTheory(Skip = "Issue#16564")]
public override Task Where_query_composition_entity_equality_no_elements_FirstOrDefault(bool isAsync)
{
return base.Where_query_composition_entity_equality_no_elements_FirstOrDefault(isAsync);
}

[ConditionalTheory(Skip = "Issue#16575")]
public override Task Project_single_element_from_collection_with_OrderBy_Distinct_and_FirstOrDefault_followed_by_projecting_length(bool isAsync)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1176,7 +1176,7 @@ public override void Contains_over_entityType_should_rewrite_to_identity_equalit
FROM [Orders] AS [o]
WHERE [o].[OrderID] = 10248",
//
@"@__entity_equality_p_0_OrderID='10248'
@"@__entity_equality_p_0_OrderID='10248' (Nullable = true)
SELECT CASE
WHEN @__entity_equality_p_0_OrderID IN (
Expand All @@ -1194,7 +1194,7 @@ public override async Task List_Contains_over_entityType_should_rewrite_to_ident
await base.List_Contains_over_entityType_should_rewrite_to_identity_equality(isAsync);

AssertSql(
@"@__entity_equality_someOrder_0_OrderID='10248'
@"@__entity_equality_someOrder_0_OrderID='10248' (Nullable = true)
SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
Expand Down Expand Up @@ -1230,7 +1230,7 @@ public override void Contains_over_entityType_with_null_should_rewrite_to_identi
base.Contains_over_entityType_with_null_should_rewrite_to_identity_equality();

AssertSql(
@"@__entity_equality_p_0_OrderID='' (Nullable = false) (DbType = Int32)
@"@__entity_equality_p_0_OrderID='' (DbType = Int32)
SELECT CASE
WHEN @__entity_equality_p_0_OrderID IN (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,8 @@ public override async Task Entity_equality_local_composite_key(bool isAsync)
await base.Entity_equality_local_composite_key(isAsync);

AssertSql(
@"@__entity_equality_local_0_OrderID='10248'
@__entity_equality_local_0_ProductID='11'
@"@__entity_equality_local_0_OrderID='10248' (Nullable = true)
@__entity_equality_local_0_ProductID='11' (Nullable = true)
SELECT [o].[OrderID], [o].[ProductID], [o].[Discount], [o].[Quantity], [o].[UnitPrice]
FROM [Order Details] AS [o]
Expand Down

0 comments on commit aa0440e

Please sign in to comment.