Skip to content
This repository has been archived by the owner on Nov 1, 2023. It is now read-only.

porting the proxy state machine #2286

Merged
merged 9 commits into from
Aug 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 2 additions & 43 deletions src/ApiService/ApiService/Functions/Scaleset.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
using System.Buffers.Binary;
using System.Diagnostics;
using System.Security.Cryptography;
using System.Threading.Tasks;
using System.Threading.Tasks;
using Microsoft.Azure.Functions.Worker;
using Microsoft.Azure.Functions.Worker.Http;

Expand Down Expand Up @@ -117,7 +114,7 @@ private async Task<HttpResponseData> Post(HttpRequestData req) {
ScalesetId: Guid.NewGuid(),
State: ScalesetState.Init,
NeedsConfigUpdate: false,
Auth: GenerateAuthentication(),
Auth: Auth.BuildAuth(),
PoolName: create.PoolName,
VmSku: create.VmSku,
Image: create.Image,
Expand Down Expand Up @@ -155,44 +152,6 @@ private async Task<HttpResponseData> Post(HttpRequestData req) {
return await RequestHandling.Ok(req, ScalesetResponse.ForScaleset(scaleset));
}

private static Authentication GenerateAuthentication() {
using var rsa = RSA.Create(2048);
var privateKey = rsa.ExportRSAPrivateKey();
var formattedPrivateKey = $"-----BEGIN RSA PRIVATE KEY-----\n{Convert.ToBase64String(privateKey)}\n-----END RSA PRIVATE KEY-----\n";

var publicKey = BuildPublicKey(rsa);
var formattedPublicKey = $"ssh-rsa {Convert.ToBase64String(publicKey)} onefuzz-generated-key";

return new Authentication(
Password: Guid.NewGuid().ToString(),
PublicKey: formattedPublicKey,
PrivateKey: formattedPrivateKey);
}

private static ReadOnlySpan<byte> SSHRSABytes => new byte[] { (byte)'s', (byte)'s', (byte)'h', (byte)'-', (byte)'r', (byte)'s', (byte)'a' };

private static byte[] BuildPublicKey(RSA rsa) {
static Span<byte> WriteLengthPrefixedBytes(ReadOnlySpan<byte> src, Span<byte> dest) {
BinaryPrimitives.WriteInt32BigEndian(dest, src.Length);
dest = dest[sizeof(int)..];
src.CopyTo(dest);
return dest[src.Length..];
}

var parameters = rsa.ExportParameters(includePrivateParameters: false);

// public key format is "ssh-rsa", exponent, modulus, all written
// as (big-endian) length-prefixed bytes
var result = new byte[sizeof(int) + SSHRSABytes.Length + sizeof(int) + parameters.Modulus!.Length + sizeof(int) + parameters.Exponent!.Length];
var spanResult = result.AsSpan();
spanResult = WriteLengthPrefixedBytes(SSHRSABytes, spanResult);
spanResult = WriteLengthPrefixedBytes(parameters.Exponent, spanResult);
spanResult = WriteLengthPrefixedBytes(parameters.Modulus, spanResult);
Debug.Assert(spanResult.Length == 0);

return result;
}

private async Task<HttpResponseData> Patch(HttpRequestData req) {
var request = await RequestHandling.ParseRequest<ScalesetUpdate>(req);
if (!request.IsOk) {
Expand Down
3 changes: 2 additions & 1 deletion src/ApiService/ApiService/Functions/TimerProxy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ public async Async.Task Run([TimerTrigger("00:00:30")] TimerInfo myTimer) {
// As this function is called via a timer, this works around a user
// requesting to use the proxy while this function is checking if it's
// out of date
if (proxy.Outdated) {
if (proxy.Outdated && !(await _context.ProxyOperations.IsUsed(proxy))) {
_logger.Warning($"scaleset-proxy: outdated and not used: {proxy.Region}");
await proxyOperations.SetState(proxy, VmState.Stopping);
// If something is "wrong" with a proxy, delete & recreate it
} else if (!proxyOperations.IsAlive(proxy)) {
Expand Down
2 changes: 1 addition & 1 deletion src/ApiService/ApiService/OneFuzzTypes/Model.cs
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ public record InstanceConfig
[DefaultValue(InitMethod.DefaultConstructor)] NetworkConfig NetworkConfig,
[DefaultValue(InitMethod.DefaultConstructor)] NetworkSecurityGroupConfig ProxyNsgConfig,
AzureVmExtensionConfig? Extensions,
string? ProxyVmSku,
string ProxyVmSku,
bool AllowPoolManagement = true,
IDictionary<Endpoint, ApiAccessRule>? ApiAccessRules = null,
IDictionary<PrincipalId, GroupId[]>? GroupMembership = null,
Expand Down
43 changes: 37 additions & 6 deletions src/ApiService/ApiService/onefuzzlib/Auth.cs
Original file line number Diff line number Diff line change
@@ -1,14 +1,45 @@
namespace Microsoft.OneFuzz.Service;

using System.Buffers.Binary;
using System.Diagnostics;
using System.Security.Cryptography;

public class Auth {

private static ReadOnlySpan<byte> SSHRSABytes => new byte[] { (byte)'s', (byte)'s', (byte)'h', (byte)'-', (byte)'r', (byte)'s', (byte)'a' };

private static byte[] BuildPublicKey(RSA rsa) {
static Span<byte> WriteLengthPrefixedBytes(ReadOnlySpan<byte> src, Span<byte> dest) {
BinaryPrimitives.WriteInt32BigEndian(dest, src.Length);
dest = dest[sizeof(int)..];
src.CopyTo(dest);
return dest[src.Length..];
}

var parameters = rsa.ExportParameters(includePrivateParameters: false);

// public key format is "ssh-rsa", exponent, modulus, all written
// as (big-endian) length-prefixed bytes
var result = new byte[sizeof(int) + SSHRSABytes.Length + sizeof(int) + parameters.Modulus!.Length + sizeof(int) + parameters.Exponent!.Length];
var spanResult = result.AsSpan();
spanResult = WriteLengthPrefixedBytes(SSHRSABytes, spanResult);
spanResult = WriteLengthPrefixedBytes(parameters.Exponent, spanResult);
spanResult = WriteLengthPrefixedBytes(parameters.Modulus, spanResult);
Debug.Assert(spanResult.Length == 0);

return result;
}
public static Authentication BuildAuth() {
var rsa = RSA.Create(2048);
string header = "-----BEGIN RSA PRIVATE KEY-----";
string footer = "-----END RSA PRIVATE KEY-----";
var privateKey = $"{header}\n{Convert.ToBase64String(rsa.ExportRSAPrivateKey())}\n{footer}";
var publiceKey = $"{header}\n{Convert.ToBase64String(rsa.ExportRSAPublicKey())}\n{footer}";
return new Authentication(Guid.NewGuid().ToString(), publiceKey, privateKey);
using var rsa = RSA.Create(2048);
var privateKey = rsa.ExportRSAPrivateKey();
var formattedPrivateKey = $"-----BEGIN RSA PRIVATE KEY-----\n{Convert.ToBase64String(privateKey)}\n-----END RSA PRIVATE KEY-----\n";

var publicKey = BuildPublicKey(rsa);
var formattedPublicKey = $"ssh-rsa {Convert.ToBase64String(publicKey)} onefuzz-generated-key";

return new Authentication(
Password: Guid.NewGuid().ToString(),
PublicKey: formattedPublicKey,
PrivateKey: formattedPrivateKey);
}
}
16 changes: 16 additions & 0 deletions src/ApiService/ApiService/onefuzzlib/Extension.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ public interface IExtensions {
Async.Task<IList<VirtualMachineScaleSetExtensionData>> FuzzExtensions(Pool pool, Scaleset scaleset);

Async.Task<Dictionary<string, VirtualMachineExtensionData>> ReproExtensions(AzureLocation region, Os reproOs, Guid reproId, ReproConfig reproConfig, Container? setupContainer);
Task<IList<VMExtensionWrapper>> ProxyManagerExtensions(string region, Guid proxyId);
}

public class Extensions : IExtensions {
Expand Down Expand Up @@ -449,4 +450,19 @@ await _context.Containers.GetFileSasUrl(
return extensionsDict;
}

public async Task<IList<VMExtensionWrapper>> ProxyManagerExtensions(string region, Guid proxyId) {
var config = await _context.Containers.GetFileSasUrl(new Container("proxy-configs"),
$"{region}/{proxyId}/config.json", StorageType.Config, BlobSasPermissions.Read);

var proxyManager = await _context.Containers.GetFileSasUrl(new Container("tools"),
$"linux/onefuzz-proxy-manager", StorageType.Config, BlobSasPermissions.Read);


var baseExtension =
await AgentConfig(region, Os.Linux, AgentMode.Proxy, new List<Uri> { config, proxyManager }, true);

var extensions = await GenericExtensions(region, Os.Linux);
extensions.Add(baseExtension);
return extensions;
}
}
13 changes: 11 additions & 2 deletions src/ApiService/ApiService/onefuzzlib/IpOperations.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,21 @@ public IpOperations(ILogTracer log, IOnefuzzContext context) {

public async System.Threading.Tasks.Task DeleteNic(string resourceGroup, string name) {
_logTracer.Info($"deleting nic {resourceGroup}:{name}");
await _context.Creds.GetResourceGroupResource().GetNetworkInterfaceAsync(name).Result.Value.DeleteAsync(WaitUntil.Started);
var networkInterface = await _context.Creds.GetResourceGroupResource().GetNetworkInterfaceAsync(name);
try {
await networkInterface.Value.DeleteAsync(WaitUntil.Started);
} catch (RequestFailedException ex) {
if (ex.ErrorCode != "NicReservedForAnotherVm") {
throw;
}
_logTracer.Warning($"unable to delete nic {resourceGroup}:{name} {ex.Message}");
}
}

public async System.Threading.Tasks.Task DeleteIp(string resourceGroup, string name) {
_logTracer.Info($"deleting ip {resourceGroup}:{name}");
await _context.Creds.GetResourceGroupResource().GetPublicIPAddressAsync(name).Result.Value.DeleteAsync(WaitUntil.Started);
var publicIpAddressAsync = await _context.Creds.GetResourceGroupResource().GetPublicIPAddressAsync(name);
await publicIpAddressAsync.Value.DeleteAsync(WaitUntil.Started);
}

public async Task<string?> GetScalesetInstanceIp(Guid scalesetId, Guid machineId) {
Expand Down
157 changes: 152 additions & 5 deletions src/ApiService/ApiService/onefuzzlib/ProxyOperations.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using System.Threading.Tasks;
using ApiService.OneFuzzLib.Orm;
using Azure.ResourceManager.Compute;
using Azure.ResourceManager.Compute.Models;
using Azure.Storage.Sas;
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;

Expand All @@ -8,12 +10,13 @@ namespace Microsoft.OneFuzz.Service;
public interface IProxyOperations : IStatefulOrm<Proxy, VmState> {
Task<Proxy?> GetByProxyId(Guid proxyId);

Async.Task SetState(Proxy proxy, VmState state);
Async.Task<Proxy> SetState(Proxy proxy, VmState state);
bool IsAlive(Proxy proxy);
Async.Task SaveProxyConfig(Proxy proxy);
bool IsOutdated(Proxy proxy);
Async.Task<Proxy?> GetOrCreate(string region);

Task<bool> IsUsed(Proxy proxy);
}
public class ProxyOperations : StatefulOrm<Proxy, VmState, ProxyOperations>, IProxyOperations {

Expand Down Expand Up @@ -56,6 +59,15 @@ public ProxyOperations(ILogTracer log, IOnefuzzContext context)
return newProxy;
}

public async Task<bool> IsUsed(Proxy proxy) {
var forwards = await GetForwards(proxy);
if (forwards.Count == 0) {
_logTracer.Info($"no forwards {proxy.Region}");
return false;
}
return true;
}

public bool IsAlive(Proxy proxy) {
var tenMinutesAgo = DateTimeOffset.UtcNow - TimeSpan.FromMinutes(10);

Expand Down Expand Up @@ -110,14 +122,15 @@ public async Async.Task SaveProxyConfig(Proxy proxy) {
}


public async Async.Task SetState(Proxy proxy, VmState state) {
public async Async.Task<Proxy> SetState(Proxy proxy, VmState state) {
if (proxy.State == state) {
return;
return proxy;
}

await Replace(proxy with { State = state });

var newProxy = proxy with { State = state };
await Replace(newProxy);
await _context.Events.SendEvent(new EventProxyStateUpdated(proxy.Region, proxy.ProxyId, proxy.State));
return newProxy;
}


Expand All @@ -133,4 +146,138 @@ public async Async.Task<List<Forward>> GetForwards(Proxy proxy) {
}
return forwards;
}

public async Async.Task<Proxy> Init(Proxy proxy) {
var config = await _context.ConfigOperations.Fetch();
var vm = GetVm(proxy, config);
var vmData = await _context.VmOperations.GetVm(vm.Name);

if (vmData != null) {
if (vmData.ProvisioningState == "Failed") {
return await SetProvisionFailed(proxy, vmData);
} else {
await SaveProxyConfig(proxy);
return await SetState(proxy, VmState.ExtensionsLaunch);
}
} else {
var nsg = new Nsg(proxy.Region, proxy.Region);
var result = await _context.NsgOperations.Create(nsg);
if (!result.IsOk) {
return await SetFailed(proxy, result.ErrorV);
}

var nsgConfig = config.ProxyNsgConfig;
var result2 = await _context.NsgOperations.SetAllowedSources(nsg, nsgConfig);

if (!result2.IsOk) {
return await SetFailed(proxy, result2.ErrorV);
}

var result3 = await _context.VmOperations.Create(vm with { Nsg = nsg });

if (!result3.IsOk) {
return await SetFailed(proxy, result3.ErrorV);
}
return proxy;
}
}

private async System.Threading.Tasks.Task<Proxy> SetProvisionFailed(Proxy proxy, VirtualMachineData vmData) {
var errors = GetErrors(proxy, vmData).ToArray();
await SetFailed(proxy, new Error(ErrorCode.PROXY_FAILED, errors));
return proxy;
}

private async Task<Proxy> SetFailed(Proxy proxy, Error error) {
if (proxy.Error != null) {
return proxy;
}

_logTracer.Error($"vm failed: {proxy.Region} -{error}");
await _context.Events.SendEvent(new EventProxyFailed(proxy.Region, proxy.ProxyId, error));
return await SetState(proxy with { Error = error }, VmState.Stopping);
}


private static IEnumerable<string> GetErrors(Proxy proxy, VirtualMachineData vmData) {
var instanceView = vmData.InstanceView;
yield return "provisioning failed";
chkeita marked this conversation as resolved.
Show resolved Hide resolved
if (instanceView is null) {
yield break;
}

foreach (var status in instanceView.Statuses) {
if (status.Level == StatusLevelTypes.Error) {
yield return $"code:{status.Code} status:{status.DisplayStatus} message:{status.Message}";
}
}
}

public static Vm GetVm(Proxy proxy, InstanceConfig config) {
var tags = config.VmssTags;
const string PROXY_IMAGE = "Canonical:UbuntuServer:18.04-LTS:latest";
return new Vm(
// name should be less than 40 chars otherwise it gets truncated by azure
Name: $"proxy-{proxy.ProxyId:N}",
Region: proxy.Region,
Sku: config.ProxyVmSku,
Image: PROXY_IMAGE,
Auth: proxy.Auth,
Tags: config.VmssTags,
Nsg: null
);
}

public async Task<Proxy> ExtensionsLaunch(Proxy proxy) {
var config = await _context.ConfigOperations.Fetch();
var vm = GetVm(proxy, config);
var vmData = await _context.VmOperations.GetVm(vm.Name);

if (vmData == null) {
return await SetFailed(proxy, new Error(ErrorCode.PROXY_FAILED, new[] { "azure not able to find vm" }));
}

if (vmData.ProvisioningState == "Failed") {
return await SetProvisionFailed(proxy, vmData);
}

var ip = await _context.IpOperations.GetPublicIp(vmData.NetworkProfile.NetworkInterfaces[0].Id);
if (ip == null) {
return proxy;
}

var newProxy = proxy with { Ip = ip };

var extensions = await _context.Extensions.ProxyManagerExtensions(newProxy.Region, newProxy.ProxyId);
var result = await _context.VmOperations.AddExtensions(vm,
extensions
.Select(e => e.GetAsVirtualMachineExtension())
.ToDictionary(x => x.Item1, x => x.Item2));

if (!result.IsOk) {
return await SetFailed(newProxy, result.ErrorV);
}

return await SetState(newProxy, VmState.Running);
}

public async Task<Proxy> Stopping(Proxy proxy) {
var config = await _context.ConfigOperations.Fetch();
var vm = GetVm(proxy, config);
if (!await _context.VmOperations.IsDeleted(vm)) {
_logTracer.Error($"stopping proxy: {proxy.Region}");
await _context.VmOperations.Delete(vm);
return proxy;
}

return await Stopped(proxy);
}

private async Task<Proxy> Stopped(Proxy proxy) {
var stoppedVm = await SetState(proxy, VmState.Stopped);
_logTracer.Info($"removing proxy: {proxy.Region}");
await _context.Events.SendEvent(new EventProxyDeleted(proxy.Region, proxy.ProxyId));
await Delete(proxy);
return stoppedVm;
}
}
Loading