diff --git a/metaflow/extension_support/plugins.py b/metaflow/extension_support/plugins.py
index 10c6f129ee7..dbb11606e0e 100644
--- a/metaflow/extension_support/plugins.py
+++ b/metaflow/extension_support/plugins.py
@@ -179,6 +179,7 @@ def resolve_plugins(category):
     "metadata_provider": lambda x: x.TYPE,
     "datastore": lambda x: x.TYPE,
     "secrets_provider": lambda x: x.TYPE,
+    "gcp_client_provider": lambda x: x.name,
     "azure_client_provider": lambda x: x.name,
     "sidecar": None,
     "logging_sidecar": None,
diff --git a/metaflow/metaflow_config.py b/metaflow/metaflow_config.py
index e575c86a57c..bfbf24de2bd 100644
--- a/metaflow/metaflow_config.py
+++ b/metaflow/metaflow_config.py
@@ -26,6 +26,7 @@
 DEFAULT_MONITOR = from_conf("DEFAULT_MONITOR", "nullSidecarMonitor")
 DEFAULT_PACKAGE_SUFFIXES = from_conf("DEFAULT_PACKAGE_SUFFIXES", ".py,.R,.RDS")
 DEFAULT_AWS_CLIENT_PROVIDER = from_conf("DEFAULT_AWS_CLIENT_PROVIDER", "boto3")
+DEFAULT_GCP_CLIENT_PROVIDER = from_conf("DEFAULT_GCP_CLIENT_PROVIDER", "gcp-default")
 DEFAULT_SECRETS_BACKEND_TYPE = from_conf("DEFAULT_SECRETS_BACKEND_TYPE")
 DEFAULT_SECRETS_ROLE = from_conf("DEFAULT_SECRETS_ROLE")
 
diff --git a/metaflow/plugins/__init__.py b/metaflow/plugins/__init__.py
index f5d4ed35898..1b40016edb1 100644
--- a/metaflow/plugins/__init__.py
+++ b/metaflow/plugins/__init__.py
@@ -127,6 +127,10 @@
     ),
 ]
 
+GCP_CLIENT_PROVIDERS_DESC = [
+    ("gcp-default", ".gcp.gs_storage_client_factory.GcpDefaultClientProvider")
+]
+
 AZURE_CLIENT_PROVIDERS_DESC = [
     ("azure-default", ".azure.azure_credential.AzureDefaultClientProvider")
 ]
@@ -154,6 +158,7 @@ def get_plugin_cli():
 AWS_CLIENT_PROVIDERS = resolve_plugins("aws_client_provider")
 SECRETS_PROVIDERS = resolve_plugins("secrets_provider")
 AZURE_CLIENT_PROVIDERS = resolve_plugins("azure_client_provider")
+GCP_CLIENT_PROVIDERS = resolve_plugins("gcp_client_provider")
 
 from .cards.card_modules import MF_EXTERNAL_CARDS
 
diff --git a/metaflow/plugins/gcp/__init__.py b/metaflow/plugins/gcp/__init__.py
index e69de29bb2d..d9f9288b3f0 100644
--- a/metaflow/plugins/gcp/__init__.py
+++ b/metaflow/plugins/gcp/__init__.py
@@ -0,0 +1 @@
+from .gs_storage_client_factory import get_credentials
diff --git a/metaflow/plugins/gcp/gs_storage_client_factory.py b/metaflow/plugins/gcp/gs_storage_client_factory.py
index df915421182..1ec528a5a61 100644
--- a/metaflow/plugins/gcp/gs_storage_client_factory.py
+++ b/metaflow/plugins/gcp/gs_storage_client_factory.py
@@ -8,7 +8,7 @@ def _get_cache_key():
     return os.getpid(), threading.get_ident()
 
 
-def get_gs_storage_client():
+def _get_gs_storage_client_default():
     cache_key = _get_cache_key()
     if cache_key not in _client_cache:
         from google.cloud import storage
@@ -19,3 +19,69 @@ def get_gs_storage_client():
             credentials=credentials, project=project_id
         )
     return _client_cache[cache_key]
+
+
+class GcpDefaultClientProvider(object):
+    """Built-in GCP client provider registered under the name "gcp-default"."""
+
+    # Matched against DEFAULT_GCP_CLIENT_PROVIDER to select this provider.
+    name = "gcp-default"
+
+    @staticmethod
+    def get_gs_storage_client(*args, **kwargs):
+        """Return the client built by ``_get_gs_storage_client_default``.
+
+        Any positional or keyword arguments are accepted and ignored.
+        """
+        return _get_gs_storage_client_default()
+
+    @staticmethod
+    def get_credentials(scopes, *args, **kwargs):
+        """Return ``google.auth.default(scopes=scopes)``.
+
+        Extra positional or keyword arguments are accepted and ignored.
+        """
+        import google.auth
+
+        return google.auth.default(scopes=scopes)
+
+
+# Provider class chosen by DEFAULT_GCP_CLIENT_PROVIDER; looked up on first use.
+cached_provider_class = None
+
+
+def _get_provider_class():
+    """Return the configured GCP client provider class, caching the lookup.
+
+    Raises ValueError if no entry in GCP_CLIENT_PROVIDERS is named
+    DEFAULT_GCP_CLIENT_PROVIDER.
+    """
+    global cached_provider_class
+    if cached_provider_class is None:
+        # Imported at call time: metaflow.plugins loads this module to resolve
+        # GcpDefaultClientProvider, so a module-level import would be circular.
+        from metaflow.metaflow_config import DEFAULT_GCP_CLIENT_PROVIDER
+        from metaflow.plugins import GCP_CLIENT_PROVIDERS
+
+        for p in GCP_CLIENT_PROVIDERS:
+            if p.name == DEFAULT_GCP_CLIENT_PROVIDER:
+                cached_provider_class = p
+                break
+        else:
+            raise ValueError(
+                "Cannot find GCP Client provider %s" % DEFAULT_GCP_CLIENT_PROVIDER
+            )
+    return cached_provider_class
+
+
+def get_gs_storage_client():
+    """Return a Google Cloud Storage client from the configured provider."""
+    return _get_provider_class().get_gs_storage_client()
+
+
+def get_credentials(scopes, *args, **kwargs):
+    """Return credentials for ``scopes`` from the configured provider.
+
+    Extra positional and keyword arguments are forwarded to the provider.
+    """
+    return _get_provider_class().get_credentials(scopes, *args, **kwargs)
diff --git a/metaflow/plugins/kubernetes/kubernetes.py b/metaflow/plugins/kubernetes/kubernetes.py
index ffea507ae64..5946283b37a 100644
--- a/metaflow/plugins/kubernetes/kubernetes.py
+++ b/metaflow/plugins/kubernetes/kubernetes.py
@@ -27,6 +27,7 @@
     DATASTORE_SYSROOT_S3,
     DATATOOLS_S3ROOT,
     DEFAULT_AWS_CLIENT_PROVIDER,
+    DEFAULT_GCP_CLIENT_PROVIDER,
     DEFAULT_METADATA,
     DEFAULT_SECRETS_BACKEND_TYPE,
     AZURE_KEY_VAULT_PREFIX,
@@ -307,6 +308,9 @@ def create_job_object(
             .environment_variable(
                 "METAFLOW_DEFAULT_AWS_CLIENT_PROVIDER", DEFAULT_AWS_CLIENT_PROVIDER
             )
+            .environment_variable(
+                "METAFLOW_DEFAULT_GCP_CLIENT_PROVIDER", DEFAULT_GCP_CLIENT_PROVIDER
+            )
             .environment_variable(
                 "METAFLOW_AWS_SECRETS_MANAGER_DEFAULT_REGION",
                 AWS_SECRETS_MANAGER_DEFAULT_REGION,