Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for #127 Is not properly using TaskHubName #128

Merged
merged 11 commits into from
Oct 11, 2022
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,7 @@ public SqlDurabilityProvider(
this.service = service ?? throw new ArgumentNullException(nameof(service));
this.durabilityOptions = durabilityOptions;
}

public SqlDurabilityProvider(
SqlOrchestrationService service,
SqlDurabilityOptions durabilityOptions,
IOrchestrationServiceClient client)
: base(Name, service, client, durabilityOptions.ConnectionStringName)
{
this.service = service ?? throw new ArgumentNullException(nameof(service));
this.durabilityOptions = durabilityOptions;
}


public override bool GuaranteesOrderedDelivery => true;

public override JObject ConfigurationJson => JObject.FromObject(this.durabilityOptions);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ namespace DurableTask.SqlServer.AzureFunctions
{
using System;
using System.Collections.Generic;
using DurableTask.Core;
using Microsoft.Azure.WebJobs.Extensions.DurableTask;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
Expand All @@ -25,7 +24,6 @@ class SqlDurabilityProviderFactory : IDurabilityProviderFactory

SqlDurabilityOptions? defaultOptions;
SqlOrchestrationServiceSettings? orchestrationServiceSettings;
SqlOrchestrationService? service;
SqlDurabilityProvider? defaultProvider;

/// <summary>
Expand Down Expand Up @@ -56,7 +54,7 @@ public DurabilityProvider GetDurabilityProvider()
if (this.defaultProvider == null)
{
SqlDurabilityOptions sqlProviderOptions = this.GetDefaultSqlOptions();
SqlOrchestrationService service = this.GetOrchestrationService();
SqlOrchestrationService service = this.GetOrchestrationService(sqlProviderOptions);
this.defaultProvider = new SqlDurabilityProvider(service, sqlProviderOptions);
}

Expand All @@ -66,13 +64,6 @@ public DurabilityProvider GetDurabilityProvider()
// Called by the Durable client binding infrastructure
public DurabilityProvider GetDurabilityProvider(DurableClientAttribute attribute)
{
// TODO: Much of this logic should go into the base class
if (string.IsNullOrEmpty(attribute.ConnectionName) &&
string.IsNullOrEmpty(attribute.TaskHub))
{
return this.GetDurabilityProvider();
}

lock (this.clientProviders)
{
string key = GetDurabilityProviderKey(attribute);
Expand All @@ -82,34 +73,27 @@ public DurabilityProvider GetDurabilityProvider(DurableClientAttribute attribute
}

SqlDurabilityOptions clientOptions = this.GetSqlOptions(attribute);
IOrchestrationServiceClient serviceClient =
new SqlOrchestrationService(clientOptions.GetOrchestrationServiceSettings(
this.extensionOptions,
this.connectionInfoResolver));
SqlOrchestrationService orchestrationService =
this.GetOrchestrationService(clientOptions);
clientProvider = new SqlDurabilityProvider(
this.GetOrchestrationService(),
clientOptions,
serviceClient);
orchestrationService,
clientOptions);

this.clientProviders.Add(key, clientProvider);
return clientProvider;
}
}

static string GetDurabilityProviderKey(DurableClientAttribute attribute)
SqlOrchestrationService GetOrchestrationService(SqlDurabilityOptions clientOptions)
{
return attribute.ConnectionName + "|" + attribute.TaskHub;
return new (clientOptions.GetOrchestrationServiceSettings(
this.extensionOptions,
this.connectionInfoResolver));
}

SqlOrchestrationService GetOrchestrationService()
static string GetDurabilityProviderKey(DurableClientAttribute attribute)
{
if (this.service == null)
{
SqlOrchestrationServiceSettings settings = this.GetOrchestrationServiceSettings();
this.service = new SqlOrchestrationService(settings);
}

return this.service;
return attribute.ConnectionName + "|" + attribute.TaskHub;
}

SqlDurabilityOptions GetDefaultSqlOptions()
Expand Down
180 changes: 12 additions & 168 deletions test/DurableTask.SqlServer.AzureFunctions.Tests/CoreScenarios.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,16 @@ namespace DurableTask.SqlServer.AzureFunctions.Tests
using System.Threading;
using System.Threading.Tasks;
using DurableTask.SqlServer.Tests.Utils;
using Microsoft.Azure.WebJobs;
using Microsoft.Azure.WebJobs.Extensions.DurableTask;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using Xunit;
using Xunit.Abstractions;

public class CoreScenarios : IntegrationTestBase
public abstract class CoreScenarios : IntegrationTestBase
{
public CoreScenarios(ITestOutputHelper output)
: base(output)
protected CoreScenarios(ITestOutputHelper output, string taskHubName, bool multiTenancy)
: base(output, taskHubName, multiTenancy)
{
this.AddFunctions(typeof(Functions));
}
Expand Down Expand Up @@ -73,7 +72,7 @@ public async Task CanOrchestrateEntities()
[Fact]
public async Task CanInteractWithEntities()
{
IDurableClient client = await this.GetDurableClientAsync();
IDurableClient client = this.GetDurableClient();

var entityId = new EntityId(nameof(Functions.Counter), Guid.NewGuid().ToString("N"));
EntityStateResponse<int> result = await client.ReadEntityStateAsync<int>(entityId);
Expand Down Expand Up @@ -101,7 +100,7 @@ public async Task SingleInstanceQuery()
Assert.Equal(OrchestrationRuntimeStatus.Completed, status.RuntimeStatus);
Assert.Equal(10, (int)status.Output);

IDurableClient client = await this.GetDurableClientAsync();
IDurableClient client = this.GetDurableClient();
status = await client.GetStatusAsync(status.InstanceId);

Assert.NotNull(status.Input);
Expand Down Expand Up @@ -154,7 +153,7 @@ await Enumerable.Range(0, 5).ParallelForEachAsync(5, i =>

DateTime afterAllFinishedTime = SharedTestHelpers.GetCurrentDatabaseTimeUtc();

IDurableClient client = await this.GetDurableClientAsync();
IDurableClient client = this.GetDurableClient();

// Create one condition object and reuse it for multiple queries
var condition = new OrchestrationStatusQueryCondition();
Expand Down Expand Up @@ -268,15 +267,16 @@ await Enumerable.Range(0, 5).ParallelForEachAsync(5, i =>
Assert.All(result.DurableOrchestrationState, state => Assert.Equal(JValue.CreateNull(), state.Input));
bhugot marked this conversation as resolved.
Show resolved Hide resolved
}


[Fact]
public async Task SingleInstancePurge()
{
DurableOrchestrationStatus status = await this.RunOrchestrationAsync(nameof(Functions.NoOp));
Assert.Equal(OrchestrationRuntimeStatus.Completed, status.RuntimeStatus);

IDurableClient client = await this.GetDurableClientAsync();
IDurableClient client = this.GetDurableClient();
PurgeHistoryResult result;

// First purge gets the active instance
result = await client.PurgeInstanceHistoryAsync(status.InstanceId);
Assert.NotNull(result);
Expand Down Expand Up @@ -315,7 +315,7 @@ public async Task MultiInstancePurge()

DateTime endTime = SharedTestHelpers.GetCurrentDatabaseTimeUtc();

IDurableClient client = await this.GetDurableClientAsync();
IDurableClient client = this.GetDurableClient();
PurgeHistoryResult purgeResult;

// First attempt should delete nothing because it's out of range
Expand Down Expand Up @@ -357,7 +357,7 @@ public async Task CanRewindOrchestration()
Functions.ThrowExceptionInCanFail = true;
DurableOrchestrationStatus status = await this.RunOrchestrationAsync(nameof(Functions.RewindOrchestration));
Assert.Equal(OrchestrationRuntimeStatus.Failed, status.RuntimeStatus);

Functions.ThrowExceptionInCanFail = false;
status = await this.RewindOrchestrationAsync(status.InstanceId);
Assert.Equal(OrchestrationRuntimeStatus.Completed, status.RuntimeStatus);
Expand All @@ -370,167 +370,11 @@ public async Task CanRewindSubOrchestration()
Functions.ThrowExceptionInCanFail = true;
DurableOrchestrationStatus status = await this.RunOrchestrationAsync(nameof(Functions.RewindSubOrchestration));
Assert.Equal(OrchestrationRuntimeStatus.Failed, status.RuntimeStatus);

Functions.ThrowExceptionInCanFail = false;
status = await this.RewindOrchestrationAsync(status.InstanceId);
Assert.Equal(OrchestrationRuntimeStatus.Completed, status.RuntimeStatus);
Assert.Equal("0,1,2,activity", status.Output);
}

static class Functions
{
[FunctionName(nameof(Sequence))]
public static async Task<int> Sequence(
[OrchestrationTrigger] IDurableOrchestrationContext ctx)
{
int value = 0;
for (int i = 0; i < 10; i++)
{
value = await ctx.CallActivityAsync<int>(nameof(PlusOne), value);
}

return value;
}

[FunctionName(nameof(PlusOne))]
public static int PlusOne([ActivityTrigger] int input) => input + 1;

[FunctionName(nameof(FanOutFanIn))]
public static async Task<string[]> FanOutFanIn(
[OrchestrationTrigger] IDurableOrchestrationContext ctx)
{
var tasks = new List<Task<string>>();
for (int i = 0; i < 10; i++)
{
tasks.Add(ctx.CallActivityAsync<string>(nameof(IntToString), i));
}

string[] results = await Task.WhenAll(tasks);
Array.Sort(results);
Array.Reverse(results);
return results;
}

[FunctionName(nameof(IntToString))]
public static string? IntToString([ActivityTrigger] int input) => input.ToString();

[FunctionName(nameof(DeterministicGuid))]
public static async Task<bool> DeterministicGuid([OrchestrationTrigger] IDurableOrchestrationContext ctx)
{
Guid currentGuid1 = ctx.NewGuid();
Guid originalGuid1 = await ctx.CallActivityAsync<Guid>(nameof(Echo), currentGuid1);
if (currentGuid1 != originalGuid1)
{
return false;
}

Guid currentGuid2 = ctx.NewGuid();
Guid originalGuid2 = await ctx.CallActivityAsync<Guid>(nameof(Echo), currentGuid2);
if (currentGuid2 != originalGuid2)
{
return false;
}

return currentGuid1 != currentGuid2;
}

[FunctionName(nameof(Echo))]
public static object Echo([ActivityTrigger] IDurableActivityContext ctx) => ctx.GetInput<object>();

[FunctionName(nameof(OrchestrateCounterEntity))]
public static async Task<int> OrchestrateCounterEntity(
[OrchestrationTrigger] IDurableOrchestrationContext ctx)
{
var entityId = new EntityId(nameof(Counter), ctx.NewGuid().ToString("N"));
ctx.SignalEntity(entityId, "incr");
ctx.SignalEntity(entityId, "incr");
ctx.SignalEntity(entityId, "incr");
ctx.SignalEntity(entityId, "add", 4);

using (await ctx.LockAsync(entityId))
{
int result = await ctx.CallEntityAsync<int>(entityId, "get");
return result;
}
}

[FunctionName(nameof(Counter))]
public static void Counter([EntityTrigger] IDurableEntityContext ctx)
{
int current = ctx.GetState<int>();
switch (ctx.OperationName)
{
case "incr":
ctx.SetState(current + 1);
break;
case "add":
int amount = ctx.GetInput<int>();
ctx.SetState(current + amount);
break;
case "get":
ctx.Return(current);
break;
case "set":
amount = ctx.GetInput<int>();
ctx.SetState(amount);
break;
case "delete":
ctx.DeleteState();
break;
default:
throw new NotImplementedException("No such entity operation");
}
}

[FunctionName(nameof(WaitForEvent))]
public static Task<object> WaitForEvent([OrchestrationTrigger] IDurableOrchestrationContext ctx)
{
string name = ctx.GetInput<string>();
return ctx.WaitForExternalEvent<object>(name);
}

[FunctionName(nameof(NoOp))]
public static Task NoOp([OrchestrationTrigger] IDurableOrchestrationContext ctx) => Task.CompletedTask;

[FunctionName(nameof(SubOrchestrationTest))]
public static async Task<string> SubOrchestrationTest([OrchestrationTrigger] IDurableOrchestrationContext ctx)
{
await ctx.CallSubOrchestratorAsync(nameof(NoOp), "NoOpInstanceId", null);
return "done";
}

[FunctionName(nameof(RewindOrchestration))]
public static async Task<string> RewindOrchestration([OrchestrationTrigger] IDurableOrchestrationContext ctx)
{
return await ctx.CallActivityAsync<string>(nameof(CanFail), "activity");
}

[FunctionName(nameof(RewindSubOrchestration))]
public static async Task<string> RewindSubOrchestration([OrchestrationTrigger] IDurableOrchestrationContext ctx)
{
var tasks = new List<Task<string>>();
for (int i = 0; i < 3; i++)
{
tasks.Add(ctx.CallActivityAsync<string>(nameof(CanFail), i.ToString()));
}
tasks.Add(ctx.CallSubOrchestratorAsync<string>(nameof(RewindOrchestration), "RewindOrchestrationId", "suborchestration"));

var results = await Task.WhenAll(tasks);
return string.Join(',', results);
}

public static bool ThrowExceptionInCanFail { get; set; }

[FunctionName(nameof(CanFail))]
public static string CanFail([ActivityTrigger] IDurableActivityContext context)
{
if (ThrowExceptionInCanFail)
{
throw new Exception("exception");
}

return context.GetInput<string>();
}
}
}
}
Loading