Skip to content

Commit

Permalink
More fixes to entity equality nullability handling
Browse files Browse the repository at this point in the history
Fixes #16972
  • Loading branch information
roji committed Aug 9, 2019
1 parent b40a38a commit 243a085
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ protected virtual Expression VisitContainsMethodCall(MethodCallExpression method
{
// The source list is a constant, evaluate and replace with a list of the keys
var listValue = (IEnumerable)listConstant.Value;
var keyListType = typeof(List<>).MakeGenericType(keyProperty.ClrType);
var keyListType = typeof(List<>).MakeGenericType(keyProperty.ClrType.MakeNullable());
var keyList = (IList)Activator.CreateInstance(keyListType);
var getter = keyProperty.GetGetter();
foreach (var listItem in listValue)
Expand All @@ -386,7 +386,6 @@ protected virtual Expression VisitContainsMethodCall(MethodCallExpression method
&& listParam.Name.StartsWith(CompiledQueryCache.CompiledQueryParameterPrefix, StringComparison.Ordinal))
{
// The source list is a parameter. Add a runtime parameter that will contain a list of the extracted keys for each execution.
var keyListType = typeof(List<>).MakeGenericType(keyProperty.ClrType);
var lambda = Expression.Lambda(
Expression.Call(
_parameterListValueExtractor.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType.MakeNullable()),
Expand All @@ -397,7 +396,10 @@ protected virtual Expression VisitContainsMethodCall(MethodCallExpression method
);

var newParameterName = $"{RuntimeParameterPrefix}{listParam.Name.Substring(CompiledQueryCache.CompiledQueryParameterPrefix.Length)}_{keyProperty.Name}";
rewrittenSource = _queryCompilationContext.RegisterRuntimeParameter(newParameterName, lambda, keyListType);
rewrittenSource = _queryCompilationContext.RegisterRuntimeParameter(
newParameterName,
lambda,
typeof(List<>).MakeGenericType(keyProperty.ClrType.MakeNullable()));
}
else
{
Expand Down Expand Up @@ -911,8 +913,6 @@ private static readonly MethodInfo _parameterValueExtractor
/// </summary>
private static List<TProperty> ParameterListValueExtractor<TEntity, TProperty>(QueryContext context, string baseParameterName, IProperty property)
{
Debug.Assert(property.ClrType == typeof(TProperty));

var baseListParameter = context.ParameterValues[baseParameterName] as IEnumerable<TEntity>;
if (baseListParameter == null)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1278,6 +1278,28 @@ FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))");
}

[ConditionalTheory(Skip = "Issue#14935 (Contains not implemented)")]
public override async Task Contains_with_parameter_list_value_type_id(bool isAsync)
{
await base.Contains_with_parameter_list_value_type_id(isAsync);

AssertSql(
@"SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))");
}

[ConditionalTheory(Skip = "Issue#14935 (Contains not implemented)")]
public override async Task Contains_with_constant_list_value_type_id(bool isAsync)
{
await base.Contains_with_constant_list_value_type_id(isAsync);

AssertSql(
@"SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))");
}

[ConditionalTheory(Skip = "Issue#14935 (Contains not implemented)")]
public override void Contains_over_entityType_with_null_should_rewrite_to_identity_equality()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1614,6 +1614,32 @@ public virtual Task List_Contains_with_parameter_list(bool isAsync)
entryCount: 2);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Contains_with_parameter_list_value_type_id(bool isAsync)
{
var orders = new List<Order>
{
new Order { OrderID = 10248 },
new Order { OrderID = 10249 }
};

return AssertQuery<Order>(isAsync, od => od.Where(o => orders.Contains(o)),
entryCount: 2);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Contains_with_constant_list_value_type_id(bool isAsync)
{
return AssertQuery<Order>(isAsync, od => od.Where(o => new List<Order>
{
new Order { OrderID = 10248 },
new Order { OrderID = 10249 }
}.Contains(o)),
entryCount: 2);
}

[ConditionalFact]
public virtual void Contains_over_keyless_entity_throws()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1225,6 +1225,26 @@ FROM [Customers] AS [c]
WHERE [c].[CustomerID] IN (N'ALFKI', N'ANATR')");
}

public override async Task Contains_with_parameter_list_value_type_id(bool isAsync)
{
await base.Contains_with_parameter_list_value_type_id(isAsync);

AssertSql(
@"SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate]
FROM [Orders] AS [o]
WHERE [o].[OrderID] IN (10248, 10249)");
}

public override async Task Contains_with_constant_list_value_type_id(bool isAsync)
{
await base.Contains_with_constant_list_value_type_id(isAsync);

AssertSql(
@"SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate]
FROM [Orders] AS [o]
WHERE [o].[OrderID] IN (10248, 10249)");
}

public override void Contains_over_entityType_with_null_should_rewrite_to_identity_equality()
{
base.Contains_over_entityType_with_null_should_rewrite_to_identity_equality();
Expand Down

0 comments on commit 243a085

Please sign in to comment.