Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
948b724
rust: extract `Logdir` trait for other filesystems
wchargin Feb 1, 2021
3ab4fac
[rust-logdir-trait: update patch]
wchargin Feb 2, 2021
358687e
[rust-logdir-trait: update diffbase]
wchargin Feb 2, 2021
d77430a
[rust-logdir-trait: update patch]
wchargin Feb 2, 2021
cf6869c
[rust-logdir-trait: update patch]
wchargin Feb 2, 2021
971c2c2
[rust-logdir-trait: update patch]
wchargin Feb 2, 2021
428aaab
[rust-logdir-trait: update diffbase]
wchargin Feb 3, 2021
b95ea29
[rust-logdir-trait: remove unnecessary `move`]
wchargin Feb 3, 2021
a890789
rust: add `reqwest` dependency
wchargin Feb 3, 2021
a65fa71
rust: add GCS listing and reading
wchargin Feb 3, 2021
61240d1
[rust-gcs-client: handle no-items case]
wchargin Feb 3, 2021
281ba86
Merge remote-tracking branch 'origin/wchargin-rust-logdir-trait' into…
wchargin Feb 3, 2021
0f22707
rust: add GCS logdir support
wchargin Feb 3, 2021
1f9224a
rust: add refresh token authentication
wchargin Feb 3, 2021
c32e282
[rust-gcs-logdir: update diffbase]
wchargin Feb 4, 2021
2dcce13
[rust-gcs-logdir: resolve conflicts]
wchargin Feb 4, 2021
33e54a4
[rust-gcs-auth: update diffbase]
wchargin Feb 4, 2021
9f812db
[rust-gcs-logdir: update diffbase]
wchargin Feb 5, 2021
9e04486
[rust-gcs-logdir: work around seanmonstar/reqwest#1017 in debug mode]
wchargin Feb 5, 2021
a6f7568
[rust-gcs-auth: update diffbase]
wchargin Feb 5, 2021
d2f8577
[rust-gcs-auth: update patch]
wchargin Feb 5, 2021
b300ea8
[rust-gcs-auth: redact `client_secret` from `Debug` impl]
wchargin Feb 5, 2021
cb662b7
[rust-gcs-auth: minor readability improvements]
wchargin Feb 6, 2021
9d8b1c6
[rust-gcs-logdir: update diffbase]
wchargin Feb 9, 2021
6013507
[rust-gcs-auth: update diffbase]
wchargin Feb 9, 2021
a19a6da
[rust-gcs-auth: update diffbase]
wchargin Feb 9, 2021
7978e34
[rust-gcs-auth: resolve conflicts]
wchargin Feb 9, 2021
ed877cd
[rust-gcs-auth: BoundedToken::unwrap_if_valid_for]
wchargin Feb 10, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tensorboard/data/server/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ rust_library(
"downsample.rs",
"event_file.rs",
"gcs.rs",
"gcs/auth.rs",
"gcs/client.rs",
"gcs/logdir.rs",
"logdir.rs",
Expand Down
5 changes: 4 additions & 1 deletion tensorboard/data/server/cli/dynamic_logdir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ impl DynLogdir {
/// This succeeds unless the path represents a GCS logdir and no HTTP client can be opened. In
/// case of failure, errors will be logged to the active logger.
///
/// This constructor is heavyweight; it may construct an HTTP client and read a GCS credentials
/// file from disk.
///
/// # Panics
///
/// May panic in debug mode if called from a thread with an active Tokio runtime; see
Expand All @@ -59,7 +62,7 @@ impl DynLogdir {
let mut parts = gcs_path.splitn(2, '/');
let bucket = parts.next().unwrap().to_string(); // splitn always yields at least one element
let prefix = parts.next().unwrap_or("").to_string();
let client = match gcs::Client::new() {
let client = match gcs::Client::new(gcs::Credentials::from_disk()) {
Err(e) => {
error!("Could not open GCS connection: {}", e);
return None;
Expand Down
2 changes: 2 additions & 0 deletions tensorboard/data/server/gcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ limitations under the License.

//! Google Cloud Storage interop.

mod auth;
mod client;
mod logdir;

pub use auth::Credentials;
pub use client::Client;
pub use logdir::Logdir;
300 changes: 300 additions & 0 deletions tensorboard/data/server/gcs/auth.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,300 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

//! OAuth integration for GCS.
//!
//! Useful resources:
//!
//! - TensorFlow OAuth implementation: [`oauth_client.cc`], [`google_auth_provider.cc`]
//! - [RFC 6749]: The OAuth 2.0 Authorization Framework
//! - ["Refreshing Access Tokens"] OAuth guide
//!
//! [`oauth_client.cc`]: https://github.com/tensorflow/tensorflow/blob/r2.4/tensorflow/core/platform/cloud/oauth_client.cc
//! [`google_auth_provider.cc`]: https://github.com/tensorflow/tensorflow/blob/r2.4/tensorflow/core/platform/cloud/google_auth_provider.cc
//! [RFC 6749]: https://tools.ietf.org/html/rfc6749
//! ["Refreshing Access Tokens"]: https://www.oauth.com/oauth2-servers/access-tokens/refreshing-access-tokens/

use log::{debug, info, warn};
use serde::{Deserialize, Serialize};
use std::fmt::{self, Debug};
use std::fs::File;
use std::io::BufReader;
use std::path::PathBuf;
use std::sync::RwLock;
use std::time::{Duration, Instant};

use reqwest::blocking::{Client as HttpClient, RequestBuilder};

const OAUTH_REFRESH_TOKEN_ENDPOINT: &str = "https://www.googleapis.com/oauth2/v3/token";

/// Refreshable OAuth credentials paired with a slot for the most recently fetched access token.
/// Call [`authenticate`][Self::authenticate] to attach an `Authorization` header to an outgoing
/// request; a fresh access token is fetched first when one is needed.
///
/// A `TokenStore` may be freely shared among threads; the token slot is guarded by an
/// [`RwLock`].
pub struct TokenStore {
    /// Persistent credentials from which new access tokens are requested.
    creds: Credentials,
    /// The most recently stored token, or `None` if there is none.
    token: RwLock<Option<BoundedToken>>,
}

impl TokenStore {
    /// Builds a token store around `creds`, starting with an empty token slot.
    ///
    /// This only wraps the credentials; no access token is fetched here.
    pub fn new(creds: Credentials) -> Self {
        let token = RwLock::new(None);
        TokenStore { creds, token }
    }
}

/// An access token together with the instant at which it stops being valid.
#[derive(Debug)]
struct BoundedToken {
    /// The token itself.
    access_token: AccessToken,
    /// First instant at which the token is no longer considered valid.
    expires: Instant,
}

impl BoundedToken {
    /// Reports whether this token is still valid at `when`, i.e., whether `when` falls strictly
    /// before the expiry instant.
    fn valid_at(&self, when: Instant) -> bool {
        self.expires > when
    }
}

/// Private module that guards against accidentally leaking access tokens into logs and the like:
/// a token can be attached to a request, but its contents cannot be extracted directly.
mod access_token {
    use super::*;

    /// An opaque access token; the wrapped string is private to this module.
    #[derive(Deserialize)]
    pub struct AccessToken(String);

    impl AccessToken {
        /// Adds this token to `rb` as bearer authentication and returns the resulting builder.
        pub fn authenticate(&self, rb: RequestBuilder) -> RequestBuilder {
            let AccessToken(secret) = self;
            rb.bearer_auth(secret)
        }
    }

    impl Debug for AccessToken {
        /// Writes a `_` placeholder where the token value would be, never the token itself.
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            let mut tuple = f.debug_tuple("AccessToken");
            tuple.field(&format_args!("_"));
            tuple.finish()
        }
    }
}
use access_token::AccessToken;

impl TokenStore {
    /// Tries to attach an access token to the outgoing request `rb`.
    ///
    /// If the stored token is expected to remain valid for at least `lifetime`, it is reused.
    /// Otherwise, a new token is fetched over `http` and stored. With anonymous credentials, or
    /// when fetching yields no token, `rb` is returned without an access token attached.
    pub fn authenticate(
        &self,
        rb: RequestBuilder,
        http: &HttpClient,
        lifetime: Duration,
    ) -> RequestBuilder {
        if self.creds.anonymous() {
            return rb;
        }

        // Fast path: only a read lock is needed when the stored token is still good.
        {
            let cached = self.token.read().expect("failed to read auth token");
            if let Some(access) = BoundedToken::unwrap_if_valid_for(&*cached, lifetime) {
                return access.authenticate(rb);
            }
        } // Read guard is released here, before the write lock is requested.

        let mut slot = self.token.write().expect("failed to write auth token");
        // Re-check under the write lock: another caller may have stored a token in the meantime,
        // in which case there is no need to fetch again.
        if let Some(access) = BoundedToken::unwrap_if_valid_for(&*slot, lifetime) {
            return access.authenticate(rb);
        }

        // Nothing usable is stored, so fetch a fresh token and store whatever comes back.
        *slot = self.creds.fetch(http);
        match &*slot {
            Some(fresh) => {
                debug!(
                    "Obtained new access token, live for the next {:?}",
                    fresh.expires.saturating_duration_since(Instant::now())
                );
                fresh.access_token.authenticate(rb)
            }
            None => rb,
        }
    }
}

impl BoundedToken {
/// Checks whether `token` represents a token that will still be valid for at least the given
/// `lifetime`, and if so returns a reference to the inner access token.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Slightly non-obvious IMO from the name that check_expiry also unwraps the token. Optional but maybe consider making it a method on BoundedToken (could replace valid_at) like unwrap_if_valid_for()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, sure: changed as suggested.

fn unwrap_if_valid_for(token: &Option<Self>, lifetime: Duration) -> Option<&AccessToken> {
    // `Instant + Duration` panics on overflow, so a very large `lifetime` would crash the
    // caller. Compute the deadline with `checked_add` instead: a deadline that cannot be
    // represented cannot be met by any stored token, so the answer is `None`.
    let deadline = Instant::now().checked_add(lifetime)?;
    token
        .as_ref()
        .filter(|t| t.valid_at(deadline))
        .map(|t| &t.access_token)
}
}

/// The user's persistent credentials, if any. This represents all the information needed to
/// request access tokens.
pub enum Credentials {
/// No credentials. For this variant, `anonymous` returns `true` and `fetch` returns `None`.
Anonymous,
/// Credentials backed by an OAuth refresh token.
RefreshToken(RefreshToken),
}
// Public wrapper struct to hide private implementation details: its only field, the
// `RefreshTokenCreds`, is private.
pub struct RefreshToken(RefreshTokenCreds);

impl Credentials {
/// Reads credentials from disk.
///
/// The path is taken from the `GOOGLE_APPLICATION_CREDENTIALS` environment variable if set,
/// else `"${XDG_CONFIG_HOME-${HOME}/.config}/gcloud/application_default_credentials.json"`. If
/// the credentials file is not found or not readable, anonymous credentials will be used.
///
/// Anonymous credentials are likewise used when the file opens but its JSON contents cannot be
/// deserialized as refresh token credentials. Open and parse failures are logged as warnings;
/// on success, the path of the file in use is logged at info level.
pub fn from_disk() -> Self {
// No candidate path at all: fall back to anonymous credentials without logging.
let creds_file = match Self::credentials_file() {
None => return Credentials::Anonymous,
Some(f) => f,
};
// Open the file behind a buffered reader; failure to open is logged and is not fatal.
let reader = match File::open(&creds_file).map(BufReader::new) {
Err(e) => {
warn!(
"Failed to open GCS credentials file {:?}; will use anonymous credentials: {}",
creds_file, e
);
return Credentials::Anonymous;
}
Ok(f) => f,
};
// Parse the JSON contents; any deserialization error is logged and is not fatal.
let creds: RefreshTokenCreds = match serde_json::from_reader(reader) {
Err(e) => {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If serde fails to deserialize the contents of the credentials file, it goes into this error case, which is good. However, does the match still fall into this case when serde properly deserializes the contents into a proper struct, which does not conform to the RefreshTokenCreds struct? (e.g. if it doesn't have a client_id field)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. Indeed: that becomes a decode error, which we
propagate up (as an io::ErrorKind::InvalidData error, through the
mapping in gcs::logdir::reqwest_to_io_error from #4646).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting, thanks

warn!(
"Failed to read GCS credentials file {:?}; will use anonymous credentials: {}",
creds_file, e
);
return Credentials::Anonymous;
}
Ok(creds) => creds,
};
info!(
"Using refresh token GCS creds from {}",
creds_file.display()
);
Credentials::RefreshToken(RefreshToken(creds))
}

/// Locates the file from which credentials might be read, if there is one.
///
/// A `GOOGLE_APPLICATION_CREDENTIALS` environment variable takes precedence and is returned
/// as-is, without checking that it names an existing file. Otherwise, the candidate is
/// `gcloud/application_default_credentials.json` under `XDG_CONFIG_HOME` (or under
/// `HOME/.config` when that variable is unset), and it is returned only if it is a file.
fn credentials_file() -> Option<PathBuf> {
    use std::env::var_os;

    if let Some(explicit) = var_os("GOOGLE_APPLICATION_CREDENTIALS") {
        return Some(PathBuf::from(explicit));
    }
    let config_dir = match var_os("XDG_CONFIG_HOME") {
        Some(xdg) => PathBuf::from(xdg),
        // With neither variable set there is no base directory to look in.
        None => PathBuf::from(var_os("HOME")?).join(".config"),
    };
    let candidate = config_dir
        .join("gcloud")
        .join("application_default_credentials.json");
    if candidate.is_file() {
        Some(candidate)
    } else {
        None
    }
}

/// Reports whether this credential is inherently anonymous. When this returns `true`,
/// [`Self::fetch`] always returns `None`.
///
/// This is an optimization: it lets a [`TokenStore`] skip its locks entirely when the
/// credential is anonymous anyway.
fn anonymous(&self) -> bool {
    match self {
        Credentials::Anonymous => true,
        Credentials::RefreshToken(_) => false,
    }
}

/// Tries to obtain a fresh access token using these credentials.
///
/// Returns `None` for anonymous credentials, and also when the refresh token request fails; in
/// the latter case the failure is logged as a warning.
fn fetch(&self, http: &HttpClient) -> Option<BoundedToken> {
    let creds = match self {
        Credentials::Anonymous => return None,
        Credentials::RefreshToken(RefreshToken(inner)) => inner,
    };
    creds
        .fetch(http)
        .map_err(|e| warn!("GCS authentication failed: {}", e))
        .ok()
}
}

/// Persistent credentials in the form of an OAuth refresh token. Can be posted to an OAuth token
/// endpoint to obtain a short-lived access token.
///
/// Deserialized from the JSON credentials file by `Credentials::from_disk`.
#[derive(Deserialize)]
struct RefreshTokenCreds {
/// Sent as `client_id` in the refresh token request.
client_id: String,
/// Sent as `client_secret` in the refresh token request.
client_secret: String,
/// Sent as `refresh_token` in the refresh token request.
refresh_token: String,
}
/// Debug formatting that prints `client_id` verbatim but substitutes `<redacted>` for both
/// `client_secret` and `refresh_token`, so that those values do not appear in debug output.
impl Debug for RefreshTokenCreds {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// `format_args!` supplies the placeholder text without quoting it like a string field.
f.debug_struct("RefreshTokenCreds")
.field("client_id", &self.client_id)
.field("client_secret", &format_args!("<redacted>"))
.field("refresh_token", &format_args!("<redacted>"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just checking: if the <redacted> here is to cover the case when we accidentally save/log our refresh tokens somewhere, should we redact the client_secret too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding is that the client “secret” is actually not secret:

# The client "secret" is public by design for installed apps. See
# https://developers.google.com/identity/protocols/OAuth2?csw=1#installed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

…but, you know, on second thought, it costs nothing to hide the client
“secret”, so I might as well, just in case my understanding is missing
some pieces. Done.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah - it's not secret for well-known installed apps like gcloud or our uploader, since it's not possible to hide it, but I think it could be a real secret for a generic GOOGLE_APPLICATION_CREDENTIALS path.

.finish()
}
}

/// POST body to [`OAUTH_REFRESH_TOKEN_ENDPOINT`].
///
/// All fields borrow their contents; `RefreshTokenCreds::fetch` fills the first three from the
/// stored credentials and serializes the struct as the JSON request body.
#[derive(Serialize)]
struct RefreshTokenRequest<'a> {
/// Borrowed from `RefreshTokenCreds::client_id`.
client_id: &'a str,
/// Borrowed from `RefreshTokenCreds::client_secret`.
client_secret: &'a str,
/// Borrowed from `RefreshTokenCreds::refresh_token`.
refresh_token: &'a str,
/// Set to `"refresh_token"` by `RefreshTokenCreds::fetch`.
grant_type: &'a str,
}
/// Body of a response from [`OAUTH_REFRESH_TOKEN_ENDPOINT`].
#[derive(Deserialize)]
struct OauthTokenResponse {
    /// The newly issued access token.
    access_token: AccessToken,
    /// Token lifetime in seconds. The field is optional per the OAuth spec, so a default is
    /// substituted when the response omits it.
    #[serde(default = "OauthTokenResponse::default_expires_in")]
    expires_in: u64,
}

impl OauthTokenResponse {
    /// Supplies the lifetime to assume when a response has no `expires_in`, logging a warning
    /// each time it is consulted.
    fn default_expires_in() -> u64 {
        // Standard response from Google OAuth servers.
        const GOOGLE_STANDARD_SECS: u64 = 3599;
        warn!(
            "OAuth response did not set `expires_in`; assuming {}",
            GOOGLE_STANDARD_SECS
        );
        GOOGLE_STANDARD_SECS
    }
}

impl RefreshTokenCreds {
/// Fetches a new access token from this refresh token credential.
pub fn fetch(&self, http: &HttpClient) -> reqwest::Result<BoundedToken> {
debug!("Fetching access token from refresh token");
let req = RefreshTokenRequest {
client_id: &self.client_id,
client_secret: &self.client_secret,
refresh_token: &self.refresh_token,
grant_type: "refresh_token",
};
let res: OauthTokenResponse = http
.post(OAUTH_REFRESH_TOKEN_ENDPOINT)
.json(&req)
.send()?
.error_for_status()?
.json()?;
Ok(BoundedToken {
access_token: res.access_token,
expires: Instant::now() + Duration::from_secs(res.expires_in),
})
}
}
Loading