From ceb2430d458166d30b38224c1d09807b9b6d38aa Mon Sep 17 00:00:00 2001 From: Aviram Hassan Date: Mon, 29 Jan 2024 13:13:49 +0200 Subject: [PATCH 1/4] refactored auth, tried not breaking public API while adding re-auth [token gen] mechanism when needed Signed-off-by: Aviram Hassan --- src/client.rs | 144 ++++++++++++++++++++++++++++++--------------- src/token_cache.rs | 20 +++---- 2 files changed, 105 insertions(+), 59 deletions(-) diff --git a/src/client.rs b/src/client.rs index d30034fb..62de530d 100644 --- a/src/client.rs +++ b/src/client.rs @@ -33,6 +33,7 @@ use std::collections::HashMap; use std::convert::TryFrom; use std::sync::Arc; use tokio::io::{AsyncWrite, AsyncWriteExt}; +use tokio::sync::RwLock; use tracing::{debug, trace, warn}; const MIME_TYPES_DISTRIBUTION_MANIFEST: &[&str] = &[ @@ -205,6 +206,8 @@ impl TryFrom for ConfigFile { #[derive(Clone)] pub struct Client { config: Arc, + // Registry -> RegistryAuth + auth_store: Arc>>, tokens: TokenCache, client: reqwest::Client, push_chunk_size: usize, @@ -213,9 +216,10 @@ pub struct Client { impl Default for Client { fn default() -> Self { Self { - config: Arc::new(ClientConfig::default()), - tokens: TokenCache::new(), - client: reqwest::Client::new(), + config: Arc::default(), + auth_store: Arc::default(), + tokens: TokenCache::default(), + client: reqwest::Client::default(), push_chunk_size: PUSH_CHUNK_MAX_SIZE, } } @@ -257,9 +261,9 @@ impl TryFrom for Client { Ok(Self { config: Arc::new(config), - tokens: TokenCache::new(), client: client_builder.build()?, push_chunk_size: PUSH_CHUNK_MAX_SIZE, + ..Default::default() }) } } @@ -271,10 +275,8 @@ impl Client { warn!("Cannot create OCI client from config: {:?}", err); warn!("Creating client with default configuration"); Self { - config: Arc::new(ClientConfig::default()), - tokens: TokenCache::new(), - client: reqwest::Client::new(), push_chunk_size: PUSH_CHUNK_MAX_SIZE, + ..Default::default() } }) } @@ -284,6 +286,41 @@ impl Client { Self::new(config_source.client_config()) } + async fn store_auth(&self, registry: &str, auth: RegistryAuth) { + self.auth_store + .write() + .await + .insert(registry.to_string(), auth); + } + + async fn is_stored_auth(&self, registry: &str) -> bool { + self.auth_store.read().await.contains_key(registry) + } + + async fn store_auth_if_needed(&self, registry: &str, auth: &RegistryAuth) { + if !self.is_stored_auth(registry).await { + self.store_auth(registry, auth.clone()).await; + } + } + + /// Checks if we got a token, if we don't - create it and store it in cache. + async fn get_auth_token( + &self, + reference: &Reference, + op: RegistryOperation, + ) -> Option { + let registry = reference.resolve_registry(); + let auth = self.auth_store.read().await.get(registry)?.clone(); + match self.tokens.get(reference, op).await { + Some(token) => Some(token), + None => { + let token = self._auth(reference, &auth, op).await.ok()??; + self.tokens.insert(reference, op, token.clone()).await; + Some(token) + } + } + } + /// Fetches the available Tags for the given Reference /// /// The client will check if it's already been authenticated and if @@ -298,9 +335,8 @@ impl Client { let op = RegistryOperation::Pull; let url = self.to_list_tags_url(image); - if !self.tokens.contains_key(image, op).await { - self.auth(image, auth, op).await?; - } + self.store_auth_if_needed(image.resolve_registry(), auth) + .await; let request = self.client.get(&url); let request = if let Some(num) = n { @@ -342,10 +378,8 @@ impl Client { accepted_media_types: Vec<&str>, ) -> Result { debug!("Pulling image: {:?}", image); - let op = RegistryOperation::Pull; - if !self.tokens.contains_key(image, op).await { - self.auth(image, auth, op).await?; - } + self.store_auth_if_needed(image.resolve_registry(), auth) + .await; let (manifest, digest, config) = self._pull_manifest_and_config(image).await?; @@ -400,10 +434,8 @@ impl Client { manifest: Option, ) -> Result { debug!("Pushing image: {:?}", image_ref); - let op = RegistryOperation::Push; - if !self.tokens.contains_key(image_ref, op).await { - self.auth(image_ref, auth, op).await?; - } + self.store_auth_if_needed(image_ref.resolve_registry(), auth) + .await; let manifest: OciImageManifest = match manifest { Some(m) => m, @@ -502,6 +534,36 @@ impl Client { authentication: &RegistryAuth, operation: RegistryOperation, ) -> Result> { + // preserve old caching behavior + match self._auth(image, authentication, operation).await { + Ok(Some(RegistryTokenType::Bearer(token))) => { + self.tokens + .insert(image, operation, RegistryTokenType::Bearer(token.clone())) + .await; + Ok(Some(token.token().to_string())) + } + Ok(Some(RegistryTokenType::Basic(username, password))) => { + self.tokens + .insert( + image, + operation, + RegistryTokenType::Basic(username, password), + ) + .await; + Ok(None) + } + Ok(None) => Ok(None), + Err(e) => Err(e), + } + } + + /// Internal auth that retrieves token. + async fn _auth( + &self, + image: &Reference, + authentication: &RegistryAuth, + operation: RegistryOperation, + ) -> Result> { debug!("Authorizing for image: {:?}", image); // The version request will tell us where to go. let url = format!( @@ -521,13 +583,10 @@ impl Client { Err(e) => { debug!(error = ?e, "Falling back to HTTP Basic Auth"); if let RegistryAuth::Basic(username, password) = authentication { - self.tokens - .insert( - image, - operation, - RegistryTokenType::Basic(username.to_string(), password.to_string()), - ) - .await; + return Ok(Some(RegistryTokenType::Basic( + username.to_string(), + password.to_string(), + ))); } return Ok(None); } @@ -566,11 +625,7 @@ impl Client { let token: RegistryToken = serde_json::from_str(&text) .map_err(|e| OciDistributionError::RegistryTokenDecodeError(e.to_string()))?; debug!("Successfully authorized for image '{:?}'", image); - let oauth_token = token.token().to_string(); - self.tokens - .insert(image, operation, RegistryTokenType::Bearer(token)) - .await; - Ok(Some(oauth_token)) + Ok(Some(RegistryTokenType::Bearer(token))) } _ => { let reason = auth_res.text().await?; @@ -593,10 +648,8 @@ impl Client { image: &Reference, auth: &RegistryAuth, ) -> Result { - let op = RegistryOperation::Pull; - if !self.tokens.contains_key(image, op).await { - self.auth(image, auth, op).await?; - } + self.store_auth_if_needed(image.resolve_registry(), auth) + .await; let url = self.to_v2_manifest_url(image); debug!("HEAD image manifest from {}", url); @@ -670,10 +723,8 @@ impl Client { image: &Reference, auth: &RegistryAuth, ) -> Result<(OciImageManifest, String)> { - let op = RegistryOperation::Pull; - if !self.tokens.contains_key(image, op).await { - self.auth(image, auth, op).await?; - } + self.store_auth_if_needed(image.resolve_registry(), auth) + .await; self._pull_image_manifest(image).await } @@ -690,10 +741,8 @@ impl Client { image: &Reference, auth: &RegistryAuth, ) -> Result<(OciManifest, String)> { - let op = RegistryOperation::Pull; - if !self.tokens.contains_key(image, op).await { - self.auth(image, auth, op).await?; - } + self.store_auth_if_needed(image.resolve_registry(), auth) + .await; self._pull_manifest(image).await } @@ -811,10 +860,8 @@ impl Client { image: &Reference, auth: &RegistryAuth, ) -> Result<(OciImageManifest, String, String)> { - let op = RegistryOperation::Pull; - if !self.tokens.contains_key(image, op).await { - self.auth(image, auth, op).await?; - } + self.store_auth_if_needed(image.resolve_registry(), auth) + .await; self._pull_manifest_and_config(image) .await @@ -856,7 +903,8 @@ impl Client { auth: &RegistryAuth, manifest: OciImageIndex, ) -> Result { - self.auth(reference, auth, RegistryOperation::Push).await?; + self.store_auth_if_needed(reference.resolve_registry(), auth) + .await; self.push_manifest(reference, &OciManifest::ImageIndex(manifest)) .await } @@ -1368,7 +1416,7 @@ impl<'a> RequestBuilderWrapper<'a> { ) -> Result { let mut headers = HeaderMap::new(); - if let Some(token) = self.client.tokens.get(image, op).await { + if let Some(token) = self.client.get_auth_token(image, op).await { match token { RegistryTokenType::Bearer(token) => { debug!("Using bearer token authentication."); diff --git a/src/token_cache.rs b/src/token_cache.rs index 2b682801..a28d5a25 100644 --- a/src/token_cache.rs +++ b/src/token_cache.rs @@ -59,7 +59,15 @@ pub enum RegistryOperation { Pull, } -type CacheType = BTreeMap<(String, String, RegistryOperation), (RegistryTokenType, u64)>; +// Types to allow better naming +type Registry = String; +type Repository = String; +type TokenCacheKey = (Registry, Repository, RegistryOperation); +type TokenExpiration = u64; +type TokenCacheValue = (RegistryTokenType, TokenExpiration); + +// (registry, repository, scope) -> (token, expiration) +type CacheType = BTreeMap; #[derive(Default, Clone)] pub(crate) struct TokenCache { @@ -68,12 +76,6 @@ pub(crate) struct TokenCache { } impl TokenCache { - pub(crate) fn new() -> Self { - TokenCache { - tokens: Arc::new(RwLock::new(BTreeMap::new())), - } - } - pub(crate) async fn insert( &self, reference: &Reference, @@ -158,8 +160,4 @@ impl TokenCache { } } } - - pub(crate) async fn contains_key(&self, reference: &Reference, op: RegistryOperation) -> bool { - self.get(reference, op).await.is_some() - } } From a2882951889901c21328b3560ccd6604c34cb7af Mon Sep 17 00:00:00 2001 From: Aviram Hassan Date: Fri, 2 Feb 2024 15:25:33 +0200 Subject: [PATCH 2/4] refactor further token cache Signed-off-by: Aviram Hassan --- src/token_cache.rs | 50 +++++++++++++++++++++++++++++++--------------- 1 file changed, 34 insertions(+), 16 deletions(-) diff --git a/src/token_cache.rs b/src/token_cache.rs index a28d5a25..237102ee 100644 --- a/src/token_cache.rs +++ b/src/token_cache.rs @@ -62,9 +62,20 @@ pub enum RegistryOperation { // Types to allow better naming type Registry = String; type Repository = String; -type TokenCacheKey = (Registry, Repository, RegistryOperation); + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +struct TokenCacheKey { + registry: Registry, + repository: Repository, + operation: RegistryOperation, +} + type TokenExpiration = u64; -type TokenCacheValue = (RegistryTokenType, TokenExpiration); + +struct TokenCacheValue { + token: RegistryTokenType, + expiration: TokenExpiration, +} // (registry, repository, scope) -> (token, expiration) type CacheType = BTreeMap; @@ -121,10 +132,14 @@ impl TokenCache { let registry = reference.resolve_registry().to_string(); let repository = reference.repository().to_string(); debug!(%registry, %repository, ?op, %expiration, "Inserting token"); - self.tokens - .write() - .await - .insert((registry, repository, op), (token, expiration)); + self.tokens.write().await.insert( + TokenCacheKey { + registry, + repository, + operation: op, + }, + TokenCacheValue { token, expiration }, + ); } pub(crate) async fn get( @@ -134,28 +149,31 @@ impl TokenCache { ) -> Option { let registry = reference.resolve_registry().to_string(); let repository = reference.repository().to_string(); - match self - .tokens - .read() - .await - .get(&(registry.clone(), repository.clone(), op)) - { - Some((ref token, expiration)) => { + let key = TokenCacheKey { + registry, + repository, + operation: op, + }; + match self.tokens.read().await.get(&key) { + Some(TokenCacheValue { + ref token, + expiration, + }) => { let now = SystemTime::now(); let epoch = now .duration_since(UNIX_EPOCH) .expect("Time went backwards") .as_secs(); if epoch > *expiration { - debug!(%registry, %repository, ?op, %expiration, miss=false, expired=true, "Fetching token"); + debug!(?key, %expiration, miss=false, expired=true, "Fetching token"); None } else { - debug!(%registry, %repository, ?op, %expiration, miss=false, expired=false, "Fetching token"); + debug!(?key, %expiration, miss=false, expired=false, "Fetching token"); Some(token.clone()) } } None => { - debug!(%registry, %repository, ?op, miss=true, "Fetching token"); + debug!(?key, miss = true, "Fetching token"); None } } From 3574da410d5602391ce53d8956fa6ec8b3d25da7 Mon Sep 17 00:00:00 2001 From: Aviram Hassan Date: Tue, 6 Feb 2024 12:22:48 +0200 Subject: [PATCH 3/4] fix tests Signed-off-by: Aviram Hassan --- examples/get-manifest/main.rs | 2 +- src/client.rs | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/examples/get-manifest/main.rs b/examples/get-manifest/main.rs index 5398f5cb..0fc2edd1 100644 --- a/examples/get-manifest/main.rs +++ b/examples/get-manifest/main.rs @@ -85,7 +85,7 @@ pub async fn main() { let auth = build_auth(&reference, &cli); let client_config = build_client_config(&cli); - let mut client = Client::new(client_config); + let client = Client::new(client_config); let (manifest, _) = client .pull_manifest(&reference, &auth) diff --git a/src/client.rs b/src/client.rs index 62de530d..1673baae 100644 --- a/src/client.rs +++ b/src/client.rs @@ -534,6 +534,8 @@ impl Client { authentication: &RegistryAuth, operation: RegistryOperation, ) -> Result> { + self.store_auth_if_needed(image.resolve_registry(), authentication) + .await; // preserve old caching behavior match self._auth(image, authentication, operation).await { Ok(Some(RegistryTokenType::Bearer(token))) => { @@ -1802,6 +1804,14 @@ mod test { .as_str() .to_string(); + // we have to have it in the stored auth so we'll get to the token cache check. + client + .store_auth( + &Reference::try_from(HELLO_IMAGE_TAG)?.resolve_registry(), + RegistryAuth::Anonymous, + ) + .await; + client .tokens .insert( From f8c23d2ae4ce5ab08a1e5bf629a3c3664fa83551 Mon Sep 17 00:00:00 2001 From: Aviram Hassan Date: Wed, 7 Feb 2024 08:30:21 +0200 Subject: [PATCH 4/4] cr Signed-off-by: Aviram Hassan --- src/token_cache.rs | 23 +++++++---------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/src/token_cache.rs b/src/token_cache.rs index 237102ee..b4f153d1 100644 --- a/src/token_cache.rs +++ b/src/token_cache.rs @@ -59,31 +59,22 @@ pub enum RegistryOperation { Pull, } -// Types to allow better naming -type Registry = String; -type Repository = String; - #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] struct TokenCacheKey { - registry: Registry, - repository: Repository, + registry: String, + repository: String, operation: RegistryOperation, } -type TokenExpiration = u64; - struct TokenCacheValue { token: RegistryTokenType, - expiration: TokenExpiration, + expiration: u64, } -// (registry, repository, scope) -> (token, expiration) -type CacheType = BTreeMap; - #[derive(Default, Clone)] pub(crate) struct TokenCache { // (registry, repository, scope) -> (token, expiration) - tokens: Arc>, + tokens: Arc>>, } impl TokenCache { @@ -165,15 +156,15 @@ impl TokenCache { .expect("Time went backwards") .as_secs(); if epoch > *expiration { - debug!(?key, %expiration, miss=false, expired=true, "Fetching token"); + debug!(%key.registry, %key.repository, ?key.operation, %expiration, miss=false, expired=true, "Fetching token"); None } else { - debug!(?key, %expiration, miss=false, expired=false, "Fetching token"); + debug!(%key.registry, %key.repository, ?key.operation, %expiration, miss=false, expired=false, "Fetching token"); Some(token.clone()) } } None => { - debug!(?key, miss = true, "Fetching token"); + debug!(%key.registry, %key.repository, ?key.operation, miss = true, "Fetching token"); None } }