Skip to content

Commit

Permalink
Delete unnecessary auth configuration (#858)
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <wild-endeavor@users.noreply.github.com>
  • Loading branch information
wild-endeavor authored Feb 23, 2022
1 parent 9477e1f commit 1a89b78
Show file tree
Hide file tree
Showing 13 changed files with 301 additions and 476 deletions.
199 changes: 151 additions & 48 deletions flytekit/clients/raw.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from __future__ import annotations

import base64 as _base64
import logging as _logging
import subprocess
import time
from typing import List
from typing import Optional

import requests as _requests
from flyteidl.service import admin_pb2_grpc as _admin_service
from flyteidl.service import auth_pb2
from flyteidl.service import auth_pb2_grpc as auth_service
from google.protobuf.json_format import MessageToJson as _MessageToJson
from grpc import RpcError as _RpcError
from grpc import StatusCode as _GrpcStatusCode
Expand All @@ -11,49 +18,57 @@
from grpc import ssl_channel_credentials as _ssl_channel_credentials

from flytekit.clis.auth import credentials as _credentials_access
from flytekit.clis.sdk_in_container import basic_auth as _basic_auth
from flytekit.configuration import creds as _creds_config
from flytekit.configuration.creds import _DEPRECATED_CLIENT_CREDENTIALS_SCOPE as _DEPRECATED_SCOPE
from flytekit.configuration import creds as creds_config
from flytekit.configuration.creds import CLIENT_CREDENTIALS_SECRET as _CREDENTIALS_SECRET
from flytekit.configuration.creds import CLIENT_ID as _CLIENT_ID
from flytekit.configuration.creds import COMMAND as _COMMAND
from flytekit.configuration.creds import DEPRECATED_OAUTH_SCOPES, SCOPES
from flytekit.configuration.platform import AUTH as _AUTH
from flytekit.exceptions import user as _user_exceptions
from flytekit.exceptions.user import FlyteAuthenticationException
from flytekit.loggers import cli_logger

_utf_8 = "utf-8"


def _refresh_credentials_standard(flyte_client):
def _refresh_credentials_standard(flyte_client: RawSynchronousFlyteClient):
"""
This function is used when the configuration value for AUTH_MODE is set to 'standard'.
This either fetches the existing access token or initiates the flow to request a valid access token and store it.
:param flyte_client: RawSynchronousFlyteClient
:return:
"""

client = _credentials_access.get_client(flyte_client.url)
if client.can_refresh_token:
authorization_header_key = flyte_client.public_client_config.authorization_metadata_key or None
if not flyte_client.oauth2_metadata or not flyte_client.public_client_config:
raise ValueError(
"Raw Flyte client attempting client credentials flow but no response from Admin detected. "
"Check your Admin server's .well-known endpoints to make sure they're working as expected."
)
client = _credentials_access.get_client(
redirect_endpoint=flyte_client.public_client_config.redirect_uri,
client_id=flyte_client.public_client_config.client_id,
scopes=flyte_client.public_client_config.scopes,
auth_endpoint=flyte_client.oauth2_metadata.authorization_endpoint,
token_endpoint=flyte_client.oauth2_metadata.token_endpoint,
)
if client.has_valid_credentials and not flyte_client.check_access_token(client.credentials.access_token):
# When Python starts up, if credentials have been stored in the keyring, then the AuthorizationClient
# will have read them into its _credentials field, but it won't be in the RawSynchronousFlyteClient's
# metadata field yet. Therefore, if there's a mismatch, copy it over.
flyte_client.set_access_token(client.credentials.access_token, authorization_header_key)
# However, after copying over credentials from the AuthorizationClient, we have to clear it to avoid the
# scenario where the stored credentials in the keyring are expired. If that's the case, then we only try
# them once (because client here is a singleton), and the next time, we'll do one of the two other conditions
# below.
client.clear()
return
elif client.can_refresh_token:
client.refresh_access_token()
else:
client.start_authorization_flow()

flyte_client.set_access_token(client.credentials.access_token)


def _get_basic_flow_scopes() -> List[str]:
"""
Merge the scope value between the old scope config option and the new list option.
:return: The scopes to use for basic auth flow.
"""
deprecated_single_scope = _DEPRECATED_SCOPE.get()
if deprecated_single_scope:
return [deprecated_single_scope]
scopes = DEPRECATED_OAUTH_SCOPES.get() or SCOPES.get()
if "openid" in scopes:
cli_logger.warning("Basic flow authentication should never use openid.")

return scopes
flyte_client.set_access_token(client.credentials.access_token, authorization_header_key)


def _refresh_credentials_basic(flyte_client):
def _refresh_credentials_basic(flyte_client: RawSynchronousFlyteClient):
"""
This function is used by the _handle_rpc_error() decorator, depending on the AUTH_MODE config object. This handler
is meant for SDK use-cases of auth (like pyflyte, or when users call SDK functions that require access to Admin,
Expand All @@ -63,16 +78,24 @@ def _refresh_credentials_basic(flyte_client):
:param flyte_client: RawSynchronousFlyteClient
:return:
"""
auth_endpoints = _credentials_access.get_authorization_endpoints(flyte_client.url)
token_endpoint = auth_endpoints.token_endpoint
client_secret = _basic_auth.get_secret()
cli_logger.debug(
"Basic authorization flow with client id {} scope {}".format(_CLIENT_ID.get(), _get_basic_flow_scopes())
)
authorization_header = _basic_auth.get_basic_authorization_header(_CLIENT_ID.get(), client_secret)
token, expires_in = _basic_auth.get_token(token_endpoint, authorization_header, _get_basic_flow_scopes())
if not flyte_client.oauth2_metadata or not flyte_client.public_client_config:
raise ValueError(
"Raw Flyte client attempting client credentials flow but no response from Admin detected. "
"Check your Admin server's .well-known endpoints to make sure they're working as expected."
)

token_endpoint = flyte_client.oauth2_metadata.token_endpoint
scopes = creds_config.SCOPES.get() or flyte_client.public_client_config.scopes
scopes = ",".join(scopes)

# Note that unlike the Pkce flow, the client ID does not come from Admin.
client_secret = get_secret()
cli_logger.debug("Basic authorization flow with client id {} scope {}".format(_CLIENT_ID.get(), scopes))
authorization_header = get_basic_authorization_header(_CLIENT_ID.get(), client_secret)
token, expires_in = get_token(token_endpoint, authorization_header, scopes)
cli_logger.info("Retrieved new token, expires in {}".format(expires_in))
flyte_client.set_access_token(token)
authorization_header_key = flyte_client.public_client_config.authorization_metadata_key or None
flyte_client.set_access_token(token, authorization_header_key)


def _refresh_credentials_from_command(flyte_client):
Expand Down Expand Up @@ -101,7 +124,7 @@ def _refresh_credentials_noop(flyte_client):
def _get_refresh_handler(auth_mode):
if auth_mode == "standard":
return _refresh_credentials_standard
elif auth_mode == "basic":
elif auth_mode == "basic" or auth_mode == "client_credentials":
return _refresh_credentials_basic
elif auth_mode == "external_process":
return _refresh_credentials_from_command
Expand Down Expand Up @@ -133,7 +156,7 @@ def handler(*args, **kwargs):
# Exit the loop and wrap the authentication error.
raise _user_exceptions.FlyteAuthenticationException(str(e))
cli_logger.error(f"Unauthenticated RPC error {e}, refreshing credentials and retrying\n")
refresh_handler_fn = _get_refresh_handler(_creds_config.AUTH_MODE.get())
refresh_handler_fn = _get_refresh_handler(creds_config.AUTH_MODE.get())
refresh_handler_fn(args[0])
# There are two cases that we should throw error immediately
# 1. Entity already exists when we register entity
Expand Down Expand Up @@ -210,29 +233,57 @@ def __init__(self, url, insecure=False, credentials=None, options=None, root_cer
options=list((options or {}).items()),
)
self._stub = _admin_service.AdminServiceStub(self._channel)
self._auth_stub = auth_service.AuthMetadataServiceStub(self._channel)
try:
resp = self._auth_stub.GetPublicClientConfig(auth_pb2.PublicClientAuthConfigRequest())
self._public_client_config = resp
except _RpcError:
cli_logger.debug("No public client auth config found, skipping.")
self._public_client_config = None
try:
resp = self._auth_stub.GetOAuth2Metadata(auth_pb2.OAuth2MetadataRequest())
self._oauth2_metadata = resp
except _RpcError:
cli_logger.debug("No OAuth2 Metadata found, skipping.")
self._oauth2_metadata = None

# metadata will hold the value of the token to send to the various endpoints.
self._metadata = None
if _AUTH.get():
self.force_auth_flow()

@property
def public_client_config(self) -> Optional[auth_pb2.PublicClientAuthConfigResponse]:
return self._public_client_config

@property
def oauth2_metadata(self) -> Optional[auth_pb2.OAuth2MetadataResponse]:
return self._oauth2_metadata

@property
def url(self) -> str:
return self._url

def set_access_token(self, access_token):
def set_access_token(self, access_token: str, authorization_header_key: Optional[str] = "authorization"):
# Always set the header to lower-case regardless of what the config is. The grpc libraries that Admin uses
# to parse the metadata don't change the metadata, but they do automatically lower the key you're looking for.
authorization_metadata_key = _creds_config.AUTHORIZATION_METADATA_KEY.get().lower()
cli_logger.debug(f"Adding authorization header. Header name: {authorization_metadata_key}.")
cli_logger.debug(f"Adding authorization header. Header name: {authorization_header_key}.")
self._metadata = [
(
authorization_metadata_key,
authorization_header_key,
f"Bearer {access_token}",
)
]

def force_auth_flow(self):
refresh_handler_fn = _get_refresh_handler(_creds_config.AUTH_MODE.get())
refresh_handler_fn(self)
def check_access_token(self, access_token: str) -> bool:
"""
This checks to see if the given access token is the same as the one already stored in the client. The reason
this is useful is so that we can prevent unnecessary refreshing of tokens.
:param access_token: The access token to check
:return: If no access token is stored, or if the stored token doesn't match, return False.
"""
if self._metadata is None:
return False
return access_token == self._metadata[0][1].replace("Bearer ", "")

####################################################################################################################
#
Expand Down Expand Up @@ -749,3 +800,55 @@ def list_matchable_attributes(self, matchable_attributes_list_request):

# TODO: (P2) Implement the event endpoints in case there becomes a use-case for third-parties to submit events
# through the client in Python.


def get_token(token_endpoint, authorization_header, scope):
"""
:param Text token_endpoint:
:param Text authorization_header: This is the value for the "Authorization" key. (eg 'Bearer abc123')
:param Text scope:
:rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration
in seconds
"""
headers = {
"Authorization": authorization_header,
"Cache-Control": "no-cache",
"Accept": "application/json",
"Content-Type": "application/x-www-form-urlencoded",
}
body = {
"grant_type": "client_credentials",
}
if scope is not None:
body["scope"] = scope
response = _requests.post(token_endpoint, data=body, headers=headers)
if response.status_code != 200:
_logging.error("Non-200 ({}) received from IDP: {}".format(response.status_code, response.text))
raise FlyteAuthenticationException("Non-200 received from IDP")

response = response.json()
return response["access_token"], response["expires_in"]


def get_secret():
"""
This function will either read in the password from the file path given by the CLIENT_CREDENTIALS_SECRET_LOCATION
config object, or from the environment variable using the CLIENT_CREDENTIALS_SECRET config object.
:rtype: Text
"""
secret = _CREDENTIALS_SECRET.get()
if secret:
return secret
raise FlyteAuthenticationException("No secret could be found")


def get_basic_authorization_header(client_id, client_secret):
"""
This function transforms the client id and the client secret into a header that conforms with http basic auth.
It joins the id and the secret with a : then base64 encodes it, then adds the appropriate text.
:param Text client_id:
:param Text client_secret:
:rtype: Text
"""
concated = "{}:{}".format(client_id, client_secret)
return "Basic {}".format(_base64.b64encode(concated.encode(_utf_8)).decode(_utf_8))
16 changes: 9 additions & 7 deletions flytekit/clis/auth/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,11 @@ def __init__(
scopes=None,
client_id=None,
redirect_uri=None,
client_secret=None,
):
self._auth_endpoint = auth_endpoint
self._token_endpoint = token_endpoint
self._client_id = client_id
self._scopes = scopes
self._scopes = scopes or []
self._redirect_uri = redirect_uri
self._code_verifier = _generate_code_verifier()
code_challenge = _create_code_challenge(self._code_verifier)
Expand All @@ -161,12 +160,11 @@ def __init__(
self._refresh_token = None
self._headers = {"content-type": "application/x-www-form-urlencoded"}
self._expired = False
self._client_secret = client_secret

self._params = {
"client_id": client_id, # This must match the Client ID of the OAuth application.
"response_type": "code", # Indicates the authorization code grant
"scope": " ".join(s.strip("' ") for s in scopes).strip(
"scope": " ".join(s.strip("' ") for s in self._scopes).strip(
"[]'"
), # ensures that the /token endpoint returns an ID and refresh token
# callback location where the user-agent will be directed to.
Expand Down Expand Up @@ -239,12 +237,12 @@ def _initialize_credentials(self, auth_token_resp):
raise ValueError('Expected "access_token" in response from oauth server')
if "refresh_token" in response_body:
self._refresh_token = response_body["refresh_token"]
_keyring.set_password(
_keyring_service_name, _keyring_refresh_token_storage_key, response_body["refresh_token"]
)

access_token = response_body["access_token"]
refresh_token = response_body["refresh_token"]

_keyring.set_password(_keyring_service_name, _keyring_access_token_storage_key, access_token)
_keyring.set_password(_keyring_service_name, _keyring_refresh_token_storage_key, refresh_token)
self._credentials = Credentials(access_token=access_token)

def request_access_token(self, auth_code):
Expand Down Expand Up @@ -299,6 +297,10 @@ def credentials(self):
"""
return self._credentials

def clear(self):
self._credentials = None
self._refresh_token = None

@property
def expired(self):
"""
Expand Down
Loading

0 comments on commit 1a89b78

Please sign in to comment.