-
Notifications
You must be signed in to change notification settings - Fork 1.7k
rust: add refresh token authentication #4647
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
948b724
3ab4fac
358687e
d77430a
cf6869c
971c2c2
428aaab
b95ea29
a890789
a65fa71
61240d1
281ba86
0f22707
1f9224a
c32e282
2dcce13
33e54a4
9f812db
9e04486
a6f7568
d2f8577
b300ea8
cb662b7
9d8b1c6
6013507
a19a6da
7978e34
ed877cd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,300 @@ | ||||||
| /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. | ||||||
|
|
||||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
| you may not use this file except in compliance with the License. | ||||||
| You may obtain a copy of the License at | ||||||
|
|
||||||
| http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|
|
||||||
| Unless required by applicable law or agreed to in writing, software | ||||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
| See the License for the specific language governing permissions and | ||||||
| limitations under the License. | ||||||
| ==============================================================================*/ | ||||||
|
|
||||||
| //! OAuth integration for GCS. | ||||||
| //! | ||||||
| //! Useful resources: | ||||||
| //! | ||||||
| //! - TensorFlow OAuth implementation: [`oauth_client.cc`], [`google_auth_provider.cc`] | ||||||
| //! - [RFC 6749]: The OAuth 2.0 Authorization Framework | ||||||
| //! - ["Refreshing Access Tokens"] OAuth guide | ||||||
| //! | ||||||
| //! [`oauth_client.cc`]: https://github.com/tensorflow/tensorflow/blob/r2.4/tensorflow/core/platform/cloud/oauth_client.cc | ||||||
| //! [`google_auth_provider.cc`]: https://github.com/tensorflow/tensorflow/blob/r2.4/tensorflow/core/platform/cloud/google_auth_provider.cc | ||||||
| //! [RFC 6749]: https://tools.ietf.org/html/rfc6749 | ||||||
| //! ["Refreshing Access Tokens"]: https://www.oauth.com/oauth2-servers/access-tokens/refreshing-access-tokens/ | ||||||
|
|
||||||
| use log::{debug, info, warn}; | ||||||
| use serde::{Deserialize, Serialize}; | ||||||
| use std::fmt::{self, Debug}; | ||||||
| use std::fs::File; | ||||||
| use std::io::BufReader; | ||||||
| use std::path::PathBuf; | ||||||
| use std::sync::RwLock; | ||||||
| use std::time::{Duration, Instant}; | ||||||
|
|
||||||
| use reqwest::blocking::{Client as HttpClient, RequestBuilder}; | ||||||
|
|
||||||
| const OAUTH_REFRESH_TOKEN_ENDPOINT: &str = "https://www.googleapis.com/oauth2/v3/token"; | ||||||
|
|
||||||
| /// A set of refreshable OAuth credentials plus a potentially active token. Use | ||||||
| /// [`authenticate`][Self::authenticate] to add an `Authorization` header to an outgoing request, | ||||||
| /// fetching a fresh access token if necessary. | ||||||
| /// | ||||||
| /// A `TokenStore` may be freely shared among threads; it synchronizes internally if needed. | ||||||
| pub struct TokenStore { | ||||||
| creds: Credentials, | ||||||
| token: RwLock<Option<BoundedToken>>, | ||||||
| } | ||||||
|
|
||||||
| impl TokenStore { | ||||||
| /// Creates a new token store from the given credentials. | ||||||
| /// | ||||||
| /// This operation is cheap and does not actually fetch any access tokens. | ||||||
| pub fn new(creds: Credentials) -> Self { | ||||||
| Self { | ||||||
| creds, | ||||||
| token: RwLock::new(None), | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| /// An access token that's valid until a particular point in time. | ||||||
| #[derive(Debug)] | ||||||
| struct BoundedToken { | ||||||
| access_token: AccessToken, | ||||||
| expires: Instant, | ||||||
| } | ||||||
|
|
||||||
| impl BoundedToken { | ||||||
| /// Tests whether this token will still be valid at the given instant. | ||||||
| fn valid_at(&self, when: Instant) -> bool { | ||||||
| when < self.expires | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| /// Private access token module to prevent accidental leaking of tokens into logs, etc. An access | ||||||
| /// token can be attached to a request, but cannot be directly extracted. | ||||||
| mod access_token { | ||||||
| use super::*; | ||||||
| #[derive(Deserialize)] | ||||||
| pub struct AccessToken(String); | ||||||
| impl AccessToken { | ||||||
| /// Attaches this token to an outgoing request. | ||||||
| pub fn authenticate(&self, rb: RequestBuilder) -> RequestBuilder { | ||||||
| rb.bearer_auth(&self.0) | ||||||
| } | ||||||
| } | ||||||
| impl Debug for AccessToken { | ||||||
| fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||||||
| f.debug_tuple("AccessToken") | ||||||
| .field(&format_args!("_")) | ||||||
| .finish() | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
| use access_token::AccessToken; | ||||||
|
|
||||||
| impl TokenStore { | ||||||
| /// Attempts to attach an access token to the given outgoing request. | ||||||
| /// | ||||||
| /// A cached token will be reused if it is expected to be valid for at least `lifetime`. | ||||||
| /// Otherwise, a new token will be fetched and stored. | ||||||
| pub fn authenticate( | ||||||
| &self, | ||||||
| rb: RequestBuilder, | ||||||
| http: &HttpClient, | ||||||
| lifetime: Duration, | ||||||
| ) -> RequestBuilder { | ||||||
| if self.creds.anonymous() { | ||||||
| return rb; | ||||||
| } | ||||||
| let token = self.token.read().expect("failed to read auth token"); | ||||||
| if let Some(t) = BoundedToken::unwrap_if_valid_for(&*token, lifetime) { | ||||||
| return t.authenticate(rb); | ||||||
| } | ||||||
| drop(token); | ||||||
| let mut token = self.token.write().expect("failed to write auth token"); | ||||||
| // Check again: may have just been written by a different client, in which case no need to | ||||||
| // re-fetch. | ||||||
| if let Some(t) = BoundedToken::unwrap_if_valid_for(&*token, lifetime) { | ||||||
| return t.authenticate(rb); | ||||||
| } | ||||||
| // If we get here, we need a fresh token. | ||||||
| *token = self.creds.fetch(http); | ||||||
| if let Some(ref t) = *token { | ||||||
| debug!( | ||||||
| "Obtained new access token, live for the next {:?}", | ||||||
| t.expires.saturating_duration_since(Instant::now()) | ||||||
| ); | ||||||
| t.access_token.authenticate(rb) | ||||||
| } else { | ||||||
| rb | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| impl BoundedToken { | ||||||
| /// Checks whether `token` represents a token that will still be valid for at least the given | ||||||
| /// `lifetime`, and if so returns a reference to the inner access token. | ||||||
| fn unwrap_if_valid_for(token: &Option<Self>, lifetime: Duration) -> Option<&AccessToken> { | ||||||
| match token.as_ref() { | ||||||
| Some(t) if t.valid_at(Instant::now() + lifetime) => Some(&t.access_token), | ||||||
| _ => None, | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| /// The user's persistent credentials, if any. This represents all the information needed to | ||||||
| /// request access tokens. | ||||||
| pub enum Credentials { | ||||||
| Anonymous, | ||||||
| RefreshToken(RefreshToken), | ||||||
| } | ||||||
| // public wrapper struct to hide private implementation details | ||||||
| pub struct RefreshToken(RefreshTokenCreds); | ||||||
|
|
||||||
| impl Credentials { | ||||||
| /// Reads credentials from disk. | ||||||
| /// | ||||||
| /// The path is taken from the `GOOGLE_APPLICATION_CREDENTIALS` environment variable if set, | ||||||
| /// else `"${XDG_CONFIG_HOME-${HOME}/.config}/gcloud/application_default_credentials.json"`. If | ||||||
| /// the credentials file is not found or not readable, anonymous credentials will be used. | ||||||
| pub fn from_disk() -> Self { | ||||||
| let creds_file = match Self::credentials_file() { | ||||||
| None => return Credentials::Anonymous, | ||||||
| Some(f) => f, | ||||||
| }; | ||||||
| let reader = match File::open(&creds_file).map(BufReader::new) { | ||||||
| Err(e) => { | ||||||
| warn!( | ||||||
| "Failed to open GCS credentials file {:?}; will use anonymous credentials: {}", | ||||||
| creds_file, e | ||||||
| ); | ||||||
| return Credentials::Anonymous; | ||||||
| } | ||||||
| Ok(f) => f, | ||||||
| }; | ||||||
| let creds: RefreshTokenCreds = match serde_json::from_reader(reader) { | ||||||
| Err(e) => { | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If serde fails to deserialize the contents of the credentials file, it goes into this error case, which is good. However, does the
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good question. Indeed: that becomes a decode error, which we
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Interesting, thanks |
||||||
| warn!( | ||||||
| "Failed to read GCS credentials file {:?}; will use anonymous credentials: {}", | ||||||
| creds_file, e | ||||||
| ); | ||||||
| return Credentials::Anonymous; | ||||||
| } | ||||||
| Ok(creds) => creds, | ||||||
| }; | ||||||
| info!( | ||||||
| "Using refresh token GCS creds from {}", | ||||||
| creds_file.display() | ||||||
| ); | ||||||
| Credentials::RefreshToken(RefreshToken(creds)) | ||||||
| } | ||||||
|
|
||||||
| /// Determines the file on disk from which credentials might be read, if any. | ||||||
| fn credentials_file() -> Option<PathBuf> { | ||||||
| if let Some(p) = std::env::var_os("GOOGLE_APPLICATION_CREDENTIALS") { | ||||||
| return Some(p.into()); | ||||||
| } | ||||||
| let base_config_dir = std::env::var_os("XDG_CONFIG_HOME") | ||||||
| .map(PathBuf::from) | ||||||
| .or_else(|| std::env::var_os("HOME").map(|p| PathBuf::from(p).join(".config"))); | ||||||
| if let Some(mut path) = base_config_dir { | ||||||
| path.extend(&["gcloud", "application_default_credentials.json"]); | ||||||
| if path.is_file() { | ||||||
| return Some(path); | ||||||
| } | ||||||
| }; | ||||||
| None | ||||||
| } | ||||||
|
|
||||||
| /// Tests whether this credential is inherently anonymous. If this returns `true`, then | ||||||
| /// [`Self::fetch`] will always return `None`. | ||||||
| /// | ||||||
| /// This exists as an optimization so that a [`TokenStore`] doesn't need to check locks all the | ||||||
| /// time when the credential is anonymous, anyway. | ||||||
| fn anonymous(&self) -> bool { | ||||||
| matches!(self, Credentials::Anonymous) | ||||||
| } | ||||||
|
|
||||||
| /// Attempts to fetch a fresh access token with these credentials. | ||||||
| fn fetch(&self, http: &HttpClient) -> Option<BoundedToken> { | ||||||
| match self { | ||||||
| Credentials::Anonymous => None, | ||||||
| Credentials::RefreshToken(RefreshToken(creds)) => match creds.fetch(http) { | ||||||
| Ok(t) => Some(t), | ||||||
| Err(e) => { | ||||||
| warn!("GCS authentication failed: {}", e); | ||||||
| None | ||||||
| } | ||||||
| }, | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| /// Persistent credentials in the form of an OAuth refresh token. Can be posted to an OAuth token | ||||||
| /// endpoint to obtain a short-lived access token. | ||||||
| #[derive(Deserialize)] | ||||||
| struct RefreshTokenCreds { | ||||||
| client_id: String, | ||||||
| client_secret: String, | ||||||
| refresh_token: String, | ||||||
| } | ||||||
| impl Debug for RefreshTokenCreds { | ||||||
| fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||||||
| f.debug_struct("RefreshTokenCreds") | ||||||
| .field("client_id", &self.client_id) | ||||||
| .field("client_secret", &format_args!("<redacted>")) | ||||||
| .field("refresh_token", &format_args!("<redacted>")) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just checking: if the
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My understanding is that the client “secret” is actually not secret: tensorboard/tensorboard/uploader/auth.py Lines 46 to 47 in b306651
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. …but, you know, on second thought, it costs nothing to hide the client
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah - it's not secret for well-known installed apps like |
||||||
| .finish() | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| /// POST body to [`OAUTH_REFRESH_TOKEN_ENDPOINT`]. | ||||||
| #[derive(Serialize)] | ||||||
| struct RefreshTokenRequest<'a> { | ||||||
| client_id: &'a str, | ||||||
| client_secret: &'a str, | ||||||
| refresh_token: &'a str, | ||||||
| grant_type: &'a str, | ||||||
| } | ||||||
| /// Response body from [`OAUTH_REFRESH_TOKEN_ENDPOINT`]. | ||||||
| #[derive(Deserialize)] | ||||||
| struct OauthTokenResponse { | ||||||
| access_token: AccessToken, | ||||||
| #[serde(default = "OauthTokenResponse::default_expires_in")] // optional per OAuth spec | ||||||
| expires_in: u64, // seconds | ||||||
| } | ||||||
| impl OauthTokenResponse { | ||||||
| fn default_expires_in() -> u64 { | ||||||
| let v = 3599; // standard response from Google OAuth servers | ||||||
| warn!("OAuth response did not set `expires_in`; assuming {}", v); | ||||||
| v | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| impl RefreshTokenCreds { | ||||||
| /// Fetches a new access token from this refresh token credential. | ||||||
| pub fn fetch(&self, http: &HttpClient) -> reqwest::Result<BoundedToken> { | ||||||
| debug!("Fetching access token from refresh token"); | ||||||
| let req = RefreshTokenRequest { | ||||||
| client_id: &self.client_id, | ||||||
| client_secret: &self.client_secret, | ||||||
| refresh_token: &self.refresh_token, | ||||||
| grant_type: "refresh_token", | ||||||
| }; | ||||||
| let res: OauthTokenResponse = http | ||||||
| .post(OAUTH_REFRESH_TOKEN_ENDPOINT) | ||||||
| .json(&req) | ||||||
| .send()? | ||||||
| .error_for_status()? | ||||||
| .json()?; | ||||||
| Ok(BoundedToken { | ||||||
| access_token: res.access_token, | ||||||
| expires: Instant::now() + Duration::from_secs(res.expires_in), | ||||||
| }) | ||||||
| } | ||||||
| } | ||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Slightly non-obvious IMO from the name that
check_expiryalso unwraps the token. Optional but maybe consider making it a method on BoundedToken (could replacevalid_at) likeunwrap_if_valid_for()?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, sure: changed as suggested.