Skip to content

Commit

Permalink
[HOPSWORKS-2091] Label/Prediction feature metadata for training datasets (#127)
Browse files Browse the repository at this point in the history
  • Loading branch information
moritzmeister authored Nov 6, 2020
1 parent 4753612 commit 7c704e0
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 15 deletions.
30 changes: 26 additions & 4 deletions java/src/main/java/com/logicalclocks/hsfs/TrainingDataset.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

@NoArgsConstructor
public class TrainingDataset {
Expand Down Expand Up @@ -98,14 +99,18 @@ public class TrainingDataset {
@JsonProperty("queryDTO")
private Query queryInt;

@Setter
@JsonIgnore
private List<String> label;

private TrainingDatasetEngine trainingDatasetEngine = new TrainingDatasetEngine();
private StatisticsEngine statisticsEngine = new StatisticsEngine(EntityEndpointType.TRAINING_DATASET);

@Builder
public TrainingDataset(@NonNull String name, Integer version, String description, DataFormat dataFormat,
StorageConnector storageConnector, String location, List<Split> splits, Long seed,
FeatureStore featureStore, Boolean statisticsEnabled, Boolean histograms,
Boolean correlations, List<String> statisticColumns) {
Boolean correlations, List<String> statisticColumns, List<String> label) {
this.name = name;
this.version = version;
this.description = description;
Expand All @@ -128,6 +133,7 @@ public TrainingDataset(@NonNull String name, Integer version, String description
this.histograms = histograms;
this.correlations = correlations;
this.statisticColumns = statisticColumns;
this.label = label;
}

/**
Expand Down Expand Up @@ -175,7 +181,7 @@ public void save(Query query, Map<String, String> writeOptions) throws FeatureSt
*/
public void save(Dataset<Row> dataset, Map<String, String> writeOptions)
throws FeatureStoreException, IOException {
trainingDatasetEngine.save(this, dataset, writeOptions);
trainingDatasetEngine.save(this, dataset, writeOptions, label);
if (statisticsEnabled) {
statisticsEngine.computeStatistics(this, dataset);
}
Expand Down Expand Up @@ -387,11 +393,27 @@ public void deleteTag(String name) throws FeatureStoreException, IOException {

/**
 * Returns the query of this training dataset for {@link Storage#ONLINE}.
 *
 * <p>Convenience overload that delegates to {@link #getQuery(Storage, boolean)} with
 * {@code withLabel = false}.
 *
 * @return the query string produced for online storage
 * @throws FeatureStoreException propagated from the delegated call
 * @throws IOException propagated from the delegated call
 */
@JsonIgnore
public String getQuery() throws FeatureStoreException, IOException {
  // Only the current delegation is kept; the superseded single-argument call left a second,
  // unreachable return statement behind.
  return getQuery(Storage.ONLINE, false);
}

/**
 * Returns the query of this training dataset for {@link Storage#ONLINE}, forwarding the
 * caller's {@code withLabel} choice.
 *
 * @param withLabel flag passed through unchanged to {@link #getQuery(Storage, boolean)}
 * @return the query string produced for online storage
 * @throws FeatureStoreException propagated from the delegated call
 * @throws IOException propagated from the delegated call
 */
@JsonIgnore
public String getQuery(boolean withLabel) throws FeatureStoreException, IOException {
  final Storage onlineStorage = Storage.ONLINE;
  return getQuery(onlineStorage, withLabel);
}

/**
 * Returns the query of this training dataset for the given storage.
 *
 * <p>Convenience overload that delegates to {@link #getQuery(Storage, boolean)} with
 * {@code withLabel = false}.
 *
 * @param storage the storage the query should be produced for
 * @return the query string for {@code storage}
 * @throws FeatureStoreException propagated from the delegated call
 * @throws IOException propagated from the delegated call
 */
@JsonIgnore
public String getQuery(Storage storage) throws FeatureStoreException, IOException {
  // The former direct two-argument engine call is dropped: the engine method now requires the
  // withLabel flag, and keeping both returns made the second one unreachable.
  return getQuery(storage, false);
}

/**
 * Returns the query of this training dataset for the given storage.
 *
 * <p>All other {@code getQuery} overloads end up here; the work itself is done by
 * {@code trainingDatasetEngine}.
 *
 * @param storage the storage the query should be produced for
 * @param withLabel flag passed through unchanged to the engine
 * @return the query string returned by the engine
 * @throws FeatureStoreException propagated from the engine
 * @throws IOException propagated from the engine
 */
@JsonIgnore
public String getQuery(Storage storage, boolean withLabel) throws FeatureStoreException, IOException {
  final String storageQuery = trainingDatasetEngine.getQuery(this, storage, withLabel);
  return storageQuery;
}

/**
 * Lists the names of the features that are flagged as labels.
 *
 * <p>The result is computed from {@code features} on every call; the separately stored
 * {@code label} field is not consulted here.
 *
 * @return names of all features whose label flag is {@code true}, in feature order
 */
@JsonIgnore
public List<String> getLabel() {
  return features.stream()
      // The label flag is a nullable Boolean. Comparing against TRUE treats a missing flag as
      // "not a label" instead of throwing a NullPointerException on auto-unboxing.
      .filter(feature -> Boolean.TRUE.equals(feature.getLabel()))
      .map(TrainingDatasetFeature::getName)
      .collect(Collectors.toList());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ public class TrainingDatasetFeature {
@Getter @Setter
private Integer index;

@Getter @Setter
private Boolean label = false;

@Builder
public TrainingDatasetFeature(String name, String type) {
this.name = name;
Expand All @@ -48,4 +51,11 @@ public TrainingDatasetFeature(String name, String type, Integer index) {
this.type = type;
this.index = index;
}

/**
 * Creates a training dataset feature with every field supplied explicitly.
 *
 * @param name value stored in the {@code name} field
 * @param type value stored in the {@code type} field
 * @param index value stored in the {@code index} field
 * @param label value stored in the {@code label} field; replaces the field's default of
 *     {@code false}
 */
public TrainingDatasetFeature(String name, String type, Integer index, Boolean label) {
  // The assignments are independent, so their order carries no meaning.
  this.label = label;
  this.index = index;
  this.type = type;
  this.name = name;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.logicalclocks.hsfs.FeatureStoreException;
import com.logicalclocks.hsfs.Storage;
import com.logicalclocks.hsfs.TrainingDataset;
import com.logicalclocks.hsfs.TrainingDatasetFeature;
import com.logicalclocks.hsfs.metadata.TagsApi;
import com.logicalclocks.hsfs.metadata.TrainingDatasetApi;
import org.apache.hadoop.fs.Path;
Expand All @@ -30,7 +31,9 @@
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Optional;

public class TrainingDatasetEngine {

Expand All @@ -39,8 +42,6 @@ public class TrainingDatasetEngine {
private Utils utils = new Utils();

private static final Logger LOGGER = LoggerFactory.getLogger(TrainingDatasetEngine.class);
//TODO:
// Compute statistics

/**
* Make a REST call to Hopsworks to create the metadata and write the data on the File System.
Expand All @@ -51,13 +52,24 @@ public class TrainingDatasetEngine {
* @throws FeatureStoreException
* @throws IOException
*/
public void save(TrainingDataset trainingDataset, Dataset<Row> dataset,
Map<String, String> userWriteOptions)
public void save(TrainingDataset trainingDataset, Dataset<Row> dataset, Map<String, String> userWriteOptions,
List<String> label)
throws FeatureStoreException, IOException {

if (trainingDataset.getQueryInt() == null) {
// if the training dataset hasn't been generated from a query, parse the schema and set the features
trainingDataset.setFeatures(utils.parseTrainingDatasetSchema(dataset));
trainingDataset.setFeatures(utils.parseTrainingDatasetSchema(dataset));

// set label features
if (label != null && !label.isEmpty()) {
for (String l : label) {
Optional<TrainingDatasetFeature> feature =
trainingDataset.getFeatures().stream().filter(f -> f.getName().equals(l)).findFirst();
if (feature.isPresent()) {
feature.get().setLabel(true);
} else {
throw new FeatureStoreException("The specified label `" + l + "` could not be found among the features: "
+ trainingDataset.getFeatures().stream().map(TrainingDatasetFeature::getName) + ".");
}
}
}

// Make the rest call to create the training dataset metadata
Expand Down Expand Up @@ -133,8 +145,8 @@ public void deleteTag(TrainingDataset trainingDataset, String name) throws Featu
tagsApi.deleteTag(trainingDataset, name);
}

public String getQuery(TrainingDataset trainingDataset, Storage storage)
public String getQuery(TrainingDataset trainingDataset, Storage storage, boolean withLabel)
throws FeatureStoreException, IOException {
return trainingDatasetApi.getQuery(trainingDataset).getStorageQuery(storage);
return trainingDatasetApi.getQuery(trainingDataset, withLabel).getStorageQuery(storage);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public class TrainingDatasetApi {

private static final String TRAINING_DATASETS_PATH = "/trainingdatasets";
private static final String TRAINING_DATASET_PATH = TRAINING_DATASETS_PATH + "{/tdName}{?version}";
private static final String TRAINING_QUERY_PATH = TRAINING_DATASETS_PATH + "{/tdId}/query";
private static final String TRAINING_QUERY_PATH = TRAINING_DATASETS_PATH + "{/tdId}/query{?withLabel}";

private static final Logger LOGGER = LoggerFactory.getLogger(TrainingDatasetApi.class);

Expand Down Expand Up @@ -84,7 +84,7 @@ public TrainingDataset createTrainingDataset(TrainingDataset trainingDataset)
return hopsworksClient.handleRequest(postRequest, TrainingDataset.class);
}

public FsQuery getQuery(TrainingDataset trainingDataset)
public FsQuery getQuery(TrainingDataset trainingDataset, boolean withLabel)
throws FeatureStoreException, IOException {
HopsworksClient hopsworksClient = HopsworksClient.getInstance();
String pathTemplate = HopsworksClient.PROJECT_PATH
Expand All @@ -95,6 +95,7 @@ public FsQuery getQuery(TrainingDataset trainingDataset)
.set("projectId", trainingDataset.getFeatureStore().getProjectId())
.set("fsId", trainingDataset.getFeatureStore().getId())
.set("tdId", trainingDataset.getId())
.set("withLabel", withLabel)
.expand();

HttpGet getRequest = new HttpGet(uri);
Expand Down

0 comments on commit 7c704e0

Please sign in to comment.