initial changes
emdavis02 committed Jul 22, 2024
1 parent 3e2bd25 commit 2310b0e
Showing 1 changed file with 58 additions and 26 deletions.
84 changes: 58 additions & 26 deletions sleap/nn/inference.py
@@ -5285,8 +5285,8 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]:
args: Parsed CLI namespace.
Returns:
A tuple of `(provider, data_path)` with the data `Provider` and path to the data
that was specified in the args.
A tuple of `(provider_list, data_path_list, output_path_list)` with the data `Provider`s, the paths to the data
that were specified in the args, and a list of output paths if a CSV file was provided.
"""

# Figure out which input path to use.
@@ -5299,72 +5299,96 @@ def _make_provider_from_cli(args: argparse.Namespace) -> Tuple[Provider, str]:
)

data_path_obj = Path(data_path)

# Default output_path_list to None; it is only populated when a CSV supplies output paths
output_path_list = None

# Check that input value is valid
if not data_path_obj.exists():
raise ValueError("Path to data_path does not exist")

elif data_path_obj.suffix.lower() == ".csv":
try:
# Read the CSV file
df = pandas.read_csv(data_path)

# Check if the 'data_path' and 'output_path' columns exist
if "data_path" in df.columns:
raw_data_path_list = df["data_path"].tolist()
else:
raise ValueError("Column 'data_path' does not exist in the provided csv file.")
if "output_path" in df.columns:
output_path_list = df["output_path"].tolist()

except FileNotFoundError as e:
raise ValueError(f"CSV file not found: {data_path}") from e
except pandas.errors.EmptyDataError as e:
raise ValueError(f"CSV file is empty: {data_path}") from e
except pandas.errors.ParserError as e:
raise ValueError(f"Error parsing CSV file: {data_path}") from e

# Check for multiple video inputs
# Compile file(s) into a list for later iteration
if data_path_obj.is_dir():
data_path_list = []
elif data_path_obj.is_dir():
raw_data_path_list = []
for file_path in data_path_obj.iterdir():
if file_path.is_file():
data_path_list.append(Path(file_path))
raw_data_path_list.append(Path(file_path))

elif data_path_obj.is_file():
data_path_list = [data_path_obj]
raw_data_path_list = [data_path_obj]

# Provider list to accommodate multiple video inputs
output_provider_list = []
output_data_path_list = []
for file_path in data_path_list:
provider_list = []
data_path_list = []
for file_path in raw_data_path_list:
# Create a provider for each file
if file_path.as_posix().endswith(".slp") and len(data_path_list) > 1:
if file_path.as_posix().endswith(".slp") and len(raw_data_path_list) > 1:
print(f"slp file skipped: {file_path.as_posix()}")

elif file_path.as_posix().endswith(".slp"):
labels = sleap.load_file(file_path.as_posix())

if args.only_labeled_frames:
output_provider_list.append(
provider_list.append(
LabelsReader.from_user_labeled_frames(labels)
)
elif args.only_suggested_frames:
output_provider_list.append(
provider_list.append(
LabelsReader.from_unlabeled_suggestions(labels)
)
elif getattr(args, "video.index") != "":
output_provider_list.append(
provider_list.append(
VideoReader(
video=labels.videos[int(getattr(args, "video.index"))],
example_indices=frame_list(args.frames),
)
)
else:
output_provider_list.append(LabelsReader(labels))
provider_list.append(LabelsReader(labels))

output_data_path_list.append(file_path)
data_path_list.append(file_path)

else:
try:
video_kwargs = dict(
dataset=vars(args).get("video.dataset"),
input_format=vars(args).get("video.input_format"),
)
output_provider_list.append(
provider_list.append(
VideoReader.from_filepath(
filename=file_path.as_posix(),
example_indices=frame_list(args.frames),
**video_kwargs,
)
)
print(f"Video: {file_path.as_posix()}")
output_data_path_list.append(file_path)
data_path_list.append(file_path)
# TODO: Clean this up.
except Exception:
print(f"Error reading file: {file_path.as_posix()}")

return output_provider_list, output_data_path_list
return provider_list, data_path_list, output_path_list
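# A usage sketch of the new three-value return (assumed; it mirrors the call
# in main() below). output_path_list is None unless the input was a CSV with
# an "output_path" column:
#
# provider_list, data_path_list, output_path_list = _make_provider_from_cli(args)
# for i, (data_path, provider) in enumerate(zip(data_path_list, provider_list)):
#     output_path = output_path_list[i] if output_path_list is not None else args.output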


def _make_predictor_from_cli(args: argparse.Namespace) -> Predictor:
@@ -5496,10 +5520,12 @@ def main(args: Optional[list] = None):
print()

# Setup data loader.
provider_list, data_path_list = _make_provider_from_cli(args)

output_path = args.output
provider_list, data_path_list, output_path_list = _make_provider_from_cli(args)

# if output paths were not extracted from a csv file, fall back to args.output
if output_path_list is None:
output_path = args.output

# check if output_path is valid before running inference
if (
output_path is not None
@@ -5520,7 +5546,7 @@ if args.models is not None:
if args.models is not None:

# Run inference on all input files
for data_path, provider in zip(data_path_list, provider_list):
for i, (data_path, provider) in enumerate(zip(data_path_list, provider_list)):
# Setup models.
data_path_obj = Path(data_path)
predictor = _make_predictor_from_cli(args)
@@ -5531,11 +5557,17 @@

# if output path was not provided, create an output path
if output_path is None:
output_path = f"{data_path.as_posix()}.predictions.slp"
output_path_obj = Path(output_path)
output_path = data_path + ".predictions.slp"
# if output paths were extracted from a csv file, use the entry matching this input
if output_path_list is not None:
output_path = output_path_list[i]
output_path_obj = Path(output_path)

elif output_path is None:
output_path = f"{data_path.as_posix()}.predictions.slp"
output_path_obj = Path(output_path)

else:
output_path_obj = Path(output_path)
else:
output_path_obj = Path(output_path)
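# The precedence above, condensed into a hypothetical helper (not in this
# commit): a CSV-supplied path wins, then an explicit args.output, then a
# per-video default next to the input file.
#
# def resolve_output_path(output_path_list, output_path, data_path, i):
#     if output_path_list is not None:
#         return Path(output_path_list[i])
#     if output_path is None:
#         return Path(f"{data_path.as_posix()}.predictions.slp")
#     return Path(output_path)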

# if output_path was provided and multiple inputs were provided, create a directory to store outputs
if len(data_path_list) > 1:
