Skip to content

Commit

Permalink
FIX #66 - Accessing project context inside hooks (and consequently fix
Browse files Browse the repository at this point in the history
  • Loading branch information
takikadiri committed Oct 17, 2020
1 parent d9d3bf5 commit 26b6572
Show file tree
Hide file tree
Showing 9 changed files with 296 additions and 54 deletions.
16 changes: 16 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,22 @@

## [Unreleased]

### Added

-

### Fixed

- `get_mlflow_config` now uses the kedro context config_loader to get configs (#66). This indirectly solves the following issues:
- `get_mlflow_config` now works in interactive mode if `load_context` is called with a path different from the working directory (#30)
- kedro_mlflow now works fine with kedro jupyter notebook independently of the working directory (#64)
- You can use global variables in `mlflow.yml` which is now properly parsed if you use a `TemplatedConfigLoader` (#72)
- `mlflow init` is now getting conf path from context.CONF_ROOT instead of hardcoded conf folder. This makes the package robust to Kedro changes.

### Changed

- `MlflowNodeHook` have now a before_pipeline_run hook which stores the ProjectContext and enable to retrieve configuration.

## [0.3.0] - 2020-10-11

### Added
Expand Down
14 changes: 11 additions & 3 deletions kedro_mlflow/framework/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import click
from kedro import __file__ as KEDRO_PATH
from kedro.framework.context import load_context

from kedro_mlflow.framework.cli.cli_utils import (
render_jinja_template,
Expand Down Expand Up @@ -88,19 +89,23 @@ def init(force, silent):
# get constants
project_path = Path().cwd()
project_globals = _get_project_globals()
context = load_context(project_path)
conf_root = context.CONF_ROOT

# mlflow.yml is just a static file,
# but the name of the experiment is set to be the same as the project
mlflow_yml = "mlflow.yml"
write_jinja_template(
src=TEMPLATE_FOLDER_PATH / mlflow_yml,
is_cookiecutter=False,
dst=project_path / "conf" / "base" / mlflow_yml,
dst=project_path / conf_root / "base" / mlflow_yml,
python_package=project_globals["python_package"],
)
if not silent:
click.secho(
click.style("'conf/base/mlflow.yml' successfully updated.", fg="green")
click.style(
f"'{conf_root}/base/mlflow.yml' successfully updated.", fg="green"
)
)
# make a check whether the project run.py is strictly identical to the template
# if yes, replace the script by the template silently
Expand Down Expand Up @@ -184,8 +189,11 @@ def ui(project_path, env):
"""

if not project_path:
project_path = Path().cwd()
context = load_context(project_path=project_path, env=env)
# the context must contains the self.mlflow attribues with mlflow configuration
mlflow_conf = get_mlflow_config(project_path=project_path, env=env)
mlflow_conf = get_mlflow_config(context)

# call mlflow ui with specific options
# TODO : add more options for ui
Expand Down
19 changes: 5 additions & 14 deletions kedro_mlflow/framework/context/mlflow_context.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,14 @@
from pathlib import Path

from kedro.config import ConfigLoader
from kedro.framework.context import KedroContext

from kedro_mlflow.framework.context.config import KedroMlflowConfig


# this could be a read-only property in the context
# with a @property decorator
# but for consistency with hook system, it is an external function
def get_mlflow_config(project_path=None, env="local"):
if project_path is None:
project_path = Path.cwd()
project_path = Path(project_path)
conf_paths = [
str(project_path / "conf" / "base"),
str(project_path / "conf" / env),
]
config_loader = ConfigLoader(conf_paths=conf_paths)
conf_mlflow_yml = config_loader.get("mlflow*", "mlflow*/**")
conf_mlflow = KedroMlflowConfig(project_path=project_path)
def get_mlflow_config(context: KedroContext):

conf_mlflow_yml = context.config_loader.get("mlflow*", "mlflow*/**")
conf_mlflow = KedroMlflowConfig(context.project_path)
conf_mlflow.from_dict(conf_mlflow_yml)
return conf_mlflow
41 changes: 40 additions & 1 deletion kedro_mlflow/framework/hooks/node_hook.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,55 @@
from typing import Any, Dict

import mlflow
from kedro.framework.context import load_context
from kedro.framework.hooks import hook_impl
from kedro.io import DataCatalog
from kedro.pipeline import Pipeline
from kedro.pipeline.node import Node

from kedro_mlflow.framework.context import get_mlflow_config


class MlflowNodeHook:
def __init__(self):
config = get_mlflow_config()
self.context = None
self.flatten = False
self.recursive = True
self.sep = "."

@hook_impl
def before_pipeline_run(
self, run_params: Dict[str, Any], pipeline: Pipeline, catalog: DataCatalog
) -> None:
"""Hook to be invoked before a pipeline runs.
Args:
run_params: The params needed for the given run.
Should be identical to the data logged by Journal.
# @fixme: this needs to be modelled explicitly as code, instead of comment
Schema: {
"run_id": str,
"project_path": str,
"env": str,
"kedro_version": str,
"tags": Optional[List[str]],
"from_nodes": Optional[List[str]],
"to_nodes": Optional[List[str]],
"node_names": Optional[List[str]],
"from_inputs": Optional[List[str]],
"load_versions": Optional[List[str]],
"pipeline_name": str,
"extra_params": Optional[Dict[str, Any]],
}
pipeline: The ``Pipeline`` that will be run.
catalog: The ``DataCatalog`` to be used during the run.
"""

self.context = load_context(
project_path=run_params["project_path"],
env=run_params["env"],
extra_params=run_params["extra_params"],
)
config = get_mlflow_config(self.context)
self.flatten = config.node_hook_opts["flatten_dict_params"]
self.recursive = config.node_hook_opts["recursive"]
self.sep = config.node_hook_opts["sep"]
Expand Down
14 changes: 11 additions & 3 deletions kedro_mlflow/framework/hooks/pipeline_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import mlflow
import yaml
from kedro.framework.context import load_context
from kedro.framework.hooks import hook_impl
from kedro.io import DataCatalog
from kedro.pipeline import Pipeline
Expand All @@ -12,11 +13,14 @@
from kedro_mlflow.framework.context import get_mlflow_config
from kedro_mlflow.io import MlflowMetricsDataSet
from kedro_mlflow.mlflow import KedroPipelineModel
from kedro_mlflow.pipeline.pipeline_ml_factory import PipelineML
from kedro_mlflow.pipeline.pipeline_ml import PipelineML
from kedro_mlflow.utils import _parse_requirements


class MlflowPipelineHook:
def __init__(self):
self.context = None

@hook_impl
def after_catalog_created(
self,
Expand Down Expand Up @@ -62,9 +66,13 @@ def before_pipeline_run(
pipeline: The ``Pipeline`` that will be run.
catalog: The ``DataCatalog`` to be used during the run.
"""
mlflow_conf = get_mlflow_config(
project_path=run_params["project_path"], env=run_params["env"]
self.context = load_context(
project_path=run_params["project_path"],
env=run_params["env"],
extra_params=run_params["extra_params"],
)

mlflow_conf = get_mlflow_config(self.context)
mlflow.set_tracking_uri(mlflow_conf.mlflow_tracking_uri)
# TODO : if the pipeline fails, we need to be able to end stop the mlflow run
# cannot figure out how to do this within hooks
Expand Down
41 changes: 40 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from pathlib import Path
from typing import Dict

Expand Down Expand Up @@ -49,7 +50,45 @@ def config_dir(tmp_path):
credentials = tmp_path / "conf" / env / "credentials.yml"
logging = tmp_path / "conf" / env / "logging.yml"
parameters = tmp_path / "conf" / env / "parameters.yml"
globals_yaml = tmp_path / "conf" / env / "globals.yml"
kedro_yaml = tmp_path / ".kedro.yml"
_write_yaml(catalog, dict())
_write_yaml(parameters, dict())
_write_yaml(globals_yaml, dict())
_write_yaml(credentials, dict())
_write_yaml(logging, _get_local_logging_config())
_write_yaml(logging, _get_local_logging_config()),

_write_yaml(
kedro_yaml,
dict(
{
"context_path": "dummy_package.run.ProjectContext",
"project_name": "dummy_package",
"project_version": "0.16.4",
"package_name": "dummy_package",
}
),
)

os.mkdir(tmp_path / "src")
os.mkdir(tmp_path / "src" / "dummy_package")
with open(tmp_path / "src" / "dummy_package" / "run.py", "w") as f:
f.writelines(
[
"from kedro.framework.context import KedroContext\n",
"from kedro.config import TemplatedConfigLoader \n"
"class ProjectContext(KedroContext):\n",
" project_name = 'dummy_package'\n",
" project_version = '0.16.4'\n",
" package_name = 'dummy_package'\n",
]
)
f.writelines(
[
" def _create_config_loader(self, conf_paths):\n",
" return TemplatedConfigLoader(\n",
" conf_paths,\n",
" globals_pattern='globals.yml'\n",
" )\n",
]
)
48 changes: 42 additions & 6 deletions tests/framework/context/test_mlflow_context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import yaml
from kedro.framework.context import load_context

from kedro_mlflow.framework.context import get_mlflow_config

Expand All @@ -7,15 +8,17 @@
# get_mlflow_config(project_path=tmp_path,env="local")


def _write_yaml(filepath, config):
filepath.parent.mkdir(parents=True, exist_ok=True)
yaml_str = yaml.dump(config)
filepath.write_text(yaml_str)


def test_get_mlflow_config(mocker, tmp_path, config_dir):
# config_with_base_mlflow_conf is a pytest.fixture in conftest
mocker.patch("logging.config.dictConfig")
mocker.patch("kedro_mlflow.utils._is_kedro_project", return_value=True)

def _write_yaml(filepath, config):
filepath.parent.mkdir(parents=True, exist_ok=True)
yaml_str = yaml.dump(config)
filepath.write_text(yaml_str)

_write_yaml(
tmp_path / "conf" / "base" / "mlflow.yml",
dict(
Expand All @@ -35,4 +38,37 @@ def _write_yaml(filepath, config):
"node": {"flatten_dict_params": True, "recursive": False, "sep": "-"}
},
}
assert get_mlflow_config(project_path=tmp_path, env="local").to_dict() == expected
context = load_context(tmp_path)
assert get_mlflow_config(context).to_dict() == expected


def test_mlflow_config_with_templated_config(mocker, tmp_path, config_dir):

_write_yaml(
tmp_path / "conf" / "base" / "mlflow.yml",
dict(
mlflow_tracking_uri="${mlflow_tracking_uri}",
experiment=dict(name="fake_package", create=True),
run=dict(id="123456789", name="my_run", nested=True),
ui=dict(port="5151", host="localhost"),
hooks=dict(node=dict(flatten_dict_params=True, recursive=False, sep="-")),
),
)

_write_yaml(
tmp_path / "conf" / "base" / "globals.yml",
dict(mlflow_tracking_uri="testruns"),
)

expected = {
"mlflow_tracking_uri": (tmp_path / "testruns").as_uri(),
"experiments": {"name": "fake_package", "create": True},
"run": {"id": "123456789", "name": "my_run", "nested": True},
"ui": {"port": "5151", "host": "localhost"},
"hooks": {
"node": {"flatten_dict_params": True, "recursive": False, "sep": "-"}
},
}

context = load_context(tmp_path)
assert get_mlflow_config(context).to_dict() == expected
Loading

0 comments on commit 26b6572

Please sign in to comment.