diff --git a/crates/catalog/rest/src/catalog.rs b/crates/catalog/rest/src/catalog.rs index 3d1f7cf83c..a3c24db293 100644 --- a/crates/catalog/rest/src/catalog.rs +++ b/crates/catalog/rest/src/catalog.rs @@ -318,6 +318,22 @@ impl RestCatalog { Ok(file_io) } + + /// Invalidate the current token without generating a new one. On the next request, the client + /// will attempt to generate a new token. + pub async fn invalidate_token(&self) -> Result<()> { + self.context().await?.client.invalidate_token().await + } + + /// Invalidate the current token and set a new one. Generates a new token before invalidating + /// the current token, meaning the old token will be used until this function acquires the lock + /// and overwrites the token. + /// + /// If credential is invalid, or the request fails, this method will return an error and leave + /// the current token unchanged. + pub async fn regenerate_token(&self) -> Result<()> { + self.context().await?.client.regenerate_token().await + } } /// All requests and expected responses are derived from the REST catalog API spec: @@ -860,21 +876,27 @@ mod tests { } async fn create_oauth_mock(server: &mut ServerGuard) -> Mock { - create_oauth_mock_with_path(server, "/v1/oauth/tokens").await + create_oauth_mock_with_path(server, "/v1/oauth/tokens", "ey000000000000", 200).await } - async fn create_oauth_mock_with_path(server: &mut ServerGuard, path: &str) -> Mock { - server - .mock("POST", path) - .with_status(200) - .with_body( - r#"{ - "access_token": "ey000000000000", + async fn create_oauth_mock_with_path( + server: &mut ServerGuard, + path: &str, + token: &str, + status: usize, + ) -> Mock { + let body = format!( + r#"{{ + "access_token": "{token}", "token_type": "Bearer", "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", "expires_in": 86400 - }"#, - ) + }}"# + ); + server + .mock("POST", path) + .with_status(status) + .with_body(body) .expect(1) .create_async() .await @@ -949,6 +971,129 @@ mod tests { assert_eq!(token, Some("ey000000000000".to_string())); } + #[tokio::test] + async fn test_invalidate_token() { + let mut server = Server::new_async().await; + let oauth_mock = create_oauth_mock(&mut server).await; + let config_mock = create_config_mock(&mut server).await; + + let mut props = HashMap::new(); + props.insert("credential".to_string(), "client1:secret1".to_string()); + + let catalog = RestCatalog::new( + RestCatalogConfig::builder() + .uri(server.url()) + .props(props) + .build(), + ); + + let token = catalog.context().await.unwrap().client.token().await; + oauth_mock.assert_async().await; + config_mock.assert_async().await; + assert_eq!(token, Some("ey000000000000".to_string())); + + let oauth_mock = + create_oauth_mock_with_path(&mut server, "/v1/oauth/tokens", "ey000000000001", 200) + .await; + catalog.invalidate_token().await.unwrap(); + let token = catalog.context().await.unwrap().client.token().await; + oauth_mock.assert_async().await; + assert_eq!(token, Some("ey000000000001".to_string())); + } + + #[tokio::test] + async fn test_invalidate_token_failing_request() { + let mut server = Server::new_async().await; + let oauth_mock = create_oauth_mock(&mut server).await; + let config_mock = create_config_mock(&mut server).await; + + let mut props = HashMap::new(); + props.insert("credential".to_string(), "client1:secret1".to_string()); + + let catalog = RestCatalog::new( + RestCatalogConfig::builder() + .uri(server.url()) + .props(props) + .build(), + ); + + let token = catalog.context().await.unwrap().client.token().await; + oauth_mock.assert_async().await; + config_mock.assert_async().await; + assert_eq!(token, Some("ey000000000000".to_string())); + + let oauth_mock = + create_oauth_mock_with_path(&mut server, "/v1/oauth/tokens", "ey000000000001", 500) + .await; + catalog.invalidate_token().await.unwrap(); + let token = catalog.context().await.unwrap().client.token().await; + oauth_mock.assert_async().await; + assert_eq!(token, None); + } + + #[tokio::test] + async fn test_regenerate_token() { + let mut server = Server::new_async().await; + let oauth_mock = create_oauth_mock(&mut server).await; + let config_mock = create_config_mock(&mut server).await; + + let mut props = HashMap::new(); + props.insert("credential".to_string(), "client1:secret1".to_string()); + + let catalog = RestCatalog::new( + RestCatalogConfig::builder() + .uri(server.url()) + .props(props) + .build(), + ); + + let token = catalog.context().await.unwrap().client.token().await; + oauth_mock.assert_async().await; + config_mock.assert_async().await; + assert_eq!(token, Some("ey000000000000".to_string())); + + let oauth_mock = + create_oauth_mock_with_path(&mut server, "/v1/oauth/tokens", "ey000000000001", 200) + .await; + catalog.regenerate_token().await.unwrap(); + oauth_mock.assert_async().await; + let token = catalog.context().await.unwrap().client.token().await; + assert_eq!(token, Some("ey000000000001".to_string())); + } + + #[tokio::test] + async fn test_regenerate_token_failing_request() { + let mut server = Server::new_async().await; + let oauth_mock = create_oauth_mock(&mut server).await; + let config_mock = create_config_mock(&mut server).await; + + let mut props = HashMap::new(); + props.insert("credential".to_string(), "client1:secret1".to_string()); + + let catalog = RestCatalog::new( + RestCatalogConfig::builder() + .uri(server.url()) + .props(props) + .build(), + ); + + let token = catalog.context().await.unwrap().client.token().await; + oauth_mock.assert_async().await; + config_mock.assert_async().await; + assert_eq!(token, Some("ey000000000000".to_string())); + + let oauth_mock = + create_oauth_mock_with_path(&mut server, "/v1/oauth/tokens", "ey000000000001", 500) + .await; + let invalidate_result = catalog.regenerate_token().await; + assert!(invalidate_result.is_err()); + oauth_mock.assert_async().await; + let token = catalog.context().await.unwrap().client.token().await; + + // original token is left intact + assert_eq!(token, Some("ey000000000000".to_string())); + } + #[tokio::test] async fn test_http_headers() { let server = Server::new_async().await; @@ -1026,7 +1171,9 @@ mod tests { let mut auth_server = Server::new_async().await; let auth_server_path = "/some/path"; - let oauth_mock = create_oauth_mock_with_path(&mut auth_server, auth_server_path).await; + let oauth_mock = + create_oauth_mock_with_path(&mut auth_server, auth_server_path, "ey000000000000", 200) + .await; let mut props = HashMap::new(); props.insert("credential".to_string(), "client1:secret1".to_string()); diff --git a/crates/catalog/rest/src/client.rs b/crates/catalog/rest/src/client.rs index fe311a71a7..0b9af9b5bd 100644 --- a/crates/catalog/rest/src/client.rs +++ b/crates/catalog/rest/src/client.rs @@ -106,38 +106,7 @@ impl HttpClient { self.token.lock().await.clone() } - /// Authenticate the request by filling token. - /// - /// - If neither token nor credential is provided, this method will do nothing. - /// - If only credential is provided, this method will try to fetch token from the server. - /// - If token is provided, this method will use the token directly. - /// - /// # TODO - /// - /// Support refreshing token while needed. - async fn authenticate(&self, req: &mut Request) -> Result<()> { - // Clone the token from lock without holding the lock for entire function. - let token = self.token.lock().await.clone(); - - if self.credential.is_none() && token.is_none() { - return Ok(()); - } - - // Use token if provided. - if let Some(token) = &token { - req.headers_mut().insert( - http::header::AUTHORIZATION, - format!("Bearer {token}").parse().map_err(|e| { - Error::new( - ErrorKind::DataInvalid, - "Invalid token received from catalog server!", - ) - .with_source(e) - })?, - ); - return Ok(()); - } - + async fn exchange_credential_for_token(&self) -> Result { // Credential must exist here. let (client_id, client_secret) = self.credential.as_ref().ok_or_else(|| { Error::new( @@ -202,7 +171,61 @@ impl HttpClient { })?; Err(Error::from(e)) }?; - let token = auth_res.access_token; + Ok(auth_res.access_token) + } + + /// Invalidate the current token without generating a new one. On the next request, the client + /// will attempt to generate a new token. + pub(crate) async fn invalidate_token(&self) -> Result<()> { + *self.token.lock().await = None; + Ok(()) + } + + /// Invalidate the current token and set a new one. Generates a new token before invalidating + /// the current token, meaning the old token will be used until this function acquires the lock + /// and overwrites the token. + /// + /// If credential is invalid, or the request fails, this method will return an error and leave + /// the current token unchanged. + pub(crate) async fn regenerate_token(&self) -> Result<()> { + let new_token = self.exchange_credential_for_token().await?; + *self.token.lock().await = Some(new_token.clone()); + Ok(()) + } + + /// Authenticate the request by filling token. + /// + /// - If neither token nor credential is provided, this method will do nothing. + /// - If only credential is provided, this method will try to fetch token from the server. + /// - If token is provided, this method will use the token directly. + /// + /// # TODO + /// + /// Support refreshing token while needed. + async fn authenticate(&self, req: &mut Request) -> Result<()> { + // Clone the token from lock without holding the lock for entire function. + let token = self.token.lock().await.clone(); + + if self.credential.is_none() && token.is_none() { + return Ok(()); + } + + // Use token if provided. + if let Some(token) = &token { + req.headers_mut().insert( + http::header::AUTHORIZATION, + format!("Bearer {token}").parse().map_err(|e| { + Error::new( + ErrorKind::DataInvalid, + "Invalid token received from catalog server!", + ) + .with_source(e) + })?, + ); + return Ok(()); + } + + let token = self.exchange_credential_for_token().await?; // Update token. *self.token.lock().await = Some(token.clone()); // Insert token in request.