Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optional unragging arg to model export #1054

Merged
merged 1 commit into from
Nov 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 45 additions & 8 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,7 @@ def export_model(
save_traces: bool = True,
model_name: Optional[str] = None,
tensors: Optional[Dict[str, str]] = None,
unrag_outputs: bool = True,
):

"""Export a trained SLEAP model as a frozen graph. Initializes model,
Expand All @@ -467,6 +468,8 @@ def export_model(
model
tensors: (Optional) Dictionary describing the predicted tensors (see
sleap.nn.data.utils.describe_tensors as an example)
unrag_outputs: If `True` (default), any ragged tensors will be
converted to normal tensors and padded with NaNs

"""

Expand All @@ -485,7 +488,7 @@ def export_model(
outputs = self.inference_model.predict(tracing_batch)

self.inference_model.export_model(
save_path, signatures, save_traces, model_name, tensors
save_path, signatures, save_traces, model_name, tensors, unrag_outputs
)


Expand Down Expand Up @@ -980,6 +983,7 @@ def export_model(
save_traces: bool = True,
model_name: Optional[str] = None,
tensors: Optional[Dict[str, str]] = None,
unrag_outputs: bool = True,
):
"""Save the frozen graph of a model.

Expand All @@ -994,6 +998,8 @@ def export_model(
model
tensors: (Optional) Dictionary describing the predicted tensors (see
sleap.nn.data.utils.describe_tensors as an example)
unrag_outputs: If `True` (default), any ragged tensors will be
converted to normal tensors and padded with NaNs


Notes:
Expand Down Expand Up @@ -1021,7 +1027,11 @@ def export_model(
if tensors:
info["predicted_tensors"] = tensors

full_model = tf.function(lambda x: model(x))
full_model = tf.function(
lambda x: sleap.nn.data.utils.unrag_example(model(x), numpy=False)
if unrag_outputs
else model(x)
)

full_model = full_model.get_concrete_function(
tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype)
Expand Down Expand Up @@ -1458,9 +1468,12 @@ def export_model(
save_traces: bool = True,
model_name: Optional[str] = None,
tensors: Optional[Dict[str, str]] = None,
unrag_outputs: bool = True,
):

super().export_model(save_path, signatures, save_traces, model_name, tensors)
super().export_model(
save_path, signatures, save_traces, model_name, tensors, unrag_outputs
)

self.confmap_config.save_json(os.path.join(save_path, "confmap_config.json"))

Expand Down Expand Up @@ -2376,9 +2389,12 @@ def export_model(
save_traces: bool = True,
model_name: Optional[str] = None,
tensors: Optional[Dict[str, str]] = None,
unrag_outputs: bool = True,
):

super().export_model(save_path, signatures, save_traces, model_name, tensors)
super().export_model(
save_path, signatures, save_traces, model_name, tensors, unrag_outputs
)

if self.confmap_config is not None:
self.confmap_config.save_json(
Expand Down Expand Up @@ -3735,11 +3751,14 @@ def export_model(
save_traces: bool = True,
model_name: Optional[str] = None,
tensors: Optional[Dict[str, str]] = None,
unrag_outputs: bool = True,
):

self.instance_peaks.optimal_grouping = False

super().export_model(save_path, signatures, save_traces, model_name, tensors)
super().export_model(
save_path, signatures, save_traces, model_name, tensors, unrag_outputs
)


@attr.s(auto_attribs=True)
Expand Down Expand Up @@ -4066,9 +4085,12 @@ def export_model(
save_traces: bool = True,
model_name: Optional[str] = None,
tensors: Optional[Dict[str, str]] = None,
unrag_outputs: bool = True,
):

super().export_model(save_path, signatures, save_traces, model_name, tensors)
super().export_model(
save_path, signatures, save_traces, model_name, tensors, unrag_outputs
)

if self.confmap_config is not None:
self.confmap_config.save_json(
Expand Down Expand Up @@ -4192,6 +4214,7 @@ def export_model(
save_traces: bool = True,
model_name: Optional[str] = None,
tensors: Optional[Dict[str, str]] = None,
unrag_outputs: bool = True,
):
"""High level export of a trained SLEAP model as a frozen graph.

Expand All @@ -4207,9 +4230,13 @@ def export_model(
output json file containing meta information about the model.
tensors: (Optional) Dictionary describing the predicted tensors (see
sleap.nn.data.utils.describe_tensors as an example).
unrag_outputs: If `True` (default), any ragged tensors will be
converted to normal tensors and padded with NaNs
"""
predictor = load_model(model_path)
predictor.export_model(save_path, signatures, save_traces, model_name, tensors)
predictor.export_model(
save_path, signatures, save_traces, model_name, tensors, unrag_outputs
)


def export_cli():
Expand All @@ -4236,9 +4263,19 @@ def export_cli():
"Defaults to a folder named 'exported_model'."
),
)
parser.add_argument(
"-u",
"--unrag",
action="store_true",
default=True,
help=(
"Convert ragged tensors into regular tensors with NaN padding. "
"Defaults to True."
),
)

args, _ = parser.parse_known_args()
export_model(args.models, args.export_path)
export_model(args.models, args.export_path, unrag_outputs=args.unrag)


def _make_cli_parser() -> argparse.ArgumentParser:
Expand Down
28 changes: 24 additions & 4 deletions tests/nn/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,10 +998,16 @@ def test_single_instance_predictor_save(min_single_instance_robot_model_path, tm

predictor.export_model(save_path=tmp_path.as_posix())

# high level export

# high level export (with unragging)
export_model(min_single_instance_robot_model_path, save_path=tmp_path.as_posix())

# high level export (without unragging)
export_model(
min_single_instance_robot_model_path,
save_path=tmp_path.as_posix(),
unrag_outputs=False,
)


def test_topdown_predictor_save(
min_centroid_model_path, min_centered_instance_model_path, tmp_path
Expand All @@ -1020,12 +1026,19 @@ def test_topdown_predictor_save(

predictor.export_model(save_path=tmp_path.as_posix())

# high level export
# high level export (with unragging)
export_model(
[min_centroid_model_path, min_centered_instance_model_path],
save_path=tmp_path.as_posix(),
)

# high level export (without unragging)
export_model(
[min_centroid_model_path, min_centered_instance_model_path],
save_path=tmp_path.as_posix(),
unrag_outputs=False,
)


def test_topdown_id_predictor_save(
min_centroid_model_path, min_topdown_multiclass_model_path, tmp_path
Expand All @@ -1044,10 +1057,17 @@ def test_topdown_id_predictor_save(

predictor.export_model(save_path=tmp_path.as_posix())

# high level export
# high level export (with unragging)
export_model(
[min_centroid_model_path, min_topdown_multiclass_model_path],
save_path=tmp_path.as_posix(),
)

# high level export (without unragging)
export_model(
[min_centroid_model_path, min_topdown_multiclass_model_path],
save_path=tmp_path.as_posix(),
unrag_outputs=False,
)


Expand Down