Skip to content

Commit

Permalink
[HOPSWORKS-1762] Add support for Spark-TFRecord (#94)
Browse files: browse the repository at this point in the history
* Spark-TFRecord
  • Loading branch information
davitbzh authored Aug 26, 2020
1 parent 141677e commit d0e83ab
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion python/hsfs/engine/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,9 @@ def write_options(self, data_format, provided_options):
if data_format.lower() == "tfrecords":
options = dict(recordType="Example")
options.update(provided_options)
elif data_format.lower() == "tfrecord":
options = dict(recordType="Example")
options.update(provided_options)
elif data_format.lower() == "csv":
options = dict(delimiter=",", header="true")
options.update(provided_options)
Expand All @@ -212,6 +215,9 @@ def read_options(self, data_format, provided_options):
if data_format.lower() == "tfrecords":
options = dict(recordType="Example", **provided_options)
options.update(provided_options)
elif data_format.lower() == "tfrecord":
options = dict(recordType="Example")
options.update(provided_options)
elif data_format.lower() == "csv":
options = dict(delimiter=",", header="true", inferSchema="true")
options.update(provided_options)
Expand All @@ -235,7 +241,7 @@ def parse_schema(self, dataframe):

def parse_schema_dict(self, dataframe):
return {
feat["name"]: feature.Feature(
feat.name: feature.Feature(
feat.name,
feat.dataType.simpleString(),
feat.metadata.get("description", ""),
Expand Down

0 comments on commit d0e83ab

Please sign in to comment.