Skip to content

Commit

Permalink
Hsm Authentication with Http over Unix domain sockets
Browse files Browse the repository at this point in the history
  • Loading branch information
ancaantochi committed May 21, 2018
1 parent 2ad6775 commit 128e147
Show file tree
Hide file tree
Showing 11 changed files with 803 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ namespace Microsoft.Azure.Devices.Client.HsmAuthentication
public partial class HsmHttpClient
{
private string _baseUrl = "http://";
private System.Net.Http.HttpClient _httpClient;
private System.Lazy<Newtonsoft.Json.JsonSerializerSettings> _settings;

public HsmHttpClient()
public HsmHttpClient(System.Net.Http.HttpClient httpClient)
{
_httpClient = httpClient;
_settings = new System.Lazy<Newtonsoft.Json.JsonSerializerSettings>(() =>
{
var settings = new Newtonsoft.Json.JsonSerializerSettings();
Expand Down Expand Up @@ -74,7 +76,7 @@ public async System.Threading.Tasks.Task<SignResponse> SignAsync(string api_vers
urlBuilder_.Append("api-version=").Append(System.Net.WebUtility.UrlEncode(ConvertToString(api_version, System.Globalization.CultureInfo.InvariantCulture))).Append("&");
urlBuilder_.Length--;

var client_ = new System.Net.Http.HttpClient();
var client_ = _httpClient;
try
{
using (var request_ = new System.Net.Http.HttpRequestMessage())
Expand Down Expand Up @@ -158,8 +160,6 @@ public async System.Threading.Tasks.Task<SignResponse> SignAsync(string api_vers
}
finally
{
if (client_ != null)
client_.Dispose();
}
}

Expand Down Expand Up @@ -193,7 +193,7 @@ public async System.Threading.Tasks.Task<EncryptResponse> EncryptAsync(string ap
urlBuilder_.Append("api-version=").Append(System.Net.WebUtility.UrlEncode(ConvertToString(api_version, System.Globalization.CultureInfo.InvariantCulture))).Append("&");
urlBuilder_.Length--;

var client_ = new System.Net.Http.HttpClient();
var client_ = _httpClient;
try
{
using (var request_ = new System.Net.Http.HttpRequestMessage())
Expand Down Expand Up @@ -277,8 +277,6 @@ public async System.Threading.Tasks.Task<EncryptResponse> EncryptAsync(string ap
}
finally
{
if (client_ != null)
client_.Dispose();
}
}

Expand Down Expand Up @@ -312,7 +310,7 @@ public async System.Threading.Tasks.Task<DecryptResponse> DecryptAsync(string ap
urlBuilder_.Append("api-version=").Append(System.Net.WebUtility.UrlEncode(ConvertToString(api_version, System.Globalization.CultureInfo.InvariantCulture))).Append("&");
urlBuilder_.Length--;

var client_ = new System.Net.Http.HttpClient();
var client_ = _httpClient;
try
{
using (var request_ = new System.Net.Http.HttpRequestMessage())
Expand Down Expand Up @@ -396,8 +394,6 @@ public async System.Threading.Tasks.Task<DecryptResponse> DecryptAsync(string ap
}
finally
{
if (client_ != null)
client_.Dispose();
}
}

Expand Down Expand Up @@ -429,7 +425,7 @@ public async System.Threading.Tasks.Task<CertificateResponse> CreateIdentityCert
urlBuilder_.Append("api-version=").Append(System.Net.WebUtility.UrlEncode(ConvertToString(api_version, System.Globalization.CultureInfo.InvariantCulture))).Append("&");
urlBuilder_.Length--;

var client_ = new System.Net.Http.HttpClient();
var client_ = _httpClient;
try
{
using (var request_ = new System.Net.Http.HttpRequestMessage())
Expand Down Expand Up @@ -511,8 +507,6 @@ public async System.Threading.Tasks.Task<CertificateResponse> CreateIdentityCert
}
finally
{
if (client_ != null)
client_.Dispose();
}
}

Expand Down Expand Up @@ -546,7 +540,7 @@ public async System.Threading.Tasks.Task<CertificateResponse> CreateServerCertif
urlBuilder_.Append("api-version=").Append(System.Net.WebUtility.UrlEncode(ConvertToString(api_version, System.Globalization.CultureInfo.InvariantCulture))).Append("&");
urlBuilder_.Length--;

var client_ = new System.Net.Http.HttpClient();
var client_ = _httpClient;
try
{
using (var request_ = new System.Net.Http.HttpRequestMessage())
Expand Down Expand Up @@ -630,8 +624,6 @@ public async System.Threading.Tasks.Task<CertificateResponse> CreateServerCertif
}
finally
{
if (client_ != null)
client_.Dispose();
}
}

Expand Down
82 changes: 68 additions & 14 deletions iothub/device/src/HsmAuthentication/HttpHsmSignatureProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,27 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Linq;
using System.Net.Http;
using System.Text;
using System.Threading.Tasks;
#if NETSTANDARD2_0
using Microsoft.Azure.Devices.Client.HsmAuthentication.Transport;
#endif
using Microsoft.Azure.Devices.Client.TransientFaultHandling;

namespace Microsoft.Azure.Devices.Client.HsmAuthentication
{
class HttpHsmSignatureProvider : ISignatureProvider
{
const SignRequestAlgo DefaultSignRequestAlgo = SignRequestAlgo.HMACSHA256;
const string DefaultKeyId = "primary";
readonly string apiVersion;
readonly HsmHttpClient httpClient;
private const string DefaultApiVersion = "2018-06-28";
private const string HttpScheme = "http";
private const string HttpsScheme = "https";
private const string UnixScheme = "unix";
private const SignRequestAlgo DefaultSignRequestAlgo = SignRequestAlgo.HMACSHA256;
private const string DefaultKeyId = "primary";
private readonly string _apiVersion;
private readonly Uri _providerUri;

static readonly ITransientErrorDetectionStrategy TransientErrorDetectionStrategy = new ErrorDetectionStrategy();
static readonly RetryStrategy TransientRetryStrategy =
Expand All @@ -30,11 +39,8 @@ public HttpHsmSignatureProvider(string providerUri, string apiVersion)
throw new ArgumentNullException(nameof(apiVersion));
}

this.httpClient = new HsmHttpClient()
{
BaseUrl = providerUri
};
this.apiVersion = apiVersion;
this._providerUri = new Uri(providerUri);
this._apiVersion = apiVersion;
}

public async Task<string> SignAsync(string moduleId, string generationId, string data)
Expand All @@ -55,9 +61,15 @@ public async Task<string> SignAsync(string moduleId, string generationId, string
Data = Encoding.UTF8.GetBytes(data)
};

HttpClient httpClient = GetHttpClient();
try
{
SignResponse response = await this.SignAsyncWithRetry(moduleId, generationId, signRequest);
var hsmHttpClient = new HsmHttpClient(httpClient)
{
BaseUrl = GetBaseUrl()
};

SignResponse response = await this.SignAsyncWithRetry(hsmHttpClient, moduleId, generationId, signRequest);

return Convert.ToBase64String(response.Digest);
}
Expand All @@ -66,19 +78,61 @@ public async Task<string> SignAsync(string moduleId, string generationId, string
switch (ex)
{
case SwaggerException<ErrorResponse> errorResponseException:
throw new HttpHsmComunicationException($"Error calling SignAsync: {errorResponseException.Result?.Message ?? string.Empty}", errorResponseException.StatusCode);
throw new HttpHsmComunicationException(
$"Error calling SignAsync: {errorResponseException.Result?.Message ?? string.Empty}",
errorResponseException.StatusCode);
case SwaggerException swaggerException:
throw new HttpHsmComunicationException($"Error calling SignAsync: {swaggerException.Response ?? string.Empty}", swaggerException.StatusCode);
throw new HttpHsmComunicationException(
$"Error calling SignAsync: {swaggerException.Response ?? string.Empty}",
swaggerException.StatusCode);
default:
throw;
}
}
finally
{
httpClient.Dispose();
}
}

private HttpClient GetHttpClient()
{
HttpClient client;

if (_providerUri.Scheme.Equals(HttpScheme, StringComparison.OrdinalIgnoreCase) || _providerUri.Scheme.Equals(HttpsScheme, StringComparison.OrdinalIgnoreCase))
{
client = new HttpClient();
return client;
}

#if NETSTANDARD2_0
if (_providerUri.Scheme.Equals(UnixScheme, StringComparison.OrdinalIgnoreCase))
{
client = new HttpClient(new HttpUdsMessageHandler(_providerUri));
return client;
}
#endif

throw new InvalidOperationException("ProviderUri scheme is not supported");
}

private string GetBaseUrl()
{

#if NETSTANDARD2_0
if (_providerUri.Scheme.Equals(UnixScheme, StringComparison.OrdinalIgnoreCase))
{
return $"{HttpScheme}://{_providerUri.Segments.Last()}";
}
#endif

return _providerUri.OriginalString;
}

async Task<SignResponse> SignAsyncWithRetry(string moduleId, string generationId, SignRequest signRequest)
private async Task<SignResponse> SignAsyncWithRetry(HsmHttpClient hsmHttpClient, string moduleId, string generationId, SignRequest signRequest)
{
var transientRetryPolicy = new RetryPolicy(TransientErrorDetectionStrategy, TransientRetryStrategy);
SignResponse response = await transientRetryPolicy.ExecuteAsync(() => this.httpClient.SignAsync(this.apiVersion, moduleId, generationId, signRequest));
SignResponse response = await transientRetryPolicy.ExecuteAsync(() => hsmHttpClient.SignAsync(_apiVersion, moduleId, generationId, signRequest));
return response;
}

Expand Down
2 changes: 1 addition & 1 deletion iothub/device/src/HsmAuthentication/SasTokenBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

namespace Microsoft.Azure.Devices.Client.HsmAuthentication
{
static class SasTokenBuilder
internal static class SasTokenBuilder
{
public static string BuildSasToken(string audience, string signature, string expiry)
{
Expand Down
6 changes: 5 additions & 1 deletion iothub/device/src/Microsoft.Azure.Devices.Client.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
<Reference Include="System.Net.Http.WebRequest" />
<Reference Include="System.Transactions" />
<Reference Include="System.Web" />
<Reference Include="System.Net.Http"/>
<Reference Include="System.Net.Http" />
<PackageReference Include="Microsoft.AspNet.WebApi.Client" Version="5.2.3" />
<PackageReference Include="Microsoft.Owin" Version="4.0.0" />
</ItemGroup>
Expand Down Expand Up @@ -91,6 +91,10 @@
<!-- NetStandard 2.0 -->
<ItemGroup Condition=" '$(TargetFramework)' == 'netstandard2.0' ">
<Compile Include="netstandard13\Common\IOThreadTimerSlim.cs" />
<Compile Include="netstandard20\HsmAuthentication\Transport\HttpUdsMessageHandler.cs" />
<Compile Include="netstandard20\HsmAuthentication\Transport\HttpBufferedStream.cs" />
<Compile Include="netstandard20\HsmAuthentication\Transport\HttpRequestResponseSerializer.cs" />
<Compile Include="netstandard20\HsmAuthentication\Transport\UnixDomainSocketEndPoint.cs" />
</ItemGroup>
<ItemGroup Condition=" '$(TargetFramework)' == 'netstandard2.0' ">
<PackageReference Include="System.Configuration.ConfigurationManager" Version="4.4.1" />
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.IO;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Azure.Devices.Client.HsmAuthentication.Transport
{
internal class HttpBufferedStream : Stream
{
private const char CR = '\r';
private const char LF = '\n';
private BufferedStream _innerStream;

public HttpBufferedStream(Stream stream)
{
_innerStream = new BufferedStream(stream);
}

public override void Flush()
{
_innerStream.Flush();
}

public override Task FlushAsync(CancellationToken cancellationToken)
{
return _innerStream.FlushAsync(cancellationToken);
}

public override int Read(byte[] buffer, int offset, int count)
{
return _innerStream.Read(buffer, offset, count);
}

public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return _innerStream.ReadAsync(buffer, offset, count, cancellationToken);
}

public async Task<string> ReadLineAsync(CancellationToken cancellationToken)
{
int position = 0;
byte[] buffer = new byte[1];
bool crFound = false;
var builder = new StringBuilder();
while (true)
{
var length = await _innerStream.ReadAsync(buffer, 0, buffer.Length, cancellationToken)
.ConfigureAwait(false);

if (length == 0)
{
throw new IOException("Unexpected end of stream.");
}

if (crFound && (char) buffer[position] == LF)
{
builder.Remove(builder.Length - 1, 1);
return builder.ToString();
}

builder.Append((char) buffer[position]);
crFound = (char) buffer[position] == CR;
}
}

public override long Seek(long offset, SeekOrigin origin)
{
return _innerStream.Seek(offset, origin);
}

public override void SetLength(long value)
{
_innerStream.SetLength(value);
}

public override void Write(byte[] buffer, int offset, int count)
{
_innerStream.Write(buffer, offset, count);
}

public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return _innerStream.WriteAsync(buffer, offset, count, cancellationToken);
}

public override bool CanRead => _innerStream.CanRead;

public override bool CanSeek => _innerStream.CanSeek;

public override bool CanWrite => _innerStream.CanWrite;

public override long Length => _innerStream.Length;

public override long Position
{
get => _innerStream.Position;
set => _innerStream.Position = value;
}

protected override void Dispose(bool disposing)
{
_innerStream.Dispose();
}
}
}
Loading

0 comments on commit 128e147

Please sign in to comment.