Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 158 additions & 11 deletions crates/catalog/rest/src/catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,22 @@ impl RestCatalog {

Ok(file_io)
}

/// Invalidate the current token without generating a new one. On the next request, the client
/// will attempt to generate a new token.
pub async fn invalidate_token(&self) -> Result<()> {
self.context().await?.client.invalidate_token().await
}

/// Invalidate the current token and set a new one. Generates a new token before invalidating
/// the current token, meaning the old token will be used until this function acquires the lock
/// and overwrites the token.
///
/// If credential is invalid, or the request fails, this method will return an error and leave
/// the current token unchanged.
pub async fn regenerate_token(&self) -> Result<()> {
self.context().await?.client.regenerate_token().await
}
}

/// All requests and expected responses are derived from the REST catalog API spec:
Expand Down Expand Up @@ -860,21 +876,27 @@ mod tests {
}

async fn create_oauth_mock(server: &mut ServerGuard) -> Mock {
create_oauth_mock_with_path(server, "/v1/oauth/tokens").await
create_oauth_mock_with_path(server, "/v1/oauth/tokens", "ey000000000000", 200).await
}

async fn create_oauth_mock_with_path(server: &mut ServerGuard, path: &str) -> Mock {
server
.mock("POST", path)
.with_status(200)
.with_body(
r#"{
"access_token": "ey000000000000",
async fn create_oauth_mock_with_path(
server: &mut ServerGuard,
path: &str,
token: &str,
status: usize,
) -> Mock {
let body = format!(
r#"{{
"access_token": "{token}",
"token_type": "Bearer",
"issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
"expires_in": 86400
}"#,
)
}}"#
);
server
.mock("POST", path)
.with_status(status)
.with_body(body)
.expect(1)
.create_async()
.await
Expand Down Expand Up @@ -949,6 +971,129 @@ mod tests {
assert_eq!(token, Some("ey000000000000".to_string()));
}

#[tokio::test]
async fn test_invalidate_token() {
let mut server = Server::new_async().await;
let oauth_mock = create_oauth_mock(&mut server).await;
let config_mock = create_config_mock(&mut server).await;

let mut props = HashMap::new();
props.insert("credential".to_string(), "client1:secret1".to_string());

let catalog = RestCatalog::new(
RestCatalogConfig::builder()
.uri(server.url())
.props(props)
.build(),
);

let token = catalog.context().await.unwrap().client.token().await;
oauth_mock.assert_async().await;
config_mock.assert_async().await;
assert_eq!(token, Some("ey000000000000".to_string()));

let oauth_mock =
create_oauth_mock_with_path(&mut server, "/v1/oauth/tokens", "ey000000000001", 200)
.await;
catalog.invalidate_token().await.unwrap();
let token = catalog.context().await.unwrap().client.token().await;
oauth_mock.assert_async().await;
assert_eq!(token, Some("ey000000000001".to_string()));
}

#[tokio::test]
async fn test_invalidate_token_failing_request() {
let mut server = Server::new_async().await;
let oauth_mock = create_oauth_mock(&mut server).await;
let config_mock = create_config_mock(&mut server).await;

let mut props = HashMap::new();
props.insert("credential".to_string(), "client1:secret1".to_string());

let catalog = RestCatalog::new(
RestCatalogConfig::builder()
.uri(server.url())
.props(props)
.build(),
);

let token = catalog.context().await.unwrap().client.token().await;
oauth_mock.assert_async().await;
config_mock.assert_async().await;
assert_eq!(token, Some("ey000000000000".to_string()));

let oauth_mock =
create_oauth_mock_with_path(&mut server, "/v1/oauth/tokens", "ey000000000001", 500)
.await;
catalog.invalidate_token().await.unwrap();
let token = catalog.context().await.unwrap().client.token().await;
oauth_mock.assert_async().await;
assert_eq!(token, None);
}

#[tokio::test]
async fn test_regenerate_token() {
let mut server = Server::new_async().await;
let oauth_mock = create_oauth_mock(&mut server).await;
let config_mock = create_config_mock(&mut server).await;

let mut props = HashMap::new();
props.insert("credential".to_string(), "client1:secret1".to_string());

let catalog = RestCatalog::new(
RestCatalogConfig::builder()
.uri(server.url())
.props(props)
.build(),
);

let token = catalog.context().await.unwrap().client.token().await;
oauth_mock.assert_async().await;
config_mock.assert_async().await;
assert_eq!(token, Some("ey000000000000".to_string()));

let oauth_mock =
create_oauth_mock_with_path(&mut server, "/v1/oauth/tokens", "ey000000000001", 200)
.await;
catalog.regenerate_token().await.unwrap();
oauth_mock.assert_async().await;
let token = catalog.context().await.unwrap().client.token().await;
assert_eq!(token, Some("ey000000000001".to_string()));
}

#[tokio::test]
async fn test_regenerate_token_failing_request() {
let mut server = Server::new_async().await;
let oauth_mock = create_oauth_mock(&mut server).await;
let config_mock = create_config_mock(&mut server).await;

let mut props = HashMap::new();
props.insert("credential".to_string(), "client1:secret1".to_string());

let catalog = RestCatalog::new(
RestCatalogConfig::builder()
.uri(server.url())
.props(props)
.build(),
);

let token = catalog.context().await.unwrap().client.token().await;
oauth_mock.assert_async().await;
config_mock.assert_async().await;
assert_eq!(token, Some("ey000000000000".to_string()));

let oauth_mock =
create_oauth_mock_with_path(&mut server, "/v1/oauth/tokens", "ey000000000001", 500)
.await;
let invalidate_result = catalog.regenerate_token().await;
assert!(invalidate_result.is_err());
oauth_mock.assert_async().await;
let token = catalog.context().await.unwrap().client.token().await;

// original token is left intact
assert_eq!(token, Some("ey000000000000".to_string()));
}

#[tokio::test]
async fn test_http_headers() {
let server = Server::new_async().await;
Expand Down Expand Up @@ -1026,7 +1171,9 @@ mod tests {

let mut auth_server = Server::new_async().await;
let auth_server_path = "/some/path";
let oauth_mock = create_oauth_mock_with_path(&mut auth_server, auth_server_path).await;
let oauth_mock =
create_oauth_mock_with_path(&mut auth_server, auth_server_path, "ey000000000000", 200)
.await;

let mut props = HashMap::new();
props.insert("credential".to_string(), "client1:secret1".to_string());
Expand Down
89 changes: 56 additions & 33 deletions crates/catalog/rest/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,38 +106,7 @@ impl HttpClient {
self.token.lock().await.clone()
}

/// Authenticate the request by filling token.
///
/// - If neither token nor credential is provided, this method will do nothing.
/// - If only credential is provided, this method will try to fetch token from the server.
/// - If token is provided, this method will use the token directly.
///
/// # TODO
///
/// Support refreshing token while needed.
async fn authenticate(&self, req: &mut Request) -> Result<()> {
// Clone the token from lock without holding the lock for entire function.
let token = self.token.lock().await.clone();

if self.credential.is_none() && token.is_none() {
return Ok(());
}

// Use token if provided.
if let Some(token) = &token {
req.headers_mut().insert(
http::header::AUTHORIZATION,
format!("Bearer {token}").parse().map_err(|e| {
Error::new(
ErrorKind::DataInvalid,
"Invalid token received from catalog server!",
)
.with_source(e)
})?,
);
return Ok(());
}

async fn exchange_credential_for_token(&self) -> Result<String> {
// Credential must exist here.
let (client_id, client_secret) = self.credential.as_ref().ok_or_else(|| {
Error::new(
Expand Down Expand Up @@ -202,7 +171,61 @@ impl HttpClient {
})?;
Err(Error::from(e))
}?;
let token = auth_res.access_token;
Ok(auth_res.access_token)
}

/// Invalidate the current token without generating a new one. On the next request, the client
/// will attempt to generate a new token.
pub(crate) async fn invalidate_token(&self) -> Result<()> {
*self.token.lock().await = None;
Ok(())
}

/// Invalidate the current token and set a new one. Generates a new token before invalidating
/// the current token, meaning the old token will be used until this function acquires the lock
/// and overwrites the token.
///
/// If credential is invalid, or the request fails, this method will return an error and leave
/// the current token unchanged.
pub(crate) async fn regenerate_token(&self) -> Result<()> {
let new_token = self.exchange_credential_for_token().await?;
*self.token.lock().await = Some(new_token.clone());
Ok(())
}

/// Authenticate the request by filling token.
///
/// - If neither token nor credential is provided, this method will do nothing.
/// - If only credential is provided, this method will try to fetch token from the server.
/// - If token is provided, this method will use the token directly.
///
/// # TODO
///
/// Support refreshing token while needed.
async fn authenticate(&self, req: &mut Request) -> Result<()> {
// Clone the token from lock without holding the lock for entire function.
let token = self.token.lock().await.clone();

if self.credential.is_none() && token.is_none() {
return Ok(());
}

// Use token if provided.
if let Some(token) = &token {
req.headers_mut().insert(
http::header::AUTHORIZATION,
format!("Bearer {token}").parse().map_err(|e| {
Error::new(
ErrorKind::DataInvalid,
"Invalid token received from catalog server!",
)
.with_source(e)
})?,
);
return Ok(());
}

let token = self.exchange_credential_for_token().await?;
// Update token.
*self.token.lock().await = Some(token.clone());
// Insert token in request.
Expand Down
Loading