Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allowing inference on multiple videos via sleap-track #1784

Merged
merged 29 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
132a1ea
implementing proposed code changes from issue #1777
emdavis02 May 24, 2024
e867ec0
comments
emdavis02 May 24, 2024
babaa77
configuring output_path to support multiple video inputs
emdavis02 May 24, 2024
83f444a
fixing errors from preexisting test cases
emdavis02 May 24, 2024
b0ac880
Test case / code fixes
emdavis02 Jun 12, 2024
dcc7a63
extending test cases for mp4 folders
emdavis02 Jun 13, 2024
35db452
test case for output directory
emdavis02 Jun 18, 2024
6f0c929
black and code rabbit fixes
emdavis02 Jun 24, 2024
bd2b016
code rabbit fixes
emdavis02 Jun 24, 2024
ec4c26d
as_posix errors resolved
emdavis02 Jun 27, 2024
abdc57c
syntax error
emdavis02 Jul 8, 2024
5ffdc96
adding test data
emdavis02 Jul 8, 2024
f179f5e
black
emdavis02 Jul 8, 2024
af565cb
output error resolved
emdavis02 Jul 8, 2024
8568cc3
edited for push to dev branch
emdavis02 Jul 8, 2024
ead7af8
black
emdavis02 Jul 8, 2024
8f0df1c
errors fixed, test cases implemented
emdavis02 Jul 8, 2024
760059f
invalid output test and invalid input test
emdavis02 Jul 9, 2024
ff706d8
deleting debugging statements
emdavis02 Jul 9, 2024
beb5e1e
deleting print statements
emdavis02 Jul 9, 2024
55bfe4b
black
emdavis02 Jul 10, 2024
3b9cd45
deleting unnecessary test case
emdavis02 Jul 10, 2024
be02a7d
implemented tmpdir
emdavis02 Jul 10, 2024
6a481c3
deleting extraneous file
emdavis02 Jul 10, 2024
488edde
fixing broken test case
emdavis02 Jul 12, 2024
4443686
fixing test_sleap_track_invalid_output
emdavis02 Jul 12, 2024
d86123d
removing support for multiple slp files
emdavis02 Jul 15, 2024
ae11b8d
implementing talmo's comments
emdavis02 Jul 15, 2024
fb587e5
adding comments
emdavis02 Jul 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
287 changes: 201 additions & 86 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5288,46 +5288,79 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]:
A tuple of `(provider, data_path)` with the data `Provider` and path to the data
that was specified in the args.
"""

# Figure out which input path to use.
labels_path = getattr(args, "labels", None)
if labels_path is not None:
data_path = labels_path
else:
data_path = args.data_path
data_path = args.data_path

if data_path is None or data_path == "":
raise ValueError(
"You must specify a path to a video or a labels dataset. "
"Run 'sleap-track -h' to see full command documentation."
)

if data_path.endswith(".slp"):
labels = sleap.load_file(data_path)

if args.only_labeled_frames:
provider = LabelsReader.from_user_labeled_frames(labels)
elif args.only_suggested_frames:
provider = LabelsReader.from_unlabeled_suggestions(labels)
elif getattr(args, "video.index") != "":
provider = VideoReader(
video=labels.videos[int(getattr(args, "video.index"))],
example_indices=frame_list(args.frames),
)
else:
provider = LabelsReader(labels)
data_path_obj = Path(data_path)

# Check for multiple video inputs
# Compile file(s) into a list for later itteration
if data_path_obj.is_dir():
data_path_list = []
for file_path in data_path_obj.iterdir():
if file_path.is_file():
data_path_list.append(Path(file_path))
elif data_path_obj.is_file():
data_path_list = [data_path_obj]

# Provider list to accomodate multiple video inputs
output_provider_list = []
output_data_path_list = []
for file_path in data_path_list:
# Create a provider for each file
if file_path.as_posix().endswith(".slp") and len(data_path_list) > 1:
print(f"slp file skipped: {file_path.as_posix()}")

elif file_path.as_posix().endswith(".slp"):
labels = sleap.load_file(file_path.as_posix())

if args.only_labeled_frames:
output_provider_list.append(
LabelsReader.from_user_labeled_frames(labels)
)
elif args.only_suggested_frames:
output_provider_list.append(
LabelsReader.from_unlabeled_suggestions(labels)
)
elif getattr(args, "video.index") != "":
output_provider_list.append(
VideoReader(
video=labels.videos[int(getattr(args, "video.index"))],
example_indices=frame_list(args.frames),
)
)
else:
output_provider_list.append(LabelsReader(labels))

else:
print(f"Video: {data_path}")
# TODO: Clean this up.
video_kwargs = dict(
dataset=vars(args).get("video.dataset"),
input_format=vars(args).get("video.input_format"),
)
provider = VideoReader.from_filepath(
filename=data_path, example_indices=frame_list(args.frames), **video_kwargs
)
output_data_path_list.append(file_path)

else:
try:
video_kwargs = dict(
dataset=vars(args).get("video.dataset"),
input_format=vars(args).get("video.input_format"),
)
output_provider_list.append(
VideoReader.from_filepath(
filename=file_path.as_posix(),
example_indices=frame_list(args.frames),
**video_kwargs,
)
)
print(f"Video: {file_path.as_posix()}")
output_data_path_list.append(file_path)
# TODO: Clean this up.
except Exception:
print(f"Error reading file: {file_path.as_posix()}")
Comment on lines +5349 to +5365
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Handle exceptions when reading files.

The code should handle exceptions more gracefully when reading files to avoid partial processing.

+ import logging

  for file_path in data_path_list:
      try:
          # Create a provider for each file
          if file_path.as_posix().endswith(".slp"):
              labels = sleap.load_file(file_path.as_posix())
              ...
          else:
              video_kwargs = dict(
                  dataset=vars(args).get("video.dataset"),
                  input_format=vars(args).get("video.input_format"),
              )
              output_provider_list.append(
                  VideoReader.from_filepath(
                      filename=file_path.as_posix(),
                      example_indices=frame_list(args.frames),
                      **video_kwargs,
                  )
              )
              print(f"Video: {file_path.as_posix()}")
              output_data_path_list.append(file_path)
      except Exception as e:
          logging.error(f"Error reading file: {file_path.as_posix()} - {e}")

Committable suggestion was skipped due to low confidence.


return provider, data_path
return output_provider_list, output_data_path_list


def _make_predictor_from_cli(args: argparse.Namespace) -> Predictor:
Expand Down Expand Up @@ -5422,8 +5455,6 @@ def main(args: Optional[list] = None):
pprint(vars(args))
print()

output_path = args.output

# Setup devices.
if args.cpu or not sleap.nn.system.is_gpu_system():
sleap.nn.system.use_cpu_only()
Expand Down Expand Up @@ -5461,43 +5492,160 @@ def main(args: Optional[list] = None):
print()

# Setup data loader.
provider, data_path = _make_provider_from_cli(args)
provider_list, data_path_list = _make_provider_from_cli(args)

# Setup tracker.
tracker = _make_tracker_from_cli(args)

output_path = args.output

if args.models is not None and "movenet" in args.models[0]:
args.models = args.models[0]

# Either run inference (and tracking) or just run tracking
# Either run inference (and tracking) or just run tracking (if using an existing prediction where inference has already been run)
if args.models is not None:
# Setup models.
predictor = _make_predictor_from_cli(args)
predictor.tracker = tracker

# Run inference!
labels_pr = predictor.predict(provider)
# Run inference on all files inputed
for data_path, provider in zip(data_path_list, provider_list):
# Setup models.
data_path_obj = Path(data_path)
predictor = _make_predictor_from_cli(args)
predictor.tracker = tracker

# Run inference!
labels_pr = predictor.predict(provider)

if output_path is None:
output_path = data_path + ".predictions.slp"
# if output path was not provided, create an output path
if output_path is None:
output_path = f"{data_path.as_posix()}.predictions.slp"
output_path_obj = Path(output_path)

else:
output_path_obj = Path(output_path)
if output_path_obj.is_file() and len(data_path_list) > 1:
raise ValueError(
"output_path argument must be a directory if multiple video inputs are given"
)

# if output_path was provided and multiple inputs were provided, create a directory to store outputs
if len(data_path_list) > 1:
output_path = (
output_path_obj
/ data_path_obj.with_suffix(".predictions.slp").name
)
output_path_obj = Path(output_path)

labels_pr.provenance["model_paths"] = predictor.model_paths
labels_pr.provenance["predictor"] = type(predictor).__name__
labels_pr.provenance["model_paths"] = predictor.model_paths
labels_pr.provenance["predictor"] = type(predictor).__name__

if args.no_empty_frames:
# Clear empty frames if specified.
labels_pr.remove_empty_frames()

finish_timestamp = str(datetime.now())
total_elapsed = time() - t0
print("Finished inference at:", finish_timestamp)
print(f"Total runtime: {total_elapsed} secs")
print(f"Predicted frames: {len(labels_pr)}/{len(provider)}")

# Add provenance metadata to predictions.
labels_pr.provenance["sleap_version"] = sleap.__version__
labels_pr.provenance["platform"] = platform.platform()
labels_pr.provenance["command"] = " ".join(sys.argv)
labels_pr.provenance["data_path"] = data_path_obj.as_posix()
labels_pr.provenance["output_path"] = output_path_obj.as_posix()
labels_pr.provenance["total_elapsed"] = total_elapsed
labels_pr.provenance["start_timestamp"] = start_timestamp
labels_pr.provenance["finish_timestamp"] = finish_timestamp

print("Provenance:")
pprint(labels_pr.provenance)
print()

labels_pr.provenance["args"] = vars(args)

# Save results.
labels_pr.save(output_path)
print("Saved output:", output_path)

if args.open_in_gui:
subprocess.call(["sleap-label", output_path])

# Reset output_path for next iteration
output_path = args.output

# running tracking on existing prediction file
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Improve error handling for missing tracker specification.

The error message is clear, but consider providing more guidance or a direct link to documentation on how to specify a tracker.

raise ValueError(
    "To retrack on predictions, must specify tracker. "
    "Use 'sleap-track --tracking.tracker ...' to specify tracker to use. "
    "See [documentation link] for more details."
)

elif getattr(args, "tracking.tracker") is not None:
# Load predictions
print("Loading predictions...")
labels_pr = sleap.load_file(args.data_path)
frames = sorted(labels_pr.labeled_frames, key=lambda lf: lf.frame_idx)
for data_path, provider in zip(data_path_list, provider_list):
# Load predictions
data_path_obj = Path(data_path)
print("Loading predictions...")
labels_pr = sleap.load_file(data_path_obj.as_posix())
frames = sorted(labels_pr.labeled_frames, key=lambda lf: lf.frame_idx)

print("Starting tracker...")
frames = run_tracker(frames=frames, tracker=tracker)
tracker.final_pass(frames)

print("Starting tracker...")
frames = run_tracker(frames=frames, tracker=tracker)
tracker.final_pass(frames)
labels_pr = Labels(labeled_frames=frames)

labels_pr = Labels(labeled_frames=frames)
if output_path is None:
output_path = f"{data_path}.{tracker.get_name()}.slp"
output_path_obj = Path(output_path)

if output_path is None:
output_path = f"{data_path}.{tracker.get_name()}.slp"
else:
output_path_obj = Path(output_path)
if (
output_path_obj.exists()
and output_path_obj.is_file()
and len(data_path_list) > 1
):
raise ValueError(
"output_path argument must be a directory if multiple video inputs are given"
)

elif not output_path_obj.exists() and len(data_path_list) > 1:
output_path = output_path_obj / data_path_obj.with_suffix(
".predictions.slp"
)
output_path_obj = Path(output_path)
output_path_obj.parent.mkdir(exist_ok=True, parents=True)

if args.no_empty_frames:
# Clear empty frames if specified.
labels_pr.remove_empty_frames()

finish_timestamp = str(datetime.now())
total_elapsed = time() - t0
print("Finished inference at:", finish_timestamp)
print(f"Total runtime: {total_elapsed} secs")
print(f"Predicted frames: {len(labels_pr)}/{len(provider)}")

# Add provenance metadata to predictions.
labels_pr.provenance["sleap_version"] = sleap.__version__
labels_pr.provenance["platform"] = platform.platform()
labels_pr.provenance["command"] = " ".join(sys.argv)
labels_pr.provenance["data_path"] = data_path_obj.as_posix()
labels_pr.provenance["output_path"] = output_path_obj.as_posix()
labels_pr.provenance["total_elapsed"] = total_elapsed
labels_pr.provenance["start_timestamp"] = start_timestamp
labels_pr.provenance["finish_timestamp"] = finish_timestamp

print("Provenance:")
pprint(labels_pr.provenance)
print()

labels_pr.provenance["args"] = vars(args)

# Save results.
labels_pr.save(output_path)
print("Saved output:", output_path)

if args.open_in_gui:
subprocess.call(["sleap-label", output_path])

# Reset output_path for next iteration
output_path = args.output
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add output folder as an option when the input is a folder.

This suggestion addresses the need to specify an output folder when processing multiple video files from a directory.

Do you want me to generate the code to implement this feature or open a GitHub issue to track this task?

Comment on lines +5659 to +5660
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resetting output_path might lead to unintended behavior.

The reset of output_path at the end of each iteration could lead to unexpected behavior in subsequent iterations, especially if args.output is modified during the process.

+    original_output_path = args.output
     for data_path, provider in zip(data_path_list, provider_list):
         ...
-        output_path = args.output
+        output_path = original_output_path
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# Reset output_path for next iteration
output_path = args.output
original_output_path = args.output
for data_path, provider in zip(data_path_list, provider_list):
...
output_path = original_output_path


else:
raise ValueError(
Expand All @@ -5506,36 +5654,3 @@ def main(args: Optional[list] = None):
"To retrack on predictions, must specify tracker. "
"Use \"sleap-track --tracking.tracker ...' to specify tracker to use."
)

if args.no_empty_frames:
# Clear empty frames if specified.
labels_pr.remove_empty_frames()

finish_timestamp = str(datetime.now())
total_elapsed = time() - t0
print("Finished inference at:", finish_timestamp)
print(f"Total runtime: {total_elapsed} secs")
print(f"Predicted frames: {len(labels_pr)}/{len(provider)}")

# Add provenance metadata to predictions.
labels_pr.provenance["sleap_version"] = sleap.__version__
labels_pr.provenance["platform"] = platform.platform()
labels_pr.provenance["command"] = " ".join(sys.argv)
labels_pr.provenance["data_path"] = data_path
labels_pr.provenance["output_path"] = output_path
labels_pr.provenance["total_elapsed"] = total_elapsed
labels_pr.provenance["start_timestamp"] = start_timestamp
labels_pr.provenance["finish_timestamp"] = finish_timestamp

print("Provenance:")
pprint(labels_pr.provenance)
print()

labels_pr.provenance["args"] = vars(args)

# Save results.
labels_pr.save(output_path)
print("Saved output:", output_path)

if args.open_in_gui:
subprocess.call(["sleap-label", output_path])
Loading