diff --git a/fennel/client_tests/test_featureset.py b/fennel/client_tests/test_featureset.py
index 3ae09914..efce20c2 100644
--- a/fennel/client_tests/test_featureset.py
+++ b/fennel/client_tests/test_featureset.py
@@ -75,6 +75,33 @@ def get_user_info(cls, ts: pd.Series, user_id: pd.Series):
         ]
 
 
+@featureset
+class UserInfoSingleErrorProneExtractor:
+    userid: int
+    age: int = F().meta(owner="aditya@fennel.ai")  # type: ignore
+    age_squared: int
+    age_cubed: int
+    is_name_common: bool
+
+    @extractor(deps=[UserInfoDataset])  # type: ignore
+    @inputs("userid", "age")
+    @outputs("age", "age_squared", "age_cubed", "is_name_common")
+    def get_user_info(
+        cls, ts: pd.Series, user_id: pd.Series, user_age: pd.Series
+    ):
+        output_df = pd.DataFrame(
+            {
+                "userid": user_id.iloc[
+                    :-1
+                ],  # Drop the last row to create a mismatch
+                "age": user_age.iloc[
+                    :-1
+                ],  # Drop the last row to create a mismatch
+            }
+        )
+        return output_df
+
+
 def get_country_geoid(country: str) -> int:
     if country == "Russia":
         return 1
@@ -243,6 +270,34 @@ def test_simple_extractor(self, client):
             res["UserInfoMultipleExtractor.age_doubled"].tolist(), [64, 48]
         )
 
+    @pytest.mark.integration
+    @mock
+    def test_simple_error_extractor(self, client):
+        client.commit(
+            message="some commit msg",
+            datasets=[UserInfoDataset],
+            featuresets=[UserInfoMultipleExtractor],
+        )
+        now = datetime.now(timezone.utc)
+        data = [
+            [18232, "John", 32, "USA", now],
+            [18234, "Monica", 24, "Chile", now],
+        ]
+        columns = ["user_id", "name", "age", "country", "timestamp"]
+        input_df = pd.DataFrame(data, columns=columns)
+        response = client.log("fennel_webhook", "UserInfoDataset", input_df)
+        assert response.status_code == requests.codes.OK, response.json()
+        client.sleep()
+        ts = pd.Series([now, now])
+        user_ids = pd.Series([18232, 18234])
+        age = pd.Series([18, 22])
+        with pytest.raises(AssertionError) as exc_info:
+            UserInfoSingleErrorProneExtractor.get_user_info(
+                UserInfoMultipleExtractor, ts, user_ids, age
+            )
+        assert "Output length mismatch" in str(exc_info.value)
+        assert "expected 2, got 1" in str(exc_info.value)
+
     @pytest.mark.integration
     @mock
     def test_e2e_query(self, client):
diff --git a/fennel/featuresets/featureset.py b/fennel/featuresets/featureset.py
index 7cf6831e..68d94b27 100644
--- a/fennel/featuresets/featureset.py
+++ b/fennel/featuresets/featureset.py
@@ -450,6 +450,64 @@ def is_user_defined(obj):
     return inspect.isclass(type(obj)) and not inspect.isbuiltin(obj)
 
 
+def _add_input_constraints(func, params):
+    """
+    Wraps the user's extractor with a function that:
+    1. Checks that it received one argument per declared input, plus the
+       two fixed leading arguments.
+    2. Renames each input series to the fully qualified name of its input.
+    3. Runs the extractor.
+    4. Checks that the number of rows in the output matches the input
+       length.
+    The output can be either a pandas Series or DataFrame.
+
+    Violations are raised explicitly as AssertionError rather than through
+    `assert` statements, so the checks are not stripped under `python -O`.
+    """
+
+    @functools.wraps(func)
+    def inner(*args, **kwargs):
+        # Ensure we have the correct number of arguments
+        if len(args) != len(params) + 2:
+            raise AssertionError(
+                f"Expected {len(params) + 2} arguments, got {len(args)}"
+            )
+
+        # Split off the first two fixed args from the input series
+        fixed_args, input_series = list(args[:2]), list(args[2:])
+
+        # Measure length of input series. An extractor without inputs only
+        # receives the two fixed args, so fall back to the second one
+        # (`ts` in the extractor signature) instead of indexing an empty
+        # list.
+        reference = input_series[0] if input_series else fixed_args[1]
+        input_length = len(reference)
+
+        renamed_args = fixed_args + [
+            arg.rename(name.fqn()) for name, arg in zip(params, input_series)
+        ]
+
+        # Run the extractor
+        ret = func(*renamed_args, **kwargs)
+
+        # Only a pandas Series or DataFrame has rows to compare
+        if not isinstance(ret, (pd.Series, pd.DataFrame)):
+            raise ValueError(
+                f"Expected a pandas Series or DataFrame but got {type(ret)} in {func.__qualname__}."
+            )
+
+        # Check that the output length matches the input length
+        output_length = len(ret)
+        if output_length != input_length:
+            raise AssertionError(
+                f"Output length mismatch in {func.__qualname__}: expected {input_length}, got {output_length}"
+            )
+
+        return ret
+
+    return inner
+
+
 def _add_featureset_name(func, params):
     """Rewrites the output column names of the extractor to be fully qualified names.
     Also add feature names to the input parameters of the extractor.
@@ -773,6 +831,7 @@ def _get_extractors(self) -> List[Extractor]:
             extractor.featureset = self._name
             extractor.inputs = inputs
             func = _add_featureset_name(extractor.func, extractor.inputs)
+            func = _add_input_constraints(func, extractor.inputs)
             # Setting func at both extractor.func and class attribute
             extractor.func = func
             setattr(self.__fennel_original_cls__, name, classmethod(func))