Skip to content
This repository has been archived by the owner on Nov 1, 2023. It is now read-only.

Bug fixes related to the unmanaged nodes #2632

Merged
merged 6 commits into from
Nov 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/ApiService/ApiService/Functions/Pool.cs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ private async Task<PoolGetResult> Populate(PoolGetResult p, bool skipSummaries =
HeartbeatQueue: queueSas,
InstanceId: instanceId,
ClientCredentials: null,
MultiTenantDomain: _context.ServiceConfiguration.MultiTenantDomain)
MultiTenantDomain: _context.ServiceConfiguration.MultiTenantDomain,
Managed: p.Managed)
};
}
}
5 changes: 2 additions & 3 deletions src/ApiService/ApiService/OneFuzzTypes/Model.cs
Original file line number Diff line number Diff line change
Expand Up @@ -678,12 +678,11 @@ public record AgentConfig(
string? InstanceTelemetryKey,
string? MicrosoftTelemetryKey,
string? MultiTenantDomain,
Guid InstanceId
Guid InstanceId,
bool? Managed = true
);




public record Vm(
string Name,
Region Region,
Expand Down
3 changes: 1 addition & 2 deletions src/ApiService/ApiService/UserCredentials.cs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ public virtual async Task<OneFuzzResult<UserAuthInfo>> ParseJwtToken(HttpRequest
switch (claim.Type) {
case "oid":
return acc with { UserInfo = acc.UserInfo with { ObjectId = Guid.Parse(claim.Value) } };
case "appId":
case "appid":
return acc with { UserInfo = acc.UserInfo with { ApplicationId = Guid.Parse(claim.Value) } };
case "upn":
return acc with { UserInfo = acc.UserInfo with { Upn = claim.Value } };
Expand All @@ -88,7 +88,6 @@ public virtual async Task<OneFuzzResult<UserAuthInfo>> ParseJwtToken(HttpRequest
return acc;
}
});

return OneFuzzResult<UserAuthInfo>.Ok(userInfo);
} else {
var tenantsStr = allowedTenants.OkV is null ? "null" : String.Join(';', allowedTenants.OkV!);
Expand Down
17 changes: 9 additions & 8 deletions src/ApiService/ApiService/onefuzzlib/EndpointAuthorization.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ public class EndpointAuthorization : IEndpointAuthorization {
private readonly IOnefuzzContext _context;
private readonly ILogTracer _log;
private readonly GraphServiceClient _graphClient;

private static readonly HashSet<string> AgentRoles = new HashSet<string> { "UnmamagedNode", "ManagedNode" };
private static readonly HashSet<string> AgentRoles = new HashSet<string> { "UnmanagedNode", "ManagedNode" };

public EndpointAuthorization(IOnefuzzContext context, ILogTracer log, GraphServiceClient graphClient) {
_context = context;
Expand All @@ -46,10 +45,10 @@ public virtual async Async.Task<HttpResponseData> CallIf(HttpRequestData req, Fu
return await _context.RequestHandling.NotOk(req, tokenResult.ErrorV, "token verification", HttpStatusCode.Unauthorized);
}

var token = tokenResult.OkV;
if (await IsUser(token)) {
var token = tokenResult.OkV.UserInfo;
if (await IsUser(tokenResult.OkV)) {
if (!allowUser) {
return await Reject(req, tokenResult.OkV.UserInfo);
return await Reject(req, token);
}

var access = await CheckAccess(req);
Expand All @@ -58,8 +57,8 @@ public virtual async Async.Task<HttpResponseData> CallIf(HttpRequestData req, Fu
}
}

if (await IsAgent(token) && !allowAgent) {
return await Reject(req, tokenResult.OkV.UserInfo);
if (await IsAgent(tokenResult.OkV) && !allowAgent) {
return await Reject(req, token);
}

return await method(req);
Expand Down Expand Up @@ -201,7 +200,9 @@ public async Async.Task<bool> IsAgent(UserAuthInfo authInfo) {
}

var principalId = await _context.Creds.GetScalesetPrincipalId();
return principalId == tokenData.ObjectId;
if (principalId == tokenData.ObjectId) {
return true;
}
}

if (!tokenData.ApplicationId.HasValue) {
Expand Down
3 changes: 2 additions & 1 deletion src/ApiService/ApiService/onefuzzlib/Extension.cs
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,8 @@ public static VMExtensionWrapper GenevaExtension(AzureLocation region) {
InstanceTelemetryKey: _context.ServiceConfiguration.ApplicationInsightsInstrumentationKey,
MicrosoftTelemetryKey: _context.ServiceConfiguration.OneFuzzTelemetry,
MultiTenantDomain: _context.ServiceConfiguration.MultiTenantDomain,
InstanceId: instanceId
InstanceId: instanceId,
Managed: pool.Managed
);

var fileName = $"{pool.Name}/config.json";
Expand Down
3 changes: 1 addition & 2 deletions src/ApiService/ApiService/onefuzzlib/PoolOperations.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using System.Threading.Tasks;
using ApiService.OneFuzzLib.Orm;
using Azure.Data.Tables;
namespace Microsoft.OneFuzz.Service;

public interface IPoolOperations : IStatefulOrm<Pool, PoolState> {
Expand Down Expand Up @@ -89,7 +88,7 @@ public async Task<bool> ScheduleWorkset(Pool pool, WorkSet workSet) {
}

public IAsyncEnumerable<Pool> GetByClientId(Guid clientId) {
return QueryAsync(filter: TableClient.CreateQueryFilter($"client_id eq {clientId}"));
return QueryAsync(filter: $"client_id eq '{clientId}'");
}

public string GetPoolQueue(Guid poolId)
Expand Down
20 changes: 17 additions & 3 deletions src/agent/onefuzz-agent/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,19 @@ pub struct StaticConfig {
pub heartbeat_queue: Option<Url>,

pub instance_id: Uuid,

#[serde(default = "default_as_true")]
pub managed: bool,
}

fn default_as_true() -> bool {
true
}

// Temporary shim type to bridge the current service-provided config.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
struct RawStaticConfig {
pub credentials: Option<ClientCredentials>,
pub client_credentials: Option<ClientCredentials>,

pub pool_name: String,

Expand All @@ -54,13 +61,16 @@ struct RawStaticConfig {
pub heartbeat_queue: Option<Url>,

pub instance_id: Uuid,

#[serde(default = "default_as_true")]
pub managed: bool,
}

impl StaticConfig {
pub fn new(data: &[u8]) -> Result<Self> {
let config: RawStaticConfig = serde_json::from_slice(data)?;

let credentials = match config.credentials {
let credentials = match config.client_credentials {
Some(client) => client.into(),
None => {
// Remove trailing `/`, which is treated as a distinct resource.
Expand All @@ -83,6 +93,7 @@ impl StaticConfig {
instance_telemetry_key: config.instance_telemetry_key,
heartbeat_queue: config.heartbeat_queue,
instance_id: config.instance_id,
managed: config.managed,
};

Ok(config)
Expand All @@ -103,6 +114,7 @@ impl StaticConfig {
let multi_tenant_domain = std::env::var("ONEFUZZ_MULTI_TENANT_DOMAIN").ok();
let onefuzz_url = Url::parse(&std::env::var("ONEFUZZ_URL")?)?;
let pool_name = std::env::var("ONEFUZZ_POOL")?;
let is_unmanaged = std::env::var("ONEFUZZ_IS_UNMANAGED").is_ok();

let heartbeat_queue = if let Ok(key) = std::env::var("ONEFUZZ_HEARTBEAT") {
Some(Url::parse(&key)?)
Expand Down Expand Up @@ -142,6 +154,7 @@ impl StaticConfig {
microsoft_telemetry_key,
heartbeat_queue,
instance_id,
managed: !is_unmanaged,
})
}

Expand Down Expand Up @@ -213,7 +226,8 @@ impl Registration {
.append_pair("machine_id", &machine_id.to_string())
.append_pair("machine_name", &machine_name)
.append_pair("pool_name", &config.pool_name)
.append_pair("version", env!("ONEFUZZ_VERSION"));
.append_pair("version", env!("ONEFUZZ_VERSION"))
.append_pair("os", std::env::consts::OS);

if managed {
let scaleset = onefuzz::machine_id::get_scaleset_name().await?;
Expand Down
2 changes: 1 addition & 1 deletion src/agent/onefuzz-agent/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ async fn run_agent(config: StaticConfig) -> Result<()> {
let registration = match config::Registration::load_existing(config.clone()).await {
Ok(registration) => registration,
Err(_) => {
if scaleset.is_some() {
if config.managed {
config::Registration::create_managed(config.clone()).await?
} else {
config::Registration::create_unmanaged(config.clone()).await?
Expand Down
1 change: 0 additions & 1 deletion src/agent/onefuzz/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ impl ClientCredentials {

let response = reqwest::Client::new()
.post(url)
.header("Content-Length", "0")
.form(&[
("client_id", self.client_id.to_hyphenated().to_string()),
("client_secret", self.client_secret.expose_ref().to_string()),
Expand Down
8 changes: 1 addition & 7 deletions src/cli/onefuzz/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1268,13 +1268,7 @@ def get_config(self, pool_name: primitives.PoolName) -> models.AgentConfig:
if pool.config is None:
raise Exception("Missing AgentConfig in response")

config = pool.config
config.client_credentials = models.ClientCredentials( # nosec - bandit consider this a hard coded password
client_id=pool.client_id,
client_secret="<client secret>",
)

return config
return pool.config

def shutdown(self, name: str, *, now: bool = False) -> responses.BoolResult:
expanded_name = self._disambiguate(
Expand Down
1 change: 1 addition & 0 deletions src/pytypes/onefuzztypes/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ class AgentConfig(BaseModel):
microsoft_telemetry_key: Optional[str]
multi_tenant_domain: Optional[str]
instance_id: UUID
managed: Optional[bool] = Field(default=True)


class TaskUnitConfig(BaseModel):
Expand Down