From 6624bc3d7dd2c27e82966748f335c90f402d0c56 Mon Sep 17 00:00:00 2001 From: Sebastien Ros Date: Thu, 2 May 2024 12:52:23 -0700 Subject: [PATCH] Add thread-safety checks option --- src/YesSql.Abstractions/IConfiguration.cs | 13 ++++ src/YesSql.Core/Configuration.cs | 2 + src/YesSql.Core/Services/DefaultQuery.cs | 19 ++++++ src/YesSql.Core/Session.cs | 72 +++++++++++++++++++++-- test/YesSql.Tests/CoreTests.cs | 46 +++++++++++++++ 5 files changed, 146 insertions(+), 6 deletions(-) diff --git a/src/YesSql.Abstractions/IConfiguration.cs b/src/YesSql.Abstractions/IConfiguration.cs index 423be7a4..52b782bb 100644 --- a/src/YesSql.Abstractions/IConfiguration.cs +++ b/src/YesSql.Abstractions/IConfiguration.cs @@ -66,6 +66,19 @@ public interface IConfiguration /// bool QueryGatingEnabled { get; set; } + /// + /// Gets or sets whether the thread-safety checks are enabled. + /// + /// + /// When enabled, YesSql will throw an if two threads are trying to execute read or write + /// operations on the database concurrently. This can help investigating thread-safety issue where an + /// instance is shared which is not supported. + /// + /// + /// The default value is . + /// + public bool EnableThreadSafetyChecks { get; set; } + /// /// Gets the collection of types that must be checked for concurrency. /// diff --git a/src/YesSql.Core/Configuration.cs b/src/YesSql.Core/Configuration.cs index 4bac7823..4fed7c32 100644 --- a/src/YesSql.Core/Configuration.cs +++ b/src/YesSql.Core/Configuration.cs @@ -21,6 +21,7 @@ public Configuration() TablePrefix = ""; CommandsPageSize = 500; QueryGatingEnabled = true; + EnableThreadSafetyChecks = false; Logger = NullLogger.Instance; ConcurrentTypes = new HashSet(); TableNameConvention = new DefaultTableNameConvention(); @@ -35,6 +36,7 @@ public Configuration() public string Schema { get; set; } public int CommandsPageSize { get; set; } public bool QueryGatingEnabled { get; set; } + public bool EnableThreadSafetyChecks { get; set; } public IIdGenerator IdGenerator { get; set; } public ILogger Logger { get; set; } public HashSet ConcurrentTypes { get; } diff --git a/src/YesSql.Core/Services/DefaultQuery.cs b/src/YesSql.Core/Services/DefaultQuery.cs index 66de7368..4cbc08d2 100644 --- a/src/YesSql.Core/Services/DefaultQuery.cs +++ b/src/YesSql.Core/Services/DefaultQuery.cs @@ -1106,6 +1106,8 @@ public async Task CountAsync() var parameters = localBuilder.Parameters; var key = new WorkerQueryKey(sql, localBuilder.Parameters); + _session.EnterAsyncExecution(); + try { return await _session._store.ProduceAsync(key, static (state) => @@ -1127,6 +1129,10 @@ public async Task CountAsync() throw; } + finally + { + _session.ExitAsyncExecution(); + } } IQuery IQuery.For(bool filterType) @@ -1206,6 +1212,8 @@ protected async Task FirstOrDefaultImpl() _query.Page(1, 0); + _query._session.EnterAsyncExecution(); + try { if (typeof(IIndex).IsAssignableFrom(typeof(T))) @@ -1259,6 +1267,10 @@ protected async Task FirstOrDefaultImpl() await _query._session.CancelAsync(); throw; } + finally + { + _query._session.ExitAsyncExecution(); + } } Task> IQuery.ListAsync() @@ -1297,6 +1309,8 @@ internal async Task> ListImpl() } } + _query._session.EnterAsyncExecution(); + try { if (typeof(IIndex).IsAssignableFrom(typeof(T))) @@ -1311,6 +1325,7 @@ internal async Task> ListImpl() var sql = sqlBuilder.ToSqlString(); var key = new WorkerQueryKey(sql, _query._queryState._sqlBuilder.Parameters); + return await _query._session._store.ProduceAsync(key, static (state) => { var logger = state.Query._session._store.Configuration.Logger; @@ -1365,6 +1380,10 @@ internal async Task> ListImpl() throw; } + finally + { + _query._session.ExitAsyncExecution(); + } } private string GetDeduplicatedQuery() diff --git a/src/YesSql.Core/Session.cs b/src/YesSql.Core/Session.cs index 49b0faa5..005140b2 100644 --- a/src/YesSql.Core/Session.cs +++ b/src/YesSql.Core/Session.cs @@ -6,6 +6,7 @@ using System.Data.Common; using System.Linq; using System.Linq.Expressions; +using System.Threading; using System.Threading.Tasks; using YesSql.Commands; using YesSql.Data; @@ -35,6 +36,10 @@ public class Session : ISession private readonly ILogger _logger; private readonly bool _withTracking; + private readonly bool _enableThreadSafetyChecks; + private int _asyncOperations = 0; + private string _previousStackTrace = null; + public Session(Store store, bool withTracking = true) { _store = store; @@ -43,6 +48,7 @@ public Session(Store store, bool withTracking = true) _logger = store.Configuration.Logger; _withTracking = withTracking; _defaultState = new SessionState(); + _enableThreadSafetyChecks = _store.Configuration.EnableThreadSafetyChecks; _collectionStates = new Dictionary() { [string.Empty] = _defaultState @@ -500,7 +506,7 @@ public async Task> GetAsync(long[] ids, string collection = nu var key = new WorkerQueryKey(nameof(GetAsync), ids); try { - var documents = await _store.ProduceAsync(key, (state) => + var documents = await _store.ProduceAsync(key, static (state) => { var logger = state.Store.Configuration.Logger; @@ -661,7 +667,12 @@ public void Dispose() GC.SuppressFinalize(this); } - public async Task FlushAsync() + public Task FlushAsync() + { + return FlushInternalAsync(false); + } + + private async Task FlushInternalAsync(bool saving) { if (!HasWork()) { @@ -684,6 +695,12 @@ public async Task FlushAsync() CheckDisposed(); + // Only check thread-safety if not called from SaveChangesAsync + if (!saving) + { + EnterAsyncExecution(); + } + try { // saving all tracked entities @@ -767,6 +784,12 @@ public async Task FlushAsync() _commands?.Clear(); _flushing = false; + + // Only check thread-safety if not called from SaveChangesAsync + if (!saving) + { + ExitAsyncExecution(); + } } } @@ -856,13 +879,40 @@ private void BatchCommands() _commands.AddRange(batches); } + public void EnterAsyncExecution() + { + if (!_enableThreadSafetyChecks) + { + return; + } + + if (Interlocked.Increment(ref _asyncOperations) > 1) + { + throw new InvalidOperationException($"Two concurrent threads have been detected accessing the same ISession instance from: \n{Environment.StackTrace}\nand:\n{_previousStackTrace}\n---"); + } + + _previousStackTrace = Environment.StackTrace; + } + + public void ExitAsyncExecution() + { + if (!_enableThreadSafetyChecks) + { + return; + } + + Interlocked.Decrement(ref _asyncOperations); + } + public async Task SaveChangesAsync() { + EnterAsyncExecution(); + try { if (!_cancel) { - await FlushAsync(); + await FlushInternalAsync(true); _save = true; } @@ -870,6 +920,7 @@ public async Task SaveChangesAsync() finally { await CommitOrRollbackTransactionAsync(); + ExitAsyncExecution(); } } @@ -1362,11 +1413,20 @@ public async Task BeginTransactionAsync(IsolationLevel isolationL public Task CancelAsync() { - CheckDisposed(); + EnterAsyncExecution(); + + try + { + CheckDisposed(); - _cancel = true; + _cancel = true; - return ReleaseTransactionAsync(); + return ReleaseTransactionAsync(); + } + finally + { + ExitAsyncExecution(); + } } public IStore Store => _store; diff --git a/test/YesSql.Tests/CoreTests.cs b/test/YesSql.Tests/CoreTests.cs index 9da1eaa2..7c5156d4 100644 --- a/test/YesSql.Tests/CoreTests.cs +++ b/test/YesSql.Tests/CoreTests.cs @@ -5,6 +5,7 @@ using System.Data.Common; using System.Diagnostics; using System.Linq; +using System.Security.Cryptography; using System.Threading; using System.Threading.Tasks; using Xunit; @@ -6401,6 +6402,51 @@ await session.SaveAsync(new Article Assert.Equal(10, result); } + [Fact] + public virtual async Task ShouldDetectThreadSafetyIssues() + { + try + { + _store.Configuration.EnableThreadSafetyChecks = true; + + await using var session = _store.CreateSession(); + + _store.Configuration.EnableThreadSafetyChecks = false; + + var person = new Person { Firstname = "Bill" }; + await session.SaveAsync(person); + await session.SaveChangesAsync(); + + Task[] tasks = null; + + var throws = Assert.ThrowsAsync(async () => + { + tasks = Enumerable.Range(0, 10).Select(x => Task.Run(DoWork)).ToArray(); + await Task.WhenAll(tasks); + }); + + var result = Task.WaitAny(throws, Task.Delay(5000)); + + Assert.Equal(0, result); + + async Task DoWork() + { + while (true) + { + var p = await session.Query().FirstOrDefaultAsync(); + Assert.NotNull(p); + + person.Firstname = "Bill" + RandomNumberGenerator.GetInt32(100); + await session.FlushAsync(); + } + } + } + finally + { + _store.Configuration.EnableThreadSafetyChecks = false; + } + } + #region FilterTests [Fact]