-
Notifications
You must be signed in to change notification settings - Fork 81
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Issue #182: added tools/get_model_signature.py
- Loading branch information
amesar
committed
Jun 3, 2024
1 parent
9177351
commit b7bec85
Showing
2 changed files
with
75 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import click | ||
|
||
def opt_input_file(function): | ||
function = click.option("--input-file", | ||
help="Input file.", | ||
type=str, | ||
required=True | ||
)(function) | ||
return function | ||
|
||
def opt_output_file(function): | ||
function = click.option("--output-file", | ||
help="Output file.", | ||
type=str, | ||
required=False | ||
)(function) | ||
return function | ||
|
||
def opt_model_uri(function): | ||
function = click.option("--model-uri", | ||
help="Model URI such as 'models:/my_model/3' or 'runs:/73ab168e5775409fa3595157a415bb62/my_model'.", | ||
type=str, | ||
required=True | ||
)(function) | ||
return function | ||
|
||
def opt_filter(function): | ||
function = click.option("--filter", | ||
help="For OSS MLflow this is a filter for search_model_version(), for Databricks it is for search_registered_models() due to Databricks MLflow search limitations.", | ||
type=str, | ||
required=False | ||
)(function) | ||
return function |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
""" | ||
Get the signature of an MLflow model. | ||
""" | ||
|
||
import click | ||
import mlflow | ||
from mlflow_export_import.common import io_utils | ||
from mlflow_export_import.common.dump_utils import dump_as_json | ||
from . click_options import opt_model_uri, opt_output_file | ||
from . tools_utils import to_json_signature | ||
|
||
|
||
def get(model_uri): | ||
model_info = mlflow.models.get_model_info(model_uri) | ||
if model_info.signature: | ||
sig = model_info.signature.to_dict() | ||
return to_json_signature(sig) | ||
else: | ||
return None | ||
|
||
|
||
@click.command() | ||
@opt_model_uri | ||
@opt_output_file | ||
def main(model_uri, output_file): | ||
""" | ||
Get the signature of an MLflow model. | ||
""" | ||
print("Options:") | ||
for k,v in locals().items(): | ||
print(f" {k}: {v}") | ||
signature = get(model_uri) | ||
if signature: | ||
print("Model Signature:") | ||
dump_as_json(signature) | ||
if output_file: | ||
io_utils.write_file(output_file, signature) | ||
else: | ||
print(f"WARNING: No model signature for '{model_uri}'") | ||
|
||
if __name__ == "__main__": | ||
main() |