Skip to content

Commit

Permalink
Issue #182: added tools/get_model_signature.py
Browse files Browse the repository at this point in the history
  • Loading branch information
amesar committed Jun 3, 2024
1 parent 9177351 commit b7bec85
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 0 deletions.
33 changes: 33 additions & 0 deletions mlflow_export_import/tools/click_options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import click

def opt_input_file(function):
function = click.option("--input-file",
help="Input file.",
type=str,
required=True
)(function)
return function

def opt_output_file(function):
function = click.option("--output-file",
help="Output file.",
type=str,
required=False
)(function)
return function

def opt_model_uri(function):
function = click.option("--model-uri",
help="Model URI such as 'models:/my_model/3' or 'runs:/73ab168e5775409fa3595157a415bb62/my_model'.",
type=str,
required=True
)(function)
return function

def opt_filter(function):
function = click.option("--filter",
help="For OSS MLflow this is a filter for search_model_version(), for Databricks it is for search_registered_models() due to Databricks MLflow search limitations.",
type=str,
required=False
)(function)
return function
42 changes: 42 additions & 0 deletions mlflow_export_import/tools/get_model_signature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""
Get the signature of an MLflow model.
"""

import click
import mlflow
from mlflow_export_import.common import io_utils
from mlflow_export_import.common.dump_utils import dump_as_json
from . click_options import opt_model_uri, opt_output_file
from . tools_utils import to_json_signature


def get(model_uri):
model_info = mlflow.models.get_model_info(model_uri)
if model_info.signature:
sig = model_info.signature.to_dict()
return to_json_signature(sig)
else:
return None


@click.command()
@opt_model_uri
@opt_output_file
def main(model_uri, output_file):
"""
Get the signature of an MLflow model.
"""
print("Options:")
for k,v in locals().items():
print(f" {k}: {v}")
signature = get(model_uri)
if signature:
print("Model Signature:")
dump_as_json(signature)
if output_file:
io_utils.write_file(output_file, signature)
else:
print(f"WARNING: No model signature for '{model_uri}'")

if __name__ == "__main__":
main()

0 comments on commit b7bec85

Please sign in to comment.