Skip to content

Commit

Permalink
logicalclocks#60 Replay training dataset query
Browse files Browse the repository at this point in the history
  • Loading branch information
SirOibaf committed Aug 9, 2020
1 parent ac4df84 commit e016f1a
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 12 deletions.
14 changes: 13 additions & 1 deletion java/src/main/java/com/logicalclocks/hsfs/TrainingDataset.java
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public class TrainingDataset {
private TrainingDatasetType trainingDatasetType = TrainingDatasetType.HOPSFS_TRAINING_DATASET;

@Getter @Setter
private List<Feature> features;
private List<TrainingDatasetFeature> features;

@Getter @Setter
@JsonIgnore
Expand All @@ -75,6 +75,9 @@ public class TrainingDataset {
@Getter @Setter
private List<Split> splits;

@Getter @Setter
private Query query;

private TrainingDatasetEngine trainingDatasetEngine = new TrainingDatasetEngine();

@Builder
Expand Down Expand Up @@ -133,6 +136,7 @@ public void save(Dataset<Row> dataset) throws FeatureStoreException, IOException
* @throws IOException
*/
public void save(Query query, Map<String, String> writeOptions) throws FeatureStoreException, IOException {
// Keep a reference to the originating query BEFORE delegating: the engine reads it
// back (via getQuery()) when building the training dataset metadata, so it can be
// stored server-side and replayed later through getOriginatingQuery().
this.query = query;
// query.read() materializes the query into a Spark dataset that the engine writes out.
trainingDatasetEngine.save(this, query.read(), writeOptions);
}

Expand Down Expand Up @@ -311,4 +315,12 @@ public Map<String, String> getTag(String name) throws FeatureStoreException, IOE
public void deleteTag(String name) throws FeatureStoreException, IOException {
// Thin wrapper: tag removal is handled entirely by the engine's metadata call.
trainingDatasetEngine.deleteTag(this, name);
}

/**
 * Get the query string originally used to generate this training dataset,
 * formatted for the ONLINE storage by default.
 *
 * @return the originating query as a SQL string
 * @throws FeatureStoreException
 * @throws IOException
 */
public String getOriginatingQuery() throws FeatureStoreException, IOException {
return getOriginatingQuery(Storage.ONLINE);
}

/**
 * Get the query string originally used to generate this training dataset,
 * formatted for the requested storage.
 *
 * @param storage the storage (online/offline) the query should target
 * @return the originating query as a SQL string
 * @throws FeatureStoreException
 * @throws IOException
 */
public String getOriginatingQuery(Storage storage) throws FeatureStoreException, IOException {
return trainingDatasetEngine.getOriginatingQuery(this, storage);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright (c) 2020 Logical Clocks AB
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*
* See the License for the specific language governing permissions and limitations under the License.
*/

package com.logicalclocks.hsfs;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;

/**
 * Metadata entity describing a single feature of a training dataset:
 * its name, its (Spark catalog) type string, the feature group it
 * originates from (when known), and its positional index in the schema.
 */
@NoArgsConstructor
@AllArgsConstructor
public class TrainingDatasetFeature {
  // Feature name as it appears in the training dataset schema.
  @Getter @Setter
  private String name;

  // Data type, e.g. Spark's catalogString() representation.
  @Getter @Setter
  private String type;

  // Originating feature group; null when the feature was not derived from a query.
  @Getter @Setter
  private FeatureGroup featureGroup;

  // Position of the feature in the training dataset schema (0-based).
  @Getter @Setter
  private Integer index;

  @Builder
  public TrainingDatasetFeature(String name, String type) {
    this.name = name;
    this.type = type;
  }

  public TrainingDatasetFeature(String name, String type, Integer index) {
    // Chain to the 2-arg constructor instead of duplicating the assignments;
    // featureGroup intentionally stays null here.
    this(name, type);
    this.index = index;
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public void saveFeatureGroup(FeatureGroup featureGroup, Dataset<Row> dataset,
throws FeatureStoreException, IOException {

if (featureGroup.getFeatureStore() != null) {
featureGroup.setFeatures(utils.parseSchema(dataset));
featureGroup.setFeatures(utils.parseFeatureGroupSchema(dataset));
}

LOGGER.info("Featuregroup features: " + featureGroup.getFeatures());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com.logicalclocks.hsfs.EntityEndpointType;
import com.logicalclocks.hsfs.FeatureStoreException;
import com.logicalclocks.hsfs.Storage;
import com.logicalclocks.hsfs.TrainingDataset;
import com.logicalclocks.hsfs.metadata.TagsApi;
import com.logicalclocks.hsfs.metadata.TrainingDatasetApi;
Expand Down Expand Up @@ -54,8 +55,11 @@ public class TrainingDatasetEngine {
public void save(TrainingDataset trainingDataset, Dataset<Row> dataset,
Map<String, String> userWriteOptions)
throws FeatureStoreException, IOException {
// TODO(Fabio): make sure we can implement the serving part as well
trainingDataset.setFeatures(utils.parseSchema(dataset));

if (trainingDataset.getQuery() != null) {
// NOTE(review): comment and condition disagree — this branch runs when the training
// dataset WAS generated from a query (getQuery() != null). If the intent is to parse
// the schema only when no query is available, the guard should be "== null"; confirm.
trainingDataset.setFeatures(utils.parseTrainingDatasetSchema(dataset));
}

// Make the rest call to create the training dataset metadata
TrainingDataset apiTD = trainingDatasetApi.createTrainingDataset(trainingDataset);
Expand Down Expand Up @@ -128,4 +132,9 @@ public Map<String, String> getTag(TrainingDataset trainingDataset, String name)
public void deleteTag(TrainingDataset trainingDataset, String name) throws FeatureStoreException, IOException {
// Thin wrapper: the tags metadata API performs the actual REST call.
tagsApi.deleteTag(trainingDataset, name);
}

/**
 * Fetch from the backend the query originally used to generate the training
 * dataset, and return its SQL string for the requested storage.
 *
 * @param trainingDataset the training dataset whose query should be replayed
 * @param storage the storage (online/offline) the query string should target
 * @return the originating query as a SQL string
 * @throws FeatureStoreException
 * @throws IOException
 */
public String getOriginatingQuery(TrainingDataset trainingDataset, Storage storage)
throws FeatureStoreException, IOException {
return trainingDatasetApi.getOriginatingQuery(trainingDataset).getStorageQuery(storage);
}
}
25 changes: 19 additions & 6 deletions java/src/main/java/com/logicalclocks/hsfs/engine/Utils.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.logicalclocks.hsfs.FeatureStoreException;
import com.logicalclocks.hsfs.FeatureGroup;
import com.logicalclocks.hsfs.StorageConnector;
import com.logicalclocks.hsfs.TrainingDatasetFeature;
import io.hops.common.Pair;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
Expand All @@ -31,13 +32,13 @@
import scala.collection.Seq;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;

public class Utils {

// TODO(Fabio): make sure we keep save the feature store/feature group for serving
public List<Feature> parseSchema(Dataset<Row> dataset) throws FeatureStoreException {
public List<Feature> parseFeatureGroupSchema(Dataset<Row> dataset) throws FeatureStoreException {
List<Feature> features = new ArrayList<>();
for (StructField structField : dataset.schema().fields()) {
// TODO(Fabio): unit test this one for complex types
Expand All @@ -48,10 +49,22 @@ public List<Feature> parseSchema(Dataset<Row> dataset) throws FeatureStoreExcept
return features;
}

// TODO(Fabio): keep into account the sorting - needs fixing in Hopsworks as well
public void schemaMatches(Dataset<Row> dataset, List<Feature> features) throws FeatureStoreException {
StructType tdStructType = new StructType(features.stream().map(
f -> new StructField(f.getName(),
/**
 * Build the list of training dataset features from a Spark dataset schema.
 * Each feature records its name, catalog type string, and 0-based position.
 *
 * @param dataset the Spark dataset whose schema should be parsed
 * @return the features, in schema order
 * @throws FeatureStoreException
 */
public List<TrainingDatasetFeature> parseTrainingDatasetSchema(Dataset<Row> dataset) throws FeatureStoreException {
  // TODO(Fabio): unit test this one for complex types
  StructField[] schemaFields = dataset.schema().fields();
  List<TrainingDatasetFeature> features = new ArrayList<>(schemaFields.length);
  for (int i = 0; i < schemaFields.length; i++) {
    StructField field = schemaFields[i];
    features.add(new TrainingDatasetFeature(field.name(), field.dataType().catalogString(), i));
  }
  return features;
}

public void schemaMatches(Dataset<Row> dataset, List<TrainingDatasetFeature> features) throws FeatureStoreException {
StructType tdStructType = new StructType(features.stream()
.sorted(Comparator.comparingInt(TrainingDatasetFeature::getIndex))
.map(f -> new StructField(f.getName(),
// What should we do about the nullables
new CatalystSqlParser(null).parseDataType(f.getType()), true, Metadata.empty())
).toArray(StructField[]::new));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.damnhandy.uri.template.UriTemplate;
import com.logicalclocks.hsfs.FeatureStore;
import com.logicalclocks.hsfs.FeatureStoreException;
import com.logicalclocks.hsfs.FsQuery;
import com.logicalclocks.hsfs.TrainingDataset;
import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpGet;
Expand All @@ -33,8 +34,7 @@ public class TrainingDatasetApi {

private static final String TRAINING_DATASETS_PATH = "/trainingdatasets";
private static final String TRAINING_DATASET_PATH = TRAINING_DATASETS_PATH + "{/tdName}{?version}";
public static final String TRAINING_DATASET_ID_PATH = TRAINING_DATASETS_PATH + "{/tdId}";
public static final String TRAINING_DATASET_TAGS_PATH = TRAINING_DATASET_ID_PATH + "/tags{/name}{?value}";
private static final String TRAINING_QUERY_PATH = TRAINING_DATASETS_PATH + "{/tdId}/query";

private static final Logger LOGGER = LoggerFactory.getLogger(TrainingDatasetApi.class);

Expand Down Expand Up @@ -83,4 +83,23 @@ public TrainingDataset createTrainingDataset(TrainingDataset trainingDataset)
LOGGER.info(trainingDatasetJson);
return hopsworksClient.handleRequest(postRequest, TrainingDataset.class);
}

/**
 * Retrieve from Hopsworks the query originally used to build the training dataset.
 *
 * @param trainingDataset the training dataset whose originating query to fetch
 * @return the originating query metadata
 * @throws FeatureStoreException
 * @throws IOException
 */
public FsQuery getOriginatingQuery(TrainingDataset trainingDataset)
    throws FeatureStoreException, IOException {
  FeatureStore featureStore = trainingDataset.getFeatureStore();

  // Endpoint: .../project/{projectId}/featurestores/{fsId}/trainingdatasets/{tdId}/query
  String uri = UriTemplate.fromTemplate(HopsworksClient.PROJECT_PATH
          + FeatureStoreApi.FEATURE_STORE_PATH
          + TRAINING_QUERY_PATH)
      .set("projectId", featureStore.getProjectId())
      .set("fsId", featureStore.getId())
      .set("tdId", trainingDataset.getId())
      .expand();

  LOGGER.info("Sending metadata request: " + uri);
  return HopsworksClient.getInstance().handleRequest(new HttpGet(uri), FsQuery.class);
}
}

0 comments on commit e016f1a

Please sign in to comment.