diff --git a/CHANGELOG.md b/CHANGELOG.md index ef427dd..d240ea4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,10 @@ * Support multiple schemas in a single database ([#110](https://github.com/microsoft/durabletask-mssql/pull/110)) - contributed by [@AndreiRR24](https://github.com/AndreiRR24) +### Updates + +* Fixed issue where external client ignores task hub name configuration ([#128](https://github.com/microsoft/durabletask-mssql/issues/128)) - contributed by [@bhugot](https://github.com/bhugot) + ## v1.0.1 ### Updates diff --git a/src/DurableTask.SqlServer.AzureFunctions/SqlDurabilityProvider.cs b/src/DurableTask.SqlServer.AzureFunctions/SqlDurabilityProvider.cs index 1ebaf0f..f49a448 100644 --- a/src/DurableTask.SqlServer.AzureFunctions/SqlDurabilityProvider.cs +++ b/src/DurableTask.SqlServer.AzureFunctions/SqlDurabilityProvider.cs @@ -32,17 +32,7 @@ public SqlDurabilityProvider( this.service = service ?? throw new ArgumentNullException(nameof(service)); this.durabilityOptions = durabilityOptions; } - - public SqlDurabilityProvider( - SqlOrchestrationService service, - SqlDurabilityOptions durabilityOptions, - IOrchestrationServiceClient client) - : base(Name, service, client, durabilityOptions.ConnectionStringName) - { - this.service = service ?? throw new ArgumentNullException(nameof(service)); - this.durabilityOptions = durabilityOptions; - } - + public override bool GuaranteesOrderedDelivery => true; public override JObject ConfigurationJson => JObject.FromObject(this.durabilityOptions); diff --git a/src/DurableTask.SqlServer.AzureFunctions/SqlDurabilityProviderFactory.cs b/src/DurableTask.SqlServer.AzureFunctions/SqlDurabilityProviderFactory.cs index 92a5629..6ceb1f8 100644 --- a/src/DurableTask.SqlServer.AzureFunctions/SqlDurabilityProviderFactory.cs +++ b/src/DurableTask.SqlServer.AzureFunctions/SqlDurabilityProviderFactory.cs @@ -5,7 +5,6 @@ namespace DurableTask.SqlServer.AzureFunctions { using System; using System.Collections.Generic; - using DurableTask.Core; using Microsoft.Azure.WebJobs.Extensions.DurableTask; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; @@ -24,8 +23,6 @@ class SqlDurabilityProviderFactory : IDurabilityProviderFactory readonly IConnectionInfoResolver connectionInfoResolver; SqlDurabilityOptions? defaultOptions; - SqlOrchestrationServiceSettings? orchestrationServiceSettings; - SqlOrchestrationService? service; SqlDurabilityProvider? defaultProvider; /// @@ -56,7 +53,7 @@ public DurabilityProvider GetDurabilityProvider() if (this.defaultProvider == null) { SqlDurabilityOptions sqlProviderOptions = this.GetDefaultSqlOptions(); - SqlOrchestrationService service = this.GetOrchestrationService(); + SqlOrchestrationService service = this.GetOrchestrationService(sqlProviderOptions); this.defaultProvider = new SqlDurabilityProvider(service, sqlProviderOptions); } @@ -66,13 +63,6 @@ public DurabilityProvider GetDurabilityProvider() // Called by the Durable client binding infrastructure public DurabilityProvider GetDurabilityProvider(DurableClientAttribute attribute) { - // TODO: Much of this logic should go into the base class - if (string.IsNullOrEmpty(attribute.ConnectionName) && - string.IsNullOrEmpty(attribute.TaskHub)) - { - return this.GetDurabilityProvider(); - } - lock (this.clientProviders) { string key = GetDurabilityProviderKey(attribute); @@ -82,34 +72,27 @@ public DurabilityProvider GetDurabilityProvider(DurableClientAttribute attribute } SqlDurabilityOptions clientOptions = this.GetSqlOptions(attribute); - IOrchestrationServiceClient serviceClient = - new SqlOrchestrationService(clientOptions.GetOrchestrationServiceSettings( - this.extensionOptions, - this.connectionInfoResolver)); + SqlOrchestrationService orchestrationService = + this.GetOrchestrationService(clientOptions); clientProvider = new SqlDurabilityProvider( - this.GetOrchestrationService(), - clientOptions, - serviceClient); + orchestrationService, + clientOptions); this.clientProviders.Add(key, clientProvider); return clientProvider; } } - static string GetDurabilityProviderKey(DurableClientAttribute attribute) + SqlOrchestrationService GetOrchestrationService(SqlDurabilityOptions clientOptions) { - return attribute.ConnectionName + "|" + attribute.TaskHub; + return new (clientOptions.GetOrchestrationServiceSettings( + this.extensionOptions, + this.connectionInfoResolver)); } - SqlOrchestrationService GetOrchestrationService() + static string GetDurabilityProviderKey(DurableClientAttribute attribute) { - if (this.service == null) - { - SqlOrchestrationServiceSettings settings = this.GetOrchestrationServiceSettings(); - this.service = new SqlOrchestrationService(settings); - } - - return this.service; + return attribute.ConnectionName + "|" + attribute.TaskHub; } SqlDurabilityOptions GetDefaultSqlOptions() @@ -148,18 +131,5 @@ SqlDurabilityOptions GetSqlOptions(DurableClientAttribute attribute) return options; } - - SqlOrchestrationServiceSettings GetOrchestrationServiceSettings() - { - if (this.orchestrationServiceSettings == null) - { - SqlDurabilityOptions options = this.GetDefaultSqlOptions(); - this.orchestrationServiceSettings = options.GetOrchestrationServiceSettings( - this.extensionOptions, - this.connectionInfoResolver); - } - - return this.orchestrationServiceSettings; - } } } diff --git a/test/DurableTask.SqlServer.AzureFunctions.Tests/CoreScenarios.cs b/test/DurableTask.SqlServer.AzureFunctions.Tests/CoreScenarios.cs index 95f13d2..de67911 100644 --- a/test/DurableTask.SqlServer.AzureFunctions.Tests/CoreScenarios.cs +++ b/test/DurableTask.SqlServer.AzureFunctions.Tests/CoreScenarios.cs @@ -9,17 +9,23 @@ namespace DurableTask.SqlServer.AzureFunctions.Tests using System.Threading; using System.Threading.Tasks; using DurableTask.SqlServer.Tests.Utils; - using Microsoft.Azure.WebJobs; using Microsoft.Azure.WebJobs.Extensions.DurableTask; using Newtonsoft.Json; using Newtonsoft.Json.Linq; using Xunit; using Xunit.Abstractions; + [Collection("CoreScenarios")] public class CoreScenarios : IntegrationTestBase { + protected CoreScenarios(ITestOutputHelper output, string taskHubName, bool multiTenancy) + : base(output, taskHubName, multiTenancy) + { + this.AddFunctions(typeof(Functions)); + } + public CoreScenarios(ITestOutputHelper output) - : base(output) + : base(output, "TaskHubWithMultiTenancy", true) { this.AddFunctions(typeof(Functions)); } @@ -73,7 +79,7 @@ public async Task CanOrchestrateEntities() [Fact] public async Task CanInteractWithEntities() { - IDurableClient client = await this.GetDurableClientAsync(); + IDurableClient client = this.GetDurableClient(); var entityId = new EntityId(nameof(Functions.Counter), Guid.NewGuid().ToString("N")); EntityStateResponse result = await client.ReadEntityStateAsync(entityId); @@ -101,7 +107,7 @@ public async Task SingleInstanceQuery() Assert.Equal(OrchestrationRuntimeStatus.Completed, status.RuntimeStatus); Assert.Equal(10, (int)status.Output); - IDurableClient client = await this.GetDurableClientAsync(); + IDurableClient client = this.GetDurableClient(); status = await client.GetStatusAsync(status.InstanceId); Assert.NotNull(status.Input); @@ -154,7 +160,7 @@ await Enumerable.Range(0, 5).ParallelForEachAsync(5, i => DateTime afterAllFinishedTime = SharedTestHelpers.GetCurrentDatabaseTimeUtc(); - IDurableClient client = await this.GetDurableClientAsync(); + IDurableClient client = this.GetDurableClient(); // Create one condition object and reuse it for multiple queries var condition = new OrchestrationStatusQueryCondition(); @@ -268,15 +274,16 @@ await Enumerable.Range(0, 5).ParallelForEachAsync(5, i => Assert.All(result.DurableOrchestrationState, state => Assert.Equal(JValue.CreateNull(), state.Input)); } + [Fact] public async Task SingleInstancePurge() { DurableOrchestrationStatus status = await this.RunOrchestrationAsync(nameof(Functions.NoOp)); Assert.Equal(OrchestrationRuntimeStatus.Completed, status.RuntimeStatus); - IDurableClient client = await this.GetDurableClientAsync(); + IDurableClient client = this.GetDurableClient(); PurgeHistoryResult result; - + // First purge gets the active instance result = await client.PurgeInstanceHistoryAsync(status.InstanceId); Assert.NotNull(result); @@ -315,7 +322,7 @@ public async Task MultiInstancePurge() DateTime endTime = SharedTestHelpers.GetCurrentDatabaseTimeUtc(); - IDurableClient client = await this.GetDurableClientAsync(); + IDurableClient client = this.GetDurableClient(); PurgeHistoryResult purgeResult; // First attempt should delete nothing because it's out of range @@ -357,7 +364,7 @@ public async Task CanRewindOrchestration() Functions.ThrowExceptionInCanFail = true; DurableOrchestrationStatus status = await this.RunOrchestrationAsync(nameof(Functions.RewindOrchestration)); Assert.Equal(OrchestrationRuntimeStatus.Failed, status.RuntimeStatus); - + Functions.ThrowExceptionInCanFail = false; status = await this.RewindOrchestrationAsync(status.InstanceId); Assert.Equal(OrchestrationRuntimeStatus.Completed, status.RuntimeStatus); @@ -370,167 +377,11 @@ public async Task CanRewindSubOrchestration() Functions.ThrowExceptionInCanFail = true; DurableOrchestrationStatus status = await this.RunOrchestrationAsync(nameof(Functions.RewindSubOrchestration)); Assert.Equal(OrchestrationRuntimeStatus.Failed, status.RuntimeStatus); - + Functions.ThrowExceptionInCanFail = false; status = await this.RewindOrchestrationAsync(status.InstanceId); Assert.Equal(OrchestrationRuntimeStatus.Completed, status.RuntimeStatus); Assert.Equal("0,1,2,activity", status.Output); } - - static class Functions - { - [FunctionName(nameof(Sequence))] - public static async Task Sequence( - [OrchestrationTrigger] IDurableOrchestrationContext ctx) - { - int value = 0; - for (int i = 0; i < 10; i++) - { - value = await ctx.CallActivityAsync(nameof(PlusOne), value); - } - - return value; - } - - [FunctionName(nameof(PlusOne))] - public static int PlusOne([ActivityTrigger] int input) => input + 1; - - [FunctionName(nameof(FanOutFanIn))] - public static async Task FanOutFanIn( - [OrchestrationTrigger] IDurableOrchestrationContext ctx) - { - var tasks = new List>(); - for (int i = 0; i < 10; i++) - { - tasks.Add(ctx.CallActivityAsync(nameof(IntToString), i)); - } - - string[] results = await Task.WhenAll(tasks); - Array.Sort(results); - Array.Reverse(results); - return results; - } - - [FunctionName(nameof(IntToString))] - public static string? IntToString([ActivityTrigger] int input) => input.ToString(); - - [FunctionName(nameof(DeterministicGuid))] - public static async Task DeterministicGuid([OrchestrationTrigger] IDurableOrchestrationContext ctx) - { - Guid currentGuid1 = ctx.NewGuid(); - Guid originalGuid1 = await ctx.CallActivityAsync(nameof(Echo), currentGuid1); - if (currentGuid1 != originalGuid1) - { - return false; - } - - Guid currentGuid2 = ctx.NewGuid(); - Guid originalGuid2 = await ctx.CallActivityAsync(nameof(Echo), currentGuid2); - if (currentGuid2 != originalGuid2) - { - return false; - } - - return currentGuid1 != currentGuid2; - } - - [FunctionName(nameof(Echo))] - public static object Echo([ActivityTrigger] IDurableActivityContext ctx) => ctx.GetInput(); - - [FunctionName(nameof(OrchestrateCounterEntity))] - public static async Task OrchestrateCounterEntity( - [OrchestrationTrigger] IDurableOrchestrationContext ctx) - { - var entityId = new EntityId(nameof(Counter), ctx.NewGuid().ToString("N")); - ctx.SignalEntity(entityId, "incr"); - ctx.SignalEntity(entityId, "incr"); - ctx.SignalEntity(entityId, "incr"); - ctx.SignalEntity(entityId, "add", 4); - - using (await ctx.LockAsync(entityId)) - { - int result = await ctx.CallEntityAsync(entityId, "get"); - return result; - } - } - - [FunctionName(nameof(Counter))] - public static void Counter([EntityTrigger] IDurableEntityContext ctx) - { - int current = ctx.GetState(); - switch (ctx.OperationName) - { - case "incr": - ctx.SetState(current + 1); - break; - case "add": - int amount = ctx.GetInput(); - ctx.SetState(current + amount); - break; - case "get": - ctx.Return(current); - break; - case "set": - amount = ctx.GetInput(); - ctx.SetState(amount); - break; - case "delete": - ctx.DeleteState(); - break; - default: - throw new NotImplementedException("No such entity operation"); - } - } - - [FunctionName(nameof(WaitForEvent))] - public static Task WaitForEvent([OrchestrationTrigger] IDurableOrchestrationContext ctx) - { - string name = ctx.GetInput(); - return ctx.WaitForExternalEvent(name); - } - - [FunctionName(nameof(NoOp))] - public static Task NoOp([OrchestrationTrigger] IDurableOrchestrationContext ctx) => Task.CompletedTask; - - [FunctionName(nameof(SubOrchestrationTest))] - public static async Task SubOrchestrationTest([OrchestrationTrigger] IDurableOrchestrationContext ctx) - { - await ctx.CallSubOrchestratorAsync(nameof(NoOp), "NoOpInstanceId", null); - return "done"; - } - - [FunctionName(nameof(RewindOrchestration))] - public static async Task RewindOrchestration([OrchestrationTrigger] IDurableOrchestrationContext ctx) - { - return await ctx.CallActivityAsync(nameof(CanFail), "activity"); - } - - [FunctionName(nameof(RewindSubOrchestration))] - public static async Task RewindSubOrchestration([OrchestrationTrigger] IDurableOrchestrationContext ctx) - { - var tasks = new List>(); - for (int i = 0; i < 3; i++) - { - tasks.Add(ctx.CallActivityAsync(nameof(CanFail), i.ToString())); - } - tasks.Add(ctx.CallSubOrchestratorAsync(nameof(RewindOrchestration), "RewindOrchestrationId", "suborchestration")); - - var results = await Task.WhenAll(tasks); - return string.Join(',', results); - } - - public static bool ThrowExceptionInCanFail { get; set; } - - [FunctionName(nameof(CanFail))] - public static string CanFail([ActivityTrigger] IDurableActivityContext context) - { - if (ThrowExceptionInCanFail) - { - throw new Exception("exception"); - } - - return context.GetInput(); - } - } } } diff --git a/test/DurableTask.SqlServer.AzureFunctions.Tests/Functions.cs b/test/DurableTask.SqlServer.AzureFunctions.Tests/Functions.cs new file mode 100644 index 0000000..104e6fe --- /dev/null +++ b/test/DurableTask.SqlServer.AzureFunctions.Tests/Functions.cs @@ -0,0 +1,167 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace DurableTask.SqlServer.AzureFunctions.Tests +{ + using System; + using System.Collections.Generic; + using System.Threading.Tasks; + using Microsoft.Azure.WebJobs; + using Microsoft.Azure.WebJobs.Extensions.DurableTask; + + static class Functions + { + [FunctionName(nameof(Sequence))] + public static async Task Sequence( + [OrchestrationTrigger] IDurableOrchestrationContext ctx) + { + int value = 0; + for (int i = 0; i < 10; i++) + { + value = await ctx.CallActivityAsync(nameof(PlusOne), value); + } + + return value; + } + + [FunctionName(nameof(PlusOne))] + public static int PlusOne([ActivityTrigger] int input) => input + 1; + + [FunctionName(nameof(FanOutFanIn))] + public static async Task FanOutFanIn( + [OrchestrationTrigger] IDurableOrchestrationContext ctx) + { + var tasks = new List>(); + for (int i = 0; i < 10; i++) + { + tasks.Add(ctx.CallActivityAsync(nameof(IntToString), i)); + } + + string[] results = await Task.WhenAll(tasks); + Array.Sort(results); + Array.Reverse(results); + return results; + } + + [FunctionName(nameof(IntToString))] + public static string? IntToString([ActivityTrigger] int input) => input.ToString(); + + [FunctionName(nameof(DeterministicGuid))] + public static async Task DeterministicGuid([OrchestrationTrigger] IDurableOrchestrationContext ctx) + { + Guid currentGuid1 = ctx.NewGuid(); + Guid originalGuid1 = await ctx.CallActivityAsync(nameof(Echo), currentGuid1); + if (currentGuid1 != originalGuid1) + { + return false; + } + + Guid currentGuid2 = ctx.NewGuid(); + Guid originalGuid2 = await ctx.CallActivityAsync(nameof(Echo), currentGuid2); + if (currentGuid2 != originalGuid2) + { + return false; + } + + return currentGuid1 != currentGuid2; + } + + [FunctionName(nameof(Echo))] + public static object Echo([ActivityTrigger] IDurableActivityContext ctx) => ctx.GetInput(); + + [FunctionName(nameof(OrchestrateCounterEntity))] + public static async Task OrchestrateCounterEntity( + [OrchestrationTrigger] IDurableOrchestrationContext ctx) + { + var entityId = new EntityId(nameof(Counter), ctx.NewGuid().ToString("N")); + ctx.SignalEntity(entityId, "incr"); + ctx.SignalEntity(entityId, "incr"); + ctx.SignalEntity(entityId, "incr"); + ctx.SignalEntity(entityId, "add", 4); + + using (await ctx.LockAsync(entityId)) + { + int result = await ctx.CallEntityAsync(entityId, "get"); + return result; + } + } + + [FunctionName(nameof(Counter))] + public static void Counter([EntityTrigger] IDurableEntityContext ctx) + { + int current = ctx.GetState(); + switch (ctx.OperationName) + { + case "incr": + ctx.SetState(current + 1); + break; + case "add": + int amount = ctx.GetInput(); + ctx.SetState(current + amount); + break; + case "get": + ctx.Return(current); + break; + case "set": + amount = ctx.GetInput(); + ctx.SetState(amount); + break; + case "delete": + ctx.DeleteState(); + break; + default: + throw new NotImplementedException("No such entity operation"); + } + } + + [FunctionName(nameof(WaitForEvent))] + public static Task WaitForEvent([OrchestrationTrigger] IDurableOrchestrationContext ctx) + { + string name = ctx.GetInput(); + return ctx.WaitForExternalEvent(name); + } + + [FunctionName(nameof(NoOp))] + public static Task NoOp([OrchestrationTrigger] IDurableOrchestrationContext ctx) => Task.CompletedTask; + + [FunctionName(nameof(SubOrchestrationTest))] + public static async Task SubOrchestrationTest([OrchestrationTrigger] IDurableOrchestrationContext ctx) + { + await ctx.CallSubOrchestratorAsync(nameof(NoOp), "NoOpInstanceId", null); + return "done"; + } + + [FunctionName(nameof(RewindOrchestration))] + public static async Task RewindOrchestration([OrchestrationTrigger] IDurableOrchestrationContext ctx) + { + return await ctx.CallActivityAsync(nameof(CanFail), "activity"); + } + + [FunctionName(nameof(RewindSubOrchestration))] + public static async Task RewindSubOrchestration([OrchestrationTrigger] IDurableOrchestrationContext ctx) + { + var tasks = new List>(); + for (int i = 0; i < 3; i++) + { + tasks.Add(ctx.CallActivityAsync(nameof(CanFail), i.ToString())); + } + tasks.Add(ctx.CallSubOrchestratorAsync(nameof(RewindOrchestration), "RewindOrchestrationId", "suborchestration")); + + var results = await Task.WhenAll(tasks); + return string.Join(',', results); + } + + public static bool ThrowExceptionInCanFail { get; set; } + + [FunctionName(nameof(CanFail))] + public static string CanFail([ActivityTrigger] IDurableActivityContext context) + { + if (ThrowExceptionInCanFail) + { + throw new Exception("exception"); + } + + return context.GetInput(); + } + } +} \ No newline at end of file diff --git a/test/DurableTask.SqlServer.AzureFunctions.Tests/IntegrationTestBase.cs b/test/DurableTask.SqlServer.AzureFunctions.Tests/IntegrationTestBase.cs index 4745caf..515575e 100644 --- a/test/DurableTask.SqlServer.AzureFunctions.Tests/IntegrationTestBase.cs +++ b/test/DurableTask.SqlServer.AzureFunctions.Tests/IntegrationTestBase.cs @@ -12,6 +12,8 @@ namespace DurableTask.SqlServer.AzureFunctions.Tests using DurableTask.SqlServer.Tests.Utils; using Microsoft.Azure.WebJobs; using Microsoft.Azure.WebJobs.Extensions.DurableTask; + using Microsoft.Azure.WebJobs.Extensions.DurableTask.ContextImplementations; + using Microsoft.Azure.WebJobs.Extensions.DurableTask.Options; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; @@ -21,21 +23,43 @@ namespace DurableTask.SqlServer.AzureFunctions.Tests public class IntegrationTestBase : IAsyncLifetime { + readonly bool multiTenancy; + readonly string schema; + readonly string taskHubName; readonly TestLogProvider logProvider; readonly TestFunctionTypeLocator typeLocator; readonly TestSettingsResolver settingsResolver; readonly string testName; readonly IHost functionsHost; - + readonly IDurableClientFactory clientFactory; + TestCredential? testCredential; - public IntegrationTestBase(ITestOutputHelper output) + public IntegrationTestBase(ITestOutputHelper output, string? taskHubName = null, bool multiTenancy = true) { + this.multiTenancy = multiTenancy; + this.schema = multiTenancy ? "dt" : "dt2"; this.logProvider = new TestLogProvider(output); this.typeLocator = new TestFunctionTypeLocator(); this.settingsResolver = new TestSettingsResolver(); this.testName = SharedTestHelpers.GetTestName(output); + this.taskHubName = taskHubName ?? SharedTestHelpers.GetTestName(output); + + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(this.settingsResolver); + serviceCollection.AddSingleton(this.settingsResolver); + serviceCollection.AddOptions(); + serviceCollection.Configure(options => + { + options.StorageProvider["type"] = "mssql"; + options.StorageProvider["schemaName"] = this.schema; + }); + serviceCollection.AddDurableTaskSqlProvider(); + serviceCollection.AddDurableClientFactory(a => + { + a.ConnectionName = "SQLDB_Connection"; + }); this.functionsHost = new HostBuilder() .ConfigureLogging( @@ -49,7 +73,9 @@ public IntegrationTestBase(ITestOutputHelper output) { webJobsBuilder.AddDurableTask(options => { + options.HubName = this.taskHubName; options.StorageProvider["type"] = "mssql"; + options.StorageProvider["schemaName"] = this.schema; }); }) .ConfigureServices( @@ -61,17 +87,16 @@ public IntegrationTestBase(ITestOutputHelper output) services.AddDurableTaskSqlProvider(); }) .Build(); - - this.AddFunctions(typeof(ClientFunctions)); + this.clientFactory = serviceCollection.BuildServiceProvider().GetRequiredService(); } async Task IAsyncLifetime.InitializeAsync() { - await SharedTestHelpers.InitializeDatabaseAsync(); + await SharedTestHelpers.InitializeDatabaseAsync(this.schema); // Create a user login specifically for this test to isolate it from other tests - await SharedTestHelpers.EnableMultitenancyAsync(); - this.testCredential = await SharedTestHelpers.CreateTaskHubLoginAsync(this.testName); + await SharedTestHelpers.EnableMultiTenancyAsync(this.multiTenancy, this.schema); + this.testCredential = await SharedTestHelpers.CreateTaskHubLoginAsync(this.testName, this.schema); this.settingsResolver.AddSetting("SQLDB_Connection", this.testCredential.ConnectionString); await this.functionsHost.StartAsync(); @@ -84,7 +109,7 @@ async Task IAsyncLifetime.DisposeAsync() // Remove the temporarily-created credentials from the database if (this.testCredential != null) { - await SharedTestHelpers.DropTaskHubLoginAsync(this.testCredential); + await SharedTestHelpers.DropTaskHubLoginAsync(this.testCredential, this.schema); } } @@ -116,7 +141,7 @@ protected async Task RunOrchestrationAsync( object? input = null, string? instanceId = null) { - IDurableClient client = await this.GetDurableClientAsync(); + IDurableClient client = this.GetDurableClient(); instanceId = await client.StartNewAsync(name, instanceId ?? Guid.NewGuid().ToString("N"), input); TimeSpan timeout = Debugger.IsAttached ? TimeSpan.FromMinutes(5) : TimeSpan.FromSeconds(10); @@ -128,9 +153,10 @@ protected async Task RunOrchestrationAsync( protected async Task StartOrchestrationAsync( string name, object? input = null, - string? instanceId = null) + string? instanceId = null, + string? taskHub = null) { - IDurableClient client = await this.GetDurableClientAsync(); + IDurableClient client = this.GetDurableClient(taskHub ?? this.taskHubName); instanceId = await client.StartNewAsync(name, instanceId ?? Guid.NewGuid().ToString("N"), input); TimeSpan timeout = Debugger.IsAttached ? TimeSpan.FromMinutes(5) : TimeSpan.FromSeconds(10); @@ -139,9 +165,22 @@ protected async Task StartOrchestrationAsync( return status; } + + protected async Task StartOrchestrationWithoutWaitingAsync( + string name, + object? input = null, + string? instanceId = null, + string? taskHub = null) + { + IDurableClient client = this.GetDurableClient(taskHub ?? this.taskHubName); + instanceId = await client.StartNewAsync(name, instanceId ?? Guid.NewGuid().ToString("N"), input); + + + Assert.NotNull(instanceId); + } protected async Task RewindOrchestrationAsync(string instanceId) { - IDurableClient client = await this.GetDurableClientAsync(); + IDurableClient client = this.GetDurableClient(); #pragma warning disable CS0618 // Type or member is obsolete (preview feature) await client.RewindAsync(instanceId, "rewind"); #pragma warning restore CS0618 @@ -152,11 +191,12 @@ protected async Task RewindOrchestrationAsync(string return status; } - protected async Task GetDurableClientAsync() + protected IDurableClient GetDurableClient(string? taskHubName = null) { - var clientRef = new IDurableClient[1]; - await this.CallFunctionAsync(nameof(ClientFunctions.GetDurableClient), "clientRef", clientRef); - IDurableClient client = clientRef[0]; + var client = this.clientFactory.CreateClient(new DurableClientOptions + { + TaskHub = taskHubName ?? this.taskHubName + }); Assert.NotNull(client); return client; } @@ -237,16 +277,5 @@ IConfigurationSection IConnectionInfoResolver.Resolve(string name) return Environment.GetEnvironmentVariable(name); } } - - static class ClientFunctions - { - [NoAutomaticTrigger] - public static void GetDurableClient( - [DurableClient] IDurableClient client, - IDurableClient[] clientRef) - { - clientRef[0] = client; - } - } } } diff --git a/test/DurableTask.SqlServer.AzureFunctions.Tests/WithoutMultiTenancyCoreScenarios.cs b/test/DurableTask.SqlServer.AzureFunctions.Tests/WithoutMultiTenancyCoreScenarios.cs new file mode 100644 index 0000000..783ebff --- /dev/null +++ b/test/DurableTask.SqlServer.AzureFunctions.Tests/WithoutMultiTenancyCoreScenarios.cs @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace DurableTask.SqlServer.AzureFunctions.Tests +{ + using System; + using System.Threading.Tasks; + using Microsoft.Azure.WebJobs.Extensions.DurableTask; + using Xunit; + using Xunit.Abstractions; + + public class WithoutMultiTenancyCoreScenarios : CoreScenarios + { + public WithoutMultiTenancyCoreScenarios(ITestOutputHelper output) + : base(output, "TaskHubWithoutMultiTenancy", false) + { + } + + [Fact] + public async Task When_Without_MultiTenancy_should_Return_Correct_Orchestration_By_TaskHub() + { + string otherTaskHubName = "SomeOtherTaskHub"; + + string currentTaskHubInstanceId = Guid.NewGuid().ToString(); + await this.StartOrchestrationAsync(nameof(Functions.Sequence), instanceId: currentTaskHubInstanceId); + + string anotherTaskHubInstanceId = Guid.NewGuid().ToString(); + await this.StartOrchestrationWithoutWaitingAsync(nameof(Functions.Sequence), instanceId: anotherTaskHubInstanceId, taskHub: otherTaskHubName); + + IDurableClient client = this.GetDurableClient(); + var current = await client.GetStatusAsync(currentTaskHubInstanceId); + Assert.NotNull(current); + var otherInstance = await client.GetStatusAsync(anotherTaskHubInstanceId); + Assert.Null(otherInstance); + + IDurableClient clientOtherTaskHub = this.GetDurableClient(otherTaskHubName); + var currentFromOtherTaskHub = await clientOtherTaskHub.GetStatusAsync(currentTaskHubInstanceId); + Assert.Null(currentFromOtherTaskHub); + var otherInstanceFromOtherTaskHub = await clientOtherTaskHub.GetStatusAsync(anotherTaskHubInstanceId); + Assert.NotNull(otherInstanceFromOtherTaskHub); + } + } +} \ No newline at end of file diff --git a/test/DurableTask.SqlServer.Tests/Utils/SharedTestHelpers.cs b/test/DurableTask.SqlServer.Tests/Utils/SharedTestHelpers.cs index 94c6644..6c9a13d 100644 --- a/test/DurableTask.SqlServer.Tests/Utils/SharedTestHelpers.cs +++ b/test/DurableTask.SqlServer.Tests/Utils/SharedTestHelpers.cs @@ -17,12 +17,13 @@ namespace DurableTask.SqlServer.Tests.Utils public static class SharedTestHelpers { + private const string DefaultSchema = "dt"; public static string GetTestName(ITestOutputHelper output) { Type type = output.GetType(); FieldInfo testMember = type.GetField("test", BindingFlags.Instance | BindingFlags.NonPublic); var test = (ITest)testMember.GetValue(output); - return test.TestCase.TestMethod.Method.Name; + return $"{test.TestCase.TestMethod.Method.Name}" + (test.TestCase.TestMethodArguments == null || test.TestCase.TestMethodArguments.Length == 0 ? string.Empty : $"_{ string.Join("_",test.TestCase.TestMethodArguments.Select(a => a.ToString()))}") ; } public static string GetDefaultConnectionString(string database = "DurableDB") @@ -80,14 +81,14 @@ public static async Task ExecuteSqlAsync(string commandText, string conn throw lastException; } - public static async Task InitializeDatabaseAsync() + public static async Task InitializeDatabaseAsync(string schema = DefaultSchema) { - var options = new SqlOrchestrationServiceSettings(GetDefaultConnectionString()); + var options = new SqlOrchestrationServiceSettings(GetDefaultConnectionString(), schemaName: schema); var service = new SqlOrchestrationService(options); await service.CreateIfNotExistsAsync(); } - public static async Task CreateTaskHubLoginAsync(string prefix) + public static async Task CreateTaskHubLoginAsync(string prefix, string schema = DefaultSchema) { // NOTE: Max length for user IDs is 128 characters string userId = $"{prefix}_{DateTime.UtcNow:yyyyMMddhhmmssff}"; @@ -96,7 +97,7 @@ public static async Task CreateTaskHubLoginAsync(string prefix) // Generate a low-priviledge user account. This will map to a unique task hub. await ExecuteSqlAsync($"CREATE LOGIN [testlogin_{userId}] WITH PASSWORD = '{password}'"); await ExecuteSqlAsync($"CREATE USER [testuser_{userId}] FOR LOGIN [testlogin_{userId}]"); - await ExecuteSqlAsync($"ALTER ROLE dt_runtime ADD MEMBER [testuser_{userId}]"); + await ExecuteSqlAsync($"ALTER ROLE {schema}_runtime ADD MEMBER [testuser_{userId}]"); var builder = new SqlConnectionStringBuilder(GetDefaultConnectionString()) { @@ -108,11 +109,11 @@ public static async Task CreateTaskHubLoginAsync(string prefix) return new TestCredential(userId, builder.ToString()); } - public static async Task DropTaskHubLoginAsync(TestCredential credential) + public static async Task DropTaskHubLoginAsync(TestCredential credential, string schema = DefaultSchema) { // Drop the generated user information string userId = credential.UserId; - await ExecuteSqlAsync($"ALTER ROLE dt_runtime DROP MEMBER [testuser_{userId}]"); + await ExecuteSqlAsync($"ALTER ROLE {schema}_runtime DROP MEMBER [testuser_{userId}]"); await ExecuteSqlAsync($"DROP USER IF EXISTS [testuser_{userId}]"); // drop all the connections; otherwise, the DROP LOGIN statement will fail @@ -120,20 +121,26 @@ public static async Task DropTaskHubLoginAsync(TestCredential credential) await ExecuteSqlAsync($"DROP LOGIN [testlogin_{userId}]"); } - public static Task EnableMultitenancyAsync() + public static async Task EnableMultiTenancyAsync(bool multiTenancy, string schema = DefaultSchema) { - return ExecuteSqlAsync($"EXECUTE dt.SetGlobalSetting @Name='TaskHubMode', @Value=1"); + if (!multiTenancy) + { + await PurgeAsync(schema); + } + + int param = multiTenancy ? 1 : 0; + await ExecuteSqlAsync($"EXECUTE {schema}.SetGlobalSetting @Name='TaskHubMode', @Value={param}"); } static string GeneratePassword() { const string AllowedChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHJKLMNPQRSTWXYZ0123456789#$"; - const int PasswordLenth = 16; + const int PasswordLength = 16; - string password = GetRandomString(AllowedChars, PasswordLenth); + string password = GetRandomString(AllowedChars, PasswordLength); while (!MeetsSqlPasswordConstraint(password)) { - password = GetRandomString(AllowedChars, PasswordLenth); + password = GetRandomString(AllowedChars, PasswordLength); } return password; @@ -233,5 +240,14 @@ static async Task InvokeThrottledAction(T item, Func action, Semapho semaphore.Release(); } } + + public static async Task PurgeAsync(string schema = "dt") + { + await ExecuteSqlAsync($"TRUNCATE TABLE [{schema}].[NewTasks]"); + await ExecuteSqlAsync($"TRUNCATE TABLE [{schema}].[NewEvents]"); + await ExecuteSqlAsync($"TRUNCATE TABLE [{schema}].[Instances]"); + await ExecuteSqlAsync($"TRUNCATE TABLE [{schema}].[History]"); + await ExecuteSqlAsync($"TRUNCATE TABLE [{schema}].[Payloads]"); + } } } diff --git a/test/DurableTask.SqlServer.Tests/Utils/TestService.cs b/test/DurableTask.SqlServer.Tests/Utils/TestService.cs index e251f59..61824dc 100644 --- a/test/DurableTask.SqlServer.Tests/Utils/TestService.cs +++ b/test/DurableTask.SqlServer.Tests/Utils/TestService.cs @@ -52,7 +52,7 @@ public async Task InitializeAsync(bool startWorker = true) await new SqlOrchestrationService(this.OrchestrationServiceOptions).CreateIfNotExistsAsync(); // Enable multitenancy to isolate each test using low-privilege credentials - await SharedTestHelpers.EnableMultitenancyAsync(); + await SharedTestHelpers.EnableMultiTenancyAsync(true); // The runtime will use low-privilege credentials this.testCredential = await SharedTestHelpers.CreateTaskHubLoginAsync(this.testName);