Skip to content

Commit

Permalink
Ensure extension RPC endpoints ready before processing gRPC messages (#…
Browse files Browse the repository at this point in the history
…10255)

* Ensure extension RPC endpoints ready before processing gRPC messages

* Add timeout and tests to waiting on RPC extensions.
  • Loading branch information
jviau committed Jul 11, 2024
1 parent ae05d5e commit d992955
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 1 deletion.
1 change: 1 addition & 0 deletions release_notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@
- Fixed incorrect function count in the log message.(#10220)
- Migrate Diagnostic Events to Azure.Data.Tables (#10218)
- Sanitize worker arguments before logging (#10260)
- Fix race condition on startup with extension RPC endpoints not being available. (#10255)
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Routing;
using Microsoft.Azure.WebJobs.Rpc.Core.Internal;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Primitives;

namespace Microsoft.Azure.WebJobs.Script.Grpc
Expand All @@ -26,6 +28,7 @@ internal sealed class ExtensionsCompositeEndpointDataSource : EndpointDataSource
private readonly object _lock = new();
private readonly List<EndpointDataSource> _dataSources = new();
private readonly IScriptHostManager _scriptHostManager;
private readonly TaskCompletionSource _initialized = new();

private IServiceProvider _extensionServices;
private List<Endpoint> _endpoints;
Expand Down Expand Up @@ -191,6 +194,7 @@ private void OnHostChanged(object sender, ActiveHostChangedEventArgs args)
.GetService<IEnumerable<WebJobsRpcEndpointDataSource>>()
?? Enumerable.Empty<WebJobsRpcEndpointDataSource>();
_dataSources.AddRange(sources);
_initialized.TrySetResult(); // signal we have first initialized.
}
else
{
Expand Down Expand Up @@ -301,5 +305,49 @@ private void ThrowIfDisposed()
throw new ObjectDisposedException(nameof(ExtensionsCompositeEndpointDataSource));
}
}

/// <summary>
/// Middleware to ensure <see cref="ExtensionsCompositeEndpointDataSource"/> is initialized before routing for the first time.
/// Must be registered as a singleton service.
/// </summary>
/// <param name="dataSource">The <see cref="ExtensionsCompositeEndpointDataSource"/> to ensure is initialized.</param>
/// <param name="logger">The logger.</param>
public sealed class EnsureInitializedMiddleware(ExtensionsCompositeEndpointDataSource dataSource, ILogger<EnsureInitializedMiddleware> logger) : IMiddleware
{
private TaskCompletionSource _initialized = new();
private bool _firstRun = true;

// used for testing to verify initialization success.
internal Task Initialized => _initialized.Task;

// settable only for testing purposes.
internal TimeSpan Timeout { get; init; } = TimeSpan.FromSeconds(2);

public Task InvokeAsync(HttpContext context, RequestDelegate next)
{
return _firstRun ? InvokeCoreAsync(context, next) : next(context);
}

private async Task InvokeCoreAsync(HttpContext context, RequestDelegate next)
{
try
{
await dataSource._initialized.Task.WaitAsync(Timeout);
}
catch (TimeoutException ex)
{
// In case of deadlock we don't want to block all gRPC requests.
// Log an error and continue.
logger.LogError(ex, "Error initializing extension endpoints.");
_initialized.TrySetException(ex);
}

// Even in case of timeout we don't want to continually test for initialization on subsequent requests.
// That would be a serious performance degredation.
_firstRun = false;
_initialized.TrySetResult();
await next(context);
}
}
}
}
7 changes: 6 additions & 1 deletion src/WebJobs.Script.Grpc/Server/Startup.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ internal class Startup
public void ConfigureServices(IServiceCollection services)
{
services.AddSingleton<ExtensionsCompositeEndpointDataSource>();
services.AddSingleton<ExtensionsCompositeEndpointDataSource.EnsureInitializedMiddleware>();
services.AddGrpc(options =>
{
options.MaxReceiveMessageSize = MaxMessageLengthBytes;
Expand All @@ -30,12 +31,16 @@ public void Configure(IApplicationBuilder app, IWebHostEnvironment env)
app.UseDeveloperExceptionPage();
}

// This must occur before 'UseRouting'. This ensures extension endpoints are registered before the
// endpoints are collected by the routing middleware.
app.UseMiddleware<ExtensionsCompositeEndpointDataSource.EnsureInitializedMiddleware>();
app.UseRouting();

app.UseEndpoints(endpoints =>
{
endpoints.MapGrpcService<FunctionRpc.FunctionRpcBase>();
endpoints.DataSources.Add(endpoints.ServiceProvider.GetRequiredService<ExtensionsCompositeEndpointDataSource>());
endpoints.DataSources.Add(
endpoints.ServiceProvider.GetRequiredService<ExtensionsCompositeEndpointDataSource>());
});
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@

using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.Azure.WebJobs.Rpc.Core.Internal;
using Microsoft.Azure.WebJobs.Script.Grpc;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.FileProviders;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Primitives;
using Moq;
using Xunit;
Expand All @@ -17,6 +21,9 @@ namespace Microsoft.Azure.WebJobs.Script.Tests.Workers.Rpc
{
public class ExtensionsCompositeEndpointDataSourceTests
{
private static readonly ILogger<ExtensionsCompositeEndpointDataSource.EnsureInitializedMiddleware> _logger
= NullLogger<ExtensionsCompositeEndpointDataSource.EnsureInitializedMiddleware>.Instance;

[Fact]
public void NoActiveHost_NoEndpoints()
{
Expand All @@ -41,6 +48,7 @@ public void ActiveHostChanged_NullHost_NoEndpoints()
public void ActiveHostChanged_NoExtensions_NoEndpoints()
{
Mock<IScriptHostManager> manager = new();

ExtensionsCompositeEndpointDataSource dataSource = new(manager.Object);

IChangeToken token = dataSource.GetChangeToken();
Expand All @@ -67,6 +75,45 @@ public void ActiveHostChanged_NewExtensions_NewEndpoints()
endpoint => Assert.Equal("Test2", endpoint.DisplayName));
}

[Fact]
public async Task ActiveHostChanged_MiddlewareWaits_Success()
{
Mock<IScriptHostManager> manager = new();

ExtensionsCompositeEndpointDataSource dataSource = new(manager.Object);
ExtensionsCompositeEndpointDataSource.EnsureInitializedMiddleware middleware =
new(dataSource, _logger) { Timeout = Timeout.InfiniteTimeSpan };
TestDelegate next = new();

Task waiter = middleware.InvokeAsync(null, next.InvokeAsync);
Assert.False(waiter.IsCompleted); // should be blocked until we raise the event.

manager.Raise(x => x.ActiveHostChanged += null, new ActiveHostChangedEventArgs(null, GetHost()));
await waiter.WaitAsync(TimeSpan.FromSeconds(5));
await middleware.Initialized;
await next.Invoked;
}

[Fact]
public async Task NoActiveHostChanged_MiddlewareWaits_Timeout()
{
Mock<IScriptHostManager> manager = new();

ExtensionsCompositeEndpointDataSource dataSource = new(manager.Object);
ExtensionsCompositeEndpointDataSource.EnsureInitializedMiddleware middleware =
new(dataSource, _logger) { Timeout = TimeSpan.Zero };
TestDelegate next = new();

await middleware.InvokeAsync(null, next.InvokeAsync).WaitAsync(TimeSpan.FromSeconds(5)); // should not throw
await Assert.ThrowsAsync<TimeoutException>(() => middleware.Initialized);
await next.Invoked;

// invoke again to verify it processes the next request.
next = new();
await middleware.InvokeAsync(null, next.InvokeAsync);
await next.Invoked;
}

[Fact]
public void Dispose_GetThrows()
{
Expand Down Expand Up @@ -105,5 +152,18 @@ public TestEndpoints(params Endpoint[] endpoints)

public override IChangeToken GetChangeToken() => NullChangeToken.Singleton;
}

private class TestDelegate
{
private readonly TaskCompletionSource _invoked = new();

public Task Invoked => _invoked.Task;

public Task InvokeAsync(HttpContext context)
{
_invoked.TrySetResult();
return Task.CompletedTask;
}
}
}
}

0 comments on commit d992955

Please sign in to comment.