[FSTORE-359] Rename Transformation Function Output Types and Fix Timezone-related issues (logicalclocks#829)

* remove unused import

* spark types for tf

* spark types for tf

* tmp

* merge rebase with upstream/master

* [HOPSWORKS-3342][Append] unit tests (logicalclocks#9)

* added unit tests for transformation functions

* fixed stylecheck

* fixed a small thing in docs

* infer_python_type

* fix tests

* fix tests

* fix checks

* removed legacy types cast from test

* add convert_column method, fix style

* updated builtin transformation functions

* style

(cherry picked from commit 6ad26c4cb72cfd286dbb3c7049cd72f0aed3ab78)

* made transformation functions timezone-safe

(cherry picked from commit 176c2da134517648d568f69784b1e13c56c580f1)

* style

* removed deprecated np types

* tests tf function returning none

* style

* fixed tests

* style

Co-authored-by: davitbzh <davit.bzhalava@gmail.com>
2 people authored and kennethmhc committed Nov 16, 2022
1 parent f439332 commit bf363c4
Showing 7 changed files with 296 additions and 121 deletions.
python/hsfs/core/transformation_function_engine.py (17 additions, 39 deletions)

```diff
@@ -177,62 +177,40 @@ def populate_builtin_attached_fns(self, attached_transformation_fns, stat_conten
     @staticmethod
     def infer_spark_type(output_type):
         if not output_type:
-            return "StringType()"  # StringType() is default type for spark udfs
+            return "STRING"  # STRING is default type for spark udfs
 
         if isinstance(output_type, str):
             output_type = output_type.lower()
 
         if output_type in (str, "str", "string"):
-            return "StringType()"
-        elif output_type in (bytes, "binary"):
-            return "BinaryType()"
-        elif output_type in (numpy.int8, "int8", "byte", "tinyint"):
-            return "ByteType()"
-        elif output_type in (numpy.int16, "int16", "short", "smallint"):
-            return "ShortType()"
-        elif output_type in (int, "int", numpy.int, numpy.int32):
-            return "IntegerType()"
-        elif output_type in (numpy.int64, "int64", "long", "bigint"):
-            return "LongType()"
-        elif output_type in (float, "float", numpy.float):
-            return "FloatType()"
-        elif output_type in (numpy.float64, "float64", "double"):
-            return "DoubleType()"
-        elif output_type in (datetime.datetime, numpy.datetime64, "datetime"):
-            return "TimestampType()"
-        elif output_type in (datetime.date, "date"):
-            return "DateType()"
-        elif output_type in (bool, "boolean", "bool", numpy.bool):
-            return "BooleanType()"
-        else:
-            raise TypeError("Not supported type %s." % output_type)
-
-    @staticmethod
-    def convert_legacy_type(output_type):
-        if output_type == "StringType()":
             return "STRING"
-        elif output_type == "BinaryType()":
+        elif output_type in (bytes, "binary"):
             return "BINARY"
-        elif output_type == "ByteType()":
+        elif output_type in (numpy.int8, "int8", "byte", "tinyint"):
             return "BYTE"
-        elif output_type == "ShortType()":
+        elif output_type in (numpy.int16, "int16", "short", "smallint"):
             return "SHORT"
-        elif output_type == "IntegerType()":
+        elif output_type in (int, "int", "integer", numpy.int32):
             return "INT"
-        elif output_type == "LongType()":
+        elif output_type in (numpy.int64, "int64", "long", "bigint"):
             return "LONG"
-        elif output_type == "FloatType()":
+        elif output_type in (float, "float"):
             return "FLOAT"
-        elif output_type == "DoubleType()":
+        elif output_type in (numpy.float64, "float64", "double"):
             return "DOUBLE"
-        elif output_type == "TimestampType()":
+        elif output_type in (
+            datetime.datetime,
+            numpy.datetime64,
+            "datetime",
+            "timestamp",
+        ):
             return "TIMESTAMP"
-        elif output_type == "DateType()":
+        elif output_type in (datetime.date, "date"):
             return "DATE"
-        elif output_type == "BooleanType()":
+        elif output_type in (bool, "boolean", "bool", numpy.bool):
             return "BOOLEAN"
         else:
-            return "STRING"  # handle gracefully, and return STRING type, the default for spark udfs
+            raise TypeError("Not supported type %s." % output_type)
 
     @staticmethod
     def compute_transformation_fn_statistics(
```
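After this change, `infer_spark_type` returns Spark SQL type names directly, so the `convert_legacy_type` shim above is removed outright. A minimal sketch of the new behavior, assuming `hsfs` is importable and the class path matches the file shown (the assertions simply follow the branches in the diff):

```python
import datetime

from hsfs.core.transformation_function_engine import TransformationFunctionEngine

# Python/numpy types and string aliases now map straight to SQL type names
assert TransformationFunctionEngine.infer_spark_type(int) == "INT"
assert TransformationFunctionEngine.infer_spark_type("Double") == "DOUBLE"  # strings are lower-cased first
assert TransformationFunctionEngine.infer_spark_type(datetime.datetime) == "TIMESTAMP"
assert TransformationFunctionEngine.infer_spark_type("timestamp") == "TIMESTAMP"  # alias added in this commit
assert TransformationFunctionEngine.infer_spark_type(None) == "STRING"  # default for spark udfs
```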
python/hsfs/engine/python.py (15 additions, 15 deletions)

```diff
@@ -825,29 +825,29 @@ def _apply_transformation_function(self, transformation_functions, dataset):
 
         return dataset
 
-    @staticmethod
-    def convert_column(output_type, feature_column):
-        if output_type in ("StringType()",):
+    def convert_column(self, output_type, feature_column):
+        if output_type == "STRING":
             return feature_column.astype(str)
-        elif output_type in ("BinaryType()",):
+        elif output_type == "BINARY":
             return feature_column.astype(bytes)
-        elif output_type in ("ByteType()",):
+        elif output_type == "BYTE":
             return feature_column.astype(np.int8)
-        elif output_type in ("ShortType()",):
+        elif output_type == "SHORT":
             return feature_column.astype(np.int16)
-        elif output_type in ("IntegerType()",):
+        elif output_type == "INT":
             return feature_column.astype(int)
-        elif output_type in ("LongType()",):
+        elif output_type == "LONG":
             return feature_column.astype(np.int64)
-        elif output_type in ("FloatType()",):
+        elif output_type == "FLOAT":
             return feature_column.astype(float)
-        elif output_type in ("DoubleType()",):
+        elif output_type == "DOUBLE":
             return feature_column.astype(np.float64)
-        elif output_type in ("TimestampType()",):
-            return pd.to_datetime(feature_column)
-        elif output_type in ("DateType()",):
-            return pd.to_datetime(feature_column).dt.date
-        elif output_type in ("BooleanType()",):
+        elif output_type == "TIMESTAMP":
+            # convert (if tz!=UTC) to utc, then make timezone unaware
+            return pd.to_datetime(feature_column, utc=True).dt.tz_localize(None)
+        elif output_type == "DATE":
+            return pd.to_datetime(feature_column, utc=True).dt.date
+        elif output_type == "BOOLEAN":
             return feature_column.astype(bool)
         else:
             return feature_column  # handle gracefully, just return the column as-is
```
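The pandas-engine fix is the `utc=True` / `tz_localize(None)` pair: values are first normalized to UTC, then stripped of their timezone label, so the result no longer depends on the driver's local timezone. A standalone sketch of that behavior in plain pandas, outside the engine:

```python
import pandas as pd

# Mixed offsets: 10:00+01:00 and 09:00+00:00 are the same instant
col = pd.Series(["2022-11-16 10:00:00+01:00", "2022-11-16 09:00:00+00:00"])

# The old pd.to_datetime(col) kept per-value offsets; the new code
# normalizes everything to UTC first, then drops the timezone label.
converted = pd.to_datetime(col, utc=True).dt.tz_localize(None)
print(converted.tolist())
# [Timestamp('2022-11-16 09:00:00'), Timestamp('2022-11-16 09:00:00')]
```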
python/hsfs/engine/spark.py (27 additions, 5 deletions)

```diff
@@ -24,9 +24,10 @@
 import numpy as np
 import pandas as pd
 import avro
-from datetime import datetime
+from datetime import datetime, timezone
 
 # in case importing in %%local
+
 try:
     from pyspark import SparkFiles
     from pyspark.sql import SparkSession, DataFrame, SQLContext
@@ -888,12 +889,33 @@ def _apply_transformation_function(self, transformation_functions, dataset):
                 + "_"
                 + feature_name
             )
+
+            def timezone_decorator(func):
+                if transformation_fn.output_type != "TIMESTAMP":
+                    return func
+
+                current_timezone = datetime.now().astimezone().tzinfo
+
+                def decorated_func(x):
+                    result = func(x)
+                    if isinstance(result, datetime):
+                        if result.tzinfo is None:
+                            # if timestamp is timezone unaware, make sure it's localized to the system's timezone.
+                            # otherwise, spark will implicitly convert it to the system's timezone.
+                            return result.replace(tzinfo=current_timezone)
+                        else:
+                            # convert to utc, then localize to system's timezone
+                            return result.astimezone(timezone.utc).replace(
+                                tzinfo=current_timezone
+                            )
+                    return result
+
+                return decorated_func
+
             self._spark_session.udf.register(
                 fn_registration_name,
-                transformation_fn.transformation_fn,
-                transformation_function_engine.TransformationFunctionEngine.convert_legacy_type(
-                    transformation_fn.output_type
-                ),
+                timezone_decorator(transformation_fn.transformation_fn),
+                transformation_fn.output_type,
             )
             transformation_fn_expressions.append(
                 "{fn_name:}({name:}) AS {name:}".format(
```
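The Spark engine takes the opposite approach from the pandas engine: Spark interprets naive datetimes returned by a UDF in the session (system) timezone, so the decorator re-labels every result with the system timezone, converting aware values through UTC first, and the implicit conversion becomes harmless. A standalone sketch of the same logic, with the `output_type` check dropped and a `timedelta` import added purely for the example:

```python
from datetime import datetime, timedelta, timezone

current_timezone = datetime.now().astimezone().tzinfo

def timezone_decorator(func):
    # mirrors the closure in the diff, minus the surrounding engine state
    def decorated_func(x):
        result = func(x)
        if isinstance(result, datetime):
            if result.tzinfo is None:
                # naive result: pin it to the system timezone so Spark's
                # implicit conversion is a no-op
                return result.replace(tzinfo=current_timezone)
            # aware result: shift to UTC, then re-label with the system timezone
            return result.astimezone(timezone.utc).replace(tzinfo=current_timezone)
        return result

    return decorated_func

@timezone_decorator
def plus_one_hour(ts):
    return ts + timedelta(hours=1)

aware = datetime(2022, 11, 16, 10, 0, tzinfo=timezone(timedelta(hours=1)))  # 09:00 UTC
print(plus_one_hour(aware))  # 10:00 UTC wall time, labeled with the system timezone
```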