diff --git a/crates/agent/src/api/authorize_task.rs b/crates/agent/src/api/authorize_task.rs index d93242e7b5..24e75f503b 100644 --- a/crates/agent/src/api/authorize_task.rs +++ b/crates/agent/src/api/authorize_task.rs @@ -15,28 +15,7 @@ pub async fn authorize_task( #[tracing::instrument(skip(app), err(level = tracing::Level::WARN))] async fn do_authorize_task(app: &App, Request { token }: &Request) -> anyhow::Result { - let jsonwebtoken::TokenData { header, mut claims }: jsonwebtoken::TokenData< - proto_gazette::Claims, - > = { - // In this pass we do not validate the signature, - // because we don't yet know which data-plane the JWT is signed by. - let empty_key = jsonwebtoken::DecodingKey::from_secret(&[]); - let mut validation = jsonwebtoken::Validation::default(); - validation.insecure_disable_signature_validation(); - jsonwebtoken::decode(token, &empty_key, &validation) - }?; - tracing::debug!(?claims, ?header, "decoded authorization request"); - - let shard_id = claims.sub.as_str(); - if shard_id.is_empty() { - anyhow::bail!("missing required shard ID (`sub` claim)"); - } - - let shard_data_plane_fqdn = claims.iss.as_str(); - if shard_data_plane_fqdn.is_empty() { - anyhow::bail!("missing required shard data-plane FQDN (`iss` claim)"); - } - + let (header, mut claims) = super::parse_untrusted_data_plane_claims(token)?; let journal_name_or_prefix = labels::expect_one(claims.sel.include(), "name")?.to_owned(); // Require the request was signed with the AUTHORIZE capability, @@ -62,8 +41,8 @@ async fn do_authorize_task(app: &App, Request { token }: &Request) -> anyhow::Re match Snapshot::evaluate(&app.snapshot, claims.iat, |snapshot: &Snapshot| { evaluate_authorization( snapshot, - shard_id, - shard_data_plane_fqdn, + &claims.sub, + &claims.iss, token, &journal_name_or_prefix, required_role, @@ -104,47 +83,13 @@ fn evaluate_authorization( journal_name_or_prefix: &str, required_role: models::Capability, ) -> anyhow::Result<(jsonwebtoken::EncodingKey, String, 
String)> { - // Map `claims.sub`, a Shard ID, into its task. - let task = tasks - .binary_search_by(|task| { - if shard_id.starts_with(&task.shard_template_id) { - std::cmp::Ordering::Equal - } else { - task.shard_template_id.as_str().cmp(shard_id) - } - }) - .ok() - .map(|index| &tasks[index]); - - // Map `claims.iss`, a data-plane FQDN, into its task-matched data-plane. - let task_data_plane = task.and_then(|task| { - data_planes - .get_by_key(&task.data_plane_id) - .filter(|data_plane| data_plane.data_plane_fqdn == shard_data_plane_fqdn) - }); - - let (Some(task), Some(task_data_plane)) = (task, task_data_plane) else { - anyhow::bail!( - "task shard {shard_id} within data-plane {shard_data_plane_fqdn} is not known" - ) - }; - - // Attempt to find an HMAC key of this data-plane which validates against the request token. - let validation = jsonwebtoken::Validation::default(); - let mut verified = false; - - for hmac_key in &task_data_plane.hmac_keys { - let key = jsonwebtoken::DecodingKey::from_base64_secret(hmac_key) - .context("invalid data-plane hmac key")?; - - if jsonwebtoken::decode::(token, &key, &validation).is_ok() { - verified = true; - break; - } - } - if !verified { - anyhow::bail!("no data-plane keys validated against the token signature"); - } + let (task, task_data_plane) = super::verify_data_plane_claims( + data_planes, + tasks, + shard_id, + shard_data_plane_fqdn, + token, + )?; // Map a required `name` journal label selector into its collection. 
let Some(collection) = collections diff --git a/crates/agent/src/api/mod.rs b/crates/agent/src/api/mod.rs index 60e635d845..888f99b8d4 100644 --- a/crates/agent/src/api/mod.rs +++ b/crates/agent/src/api/mod.rs @@ -5,6 +5,7 @@ mod authorize_task; mod authorize_user_collection; mod authorize_user_task; mod create_data_plane; +mod notify_shard_failure; mod snapshot; mod update_l2_reporting; @@ -125,6 +126,10 @@ pub fn build_router( post(update_l2_reporting::update_l2_reporting) .route_layer(axum::middleware::from_fn_with_state(app.clone(), authorize)), ) + .route( + "/notify/shard-failure", + post(notify_shard_failure::notify_shard_failure), + ) .layer(tower_http::trace::TraceLayer::new_for_http()) .layer(cors) .with_state(app); @@ -239,3 +244,81 @@ fn maybe_rewrite_address(external: bool, address: &str) -> String { address.to_string() } } + +// Parse a data-plane claims token without verifying its signature. +fn parse_untrusted_data_plane_claims( + token: &str, +) -> anyhow::Result<(jsonwebtoken::Header, proto_gazette::Claims)> { + let jsonwebtoken::TokenData { header, claims }: jsonwebtoken::TokenData = + { + // In this pass we do not validate the signature, + // because we don't yet know which data-plane the JWT is signed by. 
+ let empty_key = jsonwebtoken::DecodingKey::from_secret(&[]); + let mut validation = jsonwebtoken::Validation::default(); + validation.insecure_disable_signature_validation(); + jsonwebtoken::decode(token, &empty_key, &validation) + }?; + + if claims.sub.is_empty() { + anyhow::bail!("missing required shard ID (`sub` claim)"); + } + if claims.iss.is_empty() { + anyhow::bail!("missing required shard data-plane FQDN (`iss` claim)"); + } + + tracing::debug!(?claims, ?header, "decoded authorization request"); + + Ok((header, claims)) +} + +fn verify_data_plane_claims<'s>( + data_planes: &'s tables::DataPlanes, + tasks: &'s [snapshot::SnapshotTask], + shard_id: &str, + shard_data_plane_fqdn: &str, + token: &str, +) -> anyhow::Result<(&'s snapshot::SnapshotTask, &'s tables::DataPlane)> { + // Map `shard_id` into its task. + let task = tasks + .binary_search_by(|task| { + if shard_id.starts_with(&task.shard_template_id) { + std::cmp::Ordering::Equal + } else { + task.shard_template_id.as_str().cmp(shard_id) + } + }) + .ok() + .map(|index| &tasks[index]); + + // Map `shard_data_plane_fqdn` into its task-matched data-plane. + let task_data_plane = task.and_then(|task| { + data_planes + .get_by_key(&task.data_plane_id) + .filter(|data_plane| data_plane.data_plane_fqdn == shard_data_plane_fqdn) + }); + + let (Some(task), Some(task_data_plane)) = (task, task_data_plane) else { + anyhow::bail!( + "task shard {shard_id} within data-plane {shard_data_plane_fqdn} is not known" + ) + }; + + // Attempt to find an HMAC key of this data-plane which validates against the request token. 
+ let validation = jsonwebtoken::Validation::default(); + let mut verified = false; + + for hmac_key in &task_data_plane.hmac_keys { + let key = jsonwebtoken::DecodingKey::from_base64_secret(hmac_key) + .context("invalid data-plane hmac key")?; + + if jsonwebtoken::decode::(token, &key, &validation).is_ok() { + verified = true; + break; + } + } + if !verified { + anyhow::bail!("no data-plane keys validated against the token signature"); + } + + Ok((task, task_data_plane)) +} diff --git a/crates/agent/src/api/notify_shard_failure.rs b/crates/agent/src/api/notify_shard_failure.rs new file mode 100644 index 0000000000..ed4ac1825b --- /dev/null +++ b/crates/agent/src/api/notify_shard_failure.rs @@ -0,0 +1,69 @@ +use super::{App, Snapshot}; +use std::sync::Arc; + +/// Request sent by data-plane reactors to notify the control-plane of a shard failure. +#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] +#[serde(rename_all = "camelCase")] +pub struct Request { + /// # JWT token which identifies the shard and authorizes the request. + /// The token subject is the shard ID. + pub token: String, + /// # Error encountered by the shard. + pub error: String, +} + +#[derive(Debug, Default, serde::Serialize, schemars::JsonSchema)] +#[serde(rename_all = "camelCase")] +pub struct Response { + /// # Number of milliseconds to wait before retrying the request. + /// Zero if the request was successful. 
+ pub retry_millis: u64, +} + +#[axum::debug_handler] +pub async fn notify_shard_failure( + axum::extract::State(app): axum::extract::State>, + axum::Json(request): axum::Json, +) -> axum::response::Response { + super::wrap(async move { do_notify_shard_failure(&app, &request).await }).await +} + +#[tracing::instrument(skip(app), err(level = tracing::Level::WARN))] +async fn do_notify_shard_failure( + app: &App, + Request { token, error }: &Request, +) -> anyhow::Result { + let (_header, claims) = super::parse_untrusted_data_plane_claims(token)?; + + match Snapshot::evaluate( + &app.snapshot, + claims.iat, + |Snapshot { + data_planes, tasks, .. + }: &Snapshot| { + _ = super::verify_data_plane_claims( + data_planes, + tasks, + &claims.sub, + &claims.iss, + token, + )?; + Ok(()) + }, + ) { + Ok(()) => { + // TODO(johnny): This is a placeholder for enqueuing an automation to perform + // shard restart. + tracing::info!(%error, %claims.sub, %claims.iss, "notified of shard failure"); + + Ok(Response { + ..Default::default() + }) + } + Err(Ok(retry_millis)) => Ok(Response { + retry_millis, + ..Default::default() + }), + Err(Err(err)) => Err(err), + } +} diff --git a/go/runtime/flow_consumer.go b/go/runtime/flow_consumer.go index cb25325320..5451def539 100644 --- a/go/runtime/flow_consumer.go +++ b/go/runtime/flow_consumer.go @@ -61,6 +61,8 @@ func (c *FlowConsumerConfig) Execute(args []string) error { type FlowConsumer struct { // Configuration of this FlowConsumer. config *FlowConsumerConfig + // Control plane of this consumer, or nil in testing contexts. + controlPlane *controlPlane // Running consumer.service. service *consumer.Service // Shared catalog builds. 
@@ -208,11 +210,10 @@ func (f *FlowConsumer) InitApplication(args runconsumer.InitArgs) error { return fmt.Errorf("catalog builds service: %w", err) } - var controlPlane *controlPlane var localAuthorizer = args.Service.Authorizer if keyedAuth, ok := localAuthorizer.(*auth.KeyedAuth); ok && !config.Flow.TestAPIs { - controlPlane = newControlPlane( + f.controlPlane = newControlPlane( keyedAuth, config.Flow.DataPlaneFQDN, config.Flow.ControlAPI, @@ -221,7 +222,7 @@ func (f *FlowConsumer) InitApplication(args runconsumer.InitArgs) error { // Wrap the underlying KeyedAuth Authorizer to use the control-plane's Authorize API. // Next unwrap the raw JournalClient from its current AuthJournalClient, // and then replace it with one built using our wrapped Authorizer. - args.Service.Authorizer = newControlPlaneAuthorizer(controlPlane) + args.Service.Authorizer = newControlPlaneAuthorizer(f.controlPlane) var rawClient = args.Service.Journals.(*pb.ComposedRoutedJournalClient).JournalClient.(*pb.AuthJournalClient).Inner args.Service.Journals.(*pb.ComposedRoutedJournalClient).JournalClient = pb.NewAuthJournalClient(rawClient, args.Service.Authorizer) } diff --git a/go/runtime/task.go b/go/runtime/task.go index a1eb8a6a4c..c357ef5fa7 100644 --- a/go/runtime/task.go +++ b/go/runtime/task.go @@ -22,6 +22,8 @@ import ( "github.com/estuary/flow/go/shuffle" "github.com/gogo/protobuf/proto" "github.com/gogo/protobuf/types" + "github.com/golang-jwt/jwt/v5" + "github.com/sirupsen/logrus" "go.gazette.dev/core/allocator" "go.gazette.dev/core/broker/client" pb "go.gazette.dev/core/broker/protocol" @@ -254,11 +256,12 @@ func (t *taskReader[TaskSpec]) Coordinator() *shuffle.Coordinator { return t.coo // and then logs the final exit status of the shard. func (t *taskBase[TaskSpec]) heartbeatLoop(shard consumer.Shard) { var ( + id = shard.Spec().Id // Period between regularly-published stat intervals. // This period must cleanly divide into one hour! 
period = 3 * time.Minute // Jitters when interval stats are written cluster-wide. - jitter = intervalJitter(period, shard.FQN()) + jitter = intervalJitter(period, id) // Op notified when the shard fails. op = shard.PrimaryLoop() ) @@ -289,8 +292,33 @@ func (t *taskBase[TaskSpec]) heartbeatLoop(shard consumer.Shard) { "assignment", shard.Assignment().Decoded, ) - // TODO(johnny): Notify control-plane of failure. + // Notify control-plane of the failure. + var now = time.Now() + var claims = pb.Claims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: id.String(), + IssuedAt: jwt.NewNumericDate(now), + ExpiresAt: jwt.NewNumericDate(now.Add(time.Hour)), + }, + } + var token, postErr = t.host.controlPlane.signClaims(claims) + + var request = struct { + Token string `json:"token"` + Error string `json:"error"` + }{token, err.Error()} + + var response struct { + // No response fields. + } + + if postErr == nil { + postErr = callControlAPI(shard.Context(), t.host.controlPlane, "/notify/shard-failure", &request, &response) + } + if postErr != nil { + logrus.WithField("err", postErr).Error("failed to notify control plane of shard failure") + } return } } @@ -325,9 +353,9 @@ func durationToNextInterval(now time.Time, period time.Duration) time.Duration { // intervalJitter returns a globally consistent, unique jitter offset for `name` // so that heartbeats are uniformly distributed over time, in aggregate. 
-func intervalJitter(period time.Duration, name string) time.Duration { +func intervalJitter(period time.Duration, id pc.ShardID) time.Duration { var w = fnv.New32() - w.Write([]byte(name)) + w.Write([]byte(id)) return time.Duration(w.Sum32()%uint32(period.Seconds())) * time.Second } diff --git a/go/runtime/task_test.go b/go/runtime/task_test.go index f50f8482e1..ec6c3a32ce 100644 --- a/go/runtime/task_test.go +++ b/go/runtime/task_test.go @@ -7,16 +7,17 @@ import ( pf "github.com/estuary/flow/go/protocols/flow" "github.com/estuary/flow/go/protocols/ops" "github.com/stretchr/testify/require" + pc "go.gazette.dev/core/consumer/protocol" ) func TestIntervalJitterAndDurations(t *testing.T) { const period = time.Minute for _, tc := range []struct { - n string - i time.Duration + id pc.ShardID + i time.Duration }{{"foo", 35}, {"bar", 52}, {"baz", 0}, {"bing", 39}, {"quip", 56}} { - require.Equal(t, time.Second*tc.i, intervalJitter(period, tc.n), tc.n) + require.Equal(t, time.Second*tc.i, intervalJitter(period, tc.id), tc.id) }