diff --git a/tensorboard/data/server/BUILD b/tensorboard/data/server/BUILD index 03ecd17195..e5d077c033 100644 --- a/tensorboard/data/server/BUILD +++ b/tensorboard/data/server/BUILD @@ -35,6 +35,7 @@ rust_library( "downsample.rs", "event_file.rs", "gcs.rs", + "gcs/auth.rs", "gcs/client.rs", "gcs/logdir.rs", "logdir.rs", diff --git a/tensorboard/data/server/cli/dynamic_logdir.rs b/tensorboard/data/server/cli/dynamic_logdir.rs index c275485b68..3045f53b08 100644 --- a/tensorboard/data/server/cli/dynamic_logdir.rs +++ b/tensorboard/data/server/cli/dynamic_logdir.rs @@ -43,6 +43,9 @@ impl DynLogdir { /// This succeeds unless the path represents a GCS logdir and no HTTP client can be opened. In /// case of failure, errors will be logged to the active logger. /// + /// This constructor is heavyweight; it may construct an HTTP client and read a GCS credentials + /// file from disk. + /// /// # Panics /// /// May panic in debug mode if called from a thread with an active Tokio runtime; see @@ -59,7 +62,7 @@ impl DynLogdir { let mut parts = gcs_path.splitn(2, '/'); let bucket = parts.next().unwrap().to_string(); // splitn always yields at least one element let prefix = parts.next().unwrap_or("").to_string(); - let client = match gcs::Client::new() { + let client = match gcs::Client::new(gcs::Credentials::from_disk()) { Err(e) => { error!("Could not open GCS connection: {}", e); return None; diff --git a/tensorboard/data/server/gcs.rs b/tensorboard/data/server/gcs.rs index 8c2b8c1ba7..492317a368 100644 --- a/tensorboard/data/server/gcs.rs +++ b/tensorboard/data/server/gcs.rs @@ -15,8 +15,10 @@ limitations under the License. //! Google Cloud Storage interop. +mod auth; mod client; mod logdir; +pub use auth::Credentials; pub use client::Client; pub use logdir::Logdir; diff --git a/tensorboard/data/server/gcs/auth.rs b/tensorboard/data/server/gcs/auth.rs new file mode 100644 index 0000000000..4d4a3f693a --- /dev/null +++ b/tensorboard/data/server/gcs/auth.rs @@ -0,0 +1,300 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +//! OAuth integration for GCS. +//! +//! Useful resources: +//! +//! - TensorFlow OAuth implementation: [`oauth_client.cc`], [`google_auth_provider.cc`] +//! - [RFC 6749]: The OAuth 2.0 Authorization Framework +//! - ["Refreshing Access Tokens"] OAuth guide +//! +//! [`oauth_client.cc`]: https://github.com/tensorflow/tensorflow/blob/r2.4/tensorflow/core/platform/cloud/oauth_client.cc +//! [`google_auth_provider.cc`]: https://github.com/tensorflow/tensorflow/blob/r2.4/tensorflow/core/platform/cloud/google_auth_provider.cc +//! [RFC 6749]: https://tools.ietf.org/html/rfc6749 +//! ["Refreshing Access Tokens"]: https://www.oauth.com/oauth2-servers/access-tokens/refreshing-access-tokens/ + +use log::{debug, info, warn}; +use serde::{Deserialize, Serialize}; +use std::fmt::{self, Debug}; +use std::fs::File; +use std::io::BufReader; +use std::path::PathBuf; +use std::sync::RwLock; +use std::time::{Duration, Instant}; + +use reqwest::blocking::{Client as HttpClient, RequestBuilder}; + +const OAUTH_REFRESH_TOKEN_ENDPOINT: &str = "https://www.googleapis.com/oauth2/v3/token"; + +/// A set of refreshable OAuth credentials plus a potentially active token. Use +/// [`authenticate`][Self::authenticate] to add an `Authorization` header to an outgoing request, +/// fetching a fresh access token if necessary. +/// +/// A `TokenStore` may be freely shared among threads; it synchronizes internally if needed. +pub struct TokenStore { + creds: Credentials, + token: RwLock>, +} + +impl TokenStore { + /// Creates a new token store from the given credentials. + /// + /// This operation is cheap and does not actually fetch any access tokens. + pub fn new(creds: Credentials) -> Self { + Self { + creds, + token: RwLock::new(None), + } + } +} + +/// An access token that's valid until a particular point in time. +#[derive(Debug)] +struct BoundedToken { + access_token: AccessToken, + expires: Instant, +} + +impl BoundedToken { + /// Tests whether this token will still be valid at the given instant. + fn valid_at(&self, when: Instant) -> bool { + when < self.expires + } +} + +/// Private access token module to prevent accidental leaking of tokens into logs, etc. An access +/// token can be attached to a request, but cannot be directly extracted. +mod access_token { + use super::*; + #[derive(Deserialize)] + pub struct AccessToken(String); + impl AccessToken { + /// Attaches this token to an outgoing request. + pub fn authenticate(&self, rb: RequestBuilder) -> RequestBuilder { + rb.bearer_auth(&self.0) + } + } + impl Debug for AccessToken { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("AccessToken") + .field(&format_args!("_")) + .finish() + } + } +} +use access_token::AccessToken; + +impl TokenStore { + /// Attempts to attach an access token to the given outgoing request. + /// + /// A cached token will be reused if it is expected to be valid for at least `lifetime`. + /// Otherwise, a new token will be fetched and stored. + pub fn authenticate( + &self, + rb: RequestBuilder, + http: &HttpClient, + lifetime: Duration, + ) -> RequestBuilder { + if self.creds.anonymous() { + return rb; + } + let token = self.token.read().expect("failed to read auth token"); + if let Some(t) = BoundedToken::unwrap_if_valid_for(&*token, lifetime) { + return t.authenticate(rb); + } + drop(token); + let mut token = self.token.write().expect("failed to write auth token"); + // Check again: may have just been written by a different client, in which case no need to + // re-fetch. + if let Some(t) = BoundedToken::unwrap_if_valid_for(&*token, lifetime) { + return t.authenticate(rb); + } + // If we get here, we need a fresh token. + *token = self.creds.fetch(http); + if let Some(ref t) = *token { + debug!( + "Obtained new access token, live for the next {:?}", + t.expires.saturating_duration_since(Instant::now()) + ); + t.access_token.authenticate(rb) + } else { + rb + } + } +} + +impl BoundedToken { + /// Checks whether `token` represents a token that will still be valid for at least the given + /// `lifetime`, and if so returns a reference to the inner access token. + fn unwrap_if_valid_for(token: &Option, lifetime: Duration) -> Option<&AccessToken> { + match token.as_ref() { + Some(t) if t.valid_at(Instant::now() + lifetime) => Some(&t.access_token), + _ => None, + } + } +} + +/// The user's persistent credentials, if any. This represents all the information needed to +/// request access tokens. +pub enum Credentials { + Anonymous, + RefreshToken(RefreshToken), +} +// public wrapper struct to hide private implementation details +pub struct RefreshToken(RefreshTokenCreds); + +impl Credentials { + /// Reads credentials from disk. + /// + /// The path is taken from the `GOOGLE_APPLICATION_CREDENTIALS` environment variable if set, + /// else `"${XDG_CONFIG_HOME-${HOME}/.config}/gcloud/application_default_credentials.json"`. If + /// the credentials file is not found or not readable, anonymous credentials will be used. + pub fn from_disk() -> Self { + let creds_file = match Self::credentials_file() { + None => return Credentials::Anonymous, + Some(f) => f, + }; + let reader = match File::open(&creds_file).map(BufReader::new) { + Err(e) => { + warn!( + "Failed to open GCS credentials file {:?}; will use anonymous credentials: {}", + creds_file, e + ); + return Credentials::Anonymous; + } + Ok(f) => f, + }; + let creds: RefreshTokenCreds = match serde_json::from_reader(reader) { + Err(e) => { + warn!( + "Failed to read GCS credentials file {:?}; will use anonymous credentials: {}", + creds_file, e + ); + return Credentials::Anonymous; + } + Ok(creds) => creds, + }; + info!( + "Using refresh token GCS creds from {}", + creds_file.display() + ); + Credentials::RefreshToken(RefreshToken(creds)) + } + + /// Determines the file on disk from which credentials might be read, if any. + fn credentials_file() -> Option { + if let Some(p) = std::env::var_os("GOOGLE_APPLICATION_CREDENTIALS") { + return Some(p.into()); + } + let base_config_dir = std::env::var_os("XDG_CONFIG_HOME") + .map(PathBuf::from) + .or_else(|| std::env::var_os("HOME").map(|p| PathBuf::from(p).join(".config"))); + if let Some(mut path) = base_config_dir { + path.extend(&["gcloud", "application_default_credentials.json"]); + if path.is_file() { + return Some(path); + } + }; + None + } + + /// Tests whether this credential is inherently anonymous. If this returns `true`, then + /// [`Self::fetch`] will always return `None`. + /// + /// This exists as an optimization so that a [`TokenStore`] doesn't need to check locks all the + /// time when the credential is anonymous, anyway. + fn anonymous(&self) -> bool { + matches!(self, Credentials::Anonymous) + } + + /// Attempts to fetch a fresh access token with these credentials. + fn fetch(&self, http: &HttpClient) -> Option { + match self { + Credentials::Anonymous => None, + Credentials::RefreshToken(RefreshToken(creds)) => match creds.fetch(http) { + Ok(t) => Some(t), + Err(e) => { + warn!("GCS authentication failed: {}", e); + None + } + }, + } + } +} + +/// Persistent credentials in the form of an OAuth refresh token. Can be posted to an OAuth token +/// endpoint to obtain a short-lived access token. +#[derive(Deserialize)] +struct RefreshTokenCreds { + client_id: String, + client_secret: String, + refresh_token: String, +} +impl Debug for RefreshTokenCreds { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RefreshTokenCreds") + .field("client_id", &self.client_id) + .field("client_secret", &format_args!("")) + .field("refresh_token", &format_args!("")) + .finish() + } +} + +/// POST body to [`OAUTH_REFRESH_TOKEN_ENDPOINT`]. +#[derive(Serialize)] +struct RefreshTokenRequest<'a> { + client_id: &'a str, + client_secret: &'a str, + refresh_token: &'a str, + grant_type: &'a str, +} +/// Response body from [`OAUTH_REFRESH_TOKEN_ENDPOINT`]. +#[derive(Deserialize)] +struct OauthTokenResponse { + access_token: AccessToken, + #[serde(default = "OauthTokenResponse::default_expires_in")] // optional per OAuth spec + expires_in: u64, // seconds +} +impl OauthTokenResponse { + fn default_expires_in() -> u64 { + let v = 3599; // standard response from Google OAuth servers + warn!("OAuth response did not set `expires_in`; assuming {}", v); + v + } +} + +impl RefreshTokenCreds { + /// Fetches a new access token from this refresh token credential. + pub fn fetch(&self, http: &HttpClient) -> reqwest::Result { + debug!("Fetching access token from refresh token"); + let req = RefreshTokenRequest { + client_id: &self.client_id, + client_secret: &self.client_secret, + refresh_token: &self.refresh_token, + grant_type: "refresh_token", + }; + let res: OauthTokenResponse = http + .post(OAUTH_REFRESH_TOKEN_ENDPOINT) + .json(&req) + .send()? + .error_for_status()? + .json()?; + Ok(BoundedToken { + access_token: res.access_token, + expires: Instant::now() + Duration::from_secs(res.expires_in), + }) + } +} diff --git a/tensorboard/data/server/gcs/client.rs b/tensorboard/data/server/gcs/client.rs index 5269e885a8..da0c958ae9 100644 --- a/tensorboard/data/server/gcs/client.rs +++ b/tensorboard/data/server/gcs/client.rs @@ -16,32 +16,44 @@ limitations under the License. //! Client for listing and reading GCS files. use log::debug; -use reqwest::{blocking::Client as HttpClient, StatusCode, Url}; +use reqwest::{ + blocking::{Client as HttpClient, RequestBuilder, Response}, + StatusCode, Url, +}; use std::ops::RangeInclusive; +use std::sync::Arc; +use std::time::Duration; + +use super::auth::{Credentials, TokenStore}; /// Base URL for direct object reads. const STORAGE_BASE: &str = "https://storage.googleapis.com"; /// Base URL for JSON API access. const API_BASE: &str = "https://www.googleapis.com/storage/v1"; +/// Refresh access tokens once their remaining lifetime is shorter than this threshold. +const TOKEN_EXPIRATION_MARGIN: Duration = Duration::from_secs(60); + /// GCS client. /// -/// Cloning a GCS client is cheap and shares the underlying connection pool, as with a -/// [`reqwest::Client`]. +/// Cloning a GCS client is cheap and shares the underlying credential store and connection pool, +/// as with a [`reqwest::Client`]. #[derive(Clone)] pub struct Client { + token_store: Arc, http: HttpClient, } impl Client { - /// Creates a new GCS client. + /// Creates a new GCS client with the given credentials. /// /// May fail if constructing the underlying HTTP client fails. - pub fn new() -> reqwest::Result { + pub fn new(creds: Credentials) -> reqwest::Result { let http = HttpClient::builder() .user_agent(format!("tensorboard-data-server/{}", crate::VERSION)) .build()?; - Ok(Self { http }) + let token_store = Arc::new(TokenStore::new(creds)); + Ok(Self { http, token_store }) } } @@ -63,6 +75,12 @@ struct ListResponseItem { } impl Client { + fn send_authenticated(&self, rb: RequestBuilder) -> reqwest::Result { + self.token_store + .authenticate(rb, &self.http, TOKEN_EXPIRATION_MARGIN) + .send() + } + /// Lists all objects in a bucket matching the given prefix. pub fn list(&self, bucket: &str, prefix: &str) -> reqwest::Result> { let mut base_url = Url::parse(API_BASE).unwrap(); @@ -86,7 +104,10 @@ impl Client { "Listing page {} of bucket {:?} (prefix={:?})", page, bucket, prefix ); - let res: ListResponse = self.http.get(url).send()?.error_for_status()?.json()?; + let res: ListResponse = self + .send_authenticated(self.http.get(url))? + .error_for_status()? + .json()?; results.extend(res.items.into_iter().map(|i| i.name)); if res.next_page_token.is_none() { break; @@ -111,7 +132,7 @@ impl Client { // With "Range: bytes=a-b", if `b >= 2**63` then GCS ignores the range entirely. let max_max = (1 << 63) - 1; let range = format!("bytes={}-{}", range.start(), range.end().min(&max_max)); - let res = self.http.get(url).header("Range", range).send()?; + let res = self.send_authenticated(self.http.get(url).header("Range", range))?; if res.status() == StatusCode::RANGE_NOT_SATISFIABLE { return Ok(Vec::new()); } diff --git a/tensorboard/data/server/gcs/gsutil.rs b/tensorboard/data/server/gcs/gsutil.rs index 41bede5c4a..fa197402f7 100644 --- a/tensorboard/data/server/gcs/gsutil.rs +++ b/tensorboard/data/server/gcs/gsutil.rs @@ -60,7 +60,7 @@ fn main() { let opts: Opts = Opts::parse(); init_logging(&opts); - let client = gcs::Client::new().unwrap(); + let client = gcs::Client::new(gcs::Credentials::from_disk()).unwrap(); match opts.subcmd { Subcommand::Ls(opts) => { log::info!("ENTER gcs::Client::list");