Skip to content

Commit

Permalink
FIX Galileo-Galilei#83 - Check input_name when set in PipelineML
Browse files Browse the repository at this point in the history
  • Loading branch information
Galileo-Galilei committed Oct 10, 2020
1 parent b98b2e0 commit 4c0b6a7
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
- Remove `conda_env` and `model_name` arguments from `MlflowPipelineHook` and add them to `PipelineML` and `pipeline_ml`. This is necessary for upcoming hook auto-discovery in a future release and it enables having multiple `PipelineML` in the same project ([#58](https://github.com/Galileo-Galilei/kedro-mlflow/pull/58)). This mechanically fixes [#54](https://github.com/Galileo-Galilei/kedro-mlflow/issues/54) by making `conda_env` path absolute for airflow support.
- `flatten_dict_params`, `recursive` and `sep` arguments of the `MlflowNodeHook` are moved to the `mlflow.yml` config file to prepare plugin auto registration. This also modifies the `run.py` template (to remove the args) and the `mlflow.yml` keys to add a `hooks` entry. ([#59](https://github.com/Galileo-Galilei/kedro-mlflow/pull/59))
- Rename CI workflow to `test` ([#57](https://github.com/Galileo-Galilei/kedro-mlflow/issues/57), [#68](https://github.com/Galileo-Galilei/kedro-mlflow/pull/68))
- The `input_name` attribute of `PipelineML` is now a Python property that is checked when it is set, which prevents assigning an invalid value. The check ensures that `input_name` is a valid input of the `inference` pipeline.


### Deprecated
Expand Down
14 changes: 11 additions & 3 deletions kedro_mlflow/pipeline/pipeline_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,17 @@ def __init__(
self.conda_env = conda_env
self.model_name = model_name

self._check_input_name(input_name)
self.input_name = input_name

@property
def input_name(self) -> str:
    """Return the input name currently stored in ``self._input_name``."""
    return self._input_name

@input_name.setter
def input_name(self, name: str) -> None:
    """Validate ``name`` and store it as the new ``input_name``.

    The candidate is passed to ``self._check_input_name`` before it is
    assigned, so a value rejected by that check never overwrites
    ``self._input_name``.

    Args:
        name: The candidate input name.
    """
    # Validate first: if the check raises, the previous value is kept.
    self._check_input_name(name)
    self._input_name = name

def extract_pipeline_catalog(self, catalog: DataCatalog) -> DataCatalog:
sub_catalog = DataCatalog()
for data_set_name in self.inference.inputs():
Expand Down Expand Up @@ -134,10 +142,10 @@ def training(self):

def _check_input_name(self, input_name: str) -> str:
allowed_names = self.inference.inputs()
pp_allowed_names = "\n - ".join(allowed_names)
pp_allowed_names = "\n - ".join(allowed_names)
if input_name not in allowed_names:
raise KedroMlflowPipelineMLInputsError(
f"input_name='{input_name}' but it must be an input of inference, i.e. one of: {pp_allowed_names}"
f"input_name='{input_name}' but it must be an input of 'inference', i.e. one of: \n - {pp_allowed_names}"
)
else:
free_inputs_set = (
Expand Down
18 changes: 14 additions & 4 deletions tests/pipeline/test_pipeline_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,31 +119,33 @@ def _get_pipelines(self):
# Disable logging.config.dictConfig in KedroContext._setup_logging as
# it changes logging.config and affects other unit tests
mocker.patch("logging.config.dictConfig")

return DummyContext(tmp_path.as_posix())
dummy_context = DummyContext(tmp_path.as_posix())
return dummy_context


@pytest.fixture
def dummy_catalog():
return DataCatalog(
dummy_catalog = DataCatalog(
{
"raw_data": MemoryDataSet(),
"data": MemoryDataSet(),
"model": CSVDataSet("fake/path/to/model.csv"),
}
)
return dummy_catalog


@pytest.fixture
def catalog_with_encoder():
return DataCatalog(
catalog_with_encoder = DataCatalog(
{
"raw_data": MemoryDataSet(),
"data": MemoryDataSet(),
"encoder": CSVDataSet("fake/path/to/encoder.csv"),
"model": CSVDataSet("fake/path/to/model.csv"),
}
)
return catalog_with_encoder


@pytest.mark.parametrize(
Expand Down Expand Up @@ -332,3 +334,11 @@ def fake_dec(x):

new_pl = pipeline_ml_with_tag.decorate(fake_dec)
assert all([fake_dec in node._decorators for node in new_pl.nodes])


def test_invalid_input_name(pipeline_ml_with_tag):
    """Assigning an unknown ``input_name`` must raise ``KedroMlflowPipelineMLInputsError``."""
    # NOTE(review): assumes 'whoops_bad_name' is not an input of the fixture's
    # inference pipeline -- the fixture is defined outside this excerpt.
    with pytest.raises(
        KedroMlflowPipelineMLInputsError,
        match="input_name='whoops_bad_name' but it must be an input of 'inference'",
    ):
        pipeline_ml_with_tag.input_name = "whoops_bad_name"

0 comments on commit 4c0b6a7

Please sign in to comment.