From 1a89b7851ef6aba1b1a2de2106c9ec32f4afc778 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Tue, 22 Feb 2022 16:31:29 -0800 Subject: [PATCH] Delete unnecessary auth configuration (#858) Signed-off-by: Yee Hing Tong --- flytekit/clients/raw.py | 199 +++++++++++++----- flytekit/clis/auth/auth.py | 16 +- flytekit/clis/auth/credentials.py | 56 +---- flytekit/clis/auth/discovery.py | 73 ------- flytekit/clis/sdk_in_container/basic_auth.py | 61 ------ flytekit/configuration/creds.py | 61 +----- flytekit/configuration/platform.py | 17 -- setup.py | 1 + tests/flytekit/unit/cli/auth/test_auth.py | 11 + .../unit/cli/auth/test_credentials.py | 39 ---- .../flytekit/unit/cli/auth/test_discovery.py | 66 ------ .../unit/cli/pyflyte/test_basic_auth.py | 32 --- tests/flytekit/unit/clients/test_raw.py | 145 ++++++++++--- 13 files changed, 301 insertions(+), 476 deletions(-) delete mode 100644 flytekit/clis/auth/discovery.py delete mode 100644 flytekit/clis/sdk_in_container/basic_auth.py delete mode 100644 tests/flytekit/unit/cli/auth/test_credentials.py delete mode 100644 tests/flytekit/unit/cli/auth/test_discovery.py delete mode 100644 tests/flytekit/unit/cli/pyflyte/test_basic_auth.py diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index ce58d8cc57..24d3e118b3 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -1,8 +1,15 @@ +from __future__ import annotations + +import base64 as _base64 +import logging as _logging import subprocess import time -from typing import List +from typing import Optional +import requests as _requests from flyteidl.service import admin_pb2_grpc as _admin_service +from flyteidl.service import auth_pb2 +from flyteidl.service import auth_pb2_grpc as auth_service from google.protobuf.json_format import MessageToJson as _MessageToJson from grpc import RpcError as _RpcError from grpc import StatusCode as _GrpcStatusCode @@ -11,49 +18,57 @@ from grpc import ssl_channel_credentials as _ssl_channel_credentials from flytekit.clis.auth import credentials as _credentials_access -from flytekit.clis.sdk_in_container import basic_auth as _basic_auth -from flytekit.configuration import creds as _creds_config -from flytekit.configuration.creds import _DEPRECATED_CLIENT_CREDENTIALS_SCOPE as _DEPRECATED_SCOPE +from flytekit.configuration import creds as creds_config +from flytekit.configuration.creds import CLIENT_CREDENTIALS_SECRET as _CREDENTIALS_SECRET from flytekit.configuration.creds import CLIENT_ID as _CLIENT_ID from flytekit.configuration.creds import COMMAND as _COMMAND -from flytekit.configuration.creds import DEPRECATED_OAUTH_SCOPES, SCOPES -from flytekit.configuration.platform import AUTH as _AUTH from flytekit.exceptions import user as _user_exceptions +from flytekit.exceptions.user import FlyteAuthenticationException from flytekit.loggers import cli_logger +_utf_8 = "utf-8" + -def _refresh_credentials_standard(flyte_client): +def _refresh_credentials_standard(flyte_client: RawSynchronousFlyteClient): """ This function is used when the configuration value for AUTH_MODE is set to 'standard'. This either fetches the existing access token or initiates the flow to request a valid access token and store it. :param flyte_client: RawSynchronousFlyteClient :return: """ - - client = _credentials_access.get_client(flyte_client.url) - if client.can_refresh_token: + authorization_header_key = flyte_client.public_client_config.authorization_metadata_key or None + if not flyte_client.oauth2_metadata or not flyte_client.public_client_config: + raise ValueError( + "Raw Flyte client attempting client credentials flow but no response from Admin detected. " + "Check your Admin server's .well-known endpoints to make sure they're working as expected." + ) + client = _credentials_access.get_client( + redirect_endpoint=flyte_client.public_client_config.redirect_uri, + client_id=flyte_client.public_client_config.client_id, + scopes=flyte_client.public_client_config.scopes, + auth_endpoint=flyte_client.oauth2_metadata.authorization_endpoint, + token_endpoint=flyte_client.oauth2_metadata.token_endpoint, + ) + if client.has_valid_credentials and not flyte_client.check_access_token(client.credentials.access_token): + # When Python starts up, if credentials have been stored in the keyring, then the AuthorizationClient + # will have read them into its _credentials field, but it won't be in the RawSynchronousFlyteClient's + # metadata field yet. Therefore, if there's a mismatch, copy it over. + flyte_client.set_access_token(client.credentials.access_token, authorization_header_key) + # However, after copying over credentials from the AuthorizationClient, we have to clear it to avoid the + # scenario where the stored credentials in the keyring are expired. If that's the case, then we only try + # them once (because client here is a singleton), and the next time, we'll do one of the two other conditions + # below. + client.clear() + return + elif client.can_refresh_token: client.refresh_access_token() + else: + client.start_authorization_flow() - flyte_client.set_access_token(client.credentials.access_token) - - -def _get_basic_flow_scopes() -> List[str]: - """ - Merge the scope value between the old scope config option and the new list option. - - :return: The scopes to use for basic auth flow. - """ - deprecated_single_scope = _DEPRECATED_SCOPE.get() - if deprecated_single_scope: - return [deprecated_single_scope] - scopes = DEPRECATED_OAUTH_SCOPES.get() or SCOPES.get() - if "openid" in scopes: - cli_logger.warning("Basic flow authentication should never use openid.") - - return scopes + flyte_client.set_access_token(client.credentials.access_token, authorization_header_key) -def _refresh_credentials_basic(flyte_client): +def _refresh_credentials_basic(flyte_client: RawSynchronousFlyteClient): """ This function is used by the _handle_rpc_error() decorator, depending on the AUTH_MODE config object. This handler is meant for SDK use-cases of auth (like pyflyte, or when users call SDK functions that require access to Admin, @@ -63,16 +78,24 @@ def _refresh_credentials_basic(flyte_client): :param flyte_client: RawSynchronousFlyteClient :return: """ - auth_endpoints = _credentials_access.get_authorization_endpoints(flyte_client.url) - token_endpoint = auth_endpoints.token_endpoint - client_secret = _basic_auth.get_secret() - cli_logger.debug( - "Basic authorization flow with client id {} scope {}".format(_CLIENT_ID.get(), _get_basic_flow_scopes()) - ) - authorization_header = _basic_auth.get_basic_authorization_header(_CLIENT_ID.get(), client_secret) - token, expires_in = _basic_auth.get_token(token_endpoint, authorization_header, _get_basic_flow_scopes()) + if not flyte_client.oauth2_metadata or not flyte_client.public_client_config: + raise ValueError( + "Raw Flyte client attempting client credentials flow but no response from Admin detected. " + "Check your Admin server's .well-known endpoints to make sure they're working as expected." + ) + + token_endpoint = flyte_client.oauth2_metadata.token_endpoint + scopes = creds_config.SCOPES.get() or flyte_client.public_client_config.scopes + scopes = ",".join(scopes) + + # Note that unlike the Pkce flow, the client ID does not come from Admin. + client_secret = get_secret() + cli_logger.debug("Basic authorization flow with client id {} scope {}".format(_CLIENT_ID.get(), scopes)) + authorization_header = get_basic_authorization_header(_CLIENT_ID.get(), client_secret) + token, expires_in = get_token(token_endpoint, authorization_header, scopes) cli_logger.info("Retrieved new token, expires in {}".format(expires_in)) - flyte_client.set_access_token(token) + authorization_header_key = flyte_client.public_client_config.authorization_metadata_key or None + flyte_client.set_access_token(token, authorization_header_key) def _refresh_credentials_from_command(flyte_client): @@ -101,7 +124,7 @@ def _refresh_credentials_noop(flyte_client): def _get_refresh_handler(auth_mode): if auth_mode == "standard": return _refresh_credentials_standard - elif auth_mode == "basic": + elif auth_mode == "basic" or auth_mode == "client_credentials": return _refresh_credentials_basic elif auth_mode == "external_process": return _refresh_credentials_from_command @@ -133,7 +156,7 @@ def handler(*args, **kwargs): # Exit the loop and wrap the authentication error. raise _user_exceptions.FlyteAuthenticationException(str(e)) cli_logger.error(f"Unauthenticated RPC error {e}, refreshing credentials and retrying\n") - refresh_handler_fn = _get_refresh_handler(_creds_config.AUTH_MODE.get()) + refresh_handler_fn = _get_refresh_handler(creds_config.AUTH_MODE.get()) refresh_handler_fn(args[0]) # There are two cases that we should throw error immediately # 1. Entity already exists when we register entity @@ -210,29 +233,57 @@ def __init__(self, url, insecure=False, credentials=None, options=None, root_cer options=list((options or {}).items()), ) self._stub = _admin_service.AdminServiceStub(self._channel) + self._auth_stub = auth_service.AuthMetadataServiceStub(self._channel) + try: + resp = self._auth_stub.GetPublicClientConfig(auth_pb2.PublicClientAuthConfigRequest()) + self._public_client_config = resp + except _RpcError: + cli_logger.debug("No public client auth config found, skipping.") + self._public_client_config = None + try: + resp = self._auth_stub.GetOAuth2Metadata(auth_pb2.OAuth2MetadataRequest()) + self._oauth2_metadata = resp + except _RpcError: + cli_logger.debug("No OAuth2 Metadata found, skipping.") + self._oauth2_metadata = None + + # metadata will hold the value of the token to send to the various endpoints. self._metadata = None - if _AUTH.get(): - self.force_auth_flow() + + @property + def public_client_config(self) -> Optional[auth_pb2.PublicClientAuthConfigResponse]: + return self._public_client_config + + @property + def oauth2_metadata(self) -> Optional[auth_pb2.OAuth2MetadataResponse]: + return self._oauth2_metadata @property def url(self) -> str: return self._url - def set_access_token(self, access_token): + def set_access_token(self, access_token: str, authorization_header_key: Optional[str] = "authorization"): # Always set the header to lower-case regardless of what the config is. The grpc libraries that Admin uses # to parse the metadata don't change the metadata, but they do automatically lower the key you're looking for. - authorization_metadata_key = _creds_config.AUTHORIZATION_METADATA_KEY.get().lower() - cli_logger.debug(f"Adding authorization header. Header name: {authorization_metadata_key}.") + cli_logger.debug(f"Adding authorization header. Header name: {authorization_header_key}.") self._metadata = [ ( - authorization_metadata_key, + authorization_header_key, f"Bearer {access_token}", ) ] - def force_auth_flow(self): - refresh_handler_fn = _get_refresh_handler(_creds_config.AUTH_MODE.get()) - refresh_handler_fn(self) + def check_access_token(self, access_token: str) -> bool: + """ + This checks to see if the given access token is the same as the one already stored in the client. The reason + this is useful is so that we can prevent unnecessary refreshing of tokens. + + :param access_token: The access token to check + :return: If no access token is stored, or if the stored token doesn't match, return False. + """ + if self._metadata is None: + return False + return access_token == self._metadata[0][1].replace("Bearer ", "") #################################################################################################################### # @@ -749,3 +800,55 @@ def list_matchable_attributes(self, matchable_attributes_list_request): # TODO: (P2) Implement the event endpoints in case there becomes a use-case for third-parties to submit events # through the client in Python. + + +def get_token(token_endpoint, authorization_header, scope): + """ + :param Text token_endpoint: + :param Text authorization_header: This is the value for the "Authorization" key. (eg 'Bearer abc123') + :param Text scope: + :rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration + in seconds + """ + headers = { + "Authorization": authorization_header, + "Cache-Control": "no-cache", + "Accept": "application/json", + "Content-Type": "application/x-www-form-urlencoded", + } + body = { + "grant_type": "client_credentials", + } + if scope is not None: + body["scope"] = scope + response = _requests.post(token_endpoint, data=body, headers=headers) + if response.status_code != 200: + _logging.error("Non-200 ({}) received from IDP: {}".format(response.status_code, response.text)) + raise FlyteAuthenticationException("Non-200 received from IDP") + + response = response.json() + return response["access_token"], response["expires_in"] + + +def get_secret(): + """ + This function will either read in the password from the file path given by the CLIENT_CREDENTIALS_SECRET_LOCATION + config object, or from the environment variable using the CLIENT_CREDENTIALS_SECRET config object. + :rtype: Text + """ + secret = _CREDENTIALS_SECRET.get() + if secret: + return secret + raise FlyteAuthenticationException("No secret could be found") + + +def get_basic_authorization_header(client_id, client_secret): + """ + This function transforms the client id and the client secret into a header that conforms with http basic auth. + It joins the id and the secret with a : then base64 encodes it, then adds the appropriate text. + :param Text client_id: + :param Text client_secret: + :rtype: Text + """ + concated = "{}:{}".format(client_id, client_secret) + return "Basic {}".format(_base64.b64encode(concated.encode(_utf_8)).decode(_utf_8)) diff --git a/flytekit/clis/auth/auth.py b/flytekit/clis/auth/auth.py index 79127c6748..2f27dfc72d 100644 --- a/flytekit/clis/auth/auth.py +++ b/flytekit/clis/auth/auth.py @@ -145,12 +145,11 @@ def __init__( scopes=None, client_id=None, redirect_uri=None, - client_secret=None, ): self._auth_endpoint = auth_endpoint self._token_endpoint = token_endpoint self._client_id = client_id - self._scopes = scopes + self._scopes = scopes or [] self._redirect_uri = redirect_uri self._code_verifier = _generate_code_verifier() code_challenge = _create_code_challenge(self._code_verifier) @@ -161,12 +160,11 @@ def __init__( self._refresh_token = None self._headers = {"content-type": "application/x-www-form-urlencoded"} self._expired = False - self._client_secret = client_secret self._params = { "client_id": client_id, # This must match the Client ID of the OAuth application. "response_type": "code", # Indicates the authorization code grant - "scope": " ".join(s.strip("' ") for s in scopes).strip( + "scope": " ".join(s.strip("' ") for s in self._scopes).strip( "[]'" ), # ensures that the /token endpoint returns an ID and refresh token # callback location where the user-agent will be directed to. @@ -239,12 +237,12 @@ def _initialize_credentials(self, auth_token_resp): raise ValueError('Expected "access_token" in response from oauth server') if "refresh_token" in response_body: self._refresh_token = response_body["refresh_token"] + _keyring.set_password( + _keyring_service_name, _keyring_refresh_token_storage_key, response_body["refresh_token"] + ) access_token = response_body["access_token"] - refresh_token = response_body["refresh_token"] - _keyring.set_password(_keyring_service_name, _keyring_access_token_storage_key, access_token) - _keyring.set_password(_keyring_service_name, _keyring_refresh_token_storage_key, refresh_token) self._credentials = Credentials(access_token=access_token) def request_access_token(self, auth_code): @@ -299,6 +297,10 @@ def credentials(self): """ return self._credentials + def clear(self): + self._credentials = None + self._refresh_token = None + @property def expired(self): """ diff --git a/flytekit/clis/auth/credentials.py b/flytekit/clis/auth/credentials.py index 45f51a482a..a8475c8dfc 100644 --- a/flytekit/clis/auth/credentials.py +++ b/flytekit/clis/auth/credentials.py @@ -1,56 +1,28 @@ -import urllib.parse as _urlparse +from typing import List -from flytekit.clis.auth.auth import AuthorizationClient as _AuthorizationClient -from flytekit.clis.auth.discovery import DiscoveryClient as _DiscoveryClient -from flytekit.configuration.creds import CLIENT_CREDENTIALS_SECRET as _CLIENT_SECRET -from flytekit.configuration.creds import CLIENT_ID as _CLIENT_ID -from flytekit.configuration.creds import DEPRECATED_OAUTH_SCOPES -from flytekit.configuration.creds import REDIRECT_URI as _REDIRECT_URI -from flytekit.configuration.creds import SCOPES -from flytekit.configuration.platform import HTTP_URL as _HTTP_URL -from flytekit.configuration.platform import INSECURE as _INSECURE -from flytekit.configuration.platform import URL as _URL +from flytekit.clis.auth.auth import AuthorizationClient from flytekit.loggers import auth_logger # Default, well known-URI string used for fetching JSON metadata. See https://tools.ietf.org/html/rfc8414#section-3. discovery_endpoint_path = "./.well-known/oauth-authorization-server" - -def _get_discovery_endpoint(http_config_val, platform_url_val, insecure_val): - if http_config_val: - scheme, netloc, path, _, _, _ = _urlparse.urlparse(http_config_val) - if not scheme: - scheme = "http" if insecure_val else "https" - else: # Use the main _URL config object effectively - scheme = "http" if insecure_val else "https" - netloc = platform_url_val - path = "" - - computed_endpoint = _urlparse.urlunparse((scheme, netloc, path, None, None, None)) - # The urljoin function needs a trailing slash in order to append things correctly. Also, having an extra slash - # at the end is okay, it just gets stripped out. - computed_endpoint = _urlparse.urljoin(computed_endpoint + "/", discovery_endpoint_path) - auth_logger.debug(f"Using {computed_endpoint} as discovery endpoint") - return computed_endpoint - - # Lazy initialized authorization client singleton _authorization_client = None -def get_client(flyte_client_url): +def get_client( + redirect_endpoint: str, client_id: str, scopes: List[str], auth_endpoint: str, token_endpoint: str +) -> AuthorizationClient: global _authorization_client if _authorization_client is not None and not _authorization_client.expired: return _authorization_client - authorization_endpoints = get_authorization_endpoints(flyte_client_url) - _authorization_client = _AuthorizationClient( - redirect_uri=_REDIRECT_URI.get(), - client_id=_CLIENT_ID.get(), - scopes=DEPRECATED_OAUTH_SCOPES.get() or SCOPES.get(), - auth_endpoint=authorization_endpoints.auth_endpoint, - token_endpoint=authorization_endpoints.token_endpoint, - client_secret=_CLIENT_SECRET.get(), + _authorization_client = AuthorizationClient( + redirect_uri=redirect_endpoint, + client_id=client_id, + scopes=scopes, + auth_endpoint=auth_endpoint, + token_endpoint=token_endpoint, ) auth_logger.debug(f"Created oauth client with redirect {_authorization_client}") @@ -59,9 +31,3 @@ def get_client(flyte_client_url): _authorization_client.start_authorization_flow() return _authorization_client - - -def get_authorization_endpoints(flyte_client_url): - discovery_endpoint = _get_discovery_endpoint(_HTTP_URL.get(), flyte_client_url or _URL.get(), _INSECURE.get()) - discovery_client = _DiscoveryClient(discovery_url=discovery_endpoint) - return discovery_client.get_authorization_endpoints() diff --git a/flytekit/clis/auth/discovery.py b/flytekit/clis/auth/discovery.py deleted file mode 100644 index 5134eab974..0000000000 --- a/flytekit/clis/auth/discovery.py +++ /dev/null @@ -1,73 +0,0 @@ -import logging - -import requests as _requests - -# These response keys are defined in https://tools.ietf.org/id/draft-ietf-oauth-discovery-08.html. -_authorization_endpoint_key = "authorization_endpoint" -_token_endpoint_key = "token_endpoint" - - -class AuthorizationEndpoints(object): - """ - A simple wrapper around commonly discovered endpoints used for the PKCE auth flow. - """ - - def __init__(self, auth_endpoint=None, token_endpoint=None): - self._auth_endpoint = auth_endpoint - self._token_endpoint = token_endpoint - - @property - def auth_endpoint(self): - return self._auth_endpoint - - @property - def token_endpoint(self): - return self._token_endpoint - - -class DiscoveryClient(object): - """ - Discovers well known OpenID configuration and parses out authorization endpoints required for initiating the PKCE - auth flow. - """ - - def __init__(self, discovery_url=None): - logging.debug("Initializing discovery client with {}".format(discovery_url)) - self._discovery_url = discovery_url - self._authorization_endpoints = None - - @property - def authorization_endpoints(self): - """ - :rtype: flytekit.clis.auth.discovery.AuthorizationEndpoints: - """ - return self._authorization_endpoints - - def get_authorization_endpoints(self): - if self.authorization_endpoints is not None: - return self.authorization_endpoints - resp = _requests.get( - url=self._discovery_url, - ) - - response_body = resp.json() - - authorization_endpoint = response_body[_authorization_endpoint_key] - token_endpoint = response_body[_token_endpoint_key] - - if authorization_endpoint is None: - raise ValueError("Unable to discover authorization endpoint") - - if token_endpoint is None: - raise ValueError("Unable to discover token endpoint") - - if authorization_endpoint.startswith("/"): - authorization_endpoint = _requests.compat.urljoin(self._discovery_url, authorization_endpoint) - - if token_endpoint.startswith("/"): - token_endpoint = _requests.compat.urljoin(self._discovery_url, token_endpoint) - - self._authorization_endpoints = AuthorizationEndpoints( - auth_endpoint=authorization_endpoint, token_endpoint=token_endpoint - ) - return self.authorization_endpoints diff --git a/flytekit/clis/sdk_in_container/basic_auth.py b/flytekit/clis/sdk_in_container/basic_auth.py deleted file mode 100644 index 92612595ad..0000000000 --- a/flytekit/clis/sdk_in_container/basic_auth.py +++ /dev/null @@ -1,61 +0,0 @@ -import base64 as _base64 -import logging as _logging - -import requests as _requests - -from flytekit.configuration.creds import CLIENT_CREDENTIALS_SECRET as _CREDENTIALS_SECRET -from flytekit.exceptions.user import FlyteAuthenticationException - -_utf_8 = "utf-8" - - -def get_secret(): - """ - This function will either read in the password from the file path given by the CLIENT_CREDENTIALS_SECRET_LOCATION - config object, or from the environment variable using the CLIENT_CREDENTIALS_SECRET config object. - :rtype: Text - """ - secret = _CREDENTIALS_SECRET.get() - if secret: - return secret - raise FlyteAuthenticationException("No secret could be found") - - -def get_basic_authorization_header(client_id, client_secret): - """ - This function transforms the client id and the client secret into a header that conforms with http basic auth. - It joins the id and the secret with a : then base64 encodes it, then adds the appropriate text. - :param Text client_id: - :param Text client_secret: - :rtype: Text - """ - concated = "{}:{}".format(client_id, client_secret) - return "Basic {}".format(_base64.b64encode(concated.encode(_utf_8)).decode(_utf_8)) - - -def get_token(token_endpoint, authorization_header, scope): - """ - :param Text token_endpoint: - :param Text authorization_header: This is the value for the "Authorization" key. (eg 'Bearer abc123') - :param Text scope: - :rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration - in seconds - """ - headers = { - "Authorization": authorization_header, - "Cache-Control": "no-cache", - "Accept": "application/json", - "Content-Type": "application/x-www-form-urlencoded", - } - body = { - "grant_type": "client_credentials", - } - if scope is not None: - body["scope"] = scope - response = _requests.post(token_endpoint, data=body, headers=headers) - if response.status_code != 200: - _logging.error("Non-200 ({}) received from IDP: {}".format(response.status_code, response.text)) - raise FlyteAuthenticationException("Non-200 received from IDP") - - response = response.json() - return response["access_token"], response["expires_in"] diff --git a/flytekit/configuration/creds.py b/flytekit/configuration/creds.py index 6acfcd9129..9f11ac2d2e 100644 --- a/flytekit/configuration/creds.py +++ b/flytekit/configuration/creds.py @@ -1,10 +1,5 @@ -from warnings import warn - from flytekit.configuration import common as _config_common -deprecated_names = ["CLIENT_CREDENTIALS_SCOPE"] - - COMMAND = _config_common.FlyteStringListConfigurationEntry("credentials", "command", default=None) """ This command is executed to return a token using an external process. @@ -16,37 +11,6 @@ More details here: https://www.oauth.com/oauth2-servers/client-registration/client-id-secret/. """ -REDIRECT_URI = _config_common.FlyteStringConfigurationEntry( - "credentials", "redirect_uri", default="http://localhost:12345/callback" -) -""" -This is the callback uri registered with the app which handles authorization for a Flyte deployment. -Please note the hardcoded port number. Ideally we would not do this, but some IDPs do not allow wildcards for -the URL, which means we have to use the same port every time. This is the only reason this is a configuration option, -otherwise, we'd just hardcode the callback path as a constant. -FYI, to see if a given port is already in use, run `sudo lsof -i :` if on a Linux system. -More details here: https://www.oauth.com/oauth2-servers/redirect-uris/. -""" - -SCOPES = _config_common.FlyteStringListConfigurationEntry("credentials", "scopes", default=["openid"]) -""" -This controls the list of scopes to request from the authorization server. -""" - -DEPRECATED_OAUTH_SCOPES = _config_common.FlyteStringListConfigurationEntry("credentials", "oauth_scopes", default=None) -""" -This controls the list of scopes to request from the authorization server. -Deprecated - please use the SCOPES variable. -""" - -AUTHORIZATION_METADATA_KEY = _config_common.FlyteStringConfigurationEntry( - "credentials", "authorization_metadata_key", default="authorization" -) -""" -The authorization metadata key used for passing access tokens in gRPC requests. -Traditionally this value is 'authorization' however it is made configurable. -""" - CLIENT_CREDENTIALS_SECRET = _config_common.FlyteStringConfigurationEntry("credentials", "client_secret", default=None) """ Used for basic auth, which is automatically called during pyflyte. This will allow the Flyte engine to read the @@ -54,33 +18,14 @@ secret as a file is impossible. """ -_DEPRECATED_CLIENT_CREDENTIALS_SCOPE = _config_common.FlyteStringConfigurationEntry( - "credentials", "scope", default=None -) -""" -Used for basic auth, which is automatically called during pyflyte. This is the scope that will be requested. Because -there is no user explicitly in this auth flow, certain IDPs require a custom scope for basic auth in the configuration -of the authorization server. - -Deprecated - please use the OAUTH_SCOPES list variable instead. In the basic flow scenario, flytekit will expect a list -with at least one element. The first element will be used. If list has more than one element a warning will be logged. -Config files with both this option, and the OAUTH_SCOPES, will use this one. -""" +SCOPES = _config_common.FlyteStringListConfigurationEntry("credentials", "scopes", default=[]) AUTH_MODE = _config_common.FlyteStringConfigurationEntry("credentials", "auth_mode", default="standard") """ The auth mode defines the behavior used to request and refresh credentials. The currently supported modes include: - 'standard' This uses the pkce-enhanced authorization code flow by opening a browser window to initiate credentials access. -- 'basic' This uses cert-based auth in which the end user enters his/her username and password and public key encryption - is used to facilitate authentication. +- 'basic' or 'client_credentials' This uses cert-based auth in which the end user enters a client id and a client + secret and public key encryption is used to facilitate authentication. - None: No auth will be attempted. """ - - -# https://www.python.org/dev/peps/pep-0562/ -def __getattr__(name): - if name in deprecated_names: - warn(f"{name} is deprecated", DeprecationWarning) - return globals()[f"_DEPRECATED_{name}"] - raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/flytekit/configuration/platform.py b/flytekit/configuration/platform.py index eecbeda162..ee8bfeb895 100644 --- a/flytekit/configuration/platform.py +++ b/flytekit/configuration/platform.py @@ -1,21 +1,4 @@ from flytekit.configuration import common as _config_common URL = _config_common.FlyteStringConfigurationEntry("platform", "url") - -HTTP_URL = _config_common.FlyteStringConfigurationEntry("platform", "http_url", default=None) -""" -If not starting with either http or https, this setting should begin with // as per the urlparse library and -https://tools.ietf.org/html/rfc1808.html, otherwise the netloc will not be properly parsed. - -Currently the only use-case for this configuration setting is for Auth discovery. This setting supports the case where -Flyte Admin's gRPC and HTTP points are deployed on different ports. -""" - INSECURE = _config_common.FlyteBoolConfigurationEntry("platform", "insecure", default=False) - -AUTH = _config_common.FlyteBoolConfigurationEntry("platform", "auth", default=False) -""" -This config setting should not normally be filled in. Whether or not an admin server requires authentication should be -something published by the admin server itself (typically by returning a 401). However, to help with migration, this -config object is here to force the SDK to attempt the auth flow even without prompting by Admin. -""" diff --git a/setup.py b/setup.py index 885f32b897..3816359083 100644 --- a/setup.py +++ b/setup.py @@ -45,6 +45,7 @@ "python-dateutil>=2.1", "grpcio>=1.3.0,<2.0", "protobuf>=3.6.1,<4", + "protoc_gen_swagger", "python-json-logger>=2.0.0", "pytimeparse>=1.1.8,<2.0.0", "pytz", diff --git a/tests/flytekit/unit/cli/auth/test_auth.py b/tests/flytekit/unit/cli/auth/test_auth.py index 2deecaefaa..1bd38d3b39 100644 --- a/tests/flytekit/unit/cli/auth/test_auth.py +++ b/tests/flytekit/unit/cli/auth/test_auth.py @@ -1,6 +1,8 @@ import re from multiprocessing import Queue as _Queue +from mock import patch + from flytekit.clis.auth import auth as _auth try: # Python 3 @@ -33,3 +35,12 @@ def test_oauth_http_server(): server.handle_authorization_code(test_auth_code) auth_code = queue.get() assert test_auth_code == auth_code + + +@patch("flytekit.clis.auth.auth._keyring.get_password") +def test_clear(mock_get_password): + mock_get_password.return_value = "token" + ac = _auth.AuthorizationClient() + ac.clear() + assert ac.credentials is None + assert not ac.can_refresh_token diff --git a/tests/flytekit/unit/cli/auth/test_credentials.py b/tests/flytekit/unit/cli/auth/test_credentials.py deleted file mode 100644 index f1ae57a016..0000000000 --- a/tests/flytekit/unit/cli/auth/test_credentials.py +++ /dev/null @@ -1,39 +0,0 @@ -from flytekit.clis.auth import credentials as _credentials - - -def test_get_discovery_endpoint(): - endpoint = _credentials._get_discovery_endpoint("//localhost:8088", "localhost:8089", True) - assert endpoint == "http://localhost:8088/.well-known/oauth-authorization-server" - - endpoint = _credentials._get_discovery_endpoint("//localhost:8088", "localhost:8089", False) - assert endpoint == "https://localhost:8088/.well-known/oauth-authorization-server" - - endpoint = _credentials._get_discovery_endpoint("//localhost:8088/path", "localhost:8089", True) - assert endpoint == "http://localhost:8088/path/.well-known/oauth-authorization-server" - - endpoint = _credentials._get_discovery_endpoint("//localhost:8088/path", "localhost:8089", False) - assert endpoint == "https://localhost:8088/path/.well-known/oauth-authorization-server" - - endpoint = _credentials._get_discovery_endpoint("//flyte.corp.com", "localhost:8089", False) - assert endpoint == "https://flyte.corp.com/.well-known/oauth-authorization-server" - - endpoint = _credentials._get_discovery_endpoint("//flyte.corp.com/path", "localhost:8089", False) - assert endpoint == "https://flyte.corp.com/path/.well-known/oauth-authorization-server" - - endpoint = _credentials._get_discovery_endpoint(None, "localhost:8089", True) - assert endpoint == "http://localhost:8089/.well-known/oauth-authorization-server" - - endpoint = _credentials._get_discovery_endpoint(None, "localhost:8089", False) - assert endpoint == "https://localhost:8089/.well-known/oauth-authorization-server" - - endpoint = _credentials._get_discovery_endpoint(None, "flyte.corp.com", True) - assert endpoint == "http://flyte.corp.com/.well-known/oauth-authorization-server" - - endpoint = _credentials._get_discovery_endpoint(None, "flyte.corp.com", False) - assert endpoint == "https://flyte.corp.com/.well-known/oauth-authorization-server" - - endpoint = _credentials._get_discovery_endpoint(None, "localhost:8089", True) - assert endpoint == "http://localhost:8089/.well-known/oauth-authorization-server" - - endpoint = _credentials._get_discovery_endpoint(None, "localhost:8089", False) - assert endpoint == "https://localhost:8089/.well-known/oauth-authorization-server" diff --git a/tests/flytekit/unit/cli/auth/test_discovery.py b/tests/flytekit/unit/cli/auth/test_discovery.py deleted file mode 100644 index c75427f35d..0000000000 --- a/tests/flytekit/unit/cli/auth/test_discovery.py +++ /dev/null @@ -1,66 +0,0 @@ -import pytest -import responses - -from flytekit.clis.auth import discovery as _discovery - - -@responses.activate -def test_get_authorization_endpoints(): - discovery_url = "http://flyte-admin.com/discovery" - - auth_endpoint = "http://flyte-admin.com/authorization" - token_endpoint = "http://flyte-admin.com/token" - responses.add( - responses.GET, - discovery_url, - json={"authorization_endpoint": auth_endpoint, "token_endpoint": token_endpoint}, - ) - - discovery_client = _discovery.DiscoveryClient(discovery_url=discovery_url) - assert discovery_client.get_authorization_endpoints().auth_endpoint == auth_endpoint - assert discovery_client.get_authorization_endpoints().token_endpoint == token_endpoint - - -@responses.activate -def test_get_authorization_endpoints_relative(): - discovery_url = "http://flyte-admin.com/discovery" - - auth_endpoint = "/authorization" - token_endpoint = "/token" - responses.add( - responses.GET, - discovery_url, - json={"authorization_endpoint": auth_endpoint, "token_endpoint": token_endpoint}, - ) - - discovery_client = _discovery.DiscoveryClient(discovery_url=discovery_url) - assert discovery_client.get_authorization_endpoints().auth_endpoint == "http://flyte-admin.com/authorization" - assert discovery_client.get_authorization_endpoints().token_endpoint == "http://flyte-admin.com/token" - - -@responses.activate -def test_get_authorization_endpoints_missing_authorization_endpoint(): - discovery_url = "http://flyte-admin.com/discovery" - responses.add( - responses.GET, - discovery_url, - json={"token_endpoint": "http://flyte-admin.com/token"}, - ) - - discovery_client = _discovery.DiscoveryClient(discovery_url=discovery_url) - with pytest.raises(Exception): - discovery_client.get_authorization_endpoints() - - -@responses.activate -def test_get_authorization_endpoints_missing_token_endpoint(): - discovery_url = "http://flyte-admin.com/discovery" - responses.add( - responses.GET, - discovery_url, - json={"authorization_endpoint": "http://flyte-admin.com/authorization"}, - ) - - discovery_client = _discovery.DiscoveryClient(discovery_url=discovery_url) - with pytest.raises(Exception): - discovery_client.get_authorization_endpoints() diff --git a/tests/flytekit/unit/cli/pyflyte/test_basic_auth.py b/tests/flytekit/unit/cli/pyflyte/test_basic_auth.py deleted file mode 100644 index d18f21dfa5..0000000000 --- a/tests/flytekit/unit/cli/pyflyte/test_basic_auth.py +++ /dev/null @@ -1,32 +0,0 @@ -import json - -from mock import MagicMock, patch - -from flytekit.clis.flyte_cli.main import _welcome_message -from flytekit.clis.sdk_in_container import basic_auth -from flytekit.configuration.creds import CLIENT_CREDENTIALS_SECRET as _CREDENTIALS_SECRET - -_welcome_message() - - -def test_get_secret(): - import os - - os.environ[_CREDENTIALS_SECRET.env_var] = "abc" - assert basic_auth.get_secret() == "abc" - - -def test_get_basic_authorization_header(): - header = basic_auth.get_basic_authorization_header("client_id", "abc") - assert header == "Basic Y2xpZW50X2lkOmFiYw==" - - -@patch("flytekit.clis.sdk_in_container.basic_auth._requests") -def test_get_token(mock_requests): - response = MagicMock() - response.status_code = 200 - response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") - mock_requests.post.return_value = response - access, expiration = basic_auth.get_token("https://corp.idp.net", "abc123", "my_scope") - assert access == "abc" - assert expiration == 60 diff --git a/tests/flytekit/unit/clients/test_raw.py b/tests/flytekit/unit/clients/test_raw.py index c8e56811d4..86f5a04a6b 100644 --- a/tests/flytekit/unit/clients/test_raw.py +++ b/tests/flytekit/unit/clients/test_raw.py @@ -3,44 +3,59 @@ from subprocess import CompletedProcess import mock +import pytest from flyteidl.admin import project_pb2 as _project_pb2 +from flyteidl.service import auth_pb2 +from mock import MagicMock, patch from flytekit.clients.raw import RawSynchronousFlyteClient as _RawSynchronousFlyteClient -from flytekit.clients.raw import _get_basic_flow_scopes, _refresh_credentials_basic, _refresh_credentials_from_command -from flytekit.clis.auth.discovery import AuthorizationEndpoints as _AuthorizationEndpoints -from flytekit.configuration import TemporaryConfiguration +from flytekit.clients.raw import ( + _get_refresh_handler, + _refresh_credentials_basic, + _refresh_credentials_from_command, + _refresh_credentials_standard, + get_basic_authorization_header, + get_secret, + get_token, +) from flytekit.configuration.creds import CLIENT_CREDENTIALS_SECRET as _CREDENTIALS_SECRET -@mock.patch("flytekit.clients.raw.RawSynchronousFlyteClient.force_auth_flow") +def get_admin_stub_mock() -> mock.MagicMock: + auth_stub_mock = mock.MagicMock() + auth_stub_mock.GetPublicClientConfig.return_value = auth_pb2.PublicClientAuthConfigResponse( + client_id="flytectl", + redirect_uri="http://localhost:53593/callback", + scopes=["offline", "all"], + authorization_metadata_key="flyte-authorization", + ) + auth_stub_mock.GetOAuth2Metadata.return_value = auth_pb2.OAuth2MetadataResponse( + issuer="https://your.domain.io", + authorization_endpoint="https://your.domain.io/oauth2/authorize", + token_endpoint="https://your.domain.io/oauth2/token", + response_types_supported=["code", "token", "code token"], + scopes_supported=["all"], + token_endpoint_auth_methods_supported=["client_secret_basic"], + jwks_uri="https://your.domain.io/oauth2/jwks", + code_challenge_methods_supported=["S256"], + grant_types_supported=["client_credentials", "refresh_token", "authorization_code"], + ) + return auth_stub_mock + + +@mock.patch("flytekit.clients.raw.auth_service") @mock.patch("flytekit.clients.raw._admin_service") @mock.patch("flytekit.clients.raw._insecure_channel") @mock.patch("flytekit.clients.raw._secure_channel") -def test_client_set_token(mock_secure_channel, mock_channel, mock_admin, mock_force): - mock_force.return_value = True +def test_client_set_token(mock_secure_channel, mock_channel, mock_admin, mock_admin_auth): mock_secure_channel.return_value = True mock_channel.return_value = True mock_admin.AdminServiceStub.return_value = True + mock_admin_auth.AuthMetadataServiceStub.return_value = get_admin_stub_mock() client = _RawSynchronousFlyteClient(url="a.b.com", insecure=True) client.set_access_token("abc") assert client._metadata[0][1] == "Bearer abc" - - -@mock.patch("flytekit.clis.sdk_in_container.basic_auth._requests") -@mock.patch("flytekit.clients.raw._credentials_access") -def test_refresh_credentials_basic(mock_credentials_access, mock_requests): - mock_credentials_access.get_authorization_endpoints.return_value = _AuthorizationEndpoints("auth", "token") - response = mock.MagicMock() - response.status_code = 200 - response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") - mock_requests.post.return_value = response - os.environ[_CREDENTIALS_SECRET.env_var] = "asdf12345" - - mock_client = mock.MagicMock() - mock_client.url.return_value = "flyte.localhost" - _refresh_credentials_basic(mock_client) - mock_client.set_access_token.assert_called_with("abc") - mock_credentials_access.get_authorization_endpoints.assert_called_with(mock_client.url) + assert client.check_access_token("abc") @mock.patch("flytekit.configuration.creds.COMMAND.get") @@ -59,6 +74,57 @@ def test_refresh_credentials_from_command(mock_call_to_external_process, mock_co mock_client.set_access_token.assert_called_with(token) +@mock.patch("flytekit.configuration.creds.SCOPES.get") +@mock.patch("flytekit.clients.raw.get_secret") +@mock.patch("flytekit.clients.raw.get_basic_authorization_header") +@mock.patch("flytekit.clients.raw.get_token") +@mock.patch("flytekit.clients.raw.auth_service") +@mock.patch("flytekit.clients.raw._admin_service") +@mock.patch("flytekit.clients.raw._insecure_channel") +@mock.patch("flytekit.clients.raw._secure_channel") +def test_refresh_client_credentials_aka_basic( + mock_secure_channel, + mock_channel, + mock_admin, + mock_admin_auth, + mock_get_token, + mock_get_basic_header, + mock_secret, + mock_scopes, +): + mock_secret.return_value = "sosecret" + mock_scopes.return_value = ["a", "b", "c", "d"] + mock_secure_channel.return_value = True + mock_channel.return_value = True + mock_admin.AdminServiceStub.return_value = True + mock_get_basic_header.return_value = "Basic 123" + mock_get_token.return_value = ("token1", 1234567) + + mock_admin_auth.AuthMetadataServiceStub.return_value = get_admin_stub_mock() + client = _RawSynchronousFlyteClient(url="a.b.com", insecure=True) + client._metadata = None + assert not client.check_access_token("fdsa") + _refresh_credentials_basic(client) + + # Scopes from configuration take precendence. + mock_get_token.assert_called_once_with("https://your.domain.io/oauth2/token", "Basic 123", "a,b,c,d") + + client.set_access_token("token") + assert client._metadata[0][0] == "authorization" + + +def test_raises(): + mm = MagicMock() + mm.public_client_config = None + with pytest.raises(ValueError): + _refresh_credentials_basic(mm) + + mm = MagicMock() + mm.oauth2_metadata = None + with pytest.raises(ValueError): + _refresh_credentials_basic(mm) + + @mock.patch("flytekit.clients.raw._admin_service") @mock.patch("flytekit.clients.raw._insecure_channel") def test_update_project(mock_channel, mock_admin): @@ -77,12 +143,31 @@ def test_list_projects_paginated(mock_channel, mock_admin): mock_admin.AdminServiceStub().ListProjects.assert_called_with(project_list_request, metadata=None) -def test_scope_deprecation(): - with TemporaryConfiguration(os.path.join(os.path.dirname(__file__), "auth_deprecation.config")): - assert _get_basic_flow_scopes() == ["custom_basic"] +def test_get_secret(): + os.environ[_CREDENTIALS_SECRET.env_var] = "abc" + assert get_secret() == "abc" + + +def test_get_basic_authorization_header(): + header = get_basic_authorization_header("client_id", "abc") + assert header == "Basic Y2xpZW50X2lkOmFiYw==" - with TemporaryConfiguration(os.path.join(os.path.dirname(__file__), "auth_deprecation2.config")): - assert _get_basic_flow_scopes() == ["custom_basic", "other_scope", "profile"] - with TemporaryConfiguration(os.path.join(os.path.dirname(__file__), "auth_deprecation3.config")): - assert _get_basic_flow_scopes() == ["custom_basic"] +@patch("flytekit.clients.raw._requests") +def test_get_token(mock_requests): + response = MagicMock() + response.status_code = 200 + response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") + mock_requests.post.return_value = response + access, expiration = get_token("https://corp.idp.net", "abc123", "my_scope") + assert access == "abc" + assert expiration == 60 + + +def test_get_refresh_handler(): + cc = _get_refresh_handler("client_credentials") + basic = _get_refresh_handler("basic") + assert basic is cc + assert basic is _refresh_credentials_basic + standard = _get_refresh_handler("standard") + assert standard is _refresh_credentials_standard