Skip to content

Commit

Permalink
go/runtime + agent: introduce /notify/shard-failure API
Browse files Browse the repository at this point in the history
The runtime invokes a new /notify/shard-failure control-plane API which
is told of shard failures that have occurred within a data-plane.

At the moment, this API verifies the data-plane token and logs the
failure, but takes no further action.

Update the taskBase.heartbeatLoop() to perform this notification if the
shard's primary loop exits with a non-cancellation error.

Issue #1666
  • Loading branch information
jgraettinger committed Sep 30, 2024
1 parent fd93dc4 commit 5a20739
Show file tree
Hide file tree
Showing 6 changed files with 202 additions and 75 deletions.
75 changes: 10 additions & 65 deletions crates/agent/src/api/authorize_task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,7 @@ pub async fn authorize_task(

#[tracing::instrument(skip(app), err(level = tracing::Level::WARN))]
async fn do_authorize_task(app: &App, Request { token }: &Request) -> anyhow::Result<Response> {
let jsonwebtoken::TokenData { header, mut claims }: jsonwebtoken::TokenData<
proto_gazette::Claims,
> = {
// In this pass we do not validate the signature,
// because we don't yet know which data-plane the JWT is signed by.
let empty_key = jsonwebtoken::DecodingKey::from_secret(&[]);
let mut validation = jsonwebtoken::Validation::default();
validation.insecure_disable_signature_validation();
jsonwebtoken::decode(token, &empty_key, &validation)
}?;
tracing::debug!(?claims, ?header, "decoded authorization request");

let shard_id = claims.sub.as_str();
if shard_id.is_empty() {
anyhow::bail!("missing required shard ID (`sub` claim)");
}

let shard_data_plane_fqdn = claims.iss.as_str();
if shard_data_plane_fqdn.is_empty() {
anyhow::bail!("missing required shard data-plane FQDN (`iss` claim)");
}

let (header, mut claims) = super::parse_untrusted_data_plane_claims(token)?;
let journal_name_or_prefix = labels::expect_one(claims.sel.include(), "name")?.to_owned();

// Require the request was signed with the AUTHORIZE capability,
Expand All @@ -62,8 +41,8 @@ async fn do_authorize_task(app: &App, Request { token }: &Request) -> anyhow::Re
match Snapshot::evaluate(&app.snapshot, claims.iat, |snapshot: &Snapshot| {
evaluate_authorization(
snapshot,
shard_id,
shard_data_plane_fqdn,
&claims.sub,
&claims.iss,
token,
&journal_name_or_prefix,
required_role,
Expand Down Expand Up @@ -104,47 +83,13 @@ fn evaluate_authorization(
journal_name_or_prefix: &str,
required_role: models::Capability,
) -> anyhow::Result<(jsonwebtoken::EncodingKey, String, String)> {
// Map `claims.sub`, a Shard ID, into its task.
let task = tasks
.binary_search_by(|task| {
if shard_id.starts_with(&task.shard_template_id) {
std::cmp::Ordering::Equal
} else {
task.shard_template_id.as_str().cmp(shard_id)
}
})
.ok()
.map(|index| &tasks[index]);

// Map `claims.iss`, a data-plane FQDN, into its task-matched data-plane.
let task_data_plane = task.and_then(|task| {
data_planes
.get_by_key(&task.data_plane_id)
.filter(|data_plane| data_plane.data_plane_fqdn == shard_data_plane_fqdn)
});

let (Some(task), Some(task_data_plane)) = (task, task_data_plane) else {
anyhow::bail!(
"task shard {shard_id} within data-plane {shard_data_plane_fqdn} is not known"
)
};

// Attempt to find an HMAC key of this data-plane which validates against the request token.
let validation = jsonwebtoken::Validation::default();
let mut verified = false;

for hmac_key in &task_data_plane.hmac_keys {
let key = jsonwebtoken::DecodingKey::from_base64_secret(hmac_key)
.context("invalid data-plane hmac key")?;

if jsonwebtoken::decode::<proto_gazette::Claims>(token, &key, &validation).is_ok() {
verified = true;
break;
}
}
if !verified {
anyhow::bail!("no data-plane keys validated against the token signature");
}
let (task, task_data_plane) = super::verify_data_plane_claims(
data_planes,
tasks,
shard_id,
shard_data_plane_fqdn,
token,
)?;

// Map a required `name` journal label selector into its collection.
let Some(collection) = collections
Expand Down
83 changes: 83 additions & 0 deletions crates/agent/src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ mod authorize_task;
mod authorize_user_collection;
mod authorize_user_task;
mod create_data_plane;
mod notify_shard_failure;
mod snapshot;
mod update_l2_reporting;

Expand Down Expand Up @@ -125,6 +126,10 @@ pub fn build_router(
post(update_l2_reporting::update_l2_reporting)
.route_layer(axum::middleware::from_fn_with_state(app.clone(), authorize)),
)
.route(
"/notify/shard-failure",
post(notify_shard_failure::notify_shard_failure),
)
.layer(tower_http::trace::TraceLayer::new_for_http())
.layer(cors)
.with_state(app);
Expand Down Expand Up @@ -239,3 +244,81 @@ fn maybe_rewrite_address(external: bool, address: &str) -> String {
address.to_string()
}
}

// Decode a data-plane claims token WITHOUT verifying its signature.
//
// The returned header and claims are untrusted: the signature has not been
// checked, so callers must verify the token (see `verify_data_plane_claims`)
// before acting on them. Fails if the token cannot be decoded, or if its
// `sub` (shard ID) or `iss` (data-plane FQDN) claim is empty.
fn parse_untrusted_data_plane_claims(
    token: &str,
) -> anyhow::Result<(jsonwebtoken::Header, proto_gazette::Claims)> {
    // The signing data-plane isn't known until the claims have been read,
    // so this pass decodes with a placeholder key and signature checks disabled.
    let mut lenient = jsonwebtoken::Validation::default();
    lenient.insecure_disable_signature_validation();
    let placeholder_key = jsonwebtoken::DecodingKey::from_secret(&[]);

    let decoded: jsonwebtoken::TokenData<proto_gazette::Claims> =
        jsonwebtoken::decode(token, &placeholder_key, &lenient)?;
    let (header, claims) = (decoded.header, decoded.claims);

    anyhow::ensure!(
        !claims.sub.is_empty(),
        "missing required shard ID (`sub` claim)"
    );
    anyhow::ensure!(
        !claims.iss.is_empty(),
        "missing required shard data-plane FQDN (`iss` claim)"
    );

    tracing::debug!(?claims, ?header, "decoded authorization request");

    Ok((header, claims))
}

// Resolve and authenticate the (shard, data-plane) pair asserted by a token.
//
// Finds the task whose shard template ID is a prefix of `shard_id`, confirms
// that the task's data-plane has FQDN `shard_data_plane_fqdn`, and then requires
// that `token` validates against at least one HMAC key of that data-plane.
// Returns the matched task and data-plane, or an error if any step fails.
fn verify_data_plane_claims<'s>(
    data_planes: &'s tables::DataPlanes,
    tasks: &'s [snapshot::SnapshotTask],
    shard_id: &str,
    shard_data_plane_fqdn: &str,
    token: &str,
) -> anyhow::Result<(&'s snapshot::SnapshotTask, &'s tables::DataPlane)> {
    use std::cmp::Ordering;

    // Map `shard_id` into its task: a task matches when its shard template ID
    // prefixes the shard ID.
    // NOTE(review): binary search presumes `tasks` is ordered on
    // `shard_template_id` — confirm against the snapshot builder.
    let located = tasks.binary_search_by(|candidate| {
        let template_id = candidate.shard_template_id.as_str();
        if shard_id.starts_with(template_id) {
            Ordering::Equal
        } else {
            template_id.cmp(shard_id)
        }
    });

    // Pair the task with its data-plane, but only if that data-plane's FQDN
    // is the one named by the claims.
    let matched = located.ok().and_then(|index| {
        let task = &tasks[index];
        data_planes
            .get_by_key(&task.data_plane_id)
            .filter(|data_plane| data_plane.data_plane_fqdn == shard_data_plane_fqdn)
            .map(|data_plane| (task, data_plane))
    });

    let Some((task, task_data_plane)) = matched else {
        anyhow::bail!(
            "task shard {shard_id} within data-plane {shard_data_plane_fqdn} is not known"
        )
    };

    // Try each HMAC key of the data-plane in order; the first key which
    // validates the token authenticates the request.
    let validation = jsonwebtoken::Validation::default();

    for hmac_key in &task_data_plane.hmac_keys {
        let key = jsonwebtoken::DecodingKey::from_base64_secret(hmac_key)
            .context("invalid data-plane hmac key")?;

        if jsonwebtoken::decode::<proto_gazette::Claims>(token, &key, &validation).is_ok() {
            return Ok((task, task_data_plane));
        }
    }

    anyhow::bail!("no data-plane keys validated against the token signature")
}
69 changes: 69 additions & 0 deletions crates/agent/src/api/notify_shard_failure.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use super::{App, Snapshot};
use std::sync::Arc;

/// Request sent by data-plane reactors to notify the control-plane of a shard failure.
// Deserialized from the JSON request body; `rename_all` maps field names to camelCase.
// NOTE(review): the `///` comments here likely feed the derived JSON schema
// (a leading `# ` line presumably becomes the schema title) — confirm before rewording them.
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct Request {
/// # JWT token which identifies the shard and authorizes the request.
/// The token subject is the shard ID.
pub token: String,
/// # Error encountered by the shard.
pub error: String,
}

// Response to a shard-failure notification.
// Serialized with camelCase field names, so `retry_millis` is emitted as `retryMillis`.
// The derived `Default` yields `retry_millis == 0`.
#[derive(Debug, Default, serde::Serialize, schemars::JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct Response {
/// # Number of milliseconds to wait before retrying the request.
/// Zero if the request was successful.
pub retry_millis: u64,
}

/// Axum handler for shard-failure notifications.
///
/// Extracts the shared `App` state and the JSON-encoded `Request`, runs
/// `do_notify_shard_failure`, and passes its future through `super::wrap`
/// to produce the HTTP response.
#[axum::debug_handler]
pub async fn notify_shard_failure(
    axum::extract::State(app): axum::extract::State<Arc<App>>,
    axum::Json(request): axum::Json<Request>,
) -> axum::response::Response {
    // The future owns `app` and `request` so that it can borrow from them internally.
    let outcome = async move { do_notify_shard_failure(&app, &request).await };
    super::wrap(outcome).await
}

#[tracing::instrument(skip(app), err(level = tracing::Level::WARN))]
async fn do_notify_shard_failure(
app: &App,
Request { token, error }: &Request,
) -> anyhow::Result<Response> {
let (_header, claims) = super::parse_untrusted_data_plane_claims(token)?;

match Snapshot::evaluate(
&app.snapshot,
claims.iat,
|Snapshot {
data_planes, tasks, ..
}: &Snapshot| {
_ = super::verify_data_plane_claims(
data_planes,
tasks,
&claims.sub,
&claims.iss,
token,
)?;
Ok(())
},
) {
Ok(()) => {
// TODO(johnny): This is a placeholder for enqueuing an automation to perform
// shard restart.
tracing::info!(%error, %claims.sub, %claims.iss, "notified of shard failure");

Ok(Response {
..Default::default()
})
}
Err(Ok(retry_millis)) => Ok(Response {
retry_millis,
..Default::default()
}),
Err(Err(err)) => Err(err),
}
}
7 changes: 4 additions & 3 deletions go/runtime/flow_consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ func (c *FlowConsumerConfig) Execute(args []string) error {
type FlowConsumer struct {
// Configuration of this FlowConsumer.
config *FlowConsumerConfig
// Control plane of this consumer, or nil in testing contexts.
controlPlane *controlPlane
// Running consumer.service.
service *consumer.Service
// Shared catalog builds.
Expand Down Expand Up @@ -208,11 +210,10 @@ func (f *FlowConsumer) InitApplication(args runconsumer.InitArgs) error {
return fmt.Errorf("catalog builds service: %w", err)
}

var controlPlane *controlPlane
var localAuthorizer = args.Service.Authorizer

if keyedAuth, ok := localAuthorizer.(*auth.KeyedAuth); ok && !config.Flow.TestAPIs {
controlPlane = newControlPlane(
f.controlPlane = newControlPlane(
keyedAuth,
config.Flow.DataPlaneFQDN,
config.Flow.ControlAPI,
Expand All @@ -221,7 +222,7 @@ func (f *FlowConsumer) InitApplication(args runconsumer.InitArgs) error {
// Wrap the underlying KeyedAuth Authorizer to use the control-plane's Authorize API.
// Next unwrap the raw JournalClient from its current AuthJournalClient,
// and then replace it with one built using our wrapped Authorizer.
args.Service.Authorizer = newControlPlaneAuthorizer(controlPlane)
args.Service.Authorizer = newControlPlaneAuthorizer(f.controlPlane)
var rawClient = args.Service.Journals.(*pb.ComposedRoutedJournalClient).JournalClient.(*pb.AuthJournalClient).Inner
args.Service.Journals.(*pb.ComposedRoutedJournalClient).JournalClient = pb.NewAuthJournalClient(rawClient, args.Service.Authorizer)
}
Expand Down
36 changes: 32 additions & 4 deletions go/runtime/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import (
"github.com/estuary/flow/go/shuffle"
"github.com/gogo/protobuf/proto"
"github.com/gogo/protobuf/types"
"github.com/golang-jwt/jwt/v5"
"github.com/sirupsen/logrus"
"go.gazette.dev/core/allocator"
"go.gazette.dev/core/broker/client"
pb "go.gazette.dev/core/broker/protocol"
Expand Down Expand Up @@ -254,11 +256,12 @@ func (t *taskReader[TaskSpec]) Coordinator() *shuffle.Coordinator { return t.coo
// and then logs the final exit status of the shard.
func (t *taskBase[TaskSpec]) heartbeatLoop(shard consumer.Shard) {
var (
id = shard.Spec().Id
// Period between regularly-published stat intervals.
// This period must cleanly divide into one hour!
period = 3 * time.Minute
// Jitters when interval stats are written cluster-wide.
jitter = intervalJitter(period, shard.FQN())
jitter = intervalJitter(period, id)
// Op notified when the shard fails.
op = shard.PrimaryLoop()
)
Expand Down Expand Up @@ -289,8 +292,33 @@ func (t *taskBase[TaskSpec]) heartbeatLoop(shard consumer.Shard) {
"assignment", shard.Assignment().Decoded,
)

// TODO(johnny): Notify control-plane of failure.
// Notify control-plane of the failure.
var now = time.Now()

var claims = pb.Claims{
RegisteredClaims: jwt.RegisteredClaims{
Subject: id.String(),
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(time.Hour)),
},
}
var token, postErr = t.host.controlPlane.signClaims(claims)

var request = struct {
Token string `json:"token"`
Error string `json:"error"`
}{token, err.Error()}

var response struct {
// No response fields.
}

if postErr == nil {
postErr = callControlAPI(shard.Context(), t.host.controlPlane, "/notify/shard-failure", &request, &response)
}
if postErr != nil {
logrus.WithField("err", postErr).Error("failed to notify control plane of shard failure")
}
return
}
}
Expand Down Expand Up @@ -325,9 +353,9 @@ func durationToNextInterval(now time.Time, period time.Duration) time.Duration {

// intervalJitter returns a globally consistent, unique jitter offset for `id`
// so that heartbeats are uniformly distributed over time, in aggregate.
func intervalJitter(period time.Duration, name string) time.Duration {
func intervalJitter(period time.Duration, id pc.ShardID) time.Duration {
var w = fnv.New32()
w.Write([]byte(name))
w.Write([]byte(id))
return time.Duration(w.Sum32()%uint32(period.Seconds())) * time.Second
}

Expand Down
7 changes: 4 additions & 3 deletions go/runtime/task_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,17 @@ import (
pf "github.com/estuary/flow/go/protocols/flow"
"github.com/estuary/flow/go/protocols/ops"
"github.com/stretchr/testify/require"
pc "go.gazette.dev/core/consumer/protocol"
)

func TestIntervalJitterAndDurations(t *testing.T) {
const period = time.Minute

for _, tc := range []struct {
n string
i time.Duration
id pc.ShardID
i time.Duration
}{{"foo", 35}, {"bar", 52}, {"baz", 0}, {"bing", 39}, {"quip", 56}} {
require.Equal(t, time.Second*tc.i, intervalJitter(period, tc.n), tc.n)
require.Equal(t, time.Second*tc.i, intervalJitter(period, tc.id), tc.id)

}

Expand Down

0 comments on commit 5a20739

Please sign in to comment.