Allow retracking of old predictions #975

Closed · wants to merge 5 commits
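This PR adds a tracking-only path: when no trained models are selected but a tracker is chosen, previously saved predictions are loaded and re-tracked instead of running inference. A rough sketch of how that path could be driven from the command line using the flags added in this diff; the file names, frame range, and tracker choice are placeholders, and the subprocess wrapper is only for illustration.

import subprocess

# Hypothetical tracking-only run: no models are given, --labels points at the
# existing predictions, --use_prediction_file (new in this PR) tells
# sleap-track to re-track those instances, and --tracking.tracker selects the
# tracker. All paths and values below are placeholders.
subprocess.run(
    [
        "sleap-track",
        "session1.mp4",                     # video whose frames get re-tracked
        "--labels", "old_predictions.slp",  # previously saved predictions
        "--use_prediction_file",
        "--tracking.tracker", "simple",
        "--frames", "0-1000",               # optional frame range
    ],
    check=True,
)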
12 changes: 11 additions & 1 deletion sleap/config/pipeline_form.yaml
@@ -197,7 +197,7 @@ inference:
label: Training/Inference Pipeline Type
type: stacked
default: "multi-animal bottom-up "
options: "multi-animal bottom-up,multi-animal top-down,single animal,none"
options: "multi-animal bottom-up,multi-animal top-down,single animal,load predictions,none"

multi-animal bottom-up:
- type: text
@@ -224,6 +224,16 @@
For predicting on videos with more than one animal per frame, use a
multi-animal pipeline (even if your training data has one instance per frame).'

load predictions:
- type: bool
label: Use predictions from other file
name: use_prediction_file
default: false
- type: file_open
label: Predictions file
name: _predicted_labels
filter: '*.slp'

none:

- name: tracking.tracker
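The two new form fields above feed the GUI's inference parameters under the names that runners.py reads back below; a minimal sketch of the assumed resulting dictionary (the keys come from this diff, the values are examples only):

inference_params = {
    "use_prediction_file": True,              # bool field above
    "_predicted_labels": "other_preds.slp",   # file_open field above
    "tracking.tracker": "simple",             # existing tracker selection
}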
20 changes: 17 additions & 3 deletions sleap/gui/learning/runners.py
@@ -203,13 +203,27 @@ def make_predict_cli_call(
cli_args.extend(item_for_inference.cli_args)

# TODO: encapsulate in inference item class
only_tracking = False
_labels = self.inference_params.get("_predicted_labels", None)
# Make the 'use_prediction_file' arg value-less.
use_prediction_file = False
if "use_prediction_file" in self.inference_params:
use_prediction_file = self.inference_params["use_prediction_file"]
del self.inference_params["use_prediction_file"]

if (
not self.trained_job_paths
and "tracking.tracker" in self.inference_params
and self.labels_filename
):
# No models so we must want to re-track previous predictions
cli_args.extend(("--labels", self.labels_filename))
only_tracking = True
if use_prediction_file and _labels:
cli_args.extend(("--use_prediction_file",))
else:
# Use the project filename for labels
_labels = self.labels_filename
cli_args.extend(("--labels", _labels))

# Make path where we'll save predictions (if not specified)
if output_path is None:
@@ -227,10 +241,10 @@

# Build filename with video name and timestamp
timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
mode = "predictions" if not only_tracking else "tracking"
output_path = os.path.join(
predictions_dir,
f"{os.path.basename(item_for_inference.path)}.{timestamp}."
"predictions.slp",
f"{os.path.basename(item_for_inference.path)}.{timestamp}.{mode}.slp",
)

for job_path in self.trained_job_paths:
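With the new mode switch above, tracking-only runs get a .tracking.slp default output name instead of .predictions.slp. A self-contained sketch of that naming, mirroring the changed lines (directory and video path are placeholders):

import os
from datetime import datetime

only_tracking = True             # set when re-tracking without trained models
predictions_dir = "predictions"  # placeholder output directory
item_path = "session1.mp4"       # placeholder video path

timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
mode = "predictions" if not only_tracking else "tracking"
output_path = os.path.join(
    predictions_dir,
    f"{os.path.basename(item_path)}.{timestamp}.{mode}.slp",
)
print(output_path)  # e.g. predictions/session1.mp4.240101_120000.tracking.slp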
58 changes: 41 additions & 17 deletions sleap/nn/inference.py
@@ -88,17 +88,6 @@ def get_keras_model_path(path: Text) -> str:
return os.path.join(path, "best_model.h5")


class RateColumn(rich.progress.ProgressColumn):
"""Renders the progress rate."""

def render(self, task: "Task") -> rich.progress.Text:
"""Show progress rate."""
speed = task.speed
if speed is None:
return rich.progress.Text("?", style="progress.data.speed")
return rich.progress.Text(f"{speed:.1f} FPS", style="progress.data.speed")


@attr.s(auto_attribs=True)
class Predictor(ABC):
"""Base interface class for predictors."""
@@ -4318,6 +4307,17 @@ def _make_cli_parser() -> argparse.ArgumentParser:
"saving to output."
),
)
parser.add_argument(
"--use_prediction_file",
action="store_true",
default=False,
help=(
"For tracking-only, use the --labels argument to define the file "
"containing the predicted instances to be tracked. "
"If this option is not defined, use all the frames from the project "
"or the specified frames in the specified video."
),
)
parser.add_argument(
"--verbosity",
type=str,
@@ -4474,7 +4474,8 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]:
"""
# Figure out which input path to use.
labels_path = getattr(args, "labels", None)
if labels_path is not None:
has_models = getattr(args, "models", None)
if labels_path is not None and has_models:
data_path = labels_path
else:
data_path = args.data_path
@@ -4645,6 +4646,7 @@ def main(args: list = None):

# Setup data loader.
provider, data_path = _make_provider_from_cli(args)
n_infer = len(provider)

# Setup tracker.
tracker = _make_tracker_from_cli(args)
@@ -4668,11 +4670,33 @@
elif getattr(args, "tracking.tracker") is not None:
# Load predictions
print("Loading predictions...")
labels_pr = sleap.load_file(args.data_path)
frames = sorted(labels_pr.labeled_frames, key=lambda lf: lf.frame_idx)
use_labels_predictions = getattr(args, "use_prediction_file", False)
labels_path = getattr(args, "labels", None)
if use_labels_predictions and labels_path:
labels_pr = sleap.load_file(labels_path)
else:
labels_pr = sleap.load_file(args.data_path)

if isinstance(provider, VideoReader):
fr_list = frame_list(args.frames)
labels_pr = [
lf
for lf in labels_pr
if lf.video.filename == provider.video.filename
and lf.frame_idx in fr_list
]
elif isinstance(provider, LabelsReader):
labels_pr = provider.labels.labeled_frames
else:
mess = f"Labels should be user defined or coming from a video: {type(provider)}"
raise ValueError(mess)

# Sort
frames = sorted(labels_pr, key=lambda lf: lf.frame_idx)
n_infer = len(frames)
print(f"... found {n_infer} frames to track")

print("Starting tracker...")
frames = run_tracker(frames=frames, tracker=tracker)
frames = run_tracker(frames=frames, tracker=tracker, verbosity=args.verbosity)
tracker.final_pass(frames)

labels_pr = Labels(labeled_frames=frames)
@@ -4696,7 +4720,7 @@
total_elapsed = time() - t0
print("Finished inference at:", finish_timestamp)
print(f"Total runtime: {total_elapsed} secs")
print(f"Predicted frames: {len(labels_pr)}/{len(provider)}")
print(f"Predicted frames: {len(labels_pr)}/{n_infer}")

# Add provenance metadata to predictions.
labels_pr.provenance["sleap_version"] = sleap.__version__
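For reference, a condensed Python-API sketch of what the tracking-only branch of main() amounts to when re-tracking an existing predictions file. The tracker construction (Tracker.make_tracker_by_name) and the save call are assumptions about the SLEAP API rather than part of this diff; file names are placeholders.

import sleap
from sleap.nn.inference import run_tracker
from sleap.nn.tracking import Tracker

# Load previously saved predictions (what --labels with --use_prediction_file
# points at) and sort frames by index, as main() does above.
labels_pr = sleap.load_file("old_predictions.slp")
frames = sorted(labels_pr.labeled_frames, key=lambda lf: lf.frame_idx)

# Assumed helper for building a tracker by name (mirrors --tracking.tracker).
tracker = Tracker.make_tracker_by_name(tracker="simple")
frames = run_tracker(frames=frames, tracker=tracker)
tracker.final_pass(frames)

# Wrap the re-tracked frames and save them to a new file.
retracked = sleap.Labels(labeled_frames=frames)
sleap.Labels.save_file(retracked, "session1.retracked.slp")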