Skip to content

Commit

Permalink
ENH: Added CLI command to update skops files (#343)
Browse files Browse the repository at this point in the history
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
  • Loading branch information
EdAbati and BenjaminBossan authored Jun 27, 2023
1 parent 107904c commit 70185a9
Show file tree
Hide file tree
Showing 7 changed files with 415 additions and 2 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ node_modules
# Vim
*.swp

# MacOS
.DS_Store

exports
trash
2 changes: 2 additions & 0 deletions docs/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ v0.8
----
- Adds the abillity to set the :attr:`.Section.folded` property when using :meth:`.Card.add`.
:pr:`361` by :user:`Thomas Lazarus <lazarust>`.
- Add the CLI command to update Skops files to the latest Skops persistence format.
(:func:`.cli._update.main`). :pr:`333` by :user:`Edoardo Abati <EdAbati>`

v0.7
----
Expand Down
25 changes: 23 additions & 2 deletions docs/persistence.rst
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,13 @@ for more details.
Command Line Interface
######################

Skops has a command line interface to convert scikit-learn models persisted with
``Pickle`` to ``Skops`` files.
Skops has a command line interface to:

- convert scikit-learn models persisted with ``Pickle`` to ``Skops`` files.
- update ``Skops`` files to the latest version.

``skops convert``
~~~~~~~~~~~~~~~~~

To convert a file from the command line, use the ``skops convert`` entrypoint.

Expand All @@ -151,6 +156,22 @@ For example, to convert all ``.pkl`` flies in the current directory:
Further help for the different supported options can be found by calling
``skops convert --help`` in a terminal.

``skops update``
~~~~~~~~~~~~~~~~

To update a ``Skops`` file from the command line, use the ``skops update`` command.
Skops will check the protocol version of the file to determine if it needs to be updated to the current version.

The below command is an example on how to create an updated version of a file
``my_model.skops`` and save it as ``my_model-updated.skops``:

.. code-block:: console
skops update my_model.skops -o my_model-updated.skops
Further help for the different supported options can be found by calling
``skops update --help`` in a terminal.

Visualization
#############

Expand Down
136 changes: 136 additions & 0 deletions skops/cli/_update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
from __future__ import annotations

import argparse
import json
import logging
import shutil
import tempfile
import zipfile
from pathlib import Path

from skops.cli._utils import get_log_level
from skops.io import dump, load
from skops.io._protocol import PROTOCOL


def _update_file(
input_file: str | Path,
output_file: str | Path | None = None,
inplace: bool = False,
logger: logging.Logger = logging.getLogger(),
) -> None:
"""Function that is called by ``skops update`` entrypoint.
Loads a skops model from the input path, updates it to the current skops format, and
saves to an output file. It will overwrite the input file if `inplace` is True.
Parameters
----------
input_file : str, or Path
Path of input skops model to load.
output_file : str, or Path, default=None
Path to save the updated skops model to.
inplace : bool, default=False
Whether to update and overwrite the input file in place.
logger : logging.Logger, default=logging.getLogger()
Logger to use for logging.
"""
if inplace:
if output_file is None:
output_file = input_file
else:
raise ValueError(
"Cannot specify both an output file path and the inplace flag. Please"
" choose whether you want to create a new file or overwrite the input"
" file."
)

input_model = load(input_file, trusted=True)
with zipfile.ZipFile(input_file, "r") as zip_file:
input_file_schema = json.loads(zip_file.read("schema.json"))

if input_file_schema["protocol"] == PROTOCOL:
logger.warning(
"File was not updated because already up to date with the current protocol:"
f" {PROTOCOL}"
)
return None

if input_file_schema["protocol"] > PROTOCOL:
logger.warning(
"File cannot be updated because its protocol is more recent than the "
f"current protocol: {PROTOCOL}"
)
return None

if output_file is None:
logger.warning(
f"File can be updated to the current protocol: {PROTOCOL}. Please"
" specify an output file path or use the `inplace` flag to create the"
" updated Skops file."
)
return None

with tempfile.TemporaryDirectory() as tmp_dir:
tmp_output_file = Path(tmp_dir) / f"{output_file}.tmp"
dump(input_model, tmp_output_file)
shutil.move(str(tmp_output_file), str(output_file))
logger.info(f"Updated skops file written to {output_file}")


def format_parser(
parser: argparse.ArgumentParser | None = None,
) -> argparse.ArgumentParser:
"""Adds arguments and help to parent CLI parser for the `update` method."""

if not parser: # used in tests
parser = argparse.ArgumentParser()

parser_subgroup = parser.add_argument_group("update")
parser_subgroup.add_argument("input", help="Path to an input file to update.")

parser_subgroup.add_argument(
"-o",
"--output-file",
help="Specify the output file name for the updated skops file.",
default=None,
)
parser_subgroup.add_argument(
"--inplace",
help="Update and overwrite the input file in place.",
action="store_true",
)
parser_subgroup.add_argument(
"-v",
"--verbose",
help=(
"Increases verbosity of logging. Can be used multiple times to increase "
"verbosity further."
),
action="count",
dest="loglevel",
default=0,
)
return parser


def main(
parsed_args: argparse.Namespace,
logger: logging.Logger = logging.getLogger(),
) -> None:
output_file = Path(parsed_args.output_file) if parsed_args.output_file else None
input_file = Path(parsed_args.input)
inplace = parsed_args.inplace

logging.basicConfig(format="%(levelname)-8s: %(message)s")
logger.setLevel(level=get_log_level(parsed_args.loglevel))

_update_file(
input_file=input_file,
output_file=output_file,
inplace=inplace,
logger=logger,
)
5 changes: 5 additions & 0 deletions skops/cli/entrypoint.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse

import skops.cli._convert
import skops.cli._update


def main_cli(command_line_args=None):
Expand Down Expand Up @@ -32,6 +33,10 @@ def main_cli(command_line_args=None):
"method": skops.cli._convert.main,
"format_parser": skops.cli._convert.format_parser,
},
"update": {
"method": skops.cli._update.main,
"format_parser": skops.cli._update.format_parser,
},
}

for func_name, values in function_map.items():
Expand Down
20 changes: 20 additions & 0 deletions skops/cli/tests/test_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,23 @@ def test_convert_works_as_expected(
)

assert caplog.at_level(logging.WARNING)

@mock.patch("skops.cli._update._update_file")
def test_update_works_as_expected(
self,
update_file_mock: mock.MagicMock,
):
"""
To make sure the parser is configured correctly, when 'update'
is the first argument.
"""

args = ["update", "abc.skops", "-o", "abc-new.skops"]

main_cli(args)
update_file_mock.assert_called_once_with(
input_file=pathlib.Path("abc.skops"),
output_file=pathlib.Path("abc-new.skops"),
inplace=False,
logger=mock.ANY,
)
Loading

0 comments on commit 70185a9

Please sign in to comment.