Add fields parameter in join to fetch a subset of fields from rhs (#535)
* Add fields parameter in join to fetch a subset of fields from rhs

* Add right fields filtering in join to mock executor

* Remove a check disallowing right_fields from containing right's timestamp

* Add more checks and a test

* Update proto generated files and fix test.

* Add an integration test in social network scenario

* Address comments

* Rebased

* Fix the case where rhs fields include the timestamp

---------

Co-authored-by: Aditya Nambiar <aditya.nambiar007@gmail.com>
satrana42 and aditya-nambiar authored Sep 7, 2024
1 parent 1d22322 commit 8427e32
Showing 12 changed files with 650 additions and 143 deletions.
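
For orientation, a minimal sketch of the new `fields` kwarg, adapted from the tests added in this commit; the dataset and column names are illustrative, and only the `join(..., fields=[...])` call shape is taken from the diff below.

# Inside a Fennel @pipeline; `rating` and `revenue` are input Datasets.
# `fields` restricts which right-hand-side (rhs) value columns the join
# fetches, so rhs columns that are not listed (e.g. `extra_field`) never
# reach the joined dataset.
joined = rating.join(
    revenue,
    how="left",
    on=["movie"],             # join key(s)
    fields=["revenue", "t"],  # subset of rhs fields to pull in
)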
3 changes: 2 additions & 1 deletion docs/pages/api-reference/operators/join.md
@@ -48,7 +48,8 @@ sides must have the same data types.
<Expandable title="within" type="Tuple[Duration, Duration]" defaultVal='("forever", "0s")'>
Optional kwarg specifying the time window relative to the left side timestamp
within which the join should be performed. This can be seen as adding another
- condition to join like `WHERE left_time - d1 < right_time AND right_time < left_time + d1`
+ condition to join like `WHERE left_time - d1 < right_time AND right_time < left_time + d2`
+ where (d1, d2) = within.
- The first value in the tuple represents how far back in time should a join
happen. The term "forever" means that we can go infinitely back in time
when searching for an event to join from the left-hand side data.
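
As a hedged illustration of the corrected `within` condition above (the dataset and column names here are hypothetical; the kwarg and its `(d1, d2)` tuple semantics come from the documented text):

# Inside a Fennel @pipeline; only rhs rows satisfying
# left_time - d1 < right_time < left_time + d2 are eligible for the join,
# here d1 = "7d" (look back up to 7 days) and d2 = "0s" (no look-ahead).
joined = transactions.join(
    prices,
    how="left",
    on=["ticker"],
    within=("7d", "0s"),
)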
3 changes: 3 additions & 0 deletions fennel/CHANGELOG.md
@@ -1,5 +1,8 @@
# Changelog

## [1.5.19] - 2024-09-06
- Add ability to specify fields in join

## [1.5.18] - 2024-09-05
- Struct initializer + arrow fixes + type promotion in assign

145 changes: 145 additions & 0 deletions fennel/client_tests/test_dataset.py
@@ -1397,6 +1397,151 @@ def pipeline_join(cls, info: Dataset, sale: Dataset):
]


@meta(owner="test@test.com")
@source(
webhook.endpoint("MovieRevenueWithRightFields"),
disorder="14d",
cdc="upsert",
)
@dataset(index=True)
class MovieRevenueWithRightFields:
movie: oneof(str, ["Jumanji", "Titanic", "RaOne", "ABC"]) = field( # type: ignore
key=True
)
revenue: int
extra_field: int
t: datetime


@meta(owner="satwant@fennel.ai")
@dataset(index=True)
class MovieStatsWithRightFields:
movie: oneof(str, ["Jumanji", "Titanic", "RaOne", "ABC"]) = field( # type: ignore
key=True
)
rating: float
revenue_in_millions: float
t: Optional[datetime]
ts: datetime = field(timestamp=True)

@pipeline
@inputs(MovieRating, MovieRevenueWithRightFields)
def pipeline_join(cls, rating: Dataset, revenue: Dataset):
def to_millions(df: pd.DataFrame) -> pd.DataFrame:
df[str(cls.revenue_in_millions)] = df["revenue"] / 1000000
df[str(cls.revenue_in_millions)].fillna(-1, inplace=True)
return df[
[
str(cls.movie),
str(cls.t),
str(cls.ts),
str(cls.revenue_in_millions),
str(cls.rating),
]
]

rating = rating.rename({"t": "ts"}) # type: ignore
c = rating.join(
revenue, how="left", on=[str(cls.movie)], fields=["revenue", "t"]
)
# The join brings in additional columns, which the transform filters out.
return c.transform(
to_millions,
schema={
str(cls.movie): oneof(
str, ["Jumanji", "Titanic", "RaOne", "ABC"]
),
str(cls.rating): float,
str(cls.t): Optional[datetime],
str(cls.ts): datetime,
str(cls.revenue_in_millions): float,
},
)


class TestBasicJoinWithRightFields(unittest.TestCase):
@pytest.mark.integration
@mock
def test_basic_join_with_fields(self, client):
# Sync the dataset
client.commit(
message="msg",
datasets=[
MovieRating,
MovieRevenueWithRightFields,
MovieStatsWithRightFields,
RatingActivity,
],
)
now = datetime.now(timezone.utc)
one_hour_ago = now - timedelta(hours=1)
data = [
["Jumanji", 4, 343, 789, one_hour_ago],
["Titanic", 5, 729, 1232, now],
]
columns = ["movie", "rating", "num_ratings", "sum_ratings", "t"]
df = pd.DataFrame(data, columns=columns)
response = client.log("fennel_webhook", "MovieRating", df)
assert response.status_code == requests.codes.OK, response.json()

two_hours_ago = now - timedelta(hours=2)
data = [
["Jumanji", 2000000, 1, two_hours_ago],
["Titanic", 50000000, 2, now],
]
columns = ["movie", "revenue", "extra_field", "t"]
df = pd.DataFrame(data, columns=columns)
response = client.log(
"fennel_webhook", "MovieRevenueWithRightFields", df
)
assert response.status_code == requests.codes.OK, response.json()
client.sleep()

# Do some lookups to verify pipeline_join is working as expected
keys = pd.DataFrame({"movie": ["Jumanji", "Titanic"]})
df, _ = client.lookup(
"MovieStatsWithRightFields",
keys=keys,
)
assert df.shape == (2, 5)
assert df["movie"].tolist() == ["Jumanji", "Titanic"]
assert df["rating"].tolist() == [4, 5]
assert df["revenue_in_millions"].tolist() == [2, 50]
assert df["t"].tolist() == [two_hours_ago, now]
assert "extra_field" not in df.columns

# Do some lookup at various timestamps in the past
ts = pd.Series([two_hours_ago, one_hour_ago, one_hour_ago, now])
keys = pd.DataFrame(
{"movie": ["Jumanji", "Jumanji", "Titanic", "Titanic"]}
)
df, _ = client.lookup(
"MovieStatsWithRightFields",
timestamps=ts,
keys=keys,
)
assert df.shape == (4, 5)
assert df["movie"].tolist() == [
"Jumanji",
"Jumanji",
"Titanic",
"Titanic",
]
assert pd.isna(df["rating"].tolist()[0])
assert df["rating"].tolist()[1] == 4
assert pd.isna(df["rating"].tolist()[2])
assert df["rating"].tolist()[3] == 5
assert pd.isna(df["revenue_in_millions"].tolist()[0])
assert df["revenue_in_millions"].tolist()[1] == 2
assert pd.isna(df["revenue_in_millions"].tolist()[2])
assert df["revenue_in_millions"].tolist()[3] == 50
assert pd.isna(df["t"].tolist()[0])
assert df["t"].tolist()[1] == two_hours_ago
assert pd.isna(df["t"].tolist()[2])
assert df["t"].tolist()[3] == now
assert "extra_field" not in df.columns


class TestBasicAggregate(unittest.TestCase):
@pytest.mark.integration
@mock
155 changes: 155 additions & 0 deletions fennel/client_tests/test_social_network.py
@@ -41,6 +41,19 @@ class PostInfo:
timestamp: datetime


@source(
webhook.endpoint("PostInfoWithRightFields"), disorder="14d", cdc="upsert"
)
@dataset(index=True)
@meta(owner="data-eng@myspace.com")
class PostInfoWithRightFields:
title: str
category: str # type: ignore
post_id: int = field(key=True)
timestamp: datetime
extra_field: str


@meta(owner="data-eng@myspace.com")
@dataset
@source(webhook.endpoint("ViewData"), disorder="14d", cdc="append")
@@ -100,6 +113,25 @@ def count_user_views(cls, view_data: Dataset, post_info: Dataset):
)


@meta(owner="ml-eng@myspace.com")
@dataset(index=True)
class UserCategoryDatasetWithRightFields:
user_id: str = field(key=True)
category: str = field(key=True)
num_views: int
time_stamp: datetime

@pipeline
@inputs(ViewData, PostInfoWithRightFields)
def count_user_views(cls, view_data: Dataset, post_info: Dataset):
post_info_enriched = view_data.join(
post_info, how="inner", on=["post_id"], fields=["title", "category"]
)
return post_info_enriched.groupby("user_id", "category").aggregate(
[Count(window=Continuous("6y 8s"), into_field="num_views")]
)


@meta(owner="ml-eng@myspace.com")
@dataset(index=True)
class LastViewedPost:
@@ -166,6 +198,28 @@ def extract_user_views(cls, ts: pd.Series, user_ids: pd.Series):
return views["num_views"]


@meta(owner="feature-team@myspace.com")
@featureset
class UserFeaturesWithRightFields:
user_id: str = F(Request.user_id) # type: ignore
num_views: int
category: str = F(Request.category) # type: ignore
num_category_views: int = F(UserCategoryDatasetWithRightFields.num_views, default=0) # type: ignore
category_view_ratio: float = F(col("num_category_views") / col("num_views"))
last_viewed_post: int = F(LastViewedPost.post_id, default=-1) # type: ignore
last_viewed_post2: List[int] = F(
LastViewedPostByAgg.post_id, default=[-1] # type: ignore
)

@extractor(deps=[UserViewsDataset]) # type: ignore
@inputs(Request.user_id)
@outputs("num_views")
def extract_user_views(cls, ts: pd.Series, user_ids: pd.Series):
views, _ = UserViewsDataset.lookup(ts, user_id=user_ids) # type: ignore
views = views.fillna(0)
return views["num_views"]


@pytest.mark.integration
@mock
def test_social_network(client):
@@ -259,6 +313,107 @@ def test_social_network(client):
assert df.shape == (1998, 4)


@pytest.mark.integration
@mock
def test_social_network_with_fields(client):
client.commit(
message="social network",
datasets=[
UserInfo,
PostInfoWithRightFields,
ViewData,
CityInfo,
UserViewsDataset,
UserCategoryDatasetWithRightFields,
LastViewedPost,
LastViewedPostByAgg,
],
featuresets=[Request, UserFeaturesWithRightFields],
)
user_data_df = pd.read_csv("fennel/client_tests/data/user_data.csv")
post_data_df = pd.read_csv("fennel/client_tests/data/post_data.csv")
post_data_len = len(post_data_df.index)
post_data_df["extra_field"] = list(range(0, post_data_len))
view_data_df = pd.read_csv("fennel/client_tests/data/view_data_sampled.csv")
ts = "2018-01-01 00:00:00"
user_data_df["timestamp"] = ts
post_data_df["timestamp"] = ts
view_data_df["time_stamp"] = view_data_df["time_stamp"].apply(
lambda x: datetime.strptime(x, "%m/%d/%Y %H:%M %p")
)
# Advance all timestamps by 4 years
user_data_df["timestamp"] = pd.to_datetime(
user_data_df["timestamp"]
) + pd.DateOffset(years=4)
post_data_df["timestamp"] = pd.to_datetime(
post_data_df["timestamp"]
) + pd.DateOffset(years=4)
view_data_df["time_stamp"] = view_data_df["time_stamp"] + pd.DateOffset(
years=4
)

res = client.log("fennel_webhook", "UserInfo", user_data_df)
assert res.status_code == requests.codes.OK, res.json()
res = client.log("fennel_webhook", "PostInfoWithRightFields", post_data_df)
assert res.status_code == requests.codes.OK, res.json()
res = client.log("fennel_webhook", "ViewData", view_data_df)
assert res.status_code == requests.codes.OK, res.json()

if client.is_integration_client():
client.sleep(120)

keys = pd.DataFrame(
{
"city": ["Wufeng", "Coyaima", "San Angelo"],
"gender": ["Male", "Male", "Female"],
}
)

df, found = client.lookup(
"CityInfo",
keys=keys,
)
assert found.to_list() == [True, True, True]

feature_df = client.query(
outputs=[UserFeaturesWithRightFields],
inputs=[Request.user_id, Request.category],
input_dataframe=pd.DataFrame(
{
"Request.user_id": [
"5eece14efc13ae6609000000",
"5eece14efc13ae660900003c",
],
"Request.category": ["banking", "programming"],
}
),
)
assert (
feature_df["UserFeaturesWithRightFields.num_views"].to_list(),
feature_df["UserFeaturesWithRightFields.num_category_views"].to_list(),
feature_df["UserFeaturesWithRightFields.category_view_ratio"].to_list(),
) == ([2, 4], [0, 1], [0.0, 0.25])

# Assert that both the last_viewed_post and last_viewed_post2 features are extracted correctly
last_post_viewed = feature_df[
"UserFeaturesWithRightFields.last_viewed_post"
].to_list()
last_post_viewed2 = [
x[0]
for x in feature_df[
"UserFeaturesWithRightFields.last_viewed_post2"
].to_list()
]
assert last_post_viewed == [936609766, 735291550]
assert last_post_viewed2 == last_post_viewed

if client.is_integration_client():
return
df = client.get_dataset_df("UserCategoryDatasetWithRightFields")
assert "extra_field" not in df.columns
assert df.shape == (1998, 4)


@mock
def test_social_network_with_mock_log(client):
client.commit(
