[HOPSWORKS-3242] Improve handling of different data types in feature groups across online/offline (logicalclocks#682)

tdoehmen authored and kennethmhc committed Nov 16, 2022
1 parent 6c298af commit 94f2b06
Showing 8 changed files with 265 additions and 65 deletions.
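At a high level, the commit makes schema derivation from a user-supplied dataframe aware of the feature group's time-travel format, rejects dtypes it cannot map instead of silently falling back to string, and validates the derived schema against an existing feature group on insert. As a rough orientation (column names are invented and this snippet is not part of the commit), the pandas dtypes below are the kind of input the new Python-engine mapping works on; the offline types derived per this diff are noted in comments.

import numpy as np
import pandas as pd

df = pd.DataFrame(
    {
        "id": np.array([1, 2], dtype="int64"),               # -> bigint
        "flag": np.array([True, False], dtype="bool"),        # -> boolean (previously mapped to "bool")
        "score": np.array([0.1, 0.2], dtype="float32"),       # -> float
        "amount": np.array([1.5, 2.5], dtype="float64"),      # -> double
        "ts": pd.to_datetime(["2022-11-16", "2022-11-17"]),   # datetime64[ns] -> timestamp
        "name": ["a", "b"],                                    # object, resolved via pyarrow -> string
    }
)
print(df.dtypes)  # exactly the dtypes that parse_schema_feature_group now inspects
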
@@ -36,7 +36,8 @@ public ExternalFeatureGroup saveFeatureGroup(ExternalFeatureGroup externalFeatur
if (externalFeatureGroup.getFeatures() == null) {
onDemandDataset = SparkEngine.getInstance()
.registerOnDemandTemporaryTable(externalFeatureGroup, "read_ondmd");
externalFeatureGroup.setFeatures(utils.parseFeatureGroupSchema(onDemandDataset));
externalFeatureGroup.setFeatures(utils.parseFeatureGroupSchema(onDemandDataset,
externalFeatureGroup.getTimeTravelFormat()));
}

/* set primary features */
@@ -151,7 +151,8 @@ public FeatureGroup saveFeatureGroupMetaData(FeatureGroup featureGroup, List<Str
throws FeatureStoreException, IOException, ParseException {

if (featureGroup.getFeatures() == null) {
featureGroup.setFeatures(utils.parseFeatureGroupSchema(featureData));
featureGroup.setFeatures(utils.parseFeatureGroupSchema(featureData,
featureGroup.getTimeTravelFormat()));
}

LOGGER.info("Featuregroup features: " + featureGroup.getFeatures());
@@ -58,9 +58,9 @@ public class FeatureGroupUtils {
private KafkaApi kafkaApi = new KafkaApi();
private SimpleDateFormat dateFormat = new SimpleDateFormat("yyyyMMddHHmmssSSS");

public <S> List<Feature> parseFeatureGroupSchema(S datasetGeneric)
public <S> List<Feature> parseFeatureGroupSchema(S datasetGeneric, TimeTravelFormat timeTravelFormat)
throws FeatureStoreException {
return SparkEngine.getInstance().parseFeatureGroupSchema(datasetGeneric);
return SparkEngine.getInstance().parseFeatureGroupSchema(datasetGeneric, timeTravelFormat);
}

public <S> S sanitizeFeatureNames(S datasetGeneric) throws FeatureStoreException {
49 changes: 45 additions & 4 deletions java/src/main/java/com/logicalclocks/hsfs/engine/SparkEngine.java
@@ -53,7 +53,21 @@
import org.apache.spark.sql.streaming.DataStreamWriter;
import org.apache.spark.sql.streaming.StreamingQuery;
import org.apache.spark.sql.streaming.StreamingQueryException;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.BinaryType;
import org.apache.spark.sql.types.BooleanType;
import org.apache.spark.sql.types.ByteType;
import org.apache.spark.sql.types.DateType;
import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.sql.types.DoubleType;
import org.apache.spark.sql.types.FloatType;
import org.apache.spark.sql.types.IntegerType;
import org.apache.spark.sql.types.LongType;
import org.apache.spark.sql.types.ShortType;
import org.apache.spark.sql.types.StringType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.TimestampType;
import scala.collection.JavaConverters;

import java.io.IOException;
@@ -614,16 +628,43 @@ public void streamToHudiTable(StreamFeatureGroup streamFeatureGroup, Map<String,
writeOptions = utils.getKafkaConfig(streamFeatureGroup, writeOptions);
hudiEngine.streamToHoodieTable(sparkSession, streamFeatureGroup, writeOptions);
}

public <S> List<Feature> parseFeatureGroupSchema(S datasetGeneric) throws FeatureStoreException {

public <S> List<Feature> parseFeatureGroupSchema(S datasetGeneric,
TimeTravelFormat timeTravelFormat) throws FeatureStoreException {
List<Feature> features = new ArrayList<>();
Dataset<Row> dataset = (Dataset<Row>) datasetGeneric;
Boolean usingHudi = timeTravelFormat == TimeTravelFormat.HUDI;
for (StructField structField : dataset.schema().fields()) {
// TODO(Fabio): unit test this one for complex types
Feature f = new Feature(structField.name().toLowerCase(), structField.dataType().catalogString(), false, false);
String featureType = "";
if (!usingHudi) {
featureType = structField.dataType().catalogString();
} else if (structField.dataType() instanceof ByteType) {
featureType = "int";
} else if (structField.dataType() instanceof ShortType) {
featureType = "int";
} else if (structField.dataType() instanceof BooleanType
|| structField.dataType() instanceof IntegerType
|| structField.dataType() instanceof LongType
|| structField.dataType() instanceof FloatType
|| structField.dataType() instanceof DoubleType
|| structField.dataType() instanceof DecimalType
|| structField.dataType() instanceof TimestampType
|| structField.dataType() instanceof DateType
|| structField.dataType() instanceof StringType
|| structField.dataType() instanceof ArrayType
|| structField.dataType() instanceof StructType
|| structField.dataType() instanceof BinaryType) {
featureType = structField.dataType().catalogString();
} else {
throw new FeatureStoreException("Feature '" + structField.name().toLowerCase() + "': "
+ "spark type " + structField.dataType().catalogString() + " not supported.");
}

Feature f = new Feature(structField.name().toLowerCase(), featureType, false, false);
if (structField.metadata().contains("description")) {
f.setDescription(structField.metadata().getString("description"));
}

features.add(f);
}

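The Hudi-specific branch above only applies when the feature group's time-travel format is HUDI: byte and short columns are widened to int, a fixed allow-list of Spark types passes through as its catalog string, and any other type fails fast. The sketch below restates that rule in PySpark terms purely as an illustration (the shipped logic is the Java above; the helper name and schema are invented), and it only needs pyspark installed, since a StructType can be inspected without a SparkSession.

from pyspark.sql.types import (
    ArrayType, BinaryType, BooleanType, ByteType, DateType, DecimalType,
    DoubleType, FloatType, IntegerType, LongType, ShortType, StringType,
    StructField, StructType, TimestampType,
)

ALLOWED = (BooleanType, IntegerType, LongType, FloatType, DoubleType, DecimalType,
           TimestampType, DateType, StringType, ArrayType, StructType, BinaryType)

def hudi_feature_type(field: StructField) -> str:
    # byte/short are widened to int, mirroring the ByteType/ShortType branches above
    if isinstance(field.dataType, (ByteType, ShortType)):
        return "int"
    # allow-listed types keep their catalog string (e.g. "bigint", "array<string>")
    if isinstance(field.dataType, ALLOWED):
        return field.dataType.simpleString()
    raise ValueError(f"Feature '{field.name.lower()}': spark type "
                     f"{field.dataType.simpleString()} not supported.")

schema = StructType([                       # hypothetical schema
    StructField("id", LongType()),
    StructField("small_flag", ByteType()),
    StructField("tags", ArrayType(StringType())),
])
print({f.name: hudi_feature_type(f) for f in schema.fields})
# {'id': 'bigint', 'small_flag': 'int', 'tags': 'array<string>'}
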
@@ -119,7 +119,9 @@ public <S> StreamFeatureGroup saveFeatureGroupMetaData(StreamFeatureGroup featur
throws FeatureStoreException, IOException, ParseException {

if (featureGroup.getFeatures() == null) {
featureGroup.setFeatures(utils.parseFeatureGroupSchema(utils.sanitizeFeatureNames(featureData)));
featureGroup.setFeatures(utils
.parseFeatureGroupSchema(utils.sanitizeFeatureNames(featureData),
featureGroup.getTimeTravelFormat()));
}

LOGGER.info("Featuregroup features: " + featureGroup.getFeatures());
86 changes: 79 additions & 7 deletions python/hsfs/core/feature_group_engine.py
@@ -17,6 +17,7 @@
from hsfs import engine, client, util
from hsfs import feature_group as fg
from hsfs.client import exceptions
from hsfs.client.exceptions import FeatureStoreException
from hsfs.core import feature_group_base_engine, hudi_engine


@@ -29,8 +30,12 @@ def __init__(self, feature_store_id):

def save(self, feature_group, feature_dataframe, write_options, validation_options):

dataframe_features = engine.get_instance().parse_schema_feature_group(
feature_dataframe, feature_group.time_travel_format
)

self._save_feature_group_metadata(
feature_group, feature_dataframe, write_options
feature_group, dataframe_features, write_options
)

# ge validation on python and non stream feature groups on spark
@@ -70,9 +75,19 @@ def insert(
validation_options,
):

dataframe_features = engine.get_instance().parse_schema_feature_group(
feature_dataframe, feature_group.time_travel_format
)

if not feature_group._id:
# only save metadata if feature group does not exist
self._save_feature_group_metadata(
feature_group, feature_dataframe, write_options
feature_group, dataframe_features, write_options
)
else:
# else, just verify that feature group schema matches user-provided dataframe
self._verify_schema_compatibility(
feature_group.features, dataframe_features
)

# ge validation on python and non stream feature groups on spark
@@ -253,8 +268,14 @@ def insert_stream(
"It is currently only possible to stream to the online storage."
)

dataframe_features = engine.get_instance().parse_schema_feature_group(
dataframe, feature_group.time_travel_format
)

if not feature_group._id:
self._save_feature_group_metadata(feature_group, dataframe, write_options)
self._save_feature_group_metadata(
feature_group, dataframe_features, write_options
)

if not feature_group.stream:
# insert_stream method was called on non stream feature group object that has not been saved.
@@ -272,6 +293,11 @@ def insert_stream(
offline_write_options,
online_write_options,
)
else:
# else, just verify that feature group schema matches user-provided dataframe
self._verify_schema_compatibility(
feature_group.features, dataframe_features
)

if not feature_group.stream:
warnings.warn(
@@ -292,15 +318,61 @@ def insert_stream(

return streaming_query

def _verify_schema_compatibility(self, feature_group_features, dataframe_features):
err = []
feature_df_dict = {feat.name: feat.type for feat in dataframe_features}
for feature_fg in feature_group_features:
fg_type = feature_fg.type.lower().replace(" ", "")
# check if feature exists in dataframe
if feature_fg.name in feature_df_dict:
df_type = feature_df_dict[feature_fg.name].lower().replace(" ", "")
# remove match from lookup table
del feature_df_dict[feature_fg.name]

# check if types match
if fg_type != df_type:
# don't check structs for exact match
if fg_type.startswith("struct") and df_type.startswith("struct"):
continue

err += [
f"{feature_fg.name} ("
f"expected type: '{fg_type}', "
f"derived from input: '{df_type}') has the wrong type."
]

else:
err += [
f"{feature_fg.name} (type: '{feature_fg.type}') is missing from "
f"input dataframe."
]

# any features that are left in lookup table are superfluous
for feature_df_name, feature_df_type in feature_df_dict.items():
err += [
f"{feature_df_name} (type: '{feature_df_type}') does not exist "
f"in feature group."
]

# raise exception if any errors were found.
if len(err) > 0:
raise FeatureStoreException(
"Features are not compatible with Feature Group schema: "
+ "".join(["\n - " + e for e in err])
)
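The check collects every problem before failing, so a single exception reports wrong types, features missing from the dataframe, and dataframe columns that do not exist in the feature group; struct columns are exempt from exact matching since their nested type strings can legitimately differ. For a hypothetical mismatch (names invented), the message assembled above would read roughly as sketched here:

fg_schema = {"id": "bigint", "amount": "double"}   # existing feature group features
df_schema = {"id": "int", "extra": "string"}       # types derived from the dataframe
# -> FeatureStoreException: Features are not compatible with Feature Group schema:
#     - id (expected type: 'bigint', derived from input: 'int') has the wrong type.
#     - amount (type: 'double') is missing from input dataframe.
#     - extra (type: 'string') does not exist in feature group.
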

def _save_feature_group_metadata(
self, feature_group, feature_dataframe, write_options
self, feature_group, dataframe_features, write_options
):

# this means FG doesn't exist and should create the new one
if len(feature_group.features) == 0:
# User didn't provide a schema. extract it from the dataframe
feature_group._features = engine.get_instance().parse_schema_feature_group(
feature_dataframe
# User didn't provide a schema; extract it from the dataframe
feature_group._features = dataframe_features
else:
# User provided a schema; check if it is compatible with dataframe.
self._verify_schema_compatibility(
feature_group.features, dataframe_features
)

# set primary and partition key columns
81 changes: 54 additions & 27 deletions python/hsfs/engine/python.py
@@ -37,6 +37,7 @@
from tqdm.auto import tqdm

from hsfs import client, feature, util
from hsfs.client.exceptions import FeatureStoreException
from hsfs.core import (
feature_group_api,
dataset_api,
@@ -337,7 +338,8 @@ def convert_to_default_dataframe(self, dataframe):
]
if len(upper_case_features) > 0:
warnings.warn(
"The ingested dataframe contains upper case letters in feature names: `{}`. Feature names are sanitized to lower case in the feature store.".format(
"The ingested dataframe contains upper case letters in feature names: `{}`. "
"Feature names are sanitized to lower case in the feature store.".format(
upper_case_features
),
util.FeatureGroupWarning,
@@ -353,61 +355,86 @@ def convert_to_default_dataframe(self, dataframe):
+ "The provided dataframe has type: {}".format(type(dataframe))
)

def parse_schema_feature_group(self, dataframe):
def parse_schema_feature_group(self, dataframe, time_travel_format):
arrow_schema = pa.Schema.from_pandas(dataframe)
return [
feature.Feature(
feat_name.lower(),
self._convert_pandas_type(feat_name, feat_type, arrow_schema),
)
for feat_name, feat_type in dataframe.dtypes.items()
]
features = []
for feat_name, feat_type in dataframe.dtypes.items():
name = feat_name.lower()
try:
converted_type = self._convert_pandas_type(
feat_type, arrow_schema.field(feat_name).type
)
except ValueError as e:
raise FeatureStoreException(f"Feature '{name}': {str(e)}")
features.append(feature.Feature(name, converted_type))
return features

def parse_schema_training_dataset(self, dataframe):
raise NotImplementedError(
"Training dataset creation from Dataframes is not "
+ "supported in Python environment. Use HSFS Query object instead."
)

def _convert_pandas_type(self, feat_name, dtype, arrow_schema):
def _convert_pandas_type(self, dtype, arrow_type):
# This is a simple type conversion between pandas dtypes and pyspark (hive) types,
# using pyarrow types to convert "O (object)"-typed fields.
# In the backend, the types specified here will also be used for mapping to Avro types.
if dtype == np.dtype("O"):
return self._infer_type_pyarrow(feat_name, arrow_schema)
return self._infer_type_pyarrow(arrow_type)

return self._convert_simple_pandas_type(dtype)

def _convert_simple_pandas_type(self, dtype):
# This is a simple type conversion between pandas type and pyspark types.
# In PySpark they use PyArrow to do the schema conversion, but this python layer
# should be as thin as possible. Adding PyArrow will make the library less flexible.
# If the conversion fails, users can always fall back and provide their own types

if dtype == np.dtype("O"):
return "string"
if dtype == np.dtype("uint8"):
return "int"
elif dtype == np.dtype("uint16"):
return "int"
elif dtype == np.dtype("int8"):
return "int"
elif dtype == np.dtype("int16"):
return "int"
elif dtype == np.dtype("int32"):
return "int"
elif dtype == np.dtype("uint32"):
return "bigint"
elif dtype == np.dtype("int64"):
return "bigint"
elif dtype == np.dtype("float16"):
return "float"
elif dtype == np.dtype("float32"):
return "float"
elif dtype == np.dtype("float64"):
return "double"
elif dtype == np.dtype("datetime64[ns]"):
return "timestamp"
elif dtype == np.dtype("bool"):
return "bool"

return "string"
return "boolean"
elif dtype == "category":
return "string"

def _infer_type_pyarrow(self, field, schema):
arrow_type = schema.field(field).type
raise ValueError(f"dtype '{dtype}' not supported")
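Compared to the previous version, the catch-all string fallback at the end is gone: a dtype outside this table now raises, and parse_schema_feature_group wraps that into a FeatureStoreException naming the offending feature. Below is a condensed restatement of the table, keyed on the dtype's string form; it is illustrative only, not the hsfs implementation.

import numpy as np

_PANDAS_TO_OFFLINE = {
    "uint8": "int", "uint16": "int", "int8": "int", "int16": "int", "int32": "int",
    "uint32": "bigint", "int64": "bigint",
    "float16": "float", "float32": "float", "float64": "double",
    "datetime64[ns]": "timestamp", "bool": "boolean", "category": "string",
}

def simple_pandas_type(dtype) -> str:
    try:
        return _PANDAS_TO_OFFLINE[str(dtype)]
    except KeyError:
        raise ValueError(f"dtype '{dtype}' not supported") from None

print(simple_pandas_type(np.dtype("uint32")))   # bigint
print(simple_pandas_type(np.dtype("bool")))     # boolean
# simple_pandas_type(np.dtype("uint64"))        # ValueError: dtype 'uint64' not supported
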

def _infer_type_pyarrow(self, arrow_type):
if pa.types.is_list(arrow_type):
# figure out sub type
subtype = self._convert_simple_pandas_type(
arrow_type.value_type.to_pandas_dtype()
)
sub_arrow_type = arrow_type.value_type
sub_dtype = np.dtype(sub_arrow_type.to_pandas_dtype())
subtype = self._convert_pandas_type(sub_dtype, sub_arrow_type)
return "array<{}>".format(subtype)
return "string"
if pa.types.is_struct(arrow_type):
# best effort, based on pyarrow's string representation
return str(arrow_type)
# Currently not supported
# elif pa.types.is_decimal(arrow_type):
# return str(arrow_type).replace("decimal128", "decimal")
elif pa.types.is_date(arrow_type):
return "date"
elif pa.types.is_binary(arrow_type):
return "binary"
elif pa.types.is_string(arrow_type) or pa.types.is_unicode(arrow_type):
return "string"

raise ValueError(f"dtype 'O' (arrow_type '{str(arrow_type)}') not supported")
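For object-typed columns the pyarrow type now drives the result: lists recurse into their element type (so nested object elements are resolved too), structs fall back to pyarrow's string form on a best-effort basis, and date, binary and string get explicit mappings, while anything else, including decimals for now, is rejected. The snippet below (column names invented) shows the arrow types pyarrow reports for such columns, which is the input this method receives; the feature types derived per the code above are noted in comments.

import datetime
import pandas as pd
import pyarrow as pa

df = pd.DataFrame(
    {
        "tags": [["a", "b"], ["c"]],                          # lists of strings
        "meta": [{"k": 1}, {"k": 2}],                         # dicts, inferred as struct
        "day": [datetime.date(2022, 11, 16), datetime.date(2022, 11, 17)],
        "raw": [b"\x00", b"\x01"],
    }
)

schema = pa.Schema.from_pandas(df)
for name in df.columns:
    print(name, schema.field(name).type)
# tags: list<item: string>  -> "array<string>"
# meta: struct<k: int64>    -> "struct<k: int64>" (best-effort string form)
# day:  date32[day]         -> "date"
# raw:  binary              -> "binary"
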

def save_dataframe(
self,