dtypes: Add type conversion to pyspark types
aditya-nambiar committed Sep 26, 2024
1 parent 94f5a02 commit a21e5fc
Showing 7 changed files with 71 additions and 3,878 deletions.
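
The core of this commit is a new `to_spark_type` helper in fennel/internal_lib/schema/schema.py that maps Python typing annotations to PySpark `DataType`s. A minimal usage sketch, with the import path inferred from the changed file (not stated in the commit itself) and the expected values taken from the assertions in test_schema.py below:

from typing import Dict, List, Optional

from pyspark.sql.types import (
    ArrayType,
    DoubleType,
    LongType,
    MapType,
    StringType,
)

# Import path inferred from the changed file; adjust if the package re-exports it.
from fennel.internal_lib.schema.schema import to_spark_type

# Scalars map to their Spark equivalents (see the test_schema.py assertions below).
assert to_spark_type(int) == LongType()
assert to_spark_type(float) == DoubleType()

# Optional[T] unwraps to T; Spark types are nullable by default, so nullability
# is tracked on containing StructFields rather than on the type itself.
assert to_spark_type(Optional[int]) == LongType()

# Containers recurse element-wise.
assert to_spark_type(List[int]) == ArrayType(LongType(), containsNull=True)
assert to_spark_type(Dict[str, float]) == MapType(
    StringType(), DoubleType(), valueContainsNull=True
)
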
27 changes: 1 addition & 26 deletions fennel/client_tests/test_featureset.py
@@ -1232,31 +1232,6 @@ class IndexFeatures:
assert response.status_code == requests.codes.OK, response.json()


<<<<<<< HEAD
=======
@featureset
class QueryOffline:
MER_TRN_COUNT_1H: int
MER_TRN_COUNT_1H_SQ: int = F(col("MER_TRN_COUNT_1H") * col("MER_TRN_COUNT_1H"))
MER_TRN_COUNT_1H_DOUBLE: float = F(col("MER_TRN_COUNT_1H") * 2.0)
MER_TRN_COUNT_1H_HALF: float
MER_TRN_COUNT_1H_QUARTER: float
MER_TRN_COUNT_1H_RANDOM: float

@extractor
@inputs("MER_TRN_COUNT_1H")
@outputs("MER_TRN_COUNT_1H_HALF")
def calc(cls, ts: pd.Series, MER_TRN_COUNT_1H: pd.Series):
df = pd.DataFrame({"MER_TRN_COUNT_1H_HALF": MER_TRN_COUNT_1H * 0.5})
return df

@extractor
@inputs("MER_TRN_COUNT_1H")
@outputs("MER_TRN_COUNT_1H_QUARTER", "MER_TRN_COUNT_1H_RANDOM")
def calc2(cls, ts: pd.Series, MER_TRN_COUNT_1H: pd.Series):
df = pd.DataFrame({"MER_TRN_COUNT_1H_QUARTER": MER_TRN_COUNT_1H * 0.25, "MER_TRN_COUNT_1H_RANDOM": np.random.rand(len(MER_TRN_COUNT_1H))})
return df

@pytest.mark.integration
@mock
def test_query_time_features(client):
@@ -1288,4 +1263,4 @@ def time_feature_extractor(cls, ts: pd.Series) -> pd.DataFrame:
client.commit(
featuresets=[QueryTimeFeatures],
message="first_commit",
)
)
29 changes: 0 additions & 29 deletions fennel/expr/expr.py
@@ -410,35 +410,6 @@ def pa_to_pd(pa_data, ret_type, parse=True):
ret_type = output_dtype

serialized_ret_type = get_datatype(ret_type).SerializeToString()
print("I AM HERE")
print(proto_bytes)
print(df_pa)
print(proto_schema)
print(serialized_ret_type)
print("="*100)

# Convert serialized_ret_type back to schema_proto.DataType
# Convert serialized_ret_type back to schema_proto.DataType
from fennel.gen import schema_pb2 as schema_proto
ret_type_proto = schema_proto.DataType()
ret_type_proto.ParseFromString(serialized_ret_type)
print("CONVERTTed back :", ret_type_proto)
python_type = from_proto(ret_type_proto)
print("PYTHON TYPE :", python_type)
import base64
# Base64 encode the proto_bytes
proto_bytes_base64 = base64.b64encode(proto_bytes).decode("utf-8")
base64_schema = {}
for key, value in proto_schema.items():
base64_schema[key] = base64.b64encode(value).decode("utf-8")
# Base64 encode proto_schema
# proto_schema_base64 = base64.b64encode(proto_schema).decode("utf-8")
# Base64 encode serialized_ret_type
serialized_ret_type_base64 = base64.b64encode(serialized_ret_type).decode("utf-8")

print("Expr\n", proto_bytes_base64)
print("Schema\n", base64_schema)
print("Ret Type\n", serialized_ret_type_base64)
arrow_col = assign(
proto_bytes, df_pa, proto_schema, serialized_ret_type
)
39 changes: 25 additions & 14 deletions fennel/internal_lib/schema/schema.py
@@ -38,9 +38,6 @@
parse_datetime_in_value,
)

from typing import get_args, get_origin, Any, Union, List, Dict, Optional
import dataclasses
import datetime
import decimal
from pyspark.sql.types import (
DataType,
@@ -1018,14 +1015,16 @@ def to_spark_type(py_type: Any, nullable: bool = True) -> DataType:
args = get_args(py_type)
if len(args) == 2 and type(None) in args:
# It's Optional[T]
non_none_type = args[0] if args[1] is type(None) else args[1]
non_none_type = (
args[0] if args[1] is type(None) else args[1] # noqa: E721
) # noqa: E721
return to_spark_type(non_none_type, nullable=True)
else:
# Unions of multiple types are not directly supported; default to StringType
return StringType()
elif isinstance(py_type, _Embedding):
return ArrayType(DoubleType(), containsNull=False)
return ArrayType(DoubleType(), containsNull=False)

# Handle List[T]
elif origin in (list, List):
element_type = get_args(py_type)[0]
@@ -1035,7 +1034,9 @@ def to_spark_type(py_type: Any, nullable: bool = True) -> DataType:
# Handle Dict[K, V]
elif origin in (dict, Dict):
key_type, value_type = get_args(py_type)
spark_key_type = to_spark_type(key_type, nullable=False) # Keys cannot be null
spark_key_type = to_spark_type(
key_type, nullable=False
) # Keys cannot be null
spark_value_type = to_spark_type(value_type)
return MapType(spark_key_type, spark_value_type, valueContainsNull=True)

@@ -1053,10 +1054,20 @@ def to_spark_type(py_type: Any, nullable: bool = True) -> DataType:
field_args = get_args(field_type)
if len(field_args) == 2 and type(None) in field_args:
field_nullable = True
field_type = field_args[0] if field_args[1] is type(None) else field_args[1]
field_type = (
field_args[0]
if field_args[1] is type(None) # noqa: E721
else field_args[1]
)

spark_field_type = to_spark_type(field_type, nullable=field_nullable)
fields.append(StructField(field_name, spark_field_type, nullable=field_nullable))
spark_field_type = to_spark_type(
field_type, nullable=field_nullable
)
fields.append(
StructField(
field_name, spark_field_type, nullable=field_nullable
)
)
return StructType(fields)

# Handle basic types
@@ -1068,18 +1079,18 @@ def to_spark_type(py_type: Any, nullable: bool = True) -> DataType:
return StringType()
elif py_type is bool:
return BooleanType()
elif py_type is datetime.datetime:
elif py_type is datetime:
return TimestampType()
elif py_type is datetime.date:
elif py_type is date:
return DateType()
elif py_type is bytes:
return BinaryType()
elif py_type is decimal.Decimal:
# Default precision and scale; adjust as needed
return DecimalType(precision=38, scale=18)

elif py_type is type(None):
elif py_type is type(None): # noqa: E721
return NullType()

else:
raise ValueError(f"Unsupported type: {py_type}")
raise ValueError(f"Unsupported type: {py_type}")
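
Pieced together from the hunks above, the conversion is a recursive dispatch on typing origins. A condensed sketch of that logic, under stated assumptions: it omits the Fennel-specific `_Embedding`, bytes, decimal, and dataclass branches shown in the diff, and `to_spark_type_sketch` is a name invented here, not the function in this commit.

from datetime import date, datetime
from typing import Any, Dict, List, Union, get_args, get_origin

from pyspark.sql.types import (
    ArrayType,
    BooleanType,
    DataType,
    DateType,
    DoubleType,
    LongType,
    MapType,
    StringType,
    TimestampType,
)


def to_spark_type_sketch(py_type: Any, nullable: bool = True) -> DataType:
    # nullable only matters in the (omitted) dataclass branch, where it is
    # recorded on each StructField; Spark scalar types carry no nullability.
    origin = get_origin(py_type)
    if origin is Union:
        args = get_args(py_type)
        if len(args) == 2 and type(None) in args:
            # Optional[T]: unwrap T and mark the result nullable.
            non_none = args[0] if args[1] is type(None) else args[1]
            return to_spark_type_sketch(non_none, nullable=True)
        # Wider unions are not directly representable; fall back to strings.
        return StringType()
    if origin in (list, List):
        element_type = get_args(py_type)[0]
        return ArrayType(to_spark_type_sketch(element_type), containsNull=True)
    if origin in (dict, Dict):
        key_type, value_type = get_args(py_type)
        return MapType(
            to_spark_type_sketch(key_type, nullable=False),  # keys cannot be null
            to_spark_type_sketch(value_type),
            valueContainsNull=True,
        )
    scalars = {
        int: LongType(),
        float: DoubleType(),
        str: StringType(),
        bool: BooleanType(),
        datetime: TimestampType(),
        date: DateType(),
    }
    if py_type in scalars:
        return scalars[py_type]
    raise ValueError(f"Unsupported type: {py_type}")
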
71 changes: 44 additions & 27 deletions fennel/internal_lib/schema/test_schema.py
@@ -48,7 +48,6 @@
)



def test_get_data_type():
assert get_datatype(int) == proto.DataType(int_type=proto.IntType())
assert get_datatype(Optional[int]) == proto.DataType(
@@ -987,38 +986,48 @@ def test_to_spark_type():
assert to_spark_type(str) == StringType()
assert to_spark_type(datetime) == TimestampType()
assert to_spark_type(date) == DateType()
# Types in pyspark are nullable by default
assert to_spark_type(Optional[int]) == LongType()

# Test complex types
assert to_spark_type(List[int]) == ArrayType(LongType(), containsNull=True)
assert to_spark_type(Dict[str, float]) == MapType(StringType(), DoubleType(), valueContainsNull=True)
assert to_spark_type(Optional[List[str]]) == ArrayType(StringType(), containsNull=True)

assert to_spark_type(Dict[str, float]) == MapType(
StringType(), DoubleType(), valueContainsNull=True
)
assert to_spark_type(Optional[List[str]]) == ArrayType(
StringType(), containsNull=True
)

# Test nested complex types
assert to_spark_type(List[Dict[str, List[float]]]) == ArrayType(
MapType(StringType(), ArrayType(DoubleType(), containsNull=True), valueContainsNull=True),
containsNull=True
MapType(
StringType(),
ArrayType(DoubleType(), containsNull=True),
valueContainsNull=True,
),
containsNull=True,
)

assert to_spark_type(Union[int, str, float]) == StringType()



# Test Embedding type (should default to ArrayType of DoubleType)
assert to_spark_type(Embedding[10]) == ArrayType(DoubleType(), containsNull=False)

assert to_spark_type(Embedding[10]) == ArrayType(
DoubleType(), containsNull=False
)

# Test complex nested structure
complex_type = Dict[str, List[Optional[Dict[int, Union[str, float]]]]]
expected_complex_type = MapType(
StringType(),
ArrayType(
MapType(
LongType(),
StringType(), # Union defaults to StringType
valueContainsNull=True
StringType(),
valueContainsNull=True,
),
containsNull=True
containsNull=True,
),
valueContainsNull=True
valueContainsNull=True,
)
assert to_spark_type(complex_type) == expected_complex_type

@@ -1034,18 +1043,26 @@ class Person:
age: int
address: Address
emails: List[str]

# Convert Person dataclass to StructType
spark_type = to_spark_type(Person)

expected_spark_type = StructType([
StructField("name", StringType(), False),
StructField("age", LongType(), False),
StructField("address", StructType([
StructField("street", StringType(), False),
StructField("city", StringType(), False),
StructField("zip_code", LongType(), True)
]), False),
StructField("emails", ArrayType(StringType()), False)
])

expected_spark_type = StructType(
[
StructField("name", StringType(), False),
StructField("age", LongType(), False),
StructField(
"address",
StructType(
[
StructField("street", StringType(), False),
StructField("city", StringType(), False),
StructField("zip_code", LongType(), True),
]
),
False,
),
StructField("emails", ArrayType(StringType()), False),
]
)
assert spark_type == expected_spark_type
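
A hypothetical end-to-end use of the derived schema (not part of this commit; the `SparkSession` usage and sample row are assumptions for illustration):

from dataclasses import dataclass
from typing import List, Optional

from pyspark.sql import SparkSession

# Import path inferred from the changed file; adjust if the package re-exports it.
from fennel.internal_lib.schema.schema import to_spark_type


@dataclass
class Address:
    street: str
    city: str
    zip_code: Optional[int]


@dataclass
class Person:
    name: str
    age: int
    address: Address
    emails: List[str]


spark = SparkSession.builder.master("local[1]").getOrCreate()

# The StructType derived above can seed a DataFrame directly, with the
# nested struct passed as a tuple and the nullable zip_code left as None.
df = spark.createDataFrame(
    [("Alice", 30, ("1 Main St", "Springfield", None), ["a@example.com"])],
    schema=to_spark_type(Person),
)
df.printSchema()
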
4 changes: 0 additions & 4 deletions fennel/testing/branch.py
@@ -37,10 +37,6 @@ def get_extractor_func(extractor_proto: ProtoExtractor) -> Callable:
code = (
extractor_proto.pycode.imports + extractor_proto.pycode.generated_code
)
print("-"*100)
print(code)
print("-"*100)
print("Entry point: ", extractor_proto.pycode.entry_point)
try:
sys.modules[fqn] = mod
exec(code, mod.__dict__)