#60 Replay training dataset query (#73)
Closes #60
SirOibaf authored Oct 5, 2020
1 parent f4a6486 commit eb4edfe
Showing 14 changed files with 250 additions and 48 deletions.
23 changes: 17 additions & 6 deletions java/src/main/java/com/logicalclocks/hsfs/TrainingDataset.java
@@ -56,7 +56,7 @@ public class TrainingDataset {
private TrainingDatasetType trainingDatasetType = TrainingDatasetType.HOPSFS_TRAINING_DATASET;

@Getter @Setter
private List<Feature> features;
private List<TrainingDatasetFeature> features;

@Getter @Setter
@JsonIgnore
@@ -94,6 +94,10 @@ public class TrainingDataset {
@JsonIgnore
private List<String> statisticColumns;

@Getter @Setter
@JsonProperty("queryDTO")
private Query queryInt;

private TrainingDatasetEngine trainingDatasetEngine = new TrainingDatasetEngine();
private StatisticsEngine statisticsEngine = new StatisticsEngine(EntityEndpointType.TRAINING_DATASET);

@@ -157,11 +161,8 @@ public void save(Dataset<Row> dataset) throws FeatureStoreException, IOException
* @throws IOException
*/
public void save(Query query, Map<String, String> writeOptions) throws FeatureStoreException, IOException {
Dataset<Row> dataset = query.read();
trainingDatasetEngine.save(this, dataset, writeOptions);
if (statisticsEnabled) {
statisticsEngine.computeStatistics(this, dataset);
}
this.queryInt = query;
save(query.read(), writeOptions);
}

/**
@@ -383,4 +384,14 @@ public Map<String, String> getTag(String name) throws FeatureStoreException, IOException
public void deleteTag(String name) throws FeatureStoreException, IOException {
trainingDatasetEngine.deleteTag(this, name);
}

@JsonIgnore
public String getQuery() throws FeatureStoreException, IOException {
return getQuery(Storage.ONLINE);
}

@JsonIgnore
public String getQuery(Storage storage) throws FeatureStoreException, IOException {
return trainingDatasetEngine.getQuery(this, storage);
}
}
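
The new queryInt field (serialized as queryDTO) means save(Query, ...) now records the originating query before materializing it, and the two getQuery overloads replay it from the backend. A minimal sketch of the round trip, assuming the usual HopsworksConnection/FeatureStore entry points and training dataset builder of this library; the feature group name, versions, and class name are made up:

```java
import java.util.HashMap;

import com.logicalclocks.hsfs.FeatureGroup;
import com.logicalclocks.hsfs.FeatureStore;
import com.logicalclocks.hsfs.HopsworksConnection;
import com.logicalclocks.hsfs.Storage;
import com.logicalclocks.hsfs.TrainingDataset;

public class ReplayQueryExample {
  public static void main(String[] args) throws Exception {
    FeatureStore fs = HopsworksConnection.builder().build().getFeatureStore();
    FeatureGroup rates = fs.getFeatureGroup("exchange_rates", 1); // hypothetical

    TrainingDataset td = fs.createTrainingDataset()
        .name("rates_model")
        .version(1)
        .build();

    // save(Query, Map) stores the query on the training dataset (queryInt)
    // before delegating to save(Dataset, Map), so the backend can replay it.
    td.save(rates.selectAll(), new HashMap<>());

    // Replay: the no-arg overload defaults to Storage.ONLINE.
    String onlineSql = td.getQuery();
    String offlineSql = td.getQuery(Storage.OFFLINE);
    System.out.println(onlineSql + "\n" + offlineSql);
  }
}
```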
51 changes: 51 additions & 0 deletions java/src/main/java/com/logicalclocks/hsfs/TrainingDatasetFeature.java
@@ -0,0 +1,51 @@
/*
* Copyright (c) 2020 Logical Clocks AB
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*
* See the License for the specific language governing permissions and limitations under the License.
*/

package com.logicalclocks.hsfs;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;

@NoArgsConstructor
@AllArgsConstructor
public class TrainingDatasetFeature {
@Getter @Setter
private String name;

@Getter @Setter
private String type;

@Getter @Setter
private FeatureGroup featureGroup;

@Getter @Setter
private Integer index;

@Builder
public TrainingDatasetFeature(String name, String type) {
this.name = name;
this.type = type;
}

public TrainingDatasetFeature(String name, String type, Integer index) {
this.name = name;
this.type = type;
this.index = index;
}
}
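
Note the two construction paths above: the Lombok builder only covers name and type (index stays null until the backend assigns one), while the three-argument constructor is what parseTrainingDatasetSchema uses to pin each column's position. A quick sketch with made-up feature names:

```java
// Builder path: index is left null and assigned by the backend.
TrainingDatasetFeature label = TrainingDatasetFeature.builder()
    .name("label")
    .type("int")
    .build();

// Schema-parsing path: the explicit index records the dataframe column order,
// which trainingDatasetSchemaMatch later uses to validate inserts.
TrainingDatasetFeature amount = new TrainingDatasetFeature("amount", "double", 0);
```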
java/src/main/java/com/logicalclocks/hsfs/engine/FeatureGroupEngine.java
@@ -61,7 +61,7 @@ public void saveFeatureGroup(FeatureGroup featureGroup, Dataset<Row> dataset,
throws FeatureStoreException, IOException {

if (featureGroup.getFeatureStore() != null) {
featureGroup.setFeatures(utils.parseSchema(dataset));
featureGroup.setFeatures(utils.parseFeatureGroupSchema(dataset));
}

LOGGER.info("Featuregroup features: " + featureGroup.getFeatures());
java/src/main/java/com/logicalclocks/hsfs/engine/TrainingDatasetEngine.java
@@ -18,6 +18,7 @@

import com.logicalclocks.hsfs.EntityEndpointType;
import com.logicalclocks.hsfs.FeatureStoreException;
import com.logicalclocks.hsfs.Storage;
import com.logicalclocks.hsfs.TrainingDataset;
import com.logicalclocks.hsfs.metadata.TagsApi;
import com.logicalclocks.hsfs.metadata.TrainingDatasetApi;
@@ -53,8 +54,11 @@ public class TrainingDatasetEngine {
public void save(TrainingDataset trainingDataset, Dataset<Row> dataset,
Map<String, String> userWriteOptions)
throws FeatureStoreException, IOException {
// TODO(Fabio): make sure we can implement the serving part as well
trainingDataset.setFeatures(utils.parseSchema(dataset));

if (trainingDataset.getQueryInt() == null) {
// if the training dataset hasn't been generated from a query, parse the schema and set the features
trainingDataset.setFeatures(utils.parseTrainingDatasetSchema(dataset));
}

// Make the rest call to create the training dataset metadata
TrainingDataset apiTD = trainingDatasetApi.createTrainingDataset(trainingDataset);
@@ -89,7 +93,7 @@ public void insert(TrainingDataset trainingDataset, Dataset<Row> dataset,
Map<String, String> providedOptions, SaveMode saveMode)
throws FeatureStoreException {
// validate that the schema matches
utils.schemaMatches(dataset, trainingDataset.getFeatures());
utils.trainingDatasetSchemaMatch(dataset, trainingDataset.getFeatures());

Map<String, String> writeOptions =
SparkEngine.getInstance().getWriteOptions(providedOptions, trainingDataset.getDataFormat());
@@ -128,4 +132,9 @@ public Map<String, String> getTag(TrainingDataset trainingDataset, String name)
public void deleteTag(TrainingDataset trainingDataset, String name) throws FeatureStoreException, IOException {
tagsApi.deleteTag(trainingDataset, name);
}

public String getQuery(TrainingDataset trainingDataset, Storage storage)
throws FeatureStoreException, IOException {
return trainingDatasetApi.getQuery(trainingDataset).getStorageQuery(storage);
}
}
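
FsQuery itself is among the files not rendered on this page, so its exact shape isn't visible here. Judging from getStorageQuery(storage) above and the Python client further down, which reads the query and queryOnline keys from the same endpoint, a plausible (hypothetical) reconstruction:

```java
import com.logicalclocks.hsfs.Storage;
import lombok.Getter;
import lombok.Setter;

// Hypothetical sketch of FsQuery; the real class ships in this commit
// but is not expanded above.
public class FsQuery {
  @Getter @Setter
  private String query;        // offline (Spark/Hive) variant

  @Getter @Setter
  private String queryOnline;  // online (JDBC) variant

  public String getStorageQuery(Storage storage) {
    return storage == Storage.ONLINE ? queryOnline : query;
  }
}
```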
26 changes: 20 additions & 6 deletions java/src/main/java/com/logicalclocks/hsfs/engine/Utils.java
@@ -20,6 +20,7 @@
import com.logicalclocks.hsfs.FeatureStoreException;
import com.logicalclocks.hsfs.FeatureGroup;
import com.logicalclocks.hsfs.StorageConnector;
import com.logicalclocks.hsfs.TrainingDatasetFeature;
import io.hops.common.Pair;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
@@ -31,13 +32,13 @@
import scala.collection.Seq;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;

public class Utils {

// TODO(Fabio): make sure we keep save the feature store/feature group for serving
public List<Feature> parseSchema(Dataset<Row> dataset) throws FeatureStoreException {
public List<Feature> parseFeatureGroupSchema(Dataset<Row> dataset) throws FeatureStoreException {
List<Feature> features = new ArrayList<>();
for (StructField structField : dataset.schema().fields()) {
// TODO(Fabio): unit test this one for complex types
@@ -48,10 +49,23 @@ public List<Feature> parseSchema(Dataset<Row> dataset) throws FeatureStoreException
return features;
}

// TODO(Fabio): keep into account the sorting - needs fixing in Hopsworks as well
public void schemaMatches(Dataset<Row> dataset, List<Feature> features) throws FeatureStoreException {
StructType tdStructType = new StructType(features.stream().map(
f -> new StructField(f.getName(),
public List<TrainingDatasetFeature> parseTrainingDatasetSchema(Dataset<Row> dataset) throws FeatureStoreException {
List<TrainingDatasetFeature> features = new ArrayList<>();

int index = 0;
for (StructField structField : dataset.schema().fields()) {
// TODO(Fabio): unit test this one for complex types
features.add(new TrainingDatasetFeature(structField.name(), structField.dataType().catalogString(), index++));
}

return features;
}

public void trainingDatasetSchemaMatch(Dataset<Row> dataset, List<TrainingDatasetFeature> features)
throws FeatureStoreException {
StructType tdStructType = new StructType(features.stream()
.sorted(Comparator.comparingInt(TrainingDatasetFeature::getIndex))
.map(f -> new StructField(f.getName(),
// What should we do about the nullables
new CatalystSqlParser(null).parseDataType(f.getType()), true, Metadata.empty())
).toArray(StructField[]::new));
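
The sort by TrainingDatasetFeature::getIndex is what makes the match robust: the backend may return features in any order, and the recorded index restores the column order the dataset was created with. For illustration (hypothetical features):

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;

import com.logicalclocks.hsfs.TrainingDatasetFeature;

public class IndexOrderExample {
  public static void main(String[] args) {
    // Features as they might come back from the backend, out of order.
    List<TrainingDatasetFeature> features = Arrays.asList(
        new TrainingDatasetFeature("label", "int", 1),
        new TrainingDatasetFeature("amount", "double", 0));

    // Sorting by index recovers the creation-time column order.
    List<String> ordered = features.stream()
        .sorted(Comparator.comparingInt(TrainingDatasetFeature::getIndex))
        .map(TrainingDatasetFeature::getName)
        .collect(Collectors.toList());

    System.out.println(ordered); // [amount, label]
  }
}
```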
java/src/main/java/com/logicalclocks/hsfs/metadata/TrainingDatasetApi.java
@@ -19,6 +19,7 @@
import com.damnhandy.uri.template.UriTemplate;
import com.logicalclocks.hsfs.FeatureStore;
import com.logicalclocks.hsfs.FeatureStoreException;
import com.logicalclocks.hsfs.FsQuery;
import com.logicalclocks.hsfs.TrainingDataset;
import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpGet;
@@ -33,8 +34,7 @@ public class TrainingDatasetApi {

private static final String TRAINING_DATASETS_PATH = "/trainingdatasets";
private static final String TRAINING_DATASET_PATH = TRAINING_DATASETS_PATH + "{/tdName}{?version}";
public static final String TRAINING_DATASET_ID_PATH = TRAINING_DATASETS_PATH + "{/tdId}";
public static final String TRAINING_DATASET_TAGS_PATH = TRAINING_DATASET_ID_PATH + "/tags{/name}{?value}";
private static final String TRAINING_QUERY_PATH = TRAINING_DATASETS_PATH + "{/tdId}/query";

private static final Logger LOGGER = LoggerFactory.getLogger(TrainingDatasetApi.class);

@@ -83,4 +83,23 @@ public TrainingDataset createTrainingDataset(TrainingDataset trainingDataset)
LOGGER.info(trainingDatasetJson);
return hopsworksClient.handleRequest(postRequest, TrainingDataset.class);
}

public FsQuery getQuery(TrainingDataset trainingDataset)
throws FeatureStoreException, IOException {
HopsworksClient hopsworksClient = HopsworksClient.getInstance();
String pathTemplate = HopsworksClient.PROJECT_PATH
+ FeatureStoreApi.FEATURE_STORE_PATH
+ TRAINING_QUERY_PATH;

String uri = UriTemplate.fromTemplate(pathTemplate)
.set("projectId", trainingDataset.getFeatureStore().getProjectId())
.set("fsId", trainingDataset.getFeatureStore().getId())
.set("tdId", trainingDataset.getId())
.expand();

HttpGet getRequest = new HttpGet(uri);
LOGGER.info("Sending metadata request: " + uri);

return hopsworksClient.handleRequest(getRequest, FsQuery.class);
}
}
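
TRAINING_QUERY_PATH is the only path segment this commit confirms; assuming PROJECT_PATH and FEATURE_STORE_PATH keep their usual /project{/projectId} and /featurestores{/fsId} shapes, the template expands as below (IDs made up):

```java
import com.damnhandy.uri.template.UriTemplate;

public class QueryPathExample {
  public static void main(String[] args) {
    // Assumed prefixes; only the trailing {/tdId}/query part is new here.
    String template = "/project{/projectId}/featurestores{/fsId}"
        + "/trainingdatasets{/tdId}/query";

    String uri = UriTemplate.fromTemplate(template)
        .set("projectId", 119)
        .set("fsId", 67)
        .set("tdId", 12)
        .expand();

    // Prints: /project/119/featurestores/67/trainingdatasets/12/query
    System.out.println(uri);
  }
}
```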
2 changes: 1 addition & 1 deletion python/hsfs/client/base.py
@@ -146,7 +146,7 @@ def _send_request(
return response
else:
# handle different success response codes
if len(response.content) == 0:
if len(response.content) == 0:
return None
return response.json()

2 changes: 1 addition & 1 deletion python/hsfs/core/feature_group_engine.py
@@ -34,7 +34,7 @@ def save(self, feature_group, feature_dataframe, storage, write_options):

if len(feature_group.features) == 0:
# User didn't provide a schema. extract it from the dataframe
feature_group._features = engine.get_instance().parse_schema(
feature_group._features = engine.get_instance().parse_schema_feature_group(
feature_dataframe
)

13 changes: 13 additions & 0 deletions python/hsfs/core/training_dataset_api.py
@@ -54,3 +54,16 @@ def get(self, name, version):
return training_dataset.TrainingDataset.from_response_json(
_client._send_request("GET", path_params, query_params)[0],
)

def get_query(self, training_dataset_instance):
_client = client.get_instance()
path_params = [
"project",
_client._project_id,
"featurestores",
self._feature_store_id,
"trainingdatasets",
training_dataset_instance.id,
"query",
]
return _client._send_request("GET", path_params)
9 changes: 8 additions & 1 deletion python/hsfs/core/training_dataset_engine.py
@@ -42,7 +42,9 @@ def insert(
self, training_dataset, feature_dataframe, user_write_options, overwrite
):
# validate matching schema
engine.get_instance().schema_matches(feature_dataframe, training_dataset.schema)
engine.get_instance().training_dataset_schema_match(
feature_dataframe, training_dataset.schema
)

write_options = engine.get_instance().write_options(
training_dataset.data_format, user_write_options
@@ -72,6 +74,11 @@ def read(self, training_dataset, split, user_read_options):
path,
)

def query(self, training_dataset, storage):
return self._training_dataset_api.get_query(training_dataset)[
"queryOnline" if storage.lower() == "online" else "query"
]

def _write(self, training_dataset, dataset, write_options, save_mode):
if len(training_dataset.splits) == 0:
path = training_dataset.location + "/" + training_dataset.name
41 changes: 27 additions & 14 deletions python/hsfs/engine/spark.py
@@ -26,7 +26,7 @@
except ModuleNotFoundError:
pass

from hsfs import feature
from hsfs import feature, training_dataset_feature
from hsfs.storage_connector import StorageConnector
from hsfs.client.exceptions import FeatureStoreException

@@ -236,7 +236,7 @@ def read_options(self, data_format, provided_options):
options.update(provided_options)
return options

def parse_schema(self, dataframe):
def parse_schema_feature_group(self, dataframe):
return [
feature.Feature(
feat.name,
@@ -246,6 +246,14 @@ def parse_schema_feature_group(self, dataframe):
for feat in dataframe.schema
]

def parse_schema_training_dataset(self, dataframe):
return [
training_dataset_feature.TrainingDatasetFeature(
feat.name, feat.dataType.simpleString()
)
for feat in dataframe.schema
]

def parse_schema_dict(self, dataframe):
return {
feat.name: feature.Feature(
@@ -256,22 +264,27 @@ def parse_schema_dict(self, dataframe):
for feat in dataframe.schema
}

def schema_matches(self, dataframe, schema):
# This does not respect order, for that we would need to make sure the features in the
# list coming from the backend are ordered correctly
insert_schema = self.parse_schema_dict(dataframe)
for feat in schema:
insert_feat = insert_schema.pop(feat.name, False)
if insert_feat:
if insert_feat.type == feat.type:
pass
else:
def training_dataset_schema_match(self, dataframe, schema):
schema_sorted = sorted(schema, key=lambda f: f.index)
insert_schema = dataframe.schema
if len(schema_sorted) != len(insert_schema):
raise SchemaError(
"Schemas do not match. Expected {} features, the dataframe contains {} features".format(
len(schema_sorted), len(insert_schema)
)
)

i = 0
for feat in schema_sorted:
if feat.name != insert_schema[i].name:
raise SchemaError(
"Schemas do not match, could not find feature {} among the data to be inserted.".format(
feat.name
"Schemas do not match, expected feature {} in position {}, found {}".format(
feat.name, str(i), insert_schema[i].name
)
)

i += 1

def _setup_s3(self, storage_connector, path):
if storage_connector.access_key:
self._spark_context._jsc.hadoopConfiguration().set(
9 changes: 0 additions & 9 deletions python/hsfs/feature_group.py
@@ -67,15 +67,6 @@ def __init__(
self._default_storage = default_storage
self._hudi_enabled = hudi_enabled

if id is None:
# Initialized from the API
self._primary_key = primary_key
self._partition_key = partition_key
else:
# Initialized from the backend
self._primary_key = [f.name for f in self._features if f.primary]
self._partition_key = [f.name for f in self._features if f.partition]

if id is not None:
# initialized by backend
self.statistics_config = StatisticsConfig(
