diff --git a/src/ApiService/ApiService/onefuzzlib/NodeMessageOperations.cs b/src/ApiService/ApiService/onefuzzlib/NodeMessageOperations.cs index dcffbfc7b4..48df6bb3c1 100644 --- a/src/ApiService/ApiService/onefuzzlib/NodeMessageOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/NodeMessageOperations.cs @@ -31,11 +31,10 @@ public IAsyncEnumerable GetMessage(Guid machineId) public async Async.Task ClearMessages(Guid machineId) { _logTracer.Info($"clearing messages for node {machineId:Tag:MachineId}"); - await foreach (var message in GetMessage(machineId)) { - var r = await Delete(message); - if (!r.IsOk) { - _logTracer.WithHttpStatus(r.ErrorV).Error($"failed to delete message for node {machineId:Tag:MachineId}"); - } + var result = await DeleteAll(new (string?, string?)[] { (machineId.ToString(), null) }); + + if (result.FailureCount > 0) { + _logTracer.Error($"failed to delete {result.FailureCount:Tag:FailedDeleteMessageCount} messages for node {machineId:Tag:MachineId}"); } } diff --git a/src/ApiService/ApiService/onefuzzlib/Utils.cs b/src/ApiService/ApiService/onefuzzlib/Utils.cs index c1ff87cdb1..3ecd7ce97d 100644 --- a/src/ApiService/ApiService/onefuzzlib/Utils.cs +++ b/src/ApiService/ApiService/onefuzzlib/Utils.cs @@ -15,3 +15,27 @@ public static T EnsureNotNull(this T? thisObject, string message) { public static Async.Task IgnoreResult(this Async.Task task) => task; } + +public static class IAsyncEnumerableExtension { + public static async IAsyncEnumerable> Chunk(this IAsyncEnumerable source, int size) { + + if (size <= 0) { + throw new ArgumentException("size must be greater than 0"); + } + + var enumerator = source.GetAsyncEnumerator(); + List result = new List(size); + while (await enumerator.MoveNextAsync()) { + result.Add(enumerator.Current); + + if (result.Count == size) { + yield return result; + result = new List(size); + } + } + + if (result.Count > 0) { + yield return result; + } + } +} diff --git a/src/ApiService/ApiService/onefuzzlib/orm/Orm.cs b/src/ApiService/ApiService/onefuzzlib/orm/Orm.cs index 87bacedaf6..84ba5b115d 100644 --- a/src/ApiService/ApiService/onefuzzlib/orm/Orm.cs +++ b/src/ApiService/ApiService/onefuzzlib/orm/Orm.cs @@ -1,11 +1,13 @@ using System.Collections.Concurrent; using System.Reflection; using System.Threading.Tasks; +using Azure; using Azure.Core; using Azure.Data.Tables; using Microsoft.OneFuzz.Service; using Microsoft.OneFuzz.Service.OneFuzzLib.Orm; + namespace ApiService.OneFuzzLib.Orm { public interface IOrm where T : EntityBase { Task GetTableClient(string table, ResourceIdentifier? accountId = null); @@ -17,6 +19,8 @@ public interface IOrm where T : EntityBase { Task> Update(T entity); Task> Delete(T entity); + Task DeleteAll(IEnumerable<(string?, string?)> keys); + IAsyncEnumerable SearchAll(); IAsyncEnumerable SearchByPartitionKeys(IEnumerable partitionKeys); IAsyncEnumerable SearchByRowKeys(IEnumerable rowKeys); @@ -27,6 +31,7 @@ IAsyncEnumerable SearchByTimeRange((DateTimeOffset min, DateTimeOffset max) r => SearchByTimeRange(range.min, range.max); } + public record DeleteAllResult(int SuccessCount, int FailureCount); public abstract class Orm : IOrm where T : EntityBase { #pragma warning disable CA1051 // permit visible instance fields @@ -35,6 +40,7 @@ public abstract class Orm : IOrm where T : EntityBase { protected readonly ILogTracer _logTracer; #pragma warning restore CA1051 + const int MAX_TRANSACTION_SIZE = 100; public Orm(ILogTracer logTracer, IOnefuzzContext context) { _context = context; @@ -61,6 +67,7 @@ public async IAsyncEnumerable QueryAsync(string? filter = null) { var tableEntity = _entityConverter.ToTableEntity(entity); var response = await tableClient.AddEntityAsync(tableEntity); + if (response.IsError) { return ResultVoid<(int, string)>.Error((response.Status, response.ReasonPhrase)); } else { @@ -134,6 +141,56 @@ public IAsyncEnumerable SearchByRowKeys(IEnumerable rowKeys) public IAsyncEnumerable SearchByTimeRange(DateTimeOffset min, DateTimeOffset max) { return QueryAsync(Query.TimeRange(min, max)); } + + public async Task>> BatchOperation(IAsyncEnumerable entities, TableTransactionActionType actionType) { + var tableClient = await GetTableClient(typeof(T).Name); + var transactions = await entities.Select(e => new TableTransactionAction(actionType, _entityConverter.ToTableEntity(e))).ToListAsync(); + var responses = await tableClient.SubmitTransactionAsync(transactions); + return responses.Value.Select(response => + response.IsError ? ResultVoid<(int, string)>.Error((response.Status, response.ReasonPhrase)) : ResultVoid<(int, string)>.Ok() + ).ToList(); + } + + + public async Task DeleteAll(IEnumerable<(string?, string?)> keys) { + var query = Query.Or( + keys.Select(key => + key switch { + (null, null) => throw new ArgumentException("partitionKey and rowKey cannot both be null"), + (string partitionKey, null) => Query.PartitionKey(partitionKey), + (null, string rowKey) => Query.RowKey(rowKey), + (string partitionKey, string rowKey) => Query.And( + Query.PartitionKey(partitionKey), + Query.RowKey(rowKey) + ), + } + ) + ); + + var tableClient = await GetTableClient(typeof(T).Name); + var pages = tableClient.QueryAsync(query, select: new[] { "PartitionKey, RowKey" }); + + var requests = await pages + .Chunk(MAX_TRANSACTION_SIZE) + .Select(chunk => { + var transactions = chunk.Select(e => new TableTransactionAction(TableTransactionActionType.Delete, e)); + return tableClient.SubmitTransactionAsync(transactions); + }) + .ToListAsync(); + + var responses = await System.Threading.Tasks.Task.WhenAll(requests); + var (successes, failures) = responses + .SelectMany(x => x.Value) + .Aggregate( + (0, 0), + ((int Successes, int Failures) acc, Response current) => + current.IsError + ? (acc.Successes, acc.Failures + 1) + : (acc.Successes + 1, acc.Failures) + ); + + return new DeleteAllResult(successes, failures); + } }