Skip to content

Commit

Permalink
Query: Throw exception for null key value in non-tracking (#26311)
Browse files Browse the repository at this point in the history
Resolves #26310
  • Loading branch information
smitpatel authored Oct 28, 2021
1 parent 2d2fcdc commit 5fb9f37
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 3 deletions.
61 changes: 58 additions & 3 deletions src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.ChangeTracking.Internal;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Infrastructure;
Expand Down Expand Up @@ -324,6 +325,9 @@ private static readonly MethodInfo _startTrackingMethodInfo
= typeof(QueryContext).GetRequiredMethod(
nameof(QueryContext.StartTracking), typeof(IEntityType), typeof(object), typeof(ValueBuffer));

private static readonly MethodInfo _createNullKeyValueInNoTrackingQuery
= typeof(EntityMaterializerInjectingExpressionVisitor).GetRequiredDeclaredMethod(nameof(CreateNullKeyValueInNoTrackingQuery));

private readonly IEntityMaterializerSource _entityMaterializerSource;
private readonly QueryTrackingBehavior _queryTrackingBehavior;
private readonly bool _queryStateMananger;
Expand Down Expand Up @@ -458,16 +462,47 @@ private Expression ProcessEntityShaper(EntityShaperExpression entityShaperExpres
{
if (primaryKey != null)
{
expressions.Add(
Expression.IfThen(
primaryKey.Properties.Select(
if (entityShaperExpression.IsNullable)
{
expressions.Add(
Expression.IfThen(
primaryKey.Properties.Select(
p => Expression.NotEqual(
valueBufferExpression.CreateValueBufferReadValueExpression(typeof(object), p.GetIndex(), p),
Expression.Constant(null)))
.Aggregate((a, b) => Expression.AndAlso(a, b)),
MaterializeEntity(
entityShaperExpression, materializationContextVariable, concreteEntityTypeVariable, instanceVariable,
null)));
}
else
{
var keyValuesVariable = Expression.Variable(typeof(object[]), "keyValues" + _currentEntityIndex);
expressions.Add(
Expression.IfThenElse(
primaryKey.Properties.Select(
p => Expression.NotEqual(
valueBufferExpression.CreateValueBufferReadValueExpression(typeof(object), p.GetIndex(), p),
Expression.Constant(null)))
.Aggregate((a, b) => Expression.AndAlso(a, b)),
MaterializeEntity(
entityShaperExpression, materializationContextVariable, concreteEntityTypeVariable, instanceVariable,
null),
Expression.Block(
new[] { keyValuesVariable },
Expression.Assign(
keyValuesVariable,
Expression.NewArrayInit(
typeof(object),
primaryKey.Properties.Select(
p => valueBufferExpression.CreateValueBufferReadValueExpression(typeof(object), p.GetIndex(), p)))),
Expression.Call(
_createNullKeyValueInNoTrackingQuery,
Expression.Constant(entityType),
Expression.Constant(primaryKey.Properties),
keyValuesVariable))));

}
}
else
{
Expand Down Expand Up @@ -605,6 +640,26 @@ private BlockExpression CreateFullMaterializeExpression(

return Expression.Block(blockExpressions);
}

[UsedImplicitly]
private static Exception CreateNullKeyValueInNoTrackingQuery(
IEntityType entityType, IReadOnlyList<IProperty> properties, object?[] keyValues)
{
var index = -1;
for (var i = 0; i < keyValues.Length; i++)
{
if (keyValues[i] == null)
{
index = i;
break;
}
}

var property = properties[index];

throw new InvalidOperationException(
CoreStrings.InvalidKeyValue(entityType.DisplayName(), property.Name));
}
}
}
}
18 changes: 18 additions & 0 deletions test/EFCore.Cosmos.FunctionalTests/Query/FromSqlQueryCosmosTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,24 @@ public virtual async Task FromSqlRaw_queryable_simple_projection_not_composed(bo
) c");
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public async Task FromSqlRaw_queryable_simple_with_missing_key_and_non_tracking_throws(bool async)
{
using var context = CreateContext();
var query = context.Set<Customer>()
.FromSqlRaw(@"SELECT * FROM root c WHERE c[""Discriminator""] = ""Category""")
.AsNoTracking();
var exception = async
? await Assert.ThrowsAsync<InvalidOperationException>(() => query.ToArrayAsync())
: Assert.Throws<InvalidOperationException>(() => query.ToArray());

Assert.Equal(CoreStrings.InvalidKeyValue(
context.Model.FindEntityType(typeof(Customer))!.DisplayName(),
"CustomerID"),
exception.Message);
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down

0 comments on commit 5fb9f37

Please sign in to comment.