Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow a Processing Strategy to be injected and configured via new startup extension methods #180

Merged
merged 2 commits into from
May 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions src/AspNetCoreRateLimit/AspNetCoreRateLimit.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
<PackageLicense>http://opensource.org/licenses/MIT</PackageLicense>
<RepositoryType>git</RepositoryType>
<RepositoryUrl>https://github.com/stefanprodan/AspNetCoreRateLimit</RepositoryUrl>
<LangVersion>8</LangVersion>
<LangVersion>9</LangVersion>
<Version>3.2.3</Version>
<SignAssembly>true</SignAssembly>
<AssemblyOriginatorKeyFile>../../sgKey.snk</AssemblyOriginatorKeyFile>
Expand All @@ -28,12 +28,12 @@
<PackageReference Include="Microsoft.Extensions.Options" Version="2.2.0" />
<PackageReference Include="Newtonsoft.Json" Version="12.0.3" />
</ItemGroup>

<ItemGroup Condition="$(TargetFramework) == 'netcoreapp3.1'">
<PackageReference Include="Microsoft.Extensions.Caching.Abstractions" Version="3.1.9" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="3.1.9" />
<PackageReference Include="Microsoft.Extensions.Options" Version="3.1.9" />
<PackageReference Include="Newtonsoft.Json" Version="12.0.3" />
<PackageReference Include="Microsoft.Extensions.Caching.Abstractions" Version="3.1.9" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="3.1.9" />
<PackageReference Include="Microsoft.Extensions.Options" Version="3.1.9" />
<PackageReference Include="Newtonsoft.Json" Version="12.0.3" />
</ItemGroup>

<ItemGroup Condition="$(TargetFramework) == 'net5.0'">
Expand All @@ -52,11 +52,11 @@
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.SourceLink.GitHub" Version="1.0.0" PrivateAssets="All"/>
<PackageReference Include="Microsoft.SourceLink.GitHub" Version="1.0.0" PrivateAssets="All" />
</ItemGroup>

<PropertyGroup Condition="'$(APPVEYOR)' == 'true'">
<ContinuousIntegrationBuild>true</ContinuousIntegrationBuild>
</PropertyGroup>

</Project>
20 changes: 15 additions & 5 deletions src/AspNetCoreRateLimit/Core/ClientRateLimitProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,22 @@ namespace AspNetCoreRateLimit
public class ClientRateLimitProcessor : RateLimitProcessor, IRateLimitProcessor
{
private readonly ClientRateLimitOptions _options;
private readonly IProcessingStrategy _processingStrategy;
private readonly IRateLimitStore<ClientRateLimitPolicy> _policyStore;
private readonly ICounterKeyBuilder _counterKeyBuilder;

public ClientRateLimitProcessor(
ClientRateLimitOptions options,
IRateLimitCounterStore counterStore,
IClientPolicyStore policyStore,
IRateLimitConfiguration config)
: base(options, counterStore, new ClientCounterKeyBuilder(options), config)
ClientRateLimitOptions options,
IRateLimitCounterStore counterStore,
IClientPolicyStore policyStore,
IRateLimitConfiguration config,
IProcessingStrategy processingStrategy)
: base(options)
{
_options = options;
_policyStore = policyStore;
_counterKeyBuilder = new ClientCounterKeyBuilder(options);
_processingStrategy = processingStrategy;
}

public async Task<IEnumerable<RateLimitRule>> GetMatchingRulesAsync(ClientRequestIdentity identity, CancellationToken cancellationToken = default)
Expand All @@ -26,5 +31,10 @@ public async Task<IEnumerable<RateLimitRule>> GetMatchingRulesAsync(ClientReques

return GetMatchingRules(identity, policy?.Rules);
}

public async Task<RateLimitCounter> ProcessRequestAsync(ClientRequestIdentity requestIdentity, RateLimitRule rule, CancellationToken cancellationToken = default)
{
return await _processingStrategy.ProcessRequestAsync(requestIdentity, rule, _counterKeyBuilder, _options, cancellationToken);
}
}
}
3 changes: 0 additions & 3 deletions src/AspNetCoreRateLimit/Core/IRateLimitProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,8 @@ namespace AspNetCoreRateLimit
public interface IRateLimitProcessor
{
Task<IEnumerable<RateLimitRule>> GetMatchingRulesAsync(ClientRequestIdentity identity, CancellationToken cancellationToken = default);

RateLimitHeaders GetRateLimitHeaders(RateLimitCounter? counter, RateLimitRule rule, CancellationToken cancellationToken = default);

Task<RateLimitCounter> ProcessRequestAsync(ClientRequestIdentity requestIdentity, RateLimitRule rule, CancellationToken cancellationToken = default);

bool IsWhitelisted(ClientRequestIdentity requestIdentity);
}
}
21 changes: 16 additions & 5 deletions src/AspNetCoreRateLimit/Core/IpRateLimitProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,24 @@ public class IpRateLimitProcessor : RateLimitProcessor, IRateLimitProcessor
{
private readonly IpRateLimitOptions _options;
private readonly IRateLimitStore<IpRateLimitPolicies> _policyStore;
private readonly IProcessingStrategy _processingStrategy;
private readonly ICounterKeyBuilder _counterKeyBuilder;

public IpRateLimitProcessor(
IpRateLimitOptions options,
IRateLimitCounterStore counterStore,
IIpPolicyStore policyStore,
IRateLimitConfiguration config)
: base(options, counterStore, new IpCounterKeyBuilder(options), config)
IpRateLimitOptions options,
IRateLimitCounterStore counterStore,
IIpPolicyStore policyStore,
IRateLimitConfiguration config,
IProcessingStrategy processingStrategy)
: base(options)
{
_options = options;
_policyStore = policyStore;
_counterKeyBuilder = new IpCounterKeyBuilder(options);
_processingStrategy = processingStrategy;
}


public async Task<IEnumerable<RateLimitRule>> GetMatchingRulesAsync(ClientRequestIdentity identity, CancellationToken cancellationToken = default)
{
var policies = await _policyStore.GetAsync($"{_options.IpPolicyPrefix}", cancellationToken);
Expand All @@ -40,5 +46,10 @@ public async Task<IEnumerable<RateLimitRule>> GetMatchingRulesAsync(ClientReques

return GetMatchingRules(identity, rules);
}

public async Task<RateLimitCounter> ProcessRequestAsync(ClientRequestIdentity requestIdentity, RateLimitRule rule, CancellationToken cancellationToken = default)
{
return await _processingStrategy.ProcessRequestAsync(requestIdentity, rule, _counterKeyBuilder, _options, cancellationToken);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
using System;
using System.Threading;
using System.Threading.Tasks;

namespace AspNetCoreRateLimit
{
public class AsyncKeyLockProcessingStrategy : ProcessingStrategy
{
private readonly IRateLimitCounterStore _counterStore;
private readonly IRateLimitConfiguration _config;

public AsyncKeyLockProcessingStrategy(IRateLimitCounterStore counterStore, IRateLimitConfiguration config)
: base(config)
{
_counterStore = counterStore;
_config = config;
}

/// The key-lock used for limiting requests.
private static readonly AsyncKeyLock AsyncLock = new AsyncKeyLock();

public override async Task<RateLimitCounter> ProcessRequestAsync(ClientRequestIdentity requestIdentity, RateLimitRule rule, ICounterKeyBuilder counterKeyBuilder, RateLimitOptions rateLimitOptions, CancellationToken cancellationToken = default)
{
var counter = new RateLimitCounter
{
Timestamp = DateTime.UtcNow,
Count = 1
};

var counterId = BuildCounterKey(requestIdentity, rule, counterKeyBuilder, rateLimitOptions);

// serial reads and writes on same key
using (await AsyncLock.WriterLockAsync(counterId).ConfigureAwait(false))
{
var entry = await _counterStore.GetAsync(counterId, cancellationToken);

if (entry.HasValue)
{
// entry has not expired
if (entry.Value.Timestamp + rule.PeriodTimespan.Value >= DateTime.UtcNow)
{
// increment request count
var totalCount = entry.Value.Count + _config.RateIncrementer?.Invoke() ?? 1;

// deep copy
counter = new RateLimitCounter
{
Timestamp = entry.Value.Timestamp,
Count = totalCount
};
}
}

// stores: id (string) - timestamp (datetime) - total_requests (long)
await _counterStore.SetAsync(counterId, counter, rule.PeriodTimespan.Value, cancellationToken);
}

return counter;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
using System.Threading;
using System.Threading.Tasks;

namespace AspNetCoreRateLimit
{
public interface IProcessingStrategy
{
Task<RateLimitCounter> ProcessRequestAsync(ClientRequestIdentity requestIdentity, RateLimitRule rule, ICounterKeyBuilder counterKeyBuilder, RateLimitOptions rateLimitOptions, CancellationToken cancellationToken = default);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
using System;
using System.Security.Cryptography;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace AspNetCoreRateLimit
{
public abstract class ProcessingStrategy : IProcessingStrategy
{
private readonly IRateLimitConfiguration _config;

protected ProcessingStrategy(IRateLimitConfiguration config)
{
_config = config;
}

public abstract Task<RateLimitCounter> ProcessRequestAsync(ClientRequestIdentity requestIdentity, RateLimitRule rule, ICounterKeyBuilder counterKeyBuilder, RateLimitOptions rateLimitOptions, CancellationToken cancellationToken = default);

protected virtual string BuildCounterKey(ClientRequestIdentity requestIdentity, RateLimitRule rule, ICounterKeyBuilder counterKeyBuilder, RateLimitOptions rateLimitOptions)
{
var key = counterKeyBuilder.Build(requestIdentity, rule);

if (rateLimitOptions.EnableEndpointRateLimiting && _config.EndpointCounterKeyBuilder != null)
{
key += _config.EndpointCounterKeyBuilder.Build(requestIdentity, rule);
}

var bytes = Encoding.UTF8.GetBytes(key);

using var algorithm = new SHA1Managed();
var hash = algorithm.ComputeHash(bytes);

return Convert.ToBase64String(hash);
}
}
}
72 changes: 2 additions & 70 deletions src/AspNetCoreRateLimit/Core/RateLimitProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,12 @@ namespace AspNetCoreRateLimit
public abstract class RateLimitProcessor
{
private readonly RateLimitOptions _options;
private readonly IRateLimitCounterStore _counterStore;
private readonly ICounterKeyBuilder _counterKeyBuilder;
private readonly IRateLimitConfiguration _config;

protected RateLimitProcessor(
RateLimitOptions options,
IRateLimitCounterStore counterStore,
ICounterKeyBuilder counterKeyBuilder,
IRateLimitConfiguration config)

protected RateLimitProcessor(RateLimitOptions options)
{
_options = options;
_counterStore = counterStore;
_counterKeyBuilder = counterKeyBuilder;
_config = config;
}

/// The key-lock used for limiting requests.
private static readonly AsyncKeyLock AsyncLock = new AsyncKeyLock();

public virtual bool IsWhitelisted(ClientRequestIdentity requestIdentity)
{
Expand All @@ -55,45 +43,6 @@ public virtual bool IsWhitelisted(ClientRequestIdentity requestIdentity)
return false;
}

public virtual async Task<RateLimitCounter> ProcessRequestAsync(ClientRequestIdentity requestIdentity, RateLimitRule rule, CancellationToken cancellationToken = default)
{
var counter = new RateLimitCounter
{
Timestamp = DateTime.UtcNow,
Count = 1
};

var counterId = BuildCounterKey(requestIdentity, rule);

// serial reads and writes on same key
using (await AsyncLock.WriterLockAsync(counterId).ConfigureAwait(false))
{
var entry = await _counterStore.GetAsync(counterId, cancellationToken);

if (entry.HasValue)
{
// entry has not expired
if (entry.Value.Timestamp + rule.PeriodTimespan.Value >= DateTime.UtcNow)
{
// increment request count
var totalCount = entry.Value.Count + _config.RateIncrementer?.Invoke() ?? 1;

// deep copy
counter = new RateLimitCounter
{
Timestamp = entry.Value.Timestamp,
Count = totalCount
};
}
}

// stores: id (string) - timestamp (datetime) - total_requests (long)
await _counterStore.SetAsync(counterId, counter, rule.PeriodTimespan.Value, cancellationToken);
}

return counter;
}

public virtual RateLimitHeaders GetRateLimitHeaders(RateLimitCounter? counter, RateLimitRule rule, CancellationToken cancellationToken = default)
{
var headers = new RateLimitHeaders();
Expand All @@ -119,23 +68,6 @@ public virtual RateLimitHeaders GetRateLimitHeaders(RateLimitCounter? counter, R
return headers;
}

protected virtual string BuildCounterKey(ClientRequestIdentity requestIdentity, RateLimitRule rule)
{
var key = _counterKeyBuilder.Build(requestIdentity, rule);

if (_options.EnableEndpointRateLimiting && _config.EndpointCounterKeyBuilder != null)
{
key += _config.EndpointCounterKeyBuilder.Build(requestIdentity, rule);
}

var bytes = Encoding.UTF8.GetBytes(key);

using var algorithm = new SHA1Managed();
var hash = algorithm.ComputeHash(bytes);

return Convert.ToBase64String(hash);
}

protected virtual List<RateLimitRule> GetMatchingRules(ClientRequestIdentity identity, List<RateLimitRule> rules = null)
{
var limits = new List<RateLimitRule>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@ public class ClientRateLimitMiddleware : RateLimitMiddleware<ClientRateLimitProc
private readonly ILogger<ClientRateLimitMiddleware> _logger;

public ClientRateLimitMiddleware(RequestDelegate next,
IProcessingStrategy processingStrategy,
IOptions<ClientRateLimitOptions> options,
IRateLimitCounterStore counterStore,
IClientPolicyStore policyStore,
IRateLimitConfiguration config,
ILogger<ClientRateLimitMiddleware> logger)
: base(next, options?.Value, new ClientRateLimitProcessor(options?.Value, counterStore, policyStore, config), config)
: base(next, options?.Value, new ClientRateLimitProcessor(options?.Value, counterStore, policyStore, config, processingStrategy), config)
{
_logger = logger;
}
Expand Down
7 changes: 4 additions & 3 deletions src/AspNetCoreRateLimit/Middleware/IpRateLimitMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@ public class IpRateLimitMiddleware : RateLimitMiddleware<IpRateLimitProcessor>
private readonly ILogger<IpRateLimitMiddleware> _logger;

public IpRateLimitMiddleware(RequestDelegate next,
IProcessingStrategy processingStrategy,
IOptions<IpRateLimitOptions> options,
IRateLimitCounterStore counterStore,
IIpPolicyStore policyStore,
IRateLimitConfiguration config,
ILogger<IpRateLimitMiddleware> logger)
: base(next, options?.Value, new IpRateLimitProcessor(options?.Value, counterStore, policyStore, config), config)

ILogger<IpRateLimitMiddleware> logger
)
: base(next, options?.Value, new IpRateLimitProcessor(options?.Value, counterStore, policyStore, config, processingStrategy), config)
{
_logger = logger;
}
Expand Down
17 changes: 0 additions & 17 deletions src/AspNetCoreRateLimit/Middleware/MiddlewareExtensions.cs

This file was deleted.

Loading