Skip to content

Commit

Permalink
flow_client: Refactor to expose refreshed credentials
Browse files Browse the repository at this point in the history
`flowctl` needs these credentials exposed in order to store them in `flowctl::Config` + on disk for the next flowctl invocation
  • Loading branch information
jshearer committed Oct 3, 2024
1 parent fdd7c1a commit f9fcd6c
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 100 deletions.
29 changes: 11 additions & 18 deletions crates/dekaf/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,10 @@ pub use api_client::KafkaApiClient;

use aes_siv::{aead::Aead, Aes256SivAead, KeyInit, KeySizeUser};
use connector::DekafConfig;
use flow_client::{
client::{refresh_client, RefreshToken},
DEFAULT_AGENT_URL,
};
use flow_client::client::{refresh_authorizations, RefreshToken};
use percent_encoding::{percent_decode_str, utf8_percent_encode};
use serde::{Deserialize, Serialize};
use std::time::SystemTime;
use url::Url;

pub struct App {
/// Hostname which is advertised for Kafka access.
Expand All @@ -42,10 +38,8 @@ pub struct App {
pub kafka_client: KafkaApiClient,
/// Secret used to secure Prometheus endpoint
pub secret: String,
/// Supabase endpoint
pub api_endpoint: Url,
/// Supabase api key
pub api_key: String,
/// Share a single base client in order to re-use connection pools
pub client_base: flow_client::Client,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
Expand All @@ -72,16 +66,15 @@ impl App {
let raw_token = String::from_utf8(base64::decode(password)?.to_vec())?;
let refresh: RefreshToken = serde_json::from_str(raw_token.as_str())?;

let mut client = flow_client::Client::new(
DEFAULT_AGENT_URL.to_owned(),
self.api_key.to_owned(),
self.api_endpoint.to_owned(),
None,
Some(refresh),
);
let (access, refresh) =
refresh_authorizations(&self.client_base, None, Some(refresh)).await?;

let client = self
.client_base
.clone()
.with_creds(Some(access), Some(refresh));

refresh_client(&mut client).await?;
let claims = client.claims()?;
let claims = flow_client::client::client_claims(&client)?;

if models::Materialization::regex().is_match(username.as_ref()) {
Ok(Authenticated {
Expand Down
13 changes: 10 additions & 3 deletions crates/dekaf/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use anyhow::{bail, Context};
use axum_server::tls_rustls::RustlsConfig;
use clap::{Args, Parser};
use dekaf::{KafkaApiClient, Session};
use flow_client::{DEFAULT_PG_PUBLIC_TOKEN, DEFAULT_PG_URL, LOCAL_PG_PUBLIC_TOKEN, LOCAL_PG_URL};
use flow_client::{
DEFAULT_AGENT_URL, DEFAULT_PG_PUBLIC_TOKEN, DEFAULT_PG_URL, LOCAL_PG_PUBLIC_TOKEN, LOCAL_PG_URL,
};
use futures::{FutureExt, TryStreamExt};
use rsasl::config::SASLConfig;
use rustls::pki_types::CertificateDer;
Expand Down Expand Up @@ -133,8 +135,13 @@ async fn main() -> anyhow::Result<()> {
"failed to connect or authenticate to upstream Kafka broker used for serving group management APIs",
)?,
secret: cli.encryption_secret.to_owned(),
api_endpoint,
api_key
client_base: flow_client::Client::new(
DEFAULT_AGENT_URL.to_owned(),
api_key,
api_endpoint,
None,
None
)
});

tracing::info!(
Expand Down
130 changes: 72 additions & 58 deletions crates/flow-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ impl Client {
pg_api_token: String,
pg_url: Url,
user_access_token: Option<String>,
userrefresh_token: Option<RefreshToken>,
user_refresh_token: Option<RefreshToken>,
) -> Self {
// Build journal and shard clients with an empty default service address.
// We'll use their with_endpoint_and_metadata() routines to cheaply clone
Expand All @@ -59,7 +59,19 @@ impl Client {
journal_client,
shard_client,
user_access_token,
user_refresh_token: userrefresh_token,
user_refresh_token,
}
}

pub fn with_creds(
self,
user_access_token: Option<String>,
user_refresh_token: Option<RefreshToken>,
) -> Self {
Self {
user_access_token: user_access_token.or(self.user_access_token),
user_refresh_token: user_refresh_token.or(self.user_refresh_token),
..self
}
}

Expand All @@ -74,15 +86,6 @@ impl Client {
self.pg_parent.clone()
}

pub fn claims(&self) -> anyhow::Result<ControlClaims> {
parse_jwt_claims(
self.user_access_token
.as_ref()
.ok_or(anyhow::anyhow!("Client is not authenticated"))?
.as_str(),
)
}

pub fn from(&self, table: &str) -> postgrest::Builder {
self.pg_client().from(table)
}
Expand Down Expand Up @@ -266,29 +269,34 @@ pub async fn fetch_collection_authorization(
Ok((journal_name_prefix, journal_client))
}

pub async fn refresh_client(client: &mut Client) -> anyhow::Result<()> {
pub async fn refresh_authorizations(
client: &Client,
access_token: Option<String>,
refresh_token: Option<RefreshToken>,
) -> anyhow::Result<(String, RefreshToken)> {
// Clear expired or soon-to-expire access token
if let Some(_) = &client.user_access_token {
let claims = client.claims()?;

let now = time::OffsetDateTime::now_utc();
let exp = time::OffsetDateTime::from_unix_timestamp(claims.exp as i64).unwrap();
let access_token = if let Some(token) = &access_token {
let claims: ControlClaims = parse_jwt_claims(token.as_str())?;

// Refresh access tokens with plenty of time to spare if we have a
// refresh token. If not, allow refreshing right until the token expires
match ((now - exp).whole_seconds(), &client.user_refresh_token) {
(exp_seconds, Some(_)) if exp_seconds < 60 => client.user_access_token = None,
(exp_seconds, None) if exp_seconds <= 0 => client.user_access_token = None,
_ => {}
match (claims.time_remaining().whole_seconds(), &refresh_token) {
(exp_seconds, Some(_)) if exp_seconds < 60 => None,
(exp_seconds, None) if exp_seconds <= 0 => None,
_ => Some(token.to_owned()),
}
}
} else {
None
};

if client.user_access_token.is_some() && client.user_refresh_token.is_some() {
// Authorization is current: nothing to do.
Ok(())
} else if client.user_access_token.is_some() {
// We have an access token but no refresh token. Create one.
let refresh_token = api_exec::<RefreshToken>(
match (access_token, refresh_token) {
(Some(access), Some(refresh)) => {
// Authorization is current: nothing to do.
Ok((access, refresh))
}
(Some(access), None) => {
// We have an access token but no refresh token. Create one.
let refresh_token = api_exec::<RefreshToken>(
client.rpc(
"create_refresh_token",
serde_json::json!({"multi_use": true, "valid_for": "90d", "detail": "Created by flowctl"})
Expand All @@ -297,37 +305,43 @@ pub async fn refresh_client(client: &mut Client) -> anyhow::Result<()> {
)
.await?;

client.user_refresh_token = Some(refresh_token);

tracing::info!("created new refresh token");
Ok(())
} else if let Some(RefreshToken { id, secret }) = &client.user_refresh_token {
// We have a refresh token but no access token. Generate one.

#[derive(serde::Deserialize)]
struct Response {
access_token: String,
refresh_token: Option<RefreshToken>, // Set iff the token was single-use.
tracing::info!("created new refresh token");
Ok((access, refresh_token))
}
let Response {
access_token,
refresh_token: next_refresh_token,
} = api_exec::<Response>(client.rpc(
"generate_access_token",
serde_json::json!({"refresh_token_id": id, "secret": secret}).to_string(),
))
.await
.context("failed to obtain access token")?;

if next_refresh_token.is_some() {
client.user_refresh_token = next_refresh_token;
(None, Some(RefreshToken { id, secret })) => {
// We have a refresh token but no access token. Generate one.

#[derive(serde::Deserialize)]
struct Response {
access_token: String,
refresh_token: Option<RefreshToken>, // Set iff the token was single-use.
}
let Response {
access_token,
refresh_token: next_refresh_token,
} = api_exec::<Response>(client.rpc(
"generate_access_token",
serde_json::json!({"refresh_token_id": id, "secret": secret}).to_string(),
))
.await
.context("failed to obtain access token")?;

tracing::info!("generated a new access token");
Ok((
access_token,
next_refresh_token.unwrap_or(RefreshToken { id, secret }),
))
}

client.user_access_token = Some(access_token);

tracing::info!("generated a new access token");
Ok(())
} else {
anyhow::bail!("Client not authenticated");
_ => anyhow::bail!("Client not authenticated"),
}
}

pub fn client_claims(client: &Client) -> anyhow::Result<ControlClaims> {
parse_jwt_claims(
client
.user_access_token
.as_ref()
.ok_or(anyhow::anyhow!("Client is not authenticated"))?
.as_str(),
)
}
31 changes: 10 additions & 21 deletions crates/flowctl/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::fmt::Debug;

use anyhow::Context;
use clap::Parser;

mod auth;
Expand All @@ -16,9 +15,9 @@ mod poll;
mod preview;
mod raw;

use flow_client::client::refresh_client;
use flow_client::client::refresh_authorizations;
pub(crate) use flow_client::client::Client;
pub(crate) use flow_client::{api_exec, api_exec_paginated, parse_jwt_claims};
pub(crate) use flow_client::{api_exec, api_exec_paginated};
use output::{Output, OutputType};
use poll::poll_while_queued;

Expand Down Expand Up @@ -135,27 +134,17 @@ impl Cli {
let mut config = config::Config::load(&self.profile)?;
let output = self.output.clone();

// If the configured access token has expired then remove it before continuing.
if let Some(token) = &config.user_access_token {
let claims: models::authorizations::ControlClaims =
parse_jwt_claims(token).context("failed to parse control-plane access token")?;
let client: flow_client::Client = config.build_client();

let now = time::OffsetDateTime::now_utc();
let exp = time::OffsetDateTime::from_unix_timestamp(claims.exp as i64).unwrap();
let (access, refresh) =
refresh_authorizations(&client, config.user_access_token, config.user_refresh_token)
.await?;

if now + std::time::Duration::from_secs(60) > exp {
tracing::info!(expired=%exp, "removing expired user access token from configuration");
config.user_access_token = None;
}
}

let mut client: flow_client::Client = config.build_client();
// Make sure to store refreshed tokens back in Config so they get written back to disk
config.user_access_token = Some(access.to_owned());
config.user_refresh_token = Some(refresh.to_owned());

if config.user_access_token.is_some() || config.user_refresh_token.is_some() {
refresh_client(&mut client).await?;
} else {
tracing::warn!("You are not authenticated. Run `auth login` to login to Flow.");
}
let client = client.with_creds(Some(access), Some(refresh));

let mut context = CliContext {
client,
Expand Down
10 changes: 10 additions & 0 deletions crates/models/src/authorizations.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::cmp::max;
use validator::Validate;

/// ControlClaims are claims encoded within control-plane access tokens.
Expand All @@ -11,6 +12,15 @@ pub struct ControlClaims {
pub exp: u64,
}

impl ControlClaims {
pub fn time_remaining(&self) -> time::Duration {
let now = time::OffsetDateTime::now_utc();
let exp = time::OffsetDateTime::from_unix_timestamp(self.exp as i64).unwrap();

max(exp - now, time::Duration::ZERO)
}
}

// Data-plane claims are represented by proto_gazette::Claims,
// which is not re-exported by this crate.

Expand Down

0 comments on commit f9fcd6c

Please sign in to comment.