Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
satrana42 authored and aditya-nambiar committed Sep 6, 2024
1 parent 06cd8aa commit a0084e8
Show file tree
Hide file tree
Showing 8 changed files with 227 additions and 43 deletions.
4 changes: 2 additions & 2 deletions fennel/client_tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1434,7 +1434,7 @@ def to_millions(df: pd.DataFrame) -> pd.DataFrame:
]
]

c = rating.join(revenue, how="left", on=[str(cls.movie)], right_fields=["revenue"])
c = rating.join(revenue, how="left", on=[str(cls.movie)], fields=["revenue"])
# Transform provides additional columns which will be filtered out.
return c.transform(
to_millions,
Expand All @@ -1452,7 +1452,7 @@ def to_millions(df: pd.DataFrame) -> pd.DataFrame:
class TestBasicJoinWithRightFields(unittest.TestCase):
@pytest.mark.integration
@mock
def test_basic_join_with_right_fields(self, client):
def test_basic_join_with_fields(self, client):
# # Sync the dataset
client.commit(
message="msg",
Expand Down
4 changes: 2 additions & 2 deletions fennel/client_tests/test_social_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class UserCategoryDatasetWithRightFields:
@inputs(ViewData, PostInfoWithRightFields)
def count_user_views(cls, view_data: Dataset, post_info: Dataset):
post_info_enriched = view_data.join(
post_info, how="inner", on=["post_id"], right_fields=["title", "category"]
post_info, how="inner", on=["post_id"], fields=["title", "category"]
)
return post_info_enriched.groupby("user_id", "category").aggregate(
[Count(window=Continuous("6y 8s"), into_field="num_views")]
Expand Down Expand Up @@ -313,7 +313,7 @@ def test_social_network(client):

@pytest.mark.integration
@mock
def test_social_network_with_right_fields(client):
def test_social_network_with_fields(client):
client.commit(
message="social network",
datasets=[
Expand Down
46 changes: 23 additions & 23 deletions fennel/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def join(
left_on: Optional[List[str]] = None,
right_on: Optional[List[str]] = None,
within: Tuple[Duration, Duration] = ("forever", "0s"),
right_fields: Optional[List[str]] = None,
fields: Optional[List[str]] = None,
) -> Join:
if not isinstance(other, Dataset) and isinstance(other, _Node):
raise ValueError(
Expand All @@ -310,7 +310,7 @@ def join(
)
if not isinstance(other, _Node):
raise TypeError("Cannot join with a non-dataset object")
return Join(self, other, within, how, on, left_on, right_on, right_fields)
return Join(self, other, within, how, on, left_on, right_on, fields)

def rename(self, columns: Dict[str, str]) -> _Node:
return Rename(self, columns)
Expand Down Expand Up @@ -936,7 +936,7 @@ def __init__(
on: Optional[List[str]] = None,
left_on: Optional[List[str]] = None,
right_on: Optional[List[str]] = None,
right_fields: Optional[List[str]] = None,
fields: Optional[List[str]] = None,
# Currently not supported
lsuffix: str = "",
rsuffix: str = "",
Expand Down Expand Up @@ -965,7 +965,7 @@ def __init__(
self.right_on = right_on
self.within = within
self.how = how
self.right_fields = right_fields
self.fields = fields
self.lsuffix = lsuffix
self.rsuffix = rsuffix
self.node.out_edges.append(self)
Expand All @@ -979,7 +979,7 @@ def signature(self):
self.left_on,
self.right_on,
self.how,
self.right_fields,
self.fields,
self.lsuffix,
self.rsuffix,
)
Expand All @@ -991,7 +991,7 @@ def signature(self):
self.right_on,
self.within,
self.how,
self.right_fields,
self.fields,
self.lsuffix,
self.rsuffix,
)
Expand Down Expand Up @@ -1041,30 +1041,29 @@ def make_types_optional(types: Dict[str, Type]) -> Dict[str, Type]:
if self.how == "left":
right_value_schema = make_types_optional(right_value_schema)

# TODO(sat): Are the same checks here and in SchemaValidator redundant?
# If right_fields is set, check that it contains elements from right schema values and timestamp only
if self.right_fields is not None and len(self.right_fields) > 0:
# If fields is set, check that it contains elements from right schema values and timestamp only
if self.fields is not None and len(self.fields) > 0:
allowed_col_names = [x for x in right_value_schema.keys()] + [right_ts]
for col_name in self.right_fields:
for col_name in self.fields:
if col_name not in allowed_col_names:
raise ValueError(
f"Column `{col_name}` not found in schema {self.dataset.dsschema()} of right input "
f"fields member `{col_name}` not present in allowed fields {allowed_col_names} of right input "
f"{self.dataset.dsschema().name}"
)

# Add right value columns to left schema. Check for column name collisions. Filter keys present in right_fields.
# Add right value columns to left schema. Check for column name collisions. Filter keys present in fields.
joined_dsschema = copy.deepcopy(left_dsschema)
for col, dtype in right_value_schema.items():
if col in left_schema:
raise ValueError(
f"Column name collision. `{col}` already exists in schema of left input {left_dsschema.name}, while joining with {self.dataset.dsschema().name}"
)
if self.right_fields is not None and len(self.right_fields) > 0 and col not in self.right_fields:
if self.fields is not None and len(self.fields) > 0 and col not in self.fields:
continue
joined_dsschema.append_value_column(col, dtype)

# Add timestamp column if present in right_fields
if self.right_fields is not None and right_ts in self.right_fields:
# Add timestamp column if present in fields
if self.fields is not None and right_ts in self.fields:
joined_dsschema.append_value_column(right_ts, datetime.datetime)

return joined_dsschema
Expand Down Expand Up @@ -2866,18 +2865,19 @@ def validate_right_index(right_dataset: Dataset):
f'"how" in {output_schema_name} must be either "inner" or "left" for `{output_schema_name}`'
)

if obj.right_fields is not None and len(obj.right_fields) > 0:
allowed_right_fields = [x for x in right_schema.values.keys()] + [right_schema.timestamp]
for field in obj.right_fields:
if field not in allowed_right_fields:
if obj.fields is not None and len(obj.fields) > 0:
allowed_fields = [x for x in right_schema.values.keys()] + [right_schema.timestamp]
for field in obj.fields:
if field not in allowed_fields:
raise ValueError(
f"Field `{field}` specified in right_fields {obj.right_fields} doesn't exist in "
f"allowed fields of right schema {right_schema} of {output_schema_name}."
f"Field `{field}` specified in fields {obj.fields} "
f"doesn't exist in allowed fields {allowed_fields} of "
f"right schema of {output_schema_name}."
)

if right_schema.timestamp in obj.right_fields and right_schema.timestamp in left_schema.fields():
if right_schema.timestamp in obj.fields and right_schema.timestamp in left_schema.fields():
raise ValueError(
f"Field `{right_schema.timestamp}` specified in right_fields {obj.right_fields} "
f"Field `{right_schema.timestamp}` specified in fields {obj.fields} "
f"already exists in left schema of {output_schema_name}."
)

Expand Down
29 changes: 21 additions & 8 deletions fennel/datasets/test_invalid_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,7 @@ def create_pipeline(cls, a: Dataset, b: Dataset):
)


def test_dataset_incorrect_join_right_fields():
def test_dataset_incorrect_join_fields():
with pytest.raises(ValueError) as e:

@dataset
Expand All @@ -789,10 +789,15 @@ class XYZJoinedABC:
@pipeline
@inputs(XYZ, ABC)
def create_pipeline(cls, a: Dataset, b: Dataset):
c = a.join(b, how="inner", on=["user_id"], right_fields=["rank"]) # type: ignore
c = a.join(b, how="inner", on=["user_id"], fields=["rank"]) # type: ignore
return c

assert "doesn't exist in allowed fields of right schema" in str(e.value)
assert(
str(e.value)
== "Field `rank` specified in fields ['rank'] doesn't exist in "
"allowed fields ['age', 'timestamp'] of right schema of "
"'[Pipeline:create_pipeline]->join node'."
)

with pytest.raises(ValueError) as e:

Expand All @@ -818,10 +823,15 @@ class XYZJoinedABC:
@pipeline
@inputs(XYZ, ABC)
def create_pipeline(cls, a: Dataset, b: Dataset):
c = a.join(b, how="inner", on=["user_id"], right_fields=["user_id"]) # type: ignore
c = a.join(b, how="inner", on=["user_id"], fields=["user_id"]) # type: ignore
return c

assert "doesn't exist in allowed fields of right schema" in str(e.value)
assert(
str(e.value)
== "Field `user_id` specified in fields ['user_id'] doesn't exist in "
"allowed fields ['age', 'timestamp'] of right schema of "
"'[Pipeline:create_pipeline]->join node'."
)

with pytest.raises(ValueError) as e:

Expand All @@ -847,11 +857,14 @@ class XYZJoinedABC:
@pipeline
@inputs(XYZ, ABC)
def create_pipeline(cls, a: Dataset, b: Dataset):
c = a.join(b, how="inner", on=["user_id"], right_fields=["timestamp"]) # type: ignore
c = a.join(b, how="inner", on=["user_id"], fields=["timestamp"]) # type: ignore
return c

assert "already exists in left schema" in str(e.value)

assert(
str(e.value)
== "Field `timestamp` specified in fields ['timestamp'] already "
"exists in left schema of '[Pipeline:create_pipeline]->join node'."
)


def test_dataset_incorrect_join_bounds():
Expand Down
Loading

0 comments on commit a0084e8

Please sign in to comment.