Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add concurrency detection for query, save changes, and ADO command execution #4051

Merged
merged 1 commit into from
Dec 15, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions src/EntityFramework.Core/ChangeTracking/Internal/StateManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ private readonly Dictionary<IKeyValue, InternalEntityEntry> _identityMap
private readonly IKeyValueFactorySource _keyValueFactorySource;
private readonly IModel _model;
private readonly IDatabase _database;
private IConcurrencyDetector _concurrencyDetector;

public StateManager(
[NotNull] IInternalEntityEntryFactory factory,
Expand All @@ -42,6 +43,7 @@ public StateManager(
[NotNull] IKeyValueFactorySource keyValueFactorySource,
[NotNull] IModel model,
[NotNull] IDatabase database,
[NotNull] IConcurrencyDetector concurrencyDetector,
[NotNull] DbContext context)
{
_factory = factory;
Expand All @@ -51,6 +53,7 @@ public StateManager(
ValueGeneration = valueGeneration;
_model = model;
_database = database;
_concurrencyDetector = concurrencyDetector;
Context = context;
}

Expand Down Expand Up @@ -411,12 +414,32 @@ public virtual async Task<int> SaveChangesAsync(

protected virtual int SaveChanges(
[NotNull] IReadOnlyList<InternalEntityEntry> entriesToSave)
=> _database.SaveChanges(entriesToSave);
{
try
{
_concurrencyDetector.EnterCriticalSection();
return _database.SaveChanges(entriesToSave);
}
finally
{
_concurrencyDetector.ExitCriticalSection();
}
}

protected virtual Task<int> SaveChangesAsync(
protected virtual async Task<int> SaveChangesAsync(
[NotNull] IReadOnlyList<InternalEntityEntry> entriesToSave,
CancellationToken cancellationToken = default(CancellationToken))
=> _database.SaveChangesAsync(entriesToSave, cancellationToken);
{
try
{
_concurrencyDetector.EnterCriticalSection();
return await _database.SaveChangesAsync(entriesToSave, cancellationToken);
}
finally
{
_concurrencyDetector.ExitCriticalSection();
}
}

public virtual void AcceptAllChanges()
{
Expand Down
2 changes: 2 additions & 0 deletions src/EntityFramework.Core/EntityFramework.Core.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@
<Compile Include="Internal\EnumerableExtensions.cs" />
<Compile Include="Internal\Graph.cs" />
<Compile Include="Internal\IDatabaseProviderSelector.cs" />
<Compile Include="IConcurrencyDetector.cs" />
<Compile Include="Internal\IndentedStringBuilder.cs" />
<Compile Include="Internal\ConcurrencyDetector.cs" />
<Compile Include="Internal\LazyRef.cs" />
<Compile Include="Metadata\Internal\ConventionalAnnotatable.cs" />
<Compile Include="Metadata\Internal\ConventionalAnnotation.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using JetBrains.Annotations;
using Microsoft.Data.Entity;
using Microsoft.Data.Entity.ChangeTracking.Internal;
using Microsoft.Data.Entity.Infrastructure;
using Microsoft.Data.Entity.Internal;
Expand Down Expand Up @@ -94,6 +95,7 @@ public static EntityFrameworkServicesBuilder AddEntityFramework(
.AddScoped<IKeyPropagator, KeyPropagator>()
.AddScoped<INavigationFixer, NavigationFixer>()
.AddScoped<IStateManager, StateManager>()
.AddScoped<IConcurrencyDetector, ConcurrencyDetector>()
.AddScoped<IInternalEntityEntryFactory, InternalEntityEntryFactory>()
.AddScoped<IInternalEntityEntryNotifier, InternalEntityEntryNotifier>()
.AddScoped<IInternalEntityEntrySubscriber, InternalEntityEntrySubscriber>()
Expand Down
12 changes: 12 additions & 0 deletions src/EntityFramework.Core/IConcurrencyDetector.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

namespace Microsoft.Data.Entity
{
public interface IConcurrencyDetector
{
void EnterCriticalSection();

void ExitCriticalSection();
}
}
30 changes: 30 additions & 0 deletions src/EntityFramework.Core/Internal/ConcurrencyDetector.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Diagnostics;

namespace Microsoft.Data.Entity.Internal
{
public class ConcurrencyDetector : IConcurrencyDetector
{
private bool _isInCriticalSection;

public virtual void EnterCriticalSection()
{
if(_isInCriticalSection)
{
throw new NotSupportedException(CoreStrings.ConcurrentMethodInvocation);
}

_isInCriticalSection = true;
}

public virtual void ExitCriticalSection()
{
Debug.Assert(_isInCriticalSection, "Expected to be in a critical section");

_isInCriticalSection = false;
}
}
}
8 changes: 8 additions & 0 deletions src/EntityFramework.Core/Properties/CoreStrings.Designer.cs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions src/EntityFramework.Core/Properties/CoreStrings.resx
Original file line number Diff line number Diff line change
Expand Up @@ -522,4 +522,7 @@
<data name="KeyPropertyInForeignKey" xml:space="preserve">
<value>The property '{property}' cannot be part of a key on '{entityType}' because it is contained in a foreign key defined on a derived entity type.</value>
</data>
<data name="ConcurrentMethodInvocation" xml:space="preserve">
<value>A second operation started on this context before a previous operation completed. Any instance members are not guaranteed to be thread safe.</value>
</data>
</root>
3 changes: 2 additions & 1 deletion src/EntityFramework.Core/Query/EntityQueryModelVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,8 @@ protected virtual void InterceptExceptions()
.MakeGenericMethod(_expression.Type.GetSequenceType()),
_expression,
Expression.Constant(QueryCompilationContext.ContextType),
Expression.Constant(QueryCompilationContext.Logger));
Expression.Constant(QueryCompilationContext.Logger),
QueryContextParameter);
}

protected virtual void ExtractQueryAnnotations([NotNull] QueryModel queryModel)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ private static readonly MethodInfo _interceptExceptions

[UsedImplicitly]
internal static IAsyncEnumerable<T> _InterceptExceptions<T>(
IAsyncEnumerable<T> source, Type contextType, ILogger logger)
=> new ExceptionInterceptor<T>(source, contextType, logger);
IAsyncEnumerable<T> source, Type contextType, ILogger logger, QueryContext queryContext)
=> new ExceptionInterceptor<T>(source, contextType, logger, queryContext);

public virtual MethodInfo InterceptExceptions => _interceptExceptions;

Expand All @@ -96,13 +96,15 @@ private sealed class ExceptionInterceptor<T> : IAsyncEnumerable<T>
private readonly IAsyncEnumerable<T> _innerAsyncEnumerable;
private readonly Type _contextType;
private readonly ILogger _logger;
private readonly QueryContext _queryContext;

public ExceptionInterceptor(
IAsyncEnumerable<T> innerAsyncEnumerable, Type contextType, ILogger logger)
IAsyncEnumerable<T> innerAsyncEnumerable, Type contextType, ILogger logger, QueryContext queryContext)
{
_innerAsyncEnumerable = innerAsyncEnumerable;
_contextType = contextType;
_logger = logger;
_queryContext = queryContext;
}

public IAsyncEnumerator<T> GetEnumerator() => new EnumeratorExceptionInterceptor(this);
Expand All @@ -125,6 +127,7 @@ public async Task<bool> MoveNext(CancellationToken cancellationToken)
{
try
{
_exceptionInterceptor._queryContext.ConcurrencyDetector.EnterCriticalSection();
return await _innerEnumerator.MoveNext(cancellationToken);
}
catch (Exception exception)
Expand All @@ -138,6 +141,10 @@ public async Task<bool> MoveNext(CancellationToken cancellationToken)

throw;
}
finally
{
_exceptionInterceptor._queryContext.ConcurrencyDetector.ExitCriticalSection();
}
}

public void Dispose() => _innerEnumerator.Dispose();
Expand Down
13 changes: 10 additions & 3 deletions src/EntityFramework.Core/Query/Internal/LinqOperatorProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ private static readonly MethodInfo _interceptExceptions

[UsedImplicitly]
internal static IEnumerable<T> _InterceptExceptions<T>(
IEnumerable<T> source, Type contextType, ILogger logger)
=> new ExceptionInterceptor<T>(source, contextType, logger);
IEnumerable<T> source, Type contextType, ILogger logger, QueryContext queryContext)
=> new ExceptionInterceptor<T>(source, contextType, logger, queryContext);

public virtual MethodInfo InterceptExceptions => _interceptExceptions;

Expand All @@ -75,13 +75,15 @@ private sealed class ExceptionInterceptor<T> : IEnumerable<T>
private readonly IEnumerable<T> _innerEnumerable;
private readonly Type _contextType;
private readonly ILogger _logger;
private readonly QueryContext _queryContext;

public ExceptionInterceptor(
IEnumerable<T> innerEnumerable, Type contextType, ILogger logger)
IEnumerable<T> innerEnumerable, Type contextType, ILogger logger, QueryContext queryContext)
{
_innerEnumerable = innerEnumerable;
_contextType = contextType;
_logger = logger;
_queryContext = queryContext;
}

public IEnumerator<T> GetEnumerator() => new EnumeratorExceptionInterceptor(this);
Expand All @@ -108,6 +110,7 @@ public bool MoveNext()
{
try
{
_exceptionInterceptor._queryContext.ConcurrencyDetector.EnterCriticalSection();
return _innerEnumerator.MoveNext();
}
catch (Exception exception)
Expand All @@ -121,6 +124,10 @@ public bool MoveNext()

throw;
}
finally
{
_exceptionInterceptor._queryContext.ConcurrencyDetector.ExitCriticalSection();
}
}

public void Reset() => _innerEnumerator?.Reset();
Expand Down
7 changes: 6 additions & 1 deletion src/EntityFramework.Core/Query/QueryContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,26 @@ public class QueryContext

public QueryContext(
[NotNull] Func<IQueryBuffer> queryBufferFactory,
[NotNull] IStateManager stateManager)
[NotNull] IStateManager stateManager,
[NotNull] IConcurrencyDetector concurrencyDetector)
{
Check.NotNull(queryBufferFactory, nameof(queryBufferFactory));
Check.NotNull(stateManager, nameof(stateManager));
Check.NotNull(concurrencyDetector, nameof(concurrencyDetector));

_queryBufferFactory = queryBufferFactory;

StateManager = stateManager;
ConcurrencyDetector = concurrencyDetector;
}

public virtual IQueryBuffer QueryBuffer
=> _queryBuffer ?? (_queryBuffer = _queryBufferFactory());

public virtual IStateManager StateManager { get; }

public virtual IConcurrencyDetector ConcurrencyDetector { get; }

public virtual CancellationToken CancellationToken { get; set; }

public virtual IReadOnlyDictionary<string, object> ParameterValues
Expand Down
5 changes: 5 additions & 0 deletions src/EntityFramework.Core/Query/QueryContextFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@ public abstract class QueryContextFactory : IQueryContextFactory

protected QueryContextFactory(
[NotNull] IStateManager stateManager,
[NotNull] IConcurrencyDetector concurrencyDetector,
[NotNull] IKeyValueFactorySource keyValueFactorySource)
{
Check.NotNull(stateManager, nameof(stateManager));
Check.NotNull(concurrencyDetector, nameof(concurrencyDetector));
Check.NotNull(keyValueFactorySource, nameof(keyValueFactorySource));

StateManager = stateManager;
ConcurrencyDetector = concurrencyDetector;
_keyValueFactorySource = keyValueFactorySource;
}

Expand All @@ -28,6 +31,8 @@ protected virtual IQueryBuffer CreateQueryBuffer()

protected virtual IStateManager StateManager { get; }

protected virtual IConcurrencyDetector ConcurrencyDetector { get; }

public abstract QueryContext Create();
}
}
4 changes: 2 additions & 2 deletions src/EntityFramework.Core/Utilities/Imply.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,15 @@ public void ImplyMethods()

EF.Property<T>(null, null);

LinqOperatorProvider._InterceptExceptions<T>(null, null, null);
LinqOperatorProvider._InterceptExceptions<T>(null, null, null, null);
LinqOperatorProvider._ToEnumerable<T>(null);
LinqOperatorProvider._ToOrdered<T>(null);
LinqOperatorProvider._ToSequence((T)new object());
LinqOperatorProvider._ToQueryable<T>(null);
LinqOperatorProvider._TrackEntities<T, object>(null, null, null, null);
LinqOperatorProvider._Where<T>(null, null);

AsyncLinqOperatorProvider._InterceptExceptions<T>(null, null, null);
AsyncLinqOperatorProvider._InterceptExceptions<T>(null, null, null, null);
AsyncLinqOperatorProvider._ToEnumerable<T>(null);
AsyncLinqOperatorProvider._ToOrdered<T>(null);
AsyncLinqOperatorProvider._ToSequence((T)new object());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@ public class InMemoryQueryContext : QueryContext
public InMemoryQueryContext(
[NotNull] Func<IQueryBuffer> queryBufferFactory,
[NotNull] IInMemoryStore store,
[NotNull] IStateManager stateManager)
[NotNull] IStateManager stateManager,
[NotNull] IConcurrencyDetector concurrencyDetector)
: base(
Check.NotNull(queryBufferFactory, nameof(queryBufferFactory)),
Check.NotNull(stateManager, nameof(stateManager)))
Check.NotNull(stateManager, nameof(stateManager)),
Check.NotNull(concurrencyDetector, nameof(concurrencyDetector)))
{
Store = store;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,17 @@ public class InMemoryQueryContextFactory : QueryContextFactory

public InMemoryQueryContextFactory(
[NotNull] IStateManager stateManager,
[NotNull] IConcurrencyDetector concurrencyDetector,
[NotNull] IKeyValueFactorySource keyValueFactorySource,
[NotNull] IInMemoryDatabase database)
: base(stateManager, keyValueFactorySource)
: base(stateManager, concurrencyDetector, keyValueFactorySource)
{
Check.NotNull(database, nameof(database));

_database = database;
}

public override QueryContext Create()
=> new InMemoryQueryContext(CreateQueryBuffer, _database.Store, StateManager);
=> new InMemoryQueryContext(CreateQueryBuffer, _database.Store, StateManager, ConcurrencyDetector);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@ public class RelationalQueryContextFactory : QueryContextFactory

public RelationalQueryContextFactory(
[NotNull] IStateManager stateManager,
[NotNull] IConcurrencyDetector concurrencyDetector,
[NotNull] IKeyValueFactorySource keyValueFactorySource,
[NotNull] IRelationalConnection connection)
: base(stateManager, keyValueFactorySource)
: base(stateManager, concurrencyDetector, keyValueFactorySource)
{
_connection = connection;
}

public override QueryContext Create()
=> new RelationalQueryContext(CreateQueryBuffer, _connection, StateManager);
=> new RelationalQueryContext(CreateQueryBuffer, _connection, StateManager, ConcurrencyDetector);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@ public class RelationalQueryContext : QueryContext
public RelationalQueryContext(
[NotNull] Func<IQueryBuffer> queryBufferFactory,
[NotNull] IRelationalConnection connection,
[NotNull] IStateManager stateManager)
[NotNull] IStateManager stateManager,
[NotNull] IConcurrencyDetector concurrencyDetector)
: base(
Check.NotNull(queryBufferFactory, nameof(queryBufferFactory)),
Check.NotNull(stateManager, nameof(stateManager)))
Check.NotNull(stateManager, nameof(stateManager)),
Check.NotNull(concurrencyDetector, nameof(concurrencyDetector)))
{
Check.NotNull(connection, nameof(connection));

Expand Down
Loading