Skip to content

Commit

Permalink
Nested database context scope.
Browse files Browse the repository at this point in the history
  • Loading branch information
eben-roux committed Aug 5, 2024
1 parent 79af1a9 commit 85cac96
Show file tree
Hide file tree
Showing 15 changed files with 82 additions and 75 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ The default JSON settings structure is as follows:

# DatabaseContextScope

The `DatabaseContextService` contains a collection of the `DatabaseContext` instances created by `IDatabaseContextFactory`. However, these will not be flowed across async contexts.
The `DatabaseContextService` contains a collection of the `DatabaseContext` instances created by `IDatabaseContextFactory`. However, since the `DatabaseContextService` is a singleton the same collection will be used in all thread contexts. This includes not only the same execution context, but also "peered" execution context running in parallel.

To enable async context flow wrap the initial database context creation in a `using` statement:
To enable an individual execution/thread context-specific collection which also enables async context flow wrap the initial database context creation in a new `DatabaseContextScope()`:

``` c#
using (new DatabaseContextScope())
Expand Down
62 changes: 49 additions & 13 deletions Shuttle.Core.Data.Tests/AsyncFixture.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using System.Transactions;
using NUnit.Framework;

namespace Shuttle.Core.Data.Tests;

public class AsyncFixture : MappingFixture
{
private readonly Query _rowsQuery = new(@"
private readonly Query _rowsQuery = new Query(@"
select
Id,
Name,
Expand Down Expand Up @@ -53,26 +54,61 @@ public void Should_be_able_to_use_the_different_database_context_for_separate_th
}

[Test]
public void Should_be_able_to_use_the_same_connection_name_for_separate_tasks_and_have_them_block_async()
public void Should_not_be_able_to_create_duplicate_database_contexts_async()
{
var tasks = new List<Task>();
using (new DatabaseContextScope())
using (DatabaseContextFactory.Create())
{
Assert.That(()=> DatabaseContextFactory.Create(), Throws.InvalidOperationException);
}
}

[Test]
public async Task Should_be_able_to_create_nested_database_context_scopes_async()
{
await using (DatabaseContextFactory.Create())
{
await DatabaseGateway.ExecuteAsync(new Query("DROP TABLE IF EXISTS dbo.Nested"));
await DatabaseGateway.ExecuteAsync(new Query("CREATE TABLE dbo.Nested (Id int)"));
}

using (new DatabaseContextScope())
await using (DatabaseContextFactory.Create())
{
for (var i = 0; i < 10; i++)
var count = await DatabaseGateway.GetScalarAsync<int>(new Query("select count(*) from dbo.Nested"));

Assert.That(count, Is.Zero);

await DatabaseGateway.ExecuteAsync(new Query("INSERT INTO dbo.Nested (Id) VALUES (1)"));

count = await DatabaseGateway.GetScalarAsync<int>(new Query("select count(*) from dbo.Nested"));

Assert.That(count, Is.EqualTo(1));

using (new TransactionScope(TransactionScopeOption.Required, TransactionScopeAsyncFlowOption.Enabled))
using (new DatabaseContextScope())
await using (DatabaseContextFactory.Create())
{
tasks.Add(Task.Run(() =>
{
using (DatabaseContextFactory.Create())
{
Console.WriteLine($"{DateTime.Now:O}");
DatabaseGateway.GetRowsAsync(_rowsQuery);
}
}));
await DatabaseGateway.ExecuteAsync(new Query("INSERT INTO dbo.Nested (Id) VALUES (2)"));

count = await DatabaseGateway.GetScalarAsync<int>(new Query("select count(*) from dbo.Nested"));

Assert.That(count, Is.EqualTo(2));
}

count = await DatabaseGateway.GetScalarAsync<int>(new Query("select count(*) from dbo.Nested"));

Assert.That(count, Is.EqualTo(1));
}

Task.WaitAll(tasks.ToArray());
await using (DatabaseContextFactory.Create())
{
var count = await DatabaseGateway.GetScalarAsync<int>(new Query("select count(*) from dbo.Nested"));

Assert.That(count, Is.EqualTo(1));

await DatabaseGateway.ExecuteAsync(new Query("DROP TABLE IF EXISTS dbo.Nested"));
}
}

[Test]
Expand Down
14 changes: 2 additions & 12 deletions Shuttle.Core.Data.Tests/DatabaseContextFactoryFixture.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,6 @@ public void Should_be_able_to_create_a_database_context()
}
}

[Test]
public void Should_be_able_to_time_out_when_creating_another_context_with_the_same_name_as_an_existing_context_tha_does_not_complete()
{
var factory = DatabaseContextFactory;

using (factory.Create(DefaultConnectionStringName))
{
Assert.That(() => factory.Create(DefaultConnectionStringName, TimeSpan.FromMilliseconds(20)), Throws.TypeOf<TimeoutException>());
}
}

[Test]
public void Should_be_able_to_check_connection_availability()
Expand All @@ -37,8 +27,8 @@ public void Should_be_able_to_check_connection_availability()
Assert.That(databaseContextFactory.Object.IsAvailable(new CancellationToken(), 0, 0), Is.True);
Assert.That(databaseContextFactory.Object.IsAvailable("name", new CancellationToken(), 0, 0), Is.True);

databaseContextFactory.Setup(m => m.Create(null)).Throws(new Exception());
databaseContextFactory.Setup(m => m.Create(It.IsAny<string>(), null)).Throws(new Exception());
databaseContextFactory.Setup(m => m.Create()).Throws(new Exception());
databaseContextFactory.Setup(m => m.Create(It.IsAny<string>())).Throws(new Exception());

Assert.That(databaseContextFactory.Object.IsAvailable(new CancellationToken(), 0, 0), Is.False);
Assert.That(databaseContextFactory.Object.IsAvailable("name", new CancellationToken(), 0, 0), Is.False);
Expand Down
15 changes: 7 additions & 8 deletions Shuttle.Core.Data.Tests/DatabaseContextFixture.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using System;
using System.Data.Common;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Data.SqlClient;
using Moq;
Expand All @@ -17,7 +16,7 @@ public void Should_not_be_able_to_create_an_invalid_connection()
{
Assert.Throws<ArgumentException>(() =>
{
using (new DatabaseContext("context", "Microsoft.Data.SqlClient", new SqlConnection("```"), new Mock<IDbCommandFactory>().Object, DatabaseContextService, new SemaphoreSlim(1,1)))
using (new DatabaseContext("context", "Microsoft.Data.SqlClient", new SqlConnection("```"), new Mock<IDbCommandFactory>().Object, DatabaseContextService))
{
}
});
Expand All @@ -28,7 +27,7 @@ public void Should_not_be_able_to_create_a_non_existent_connection()
{
Assert.Throws<SqlException>(() =>
{
using (var databaseContext = new DatabaseContext("context", "Microsoft.Data.SqlClient", new SqlConnection("data source=.;initial catalog=idontexist;integrated security=sspi"), new Mock<IDbCommandFactory>().Object, DatabaseContextService, new SemaphoreSlim(0, 1)))
using (var databaseContext = new DatabaseContext("context", "Microsoft.Data.SqlClient", new SqlConnection("data source=.;initial catalog=idontexist;integrated security=sspi"), new Mock<IDbCommandFactory>().Object, DatabaseContextService))
{
databaseContext.CreateCommand(new Query("select 1"));
}
Expand All @@ -49,7 +48,7 @@ public async Task Should_be_able_to_begin_and_commit_a_transaction_async()

private async Task Should_be_able_to_begin_and_commit_a_transaction_async(bool sync)
{
await using (var databaseContext = new DatabaseContext("context", "Microsoft.Data.SqlClient", DbConnectionFactory.Create(DefaultProviderName, DefaultConnectionString), new Mock<IDbCommandFactory>().Object, DatabaseContextService, new SemaphoreSlim(0, 1)))
await using (var databaseContext = new DatabaseContext("context", "Microsoft.Data.SqlClient", DbConnectionFactory.Create(DefaultProviderName, DefaultConnectionString), new Mock<IDbCommandFactory>().Object, DatabaseContextService))
{
if (sync)
{
Expand Down Expand Up @@ -78,7 +77,7 @@ public async Task Should_be_able_to_begin_and_rollback_a_transaction_async()

private async Task Should_be_able_to_begin_and_rollback_a_transaction_async(bool sync)
{
await using (var databaseContext = new DatabaseContext("context", "Microsoft.Data.SqlClient", DbConnectionFactory.Create(DefaultProviderName, DefaultConnectionString), new Mock<IDbCommandFactory>().Object, new DatabaseContextService(), new SemaphoreSlim(0, 1)))
await using (var databaseContext = new DatabaseContext("context", "Microsoft.Data.SqlClient", DbConnectionFactory.Create(DefaultProviderName, DefaultConnectionString), new Mock<IDbCommandFactory>().Object, new DatabaseContextService()))
{
if (sync)
{
Expand All @@ -105,7 +104,7 @@ public async Task Should_be_able_to_call_commit_without_a_transaction_async()

private async Task Should_be_able_to_call_commit_without_a_transaction_async(bool sync)
{
await using (var databaseContext = new DatabaseContext("context", "Microsoft.Data.SqlClient", DbConnectionFactory.Create(DefaultProviderName, DefaultConnectionString), new Mock<IDbCommandFactory>().Object, DatabaseContextService, new SemaphoreSlim(0, 1)))
await using (var databaseContext = new DatabaseContext("context", "Microsoft.Data.SqlClient", DbConnectionFactory.Create(DefaultProviderName, DefaultConnectionString), new Mock<IDbCommandFactory>().Object, DatabaseContextService))
{
if (sync)
{
Expand All @@ -121,7 +120,7 @@ private async Task Should_be_able_to_call_commit_without_a_transaction_async(boo
[Test]
public void Should_be_able_to_call_dispose_more_than_once()
{
using (var databaseContext = new DatabaseContext("context", "Microsoft.Data.SqlClient", DbConnectionFactory.Create(DefaultProviderName, DefaultConnectionString), new Mock<IDbCommandFactory>().Object, DatabaseContextService, new SemaphoreSlim(0, 1)))
using (var databaseContext = new DatabaseContext("context", "Microsoft.Data.SqlClient", DbConnectionFactory.Create(DefaultProviderName, DefaultConnectionString), new Mock<IDbCommandFactory>().Object, DatabaseContextService))
{
databaseContext.Dispose();
databaseContext.Dispose();
Expand All @@ -138,7 +137,7 @@ public void Should_be_able_to_create_a_command()

dbCommandFactory.Setup(m => m.Create(dbConnection, query.Object)).Returns(dbCommand.Object);

using (var databaseContext = new DatabaseContext("context", "Microsoft.Data.SqlClient", dbConnection, dbCommandFactory.Object, DatabaseContextService, new SemaphoreSlim(0, 1)))
using (var databaseContext = new DatabaseContext("context", "Microsoft.Data.SqlClient", dbConnection, dbCommandFactory.Object, DatabaseContextService))
{
databaseContext.CreateCommand(query.Object);
}
Expand Down
5 changes: 2 additions & 3 deletions Shuttle.Core.Data.Tests/DatabaseContextServiceFixture.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using System.Data.Common;
using System.Threading;
using Moq;
using NUnit.Framework;

Expand All @@ -12,11 +11,11 @@ public void Should_be_able_to_use_different_contexts()
{
var service = new DatabaseContextService();

var context1 = new DatabaseContext("mock-1", "provider-name", new Mock<DbConnection>().Object, new Mock<IDbCommandFactory>().Object, service, new SemaphoreSlim(1, 1));
var context1 = new DatabaseContext("mock-1", "provider-name", new Mock<DbConnection>().Object, new Mock<IDbCommandFactory>().Object, service);

Assert.That(service.Active.Name, Is.EqualTo(context1.Name));

var context2 = new DatabaseContext("mock-2", "provider-name", new Mock<DbConnection>().Object, new Mock<IDbCommandFactory>().Object, service, new SemaphoreSlim(1, 1));
var context2 = new DatabaseContext("mock-2", "provider-name", new Mock<DbConnection>().Object, new Mock<IDbCommandFactory>().Object, service);

Assert.That(service.Active.Name, Is.EqualTo(context2.Name));

Expand Down
5 changes: 2 additions & 3 deletions Shuttle.Core.Data/.package/package.nuspec
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
<package>
<metadata>
<id>Shuttle.Core.Data</id>
<version>16.0.2</version>
<version>17.0.0</version>
<authors>Eben Roux</authors>
<owners>Eben Roux</owners>
<license type="expression">BSD-3-Clause</license>
Expand All @@ -21,8 +21,7 @@
<dependency id="Microsoft.Extensions.Configuration.Binder" version="7.0.3" />
<dependency id="Microsoft.Extensions.DependencyInjection" version="7.0.0" />
<dependency id="Microsoft.Extensions.Options" version="7.0.1" />
<dependency id="Shuttle.Core.Contract" version="11.1.0" />
<dependency id="Shuttle.Core.Threading" version="13.0.0" />
<dependency id="Shuttle.Core.Contract" version="11.1.1" />
</dependencies>
</metadata>
<files>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,5 @@ namespace Shuttle.Core.Data
public class DatabaseContextFactoryOptions
{
public string DefaultConnectionStringName { get; set; }
public TimeSpan DefaultCreateTimeout { get; set; } = TimeSpan.FromSeconds(30);
}
}
5 changes: 1 addition & 4 deletions Shuttle.Core.Data/DatabaseContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,19 @@ namespace Shuttle.Core.Data
{
public class DatabaseContext : IDatabaseContext
{
private readonly SemaphoreSlim _lock;
private readonly IDatabaseContextService _databaseContextService;
private readonly IDbCommandFactory _dbCommandFactory;
private readonly IDbConnection _dbConnection;
private bool _disposed;

public DatabaseContext(string name, string providerName, IDbConnection dbConnection, IDbCommandFactory dbCommandFactory, IDatabaseContextService databaseContextService, SemaphoreSlim semaphoreSlim)
public DatabaseContext(string name, string providerName, IDbConnection dbConnection, IDbCommandFactory dbCommandFactory, IDatabaseContextService databaseContextService)
{
Name = Guard.AgainstNullOrEmptyString(name, nameof(name));
ProviderName = Guard.AgainstNullOrEmptyString(providerName, "providerName");

_dbCommandFactory = Guard.AgainstNull(dbCommandFactory, nameof(dbCommandFactory));
_databaseContextService = Guard.AgainstNull(databaseContextService, nameof(databaseContextService));
_dbConnection = Guard.AgainstNull(dbConnection, nameof(dbConnection));
_lock = Guard.AgainstNull(semaphoreSlim, nameof(semaphoreSlim));

_databaseContextService.Add(this);
}
Expand Down Expand Up @@ -105,7 +103,6 @@ public void Dispose()
}
finally
{
_lock.Release();
_disposed = true;
}

Expand Down
19 changes: 4 additions & 15 deletions Shuttle.Core.Data/DatabaseContextFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ public class DatabaseContextFactory : IDatabaseContextFactory

private readonly IDbConnectionFactory _dbConnectionFactory;
private readonly SemaphoreSlim _lock = new SemaphoreSlim(1, 1);
private readonly Dictionary<string, SemaphoreSlim> _semaphores = new Dictionary<string, SemaphoreSlim>();

public DatabaseContextFactory(IOptionsMonitor<ConnectionStringOptions> connectionStringOptions, IOptions<DataAccessOptions> dataAccessOptions, IDbConnectionFactory dbConnectionFactory, IDbCommandFactory dbCommandFactory, IDatabaseContextService databaseContextService)
{
Expand All @@ -32,22 +31,12 @@ public DatabaseContextFactory(IOptionsMonitor<ConnectionStringOptions> connectio

public event EventHandler<DatabaseContextEventArgs> DatabaseContextCreated;

public IDatabaseContext Create(string connectionStringName, TimeSpan? timeout = null)
public IDatabaseContext Create(string connectionStringName)
{
Guard.AgainstNullOrEmptyString(connectionStringName, nameof(connectionStringName));

_lock.Wait();

if (!_semaphores.ContainsKey(connectionStringName))
{
_semaphores.Add(connectionStringName, new SemaphoreSlim(1, 1));
}

if (!_semaphores[connectionStringName].Wait(timeout ?? _dataAccessOptions.DatabaseContextFactory.DefaultCreateTimeout))
{
throw new TimeoutException(string.Format(Resources.DatabaseContextFactoryTimeoutException, connectionStringName, timeout ?? _dataAccessOptions.DatabaseContextFactory.DefaultCreateTimeout));
}

try
{
var connectionStringOptions = _connectionStringOptions.Get(connectionStringName);
Expand All @@ -62,7 +51,7 @@ public IDatabaseContext Create(string connectionStringName, TimeSpan? timeout =
throw new InvalidOperationException(string.Format(Resources.DuplicateDatabaseContextException, connectionStringName));
}

var databaseContext = new DatabaseContext(connectionStringName, connectionStringOptions.ProviderName, (DbConnection)_dbConnectionFactory.Create(connectionStringOptions.ProviderName, connectionStringOptions.ConnectionString), _dbCommandFactory, _databaseContextService, _semaphores[connectionStringName]);
var databaseContext = new DatabaseContext(connectionStringName, connectionStringOptions.ProviderName, (DbConnection)_dbConnectionFactory.Create(connectionStringOptions.ProviderName, connectionStringOptions.ConnectionString), _dbCommandFactory, _databaseContextService);

DatabaseContextCreated?.Invoke(this, new DatabaseContextEventArgs(databaseContext));

Expand All @@ -74,14 +63,14 @@ public IDatabaseContext Create(string connectionStringName, TimeSpan? timeout =
}
}

public IDatabaseContext Create(TimeSpan? timeout = null)
public IDatabaseContext Create()
{
if (string.IsNullOrEmpty(_dataAccessOptions.DatabaseContextFactory.DefaultConnectionStringName))
{
throw new InvalidOperationException(Resources.DatabaseContextFactoryOptionsException);
}

return Create(_dataAccessOptions.DatabaseContextFactory.DefaultConnectionStringName, timeout);
return Create(_dataAccessOptions.DatabaseContextFactory.DefaultConnectionStringName);
}
}
}
10 changes: 7 additions & 3 deletions Shuttle.Core.Data/DatabaseContextScope.cs
Original file line number Diff line number Diff line change
@@ -1,27 +1,31 @@
using System;
using System.Collections.Generic;
using System.Threading;

namespace Shuttle.Core.Data
{
public class DatabaseContextScope : IDisposable
{
private static readonly AsyncLocal<Stack<DatabaseContextCollection>> DatabaseContextCollectionStack = new AsyncLocal<Stack<DatabaseContextCollection>>();
private static readonly AsyncLocal<DatabaseContextCollection> AmbientData = new AsyncLocal<DatabaseContextCollection>();

public DatabaseContextScope()
{
if (AmbientData.Value != null)
if (DatabaseContextCollectionStack.Value == null)
{
throw new InvalidOperationException(Resources.AmbientScopeException);
DatabaseContextCollectionStack.Value = new Stack<DatabaseContextCollection>();
}

DatabaseContextCollectionStack.Value.Push(AmbientData.Value);

AmbientData.Value = new DatabaseContextCollection();
}

public static DatabaseContextCollection Current => AmbientData.Value;

public void Dispose()
{
AmbientData.Value = null;
AmbientData.Value = DatabaseContextCollectionStack.Value.Pop();
}
}
}
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
using System;
using System.Data;
using System.Data.Common;
using System.Threading;
using Shuttle.Core.Contract;

namespace Shuttle.Core.Data
{
public static class DatabaseContextFactoryExtensions
{
public static bool IsAvailable(this IDatabaseContextFactory databaseContextFactory,
CancellationToken cancellationToken, int retries = 4, int secondsBetweenRetries = 15)
public static bool IsAvailable(this IDatabaseContextFactory databaseContextFactory, CancellationToken cancellationToken, int retries = 4, int secondsBetweenRetries = 15)
{
Guard.AgainstNull(databaseContextFactory, nameof(databaseContextFactory));

Expand Down
Loading

0 comments on commit 85cac96

Please sign in to comment.