diff --git a/ludwig/cli.py b/ludwig/cli.py
index 9ad5bf05cd3..db95bf0d785 100644
--- a/ludwig/cli.py
+++ b/ludwig/cli.py
@@ -49,6 +49,7 @@ def __init__(self):
    datasets              Downloads and lists Ludwig-ready datasets
    export_torchscript    Exports Ludwig models to Torchscript
    export_triton         Exports Ludwig models to Triton
+   export_carton         Exports Ludwig models to Carton
    export_neuropod       Exports Ludwig models to Neuropod
    export_mlflow         Exports Ludwig models to MLflow
    preprocess            Preprocess data and saves it into HDF5 and JSON format
@@ -140,6 +141,11 @@ def export_triton(self):
 
         export.cli_export_triton(sys.argv[2:])
 
+    def export_carton(self):
+        from ludwig import export
+
+        export.cli_export_carton(sys.argv[2:])
+
     def export_neuropod(self):
         from ludwig import export
 
diff --git a/ludwig/export.py b/ludwig/export.py
index 06d767adf1a..60622650a1f 100644
--- a/ludwig/export.py
+++ b/ludwig/export.py
@@ -22,6 +22,7 @@
 from ludwig.api import LudwigModel
 from ludwig.contrib import add_contrib_callback_args
 from ludwig.globals import LUDWIG_VERSION
+from ludwig.utils.carton_utils import export_carton as utils_export_carton
 from ludwig.utils.neuropod_utils import export_neuropod as utils_export_neuropod
 from ludwig.utils.print_utils import get_logging_level_registry, print_ludwig
 from ludwig.utils.triton_utils import export_triton as utils_export_triton
@@ -88,6 +89,32 @@ def export_triton(model_path, output_path="model_repository", model_name="ludwig
     logger.info(f"Saved to: {output_path}")
 
 
+def export_carton(model_path, output_path="carton", model_name="carton", **kwargs):
+    """Exports a model to Carton.
+
+    # Inputs
+
+    :param model_path: (str) filepath to pre-trained model.
+    :param output_path: (str, default: `'carton'`) directory to store the
+        carton model.
+    :param model_name: (str, default: `'carton'`) save carton under this
+        name.
+
+    # Return
+
+    :returns: (`None`)
+    """
+    logger.info(f"Model path: {model_path}")
+    logger.info(f"Output path: {output_path}")
+    logger.info("\n")
+
+    model = LudwigModel.load(model_path)
+    os.makedirs(output_path, exist_ok=True)
+    utils_export_carton(model, output_path, model_name)
+
+    logger.info(f"Saved to: {output_path}")
+
+
 def export_neuropod(model_path, output_path="neuropod", model_name="neuropod", **kwargs):
     """Exports a model to Neuropod.
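
A quick usage sketch for the new entry point (the paths below are hypothetical; the same flow is exposed on the command line as `ludwig export_carton -m <model_path> -op <output_path>`):

    from ludwig.export import export_carton

    # Load the trained model from model_path and pack it as a Carton
    # under output_path.
    export_carton(
        model_path="results/experiment_run/model",
        output_path="/tmp/carton_out",
        model_name="my_model",
    )
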
@@ -254,6 +281,52 @@ def cli_export_triton(sys_argv):
     export_triton(**vars(args))
 
 
+def cli_export_carton(sys_argv):
+    parser = argparse.ArgumentParser(
+        description="This script loads a pretrained model " "and saves it as a Carton.",
+        prog="ludwig export_carton",
+        usage="%(prog)s [options]",
+    )
+
+    # ----------------
+    # Model parameters
+    # ----------------
+    parser.add_argument("-m", "--model_path", help="model to load", required=True)
+    parser.add_argument("-mn", "--model_name", help="model name", default="carton")
+
+    # -----------------
+    # Output parameters
+    # -----------------
+    parser.add_argument("-op", "--output_path", type=str, help="path where to save the exported model", required=True)
+
+    # ------------------
+    # Runtime parameters
+    # ------------------
+    parser.add_argument(
+        "-l",
+        "--logging_level",
+        default="info",
+        help="the level of logging to use",
+        choices=["critical", "error", "warning", "info", "debug", "notset"],
+    )
+
+    add_contrib_callback_args(parser)
+    args = parser.parse_args(sys_argv)
+
+    args.callbacks = args.callbacks or []
+    for callback in args.callbacks:
+        callback.on_cmdline("export_carton", *sys_argv)
+
+    args.logging_level = get_logging_level_registry()[args.logging_level]
+    logging.getLogger("ludwig").setLevel(args.logging_level)
+    global logger
+    logger = logging.getLogger("ludwig.export")
+
+    print_ludwig("Export Carton", LUDWIG_VERSION)
+
+    export_carton(**vars(args))
+
+
 def cli_export_neuropod(sys_argv):
     parser = argparse.ArgumentParser(
         description="This script loads a pretrained model " "and saves it as a Neuropod.",
@@ -360,6 +433,8 @@ def cli_export_mlflow(sys_argv):
         cli_export_mlflow(sys.argv[2:])
     elif sys.argv[1] == "triton":
         cli_export_triton(sys.argv[2:])
+    elif sys.argv[1] == "carton":
+        cli_export_carton(sys.argv[2:])
     elif sys.argv[1] == "neuropod":
         cli_export_neuropod(sys.argv[2:])
     else:
diff --git a/ludwig/utils/carton_utils.py b/ludwig/utils/carton_utils.py
new file mode 100644
index 00000000000..f19ad84dae2
--- /dev/null
+++ b/ludwig/utils/carton_utils.py
@@ -0,0 +1,133 @@
+import asyncio
+import importlib.util
+import logging
+import os
+import shutil
+import tempfile
+from typing import Any, Dict, List
+
+import torch
+
+from ludwig.api import LudwigModel
+from ludwig.api_annotations import DeveloperAPI
+from ludwig.constants import NAME
+from ludwig.types import ModelConfigDict
+from ludwig.utils.fs_utils import open_file
+
+logger = logging.getLogger(__name__)
+
+
+INFERENCE_MODULE_TEMPLATE = """
+from typing import Any, Dict, List, Tuple, Union
+import torch
+from ludwig.utils.types import TorchscriptPreprocessingInput
+
+class GeneratedInferenceModule(torch.nn.Module):
+    def __init__(self, inference_module):
+        super().__init__()
+        self.inference_module = inference_module
+
+    def forward(self, inputs: Dict[str, Any]):
+        retyped_inputs: Dict[str, TorchscriptPreprocessingInput] = {{}}
+        for k, v in inputs.items():
+            assert isinstance(v, TorchscriptPreprocessingInput)
+            retyped_inputs[k] = v
+
+        results = self.inference_module(retyped_inputs)
+        return {output_dicts}
+"""
+
+
+def _get_output_dicts(config: ModelConfigDict) -> str:
+    results = []
+    for feature in config["output_features"]:
+        name = feature[NAME]
+        results.append(f'"{name}": results["{name}"]["predictions"]')
+    return "{" + ", ".join(results) + "}"
+
+
+@DeveloperAPI
+def generate_carton_torchscript(model: LudwigModel):
+    config = model.config
+    inference_module = model.to_torchscript()
+    with tempfile.TemporaryDirectory() as tmpdir:
+        ts_path = os.path.join(tmpdir, "generated.py")
+        with open_file(ts_path, "w") as f:
+            f.write(
+                INFERENCE_MODULE_TEMPLATE.format(
+                    output_dicts=_get_output_dicts(config),
+                )
+            )
+
+        spec = importlib.util.spec_from_file_location("generated.ts", ts_path)
+        gen_ts = importlib.util.module_from_spec(spec)
+        spec.loader.exec_module(gen_ts)
+
+        gen_module = gen_ts.GeneratedInferenceModule(inference_module)
+        scripted_module = torch.jit.script(gen_module)
+        return scripted_module
+
+
+def _get_input_spec(model: LudwigModel) -> List[Dict[str, Any]]:
+    from cartonml import TensorSpec
+
+    spec = []
+    for feature_name, feature in model.model.input_features.items():
+        metadata = model.training_set_metadata[feature_name]
+        spec.append(
+            TensorSpec(
+                name=feature.feature_name, dtype=feature.get_preproc_input_dtype(metadata), shape=("batch_size",)
+            )
+        )
+    return spec
+
+
+def _get_output_spec(model: LudwigModel) -> List[Dict[str, Any]]:
+    from cartonml import TensorSpec
+
+    spec = []
+    for feature_name, feature in model.model.output_features.items():
+        metadata = model.training_set_metadata[feature_name]
+        spec.append(
+            TensorSpec(
+                name=feature.feature_name, dtype=feature.get_postproc_output_dtype(metadata), shape=("batch_size",)
+            )
+        )
+    return spec
+
+
+@DeveloperAPI
+def export_carton(model: LudwigModel, carton_path: str, carton_model_name="ludwig_model"):
+    try:
+        import cartonml as carton
+    except ImportError:
+        raise RuntimeError('The "cartonml-nightly" package is not installed in your environment.')
+
+    # Generate a torchscript model
+    model_ts = generate_carton_torchscript(model)
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        # Save the model to a temp dir
+        input_model_path = os.path.join(tmpdir, "model.pt")
+        torch.jit.save(model_ts, input_model_path)
+
+        # carton.pack is an async function so we run it and wait until it's complete
+        # See https://pyo3.rs/v0.20.0/ecosystem/async-await#a-note-about-asynciorun for why we wrap it
+        # in another function
+        async def pack():
+            return await carton.pack(
+                input_model_path,
+                runner_name="torchscript",
+                # Any 2.x.x version is okay
+                # TODO: improve this
+                required_framework_version="=2",
+                model_name=carton_model_name,
+                inputs=_get_input_spec(model),
+                outputs=_get_output_spec(model),
+            )
+
+        loop = asyncio.get_event_loop()
+        tmp_out_path = loop.run_until_complete(pack())
+
+    # Move it to the output path
+    shutil.move(tmp_out_path, carton_path)
diff --git a/requirements_serve.txt b/requirements_serve.txt
index 160f2b2daf8..353adbb3f5c 100644
--- a/requirements_serve.txt
+++ b/requirements_serve.txt
@@ -3,3 +3,4 @@ httpx
 fastapi
 python-multipart
 neuropod==0.3.0rc6 ; platform_system != "Windows" and python_version < '3.9'
+cartonml-nightly
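
To make the templating in carton_utils.py concrete: `_get_output_dicts` builds the string that `INFERENCE_MODULE_TEMPLATE.format` substitutes for `{output_dicts}` (the `{{}}` in the template is `str.format` escaping for a literal `{}`). A sketch with hypothetical feature names:

    # Hypothetical config with two output features.
    config = {"output_features": [{"name": "label"}, {"name": "score"}]}
    _get_output_dicts(config)
    # -> '{"label": results["label"]["predictions"], "score": results["score"]["predictions"]}'
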
diff --git a/tests/integration_tests/test_carton.py b/tests/integration_tests/test_carton.py
new file mode 100644
index 00000000000..37fb0b4e389
--- /dev/null
+++ b/tests/integration_tests/test_carton.py
@@ -0,0 +1,144 @@
+# Copyright (c) 2023 Predibase, Inc., 2019 Uber Technologies, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import asyncio
+import os
+import platform
+
+import numpy as np
+import pandas as pd
+import pytest
+
+from ludwig.api import LudwigModel
+from ludwig.constants import BATCH_SIZE, NAME, PREDICTIONS, TRAINER
+from ludwig.utils.carton_utils import export_carton
+from tests.integration_tests.utils import (
+    binary_feature,
+    category_feature,
+    generate_data,
+    LocalTestBackend,
+    number_feature,
+)
+
+
+@pytest.mark.skipif(platform.system() == "Windows", reason="Carton is not supported on Windows")
+def test_carton_torchscript(csv_filename, tmpdir):
+    data_csv_path = os.path.join(tmpdir, csv_filename)
+
+    # Configure features to be tested:
+    bin_str_feature = binary_feature()
+    input_features = [
+        bin_str_feature,
+        # binary_feature(),
+        number_feature(),
+        category_feature(encoder={"vocab_size": 3}),
+        # TODO: future support
+        # sequence_feature(vocab_size=3),
+        # text_feature(vocab_size=3),
+        # vector_feature(),
+        # image_feature(image_dest_folder),
+        # audio_feature(audio_dest_folder),
+        # timeseries_feature(),
+        # date_feature(),
+        # h3_feature(),
+        # set_feature(vocab_size=3),
+        # bag_feature(vocab_size=3),
+    ]
+    output_features = [
+        bin_str_feature,
+        # binary_feature(),
+        number_feature(),
+        category_feature(decoder={"vocab_size": 3}, output_feature=True),
+        # TODO: future support
+        # sequence_feature(vocab_size=3),
+        # text_feature(vocab_size=3),
+        # set_feature(vocab_size=3),
+        # vector_feature()
+    ]
+    backend = LocalTestBackend()
+    config = {
+        "input_features": input_features,
+        "output_features": output_features,
+        TRAINER: {"epochs": 2, BATCH_SIZE: 128},
+    }
+
+    # Generate training data
+    training_data_csv_path = generate_data(input_features, output_features, data_csv_path)
+
+    # Convert bool values to strings, e.g., {'Yes', 'No'}
+    df = pd.read_csv(training_data_csv_path)
+    false_value, true_value = "No", "Yes"
+    df[bin_str_feature[NAME]] = df[bin_str_feature[NAME]].map(lambda x: true_value if x else false_value)
+    df.to_csv(training_data_csv_path)
+
+    # Train Ludwig (Pythonic) model:
+    ludwig_model = LudwigModel(config, backend=backend)
+    ludwig_model.train(
+        dataset=training_data_csv_path,
+        skip_save_training_description=True,
+        skip_save_training_statistics=True,
+        skip_save_model=True,
+        skip_save_progress=True,
+        skip_save_log=True,
+        skip_save_processed_input=True,
+    )
+
+    # Obtain predictions from Python model
+    preds_dict, _ = ludwig_model.predict(dataset=training_data_csv_path, return_type=dict)
+
+    # Create graph inference model (Torchscript) from trained Ludwig model.
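+    # Note: export_carton packs the scripted inference module together with
+    # input/output TensorSpecs derived from the model; the artifact is loaded
+    # back below through cartonml's async API for comparison.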
+    carton_path = os.path.join(tmpdir, "carton")
+    export_carton(ludwig_model, carton_path)
+
+    import cartonml as carton
+
+    # Load the carton model
+    # See https://pyo3.rs/v0.20.0/ecosystem/async-await#a-note-about-asynciorun for why we wrap it
+    # in another function
+    async def load():
+        return await carton.load(carton_path)
+
+    loop = asyncio.get_event_loop()
+    carton_model = loop.run_until_complete(load())
+
+    def to_input(s: pd.Series) -> np.ndarray:
+        if s.dtype == "object":
+            return np.array(s.to_list())
+        return s.to_numpy().astype(np.float32)
+
+    df = pd.read_csv(training_data_csv_path)
+    inputs = {name: to_input(df[feature.column]) for name, feature in ludwig_model.model.input_features.items()}
+
+    # See https://pyo3.rs/v0.20.0/ecosystem/async-await#a-note-about-asynciorun for why we wrap it
+    # in another function
+    async def infer(inputs):
+        return await carton_model.infer(inputs)
+
+    outputs = loop.run_until_complete(infer(inputs))
+
+    # Compare results from Python trained model against Carton
+    assert len(preds_dict) == len(outputs)
+    for feature_name, feature_outputs_expected in preds_dict.items():
+        assert feature_name in outputs
+
+        output_values_expected = feature_outputs_expected[PREDICTIONS]
+        output_values = outputs[feature_name]
+        if output_values.dtype.type in {np.string_, np.str_}:
+            # Strings should match exactly
+            assert np.all(output_values == output_values_expected), f"feature: {feature_name}, output: predictions"
+        else:
+            assert np.allclose(output_values, output_values_expected), f"feature: {feature_name}, output: predictions"
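
Outside the test harness, consuming the packed artifact looks roughly like the sketch below (the path and input feature name are hypothetical; the event-loop wrapping mirrors the pyo3 note referenced above):

    import asyncio

    import cartonml as carton
    import numpy as np

    async def run():
        # Load the packed model, then run one batched inference.
        model = await carton.load("/tmp/carton_out")
        return await model.infer({"my_number": np.array([1.0, 2.0], dtype=np.float32)})

    loop = asyncio.get_event_loop()
    outputs = loop.run_until_complete(run())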