Change max age to integer, filter source feature tables, tests for large dataframe

Signed-off-by: Khor Shu Heng <khor.heng@gojek.com>
khorshuheng committed Oct 6, 2020
1 parent 7522b88 commit 1a36737
Showing 2 changed files with 137 additions and 14 deletions.
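
The headline change: `max_age` is now an integer number of seconds rather than a Spark interval string. A minimal before/after sketch of a call site, mirroring the doctest further down (the surrounding entity and feature dataframes are assumed to be set up as in that doctest):

# Before this commit, max_age was an interval string:
df = as_of_join(entity, ["entity"], feature_table, ["feature"], max_age="12 hour")

# After this commit, max_age is an integer number of seconds:
df = as_of_join(entity, ["entity"], feature_table, ["feature"], max_age=12 * 60 * 60)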
28 changes: 22 additions & 6 deletions sdk/python/feast/pyspark/historical_feature_retrieval_job.py
@@ -1,4 +1,5 @@
import json
from datetime import timedelta
from typing import Any, Dict, List

from pyspark import SparkFiles
@@ -12,7 +13,7 @@ def as_of_join(
feature_table: DataFrame,
features: List[str],
feature_prefix: str = "",
max_age: str = None,
max_age: int = None,
) -> DataFrame:
"""Perform an as of join between entity and feature table, given a maximum age tolerance.
Join conditions:
@@ -36,8 +37,8 @@ def as_of_join(
feature_prefix (str):
Feature column prefix for the result dataframe. Useful for cases where the entity dataframe
contains one or more columns that share the same name as the features.
max_age (str):
Tolerance for the feature event timestamp recency.
max_age (int):
Tolerance for the feature event timestamp recency, in seconds.
Returns:
DataFrame: Join result.
@@ -66,7 +67,7 @@ def as_of_join(
| 1001|2020-09-02 00:00:00| 200|
+------+-------------------+--------------+
>>> df = as_of_join(entity, ["entity"], feature_table, ["feature"], max_age = "12 hour")
>>> df = as_of_join(entity, ["entity"], feature_table, ["feature"], max_age = 12 * 60 * 60)
>>> df.show()
+------+-------------------+-------+
|entity| event_timestamp|feature|
@@ -96,7 +97,7 @@ def as_of_join(
if max_age:
join_cond = join_cond & (
selected_feature_table[feature_event_timestamp]
>= entity_with_id.event_timestamp - expr(f"INTERVAL {max_age}")
>= entity_with_id.event_timestamp - expr(f"INTERVAL {max_age} seconds")
)

for key in entity_keys:
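
For reference, a standalone check of what the seconds-based interval expression does (a minimal sketch assuming a local SparkSession; not code from the repository):

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, expr

spark = SparkSession.builder.master("local[*]").getOrCreate()
max_age = 12 * 60 * 60  # 12 hours, expressed in seconds

df = spark.createDataFrame(
    [("2020-09-02 00:00:00",)], ["event_timestamp"]
).withColumn("event_timestamp", col("event_timestamp").cast("timestamp"))

# Subtracting the interval gives the oldest feature timestamp still within tolerance:
# 2020-09-02 00:00:00 - 12 hours = 2020-09-01 12:00:00.
df.select(
    (col("event_timestamp") - expr(f"INTERVAL {max_age} seconds")).alias("lower_bound")
).show()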
@@ -317,7 +318,7 @@ def retrieve_historical_features(spark: SparkSession, conf: Dict) -> DataFrame:
"table": "transactions",
"features": ["daily_transactions"],
"join": ["customer_id"],
"max_age": "2 day",
"max_age": 172800,
},
{
"table": "bookings",
@@ -337,6 +338,9 @@
`options` is optional. If present, the options will be used when reading / writing the input / output.
`max_age` is in seconds, and determines the lower bound of the timestamp of the retrieved feature.
If not specified, this would be unbounded.
If necessary, `col_mapping` can be provided to map the columns of the dataframes before performing
the join operation. `col_mapping` is a dictionary where the key is the source column and the value
is the mapped column.
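
As a worked example of a query entry under the integer convention (values mirror the conf above; the `col_mapping` column names are purely hypothetical):

query = {
    "table": "transactions",
    "features": ["daily_transactions"],
    "join": ["customer_id"],
    "max_age": 2 * 24 * 60 * 60,  # 172800 seconds, i.e. a two-day tolerance
}

# A col_mapping, where used, is a plain {source_column: mapped_column} dict,
# e.g. renaming a hypothetical "datetime" column to the expected "event_timestamp":
col_mapping = {"datetime": "event_timestamp"}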
@@ -384,6 +388,18 @@ def map_column(df: DataFrame, col_mapping: Dict[str, str]):
is_feature_table=True,
)

max_timestamp = mapped_entity_df.agg({"event_timestamp": "max"}).collect()[0][0]
min_timestamp = mapped_entity_df.agg({"event_timestamp": "min"}).collect()[0][0]

for query in conf["queries"]:
max_age = query.get("max_age")
if max_age:
tables[query["table"]] = tables[query["table"]].filter(
col("event_timestamp").between(
min_timestamp - timedelta(seconds=max_age), max_timestamp
)
)

return join_entity_to_feature_tables(conf["queries"], mapped_entity_df, tables)
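
The new pre-filter discards feature rows that cannot match any entity row. A small worked example of the window it computes (dates chosen for illustration only):

from datetime import datetime, timedelta

min_timestamp = datetime(2020, 9, 1)   # earliest entity event_timestamp
max_timestamp = datetime(2020, 9, 30)  # latest entity event_timestamp
max_age = 172800                       # two days, in seconds

lower_bound = min_timestamp - timedelta(seconds=max_age)
# Feature rows are kept only when lower_bound <= event_timestamp <= max_timestamp,
# i.e. between 2020-08-30 00:00:00 and 2020-09-30 00:00:00 in this example.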


123 changes: 115 additions & 8 deletions sdk/python/tests/test_as_of_join.py
@@ -1,6 +1,10 @@
import os
import pathlib
from datetime import datetime
import shutil
import tempfile
from datetime import datetime, timedelta
from os import path
from typing import Any, Dict, List

import pytest
from pyspark.sql import DataFrame, SparkSession
@@ -35,6 +39,62 @@ def spark(pytestconfig):
spark_session.stop()


@pytest.yield_fixture(scope="module")
def large_entity_csv_file(pytestconfig, spark):
start_datetime = datetime(year=2020, month=8, day=31)
nr_rows = 1000
entity_data = [
(1000 + i, start_datetime + timedelta(days=i)) for i in range(nr_rows)
]
temp_dir = tempfile.mkdtemp()
file_path = os.path.join(temp_dir, "large_entity")
entity_schema = StructType(
[
StructField("customer_id", IntegerType()),
StructField("event_timestamp", TimestampType()),
]
)
large_entity_df = spark.createDataFrame(
spark.sparkContext.parallelize(entity_data), entity_schema
)

large_entity_df.write.option("header", "true").csv(file_path)
yield file_path
shutil.rmtree(temp_dir)


@pytest.yield_fixture(scope="module")
def large_feature_csv_file(pytestconfig, spark):
start_datetime = datetime(year=2020, month=8, day=30)
nr_rows = 1000
feature_data = [
(
1000 + i,
start_datetime + timedelta(days=i),
start_datetime + timedelta(days=i + 1),
i * 10,
)
for i in range(nr_rows)
]
temp_dir = tempfile.mkdtemp()
file_path = os.path.join(temp_dir, "large_feature")
feature_schema = StructType(
[
StructField("customer_id", IntegerType()),
StructField("event_timestamp", TimestampType()),
StructField("created_timestamp", TimestampType()),
StructField("total_bookings", IntegerType()),
]
)
large_feature_df = spark.createDataFrame(
spark.sparkContext.parallelize(feature_data), feature_schema
)

large_feature_df.write.option("header", "true").csv(file_path)
yield file_path
shutil.rmtree(temp_dir)


@pytest.fixture
def single_entity_schema():
return StructType(
@@ -235,7 +295,7 @@ def test_join_with_max_age(
feature_table_df,
["daily_transactions"],
feature_prefix="transactions__",
max_age="1 day",
max_age=86400,
)

expected_joined_schema = StructType(
@@ -308,7 +368,7 @@ def test_join_with_composite_entity(
feature_table_df,
["customer_rating", "driver_rating"],
feature_prefix="ratings__",
max_age="1 day",
max_age=86400,
)

expected_joined_schema = StructType(
@@ -397,12 +457,12 @@ def test_multiple_join(
customer_feature_schema: StructType,
driver_feature_schema: StructType,
):
query_conf = [
query_conf: List[Dict[str, Any]] = [
{
"table": "transactions",
"features": ["daily_transactions"],
"join": ["customer_id"],
"max_age": "1 day",
"max_age": 86400,
},
{
"table": "bookings",
@@ -524,7 +584,7 @@ def test_historical_feature_retrieval(spark):
"table": "transactions",
"features": ["daily_transactions"],
"join": ["customer_id"],
"max_age": "1 day",
"max_age": 86400,
},
{
"table": "bookings",
@@ -562,7 +622,7 @@ def test_historical_feature_retrieval(spark):

def test_historical_feature_retrieval_with_mapping(spark):
test_data_dir = path.join(pathlib.Path(__file__).parent.absolute(), "data")
batch_retrieval_conf = {
retrieval_conf = {
"entity": {
"format": "csv",
"path": f"file://{path.join(test_data_dir, 'column_mapping_test_entity.csv')}",
@@ -592,7 +652,7 @@ def test_historical_feature_retrieval_with_mapping(spark):
],
}

joined_df = retrieve_historical_features(spark, batch_retrieval_conf)
joined_df = retrieve_historical_features(spark, retrieval_conf)

expected_joined_schema = StructType(
[
@@ -616,6 +676,53 @@
assert_dataframe_equal(joined_df, expected_joined_df)


def test_large_historical_feature_retrieval(
spark, large_entity_csv_file, large_feature_csv_file
):
nr_rows = 1000
start_datetime = datetime(year=2020, month=8, day=31)
expected_join_data = [
(1000 + i, start_datetime + timedelta(days=i), i * 10) for i in range(nr_rows)
]
expected_join_data_schema = StructType(
[
StructField("customer_id", IntegerType()),
StructField("event_timestamp", TimestampType()),
StructField("feature__total_bookings", IntegerType()),
]
)

expected_join_data_df = spark.createDataFrame(
spark.sparkContext.parallelize(expected_join_data), expected_join_data_schema
)

retrieval_conf = {
"entity": {
"format": "csv",
"path": f"file://{large_entity_csv_file}",
"options": {"inferSchema": "true", "header": "true"},
},
"tables": [
{
"format": "csv",
"path": f"file://{large_feature_csv_file}",
"name": "feature",
"options": {"inferSchema": "true", "header": "true"},
},
],
"queries": [
{
"table": "feature",
"features": ["total_bookings"],
"join": ["customer_id"],
}
],
}

joined_df = retrieve_historical_features(spark, retrieval_conf)
assert_dataframe_equal(joined_df, expected_join_data_df)


def test_schema_verification(spark):
entity_schema = StructType(
[
