diff --git a/python/hsfs/client/external.py b/python/hsfs/client/external.py
index da607371b6..56fd7da1c7 100644
--- a/python/hsfs/client/external.py
+++ b/python/hsfs/client/external.py
@@ -53,8 +53,6 @@ def __init__(
         self._base_url = "https://" + self._host + ":" + str(self._port)
         self._project_name = project
         self._region_name = region_name or self.DEFAULT_REGION
-        self._cert_folder_base = cert_folder
-        self._cert_folder = os.path.join(cert_folder, host, project)

         if api_key_value is not None:
             api_key = api_key_value
@@ -69,35 +67,45 @@
         project_info = self._get_project_info(self._project_name)
         self._project_id = str(project_info["projectId"])

-        os.makedirs(self._cert_folder, exist_ok=True)
-        credentials = self._get_credentials(self._project_id)
-        self._write_b64_cert_to_bytes(
-            str(credentials["kStore"]),
-            path=os.path.join(self._cert_folder, "keyStore.jks"),
-        )
-        self._write_b64_cert_to_bytes(
-            str(credentials["tStore"]),
-            path=os.path.join(self._cert_folder, "trustStore.jks"),
-        )
-
-        self._cert_key = str(credentials["password"])
-        with open(os.path.join(self._cert_folder, "material_passwd"), "w") as f:
-            f.write(str(credentials["password"]))
+        if cert_folder:
+            # On external Spark clients (Databricks, Spark Cluster),
+            # certificates need to be provided before the Spark application starts.
+            self._cert_folder_base = cert_folder
+            self._cert_folder = os.path.join(cert_folder, host, project)
+
+            os.makedirs(self._cert_folder, exist_ok=True)
+            credentials = self._get_credentials(self._project_id)
+            self._write_b64_cert_to_bytes(
+                str(credentials["kStore"]),
+                path=os.path.join(self._cert_folder, "keyStore.jks"),
+            )
+            self._write_b64_cert_to_bytes(
+                str(credentials["tStore"]),
+                path=os.path.join(self._cert_folder, "trustStore.jks"),
+            )
+
+            self._cert_key = str(credentials["password"])
+            with open(os.path.join(self._cert_folder, "material_passwd"), "w") as f:
+                f.write(str(credentials["password"]))

     def _close(self):
         """Closes a client and deletes certificates."""
-        if not os.path.exists("/dbfs/"):
-            # Clean up only on AWS, on databricks certs are needed at startup time
-            self._cleanup_file(os.path.join(self._cert_folder, "keyStore.jks"))
-            self._cleanup_file(os.path.join(self._cert_folder, "trustStore.jks"))
-            self._cleanup_file(os.path.join(self._cert_folder, "material_passwd"))
+        if self._cert_folder_base is None:
+            # On external Spark clients (Databricks, Spark Cluster),
+            # certificates need to be provided before the Spark application starts.
+            return
+
+        # Clean up only on AWS
+        self._cleanup_file(os.path.join(self._cert_folder, "keyStore.jks"))
+        self._cleanup_file(os.path.join(self._cert_folder, "trustStore.jks"))
+        self._cleanup_file(os.path.join(self._cert_folder, "material_passwd"))
+
         try:
             # delete project level
             os.rmdir(self._cert_folder)
             # delete host level
             os.rmdir(os.path.dirname(self._cert_folder))
             # on AWS base dir will be empty, and can be deleted otherwise raises OSError
-            # on Databricks there will still be the scripts and clients therefore raises OSError
             os.rmdir(self._cert_folder_base)
         except OSError:
             pass
diff --git a/python/hsfs/connection.py b/python/hsfs/connection.py
index ddfe87c865..de02a01cfc 100644
--- a/python/hsfs/connection.py
+++ b/python/hsfs/connection.py
@@ -16,6 +16,8 @@

 import os

+import importlib.util
+
 from requests.exceptions import ConnectionError

 from hsfs.decorators import connected, not_connected
@@ -36,10 +38,10 @@ class Connection:
     store but also any feature store which has been shared with the project you
     connect to.

-    This class provides convenience classmethods accesible from the `hsfs`-module:
+    This class provides convenience classmethods accessible from the `hsfs`-module:

     !!! example "Connection factory"
-        For convenience, `hsfs` provides a factory method, accesible from the top level
+        For convenience, `hsfs` provides a factory method, accessible from the top level
         module, so you don't have to import the `Connection` class manually:

         ```python
@@ -89,7 +91,7 @@ class Connection:
         trust_store_path: Path on the file system containing the Hopsworks certificates,
             defaults to `None`.
         cert_folder: The directory to store retrieved HopsFS certificates, defaults to
-            `"hops"`.
+            `"hops"`. Only required when running without a Spark environment.
         api_key_file: Path to a file containing the API Key, if provided, `secrets_store`
             will be ignored, defaults to `None`.
         api_key_value: API Key as string, if provided, `secrets_store` will be ignored`,
@@ -167,8 +169,8 @@ def connect(self):
         self._connected = True
         try:
             if client.base.Client.REST_ENDPOINT not in os.environ:
-                if os.path.exists("/dbfs/"):
-                    # databricks
+                if importlib.util.find_spec("pyspark"):
+                    # databricks, emr, external spark clusters
                     client.init(
                         "external",
                         self._host,
@@ -177,11 +179,11 @@ def connect(self):
                         self._region_name,
                         self._secrets_store,
                         self._hostname_verification,
-                        os.path.join("/dbfs", self._trust_store_path)
+                        self._trust_store_path
                         if self._trust_store_path is not None
                         else None,
-                        os.path.join("/dbfs", self._cert_folder),
-                        os.path.join("/dbfs", self._api_key_file)
+                        None,
+                        self._api_key_file
                         if self._api_key_file is not None
                         else None,
                         self._api_key_value,
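Usage note (a minimal sketch, not part of the patch): the snippet below illustrates the two code paths this change distinguishes, using the factory method and keyword arguments from the `Connection` docstring above; the host, project name, and API key are placeholders.

```python
import hsfs

# With this change, connect() picks the certificate handling based on whether
# pyspark is importable (importlib.util.find_spec("pyspark")):
#
#  * Databricks / EMR / external Spark cluster: cert_folder is passed as None,
#    certificates must already be in place before the Spark application starts,
#    and the client neither writes nor deletes keyStore.jks, trustStore.jks,
#    or material_passwd.
#  * No Spark environment: cert_folder (default "hops") is still used; the
#    external client writes the stores under <cert_folder>/<host>/<project>/
#    and removes them again when the connection is closed.

conn = hsfs.connection(
    host="my-instance.hopsworks.ai",   # placeholder host
    project="my_project",              # placeholder project name
    api_key_value="<api-key>",         # placeholder API key
)
fs = conn.get_feature_store()
# ... use the feature store ...
conn.close()
```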