Skip to content
This repository has been archived by the owner on Nov 1, 2023. It is now read-only.

Refactor agent commands #1922

Merged
merged 7 commits into from
May 11, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 11 additions & 25 deletions src/ApiService/ApiService/AgentCanSchedule.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,64 +6,50 @@ namespace Microsoft.OneFuzz.Service;
public class AgentCanSchedule {
private readonly ILogTracer _log;

private readonly IStorage _storage;
private readonly IOnefuzzContext _context;

private readonly INodeOperations _nodeOperations;

private readonly ITaskOperations _taskOperations;

private readonly IScalesetOperations _scalesetOperations;

public AgentCanSchedule(ILogTracer log, IStorage storage, INodeOperations nodeOperations, ITaskOperations taskOperations, IScalesetOperations scalesetOperations) {
public AgentCanSchedule(ILogTracer log, IOnefuzzContext context) {
_log = log;
_storage = storage;
_nodeOperations = nodeOperations;
_taskOperations = taskOperations;
_scalesetOperations = scalesetOperations;
_context = context;
}

// [Function("AgentCanSchedule")]
public async Async.Task<HttpResponseData> Run([HttpTrigger] HttpRequestData req) {
var request = await RequestHandling.ParseRequest<CanScheduleRequest>(req);
if (!request.IsOk || request.OkV == null) {
return await RequestHandling.NotOk(req, request.ErrorV, typeof(CanScheduleRequest).ToString(), _log);
return await _context.RequestHandling.NotOk(req, request.ErrorV, typeof(CanScheduleRequest).ToString());
}

var canScheduleRequest = request.OkV;

var node = await _nodeOperations.GetByMachineId(canScheduleRequest.MachineId);
var node = await _context.NodeOperations.GetByMachineId(canScheduleRequest.MachineId);
if (node == null) {
return await RequestHandling.NotOk(
return await _context.RequestHandling.NotOk(
req,
new Error(
ErrorCode.UNABLE_TO_FIND,
new string[] {
"unable to find node"
}
),
canScheduleRequest.MachineId.ToString(),
_log
canScheduleRequest.MachineId.ToString()
);
}

var allowed = true;
var workStopped = false;

if (!await _nodeOperations.CanProcessNewWork(node)) {
if (!await _context.NodeOperations.CanProcessNewWork(node)) {
allowed = false;
}

var task = await _taskOperations.GetByTaskId(canScheduleRequest.TaskId);
var task = await _context.TaskOperations.GetByTaskId(canScheduleRequest.TaskId);
workStopped = task == null || TaskStateHelper.ShuttingDown.Contains(task.State);

if (allowed) {
allowed = (await _nodeOperations.AcquireScaleInProtection(node)).IsOk;
allowed = (await _context.NodeOperations.AcquireScaleInProtection(node)).IsOk;
}

return await RequestHandling.Ok(
req,
new BaseResponse[] {
new CanSchedule(allowed, workStopped)
});
return await RequestHandling.Ok(req, new CanSchedule(allowed, workStopped));
}
}
57 changes: 57 additions & 0 deletions src/ApiService/ApiService/AgentCommands.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
using Microsoft.Azure.Functions.Worker;
using Microsoft.Azure.Functions.Worker.Http;

namespace Microsoft.OneFuzz.Service;

public class AgentCommands {
private readonly ILogTracer _log;

private readonly IOnefuzzContext _context;

public AgentCommands(ILogTracer log, IOnefuzzContext context) {
_log = log;
_context = context;
}

// [Function("AgentCommands")]
public async Async.Task<HttpResponseData> Run([HttpTrigger("get", "delete")] HttpRequestData req) {
return req.Method switch {
"GET" => await Get(req),
"DELETE" => await Delete(req),
_ => throw new NotImplementedException($"HTTP Method {req.Method} is not supported for this method")
};
}

private async Async.Task<HttpResponseData> Get(HttpRequestData req) {
var request = await RequestHandling.ParseRequest<NodeCommandGet>(req);
if (!request.IsOk || request.OkV == null) {
return await _context.RequestHandling.NotOk(req, request.ErrorV, typeof(NodeCommandGet).ToString());
}
var nodeCommand = request.OkV;

var message = await _context.NodeMessageOperations.GetMessage(nodeCommand.MachineId).FirstOrDefaultAsync();
if (message != null) {
var command = message.Message;
var messageId = message.MessageId;
var envelope = new NodeCommandEnvelope(command, messageId);
return await RequestHandling.Ok(req, new PendingNodeCommand(envelope));
} else {
return await RequestHandling.Ok(req, new PendingNodeCommand(null));
}
}

private async Async.Task<HttpResponseData> Delete(HttpRequestData req) {
var request = await RequestHandling.ParseRequest<NodeCommandDelete>(req);
if (!request.IsOk || request.OkV == null) {
return await _context.RequestHandling.NotOk(req, request.ErrorV, typeof(NodeCommandDelete).ToString());
}
var nodeCommand = request.OkV;

var message = await _context.NodeMessageOperations.GetEntityAsync(nodeCommand.MachineId.ToString(), nodeCommand.MessageId);
if (message != null) {
await _context.NodeMessageOperations.Delete(message);
}

return await RequestHandling.Ok(req, new BoolResult(true));
}
}
5 changes: 5 additions & 0 deletions src/ApiService/ApiService/OneFuzzTypes/Model.cs
Original file line number Diff line number Diff line change
Expand Up @@ -695,3 +695,8 @@ Uri HeartbeatQueue
public IContainerDef? RegressionReport { get; set; }

}

public record NodeCommandEnvelope(
NodeCommand Command,
string MessageId
);
9 changes: 9 additions & 0 deletions src/ApiService/ApiService/OneFuzzTypes/Requests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,12 @@ public record CanScheduleRequest(
Guid MachineId,
Guid TaskId
) : BaseRequest;

public record NodeCommandGet(
Guid MachineId
) : BaseRequest;

public record NodeCommandDelete(
Guid MachineId,
string MessageId
) : BaseRequest;
30 changes: 27 additions & 3 deletions src/ApiService/ApiService/OneFuzzTypes/Responses.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,32 @@
namespace Microsoft.OneFuzz.Service;
using System.Text.Json;
using System.Text.Json.Serialization;

public record BaseResponse();
namespace Microsoft.OneFuzz.Service;

[JsonConverter(typeof(BaseResponseConverter))]
public abstract record BaseResponse();

public record CanSchedule(
bool Allowed,
bool WorkStopped
) : BaseResponse;
) : BaseResponse();

public record PendingNodeCommand(
NodeCommandEnvelope? Envelope
) : BaseResponse();

public record BoolResult(
bool Result
) : BaseResponse();


public class BaseResponseConverter : JsonConverter<BaseResponse> {
public override BaseResponse? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) {
return null;
}

public override void Write(Utf8JsonWriter writer, BaseResponse value, JsonSerializerOptions options) {
var eventType = value.GetType();
JsonSerializer.Serialize(writer, value, eventType, options);
}
}
1 change: 1 addition & 0 deletions src/ApiService/ApiService/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ public static void Main() {
.AddScoped<IVmssOperations, VmssOperations>()
.AddScoped<INodeTasksOperations, NodeTasksOperations>()
.AddScoped<INodeMessageOperations, NodeMessageOperations>()
.AddScoped<IRequestHandling, RequestHandling>()
.AddScoped<IOnefuzzContext, OnefuzzContext>()

.AddSingleton<ICreds, Creds>()
Expand Down
8 changes: 4 additions & 4 deletions src/ApiService/ApiService/UserCredentials.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace Microsoft.OneFuzz.Service;
public interface IUserCredentials {
public string? GetBearerToken(HttpRequestData req);
public string? GetAuthToken(HttpRequestData req);
public Task<OneFuzzResult<UserInfo>> ParseJwtToken(LogTracer log, HttpRequestData req);
public Task<OneFuzzResult<UserInfo>> ParseJwtToken(HttpRequestData req);
}

public class UserCredentials : IUserCredentials {
Expand Down Expand Up @@ -58,7 +58,7 @@ from t in r.AllowedAadTenants
return OneFuzzResult<string[]>.Ok(allowedAddTenantsQuery.ToArray());
}

public async Task<OneFuzzResult<UserInfo>> ParseJwtToken(LogTracer log, HttpRequestData req) {
public async Task<OneFuzzResult<UserInfo>> ParseJwtToken(HttpRequestData req) {
var authToken = GetAuthToken(req);
if (authToken is null) {
return OneFuzzResult<UserInfo>.Error(ErrorCode.INVALID_REQUEST, new[] { "unable to find authorization token" });
Expand All @@ -84,11 +84,11 @@ from t in token.Claims

return OneFuzzResult<UserInfo>.Ok(new(applicationId, objectId, upn));
} else {
log.Error($"issuer not from allowed tenant: {token.Issuer} - {allowedTenants}");
_log.Error($"issuer not from allowed tenant: {token.Issuer} - {allowedTenants}");
return OneFuzzResult<UserInfo>.Error(ErrorCode.INVALID_REQUEST, new[] { "unauthorized AAD issuer" });
}
} else {
log.Error("Failed to get allowed tenants");
_log.Error("Failed to get allowed tenants");
return OneFuzzResult<UserInfo>.Error(allowedTenants.ErrorV);
}
}
Expand Down
19 changes: 18 additions & 1 deletion src/ApiService/ApiService/onefuzzlib/Creds.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Azure.Core;
using System.Text.Json;
using Azure.Core;
using Azure.Identity;
using Azure.ResourceManager;
using Azure.ResourceManager.Resources;
Expand All @@ -23,6 +24,7 @@ public interface ICreds {
public Async.Task<string> GetBaseRegion();

public Uri GetInstanceUrl();
Guid GetScalesetPrincipalId();
}

public class Creds : ICreds {
Expand Down Expand Up @@ -85,4 +87,19 @@ public async Async.Task<string> GetBaseRegion() {
public Uri GetInstanceUrl() {
return new Uri($"https://{GetInstanceName()}.azurewebsites.net");
}

public Guid GetScalesetPrincipalId() {
var uid = ArmClient.GetGenericResource(
new ResourceIdentifier(GetScalesetIdentityResourcePath())
);
var principalId = JsonSerializer.Deserialize<JsonDocument>(uid.Data.Properties.ToString())?.RootElement.GetProperty("principalId").GetString()!;
return new Guid(principalId);
}

public string GetScalesetIdentityResourcePath() {
var scalesetIdName = $"{GetInstanceName()}-scalesetid";
var resourceGroupPath = $"/subscriptions/{GetSubscription()}/resourceGroups/{GetBaseResourceGroup()}/providers";

return $"{resourceGroupPath}/Microsoft.ManagedIdentity/userAssignedIdentities/{scalesetIdName}";
}
}
91 changes: 91 additions & 0 deletions src/ApiService/ApiService/onefuzzlib/EndpointAuthorization.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
using System.Net;
using Microsoft.Azure.Functions.Worker.Http;

namespace Microsoft.OneFuzz.Service;

public class EndpointAuthorization {
private readonly IOnefuzzContext _context;
private readonly ILogTracer _log;

public EndpointAuthorization(IOnefuzzContext context, ILogTracer log) {
_context = context;
_log = log;
}
public async Async.Task<HttpResponseData> CallIfAgent(HttpRequestData req, Func<HttpRequestData, Async.Task<HttpResponseData>> method) {
return await CallIf(req, method, allowAgent: true);
}

public async Async.Task<HttpResponseData> CallIf(HttpRequestData req, Func<HttpRequestData, Async.Task<HttpResponseData>> method, bool allowUser = false, bool allowAgent = false) {
var tokenResult = await _context.UserCredentials.ParseJwtToken(req);

if (!tokenResult.IsOk) {
return await _context.RequestHandling.NotOk(req, tokenResult.ErrorV, "token verification", HttpStatusCode.Unauthorized);
}
var token = tokenResult.OkV!;

if (await IsUser(token)) {
if (!allowUser) {
return await Reject(req, token);
}

var access = CheckAccess(req);
if (!access.IsOk) {
return await _context.RequestHandling.NotOk(req, access.ErrorV, "access control", HttpStatusCode.Unauthorized);
}
}


if (await IsAgent(token) && !allowAgent) {
return await Reject(req, token);
}

return await method(req);
}

public async Async.Task<bool> IsUser(UserInfo tokenData) {
return !await IsAgent(tokenData);
}

public async Async.Task<HttpResponseData> Reject(HttpRequestData req, UserInfo token) {
_log.Error(
$"reject token. url:{req.Url} token:{token} body:{await req.ReadAsStringAsync()}"
);

return await _context.RequestHandling.NotOk(
req,
new Error(
ErrorCode.UNAUTHORIZED,
new string[] { "Unrecognized agent" }
),
"token verification",
HttpStatusCode.Unauthorized
);
}

public OneFuzzResultVoid CheckAccess(HttpRequestData req) {
throw new NotImplementedException();
}

public async Async.Task<bool> IsAgent(UserInfo tokenData) {
if (tokenData.ObjectId != null) {
var scalesets = _context.ScalesetOperations.GetByObjectId(tokenData.ObjectId.Value);
if (await scalesets.AnyAsync()) {
return true;
}

var principalId = _context.Creds.GetScalesetPrincipalId();
return principalId == tokenData.ObjectId;
}

if (!tokenData.ApplicationId.HasValue) {
return false;
}

var pools = _context.PoolOperations.GetByClientId(tokenData.ApplicationId.Value);
if (await pools.AnyAsync()) {
return true;
}

return false;
}
}
2 changes: 1 addition & 1 deletion src/ApiService/ApiService/onefuzzlib/NodeOperations.cs
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ public NodeMessageOperations(ILogTracer log, IOnefuzzContext context) : base(log
}

public IAsyncEnumerable<NodeMessage> GetMessage(Guid machineId) {
return QueryAsync($"machine_id eq '{machineId}'");
return QueryAsync($"PartitionKey eq '{machineId}'");
}

public async Async.Task ClearMessages(Guid machineId) {
Expand Down
4 changes: 4 additions & 0 deletions src/ApiService/ApiService/onefuzzlib/OnefuzzContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ public interface IOnefuzzContext {
IVmssOperations VmssOperations { get; }
IWebhookMessageLogOperations WebhookMessageLogOperations { get; }
IWebhookOperations WebhookOperations { get; }

IRequestHandling RequestHandling { get; }
}

public class OnefuzzContext : IOnefuzzContext {
Expand Down Expand Up @@ -71,6 +73,8 @@ public class OnefuzzContext : IOnefuzzContext {
public ICreds Creds { get => _serviceProvider.GetService<ICreds>() ?? throw new Exception("No ICreds service"); }
public IServiceConfig ServiceConfiguration { get => _serviceProvider.GetService<IServiceConfig>() ?? throw new Exception("No IServiceConfiguration service"); }

public IRequestHandling RequestHandling { get => _serviceProvider.GetService<IRequestHandling>() ?? throw new Exception("No IRequestHandling service"); }

public OnefuzzContext(IServiceProvider serviceProvider) {
_serviceProvider = serviceProvider;
}
Expand Down
Loading