From fa4fb6ff44054550b715cdf09c16fa949c1bd1f7 Mon Sep 17 00:00:00 2001 From: George Pollard Date: Wed, 22 Jun 2022 03:45:05 +0000 Subject: [PATCH] Implement `node` function --- src/ApiService/ApiService/Node.cs | 179 ++++++++++++++++ .../ApiService/OneFuzzTypes/Events.cs | 1 - .../ApiService/OneFuzzTypes/Model.cs | 20 +- .../ApiService/OneFuzzTypes/Requests.cs | 16 ++ .../ApiService/OneFuzzTypes/Responses.cs | 22 +- .../ApiService/OneFuzzTypes/ReturnTypes.cs | 8 +- .../ApiService/OneFuzzTypes/Validated.cs | 142 +++++++++++++ .../TestHooks/NodeOperationsTestHooks.cs | 6 +- .../TestHooks/PoolOperationsTestHooks.cs | 2 +- src/ApiService/ApiService/UserCredentials.cs | 2 +- .../onefuzzlib/EndpointAuthorization.cs | 57 ++++- .../ApiService/onefuzzlib/NodeOperations.cs | 17 +- .../ApiService/onefuzzlib/NsgOperations.cs | 6 +- .../ApiService/onefuzzlib/PoolOperations.cs | 6 +- .../onefuzzlib/ScalesetOperations.cs | 6 +- .../ApiService/onefuzzlib/Scheduler.cs | 6 +- .../ApiService/onefuzzlib/VmssOperations.cs | 6 +- .../onefuzzlib/orm/EntityConverter.cs | 16 +- .../ApiService/onefuzzlib/orm/Queries.cs | 3 + src/ApiService/Tests/Fakes/TestContext.cs | 9 +- .../Tests/Fakes/TestEndpointAuthorization.cs | 6 +- .../Tests/Fakes/TestServiceConfiguration.cs | 4 +- .../Tests/Fakes/TestUserCredentials.cs | 19 ++ .../Tests/Functions/AgentEventsTests.cs | 2 +- src/ApiService/Tests/Functions/InfoTests.cs | 6 +- src/ApiService/Tests/Functions/NodeTests.cs | 200 ++++++++++++++++++ .../Tests/Functions/_FunctionTestBase.cs | 5 + src/ApiService/Tests/OrmModelsTest.cs | 82 +++---- src/ApiService/Tests/OrmTest.cs | 2 +- src/ApiService/Tests/ValidatedStringTests.cs | 28 +++ 30 files changed, 774 insertions(+), 110 deletions(-) create mode 100644 src/ApiService/ApiService/Node.cs create mode 100644 src/ApiService/ApiService/OneFuzzTypes/Validated.cs create mode 100644 src/ApiService/Tests/Fakes/TestUserCredentials.cs create mode 100644 src/ApiService/Tests/Functions/NodeTests.cs create mode 100644 src/ApiService/Tests/ValidatedStringTests.cs diff --git a/src/ApiService/ApiService/Node.cs b/src/ApiService/ApiService/Node.cs new file mode 100644 index 00000000000..65ee6f22264 --- /dev/null +++ b/src/ApiService/ApiService/Node.cs @@ -0,0 +1,179 @@ +using System.Threading.Tasks; +using Microsoft.Azure.Functions.Worker; +using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.OneFuzzLib.Orm; + +namespace Microsoft.OneFuzz.Service; + +public class NodeFunction { + private readonly ILogTracer _log; + private readonly IEndpointAuthorization _auth; + private readonly IOnefuzzContext _context; + + public NodeFunction(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { + _log = log; + _auth = auth; + _context = context; + } + + private static readonly EntityConverter _entityConverter = new(); + + // [Function("Node") + public Async.Task Run([HttpTrigger("GET", "PATCH", "POST", "DELETE")] HttpRequestData req) { + return _auth.CallIfUser(req, r => r.Method switch { + "GET" => Get(r), + "PATCH" => Patch(r), + "POST" => Post(r), + "DELETE" => Delete(r), + _ => throw new InvalidOperationException("Unsupported HTTP method"), + }); + } + + private async Async.Task Get(HttpRequestData req) { + var request = await RequestHandling.ParseRequest(req); + if (!request.IsOk) { + return await _context.RequestHandling.NotOk(req, request.ErrorV, "pool get"); + } + + var search = request.OkV; + if (search.MachineId is Guid machineId) { + var node = await _context.NodeOperations.GetByMachineId(machineId); + if (node is null) { + return await _context.RequestHandling.NotOk( + req, + new Error( + Code: ErrorCode.UNABLE_TO_FIND, + Errors: new string[] { "unable to find node " }), + context: machineId.ToString()); + } + + var (tasks, messages) = await ( + _context.NodeTasksOperations.GetByMachineId(machineId).ToListAsync().AsTask(), + _context.NodeMessageOperations.GetMessage(machineId).ToListAsync().AsTask()); + + var commands = messages.Select(m => m.Message).ToList(); + return await RequestHandling.Ok(req, NodeToNodeSearchResult(node with { Tasks = tasks, Messages = commands })); + } + + var nodes = await _context.NodeOperations.SearchStates( + states: search.State, + poolName: search.PoolName, + scaleSetId: search.ScalesetId).ToListAsync(); + + return await RequestHandling.Ok(req, nodes.Select(NodeToNodeSearchResult)); + } + + private static NodeSearchResult NodeToNodeSearchResult(Node node) { + return new NodeSearchResult( + PoolId: node.PoolId, + PoolName: node.PoolName, + MachineId: node.MachineId, + Version: node.Version, + Heartbeat: node.Heartbeat, + InitializedAt: node.InitializedAt, + State: node.State, + ScalesetId: node.ScalesetId, + ReimageRequested: node.ReimageRequested, + DeleteRequested: node.DeleteRequested, + DebugKeepNode: node.DebugKeepNode); + } + + private async Async.Task Patch(HttpRequestData req) { + var request = await RequestHandling.ParseRequest(req); + if (!request.IsOk) { + return await _context.RequestHandling.NotOk( + req, + request.ErrorV, + "NodeReimage"); + } + + var authCheck = await _auth.CheckRequireAdmins(req); + if (!authCheck.IsOk) { + return await _context.RequestHandling.NotOk(req, authCheck.ErrorV, "NodeReimage"); + } + + var patch = request.OkV; + var node = await _context.NodeOperations.GetByMachineId(patch.MachineId); + if (node is null) { + return await _context.RequestHandling.NotOk( + req, + new Error( + Code: ErrorCode.UNABLE_TO_FIND, + Errors: new string[] { "unable to find node " }), + context: patch.MachineId.ToString()); + } + + await _context.NodeOperations.Stop(node, done: true); + if (node.DebugKeepNode) { + await _context.NodeOperations.Replace(node with { DebugKeepNode = false }); + } + + return await RequestHandling.Ok(req, true); + } + + private async Async.Task Post(HttpRequestData req) { + var request = await RequestHandling.ParseRequest(req); + if (!request.IsOk) { + return await _context.RequestHandling.NotOk( + req, + request.ErrorV, + "NodeUpdate"); + } + + var authCheck = await _auth.CheckRequireAdmins(req); + if (!authCheck.IsOk) { + return await _context.RequestHandling.NotOk(req, authCheck.ErrorV, "NodeUpdate"); + } + + var post = request.OkV; + var node = await _context.NodeOperations.GetByMachineId(post.MachineId); + if (node is null) { + return await _context.RequestHandling.NotOk( + req, + new Error( + Code: ErrorCode.UNABLE_TO_FIND, + Errors: new string[] { "unable to find node " }), + context: post.MachineId.ToString()); + } + + if (post.DebugKeepNode is bool value) { + node = node with { DebugKeepNode = value }; + } + + await _context.NodeOperations.Replace(node); + return await RequestHandling.Ok(req, true); + } + + private async Async.Task Delete(HttpRequestData req) { + var request = await RequestHandling.ParseRequest(req); + if (!request.IsOk) { + return await _context.RequestHandling.NotOk( + req, + request.ErrorV, + context: "NodeDelete"); + } + + var authCheck = await _auth.CheckRequireAdmins(req); + if (!authCheck.IsOk) { + return await _context.RequestHandling.NotOk(req, authCheck.ErrorV, "NodeDelete"); + } + + var delete = request.OkV; + var node = await _context.NodeOperations.GetByMachineId(delete.MachineId); + if (node is null) { + return await _context.RequestHandling.NotOk( + req, + new Error( + Code: ErrorCode.UNABLE_TO_FIND, + new string[] { "unable to find node" }), + context: delete.MachineId.ToString()); + } + + await _context.NodeOperations.SetHalt(node); + if (node.DebugKeepNode) { + await _context.NodeOperations.Replace(node with { DebugKeepNode = false }); + } + + return await RequestHandling.Ok(req, true); + } +} diff --git a/src/ApiService/ApiService/OneFuzzTypes/Events.cs b/src/ApiService/ApiService/OneFuzzTypes/Events.cs index 6e75a1cd2aa..0cc1f6fcb84 100644 --- a/src/ApiService/ApiService/OneFuzzTypes/Events.cs +++ b/src/ApiService/ApiService/OneFuzzTypes/Events.cs @@ -1,7 +1,6 @@ using System.Text.Json; using System.Text.Json.Serialization; using Microsoft.OneFuzz.Service.OneFuzzLib.Orm; -using PoolName = System.String; using Region = System.String; namespace Microsoft.OneFuzz.Service; diff --git a/src/ApiService/ApiService/OneFuzzTypes/Model.cs b/src/ApiService/ApiService/OneFuzzTypes/Model.cs index 7fd86ca08cf..38876528cb0 100644 --- a/src/ApiService/ApiService/OneFuzzTypes/Model.cs +++ b/src/ApiService/ApiService/OneFuzzTypes/Model.cs @@ -3,7 +3,6 @@ using Microsoft.OneFuzz.Service.OneFuzzLib.Orm; using Endpoint = System.String; using GroupId = System.Guid; -using PoolName = System.String; using PrincipalId = System.Guid; using Region = System.String; @@ -152,7 +151,6 @@ public record Error(ErrorCode Code, string[]? Errors = null); public record UserInfo(Guid? ApplicationId, Guid? ObjectId, String? Upn); - public record TaskDetails( TaskType Type, int Duration, @@ -316,11 +314,11 @@ public record InstanceConfig NetworkSecurityGroupConfig ProxyNsgConfig, AzureVmExtensionConfig? Extensions, string ProxyVmSku, - IDictionary? ApiAccessRules, - IDictionary? GroupMembership, - - IDictionary? VmTags, - IDictionary? VmssTags + IDictionary? ApiAccessRules = null, + IDictionary? GroupMembership = null, + IDictionary? VmTags = null, + IDictionary? VmssTags = null, + bool? RequireAdminPrivileges = null ) : EntityBase() { public InstanceConfig(string instanceName) : this( instanceName, @@ -330,12 +328,7 @@ public InstanceConfig(string instanceName) : this( new NetworkConfig(), new NetworkSecurityGroupConfig(), null, - "Standard_B2s", - null, - null, - null, - null) { } - + "Standard_B2s") { } public InstanceConfig() : this(String.Empty) { } public List? CheckAdmins(List? value) { @@ -346,7 +339,6 @@ public InstanceConfig() : this(String.Empty) { } } } - //# At the moment, this only checks allowed_aad_tenants, however adding //# support for 3rd party JWT validation is anticipated in a future release. public ResultVoid> CheckInstanceConfig() { diff --git a/src/ApiService/ApiService/OneFuzzTypes/Requests.cs b/src/ApiService/ApiService/OneFuzzTypes/Requests.cs index 7f7432f9f10..3ea21e0719b 100644 --- a/src/ApiService/ApiService/OneFuzzTypes/Requests.cs +++ b/src/ApiService/ApiService/OneFuzzTypes/Requests.cs @@ -18,6 +18,22 @@ public record NodeCommandDelete( string MessageId ) : BaseRequest; +public record NodeGet( + Guid MachineId +) : BaseRequest; + +public record NodeUpdate( + Guid MachineId, + bool? DebugKeepNode +) : BaseRequest; + +public record NodeSearch( + Guid? MachineId = null, + List? State = null, + Guid? ScalesetId = null, + PoolName? PoolName = null +) : BaseRequest; + public record NodeStateEnvelope( NodeEventBase Event, Guid MachineId diff --git a/src/ApiService/ApiService/OneFuzzTypes/Responses.cs b/src/ApiService/ApiService/OneFuzzTypes/Responses.cs index 38785e4791c..ad66e277a1b 100644 --- a/src/ApiService/ApiService/OneFuzzTypes/Responses.cs +++ b/src/ApiService/ApiService/OneFuzzTypes/Responses.cs @@ -4,7 +4,10 @@ namespace Microsoft.OneFuzz.Service; [JsonConverter(typeof(BaseResponseConverter))] -public abstract record BaseResponse(); +public abstract record BaseResponse() { + public static implicit operator BaseResponse(bool value) + => new BoolResult(value); +}; public record CanSchedule( bool Allowed, @@ -15,6 +18,23 @@ public record PendingNodeCommand( NodeCommandEnvelope? Envelope ) : BaseResponse(); +// TODO: not sure how much of this is actually +// needed in the search results, so at the moment +// it is a copy of the whole Node type +public record NodeSearchResult( + PoolName PoolName, + Guid MachineId, + Guid? PoolId, + string Version, + DateTimeOffset? Heartbeat, + DateTimeOffset? InitializedAt, + NodeState State, + Guid? ScalesetId, + bool ReimageRequested, + bool DeleteRequested, + bool DebugKeepNode +) : BaseResponse(); + public record BoolResult( bool Result ) : BaseResponse(); diff --git a/src/ApiService/ApiService/OneFuzzTypes/ReturnTypes.cs b/src/ApiService/ApiService/OneFuzzTypes/ReturnTypes.cs index 769c521f82c..b1f71c9b2f5 100644 --- a/src/ApiService/ApiService/OneFuzzTypes/ReturnTypes.cs +++ b/src/ApiService/ApiService/OneFuzzTypes/ReturnTypes.cs @@ -58,6 +58,9 @@ public struct OneFuzzResult { public static OneFuzzResult Error(ErrorCode errorCode, string error) => new(errorCode, new[] { error }); public static OneFuzzResult Error(Error err) => new(err); + + // Allow simple conversion of Errors to Results. + public static implicit operator OneFuzzResult(Error err) => new(err); } @@ -77,10 +80,13 @@ public struct OneFuzzResultVoid { private OneFuzzResultVoid(Error err) => (error, isOk) = (err, false); - public static OneFuzzResultVoid Ok() => new(); + public static OneFuzzResultVoid Ok => new(); public static OneFuzzResultVoid Error(ErrorCode errorCode, string[] errors) => new(errorCode, errors); public static OneFuzzResultVoid Error(ErrorCode errorCode, string error) => new(errorCode, new[] { error }); public static OneFuzzResultVoid Error(Error err) => new(err); + + // Allow simple conversion of Errors to Results. + public static implicit operator OneFuzzResultVoid(Error err) => new(err); } diff --git a/src/ApiService/ApiService/OneFuzzTypes/Validated.cs b/src/ApiService/ApiService/OneFuzzTypes/Validated.cs new file mode 100644 index 00000000000..965321d4591 --- /dev/null +++ b/src/ApiService/ApiService/OneFuzzTypes/Validated.cs @@ -0,0 +1,142 @@ + +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Text.Json; +using System.Text.Json.Serialization; +using System.Text.RegularExpressions; + +namespace Microsoft.OneFuzz.Service; + +static class Check { + private static readonly Regex _isAlnum = new(@"\A[a-zA-Z0-9]+\z", RegexOptions.Compiled); + public static bool IsAlnum(string input) => _isAlnum.IsMatch(input); + + private static readonly Regex _isAlnumDash = new(@"\A[a-zA-Z0-9\-]+\z", RegexOptions.Compiled); + public static bool IsAlnumDash(string input) => _isAlnumDash.IsMatch(input); +} + +// Base class for types that are wrappers around a validated string. +public abstract record ValidatedString(string String) { + public sealed override string ToString() => String; +} + +// JSON converter for types that are wrappers around a validated string. +public abstract class ValidatedStringConverter : JsonConverter where T : ValidatedString { + protected abstract bool TryParse(string input, out T? output); + + public sealed override bool CanConvert(Type typeToConvert) + => typeToConvert == typeof(T); + + public sealed override T? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) { + if (reader.TokenType != JsonTokenType.String) { + throw new JsonException("expected a string"); + } + + var value = reader.GetString(); + if (value is null) { + throw new JsonException("expected a string"); + } + + if (TryParse(value, out var result)) { + return result; + } else { + throw new JsonException($"unable to parse input as a {typeof(T).Name}"); + } + } + + public sealed override void Write(Utf8JsonWriter writer, T value, JsonSerializerOptions options) + => writer.WriteStringValue(value.String); +} + +[JsonConverter(typeof(Converter))] +public record PoolName : ValidatedString { + private PoolName(string value) : base(value) { + Debug.Assert(Check.IsAlnumDash(value)); + } + + public static PoolName Parse(string input) { + if (TryParse(input, out var result)) { + return result; + } + + throw new ArgumentException("Pool name must have only numbers, letters or dashes"); + } + + public static bool TryParse(string input, [NotNullWhen(returnValue: true)] out PoolName? result) { + if (!Check.IsAlnumDash(input)) { + result = default; + return false; + } + + result = new PoolName(input); + return true; + } + + public sealed class Converter : ValidatedStringConverter { + protected override bool TryParse(string input, out PoolName? output) + => PoolName.TryParse(input, out output); + } +} + +/* TODO: to be enabled in a separate PR + +[JsonConverter(typeof(Converter))] +public record Region : ValidatedString { + private Region(string value) : base(value) { + Debug.Assert(Check.IsAlnum(value)); + } + + public static Region Parse(string input) { + if (TryParse(input, out var result)) { + return result; + } + + throw new ArgumentException("Region name must have only numbers, letters or dashes"); + } + + public static bool TryParse(string input, [NotNullWhen(returnValue: true)] out Region? result) { + if (!Check.IsAlnum(input)) { + result = default; + return false; + } + + result = new Region(input); + return true; + } + + public sealed class Converter : ValidatedStringConverter { + protected override bool TryParse(string input, out Region? output) + => Region.TryParse(input, out output); + } +} + +[JsonConverter(typeof(Converter))] +public record Container : ValidatedString { + private Container(string value) : base(value) { + Debug.Assert(Check.IsAlnumDash(value)); + } + + public static Container Parse(string input) { + if (TryParse(input, out var result)) { + return result; + } + + throw new ArgumentException("Container name must have only numbers, letters or dashes"); + } + + public static bool TryParse(string input, [NotNullWhen(returnValue: true)] out Container? result) { + if (!Check.IsAlnumDash(input)) { + result = default; + return false; + } + + result = new Container(input); + return true; + } + + public sealed class Converter : ValidatedStringConverter { + protected override bool TryParse(string input, out Container? output) + => Container.TryParse(input, out output); + } +} +*/ diff --git a/src/ApiService/ApiService/TestHooks/NodeOperationsTestHooks.cs b/src/ApiService/ApiService/TestHooks/NodeOperationsTestHooks.cs index d5c3a31626b..f4e7dd98535 100644 --- a/src/ApiService/ApiService/TestHooks/NodeOperationsTestHooks.cs +++ b/src/ApiService/ApiService/TestHooks/NodeOperationsTestHooks.cs @@ -165,7 +165,9 @@ public async Task SearchStates([HttpTrigger(AuthorizationLevel if (query.ContainsKey("states")) { states = query["states"].Split('-').Select(s => Enum.Parse(s)).ToList(); } - string? poolName = UriExtension.GetString("poolName", query); + string? poolNameString = UriExtension.GetString("poolName", query); + + PoolName? poolName = poolNameString is null ? null : PoolName.Parse(poolNameString); var excludeUpdateScheduled = UriExtension.GetBool("excludeUpdateScheduled", query, false); int? numResults = UriExtension.GetInt("numResults", query); @@ -209,7 +211,7 @@ public async Task CreateNode([HttpTrigger(AuthorizationLevel.A var query = UriExtension.GetQueryComponents(req.Url); Guid poolId = Guid.Parse(query["poolId"]); - string poolName = query["poolName"]; + var poolName = PoolName.Parse(query["poolName"]); Guid machineId = Guid.Parse(query["machineId"]); Guid? scaleSetId = default; diff --git a/src/ApiService/ApiService/TestHooks/PoolOperationsTestHooks.cs b/src/ApiService/ApiService/TestHooks/PoolOperationsTestHooks.cs index 75a40c8c8a4..435f3925877 100644 --- a/src/ApiService/ApiService/TestHooks/PoolOperationsTestHooks.cs +++ b/src/ApiService/ApiService/TestHooks/PoolOperationsTestHooks.cs @@ -25,7 +25,7 @@ public async Task GetPool([HttpTrigger(AuthorizationLevel.Anon _log.Info("get pool"); var query = UriExtension.GetQueryComponents(req.Url); - var poolRes = await _poolOps.GetByName(query["name"]); + var poolRes = await _poolOps.GetByName(PoolName.Parse(query["name"])); if (poolRes.IsOk) { var resp = req.CreateResponse(HttpStatusCode.OK); diff --git a/src/ApiService/ApiService/UserCredentials.cs b/src/ApiService/ApiService/UserCredentials.cs index d1e7aa8d1aa..8a152eaea57 100644 --- a/src/ApiService/ApiService/UserCredentials.cs +++ b/src/ApiService/ApiService/UserCredentials.cs @@ -58,7 +58,7 @@ from t in r.AllowedAadTenants return OneFuzzResult.Ok(allowedAddTenantsQuery.ToArray()); } - public async Task> ParseJwtToken(HttpRequestData req) { + public virtual async Task> ParseJwtToken(HttpRequestData req) { var authToken = GetAuthToken(req); if (authToken is null) { return OneFuzzResult.Error(ErrorCode.INVALID_REQUEST, new[] { "unable to find authorization token" }); diff --git a/src/ApiService/ApiService/onefuzzlib/EndpointAuthorization.cs b/src/ApiService/ApiService/onefuzzlib/EndpointAuthorization.cs index 2ef045eb3b1..67562a2dc28 100644 --- a/src/ApiService/ApiService/onefuzzlib/EndpointAuthorization.cs +++ b/src/ApiService/ApiService/onefuzzlib/EndpointAuthorization.cs @@ -19,6 +19,8 @@ Async.Task CallIf( Func> method, bool allowUser = false, bool allowAgent = false); + + Async.Task CheckRequireAdmins(HttpRequestData req); } public class EndpointAuthorization : IEndpointAuthorization { @@ -30,7 +32,7 @@ public EndpointAuthorization(IOnefuzzContext context, ILogTracer log) { _log = log; } - public async Async.Task CallIf(HttpRequestData req, Func> method, bool allowUser = false, bool allowAgent = false) { + public virtual async Async.Task CallIf(HttpRequestData req, Func> method, bool allowUser = false, bool allowAgent = false) { var tokenResult = await _context.UserCredentials.ParseJwtToken(req); if (!tokenResult.IsOk) { @@ -77,6 +79,59 @@ public async Async.Task Reject(HttpRequestData req, UserInfo t ); } + public async Async.Task CheckRequireAdmins(HttpRequestData req) { + var tokenResult = await _context.UserCredentials.ParseJwtToken(req); + if (!tokenResult.IsOk) { + return tokenResult.ErrorV; + } + + var config = await _context.ConfigOperations.Fetch(); + if (config is null) { + return new Error( + Code: ErrorCode.INVALID_CONFIGURATION, + Errors: new string[] { "no instance configuration found " }); + } + + return CheckRequireAdminsImpl(config, tokenResult.OkV); + } + + private static OneFuzzResultVoid CheckRequireAdminsImpl(InstanceConfig config, UserInfo userInfo) { + // When there are no admins in the `admins` list, all users are considered + // admins. However, `require_admin_privileges` is still useful to protect from + // mistakes. + // + // To make changes while still protecting against accidental changes to + // pools, do the following: + // + // 1. set `require_admin_privileges` to `False` + // 2. make the change + // 3. set `require_admin_privileges` to `True` + + if (config.RequireAdminPrivileges == false) { + return OneFuzzResultVoid.Ok; + } + + if (config.Admins is null) { + return new Error( + Code: ErrorCode.UNAUTHORIZED, + Errors: new string[] { "pool modification disabled " }); + } + + if (userInfo.ObjectId is Guid objectId) { + if (config.Admins.Contains(objectId)) { + return OneFuzzResultVoid.Ok; + } + + return new Error( + Code: ErrorCode.UNAUTHORIZED, + Errors: new string[] { "not authorized to manage pools" }); + } else { + return new Error( + Code: ErrorCode.UNAUTHORIZED, + Errors: new string[] { "user had no Object ID" }); + } + } + public OneFuzzResultVoid CheckAccess(HttpRequestData req) { throw new NotImplementedException(); } diff --git a/src/ApiService/ApiService/onefuzzlib/NodeOperations.cs b/src/ApiService/ApiService/onefuzzlib/NodeOperations.cs index 24b40c9c66d..9a0a426672c 100644 --- a/src/ApiService/ApiService/onefuzzlib/NodeOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/NodeOperations.cs @@ -22,7 +22,7 @@ public interface INodeOperations : IStatefulOrm { IAsyncEnumerable SearchStates(Guid? poolId = default, Guid? scaleSetId = default, IEnumerable? states = default, - string? poolName = default, + PoolName? poolName = default, bool excludeUpdateScheduled = false, int? numResults = default); @@ -32,7 +32,7 @@ IAsyncEnumerable SearchStates(Guid? poolId = default, Async.Task Create( Guid poolId, - string poolName, + PoolName poolName, Guid machineId, Guid? scaleSetId, string version, @@ -67,7 +67,7 @@ public async Task AcquireScaleInProtection(Node node) { _logTracer.Info($"Setting scale-in protection on node {node.MachineId}"); return await _context.VmssOperations.UpdateScaleInProtection((Guid)node.ScalesetId, node.MachineId, protectFromScaleIn: true); } - return OneFuzzResultVoid.Ok(); + return OneFuzzResultVoid.Ok; } public async Async.Task ScalesetNodeExists(Node node) { @@ -207,7 +207,7 @@ public IAsyncEnumerable GetDeadNodes(Guid scaleSetId, TimeSpan expirationP public async Async.Task Create( Guid poolId, - string poolName, + PoolName poolName, Guid machineId, Guid? scaleSetId, string version, @@ -308,7 +308,7 @@ public static string SearchStatesQuery( Guid? poolId = default, Guid? scaleSetId = default, IEnumerable? states = default, - string? poolName = default, + PoolName? poolName = default, bool excludeUpdateScheduled = false, int? numResults = default) { @@ -346,7 +346,7 @@ public IAsyncEnumerable SearchStates( Guid? poolId = default, Guid? scaleSetId = default, IEnumerable? states = default, - string? poolName = default, + PoolName? poolName = default, bool excludeUpdateScheduled = false, int? numResults = default) { var query = NodeOperations.SearchStatesQuery(_context.ServiceConfiguration.OneFuzzVersion, poolId, scaleSetId, states, poolName, excludeUpdateScheduled, numResults); @@ -467,9 +467,8 @@ public NodeMessageOperations(ILogTracer log, IOnefuzzContext context) : base(log _log = log; } - public IAsyncEnumerable GetMessage(Guid machineId) { - return QueryAsync($"PartitionKey eq '{machineId}'"); - } + public IAsyncEnumerable GetMessage(Guid machineId) + => QueryAsync(Query.PartitionKey(machineId)); public async Async.Task ClearMessages(Guid machineId) { _logTracer.Info($"clearing messages for node {machineId}"); diff --git a/src/ApiService/ApiService/onefuzzlib/NsgOperations.cs b/src/ApiService/ApiService/onefuzzlib/NsgOperations.cs index 33b120e7737..561c3c0e4ae 100644 --- a/src/ApiService/ApiService/onefuzzlib/NsgOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/NsgOperations.cs @@ -48,7 +48,7 @@ public NsgOperations(ICreds creds, ILogTracer logTracer) { public async Async.Task DissociateNic(Nsg nsg, NetworkInterfaceResource nic) { if (nic.Data.NetworkSecurityGroup == null) { - return OneFuzzResultVoid.Ok(); + return OneFuzzResultVoid.Ok; } var azureNsg = await GetNsg(nsg.Name); @@ -83,7 +83,7 @@ await _creds.GetResourceGroupResource() err, ) */ - return OneFuzzResultVoid.Ok(); + return OneFuzzResultVoid.Ok; } return OneFuzzResultVoid.Error( ErrorCode.UNABLE_TO_UPDATE, @@ -93,7 +93,7 @@ await _creds.GetResourceGroupResource() ); } - return OneFuzzResultVoid.Ok(); + return OneFuzzResultVoid.Ok; } public async Async.Task GetNsg(string name) { diff --git a/src/ApiService/ApiService/onefuzzlib/PoolOperations.cs b/src/ApiService/ApiService/onefuzzlib/PoolOperations.cs index 4717f56655a..67153a1f031 100644 --- a/src/ApiService/ApiService/onefuzzlib/PoolOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/PoolOperations.cs @@ -4,7 +4,7 @@ namespace Microsoft.OneFuzz.Service; public interface IPoolOperations { - public Async.Task> GetByName(string poolName); + public Async.Task> GetByName(PoolName poolName); Task ScheduleWorkset(Pool pool, WorkSet workSet); IAsyncEnumerable GetByClientId(Guid clientId); } @@ -16,8 +16,8 @@ public PoolOperations(ILogTracer log, IOnefuzzContext context) } - public async Async.Task> GetByName(string poolName) { - var pools = QueryAsync(filter: $"PartitionKey eq '{poolName}'"); + public async Async.Task> GetByName(PoolName poolName) { + var pools = QueryAsync(filter: $"PartitionKey eq '{poolName.String}'"); if (pools == null || await pools.CountAsync() == 0) { return OneFuzzResult.Error(ErrorCode.INVALID_REQUEST, "unable to find pool"); diff --git a/src/ApiService/ApiService/onefuzzlib/ScalesetOperations.cs b/src/ApiService/ApiService/onefuzzlib/ScalesetOperations.cs index a4d9b5b9ac7..0faee37954e 100644 --- a/src/ApiService/ApiService/onefuzzlib/ScalesetOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/ScalesetOperations.cs @@ -6,7 +6,7 @@ namespace Microsoft.OneFuzz.Service; public interface IScalesetOperations : IOrm { IAsyncEnumerable Search(); - public IAsyncEnumerable SearchByPool(string poolName); + public IAsyncEnumerable SearchByPool(PoolName poolName); public Async.Task UpdateConfigs(Scaleset scaleSet); @@ -29,8 +29,8 @@ public IAsyncEnumerable Search() { return QueryAsync(); } - public IAsyncEnumerable SearchByPool(string poolName) { - return QueryAsync(filter: $"pool_name eq '{poolName}'"); + public IAsyncEnumerable SearchByPool(PoolName poolName) { + return QueryAsync(filter: $"PartitionKey eq '{poolName}'"); } diff --git a/src/ApiService/ApiService/onefuzzlib/Scheduler.cs b/src/ApiService/ApiService/onefuzzlib/Scheduler.cs index d219de1b3b7..1d4bc711405 100644 --- a/src/ApiService/ApiService/onefuzzlib/Scheduler.cs +++ b/src/ApiService/ApiService/onefuzzlib/Scheduler.cs @@ -182,7 +182,7 @@ record BucketConfig(int count, bool reboot, Container setupContainer, string? se return (bucketConfig, workUnit); } - record struct BucketId(Os os, Guid jobId, (string, string)? vm, string? pool, string setupContainer, bool? reboot, Guid? unique); + record struct BucketId(Os os, Guid jobId, (string, string)? vm, PoolName? pool, string setupContainer, bool? reboot, Guid? unique); private ILookup BucketTasks(IEnumerable tasks) { @@ -205,7 +205,7 @@ private ILookup BucketTasks(IEnumerable tasks) { } // check for multiple VMs for 1.0.0 and later tasks - string? pool = task.Config.Pool?.PoolName; + var pool = task.Config.Pool?.PoolName; if ((task.Config.Pool?.Count ?? 0) > 1) { unique = Guid.NewGuid(); } @@ -219,5 +219,3 @@ private ILookup BucketTasks(IEnumerable tasks) { }); } } - - diff --git a/src/ApiService/ApiService/onefuzzlib/VmssOperations.cs b/src/ApiService/ApiService/onefuzzlib/VmssOperations.cs index 1e90b7e5eae..06d4a61aa72 100644 --- a/src/ApiService/ApiService/onefuzzlib/VmssOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/VmssOperations.cs @@ -80,7 +80,7 @@ public async Async.Task UpdateExtensions(Guid name, IList UpdateScaleInProtection(Guid name, Gu _log.WithHttpStatus((r.GetRawResponse().Status, r.GetRawResponse().ReasonPhrase)).Error(msg); return OneFuzzResultVoid.Error(ErrorCode.UNABLE_TO_UPDATE, msg); } else { - return OneFuzzResultVoid.Ok(); + return OneFuzzResultVoid.Ok; } } catch (Exception ex) when (ex is RequestFailedException || ex is CloudException) { if (ex.Message.Contains(INSTANCE_NOT_FOUND) && protectFromScaleIn == false) { _log.Info($"Tried to remove scale in protection on node {name} {vmId} but instance no longer exists"); - return OneFuzzResultVoid.Ok(); + return OneFuzzResultVoid.Ok; } else { var msg = $"failed to update scale in protection on vm {vmId} for scaleset {name}"; _log.Exception(ex, msg); diff --git a/src/ApiService/ApiService/onefuzzlib/orm/EntityConverter.cs b/src/ApiService/ApiService/onefuzzlib/orm/EntityConverter.cs index 8410cbfc48e..bb155f3c611 100644 --- a/src/ApiService/ApiService/onefuzzlib/orm/EntityConverter.cs +++ b/src/ApiService/ApiService/onefuzzlib/orm/EntityConverter.cs @@ -143,10 +143,9 @@ private EntityInfo GetEntityInfo() { }); } - public string ToJsonString(T typedEntity) { - var serialized = JsonSerializer.Serialize(typedEntity, _options); - return serialized; - } + public string ToJsonString(T typedEntity) => JsonSerializer.Serialize(typedEntity, _options); + + public T? FromJsonString(string value) => JsonSerializer.Deserialize(value, _options); public TableEntity ToTableEntity(T typedEntity) where T : EntityBase { if (typedEntity == null) { @@ -211,8 +210,11 @@ public TableEntity ToTableEntity(T typedEntity) where T : EntityBase { return Guid.Parse(entity.GetString(ef.kind.ToString())); else if (ef.type == typeof(int)) return int.Parse(entity.GetString(ef.kind.ToString())); + else if (ef.type == typeof(PoolName)) + // TODO: this should be able to be generic over any ValidatedString + return PoolName.Parse(entity.GetString(ef.kind.ToString())); else { - throw new Exception("invalid "); + throw new Exception($"invalid partition or row key type of {info.type} property {name}: {ef.type}"); } } @@ -247,7 +249,6 @@ public TableEntity ToTableEntity(T typedEntity) where T : EntityBase { outputType = typeProvider.GetTypeInfo(v); } - if (objType == typeof(string)) { var value = entity.GetString(fieldName); if (value.StartsWith('[') || value.StartsWith('{') || value == "null") { @@ -283,6 +284,3 @@ public T ToRecord(TableEntity entity) where T : EntityBase { } } - - - diff --git a/src/ApiService/ApiService/onefuzzlib/orm/Queries.cs b/src/ApiService/ApiService/onefuzzlib/orm/Queries.cs index 51bf2c9c0aa..24a6822a9ba 100644 --- a/src/ApiService/ApiService/onefuzzlib/orm/Queries.cs +++ b/src/ApiService/ApiService/onefuzzlib/orm/Queries.cs @@ -11,6 +11,9 @@ public static class Query { public static string PartitionKey(string partitionKey) => TableClient.CreateQueryFilter($"PartitionKey eq {partitionKey}"); + public static string PartitionKey(Guid partitionKey) + => TableClient.CreateQueryFilter($"PartitionKey eq {partitionKey}"); + public static string RowKey(string rowKey) => TableClient.CreateQueryFilter($"RowKey eq {rowKey}"); diff --git a/src/ApiService/Tests/Fakes/TestContext.cs b/src/ApiService/Tests/Fakes/TestContext.cs index d782fee682c..206277c29cd 100644 --- a/src/ApiService/Tests/Fakes/TestContext.cs +++ b/src/ApiService/Tests/Fakes/TestContext.cs @@ -24,6 +24,9 @@ public TestContext(ILogTracer logTracer, IStorage storage, ICreds creds, string NodeTasksOperations = new NodeTasksOperations(logTracer, this); TaskEventOperations = new TaskEventOperations(logTracer, this); NodeMessageOperations = new NodeMessageOperations(logTracer, this); + ConfigOperations = new ConfigOperations(logTracer, this); + + UserCredentials = new UserCredentials(logTracer, ConfigOperations); } public TestEvents Events { get; set; } = new(); @@ -36,6 +39,7 @@ public Async.Task InsertAll(params EntityBase[] objs) Node n => NodeOperations.Insert(n), Job j => JobOperations.Insert(j), NodeTasks nt => NodeTasksOperations.Insert(nt), + InstanceConfig ic => ConfigOperations.Insert(ic), _ => throw new NotImplementedException($"Need to add an TestContext.InsertAll case for {x.GetType()} entities"), })); @@ -48,6 +52,7 @@ public Async.Task InsertAll(params EntityBase[] objs) public IStorage Storage { get; } public ICreds Creds { get; } public IContainers Containers { get; } + public IUserCredentials UserCredentials { get; set; } public IRequestHandling RequestHandling { get; } @@ -57,12 +62,12 @@ public Async.Task InsertAll(params EntityBase[] objs) public INodeTasksOperations NodeTasksOperations { get; } public ITaskEventOperations TaskEventOperations { get; } public INodeMessageOperations NodeMessageOperations { get; } + public IConfigOperations ConfigOperations { get; } // -- Remainder not implemented -- public IConfig Config => throw new System.NotImplementedException(); - public IConfigOperations ConfigOperations => throw new System.NotImplementedException(); public IDiskOperations DiskOperations => throw new System.NotImplementedException(); @@ -92,8 +97,6 @@ public Async.Task InsertAll(params EntityBase[] objs) public ISecretsOperations SecretsOperations => throw new System.NotImplementedException(); - public IUserCredentials UserCredentials => throw new System.NotImplementedException(); - public IVmOperations VmOperations => throw new System.NotImplementedException(); public IVmssOperations VmssOperations => throw new System.NotImplementedException(); diff --git a/src/ApiService/Tests/Fakes/TestEndpointAuthorization.cs b/src/ApiService/Tests/Fakes/TestEndpointAuthorization.cs index 70099268455..c1a2979a5a9 100644 --- a/src/ApiService/Tests/Fakes/TestEndpointAuthorization.cs +++ b/src/ApiService/Tests/Fakes/TestEndpointAuthorization.cs @@ -11,16 +11,16 @@ enum RequestType { Agent, } -sealed class TestEndpointAuthorization : IEndpointAuthorization { +sealed class TestEndpointAuthorization : EndpointAuthorization { private readonly RequestType _type; private readonly IOnefuzzContext _context; - public TestEndpointAuthorization(RequestType type, IOnefuzzContext context) { + public TestEndpointAuthorization(RequestType type, ILogTracer log, IOnefuzzContext context) : base(context, log) { _type = type; _context = context; } - public Task CallIf( + public override Task CallIf( HttpRequestData req, Func> method, bool allowUser = false, diff --git a/src/ApiService/Tests/Fakes/TestServiceConfiguration.cs b/src/ApiService/Tests/Fakes/TestServiceConfiguration.cs index a157e967e1b..b5196051fbd 100644 --- a/src/ApiService/Tests/Fakes/TestServiceConfiguration.cs +++ b/src/ApiService/Tests/Fakes/TestServiceConfiguration.cs @@ -19,6 +19,8 @@ public TestServiceConfiguration(string tablePrefix) { public string? ApplicationInsightsInstrumentationKey { get; set; } = "TestAppInsightsInstrumentationKey"; + public string? OneFuzzInstanceName => "UnitTestInstance"; + // -- Remainder not implemented -- public LogDestination[] LogDestinations { get => throw new System.NotImplementedException(); set => throw new System.NotImplementedException(); } @@ -42,8 +44,6 @@ public TestServiceConfiguration(string tablePrefix) { public string? OneFuzzInstance => throw new System.NotImplementedException(); - public string? OneFuzzInstanceName => throw new System.NotImplementedException(); - public string? OneFuzzKeyvault => throw new System.NotImplementedException(); public string? OneFuzzMonitor => throw new System.NotImplementedException(); diff --git a/src/ApiService/Tests/Fakes/TestUserCredentials.cs b/src/ApiService/Tests/Fakes/TestUserCredentials.cs new file mode 100644 index 00000000000..fc92357819d --- /dev/null +++ b/src/ApiService/Tests/Fakes/TestUserCredentials.cs @@ -0,0 +1,19 @@ +using System.Threading.Tasks; +using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service; + +using Async = System.Threading.Tasks; + +namespace Tests.Fakes; + +sealed class TestUserCredentials : UserCredentials { + + private readonly OneFuzzResult _tokenResult; + + public TestUserCredentials(ILogTracer log, IConfigOperations instanceConfig, OneFuzzResult tokenResult) + : base(log, instanceConfig) { + _tokenResult = tokenResult; + } + + public override Task> ParseJwtToken(HttpRequestData req) => Async.Task.FromResult(_tokenResult); +} diff --git a/src/ApiService/Tests/Functions/AgentEventsTests.cs b/src/ApiService/Tests/Functions/AgentEventsTests.cs index 4f232a78acc..5e72b9677e7 100644 --- a/src/ApiService/Tests/Functions/AgentEventsTests.cs +++ b/src/ApiService/Tests/Functions/AgentEventsTests.cs @@ -29,7 +29,7 @@ public AgentEventsTestsBase(ITestOutputHelper output, IStorage storage) readonly Guid jobId = Guid.NewGuid(); readonly Guid taskId = Guid.NewGuid(); readonly Guid machineId = Guid.NewGuid(); - readonly string poolName = $"pool-{Guid.NewGuid()}"; + readonly PoolName poolName = PoolName.Parse($"pool-{Guid.NewGuid()}"); readonly Guid poolId = Guid.NewGuid(); readonly string poolVersion = $"version-{Guid.NewGuid()}"; diff --git a/src/ApiService/Tests/Functions/InfoTests.cs b/src/ApiService/Tests/Functions/InfoTests.cs index 1adc547da77..fe1cbc7ac97 100644 --- a/src/ApiService/Tests/Functions/InfoTests.cs +++ b/src/ApiService/Tests/Functions/InfoTests.cs @@ -27,7 +27,7 @@ public InfoTestBase(ITestOutputHelper output, IStorage storage) [Fact] public async Async.Task TestInfo_WithoutAuthorization_IsRejected() { - var auth = new TestEndpointAuthorization(RequestType.NoAuthorization, Context); + var auth = new TestEndpointAuthorization(RequestType.NoAuthorization, Logger, Context); var func = new Info(auth, Context); var result = await func.Run(TestHttpRequestData.Empty("GET")); @@ -36,7 +36,7 @@ public async Async.Task TestInfo_WithoutAuthorization_IsRejected() { [Fact] public async Async.Task TestInfo_WithAgentCredentials_IsRejected() { - var auth = new TestEndpointAuthorization(RequestType.Agent, Context); + var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context); var func = new Info(auth, Context); var result = await func.Run(TestHttpRequestData.Empty("GET")); @@ -52,7 +52,7 @@ public async Async.Task TestInfo_WithUserCredentials_Succeeds() { await containerClient.CreateAsync(); await containerClient.GetBlobClient("instance_id").UploadAsync(new BinaryData(instanceId)); - var auth = new TestEndpointAuthorization(RequestType.User, Context); + var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); var func = new Info(auth, Context); var result = await func.Run(TestHttpRequestData.Empty("GET")); diff --git a/src/ApiService/Tests/Functions/NodeTests.cs b/src/ApiService/Tests/Functions/NodeTests.cs new file mode 100644 index 00000000000..f959e81bcac --- /dev/null +++ b/src/ApiService/Tests/Functions/NodeTests.cs @@ -0,0 +1,200 @@ + +using System; +using System.Linq; +using System.Net; +using Microsoft.OneFuzz.Service; +using Tests.Fakes; +using Xunit; +using Xunit.Abstractions; + +using Async = System.Threading.Tasks; + +namespace Tests.Functions; + +[Trait("Category", "Integration")] +public class AzureStorageNodeTest : NodeTestBase { + public AzureStorageNodeTest(ITestOutputHelper output) + : base(output, Integration.AzureStorage.FromEnvironment()) { } +} + +public class AzuriteNodeTest : NodeTestBase { + public AzuriteNodeTest(ITestOutputHelper output) + : base(output, new Integration.AzuriteStorage()) { } +} + +public abstract class NodeTestBase : FunctionTestBase { + public NodeTestBase(ITestOutputHelper output, IStorage storage) + : base(output, storage) { } + + private readonly Guid _machineId = Guid.NewGuid(); + private readonly PoolName _poolName = PoolName.Parse($"pool-{Guid.NewGuid()}"); + private readonly string _version = Guid.NewGuid().ToString(); + + [Fact] + public async Async.Task Search_SpecificNode_NotFound_ReturnsNotFound() { + var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); + + var req = new NodeSearch(MachineId: _machineId); + var func = new NodeFunction(Logger, auth, Context); + var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); + Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); + } + + [Fact] + public async Async.Task Search_SpecificNode_Found_ReturnsOk() { + await Context.InsertAll( + new Node(_poolName, _machineId, null, _version)); + + var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); + + var req = new NodeSearch(MachineId: _machineId); + var func = new NodeFunction(Logger, auth, Context); + var result = await func.Run(TestHttpRequestData.FromJson("GET", req)); + Assert.Equal(HttpStatusCode.OK, result.StatusCode); + + // make sure we got the data from the table + var deserialized = BodyAs(result); + Assert.Equal(_version, deserialized.Version); + } + + [Theory] + [InlineData("PATCH")] + [InlineData("POST")] + [InlineData("DELETE")] + public async Async.Task RequiresAdmin(string method) { + // config must be found + await Context.InsertAll( + new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!)); + + // must be a user to auth + var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); + + // override the found user credentials + var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: Guid.NewGuid(), "upn"); + Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); + + var req = new NodeGet(MachineId: _machineId); + var func = new NodeFunction(Logger, auth, Context); + var result = await func.Run(TestHttpRequestData.FromJson(method, req)); + Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); + + var err = BodyAs(result); + Assert.Equal(ErrorCode.UNAUTHORIZED, err.Code); + Assert.Contains("pool modification disabled", err.Errors?.Single()); + } + + [Theory] + [InlineData("PATCH")] + [InlineData("POST")] + [InlineData("DELETE")] + public async Async.Task RequiresAdmin_CanBeDisabled(string method) { + // disable requiring admin privileges + await Context.InsertAll( + new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { + RequireAdminPrivileges = false + }); + + // must be a user to auth + var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); + + // override the found user credentials + var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: Guid.NewGuid(), "upn"); + Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); + + var req = new NodeGet(MachineId: _machineId); + var func = new NodeFunction(Logger, auth, Context); + var result = await func.Run(TestHttpRequestData.FromJson(method, req)); + + // we will fail with BadRequest but due to not being able to find the Node, + // not because of UNAUTHORIZED + Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); + Assert.Equal(ErrorCode.UNABLE_TO_FIND, BodyAs(result).Code); + } + + [Theory] + [InlineData("PATCH")] + [InlineData("POST")] + [InlineData("DELETE")] + public async Async.Task UserCanBeAdmin(string method) { + var userObjectId = Guid.NewGuid(); + + // config specifies that user is admin + await Context.InsertAll( + new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { + Admins = new[] { userObjectId } + }); + + // must be a user to auth + var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); + + // override the found user credentials + var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: userObjectId, "upn"); + Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); + + var req = new NodeGet(MachineId: _machineId); + var func = new NodeFunction(Logger, auth, Context); + var result = await func.Run(TestHttpRequestData.FromJson(method, req)); + + // we will fail with BadRequest but due to not being able to find the Node, + // not because of UNAUTHORIZED + Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); + Assert.Equal(ErrorCode.UNABLE_TO_FIND, BodyAs(result).Code); + } + + [Theory] + [InlineData("PATCH")] + [InlineData("POST")] + [InlineData("DELETE")] + public async Async.Task EnablingAdminForAnotherUserDoesNotPermitThisUser(string method) { + var userObjectId = Guid.NewGuid(); + var otherObjectId = Guid.NewGuid(); + + // config specifies that a different user is admin + await Context.InsertAll( + new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { + Admins = new[] { otherObjectId } + }); + + // must be a user to auth + var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); + + // override the found user credentials + var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: userObjectId, "upn"); + Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); + + var req = new NodeGet(MachineId: _machineId); + var func = new NodeFunction(Logger, auth, Context); + var result = await func.Run(TestHttpRequestData.FromJson(method, req)); + Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); + + var err = BodyAs(result); + Assert.Equal(ErrorCode.UNAUTHORIZED, err.Code); + Assert.Contains("not authorized to manage pools", err.Errors?.Single()); + } + + [Theory] + [InlineData("PATCH")] + [InlineData("POST")] + [InlineData("DELETE")] + public async Async.Task CanPerformOperation(string method) { + // disable requiring admin privileges + await Context.InsertAll( + new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) { + RequireAdminPrivileges = false + }, + new Node(_poolName, _machineId, null, _version)); + + // must be a user to auth + var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context); + + // override the found user credentials + var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: Guid.NewGuid(), "upn"); + Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult.Ok(userInfo)); + + // all of these operations use NodeGet + var req = new NodeGet(MachineId: _machineId); + var func = new NodeFunction(Logger, auth, Context); + var result = await func.Run(TestHttpRequestData.FromJson(method, req)); + Assert.Equal(HttpStatusCode.OK, result.StatusCode); + } +} diff --git a/src/ApiService/Tests/Functions/_FunctionTestBase.cs b/src/ApiService/Tests/Functions/_FunctionTestBase.cs index 1653fb94bc9..b8595aeb749 100644 --- a/src/ApiService/Tests/Functions/_FunctionTestBase.cs +++ b/src/ApiService/Tests/Functions/_FunctionTestBase.cs @@ -7,6 +7,7 @@ using Azure.Storage.Blobs; using Microsoft.Azure.Functions.Worker.Http; using Microsoft.OneFuzz.Service; +using Microsoft.OneFuzz.Service.OneFuzzLib.Orm; using Tests.Fakes; using Xunit.Abstractions; @@ -60,6 +61,9 @@ protected static string BodyAsString(HttpResponseData data) { return sr.ReadToEnd(); } + protected static T BodyAs(HttpResponseData data) + => new EntityConverter().FromJsonString(BodyAsString(data)) ?? throw new Exception($"unable to deserialize body as {typeof(T)}"); + public void Dispose() { var (accountName, accountKey) = _storage.GetStorageAccountNameAndKey("").Result; // sync for test impls if (accountName is not null && accountKey is not null) { @@ -72,6 +76,7 @@ public void Dispose() { new StorageSharedKeyCredential(accountName, accountKey)); } } + private void CleanupBlobs(Uri endpoint, StorageSharedKeyCredential creds) { var blobClient = new BlobServiceClient(endpoint, creds); diff --git a/src/ApiService/Tests/OrmModelsTest.cs b/src/ApiService/Tests/OrmModelsTest.cs index fd40b71d1f2..8626f8789c3 100644 --- a/src/ApiService/Tests/OrmModelsTest.cs +++ b/src/ApiService/Tests/OrmModelsTest.cs @@ -74,21 +74,26 @@ public static Gen NodeTasks() { ); } - public static Gen Node() { - return Arb.Generate, Tuple>>().Select( - arg => new Node( + public static Gen PoolNameGen { get; } + = from name in Arb.Generate() + where PoolName.TryParse(name.Get, out _) + select PoolName.Parse(name.Get); + + public static Gen Node { get; } + = from arg in Arb.Generate, Tuple>>() + from poolName in PoolNameGen + select new Node( InitializedAt: arg.Item1.Item1, - PoolName: arg.Item1.Item2, + PoolName: poolName, PoolId: arg.Item1.Item3, - MachineId: arg.Item1.Item4, - State: arg.Item1.Item5, + MachineId: arg.Item1.Item3, + State: arg.Item1.Item4, ScalesetId: arg.Item2.Item1, Heartbeat: arg.Item2.Item2, Version: arg.Item2.Item3, ReimageRequested: arg.Item2.Item4, DeleteRequested: arg.Item2.Item5, - DebugKeepNode: arg.Item2.Item6)); - } + DebugKeepNode: arg.Item2.Item6); public static Gen ProxyForward() { return Arb.Generate, Tuple>>().Select( @@ -200,35 +205,31 @@ public static Gen Task() { ) ); } - public static Gen Scaleset() { - return Arb.Generate, - Tuple, Guid?>, - Tuple>>>().Select( - arg => - new Scaleset( - PoolName: arg.Item1.Item1, - ScalesetId: arg.Item1.Item2, - State: arg.Item1.Item3, - Auth: arg.Item1.Item4, - VmSku: arg.Item1.Item5, - Image: arg.Item1.Item6, - Region: arg.Item1.Item7, - - Size: arg.Item2.Item1, - SpotInstance: arg.Item2.Item2, - EphemeralOsDisks: arg.Item2.Item3, - NeedsConfigUpdate: arg.Item2.Item4, - Error: arg.Item2.Item5, - Nodes: arg.Item2.Item6, - ClientId: arg.Item2.Item7, - - ClientObjectId: arg.Item3.Item1, - Tags: arg.Item3.Item2 - ) - ); - } - + public static Gen Scaleset { get; } + = from arg in Arb.Generate, + Tuple, Guid?>, + Tuple>>>() + from poolName in PoolNameGen + select new Scaleset( + PoolName: poolName, + ScalesetId: arg.Item1.Item1, + State: arg.Item1.Item2, + Auth: arg.Item1.Item3, + VmSku: arg.Item1.Item4, + Image: arg.Item1.Item5, + Region: arg.Item1.Item6, + + Size: arg.Item2.Item1, + SpotInstance: arg.Item2.Item2, + EphemeralOsDisks: arg.Item2.Item3, + NeedsConfigUpdate: arg.Item2.Item4, + Error: arg.Item2.Item5, + Nodes: arg.Item2.Item6, + ClientId: arg.Item2.Item7, + + ClientObjectId: arg.Item3.Item1, + Tags: arg.Item3.Item2); public static Gen Webhook() { return Arb.Generate, string, WebhookMessageFormat>>().Select( @@ -348,7 +349,7 @@ public static Arbitrary NodeTasks() { } public static Arbitrary Node() { - return Arb.From(OrmGenerators.Node()); + return Arb.From(OrmGenerators.Node); } public static Arbitrary ProxyForward() { @@ -383,9 +384,8 @@ public static Arbitrary Task() { return Arb.From(OrmGenerators.Task()); } - public static Arbitrary Scaleset() { - return Arb.From(OrmGenerators.Scaleset()); - } + public static Arbitrary Scaleset() + => Arb.From(OrmGenerators.Scaleset); public static Arbitrary Webhook() { return Arb.From(OrmGenerators.Webhook()); diff --git a/src/ApiService/Tests/OrmTest.cs b/src/ApiService/Tests/OrmTest.cs index d68952c38a3..57b2868b28b 100644 --- a/src/ApiService/Tests/OrmTest.cs +++ b/src/ApiService/Tests/OrmTest.cs @@ -234,7 +234,7 @@ public void TestConvertSnakeToPAscalCase() { [Fact] public void TestEventSerialization() { - var expectedEvent = new EventMessage(Guid.NewGuid(), EventType.NodeHeartbeat, new EventNodeHeartbeat(Guid.NewGuid(), Guid.NewGuid(), "test Poool"), Guid.NewGuid(), "test"); + var expectedEvent = new EventMessage(Guid.NewGuid(), EventType.NodeHeartbeat, new EventNodeHeartbeat(Guid.NewGuid(), Guid.NewGuid(), PoolName.Parse("test-Poool")), Guid.NewGuid(), "test"); var serialized = JsonSerializer.Serialize(expectedEvent, EntityConverter.GetJsonSerializerOptions()); var actualEvent = JsonSerializer.Deserialize(serialized, EntityConverter.GetJsonSerializerOptions()); Assert.Equal(expectedEvent, actualEvent); diff --git a/src/ApiService/Tests/ValidatedStringTests.cs b/src/ApiService/Tests/ValidatedStringTests.cs new file mode 100644 index 00000000000..783bf769813 --- /dev/null +++ b/src/ApiService/Tests/ValidatedStringTests.cs @@ -0,0 +1,28 @@ +using System.Text.Json; +using Microsoft.OneFuzz.Service; +using Xunit; + +namespace Tests; + +public class ValidatedStringTests { + + record ThingContainingPoolName(PoolName PoolName); + + [Fact] + public void PoolNameValidatesOnDeserialization() { + var ex = Assert.Throws(() => JsonSerializer.Deserialize("{ \"PoolName\": \"is-not!-a-pool\" }")); + Assert.Equal("unable to parse input as a PoolName", ex.Message); + } + + [Fact] + public void PoolNameDeserializesFromString() { + var result = JsonSerializer.Deserialize("{ \"PoolName\": \"is-a-pool\" }"); + Assert.Equal("is-a-pool", result?.PoolName.String); + } + + [Fact] + public void PoolNameSerializesToString() { + var result = JsonSerializer.Serialize(new ThingContainingPoolName(PoolName.Parse("is-a-pool"))); + Assert.Equal("{\"PoolName\":\"is-a-pool\"}", result); + } +}