Skip to content

Commit

Permalink
Update SparkSource to have proper comparable
Browse files Browse the repository at this point in the history
  • Loading branch information
thechopkins committed Oct 25, 2023
1 parent 0151961 commit 9852b6d
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,18 @@ def get_table_query_string(self) -> str:

return f"`{tmp_table_name}`"

# Note: Python requires redefining hash in child classes that override __eq__
def __hash__(self):
return super().__hash__()

def __eq__(self, other):
if not isinstance(other, SparkSource):
raise TypeError("Comparisons should only involve SparkSource class objects.")
return (
super().__eq__(other)
and self.spark_options == other.spark_options
)


class SparkOptions:
allowed_formats = [format.value for format in SparkSourceFormat]
Expand Down Expand Up @@ -282,6 +294,17 @@ def to_proto(self) -> DataSourceProto.SparkOptions:

return spark_options_proto

def __eq__(self, other: object) -> bool:
if not isinstance(other, SparkOptions):
raise TypeError("Comparisons should only involve SparkOptions class objects.")

return (
self.table == other.table
and self.query == other.query
and self.path == other.path
and self.file_format == other.file_format
)


class SavedDatasetSparkStorage(SavedDatasetStorage):
_proto_attr_name = "spark_storage"
Expand Down
53 changes: 53 additions & 0 deletions sdk/python/tests/unit/test_data_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from feast.infra.offline_stores.file_source import FileSource
from feast.infra.offline_stores.redshift_source import RedshiftSource
from feast.infra.offline_stores.snowflake_source import SnowflakeSource
from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import SparkSource
from feast.types import Bool, Float32, Int64


Expand Down Expand Up @@ -233,3 +234,55 @@ def test_redshift_fully_qualified_table_name(source_kwargs, expected_name):
)

assert redshift_source.redshift_options.fully_qualified_table_name == expected_name

@pytest.mark.parameterize(
"test_data,are_equal",
[
(
SparkSource(
name='name',
table='table',
query='query',
file_format='file_format'
),
True
),
(
SparkSource(
table='table',
query='query',
file_format='file_format'
),
False
),
(
SparkSource(
name='name',
table='table',
query='query',
file_format='file_format1'
),
False
),
(
SparkSource(
name='name',
table='table',
query='query1',
file_format='file_format'
),
True
),
]
)
def test_spark_source_equality(test_data, are_equal):
default = SparkSource(
name='name',
table='table1',
query='query',
file_format='file_format'
)
if are_equal:
assert default == test_data
else:
assert default != test_data

0 comments on commit 9852b6d

Please sign in to comment.