From 0f7573906516a47ee375a763469456064e022d14 Mon Sep 17 00:00:00 2001 From: sheridana Date: Mon, 21 Nov 2022 11:44:10 -0800 Subject: [PATCH] Add optional unragging arg to model export --- sleap/nn/inference.py | 53 ++++++++++++++++++++++++++++++++------ tests/nn/test_inference.py | 28 +++++++++++++++++--- 2 files changed, 69 insertions(+), 12 deletions(-) diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 3415874a4..41a3984ef 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -450,6 +450,7 @@ def export_model( save_traces: bool = True, model_name: Optional[str] = None, tensors: Optional[Dict[str, str]] = None, + unrag_outputs: bool = True, ): """Export a trained SLEAP model as a frozen graph. Initializes model, @@ -467,6 +468,8 @@ def export_model( model tensors: (Optional) Dictionary describing the predicted tensors (see sleap.nn.data.utils.describe_tensors as an example) + unrag_outputs: If `True` (default), any ragged tensors will be + converted to normal tensors and padded with NaNs """ @@ -485,7 +488,7 @@ def export_model( outputs = self.inference_model.predict(tracing_batch) self.inference_model.export_model( - save_path, signatures, save_traces, model_name, tensors + save_path, signatures, save_traces, model_name, tensors, unrag_outputs ) @@ -980,6 +983,7 @@ def export_model( save_traces: bool = True, model_name: Optional[str] = None, tensors: Optional[Dict[str, str]] = None, + unrag_outputs: bool = True, ): """Save the frozen graph of a model. 
@@ -994,6 +998,8 @@ def export_model( model tensors: (Optional) Dictionary describing the predicted tensors (see sleap.nn.data.utils.describe_tensors as an example) + unrag_outputs: If `True` (default), any ragged tensors will be + converted to normal tensors and padded with NaNs Notes: @@ -1021,7 +1027,11 @@ def export_model( if tensors: info["predicted_tensors"] = tensors - full_model = tf.function(lambda x: model(x)) + full_model = tf.function( + lambda x: sleap.nn.data.utils.unrag_example(model(x), numpy=False) + if unrag_outputs + else model(x) + ) full_model = full_model.get_concrete_function( tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype) @@ -1458,9 +1468,12 @@ def export_model( save_traces: bool = True, model_name: Optional[str] = None, tensors: Optional[Dict[str, str]] = None, + unrag_outputs: bool = True, ): - super().export_model(save_path, signatures, save_traces, model_name, tensors) + super().export_model( + save_path, signatures, save_traces, model_name, tensors, unrag_outputs + ) self.confmap_config.save_json(os.path.join(save_path, "confmap_config.json")) @@ -2376,9 +2389,12 @@ def export_model( save_traces: bool = True, model_name: Optional[str] = None, tensors: Optional[Dict[str, str]] = None, + unrag_outputs: bool = True, ): - super().export_model(save_path, signatures, save_traces, model_name, tensors) + super().export_model( + save_path, signatures, save_traces, model_name, tensors, unrag_outputs + ) if self.confmap_config is not None: self.confmap_config.save_json( @@ -3735,11 +3751,14 @@ def export_model( save_traces: bool = True, model_name: Optional[str] = None, tensors: Optional[Dict[str, str]] = None, + unrag_outputs: bool = True, ): self.instance_peaks.optimal_grouping = False - super().export_model(save_path, signatures, save_traces, model_name, tensors) + super().export_model( + save_path, signatures, save_traces, model_name, tensors, unrag_outputs + ) @attr.s(auto_attribs=True) @@ -4066,9 +4085,12 @@ def export_model( 
save_traces: bool = True, model_name: Optional[str] = None, tensors: Optional[Dict[str, str]] = None, + unrag_outputs: bool = True, ): - super().export_model(save_path, signatures, save_traces, model_name, tensors) + super().export_model( + save_path, signatures, save_traces, model_name, tensors, unrag_outputs + ) if self.confmap_config is not None: self.confmap_config.save_json( @@ -4192,6 +4214,7 @@ def export_model( save_traces: bool = True, model_name: Optional[str] = None, tensors: Optional[Dict[str, str]] = None, + unrag_outputs: bool = True, ): """High level export of a trained SLEAP model as a frozen graph. @@ -4207,9 +4230,13 @@ def export_model( output json file containing meta information about the model. tensors: (Optional) Dictionary describing the predicted tensors (see sleap.nn.data.utils.describe_tensors as an example). + unrag_outputs: If `True` (default), any ragged tensors will be + converted to normal tensors and padded with NaNs """ predictor = load_model(model_path) - predictor.export_model(save_path, signatures, save_traces, model_name, tensors) + predictor.export_model( + save_path, signatures, save_traces, model_name, tensors, unrag_outputs + ) def export_cli(): @@ -4236,9 +4263,19 @@ def export_cli(): "Defaults to a folder named 'exported_model'." ), ) + parser.add_argument( + "-u", + "--unrag", + type=int, + default=1, + help=( + "Convert ragged tensors into regular tensors with NaN padding. " + "Defaults to 1 (True); pass 0 to disable."
+ ), + ) args, _ = parser.parse_known_args() - export_model(args.models, args.export_path) + export_model(args.models, args.export_path, unrag_outputs=args.unrag) def _make_cli_parser() -> argparse.ArgumentParser: diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index b96fe867d..7a4dec870 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -998,10 +998,16 @@ def test_single_instance_predictor_save(min_single_instance_robot_model_path, tm predictor.export_model(save_path=tmp_path.as_posix()) - # high level export - + # high level export (with unragging) export_model(min_single_instance_robot_model_path, save_path=tmp_path.as_posix()) + # high level export (without unragging) + export_model( + min_single_instance_robot_model_path, + save_path=tmp_path.as_posix(), + unrag_outputs=False, + ) + def test_topdown_predictor_save( min_centroid_model_path, min_centered_instance_model_path, tmp_path @@ -1020,12 +1026,19 @@ def test_topdown_predictor_save( predictor.export_model(save_path=tmp_path.as_posix()) - # high level export + # high level export (with unragging) export_model( [min_centroid_model_path, min_centered_instance_model_path], save_path=tmp_path.as_posix(), ) + # high level export (without unragging) + export_model( + [min_centroid_model_path, min_centered_instance_model_path], + save_path=tmp_path.as_posix(), + unrag_outputs=False, + ) + def test_topdown_id_predictor_save( min_centroid_model_path, min_topdown_multiclass_model_path, tmp_path @@ -1044,10 +1057,17 @@ def test_topdown_id_predictor_save( predictor.export_model(save_path=tmp_path.as_posix()) - # high level export + # high level export (with unragging) + export_model( + [min_centroid_model_path, min_topdown_multiclass_model_path], + save_path=tmp_path.as_posix(), + ) + + # high level export (without unragging) export_model( [min_centroid_model_path, min_topdown_multiclass_model_path], save_path=tmp_path.as_posix(), + unrag_outputs=False, )