diff --git a/src/EFCore/Internal/EntityFinder.cs b/src/EFCore/Internal/EntityFinder.cs index 5bb926101d2..9279c388b3d 100644 --- a/src/EFCore/Internal/EntityFinder.cs +++ b/src/EFCore/Internal/EntityFinder.cs @@ -46,10 +46,18 @@ public EntityFinder( /// doing so can result in application failures when updating to a new Entity Framework Core release. /// public virtual TEntity? Find(object?[]? keyValues) - => keyValues == null || keyValues.Any(v => v == null) - ? null - : (FindTracked(keyValues!, out var keyProperties) - ?? _queryRoot.FirstOrDefault(BuildLambda(keyProperties, new ValueBuffer(keyValues)))); + { + if (keyValues == null + || keyValues.Any(v => v == null)) + { + return default; + } + + var (key, processedKeyValues, _) = ValidateKeyPropertiesAndExtractCancellationToken(keyValues!, async: false, default); + + return FindTracked(key, processedKeyValues) + ?? _queryRoot.FirstOrDefault(BuildLambda(key.Properties, new ValueBuffer(processedKeyValues))); + } /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -74,11 +82,13 @@ public EntityFinder( return default; } - var tracked = FindTracked(keyValues!, out var keyProperties); + var (key, processedKeyValues, ct) = ValidateKeyPropertiesAndExtractCancellationToken(keyValues!, async: true, cancellationToken); + + var tracked = FindTracked(key, processedKeyValues); return tracked != null ? new ValueTask(tracked) : new ValueTask( - _queryRoot.FirstOrDefaultAsync(BuildLambda(keyProperties, new ValueBuffer(keyValues)), cancellationToken)); + _queryRoot.FirstOrDefaultAsync(BuildLambda(key.Properties, new ValueBuffer(processedKeyValues)), ct)); } /// @@ -95,12 +105,14 @@ public EntityFinder( return default; } - var tracked = FindTracked(keyValues!, out var keyProperties); + var (key, processedKeyValues, ct) = ValidateKeyPropertiesAndExtractCancellationToken(keyValues!, async: true, cancellationToken); + + var tracked = FindTracked(key, processedKeyValues); return tracked != null ? new ValueTask(tracked) : new ValueTask( _queryRoot.FirstOrDefaultAsync( - BuildObjectLambda(keyProperties, new ValueBuffer(keyValues)), cancellationToken)); + BuildObjectLambda(key.Properties, new ValueBuffer(processedKeyValues)), ct)); } /// @@ -259,23 +271,41 @@ private static IReadOnlyList GetLoadProperties(INavigation navigation ? navigation.ForeignKey.PrincipalKey.Properties : navigation.ForeignKey.Properties; - private TEntity? FindTracked(object[] keyValues, out IReadOnlyList keyProperties) + private (IKey Key, object[] KeyValues,CancellationToken CancellationToken) ValidateKeyPropertiesAndExtractCancellationToken( + object[] keyValues, + bool async, + CancellationToken cancellationToken) { var key = _entityType.FindPrimaryKey()!; - keyProperties = key.Properties; + var keyPropertiesCount = key.Properties.Count; - if (keyProperties.Count != keyValues.Length) + if (keyPropertiesCount != keyValues.Length) { - if (keyProperties.Count == 1) + if (async + && keyPropertiesCount == keyValues.Length - 1 + && keyValues[keyPropertiesCount] is CancellationToken ct) + { + var newValues = new object[keyPropertiesCount]; + Array.Copy(keyValues, newValues, keyPropertiesCount); + return (key, newValues, ct); + } + + if (keyPropertiesCount == 1) { throw new ArgumentException( CoreStrings.FindNotCompositeKey(typeof(TEntity).ShortDisplayName(), keyValues.Length)); } throw new ArgumentException( - CoreStrings.FindValueCountMismatch(typeof(TEntity).ShortDisplayName(), keyProperties.Count, keyValues.Length)); + CoreStrings.FindValueCountMismatch(typeof(TEntity).ShortDisplayName(), keyPropertiesCount, keyValues.Length)); } + return (key, keyValues, cancellationToken); + } + + private TEntity? FindTracked(IKey key, object[] keyValues) + { + var keyProperties = key.Properties; for (var i = 0; i < keyValues.Length; i++) { var valueType = keyValues[i].GetType(); diff --git a/test/EFCore.Cosmos.FunctionalTests/FindCosmosTest.cs b/test/EFCore.Cosmos.FunctionalTests/FindCosmosTest.cs index 8b479c44d49..ed891f7c328 100644 --- a/test/EFCore.Cosmos.FunctionalTests/FindCosmosTest.cs +++ b/test/EFCore.Cosmos.FunctionalTests/FindCosmosTest.cs @@ -14,15 +14,15 @@ protected FindCosmosTest(FindCosmosFixture fixture) [ConditionalFact(Skip = "#25886")] public override void Find_base_type_using_derived_set_tracked() { } - [ConditionalFact(Skip = "#25886")] - public override Task Find_base_type_using_derived_set_tracked_async() + [ConditionalTheory(Skip = "#25886")] + public override Task Find_base_type_using_derived_set_tracked_async(CancellationType cancellationType) => Task.CompletedTask; [ConditionalFact(Skip = "#25886")] public override void Find_derived_using_base_set_type_from_store() { } - [ConditionalFact(Skip = "#25886")] - public override Task Find_derived_using_base_set_type_from_store_async() + [ConditionalTheory(Skip = "#25886")] + public override Task Find_derived_using_base_set_type_from_store_async(CancellationType cancellationType) => Task.CompletedTask; public class FindCosmosTestSet : FindCosmosTest @@ -32,11 +32,7 @@ public FindCosmosTestSet(FindCosmosFixture fixture) { } - protected override TEntity Find(DbContext context, params object[] keyValues) - => context.Set().Find(keyValues); - - protected override ValueTask FindAsync(DbContext context, params object[] keyValues) - => context.Set().FindAsync(keyValues); + protected override TestFinder Finder { get; } = new FindViaSetFinder(); } public class FindCosmosTestContext : FindCosmosTest @@ -46,11 +42,7 @@ public FindCosmosTestContext(FindCosmosFixture fixture) { } - protected override TEntity Find(DbContext context, params object[] keyValues) - => context.Find(keyValues); - - protected override ValueTask FindAsync(DbContext context, params object[] keyValues) - => context.FindAsync(keyValues); + protected override TestFinder Finder { get; } = new FindViaContextFinder(); } public class FindCosmosTestNonGeneric : FindCosmosTest @@ -60,11 +52,7 @@ public FindCosmosTestNonGeneric(FindCosmosFixture fixture) { } - protected override TEntity Find(DbContext context, params object[] keyValues) - => (TEntity)context.Find(typeof(TEntity), keyValues); - - protected override async ValueTask FindAsync(DbContext context, params object[] keyValues) - => (TEntity)await context.FindAsync(typeof(TEntity), keyValues); + protected override TestFinder Finder { get; } = new FindViaNonGenericContextFinder(); } public class FindCosmosFixture : FindFixtureBase diff --git a/test/EFCore.InMemory.FunctionalTests/FindInMemoryTest.cs b/test/EFCore.InMemory.FunctionalTests/FindInMemoryTest.cs index e1a2d61219b..8af4437aca6 100644 --- a/test/EFCore.InMemory.FunctionalTests/FindInMemoryTest.cs +++ b/test/EFCore.InMemory.FunctionalTests/FindInMemoryTest.cs @@ -17,11 +17,7 @@ public FindInMemoryTestSet(FindInMemoryFixture fixture) { } - protected override TEntity Find(DbContext context, params object[] keyValues) - => context.Set().Find(keyValues); - - protected override ValueTask FindAsync(DbContext context, params object[] keyValues) - => context.Set().FindAsync(keyValues); + protected override TestFinder Finder { get; } = new FindViaSetFinder(); } public class FindInMemoryTestContext : FindInMemoryTest @@ -31,11 +27,7 @@ public FindInMemoryTestContext(FindInMemoryFixture fixture) { } - protected override TEntity Find(DbContext context, params object[] keyValues) - => context.Find(keyValues); - - protected override ValueTask FindAsync(DbContext context, params object[] keyValues) - => context.FindAsync(keyValues); + protected override TestFinder Finder { get; } = new FindViaContextFinder(); } public class FindInMemoryTestNonGeneric : FindInMemoryTest @@ -45,11 +37,7 @@ public FindInMemoryTestNonGeneric(FindInMemoryFixture fixture) { } - protected override TEntity Find(DbContext context, params object[] keyValues) - => (TEntity)context.Find(typeof(TEntity), keyValues); - - protected override async ValueTask FindAsync(DbContext context, params object[] keyValues) - => (TEntity)await context.FindAsync(typeof(TEntity), keyValues); + protected override TestFinder Finder { get; } = new FindViaNonGenericContextFinder(); } public class FindInMemoryFixture : FindFixtureBase diff --git a/test/EFCore.Specification.Tests/FindTestBase.cs b/test/EFCore.Specification.Tests/FindTestBase.cs index 5c95e71a1f5..b7ff1feb958 100644 --- a/test/EFCore.Specification.Tests/FindTestBase.cs +++ b/test/EFCore.Specification.Tests/FindTestBase.cs @@ -16,11 +16,7 @@ protected FindTestBase(TFixture fixture) protected TFixture Fixture { get; } - protected abstract TEntity Find(DbContext context, params object[] keyValues) - where TEntity : class; - - protected abstract ValueTask FindAsync(DbContext context, params object[] keyValues) - where TEntity : class; + protected abstract TestFinder Finder { get; } [ConditionalFact] public virtual void Find_int_key_tracked() @@ -29,21 +25,21 @@ public virtual void Find_int_key_tracked() var entity = context.Attach( new IntKey { Id = 88 }).Entity; - Assert.Same(entity, Find(context, 88)); + Assert.Same(entity, Finder.Find(context, 88)); } [ConditionalFact] public virtual void Find_int_key_from_store() { using var context = CreateContext(); - Assert.Equal("Smokey", Find(context, 77).Foo); + Assert.Equal("Smokey", Finder.Find(context, 77).Foo); } [ConditionalFact] public virtual void Returns_null_for_int_key_not_in_store() { using var context = CreateContext(); - Assert.Null(Find(context, 99)); + Assert.Null(Finder.Find(context, 99)); } [ConditionalFact] @@ -53,21 +49,21 @@ public virtual void Find_nullable_int_key_tracked() var entity = context.Attach( new NullableIntKey { Id = 88 }).Entity; - Assert.Same(entity, Find(context, 88)); + Assert.Same(entity, Finder.Find(context, 88)); } [ConditionalFact] public virtual void Find_nullable_int_key_from_store() { using var context = CreateContext(); - Assert.Equal("Smokey", Find(context, 77).Foo); + Assert.Equal("Smokey", Finder.Find(context, 77).Foo); } [ConditionalFact] public virtual void Returns_null_for_nullable_int_key_not_in_store() { using var context = CreateContext(); - Assert.Null(Find(context, 99)); + Assert.Null(Finder.Find(context, 99)); } [ConditionalFact] @@ -77,21 +73,21 @@ public virtual void Find_string_key_tracked() var entity = context.Attach( new StringKey { Id = "Rabbit" }).Entity; - Assert.Same(entity, Find(context, "Rabbit")); + Assert.Same(entity, Finder.Find(context, "Rabbit")); } [ConditionalFact] public virtual void Find_string_key_from_store() { using var context = CreateContext(); - Assert.Equal("Alice", Find(context, "Cat").Foo); + Assert.Equal("Alice", Finder.Find(context, "Cat").Foo); } [ConditionalFact] public virtual void Returns_null_for_string_key_not_in_store() { using var context = CreateContext(); - Assert.Null(Find(context, "Fox")); + Assert.Null(Finder.Find(context, "Fox")); } [ConditionalFact] @@ -101,21 +97,21 @@ public virtual void Find_composite_key_tracked() var entity = context.Attach( new CompositeKey { Id1 = 88, Id2 = "Rabbit" }).Entity; - Assert.Same(entity, Find(context, 88, "Rabbit")); + Assert.Same(entity, Finder.Find(context, 88, "Rabbit")); } [ConditionalFact] public virtual void Find_composite_key_from_store() { using var context = CreateContext(); - Assert.Equal("Olive", Find(context, 77, "Dog").Foo); + Assert.Equal("Olive", Finder.Find(context, 77, "Dog").Foo); } [ConditionalFact] public virtual void Returns_null_for_composite_key_not_in_store() { using var context = CreateContext(); - Assert.Null(Find(context, 77, "Fox")); + Assert.Null(Finder.Find(context, 77, "Fox")); } [ConditionalFact] @@ -125,21 +121,21 @@ public virtual void Find_base_type_tracked() var entity = context.Attach( new BaseType { Id = 88 }).Entity; - Assert.Same(entity, Find(context, 88)); + Assert.Same(entity, Finder.Find(context, 88)); } [ConditionalFact] public virtual void Find_base_type_from_store() { using var context = CreateContext(); - Assert.Equal("Baxter", Find(context, 77).Foo); + Assert.Equal("Baxter", Finder.Find(context, 77).Foo); } [ConditionalFact] public virtual void Returns_null_for_base_type_not_in_store() { using var context = CreateContext(); - Assert.Null(Find(context, 99)); + Assert.Null(Finder.Find(context, 99)); } [ConditionalFact] @@ -149,14 +145,14 @@ public virtual void Find_derived_type_tracked() var entity = context.Attach( new DerivedType { Id = 88 }).Entity; - Assert.Same(entity, Find(context, 88)); + Assert.Same(entity, Finder.Find(context, 88)); } [ConditionalFact] public virtual void Find_derived_type_from_store() { using var context = CreateContext(); - var derivedType = Find(context, 78); + var derivedType = Finder.Find(context, 78); Assert.Equal("Strawberry", derivedType.Foo); Assert.Equal("Cheesecake", derivedType.Boo); } @@ -165,7 +161,7 @@ public virtual void Find_derived_type_from_store() public virtual void Returns_null_for_derived_type_not_in_store() { using var context = CreateContext(); - Assert.Null(Find(context, 99)); + Assert.Null(Finder.Find(context, 99)); } [ConditionalFact] @@ -175,14 +171,14 @@ public virtual void Find_base_type_using_derived_set_tracked() context.Attach( new BaseType { Id = 88 }); - Assert.Null(Find(context, 88)); + Assert.Null(Finder.Find(context, 88)); } [ConditionalFact] public virtual void Find_base_type_using_derived_set_from_store() { using var context = CreateContext(); - Assert.Null(Find(context, 77)); + Assert.Null(Finder.Find(context, 77)); } [ConditionalFact] @@ -192,14 +188,14 @@ public virtual void Find_derived_type_using_base_set_tracked() var entity = context.Attach( new DerivedType { Id = 88 }).Entity; - Assert.Same(entity, Find(context, 88)); + Assert.Same(entity, Finder.Find(context, 88)); } [ConditionalFact] public virtual void Find_derived_using_base_set_type_from_store() { using var context = CreateContext(); - var derivedType = Find(context, 78); + var derivedType = Finder.Find(context, 78); Assert.Equal("Strawberry", derivedType.Foo); Assert.Equal("Cheesecake", ((DerivedType)derivedType).Boo); } @@ -212,49 +208,49 @@ public virtual void Find_shadow_key_tracked() entry.Property("Id").CurrentValue = 88; entry.State = EntityState.Unchanged; - Assert.Same(entry.Entity, Find(context, 88)); + Assert.Same(entry.Entity, Finder.Find(context, 88)); } [ConditionalFact] public virtual void Find_shadow_key_from_store() { using var context = CreateContext(); - Assert.Equal("Clippy", Find(context, 77).Foo); + Assert.Equal("Clippy", Finder.Find(context, 77).Foo); } [ConditionalFact] public virtual void Returns_null_for_shadow_key_not_in_store() { using var context = CreateContext(); - Assert.Null(Find(context, 99)); + Assert.Null(Finder.Find(context, 99)); } [ConditionalFact] public virtual void Returns_null_for_null_key_values_array() { using var context = CreateContext(); - Assert.Null(Find(context, null)); + Assert.Null(Finder.Find(context, null)); } [ConditionalFact] public virtual void Returns_null_for_null_key() { using var context = CreateContext(); - Assert.Null(Find(context, new object[] { null })); + Assert.Null(Finder.Find(context, new object[] { null })); } [ConditionalFact] public virtual void Returns_null_for_null_nullable_key() { using var context = CreateContext(); - Assert.Null(Find(context, new object[] { null })); + Assert.Null(Finder.Find(context, new object[] { null })); } [ConditionalFact] public virtual void Returns_null_for_null_in_composite_key() { using var context = CreateContext(); - Assert.Null(Find(context, 77, null)); + Assert.Null(Finder.Find(context, 77, null)); } [ConditionalFact] @@ -263,7 +259,7 @@ public virtual void Throws_for_multiple_values_passed_for_simple_key() using var context = CreateContext(); Assert.Equal( CoreStrings.FindNotCompositeKey("IntKey", 2), - Assert.Throws(() => Find(context, 77, 88)).Message); + Assert.Throws(() => Finder.Find(context, 77, 88)).Message); } [ConditionalFact] @@ -272,7 +268,7 @@ public virtual void Throws_for_wrong_number_of_values_for_composite_key() using var context = CreateContext(); Assert.Equal( CoreStrings.FindValueCountMismatch("CompositeKey", 2, 1), - Assert.Throws(() => Find(context, 77)).Message); + Assert.Throws(() => Finder.Find(context, 77)).Message); } [ConditionalFact] @@ -281,7 +277,7 @@ public virtual void Throws_for_bad_type_for_simple_key() using var context = CreateContext(); Assert.Equal( CoreStrings.FindValueTypeMismatch(0, "IntKey", "string", "int"), - Assert.Throws(() => Find(context, "77")).Message); + Assert.Throws(() => Finder.Find(context, "77")).Message); } [ConditionalFact] @@ -290,7 +286,7 @@ public virtual void Throws_for_bad_type_for_composite_key() using var context = CreateContext(); Assert.Equal( CoreStrings.FindValueTypeMismatch(1, "CompositeKey", "int", "string"), - Assert.Throws(() => Find(context, 77, 88)).Message); + Assert.Throws(() => Finder.Find(context, 77, 88)).Message); } [ConditionalFact] @@ -300,7 +296,7 @@ public virtual void Throws_for_bad_entity_type() Assert.Equal( CoreStrings.InvalidSetType(nameof(Random)), - Assert.Throws(() => Find(context, 77)).Message); + Assert.Throws(() => Finder.Find(context, 77)).Message); } [ConditionalFact] @@ -311,296 +307,414 @@ public virtual void Throws_for_bad_entity_type_with_different_namespace() Assert.Equal( CoreStrings.InvalidSetSameTypeWithDifferentNamespace( typeof(DifferentNamespace.ShadowKey).DisplayName(), typeof(ShadowKey).DisplayName()), - Assert.Throws(() => Find(context, 77)).Message); + Assert.Throws(() => Finder.Find(context, 77)).Message); } - [ConditionalFact] - public virtual async Task Find_int_key_tracked_async() + [ConditionalTheory] + [InlineData((int)CancellationType.Right)] + [InlineData((int)CancellationType.Wrong)] + [InlineData((int)CancellationType.None)] + public virtual async Task Find_int_key_tracked_async(CancellationType cancellationType) { using var context = CreateContext(); var entity = context.Attach( new IntKey { Id = 88 }).Entity; - var valueTask = FindAsync(context, 88); + var valueTask = Finder.FindAsync(cancellationType, context, new object[] { 88 }); + Assert.True(valueTask.IsCompleted); Assert.Same(entity, await valueTask); } - [ConditionalFact] - public virtual async Task Find_int_key_from_store_async() + [ConditionalTheory] + [InlineData((int)CancellationType.Right)] + [InlineData((int)CancellationType.Wrong)] + [InlineData((int)CancellationType.None)] + public virtual async Task Find_int_key_from_store_async(CancellationType cancellationType) { using var context = CreateContext(); - Assert.Equal("Smokey", (await FindAsync(context, 77)).Foo); + Assert.Equal("Smokey", (await Finder.FindAsync(cancellationType, context, new object[] { 77 })).Foo); } - [ConditionalFact] - public virtual async Task Returns_null_for_int_key_not_in_store_async() + [ConditionalTheory] + [InlineData((int)CancellationType.Right)] + [InlineData((int)CancellationType.Wrong)] + [InlineData((int)CancellationType.None)] + public virtual async Task Returns_null_for_int_key_not_in_store_async(CancellationType cancellationType) { using var context = CreateContext(); - Assert.Null(await FindAsync(context, 99)); + Assert.Null(await Finder.FindAsync(cancellationType, context, new object[] { 99 })); } - [ConditionalFact] - public virtual async Task Find_nullable_int_key_tracked_async() + [ConditionalTheory] + [InlineData((int)CancellationType.Right)] + [InlineData((int)CancellationType.Wrong)] + [InlineData((int)CancellationType.None)] + public virtual async Task Find_nullable_int_key_tracked_async(CancellationType cancellationType) { using var context = CreateContext(); var entity = context.Attach( new NullableIntKey { Id = 88 }).Entity; - Assert.Same(entity, await FindAsync(context, 88)); + Assert.Same(entity, await Finder.FindAsync(cancellationType, context, new object[] { 88 })); } - [ConditionalFact] - public virtual async Task Find_nullable_int_key_from_store_async() + [ConditionalTheory] + [InlineData((int)CancellationType.Right)] + [InlineData((int)CancellationType.Wrong)] + [InlineData((int)CancellationType.None)] + public virtual async Task Find_nullable_int_key_from_store_async(CancellationType cancellationType) { using var context = CreateContext(); - Assert.Equal("Smokey", (await FindAsync(context, 77)).Foo); + Assert.Equal("Smokey", (await Finder.FindAsync(cancellationType, context, new object[] { 77 })).Foo); } - [ConditionalFact] - public virtual async Task Returns_null_for_nullable_int_key_not_in_store_async() + [ConditionalTheory] + [InlineData((int)CancellationType.Right)] + [InlineData((int)CancellationType.Wrong)] + [InlineData((int)CancellationType.None)] + public virtual async Task Returns_null_for_nullable_int_key_not_in_store_async(CancellationType cancellationType) { using var context = CreateContext(); - Assert.Null(await FindAsync(context, 99)); + Assert.Null(await Finder.FindAsync(cancellationType, context, new object[] { 99 })); } - [ConditionalFact] - public virtual async Task Find_string_key_tracked_async() + [ConditionalTheory] + [InlineData((int)CancellationType.Right)] + [InlineData((int)CancellationType.Wrong)] + [InlineData((int)CancellationType.None)] + public virtual async Task Find_string_key_tracked_async(CancellationType cancellationType) { using var context = CreateContext(); var entity = context.Attach( new StringKey { Id = "Rabbit" }).Entity; - Assert.Same(entity, await FindAsync(context, "Rabbit")); + Assert.Same(entity, await Finder.FindAsync(cancellationType, context, new object[] { "Rabbit" })); } - [ConditionalFact] - public virtual async Task Find_string_key_from_store_async() + [ConditionalTheory] + [InlineData((int)CancellationType.Right)] + [InlineData((int)CancellationType.Wrong)] + [InlineData((int)CancellationType.None)] + public virtual async Task Find_string_key_from_store_async(CancellationType cancellationType) { using var context = CreateContext(); - Assert.Equal("Alice", (await FindAsync(context, "Cat")).Foo); + Assert.Equal("Alice", (await Finder.FindAsync(cancellationType, context, new object[] { "Cat" })).Foo); } - [ConditionalFact] - public virtual async Task Returns_null_for_string_key_not_in_store_async() + [ConditionalTheory] + [InlineData((int)CancellationType.Right)] + [InlineData((int)CancellationType.Wrong)] + [InlineData((int)CancellationType.None)] + public virtual async Task Returns_null_for_string_key_not_in_store_async(CancellationType cancellationType) { using var context = CreateContext(); - Assert.Null(await FindAsync(context, "Fox")); + Assert.Null(await Finder.FindAsync(cancellationType, context, new object[] { "Fox" })); } - [ConditionalFact] - public virtual async Task Find_composite_key_tracked_async() + [ConditionalTheory] + [InlineData((int)CancellationType.Right)] + [InlineData((int)CancellationType.Wrong)] + [InlineData((int)CancellationType.None)] + public virtual async Task Find_composite_key_tracked_async(CancellationType cancellationType) { using var context = CreateContext(); var entity = context.Attach( new CompositeKey { Id1 = 88, Id2 = "Rabbit" }).Entity; - Assert.Same(entity, await FindAsync(context, 88, "Rabbit")); + Assert.Same(entity, await Finder.FindAsync(cancellationType, context, new object[] { 88, "Rabbit" })); } - [ConditionalFact] - public virtual async Task Find_composite_key_from_store_async() + [ConditionalTheory] + [InlineData((int)CancellationType.Right)] + [InlineData((int)CancellationType.Wrong)] + [InlineData((int)CancellationType.None)] + public virtual async Task Find_composite_key_from_store_async(CancellationType cancellationType) { using var context = CreateContext(); - Assert.Equal("Olive", (await FindAsync(context, 77, "Dog")).Foo); + Assert.Equal("Olive", (await Finder.FindAsync(cancellationType, context, new object[] { 77, "Dog" })).Foo); } - [ConditionalFact] - public virtual async Task Returns_null_for_composite_key_not_in_store_async() + [ConditionalTheory] + [InlineData((int)CancellationType.Right)] + [InlineData((int)CancellationType.Wrong)] + [InlineData((int)CancellationType.None)] + public virtual async Task Returns_null_for_composite_key_not_in_store_async(CancellationType cancellationType) { using var context = CreateContext(); - Assert.Null(await FindAsync(context, 77, "Fox")); + Assert.Null(await Finder.FindAsync(cancellationType, context, new object[] { 77, "Fox" })); } - [ConditionalFact] - public virtual async Task Find_base_type_tracked_async() + [ConditionalTheory] + [InlineData((int)CancellationType.Right)] + [InlineData((int)CancellationType.Wrong)] + [InlineData((int)CancellationType.None)] + public virtual async Task Find_base_type_tracked_async(CancellationType cancellationType) { using var context = CreateContext(); var entity = context.Attach( new BaseType { Id = 88 }).Entity; - Assert.Same(entity, await FindAsync(context, 88)); + Assert.Same(entity, await Finder.FindAsync(cancellationType, context, new object[] { 88 })); } - [ConditionalFact] - public virtual async Task Find_base_type_from_store_async() + [ConditionalTheory] + [InlineData((int)CancellationType.Right)] + [InlineData((int)CancellationType.Wrong)] + [InlineData((int)CancellationType.None)] + public virtual async Task Find_base_type_from_store_async(CancellationType cancellationType) { using var context = CreateContext(); - Assert.Equal("Baxter", (await FindAsync(context, 77)).Foo); + Assert.Equal("Baxter", (await Finder.FindAsync(cancellationType, context, new object[] { 77 })).Foo); } - [ConditionalFact] - public virtual async Task Returns_null_for_base_type_not_in_store_async() + [ConditionalTheory] + [InlineData((int)CancellationType.Right)] + [InlineData((int)CancellationType.Wrong)] + [InlineData((int)CancellationType.None)] + public virtual async Task Returns_null_for_base_type_not_in_store_async(CancellationType cancellationType) { using var context = CreateContext(); - Assert.Null(await FindAsync(context, 99)); + Assert.Null(await Finder.FindAsync(cancellationType, context, new object[] { 99 })); } - [ConditionalFact] - public virtual async Task Find_derived_type_tracked_async() + [ConditionalTheory] + [InlineData((int)CancellationType.Right)] + [InlineData((int)CancellationType.Wrong)] + [InlineData((int)CancellationType.None)] + public virtual async Task Find_derived_type_tracked_async(CancellationType cancellationType) { using var context = CreateContext(); var entity = context.Attach( new DerivedType { Id = 88 }).Entity; - Assert.Same(entity, await FindAsync(context, 88)); + Assert.Same(entity, await Finder.FindAsync(cancellationType, context, new object[] { 88 })); } - [ConditionalFact] - public virtual async Task Find_derived_type_from_store_async() + [ConditionalTheory] + [InlineData((int)CancellationType.Right)] + [InlineData((int)CancellationType.Wrong)] + [InlineData((int)CancellationType.None)] + public virtual async Task Find_derived_type_from_store_async(CancellationType cancellationType) { using var context = CreateContext(); - var derivedType = await FindAsync(context, 78); + var derivedType = await Finder.FindAsync(cancellationType, context, new object[] { 78 }); Assert.Equal("Strawberry", derivedType.Foo); Assert.Equal("Cheesecake", derivedType.Boo); } - [ConditionalFact] - public virtual async Task Returns_null_for_derived_type_not_in_store_async() + [ConditionalTheory] + [InlineData((int)CancellationType.Right)] + [InlineData((int)CancellationType.Wrong)] + [InlineData((int)CancellationType.None)] + public virtual async Task Returns_null_for_derived_type_not_in_store_async(CancellationType cancellationType) { using var context = CreateContext(); - Assert.Null(await FindAsync(context, 99)); + Assert.Null(await Finder.FindAsync(cancellationType, context, new object[] { 99 })); } - [ConditionalFact] - public virtual async Task Find_base_type_using_derived_set_tracked_async() + [ConditionalTheory] + [InlineData((int)CancellationType.Right)] + [InlineData((int)CancellationType.Wrong)] + [InlineData((int)CancellationType.None)] + public virtual async Task Find_base_type_using_derived_set_tracked_async(CancellationType cancellationType) { using var context = CreateContext(); context.Attach( new BaseType { Id = 88 }); - Assert.Null(await FindAsync(context, 88)); + Assert.Null(await Finder.FindAsync(cancellationType, context, new object[] { 88 })); } - [ConditionalFact] - public virtual async Task Find_base_type_using_derived_set_from_store_async() + [ConditionalTheory] + [InlineData((int)CancellationType.Right)] + [InlineData((int)CancellationType.Wrong)] + [InlineData((int)CancellationType.None)] + public virtual async Task Find_base_type_using_derived_set_from_store_async(CancellationType cancellationType) { using var context = CreateContext(); - Assert.Null(await FindAsync(context, 77)); + Assert.Null(await Finder.FindAsync(cancellationType, context, new object[] { 77 })); } - [ConditionalFact] - public virtual async Task Find_derived_type_using_base_set_tracked_async() + [ConditionalTheory] + [InlineData((int)CancellationType.Right)] + [InlineData((int)CancellationType.Wrong)] + [InlineData((int)CancellationType.None)] + public virtual async Task Find_derived_type_using_base_set_tracked_async(CancellationType cancellationType) { using var context = CreateContext(); var entity = context.Attach( new DerivedType { Id = 88 }).Entity; - Assert.Same(entity, await FindAsync(context, 88)); + Assert.Same(entity, await Finder.FindAsync(cancellationType, context, new object[] { 88 })); } - [ConditionalFact] - public virtual async Task Find_derived_using_base_set_type_from_store_async() + [ConditionalTheory] + [InlineData((int)CancellationType.Right)] + [InlineData((int)CancellationType.Wrong)] + [InlineData((int)CancellationType.None)] + public virtual async Task Find_derived_using_base_set_type_from_store_async(CancellationType cancellationType) { using var context = CreateContext(); - var derivedType = await FindAsync(context, 78); + var derivedType = await Finder.FindAsync(cancellationType, context, new object[] { 78 }); Assert.Equal("Strawberry", derivedType.Foo); Assert.Equal("Cheesecake", ((DerivedType)derivedType).Boo); } - [ConditionalFact] - public virtual async Task Find_shadow_key_tracked_async() + [ConditionalTheory] + [InlineData((int)CancellationType.Right)] + [InlineData((int)CancellationType.Wrong)] + [InlineData((int)CancellationType.None)] + public virtual async Task Find_shadow_key_tracked_async(CancellationType cancellationType) { using var context = CreateContext(); var entry = context.Entry(new ShadowKey()); entry.Property("Id").CurrentValue = 88; entry.State = EntityState.Unchanged; - Assert.Same(entry.Entity, await FindAsync(context, 88)); + Assert.Same(entry.Entity, await Finder.FindAsync(cancellationType, context, new object[] { 88 })); } - [ConditionalFact] - public virtual async Task Find_shadow_key_from_store_async() + [ConditionalTheory] + [InlineData((int)CancellationType.Right)] + [InlineData((int)CancellationType.Wrong)] + [InlineData((int)CancellationType.None)] + public virtual async Task Find_shadow_key_from_store_async(CancellationType cancellationType) { using var context = CreateContext(); - Assert.Equal("Clippy", (await FindAsync(context, 77)).Foo); + Assert.Equal("Clippy", (await Finder.FindAsync(cancellationType, context, new object[] { 77 })).Foo); } - [ConditionalFact] - public virtual async Task Returns_null_for_shadow_key_not_in_store_async() + [ConditionalTheory] + [InlineData((int)CancellationType.Right)] + [InlineData((int)CancellationType.Wrong)] + [InlineData((int)CancellationType.None)] + public virtual async Task Returns_null_for_shadow_key_not_in_store_async(CancellationType cancellationType) { using var context = CreateContext(); - Assert.Null(await FindAsync(context, 99)); + Assert.Null(await Finder.FindAsync(cancellationType, context, new object[] { 99 })); } - [ConditionalFact] - public virtual async Task Returns_null_for_null_key_values_array_async() + [ConditionalTheory] + [InlineData((int)CancellationType.Right)] + [InlineData((int)CancellationType.Wrong)] + [InlineData((int)CancellationType.None)] + public virtual async Task Returns_null_for_null_key_values_array_async(CancellationType cancellationType) { using var context = CreateContext(); - Assert.Null(await FindAsync(context, null)); + Assert.Null(await Finder.FindAsync(cancellationType, context, null)); } - [ConditionalFact] - public virtual async Task Returns_null_for_null_key_async() + [ConditionalTheory] + [InlineData((int)CancellationType.Right)] + [InlineData((int)CancellationType.Wrong)] + [InlineData((int)CancellationType.None)] + public virtual async Task Returns_null_for_null_key_async(CancellationType cancellationType) { using var context = CreateContext(); - Assert.Null(await FindAsync(context, new object[] { null })); + Assert.Null(await Finder.FindAsync(cancellationType, context, new object[] { null })); } - [ConditionalFact] - public virtual async Task Returns_null_for_null_in_composite_key_async() + [ConditionalTheory] + [InlineData((int)CancellationType.Right)] + [InlineData((int)CancellationType.Wrong)] + [InlineData((int)CancellationType.None)] + public virtual async Task Returns_null_for_null_in_composite_key_async(CancellationType cancellationType) { using var context = CreateContext(); - Assert.Null(await FindAsync(context, 77, null)); + Assert.Null(await Finder.FindAsync(cancellationType, context, new object[] { 77, null })); } - [ConditionalFact] - public virtual async Task Throws_for_multiple_values_passed_for_simple_key_async() + [ConditionalTheory] + [InlineData((int)CancellationType.Right)] + [InlineData((int)CancellationType.Wrong)] + [InlineData((int)CancellationType.None)] + public virtual async Task Throws_for_multiple_values_passed_for_simple_key_async(CancellationType cancellationType) { using var context = CreateContext(); Assert.Equal( - CoreStrings.FindNotCompositeKey("IntKey", 2), - (await Assert.ThrowsAsync(() => FindAsync(context, 77, 88).AsTask())).Message); + CoreStrings.FindNotCompositeKey("IntKey", cancellationType == CancellationType.Wrong ? 3 : 2), + (await Assert.ThrowsAsync( + () => Finder.FindAsync(cancellationType, context, new object[] { 77, 88 }).AsTask())).Message); } - [ConditionalFact] - public virtual async Task Throws_for_wrong_number_of_values_for_composite_key_async() + [ConditionalTheory] + [InlineData((int)CancellationType.Right)] + [InlineData((int)CancellationType.Wrong)] + [InlineData((int)CancellationType.None)] + public virtual async Task Throws_for_wrong_number_of_values_for_composite_key_async(CancellationType cancellationType) { using var context = CreateContext(); Assert.Equal( - CoreStrings.FindValueCountMismatch("CompositeKey", 2, 1), - (await Assert.ThrowsAsync(() => FindAsync(context, 77).AsTask())).Message); + cancellationType == CancellationType.Wrong + ? CoreStrings.FindValueTypeMismatch(1, "CompositeKey", "CancellationToken", "string") + : CoreStrings.FindValueCountMismatch("CompositeKey", 2, 1), + (await Assert.ThrowsAsync( + () => Finder.FindAsync(cancellationType, context, new object[] { 77 }).AsTask())).Message); } - [ConditionalFact] - public virtual async Task Throws_for_bad_type_for_simple_key_async() + [ConditionalTheory] + [InlineData((int)CancellationType.Right)] + [InlineData((int)CancellationType.Wrong)] + [InlineData((int)CancellationType.None)] + public virtual async Task Throws_for_bad_type_for_simple_key_async(CancellationType cancellationType) { using var context = CreateContext(); Assert.Equal( CoreStrings.FindValueTypeMismatch(0, "IntKey", "string", "int"), - (await Assert.ThrowsAsync(() => FindAsync(context, "77").AsTask())).Message); + (await Assert.ThrowsAsync( + () => Finder.FindAsync(cancellationType, context, new object[] { "77" }).AsTask())).Message); } - [ConditionalFact] - public virtual async Task Throws_for_bad_type_for_composite_key_async() + [ConditionalTheory] + [InlineData((int)CancellationType.Right)] + [InlineData((int)CancellationType.Wrong)] + [InlineData((int)CancellationType.None)] + public virtual async Task Throws_for_bad_type_for_composite_key_async(CancellationType cancellationType) { using var context = CreateContext(); Assert.Equal( CoreStrings.FindValueTypeMismatch(1, "CompositeKey", "int", "string"), - (await Assert.ThrowsAsync(() => FindAsync(context, 77, 88).AsTask())).Message); + (await Assert.ThrowsAsync( + () => Finder.FindAsync(cancellationType, context, new object[] { 77, 78 }).AsTask())).Message); } - [ConditionalFact] - public virtual async Task Throws_for_bad_entity_type_async() + [ConditionalTheory] + [InlineData((int)CancellationType.Right)] + [InlineData((int)CancellationType.Wrong)] + [InlineData((int)CancellationType.None)] + public virtual async Task Throws_for_bad_entity_type_async(CancellationType cancellationType) { using var context = CreateContext(); Assert.Equal( CoreStrings.InvalidSetType(nameof(Random)), - (await Assert.ThrowsAsync(() => FindAsync(context, 77).AsTask())).Message); + (await Assert.ThrowsAsync( + () => Finder.FindAsync(cancellationType, context, new object[] { 77 }).AsTask())).Message); } - [ConditionalFact] - public virtual async Task Throws_for_bad_entity_type_with_different_namespace_async() + [ConditionalTheory] + [InlineData((int)CancellationType.Right)] + [InlineData((int)CancellationType.Wrong)] + [InlineData((int)CancellationType.None)] + public virtual async Task Throws_for_bad_entity_type_with_different_namespace_async(CancellationType cancellationType) { using var context = CreateContext(); Assert.Equal( CoreStrings.InvalidSetSameTypeWithDifferentNamespace( typeof(DifferentNamespace.ShadowKey).DisplayName(), typeof(ShadowKey).DisplayName()), - (await Assert.ThrowsAsync(() => FindAsync(context, 77).AsTask())) + (await Assert.ThrowsAsync( + () => Finder.FindAsync(cancellationType, context, new object[] { 77 }).AsTask())) .Message); } + public enum CancellationType + { + Right, + Wrong, + None + } + protected class BaseType { [DatabaseGenerated(DatabaseGeneratedOption.None)] @@ -696,6 +810,81 @@ protected override void Seed(PoolableDbContext context) context.SaveChanges(); } } + + public abstract class TestFinder + { + public abstract TEntity Find(DbContext context, params object[] keyValues) + where TEntity : class; + + public abstract ValueTask FindAsync( + CancellationType cancellationType, + DbContext context, + object[] keyValues, + CancellationToken cancellationToken = default) + where TEntity : class; + } + + public class FindViaSetFinder : TestFinder + { + public override TEntity Find(DbContext context, params object[] keyValues) + => context.Set().Find(keyValues); + + public override ValueTask FindAsync( + CancellationType cancellationType, + DbContext context, + object[] keyValues, + CancellationToken cancellationToken = default) + => cancellationType switch + { + CancellationType.Right => context.Set().FindAsync(keyValues, cancellationToken: cancellationToken), + CancellationType.Wrong => context.Set() + .FindAsync(keyValues?.Concat(new object[] { cancellationToken }).ToArray()), + CancellationType.None => context.Set().FindAsync(keyValues), + _ => throw new ArgumentOutOfRangeException(nameof(cancellationType), cancellationType, null) + }; + } + + public class FindViaContextFinder : TestFinder + { + public override TEntity Find(DbContext context, params object[] keyValues) + => (TEntity)context.Find(typeof(TEntity), keyValues); + + public override async ValueTask FindAsync( + CancellationType cancellationType, + DbContext context, + object[] keyValues, + CancellationToken cancellationToken = default) + => cancellationType switch + { + CancellationType.Right => (TEntity)await context.FindAsync( + typeof(TEntity), keyValues, cancellationToken: cancellationToken), + CancellationType.Wrong => (TEntity)await context.FindAsync( + typeof(TEntity), keyValues?.Concat(new object[] { cancellationToken }).ToArray()), + CancellationType.None => (TEntity)await context.FindAsync(typeof(TEntity), keyValues), + _ => throw new ArgumentOutOfRangeException(nameof(cancellationType), cancellationType, null) + }; + } + + public class FindViaNonGenericContextFinder : TestFinder + { + public override TEntity Find(DbContext context, params object[] keyValues) + => context.Find(keyValues); + + public override ValueTask FindAsync( + CancellationType cancellationType, + DbContext context, + object[] keyValues, + CancellationToken cancellationToken = default) + => cancellationType switch + { + CancellationType.Right => context.FindAsync(keyValues, cancellationToken: cancellationToken), + CancellationType.Wrong => context.FindAsync( + keyValues?.Concat(new object[] { cancellationToken }).ToArray()), + CancellationType.None => context.FindAsync(keyValues), + _ => throw new ArgumentOutOfRangeException(nameof(cancellationType), cancellationType, null) + }; + + } } } diff --git a/test/EFCore.SqlServer.FunctionalTests/FindSqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/FindSqlServerTest.cs index 023cd41f258..10dcc1320d1 100644 --- a/test/EFCore.SqlServer.FunctionalTests/FindSqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/FindSqlServerTest.cs @@ -18,11 +18,7 @@ public FindSqlServerTestSet(FindSqlServerFixture fixture) { } - protected override TEntity Find(DbContext context, params object[] keyValues) - => context.Set().Find(keyValues); - - protected override ValueTask FindAsync(DbContext context, params object[] keyValues) - => context.Set().FindAsync(keyValues); + protected override TestFinder Finder { get; } = new FindViaSetFinder(); } public class FindSqlServerTestContext : FindSqlServerTest @@ -32,11 +28,7 @@ public FindSqlServerTestContext(FindSqlServerFixture fixture) { } - protected override TEntity Find(DbContext context, params object[] keyValues) - => context.Find(keyValues); - - protected override ValueTask FindAsync(DbContext context, params object[] keyValues) - => context.FindAsync(keyValues); + protected override TestFinder Finder { get; } = new FindViaContextFinder(); } public class FindSqlServerTestNonGeneric : FindSqlServerTest @@ -46,11 +38,7 @@ public FindSqlServerTestNonGeneric(FindSqlServerFixture fixture) { } - protected override TEntity Find(DbContext context, params object[] keyValues) - => (TEntity)context.Find(typeof(TEntity), keyValues); - - protected override async ValueTask FindAsync(DbContext context, params object[] keyValues) - => (TEntity)await context.FindAsync(typeof(TEntity), keyValues); + protected override TestFinder Finder { get; } = new FindViaNonGenericContextFinder(); } public override void Find_int_key_tracked() diff --git a/test/EFCore.Sqlite.FunctionalTests/FindSqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/FindSqliteTest.cs index a2c15e6fc72..e2571ac9f35 100644 --- a/test/EFCore.Sqlite.FunctionalTests/FindSqliteTest.cs +++ b/test/EFCore.Sqlite.FunctionalTests/FindSqliteTest.cs @@ -17,11 +17,7 @@ public FindSqliteTestSet(FindSqliteFixture fixture) { } - protected override TEntity Find(DbContext context, params object[] keyValues) - => context.Set().Find(keyValues); - - protected override ValueTask FindAsync(DbContext context, params object[] keyValues) - => context.Set().FindAsync(keyValues); + protected override TestFinder Finder { get; } = new FindViaSetFinder(); } public class FindSqliteTestContext : FindSqliteTest @@ -31,11 +27,7 @@ public FindSqliteTestContext(FindSqliteFixture fixture) { } - protected override TEntity Find(DbContext context, params object[] keyValues) - => context.Find(keyValues); - - protected override ValueTask FindAsync(DbContext context, params object[] keyValues) - => context.FindAsync(keyValues); + protected override TestFinder Finder { get; } = new FindViaContextFinder(); } public class FindSqliteTestNonGeneric : FindSqliteTest @@ -45,11 +37,7 @@ public FindSqliteTestNonGeneric(FindSqliteFixture fixture) { } - protected override TEntity Find(DbContext context, params object[] keyValues) - => (TEntity)context.Find(typeof(TEntity), keyValues); - - protected override async ValueTask FindAsync(DbContext context, params object[] keyValues) - => (TEntity)await context.FindAsync(typeof(TEntity), keyValues); + protected override TestFinder Finder { get; } = new FindViaNonGenericContextFinder(); } public class FindSqliteFixture : FindFixtureBase