Skip to content

Commit

Permalink
Allow file middlewares to run if there's an endpoint with a null requ…
Browse files Browse the repository at this point in the history
…est delegate (#42458)

* Allow file middlewares to run if there's an endpoint with a null request delegate #42413

* Ensure EndpointMiddleware checks AuthZ/CORS metadata even if endpoint.RequestDelegate is null

* Update null equality check to use "is"

* Added tests for StaticFileMiddleware changes

* Add tests for DefaultFiles/DirectoryBrowser middleware

* Add endpoint middleware test
  • Loading branch information
DamianEdwards authored Jun 28, 2022
1 parent d6d2b0c commit a0513ea
Show file tree
Hide file tree
Showing 8 changed files with 222 additions and 24 deletions.
35 changes: 19 additions & 16 deletions src/Http/Routing/src/EndpointMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,41 +31,44 @@ public EndpointMiddleware(
public Task Invoke(HttpContext httpContext)
{
var endpoint = httpContext.GetEndpoint();
if (endpoint?.RequestDelegate != null)
if (endpoint is not null)
{
if (!_routeOptions.SuppressCheckForUnhandledSecurityMetadata)
{
if (endpoint.Metadata.GetMetadata<IAuthorizeData>() != null &&
if (endpoint.Metadata.GetMetadata<IAuthorizeData>() is not null &&
!httpContext.Items.ContainsKey(AuthorizationMiddlewareInvokedKey))
{
ThrowMissingAuthMiddlewareException(endpoint);
}

if (endpoint.Metadata.GetMetadata<ICorsMetadata>() != null &&
if (endpoint.Metadata.GetMetadata<ICorsMetadata>() is not null &&
!httpContext.Items.ContainsKey(CorsMiddlewareInvokedKey))
{
ThrowMissingCorsMiddlewareException(endpoint);
}
}

Log.ExecutingEndpoint(_logger, endpoint);

try
if (endpoint.RequestDelegate is not null)
{
var requestTask = endpoint.RequestDelegate(httpContext);
if (!requestTask.IsCompletedSuccessfully)
Log.ExecutingEndpoint(_logger, endpoint);

try
{
return AwaitRequestTask(endpoint, requestTask, _logger);
var requestTask = endpoint.RequestDelegate(httpContext);
if (!requestTask.IsCompletedSuccessfully)
{
return AwaitRequestTask(endpoint, requestTask, _logger);
}
}
}
catch (Exception exception)
{
catch (Exception exception)
{
Log.ExecutedEndpoint(_logger, endpoint);
return Task.FromException(exception);
}

Log.ExecutedEndpoint(_logger, endpoint);
return Task.FromException(exception);
return Task.CompletedTask;
}

Log.ExecutedEndpoint(_logger, endpoint);
return Task.CompletedTask;
}

return _next(httpContext);
Expand Down
29 changes: 29 additions & 0 deletions src/Http/Routing/test/UnitTests/EndpointMiddlewareTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,35 @@ public async Task Invoke_WithEndpoint_ThrowsIfAuthAttributesWereFound_ButAuthMid
Assert.Equal(expected, ex.Message);
}

[Fact]
public async Task Invoke_WithEndpointWithNullRequestDelegate_ThrowsIfAuthAttributesWereFound_ButAuthMiddlewareNotInvoked()
{
// Arrange
var expected = "Endpoint Test contains authorization metadata, but a middleware was not found that supports authorization." +
Environment.NewLine +
"Configure your application startup by adding app.UseAuthorization() in the application startup code. " +
"If there are calls to app.UseRouting() and app.UseEndpoints(...), the call to app.UseAuthorization() must go between them.";
var httpContext = new DefaultHttpContext
{
RequestServices = new ServiceProvider()
};

RequestDelegate throwIfCalled = (c) =>
{
throw new InvalidTimeZoneException("Should not be called");
};

httpContext.SetEndpoint(new Endpoint(requestDelegate: null, new EndpointMetadataCollection(Mock.Of<IAuthorizeData>()), "Test"));

var middleware = new EndpointMiddleware(NullLogger<EndpointMiddleware>.Instance, throwIfCalled, RouteOptions);

// Act & Assert
var ex = await Assert.ThrowsAsync<InvalidOperationException>(() => middleware.Invoke(httpContext));

// Assert
Assert.Equal(expected, ex.Message);
}

[Fact]
public async Task Invoke_WithEndpoint_WorksIfAuthAttributesWereFound_AndAuthMiddlewareInvoked()
{
Expand Down
2 changes: 1 addition & 1 deletion src/Middleware/StaticFiles/src/DefaultFilesMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public DefaultFilesMiddleware(RequestDelegate next, IWebHostEnvironment hostingE
/// <returns></returns>
public Task Invoke(HttpContext context)
{
if (context.GetEndpoint() == null
if (context.GetEndpoint()?.RequestDelegate is null
&& Helpers.IsGetOrHeadMethod(context.Request.Method)
&& Helpers.TryMatchPath(context, _matchUrl, forDirectory: true, subpath: out var subpath))
{
Expand Down
4 changes: 2 additions & 2 deletions src/Middleware/StaticFiles/src/DirectoryBrowserMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ public DirectoryBrowserMiddleware(RequestDelegate next, IWebHostEnvironment host
/// <returns></returns>
public Task Invoke(HttpContext context)
{
// Check if the URL matches any expected paths, skip if an endpoint was selected
if (context.GetEndpoint() == null
// Check if the URL matches any expected paths, skip if an endpoint with a request delegate was selected
if (context.GetEndpoint()?.RequestDelegate is null
&& Helpers.IsGetOrHeadMethod(context.Request.Method)
&& Helpers.TryMatchPath(context, _matchUrl, forDirectory: true, subpath: out var subpath)
&& TryGetDirectoryInfo(subpath, out var contents))
Expand Down
6 changes: 3 additions & 3 deletions src/Middleware/StaticFiles/src/StaticFileMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public StaticFileMiddleware(RequestDelegate next, IWebHostEnvironment hostingEnv
/// <returns></returns>
public Task Invoke(HttpContext context)
{
if (!ValidateNoEndpoint(context))
if (!ValidateNoEndpointDelegate(context))
{
_logger.EndpointMatched();
}
Expand All @@ -91,8 +91,8 @@ public Task Invoke(HttpContext context)
return _next(context);
}

// Return true because we only want to run if there is no endpoint.
private static bool ValidateNoEndpoint(HttpContext context) => context.GetEndpoint() == null;
// Return true because we only want to run if there is no endpoint delegate.
private static bool ValidateNoEndpointDelegate(HttpContext context) => context.GetEndpoint()?.RequestDelegate is null;

private static bool ValidateMethod(HttpContext context)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ private async Task NoMatch_PassesThrough(string baseUrl, string baseDir, string
}

[Fact]
public async Task Endpoint_PassesThrough()
public async Task Endpoint_With_RequestDelegate_PassesThrough()
{
using (var fileProvider = new PhysicalFileProvider(Path.Combine(AppContext.BaseDirectory, ".")))
{
Expand Down Expand Up @@ -107,6 +107,9 @@ public async Task Endpoint_PassesThrough()
});

app.UseEndpoints(endpoints => { });

// Echo back the current request path value
app.Run(context => context.Response.WriteAsync(context.Request.Path.Value));
},
services => { services.AddDirectoryBrowser(); services.AddRouting(); });
using var server = host.GetTestServer();
Expand All @@ -117,6 +120,48 @@ public async Task Endpoint_PassesThrough()
}
}

[Fact]
public async Task Endpoint_With_Null_RequestDelegate_Does_Not_PassThrough()
{
using (var fileProvider = new PhysicalFileProvider(Path.Combine(AppContext.BaseDirectory, ".")))
{
using var host = await StaticFilesTestServer.Create(
app =>
{
app.UseRouting();

app.Use(next => context =>
{
// Assign an endpoint with a null RequestDelegate, the default files should still run
context.SetEndpoint(new Endpoint(requestDelegate: null,
new EndpointMetadataCollection(),
"test"));

return next(context);
});

app.UseDefaultFiles(new DefaultFilesOptions
{
RequestPath = new PathString(""),
FileProvider = fileProvider
});

app.UseEndpoints(endpoints => { });

// Echo back the current request path value
app.Run(context => context.Response.WriteAsync(context.Request.Path.Value));
},
services => { services.AddDirectoryBrowser(); services.AddRouting(); });
using var server = host.GetTestServer();

var response = await server.CreateRequest("/SubFolder/").GetAsync();
var responseContent = await response.Content.ReadAsStringAsync();

Assert.Equal(HttpStatusCode.OK, response.StatusCode);
Assert.Equal("/SubFolder/default.html", responseContent); // Should be modified and be valid path to file
}
}

[Theory]
[InlineData("", @".", "/SubFolder/")]
[InlineData("", @"./", "/SubFolder/")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ private async Task NoMatch_PassesThrough(string baseUrl, string baseDir, string
}

[Fact]
public async Task Endpoint_PassesThrough()
public async Task Endpoint_With_RequestDelegate_PassesThrough()
{
using (var fileProvider = new PhysicalFileProvider(Path.Combine(AppContext.BaseDirectory, ".")))
{
Expand Down Expand Up @@ -135,6 +135,45 @@ public async Task Endpoint_PassesThrough()
}
}

[Fact]
public async Task Endpoint_With_Null_RequestDelegate_Does_Not_PassThrough()
{
using (var fileProvider = new PhysicalFileProvider(Path.Combine(AppContext.BaseDirectory, ".")))
{
using var host = await StaticFilesTestServer.Create(
app =>
{
app.UseRouting();

app.Use(next => context =>
{
// Assign an endpoint with a null RequestDelegate, the directory browser should still run
context.SetEndpoint(new Endpoint(requestDelegate: null,
new EndpointMetadataCollection(),
"test"));

return next(context);
});

app.UseDirectoryBrowser(new DirectoryBrowserOptions
{
RequestPath = new PathString(""),
FileProvider = fileProvider
});

app.UseEndpoints(endpoints => { });
},
services => { services.AddDirectoryBrowser(); services.AddRouting(); });
using var server = host.GetTestServer();

var response = await server.CreateRequest("/").GetAsync();
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
Assert.Equal("text/html; charset=utf-8", response.Content.Headers.ContentType.ToString());
Assert.True(response.Content.Headers.ContentLength > 0);
Assert.Equal(response.Content.Headers.ContentLength, (await response.Content.ReadAsByteArrayAsync()).Length);
}
}

[Theory]
[InlineData("", @".", "/")]
[InlineData("", @".", "/SubFolder/")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Globalization;
using System.Net;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.TestHost;
using Microsoft.AspNetCore.Testing;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.FileProviders;
using Microsoft.Extensions.Hosting;
using Moq;
Expand Down Expand Up @@ -193,6 +195,86 @@ private async Task FoundFile_Served(string baseUrl, string baseDir, string reque
}
}

[Fact]
public async Task File_Served_If_Endpoint_With_Null_RequestDelegate_Is_Active()
{
using (var fileProvider = new PhysicalFileProvider(Path.Combine(AppContext.BaseDirectory, ".")))
{
using var host = await StaticFilesTestServer.Create(app =>
{
app.UseRouting();
app.Use((ctx, next) =>
{
ctx.SetEndpoint(new Endpoint(requestDelegate: null, new EndpointMetadataCollection(), "NullRequestDelegateEndpoint"));
return next();
});
app.UseStaticFiles(new StaticFileOptions
{
RequestPath = new PathString(),
FileProvider = fileProvider
});
app.UseEndpoints(endpoints => { });
}, services => services.AddRouting());
using var server = host.GetTestServer();
var requestUrl = "/TestDocument.txt";
var fileInfo = fileProvider.GetFileInfo(Path.GetFileName(requestUrl));
var response = await server.CreateRequest(requestUrl).GetAsync();
var responseContent = await response.Content.ReadAsByteArrayAsync();

Assert.Equal(HttpStatusCode.OK, response.StatusCode);
Assert.Equal("text/plain", response.Content.Headers.ContentType.ToString());
Assert.True(response.Content.Headers.ContentLength == fileInfo.Length);
Assert.Equal(response.Content.Headers.ContentLength, responseContent.Length);
Assert.NotNull(response.Headers.ETag);

using (var stream = fileInfo.CreateReadStream())
{
var fileContents = new byte[stream.Length];
stream.Read(fileContents, 0, (int)stream.Length);
Assert.True(responseContent.SequenceEqual(fileContents));
}
}
}

[Fact]
public async Task File_NotServed_If_Endpoint_With_RequestDelegate_Is_Active()
{
var responseText = DateTime.UtcNow.Ticks.ToString(CultureInfo.InvariantCulture);
RequestDelegate handler = async (ctx) =>
{
ctx.Response.ContentType = "text/customfortest+plain";
await ctx.Response.WriteAsync(responseText);
};

using (var fileProvider = new PhysicalFileProvider(Path.Combine(AppContext.BaseDirectory, ".")))
{
using var host = await StaticFilesTestServer.Create(app =>
{
app.UseRouting();
app.Use((ctx, next) =>
{
ctx.SetEndpoint(new Endpoint(handler, new EndpointMetadataCollection(), "RequestDelegateEndpoint"));
return next();
});
app.UseStaticFiles(new StaticFileOptions
{
RequestPath = new PathString(),
FileProvider = fileProvider
});
app.UseEndpoints(endpoints => { });
}, services => services.AddRouting());
using var server = host.GetTestServer();
var requestUrl = "/TestDocument.txt";

var response = await server.CreateRequest(requestUrl).GetAsync();
var responseContent = await response.Content.ReadAsStringAsync();

Assert.Equal(HttpStatusCode.OK, response.StatusCode);
Assert.Equal("text/customfortest+plain", response.Content.Headers.ContentType.ToString());
Assert.Equal(responseText, responseContent);
}
}

[Theory]
[MemberData(nameof(ExistingFiles))]
public async Task HeadFile_HeadersButNotBodyServed(string baseUrl, string baseDir, string requestUrl)
Expand Down

0 comments on commit a0513ea

Please sign in to comment.