diff --git a/src/Ocelot/Configuration/Builder/DownstreamReRouteBuilder.cs b/src/Ocelot/Configuration/Builder/DownstreamReRouteBuilder.cs index 972adeddd..8e2583d2d 100644 --- a/src/Ocelot/Configuration/Builder/DownstreamReRouteBuilder.cs +++ b/src/Ocelot/Configuration/Builder/DownstreamReRouteBuilder.cs @@ -38,7 +38,7 @@ public class DownstreamReRouteBuilder private List _addHeadersToDownstream; private List _addHeadersToUpstream; private bool _dangerousAcceptAnyServerCertificateValidator; - + private SecurityOptions _securityOptions; public DownstreamReRouteBuilder() { _downstreamAddresses = new List(); @@ -227,6 +227,12 @@ public DownstreamReRouteBuilder WithDangerousAcceptAnyServerCertificateValidator return this; } + public DownstreamReRouteBuilder WithSecurityOptions(SecurityOptions securityOptions) + { + _securityOptions = securityOptions; + return this; + } + public DownstreamReRoute Build() { return new DownstreamReRoute( @@ -258,7 +264,8 @@ public DownstreamReRoute Build() _delegatingHandlers, _addHeadersToDownstream, _addHeadersToUpstream, - _dangerousAcceptAnyServerCertificateValidator); + _dangerousAcceptAnyServerCertificateValidator, + _securityOptions); } } } diff --git a/src/Ocelot/Configuration/Creator/ISecurityOptionsCreator.cs b/src/Ocelot/Configuration/Creator/ISecurityOptionsCreator.cs new file mode 100644 index 000000000..792c85304 --- /dev/null +++ b/src/Ocelot/Configuration/Creator/ISecurityOptionsCreator.cs @@ -0,0 +1,12 @@ +using Ocelot.Configuration.File; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Ocelot.Configuration.Creator +{ + public interface ISecurityOptionsCreator + { + SecurityOptions Create(FileSecurityOptions securityOptions); + } +} diff --git a/src/Ocelot/Configuration/Creator/ReRoutesCreator.cs b/src/Ocelot/Configuration/Creator/ReRoutesCreator.cs index d8e5b66de..d2ec6619d 100644 --- a/src/Ocelot/Configuration/Creator/ReRoutesCreator.cs +++ b/src/Ocelot/Configuration/Creator/ReRoutesCreator.cs @@ -21,6 +21,7 @@ public class ReRoutesCreator : IReRoutesCreator private readonly IHeaderFindAndReplaceCreator _headerFAndRCreator; private readonly IDownstreamAddressesCreator _downstreamAddressesCreator; private readonly IReRouteKeyCreator _reRouteKeyCreator; + private readonly ISecurityOptionsCreator _securityOptionsCreator; public ReRoutesCreator( IClaimsToThingCreator claimsToThingCreator, @@ -35,7 +36,8 @@ public ReRoutesCreator( IHeaderFindAndReplaceCreator headerFAndRCreator, IDownstreamAddressesCreator downstreamAddressesCreator, ILoadBalancerOptionsCreator loadBalancerOptionsCreator, - IReRouteKeyCreator reRouteKeyCreator + IReRouteKeyCreator reRouteKeyCreator, + ISecurityOptionsCreator securityOptionsCreator ) { _reRouteKeyCreator = reRouteKeyCreator; @@ -52,6 +54,7 @@ IReRouteKeyCreator reRouteKeyCreator _fileReRouteOptionsCreator = fileReRouteOptionsCreator; _httpHandlerOptionsCreator = httpHandlerOptionsCreator; _loadBalancerOptionsCreator = loadBalancerOptionsCreator; + _securityOptionsCreator = securityOptionsCreator; } public List Create(FileConfiguration fileConfiguration) @@ -97,6 +100,8 @@ private DownstreamReRoute SetUpDownstreamReRoute(FileReRoute fileReRoute, FileGl var lbOptions = _loadBalancerOptionsCreator.Create(fileReRoute.LoadBalancerOptions); + var securityOptions = _securityOptionsCreator.Create(fileReRoute.SecurityOptions); + var reRoute = new DownstreamReRouteBuilder() .WithKey(fileReRoute.Key) .WithDownstreamPathTemplate(fileReRoute.DownstreamPathTemplate) @@ -128,6 +133,7 @@ private DownstreamReRoute SetUpDownstreamReRoute(FileReRoute fileReRoute, FileGl .WithAddHeadersToDownstream(hAndRs.AddHeadersToDownstream) .WithAddHeadersToUpstream(hAndRs.AddHeadersToUpstream) .WithDangerousAcceptAnyServerCertificateValidator(fileReRoute.DangerousAcceptAnyServerCertificateValidator) + .WithSecurityOptions(securityOptions) .Build(); return reRoute; diff --git a/src/Ocelot/Configuration/Creator/SecurityOptionsCreator.cs b/src/Ocelot/Configuration/Creator/SecurityOptionsCreator.cs new file mode 100644 index 000000000..d1ee35083 --- /dev/null +++ b/src/Ocelot/Configuration/Creator/SecurityOptionsCreator.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Ocelot.Configuration.File; + +namespace Ocelot.Configuration.Creator +{ + public class SecurityOptionsCreator : ISecurityOptionsCreator + { + public SecurityOptions Create(FileSecurityOptions securityOptions) + { + return new SecurityOptions(securityOptions.IPAllowedList, securityOptions.IPBlockedList); + } + } +} diff --git a/src/Ocelot/Configuration/DownstreamReRoute.cs b/src/Ocelot/Configuration/DownstreamReRoute.cs index ec6f03c6f..b8ec926ce 100644 --- a/src/Ocelot/Configuration/DownstreamReRoute.cs +++ b/src/Ocelot/Configuration/DownstreamReRoute.cs @@ -35,7 +35,8 @@ public DownstreamReRoute( List delegatingHandlers, List addHeadersToDownstream, List addHeadersToUpstream, - bool dangerousAcceptAnyServerCertificateValidator) + bool dangerousAcceptAnyServerCertificateValidator, + SecurityOptions securityOptions) { DangerousAcceptAnyServerCertificateValidator = dangerousAcceptAnyServerCertificateValidator; AddHeadersToDownstream = addHeadersToDownstream; @@ -66,6 +67,7 @@ public DownstreamReRoute( DownstreamPathTemplate = downstreamPathTemplate; LoadBalancerKey = loadBalancerKey; AddHeadersToUpstream = addHeadersToUpstream; + SecurityOptions = securityOptions; } public string Key { get; } @@ -97,5 +99,6 @@ public DownstreamReRoute( public List AddHeadersToDownstream { get; } public List AddHeadersToUpstream { get; } public bool DangerousAcceptAnyServerCertificateValidator { get; } + public SecurityOptions SecurityOptions { get; } } } diff --git a/src/Ocelot/Configuration/File/FileReRoute.cs b/src/Ocelot/Configuration/File/FileReRoute.cs index f03156e41..1a8f2a1a3 100644 --- a/src/Ocelot/Configuration/File/FileReRoute.cs +++ b/src/Ocelot/Configuration/File/FileReRoute.cs @@ -21,6 +21,7 @@ public FileReRoute() DownstreamHostAndPorts = new List(); DelegatingHandlers = new List(); LoadBalancerOptions = new FileLoadBalancerOptions(); + SecurityOptions = new FileSecurityOptions(); Priority = 1; } @@ -50,5 +51,6 @@ public FileReRoute() public int Priority { get;set; } public int Timeout { get; set; } public bool DangerousAcceptAnyServerCertificateValidator { get; set; } + public FileSecurityOptions SecurityOptions { get; set; } } } diff --git a/src/Ocelot/Configuration/File/FileSecurityOptions.cs b/src/Ocelot/Configuration/File/FileSecurityOptions.cs new file mode 100644 index 000000000..1f383a5b6 --- /dev/null +++ b/src/Ocelot/Configuration/File/FileSecurityOptions.cs @@ -0,0 +1,19 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Ocelot.Configuration.File +{ + public class FileSecurityOptions + { + public FileSecurityOptions() + { + IPAllowedList = new List(); + IPBlockedList = new List(); + } + + public List IPAllowedList { get; set; } + + public List IPBlockedList { get; set; } + } +} diff --git a/src/Ocelot/Configuration/SecurityOptions.cs b/src/Ocelot/Configuration/SecurityOptions.cs new file mode 100644 index 000000000..88d4d08af --- /dev/null +++ b/src/Ocelot/Configuration/SecurityOptions.cs @@ -0,0 +1,19 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Ocelot.Configuration +{ + public class SecurityOptions + { + public SecurityOptions(List allowedList, List blockedList) + { + this.IPAllowedList = allowedList; + this.IPBlockedList = blockedList; + } + + public List IPAllowedList { get; private set; } + + public List IPBlockedList { get; private set; } + } +} diff --git a/src/Ocelot/DependencyInjection/OcelotBuilder.cs b/src/Ocelot/DependencyInjection/OcelotBuilder.cs index 77f29c5a9..68b8414cc 100644 --- a/src/Ocelot/DependencyInjection/OcelotBuilder.cs +++ b/src/Ocelot/DependencyInjection/OcelotBuilder.cs @@ -35,6 +35,8 @@ namespace Ocelot.DependencyInjection using Ocelot.Infrastructure; using Ocelot.Middleware.Multiplexer; using Ocelot.Request.Creator; + using Ocelot.Security.IPSecurity; + using Ocelot.Security; public class OcelotBuilder : IOcelotBuilder { @@ -125,6 +127,9 @@ public OcelotBuilder(IServiceCollection services, IConfiguration configurationRo Services.TryAddSingleton(); Services.TryAddSingleton(); + //add security + this.AddSecurity(); + //add asp.net services.. var assembly = typeof(FileConfigurationController).GetTypeInfo().Assembly; @@ -139,22 +144,28 @@ public OcelotBuilder(IServiceCollection services, IConfiguration configurationRo Services.AddWebEncoders(); } - public IOcelotBuilder AddSingletonDefinedAggregator() + public IOcelotBuilder AddSingletonDefinedAggregator() where T : class, IDefinedAggregator { Services.AddSingleton(); return this; } - public IOcelotBuilder AddTransientDefinedAggregator() + public IOcelotBuilder AddTransientDefinedAggregator() where T : class, IDefinedAggregator { Services.AddTransient(); return this; } - public IOcelotBuilder AddDelegatingHandler(bool global = false) - where THandler : DelegatingHandler + private void AddSecurity() + { + Services.TryAddSingleton(); + Services.TryAddSingleton(); + } + + public IOcelotBuilder AddDelegatingHandler(bool global = false) + where THandler : DelegatingHandler { if(global) { diff --git a/src/Ocelot/Middleware/Pipeline/OcelotPipelineExtensions.cs b/src/Ocelot/Middleware/Pipeline/OcelotPipelineExtensions.cs index 1153ccd59..67865b68f 100644 --- a/src/Ocelot/Middleware/Pipeline/OcelotPipelineExtensions.cs +++ b/src/Ocelot/Middleware/Pipeline/OcelotPipelineExtensions.cs @@ -15,6 +15,7 @@ using Ocelot.Requester.Middleware; using Ocelot.RequestId.Middleware; using Ocelot.Responder.Middleware; +using Ocelot.Security.Middleware; using Ocelot.WebSockets.Middleware; namespace Ocelot.Middleware.Pipeline @@ -48,6 +49,9 @@ public static OcelotRequestDelegate BuildOcelotPipeline(this IOcelotPipelineBuil // Then we get the downstream route information builder.UseDownstreamRouteFinderMiddleware(); + // This security module, IP whitelist blacklist, extended security mechanism + builder.UseSecurityMiddleware(); + //Expand other branch pipes if (pipelineConfiguration.MapWhenOcelotPipeline != null) { diff --git a/src/Ocelot/Security/IPSecurity/IPSecurityPolicy.cs b/src/Ocelot/Security/IPSecurity/IPSecurityPolicy.cs new file mode 100644 index 000000000..bbf97778d --- /dev/null +++ b/src/Ocelot/Security/IPSecurity/IPSecurityPolicy.cs @@ -0,0 +1,44 @@ +using System; +using System.Collections.Generic; +using System.Net; +using System.Text; +using System.Threading.Tasks; +using Ocelot.Configuration; +using Ocelot.Middleware; +using Ocelot.Responses; + +namespace Ocelot.Security.IPSecurity +{ + public class IPSecurityPolicy : ISecurityPolicy + { + public async Task Security(DownstreamContext context) + { + IPAddress clientIp = context.HttpContext.Connection.RemoteIpAddress; + SecurityOptions securityOptions = context.DownstreamReRoute.SecurityOptions; + if (securityOptions == null) + { + return new OkResponse(); + } + + if (securityOptions.IPBlockedList != null) + { + if (securityOptions.IPBlockedList.Exists(f => f == clientIp.ToString())) + { + var error = new UnauthenticatedError($" This request rejects access to {clientIp.ToString()} IP"); + return new ErrorResponse(error); + } + } + + if (securityOptions.IPAllowedList != null && securityOptions.IPAllowedList.Count > 0) + { + if (!securityOptions.IPAllowedList.Exists(f => f == clientIp.ToString())) + { + var error = new UnauthenticatedError($"{clientIp.ToString()} does not allow access, the request is invalid"); + return new ErrorResponse(error); + } + } + + return await Task.FromResult(new OkResponse()); + } + } +} diff --git a/src/Ocelot/Security/ISecurityPolicy.cs b/src/Ocelot/Security/ISecurityPolicy.cs new file mode 100644 index 000000000..2b3457ece --- /dev/null +++ b/src/Ocelot/Security/ISecurityPolicy.cs @@ -0,0 +1,14 @@ +using Ocelot.Middleware; +using Ocelot.Responses; +using System; +using System.Collections.Generic; +using System.Text; +using System.Threading.Tasks; + +namespace Ocelot.Security +{ + public interface ISecurityPolicy + { + Task Security(DownstreamContext context); + } +} diff --git a/src/Ocelot/Security/Middleware/SecurityMiddleware.cs b/src/Ocelot/Security/Middleware/SecurityMiddleware.cs new file mode 100644 index 000000000..1665ba3b6 --- /dev/null +++ b/src/Ocelot/Security/Middleware/SecurityMiddleware.cs @@ -0,0 +1,43 @@ +using Ocelot.Logging; +using Ocelot.Middleware; +using System.Collections.Generic; +using System.Threading.Tasks; + +namespace Ocelot.Security.Middleware +{ + public class SecurityMiddleware : OcelotMiddleware + { + private readonly OcelotRequestDelegate _next; + private readonly IOcelotLogger _logger; + private readonly IEnumerable _securityPolicies; + public SecurityMiddleware(IOcelotLoggerFactory loggerFactory, + IEnumerable securityPolicies, + OcelotRequestDelegate next) + : base(loggerFactory.CreateLogger()) + { + _logger = loggerFactory.CreateLogger(); + _securityPolicies = securityPolicies; + _next = next; + } + + public async Task Invoke(DownstreamContext context) + { + if (_securityPolicies != null) + { + foreach (var policie in _securityPolicies) + { + var result = await policie.Security(context); + if (!result.IsError) + { + continue; + } + + this.SetPipelineError(context, result.Errors); + return; + } + } + + await _next.Invoke(context); + } + } +} diff --git a/src/Ocelot/Security/Middleware/SecurityMiddlewareExtensions.cs b/src/Ocelot/Security/Middleware/SecurityMiddlewareExtensions.cs new file mode 100644 index 000000000..1e454d3a9 --- /dev/null +++ b/src/Ocelot/Security/Middleware/SecurityMiddlewareExtensions.cs @@ -0,0 +1,15 @@ +using Ocelot.Middleware.Pipeline; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Ocelot.Security.Middleware +{ + public static class SecurityMiddlewareExtensions + { + public static IOcelotPipelineBuilder UseSecurityMiddleware(this IOcelotPipelineBuilder builder) + { + return builder.UseMiddleware(); + } + } +} diff --git a/test/Ocelot.UnitTests/Configuration/ReRoutesCreatorTests.cs b/test/Ocelot.UnitTests/Configuration/ReRoutesCreatorTests.cs index 7d27f376e..dbfcc71a3 100644 --- a/test/Ocelot.UnitTests/Configuration/ReRoutesCreatorTests.cs +++ b/test/Ocelot.UnitTests/Configuration/ReRoutesCreatorTests.cs @@ -30,7 +30,8 @@ public class ReRoutesCreatorTests private Mock _daCreator; private Mock _lboCreator; private Mock _rrkCreator; - private FileConfiguration _fileConfig; + private Mock _soCreator; + private FileConfiguration _fileConfig; private ReRouteOptions _rro; private string _requestId; private string _rrk; @@ -45,6 +46,7 @@ public class ReRoutesCreatorTests private List _dhp; private LoadBalancerOptions _lbo; private List _result; + private SecurityOptions _securityOptions; public ReRoutesCreatorTests() { @@ -61,6 +63,7 @@ public ReRoutesCreatorTests() _daCreator = new Mock(); _lboCreator = new Mock(); _rrkCreator = new Mock(); + _soCreator = new Mock(); _creator = new ReRoutesCreator( _cthCreator.Object, @@ -75,7 +78,8 @@ public ReRoutesCreatorTests() _hfarCreator.Object, _daCreator.Object, _lboCreator.Object, - _rrkCreator.Object + _rrkCreator.Object, + _soCreator.Object ); } @@ -266,6 +270,7 @@ private void ThenTheDepsAreCalledFor(FileReRoute fileReRoute, FileGlobalConfigur _hfarCreator.Verify(x => x.Create(fileReRoute), Times.Once); _daCreator.Verify(x => x.Create(fileReRoute), Times.Once); _lboCreator.Verify(x => x.Create(fileReRoute.LoadBalancerOptions), Times.Once); + _soCreator.Verify(x => x.Create(fileReRoute.SecurityOptions), Times.Once); } } } diff --git a/test/Ocelot.UnitTests/Configuration/SecurityOptionsCreatorTests.cs b/test/Ocelot.UnitTests/Configuration/SecurityOptionsCreatorTests.cs new file mode 100644 index 000000000..220dadcd7 --- /dev/null +++ b/test/Ocelot.UnitTests/Configuration/SecurityOptionsCreatorTests.cs @@ -0,0 +1,72 @@ +using Ocelot.Configuration; +using Ocelot.Configuration.Creator; +using Ocelot.Configuration.File; +using Shouldly; +using System; +using System.Collections.Generic; +using System.Text; +using TestStack.BDDfy; +using Xunit; + +namespace Ocelot.UnitTests.Configuration +{ + public class SecurityOptionsCreatorTests + { + private FileReRoute _fileReRoute; + private FileGlobalConfiguration _fileGlobalConfig; + private SecurityOptions _result; + private ISecurityOptionsCreator _creator; + + public SecurityOptionsCreatorTests() + { + _creator = new SecurityOptionsCreator(); + } + + [Fact] + public void should_create_security_config() + { + var ipAllowedList = new List() { "127.0.0.1", "192.168.1.1" }; + var ipBlockedList = new List() { "127.0.0.1", "192.168.1.1" }; + var fileReRoute = new FileReRoute + { + SecurityOptions = new FileSecurityOptions() + { + IPAllowedList = ipAllowedList, + IPBlockedList = ipBlockedList + } + }; + + var expected = new SecurityOptions(ipAllowedList, ipBlockedList); + + this.Given(x => x.GivenThe(fileReRoute)) + .When(x => x.WhenICreate()) + .Then(x => x.ThenTheResultIs(expected)) + .BDDfy(); + } + + private void GivenThe(FileReRoute reRoute) + { + _fileReRoute = reRoute; + } + + + private void WhenICreate() + { + _result = _creator.Create(_fileReRoute.SecurityOptions); + } + + private void ThenTheResultIs(SecurityOptions expected) + { + for (int i = 0; i < expected.IPAllowedList.Count; i++) + { + _result.IPAllowedList[i].ShouldBe(expected.IPAllowedList[i]); + } + + for (int i = 0; i < expected.IPBlockedList.Count; i++) + { + _result.IPBlockedList[i].ShouldBe(expected.IPBlockedList[i]); + } + } + + } +} diff --git a/test/Ocelot.UnitTests/Security/IPSecurityPolicyTests.cs b/test/Ocelot.UnitTests/Security/IPSecurityPolicyTests.cs new file mode 100644 index 000000000..5e4062aea --- /dev/null +++ b/test/Ocelot.UnitTests/Security/IPSecurityPolicyTests.cs @@ -0,0 +1,117 @@ +using Microsoft.AspNetCore.Http; +using Ocelot.Configuration; +using Ocelot.Configuration.Builder; +using Ocelot.Middleware; +using Ocelot.Request.Middleware; +using Ocelot.Responses; +using Ocelot.Security.IPSecurity; +using System; +using System.Collections.Generic; +using System.Net; +using System.Net.Http; +using System.Text; +using TestStack.BDDfy; +using Xunit; + +namespace Ocelot.UnitTests.Security +{ + public class IPSecurityPolicyTests + { + private readonly DownstreamContext _downstreamContext; + private readonly DownstreamReRouteBuilder _downstreamReRouteBuilder; + private readonly IPSecurityPolicy _ipSecurityPolicy; + private Response response; + public IPSecurityPolicyTests() + { + _downstreamContext = new DownstreamContext(new DefaultHttpContext()); + _downstreamContext.DownstreamRequest = new DownstreamRequest(new HttpRequestMessage(HttpMethod.Get, "http://test.com")); + _downstreamContext.HttpContext.Connection.RemoteIpAddress = Dns.GetHostAddresses("192.168.1.1")[0]; + _downstreamReRouteBuilder = new DownstreamReRouteBuilder(); + _ipSecurityPolicy = new IPSecurityPolicy(); + } + + [Fact] + private void should_No_blocked_Ip_and_allowed_Ip() + { + this.Given(x => x.GivenSetDownstreamReRoute()) + .When(x => x.WhenTheSecurityPolicy()) + .Then(x => x.ThenSecurityPassing()) + .BDDfy(); + } + + [Fact] + private void should_blockedIp_clientIp_block() + { + _downstreamContext.HttpContext.Connection.RemoteIpAddress = Dns.GetHostAddresses("192.168.1.1")[0]; + this.Given(x => x.GivenSetBlockedIP()) + .Given(x => x.GivenSetDownstreamReRoute()) + .When(x => x.WhenTheSecurityPolicy()) + .Then(x => x.ThenNotSecurityPassing()) + .BDDfy(); + } + + [Fact] + private void should_blockedIp_clientIp_Not_block() + { + _downstreamContext.HttpContext.Connection.RemoteIpAddress = Dns.GetHostAddresses("192.168.1.2")[0]; + this.Given(x => x.GivenSetBlockedIP()) + .Given(x => x.GivenSetDownstreamReRoute()) + .When(x => x.WhenTheSecurityPolicy()) + .Then(x => x.ThenSecurityPassing()) + .BDDfy(); + } + + + [Fact] + private void should_allowedIp_clientIp_block() + { + _downstreamContext.HttpContext.Connection.RemoteIpAddress = Dns.GetHostAddresses("192.168.1.1")[0]; + this.Given(x => x.GivenSetAllowedIP()) + .Given(x => x.GivenSetDownstreamReRoute()) + .When(x => x.WhenTheSecurityPolicy()) + .Then(x => x.ThenSecurityPassing()) + .BDDfy(); + } + + [Fact] + private void should_allowedIp_clientIp_Not_block() + { + _downstreamContext.HttpContext.Connection.RemoteIpAddress = Dns.GetHostAddresses("192.168.1.2")[0]; + this.Given(x => x.GivenSetAllowedIP()) + .Given(x => x.GivenSetDownstreamReRoute()) + .When(x => x.WhenTheSecurityPolicy()) + .Then(x => x.ThenNotSecurityPassing()) + .BDDfy(); + } + + private void GivenSetAllowedIP() + { + _downstreamReRouteBuilder.WithSecurityOptions(new SecurityOptions(new List { "192.168.1.1" }, new List())); + } + + private void GivenSetBlockedIP() + { + _downstreamReRouteBuilder.WithSecurityOptions(new SecurityOptions(new List(), new List { "192.168.1.1" })); + } + + private void GivenSetDownstreamReRoute() + { + _downstreamContext.DownstreamReRoute = _downstreamReRouteBuilder.Build(); + } + + private void WhenTheSecurityPolicy() + { + response = this._ipSecurityPolicy.Security(_downstreamContext).GetAwaiter().GetResult(); + } + + private void ThenSecurityPassing() + { + Assert.False(response.IsError); + } + + private void ThenNotSecurityPassing() + { + Assert.True(response.IsError); + } + } +} diff --git a/test/Ocelot.UnitTests/Security/SecurityMiddlewareTests.cs b/test/Ocelot.UnitTests/Security/SecurityMiddlewareTests.cs new file mode 100644 index 000000000..d7d381402 --- /dev/null +++ b/test/Ocelot.UnitTests/Security/SecurityMiddlewareTests.cs @@ -0,0 +1,108 @@ +using Microsoft.AspNetCore.Http; +using Moq; +using Ocelot.Errors; +using Ocelot.Logging; +using Ocelot.Middleware; +using Ocelot.Request.Middleware; +using Ocelot.Responses; +using Ocelot.Security; +using Ocelot.Security.IPSecurity; +using Ocelot.Security.Middleware; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Http; +using System.Text; +using System.Threading.Tasks; +using TestStack.BDDfy; +using Xunit; + +namespace Ocelot.UnitTests.Security +{ + public class SecurityMiddlewareTests + { + private List> _securityPolicyList; + private Mock _loggerFactory; + private Mock _logger; + private readonly SecurityMiddleware _middleware; + private readonly DownstreamContext _downstreamContext; + private readonly OcelotRequestDelegate _next; + + public SecurityMiddlewareTests() + { + _loggerFactory = new Mock(); + _logger = new Mock(); + _loggerFactory.Setup(x => x.CreateLogger()).Returns(_logger.Object); + _securityPolicyList = new List>(); + _securityPolicyList.Add(new Mock()); + _securityPolicyList.Add(new Mock()); + _next = context => + { + return Task.CompletedTask; + }; + _middleware = new SecurityMiddleware(_loggerFactory.Object, _securityPolicyList.Select(f => f.Object).ToList(), _next); + _downstreamContext = new DownstreamContext(new DefaultHttpContext()); + _downstreamContext.DownstreamRequest = new DownstreamRequest(new HttpRequestMessage(HttpMethod.Get, "http://test.com")); + } + [Fact] + public void should_legal_request() + { + this.Given(x => x.GivenPassingSecurityVerification()) + .When(x => x.WhenICallTheMiddleware()) + .Then(x => x.ThenTheRequestIsPassingSecurity()) + .BDDfy(); + } + + [Fact] + public void should_verification_failed_request() + { + this.Given(x => x.GivenNotPassingSecurityVerification()) + .When(x => x.WhenICallTheMiddleware()) + .Then(x => x.ThenTheRequestIsNotPassingSecurity()) + .BDDfy(); + } + + private void GivenPassingSecurityVerification() + { + foreach (var item in _securityPolicyList) + { + Response response = new OkResponse(); + item.Setup(x => x.Security(_downstreamContext)).Returns(Task.FromResult(response)); + } + } + + private void GivenNotPassingSecurityVerification() + { + for (int i = 0; i < _securityPolicyList.Count; i++) + { + Mock item = _securityPolicyList[i]; + if (i == 0) + { + Error error = new UnauthenticatedError($"Not passing security verification"); + Response response = new ErrorResponse(error); + item.Setup(x => x.Security(_downstreamContext)).Returns(Task.FromResult(response)); + } + else + { + Response response = new OkResponse(); + item.Setup(x => x.Security(_downstreamContext)).Returns(Task.FromResult(response)); + } + } + } + + private void WhenICallTheMiddleware() + { + _middleware.Invoke(_downstreamContext).GetAwaiter().GetResult(); + } + + private void ThenTheRequestIsPassingSecurity() + { + Assert.False(_downstreamContext.IsError); + } + + private void ThenTheRequestIsNotPassingSecurity() + { + Assert.True(_downstreamContext.IsError); + } + } +}