Skip to content
This repository has been archived by the owner on Nov 1, 2023. It is now read-only.

Commit

Permalink
Refactor agent commands (#1922)
Browse files Browse the repository at this point in the history
* Checkpoint

* Disable the function for now

* snapshot

* Tested locally

* fmt
  • Loading branch information
tevoinea authored May 11, 2022
1 parent c9b46e9 commit 5f4a025
Show file tree
Hide file tree
Showing 14 changed files with 252 additions and 39 deletions.
36 changes: 11 additions & 25 deletions src/ApiService/ApiService/AgentCanSchedule.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,64 +6,50 @@ namespace Microsoft.OneFuzz.Service;
public class AgentCanSchedule {
private readonly ILogTracer _log;

private readonly IStorage _storage;
private readonly IOnefuzzContext _context;

private readonly INodeOperations _nodeOperations;

private readonly ITaskOperations _taskOperations;

private readonly IScalesetOperations _scalesetOperations;

public AgentCanSchedule(ILogTracer log, IStorage storage, INodeOperations nodeOperations, ITaskOperations taskOperations, IScalesetOperations scalesetOperations) {
public AgentCanSchedule(ILogTracer log, IOnefuzzContext context) {
_log = log;
_storage = storage;
_nodeOperations = nodeOperations;
_taskOperations = taskOperations;
_scalesetOperations = scalesetOperations;
_context = context;
}

// [Function("AgentCanSchedule")]
public async Async.Task<HttpResponseData> Run([HttpTrigger] HttpRequestData req) {
var request = await RequestHandling.ParseRequest<CanScheduleRequest>(req);
if (!request.IsOk || request.OkV == null) {
return await RequestHandling.NotOk(req, request.ErrorV, typeof(CanScheduleRequest).ToString(), _log);
return await _context.RequestHandling.NotOk(req, request.ErrorV, typeof(CanScheduleRequest).ToString());
}

var canScheduleRequest = request.OkV;

var node = await _nodeOperations.GetByMachineId(canScheduleRequest.MachineId);
var node = await _context.NodeOperations.GetByMachineId(canScheduleRequest.MachineId);
if (node == null) {
return await RequestHandling.NotOk(
return await _context.RequestHandling.NotOk(
req,
new Error(
ErrorCode.UNABLE_TO_FIND,
new string[] {
"unable to find node"
}
),
canScheduleRequest.MachineId.ToString(),
_log
canScheduleRequest.MachineId.ToString()
);
}

var allowed = true;
var workStopped = false;

if (!await _nodeOperations.CanProcessNewWork(node)) {
if (!await _context.NodeOperations.CanProcessNewWork(node)) {
allowed = false;
}

var task = await _taskOperations.GetByTaskId(canScheduleRequest.TaskId);
var task = await _context.TaskOperations.GetByTaskId(canScheduleRequest.TaskId);
workStopped = task == null || TaskStateHelper.ShuttingDown.Contains(task.State);

if (allowed) {
allowed = (await _nodeOperations.AcquireScaleInProtection(node)).IsOk;
allowed = (await _context.NodeOperations.AcquireScaleInProtection(node)).IsOk;
}

return await RequestHandling.Ok(
req,
new BaseResponse[] {
new CanSchedule(allowed, workStopped)
});
return await RequestHandling.Ok(req, new CanSchedule(allowed, workStopped));
}
}
57 changes: 57 additions & 0 deletions src/ApiService/ApiService/AgentCommands.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
using Microsoft.Azure.Functions.Worker;
using Microsoft.Azure.Functions.Worker.Http;

namespace Microsoft.OneFuzz.Service;

public class AgentCommands {
private readonly ILogTracer _log;

private readonly IOnefuzzContext _context;

public AgentCommands(ILogTracer log, IOnefuzzContext context) {
_log = log;
_context = context;
}

// [Function("AgentCommands")]
public async Async.Task<HttpResponseData> Run([HttpTrigger("get", "delete")] HttpRequestData req) {
return req.Method switch {
"GET" => await Get(req),
"DELETE" => await Delete(req),
_ => throw new NotImplementedException($"HTTP Method {req.Method} is not supported for this method")
};
}

private async Async.Task<HttpResponseData> Get(HttpRequestData req) {
var request = await RequestHandling.ParseRequest<NodeCommandGet>(req);
if (!request.IsOk || request.OkV == null) {
return await _context.RequestHandling.NotOk(req, request.ErrorV, typeof(NodeCommandGet).ToString());
}
var nodeCommand = request.OkV;

var message = await _context.NodeMessageOperations.GetMessage(nodeCommand.MachineId).FirstOrDefaultAsync();
if (message != null) {
var command = message.Message;
var messageId = message.MessageId;
var envelope = new NodeCommandEnvelope(command, messageId);
return await RequestHandling.Ok(req, new PendingNodeCommand(envelope));
} else {
return await RequestHandling.Ok(req, new PendingNodeCommand(null));
}
}

private async Async.Task<HttpResponseData> Delete(HttpRequestData req) {
var request = await RequestHandling.ParseRequest<NodeCommandDelete>(req);
if (!request.IsOk || request.OkV == null) {
return await _context.RequestHandling.NotOk(req, request.ErrorV, typeof(NodeCommandDelete).ToString());
}
var nodeCommand = request.OkV;

var message = await _context.NodeMessageOperations.GetEntityAsync(nodeCommand.MachineId.ToString(), nodeCommand.MessageId);
if (message != null) {
await _context.NodeMessageOperations.Delete(message);
}

return await RequestHandling.Ok(req, new BoolResult(true));
}
}
5 changes: 5 additions & 0 deletions src/ApiService/ApiService/OneFuzzTypes/Model.cs
Original file line number Diff line number Diff line change
Expand Up @@ -695,3 +695,8 @@ Uri HeartbeatQueue
public IContainerDef? RegressionReport { get; set; }

}

public record NodeCommandEnvelope(
NodeCommand Command,
string MessageId
);
9 changes: 9 additions & 0 deletions src/ApiService/ApiService/OneFuzzTypes/Requests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,12 @@ public record CanScheduleRequest(
Guid MachineId,
Guid TaskId
) : BaseRequest;

public record NodeCommandGet(
Guid MachineId
) : BaseRequest;

public record NodeCommandDelete(
Guid MachineId,
string MessageId
) : BaseRequest;
30 changes: 27 additions & 3 deletions src/ApiService/ApiService/OneFuzzTypes/Responses.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,32 @@
namespace Microsoft.OneFuzz.Service;
using System.Text.Json;
using System.Text.Json.Serialization;

public record BaseResponse();
namespace Microsoft.OneFuzz.Service;

[JsonConverter(typeof(BaseResponseConverter))]
public abstract record BaseResponse();

public record CanSchedule(
bool Allowed,
bool WorkStopped
) : BaseResponse;
) : BaseResponse();

public record PendingNodeCommand(
NodeCommandEnvelope? Envelope
) : BaseResponse();

public record BoolResult(
bool Result
) : BaseResponse();


public class BaseResponseConverter : JsonConverter<BaseResponse> {
public override BaseResponse? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) {
return null;
}

public override void Write(Utf8JsonWriter writer, BaseResponse value, JsonSerializerOptions options) {
var eventType = value.GetType();
JsonSerializer.Serialize(writer, value, eventType, options);
}
}
1 change: 1 addition & 0 deletions src/ApiService/ApiService/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ public static void Main() {
.AddScoped<IVmssOperations, VmssOperations>()
.AddScoped<INodeTasksOperations, NodeTasksOperations>()
.AddScoped<INodeMessageOperations, NodeMessageOperations>()
.AddScoped<IRequestHandling, RequestHandling>()
.AddScoped<IOnefuzzContext, OnefuzzContext>()

.AddSingleton<ICreds, Creds>()
Expand Down
8 changes: 4 additions & 4 deletions src/ApiService/ApiService/UserCredentials.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace Microsoft.OneFuzz.Service;
public interface IUserCredentials {
public string? GetBearerToken(HttpRequestData req);
public string? GetAuthToken(HttpRequestData req);
public Task<OneFuzzResult<UserInfo>> ParseJwtToken(LogTracer log, HttpRequestData req);
public Task<OneFuzzResult<UserInfo>> ParseJwtToken(HttpRequestData req);
}

public class UserCredentials : IUserCredentials {
Expand Down Expand Up @@ -58,7 +58,7 @@ from t in r.AllowedAadTenants
return OneFuzzResult<string[]>.Ok(allowedAddTenantsQuery.ToArray());
}

public async Task<OneFuzzResult<UserInfo>> ParseJwtToken(LogTracer log, HttpRequestData req) {
public async Task<OneFuzzResult<UserInfo>> ParseJwtToken(HttpRequestData req) {
var authToken = GetAuthToken(req);
if (authToken is null) {
return OneFuzzResult<UserInfo>.Error(ErrorCode.INVALID_REQUEST, new[] { "unable to find authorization token" });
Expand All @@ -84,11 +84,11 @@ from t in token.Claims

return OneFuzzResult<UserInfo>.Ok(new(applicationId, objectId, upn));
} else {
log.Error($"issuer not from allowed tenant: {token.Issuer} - {allowedTenants}");
_log.Error($"issuer not from allowed tenant: {token.Issuer} - {allowedTenants}");
return OneFuzzResult<UserInfo>.Error(ErrorCode.INVALID_REQUEST, new[] { "unauthorized AAD issuer" });
}
} else {
log.Error("Failed to get allowed tenants");
_log.Error("Failed to get allowed tenants");
return OneFuzzResult<UserInfo>.Error(allowedTenants.ErrorV);
}
}
Expand Down
19 changes: 18 additions & 1 deletion src/ApiService/ApiService/onefuzzlib/Creds.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Azure.Core;
using System.Text.Json;
using Azure.Core;
using Azure.Identity;
using Azure.ResourceManager;
using Azure.ResourceManager.Resources;
Expand All @@ -23,6 +24,7 @@ public interface ICreds {
public Async.Task<string> GetBaseRegion();

public Uri GetInstanceUrl();
Guid GetScalesetPrincipalId();
}

public class Creds : ICreds {
Expand Down Expand Up @@ -85,4 +87,19 @@ public async Async.Task<string> GetBaseRegion() {
public Uri GetInstanceUrl() {
return new Uri($"https://{GetInstanceName()}.azurewebsites.net");
}

public Guid GetScalesetPrincipalId() {
var uid = ArmClient.GetGenericResource(
new ResourceIdentifier(GetScalesetIdentityResourcePath())
);
var principalId = JsonSerializer.Deserialize<JsonDocument>(uid.Data.Properties.ToString())?.RootElement.GetProperty("principalId").GetString()!;
return new Guid(principalId);
}

public string GetScalesetIdentityResourcePath() {
var scalesetIdName = $"{GetInstanceName()}-scalesetid";
var resourceGroupPath = $"/subscriptions/{GetSubscription()}/resourceGroups/{GetBaseResourceGroup()}/providers";

return $"{resourceGroupPath}/Microsoft.ManagedIdentity/userAssignedIdentities/{scalesetIdName}";
}
}
91 changes: 91 additions & 0 deletions src/ApiService/ApiService/onefuzzlib/EndpointAuthorization.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
using System.Net;
using Microsoft.Azure.Functions.Worker.Http;

namespace Microsoft.OneFuzz.Service;

public class EndpointAuthorization {
private readonly IOnefuzzContext _context;
private readonly ILogTracer _log;

public EndpointAuthorization(IOnefuzzContext context, ILogTracer log) {
_context = context;
_log = log;
}
public async Async.Task<HttpResponseData> CallIfAgent(HttpRequestData req, Func<HttpRequestData, Async.Task<HttpResponseData>> method) {
return await CallIf(req, method, allowAgent: true);
}

public async Async.Task<HttpResponseData> CallIf(HttpRequestData req, Func<HttpRequestData, Async.Task<HttpResponseData>> method, bool allowUser = false, bool allowAgent = false) {
var tokenResult = await _context.UserCredentials.ParseJwtToken(req);

if (!tokenResult.IsOk) {
return await _context.RequestHandling.NotOk(req, tokenResult.ErrorV, "token verification", HttpStatusCode.Unauthorized);
}
var token = tokenResult.OkV!;

if (await IsUser(token)) {
if (!allowUser) {
return await Reject(req, token);
}

var access = CheckAccess(req);
if (!access.IsOk) {
return await _context.RequestHandling.NotOk(req, access.ErrorV, "access control", HttpStatusCode.Unauthorized);
}
}


if (await IsAgent(token) && !allowAgent) {
return await Reject(req, token);
}

return await method(req);
}

public async Async.Task<bool> IsUser(UserInfo tokenData) {
return !await IsAgent(tokenData);
}

public async Async.Task<HttpResponseData> Reject(HttpRequestData req, UserInfo token) {
_log.Error(
$"reject token. url:{req.Url} token:{token} body:{await req.ReadAsStringAsync()}"
);

return await _context.RequestHandling.NotOk(
req,
new Error(
ErrorCode.UNAUTHORIZED,
new string[] { "Unrecognized agent" }
),
"token verification",
HttpStatusCode.Unauthorized
);
}

public OneFuzzResultVoid CheckAccess(HttpRequestData req) {
throw new NotImplementedException();
}

public async Async.Task<bool> IsAgent(UserInfo tokenData) {
if (tokenData.ObjectId != null) {
var scalesets = _context.ScalesetOperations.GetByObjectId(tokenData.ObjectId.Value);
if (await scalesets.AnyAsync()) {
return true;
}

var principalId = _context.Creds.GetScalesetPrincipalId();
return principalId == tokenData.ObjectId;
}

if (!tokenData.ApplicationId.HasValue) {
return false;
}

var pools = _context.PoolOperations.GetByClientId(tokenData.ApplicationId.Value);
if (await pools.AnyAsync()) {
return true;
}

return false;
}
}
2 changes: 1 addition & 1 deletion src/ApiService/ApiService/onefuzzlib/NodeOperations.cs
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ public NodeMessageOperations(ILogTracer log, IOnefuzzContext context) : base(log
}

public IAsyncEnumerable<NodeMessage> GetMessage(Guid machineId) {
return QueryAsync($"machine_id eq '{machineId}'");
return QueryAsync($"PartitionKey eq '{machineId}'");
}

public async Async.Task ClearMessages(Guid machineId) {
Expand Down
4 changes: 4 additions & 0 deletions src/ApiService/ApiService/onefuzzlib/OnefuzzContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ public interface IOnefuzzContext {
IVmssOperations VmssOperations { get; }
IWebhookMessageLogOperations WebhookMessageLogOperations { get; }
IWebhookOperations WebhookOperations { get; }

IRequestHandling RequestHandling { get; }
}

public class OnefuzzContext : IOnefuzzContext {
Expand Down Expand Up @@ -71,6 +73,8 @@ public class OnefuzzContext : IOnefuzzContext {
public ICreds Creds { get => _serviceProvider.GetService<ICreds>() ?? throw new Exception("No ICreds service"); }
public IServiceConfig ServiceConfiguration { get => _serviceProvider.GetService<IServiceConfig>() ?? throw new Exception("No IServiceConfiguration service"); }

public IRequestHandling RequestHandling { get => _serviceProvider.GetService<IRequestHandling>() ?? throw new Exception("No IRequestHandling service"); }

public OnefuzzContext(IServiceProvider serviceProvider) {
_serviceProvider = serviceProvider;
}
Expand Down
Loading

0 comments on commit 5f4a025

Please sign in to comment.