Skip to content

Commit

Permalink
Make concurrency check handle re-entrance
Browse files Browse the repository at this point in the history
Fixes #7375
  • Loading branch information
AndriySvyryd committed Jun 16, 2017
1 parent adf1630 commit bc01a55
Show file tree
Hide file tree
Showing 17 changed files with 255 additions and 170 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

using System;
using System.Data.Common;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Internal;
Expand Down Expand Up @@ -51,12 +53,15 @@ public virtual void Throws_on_concurrent_command()
{
using (var context = CreateContext())
{
((IInfrastructure<IServiceProvider>)context).Instance.GetService<IConcurrencyDetector>().EnterCriticalSection();
context.Database.EnsureCreated();

Assert.Equal(
CoreStrings.ConcurrentMethodInvocation,
Assert.Throws<InvalidOperationException>(
() => context.Database.ExecuteSqlCommand(@"SELECT * FROM ""Customers""")).Message);
Assert.Throws<AggregateException>(
() => Parallel.For(0, 10, i =>
{
context.Database.ExecuteSqlCommand(@"SELECT * FROM ""Customers""");
})).InnerExceptions.First().Message);
}
}

Expand Down Expand Up @@ -126,12 +131,17 @@ public virtual async Task Throws_on_concurrent_command_async()
{
using (var context = CreateContext())
{
((IInfrastructure<IServiceProvider>)context).Instance.GetService<IConcurrencyDetector>().EnterCriticalSection();
context.Database.EnsureCreated();

var task = Task.WhenAll(Enumerable.Range(0, 10)
.Select(i => context.Database.ExecuteSqlCommandAsync(@"SELECT * FROM ""Customers""")));

Assert.Equal(
CoreStrings.ConcurrentMethodInvocation,
(await Assert.ThrowsAsync<InvalidOperationException>(
async () => await context.Database.ExecuteSqlCommandAsync(@"SELECT * FROM ""Customers"""))).Message);
Assert.Throws<InvalidOperationException>(
() => context.Database.ExecuteSqlCommand(@"SELECT * FROM ""Customers""")).Message);

await task;
}
}

Expand Down
78 changes: 35 additions & 43 deletions src/EFCore.Relational/RelationalDatabaseFacadeExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ public static int ExecuteSqlCommand(
RawSqlString sql,
[NotNull] params object[] parameters)
=> ExecuteSqlCommand(databaseFacade, sql, (IEnumerable<object>)parameters);

// Note that this method doesn't start a transaction hence it doesn't use ExecutionStrategy
public static int ExecuteSqlCommand(
[NotNull] this DatabaseFacade databaseFacade,
Expand All @@ -120,26 +120,22 @@ public static int ExecuteSqlCommand(
[NotNull] this DatabaseFacade databaseFacade,
RawSqlString sql,
[NotNull] IEnumerable<object> parameters)
{
Check.NotNull(databaseFacade, nameof(databaseFacade));
Check.NotNull(sql, nameof(sql));
Check.NotNull(parameters, nameof(parameters));

var concurrencyDetector = databaseFacade.GetService<IConcurrencyDetector>();

using (concurrencyDetector.EnterCriticalSection())
{
var rawSqlCommand = databaseFacade
.GetRelationalService<IRawSqlCommandBuilder>()
.Build(sql.Format, parameters);

return rawSqlCommand
.RelationalCommand
.ExecuteNonQuery(
databaseFacade.GetRelationalService<IRelationalConnection>(),
rawSqlCommand.ParameterValues);
}
}
=> Check.NotNull(databaseFacade, nameof(databaseFacade)).GetService<IConcurrencyDetector>().ExecuteInCriticalSection(
(DatabaseFacade: databaseFacade,
Sql: Check.NotNull(sql, nameof(sql)),
Parameters: Check.NotNull(parameters, nameof(parameters))),
s =>
{
var rawSqlCommand = s.DatabaseFacade
.GetRelationalService<IRawSqlCommandBuilder>()
.Build(s.Sql.Format, s.Parameters);
return rawSqlCommand
.RelationalCommand
.ExecuteNonQuery(
s.DatabaseFacade.GetRelationalService<IRelationalConnection>(),
rawSqlCommand.ParameterValues);
});

// Note that this method doesn't start a transaction hence it doesn't use ExecutionStrategy
public static Task<int> ExecuteSqlCommandAsync(
Expand All @@ -163,32 +159,28 @@ public static Task<int> ExecuteSqlCommandAsync(
=> ExecuteSqlCommandAsync(databaseFacade, sql, (IEnumerable<object>)parameters);

// Note that this method doesn't start a transaction hence it doesn't use ExecutionStrategy
public static async Task<int> ExecuteSqlCommandAsync(
public static Task<int> ExecuteSqlCommandAsync(
[NotNull] this DatabaseFacade databaseFacade,
RawSqlString sql,
[NotNull] IEnumerable<object> parameters,
CancellationToken cancellationToken = default(CancellationToken))
{
Check.NotNull(databaseFacade, nameof(databaseFacade));
Check.NotNull(sql, nameof(sql));
Check.NotNull(parameters, nameof(parameters));

var concurrencyDetector = databaseFacade.GetService<IConcurrencyDetector>();

using (await concurrencyDetector.EnterCriticalSectionAsync(cancellationToken))
{
var rawSqlCommand = databaseFacade
.GetRelationalService<IRawSqlCommandBuilder>()
.Build(sql.Format, parameters);

return await rawSqlCommand
.RelationalCommand
.ExecuteNonQueryAsync(
databaseFacade.GetRelationalService<IRelationalConnection>(),
rawSqlCommand.ParameterValues,
cancellationToken);
}
}
=> Check.NotNull(databaseFacade, nameof(databaseFacade)).GetService<IConcurrencyDetector>().ExecuteInCriticalSectionAsync(
(DatabaseFacade: databaseFacade,
Sql: Check.NotNull(sql, nameof(sql)),
Parameters: Check.NotNull(parameters, nameof(parameters))),
async (s, ct) =>
{
var rawSqlCommand = s.DatabaseFacade
.GetRelationalService<IRawSqlCommandBuilder>()
.Build(s.Sql.Format, s.Parameters);
return await rawSqlCommand
.RelationalCommand
.ExecuteNonQueryAsync(
s.DatabaseFacade.GetRelationalService<IRelationalConnection>(),
rawSqlCommand.ParameterValues,
ct);
}, cancellationToken);

public static DbConnection GetDbConnection([NotNull] this DatabaseFacade databaseFacade)
=> databaseFacade.GetRelationalService<IRelationalConnection>().DbConnection;
Expand Down
53 changes: 45 additions & 8 deletions src/EFCore.Specification.Tests/Query/AsyncQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1440,6 +1440,33 @@ from c in cs
select c.City);
}

[ConditionalFact]
public virtual async Task Select_nested_projection()
{
using (var context = CreateContext())
{
var customers = await context.Customers
.Where(c => c.CustomerID.StartsWith("A"))
.Select(c => new
{
Customer = c,
CustomerAgain = Get(context, c.CustomerID)
})
.ToListAsync();

Assert.Equal(4, customers.Count);
foreach (var customer in customers)
{
// Issue #8864
//Assert.Same(customer.Customer, customer.CustomerAgain);
Assert.Equal(customer.Customer.CustomerID, customer.CustomerAgain.CustomerID);
}
}
}

private Customer Get(NorthwindContext context, string id)
=> context.Customers.Single(c => c.CustomerID == id);

[ConditionalFact]
public virtual async Task Select_nested_collection()
{
Expand Down Expand Up @@ -2255,7 +2282,7 @@ from o in os.OrderBy(o => o.OrderID)
assertOrder: true);
}

// TODO: Need to figure out how to do this
// TODO: Need to figure out how to do this
// [ConditionalFact]
// public virtual async Task GroupBy_anonymous()
// {
Expand Down Expand Up @@ -3349,26 +3376,36 @@ public virtual async Task Throws_on_concurrent_query_list()
{
using (var context = CreateContext())
{
((IInfrastructure<IServiceProvider>)context).Instance.GetService<IConcurrencyDetector>().EnterCriticalSection();
context.Database.EnsureCreated();

var task = Task.WhenAll(Enumerable.Range(0, 10)
.Select(i => context.Customers.ToListAsync()));

Assert.Equal(
CoreStrings.ConcurrentMethodInvocation,
(await Assert.ThrowsAsync<InvalidOperationException>(
async () => await context.Customers.ToListAsync())).Message);
Assert.Throws<InvalidOperationException>(
() => context.Customers.ToList()).Message);

await task;
}
}

[ConditionalFact]
[ConditionalFact(Skip = "This test is flaky see #8305")]
public virtual async Task Throws_on_concurrent_query_first()
{
using (var context = CreateContext())
{
((IInfrastructure<IServiceProvider>)context).Instance.GetService<IConcurrencyDetector>().EnterCriticalSection();
context.Database.EnsureCreated();

var task = Task.WhenAll(Enumerable.Range(0, 10)
.Select(i => context.Customers.ToListAsync()));

Assert.Equal(
CoreStrings.ConcurrentMethodInvocation,
(await Assert.ThrowsAsync<InvalidOperationException>(
async () => await context.Customers.FirstAsync())).Message);
Assert.Throws<InvalidOperationException>(
() => context.Customers.First()).Message);

await task;
}
}

Expand Down
27 changes: 27 additions & 0 deletions src/EFCore.Specification.Tests/Query/QueryTestBase.Select.cs
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,33 @@ from c in cs
select c.City);
}

[ConditionalFact]
public virtual void Select_nested_projection()
{
using (var context = CreateContext())
{
var customers = context.Customers
.Where(c => c.CustomerID.StartsWith("A"))
.Select(c => new
{
Customer = c,
CustomerAgain = Get(context, c.CustomerID)
})
.ToList();

Assert.Equal(4, customers.Count);
foreach (var customer in customers)
{
// Issue #8864
//Assert.Same(customer.Customer, customer.CustomerAgain);
Assert.Equal(customer.Customer.CustomerID, customer.CustomerAgain.CustomerID);
}
}
}

private Customer Get(NorthwindContext context, string id)
=> context.Customers.Single(c => c.CustomerID == id);

[ConditionalFact]
public virtual void Select_nested_collection()
{
Expand Down
26 changes: 17 additions & 9 deletions src/EFCore.Specification.Tests/Query/QueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Globalization;
using System.Linq;
using System.Linq.Expressions;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Query.Expressions.Internal;
Expand Down Expand Up @@ -2695,26 +2697,32 @@ public virtual void Throws_on_concurrent_query_list()
{
using (var context = CreateContext())
{
((IInfrastructure<IServiceProvider>)context).Instance.GetService<IConcurrencyDetector>().EnterCriticalSection();
context.Database.EnsureCreated();

Assert.Equal(
CoreStrings.ConcurrentMethodInvocation,
Assert.Throws<InvalidOperationException>(
() => context.Customers.ToList()).Message);
Assert.Throws<AggregateException>(
() => Parallel.For(0, 10, i =>
{
context.Customers.ToList();
})).InnerExceptions.First().Message);
}
}

[ConditionalFact]
[ConditionalFact(Skip = "This test is flaky see #8305")]
public virtual void Throws_on_concurrent_query_first()
{
using (var context = CreateContext())
{
((IInfrastructure<IServiceProvider>)context).Instance.GetService<IConcurrencyDetector>().EnterCriticalSection();
context.Database.EnsureCreated();

Assert.Equal(
CoreStrings.ConcurrentMethodInvocation,
Assert.Throws<InvalidOperationException>(
() => context.Customers.First()).Message);
Assert.Throws<AggregateException>(
() => Parallel.For(0, 10, i =>
{
context.Customers.First();
})).InnerExceptions.First().Message);
}
}

Expand Down Expand Up @@ -3338,7 +3346,7 @@ public virtual void OrderBy_skip_take_distinct_orderby_take()
assertOrder: false,
entryCount: 8);
}

[ConditionalFact]
public virtual void No_orderby_added_for_fully_translated_manually_constructed_LOJ()
{
Expand Down Expand Up @@ -3944,7 +3952,7 @@ protected QueryTestBase(TFixture fixture)
Fixture = fixture;
}

protected TFixture Fixture { get; }
protected TFixture Fixture { [DebuggerStepThrough] get; }

private void AssertQuery<TItem>(
Func<IQueryable<TItem>, int> query,
Expand Down
19 changes: 6 additions & 13 deletions src/EFCore/ChangeTracking/Internal/StateManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -781,26 +781,19 @@ public virtual async Task<int> SaveChangesAsync(
/// </summary>
protected virtual int SaveChanges(
[NotNull] IReadOnlyList<InternalEntityEntry> entriesToSave)
{
using (_concurrencyDetector.EnterCriticalSection())
{
return _database.SaveChanges(entriesToSave);
}
}
=> _concurrencyDetector.ExecuteInCriticalSection((Database: _database, EntriesToSave: entriesToSave),
s => s.Database.SaveChanges(s.EntriesToSave));

/// <summary>
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
/// directly from your code. This API may change or be removed in future releases.
/// </summary>
protected virtual async Task<int> SaveChangesAsync(
protected virtual Task<int> SaveChangesAsync(
[NotNull] IReadOnlyList<InternalEntityEntry> entriesToSave,
CancellationToken cancellationToken = default(CancellationToken))
{
using (_concurrencyDetector.EnterCriticalSection())
{
return await _database.SaveChangesAsync(entriesToSave, cancellationToken);
}
}
=> _concurrencyDetector.ExecuteInCriticalSectionAsync((Database: _database, EntriesToSave: entriesToSave),
(s, ct) => s.Database.SaveChangesAsync(s.EntriesToSave, ct),
cancellationToken);

/// <summary>
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
Expand Down
Loading

0 comments on commit bc01a55

Please sign in to comment.