diff --git a/src/ApiService/ApiService/OneFuzzTypes/Model.cs b/src/ApiService/ApiService/OneFuzzTypes/Model.cs index fc27cf12fc..1b129ce808 100644 --- a/src/ApiService/ApiService/OneFuzzTypes/Model.cs +++ b/src/ApiService/ApiService/OneFuzzTypes/Model.cs @@ -281,7 +281,9 @@ public record AzureMonitorExtensionConfig( public record AzureVmExtensionConfig( KeyvaultExtensionConfig? Keyvault, - AzureMonitorExtensionConfig AzureMonitor + AzureMonitorExtensionConfig? AzureMonitor, + AzureSecurityExtensionConfig? AzureSecurity, + GenevaExtensionConfig? Geneva ); public record NetworkConfig( diff --git a/src/ApiService/ApiService/OneFuzzTypes/ReturnTypes.cs b/src/ApiService/ApiService/OneFuzzTypes/ReturnTypes.cs index 89f543fe8f..36abcdec59 100644 --- a/src/ApiService/ApiService/OneFuzzTypes/ReturnTypes.cs +++ b/src/ApiService/ApiService/OneFuzzTypes/ReturnTypes.cs @@ -56,7 +56,32 @@ public struct OneFuzzResult { public static OneFuzzResult Ok(T_Ok ok) => new(ok); public static OneFuzzResult Error(ErrorCode errorCode, string[] errors) => new(errorCode, errors); + public static OneFuzzResult Error(ErrorCode errorCode, string error) => new(errorCode, new[] { error }); public static OneFuzzResult Error(Error err) => new(err); } + + + public struct OneFuzzResultVoid { + static Error NoError = new(0); + + readonly Error error; + readonly bool isOk; + + public bool IsOk => isOk; + + public Error ErrorV => error; + + private OneFuzzResultVoid(ErrorCode errorCode, string[] errors) => (error, isOk) = (new Error(errorCode, errors), false); + + private OneFuzzResultVoid(Error err) => (error, isOk) = (err, false); + + public static OneFuzzResultVoid Ok() => new(); + public static OneFuzzResultVoid Error(ErrorCode errorCode, string[] errors) => new(errorCode, errors); + public static OneFuzzResultVoid Error(ErrorCode errorCode, string error) => new(errorCode, new[] { error }); + public static OneFuzzResultVoid Error(Error err) => new(err); + } + + + } diff --git a/src/ApiService/ApiService/Program.cs b/src/ApiService/ApiService/Program.cs index bc37ac2dee..3ff38439d0 100644 --- a/src/ApiService/ApiService/Program.cs +++ b/src/ApiService/ApiService/Program.cs @@ -86,6 +86,8 @@ public static void Main() { .AddScoped() .AddScoped() .AddScoped() + .AddScoped() + .AddScoped() //Move out expensive resources into separate class, and add those as Singleton // ArmClient, Table Client(s), Queue Client(s), HttpClient, etc.\ diff --git a/src/ApiService/ApiService/TestHooks.cs b/src/ApiService/ApiService/TestHooks.cs index a06522233b..620bdbe207 100644 --- a/src/ApiService/ApiService/TestHooks.cs +++ b/src/ApiService/ApiService/TestHooks.cs @@ -1,4 +1,5 @@ using System.Net; +using System.Text.Json; using System.Threading.Tasks; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Http; @@ -13,12 +14,16 @@ public class TestHooks { private readonly IConfigOperations _configOps; private readonly IEvents _events; private readonly IServiceConfig _config; + private readonly ISecretsOperations _secretOps; + private readonly ILogAnalytics _logAnalytics; - public TestHooks(ILogTracer log, IConfigOperations configOps, IEvents events, IServiceConfig config) { + public TestHooks(ILogTracer log, IConfigOperations configOps, IEvents events, IServiceConfig config, ISecretsOperations secretOps, ILogAnalytics logAnalytics) { _log = log; _configOps = configOps; _events = events; _config = config; + _secretOps = secretOps; + _logAnalytics = logAnalytics; } [Function("Info")] @@ -57,4 +62,68 @@ public async Task InstanceConfig([HttpTrigger(AuthorizationLev return resp; } } + + [Function("GetKeyvaultAddress")] + public async Task GetKeyVaultAddress([HttpTrigger(AuthorizationLevel.Anonymous, "get", Route = "testhooks/secrets/keyvaultaddress")] HttpRequestData req) { + _log.Info("Getting keyvault address"); + var addr = _secretOps.GetKeyvaultAddress(); + var resp = req.CreateResponse(HttpStatusCode.OK); + await resp.WriteAsJsonAsync(addr); + return resp; + } + + [Function("SaveToKeyvault")] + public async Task SaveToKeyvault([HttpTrigger(AuthorizationLevel.Anonymous, "post", Route = "testhooks/secrets/keyvault")] HttpRequestData req) { + var s = await req.ReadAsStringAsync(); + var secretData = JsonSerializer.Deserialize>(s!, EntityConverter.GetJsonSerializerOptions()); + if (secretData is null) { + _log.Error("Secret data is null"); + return req.CreateResponse(HttpStatusCode.BadRequest); + } else { + _log.Info($"Saving secret data in the keyvault"); + var r = await _secretOps.SaveToKeyvault(secretData); + var addr = _secretOps.GetKeyvaultAddress(); + var resp = req.CreateResponse(HttpStatusCode.OK); + await resp.WriteAsJsonAsync(addr); + return resp; + } + } + + [Function("GetSecretStringValue")] + public async Task GetSecretStringValue([HttpTrigger(AuthorizationLevel.Anonymous, "get", Route = "testhooks/secrets/keyvault")] HttpRequestData req) { + var queryComponents = req.Url.GetComponents(UriComponents.Query, UriFormat.UriEscaped).Split("&"); + + var q = + from cs in queryComponents + where !string.IsNullOrEmpty(cs) + let i = cs.IndexOf('=') + select new KeyValuePair(Uri.UnescapeDataString(cs.Substring(0, i)), Uri.UnescapeDataString(cs.Substring(i + 1))); + + var qs = new Dictionary(q); + var d = await _secretOps.GetSecretStringValue(new SecretData(qs["SecretName"])); + + var resp = req.CreateResponse(HttpStatusCode.OK); + await resp.WriteAsJsonAsync(d); + return resp; + } + + + [Function("GetWorkspaceId")] + public async Task GetWorkspaceId([HttpTrigger(AuthorizationLevel.Anonymous, "get", Route = "testhooks/logAnalytics/workspaceId")] HttpRequestData req) { + var id = _logAnalytics.GetWorkspaceId(); + var resp = req.CreateResponse(HttpStatusCode.OK); + await resp.WriteAsJsonAsync(id); + return resp; + } + + + + [Function("GetMonitorSettings")] + public async Task GetMonitorSettings([HttpTrigger(AuthorizationLevel.Anonymous, "get", Route = "testhooks/logAnalytics/monitorSettings")] HttpRequestData req) { + var settings = await _logAnalytics.GetMonitorSettings(); + var resp = req.CreateResponse(HttpStatusCode.OK); + await resp.WriteAsJsonAsync(settings); + return resp; + } + } diff --git a/src/ApiService/ApiService/onefuzzlib/Containers.cs b/src/ApiService/ApiService/onefuzzlib/Containers.cs index 474db0645e..a9d0d754f4 100644 --- a/src/ApiService/ApiService/onefuzzlib/Containers.cs +++ b/src/ApiService/ApiService/onefuzzlib/Containers.cs @@ -142,35 +142,35 @@ public async Async.Task GetInstanceId() { } return System.Guid.Parse(blob.ToString()); } - - public Uri? GetContainerSasUrlService( - BlobContainerClient client, - BlobSasPermissions permissions, - bool tag = false, - TimeSpan? timeSpan = null) { - var (start, expiry) = SasTimeWindow(timeSpan ?? TimeSpan.FromDays(30.0)); - var sasBuilder = new BlobSasBuilder(permissions, expiry) { StartsOn = start }; - var sas = client.GenerateSasUri(sasBuilder); - return sas; - } - - - //TODO: instead of returning null when container not found, convert to return to "Result" type and set appropriate error - public async Async.Task GetContainerSasUrl(Container container, StorageType storageType, BlobSasPermissions permissions) { - var client = await FindContainer(container, storageType); - - if (client is null) { - return null; - } - - var uri = GetContainerSasUrlService(client, permissions); - - if (uri is null) { - //TODO: return result error - return uri; - } else { - return uri; - } - } -} - + + public Uri? GetContainerSasUrlService( + BlobContainerClient client, + BlobSasPermissions permissions, + bool tag = false, + TimeSpan? timeSpan = null) { + var (start, expiry) = SasTimeWindow(timeSpan ?? TimeSpan.FromDays(30.0)); + var sasBuilder = new BlobSasBuilder(permissions, expiry) { StartsOn = start }; + var sas = client.GenerateSasUri(sasBuilder); + return sas; + } + + + //TODO: instead of returning null when container not found, convert to return to "Result" type and set appropriate error + public async Async.Task GetContainerSasUrl(Container container, StorageType storageType, BlobSasPermissions permissions) { + var client = await FindContainer(container, storageType); + + if (client is null) { + return null; + } + + var uri = GetContainerSasUrlService(client, permissions); + + if (uri is null) { + //TODO: return result error + return uri; + } else { + return uri; + } + } +} + diff --git a/src/ApiService/ApiService/onefuzzlib/Extension.cs b/src/ApiService/ApiService/onefuzzlib/Extension.cs new file mode 100644 index 0000000000..2aa8cc0541 --- /dev/null +++ b/src/ApiService/ApiService/onefuzzlib/Extension.cs @@ -0,0 +1,351 @@ +using System.Text.Json; +using Azure.ResourceManager.Compute; +using Azure.Storage.Sas; +using Microsoft.OneFuzz.Service.OneFuzzLib.Orm; + +namespace Microsoft.OneFuzz.Service; + +public interface IExtensions { + public Async.Task> FuzzExtensions(Pool pool, Scaleset scaleset); +} + + +public class Extensions : IExtensions { + IServiceConfig _serviceConfig; + ICreds _creds; + IQueue _queue; + IContainers _containers; + IConfigOperations _instanceConfigOps; + ILogAnalytics _logAnalytics; + + public Extensions(IServiceConfig config, ICreds creds, IQueue queue, IContainers containers, IConfigOperations instanceConfigOps, ILogAnalytics logAnalytics) { + _serviceConfig = config; + _creds = creds; + _queue = queue; + _containers = containers; + _instanceConfigOps = instanceConfigOps; + _logAnalytics = logAnalytics; + } + + public async Async.Task ConfigUrl(Container container, string fileName, bool withSas) { + if (withSas) + return await _containers.GetFileSasUrl(container, fileName, StorageType.Config, BlobSasPermissions.Read); + else + return await _containers.GetFileUrl(container, fileName, StorageType.Config); + } + + public async Async.Task> GenericExtensions(string region, Os vmOs) { + var extensions = new List(); + + var instanceConfig = await _instanceConfigOps.Fetch(); + extensions.Add(await MonitorExtension(region, vmOs)); + + var depenency = DependencyExtension(region, vmOs); + if (depenency is not null) { + extensions.Add(depenency); + } + + if (instanceConfig.Extensions is not null) { + + if (instanceConfig.Extensions.Keyvault is not null) { + var keyvault = KeyVaultExtension(region, instanceConfig.Extensions.Keyvault, vmOs); + extensions.Add(keyvault); + } + + if (instanceConfig.Extensions.Geneva is not null && vmOs == Os.Windows) { + var geneva = GenevaExtension(region); + extensions.Add(geneva); + } + + if (instanceConfig.Extensions.AzureMonitor is not null && vmOs == Os.Linux) { + var azMon = AzMonExtension(region, instanceConfig.Extensions.AzureMonitor); + extensions.Add(azMon); + } + + if (instanceConfig.Extensions.AzureSecurity is not null && vmOs == Os.Linux) { + var azSec = AzSecExtension(region); + extensions.Add(azSec); + } + } + + return extensions; + } + + public VirtualMachineScaleSetExtensionData KeyVaultExtension(string region, KeyvaultExtensionConfig keyVault, Os vmOs) { + var keyVaultName = keyVault.KeyVaultName; + var certName = keyVault.CertName; + var uri = keyVaultName + certName; + + if (vmOs == Os.Windows) { + return new VirtualMachineScaleSetExtensionData { + Name = "KVVMExtensionForWindows", + Publisher = "Microsoft.Azure.KeyVault", + TypePropertiesType = "KeyVaultForWindows", + TypeHandlerVersion = "1.0", + AutoUpgradeMinorVersion = true, + Settings = new BinaryData(new { + SecretsManagementSettings = new { + PollingIntervalInS = "3600", + CertificateStoreName = "MY", + LinkOnRenewal = false, + CertificateStoreLocation = "LocalMachine", + RequireInitialSync = true, + ObservedCertificates = new string[] { uri }, + } + }) + }; + } else if (vmOs == Os.Linux) { + var certPath = keyVault.CertPath; + var extensionStore = keyVault.ExtensionStore; + var certLocation = certPath + extensionStore; + + return new VirtualMachineScaleSetExtensionData { + Name = "KVVMExtensionForLinux", + Publisher = "Microsoft.Azure.KeyVault", + TypePropertiesType = "KeyVaultForLinux", + TypeHandlerVersion = "2.0", + AutoUpgradeMinorVersion = true, + Settings = new BinaryData(new { + SecretsManagementSettings = new { + PollingIntervalInS = "3600", + CertificateStoreLocation = certLocation, + RequireInitialSync = true, + ObservedCertificates = new string[] { uri }, + } + }) + }; + } else { + throw new NotImplementedException($"unsupported os {vmOs}"); + } + } + + public VirtualMachineScaleSetExtensionData AzSecExtension(string region) { + return new VirtualMachineScaleSetExtensionData { + Name = "AzureSecurityLinuxAgent", + Publisher = "Microsoft.Azure.Security.Monitoring", + TypePropertiesType = "AzureSecurityLinuxAgent", + TypeHandlerVersion = "2.0", + AutoUpgradeMinorVersion = true, + Settings = new BinaryData(new { EnableGenevaUpload = true, EnableAutoConfig = true }) + }; + + } + + public VirtualMachineScaleSetExtensionData AzMonExtension(string region, AzureMonitorExtensionConfig azureMonitor) { + var authId = azureMonitor.MonitoringGCSAuthId; + var configVersion = azureMonitor.ConfigVersion; + var moniker = azureMonitor.Moniker; + var namespaceName = azureMonitor.Namespace; + var environment = azureMonitor.MonitoringGSEnvironment; + var account = azureMonitor.MonitoringGCSAccount; + var authIdType = azureMonitor.MonitoringGCSAuthIdType; + + return new VirtualMachineScaleSetExtensionData { + Name = "AzureMonitorLinuxAgent", + Publisher = "Microsoft.Azure.Monitor", + TypePropertiesType = "AzureMonitorLinuxAgent", + AutoUpgradeMinorVersion = true, + TypeHandlerVersion = "1.0", + Settings = new BinaryData(new { GCS_AUTO_CONFIG = true }), + ProtectedSettings = + new BinaryData( + new { + ConfigVersion = configVersion, + Moniker = moniker, + Namespace = namespaceName, + MonitoringGCSEnvironment = environment, + MonitoringGCSAccount = account, + MonitoringGCSRegion = region, + MonitoringGCSAuthId = authId, + MonitoringGCSAuthIdType = authIdType, + }) + }; + } + + + + public VirtualMachineScaleSetExtensionData GenevaExtension(string region) { + return new VirtualMachineScaleSetExtensionData { + Name = "Microsoft.Azure.Geneva.GenevaMonitoring", + Publisher = "Microsoft.Azure.Geneva", + TypePropertiesType = "GenevaMonitoring", + TypeHandlerVersion = "2.0", + AutoUpgradeMinorVersion = true, + EnableAutomaticUpgrade = true, + }; + } + + public VirtualMachineScaleSetExtensionData? DependencyExtension(string region, Os vmOs) { + + if (vmOs == Os.Windows) { + return new VirtualMachineScaleSetExtensionData { + AutoUpgradeMinorVersion = true, + Name = "DependencyAgentWindows", + Publisher = "Microsoft.Azure.Monitoring.DependencyAgent", + TypePropertiesType = "DependencyAgentWindows", + TypeHandlerVersion = "9.5" + }; + } else { + // THIS TODO IS FROM PYTHON CODE + //# TODO: dependency agent for linux is not reliable + //# extension = { + //# "name": "DependencyAgentLinux", + //# "publisher": "Microsoft.Azure.Monitoring.DependencyAgent", + //# "type": "DependencyAgentLinux", + //# "typeHandlerVersion": "9.5", + //# "location": vm.region, + //# "autoUpgradeMinorVersion": True, + //# } + return null; + } + } + + + public async Async.Task BuildPoolConfig(Pool pool) { + var instanceId = await _containers.GetInstanceId(); + + var queueSas = await _queue.GetQueueSas("node-heartbeat", StorageType.Config, QueueSasPermissions.Add); + var config = new AgentConfig( + ClientCredentials: null, + OneFuzzUrl: _creds.GetInstanceUrl(), + PoolName: pool.Name, + HeartbeatQueue: queueSas, + InstanceTelemetryKey: _serviceConfig.ApplicationInsightsInstrumentationKey, + MicrosoftTelemetryKey: _serviceConfig.OneFuzzTelemetry, + MultiTenantDomain: _serviceConfig.MultiTenantDomain, + InstanceId: instanceId + ); + + var fileName = $"{pool.Name}/config.json"; + await _containers.SaveBlob(new Container("vm-scripts"), fileName, (JsonSerializer.Serialize(config, EntityConverter.GetJsonSerializerOptions())), StorageType.Config); + return await ConfigUrl(new Container("vm-scripts"), fileName, false); + } + + + public async Async.Task BuildScaleSetScript(Pool pool, Scaleset scaleSet) { + List commands = new(); + var extension = pool.Os == Os.Windows ? "ps1" : "sh"; + var fileName = $"{scaleSet.ScalesetId}/scaleset-setup.{extension}"; + var sep = pool.Os == Os.Windows ? "\r\n" : "\n"; + + if (pool.Os == Os.Windows && scaleSet.Auth is not null) { + var sshKey = scaleSet.Auth.PublicKey.Trim(); + var sshPath = "$env:ProgramData/ssh/administrators_authorized_keys"; + commands.Add($"Set-Content -Path {sshPath} -Value \"{sshKey}\""); + } + + await _containers.SaveBlob(new Container("vm-scripts"), fileName, string.Join(sep, commands) + sep, StorageType.Config); + return await _containers.GetFileUrl(new Container("vm-scripts"), fileName, StorageType.Config); + } + + public async Async.Task UpdateManagedScripts() { + var instanceSpecificSetupSas = _containers.GetContainerSasUrl(new Container("instance-specific-setup"), StorageType.Config, BlobSasPermissions.List | BlobSasPermissions.Read); + var toolsSas = _containers.GetContainerSasUrl(new Container("tools"), StorageType.Config, BlobSasPermissions.List | BlobSasPermissions.Read); + + string[] commands = { + $"azcopy sync '{instanceSpecificSetupSas}' instance-specific-setup", + $"azcopy sync '{toolsSas}' tools" + }; + + await _containers.SaveBlob(new Container("vm-scripts"), "managed.ps1", string.Join("\r\n", commands) + "\r\n", StorageType.Config); + await _containers.SaveBlob(new Container("vm-scripts"), "managed.sh", string.Join("\n", commands) + "\n", StorageType.Config); + } + + + public async Async.Task AgentConfig(string region, Os vmOs, AgentMode mode, List? urls = null, bool withSas = false) { + await UpdateManagedScripts(); + var urlsUpdated = urls ?? new(); + + if (vmOs == Os.Windows) { + var vmScripts = await ConfigUrl(new Container("vm-scripts"), "managed.ps1", withSas) ?? throw new Exception("failed to get VmScripts config url"); + var toolsAzCopy = await ConfigUrl(new Container("tools"), "win64/azcopy.exe", withSas) ?? throw new Exception("failed to get toolsAzCopy config url"); + var toolsSetup = await ConfigUrl(new Container("tools"), "win64/setup.ps1", withSas) ?? throw new Exception("failed to get toolsSetup config url"); + var toolsOneFuzz = await ConfigUrl(new Container("tools"), "win64/onefuzz.ps1", withSas) ?? throw new Exception("failed to get toolsOneFuzz config url"); + + urlsUpdated.Add(vmScripts); + urlsUpdated.Add(toolsAzCopy); + urlsUpdated.Add(toolsSetup); + urlsUpdated.Add(toolsOneFuzz); + + var toExecuteCmd = $"powershell -ExecutionPolicy Unrestricted -File win64/setup.ps1 -mode {mode}"; + + var extension = new VirtualMachineScaleSetExtensionData { + Name = "CustomScriptExtension", + TypePropertiesType = "CustomScriptExtension", + Publisher = "Microsoft.Compute", + ForceUpdateTag = Guid.NewGuid().ToString(), + TypeHandlerVersion = "1.9", + AutoUpgradeMinorVersion = true, + Settings = new BinaryData(new { commandToExecute = toExecuteCmd, fileUrls = urlsUpdated }), + ProtectedSettings = new BinaryData(new { managedIdentity = new Dictionary() }) + }; + return extension; + } else if (vmOs == Os.Linux) { + + var vmScripts = await ConfigUrl(new Container("vm-scripts"), "managed.sh", withSas) ?? throw new Exception("failed to get VmScripts config url"); + var toolsAzCopy = await ConfigUrl(new Container("tools"), "linux/azcopy", withSas) ?? throw new Exception("failed to get toolsAzCopy config url"); + var toolsSetup = await ConfigUrl(new Container("tools"), "linux/setup.sh", withSas) ?? throw new Exception("failed to get toolsSetup config url"); + + urlsUpdated.Add(vmScripts); + urlsUpdated.Add(toolsAzCopy); + urlsUpdated.Add(toolsSetup); + + var toExecuteCmd = $"sh setup.sh {mode}"; + + var extension = new VirtualMachineScaleSetExtensionData { + Name = "CustomScript", + TypePropertiesType = "CustomScript", + Publisher = "Microsoft.Azure.Extension", + ForceUpdateTag = Guid.NewGuid().ToString(), + TypeHandlerVersion = "2.1", + AutoUpgradeMinorVersion = true, + Settings = new BinaryData(new { CommandToExecute = toExecuteCmd, FileUrls = urlsUpdated }), + ProtectedSettings = new BinaryData(new { ManagedIdentity = new Dictionary() }) + }; + return extension; + } + + throw new NotImplementedException($"unsupported OS: {vmOs}"); + } + + public async Async.Task MonitorExtension(string region, Os vmOs) { + var settings = await _logAnalytics.GetMonitorSettings(); + + if (vmOs == Os.Windows) { + return new VirtualMachineScaleSetExtensionData { + Name = "OMSExtension", + TypePropertiesType = "MicrosoftMonitoringAgent", + Publisher = "Microsoft.EnterpriseCloud.Monitoring", + TypeHandlerVersion = "1.0", + AutoUpgradeMinorVersion = true, + Settings = new BinaryData(new { WorkSpaceId = settings.Id }), + ProtectedSettings = new BinaryData(new { WorkspaceKey = settings.Key }) + }; + } else if (vmOs == Os.Linux) { + return new VirtualMachineScaleSetExtensionData { + Name = "OMSExtension", + TypePropertiesType = "OmsAgentForLinux", + Publisher = "Microsoft.EnterpriseCloud.Monitoring", + TypeHandlerVersion = "1.12", + AutoUpgradeMinorVersion = true, + Settings = new BinaryData(new { WorkSpaceId = settings.Id }), + ProtectedSettings = new BinaryData(new { WorkspaceKey = settings.Key }) + }; + } else { + throw new NotImplementedException($"unsupported os: {vmOs}"); + } + } + + + public async Async.Task> FuzzExtensions(Pool pool, Scaleset scaleset) { + var poolConfig = await BuildPoolConfig(pool) ?? throw new Exception("pool config url is null"); + var scaleSetScript = await BuildScaleSetScript(pool, scaleset) ?? throw new Exception("scaleSet script url is null"); + var urls = new List() { poolConfig, scaleSetScript }; + + var fuzzExtension = await AgentConfig(scaleset.Region, pool.Os, AgentMode.Fuzz, urls); + var extensions = await GenericExtensions(scaleset.Region, pool.Os); + + extensions.Add(fuzzExtension); + return extensions; + } +} diff --git a/src/ApiService/ApiService/onefuzzlib/LogAnalytics.cs b/src/ApiService/ApiService/onefuzzlib/LogAnalytics.cs index 02d296c0ea..ac25141d59 100644 --- a/src/ApiService/ApiService/onefuzzlib/LogAnalytics.cs +++ b/src/ApiService/ApiService/onefuzzlib/LogAnalytics.cs @@ -3,7 +3,7 @@ namespace Microsoft.OneFuzz.Service; -public record MonitorSettings(string CustomerId, string Key); +public record MonitorSettings(string Id, string Key); public interface ILogAnalytics { public ResourceIdentifier GetWorkspaceId(); diff --git a/src/ApiService/ApiService/onefuzzlib/ScalesetOperations.cs b/src/ApiService/ApiService/onefuzzlib/ScalesetOperations.cs index a3e89e1a32..0147825e7b 100644 --- a/src/ApiService/ApiService/onefuzzlib/ScalesetOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/ScalesetOperations.cs @@ -7,13 +7,24 @@ public interface IScalesetOperations : IOrm { public IAsyncEnumerable SearchByPool(string poolName); + public Async.Task UpdateConfigs(Scaleset scaleSet); } public class ScalesetOperations : StatefulOrm, IScalesetOperations { + const string SCALESET_LOG_PREFIX = "scalesets: "; + ILogTracer _log; + IPoolOperations _poolOps; + IEvents _events; + IExtensions _extensions; + IVmssOperations _vmssOps; - public ScalesetOperations(IStorage storage, ILogTracer log, IServiceConfig config) + public ScalesetOperations(IStorage storage, ILogTracer log, IServiceConfig config, IPoolOperations poolOps, IEvents events, IExtensions extensions, IVmssOperations vmssOps) : base(storage, log, config) { - + _log = log; + _poolOps = poolOps; + _events = events; + _extensions = extensions; + _vmssOps = vmssOps; } public IAsyncEnumerable Search() { @@ -24,4 +35,69 @@ public IAsyncEnumerable SearchByPool(string poolName) { return QueryAsync(filter: $"pool_name eq '{poolName}'"); } + + async Async.Task SetState(Scaleset scaleSet, ScalesetState state) { + if (scaleSet.State == state) + return; + + if (scaleSet.State == ScalesetState.Halt) + return; + + var updatedScaleSet = scaleSet with { State = state }; + var r = await this.Replace(updatedScaleSet); + if (!r.IsOk) { + _log.Error($"Failed to update scaleset {scaleSet.ScalesetId} when updating state from {scaleSet.State} to {state}"); + } + + if (state == ScalesetState.Resize) { + await _events.SendEvent( + new EventScalesetResizeScheduled(updatedScaleSet.ScalesetId, updatedScaleSet.PoolName, updatedScaleSet.Size) + ); + } else { + await _events.SendEvent( + new EventScalesetStateUpdated(updatedScaleSet.ScalesetId, updatedScaleSet.PoolName, updatedScaleSet.State) + ); + } + } + + async Async.Task SetFailed(Scaleset scaleSet, Error error) { + if (scaleSet.Error is not null) + return; + + await SetState(scaleSet with { Error = error }, ScalesetState.CreationFailed); + await _events.SendEvent(new EventScalesetFailed(scaleSet.ScalesetId, scaleSet.PoolName, error)); + } + + public async Async.Task UpdateConfigs(Scaleset scaleSet) { + if (scaleSet == null) { + _log.Warning("skipping update configs on scaleset, since scaleset is null"); + return; + } + if (scaleSet.State == ScalesetState.Halt) { + _log.Info($"{SCALESET_LOG_PREFIX} not updating configs, scalest is set to be deleted. scaleset_id: {scaleSet.ScalesetId}"); + return; + } + if (!scaleSet.NeedsConfigUpdate) { + _log.Verbose($"{SCALESET_LOG_PREFIX} config update no needed. scaleset_id: {scaleSet.ScalesetId}"); + return; + } + + _log.Info($"{SCALESET_LOG_PREFIX} updating scalset configs. scalset_id: {scaleSet.ScalesetId}"); + + var pool = await _poolOps.GetByName(scaleSet.PoolName); + + if (!pool.IsOk || pool.OkV is null) { + _log.Error($"{SCALESET_LOG_PREFIX}: unable to find pool during config update. pool:{scaleSet.PoolName}, scaleset_id:{scaleSet.ScalesetId}"); + await SetFailed(scaleSet, pool.ErrorV!); + return; + } + + var extensions = await _extensions.FuzzExtensions(pool.OkV, scaleSet); + + var res = await _vmssOps.UpdateExtensions(scaleSet.ScalesetId, extensions); + + if (!res.IsOk) { + _log.Info($"{SCALESET_LOG_PREFIX}: unable to update configs {string.Join(',', res.ErrorV.Errors!)}"); + } + } } diff --git a/src/ApiService/ApiService/onefuzzlib/VmssOperations.cs b/src/ApiService/ApiService/onefuzzlib/VmssOperations.cs new file mode 100644 index 0000000000..d4bbee3e7b --- /dev/null +++ b/src/ApiService/ApiService/onefuzzlib/VmssOperations.cs @@ -0,0 +1,66 @@ +using Azure; +using Azure.ResourceManager.Compute; +using Azure.ResourceManager.Compute.Models; + +namespace Microsoft.OneFuzz.Service; + +public interface IVmssOperations { + public Async.Task UpdateExtensions(Guid name, IList extensions); + +} + +public class VmssOperations : IVmssOperations { + + ILogTracer _log; + ICreds _creds; + + public VmssOperations(ILogTracer log, ICreds creds) { + _log = log; + _creds = creds; + } + + private VirtualMachineScaleSetResource GetVmssResource(Guid name) { + var resourceGroup = _creds.GetBaseResourceGroup(); + var id = VirtualMachineScaleSetResource.CreateResourceIdentifier(_creds.GetSubscription(), resourceGroup, name.ToString()); + return _creds.ArmClient.GetVirtualMachineScaleSetResource(id); + } + + + public async Async.Task GetVmss(Guid name) { + var res = GetVmssResource(name); + _log.Verbose($"getting vmss: {name}"); + var r = await res.GetAsync(); + return r.Value.Data; + } + + public async Async.Task> CheckCanUpdate(Guid name) { + var vmss = await GetVmss(name); + if (vmss is null) { + return OneFuzzResult.Error(ErrorCode.UNABLE_TO_UPDATE, $"vmss not found: {name}"); + } + if (vmss.ProvisioningState == "Updating") { + return OneFuzzResult.Error(ErrorCode.UNABLE_TO_UPDATE, $"vmss is in updating state: {name}"); + } + return OneFuzzResult.Ok(vmss); + } + + + public async Async.Task UpdateExtensions(Guid name, IList extensions) { + var canUpdate = await CheckCanUpdate(name); + if (canUpdate.IsOk) { + _log.Info($"updating VM extensions: {name}"); + var res = GetVmssResource(name); + var patch = new VirtualMachineScaleSetPatch(); + + foreach (var ext in extensions) { + patch.VirtualMachineProfile.ExtensionProfile.Extensions.Add(ext); + } + var _ = await res.UpdateAsync(WaitUntil.Started, patch); + _log.Info($"VM extensions updated: {name}"); + return OneFuzzResultVoid.Ok(); + + } else { + return OneFuzzResultVoid.Error(canUpdate.ErrorV); + } + } +}