diff --git a/src/ApiService/ApiService/Functions/Pool.cs b/src/ApiService/ApiService/Functions/Pool.cs index d9fefc49fe..6ca19ec640 100644 --- a/src/ApiService/ApiService/Functions/Pool.cs +++ b/src/ApiService/ApiService/Functions/Pool.cs @@ -133,7 +133,8 @@ private async Task Populate(PoolGetResult p, bool skipSummaries = HeartbeatQueue: queueSas, InstanceId: instanceId, ClientCredentials: null, - MultiTenantDomain: _context.ServiceConfiguration.MultiTenantDomain) + MultiTenantDomain: _context.ServiceConfiguration.MultiTenantDomain, + Managed: p.Managed) }; } } diff --git a/src/ApiService/ApiService/OneFuzzTypes/Model.cs b/src/ApiService/ApiService/OneFuzzTypes/Model.cs index c7620ea38c..09a1bc4e90 100644 --- a/src/ApiService/ApiService/OneFuzzTypes/Model.cs +++ b/src/ApiService/ApiService/OneFuzzTypes/Model.cs @@ -678,12 +678,11 @@ public record AgentConfig( string? InstanceTelemetryKey, string? MicrosoftTelemetryKey, string? MultiTenantDomain, - Guid InstanceId + Guid InstanceId, + bool? Managed = true ); - - public record Vm( string Name, Region Region, diff --git a/src/ApiService/ApiService/UserCredentials.cs b/src/ApiService/ApiService/UserCredentials.cs index 7446b0fd9b..fb9bf2d7f7 100644 --- a/src/ApiService/ApiService/UserCredentials.cs +++ b/src/ApiService/ApiService/UserCredentials.cs @@ -77,7 +77,7 @@ public virtual async Task> ParseJwtToken(HttpRequest switch (claim.Type) { case "oid": return acc with { UserInfo = acc.UserInfo with { ObjectId = Guid.Parse(claim.Value) } }; - case "appId": + case "appid": return acc with { UserInfo = acc.UserInfo with { ApplicationId = Guid.Parse(claim.Value) } }; case "upn": return acc with { UserInfo = acc.UserInfo with { Upn = claim.Value } }; @@ -88,7 +88,6 @@ public virtual async Task> ParseJwtToken(HttpRequest return acc; } }); - return OneFuzzResult.Ok(userInfo); } else { var tenantsStr = allowedTenants.OkV is null ? "null" : String.Join(';', allowedTenants.OkV!); diff --git a/src/ApiService/ApiService/onefuzzlib/EndpointAuthorization.cs b/src/ApiService/ApiService/onefuzzlib/EndpointAuthorization.cs index 045d43ac9e..6eccef84a4 100644 --- a/src/ApiService/ApiService/onefuzzlib/EndpointAuthorization.cs +++ b/src/ApiService/ApiService/onefuzzlib/EndpointAuthorization.cs @@ -30,8 +30,7 @@ public class EndpointAuthorization : IEndpointAuthorization { private readonly IOnefuzzContext _context; private readonly ILogTracer _log; private readonly GraphServiceClient _graphClient; - - private static readonly HashSet AgentRoles = new HashSet { "UnmamagedNode", "ManagedNode" }; + private static readonly HashSet AgentRoles = new HashSet { "UnmanagedNode", "ManagedNode" }; public EndpointAuthorization(IOnefuzzContext context, ILogTracer log, GraphServiceClient graphClient) { _context = context; @@ -46,10 +45,10 @@ public virtual async Async.Task CallIf(HttpRequestData req, Fu return await _context.RequestHandling.NotOk(req, tokenResult.ErrorV, "token verification", HttpStatusCode.Unauthorized); } - var token = tokenResult.OkV; - if (await IsUser(token)) { + var token = tokenResult.OkV.UserInfo; + if (await IsUser(tokenResult.OkV)) { if (!allowUser) { - return await Reject(req, tokenResult.OkV.UserInfo); + return await Reject(req, token); } var access = await CheckAccess(req); @@ -58,8 +57,8 @@ public virtual async Async.Task CallIf(HttpRequestData req, Fu } } - if (await IsAgent(token) && !allowAgent) { - return await Reject(req, tokenResult.OkV.UserInfo); + if (await IsAgent(tokenResult.OkV) && !allowAgent) { + return await Reject(req, token); } return await method(req); @@ -201,7 +200,9 @@ public async Async.Task IsAgent(UserAuthInfo authInfo) { } var principalId = await _context.Creds.GetScalesetPrincipalId(); - return principalId == tokenData.ObjectId; + if (principalId == tokenData.ObjectId) { + return true; + } } if (!tokenData.ApplicationId.HasValue) { diff --git a/src/ApiService/ApiService/onefuzzlib/Extension.cs b/src/ApiService/ApiService/onefuzzlib/Extension.cs index d9237104bd..a9bc508852 100644 --- a/src/ApiService/ApiService/onefuzzlib/Extension.cs +++ b/src/ApiService/ApiService/onefuzzlib/Extension.cs @@ -220,7 +220,8 @@ public static VMExtensionWrapper GenevaExtension(AzureLocation region) { InstanceTelemetryKey: _context.ServiceConfiguration.ApplicationInsightsInstrumentationKey, MicrosoftTelemetryKey: _context.ServiceConfiguration.OneFuzzTelemetry, MultiTenantDomain: _context.ServiceConfiguration.MultiTenantDomain, - InstanceId: instanceId + InstanceId: instanceId, + Managed: pool.Managed ); var fileName = $"{pool.Name}/config.json"; diff --git a/src/ApiService/ApiService/onefuzzlib/PoolOperations.cs b/src/ApiService/ApiService/onefuzzlib/PoolOperations.cs index 27867ab470..09c0ec342c 100644 --- a/src/ApiService/ApiService/onefuzzlib/PoolOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/PoolOperations.cs @@ -1,6 +1,5 @@ using System.Threading.Tasks; using ApiService.OneFuzzLib.Orm; -using Azure.Data.Tables; namespace Microsoft.OneFuzz.Service; public interface IPoolOperations : IStatefulOrm { @@ -89,7 +88,7 @@ public async Task ScheduleWorkset(Pool pool, WorkSet workSet) { } public IAsyncEnumerable GetByClientId(Guid clientId) { - return QueryAsync(filter: TableClient.CreateQueryFilter($"client_id eq {clientId}")); + return QueryAsync(filter: $"client_id eq '{clientId}'"); } public string GetPoolQueue(Guid poolId) diff --git a/src/agent/onefuzz-agent/src/config.rs b/src/agent/onefuzz-agent/src/config.rs index 7d3060e2e5..e0fafe4841 100644 --- a/src/agent/onefuzz-agent/src/config.rs +++ b/src/agent/onefuzz-agent/src/config.rs @@ -34,12 +34,19 @@ pub struct StaticConfig { pub heartbeat_queue: Option, pub instance_id: Uuid, + + #[serde(default = "default_as_true")] + pub managed: bool, +} + +fn default_as_true() -> bool { + true } // Temporary shim type to bridge the current service-provided config. #[derive(Clone, Debug, Deserialize, Eq, PartialEq)] struct RawStaticConfig { - pub credentials: Option, + pub client_credentials: Option, pub pool_name: String, @@ -54,13 +61,16 @@ struct RawStaticConfig { pub heartbeat_queue: Option, pub instance_id: Uuid, + + #[serde(default = "default_as_true")] + pub managed: bool, } impl StaticConfig { pub fn new(data: &[u8]) -> Result { let config: RawStaticConfig = serde_json::from_slice(data)?; - let credentials = match config.credentials { + let credentials = match config.client_credentials { Some(client) => client.into(), None => { // Remove trailing `/`, which is treated as a distinct resource. @@ -83,6 +93,7 @@ impl StaticConfig { instance_telemetry_key: config.instance_telemetry_key, heartbeat_queue: config.heartbeat_queue, instance_id: config.instance_id, + managed: config.managed, }; Ok(config) @@ -103,6 +114,7 @@ impl StaticConfig { let multi_tenant_domain = std::env::var("ONEFUZZ_MULTI_TENANT_DOMAIN").ok(); let onefuzz_url = Url::parse(&std::env::var("ONEFUZZ_URL")?)?; let pool_name = std::env::var("ONEFUZZ_POOL")?; + let is_unmanaged = std::env::var("ONEFUZZ_IS_UNMANAGED").is_ok(); let heartbeat_queue = if let Ok(key) = std::env::var("ONEFUZZ_HEARTBEAT") { Some(Url::parse(&key)?) @@ -142,6 +154,7 @@ impl StaticConfig { microsoft_telemetry_key, heartbeat_queue, instance_id, + managed: !is_unmanaged, }) } @@ -213,7 +226,8 @@ impl Registration { .append_pair("machine_id", &machine_id.to_string()) .append_pair("machine_name", &machine_name) .append_pair("pool_name", &config.pool_name) - .append_pair("version", env!("ONEFUZZ_VERSION")); + .append_pair("version", env!("ONEFUZZ_VERSION")) + .append_pair("os", std::env::consts::OS); if managed { let scaleset = onefuzz::machine_id::get_scaleset_name().await?; diff --git a/src/agent/onefuzz-agent/src/main.rs b/src/agent/onefuzz-agent/src/main.rs index 56ff42c7b1..642e453c98 100644 --- a/src/agent/onefuzz-agent/src/main.rs +++ b/src/agent/onefuzz-agent/src/main.rs @@ -277,7 +277,7 @@ async fn run_agent(config: StaticConfig) -> Result<()> { let registration = match config::Registration::load_existing(config.clone()).await { Ok(registration) => registration, Err(_) => { - if scaleset.is_some() { + if config.managed { config::Registration::create_managed(config.clone()).await? } else { config::Registration::create_unmanaged(config.clone()).await? diff --git a/src/agent/onefuzz/src/auth.rs b/src/agent/onefuzz/src/auth.rs index 19a9b08d66..d25a3807f9 100644 --- a/src/agent/onefuzz/src/auth.rs +++ b/src/agent/onefuzz/src/auth.rs @@ -134,7 +134,6 @@ impl ClientCredentials { let response = reqwest::Client::new() .post(url) - .header("Content-Length", "0") .form(&[ ("client_id", self.client_id.to_hyphenated().to_string()), ("client_secret", self.client_secret.expose_ref().to_string()), diff --git a/src/cli/onefuzz/api.py b/src/cli/onefuzz/api.py index b091ded41e..c77e4e3dd6 100644 --- a/src/cli/onefuzz/api.py +++ b/src/cli/onefuzz/api.py @@ -1268,13 +1268,7 @@ def get_config(self, pool_name: primitives.PoolName) -> models.AgentConfig: if pool.config is None: raise Exception("Missing AgentConfig in response") - config = pool.config - config.client_credentials = models.ClientCredentials( # nosec - bandit consider this a hard coded password - client_id=pool.client_id, - client_secret="", - ) - - return config + return pool.config def shutdown(self, name: str, *, now: bool = False) -> responses.BoolResult: expanded_name = self._disambiguate( diff --git a/src/pytypes/onefuzztypes/models.py b/src/pytypes/onefuzztypes/models.py index 0913fcbf61..858dd18c3d 100644 --- a/src/pytypes/onefuzztypes/models.py +++ b/src/pytypes/onefuzztypes/models.py @@ -339,6 +339,7 @@ class AgentConfig(BaseModel): microsoft_telemetry_key: Optional[str] multi_tenant_domain: Optional[str] instance_id: UUID + managed: Optional[bool] = Field(default=True) class TaskUnitConfig(BaseModel):