Skip to content

Commit

Permalink
added more logging to certificate import (cherry pick #6085) (#6100)
Browse files Browse the repository at this point in the history
  • Loading branch information
vipeller authored Feb 8, 2022
1 parent 3287237 commit c9f8daa
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public static async Task<EdgeHubCertificates> LoadAsync(IConfigurationRoot confi
string edgeletApiVersion = configuration.GetValue<string>(Constants.ConfigKey.WorkloadAPiVersion);
DateTime expiration = DateTime.UtcNow.AddDays(Constants.CertificateValidityDays);

certificates = await CertificateHelper.GetServerCertificatesFromEdgelet(workloadUri, edgeletApiVersion, Constants.WorkloadApiVersion, moduleId, generationId, edgeHubHostname, expiration);
certificates = await CertificateHelper.GetServerCertificatesFromEdgelet(workloadUri, edgeletApiVersion, Constants.WorkloadApiVersion, moduleId, generationId, edgeHubHostname, expiration, logger);
IEnumerable<X509Certificate2> trustBundle = await CertificateHelper.GetTrustBundleFromEdgelet(workloadUri, edgeletApiVersion, Constants.WorkloadApiVersion, moduleId, generationId);

result = new EdgeHubCertificates(
Expand Down
135 changes: 72 additions & 63 deletions edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Service/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,74 +51,83 @@ static async Task<int> MainAsync(IConfigurationRoot configuration)

ILogger logger = Logger.Factory.CreateLogger("EdgeHub");

EdgeHubCertificates certificates = await EdgeHubCertificates.LoadAsync(configuration, logger);
bool clientCertAuthEnabled = configuration.GetValue(Constants.ConfigKey.EdgeHubClientCertAuthEnabled, false);

string sslProtocolsConfig = configuration.GetValue(Constants.ConfigKey.SslProtocols, string.Empty);
SslProtocols sslProtocols = SslProtocolsHelper.Parse(sslProtocolsConfig, DefaultSslProtocols, logger);
logger.LogInformation($"Enabling SSL protocols: {sslProtocols.Print()}");

IDependencyManager dependencyManager = new DependencyManager(configuration, certificates.ServerCertificate, certificates.TrustBundle, sslProtocols);
Hosting hosting = Hosting.Initialize(configuration, certificates.ServerCertificate, dependencyManager, clientCertAuthEnabled, sslProtocols);
IContainer container = hosting.Container;

logger.LogInformation("Initializing Edge Hub");
LogLogo(logger);
LogVersionInfo(logger);
logger.LogInformation($"OptimizeForPerformance={configuration.GetValue("OptimizeForPerformance", true)}");
logger.LogInformation($"MessageAckTimeoutSecs={configuration.GetValue("MessageAckTimeoutSecs", 30)}");
logger.LogInformation("Loaded server certificate with expiration date of {0}", certificates.ServerCertificate.NotAfter.ToString("o"));

var metricsProvider = container.Resolve<IMetricsProvider>();
Metrics.InitWithAspNet(metricsProvider, logger); // Note this requires App.UseMetricServer() to be called in Startup.cs

// EdgeHub and CloudConnectionProvider have a circular dependency. So need to Bind the EdgeHub to the CloudConnectionProvider.
IEdgeHub edgeHub = await container.Resolve<Task<IEdgeHub>>();
ICloudConnectionProvider cloudConnectionProvider = await container.Resolve<Task<ICloudConnectionProvider>>();
cloudConnectionProvider.BindEdgeHub(edgeHub);

// EdgeHub cloud proxy and DeviceConnectivityManager have a circular dependency,
// so the cloud proxy has to be set on the DeviceConnectivityManager after both have been initialized.
var deviceConnectivityManager = container.Resolve<IDeviceConnectivityManager>();
IConnectionManager connectionManager = await container.Resolve<Task<IConnectionManager>>();
(deviceConnectivityManager as DeviceConnectivityManager)?.SetConnectionManager(connectionManager);

// Register EdgeHub credentials
var edgeHubCredentials = container.ResolveNamed<IClientCredentials>("EdgeHubCredentials");
ICredentialsCache credentialsCache = await container.Resolve<Task<ICredentialsCache>>();
await credentialsCache.Add(edgeHubCredentials);

// Initializing configuration
logger.LogInformation("Initializing configuration");
IConfigSource configSource = await container.Resolve<Task<IConfigSource>>();
ConfigUpdater configUpdater = await container.Resolve<Task<ConfigUpdater>>();
await configUpdater.Init(configSource);

if (!Enum.TryParse(configuration.GetValue("AuthenticationMode", string.Empty), true, out AuthenticationMode authenticationMode)
|| authenticationMode != AuthenticationMode.Cloud)
try
{
ConnectionReauthenticator connectionReauthenticator = await container.Resolve<Task<ConnectionReauthenticator>>();
connectionReauthenticator.Init();
EdgeHubCertificates certificates = await EdgeHubCertificates.LoadAsync(configuration, logger);
bool clientCertAuthEnabled = configuration.GetValue(Constants.ConfigKey.EdgeHubClientCertAuthEnabled, false);

string sslProtocolsConfig = configuration.GetValue(Constants.ConfigKey.SslProtocols, string.Empty);
SslProtocols sslProtocols = SslProtocolsHelper.Parse(sslProtocolsConfig, DefaultSslProtocols, logger);
logger.LogInformation($"Enabling SSL protocols: {sslProtocols.Print()}");

IDependencyManager dependencyManager = new DependencyManager(configuration, certificates.ServerCertificate, certificates.TrustBundle, sslProtocols);
Hosting hosting = Hosting.Initialize(configuration, certificates.ServerCertificate, dependencyManager, clientCertAuthEnabled, sslProtocols);
IContainer container = hosting.Container;

logger.LogInformation("Initializing Edge Hub");
LogLogo(logger);
LogVersionInfo(logger);
logger.LogInformation($"OptimizeForPerformance={configuration.GetValue("OptimizeForPerformance", true)}");
logger.LogInformation($"MessageAckTimeoutSecs={configuration.GetValue("MessageAckTimeoutSecs", 30)}");
logger.LogInformation("Loaded server certificate with expiration date of {0}", certificates.ServerCertificate.NotAfter.ToString("o"));

var metricsProvider = container.Resolve<IMetricsProvider>();
Metrics.InitWithAspNet(metricsProvider, logger); // Note this requires App.UseMetricServer() to be called in Startup.cs

// EdgeHub and CloudConnectionProvider have a circular dependency. So need to Bind the EdgeHub to the CloudConnectionProvider.
IEdgeHub edgeHub = await container.Resolve<Task<IEdgeHub>>();
ICloudConnectionProvider cloudConnectionProvider = await container.Resolve<Task<ICloudConnectionProvider>>();
cloudConnectionProvider.BindEdgeHub(edgeHub);

// EdgeHub cloud proxy and DeviceConnectivityManager have a circular dependency,
// so the cloud proxy has to be set on the DeviceConnectivityManager after both have been initialized.
var deviceConnectivityManager = container.Resolve<IDeviceConnectivityManager>();
IConnectionManager connectionManager = await container.Resolve<Task<IConnectionManager>>();
(deviceConnectivityManager as DeviceConnectivityManager)?.SetConnectionManager(connectionManager);

// Register EdgeHub credentials
var edgeHubCredentials = container.ResolveNamed<IClientCredentials>("EdgeHubCredentials");
ICredentialsCache credentialsCache = await container.Resolve<Task<ICredentialsCache>>();
await credentialsCache.Add(edgeHubCredentials);

// Initializing configuration
logger.LogInformation("Initializing configuration");
IConfigSource configSource = await container.Resolve<Task<IConfigSource>>();
ConfigUpdater configUpdater = await container.Resolve<Task<ConfigUpdater>>();
await configUpdater.Init(configSource);

if (!Enum.TryParse(configuration.GetValue("AuthenticationMode", string.Empty), true, out AuthenticationMode authenticationMode)
|| authenticationMode != AuthenticationMode.Cloud)
{
ConnectionReauthenticator connectionReauthenticator = await container.Resolve<Task<ConnectionReauthenticator>>();
connectionReauthenticator.Init();
}

TimeSpan shutdownWaitPeriod = TimeSpan.FromSeconds(configuration.GetValue("ShutdownWaitPeriod", DefaultShutdownWaitPeriod));
(CancellationTokenSource cts, ManualResetEventSlim completed, Option<object> handler) = ShutdownHandler.Init(shutdownWaitPeriod, logger);

using (IProtocolHead protocolHead = await GetEdgeHubProtocolHeadAsync(logger, configuration, container, hosting))
using (var renewal = new CertificateRenewal(certificates, logger))
{
await protocolHead.StartAsync();
await Task.WhenAny(cts.Token.WhenCanceled(), renewal.Token.WhenCanceled());
logger.LogInformation("Stopping the protocol heads...");
await protocolHead.CloseAsync(CancellationToken.None);
logger.LogInformation("Protocol heads stopped.");

await CloseDbStoreProviderAsync(container);
}

completed.Set();
handler.ForEach(h => GC.KeepAlive(h));
logger.LogInformation("Shutdown complete.");
}

TimeSpan shutdownWaitPeriod = TimeSpan.FromSeconds(configuration.GetValue("ShutdownWaitPeriod", DefaultShutdownWaitPeriod));
(CancellationTokenSource cts, ManualResetEventSlim completed, Option<object> handler) = ShutdownHandler.Init(shutdownWaitPeriod, logger);

using (IProtocolHead protocolHead = await GetEdgeHubProtocolHeadAsync(logger, configuration, container, hosting))
using (var renewal = new CertificateRenewal(certificates, logger))
catch (Exception ex)
{
await protocolHead.StartAsync();
await Task.WhenAny(cts.Token.WhenCanceled(), renewal.Token.WhenCanceled());
logger.LogInformation("Stopping the protocol heads...");
await protocolHead.CloseAsync(CancellationToken.None);
logger.LogInformation("Protocol heads stopped.");

await CloseDbStoreProviderAsync(container);
logger.LogError(ex, "Stopping with exception");
throw;
}

completed.Set();
handler.ForEach(h => GC.KeepAlive(h));
logger.LogInformation("Shutdown complete.");
return 0;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,15 +248,15 @@ public static IEnumerable<X509Certificate2> GetCertificatesFromPem(IEnumerable<s
.Select(c => new X509Certificate2(c))
.ToList();

public static async Task<(X509Certificate2 ServerCertificate, IEnumerable<X509Certificate2> CertificateChain)> GetServerCertificatesFromEdgelet(Uri workloadUri, string workloadApiVersion, string workloadClientApiVersion, string moduleId, string moduleGenerationId, string edgeHubHostname, DateTime expiration)
public static async Task<(X509Certificate2 ServerCertificate, IEnumerable<X509Certificate2> CertificateChain)> GetServerCertificatesFromEdgelet(Uri workloadUri, string workloadApiVersion, string workloadClientApiVersion, string moduleId, string moduleGenerationId, string edgeHubHostname, DateTime expiration, ILogger logger)
{
if (string.IsNullOrEmpty(edgeHubHostname))
{
throw new InvalidOperationException($"{nameof(edgeHubHostname)} is required.");
}

ServerCertificateResponse response = await new WorkloadClient(workloadUri, workloadApiVersion, workloadClientApiVersion, moduleId, moduleGenerationId).CreateServerCertificateAsync(edgeHubHostname, expiration);
return ParseCertificateResponse(response);
return ParseCertificateResponse(response, logger);
}

public static async Task<IEnumerable<X509Certificate2>> GetTrustBundleFromEdgelet(Uri workloadUri, string workloadApiVersion, string workloadClientApiVersion, string moduleId, string moduleGenerationId)
Expand All @@ -265,7 +265,7 @@ public static async Task<IEnumerable<X509Certificate2>> GetTrustBundleFromEdgele
return ParseTrustedBundleCerts(response);
}

public static (X509Certificate2 ServerCertificate, IEnumerable<X509Certificate2> CertificateChain) GetServerCertificateAndChainFromFile(string serverWithChainFilePath, string serverPrivateKeyFilePath)
public static (X509Certificate2 ServerCertificate, IEnumerable<X509Certificate2> CertificateChain) GetServerCertificateAndChainFromFile(string serverWithChainFilePath, string serverPrivateKeyFilePath, ILogger logger = null)
{
string cert, privateKey;

Expand All @@ -289,7 +289,7 @@ public static (X509Certificate2 ServerCertificate, IEnumerable<X509Certificate2>
privateKey = sr.ReadToEnd();
}

return ParseCertificateAndKey(cert, privateKey);
return ParseCertificateAndKey(cert, privateKey, logger);
}

public static IEnumerable<X509Certificate2> GetServerCACertificatesFromFile(string chainPath)
Expand Down Expand Up @@ -338,10 +338,10 @@ internal static IEnumerable<X509Certificate2> ParseTrustedBundleCerts(string tru
return GetCertificatesFromPem(ParsePemCerts(trustedCACerts));
}

internal static (X509Certificate2, IEnumerable<X509Certificate2>) ParseCertificateResponse(ServerCertificateResponse response) =>
ParseCertificateAndKey(response.Certificate, response.PrivateKey);
internal static (X509Certificate2, IEnumerable<X509Certificate2>) ParseCertificateResponse(ServerCertificateResponse response, ILogger logger = null) =>
ParseCertificateAndKey(response.Certificate, response.PrivateKey, logger);

internal static (X509Certificate2, IEnumerable<X509Certificate2>) ParseCertificateAndKey(string certificateWithChain, string privateKey)
internal static (X509Certificate2, IEnumerable<X509Certificate2>) ParseCertificateAndKey(string certificateWithChain, string privateKey, ILogger logger = null)
{
IEnumerable<string> pemCerts = ParsePemCerts(certificateWithChain);

Expand All @@ -353,7 +353,7 @@ internal static (X509Certificate2, IEnumerable<X509Certificate2>) ParseCertifica
IEnumerable<X509Certificate2> certsChain = GetCertificatesFromPem(pemCerts.Skip(1));

var certWithNoKey = new X509Certificate2(Encoding.UTF8.GetBytes(pemCerts.First()));
var certWithPrivateKey = AttachPrivateKey(certWithNoKey, privateKey);
var certWithPrivateKey = AttachPrivateKey(certWithNoKey, privateKey, logger);

return (certWithPrivateKey, certsChain);
}
Expand Down Expand Up @@ -385,7 +385,7 @@ static Option<string> GetCommonNameFromSubject(string subject)
return commonName;
}

static X509Certificate2 AttachPrivateKey(X509Certificate2 certificate, string pemEncodedKey)
static X509Certificate2 AttachPrivateKey(X509Certificate2 certificate, string pemEncodedKey, ILogger logger)
{
var pkcs8Label = "PRIVATE KEY";
var rsaLabel = "RSA PRIVATE KEY";
Expand All @@ -399,6 +399,8 @@ static X509Certificate2 AttachPrivateKey(X509Certificate2 certificate, string pe
{
if (oidRsaEncryption.Value == keyAlgorithm)
{
logger?.LogDebug("Importing RSA private key");

var decodedKey = UnwrapPrivateKey(pemEncodedKey, isPkcs8 ? pkcs8Label : rsaLabel);
var key = RSA.Create();

Expand All @@ -412,9 +414,13 @@ static X509Certificate2 AttachPrivateKey(X509Certificate2 certificate, string pe
}

result = certificate.CopyWithPrivateKey(key);

logger?.LogDebug("RSA private key has been imported and assigned to certificate");
}
else if (oidEcPublicKey.Value == keyAlgorithm)
{
logger?.LogDebug("Importing ECC private key");

var decodedKey = UnwrapPrivateKey(pemEncodedKey, isPkcs8 ? pkcs8Label : ecLabel);
var key = ECDsa.Create();

Expand All @@ -428,16 +434,22 @@ static X509Certificate2 AttachPrivateKey(X509Certificate2 certificate, string pe
}

result = certificate.CopyWithPrivateKey(key);

logger?.LogDebug("ECC private key has been imported and assigned to certificate");
}
}
catch (Exception ex)
{
throw new InvalidOperationException("Cannot import private key", ex);
var errorMessage = "Cannot import private key";
logger?.LogError(ex, errorMessage);
throw new InvalidOperationException(errorMessage, ex);
}

if (result == null)
{
throw new InvalidOperationException($"Cannot use certificate, not supported key algorithm: ${keyAlgorithm}");
var errorMessage = $"Cannot use certificate, not supported key algorithm: ${keyAlgorithm}";
logger?.LogError(errorMessage);
throw new InvalidOperationException(errorMessage);
}

return result;
Expand Down

0 comments on commit c9f8daa

Please sign in to comment.