Skip to content

Commit

Permalink
Java client fixes for databricks (#53)
Browse files Browse the repository at this point in the history
Co-authored-by: Fabio Buso <buso.fabio@gmail.com>
  • Loading branch information
moritzmeister and SirOibaf authored Jul 28, 2020
1 parent 4d3d139 commit 604ae63
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public TrainingDataset(@NonNull String name, Integer version, String description
this.name = name;
this.version = version;
this.description = description;
this.dataFormat = dataFormat;
this.dataFormat = dataFormat != null ? dataFormat : DataFormat.TFRECORDS;
this.location = location;
this.storageConnector = storageConnector;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,13 @@
import com.logicalclocks.hsfs.TrainingDataset;
import com.logicalclocks.hsfs.util.Constants;
import lombok.Getter;
import org.apache.hadoop.fs.Path;
import org.apache.parquet.Strings;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.SparkSession;

import java.io.IOException;
import java.nio.file.Paths;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -117,7 +116,7 @@ public void write(TrainingDataset trainingDataset, Dataset<Row> dataset,
// The actual data is stored in training_ds_version/training_ds. The double directory is needed
// for cases such as tfrecords, in which we also need to store the schema.
// In case of multiple splits, the individual splits are stored inside the training dataset dir.
String path = Paths.get(trainingDataset.getLocation(), trainingDataset.getName()).toString();
String path = new Path(trainingDataset.getLocation(), trainingDataset.getName()).toString();

writeSingle(dataset, trainingDataset.getDataFormat(),
writeOptions, saveMode, path);
Expand Down Expand Up @@ -200,7 +199,7 @@ private void writeSplits(Dataset<Row>[] datasets, DataFormat dataFormat, Map<Str
SaveMode saveMode, String basePath, List<Split> splits) {
for (int i=0; i < datasets.length; i++) {
writeSingle(datasets[i], dataFormat, writeOptions, saveMode,
Paths.get(basePath, splits.get(i).getName()).toString());
new Path(basePath, splits.get(i).getName()).toString());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,18 @@
package com.logicalclocks.hsfs.engine;

import com.logicalclocks.hsfs.EntityEndpointType;
import com.logicalclocks.hsfs.FeatureGroup;
import com.logicalclocks.hsfs.FeatureStoreException;
import com.logicalclocks.hsfs.TrainingDataset;
import com.logicalclocks.hsfs.metadata.TagsApi;
import com.logicalclocks.hsfs.metadata.TrainingDatasetApi;
import org.apache.hadoop.fs.Path;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SaveMode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.nio.file.Paths;
import java.util.Map;

public class TrainingDatasetEngine {
Expand Down Expand Up @@ -103,10 +102,9 @@ public Dataset<Row> read(TrainingDataset trainingDataset, String split, Map<Stri
String path = "";
if (com.google.common.base.Strings.isNullOrEmpty(split)) {
// ** glob means "all sub directories"
// TODO(Fabio): make sure it works on S3
path = Paths.get(trainingDataset.getLocation(), "**").toString();
path = new Path(trainingDataset.getLocation(), "**").toString();
} else {
path = Paths.get(trainingDataset.getLocation(), split).toString();
path = new Path(trainingDataset.getLocation(), split).toString();
}

Map<String, String> readOptions =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,39 @@
package com.logicalclocks.hsfs.metadata;

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;

/**
 * Plain holder for three credential strings: a key store, a trust store and a password.
 *
 * <p>NOTE(review): the fields carry Lombok {@code @Getter @Setter} and the class also declares
 * hand-written accessors for the same fields; the annotations look redundant — verify which
 * variant is intended before cleaning up.
 *
 * <p>NOTE(review): the hand-written accessors keep the lower-case first letter of the field
 * ({@code getkStore}, {@code gettStore}); confirm with callers before renaming them.
 */
@NoArgsConstructor
@AllArgsConstructor
public class Credentials {

// Key store content as a string.
@Getter @Setter
private String kStore;
// Trust store content as a string.
@Getter @Setter

private String tStore;
// Password string returned to callers via getPassword().
@Getter @Setter

private String password;

/** Returns the key store string. */
public String getkStore() {
return kStore;
}

/** Replaces the key store string. */
public void setkStore(String kStore) {
this.kStore = kStore;
}

/** Returns the trust store string. */
public String gettStore() {
return tStore;
}

/** Replaces the trust store string. */
public void settStore(String tStore) {
this.tStore = tStore;
}

/** Returns the password string. */
public String getPassword() {
return password;
}

/** Replaces the password string. */
public void setPassword(String password) {
this.password = password;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package com.logicalclocks.hsfs.metadata;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.logicalclocks.hsfs.FeatureStoreException;
import com.logicalclocks.hsfs.Project;
import com.logicalclocks.hsfs.SecretStore;
Expand All @@ -29,6 +30,7 @@
import org.apache.http.conn.socket.ConnectionSocketFactory;
import org.apache.http.conn.socket.PlainConnectionSocketFactory;
import org.apache.http.conn.ssl.SSLConnectionSocketFactory;
import org.apache.http.conn.ssl.TrustAllStrategy;
import org.apache.http.conn.ssl.TrustSelfSignedStrategy;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
Expand All @@ -54,11 +56,13 @@
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.cert.CertificateException;
import java.util.HashMap;

public class HopsworksExternalClient implements HopsworksHttpClient {

private static final Logger LOGGER = LoggerFactory.getLogger(HopsworksExternalClient.class.getName());
private static final String PARAM_NAME = "/hopsworks/role/";
private static final String PARAM_NAME_SECRET_STORE = "hopsworks/role/";
private static final String PARAM_NAME_PARAMETER_STORE = "/hopsworks/role/";

private PoolingHttpClientConnectionManager connectionPool = null;

Expand All @@ -73,7 +77,7 @@ public HopsworksExternalClient(String host, int port, Region region,
throws IOException, FeatureStoreException, KeyStoreException, CertificateException,
NoSuchAlgorithmException, KeyManagementException {

httpHost = new HttpHost(host, port);
httpHost = new HttpHost(host, port, "https");

connectionPool = new PoolingHttpClientConnectionManager(
createConnectionFactory(httpHost, hostnameVerification, trustStorePath));
Expand Down Expand Up @@ -102,6 +106,11 @@ private Registry<ConnectionSocketFactory> createConnectionFactory(HttpHost httpH
sslCtx = SSLContexts.custom()
.loadTrustMaterial(Paths.get(trustStorePath).toFile(), null, new TrustSelfSignedStrategy())
.build();
} else if (!hostnameVerification) {
// if hostnameVerification is set to false, trust all certificates, including self-signed ones
sslCtx = SSLContexts.custom()
.loadTrustMaterial(new TrustAllStrategy())
.build();
} else {
sslCtx = SSLContext.getDefault();
}
Expand Down Expand Up @@ -147,27 +156,37 @@ private String readAPIKeyParamStore(Region region, String secretKey) throws Feat
SsmClient ssmClient = SsmClient.builder()
.region(region)
.build();
String paramName = PARAM_NAME + getAssumedRole() + "/type/" + secretKey;
String paramName = PARAM_NAME_PARAMETER_STORE + getAssumedRole() + "/type/" + secretKey;
GetParameterRequest paramRequest = GetParameterRequest.builder()
.name(paramName)
.withDecryption(true)
.build();
GetParameterResponse parameterResponse = ssmClient.getParameter(paramRequest);
return parameterResponse.getValueForField("Parameter", String.class)
.orElseThrow(() -> new FeatureStoreException("Could not find parameter " + paramName + " in parameter store"));
String apiKey = parameterResponse.parameter().value();
if (!Strings.isNullOrEmpty(apiKey)) {
return apiKey;
} else {
throw new FeatureStoreException("Could not find parameter " + paramName + " in parameter store");
}
}

private String readAPIKeySecretManager(Region region, String secretKey) throws FeatureStoreException {
private String readAPIKeySecretManager(Region region, String secretKey) throws FeatureStoreException, IOException {
SecretsManagerClient secretsManagerClient = SecretsManagerClient.builder()
.region(region)
.build();
String paramName = PARAM_NAME + getAssumedRole();
String paramName = PARAM_NAME_SECRET_STORE + getAssumedRole();
GetSecretValueRequest secretValueRequest = GetSecretValueRequest.builder()
.secretId(paramName)
.build();
GetSecretValueResponse secretValueResponse = secretsManagerClient.getSecretValue(secretValueRequest);
return secretValueResponse.getValueForField(secretKey, String.class)
.orElseThrow(() -> new FeatureStoreException("Could not find secret " + paramName + " in secret store"));
ObjectMapper objectMapper = new ObjectMapper();
HashMap<String, String> secretMap = objectMapper.readValue(secretValueResponse.secretString(), HashMap.class);
String apiKey = secretMap.get("api-key");
if (!Strings.isNullOrEmpty(apiKey)) {
return apiKey;
} else {
throw new FeatureStoreException("Could not find secret " + paramName + " in secret store");
}
}

private String getAssumedRole() throws FeatureStoreException {
Expand All @@ -177,16 +196,15 @@ private String getAssumedRole() throws FeatureStoreException {
// arn:aws:sts::123456789012:assumed-role/my-role-name/my-role-session-name
String arn = callerIdentityResponse.arn();
String[] arnSplits = arn.split("/");
if (arnSplits.length != 3 || !arnSplits[0].equals("assumed-role")) {
if (arnSplits.length != 3 || !arnSplits[0].endsWith("assumed-role")) {
throw new FeatureStoreException("Failed to extract assumed role from arn: " + arn);
}
return arnSplits[1];
}

@Override
public <T> T handleRequest(HttpRequest request, ResponseHandler<T> responseHandler) throws IOException,
FeatureStoreException {
LOGGER.debug("Handling metadata request: " + request);
public <T> T handleRequest(HttpRequest request, ResponseHandler<T> responseHandler) throws IOException {
LOGGER.info("Handling metadata request: " + request);
AuthorizationHandler<T> authHandler = new AuthorizationHandler<>(responseHandler);
request.setHeader(HttpHeaders.AUTHORIZATION, "ApiKey " + apiKey);
try {
Expand All @@ -204,9 +222,9 @@ public String downloadCredentials(Project project, String certPath) throws IOExc
Credentials credentials = projectApi.downloadCredentials(project);

FileUtils.writeByteArrayToFile(Paths.get(certPath, "keyStore.jks").toFile(),
Base64.decodeBase64(credentials.getKStore()));
Base64.decodeBase64(credentials.getkStore()));
FileUtils.writeByteArrayToFile(Paths.get(certPath, "trustStore.jks").toFile(),
Base64.decodeBase64(credentials.getTStore()));
Base64.decodeBase64(credentials.gettStore()));
return credentials.getPassword();
}
}

0 comments on commit 604ae63

Please sign in to comment.