Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow SSL for AMQP connection #178

Merged
merged 11 commits into from
Feb 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ The above sample waits on a trigger from the queue named "queue" connected to th
|HostName|(ignored if using ConnectionStringSetting) Hostname of the queue|`10.26.45.210`|
|UserName|(ignored if using ConnectionStringSetting) User name to access queue|`user`|
|Password|(ignored if using ConnectionStringSetting) Password to access queue|`password`|
|EnableSsl|Bool to enable or disable SSL in AMQP connection (default false)|`true`|
|SkipCertificateValidation|Bool to enable or disable checking certificate when EnableSsl=true. It will accept RemoteCertificateChainErrors and RemoteCertificateNameMismatch errors (default false. Not recommended for production)|`true`|

# Contributing

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,16 @@
* @return The port to attach.
*/
int port() default 0;

/**
* Enable or disable ssl in RabbitMQ connection.
* @return A bool to enable Ssl.
*/
boolean enableSsl() default false;

/**
* Enable os disable checking certificate when Ssl is enabled (not recommended for production).
* @return A bool to enable checking certificates when Ssl is enabled.
*/
boolean skipCertificateValidation() default false;
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,16 @@
* @return The port to attach.
*/
int port() default 0;

/**
* Enable or disable ssl in RabbitMQ connection.
* @return A bool to enable Ssl.
*/
boolean enableSsl() default false;

/**
* Enable os disable checking certificate when Ssl is enabled (not recommended for production).
* @return A bool to enable checking certificates when Ssl is enabled.
*/
boolean skipCertificateValidation() default false;
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ public void TestRabbitMQOutput() {
EasyMock.expect(outputInterface.userName()).andReturn("randomUserName");
EasyMock.expect(outputInterface.connectionStringSetting()).andReturn("randomConnectionStringSetting");
EasyMock.expect(outputInterface.port()).andReturn(123);
EasyMock.expect(outputInterface.enableSsl()).andReturn(false);
EasyMock.expect(outputInterface.skipCertificateValidation()).andReturn(false);

EasyMock.replay(outputInterface);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ public void TestRabbitMQTrigger() {
EasyMock.expect(triggerInterface.userNameSetting()).andReturn("randomUserName");
EasyMock.expect(triggerInterface.connectionStringSetting()).andReturn("randomConnectionStringSetting");
EasyMock.expect(triggerInterface.port()).andReturn(123);
EasyMock.expect(triggerInterface.enableSsl()).andReturn(false);
EasyMock.expect(triggerInterface.skipCertificateValidation()).andReturn(false);

EasyMock.replay(triggerInterface);
}
Expand Down
6 changes: 4 additions & 2 deletions src/Bindings/RabbitMQClientBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@ private IModel CreateModelFromAttribute(RabbitMQAttribute attribute)
string resolvedUserName = Utility.FirstOrDefault(attribute.UserName, _options.Value.UserName);
string resolvedPassword = Utility.FirstOrDefault(attribute.Password, _options.Value.Password);
int resolvedPort = Utility.FirstOrDefault(attribute.Port, _options.Value.Port);
bool resolvedEnableSsl = Utility.FirstOrDefault(attribute.EnableSsl, _options.Value.EnableSsl);
bool resolvedSkipCertificateValidation = Utility.FirstOrDefault(attribute.SkipCertificateValidation, _options.Value.SkipCertificateValidation);

IRabbitMQService service = _configProvider.GetService(resolvedConnectionString, resolvedHostName, resolvedUserName, resolvedPassword, resolvedPort);
IRabbitMQService service = _configProvider.GetService(resolvedConnectionString, resolvedHostName, resolvedUserName, resolvedPassword, resolvedPort, resolvedEnableSsl, resolvedSkipCertificateValidation);

return service.Model;
}
}
}
}
8 changes: 4 additions & 4 deletions src/Config/DefaultRabbitMQServiceFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ namespace Microsoft.Azure.WebJobs.Extensions.RabbitMQ
{
internal class DefaultRabbitMQServiceFactory : IRabbitMQServiceFactory
{
public IRabbitMQService CreateService(string connectionString, string hostName, string queueName, string userName, string password, int port)
public IRabbitMQService CreateService(string connectionString, string hostName, string queueName, string userName, string password, int port, bool ssl, bool insecureSsl)
{
return new RabbitMQService(connectionString, hostName, queueName, userName, password, port);
return new RabbitMQService(connectionString, hostName, queueName, userName, password, port, ssl, insecureSsl);
}

public IRabbitMQService CreateService(string connectionString, string hostName, string userName, string password, int port)
public IRabbitMQService CreateService(string connectionString, string hostName, string userName, string password, int port, bool ssl, bool insecureSsl)
{
return new RabbitMQService(connectionString, hostName, userName, password, port);
return new RabbitMQService(connectionString, hostName, userName, password, port, ssl, insecureSsl);
}
}
}
4 changes: 2 additions & 2 deletions src/Config/IRabbitMQServiceFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ namespace Microsoft.Azure.WebJobs.Extensions.RabbitMQ
{
public interface IRabbitMQServiceFactory
{
IRabbitMQService CreateService(string connectionString, string hostName, string queueName, string userName, string password, int port);
IRabbitMQService CreateService(string connectionString, string hostName, string queueName, string userName, string password, int port, bool enableSsl, bool skipCertificateValidation);

IRabbitMQService CreateService(string connectionString, string hostName, string userName, string password, int port);
IRabbitMQService CreateService(string connectionString, string hostName, string userName, string password, int port, bool enableSsl, bool skipCertificateValidation);
}
}
14 changes: 9 additions & 5 deletions src/Config/RabbitMQExtensionConfigProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ internal RabbitMQContext CreateContext(RabbitMQAttribute attribute)
string userName = Utility.FirstOrDefault(attribute.UserName, _options.Value.UserName);
string password = Utility.FirstOrDefault(attribute.Password, _options.Value.Password);
int port = Utility.FirstOrDefault(attribute.Port, _options.Value.Port);
bool enableSsl = Utility.FirstOrDefault(attribute.EnableSsl, _options.Value.EnableSsl);
bool skipCertificateValidation = Utility.FirstOrDefault(attribute.SkipCertificateValidation, _options.Value.SkipCertificateValidation);

RabbitMQAttribute resolvedAttribute;
IRabbitMQService service;
Expand All @@ -98,9 +100,11 @@ internal RabbitMQContext CreateContext(RabbitMQAttribute attribute)
UserName = userName,
Password = password,
Port = port,
EnableSsl = enableSsl,
SkipCertificateValidation = skipCertificateValidation,
};

service = GetService(connectionString, hostName, queueName, userName, password, port);
service = GetService(connectionString, hostName, queueName, userName, password, port, enableSsl, skipCertificateValidation);

return new RabbitMQContext
{
Expand All @@ -109,19 +113,19 @@ internal RabbitMQContext CreateContext(RabbitMQAttribute attribute)
};
}

internal IRabbitMQService GetService(string connectionString, string hostName, string queueName, string userName, string password, int port)
internal IRabbitMQService GetService(string connectionString, string hostName, string queueName, string userName, string password, int port, bool enableSsl, bool skipCertificateValidation)
{
string[] keyArray = { connectionString, hostName, queueName, userName, password, port.ToString() };
string key = string.Join(",", keyArray);
return _connectionParametersToService.GetOrAdd(key, _ => _rabbitMQServiceFactory.CreateService(connectionString, hostName, queueName, userName, password, port));
return _connectionParametersToService.GetOrAdd(key, _ => _rabbitMQServiceFactory.CreateService(connectionString, hostName, queueName, userName, password, port, enableSsl, skipCertificateValidation));
}

// Overloaded method used only for getting the RabbitMQ client
internal IRabbitMQService GetService(string connectionString, string hostName, string userName, string password, int port)
internal IRabbitMQService GetService(string connectionString, string hostName, string userName, string password, int port, bool enableSsl, bool skipCertificateValidation)
{
string[] keyArray = { connectionString, hostName, userName, password, port.ToString() };
string key = string.Join(",", keyArray);
return _connectionParametersToService.GetOrAdd(key, _ => _rabbitMQServiceFactory.CreateService(connectionString, hostName, userName, password, port));
return _connectionParametersToService.GetOrAdd(key, _ => _rabbitMQServiceFactory.CreateService(connectionString, hostName, userName, password, port, enableSsl, skipCertificateValidation));
}
}
}
10 changes: 10 additions & 0 deletions src/Config/RabbitMQOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,16 @@ public RabbitMQOptions()
/// </summary>
public int Port { get; set; }

/// <summary>
/// Enable or disable ssl in RabbitMQ connection.
/// </summary>
public bool EnableSsl { get; set; }

/// <summary>
/// Enable os disable checking certificate when Ssl is enabled (not recommended for production).
/// </summary>
public bool SkipCertificateValidation { get; set; }

/// <summary>
/// Gets or sets the prefetch count while creating the RabbitMQ QoS. This seting controls how many values are cached.
/// </summary>
Expand Down
10 changes: 10 additions & 0 deletions src/RabbitMQAttribute.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,15 @@ public sealed class RabbitMQAttribute : Attribute
/// </summary>
[ConnectionString]
public string ConnectionStringSetting { get; set; }

/// <summary>
/// Enable or disable ssl in RabbitMQ connection.
/// </summary>
public bool EnableSsl { get; set; }

/// <summary>
/// Enable os disable checking certificate when Ssl is enabled (not recommended for production).
/// </summary>
public bool SkipCertificateValidation { get; set; }
}
}
40 changes: 34 additions & 6 deletions src/Services/RabbitMQService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License. See License.txt in the project root for license information.

using System;
using System.Net.Security;
using RabbitMQ.Client;

namespace Microsoft.Azure.WebJobs.Extensions.RabbitMQ
Expand All @@ -16,26 +17,30 @@ internal sealed class RabbitMQService : IRabbitMQService
private readonly string _userName;
private readonly string _password;
private readonly int _port;
private readonly bool _enableSsl;
private readonly bool _skipCertificateValidation;
private readonly object _publishBatchLock;

private IBasicPublishBatch _batch;

public RabbitMQService(string connectionString, string hostName, string userName, string password, int port)
public RabbitMQService(string connectionString, string hostName, string userName, string password, int port, bool enableSsl, bool skipCertificateValidation)
{
_connectionString = connectionString;
_hostName = hostName;
_userName = userName;
_password = password;
_port = port;
_enableSsl = enableSsl;
_skipCertificateValidation = skipCertificateValidation;

ConnectionFactory connectionFactory = GetConnectionFactory(_connectionString, _hostName, _userName, _password, _port);
ConnectionFactory connectionFactory = GetConnectionFactory(_connectionString, _hostName, _userName, _password, _port, _enableSsl, _skipCertificateValidation);

_model = connectionFactory.CreateConnection().CreateModel();
_publishBatchLock = new object();
}

public RabbitMQService(string connectionString, string hostName, string queueName, string userName, string password, int port)
: this(connectionString, hostName, userName, password, port)
public RabbitMQService(string connectionString, string hostName, string queueName, string userName, string password, int port, bool enableSsl, bool skipCertificateValidation)
: this(connectionString, hostName, userName, password, port, enableSsl, skipCertificateValidation)
{
_rabbitMQModel = new RabbitMQModel(_model);
_queueName = queueName ?? throw new ArgumentNullException(nameof(queueName));
Expand All @@ -58,14 +63,16 @@ public void ResetPublishBatch()
_batch = _model.CreateBasicPublishBatch();
}

internal static ConnectionFactory GetConnectionFactory(string connectionString, string hostName, string userName, string password, int port)
internal static ConnectionFactory GetConnectionFactory(string connectionString, string hostName, string userName, string password, int port, bool enableSsl, bool skipCertificateValidation)
{
ConnectionFactory connectionFactory = new ConnectionFactory();

// Only set these if specified by user. Otherwise, API will use default parameters.
if (!string.IsNullOrEmpty(connectionString))
{
connectionFactory.Uri = new Uri(connectionString);
Uri amqpUri = new Uri(connectionString);
connectionFactory.Uri = amqpUri;
ConfigureSsl(connectionFactory, amqpUri.Host, enableSsl, skipCertificateValidation);
}
else
{
Expand All @@ -88,9 +95,30 @@ internal static ConnectionFactory GetConnectionFactory(string connectionString,
{
connectionFactory.Port = port;
}

ConfigureSsl(connectionFactory, hostName, enableSsl, skipCertificateValidation);
}

return connectionFactory;
}

internal static void ConfigureSsl(ConnectionFactory connectionFactory, string hostname, bool enableSsl, bool skipCertificateValidation)
{
if (enableSsl)
{
connectionFactory.Ssl = new SslOption
{
Enabled = true,

// Set SNI in order to work for multiple RabbitMQ clusters located behind a LoadBalancer
ServerName = hostname,
};
aaguilartablada marked this conversation as resolved.
Show resolved Hide resolved
if (skipCertificateValidation)
{
connectionFactory.Ssl.AcceptablePolicyErrors =
SslPolicyErrors.RemoteCertificateNameMismatch | SslPolicyErrors.RemoteCertificateChainErrors;
}
}
}
}
}
10 changes: 10 additions & 0 deletions src/Trigger/RabbitMQTriggerAttribute.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,15 @@ public RabbitMQTriggerAttribute(string hostName, string userNameSetting, string
/// Gets or sets the Port used. Defaults to 0.
/// </summary>
public int Port { get; set; }

/// <summary>
/// Enable or disable ssl in RabbitMQ connection.
/// </summary>
public bool EnableSsl { get; set; }

/// <summary>
/// Enable os disable checking certificate when Ssl is enabled (not recommended for production).
aaguilartablada marked this conversation as resolved.
Show resolved Hide resolved
/// </summary>
public bool SkipCertificateValidation { get; set; }
}
}
6 changes: 5 additions & 1 deletion src/Trigger/RabbitMQTriggerAttributeBindingProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,18 @@ public Task<ITriggerBinding> TryCreateAsync(TriggerBindingProviderContext contex

string password = Resolve(attribute.PasswordSetting);

bool enableSsl = attribute.EnableSsl;

bool skipCertificateValidation = attribute.SkipCertificateValidation;

int port = attribute.Port;

if (string.IsNullOrEmpty(connectionString) && !Utility.ValidateUserNamePassword(userName, password, hostName))
{
throw new InvalidOperationException("RabbitMQ username and password required if not connecting to localhost");
}

IRabbitMQService service = _provider.GetService(connectionString, hostName, queueName, userName, password, port);
IRabbitMQService service = _provider.GetService(connectionString, hostName, queueName, userName, password, port, enableSsl, skipCertificateValidation);

return Task.FromResult<ITriggerBinding>(new RabbitMQTriggerBinding(service, hostName, queueName, _logger, parameter.ParameterType, _options.Value.PrefetchCount));
}
Expand Down
13 changes: 5 additions & 8 deletions src/Utility.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,12 @@ internal static string FirstOrDefault(params string[] values)

internal static int FirstOrDefault(params int[] values)
{
return values.FirstOrDefault(v =>
{
if (v != 0)
{
return true;
}
return values.FirstOrDefault(v => v != 0);
}

return false;
});
internal static bool FirstOrDefault(params bool[] values)
{
return values.FirstOrDefault(v => v);
}

internal static bool ValidateUserNamePassword(string userName, string password, string hostName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,21 @@ public void Opens_Connection()
var options = new OptionsWrapper<RabbitMQOptions>(new RabbitMQOptions { HostName = Constants.LocalHost });
var mockServiceFactory = new Mock<IRabbitMQServiceFactory>();
var config = new RabbitMQExtensionConfigProvider(options, new Mock<INameResolver>().Object, mockServiceFactory.Object, new LoggerFactory(), _emptyConfig);
mockServiceFactory.Setup(m => m.CreateService(It.IsAny<string>(), It.IsAny<string>(), It.IsAny<string>(), It.IsAny<string>(), It.IsAny<int>())).Returns(new Mock<IRabbitMQService>().Object);
mockServiceFactory.Setup(m => m.CreateService(It.IsAny<string>(), It.IsAny<string>(), It.IsAny<string>(), It.IsAny<string>(), It.IsAny<int>(), false, false)).Returns(new Mock<IRabbitMQService>().Object);
RabbitMQAttribute attr = GetTestAttribute();

RabbitMQClientBuilder clientBuilder = new RabbitMQClientBuilder(config, options);
var model = clientBuilder.Convert(attr);

mockServiceFactory.Verify(m => m.CreateService(It.IsAny<string>(), Constants.LocalHost, "guest", "guest", 5672), Times.Exactly(1));
mockServiceFactory.Verify(m => m.CreateService(It.IsAny<string>(), Constants.LocalHost, "guest", "guest", 5672, false, false), Times.Exactly(1));
}

[Fact]
public void TestWhetherConnectionIsPooled()
{
var options = new OptionsWrapper<RabbitMQOptions>(new RabbitMQOptions { HostName = Constants.LocalHost });
var mockServiceFactory = new Mock<IRabbitMQServiceFactory>();
mockServiceFactory.SetupSequence(m => m.CreateService(It.IsAny<string>(), It.IsAny<string>(), It.IsAny<string>(), It.IsAny<string>(), It.IsAny<int>()))
mockServiceFactory.SetupSequence(m => m.CreateService(It.IsAny<string>(), It.IsAny<string>(), It.IsAny<string>(), It.IsAny<string>(), It.IsAny<int>(), false, false))
.Returns(GetRabbitMQService())
.Returns(GetRabbitMQService());
var config = new RabbitMQExtensionConfigProvider(options, new Mock<INameResolver>().Object, mockServiceFactory.Object, new LoggerFactory(), _emptyConfig);
Expand Down Expand Up @@ -73,4 +73,4 @@ private IRabbitMQService GetRabbitMQService()
return mockService.Object;
}
}
}
}
Loading