Skip to content

Commit

Permalink
Allow cancellation token to be passed as part of params array in Find…
Browse files Browse the repository at this point in the history
…Async (#28389)
  • Loading branch information
ajcvickers committed Jul 7, 2022
1 parent b4ad698 commit 32987fa
Show file tree
Hide file tree
Showing 6 changed files with 392 additions and 221 deletions.
56 changes: 43 additions & 13 deletions src/EFCore/Internal/EntityFinder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,18 @@ public EntityFinder(
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public virtual TEntity? Find(object?[]? keyValues)
=> keyValues == null || keyValues.Any(v => v == null)
? null
: (FindTracked(keyValues!, out var keyProperties)
?? _queryRoot.FirstOrDefault(BuildLambda(keyProperties, new ValueBuffer(keyValues))));
{
if (keyValues == null
|| keyValues.Any(v => v == null))
{
return default;
}

var (key, processedKeyValues, _) = ValidateKeyPropertiesAndExtractCancellationToken(keyValues!, async: false, default);

return FindTracked(key, processedKeyValues)
?? _queryRoot.FirstOrDefault(BuildLambda(key.Properties, new ValueBuffer(processedKeyValues)));
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand All @@ -74,11 +82,13 @@ public EntityFinder(
return default;
}

var tracked = FindTracked(keyValues!, out var keyProperties);
var (key, processedKeyValues, ct) = ValidateKeyPropertiesAndExtractCancellationToken(keyValues!, async: true, cancellationToken);

var tracked = FindTracked(key, processedKeyValues);
return tracked != null
? new ValueTask<TEntity?>(tracked)
: new ValueTask<TEntity?>(
_queryRoot.FirstOrDefaultAsync(BuildLambda(keyProperties, new ValueBuffer(keyValues)), cancellationToken));
_queryRoot.FirstOrDefaultAsync(BuildLambda(key.Properties, new ValueBuffer(processedKeyValues)), ct));
}

/// <summary>
Expand All @@ -95,12 +105,14 @@ public EntityFinder(
return default;
}

var tracked = FindTracked(keyValues!, out var keyProperties);
var (key, processedKeyValues, ct) = ValidateKeyPropertiesAndExtractCancellationToken(keyValues!, async: true, cancellationToken);

var tracked = FindTracked(key, processedKeyValues);
return tracked != null
? new ValueTask<object?>(tracked)
: new ValueTask<object?>(
_queryRoot.FirstOrDefaultAsync(
BuildObjectLambda(keyProperties, new ValueBuffer(keyValues)), cancellationToken));
BuildObjectLambda(key.Properties, new ValueBuffer(processedKeyValues)), ct));
}

/// <summary>
Expand Down Expand Up @@ -259,23 +271,41 @@ private static IReadOnlyList<IProperty> GetLoadProperties(INavigation navigation
? navigation.ForeignKey.PrincipalKey.Properties
: navigation.ForeignKey.Properties;

private TEntity? FindTracked(object[] keyValues, out IReadOnlyList<IProperty> keyProperties)
private (IKey Key, object[] KeyValues,CancellationToken CancellationToken) ValidateKeyPropertiesAndExtractCancellationToken(
object[] keyValues,
bool async,
CancellationToken cancellationToken)
{
var key = _entityType.FindPrimaryKey()!;
keyProperties = key.Properties;
var keyPropertiesCount = key.Properties.Count;

if (keyProperties.Count != keyValues.Length)
if (keyPropertiesCount != keyValues.Length)
{
if (keyProperties.Count == 1)
if (async
&& keyPropertiesCount == keyValues.Length - 1
&& keyValues[keyPropertiesCount] is CancellationToken ct)
{
var newValues = new object[keyPropertiesCount];
Array.Copy(keyValues, newValues, keyPropertiesCount);
return (key, newValues, ct);
}

if (keyPropertiesCount == 1)
{
throw new ArgumentException(
CoreStrings.FindNotCompositeKey(typeof(TEntity).ShortDisplayName(), keyValues.Length));
}

throw new ArgumentException(
CoreStrings.FindValueCountMismatch(typeof(TEntity).ShortDisplayName(), keyProperties.Count, keyValues.Length));
CoreStrings.FindValueCountMismatch(typeof(TEntity).ShortDisplayName(), keyPropertiesCount, keyValues.Length));
}

return (key, keyValues, cancellationToken);
}

private TEntity? FindTracked(IKey key, object[] keyValues)
{
var keyProperties = key.Properties;
for (var i = 0; i < keyValues.Length; i++)
{
var valueType = keyValues[i].GetType();
Expand Down
26 changes: 7 additions & 19 deletions test/EFCore.Cosmos.FunctionalTests/FindCosmosTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ protected FindCosmosTest(FindCosmosFixture fixture)
[ConditionalFact(Skip = "#25886")]
public override void Find_base_type_using_derived_set_tracked() { }

[ConditionalFact(Skip = "#25886")]
public override Task Find_base_type_using_derived_set_tracked_async()
[ConditionalTheory(Skip = "#25886")]
public override Task Find_base_type_using_derived_set_tracked_async(CancellationType cancellationType)
=> Task.CompletedTask;

[ConditionalFact(Skip = "#25886")]
public override void Find_derived_using_base_set_type_from_store() { }

[ConditionalFact(Skip = "#25886")]
public override Task Find_derived_using_base_set_type_from_store_async()
[ConditionalTheory(Skip = "#25886")]
public override Task Find_derived_using_base_set_type_from_store_async(CancellationType cancellationType)
=> Task.CompletedTask;

public class FindCosmosTestSet : FindCosmosTest
Expand All @@ -32,11 +32,7 @@ public FindCosmosTestSet(FindCosmosFixture fixture)
{
}

protected override TEntity Find<TEntity>(DbContext context, params object[] keyValues)
=> context.Set<TEntity>().Find(keyValues);

protected override ValueTask<TEntity> FindAsync<TEntity>(DbContext context, params object[] keyValues)
=> context.Set<TEntity>().FindAsync(keyValues);
protected override TestFinder Finder { get; } = new FindViaSetFinder();
}

public class FindCosmosTestContext : FindCosmosTest
Expand All @@ -46,11 +42,7 @@ public FindCosmosTestContext(FindCosmosFixture fixture)
{
}

protected override TEntity Find<TEntity>(DbContext context, params object[] keyValues)
=> context.Find<TEntity>(keyValues);

protected override ValueTask<TEntity> FindAsync<TEntity>(DbContext context, params object[] keyValues)
=> context.FindAsync<TEntity>(keyValues);
protected override TestFinder Finder { get; } = new FindViaContextFinder();
}

public class FindCosmosTestNonGeneric : FindCosmosTest
Expand All @@ -60,11 +52,7 @@ public FindCosmosTestNonGeneric(FindCosmosFixture fixture)
{
}

protected override TEntity Find<TEntity>(DbContext context, params object[] keyValues)
=> (TEntity)context.Find(typeof(TEntity), keyValues);

protected override async ValueTask<TEntity> FindAsync<TEntity>(DbContext context, params object[] keyValues)
=> (TEntity)await context.FindAsync(typeof(TEntity), keyValues);
protected override TestFinder Finder { get; } = new FindViaNonGenericContextFinder();
}

public class FindCosmosFixture : FindFixtureBase
Expand Down
18 changes: 3 additions & 15 deletions test/EFCore.InMemory.FunctionalTests/FindInMemoryTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,7 @@ public FindInMemoryTestSet(FindInMemoryFixture fixture)
{
}

protected override TEntity Find<TEntity>(DbContext context, params object[] keyValues)
=> context.Set<TEntity>().Find(keyValues);

protected override ValueTask<TEntity> FindAsync<TEntity>(DbContext context, params object[] keyValues)
=> context.Set<TEntity>().FindAsync(keyValues);
protected override TestFinder Finder { get; } = new FindViaSetFinder();
}

public class FindInMemoryTestContext : FindInMemoryTest
Expand All @@ -31,11 +27,7 @@ public FindInMemoryTestContext(FindInMemoryFixture fixture)
{
}

protected override TEntity Find<TEntity>(DbContext context, params object[] keyValues)
=> context.Find<TEntity>(keyValues);

protected override ValueTask<TEntity> FindAsync<TEntity>(DbContext context, params object[] keyValues)
=> context.FindAsync<TEntity>(keyValues);
protected override TestFinder Finder { get; } = new FindViaContextFinder();
}

public class FindInMemoryTestNonGeneric : FindInMemoryTest
Expand All @@ -45,11 +37,7 @@ public FindInMemoryTestNonGeneric(FindInMemoryFixture fixture)
{
}

protected override TEntity Find<TEntity>(DbContext context, params object[] keyValues)
=> (TEntity)context.Find(typeof(TEntity), keyValues);

protected override async ValueTask<TEntity> FindAsync<TEntity>(DbContext context, params object[] keyValues)
=> (TEntity)await context.FindAsync(typeof(TEntity), keyValues);
protected override TestFinder Finder { get; } = new FindViaNonGenericContextFinder();
}

public class FindInMemoryFixture : FindFixtureBase
Expand Down
Loading

0 comments on commit 32987fa

Please sign in to comment.