Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improving error handling in SecretManager #9429

Merged
merged 6 commits into from
Aug 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using DryIoc;
using Microsoft.Azure.Web.DataProtection;
using Microsoft.Azure.WebJobs.Extensions.Http;
using Microsoft.Azure.WebJobs.Script.Diagnostics;
using Microsoft.Azure.WebJobs.Script.WebHost.Properties;
using Microsoft.Azure.WebJobs.Script.WebHost.Security;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;

using DataProtectionConstants = Microsoft.Azure.Web.DataProtection.Constants;

namespace Microsoft.Azure.WebJobs.Script.WebHost
Expand Down Expand Up @@ -90,12 +92,26 @@ public async virtual Task<HostSecretsInfo> GetHostSecretsAsync()
_logger.LogDebug("Loading host secrets");

hostSecrets = await LoadSecretsAsync<HostSecrets>();
if (hostSecrets == null)
try
{
// host secrets do not yet exist so generate them
_logger.LogDebug(Resources.TraceHostSecretGeneration);
hostSecrets = GenerateHostSecrets();
await PersistSecretsAsync(hostSecrets);
if (hostSecrets == null)
aishwaryabh marked this conversation as resolved.
Show resolved Hide resolved
{
// host secrets do not yet exist so generate them
_logger.LogDebug(Resources.TraceHostSecretGeneration);
hostSecrets = GenerateHostSecrets();
await PersistSecretsAsync(hostSecrets);
}
}
catch (Exception ex)
aishwaryabh marked this conversation as resolved.
Show resolved Hide resolved
{
_logger.LogDebug(ex, "Exception while generating host secrets. This can happen if another instance is also generating secrets. Attempting to read secrets again.");
hostSecrets = await LoadSecretsAsync<HostSecrets>();

if (hostSecrets == null)
{
_logger.LogError("Host secrets are still null on second attempt.");
throw;
}
}

try
Expand Down Expand Up @@ -158,14 +174,29 @@ public async virtual Task<IDictionary<string, string>> GetFunctionSecretsAsync(s
_logger.LogDebug($"Loading secrets for function '{functionName}'");

FunctionSecrets secrets = await LoadFunctionSecretsAsync(functionName);
if (secrets == null)
{
// no secrets exist for this function so generate them
string message = string.Format(Resources.TraceFunctionSecretGeneration, functionName);
_logger.LogDebug(message);
secrets = GenerateFunctionSecrets();

await PersistSecretsAsync(secrets, functionName);
try
{
if (secrets == null)
{
// no secrets exist for this function so generate them
string message = string.Format(Resources.TraceFunctionSecretGeneration, functionName);
_logger.LogDebug(message);

Check failure

Code scanning / CodeQL

Log entries created from user input

This log entry depends on a [user-provided value](1). This log entry depends on a [user-provided value](2).
secrets = GenerateFunctionSecrets();

await PersistSecretsAsync(secrets, functionName);
}
}
catch (Exception ex)
{
_logger.LogDebug(ex, "Exception while generating function secrets. This can happen if another instance is also generating secrets. Attempting to read secrets again.");
secrets = await LoadFunctionSecretsAsync(functionName);

if (secrets == null)
{
_logger.LogError("Function secrets are still null on second attempt.");
throw;
}
}

try
Expand Down
78 changes: 76 additions & 2 deletions test/WebJobs.Script.Tests/Security/SecretManagerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,68 @@ public async Task GetHostSecretsAsync_SecretGenerationIsSerialized()
}
}

[Fact]
public async Task SecretsRepository_SimultaneousCreates_Throws_Conflict()
{
var mockValueConverterFactory = GetConverterFactoryMock(false, false);
var metricsLogger = new TestMetricsLogger();

// Test repository that will fail on WriteAsync() due to a conflict, but will replicate a success when LoadSecretsAsync() is called
// indicating that there was race condition
var testRepository = new TestSecretsRepository(true, true, true);
string testFunctionName = "host";

using (var secretManager = new SecretManager(testRepository, mockValueConverterFactory.Object, _logger, metricsLogger, _hostNameProvider, _startupContextProvider))
{
var tasks = new List<Task<HostSecretsInfo>>();
for (int i = 0; i < 2; i++)
{
tasks.Add(secretManager.GetHostSecretsAsync());
}

// Ensure nothing is there.
HostSecrets secretsContent = await testRepository.ReadAsync(ScriptSecretsType.Host, testFunctionName) as HostSecrets;
Assert.Null(secretsContent);
await Task.WhenAll(tasks);

// verify all calls return the same result
var masterKey = tasks.First().Result.MasterKey;
var functionKey = tasks.First().Result.FunctionKeys.First();
Assert.True(tasks.Select(p => p.Result).All(q => q.MasterKey == masterKey));
Assert.True(tasks.Select(p => p.Result).All(q => q.FunctionKeys.First().Value == functionKey.Value));

// verify generated master and function keys are valid
tasks.Select(p => p.Result).All(q => ValidateHostSecrets(q));
}
}

[Fact]
public async Task FunctionSecrets_SimultaneousCreates_Throws_Conflict()
{
var mockValueConverterFactory = GetConverterFactoryMock(false, false);
var metricsLogger = new TestMetricsLogger();
var testRepository = new TestSecretsRepository(true, true, true);
string testFunctionName = $"TestFunction";

using (var secretManager = new SecretManager(testRepository, mockValueConverterFactory.Object, _logger, metricsLogger, _hostNameProvider, _startupContextProvider))
{
var tasks = new List<Task<IDictionary<string, string>>>();
for (int i = 0; i < 2; i++)
{
tasks.Add(secretManager.GetFunctionSecretsAsync(testFunctionName));
}

await Task.WhenAll(tasks);

// verify all calls return the same result
Assert.Equal(1, testRepository.FunctionSecrets.Count);
var functionSecrets = (FunctionSecrets)testRepository.FunctionSecrets[testFunctionName];
string defaultKeyValue = functionSecrets.Keys.Where(p => p.Name == "default").Single().Value;
SecretGeneratorTests.ValidateSecret(defaultKeyValue, SecretGenerator.FunctionKeySeed);
Assert.True(tasks.Select(p => p.Result).All(t => t["default"] == defaultKeyValue));
}
}

[Fact]
public async Task GetHostSecrets_WhenNoHostSecretFileExists_GeneratesSecretsAndPersistsFiles()
{
Expand Down Expand Up @@ -532,7 +594,7 @@ public async Task AddOrUpdateFunctionSecret_WhenStorageWriteError_ThrowsExceptio

KeyOperationResult result;

ISecretsRepository repository = new TestSecretsRepository(false, true, HttpStatusCode.InternalServerError);
ISecretsRepository repository = new TestSecretsRepository(false, true, false, HttpStatusCode.InternalServerError);
using (var secretManager = CreateSecretManager(directory.Path, simulateWriteConversion: false, secretsRepository: repository))
{
try
Expand Down Expand Up @@ -1265,17 +1327,19 @@ private class TestSecretsRepository : ISecretsRepository
private Random _rand = new Random();
private bool _enforceSerialWrites = false;
private bool _forceWriteErrors = false;
private bool _shouldSuceedAfterFailing = false;
private HttpStatusCode _httpstaus;

public TestSecretsRepository(bool enforceSerialWrites)
{
_enforceSerialWrites = enforceSerialWrites;
}

public TestSecretsRepository(bool enforceSerialWrites, bool forceWriteErrors, HttpStatusCode httpstaus = HttpStatusCode.InternalServerError)
public TestSecretsRepository(bool enforceSerialWrites, bool forceWriteErrors, bool shouldSucceedAfterFailing = false, HttpStatusCode httpstaus = HttpStatusCode.InternalServerError)
: this(enforceSerialWrites)
{
_forceWriteErrors = forceWriteErrors;
_shouldSuceedAfterFailing = shouldSucceedAfterFailing;
_httpstaus = httpstaus;
}

Expand Down Expand Up @@ -1319,6 +1383,11 @@ public async Task WriteAsync(ScriptSecretsType type, string functionName, Script
{
if (_forceWriteErrors)
{
// Replicate making the first write fail, but succeed on the second attempt
if (_shouldSuceedAfterFailing)
{
await WriteAsyncHelper(type, functionName, secrets);
}
throw new RequestFailedException((int)_httpstaus, "Error");
}

Expand All @@ -1327,6 +1396,11 @@ public async Task WriteAsync(ScriptSecretsType type, string functionName, Script
throw new Exception("Concurrent writes detected!");
}

await WriteAsyncHelper(type, functionName, secrets);
}

private async Task WriteAsyncHelper(ScriptSecretsType type, string functionName, ScriptSecrets secrets)
{
Interlocked.Increment(ref _writeCount);

await Task.Delay(_rand.Next(100, 300));
Expand Down