Skip to content

Commit

Permalink
Updated DurableClient ThrowIfFunctionDoesNotExist calls to work with …
Browse files Browse the repository at this point in the history
…correctly when calling a function in another app (#1306)
  • Loading branch information
amdeel authored Apr 10, 2020
1 parent e8ba40c commit df1e28e
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,10 @@ async Task<IActionResult> IDurableOrchestrationClient.WaitForCompletionOrCreateC
/// <inheritdoc />
async Task<string> IDurableOrchestrationClient.StartNewAsync<T>(string orchestratorFunctionName, string instanceId, T input)
{
this.config.ThrowIfFunctionDoesNotExist(orchestratorFunctionName, FunctionType.Orchestrator);
if (this.ClientReferencesCurrentApp(this))
{
this.config.ThrowIfFunctionDoesNotExist(orchestratorFunctionName, FunctionType.Orchestrator);
}

if (string.IsNullOrEmpty(instanceId))
{
Expand Down Expand Up @@ -210,7 +213,7 @@ Task IDurableEntityClient.SignalEntityAsync(EntityId entityId, string operationN
{
if (string.IsNullOrEmpty(taskHubName))
{
return this.SignalEntityAsyncInternal(this.client, this.TaskHubName, entityId, null, operationName, operationInput);
return this.SignalEntityAsyncInternal(this, this.TaskHubName, entityId, null, operationName, operationInput);
}
else
{
Expand All @@ -225,8 +228,8 @@ Task IDurableEntityClient.SignalEntityAsync(EntityId entityId, string operationN
ConnectionName = connectionName,
};

TaskHubClient taskHubClient = ((DurableClient)this.config.GetClient(attribute)).client;
return this.SignalEntityAsyncInternal(taskHubClient, taskHubName, entityId, null, operationName, operationInput);
var durableClient = (DurableClient)this.config.GetClient(attribute);
return this.SignalEntityAsyncInternal(durableClient, taskHubName, entityId, null, operationName, operationInput);
}
}

Expand All @@ -235,7 +238,7 @@ Task IDurableEntityClient.SignalEntityAsync(EntityId entityId, DateTime schedule
{
if (string.IsNullOrEmpty(taskHubName))
{
return this.SignalEntityAsyncInternal(this.client, this.TaskHubName, entityId, scheduledTimeUtc, operationName, operationInput);
return this.SignalEntityAsyncInternal(this, this.TaskHubName, entityId, scheduledTimeUtc, operationName, operationInput);
}
else
{
Expand All @@ -250,19 +253,19 @@ Task IDurableEntityClient.SignalEntityAsync(EntityId entityId, DateTime schedule
ConnectionName = connectionName,
};

TaskHubClient taskHubClient = ((DurableClient)this.config.GetClient(attribute)).client;
return this.SignalEntityAsyncInternal(taskHubClient, taskHubName, entityId, scheduledTimeUtc, operationName, operationInput);
var durableClient = (DurableClient)this.config.GetClient(attribute);
return this.SignalEntityAsyncInternal(durableClient, taskHubName, entityId, scheduledTimeUtc, operationName, operationInput);
}
}

private async Task SignalEntityAsyncInternal(TaskHubClient client, string hubName, EntityId entityId, DateTime? scheduledTimeUtc, string operationName, object operationInput)
private async Task SignalEntityAsyncInternal(DurableClient durableClient, string hubName, EntityId entityId, DateTime? scheduledTimeUtc, string operationName, object operationInput)
{
if (operationName == null)
{
throw new ArgumentNullException(nameof(operationName));
}

if (this.client.Equals(client))
if (this.ClientReferencesCurrentApp(durableClient))
{
this.config.ThrowIfFunctionDoesNotExist(entityId.EntityName, FunctionType.Entity);
}
Expand All @@ -286,7 +289,7 @@ private async Task SignalEntityAsyncInternal(TaskHubClient client, string hubNam

var jrequest = JToken.FromObject(request, this.messageDataConverter.JsonSerializer);
var eventName = scheduledTimeUtc.HasValue ? EntityMessageEventNames.ScheduledRequestMessageEventName(scheduledTimeUtc.Value) : EntityMessageEventNames.RequestMessageEventName;
await client.RaiseEventAsync(instance, eventName, jrequest);
await durableClient.client.RaiseEventAsync(instance, eventName, jrequest);

this.traceHelper.FunctionScheduled(
hubName,
Expand All @@ -297,6 +300,29 @@ private async Task SignalEntityAsyncInternal(TaskHubClient client, string hubNam
isReplay: false);
}

private bool ClientReferencesCurrentApp(DurableClient client)
{
return this.TaskHubMatchesCurrentApp(client) && this.ConnectionNameMatchesCurrentApp(client);
}

private bool TaskHubMatchesCurrentApp(DurableClient client)
{
var taskHubName = this.config.Options.HubName;
return client.TaskHubName.Equals(taskHubName);
}

private bool ConnectionNameMatchesCurrentApp(DurableClient client)
{
var storageProvider = this.config.Options.StorageProvider;
if (storageProvider.TryGetValue("ConnectionStringName", out object connectionName))
{
var newConnectionName = client.DurabilityProvider.ConnectionName;
return newConnectionName.Equals(connectionName);
}

return false;
}

/// <inheritdoc />
async Task IDurableOrchestrationClient.TerminateAsync(string instanceId, string reason)
{
Expand Down
1 change: 0 additions & 1 deletion test/Common/DurableClientBaseTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ public async Task StartNewAsync_is_calling_overload_method_with_specific_instanc

[Theory]
[Trait("Category", PlatformSpecificHelpers.TestCategory)]
[InlineData("")]
[InlineData("@invalid")]
[InlineData("/invalid")]
[InlineData("invalid\\")]
Expand Down
6 changes: 3 additions & 3 deletions test/Common/DurableTaskEndToEndTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4601,13 +4601,13 @@ private static void ValidateHttpManagementPayload(HttpManagementPayload httpMana
taskHubName += PlatformSpecificHelpers.VersionSuffix;

Assert.Equal(
$"{notificationUrl}/instances/{instanceId}?taskHub={taskHubName}&connection=Storage&code=mykey",
$"{notificationUrl}/instances/{instanceId}?taskHub={taskHubName}&connection=AzureWebJobsStorage&code=mykey",
httpManagementPayload.StatusQueryGetUri);
Assert.Equal(
$"{notificationUrl}/instances/{instanceId}/raiseEvent/{{eventName}}?taskHub={taskHubName}&connection=Storage&code=mykey",
$"{notificationUrl}/instances/{instanceId}/raiseEvent/{{eventName}}?taskHub={taskHubName}&connection=AzureWebJobsStorage&code=mykey",
httpManagementPayload.SendEventPostUri);
Assert.Equal(
$"{notificationUrl}/instances/{instanceId}/terminate?reason={{text}}&taskHub={taskHubName}&connection=Storage&code=mykey",
$"{notificationUrl}/instances/{instanceId}/terminate?reason={{text}}&taskHub={taskHubName}&connection=AzureWebJobsStorage&code=mykey",
httpManagementPayload.TerminatePostUri);
}

Expand Down
1 change: 1 addition & 0 deletions test/Common/TestHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ public static ITestHost GetJobHost(
// Azure Storage specfic tests
if (string.Equals(storageProviderType, AzureStorageProviderType))
{
options.StorageProvider["ConnectionStringName"] = "AzureWebJobsStorage";
options.StorageProvider["fetchLargeMessagesAutomatically"] = autoFetchLargeMessages;
if (maxQueuePollingInterval != null)
{
Expand Down

0 comments on commit df1e28e

Please sign in to comment.