HSFS is defaulting to PyHive instead of Spark on EMR (#171)
Also remove the certificate download from the connect method for PySpark
clients; the certificates should already be present when the application
starts.
SirOibaf authored Nov 30, 2020
1 parent 1fe5adb commit ac6c24d
Showing 10 changed files with 43 additions and 109 deletions.
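
The underlying problem is that engine selection keyed off the Databricks-specific `/dbfs/` mount, so other external Spark environments such as EMR silently fell back to the PyHive engine. The sketch below isolates the old and new detection heuristics; the function names and engine labels are illustrative, not the actual hsfs internals.

```python
import importlib.util
import os


def pick_engine_old():
    # Old heuristic: only Databricks (which mounts /dbfs/) counted as a Spark
    # environment, so an EMR cluster silently fell back to PyHive.
    return "spark" if os.path.exists("/dbfs/") else "hive"


def pick_engine_new():
    # New heuristic: any environment where PySpark is importable
    # (Databricks, EMR, standalone Spark clusters) uses the Spark engine.
    return "spark" if importlib.util.find_spec("pyspark") is not None else "hive"
```
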
HopsworksConnection.java
@@ -85,10 +85,9 @@ public HopsworksConnection(String host, int port, String project, Region region,
this.apiKeyFilePath = apiKeyFilePath;
this.apiKeyValue = apiKeyValue;

HopsworksClient hopsworksClient = HopsworksClient.setupHopsworksClient(host, port, region, secretStore,
HopsworksClient.setupHopsworksClient(host, port, region, secretStore,
hostnameVerification, trustStorePath, this.apiKeyFilePath, this.apiKeyValue);
projectObj = getProject();
hopsworksClient.downloadCredentials(projectObj, certPath);
}

/**
HopsworksClient.java
@@ -21,7 +21,6 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.annotations.VisibleForTesting;
import com.logicalclocks.hsfs.FeatureStoreException;
import com.logicalclocks.hsfs.Project;
import com.logicalclocks.hsfs.SecretStore;
import lombok.AllArgsConstructor;
import lombok.Getter;
@@ -155,8 +154,4 @@ public <T> T handleRequest(HttpRequest request, Class<T> cls) throws IOException
public <T> T handleRequest(HttpRequest request) throws IOException, FeatureStoreException {
return hopsworksHttpClient.handleRequest(request, null);
}

public void downloadCredentials(Project project, String certPath) throws IOException, FeatureStoreException {
certPwd = hopsworksHttpClient.downloadCredentials(project, certPath);
}
}
HopsworksExternalClient.java
@@ -18,10 +18,8 @@

import com.fasterxml.jackson.databind.ObjectMapper;
import com.logicalclocks.hsfs.FeatureStoreException;
import com.logicalclocks.hsfs.Project;
import com.logicalclocks.hsfs.SecretStore;
import org.apache.commons.io.FileUtils;
import org.apache.commons.net.util.Base64;
import org.apache.http.HttpHeaders;
import org.apache.http.HttpHost;
import org.apache.http.HttpRequest;
@@ -243,17 +241,4 @@ public <T> T handleRequest(HttpRequest request, ResponseHandler<T> responseHandl
return httpClient.execute(httpHost, request, authHandler);
}
}

@Override
public String downloadCredentials(Project project, String certPath) throws IOException, FeatureStoreException {
LOGGER.info("Fetching certificates for the project");
ProjectApi projectApi = new ProjectApi();
Credentials credentials = projectApi.downloadCredentials(project);

FileUtils.writeByteArrayToFile(Paths.get(certPath, "keyStore.jks").toFile(),
Base64.decodeBase64(credentials.getkStore()));
FileUtils.writeByteArrayToFile(Paths.get(certPath, "trustStore.jks").toFile(),
Base64.decodeBase64(credentials.gettStore()));
return credentials.getPassword();
}
}
HopsworksHttpClient.java
@@ -17,7 +17,6 @@
package com.logicalclocks.hsfs.metadata;

import com.logicalclocks.hsfs.FeatureStoreException;
import com.logicalclocks.hsfs.Project;
import org.apache.http.HttpRequest;
import org.apache.http.client.ResponseHandler;

@@ -26,6 +25,4 @@
public interface HopsworksHttpClient {
<T> T handleRequest(HttpRequest request, ResponseHandler<T> responseHandler)
throws IOException, FeatureStoreException;

String downloadCredentials(Project project, String certPath) throws IOException, FeatureStoreException;
}
HopsworksInternalClient.java
@@ -17,7 +17,6 @@
package com.logicalclocks.hsfs.metadata;

import com.logicalclocks.hsfs.FeatureStoreException;
import com.logicalclocks.hsfs.Project;
import jdk.nashorn.internal.runtime.regexp.joni.exception.InternalException;
import org.apache.http.HttpHeaders;
import org.apache.http.HttpHost;
@@ -157,10 +156,4 @@ public <T> T handleRequest(HttpRequest request, ResponseHandler<T> responseHandl
return httpClient.execute(httpHost, request, authHandler);
}
}

@Override
public String downloadCredentials(Project project, String certPath) {
// In Hopsworks internal client credentials are already setup.
return null;
}
}
10 changes: 0 additions & 10 deletions java/src/main/java/com/logicalclocks/hsfs/metadata/ProjectApi.java
@@ -30,7 +30,6 @@ public class ProjectApi {
private static final Logger LOGGER = LoggerFactory.getLogger(ProjectApi.class);

private static final String PROJECT_INFO_PATH = "/project/getProjectInfo{/projectName}";
private static final String CREDENTIALS_PATH = "/project{/projectId}/credentials";

public Project get(String name) throws IOException, FeatureStoreException {
HopsworksClient hopsworksClient = HopsworksClient.getInstance();
@@ -40,13 +39,4 @@ public Project get(String name) throws IOException, FeatureStoreException {
LOGGER.info("Sending metadata request: " + uri);
return hopsworksClient.handleRequest(new HttpGet(uri), Project.class);
}

public Credentials downloadCredentials(Project project) throws IOException, FeatureStoreException {
HopsworksClient hopsworksClient = HopsworksClient.getInstance();
String uri = UriTemplate.fromTemplate(HopsworksClient.API_PATH + CREDENTIALS_PATH)
.set("projectId", project.getProjectId())
.expand();
LOGGER.info("Sending metadata request: " + uri);
return hopsworksClient.handleRequest(new HttpGet(uri), Credentials.class);
}
}
TestHopsworksExternalClient.java
@@ -16,7 +16,6 @@
package com.logicalclocks.hsfs;

import com.logicalclocks.hsfs.metadata.Credentials;
import com.logicalclocks.hsfs.metadata.HopsworksClient;
import com.logicalclocks.hsfs.metadata.HopsworksExternalClient;
import io.specto.hoverfly.junit.core.SimulationSource;
import io.specto.hoverfly.junit.dsl.HttpBodyConverter;
@@ -53,21 +52,6 @@ public class TestHopsworksExternalClient {
.willReturn(success().body(HttpBodyConverter.json(credentials)))
));

// @Test
// public void testReadAPIKey() throws IOException, FeatureStoreException {
// CloseableHttpClient httpClient = HttpClients.createSystem();
// try {
// HopsworksConnection hc = HopsworksConnection.builder().host("35.241.253.100").hostnameVerification(false)
// .project("demo_featurestore_admin000")
// .apiKeyValue("ovVQksgJezSckjyK.ftO2YywCI6gZp4btlvWRnSDjSgyAQgCTRAoQTTSXBxPRMo0Dq029eAf3HVq3I6JO").build();
// System.out.println("Connected");
// FeatureStore fs = hc.getFeatureStore();
// Assert.assertTrue(fs != null);
// } catch (Exception e) {
// // Do not assert an error as this unit test method needs an external cluster
// }
// }

@Test
public void testReadAPIKeyFromFile() throws IOException, FeatureStoreException {
Path apiFilePath = Paths.get(System.getProperty("java.io.tmpdir"), "test.api");
@@ -79,19 +63,4 @@ public void testReadAPIKeyFromFile() throws IOException, FeatureStoreException {
String apiKey = hopsworksExternalClient.readApiKey(null, null, apiFilePath.toString());
Assert.assertEquals("hello", apiKey);
}

@Test
public void testDownloadCredential() throws Exception {
Project project = new Project(1);

CloseableHttpClient httpClient = HttpClients.createSystem();
HttpHost httpHost = new HttpHost("test");

HopsworksExternalClient hopsworksExternalClient = new HopsworksExternalClient(
httpClient, httpHost);

HopsworksClient.setInstance(new HopsworksClient(hopsworksExternalClient));
String password = hopsworksExternalClient.downloadCredentials(project, System.getProperty("java.io.tmpdir"));
Assert.assertEquals(certPwd, password);
}
}
52 changes: 30 additions & 22 deletions python/hsfs/client/external.py
@@ -53,8 +53,6 @@ def __init__(
self._base_url = "https://" + self._host + ":" + str(self._port)
self._project_name = project
self._region_name = region_name or self.DEFAULT_REGION
self._cert_folder_base = cert_folder
self._cert_folder = os.path.join(cert_folder, host, project)

if api_key_value is not None:
api_key = api_key_value
@@ -69,35 +67,45 @@ def __init__(
project_info = self._get_project_info(self._project_name)
self._project_id = str(project_info["projectId"])

os.makedirs(self._cert_folder, exist_ok=True)
credentials = self._get_credentials(self._project_id)
self._write_b64_cert_to_bytes(
str(credentials["kStore"]),
path=os.path.join(self._cert_folder, "keyStore.jks"),
)
self._write_b64_cert_to_bytes(
str(credentials["tStore"]),
path=os.path.join(self._cert_folder, "trustStore.jks"),
)

self._cert_key = str(credentials["password"])
with open(os.path.join(self._cert_folder, "material_passwd"), "w") as f:
f.write(str(credentials["password"]))
if cert_folder:
# On external Spark clients (Databricks, Spark Cluster),
# certificates need to be provided before the Spark application starts.
self._cert_folder_base = cert_folder
self._cert_folder = os.path.join(cert_folder, host, project)

os.makedirs(self._cert_folder, exist_ok=True)
credentials = self._get_credentials(self._project_id)
self._write_b64_cert_to_bytes(
str(credentials["kStore"]),
path=os.path.join(self._cert_folder, "keyStore.jks"),
)
self._write_b64_cert_to_bytes(
str(credentials["tStore"]),
path=os.path.join(self._cert_folder, "trustStore.jks"),
)

self._cert_key = str(credentials["password"])
with open(os.path.join(self._cert_folder, "material_passwd"), "w") as f:
f.write(str(credentials["password"]))

def _close(self):
"""Closes a client and deletes certificates."""
if not os.path.exists("/dbfs/"):
# Clean up only on AWS, on databricks certs are needed at startup time
self._cleanup_file(os.path.join(self._cert_folder, "keyStore.jks"))
self._cleanup_file(os.path.join(self._cert_folder, "trustStore.jks"))
self._cleanup_file(os.path.join(self._cert_folder, "material_passwd"))
if self._cert_folder_base is None:
# On external Spark clients (Databricks, Spark Cluster),
# certificates need to be provided before the Spark application starts.
return

# Clean up only on AWS
self._cleanup_file(os.path.join(self._cert_folder, "keyStore.jks"))
self._cleanup_file(os.path.join(self._cert_folder, "trustStore.jks"))
self._cleanup_file(os.path.join(self._cert_folder, "material_passwd"))

try:
# delete project level
os.rmdir(self._cert_folder)
# delete host level
os.rmdir(os.path.dirname(self._cert_folder))
# on AWS base dir will be empty, and can be deleted otherwise raises OSError
# on Databricks there will still be the scripts and clients therefore raises OSError
os.rmdir(self._cert_folder_base)
except OSError:
pass
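
For context, the diff calls `self._write_b64_cert_to_bytes(...)` without showing its body. A minimal sketch of what such a helper presumably does, decoding the base64-encoded keystore returned by the credentials endpoint and writing it out as binary, is shown below; this is an assumption for illustration, not the verbatim hsfs implementation.

```python
import base64


def _write_b64_cert_to_bytes(b64_string: str, path: str) -> None:
    # Assumed behaviour: the credentials endpoint returns the JKS key store and
    # trust store base64-encoded, so decode the payload and persist it as a
    # binary file at the given path.
    with open(path, "wb") as f:
        f.write(base64.b64decode(b64_string))
```
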
22 changes: 10 additions & 12 deletions python/hsfs/connection.py
@@ -16,6 +16,8 @@


import os
import importlib.util

from requests.exceptions import ConnectionError

from hsfs.decorators import connected, not_connected
@@ -36,10 +38,10 @@ class Connection:
store but also any feature store which has been shared with the project you connect
to.
This class provides convenience classmethods accesible from the `hsfs`-module:
This class provides convenience classmethods accessible from the `hsfs`-module:
!!! example "Connection factory"
For convenience, `hsfs` provides a factory method, accesible from the top level
For convenience, `hsfs` provides a factory method, accessible from the top level
module, so you don't have to import the `Connection` class manually:
```python
@@ -89,7 +91,7 @@ class Connection:
trust_store_path: Path on the file system containing the Hopsworks certificates,
defaults to `None`.
cert_folder: The directory to store retrieved HopsFS certificates, defaults to
`"hops"`.
`"hops"`. Only required when running without a Spark environment.
api_key_file: Path to a file containing the API Key, if provided,
`secrets_store` will be ignored, defaults to `None`.
api_key_value: API Key as string, if provided, `secrets_store` will be ignored`,
@@ -167,8 +169,8 @@ def connect(self):
self._connected = True
try:
if client.base.Client.REST_ENDPOINT not in os.environ:
if os.path.exists("/dbfs/"):
# databricks
if importlib.util.find_spec("pyspark"):
# databricks, emr, external spark clusters
client.init(
"external",
self._host,
@@ -177,13 +179,9 @@
self._region_name,
self._secrets_store,
self._hostname_verification,
os.path.join("/dbfs", self._trust_store_path)
if self._trust_store_path is not None
else None,
os.path.join("/dbfs", self._cert_folder),
os.path.join("/dbfs", self._api_key_file)
if self._api_key_file is not None
else None,
self._trust_store_path,
None,
self._api_key_file,
self._api_key_value,
)
engine.init("spark")
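
With this change, a PySpark client on EMR should no longer need to pass `cert_folder` or have its paths rewritten under `/dbfs`; the key and trust stores are expected to be in place before the application starts. A minimal usage sketch, where the host, project name, and API key file path are placeholders:

```python
import hsfs

# Placeholder values; substitute your own Hopsworks host, project, and API key file.
connection = hsfs.connection(
    host="my-instance.cloud.hopsworks.ai",
    project="my_project",
    api_key_file="/home/hadoop/hopsworks_api_key",
)
fs = connection.get_feature_store()
```
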
4 changes: 2 additions & 2 deletions python/hsfs/engine/spark.py
@@ -14,7 +14,7 @@
# limitations under the License.
#

import os
import importlib.util

import pandas as pd
import numpy as np
@@ -45,7 +45,7 @@ def __init__(self):
self._spark_session.conf.set("hive.exec.dynamic.partition.mode", "nonstrict")
self._spark_session.conf.set("spark.sql.hive.convertMetastoreParquet", "false")

if not os.path.exists("/dbfs/"):
if importlib.util.find_spec("pydoop"):
# If we are on Databricks don't setup Pydoop as it's not available and cannot be easily installed.
self._setup_pydoop()

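
The same module-availability pattern guards the Pydoop setup: rather than inferring the platform from `/dbfs/`, `_setup_pydoop` now runs only when the `pydoop` package can actually be found. A hedged comparison of the two common ways to test for an optional dependency, which is presumably why `find_spec` was chosen here:

```python
import importlib.util

# find_spec only consults the import machinery, so the check does not execute
# the package's import-time code:
pydoop_available = importlib.util.find_spec("pydoop") is not None

# The try/except alternative imports the package eagerly, which can fail for
# reasons other than the package simply being absent:
try:
    import pydoop  # noqa: F401
    pydoop_importable = True
except ImportError:
    pydoop_importable = False
```
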
