Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions java/src/main/java/com/logicalclocks/hsfs/FeatureGroup.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.logicalclocks.hsfs.metadata.Query;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
Expand Down Expand Up @@ -78,16 +79,10 @@ public class FeatureGroup {
private FeatureGroupEngine featureGroupEngine = new FeatureGroupEngine();

@Builder
public FeatureGroup(FeatureStore featureStore, String name, Integer version, String description,
public FeatureGroup(FeatureStore featureStore, @NonNull String name, Integer version, String description,
List<String> primaryKeys, List<String> partitionKeys,
boolean onlineEnabled, Storage defaultStorage, List<Feature> features)
throws FeatureStoreException {
if (name == null) {
throw new FeatureStoreException("Name is required when creating a feature group");
}
if (version == null) {
throw new FeatureStoreException("Version is required when creating a feature group");
}

this.featureStore = featureStore;
this.name = name;
Expand Down
45 changes: 36 additions & 9 deletions java/src/main/java/com/logicalclocks/hsfs/FeatureStore.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,18 @@

import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Strings;
import com.logicalclocks.hsfs.engine.FeatureGroupEngine;
import com.logicalclocks.hsfs.engine.SparkEngine;
import com.logicalclocks.hsfs.metadata.FeatureGroupApi;
import com.logicalclocks.hsfs.metadata.StorageConnectorApi;
import com.logicalclocks.hsfs.metadata.TrainingDatasetApi;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;

Expand All @@ -45,27 +49,40 @@ public class FeatureStore {
private TrainingDatasetApi trainingDatasetApi;
private StorageConnectorApi storageConnectorApi;

private static final Logger LOGGER = LoggerFactory.getLogger(FeatureStore.class);

final Integer DEFAULT_VERSION = 1;

public FeatureStore() throws FeatureStoreException {
featureGroupApi = new FeatureGroupApi();
trainingDatasetApi = new TrainingDatasetApi();
storageConnectorApi = new StorageConnectorApi();
}

/**
* Get a feature group from the feature store
* Get a feature group object from the feature store
* @param name: the name of the feature group
* @param version: the version of the feature group
* @return
* @throws FeatureStoreException
*/
public FeatureGroup getFeatureGroup(String name, Integer version)
public FeatureGroup getFeatureGroup(@NonNull String name, @NonNull Integer version)
throws FeatureStoreException, IOException {
if (Strings.isNullOrEmpty(name) || version == null) {
throw new FeatureStoreException("Both name and version are required");
}
return featureGroupApi.get(this, name, version);
}

/**
* Get a feature group object with default version `1` from the feature store
* @param name: the name of the feature group
* @return
* @throws FeatureStoreException
*/
public FeatureGroup getFeatureGroup(String name) throws FeatureStoreException, IOException {
LOGGER.info("VersionWarning: No version provided for getting feature group `" + name + "`, defaulting to `" +
DEFAULT_VERSION + "`.");
return getFeatureGroup(name, DEFAULT_VERSION);
}

public Dataset<Row> sql(String query) {
return SparkEngine.getInstance().sql(query);
}
Expand Down Expand Up @@ -93,14 +110,24 @@ public TrainingDataset.TrainingDatasetBuilder createTrainingDataset() {
* @throws FeatureStoreException
* @throws IOException
*/
public TrainingDataset getTrainingDataset(String name, Integer version)
public TrainingDataset getTrainingDataset(@NonNull String name, @NonNull Integer version)
throws FeatureStoreException, IOException {
if (Strings.isNullOrEmpty(name) || version == null) {
throw new FeatureStoreException("Both name and version are required");
}
return trainingDatasetApi.get(this, name, version);
}

/**
* Get a training dataset object with the default version `1` from the selected feature store
* @param name: name of the training dataset
* @return
* @throws FeatureStoreException
* @throws IOException
*/
public TrainingDataset getTrainingDataset(String name) throws FeatureStoreException, IOException {
LOGGER.info("VersionWarning: No version provided for getting training dataset `" + name + "`, defaulting to `" +
DEFAULT_VERSION + "`.");
return getTrainingDataset(name, DEFAULT_VERSION);
}

@Override
public String toString() {
return "FeatureStore{" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ public class TrainingDataset {
private TrainingDatasetEngine trainingDatasetEngine = new TrainingDatasetEngine();

@Builder
public TrainingDataset(@NonNull String name, @NonNull Integer version, String description,
public TrainingDataset(@NonNull String name, Integer version, String description,
DataFormat dataFormat, StorageConnector storageConnector,
String location, List<Split> splits, Long seed,
FeatureStore featureStore) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,15 @@ public void saveFeatureGroup(FeatureGroup featureGroup, Dataset<Row> dataset,
}

// Send Hopsworks the request to create a new feature group
featureGroupApi.save(featureGroup);
FeatureGroup apiFG = featureGroupApi.save(featureGroup);

if (featureGroup.getVersion() == null) {
LOGGER.info("VersionWarning: No version provided for creating feature group `" + featureGroup.getName() +
"`, incremented version to `" + apiFG.getVersion() + "`.");
}

// Update the original object - Hopsworks returns the incremented version
featureGroup.setVersion(apiFG.getVersion());

// Write the dataframe
saveDataframe(featureGroup, dataset, storage, SaveMode.Append, writeOptions);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SaveMode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.nio.file.Paths;
Expand All @@ -31,6 +33,8 @@ public class TrainingDatasetEngine {
private TrainingDatasetApi trainingDatasetApi = new TrainingDatasetApi();
private Utils utils = new Utils();

private static final Logger LOGGER = LoggerFactory.getLogger(TrainingDatasetEngine.class);

//TODO:
// Compute statistics

Expand All @@ -50,8 +54,15 @@ public void save(TrainingDataset trainingDataset, Dataset<Row> dataset,

// Make the rest call to create the training dataset metadata
TrainingDataset apiTD = trainingDatasetApi.createTrainingDataset(trainingDataset);
// Update the original object - Hopsworks returns the full location

if (trainingDataset.getVersion() == null) {
LOGGER.info("VersionWarning: No version provided for creating training dataset `" + trainingDataset.getName() +
"`, incremented version to `" + apiTD.getVersion() + "`.");
}

// Update the original object - Hopsworks returns the full location and incremented version
trainingDataset.setLocation(apiTD.getLocation());
trainingDataset.setVersion(apiTD.getVersion());

// Build write options map
Map<String, String> writeOptions =
Expand Down