Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add thread-safety checks option #540

Merged
merged 1 commit into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/YesSql.Abstractions/IConfiguration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,19 @@ public interface IConfiguration
/// </summary>
bool QueryGatingEnabled { get; set; }

/// <summary>
/// Gets or sets whether the thread-safety checks are enabled.
/// </summary>
/// <remarks>
/// When enabled, YesSql will throw an <see cref="InvalidOperationException" /> if two threads are trying to execute read or write
/// operations on the database concurrently. This can help investigating thread-safety issue where an <see cref="ISession"/>
/// instance is shared which is not supported.
/// </remarks>
/// <value>
/// The default value is <see langword="false"/>.
/// </value>
public bool EnableThreadSafetyChecks { get; set; }

/// <summary>
/// Gets the collection of types that must be checked for concurrency.
/// </summary>
Expand Down
2 changes: 2 additions & 0 deletions src/YesSql.Core/Configuration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public Configuration()
TablePrefix = "";
CommandsPageSize = 500;
QueryGatingEnabled = true;
EnableThreadSafetyChecks = false;
Logger = NullLogger.Instance;
ConcurrentTypes = new HashSet<Type>();
TableNameConvention = new DefaultTableNameConvention();
Expand All @@ -35,6 +36,7 @@ public Configuration()
public string Schema { get; set; }
public int CommandsPageSize { get; set; }
public bool QueryGatingEnabled { get; set; }
public bool EnableThreadSafetyChecks { get; set; }
public IIdGenerator IdGenerator { get; set; }
public ILogger Logger { get; set; }
public HashSet<Type> ConcurrentTypes { get; }
Expand Down
19 changes: 19 additions & 0 deletions src/YesSql.Core/Services/DefaultQuery.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1106,6 +1106,8 @@ public async Task<int> CountAsync()
var parameters = localBuilder.Parameters;
var key = new WorkerQueryKey(sql, localBuilder.Parameters);

_session.EnterAsyncExecution();

try
{
return await _session._store.ProduceAsync(key, static (state) =>
Expand All @@ -1127,6 +1129,10 @@ public async Task<int> CountAsync()

throw;
}
finally
{
_session.ExitAsyncExecution();
}
}

IQuery<T> IQuery.For<T>(bool filterType)
Expand Down Expand Up @@ -1206,6 +1212,8 @@ protected async Task<T> FirstOrDefaultImpl()

_query.Page(1, 0);

_query._session.EnterAsyncExecution();

try
{
if (typeof(IIndex).IsAssignableFrom(typeof(T)))
Expand Down Expand Up @@ -1259,6 +1267,10 @@ protected async Task<T> FirstOrDefaultImpl()
await _query._session.CancelAsync();
throw;
}
finally
{
_query._session.ExitAsyncExecution();
}
}

Task<IEnumerable<T>> IQuery<T>.ListAsync()
Expand Down Expand Up @@ -1297,6 +1309,8 @@ internal async Task<IEnumerable<T>> ListImpl()
}
}

_query._session.EnterAsyncExecution();

try
{
if (typeof(IIndex).IsAssignableFrom(typeof(T)))
Expand All @@ -1311,6 +1325,7 @@ internal async Task<IEnumerable<T>> ListImpl()

var sql = sqlBuilder.ToSqlString();
var key = new WorkerQueryKey(sql, _query._queryState._sqlBuilder.Parameters);

return await _query._session._store.ProduceAsync(key, static (state) =>
{
var logger = state.Query._session._store.Configuration.Logger;
Expand Down Expand Up @@ -1365,6 +1380,10 @@ internal async Task<IEnumerable<T>> ListImpl()

throw;
}
finally
{
_query._session.ExitAsyncExecution();
}
}

private string GetDeduplicatedQuery()
Expand Down
72 changes: 66 additions & 6 deletions src/YesSql.Core/Session.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Data.Common;
using System.Linq;
using System.Linq.Expressions;
using System.Threading;
using System.Threading.Tasks;
using YesSql.Commands;
using YesSql.Data;
Expand Down Expand Up @@ -35,6 +36,10 @@ public class Session : ISession
private readonly ILogger _logger;
private readonly bool _withTracking;

private readonly bool _enableThreadSafetyChecks;
private int _asyncOperations = 0;
private string _previousStackTrace = null;

public Session(Store store, bool withTracking = true)
{
_store = store;
Expand All @@ -43,6 +48,7 @@ public Session(Store store, bool withTracking = true)
_logger = store.Configuration.Logger;
_withTracking = withTracking;
_defaultState = new SessionState();
_enableThreadSafetyChecks = _store.Configuration.EnableThreadSafetyChecks;
_collectionStates = new Dictionary<string, SessionState>()
{
[string.Empty] = _defaultState
Expand Down Expand Up @@ -500,7 +506,7 @@ public async Task<IEnumerable<T>> GetAsync<T>(long[] ids, string collection = nu
var key = new WorkerQueryKey(nameof(GetAsync), ids);
try
{
var documents = await _store.ProduceAsync(key, (state) =>
var documents = await _store.ProduceAsync(key, static (state) =>
{
var logger = state.Store.Configuration.Logger;

Expand Down Expand Up @@ -661,7 +667,12 @@ public void Dispose()
GC.SuppressFinalize(this);
}

public async Task FlushAsync()
public Task FlushAsync()
{
return FlushInternalAsync(false);
}

private async Task FlushInternalAsync(bool saving)
{
if (!HasWork())
{
Expand All @@ -684,6 +695,12 @@ public async Task FlushAsync()

CheckDisposed();

// Only check thread-safety if not called from SaveChangesAsync
if (!saving)
{
EnterAsyncExecution();
}

try
{
// saving all tracked entities
Expand Down Expand Up @@ -767,6 +784,12 @@ public async Task FlushAsync()

_commands?.Clear();
_flushing = false;

// Only check thread-safety if not called from SaveChangesAsync
if (!saving)
{
ExitAsyncExecution();
}
}
}

Expand Down Expand Up @@ -856,20 +879,48 @@ private void BatchCommands()
_commands.AddRange(batches);
}

public void EnterAsyncExecution()
{
if (!_enableThreadSafetyChecks)
{
return;
}

if (Interlocked.Increment(ref _asyncOperations) > 1)
{
throw new InvalidOperationException($"Two concurrent threads have been detected accessing the same ISession instance from: \n{Environment.StackTrace}\nand:\n{_previousStackTrace}\n---");
}

_previousStackTrace = Environment.StackTrace;
}

public void ExitAsyncExecution()
{
if (!_enableThreadSafetyChecks)
{
return;
}

Interlocked.Decrement(ref _asyncOperations);
}

public async Task SaveChangesAsync()
{
EnterAsyncExecution();

try
{
if (!_cancel)
{
await FlushAsync();
await FlushInternalAsync(true);

_save = true;
}
}
finally
{
await CommitOrRollbackTransactionAsync();
ExitAsyncExecution();
}
}

Expand Down Expand Up @@ -1362,11 +1413,20 @@ public async Task<DbTransaction> BeginTransactionAsync(IsolationLevel isolationL

public Task CancelAsync()
{
CheckDisposed();
EnterAsyncExecution();

try
{
CheckDisposed();

_cancel = true;
_cancel = true;

return ReleaseTransactionAsync();
return ReleaseTransactionAsync();
}
finally
{
ExitAsyncExecution();
}
}

public IStore Store => _store;
Expand Down
46 changes: 46 additions & 0 deletions test/YesSql.Tests/CoreTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Data.Common;
using System.Diagnostics;
using System.Linq;
using System.Security.Cryptography;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
Expand Down Expand Up @@ -6401,6 +6402,51 @@
Assert.Equal(10, result);
}

[Fact]
public virtual async Task ShouldDetectThreadSafetyIssues()
{
try
{
_store.Configuration.EnableThreadSafetyChecks = true;

await using var session = _store.CreateSession();

_store.Configuration.EnableThreadSafetyChecks = false;

var person = new Person { Firstname = "Bill" };
await session.SaveAsync(person);
await session.SaveChangesAsync();

Task[] tasks = null;

var throws = Assert.ThrowsAsync<InvalidOperationException>(async () =>
{
tasks = Enumerable.Range(0, 10).Select(x => Task.Run(DoWork)).ToArray();
await Task.WhenAll(tasks);
});

var result = Task.WaitAny(throws, Task.Delay(5000));

Check warning on line 6428 in test/YesSql.Tests/CoreTests.cs

View workflow job for this annotation

GitHub Actions / build

Test methods should not use blocking task operations, as they can cause deadlocks. Use an async test method and await instead. (https://xunit.net/xunit.analyzers/rules/xUnit1031)

Check warning on line 6428 in test/YesSql.Tests/CoreTests.cs

View workflow job for this annotation

GitHub Actions / build

Test methods should not use blocking task operations, as they can cause deadlocks. Use an async test method and await instead. (https://xunit.net/xunit.analyzers/rules/xUnit1031)

Check warning on line 6428 in test/YesSql.Tests/CoreTests.cs

View workflow job for this annotation

GitHub Actions / build

Test methods should not use blocking task operations, as they can cause deadlocks. Use an async test method and await instead. (https://xunit.net/xunit.analyzers/rules/xUnit1031)

Check warning on line 6428 in test/YesSql.Tests/CoreTests.cs

View workflow job for this annotation

GitHub Actions / build

Test methods should not use blocking task operations, as they can cause deadlocks. Use an async test method and await instead. (https://xunit.net/xunit.analyzers/rules/xUnit1031)

Check warning on line 6428 in test/YesSql.Tests/CoreTests.cs

View workflow job for this annotation

GitHub Actions / build

Test methods should not use blocking task operations, as they can cause deadlocks. Use an async test method and await instead. (https://xunit.net/xunit.analyzers/rules/xUnit1031)

Check warning on line 6428 in test/YesSql.Tests/CoreTests.cs

View workflow job for this annotation

GitHub Actions / build

Test methods should not use blocking task operations, as they can cause deadlocks. Use an async test method and await instead. (https://xunit.net/xunit.analyzers/rules/xUnit1031)

Assert.Equal(0, result);

async Task DoWork()
{
while (true)
{
var p = await session.Query<Person>().FirstOrDefaultAsync();
Assert.NotNull(p);

person.Firstname = "Bill" + RandomNumberGenerator.GetInt32(100);
await session.FlushAsync();
}
}
}
finally
{
_store.Configuration.EnableThreadSafetyChecks = false;
}
}

#region FilterTests

[Fact]
Expand Down
Loading