Skip to content

Commit

Permalink
FEA update_env implementation (#136)
Browse files Browse the repository at this point in the history
  • Loading branch information
merveenoyan authored Sep 19, 2022
1 parent 4fe963a commit cfea7cb
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 6 deletions.
19 changes: 13 additions & 6 deletions skops/hub_utils/_hf_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,7 @@ def recursively_default_dict() -> MutableMapping:
else:
raise ValueError("The data needs to be a list of strings.")

with open(Path(dst) / "config.json", mode="w") as f:
json.dump(config, f, sort_keys=True, indent=4)
dump_json(Path(dst) / "config.json", config)


def _check_model_file(path: str | Path) -> Path:
Expand Down Expand Up @@ -381,6 +380,11 @@ def add_files(*files: str | Path, dst: str | Path, exist_ok: bool = False) -> No
shutil.copy2(src_file, dst_file)


def dump_json(path, content):
with open(Path(path), mode="w") as f:
json.dump(content, f, sort_keys=True, indent=4)


def update_env(
*, path: Union[str, Path], requirements: Union[List[str], None] = None
) -> None:
Expand All @@ -398,11 +402,14 @@ def update_env(
The list of required packages for the model. If none is passed, the
list of existing requirements is used and their versions are updated.
Returns
-------
None
"""
pass

with open(Path(path) / "config.json") as f:
config = json.load(f)

config["sklearn"]["environment"] = requirements

dump_json(Path(path) / "config.json", config)


def push(
Expand Down
8 changes: 8 additions & 0 deletions skops/hub_utils/tests/test_hf_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
get_requirements,
init,
push,
update_env,
)
from skops.hub_utils._hf_hub import (
_create_config,
Expand Down Expand Up @@ -462,6 +463,13 @@ def test_get_config(repo_path):
assert get_requirements(repo_path) == ['scikit-learn="1.1.1"']


def test_update_env(repo_path, config_json):
# sanity check
assert get_requirements(repo_path) == ['scikit-learn="1.1.1"']
update_env(path=repo_path, requirements=['scikit-learn="1.1.2"'])
assert get_requirements(repo_path) == ['scikit-learn="1.1.2"']


def test_get_example_input():
"""Test the _get_example_input function."""
with pytest.raises(
Expand Down

0 comments on commit cfea7cb

Please sign in to comment.