Skip to content

Commit

Permalink
Add StackExchangeRedis atomic Lua script support and refactor to supp…
Browse files Browse the repository at this point in the history
…ort differing processing strategies
  • Loading branch information
nick-cromwell committed Dec 17, 2020
1 parent 5b21509 commit 60e6ebc
Show file tree
Hide file tree
Showing 27 changed files with 419 additions and 112 deletions.
19 changes: 11 additions & 8 deletions src/AspNetCoreRateLimit/AspNetCoreRateLimit.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
<PackageLicense>http://opensource.org/licenses/MIT</PackageLicense>
<RepositoryType>git</RepositoryType>
<RepositoryUrl>https://github.com/stefanprodan/AspNetCoreRateLimit</RepositoryUrl>
<LangVersion>8</LangVersion>
<LangVersion>9</LangVersion>
<Version>3.2.3</Version>
<SignAssembly>true</SignAssembly>
<AssemblyOriginatorKeyFile>../../sgKey.snk</AssemblyOriginatorKeyFile>
Expand All @@ -27,20 +27,23 @@
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="2.2.0" />
<PackageReference Include="Microsoft.Extensions.Options" Version="2.2.0" />
<PackageReference Include="Newtonsoft.Json" Version="12.0.3" />
<PackageReference Include="StackExchange.Redis" Version="2.2.4" />
</ItemGroup>

<ItemGroup Condition="$(TargetFramework) == 'netcoreapp3.1'">
<PackageReference Include="Microsoft.Extensions.Caching.Abstractions" Version="3.1.9" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="3.1.9" />
<PackageReference Include="Microsoft.Extensions.Options" Version="3.1.9" />
<PackageReference Include="Newtonsoft.Json" Version="12.0.3" />
<PackageReference Include="Microsoft.Extensions.Caching.Abstractions" Version="3.1.9" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="3.1.9" />
<PackageReference Include="Microsoft.Extensions.Options" Version="3.1.9" />
<PackageReference Include="Newtonsoft.Json" Version="12.0.3" />
<PackageReference Include="StackExchange.Redis" Version="2.2.4" />
</ItemGroup>

<ItemGroup Condition="$(TargetFramework) == 'net5.0'">
<PackageReference Include="Microsoft.Extensions.Caching.Abstractions" Version="5.0.0" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="5.0.0" />
<PackageReference Include="Microsoft.Extensions.Options" Version="5.0.0" />
<PackageReference Include="Newtonsoft.Json" Version="12.0.3" />
<PackageReference Include="StackExchange.Redis" Version="2.2.4" />
</ItemGroup>

<ItemGroup Condition="$(TargetFramework) == 'netcoreapp3.1'">
Expand All @@ -52,11 +55,11 @@
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.SourceLink.GitHub" Version="1.0.0" PrivateAssets="All"/>
<PackageReference Include="Microsoft.SourceLink.GitHub" Version="1.0.0" PrivateAssets="All" />
</ItemGroup>

<PropertyGroup Condition="'$(APPVEYOR)' == 'true'">
<ContinuousIntegrationBuild>true</ContinuousIntegrationBuild>
</PropertyGroup>

</Project>
18 changes: 13 additions & 5 deletions src/AspNetCoreRateLimit/Core/ClientRateLimitProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,20 @@ namespace AspNetCoreRateLimit
public class ClientRateLimitProcessor : RateLimitProcessor, IRateLimitProcessor
{
private readonly ClientRateLimitOptions _options;
private readonly IProcessingStrategy _processingStrategy;
private readonly IRateLimitStore<ClientRateLimitPolicy> _policyStore;

public ClientRateLimitProcessor(
ClientRateLimitOptions options,
IRateLimitCounterStore counterStore,
IClientPolicyStore policyStore,
IRateLimitConfiguration config)
: base(options, counterStore, new ClientCounterKeyBuilder(options), config)
IProcessingStrategyFactory processingStrategyFactory,
ClientRateLimitOptions options,
IRateLimitCounterStore counterStore,
IClientPolicyStore policyStore,
IRateLimitConfiguration config)
: base(options)
{
_options = options;
_policyStore = policyStore;
_processingStrategy = processingStrategyFactory.CreateProcessingStrategy(counterStore, new ClientCounterKeyBuilder(options), config, options);
}

public async Task<IEnumerable<RateLimitRule>> GetMatchingRulesAsync(ClientRequestIdentity identity, CancellationToken cancellationToken = default)
Expand All @@ -26,5 +29,10 @@ public async Task<IEnumerable<RateLimitRule>> GetMatchingRulesAsync(ClientReques

return GetMatchingRules(identity, policy?.Rules);
}

public async Task<RateLimitCounter> ProcessRequestAsync(ClientRequestIdentity requestIdentity, RateLimitRule rule, CancellationToken cancellationToken = default)
{
return await _processingStrategy.ProcessRequestAsync(requestIdentity, rule, cancellationToken);
}
}
}
19 changes: 14 additions & 5 deletions src/AspNetCoreRateLimit/Core/IpRateLimitProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,22 @@ public class IpRateLimitProcessor : RateLimitProcessor, IRateLimitProcessor
{
private readonly IpRateLimitOptions _options;
private readonly IRateLimitStore<IpRateLimitPolicies> _policyStore;
private readonly IProcessingStrategy _processingStrategy;

public IpRateLimitProcessor(
IpRateLimitOptions options,
IRateLimitCounterStore counterStore,
IIpPolicyStore policyStore,
IRateLimitConfiguration config)
: base(options, counterStore, new IpCounterKeyBuilder(options), config)
IProcessingStrategyFactory processingStrategyFactory,
IpRateLimitOptions options,
IRateLimitCounterStore counterStore,
IIpPolicyStore policyStore,
IRateLimitConfiguration config)
: base(options)
{
_options = options;
_policyStore = policyStore;
_processingStrategy = processingStrategyFactory.CreateProcessingStrategy(counterStore, new IpCounterKeyBuilder(options), config, options);
}


public async Task<IEnumerable<RateLimitRule>> GetMatchingRulesAsync(ClientRequestIdentity identity, CancellationToken cancellationToken = default)
{
var policies = await _policyStore.GetAsync($"{_options.IpPolicyPrefix}", cancellationToken);
Expand All @@ -40,5 +44,10 @@ public async Task<IEnumerable<RateLimitRule>> GetMatchingRulesAsync(ClientReques

return GetMatchingRules(identity, rules);
}

public async Task<RateLimitCounter> ProcessRequestAsync(ClientRequestIdentity requestIdentity, RateLimitRule rule, CancellationToken cancellationToken = default)
{
return await _processingStrategy.ProcessRequestAsync(requestIdentity, rule, cancellationToken);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
using System;
using System.Threading;
using System.Threading.Tasks;

namespace AspNetCoreRateLimit
{
public class AsyncKeyLockProcessingStrategy : ProcessingStrategy
{
private readonly RateLimitOptions _options;
private readonly IRateLimitCounterStore _counterStore;
private readonly ICounterKeyBuilder _counterKeyBuilder;
private readonly IRateLimitConfiguration _config;

public AsyncKeyLockProcessingStrategy(IRateLimitCounterStore counterStore, ICounterKeyBuilder counterKeyBuilder, IRateLimitConfiguration config, RateLimitOptions options)
: base(counterKeyBuilder, config, options)
{
_counterStore = counterStore;
_counterKeyBuilder = counterKeyBuilder;
_config = config;
_options = options;
}


/// The key-lock used for limiting requests.
private static readonly AsyncKeyLock AsyncLock = new AsyncKeyLock();

public override async Task<RateLimitCounter> ProcessRequestAsync(ClientRequestIdentity requestIdentity, RateLimitRule rule, CancellationToken cancellationToken = default)
{
var counter = new RateLimitCounter
{
Timestamp = DateTime.UtcNow,
Count = 1
};

var counterId = BuildCounterKey(requestIdentity, rule);

// serial reads and writes on same key
using (await AsyncLock.WriterLockAsync(counterId).ConfigureAwait(false))
{
var entry = await _counterStore.GetAsync(counterId, cancellationToken);

if (entry.HasValue)
{
// entry has not expired
if (entry.Value.Timestamp + rule.PeriodTimespan.Value >= DateTime.UtcNow)
{
// increment request count
var totalCount = entry.Value.Count + _config.RateIncrementer?.Invoke() ?? 1;

// deep copy
counter = new RateLimitCounter
{
Timestamp = entry.Value.Timestamp,
Count = totalCount
};
}
}

// stores: id (string) - timestamp (datetime) - total_requests (long)
await _counterStore.SetAsync(counterId, counter, rule.PeriodTimespan.Value, cancellationToken);
}

return counter;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
using System.Threading;
using System.Threading.Tasks;

namespace AspNetCoreRateLimit
{
public interface IProcessingStrategy
{
Task<RateLimitCounter> ProcessRequestAsync(ClientRequestIdentity requestIdentity, RateLimitRule rule, CancellationToken cancellationToken = default);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
namespace AspNetCoreRateLimit
{
public interface IProcessingStrategyFactory
{
ProcessingStrategy CreateProcessingStrategy(IRateLimitCounterStore counterStore, ICounterKeyBuilder counterKeyBuilder, IRateLimitConfiguration config, RateLimitOptions options);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
using System;
using System.Security.Cryptography;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace AspNetCoreRateLimit
{
public abstract class ProcessingStrategy : IProcessingStrategy
{
private readonly RateLimitOptions _options;
private readonly ICounterKeyBuilder _counterKeyBuilder;
private readonly IRateLimitConfiguration _config;

public ProcessingStrategy(ICounterKeyBuilder counterKeyBuilder, IRateLimitConfiguration config, RateLimitOptions options)
: base()
{
_counterKeyBuilder = counterKeyBuilder;
_config = config;
_options = options;
}

protected virtual string BuildCounterKey(ClientRequestIdentity requestIdentity, RateLimitRule rule)
{
var key = _counterKeyBuilder.Build(requestIdentity, rule);

if (_options.EnableEndpointRateLimiting && _config.EndpointCounterKeyBuilder != null)
{
key += _config.EndpointCounterKeyBuilder.Build(requestIdentity, rule);
}

var bytes = Encoding.UTF8.GetBytes(key);

using var algorithm = new SHA1Managed();
var hash = algorithm.ComputeHash(bytes);

return Convert.ToBase64String(hash);
}

public abstract Task<RateLimitCounter> ProcessRequestAsync(ClientRequestIdentity requestIdentity, RateLimitRule rule, CancellationToken cancellationToken = default);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
using System;
using StackExchange.Redis;

namespace AspNetCoreRateLimit
{

public class ProcessingStrategyFactory : IProcessingStrategyFactory
{
private readonly IConnectionMultiplexer _connectionMultiplexer;

public ProcessingStrategyFactory(IConnectionMultiplexer connectionMultiplexer = null)
{
_connectionMultiplexer = connectionMultiplexer;
}

public ProcessingStrategy CreateProcessingStrategy(IRateLimitCounterStore counterStore, ICounterKeyBuilder counterKeyBuilder, IRateLimitConfiguration config, RateLimitOptions options)
{
return counterStore switch
{
MemoryCacheRateLimitCounterStore => new AsyncKeyLockProcessingStrategy(counterStore, counterKeyBuilder, config, options),
DistributedCacheRateLimitCounterStore => new AsyncKeyLockProcessingStrategy(counterStore, counterKeyBuilder, config, options),
StackExchangeRedisRateLimitCounterStore => new StackExchangeRedisProcessingStrategy(_connectionMultiplexer, counterStore, counterKeyBuilder, config, options),
_ => throw new ArgumentException("Unsupported instance of IRateLimitCounterStore provided")
};
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using StackExchange.Redis;

namespace AspNetCoreRateLimit
{
public class StackExchangeRedisProcessingStrategy : ProcessingStrategy
{
private readonly IConnectionMultiplexer _connectionMultiplexer;
private readonly IRateLimitConfiguration _config;

public StackExchangeRedisProcessingStrategy(IConnectionMultiplexer connectionMultiplexer, IRateLimitCounterStore counterStore, ICounterKeyBuilder counterKeyBuilder, IRateLimitConfiguration config, RateLimitOptions options)
: base(counterKeyBuilder, config, options)
{
_connectionMultiplexer = connectionMultiplexer ?? throw new ArgumentException("IConnectionMultiplexer was null. Ensure StackExchange.Redis was successfully registered");
_config = config;
}


static private readonly LuaScript _atomicIncrement = LuaScript.Prepare("local count count = redis.call(\"INCRBYFLOAT\", @key, tonumber(@delta)) if tonumber(count) == @delta then redis.call(\"EXPIRE\", @key, @timeout) end return count");

public override async Task<RateLimitCounter> ProcessRequestAsync(ClientRequestIdentity requestIdentity, RateLimitRule rule, CancellationToken cancellationToken = default)
{
var counterId = BuildCounterKey(requestIdentity, rule);
return await IncrementAsync(counterId, rule.PeriodTimespan.Value, _config.RateIncrementer);
}

public async Task<RateLimitCounter> IncrementAsync(string counterId, TimeSpan interval, Func<double> RateIncrementer = null)
{
var now = DateTime.UtcNow;
var numberOfIntervals = now.Ticks / interval.Ticks;
var intervalStart = new DateTime(numberOfIntervals * interval.Ticks, DateTimeKind.Utc);

// Call the Lua script
var count = await _connectionMultiplexer.GetDatabase().ScriptEvaluateAsync(_atomicIncrement, new { key = counterId, timeout = interval.TotalSeconds, delta = RateIncrementer?.Invoke() ?? 1D });
return new RateLimitCounter
{
Count = (double)count,
Timestamp = intervalStart
};
}
}
}
Loading

0 comments on commit 60e6ebc

Please sign in to comment.