Skip to content

Commit

Permalink
dekaf: Refactor to support new agent auth mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
jshearer committed Sep 26, 2024
1 parent 1a20b83 commit e8e48e3
Show file tree
Hide file tree
Showing 8 changed files with 178 additions and 164 deletions.
1 change: 1 addition & 0 deletions crates/dekaf/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ doc = { path = "../doc" }
flowctl = { path = "../flowctl" }
gazette = { path = "../gazette" }
labels = { path = "../labels" }
models = { path = "../models" }
ops = { path = "../ops" }
proto-flow = { path = "../proto-flow" }
proto-gazette = { path = "../proto-gazette" }
Expand Down
64 changes: 9 additions & 55 deletions crates/dekaf/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,15 @@ use read::Read;
mod session;
pub use session::Session;

pub mod registry;
pub mod metrics;
pub mod registry;

mod api_client;
pub use api_client::KafkaApiClient;

use aes_siv::{aead::Aead, Aes256SivAead, KeyInit, KeySizeUser};
use itertools::Itertools;
use percent_encoding::{percent_decode_str, utf8_percent_encode};
use serde::{Deserialize, Serialize};
use serde_json::de;

pub struct App {
/// Anonymous API client for the Estuary control plane.
Expand All @@ -47,17 +45,9 @@ pub struct ConfigOptions {
}

pub struct Authenticated {
client: postgrest::Postgrest,
client: flowctl::Client,
user_config: ConfigOptions,
claims: JwtClaims,
}

#[derive(Deserialize)]
struct JwtClaims {
/// Unix timestamp in seconds when this token will expire
exp: u64,
/// ID of the user that owns this token
sub: String,
claims: models::authorizations::ControlClaims,
}

impl App {
Expand All @@ -71,51 +61,15 @@ impl App {
let config: ConfigOptions = serde_json::from_str(&username_str)
.context("failed to parse username as a JSON object")?;

#[derive(serde::Deserialize)]
struct RefreshToken {
id: String,
secret: String,
}
let RefreshToken {
id: refresh_token_id,
secret,
} = serde_json::from_slice(&base64::decode(password).context("password is not base64")?)
.context("failed to decode refresh token from password")?;

tracing::info!(refresh_token_id, "authenticating refresh token");
let mut client = flowctl::Client::new(&flowctl::Config::from_refresh_token(
String::from_utf8(base64::decode(password)?.to_vec())?.as_str(),
)?);

#[derive(serde::Deserialize)]
struct AccessToken {
access_token: String,
}
let AccessToken { access_token } = self
.anon_client
.rpc(
"generate_access_token",
serde_json::json!({"refresh_token_id": refresh_token_id, "secret": secret})
.to_string(),
)
.execute()
.await
.and_then(|r| r.error_for_status())
.context("generating access token")?
.json()
.await?;

let authenticated_client = self
.anon_client
.clone()
.insert_header("Authorization", format!("Bearer {access_token}"));

let claims = base64::decode(access_token.split(".").collect_vec()[1])
.map_err(anyhow::Error::from)
.and_then(|decoded| {
de::from_slice::<JwtClaims>(&decoded[..]).map_err(anyhow::Error::from)
})
.context("Failed to parse access token claims")?;
client.refresh().await?;
let claims = client.claims()?;

Ok(Authenticated {
client: authenticated_client,
client,
user_config: config,
claims,
})
Expand Down
4 changes: 3 additions & 1 deletion crates/dekaf/src/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ async fn all_subjects(
..
} = app.authenticate(auth.username(), auth.password()).await?;

let client = client.pg_client();

super::fetch_all_collection_names(&client)
.await
.context("failed to list collections from the control plane")
Expand Down Expand Up @@ -95,7 +97,7 @@ async fn get_subject_latest(
.with_context(|| format!("collection {collection} does not exist"))?;

let (key_id, value_id) = collection
.registered_schema_ids(&client)
.registered_schema_ids(&client.pg_client())
.await
.context("failed to resolve registered Avro schemas")?;

Expand Down
41 changes: 29 additions & 12 deletions crates/dekaf/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ struct PendingRead {

pub struct Session {
app: Arc<App>,
client: postgrest::Postgrest,
client: Option<flowctl::Client>,
reads: HashMap<(TopicName, i32), PendingRead>,
/// ID of the authenticated user
user_id: Option<String>,
Expand All @@ -41,10 +41,9 @@ pub struct Session {

impl Session {
pub fn new(app: Arc<App>, secret: String) -> Self {
let client = app.anon_client.clone();
Self {
app,
client,
client: None,
reads: HashMap::new(),
user_id: None,
config: None,
Expand Down Expand Up @@ -87,9 +86,9 @@ impl Session {
user_config,
claims,
}) => {
self.client = client;
self.client.replace(client);
self.config.replace(user_config);
self.user_id.replace(claims.sub);
self.user_id.replace(claims.sub.to_string());

let mut response = messages::SaslAuthenticateResponse::default();
response.session_lifetime_ms = (1000
Expand Down Expand Up @@ -144,7 +143,14 @@ impl Session {
async fn metadata_all_topics(
&mut self,
) -> anyhow::Result<IndexMap<TopicName, MetadataResponseTopic>> {
let collections = fetch_all_collection_names(&self.client).await?;
let collections = fetch_all_collection_names(
&self
.client
.as_ref()
.ok_or(anyhow::anyhow!("Session not authenticated"))?
.pg_client(),
)
.await?;

tracing::debug!(collections=?ops::DebugJson(&collections), "fetched all collections");

Expand All @@ -170,7 +176,10 @@ impl Session {
&mut self,
requests: Vec<messages::metadata_request::MetadataRequestTopic>,
) -> anyhow::Result<IndexMap<TopicName, MetadataResponseTopic>> {
let client = &self.client;
let client = &self
.client
.as_ref()
.ok_or(anyhow::anyhow!("Session not authenticated"))?;

// Concurrently fetch Collection instances for all requested topics.
let collections: anyhow::Result<Vec<(TopicName, Option<Collection>)>> =
Expand Down Expand Up @@ -247,7 +256,10 @@ impl Session {
&mut self,
request: messages::ListOffsetsRequest,
) -> anyhow::Result<messages::ListOffsetsResponse> {
let client = &self.client;
let client = &self
.client
.as_ref()
.ok_or(anyhow::anyhow!("Session not authenticated"))?;

// Concurrently fetch Collection instances and offsets for all requested topics and partitions.
// Map each "topic" into Vec<(Partition Index, Option<(Journal Offset, Timestamp)>)>.
Expand Down Expand Up @@ -342,7 +354,11 @@ impl Session {
..
} = request;

let client = &self.client;
let client = &self
.client
.as_ref()
.ok_or(anyhow::anyhow!("Session not authenticated"))?;

let timeout = tokio::time::sleep(std::time::Duration::from_millis(max_wait_ms as u64));
let timeout = futures::future::maybe_done(timeout);
tokio::pin!(timeout);
Expand Down Expand Up @@ -370,10 +386,11 @@ impl Session {
tracing::debug!(collection = ?&key.0, partition=partition_request.partition, "Partition doesn't exist!");
continue; // Partition doesn't exist.
};
let (key_schema_id, value_schema_id) =
collection.registered_schema_ids(&client).await?;
let (key_schema_id, value_schema_id) = collection
.registered_schema_ids(&client.pg_client())
.await?;

let read = Read::new(
let read: Read = Read::new(
collection.journal_client.clone(),
&collection,
partition,
Expand Down
46 changes: 7 additions & 39 deletions crates/dekaf/src/topology.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use anyhow::Context;
use flowctl::fetch_collection_authorization;
use futures::{StreamExt, TryStreamExt};
use gazette::{broker, journal, uuid};
use proto_flow::flow;
Expand Down Expand Up @@ -50,11 +51,9 @@ pub struct Partition {

impl Collection {
/// Build a Collection by fetching its spec, an authenticated data-plane access token, and its partitions.
pub async fn new(
client: &postgrest::Postgrest,
collection: &str,
) -> anyhow::Result<Option<Self>> {
pub async fn new(client: &flowctl::Client, collection: &str) -> anyhow::Result<Option<Self>> {
let not_before = uuid::Clock::default();
let pg_client = client.pg_client();

// Build a journal client and use it to fetch partitions while concurrently
// fetching the collection's metadata from the control plane.
Expand All @@ -64,7 +63,7 @@ impl Collection {
Ok((journal_client, partitions))
};
let (spec, client_partitions): (anyhow::Result<_>, anyhow::Result<_>) =
futures::join!(Self::fetch_spec(&client, collection), client_partitions);
futures::join!(Self::fetch_spec(&pg_client, collection), client_partitions);

let Some(spec) = spec? else { return Ok(None) };
let (journal_client, partitions) = client_partitions?;
Expand Down Expand Up @@ -234,43 +233,12 @@ impl Collection {

/// Build a journal client by resolving the collection's data-plane gateway and an access token.
async fn build_journal_client(
client: &postgrest::Postgrest,
client: &flowctl::Client,
collection: &str,
) -> anyhow::Result<journal::Client> {
let body = serde_json::json!({
"prefixes": [collection],
})
.to_string();

#[derive(serde::Deserialize)]
struct Auth {
token: String,
gateway_url: String,
}

let [auth]: [Auth; 1] = client
.rpc("gateway_auth_token", body)
.build()
.send()
.await
.and_then(|r| r.error_for_status())
.context("requesting data plane gateway auth token")?
.json()
.await?;

tracing::debug!(
collection,
gateway = auth.gateway_url,
"fetched data-plane token"
);

let mut metadata = gazette::Metadata::default();
metadata.bearer_token(&auth.token)?;

let router = gazette::Router::new("dekaf");
let client = journal::Client::new(auth.gateway_url, metadata, router);
let (_, journal_client) = fetch_collection_authorization(client, collection).await?;

Ok(client)
Ok(journal_client)
}

async fn registered_schema_id(
Expand Down
Loading

0 comments on commit e8e48e3

Please sign in to comment.