diff --git a/src/ApiService/ApiService/OneFuzzTypes/Model.cs b/src/ApiService/ApiService/OneFuzzTypes/Model.cs index 295b44ccc7..6b08b3a373 100644 --- a/src/ApiService/ApiService/OneFuzzTypes/Model.cs +++ b/src/ApiService/ApiService/OneFuzzTypes/Model.cs @@ -86,17 +86,19 @@ public record Node [RowKey] Guid MachineId, Guid? PoolId, string Version, - DateTimeOffset? Heartbeat = null, DateTimeOffset? InitializedAt = null, NodeState State = NodeState.Init, - List? Tasks = null, - List? Messages = null, + Guid? ScalesetId = null, bool ReimageRequested = false, bool DeleteRequested = false, bool DebugKeepNode = false -) : StatefulEntityBase(State); +) : StatefulEntityBase(State) { + + public List? Tasks { get; set; } + public List? Messages { get; set; } +} public record Forward diff --git a/src/ApiService/ApiService/Program.cs b/src/ApiService/ApiService/Program.cs index c4a433ad76..806bc2bb57 100644 --- a/src/ApiService/ApiService/Program.cs +++ b/src/ApiService/ApiService/Program.cs @@ -103,6 +103,7 @@ public static void Main() { .AddScoped() .AddScoped() .AddScoped() + .AddScoped() .AddSingleton() .AddSingleton() diff --git a/src/ApiService/ApiService/QueueNodeHearbeat.cs b/src/ApiService/ApiService/QueueNodeHearbeat.cs index c724137ed5..f8935c8069 100644 --- a/src/ApiService/ApiService/QueueNodeHearbeat.cs +++ b/src/ApiService/ApiService/QueueNodeHearbeat.cs @@ -8,22 +8,22 @@ namespace Microsoft.OneFuzz.Service; public class QueueNodeHearbeat { private readonly ILogTracer _log; - private readonly IEvents _events; - private readonly INodeOperations _nodes; + private readonly IOnefuzzContext _context; - public QueueNodeHearbeat(ILogTracer log, INodeOperations nodes, IEvents events) { + public QueueNodeHearbeat(ILogTracer log, IOnefuzzContext context) { _log = log; - _nodes = nodes; - _events = events; + _context = context; } [Function("QueueNodeHearbeat")] public async Async.Task Run([QueueTrigger("myqueue-items", Connection = "AzureWebJobsStorage")] string msg) { _log.Info($"heartbeat: {msg}"); + var nodes = _context.NodeOperations; + var events = _context.Events; var hb = JsonSerializer.Deserialize(msg, EntityConverter.GetJsonSerializerOptions()).EnsureNotNull($"wrong data {msg}"); - var node = await _nodes.GetByMachineId(hb.NodeId); + var node = await nodes.GetByMachineId(hb.NodeId); var log = _log.WithTag("NodeId", hb.NodeId.ToString()); @@ -34,7 +34,7 @@ public async Async.Task Run([QueueTrigger("myqueue-items", Connection = "AzureWe var newNode = node with { Heartbeat = DateTimeOffset.UtcNow }; - var r = await _nodes.Replace(newNode); + var r = await nodes.Replace(newNode); if (!r.IsOk) { var (status, reason) = r.ErrorV; @@ -42,6 +42,6 @@ public async Async.Task Run([QueueTrigger("myqueue-items", Connection = "AzureWe } // TODO: do we still send event if we fail do update the table ? - await _events.SendEvent(new EventNodeHeartbeat(node.MachineId, node.ScalesetId, node.PoolName)); + await events.SendEvent(new EventNodeHeartbeat(node.MachineId, node.ScalesetId, node.PoolName)); } } diff --git a/src/ApiService/ApiService/onefuzzlib/InstanceConfig.cs b/src/ApiService/ApiService/onefuzzlib/InstanceConfig.cs index fe77a88465..bb410c248c 100644 --- a/src/ApiService/ApiService/onefuzzlib/InstanceConfig.cs +++ b/src/ApiService/ApiService/onefuzzlib/InstanceConfig.cs @@ -11,16 +11,14 @@ public interface IConfigOperations : IOrm { } public class ConfigOperations : Orm, IConfigOperations { - private readonly IEvents _events; private readonly ILogTracer _log; - public ConfigOperations(IStorage storage, IEvents events, ILogTracer log, IServiceConfig config) : base(storage, log, config) { - _events = events; + public ConfigOperations(ILogTracer log, IOnefuzzContext context) : base(log, context) { _log = log; } public async Task Fetch() { - var key = _config.OneFuzzInstanceName ?? throw new Exception("Environment variable ONEFUZZ_INSTANCE_NAME is not set"); + var key = _context.ServiceConfiguration.OneFuzzInstanceName ?? throw new Exception("Environment variable ONEFUZZ_INSTANCE_NAME is not set"); var config = await GetEntityAsync(key, key); return config; } @@ -44,6 +42,6 @@ public async Async.Task Save(InstanceConfig config, bool isNew = false, bool req } } - await _events.SendEvent(new EventInstanceConfigUpdated(config)); + await _context.Events.SendEvent(new EventInstanceConfigUpdated(config)); } } diff --git a/src/ApiService/ApiService/onefuzzlib/JobOperations.cs b/src/ApiService/ApiService/onefuzzlib/JobOperations.cs index 4611678ff6..fc1fa242aa 100644 --- a/src/ApiService/ApiService/onefuzzlib/JobOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/JobOperations.cs @@ -12,10 +12,8 @@ public interface IJobOperations : IStatefulOrm { } public class JobOperations : StatefulOrm, IJobOperations { - private readonly IEvents _events; - public JobOperations(IStorage storage, ILogTracer logTracer, IServiceConfig config, IEvents events) : base(storage, logTracer, config) { - _events = events; + public JobOperations(ILogTracer logTracer, IOnefuzzContext context) : base(logTracer, context) { } public async Async.Task Get(Guid jobId) { @@ -56,7 +54,7 @@ public async Async.Task Stopping(Job job, ITaskOperations taskOperations) { } else { job = job with { State = JobState.Stopped }; var taskInfo = stopped.Select(t => new JobTaskStopped(t.TaskId, t.Config.Task.Type, t.Error)).ToList(); - await _events.SendEvent(new EventJobStopped(job.JobId, job.Config, job.UserInfo, taskInfo)); + await _context.Events.SendEvent(new EventJobStopped(job.JobId, job.Config, job.UserInfo, taskInfo)); } await Replace(job); diff --git a/src/ApiService/ApiService/onefuzzlib/NodeOperations.cs b/src/ApiService/ApiService/onefuzzlib/NodeOperations.cs index 4e2fd3272e..ebf5e76717 100644 --- a/src/ApiService/ApiService/onefuzzlib/NodeOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/NodeOperations.cs @@ -52,47 +52,20 @@ Async.Task Create( /// https://docs.microsoft.com/en-us/azure/azure-monitor/platform/autoscale-common-metrics#commonly-used-storage-metrics public class NodeOperations : StatefulOrm, INodeOperations { - private IScalesetOperations _scalesetOperations; - private IPoolOperations _poolOperations; - private readonly INodeTasksOperations _nodeTasksOps; - private readonly ITaskOperations _taskOps; - private readonly INodeMessageOperations _nodeMessageOps; - private readonly IEvents _events; - private readonly ILogTracer _log; - private readonly ICreds _creds; - private readonly IVmssOperations _vmssOperations; public NodeOperations( - IStorage storage, ILogTracer log, - IServiceConfig config, - ITaskOperations taskOps, - INodeTasksOperations nodeTasksOps, - INodeMessageOperations nodeMessageOps, - IEvents events, - IScalesetOperations scalesetOperations, - IPoolOperations poolOperations, - IVmssOperations vmssOperations, - ICreds creds + IOnefuzzContext context ) - : base(storage, log, config) { - - _taskOps = taskOps; - _nodeTasksOps = nodeTasksOps; - _nodeMessageOps = nodeMessageOps; - _events = events; - _scalesetOperations = scalesetOperations; - _poolOperations = poolOperations; - _vmssOperations = vmssOperations; - _creds = creds; - _log = log; + : base(log, context) { + } public async Task AcquireScaleInProtection(Node node) { if (await ScalesetNodeExists(node) && node.ScalesetId != null) { _logTracer.Info($"Setting scale-in protection on node {node.MachineId}"); - return await _vmssOperations.UpdateScaleInProtection((Guid)node.ScalesetId, node.MachineId, protectFromScaleIn: true); + return await _context.VmssOperations.UpdateScaleInProtection((Guid)node.ScalesetId, node.MachineId, protectFromScaleIn: true); } return OneFuzzResultVoid.Ok(); } @@ -102,19 +75,19 @@ public async Async.Task ScalesetNodeExists(Node node) { return false; } - var scalesetResult = await _scalesetOperations.GetById((Guid)(node.ScalesetId!)); + var scalesetResult = await _context.ScalesetOperations.GetById((Guid)(node.ScalesetId!)); if (!scalesetResult.IsOk || scalesetResult.OkV == null) { return false; } var scaleset = scalesetResult.OkV; - var instanceId = await _vmssOperations.GetInstanceId(scaleset.ScalesetId, node.MachineId); + var instanceId = await _context.VmssOperations.GetInstanceId(scaleset.ScalesetId, node.MachineId); return instanceId.IsOk; } public async Task CanProcessNewWork(Node node) { if (IsOutdated(node)) { - _logTracer.Info($"can_process_new_work agent and service versions differ, stopping node. machine_id:{node.MachineId} agent_version:{node.Version} service_version:{_config.OneFuzzVersion}"); + _logTracer.Info($"can_process_new_work agent and service versions differ, stopping node. machine_id:{node.MachineId} agent_version:{node.Version} service_version:{_context.ServiceConfiguration.OneFuzzVersion}"); await Stop(node, done: true); return false; } @@ -154,7 +127,7 @@ public async Task CanProcessNewWork(Node node) { } if (node.ScalesetId != null) { - var scalesetResult = await _scalesetOperations.GetById(node.ScalesetId.Value); + var scalesetResult = await _context.ScalesetOperations.GetById(node.ScalesetId.Value); if (!scalesetResult.IsOk || scalesetResult.OkV == null) { _logTracer.Info($"can_process_new_work invalid scaleset. scaleset_id:{node.ScalesetId} machine_id:{node.MachineId}"); return false; @@ -167,7 +140,7 @@ public async Task CanProcessNewWork(Node node) { } } - var poolResult = await _poolOperations.GetByName(node.PoolName); + var poolResult = await _context.PoolOperations.GetByName(node.PoolName); if (!poolResult.IsOk || poolResult.OkV == null) { _logTracer.Info($"can_schedule - invalid pool. pool_name:{node.PoolName} machine_id:{node.MachineId}"); return false; @@ -192,7 +165,7 @@ public async Async.Task ReimageLongLivedNodes(Guid scaleSetId) { await foreach (var node in QueryAsync($"(scaleset_id eq {scaleSetId}) and {timeFilter}")) { if (node.DebugKeepNode) { - _log.Info($"removing debug_keep_node for expired node. scaleset_id:{node.ScalesetId} machine_id:{node.MachineId}"); + _logTracer.Info($"removing debug_keep_node for expired node. scaleset_id:{node.ScalesetId} machine_id:{node.MachineId}"); } await ToReimage(node with { DebugKeepNode = false }); } @@ -209,7 +182,7 @@ public async Async.Task ToReimage(Node node, bool done = false) { var reimageRequested = node.ReimageRequested; if (!node.ReimageRequested && !node.DeleteRequested) { - _log.Info($"setting reimage_requested: {node.MachineId}"); + _logTracer.Info($"setting reimage_requested: {node.MachineId}"); reimageRequested = true; } @@ -219,7 +192,7 @@ public async Async.Task ToReimage(Node node, bool done = false) { var r = await Replace(updatedNode); if (!r.IsOk) { - _log.WithHttpStatus(r.ErrorV).Error("Failed to save Node record"); + _logTracer.WithHttpStatus(r.ErrorV).Error("Failed to save Node record"); } } @@ -248,9 +221,9 @@ public async Async.Task Create( r = await Update(node); } if (!r.IsOk) { - _log.WithHttpStatus(r.ErrorV).Error($"failed to save NodeRecord, isNew: {isNew}"); + _logTracer.WithHttpStatus(r.ErrorV).Error($"failed to save NodeRecord, isNew: {isNew}"); } else { - await _events.SendEvent( + await _context.Events.SendEvent( new EventNodeCreated( node.MachineId, node.ScalesetId, @@ -273,23 +246,23 @@ public async Async.Task Stop(Node node, bool done = false) { /// /// public async Async.Task SetHalt(Node node) { - _log.Info($"setting halt: {node.MachineId}"); + _logTracer.Info($"setting halt: {node.MachineId}"); var updatedNode = node with { DeleteRequested = true }; await Stop(updatedNode, true); await SendStopIfFree(updatedNode); } public async Async.Task SendStopIfFree(Node node) { - var ver = new Version(_config.OneFuzzVersion.Split('-')[0]); + var ver = new Version(_context.ServiceConfiguration.OneFuzzVersion.Split('-')[0]); if (ver >= Version.Parse("2.16.1")) { await SendMessage(node, new NodeCommand(StopIfFree: new NodeCommandStopIfFree())); } } public async Async.Task SendMessage(Node node, NodeCommand message) { - var r = await _nodeMessageOps.Replace(new NodeMessage(node.MachineId, message)); + var r = await _context.NodeMessageOperations.Replace(new NodeMessage(node.MachineId, message)); if (!r.IsOk) { - _log.WithHttpStatus(r.ErrorV).Error($"failed to replace NodeMessge record for machine_id: {node.MachineId}"); + _logTracer.WithHttpStatus(r.ErrorV).Error($"failed to replace NodeMessge record for machine_id: {node.MachineId}"); } } @@ -301,7 +274,7 @@ public async Async.Task SendMessage(Node node, NodeCommand message) { } public bool IsOutdated(Node node) { - return node.Version != _config.OneFuzzVersion; + return node.Version != _context.ServiceConfiguration.OneFuzzVersion; } public bool IsTooOld(Node node) { @@ -318,7 +291,7 @@ public async Async.Task SetState(Node node, NodeState state) { var newNode = node; if (node.State != state) { newNode = newNode with { State = state }; - await _events.SendEvent(new EventNodeStateUpdated( + await _context.Events.SendEvent(new EventNodeStateUpdated( node.MachineId, node.ScalesetId, node.PoolName, @@ -375,7 +348,7 @@ public IAsyncEnumerable SearchStates( string? poolName = default, bool excludeUpdateScheduled = false, int? numResults = default) { - var query = NodeOperations.SearchStatesQuery(_config.OneFuzzVersion, poolId, scaleSetId, states, poolName, excludeUpdateScheduled, numResults); + var query = NodeOperations.SearchStatesQuery(_context.ServiceConfiguration.OneFuzzVersion, poolId, scaleSetId, states, poolName, excludeUpdateScheduled, numResults); return QueryAsync(query); } @@ -384,10 +357,10 @@ public async Async.Task MarkTasksStoppedEarly(Node node, Error? error = null) { error = new Error(ErrorCode.TASK_FAILED, new[] { $"node reimaged during task execution. machine_id: {node.MachineId}" }); } - await foreach (var entry in _nodeTasksOps.GetByMachineId(node.MachineId)) { - var task = await _taskOps.GetByTaskId(entry.TaskId); + await foreach (var entry in _context.NodeTasksOperations.GetByMachineId(node.MachineId)) { + var task = await _context.TaskOperations.GetByTaskId(entry.TaskId); if (task is not null) { - await _taskOps.MarkFailed(task, error); + await _context.TaskOperations.MarkFailed(task, error); } if (!node.DebugKeepNode) { await Delete(node); @@ -397,11 +370,11 @@ public async Async.Task MarkTasksStoppedEarly(Node node, Error? error = null) { public new async Async.Task Delete(Node node) { await MarkTasksStoppedEarly(node); - await _nodeTasksOps.ClearByMachineId(node.MachineId); - await _nodeMessageOps.ClearMessages(node.MachineId); + await _context.NodeTasksOperations.ClearByMachineId(node.MachineId); + await _context.NodeMessageOperations.ClearMessages(node.MachineId); await base.Delete(node); - await _events.SendEvent(new EventNodeDeleted(node.MachineId, node.ScalesetId, node.PoolName)); + await _context.Events.SendEvent(new EventNodeDeleted(node.MachineId, node.ScalesetId, node.PoolName)); } } @@ -419,8 +392,8 @@ public class NodeTasksOperations : StatefulOrm, INodeT ILogTracer _log; - public NodeTasksOperations(IStorage storage, ILogTracer log, IServiceConfig config) - : base(storage, log, config) { + public NodeTasksOperations(ILogTracer log, IOnefuzzContext context) + : base(log, context) { _log = log; } @@ -453,11 +426,11 @@ public IAsyncEnumerable GetByTaskId(Guid taskId) { } public async Async.Task ClearByMachineId(Guid machineId) { - _log.Info($"clearing tasks for node {machineId}"); + _logTracer.Info($"clearing tasks for node {machineId}"); await foreach (var entry in GetByMachineId(machineId)) { var res = await Delete(entry); if (!res.IsOk) { - _log.WithHttpStatus(res.ErrorV).Error($"failed to delete node task entry for machine_id: {entry.MachineId}"); + _logTracer.WithHttpStatus(res.ErrorV).Error($"failed to delete node task entry for machine_id: {entry.MachineId}"); } } } @@ -484,7 +457,7 @@ public interface INodeMessageOperations : IOrm { public class NodeMessageOperations : Orm, INodeMessageOperations { private readonly ILogTracer _log; - public NodeMessageOperations(IStorage storage, ILogTracer log, IServiceConfig config) : base(storage, log, config) { + public NodeMessageOperations(ILogTracer log, IOnefuzzContext context) : base(log, context) { _log = log; } @@ -493,12 +466,12 @@ public IAsyncEnumerable GetMessage(Guid machineId) { } public async Async.Task ClearMessages(Guid machineId) { - _log.Info($"clearing messages for node {machineId}"); + _logTracer.Info($"clearing messages for node {machineId}"); await foreach (var message in GetMessage(machineId)) { var r = await Delete(message); if (!r.IsOk) { - _log.WithHttpStatus(r.ErrorV).Error($"failed to delete message for node {machineId}"); + _logTracer.WithHttpStatus(r.ErrorV).Error($"failed to delete message for node {machineId}"); } } } diff --git a/src/ApiService/ApiService/onefuzzlib/NotificationOperations.cs b/src/ApiService/ApiService/onefuzzlib/NotificationOperations.cs index c1bba766c7..66b229b845 100644 --- a/src/ApiService/ApiService/onefuzzlib/NotificationOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/NotificationOperations.cs @@ -9,28 +9,15 @@ public interface INotificationOperations : IOrm { } public class NotificationOperations : Orm, INotificationOperations { - private IReports _reports; - private ITaskOperations _taskOperations; - private IContainers _containers; + public NotificationOperations(ILogTracer log, IOnefuzzContext context) + : base(log, context) { - private IQueue _queue; - - private IEvents _events; - - public NotificationOperations(ILogTracer log, IStorage storage, IReports reports, ITaskOperations taskOperations, IContainers containers, IQueue queue, IEvents events, IServiceConfig config) - : base(storage, log, config) { - - _reports = reports; - _taskOperations = taskOperations; - _containers = containers; - _queue = queue; - _events = events; } public async Async.Task NewFiles(Container container, string filename, bool failTaskOnTransientError) { var notifications = GetNotifications(container); var hasNotifications = await notifications.AnyAsync(); - var report = await _reports.GetReportOrRegression(container, filename, expectReports: hasNotifications); + var report = await _context.Reports.GetReportOrRegression(container, filename, expectReports: hasNotifications); if (!hasNotifications) { return; @@ -64,18 +51,18 @@ public async Async.Task NewFiles(Container container, string filename, bool fail await foreach (var (task, containers) in GetQueueTasks()) { if (containers.Contains(container.ContainerName)) { _logTracer.Info($"queuing input {container.ContainerName} {filename} {task.TaskId}"); - var url = _containers.GetFileSasUrl(container, filename, StorageType.Corpus, BlobSasPermissions.Read | BlobSasPermissions.Delete); - await _queue.SendMessage(task.TaskId.ToString(), url?.ToString() ?? "", StorageType.Corpus); + var url = _context.Containers.GetFileSasUrl(container, filename, StorageType.Corpus, BlobSasPermissions.Read | BlobSasPermissions.Delete); + await _context.Queue.SendMessage(task.TaskId.ToString(), url?.ToString() ?? "", StorageType.Corpus); } } if (report == null) { - await _events.SendEvent(new EventFileAdded(container, filename)); + await _context.Events.SendEvent(new EventFileAdded(container, filename)); } else if (report.Report != null) { - var reportTask = await _taskOperations.GetByJobIdAndTaskId(report.Report.JobId, report.Report.TaskId); + var reportTask = await _context.TaskOperations.GetByJobIdAndTaskId(report.Report.JobId, report.Report.TaskId); var crashReportedEvent = new EventCrashReported(report.Report, container, filename, reportTask?.Config); - await _events.SendEvent(crashReportedEvent); + await _context.Events.SendEvent(crashReportedEvent); } else if (report.RegressionReport != null) { var reportTask = await GetRegressionReportTask(report.RegressionReport); @@ -89,17 +76,17 @@ public IAsyncEnumerable GetNotifications(Container container) { public IAsyncEnumerable<(Task, IEnumerable)> GetQueueTasks() { // Nullability mismatch: We filter tuples where the containers are null - return _taskOperations.SearchStates(states: TaskStateHelper.Available()) - .Select(task => (task, _taskOperations.GetInputContainerQueues(task.Config))) + return _context.TaskOperations.SearchStates(states: TaskStateHelper.Available()) + .Select(task => (task, _context.TaskOperations.GetInputContainerQueues(task.Config))) .Where(taskTuple => taskTuple.Item2 != null)!; } private async Async.Task GetRegressionReportTask(RegressionReport report) { if (report.CrashTestResult.CrashReport != null) { - return await _taskOperations.GetByJobIdAndTaskId(report.CrashTestResult.CrashReport.JobId, report.CrashTestResult.CrashReport.TaskId); + return await _context.TaskOperations.GetByJobIdAndTaskId(report.CrashTestResult.CrashReport.JobId, report.CrashTestResult.CrashReport.TaskId); } if (report.CrashTestResult.NoReproReport != null) { - return await _taskOperations.GetByJobIdAndTaskId(report.CrashTestResult.NoReproReport.JobId, report.CrashTestResult.NoReproReport.TaskId); + return await _context.TaskOperations.GetByJobIdAndTaskId(report.CrashTestResult.NoReproReport.JobId, report.CrashTestResult.NoReproReport.TaskId); } _logTracer.Error($"unable to find crash_report or no repro entry for report: {JsonSerializer.Serialize(report)}"); diff --git a/src/ApiService/ApiService/onefuzzlib/OnefuzzContext.cs b/src/ApiService/ApiService/onefuzzlib/OnefuzzContext.cs new file mode 100644 index 0000000000..a2f2caf159 --- /dev/null +++ b/src/ApiService/ApiService/onefuzzlib/OnefuzzContext.cs @@ -0,0 +1,78 @@ +namespace Microsoft.OneFuzz.Service; + +using Microsoft.Extensions.DependencyInjection; + +public interface IOnefuzzContext { + IConfig Config { get; } + IConfigOperations ConfigOperations { get; } + IContainers Containers { get; } + ICreds Creds { get; } + IDiskOperations DiskOperations { get; } + IEvents Events { get; } + IExtensions Extensions { get; } + IIpOperations IpOperations { get; } + IJobOperations JobOperations { get; } + ILogAnalytics LogAnalytics { get; } + INodeMessageOperations NodeMessageOperations { get; } + INodeOperations NodeOperations { get; } + INodeTasksOperations NodeTasksOperations { get; } + INotificationOperations NotificationOperations { get; } + IPoolOperations PoolOperations { get; } + IProxyForwardOperations ProxyForwardOperations { get; } + IProxyOperations ProxyOperations { get; } + IQueue Queue { get; } + IReports Reports { get; } + IReproOperations ReproOperations { get; } + IScalesetOperations ScalesetOperations { get; } + IScheduler Scheduler { get; } + ISecretsOperations SecretsOperations { get; } + IServiceConfig ServiceConfiguration { get; } + IStorage Storage { get; } + ITaskOperations TaskOperations { get; } + IUserCredentials UserCredentials { get; } + IVmOperations VmOperations { get; } + IVmssOperations VmssOperations { get; } + IWebhookMessageLogOperations WebhookMessageLogOperations { get; } + IWebhookOperations WebhookOperations { get; } +} + +public class OnefuzzContext : IOnefuzzContext { + + private readonly IServiceProvider _serviceProvider; + public INodeOperations NodeOperations { get => _serviceProvider.GetService() ?? throw new Exception("No INodeOperations service"); } + public IEvents Events { get => _serviceProvider.GetService() ?? throw new Exception("No IEvents service"); } + public IWebhookOperations WebhookOperations { get => _serviceProvider.GetService() ?? throw new Exception("No IWebhookOperations service"); } + public IWebhookMessageLogOperations WebhookMessageLogOperations { get => _serviceProvider.GetService() ?? throw new Exception("No IWebhookMessageLogOperations service"); } + public ITaskOperations TaskOperations { get => _serviceProvider.GetService() ?? throw new Exception("No ITaskOperations service"); } + public IQueue Queue { get => _serviceProvider.GetService() ?? throw new Exception("No IQueue service"); } + public IStorage Storage { get => _serviceProvider.GetService() ?? throw new Exception("No IStorage service"); } + public IProxyOperations ProxyOperations { get => _serviceProvider.GetService() ?? throw new Exception("No IProxyOperations service"); } + public IProxyForwardOperations ProxyForwardOperations { get => _serviceProvider.GetService() ?? throw new Exception("No IProxyForwardOperations service"); } + public IConfigOperations ConfigOperations { get => _serviceProvider.GetService() ?? throw new Exception("No IConfigOperations service"); } + public IScalesetOperations ScalesetOperations { get => _serviceProvider.GetService() ?? throw new Exception("No IScalesetOperations service"); } + public IContainers Containers { get => _serviceProvider.GetService() ?? throw new Exception("No IContainers service"); } + public IReports Reports { get => _serviceProvider.GetService() ?? throw new Exception("No IReports service"); } + public INotificationOperations NotificationOperations { get => _serviceProvider.GetService() ?? throw new Exception("No INotificationOperations service"); } + public IUserCredentials UserCredentials { get => _serviceProvider.GetService() ?? throw new Exception("No IUserCredentials service"); } + public IReproOperations ReproOperations { get => _serviceProvider.GetService() ?? throw new Exception("No IReproOperations service"); } + public IPoolOperations PoolOperations { get => _serviceProvider.GetService() ?? throw new Exception("No IPoolOperations service"); } + public IIpOperations IpOperations { get => _serviceProvider.GetService() ?? throw new Exception("No IIpOperations service"); } + public IDiskOperations DiskOperations { get => _serviceProvider.GetService() ?? throw new Exception("No IDiskOperations service"); } + public IVmOperations VmOperations { get => _serviceProvider.GetService() ?? throw new Exception("No IVmOperations service"); } + public ISecretsOperations SecretsOperations { get => _serviceProvider.GetService() ?? throw new Exception("No ISecretsOperations service"); } + public IJobOperations JobOperations { get => _serviceProvider.GetService() ?? throw new Exception("No IJobOperations service"); } + public IScheduler Scheduler { get => _serviceProvider.GetService() ?? throw new Exception("No IScheduler service"); } + public IConfig Config { get => _serviceProvider.GetService() ?? throw new Exception("No IConfig service"); } + public ILogAnalytics LogAnalytics { get => _serviceProvider.GetService() ?? throw new Exception("No ILogAnalytics service"); } + public IExtensions Extensions { get => _serviceProvider.GetService() ?? throw new Exception("No IExtensions service"); } + public IVmssOperations VmssOperations { get => _serviceProvider.GetService() ?? throw new Exception("No IVmssOperations service"); } + public INodeTasksOperations NodeTasksOperations { get => _serviceProvider.GetService() ?? throw new Exception("No INodeTasksOperations service"); } + public INodeMessageOperations NodeMessageOperations { get => _serviceProvider.GetService() ?? throw new Exception("No INodeMessageOperations service"); } + public ICreds Creds { get => _serviceProvider.GetService() ?? throw new Exception("No ICreds service"); } + public IServiceConfig ServiceConfiguration { get => _serviceProvider.GetService() ?? throw new Exception("No IServiceConfiguration service"); } + + public OnefuzzContext(IServiceProvider serviceProvider) { + _serviceProvider = serviceProvider; + } +} + diff --git a/src/ApiService/ApiService/onefuzzlib/PoolOperations.cs b/src/ApiService/ApiService/onefuzzlib/PoolOperations.cs index 5e14608828..a8089b7733 100644 --- a/src/ApiService/ApiService/onefuzzlib/PoolOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/PoolOperations.cs @@ -9,13 +9,10 @@ public interface IPoolOperations { } public class PoolOperations : StatefulOrm, IPoolOperations { - private IConfigOperations _configOperations; - private readonly IQueue _queue; - public PoolOperations(IStorage storage, ILogTracer log, IServiceConfig config, IConfigOperations configOperations, IQueue queue) - : base(storage, log, config) { - _configOperations = configOperations; - _queue = queue; + public PoolOperations(ILogTracer log, IOnefuzzContext context) + : base(log, context) { + } public async Async.Task> GetByName(string poolName) { @@ -37,7 +34,7 @@ public async Task ScheduleWorkset(Pool pool, WorkSet workSet) { return false; } - return await _queue.QueueObject(GetPoolQueue(pool), workSet, StorageType.Corpus); + return await _context.Queue.QueueObject(GetPoolQueue(pool), workSet, StorageType.Corpus); } private string GetPoolQueue(Pool pool) { diff --git a/src/ApiService/ApiService/onefuzzlib/ProxyForwardOperations.cs b/src/ApiService/ApiService/onefuzzlib/ProxyForwardOperations.cs index 805b5485d3..7ad9e0920f 100644 --- a/src/ApiService/ApiService/onefuzzlib/ProxyForwardOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/ProxyForwardOperations.cs @@ -9,7 +9,9 @@ public interface IProxyForwardOperations : IOrm { public class ProxyForwardOperations : Orm, IProxyForwardOperations { - public ProxyForwardOperations(IStorage storage, ILogTracer logTracer, IServiceConfig config) : base(storage, logTracer, config) { + public ProxyForwardOperations(ILogTracer log, IOnefuzzContext context) + : base(log, context) { + } public IAsyncEnumerable SearchForward(Guid? scalesetId = null, string? region = null, Guid? machineId = null, Guid? proxyId = null, int? dstPort = null) { diff --git a/src/ApiService/ApiService/onefuzzlib/ProxyOperations.cs b/src/ApiService/ApiService/onefuzzlib/ProxyOperations.cs index 609d31c1b0..1dc040fe23 100644 --- a/src/ApiService/ApiService/onefuzzlib/ProxyOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/ProxyOperations.cs @@ -15,23 +15,15 @@ public interface IProxyOperations : IStatefulOrm { } public class ProxyOperations : StatefulOrm, IProxyOperations { - private readonly IEvents _events; - private readonly IProxyForwardOperations _proxyForwardOperations; - private readonly IContainers _containers; - private readonly IQueue _queue; - private readonly ICreds _creds; static TimeSpan PROXY_LIFESPAN = TimeSpan.FromDays(7); - public ProxyOperations(ILogTracer log, IStorage storage, IEvents events, IProxyForwardOperations proxyForwardOperations, IContainers containers, IQueue queue, ICreds creds, IServiceConfig config) - : base(storage, log.WithTag("Component", "scaleset-proxy"), config) { - _events = events; - _proxyForwardOperations = proxyForwardOperations; - _containers = containers; - _queue = queue; - _creds = creds; + public ProxyOperations(ILogTracer log, IOnefuzzContext context) + : base(log.WithTag("Component", "scaleset-proxy"), context) { + } + public async Task GetByProxyId(Guid proxyId) { var data = QueryAsync(filter: $"RowKey eq '{proxyId}'"); @@ -55,10 +47,10 @@ public ProxyOperations(ILogTracer log, IStorage storage, IEvents events, IProxyF } _logTracer.Info($"creating proxy: region:{region}"); - var newProxy = new Proxy(region, Guid.NewGuid(), DateTimeOffset.UtcNow, VmState.Init, Auth.BuildAuth(), null, null, _config.OneFuzzVersion.ToString(), null, false); + var newProxy = new Proxy(region, Guid.NewGuid(), DateTimeOffset.UtcNow, VmState.Init, Auth.BuildAuth(), null, null, _context.ServiceConfiguration.OneFuzzVersion, null, false); await Replace(newProxy); - await _events.SendEvent(new EventProxyCreated(region, newProxy.ProxyId)); + await _context.Events.SendEvent(new EventProxyCreated(region, newProxy.ProxyId)); return newProxy; } @@ -83,8 +75,8 @@ public bool IsOutdated(Proxy proxy) { return false; } - if (proxy.Version != _config.OneFuzzVersion) { - _logTracer.Info($"mismatch version: proxy:{proxy.Version} service:{_config.OneFuzzVersion} state:{proxy.State}"); + if (proxy.Version != _context.ServiceConfiguration.OneFuzzVersion) { + _logTracer.Info($"mismatch version: proxy:{proxy.Version} service:{_context.ServiceConfiguration.OneFuzzVersion} state:{proxy.State}"); return true; } @@ -99,8 +91,8 @@ public bool IsOutdated(Proxy proxy) { public async Async.Task SaveProxyConfig(Proxy proxy) { var forwards = await GetForwards(proxy); - var url = (await _containers.GetFileSasUrl(new Container("proxy-configs"), $"{proxy.Region}/{proxy.ProxyId}/config.json", StorageType.Config, BlobSasPermissions.Read)).EnsureNotNull("Can't generate file sas"); - var queueSas = await _queue.GetQueueSas("proxy", StorageType.Config, QueueSasPermissions.Add).EnsureNotNull("can't generate queue sas") ?? throw new Exception("Queue sas is null"); + var url = (await _context.Containers.GetFileSasUrl(new Container("proxy-configs"), $"{proxy.Region}/{proxy.ProxyId}/config.json", StorageType.Config, BlobSasPermissions.Read)).EnsureNotNull("Can't generate file sas"); + var queueSas = await _context.Queue.GetQueueSas("proxy", StorageType.Config, QueueSasPermissions.Add).EnsureNotNull("can't generate queue sas") ?? throw new Exception("Queue sas is null"); var proxyConfig = new ProxyConfig( Url: url, @@ -108,11 +100,11 @@ public async Async.Task SaveProxyConfig(Proxy proxy) { Region: proxy.Region, ProxyId: proxy.ProxyId, Forwards: forwards, - InstanceTelemetryKey: _config.ApplicationInsightsInstrumentationKey.EnsureNotNull("missing InstrumentationKey"), - MicrosoftTelemetryKey: _config.OneFuzzTelemetry.EnsureNotNull("missing Telemetry"), - InstanceId: await _containers.GetInstanceId()); + InstanceTelemetryKey: _context.ServiceConfiguration.ApplicationInsightsInstrumentationKey.EnsureNotNull("missing InstrumentationKey"), + MicrosoftTelemetryKey: _context.ServiceConfiguration.OneFuzzTelemetry.EnsureNotNull("missing Telemetry"), + InstanceId: await _context.Containers.GetInstanceId()); - await _containers.SaveBlob(new Container("proxy-configs"), $"{proxy.Region}/{proxy.ProxyId}/config.json", _entityConverter.ToJsonString(proxyConfig), StorageType.Config); + await _context.Containers.SaveBlob(new Container("proxy-configs"), $"{proxy.Region}/{proxy.ProxyId}/config.json", _entityConverter.ToJsonString(proxyConfig), StorageType.Config); } @@ -124,16 +116,16 @@ public async Async.Task SetState(Proxy proxy, VmState state) { await Replace(proxy with { State = state }); - await _events.SendEvent(new EventProxyStateUpdated(proxy.Region, proxy.ProxyId, proxy.State)); + await _context.Events.SendEvent(new EventProxyStateUpdated(proxy.Region, proxy.ProxyId, proxy.State)); } public async Async.Task> GetForwards(Proxy proxy) { var forwards = new List(); - await foreach (var entry in _proxyForwardOperations.SearchForward(region: proxy.Region, proxyId: proxy.ProxyId)) { + await foreach (var entry in _context.ProxyForwardOperations.SearchForward(region: proxy.Region, proxyId: proxy.ProxyId)) { if (entry.EndTime < DateTimeOffset.UtcNow) { - await _proxyForwardOperations.Delete(entry); + await _context.ProxyForwardOperations.Delete(entry); } else { forwards.Add(new Forward(entry.Port, entry.DstPort, entry.DstIp)); } diff --git a/src/ApiService/ApiService/onefuzzlib/ReproOperations.cs b/src/ApiService/ApiService/onefuzzlib/ReproOperations.cs index 43faf70dc9..c45c07c451 100644 --- a/src/ApiService/ApiService/onefuzzlib/ReproOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/ReproOperations.cs @@ -19,19 +19,11 @@ public class ReproOperations : StatefulOrm, IReproOperations { const string DEFAULT_SKU = "Standard_DS1_v2"; - private IConfigOperations _configOperations; - private ITaskOperations _taskOperations; - private IVmOperations _vmOperations; - private ICreds _creds; + public ReproOperations(ILogTracer log, IOnefuzzContext context) + : base(log, context) { - public ReproOperations(IStorage storage, ILogTracer log, IServiceConfig config, IConfigOperations configOperations, ITaskOperations taskOperations, ICreds creds, IVmOperations vmOperations) - : base(storage, log, config) { - _configOperations = configOperations; - _taskOperations = taskOperations; - _creds = creds; - _vmOperations = vmOperations; } public IAsyncEnumerable SearchExpired() { @@ -39,20 +31,21 @@ public IAsyncEnumerable SearchExpired() { } public async Async.Task GetVm(Repro repro, InstanceConfig config) { + var taskOperations = _context.TaskOperations; var tags = config.VmTags; - var task = await _taskOperations.GetByTaskId(repro.TaskId); + var task = await taskOperations.GetByTaskId(repro.TaskId); if (task == null) { throw new Exception($"previous existing task missing: {repro.TaskId}"); } - var vmConfig = await _taskOperations.GetReproVmConfig(task); + var vmConfig = await taskOperations.GetReproVmConfig(task); if (vmConfig == null) { if (!DEFAULT_OS.ContainsKey(task.Os)) { throw new NotImplementedException($"unsupport OS for repro {task.Os}"); } vmConfig = new TaskVm( - await _creds.GetBaseRegion(), + await _context.Creds.GetBaseRegion(), DEFAULT_SKU, DEFAULT_OS[task.Os], null @@ -75,11 +68,12 @@ await _creds.GetBaseRegion(), } public async Async.Task Stopping(Repro repro) { - var config = await _configOperations.Fetch(); + var config = await _context.ConfigOperations.Fetch(); var vm = await GetVm(repro, config); - if (!await _vmOperations.IsDeleted(vm)) { + var vmOperations = _context.VmOperations; + if (!await vmOperations.IsDeleted(vm)) { _logTracer.Info($"vm stopping: {repro.VmId}"); - await _vmOperations.Delete(vm); + await vmOperations.Delete(vm); await Replace(repro); } else { await Stopped(repro); diff --git a/src/ApiService/ApiService/onefuzzlib/ScalesetOperations.cs b/src/ApiService/ApiService/onefuzzlib/ScalesetOperations.cs index 9dc92d9201..535fa9bdc5 100644 --- a/src/ApiService/ApiService/onefuzzlib/ScalesetOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/ScalesetOperations.cs @@ -17,34 +17,11 @@ public class ScalesetOperations : StatefulOrm, IScalese const string SCALESET_LOG_PREFIX = "scalesets: "; ILogTracer _log; - IPoolOperations _poolOps; - IEvents _events; - IExtensions _extensions; - IVmssOperations _vmssOps; - IQueue _queue; - IServiceConfig _serviceConfig; - ICreds _creds; - - public ScalesetOperations( - IStorage storage, - ILogTracer log, - IServiceConfig config, - IPoolOperations poolOps, - IEvents events, - IExtensions extensions, - IVmssOperations vmssOps, - IQueue queue, - ICreds creds - ) - : base(storage, log, config) { + + public ScalesetOperations(ILogTracer log, IOnefuzzContext context) + : base(log, context) { _log = log; - _poolOps = poolOps; - _events = events; - _extensions = extensions; - _vmssOps = vmssOps; - _queue = queue; - _serviceConfig = config; - _creds = creds; + } public IAsyncEnumerable Search() { @@ -70,11 +47,11 @@ async Async.Task SetState(Scaleset scaleSet, ScalesetState state) { } if (state == ScalesetState.Resize) { - await _events.SendEvent( + await _context.Events.SendEvent( new EventScalesetResizeScheduled(updatedScaleSet.ScalesetId, updatedScaleSet.PoolName, updatedScaleSet.Size) ); } else { - await _events.SendEvent( + await _context.Events.SendEvent( new EventScalesetStateUpdated(updatedScaleSet.ScalesetId, updatedScaleSet.PoolName, updatedScaleSet.State) ); } @@ -85,7 +62,7 @@ async Async.Task SetFailed(Scaleset scaleSet, Error error) { return; await SetState(scaleSet with { Error = error }, ScalesetState.CreationFailed); - await _events.SendEvent(new EventScalesetFailed(scaleSet.ScalesetId, scaleSet.PoolName, error)); + await _context.Events.SendEvent(new EventScalesetFailed(scaleSet.ScalesetId, scaleSet.PoolName, error)); } public async Async.Task UpdateConfigs(Scaleset scaleSet) { @@ -104,7 +81,7 @@ public async Async.Task UpdateConfigs(Scaleset scaleSet) { _log.Info($"{SCALESET_LOG_PREFIX} updating scalset configs. scalset_id: {scaleSet.ScalesetId}"); - var pool = await _poolOps.GetByName(scaleSet.PoolName); + var pool = await _context.PoolOperations.GetByName(scaleSet.PoolName); if (!pool.IsOk || pool.OkV is null) { _log.Error($"{SCALESET_LOG_PREFIX} unable to find pool during config update. pool:{scaleSet.PoolName}, scaleset_id:{scaleSet.ScalesetId}"); @@ -112,26 +89,27 @@ public async Async.Task UpdateConfigs(Scaleset scaleSet) { return; } - var extensions = await _extensions.FuzzExtensions(pool.OkV, scaleSet); + var extensions = await _context.Extensions.FuzzExtensions(pool.OkV, scaleSet); - var res = await _vmssOps.UpdateExtensions(scaleSet.ScalesetId, extensions); + var res = await _context.VmssOperations.UpdateExtensions(scaleSet.ScalesetId, extensions); if (!res.IsOk) { _log.Info($"{SCALESET_LOG_PREFIX} unable to update configs {string.Join(',', res.ErrorV.Errors!)}"); } } - public async Async.Task Halt(Scaleset scaleset, INodeOperations _nodeOps) { - var shrinkQueue = new ShrinkQueue(scaleset.ScalesetId, _queue, _log); + + public async Async.Task Halt(Scaleset scaleset) { + var shrinkQueue = new ShrinkQueue(scaleset.ScalesetId, _context.Queue, _log); await shrinkQueue.Delete(); - await foreach (var node in _nodeOps.SearchStates(scaleSetId: scaleset.ScalesetId)) { + await foreach (var node in _context.NodeOperations.SearchStates(scaleSetId: scaleset.ScalesetId)) { _log.Info($"{SCALESET_LOG_PREFIX} deleting node scaleset_id {scaleset.ScalesetId} machine_id {node.MachineId}"); - await _nodeOps.Delete(node); + await _context.NodeOperations.Delete(node); } _log.Info($"{SCALESET_LOG_PREFIX} scaleset delete starting: scaleset_id:{scaleset.ScalesetId}"); - if (await _vmssOps.DeleteVmss(scaleset.ScalesetId)) { + if (await _context.VmssOperations.DeleteVmss(scaleset.ScalesetId)) { _log.Info($"{SCALESET_LOG_PREFIX}scaleset deleted: scaleset_id {scaleset.ScalesetId}"); var r = await Delete(scaleset); if (!r.IsOk) { @@ -150,32 +128,32 @@ public async Async.Task Halt(Scaleset scaleset, INodeOperations _nodeOps) { /// /// /// true if scaleset got modified - public async Async.Task CleanupNodes(Scaleset scaleSet, INodeOperations _nodeOps) { + public async Async.Task CleanupNodes(Scaleset scaleSet) { _log.Info($"{SCALESET_LOG_PREFIX} cleaning up nodes. scaleset_id {scaleSet.ScalesetId}"); if (scaleSet.State == ScalesetState.Halt) { _log.Info($"{SCALESET_LOG_PREFIX} halting scaleset scaleset_id {scaleSet.ScalesetId}"); - await Halt(scaleSet, _nodeOps); + await Halt(scaleSet); return true; } - var pool = await _poolOps.GetByName(scaleSet.PoolName); + var pool = await _context.PoolOperations.GetByName(scaleSet.PoolName); if (!pool.IsOk) { _log.Error($"unable to find pool during cleanup {scaleSet.ScalesetId} - {scaleSet.PoolName}"); await SetFailed(scaleSet, pool.ErrorV!); return true; } - await _nodeOps.ReimageLongLivedNodes(scaleSet.ScalesetId); + await _context.NodeOperations.ReimageLongLivedNodes(scaleSet.ScalesetId); //ground truth of existing nodes - var azureNodes = await _vmssOps.ListInstanceIds(scaleSet.ScalesetId); - var nodes = _nodeOps.SearchStates(scaleSetId: scaleSet.ScalesetId); + var azureNodes = await _context.VmssOperations.ListInstanceIds(scaleSet.ScalesetId); + var nodes = _context.NodeOperations.SearchStates(scaleSetId: scaleSet.ScalesetId); //# Nodes do not exists in scalesets but in table due to unknown failure await foreach (var node in nodes) { if (!azureNodes.ContainsKey(node.MachineId)) { _log.Info($"{SCALESET_LOG_PREFIX} no longer in scaleset. scaleset_id:{scaleSet.ScalesetId} machine_id:{node.MachineId}"); - await _nodeOps.Delete(node); + await _context.NodeOperations.Delete(node); } } @@ -202,7 +180,7 @@ public async Async.Task CleanupNodes(Scaleset scaleSet, INodeOperations _n //Python code does use created node //pool.IsOk was handled above, OkV must be not null at this point - var _ = await _nodeOps.Create(pool.OkV!.PoolId, scaleSet.PoolName, machineId, scaleSet.ScalesetId, _config.OneFuzzVersion, true); + var _ = await _context.NodeOperations.Create(pool.OkV!.PoolId, scaleSet.PoolName, machineId, scaleSet.ScalesetId, _context.ServiceConfiguration.OneFuzzVersion, true); } var existingNodes = @@ -223,11 +201,11 @@ where NodeStateHelper.ReadyForReset.Contains(x.State) if (node.DeleteRequested) { toDelete[node.MachineId] = node; } else { - if (await new ShrinkQueue(scaleSet.ScalesetId, _queue, _log).ShouldShrink()) { - await _nodeOps.SetHalt(node); + if (await new ShrinkQueue(scaleSet.ScalesetId, _context.Queue, _log).ShouldShrink()) { + await _context.NodeOperations.SetHalt(node); toDelete[node.MachineId] = node; - } else if (await new ShrinkQueue(pool.OkV!.PoolId, _queue, _log).ShouldShrink()) { - await _nodeOps.SetHalt(node); + } else if (await new ShrinkQueue(pool.OkV!.PoolId, _context.Queue, _log).ShouldShrink()) { + await _context.NodeOperations.SetHalt(node); toDelete[node.MachineId] = node; } else { toReimage[node.MachineId] = node; @@ -235,7 +213,7 @@ where NodeStateHelper.ReadyForReset.Contains(x.State) } } - var deadNodes = _nodeOps.GetDeadNodes(scaleSet.ScalesetId, INodeOperations.NODE_EXPIRATION_TIME); + var deadNodes = _context.NodeOperations.GetDeadNodes(scaleSet.ScalesetId, INodeOperations.NODE_EXPIRATION_TIME); await foreach (var deadNode in deadNodes) { string errorMessage; @@ -246,14 +224,14 @@ where NodeStateHelper.ReadyForReset.Contains(x.State) } var error = new Error(ErrorCode.TASK_FAILED, new[] { $"{errorMessage} scaleset_id {deadNode.ScalesetId} last heartbeat:{deadNode.Heartbeat}" }); - await _nodeOps.MarkTasksStoppedEarly(deadNode, error); - await _nodeOps.ToReimage(deadNode, true); + await _context.NodeOperations.MarkTasksStoppedEarly(deadNode, error); + await _context.NodeOperations.ToReimage(deadNode, true); toReimage[deadNode.MachineId] = deadNode; } // Perform operations until they fail due to scaleset getting locked NodeDisposalStrategy strategy = - (_serviceConfig.OneFuzzNodeDisposalStrategy.ToLowerInvariant()) switch { + (_context.ServiceConfiguration.OneFuzzNodeDisposalStrategy.ToLowerInvariant()) switch { "decomission" => NodeDisposalStrategy.Decomission, _ => NodeDisposalStrategy.ScaleIn }; @@ -262,7 +240,7 @@ where NodeStateHelper.ReadyForReset.Contains(x.State) } - public async Async.Task ReimageNodes(Scaleset scaleSet, IEnumerable nodes, NodeDisposalStrategy disposalStrategy, INodeOperations _nodeOps) { + public async Async.Task ReimageNodes(Scaleset scaleSet, IEnumerable nodes, NodeDisposalStrategy disposalStrategy) { if (nodes is null || !nodes.Any()) { _log.Info($"{SCALESET_LOG_PREFIX} no nodes to reimage: scaleset_id: {scaleSet.ScalesetId}"); @@ -271,7 +249,7 @@ public async Async.Task ReimageNodes(Scaleset scaleSet, IEnumerable nodes, if (scaleSet.State == ScalesetState.Shutdown) { _log.Info($"{SCALESET_LOG_PREFIX} scaleset shutting down, deleting rather than reimaging nodes. scaleset_id: {scaleSet.ScalesetId}"); - await DeleteNodes(scaleSet, nodes, disposalStrategy, _nodeOps); + await DeleteNodes(scaleSet, nodes, disposalStrategy); return; } @@ -301,7 +279,7 @@ public async Async.Task ReimageNodes(Scaleset scaleSet, IEnumerable nodes, throw new NotImplementedException(); } - public async Async.Task DeleteNodes(Scaleset scaleSet, IEnumerable nodes, NodeDisposalStrategy disposalStrategy, INodeOperations _nodeOps) { + public async Async.Task DeleteNodes(Scaleset scaleSet, IEnumerable nodes, NodeDisposalStrategy disposalStrategy) { if (nodes is null || !nodes.Any()) { _log.Info($"{SCALESET_LOG_PREFIX} no nodes to delete: scaleset_id: {scaleSet.ScalesetId}"); return; @@ -309,7 +287,7 @@ public async Async.Task DeleteNodes(Scaleset scaleSet, IEnumerable nodes, foreach (var node in nodes) { - await _nodeOps.SetHalt(node); + await _context.NodeOperations.SetHalt(node); } if (scaleSet.State == ScalesetState.Halt) { diff --git a/src/ApiService/ApiService/onefuzzlib/TaskOperations.cs b/src/ApiService/ApiService/onefuzzlib/TaskOperations.cs index aed2bf4c81..31f5f8c34d 100644 --- a/src/ApiService/ApiService/onefuzzlib/TaskOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/TaskOperations.cs @@ -23,17 +23,11 @@ public interface ITaskOperations : IStatefulOrm { } public class TaskOperations : StatefulOrm, ITaskOperations { - private readonly IEvents _events; - private readonly IJobOperations _jobOperations; - private readonly IPoolOperations _poolOperations; - private readonly IScalesetOperations _scalesetOperations; - - public TaskOperations(IStorage storage, ILogTracer log, IServiceConfig config, IPoolOperations poolOperations, IScalesetOperations scalesetOperations, IEvents events, IJobOperations jobOperations) - : base(storage, log, config) { - _poolOperations = poolOperations; - _scalesetOperations = scalesetOperations; - _events = events; - _jobOperations = jobOperations; + + + public TaskOperations(ILogTracer log, IOnefuzzContext context) + : base(log, context) { + } public async Async.Task GetByTaskId(Guid taskId) { @@ -132,7 +126,7 @@ public async Async.Task SetState(Task task, TaskState state) { } await this.Replace(task); - + var _events = _context.Events; if (task.State == TaskState.Stopped) { if (task.Error != null) { await _events.SendEvent(new EventTaskFailed( @@ -167,9 +161,10 @@ private async Async.Task OnStart(Task task) { if (task.EndTime == null) { task = task with { EndTime = DateTimeOffset.UtcNow + TimeSpan.FromHours(task.Config.Task.Duration) }; - Job? job = await _jobOperations.Get(task.JobId); + var jobOperations = _context.JobOperations; + Job? job = await jobOperations.Get(task.JobId); if (job != null) { - await _jobOperations.OnStart(job); + await jobOperations.OnStart(job); } } @@ -187,14 +182,14 @@ private async Async.Task OnStart(Task task) { throw new Exception($"either pool or vm must be specified: {task.TaskId}"); } - var pool = await _poolOperations.GetByName(task.Config.Pool.PoolName); + var pool = await _context.PoolOperations.GetByName(task.Config.Pool.PoolName); if (!pool.IsOk) { _logTracer.Info($"unable to find pool from task: {task.TaskId}"); return null; } - var scaleset = await _scalesetOperations.SearchByPool(task.Config.Pool.PoolName).FirstOrDefaultAsync(); + var scaleset = await _context.ScalesetOperations.SearchByPool(task.Config.Pool.PoolName).FirstOrDefaultAsync(); if (scaleset == null) { _logTracer.Warning($"no scalesets are defined for task: {task.JobId}:{task.TaskId}"); @@ -225,7 +220,7 @@ public async Async.Task CheckPrereqTasks(Task task) { public async Async.Task GetPool(Task task) { if (task.Config.Pool != null) { - var pool = await _poolOperations.GetByName(task.Config.Pool.PoolName); + var pool = await _context.PoolOperations.GetByName(task.Config.Pool.PoolName); if (!pool.IsOk) { _logTracer.Info( $"unable to schedule task to pool: {task.TaskId} - {pool.ErrorV}" @@ -234,13 +229,13 @@ public async Async.Task CheckPrereqTasks(Task task) { } return pool.OkV; } else if (task.Config.Vm != null) { - var scalesets = _scalesetOperations.Search().Where(s => s.VmSku == task.Config.Vm.Sku && s.Image == task.Config.Vm.Image); + var scalesets = _context.ScalesetOperations.Search().Where(s => s.VmSku == task.Config.Vm.Sku && s.Image == task.Config.Vm.Image); await foreach (var scaleset in scalesets) { if (task.Config.Pool == null) { continue; } - var pool = await _poolOperations.GetByName(task.Config.Pool.PoolName); + var pool = await _context.PoolOperations.GetByName(task.Config.Pool.PoolName); if (!pool.IsOk) { _logTracer.Info( $"unable to schedule task to pool: {task.TaskId} - {pool.ErrorV}" diff --git a/src/ApiService/ApiService/onefuzzlib/WebhookOperations.cs b/src/ApiService/ApiService/onefuzzlib/WebhookOperations.cs index 756bae1b59..c93c31a569 100644 --- a/src/ApiService/ApiService/onefuzzlib/WebhookOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/WebhookOperations.cs @@ -15,18 +15,10 @@ public interface IWebhookOperations { public class WebhookOperations : Orm, IWebhookOperations { - private readonly IWebhookMessageLogOperations _webhookMessageLogOperations; - private readonly ILogTracer _log; - private readonly ICreds _creds; - private readonly IContainers _containers; private readonly IHttpClientFactory _httpFactory; - public WebhookOperations(IHttpClientFactory httpFactory, ICreds creds, IStorage storage, IWebhookMessageLogOperations webhookMessageLogOperations, IContainers containers, ILogTracer log, IServiceConfig config) - : base(storage, log, config) { - _webhookMessageLogOperations = webhookMessageLogOperations; - _log = log; - _creds = creds; - _containers = containers; + public WebhookOperations(IHttpClientFactory httpFactory, ILogTracer log, IOnefuzzContext context) + : base(log, context) { _httpFactory = httpFactory; } @@ -50,10 +42,10 @@ async private Async.Task AddEvent(Webhook webhook, EventMessage eventMessage) { TryCount: 0 ); - var r = await _webhookMessageLogOperations.Replace(message); + var r = await _context.WebhookMessageLogOperations.Replace(message); if (!r.IsOk) { var (status, reason) = r.ErrorV; - _log.Error($"Failed to replace webhook message log due to [{status}] {reason}"); + _logTracer.Error($"Failed to replace webhook message log due to [{status}] {reason}"); } } @@ -65,14 +57,14 @@ public async Async.Task Send(WebhookMessageLog messageLog) { var (data, digest) = await BuildMessage(webhookId: webhook.WebhookId, eventId: messageLog.EventId, eventType: messageLog.EventType, webhookEvent: messageLog.Event, secretToken: webhook.SecretToken, messageFormat: webhook.MessageFormat); - var headers = new Dictionary { { "User-Agent", $"onefuzz-webhook {_config.OneFuzzVersion}" } }; + var headers = new Dictionary { { "User-Agent", $"onefuzz-webhook {_context.ServiceConfiguration.OneFuzzVersion}" } }; if (digest != null) { headers["X-Onefuzz-Digest"] = digest; } var client = new Request(_httpFactory.CreateClient()); - _log.Info(data); + _logTracer.Info(data); var response = client.Post(url: webhook.Url, json: data, headers: headers); var result = response.Result; if (result.StatusCode == HttpStatusCode.Accepted) { @@ -81,16 +73,16 @@ public async Async.Task Send(WebhookMessageLog messageLog) { return false; } - // Not converting to bytes, as it's not neccessary in C#. Just keeping as string. + // Not converting to bytes, as it's not neccessary in C#. Just keeping as string. public async Async.Task> BuildMessage(Guid webhookId, Guid eventId, EventType eventType, BaseEvent webhookEvent, String? secretToken, WebhookMessageFormat? messageFormat) { var entityConverter = new EntityConverter(); string data = ""; if (messageFormat != null && messageFormat == WebhookMessageFormat.EventGrid) { - var eventGridMessage = new[] { new WebhookMessageEventGrid(Id: eventId, Data: webhookEvent, DataVersion: "1.0.0", Subject: _creds.GetInstanceName(), EventType: eventType, EventTime: DateTimeOffset.UtcNow) }; + var eventGridMessage = new[] { new WebhookMessageEventGrid(Id: eventId, Data: webhookEvent, DataVersion: "1.0.0", Subject: _context.Creds.GetInstanceName(), EventType: eventType, EventTime: DateTimeOffset.UtcNow) }; data = JsonSerializer.Serialize(eventGridMessage, options: EntityConverter.GetJsonSerializerOptions()); } else { - var instanceId = await _containers.GetInstanceId(); - var webhookMessage = new WebhookMessage(WebhookId: webhookId, EventId: eventId, EventType: eventType, Event: webhookEvent, InstanceId: instanceId, InstanceName: _creds.GetInstanceName()); + var instanceId = await _context.Containers.GetInstanceId(); + var webhookMessage = new WebhookMessage(WebhookId: webhookId, EventId: eventId, EventType: eventType, Event: webhookEvent, InstanceId: instanceId, InstanceName: _context.Creds.GetInstanceName()); data = JsonSerializer.Serialize(webhookMessage, options: EntityConverter.GetJsonSerializerOptions()); } @@ -128,14 +120,10 @@ public class WebhookMessageLogOperations : Orm, IWebhookMessa const int EXPIRE_DAYS = 7; const int MAX_TRIES = 5; - private readonly IQueue _queue; - private readonly ILogTracer _log; - private readonly IWebhookOperations _webhook; - public WebhookMessageLogOperations(IStorage storage, IQueue queue, ILogTracer log, IServiceConfig config, ICreds creds, IHttpClientFactory httpFactory, IContainers containers) : base(storage, log, config) { - _queue = queue; - _log = log; - _webhook = new WebhookOperations(httpFactory: httpFactory, creds: creds, storage: storage, webhookMessageLogOperations: this, containers: containers, log: log, config: config); + public WebhookMessageLogOperations(IHttpClientFactory httpFactory, ILogTracer log, IOnefuzzContext context) + : base(log, context) { + } @@ -149,14 +137,14 @@ public async Async.Task QueueWebhook(WebhookMessageLog webhookLog) { }; if (visibilityTimeout == null) { - _log.WithTags( + _logTracer.WithTags( new[] { ("WebhookId", webhookLog.WebhookId.ToString()), ("EventId", webhookLog.EventId.ToString()) } ). Error($"invalid WebhookMessage queue state, not queuing. {webhookLog.WebhookId}:{webhookLog.EventId} - {webhookLog.State}"); } else { - await _queue.QueueObject("webhooks", obj, StorageType.Config, visibilityTimeout: visibilityTimeout); + await _context.Queue.QueueObject("webhooks", obj, StorageType.Config, visibilityTimeout: visibilityTimeout); } } @@ -164,7 +152,7 @@ public async Async.Task ProcessFromQueue(WebhookMessageQueueObj obj) { var message = await GetWebhookMessageById(obj.WebhookId, obj.EventId); if (message == null) { - _log.WithTags( + _logTracer.WithTags( new[] { ("WebhookId", obj.WebhookId.ToString()), ("EventId", obj.EventId.ToString()) } @@ -178,7 +166,7 @@ public async Async.Task ProcessFromQueue(WebhookMessageQueueObj obj) { private async Async.Task Process(WebhookMessageLog message) { if (message.State == WebhookMessageState.Failed || message.State == WebhookMessageState.Succeeded) { - _log.WithTags( + _logTracer.WithTags( new[] { ("WebhookId", message.WebhookId.ToString()), ("EventId", message.EventId.ToString()) } @@ -189,29 +177,29 @@ private async Async.Task Process(WebhookMessageLog message) { var newMessage = message with { TryCount = message.TryCount + 1 }; - _log.Info($"sending webhook: {message.WebhookId}:{message.EventId}"); + _logTracer.Info($"sending webhook: {message.WebhookId}:{message.EventId}"); var success = await Send(newMessage); if (success) { newMessage = newMessage with { State = WebhookMessageState.Succeeded }; await Replace(newMessage); - _log.Info($"sent webhook event {newMessage.WebhookId}:{newMessage.EventId}"); + _logTracer.Info($"sent webhook event {newMessage.WebhookId}:{newMessage.EventId}"); } else if (newMessage.TryCount < MAX_TRIES) { newMessage = newMessage with { State = WebhookMessageState.Retrying }; await Replace(newMessage); await QueueWebhook(newMessage); - _log.Warning($"sending webhook event failed, re-queued {newMessage.WebhookId}:{newMessage.EventId}"); + _logTracer.Warning($"sending webhook event failed, re-queued {newMessage.WebhookId}:{newMessage.EventId}"); } else { newMessage = newMessage with { State = WebhookMessageState.Failed }; await Replace(newMessage); - _log.Info($"sending webhook: {newMessage.WebhookId} event: {newMessage.EventId} failed {newMessage.TryCount} times."); + _logTracer.Info($"sending webhook: {newMessage.WebhookId} event: {newMessage.EventId} failed {newMessage.TryCount} times."); } } private async Async.Task Send(WebhookMessageLog message) { - var webhook = await _webhook.GetByWebhookId(message.WebhookId); + var webhook = await _context.WebhookOperations.GetByWebhookId(message.WebhookId); if (webhook == null) { - _log.WithTags( + _logTracer.WithTags( new[] { ("WebhookId", message.WebhookId.ToString()), } @@ -221,9 +209,9 @@ private async Async.Task Send(WebhookMessageLog message) { } try { - return await _webhook.Send(message); + return await _context.WebhookOperations.Send(message); } catch (Exception exc) { - _log.WithTags( + _logTracer.WithTags( new[] { ("WebhookId", message.WebhookId.ToString()) } diff --git a/src/ApiService/ApiService/onefuzzlib/orm/EntityConverter.cs b/src/ApiService/ApiService/onefuzzlib/orm/EntityConverter.cs index 76a9fb4968..f657d8ab37 100644 --- a/src/ApiService/ApiService/onefuzzlib/orm/EntityConverter.cs +++ b/src/ApiService/ApiService/onefuzzlib/orm/EntityConverter.cs @@ -276,9 +276,9 @@ public T ToRecord(TableEntity entity) where T : EntityBase { entityRecord.TimeStamp = entity.Timestamp; return entityRecord; - } catch (Exception) { + } catch (Exception ex) { var stringParam = string.Join(", ", parameters); - throw new Exception($"Could not initialize object of type {typeof(T)} with the following parameters: {stringParam} constructor {entityInfo.constructor}"); + throw new Exception($"Could not initialize object of type {typeof(T)} with the following parameters: {stringParam} constructor {entityInfo.constructor} : {ex}"); } } diff --git a/src/ApiService/ApiService/onefuzzlib/orm/Orm.cs b/src/ApiService/ApiService/onefuzzlib/orm/Orm.cs index a843db97b7..0104662ac3 100644 --- a/src/ApiService/ApiService/onefuzzlib/orm/Orm.cs +++ b/src/ApiService/ApiService/onefuzzlib/orm/Orm.cs @@ -19,18 +19,15 @@ public interface IOrm where T : EntityBase { public class Orm : IOrm where T : EntityBase { - protected readonly IStorage _storage; protected readonly EntityConverter _entityConverter; + protected readonly IOnefuzzContext _context; protected readonly ILogTracer _logTracer; - protected readonly IServiceConfig _config; - - public Orm(IStorage storage, ILogTracer logTracer, IServiceConfig config) { - _storage = storage; - _entityConverter = new EntityConverter(); + public Orm(ILogTracer logTracer, IOnefuzzContext context) { + _context = context; _logTracer = logTracer; - _config = config; + _entityConverter = new EntityConverter(); } public async IAsyncEnumerable QueryAsync(string? filter = null) { @@ -87,8 +84,8 @@ public async Task GetEntityAsync(string partitionKey, string rowKey) { } public async Task GetTableClient(string table, string? accountId = null) { - var account = accountId ?? _config.OneFuzzFuncStorage ?? throw new ArgumentNullException(nameof(accountId)); - var (name, key) = await _storage.GetStorageAccountNameAndKey(account); + var account = accountId ?? _context.ServiceConfiguration.OneFuzzFuncStorage ?? throw new ArgumentNullException(nameof(accountId)); + var (name, key) = await _context.Storage.GetStorageAccountNameAndKey(account); var tableClient = new TableServiceClient(new Uri($"https://{name}.table.core.windows.net"), new TableSharedKeyCredential(name, key)); await tableClient.CreateTableIfNotExistsAsync(table); return tableClient.GetTableClient(table); @@ -134,7 +131,7 @@ static StatefulOrm() { }; } - public StatefulOrm(IStorage storage, ILogTracer logTracer, IServiceConfig config) : base(storage, logTracer, config) { + public StatefulOrm(ILogTracer logTracer, IOnefuzzContext context) : base(logTracer, context) { } ///