Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[in-proc port] Ensure extension RPC endpoints ready before processing gRPC messages #10282

Merged
merged 3 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions release_notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@
- Fixed incorrect function count in the log message.(#10220)
- Migrate Diagnostic Events to Azure.Data.Tables (#10218)
- Sanitize worker arguments before logging (#10260)
- Fix race condition on startup with extension RPC endpoints not being available. (#10282)
- Adding a timeout when retrieving function metadata from metadata providers (#10219)
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Routing;
using Microsoft.Azure.WebJobs.Rpc.Core.Internal;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Primitives;

namespace Microsoft.Azure.WebJobs.Script.Grpc
Expand All @@ -26,6 +28,7 @@ internal sealed class ExtensionsCompositeEndpointDataSource : EndpointDataSource
private readonly object _lock = new();
private readonly List<EndpointDataSource> _dataSources = new();
private readonly IScriptHostManager _scriptHostManager;
private readonly TaskCompletionSource _initialized = new();

private IServiceProvider _extensionServices;
private List<Endpoint> _endpoints;
Expand Down Expand Up @@ -191,6 +194,7 @@ private void OnHostChanged(object sender, ActiveHostChangedEventArgs args)
.GetService<IEnumerable<WebJobsRpcEndpointDataSource>>()
?? Enumerable.Empty<WebJobsRpcEndpointDataSource>();
_dataSources.AddRange(sources);
_initialized.TrySetResult(); // signal we have first initialized.
}
else
{
Expand Down Expand Up @@ -301,5 +305,49 @@ private void ThrowIfDisposed()
throw new ObjectDisposedException(nameof(ExtensionsCompositeEndpointDataSource));
}
}

/// <summary>
/// Middleware to ensure <see cref="ExtensionsCompositeEndpointDataSource"/> is initialized before routing for the first time.
/// Must be registered as a singleton service.
/// </summary>
/// <param name="dataSource">The <see cref="ExtensionsCompositeEndpointDataSource"/> to ensure is initialized.</param>
/// <param name="logger">The logger.</param>
public sealed class EnsureInitializedMiddleware(ExtensionsCompositeEndpointDataSource dataSource, ILogger<EnsureInitializedMiddleware> logger) : IMiddleware
{
private TaskCompletionSource _initialized = new();
private bool _firstRun = true;

// used for testing to verify initialization success.
internal Task Initialized => _initialized.Task;

// settable only for testing purposes.
internal TimeSpan Timeout { get; init; } = TimeSpan.FromSeconds(2);

public Task InvokeAsync(HttpContext context, RequestDelegate next)
{
return _firstRun ? InvokeCoreAsync(context, next) : next(context);
}

private async Task InvokeCoreAsync(HttpContext context, RequestDelegate next)
{
try
{
await dataSource._initialized.Task.WaitAsync(Timeout);
}
catch (TimeoutException ex)
{
// In case of deadlock we don't want to block all gRPC requests.
// Log an error and continue.
logger.LogError(ex, "Error initializing extension endpoints.");
_initialized.TrySetException(ex);
}

// Even in case of timeout we don't want to continually test for initialization on subsequent requests.
// That would be a serious performance degredation.
_firstRun = false;
_initialized.TrySetResult();
await next(context);
}
}
}
}
7 changes: 6 additions & 1 deletion src/WebJobs.Script.Grpc/Server/Startup.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ internal class Startup
public void ConfigureServices(IServiceCollection services)
{
services.AddSingleton<ExtensionsCompositeEndpointDataSource>();
services.AddSingleton<ExtensionsCompositeEndpointDataSource.EnsureInitializedMiddleware>();
services.AddGrpc(options =>
{
options.MaxReceiveMessageSize = MaxMessageLengthBytes;
Expand All @@ -30,12 +31,16 @@ public void Configure(IApplicationBuilder app, IWebHostEnvironment env)
app.UseDeveloperExceptionPage();
}

// This must occur before 'UseRouting'. This ensures extension endpoints are registered before the
// endpoints are collected by the routing middleware.
app.UseMiddleware<ExtensionsCompositeEndpointDataSource.EnsureInitializedMiddleware>();
app.UseRouting();

app.UseEndpoints(endpoints =>
{
endpoints.MapGrpcService<FunctionRpc.FunctionRpcBase>();
endpoints.DataSources.Add(endpoints.ServiceProvider.GetRequiredService<ExtensionsCompositeEndpointDataSource>());
endpoints.DataSources.Add(
endpoints.ServiceProvider.GetRequiredService<ExtensionsCompositeEndpointDataSource>());
});
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@

using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.Azure.WebJobs.Rpc.Core.Internal;
using Microsoft.Azure.WebJobs.Script.Grpc;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.FileProviders;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Primitives;
using Moq;
using Xunit;
Expand All @@ -17,6 +21,9 @@ namespace Microsoft.Azure.WebJobs.Script.Tests.Workers.Rpc
{
public class ExtensionsCompositeEndpointDataSourceTests
{
private static readonly ILogger<ExtensionsCompositeEndpointDataSource.EnsureInitializedMiddleware> _logger
= NullLogger<ExtensionsCompositeEndpointDataSource.EnsureInitializedMiddleware>.Instance;

[Fact]
public void NoActiveHost_NoEndpoints()
{
Expand All @@ -41,6 +48,7 @@ public void ActiveHostChanged_NullHost_NoEndpoints()
public void ActiveHostChanged_NoExtensions_NoEndpoints()
{
Mock<IScriptHostManager> manager = new();

ExtensionsCompositeEndpointDataSource dataSource = new(manager.Object);

IChangeToken token = dataSource.GetChangeToken();
Expand All @@ -67,6 +75,45 @@ public void ActiveHostChanged_NewExtensions_NewEndpoints()
endpoint => Assert.Equal("Test2", endpoint.DisplayName));
}

[Fact]
public async Task ActiveHostChanged_MiddlewareWaits_Success()
{
Mock<IScriptHostManager> manager = new();

ExtensionsCompositeEndpointDataSource dataSource = new(manager.Object);
ExtensionsCompositeEndpointDataSource.EnsureInitializedMiddleware middleware =
new(dataSource, _logger) { Timeout = Timeout.InfiniteTimeSpan };
TestDelegate next = new();

Task waiter = middleware.InvokeAsync(null, next.InvokeAsync);
Assert.False(waiter.IsCompleted); // should be blocked until we raise the event.

manager.Raise(x => x.ActiveHostChanged += null, new ActiveHostChangedEventArgs(null, GetHost()));
await waiter.WaitAsync(TimeSpan.FromSeconds(5));
await middleware.Initialized;
await next.Invoked;
}

[Fact]
public async Task NoActiveHostChanged_MiddlewareWaits_Timeout()
{
Mock<IScriptHostManager> manager = new();

ExtensionsCompositeEndpointDataSource dataSource = new(manager.Object);
ExtensionsCompositeEndpointDataSource.EnsureInitializedMiddleware middleware =
new(dataSource, _logger) { Timeout = TimeSpan.Zero };
TestDelegate next = new();

await middleware.InvokeAsync(null, next.InvokeAsync).WaitAsync(TimeSpan.FromSeconds(5)); // should not throw
await Assert.ThrowsAsync<TimeoutException>(() => middleware.Initialized);
await next.Invoked;

// invoke again to verify it processes the next request.
next = new();
await middleware.InvokeAsync(null, next.InvokeAsync);
await next.Invoked;
}

[Fact]
public void Dispose_GetThrows()
{
Expand Down Expand Up @@ -105,5 +152,18 @@ public TestEndpoints(params Endpoint[] endpoints)

public override IChangeToken GetChangeToken() => NullChangeToken.Singleton;
}

private class TestDelegate
{
private readonly TaskCompletionSource _invoked = new();

public Task Invoked => _invoked.Task;

public Task InvokeAsync(HttpContext context)
{
_invoked.TrySetResult();
return Task.CompletedTask;
}
}
}
}
Loading