Add support for exporting models to Carton #3797

Merged
merged 1 commit on Jan 4, 2024
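This PR adds an `export_carton` sub-command to the Ludwig CLI and a new `ludwig/utils/carton_utils.py` module that converts a trained Ludwig model to TorchScript and packs it into a Carton artifact via the `cartonml` package. Based on the argparse setup in the diff below, a typical invocation would look like `ludwig export_carton -m /path/to/trained/model -op /path/to/output` (paths illustrative).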
6 changes: 6 additions & 0 deletions ludwig/cli.py
@@ -49,6 +49,7 @@ def __init__(self):
   datasets              Downloads and lists Ludwig-ready datasets
   export_torchscript    Exports Ludwig models to Torchscript
   export_triton         Exports Ludwig models to Triton
   export_carton         Exports Ludwig models to Carton
   export_neuropod       Exports Ludwig models to Neuropod
   export_mlflow         Exports Ludwig models to MLflow
   preprocess            Preprocess data and saves it into HDF5 and JSON format
@@ -140,6 +141,11 @@ def export_triton(self):

        export.cli_export_triton(sys.argv[2:])

    def export_carton(self):
        from ludwig import export

        export.cli_export_carton(sys.argv[2:])

    def export_neuropod(self):
        from ludwig import export

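For context, Ludwig's top-level CLI resolves the first command-line argument to a method on its CLI class, which is why adding the `export_carton` method above is enough to expose the new sub-command. A minimal sketch of that dispatch pattern (illustrative, not the verbatim Ludwig implementation):

import sys

class CLI:
    def export_carton(self):
        from ludwig import export
        export.cli_export_carton(sys.argv[2:])

# `ludwig export_carton ...`: the sub-command name selects the method;
# the method itself reads the remaining arguments from sys.argv.
getattr(CLI(), sys.argv[1])()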
75 changes: 75 additions & 0 deletions ludwig/export.py
@@ -22,6 +22,7 @@
from ludwig.api import LudwigModel
from ludwig.contrib import add_contrib_callback_args
from ludwig.globals import LUDWIG_VERSION
from ludwig.utils.carton_utils import export_carton as utils_export_carton
from ludwig.utils.neuropod_utils import export_neuropod as utils_export_neuropod
from ludwig.utils.print_utils import get_logging_level_registry, print_ludwig
from ludwig.utils.triton_utils import export_triton as utils_export_triton
@@ -88,6 +89,32 @@ def export_triton(model_path, output_path="model_repository", model_name="ludwig
    logger.info(f"Saved to: {output_path}")


def export_carton(model_path, output_path="carton", model_name="carton", **kwargs):
    """Exports a model to Carton.

    # Inputs

    :param model_path: (str) filepath to pre-trained model.
    :param output_path: (str, default: `'carton'`) directory to store the
        carton model.
    :param model_name: (str, default: `'carton'`) save carton under this
        name.

    # Return

    :returns: (`None`)
    """
    logger.info(f"Model path: {model_path}")
    logger.info(f"Output path: {output_path}")
    logger.info("\n")

    model = LudwigModel.load(model_path)
    os.makedirs(output_path, exist_ok=True)
    utils_export_carton(model, output_path, model_name)

    logger.info(f"Saved to: {output_path}")


def export_neuropod(model_path, output_path="neuropod", model_name="neuropod", **kwargs):
    """Exports a model to Neuropod.

@@ -254,6 +281,52 @@ def cli_export_triton(sys_argv):
    export_triton(**vars(args))


def cli_export_carton(sys_argv):
    parser = argparse.ArgumentParser(
        description="This script loads a pretrained model " "and saves it as a Carton.",
        prog="ludwig export_carton",
        usage="%(prog)s [options]",
    )

    # ----------------
    # Model parameters
    # ----------------
    parser.add_argument("-m", "--model_path", help="model to load", required=True)
    parser.add_argument("-mn", "--model_name", help="model name", default="carton")

    # -----------------
    # Output parameters
    # -----------------
    parser.add_argument("-op", "--output_path", type=str, help="path where to save the export model", required=True)

    # ------------------
    # Runtime parameters
    # ------------------
    parser.add_argument(
        "-l",
        "--logging_level",
        default="info",
        help="the level of logging to use",
        choices=["critical", "error", "warning", "info", "debug", "notset"],
    )

    add_contrib_callback_args(parser)
    args = parser.parse_args(sys_argv)

    args.callbacks = args.callbacks or []
    for callback in args.callbacks:
        callback.on_cmdline("export_carton", *sys_argv)

    args.logging_level = get_logging_level_registry()[args.logging_level]
    logging.getLogger("ludwig").setLevel(args.logging_level)
    global logger
    logger = logging.getLogger("ludwig.export")

    print_ludwig("Export Carton", LUDWIG_VERSION)

    export_carton(**vars(args))


def cli_export_neuropod(sys_argv):
    parser = argparse.ArgumentParser(
        description="This script loads a pretrained model " "and saves it as a Neuropod.",
@@ -360,6 +433,8 @@ def cli_export_mlflow(sys_argv):
        cli_export_mlflow(sys.argv[2:])
    elif sys.argv[1] == "triton":
        cli_export_triton(sys.argv[2:])
    elif sys.argv[1] == "carton":
        cli_export_carton(sys.argv[2:])
    elif sys.argv[1] == "neuropod":
        cli_export_neuropod(sys.argv[2:])
    else:
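The export is also callable from Python. Given the `export_carton` function added above, a minimal usage sketch (model path and names illustrative):

from ludwig.export import export_carton

export_carton(
    model_path="results/experiment_run/model",  # directory produced by Ludwig training
    output_path="model.carton",
    model_name="my_model",
)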
133 changes: 133 additions & 0 deletions ludwig/utils/carton_utils.py
@@ -0,0 +1,133 @@
import asyncio
import importlib.util
import logging
import os
import shutil
import tempfile
from typing import Any, Dict, List

import torch

from ludwig.api import LudwigModel
from ludwig.api_annotations import DeveloperAPI
from ludwig.constants import NAME
from ludwig.types import ModelConfigDict
from ludwig.utils.fs_utils import open_file

logger = logging.getLogger(__name__)


INFERENCE_MODULE_TEMPLATE = """
from typing import Any, Dict, List, Tuple, Union

import torch

from ludwig.utils.types import TorchscriptPreprocessingInput


class GeneratedInferenceModule(torch.nn.Module):
    def __init__(self, inference_module):
        super().__init__()
        self.inference_module = inference_module

    def forward(self, inputs: Dict[str, Any]):
        retyped_inputs: Dict[str, TorchscriptPreprocessingInput] = {{}}
        for k, v in inputs.items():
            assert isinstance(v, TorchscriptPreprocessingInput)
            retyped_inputs[k] = v

        results = self.inference_module(retyped_inputs)
        return {output_dicts}
"""


def _get_output_dicts(config: ModelConfigDict) -> str:
    results = []
    for feature in config["output_features"]:
        name = feature[NAME]
        results.append(f'"{name}": results["{name}"]["predictions"]')
    return "{" + ", ".join(results) + "}"
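# Illustrative example (feature name hypothetical): for a config with a single
# output feature named "recommended", _get_output_dicts returns the string
#     '{"recommended": results["recommended"]["predictions"]}'
# which is substituted for {output_dicts} in INFERENCE_MODULE_TEMPLATE above.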


@DeveloperAPI
def generate_carton_torchscript(model: LudwigModel):
    config = model.config
    inference_module = model.to_torchscript()
    with tempfile.TemporaryDirectory() as tmpdir:
        ts_path = os.path.join(tmpdir, "generated.py")
        with open_file(ts_path, "w") as f:
            f.write(
                INFERENCE_MODULE_TEMPLATE.format(
                    output_dicts=_get_output_dicts(config),
                )
            )

        spec = importlib.util.spec_from_file_location("generated.ts", ts_path)
        gen_ts = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(gen_ts)

        gen_module = gen_ts.GeneratedInferenceModule(inference_module)
        scripted_module = torch.jit.script(gen_module)

    return scripted_module
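# Why write the module to disk and import it? torch.jit.script compiles a class
# from its Python source, so the model-specific return statement produced by
# _get_output_dicts has to exist as real code; generating a temp module and
# importing it with importlib is how this file achieves that.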


def _get_input_spec(model: LudwigModel) -> List[Dict[str, Any]]:
    from cartonml import TensorSpec

    spec = []
    for feature_name, feature in model.model.input_features.items():
        metadata = model.training_set_metadata[feature_name]
        spec.append(
            TensorSpec(
                name=feature.feature_name, dtype=feature.get_preproc_input_dtype(metadata), shape=("batch_size",)
            )
        )
    return spec


def _get_output_spec(model: LudwigModel) -> List[Dict[str, Any]]:
    from cartonml import TensorSpec

    spec = []
    for feature_name, feature in model.model.output_features.items():
        metadata = model.training_set_metadata[feature_name]
        spec.append(
            TensorSpec(
                name=feature.feature_name, dtype=feature.get_postproc_output_dtype(metadata), shape=("batch_size",)
            )
        )
    return spec


@DeveloperAPI
def export_carton(model: LudwigModel, carton_path: str, carton_model_name="ludwig_model"):
    try:
        import cartonml as carton
    except ImportError:
        raise RuntimeError('The "cartonml-nightly" package is not installed in your environment.')

    # Generate a torchscript model
    model_ts = generate_carton_torchscript(model)

    with tempfile.TemporaryDirectory() as tmpdir:
        # Save the model to a temp dir
        input_model_path = os.path.join(tmpdir, "model.pt")
        torch.jit.save(model_ts, input_model_path)

        # carton.pack is an async function so we run it and wait until it's complete.
        # See https://pyo3.rs/v0.20.0/ecosystem/async-await#a-note-about-asynciorun for why we wrap it
        # in another function.
        async def pack():
            return await carton.pack(
                input_model_path,
                runner_name="torchscript",
                # Any 2.x.x version is okay
                # TODO: improve this
                required_framework_version="=2",
                model_name=carton_model_name,
                inputs=_get_input_spec(model),
                outputs=_get_output_spec(model),
            )

        loop = asyncio.get_event_loop()
        tmp_out_path = loop.run_until_complete(pack())

        # Move it to the output path
        shutil.move(tmp_out_path, carton_path)
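Once packed, the artifact can be loaded back for inference with `cartonml`. A minimal loading sketch, assuming cartonml's async load/infer API and a model with a single text input feature named "text" (feature name and path illustrative):

import asyncio

import numpy as np
import cartonml as carton


async def main():
    # Load the packed model and run one inference request
    model = await carton.load("model.carton")
    out = await model.infer({"text": np.array(["hello world"])})
    print(out)

asyncio.run(main())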
1 change: 1 addition & 0 deletions requirements_serve.txt
@@ -3,3 +3,4 @@ httpx
fastapi
python-multipart
neuropod==0.3.0rc6 ; platform_system != "Windows" and python_version < '3.9'
cartonml-nightly