Skip to content

Commit

Permalink
#1305 Populate RateLimiting headers in the original HttpContext res…
Browse files Browse the repository at this point in the history
…ponse accessed via `IHttpContextAccessor` (#1307)

* set rate limiting headers on the proper httpcontext

* fix Retry-After header

* merge fix

* refactor of ClientRateLimitTests

* merge fix

* Fix build after rebasing

* EOL: test/Ocelot.AcceptanceTests/Steps.cs

* Add `RateLimitingSteps`

* code review by @raman-m

* Inject IHttpContextAccessor, not IServiceProvider

* Ocelot's rate-limiting headers have become legacy

* Headers definition life hack

* A good StackOverflow link

---------

Co-authored-by: Jolanta Łukawska <jolanta.lukawska@outlook.com>
Co-authored-by: Raman Maksimchuk <dotnet044@gmail.com>
  • Loading branch information
3 people authored Nov 9, 2024
1 parent d310508 commit da9d6fa
Show file tree
Hide file tree
Showing 8 changed files with 135 additions and 42 deletions.
6 changes: 3 additions & 3 deletions src/Ocelot/Configuration/RateLimitOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public RateLimitOptions(bool enableRateLimiting, string clientIdHeader, Func<Lis
/// Gets the list of white listed clients.
/// </summary>
/// <value>
/// A <see cref="List{T}"/> collection with white listed clients.
/// A <see cref="List{T}"/> (where T is <see cref="string"/>) collection with white listed clients.
/// </value>
public List<string> ClientWhitelist => _getClientWhitelist();

Expand Down Expand Up @@ -80,10 +80,10 @@ public RateLimitOptions(bool enableRateLimiting, string clientIdHeader, Func<Lis
public bool EnableRateLimiting { get; }

/// <summary>
/// Disables X-Rate-Limit and Rety-After headers.
/// Disables <c>X-Rate-Limit</c> and <c>Retry-After</c> headers.
/// </summary>
/// <value>
/// A boolean value for disabling X-Rate-Limit and Rety-After headers.
/// A boolean value for disabling <c>X-Rate-Limit</c> and <c>Retry-After</c> headers.
/// </value>
public bool DisableRateLimitHeaders { get; }
}
Expand Down
55 changes: 30 additions & 25 deletions src/Ocelot/RateLimiting/Middleware/RateLimitingMiddleware.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Microsoft.AspNetCore.Http;
using Microsoft.Net.Http.Headers;
using Ocelot.Configuration;
using Ocelot.Logging;
using Ocelot.Middleware;
Expand All @@ -10,15 +11,18 @@ public class RateLimitingMiddleware : OcelotMiddleware
{
private readonly RequestDelegate _next;
private readonly IRateLimiting _limiter;
private readonly IHttpContextAccessor _contextAccessor;

public RateLimitingMiddleware(
RequestDelegate next,
IOcelotLoggerFactory factory,
IRateLimiting limiter)
IRateLimiting limiter,
IHttpContextAccessor contextAccessor)
: base(factory.CreateLogger<RateLimitingMiddleware>())
{
_next = next;
_limiter = limiter;
_contextAccessor = contextAccessor;
}

public async Task Invoke(HttpContext httpContext)
Expand Down Expand Up @@ -68,11 +72,15 @@ public async Task Invoke(HttpContext httpContext)
}
}

//set X-Rate-Limit headers for the longest period
// Set X-Rate-Limit headers for the longest period
if (!options.DisableRateLimitHeaders)
{
var headers = _limiter.GetHeaders(httpContext, identity, options);
httpContext.Response.OnStarting(SetRateLimitHeaders, state: headers);
var originalContext = _contextAccessor?.HttpContext;
if (originalContext != null)
{
var headers = _limiter.GetHeaders(originalContext, identity, options);
originalContext.Response.OnStarting(SetRateLimitHeaders, state: headers);
}
}

await _next.Invoke(httpContext);
Expand All @@ -93,15 +101,8 @@ public virtual ClientRequestIdentity SetIdentity(HttpContext httpContext, RateLi
);
}

public bool IsWhitelisted(ClientRequestIdentity requestIdentity, RateLimitOptions option)
{
if (option.ClientWhitelist.Contains(requestIdentity.ClientId))
{
return true;
}

return false;
}
public static bool IsWhitelisted(ClientRequestIdentity requestIdentity, RateLimitOptions option)
=> option.ClientWhitelist.Contains(requestIdentity.ClientId);

public virtual void LogBlockedRequest(HttpContext httpContext, ClientRequestIdentity identity, RateLimitCounter counter, RateLimitRule rule, DownstreamRoute downstreamRoute)
{
Expand All @@ -112,14 +113,15 @@ public virtual void LogBlockedRequest(HttpContext httpContext, ClientRequestIden
public virtual DownstreamResponse ReturnQuotaExceededResponse(HttpContext httpContext, RateLimitOptions option, string retryAfter)
{
var message = GetResponseMessage(option);

var http = new HttpResponseMessage((HttpStatusCode)option.HttpStatusCode);

http.Content = new StringContent(message);
var http = new HttpResponseMessage((HttpStatusCode)option.HttpStatusCode)
{
Content = new StringContent(message),
};

if (!option.DisableRateLimitHeaders)
{
http.Headers.TryAddWithoutValidation("Retry-After", retryAfter); // in seconds, not date string
http.Headers.TryAddWithoutValidation(HeaderNames.RetryAfter, retryAfter); // in seconds, not date string
httpContext.Response.Headers[HeaderNames.RetryAfter] = retryAfter;
}

return new DownstreamResponse(http);
Expand All @@ -133,14 +135,17 @@ private static string GetResponseMessage(RateLimitOptions option)
return message;
}

private static Task SetRateLimitHeaders(object rateLimitHeaders)
/// <summary>TODO: Produced Ocelot's headers don't follow industry standards.</summary>
/// <remarks>More details in <see cref="RateLimitingHeaders"/> docs.</remarks>
/// <param name="state">Captured state as a <see cref="RateLimitHeaders"/> object.</param>
/// <returns>The <see cref="Task.CompletedTask"/> object.</returns>
private static Task SetRateLimitHeaders(object state)
{
var headers = (RateLimitHeaders)rateLimitHeaders;

headers.Context.Response.Headers["X-Rate-Limit-Limit"] = headers.Limit;
headers.Context.Response.Headers["X-Rate-Limit-Remaining"] = headers.Remaining;
headers.Context.Response.Headers["X-Rate-Limit-Reset"] = headers.Reset;

var limitHeaders = (RateLimitHeaders)state;
var headers = limitHeaders.Context.Response.Headers;
headers[RateLimitingHeaders.X_Rate_Limit_Limit] = limitHeaders.Limit;
headers[RateLimitingHeaders.X_Rate_Limit_Remaining] = limitHeaders.Remaining;
headers[RateLimitingHeaders.X_Rate_Limit_Reset] = limitHeaders.Reset;
return Task.CompletedTask;
}
}
Expand Down
32 changes: 32 additions & 0 deletions src/Ocelot/RateLimiting/RateLimitingHeaders.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
using Microsoft.Net.Http.Headers;

namespace Ocelot.RateLimiting;

/// <summary>
/// TODO These Ocelot's RateLimiting headers don't follow industry standards, see links.
/// </summary>
/// <remarks>Links:
/// <list type="bullet">
/// <item>GitHub: <see href="https://github.com/ioggstream/draft-polli-ratelimit-headers">draft-polli-ratelimit-headers</see></item>
/// <item>GitHub: <see href="https://github.com/ietf-wg-httpapi/ratelimit-headers">ratelimit-headers</see></item>
/// <item>GitHub Wiki: <see href="https://ietf-wg-httpapi.github.io/ratelimit-headers/draft-ietf-httpapi-ratelimit-headers.html">RateLimit header fields for HTTP</see></item>
/// <item>StackOverflow: <see href="https://stackoverflow.com/questions/16022624/examples-of-http-api-rate-limiting-http-response-headers">Examples of HTTP API Rate Limiting HTTP Response headers</see></item>
/// </list>
/// </remarks>
public static class RateLimitingHeaders
{
public const char Dash = '-';
public const char Underscore = '_';

/// <summary>Gets the <c>Retry-After</c> HTTP header name.</summary>
public static readonly string Retry_After = HeaderNames.RetryAfter;

/// <summary>Gets the <c>X-Rate-Limit-Limit</c> Ocelot's header name.</summary>
public static readonly string X_Rate_Limit_Limit = nameof(X_Rate_Limit_Limit).Replace(Underscore, Dash);

/// <summary>Gets the <c>X-Rate-Limit-Remaining</c> Ocelot's header name.</summary>
public static readonly string X_Rate_Limit_Remaining = nameof(X_Rate_Limit_Remaining).Replace(Underscore, Dash);

/// <summary>Gets the <c>X-Rate-Limit-Reset</c> Ocelot's header name.</summary>
public static readonly string X_Rate_Limit_Reset = nameof(X_Rate_Limit_Reset).Replace(Underscore, Dash);
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
using Microsoft.AspNetCore.Http;
using Microsoft.Net.Http.Headers;
using Ocelot.Configuration.File;
using Ocelot.RateLimiting;

namespace Ocelot.AcceptanceTests.RateLimiting;

public sealed class ClientRateLimitingTests : Steps, IDisposable
public sealed class ClientRateLimitingTests : RateLimitingSteps, IDisposable
{
const int OK = (int)HttpStatusCode.OK;
const int TooManyRequests = (int)HttpStatusCode.TooManyRequests;
Expand Down Expand Up @@ -129,6 +131,53 @@ public void StatusShouldNotBeEqualTo429_PeriodTimespanValueIsGreaterThanPeriod()
.And(x => ThenTheResponseBodyShouldBe("101")) // total 101 OK responses
.BDDfy();
}

[Theory]
[Trait("Bug", "1305")]
[InlineData(false)]
[InlineData(true)]
public void Should_set_ratelimiting_headers_on_response_when_DisableRateLimitHeaders_set_to(bool disableRateLimitHeaders)
{
int port = PortFinder.GetRandomPort();
var configuration = CreateConfigurationForCheckingHeaders(port, disableRateLimitHeaders);
bool exist = !disableRateLimitHeaders;
this.Given(x => x.GivenThereIsAServiceRunningOn(DownstreamUrl(port), "/api/ClientRateLimit"))
.And(x => GivenThereIsAConfiguration(configuration))
.And(x => GivenOcelotIsRunning())
.When(x => WhenIGetUrlOnTheApiGatewayMultipleTimesForRateLimit("/api/ClientRateLimit", 1))
.Then(x => ThenRateLimitingHeadersExistInResponse(exist))
.And(x => ThenRetryAfterHeaderExistsInResponse(false))
.When(x => WhenIGetUrlOnTheApiGatewayMultipleTimesForRateLimit("/api/ClientRateLimit", 2))
.Then(x => ThenRateLimitingHeadersExistInResponse(exist))
.And(x => ThenRetryAfterHeaderExistsInResponse(false))
.When(x => WhenIGetUrlOnTheApiGatewayMultipleTimesForRateLimit("/api/ClientRateLimit", 1))
.Then(x => ThenRateLimitingHeadersExistInResponse(false))
.And(x => ThenRetryAfterHeaderExistsInResponse(exist))
.BDDfy();
}

private FileConfiguration CreateConfigurationForCheckingHeaders(int port, bool disableRateLimitHeaders)
{
var route = GivenRoute(port, null, null, new(), 3, "100s", 1000.0D);
var config = GivenConfiguration(route);
config.GlobalConfiguration.RateLimitOptions = new FileRateLimitOptions()
{
DisableRateLimitHeaders = disableRateLimitHeaders,
QuotaExceededMessage = "",
HttpStatusCode = TooManyRequests,
};
return config;
}

private void ThenRateLimitingHeadersExistInResponse(bool headersExist)
{
_response.Headers.Contains(RateLimitingHeaders.X_Rate_Limit_Limit).ShouldBe(headersExist);
_response.Headers.Contains(RateLimitingHeaders.X_Rate_Limit_Remaining).ShouldBe(headersExist);
_response.Headers.Contains(RateLimitingHeaders.X_Rate_Limit_Reset).ShouldBe(headersExist);
}

private void ThenRetryAfterHeaderExistsInResponse(bool headersExist)
=> _response.Headers.Contains(HeaderNames.RetryAfter).ShouldBe(headersExist);

private void GivenThereIsAServiceRunningOn(string baseUrl, string basePath)
{
Expand Down
15 changes: 15 additions & 0 deletions test/Ocelot.AcceptanceTests/RateLimiting/RateLimitingSteps.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
namespace Ocelot.AcceptanceTests.RateLimiting;

public class RateLimitingSteps : Steps
{
public async Task WhenIGetUrlOnTheApiGatewayMultipleTimesForRateLimit(string url, int times)
{
for (var i = 0; i < times; i++)
{
const string clientId = "ocelotclient1";
var request = new HttpRequestMessage(new HttpMethod("GET"), url);
request.Headers.Add("ClientId", clientId);
_response = await _ocelotClient.SendAsync(request);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using Microsoft.AspNetCore.TestHost;
using Microsoft.Extensions.Configuration;
using Newtonsoft.Json;
using Ocelot.AcceptanceTests.RateLimiting;
using Ocelot.Cache;
using Ocelot.Configuration.File;
using Ocelot.DependencyInjection;
Expand All @@ -14,7 +15,7 @@

namespace Ocelot.AcceptanceTests.ServiceDiscovery
{
public sealed class ConsulConfigurationInConsulTests : Steps, IDisposable
public sealed class ConsulConfigurationInConsulTests : RateLimitingSteps, IDisposable
{
private IWebHost _builder;
private IWebHost _fakeConsulBuilder;
Expand Down
11 changes: 0 additions & 11 deletions test/Ocelot.AcceptanceTests/Steps.cs
Original file line number Diff line number Diff line change
Expand Up @@ -733,17 +733,6 @@ public static async Task WhenIDoActionMultipleTimes(int times, Func<int, Task> a
await action.Invoke(i);
}

public async Task WhenIGetUrlOnTheApiGatewayMultipleTimesForRateLimit(string url, int times)
{
for (var i = 0; i < times; i++)
{
const string clientId = "ocelotclient1";
var request = new HttpRequestMessage(new HttpMethod("GET"), url);
request.Headers.Add("ClientId", clientId);
_response = await _ocelotClient.SendAsync(request);
}
}

public async Task WhenIGetUrlOnTheApiGateway(string url, string requestId)
{
_ocelotClient.DefaultRequestHeaders.TryAddWithoutValidation(RequestIdKey, requestId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ public class RateLimitingMiddlewareTests : UnitTest
private readonly IRateLimitStorage _storage;
private readonly Mock<IOcelotLoggerFactory> _loggerFactory;
private readonly Mock<IOcelotLogger> _logger;
private readonly Mock<IHttpContextAccessor> _contextAccessor;
private readonly RateLimitingMiddleware _middleware;
private readonly RequestDelegate _next;
private readonly IRateLimiting _rateLimiting;
Expand All @@ -34,7 +35,8 @@ public RateLimitingMiddlewareTests()
_loggerFactory.Setup(x => x.CreateLogger<RateLimitingMiddleware>()).Returns(_logger.Object);
_next = context => Task.CompletedTask;
_rateLimiting = new _RateLimiting_(_storage);
_middleware = new RateLimitingMiddleware(_next, _loggerFactory.Object, _rateLimiting);
_contextAccessor = new Mock<IHttpContextAccessor>();
_middleware = new RateLimitingMiddleware(_next, _loggerFactory.Object, _rateLimiting, _contextAccessor.Object);
_downstreamResponses = new();
}

Expand Down

0 comments on commit da9d6fa

Please sign in to comment.