Skip to content

Commit

Permalink
Merge pull request #180 from rangerlabs/master
Browse files Browse the repository at this point in the history
Allow a Processing Strategy to be injected and configured via new startup extension methods
  • Loading branch information
cristipufu authored May 25, 2021
2 parents 5b21509 + a88e84a commit da8f0cd
Show file tree
Hide file tree
Showing 21 changed files with 199 additions and 117 deletions.
16 changes: 8 additions & 8 deletions src/AspNetCoreRateLimit/AspNetCoreRateLimit.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
<PackageLicense>http://opensource.org/licenses/MIT</PackageLicense>
<RepositoryType>git</RepositoryType>
<RepositoryUrl>https://github.com/stefanprodan/AspNetCoreRateLimit</RepositoryUrl>
<LangVersion>8</LangVersion>
<LangVersion>9</LangVersion>
<Version>3.2.3</Version>
<SignAssembly>true</SignAssembly>
<AssemblyOriginatorKeyFile>../../sgKey.snk</AssemblyOriginatorKeyFile>
Expand All @@ -28,12 +28,12 @@
<PackageReference Include="Microsoft.Extensions.Options" Version="2.2.0" />
<PackageReference Include="Newtonsoft.Json" Version="12.0.3" />
</ItemGroup>

<ItemGroup Condition="$(TargetFramework) == 'netcoreapp3.1'">
<PackageReference Include="Microsoft.Extensions.Caching.Abstractions" Version="3.1.9" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="3.1.9" />
<PackageReference Include="Microsoft.Extensions.Options" Version="3.1.9" />
<PackageReference Include="Newtonsoft.Json" Version="12.0.3" />
<PackageReference Include="Microsoft.Extensions.Caching.Abstractions" Version="3.1.9" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="3.1.9" />
<PackageReference Include="Microsoft.Extensions.Options" Version="3.1.9" />
<PackageReference Include="Newtonsoft.Json" Version="12.0.3" />
</ItemGroup>

<ItemGroup Condition="$(TargetFramework) == 'net5.0'">
Expand All @@ -52,11 +52,11 @@
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.SourceLink.GitHub" Version="1.0.0" PrivateAssets="All"/>
<PackageReference Include="Microsoft.SourceLink.GitHub" Version="1.0.0" PrivateAssets="All" />
</ItemGroup>

<PropertyGroup Condition="'$(APPVEYOR)' == 'true'">
<ContinuousIntegrationBuild>true</ContinuousIntegrationBuild>
</PropertyGroup>

</Project>
20 changes: 15 additions & 5 deletions src/AspNetCoreRateLimit/Core/ClientRateLimitProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,22 @@ namespace AspNetCoreRateLimit
public class ClientRateLimitProcessor : RateLimitProcessor, IRateLimitProcessor
{
private readonly ClientRateLimitOptions _options;
private readonly IProcessingStrategy _processingStrategy;
private readonly IRateLimitStore<ClientRateLimitPolicy> _policyStore;
private readonly ICounterKeyBuilder _counterKeyBuilder;

public ClientRateLimitProcessor(
ClientRateLimitOptions options,
IRateLimitCounterStore counterStore,
IClientPolicyStore policyStore,
IRateLimitConfiguration config)
: base(options, counterStore, new ClientCounterKeyBuilder(options), config)
ClientRateLimitOptions options,
IRateLimitCounterStore counterStore,
IClientPolicyStore policyStore,
IRateLimitConfiguration config,
IProcessingStrategy processingStrategy)
: base(options)
{
_options = options;
_policyStore = policyStore;
_counterKeyBuilder = new ClientCounterKeyBuilder(options);
_processingStrategy = processingStrategy;
}

public async Task<IEnumerable<RateLimitRule>> GetMatchingRulesAsync(ClientRequestIdentity identity, CancellationToken cancellationToken = default)
Expand All @@ -26,5 +31,10 @@ public async Task<IEnumerable<RateLimitRule>> GetMatchingRulesAsync(ClientReques

return GetMatchingRules(identity, policy?.Rules);
}

public async Task<RateLimitCounter> ProcessRequestAsync(ClientRequestIdentity requestIdentity, RateLimitRule rule, CancellationToken cancellationToken = default)
{
return await _processingStrategy.ProcessRequestAsync(requestIdentity, rule, _counterKeyBuilder, _options, cancellationToken);
}
}
}
3 changes: 0 additions & 3 deletions src/AspNetCoreRateLimit/Core/IRateLimitProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,8 @@ namespace AspNetCoreRateLimit
public interface IRateLimitProcessor
{
Task<IEnumerable<RateLimitRule>> GetMatchingRulesAsync(ClientRequestIdentity identity, CancellationToken cancellationToken = default);

RateLimitHeaders GetRateLimitHeaders(RateLimitCounter? counter, RateLimitRule rule, CancellationToken cancellationToken = default);

Task<RateLimitCounter> ProcessRequestAsync(ClientRequestIdentity requestIdentity, RateLimitRule rule, CancellationToken cancellationToken = default);

bool IsWhitelisted(ClientRequestIdentity requestIdentity);
}
}
21 changes: 16 additions & 5 deletions src/AspNetCoreRateLimit/Core/IpRateLimitProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,24 @@ public class IpRateLimitProcessor : RateLimitProcessor, IRateLimitProcessor
{
private readonly IpRateLimitOptions _options;
private readonly IRateLimitStore<IpRateLimitPolicies> _policyStore;
private readonly IProcessingStrategy _processingStrategy;
private readonly ICounterKeyBuilder _counterKeyBuilder;

public IpRateLimitProcessor(
IpRateLimitOptions options,
IRateLimitCounterStore counterStore,
IIpPolicyStore policyStore,
IRateLimitConfiguration config)
: base(options, counterStore, new IpCounterKeyBuilder(options), config)
IpRateLimitOptions options,
IRateLimitCounterStore counterStore,
IIpPolicyStore policyStore,
IRateLimitConfiguration config,
IProcessingStrategy processingStrategy)
: base(options)
{
_options = options;
_policyStore = policyStore;
_counterKeyBuilder = new IpCounterKeyBuilder(options);
_processingStrategy = processingStrategy;
}


public async Task<IEnumerable<RateLimitRule>> GetMatchingRulesAsync(ClientRequestIdentity identity, CancellationToken cancellationToken = default)
{
var policies = await _policyStore.GetAsync($"{_options.IpPolicyPrefix}", cancellationToken);
Expand All @@ -40,5 +46,10 @@ public async Task<IEnumerable<RateLimitRule>> GetMatchingRulesAsync(ClientReques

return GetMatchingRules(identity, rules);
}

public async Task<RateLimitCounter> ProcessRequestAsync(ClientRequestIdentity requestIdentity, RateLimitRule rule, CancellationToken cancellationToken = default)
{
return await _processingStrategy.ProcessRequestAsync(requestIdentity, rule, _counterKeyBuilder, _options, cancellationToken);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
using System;
using System.Threading;
using System.Threading.Tasks;

namespace AspNetCoreRateLimit
{
public class AsyncKeyLockProcessingStrategy : ProcessingStrategy
{
private readonly IRateLimitCounterStore _counterStore;
private readonly IRateLimitConfiguration _config;

public AsyncKeyLockProcessingStrategy(IRateLimitCounterStore counterStore, IRateLimitConfiguration config)
: base(config)
{
_counterStore = counterStore;
_config = config;
}

/// The key-lock used for limiting requests.
private static readonly AsyncKeyLock AsyncLock = new AsyncKeyLock();

public override async Task<RateLimitCounter> ProcessRequestAsync(ClientRequestIdentity requestIdentity, RateLimitRule rule, ICounterKeyBuilder counterKeyBuilder, RateLimitOptions rateLimitOptions, CancellationToken cancellationToken = default)
{
var counter = new RateLimitCounter
{
Timestamp = DateTime.UtcNow,
Count = 1
};

var counterId = BuildCounterKey(requestIdentity, rule, counterKeyBuilder, rateLimitOptions);

// serial reads and writes on same key
using (await AsyncLock.WriterLockAsync(counterId).ConfigureAwait(false))
{
var entry = await _counterStore.GetAsync(counterId, cancellationToken);

if (entry.HasValue)
{
// entry has not expired
if (entry.Value.Timestamp + rule.PeriodTimespan.Value >= DateTime.UtcNow)
{
// increment request count
var totalCount = entry.Value.Count + _config.RateIncrementer?.Invoke() ?? 1;

// deep copy
counter = new RateLimitCounter
{
Timestamp = entry.Value.Timestamp,
Count = totalCount
};
}
}

// stores: id (string) - timestamp (datetime) - total_requests (long)
await _counterStore.SetAsync(counterId, counter, rule.PeriodTimespan.Value, cancellationToken);
}

return counter;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
using System.Threading;
using System.Threading.Tasks;

namespace AspNetCoreRateLimit
{
public interface IProcessingStrategy
{
Task<RateLimitCounter> ProcessRequestAsync(ClientRequestIdentity requestIdentity, RateLimitRule rule, ICounterKeyBuilder counterKeyBuilder, RateLimitOptions rateLimitOptions, CancellationToken cancellationToken = default);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
using System;
using System.Security.Cryptography;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace AspNetCoreRateLimit
{
public abstract class ProcessingStrategy : IProcessingStrategy
{
private readonly IRateLimitConfiguration _config;

protected ProcessingStrategy(IRateLimitConfiguration config)
{
_config = config;
}

public abstract Task<RateLimitCounter> ProcessRequestAsync(ClientRequestIdentity requestIdentity, RateLimitRule rule, ICounterKeyBuilder counterKeyBuilder, RateLimitOptions rateLimitOptions, CancellationToken cancellationToken = default);

protected virtual string BuildCounterKey(ClientRequestIdentity requestIdentity, RateLimitRule rule, ICounterKeyBuilder counterKeyBuilder, RateLimitOptions rateLimitOptions)
{
var key = counterKeyBuilder.Build(requestIdentity, rule);

if (rateLimitOptions.EnableEndpointRateLimiting && _config.EndpointCounterKeyBuilder != null)
{
key += _config.EndpointCounterKeyBuilder.Build(requestIdentity, rule);
}

var bytes = Encoding.UTF8.GetBytes(key);

using var algorithm = new SHA1Managed();
var hash = algorithm.ComputeHash(bytes);

return Convert.ToBase64String(hash);
}
}
}
72 changes: 2 additions & 70 deletions src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,12 @@ namespace AspNetCoreRateLimit
public abstract class RateLimitProcessor
{
private readonly RateLimitOptions _options;
private readonly IRateLimitCounterStore _counterStore;
private readonly ICounterKeyBuilder _counterKeyBuilder;
private readonly IRateLimitConfiguration _config;

protected RateLimitProcessor(
RateLimitOptions options,
IRateLimitCounterStore counterStore,
ICounterKeyBuilder counterKeyBuilder,
IRateLimitConfiguration config)

protected RateLimitProcessor(RateLimitOptions options)
{
_options = options;
_counterStore = counterStore;
_counterKeyBuilder = counterKeyBuilder;
_config = config;
}

/// The key-lock used for limiting requests.
private static readonly AsyncKeyLock AsyncLock = new AsyncKeyLock();

public virtual bool IsWhitelisted(ClientRequestIdentity requestIdentity)
{
Expand All @@ -55,45 +43,6 @@ public virtual bool IsWhitelisted(ClientRequestIdentity requestIdentity)
return false;
}

public virtual async Task<RateLimitCounter> ProcessRequestAsync(ClientRequestIdentity requestIdentity, RateLimitRule rule, CancellationToken cancellationToken = default)
{
var counter = new RateLimitCounter
{
Timestamp = DateTime.UtcNow,
Count = 1
};

var counterId = BuildCounterKey(requestIdentity, rule);

// serial reads and writes on same key
using (await AsyncLock.WriterLockAsync(counterId).ConfigureAwait(false))
{
var entry = await _counterStore.GetAsync(counterId, cancellationToken);

if (entry.HasValue)
{
// entry has not expired
if (entry.Value.Timestamp + rule.PeriodTimespan.Value >= DateTime.UtcNow)
{
// increment request count
var totalCount = entry.Value.Count + _config.RateIncrementer?.Invoke() ?? 1;

// deep copy
counter = new RateLimitCounter
{
Timestamp = entry.Value.Timestamp,
Count = totalCount
};
}
}

// stores: id (string) - timestamp (datetime) - total_requests (long)
await _counterStore.SetAsync(counterId, counter, rule.PeriodTimespan.Value, cancellationToken);
}

return counter;
}

public virtual RateLimitHeaders GetRateLimitHeaders(RateLimitCounter? counter, RateLimitRule rule, CancellationToken cancellationToken = default)
{
var headers = new RateLimitHeaders();
Expand All @@ -119,23 +68,6 @@ public virtual RateLimitHeaders GetRateLimitHeaders(RateLimitCounter? counter, R
return headers;
}

protected virtual string BuildCounterKey(ClientRequestIdentity requestIdentity, RateLimitRule rule)
{
var key = _counterKeyBuilder.Build(requestIdentity, rule);

if (_options.EnableEndpointRateLimiting && _config.EndpointCounterKeyBuilder != null)
{
key += _config.EndpointCounterKeyBuilder.Build(requestIdentity, rule);
}

var bytes = Encoding.UTF8.GetBytes(key);

using var algorithm = new SHA1Managed();
var hash = algorithm.ComputeHash(bytes);

return Convert.ToBase64String(hash);
}

protected virtual List<RateLimitRule> GetMatchingRules(ClientRequestIdentity identity, List<RateLimitRule> rules = null)
{
var limits = new List<RateLimitRule>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@ public class ClientRateLimitMiddleware : RateLimitMiddleware<ClientRateLimitProc
private readonly ILogger<ClientRateLimitMiddleware> _logger;

public ClientRateLimitMiddleware(RequestDelegate next,
IProcessingStrategy processingStrategy,
IOptions<ClientRateLimitOptions> options,
IRateLimitCounterStore counterStore,
IClientPolicyStore policyStore,
IRateLimitConfiguration config,
ILogger<ClientRateLimitMiddleware> logger)
: base(next, options?.Value, new ClientRateLimitProcessor(options?.Value, counterStore, policyStore, config), config)
: base(next, options?.Value, new ClientRateLimitProcessor(options?.Value, counterStore, policyStore, config, processingStrategy), config)
{
_logger = logger;
}
Expand Down
7 changes: 4 additions & 3 deletions src/AspNetCoreRateLimit/Middleware/IpRateLimitMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@ public class IpRateLimitMiddleware : RateLimitMiddleware<IpRateLimitProcessor>
private readonly ILogger<IpRateLimitMiddleware> _logger;

public IpRateLimitMiddleware(RequestDelegate next,
IProcessingStrategy processingStrategy,
IOptions<IpRateLimitOptions> options,
IRateLimitCounterStore counterStore,
IIpPolicyStore policyStore,
IRateLimitConfiguration config,
ILogger<IpRateLimitMiddleware> logger)
: base(next, options?.Value, new IpRateLimitProcessor(options?.Value, counterStore, policyStore, config), config)

ILogger<IpRateLimitMiddleware> logger
)
: base(next, options?.Value, new IpRateLimitProcessor(options?.Value, counterStore, policyStore, config, processingStrategy), config)
{
_logger = logger;
}
Expand Down
17 changes: 0 additions & 17 deletions src/AspNetCoreRateLimit/Middleware/MiddlewareExtensions.cs

This file was deleted.

Loading

0 comments on commit da8f0cd

Please sign in to comment.