diff --git a/sdk/python/feast/grpc/auth.py b/sdk/python/feast/grpc/auth.py index bc7a9921d8..3deb95be24 100644 --- a/sdk/python/feast/grpc/auth.py +++ b/sdk/python/feast/grpc/auth.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import json +import time from http import HTTPStatus import grpc @@ -159,6 +160,7 @@ def __init__(self, config: Config): self._static_token = None self._token = None + self._token_expiry_ts = time.time() # If provided, set a static token if config.exists(opt.AUTH_TOKEN): @@ -169,6 +171,9 @@ def __init__(self, config: Config): def get_signed_meta(self): """ Creates a signed authorization metadata token.""" + + if time.time() > self._token_expiry_ts: + self._refresh_token() return (("authorization", "Bearer {}".format(self._token)),) def _refresh_token(self): @@ -179,10 +184,13 @@ def _refresh_token(self): self._token = self._static_token return - from google.oauth2.id_token import fetch_id_token + from google.oauth2.id_token import fetch_id_token, verify_oauth2_token try: self._token = fetch_id_token(self._request, audience="feast.dev") + self._token_expiry_ts = verify_oauth2_token(self._token, self._request)[ + "exp" + ] return except DefaultCredentialsError: pass @@ -195,6 +203,9 @@ def _refresh_token(self): credentials.refresh(self._request) if hasattr(credentials, "id_token"): self._token = credentials.id_token + self._token_expiry_ts = verify_oauth2_token(self._token, self._request)[ + "exp" + ] return except DefaultCredentialsError: pass # Could not determine credentials, skip diff --git a/sdk/python/tests/grpc/test_auth.py b/sdk/python/tests/grpc/test_auth.py index ebbe936f88..507c9fa032 100644 --- a/sdk/python/tests/grpc/test_auth.py +++ b/sdk/python/tests/grpc/test_auth.py @@ -147,9 +147,13 @@ def test_get_auth_metadata_plugin_oauth_should_raise_when_config_is_incorrect( get_auth_metadata_plugin(config_with_missing_variable) +@patch( + "google.oauth2.id_token.verify_token", + return_value={"iss": "accounts.google.com", "exp": 12341234}, +) @patch("google.oauth2.id_token.fetch_id_token", return_value="Some Token") def test_get_auth_metadata_plugin_google_should_pass_with_token_from_gcloud_sdk( - fetch_id_token, config_google + verify_token, fetch_id_token, config_google ): auth_metadata_plugin = get_auth_metadata_plugin(config_google) assert isinstance(auth_metadata_plugin, GoogleOpenIDAuthMetadataPlugin) @@ -158,6 +162,10 @@ def test_get_auth_metadata_plugin_google_should_pass_with_token_from_gcloud_sdk( ) +@patch( + "google.oauth2.id_token.verify_token", + return_value={"iss": "accounts.google.com", "exp": 12341234}, +) @patch( "google.auth.default", return_value=[ @@ -167,7 +175,7 @@ def test_get_auth_metadata_plugin_google_should_pass_with_token_from_gcloud_sdk( ) @patch("google.oauth2.id_token.fetch_id_token", side_effect=DefaultCredentialsError()) def test_get_auth_metadata_plugin_google_should_pass_with_token_from_google_auth_lib( - fetch_id_token, default, config_google + verify_token, fetch_id_token, default, config_google ): auth_metadata_plugin = get_auth_metadata_plugin(config_google) assert isinstance(auth_metadata_plugin, GoogleOpenIDAuthMetadataPlugin)