Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return the auth challenge in the WWW-Authenticate when auth fails #1387

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/Ocelot/Authentication/Middleware/AuthenticationMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ public async Task Invoke(HttpContext httpContext)

if (result.Principal?.Identity == null)
{
await ChallengeAsync(httpContext, downstreamRoute, result);
SetUnauthenticatedError(httpContext, path, null);
return;
}
Expand All @@ -49,6 +50,7 @@ public async Task Invoke(HttpContext httpContext)
return;
}

await ChallengeAsync(httpContext, downstreamRoute, result);
SetUnauthenticatedError(httpContext, path, httpContext.User.Identity.Name);
}

Expand All @@ -59,6 +61,18 @@ private void SetUnauthenticatedError(HttpContext httpContext, string path, strin
httpContext.Items.SetError(error);
}

private async Task ChallengeAsync(HttpContext context, DownstreamRoute route, AuthenticateResult status)
{
// Perform a challenge. This populates the WWW-Authenticate header on the response
await context.ChallengeAsync(route.AuthenticationOptions.AuthenticationProviderKey); // TODO Read failed scheme from auth result

// Since the response gets re-created down the pipeline, we store the challenge in the Items, so we can re-apply it when sending the response
if (context.Response.Headers.TryGetValue("WWW-Authenticate", out var authenticateHeader))
{
context.Items.SetAuthChallenge(authenticateHeader);
}
}

private async Task<AuthenticateResult> AuthenticateAsync(HttpContext context, DownstreamRoute route)
{
var options = route.AuthenticationOptions;
Expand Down
6 changes: 6 additions & 0 deletions src/Ocelot/Middleware/HttpItemsExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ public static void SetError(this IDictionary<object, object> input, Error error)
input.Upsert("Errors", errors);
}

public static void SetAuthChallenge(this IDictionary<object, object> input, string challengeString) =>
input.Upsert("AuthChallenge", challengeString);

public static string AuthChallenge(this IDictionary<object, object> input) =>
input.Get<string>("AuthChallenge");

public static void SetIInternalConfiguration(this IDictionary<object, object> input, IInternalConfiguration config)
{
input.Upsert("IInternalConfiguration", config);
Expand Down
57 changes: 29 additions & 28 deletions src/Ocelot/Multiplexer/MultiplexingMiddleware.cs
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Primitives;
using Newtonsoft.Json.Linq;
using Ocelot.Configuration;
using Ocelot.Configuration;
using Ocelot.Configuration.File;
using Ocelot.DownstreamRouteFinder.UrlMatcher;
using Ocelot.Logging;
using Ocelot.Middleware;
using Ocelot.DownstreamRouteFinder.UrlMatcher;
using Ocelot.Logging;
using Ocelot.Middleware;
using System.Collections;
using Route = Ocelot.Configuration.Route;

namespace Ocelot.Multiplexer;

public class MultiplexingMiddleware : OcelotMiddleware
{
private readonly RequestDelegate _next;
private readonly IResponseAggregatorFactory _factory;
private const string RequestIdString = "RequestId";

public MultiplexingMiddleware(RequestDelegate next,
IOcelotLoggerFactory loggerFactory,
IResponseAggregatorFactory factory)
Expand All @@ -25,7 +25,7 @@ public MultiplexingMiddleware(RequestDelegate next,
_factory = factory;
_next = next;
}

public async Task Invoke(HttpContext httpContext)
{
var downstreamRouteHolder = httpContext.Items.DownstreamRouteHolder();
Expand All @@ -38,37 +38,37 @@ public async Task Invoke(HttpContext httpContext)
await ProcessSingleRouteAsync(httpContext, downstreamRoutes[0]);
return;
}

// Case 2: if no downstream routes
if (downstreamRoutes.Count == 0)
{
return;
}

// Case 3: if multiple downstream routes
var routeKeysConfigs = route.DownstreamRouteConfig;
if (routeKeysConfigs == null || routeKeysConfigs.Count == 0)
{
await ProcessRoutesAsync(httpContext, route);
return;
}

// Case 4: if multiple downstream routes with route keys
var mainResponseContext = await ProcessMainRouteAsync(httpContext, downstreamRoutes[0]);
if (mainResponseContext == null)
{
return;
}

var responsesContexts = await ProcessRoutesWithRouteKeysAsync(httpContext, downstreamRoutes, routeKeysConfigs, mainResponseContext);
if (responsesContexts.Length == 0)
{
return;
}

await MapResponsesAsync(httpContext, route, mainResponseContext, responsesContexts);
}

/// <summary>
/// Helper method to determine if only the first downstream route should be processed.
/// It is the case if the request is a websocket request or if there is only one downstream route.
Expand All @@ -78,7 +78,7 @@ public async Task Invoke(HttpContext httpContext)
/// <returns>True if only the first downstream route should be processed.</returns>
private static bool ShouldProcessSingleRoute(HttpContext context, ICollection routes)
=> context.WebSockets.IsWebSocketRequest || routes.Count == 1;

/// <summary>
/// Processing a single downstream route (no route keys).
/// In that case, no need to make copies of the http context.
Expand All @@ -89,9 +89,10 @@ private static bool ShouldProcessSingleRoute(HttpContext context, ICollection ro
protected virtual Task ProcessSingleRouteAsync(HttpContext context, DownstreamRoute route)
{
context.Items.UpsertDownstreamRoute(route);
context.Items.SetAuthChallenge(/*finished*/context.Items.AuthChallenge());
return _next.Invoke(context);
}

/// <summary>
/// Processing the downstream routes (no route keys).
/// </summary>
Expand All @@ -105,7 +106,7 @@ private async Task ProcessRoutesAsync(HttpContext context, Route route)
var contexts = await Task.WhenAll(tasks);
await MapAsync(context, route, new(contexts));
}

/// <summary>
/// When using route keys, the first route is the main route and the rest are additional routes.
/// Since we need to break if the main route response is null, we must process the main route first.
Expand All @@ -119,7 +120,7 @@ private async Task<HttpContext> ProcessMainRouteAsync(HttpContext context, Downs
await _next.Invoke(context);
return context;
}

/// <summary>
/// Processing the downstream routes with route keys except the main route that has already been processed.
/// </summary>
Expand All @@ -133,7 +134,7 @@ protected virtual async Task<HttpContext[]> ProcessRoutesWithRouteKeysAsync(Http
var processing = new List<Task<HttpContext>>();
var content = await mainResponse.Items.DownstreamResponse().Content.ReadAsStringAsync();
var jObject = JToken.Parse(content);

foreach (var downstreamRoute in routes.Skip(1))
{
var matchAdvancedAgg = routeKeysConfigs.FirstOrDefault(q => q.RouteKey == downstreamRoute.Key);
Expand All @@ -142,13 +143,13 @@ protected virtual async Task<HttpContext[]> ProcessRoutesWithRouteKeysAsync(Http
processing.AddRange(ProcessRouteWithComplexAggregation(matchAdvancedAgg, jObject, context, downstreamRoute));
continue;
}

processing.Add(ProcessRouteAsync(context, downstreamRoute));
}

return await Task.WhenAll(processing);
}

/// <summary>
/// Mapping responses.
/// </summary>
Expand All @@ -158,7 +159,7 @@ private Task MapResponsesAsync(HttpContext context, Route route, HttpContext mai
contexts.AddRange(responsesContexts);
return MapAsync(context, route, contexts);
}

/// <summary>
/// Processing a route with aggregation.
/// </summary>
Expand All @@ -173,7 +174,7 @@ private IEnumerable<Task<HttpContext>> ProcessRouteWithComplexAggregation(Aggreg
tPnv.Add(new PlaceholderNameAndValue('{' + matchAdvancedAgg.Parameter + '}', value));
processing.Add(ProcessRouteAsync(httpContext, downstreamRoute, tPnv));
}

return processing;
}

Expand All @@ -186,11 +187,11 @@ private async Task<HttpContext> ProcessRouteAsync(HttpContext sourceContext, Dow
var newHttpContext = await CreateThreadContextAsync(sourceContext, route);
CopyItemsToNewContext(newHttpContext, sourceContext, placeholders);
newHttpContext.Items.UpsertDownstreamRoute(route);

await _next.Invoke(newHttpContext);
return newHttpContext;
}

/// <summary>
/// Copying some needed parameters to the Http context items.
/// </summary>
Expand Down Expand Up @@ -247,7 +248,7 @@ protected virtual async Task<HttpContext> CreateThreadContextAsync(HttpContext s
target.Response.RegisterForDisposeAsync(bodyStream); // manage Stream lifetime by HttpResponse object
return target;
}

protected virtual Task MapAsync(HttpContext httpContext, Route route, List<HttpContext> contexts)
{
if (route.DownstreamRoute.Count == 1)
Expand Down Expand Up @@ -282,4 +283,4 @@ protected virtual async Task<Stream> CloneRequestBodyAsync(HttpRequest request,

return targetBuffer;
}
}
}
3 changes: 3 additions & 0 deletions src/Ocelot/Responder/HttpContextResponder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ public async Task SetErrorResponseOnContext(HttpContext context, DownstreamRespo
}
}

public void SetAuthChallengeOnContext(HttpContext context, string challenge)
=> AddHeaderIfDoesntExist(context, new Header("WWW-Authenticate", new[] { challenge }));

private static void SetStatusCode(HttpContext context, int statusCode)
{
if (!context.Response.HasStarted)
Expand Down
4 changes: 3 additions & 1 deletion src/Ocelot/Responder/IHttpResponder.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Microsoft.AspNetCore.Http;
using Ocelot.Middleware;

namespace Ocelot.Responder
{
public interface IHttpResponder
Expand All @@ -10,5 +10,7 @@ public interface IHttpResponder
void SetErrorResponseOnContext(HttpContext context, int statusCode);

Task SetErrorResponseOnContext(HttpContext context, DownstreamResponse response);

void SetAuthChallengeOnContext(HttpContext context, string challenge);
}
}
15 changes: 11 additions & 4 deletions src/Ocelot/Responder/Middleware/ResponderMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,20 @@ private async Task SetErrorResponse(HttpContext context, List<Error> errors)
var statusCode = _codeMapper.Map(errors);
_responder.SetErrorResponseOnContext(context, statusCode);

if (errors.All(e => e.Code != OcelotErrorCode.QuotaExceededError))
if (errors.Any(e => e.Code == OcelotErrorCode.QuotaExceededError))
{
return;
var downstreamResponse = context.Items.DownstreamResponse();
await _responder.SetErrorResponseOnContext(context, downstreamResponse);
}

var downstreamResponse = context.Items.DownstreamResponse();
await _responder.SetErrorResponseOnContext(context, downstreamResponse);
if (errors.Any(e => e.Code == OcelotErrorCode.UnauthenticatedError))
{
var challenge = context.Items.AuthChallenge();
if (!string.IsNullOrEmpty(challenge))
{
_responder.SetAuthChallengeOnContext(context, challenge);
}
}
}
}
}
36 changes: 36 additions & 0 deletions test/Ocelot.AcceptanceTests/Authentication/AuthenticationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
using IdentityServer4.Models;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;
using Ocelot.DependencyInjection;

namespace Ocelot.AcceptanceTests.Authentication
{
Expand Down Expand Up @@ -112,6 +114,40 @@ public void Should_return_201_using_identity_server_reference_token()
.BDDfy();
}

[Fact]
[Trait("Feat", "1387")]
public void Should_return_www_authenticate_header_on_401()
{
var port = PortFinder.GetRandomPort();
var route = GivenDefaultAuthRoute(port);
var configuration = GivenConfiguration(route);
this.Given(x => GivenThereIsAConfiguration(configuration))
.And(x => GivenOcelotIsRunningWithJwtAuth("Test"))
.And(x => GivenIHaveNoTokenForMyRequest())
.When(x => WhenIGetUrlOnTheApiGateway("/"))
.Then(x => ThenTheStatusCodeShouldBe(HttpStatusCode.Unauthorized))
.And(x => ThenTheResponseShouldContainAuthChallenge())
.BDDfy();
}
private void GivenOcelotIsRunningWithJwtAuth(string authenticationProviderKey)
{
GivenOcelotIsRunningWithServices(WithJwtBearer);
void WithJwtBearer(IServiceCollection s)
{
s.AddAuthentication().AddJwtBearer(authenticationProviderKey, options => { });
s.AddOcelot();
}
}
private void GivenIHaveNoTokenForMyRequest()
{
_ocelotClient.DefaultRequestHeaders.Authorization = null;
}
private void ThenTheResponseShouldContainAuthChallenge()
{
_response.Headers.TryGetValues("WWW-Authenticate", out var headerValue).ShouldBeTrue();
headerValue.ShouldNotBeEmpty();
}

[IgnorePublicMethod]
public async Task GivenThereIsAnIdentityServerOn(string url, AccessTokenType tokenType)
{
Expand Down
Loading