Skip to content

Commit

Permalink
Version default to 1 for get methods and increment for create (#50)
Browse files Browse the repository at this point in the history
  • Loading branch information
moritzmeister authored Jul 7, 2020
1 parent b95f41f commit bd4bb3d
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 19 deletions.
9 changes: 2 additions & 7 deletions java/src/main/java/com/logicalclocks/hsfs/FeatureGroup.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.logicalclocks.hsfs.metadata.Query;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
Expand Down Expand Up @@ -78,16 +79,10 @@ public class FeatureGroup {
private FeatureGroupEngine featureGroupEngine = new FeatureGroupEngine();

@Builder
public FeatureGroup(FeatureStore featureStore, String name, Integer version, String description,
public FeatureGroup(FeatureStore featureStore, @NonNull String name, Integer version, String description,
List<String> primaryKeys, List<String> partitionKeys,
boolean onlineEnabled, Storage defaultStorage, List<Feature> features)
throws FeatureStoreException {
if (name == null) {
throw new FeatureStoreException("Name is required when creating a feature group");
}
if (version == null) {
throw new FeatureStoreException("Version is required when creating a feature group");
}

this.featureStore = featureStore;
this.name = name;
Expand Down
45 changes: 36 additions & 9 deletions java/src/main/java/com/logicalclocks/hsfs/FeatureStore.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,18 @@

import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Strings;
import com.logicalclocks.hsfs.engine.FeatureGroupEngine;
import com.logicalclocks.hsfs.engine.SparkEngine;
import com.logicalclocks.hsfs.metadata.FeatureGroupApi;
import com.logicalclocks.hsfs.metadata.StorageConnectorApi;
import com.logicalclocks.hsfs.metadata.TrainingDatasetApi;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;

Expand All @@ -45,27 +49,40 @@ public class FeatureStore {
private TrainingDatasetApi trainingDatasetApi;
private StorageConnectorApi storageConnectorApi;

private static final Logger LOGGER = LoggerFactory.getLogger(FeatureStore.class);

final Integer DEFAULT_VERSION = 1;

/**
 * Create a feature store handle and instantiate the metadata API clients it delegates to
 * (feature groups, training datasets and storage connectors).
 *
 * @throws FeatureStoreException declared by this constructor; presumably raised when one of the
 *     API clients cannot be created (their constructors are not visible here)
 */
public FeatureStore() throws FeatureStoreException {
  this.featureGroupApi = new FeatureGroupApi();
  this.trainingDatasetApi = new TrainingDatasetApi();
  this.storageConnectorApi = new StorageConnectorApi();
}

/**
* Get a feature group from the feature store
* Get a feature group object from the feature store
* @param name: the name of the feature group
* @param version: the version of the feature group
* @return
* @throws FeatureStoreException
*/
public FeatureGroup getFeatureGroup(String name, Integer version)
public FeatureGroup getFeatureGroup(@NonNull String name, @NonNull Integer version)
throws FeatureStoreException, IOException {
if (Strings.isNullOrEmpty(name) || version == null) {
throw new FeatureStoreException("Both name and version are required");
}
return featureGroupApi.get(this, name, version);
}

/**
 * Get a feature group object with the default version `1` from the feature store.
 *
 * <p>Logs a version notice and then delegates to {@link #getFeatureGroup(String, Integer)}
 * with {@code DEFAULT_VERSION}.
 *
 * @param name the name of the feature group; must not be null
 * @return the feature group returned by {@link #getFeatureGroup(String, Integer)} for the default version
 * @throws FeatureStoreException propagated from {@link #getFeatureGroup(String, Integer)}
 * @throws IOException propagated from {@link #getFeatureGroup(String, Integer)}
 */
public FeatureGroup getFeatureGroup(@NonNull String name) throws FeatureStoreException, IOException {
  // Parameterized logging: the message is only formatted when INFO is enabled.
  LOGGER.info("VersionWarning: No version provided for getting feature group `{}`, defaulting to `{}`.",
      name, DEFAULT_VERSION);
  return getFeatureGroup(name, DEFAULT_VERSION);
}

/**
 * Run a SQL query through the engine obtained from {@code SparkEngine.getInstance()}.
 *
 * @param query the SQL query string, passed to the engine unchanged
 * @return the dataset returned by the engine's {@code sql} call
 */
public Dataset<Row> sql(String query) {
return SparkEngine.getInstance().sql(query);
}
Expand Down Expand Up @@ -93,14 +110,24 @@ public TrainingDataset.TrainingDatasetBuilder createTrainingDataset() {
* @throws FeatureStoreException
* @throws IOException
*/
public TrainingDataset getTrainingDataset(String name, Integer version)
public TrainingDataset getTrainingDataset(@NonNull String name, @NonNull Integer version)
throws FeatureStoreException, IOException {
if (Strings.isNullOrEmpty(name) || version == null) {
throw new FeatureStoreException("Both name and version are required");
}
return trainingDatasetApi.get(this, name, version);
}

/**
 * Get a training dataset object with the default version `1` from the selected feature store.
 *
 * <p>Logs a version notice and then delegates to {@link #getTrainingDataset(String, Integer)}
 * with {@code DEFAULT_VERSION}.
 *
 * @param name the name of the training dataset; must not be null
 * @return the training dataset returned by {@link #getTrainingDataset(String, Integer)} for the default version
 * @throws FeatureStoreException propagated from {@link #getTrainingDataset(String, Integer)}
 * @throws IOException propagated from {@link #getTrainingDataset(String, Integer)}
 */
public TrainingDataset getTrainingDataset(@NonNull String name) throws FeatureStoreException, IOException {
  // Parameterized logging: the message is only formatted when INFO is enabled.
  LOGGER.info("VersionWarning: No version provided for getting training dataset `{}`, defaulting to `{}`.",
      name, DEFAULT_VERSION);
  return getTrainingDataset(name, DEFAULT_VERSION);
}

@Override
public String toString() {
return "FeatureStore{" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ public class TrainingDataset {
private TrainingDatasetEngine trainingDatasetEngine = new TrainingDatasetEngine();

@Builder
public TrainingDataset(@NonNull String name, @NonNull Integer version, String description,
public TrainingDataset(@NonNull String name, Integer version, String description,
DataFormat dataFormat, StorageConnector storageConnector,
String location, List<Split> splits, Long seed,
FeatureStore featureStore) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,15 @@ public void saveFeatureGroup(FeatureGroup featureGroup, Dataset<Row> dataset,
}

// Send Hopsworks the request to create a new feature group
featureGroupApi.save(featureGroup);
FeatureGroup apiFG = featureGroupApi.save(featureGroup);

if (featureGroup.getVersion() == null) {
LOGGER.info("VersionWarning: No version provided for creating feature group `" + featureGroup.getName() +
"`, incremented version to `" + apiFG.getVersion() + "`.");
}

// Update the original object - Hopsworks returns the incremented version
featureGroup.setVersion(apiFG.getVersion());

// Write the dataframe
saveDataframe(featureGroup, dataset, storage, SaveMode.Append, writeOptions);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SaveMode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.nio.file.Paths;
Expand All @@ -31,6 +33,8 @@ public class TrainingDatasetEngine {
private TrainingDatasetApi trainingDatasetApi = new TrainingDatasetApi();
private Utils utils = new Utils();

private static final Logger LOGGER = LoggerFactory.getLogger(TrainingDatasetEngine.class);

//TODO:
// Compute statistics

Expand All @@ -50,8 +54,15 @@ public void save(TrainingDataset trainingDataset, Dataset<Row> dataset,

// Make the rest call to create the training dataset metadata
TrainingDataset apiTD = trainingDatasetApi.createTrainingDataset(trainingDataset);
// Update the original object - Hopsworks returns the full location

if (trainingDataset.getVersion() == null) {
LOGGER.info("VersionWarning: No version provided for creating training dataset `" + trainingDataset.getName() +
"`, incremented version to `" + apiTD.getVersion() + "`.");
}

// Update the original object - Hopsworks returns the full location and incremented version
trainingDataset.setLocation(apiTD.getLocation());
trainingDataset.setVersion(apiTD.getVersion());

// Build write options map
Map<String, String> writeOptions =
Expand Down

0 comments on commit bd4bb3d

Please sign in to comment.