Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor inference script #51

Merged
merged 42 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
8ddc557
add back-reference to frame as attribute for instance
aaprasad May 15, 2024
7887d56
separate get_boxes_times into two functions and use `Instances` as input
aaprasad May 15, 2024
9bab7bc
use instances as input into model instead of frames
aaprasad May 15, 2024
82293d2
create io module, move config, visualize there. abstract `Frame` and …
aaprasad May 16, 2024
b4049b8
refactor `Frame` and `Instance` initialization to use `attrs` instead…
aaprasad May 16, 2024
2b0bc55
add doc strings, fix small bugs
aaprasad May 16, 2024
ef26012
Implement AssociationMatrix class for handling model output
aaprasad May 16, 2024
94a0e61
create io module, move config, visualize there. abstract `Frame` and …
aaprasad May 16, 2024
c4bc0fb
refactor `Frame` and `Instance` initialization to use `attrs` instead…
aaprasad May 16, 2024
42f8a8c
add doc strings, fix small bugs
aaprasad May 16, 2024
b5f39b4
Implement AssociationMatrix class for handling model output
aaprasad May 16, 2024
ccd523a
Merge remote-tracking branch 'origin/aadi/refactor-data-structures' i…
aaprasad May 16, 2024
0f535af
fix overwrites from merge
aaprasad May 17, 2024
56e038a
store model outputs in association matrix
aaprasad May 17, 2024
8a71d6d
add track object for storing tracklets
aaprasad May 17, 2024
766820b
add reduction function to association matrix
aaprasad May 20, 2024
c0ceac1
add doc_strings
aaprasad May 20, 2024
92095e0
fix tests, docstrings
aaprasad May 20, 2024
a6a6ace
add spatial/temporal embeddings as attribute to `Instance`
aaprasad May 20, 2024
56e0555
fix typo
aaprasad May 20, 2024
80400fc
add `from_slp` converters
aaprasad May 21, 2024
7ff22e7
fix docstrings
aaprasad May 21, 2024
89007a9
store embeddings in Instance object instead of returning
aaprasad May 21, 2024
cbf915a
only keep visualize in io
aaprasad May 21, 2024
b3c5661
remove mutable types from default arguments. Don't use kwargs unless …
aaprasad May 21, 2024
adb3715
handle edge case where ckpt_path is not in config
aaprasad May 21, 2024
557d4e9
expose appropriate modules in respective `__init__.py`
aaprasad May 21, 2024
1013847
separate `files` into `vid_files`, `label_files` for finer grained co…
aaprasad May 27, 2024
e030da7
fix edge case for get trainer when trainer params don't exist
aaprasad May 27, 2024
f86317c
fix `to_slp` bugs stemming from type change
aaprasad May 27, 2024
ed85a40
use tmp dir for tests
aaprasad May 27, 2024
cf88413
refactor inference script
aaprasad May 27, 2024
9f20e48
add logic to handle directory paths instead of only file paths
aaprasad May 27, 2024
7e9eb65
add `from_yaml` classmethod for direct config loading
aaprasad May 27, 2024
d00a9f4
add documentation for cli calls
aaprasad May 27, 2024
5b8ad52
fix small typo + docstrings
aaprasad May 27, 2024
9d4d24a
fix docstring typo
aaprasad May 27, 2024
e1f7c66
fix small edge case when initializing new tracks
aaprasad May 28, 2024
bd724ec
Update biogtr/datasets/base_dataset.py
talmo May 28, 2024
f5b766b
lint
aaprasad May 29, 2024
06a4500
Merge branch 'main' into aadi/refactor-inference
aaprasad Jun 3, 2024
f725848
lint post-merge
aaprasad Jun 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions biogtr/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ class BaseDataset(Dataset):

def __init__(
self,
files: list[str],
label_files: list[str],
vid_files: list[str],
padding: int,
crop_size: int,
chunk: bool,
Expand All @@ -27,7 +28,9 @@ def __init__(
"""Initialize Dataset.

Args:
files: a list of files, file types are combined in subclasses
label_files: a list of paths to label files.
should at least contain detections for inference, detections + tracks for training.
vid_files: list of paths to video files.
padding: amount of padding around object crops
crop_size: the size of the object crops
chunk: whether or not to chunk the dataset into batches
Expand All @@ -42,7 +45,8 @@ def __init__(
gt_list: An optional path to .txt file containing ground truth for
cell tracking challenge datasets.
"""
self.files = files
self.vid_files = vid_files
self.label_files = label_files
self.padding = padding
self.crop_size = crop_size
self.chunk = chunk
Expand Down
3 changes: 2 additions & 1 deletion biogtr/datasets/cell_tracking_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def __init__(
"end_frame", "parent_id"
"""
super().__init__(
raw_images + gt_images,
gt_images,
raw_images,
padding,
crop_size,
chunk,
Expand Down
3 changes: 2 additions & 1 deletion biogtr/datasets/microscopy_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def __init__(
seed: set a seed for reproducibility
"""
super().__init__(
videos + tracks,
tracks,
videos,
padding,
crop_size,
chunk,
Expand Down
3 changes: 2 additions & 1 deletion biogtr/datasets/sleap_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ def __init__(
verbose: boolean representing whether to print
"""
super().__init__(
slp_files + video_files,
slp_files,
video_files,
padding,
crop_size,
chunk,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ tracker:
decay_time: 0.9
iou: "mult"
max_center_dist: 1.0
persistent_tracking: True

dataset:
test_dataset:
slp_files: ["../training/190612_110405_wt_18159111_rig2.2@11730.slp", "../training/190612_110405_wt_18159111_rig2.2@11730.slp"]
video_files: ["../training/190612_110405_wt_18159111_rig2.2@11730.mp4", "../training/190612_110405_wt_18159111_rig2.2@11730.mp4"]
chunk: True
clip_length: 32
anchor: "centroid"

dataloader:
test_dataloader:
Expand Down
103 changes: 54 additions & 49 deletions biogtr/inference/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pandas as pd
import pytorch_lightning as pl
import torch
import sleap_io as sio


def export_trajectories(frames_pred: list["biogtr.io.Frame"], save_path: str = None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix undefined name error.

- from biogtr.io import Config
+ from biogtr.io import Config, Frame
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
def export_trajectories(frames_pred: list["biogtr.io.Frame"], save_path: str = None):
from biogtr.io import Config, Frame
def export_trajectories(frames_pred: list["biogtr.io.Frame"], save_path: str = None):
Tools
Ruff

17-17: Undefined name biogtr

Expand Down Expand Up @@ -50,60 +51,45 @@ def export_trajectories(frames_pred: list["biogtr.io.Frame"], save_path: str = N
return save_df


def inference(
model: GTRRunner, dataloader: torch.utils.data.DataLoader
def track(
model: GTRRunner, trainer: pl.Trainer, dataloader: torch.utils.data.DataLoader
) -> list[pd.DataFrame]:
"""Run Inference.

Args:
model: model loaded from checkpoint used for inference
model: GTRRunner model loaded from checkpoint used for inference
trainer: lightning Trainer object used for handling inference log.
dataloader: dataloader containing inference data

Return:
List of DataFrames containing prediction results for each video
"""
num_videos = len(dataloader.dataset.slp_files)
trainer = pl.Trainer(devices=1, limit_predict_batches=3)
num_videos = len(dataloader.dataset.vid_files)
preds = trainer.predict(model, dataloader)

vid_trajectories = [[] for i in range(num_videos)]
vid_trajectories = {i: [] for i in range(num_videos)}

tracks = {}
for batch in preds:
for frame in batch:
vid_trajectories[frame.video_id].append(frame)
lf, tracks = frame.to_slp(tracks)
if frame.frame_id.item() == 0:
print(f"Video: {lf.video}")
vid_trajectories[frame.video_id.item()].append(lf)

saved = []

for video in vid_trajectories:
for vid_id, video in vid_trajectories.items():
if len(video) > 0:
save_dict = {}
video_ids = []
frame_ids = []
X, Y = [], []
pred_track_ids = []
for frame in video:
for i, instance in frame.instances:
video_ids.append(frame.video_id.item())
frame_ids.append(frame.frame_id.item())
bbox = instance.bbox
y = (bbox[2] + bbox[0]) / 2
x = (bbox[3] + bbox[1]) / 2
X.append(x.item())
Y.append(y.item())
pred_track_ids.append(instance.pred_track_id.item())
save_dict["Video"] = video_ids
save_dict["Frame"] = frame_ids
save_dict["X"] = X
save_dict["Y"] = Y
save_dict["Pred_track_id"] = pred_track_ids
save_df = pd.DataFrame(save_dict)
saved.append(save_df)

return saved
try:
vid_trajectories[vid_id] = sio.Labels(video)
except AttributeError as e:
print(video[0].video)
raise (e)

return vid_trajectories


@hydra.main(config_path="configs", config_name=None, version_base=None)
def main(cfg: DictConfig):
def run(cfg: DictConfig) -> dict[int, sio.Labels]:
"""Run inference based on config file.

Args:
Expand All @@ -116,37 +102,56 @@ def main(cfg: DictConfig):
index = int(os.environ["POD_INDEX"])
# For testing without deploying a job on runai
except KeyError:
print("Pod Index Not found! Setting index to 0")
index = 0
index = input("Pod Index Not found! Please choose a pod index: ")

print(f"Pod Index: {index}")

checkpoints = pd.read_csv(cfg.checkpoints)
checkpoint = checkpoints.iloc[index]
else:
checkpoint = pred_cfg.get_ckpt_path()
checkpoint = pred_cfg.cfg.ckpt_path

model = GTRRunner.load_from_checkpoint(checkpoint)
tracker_cfg = pred_cfg.get_tracker_cfg()
print("Updating tracker hparams")
model.tracker_cfg = tracker_cfg
print(f"Using the following params for tracker:")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove unnecessary f-string.

- print(f"Using the following params for tracker:")
+ print("Using the following params for tracker:")
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
print(f"Using the following params for tracker:")
print("Using the following params for tracker:")
Tools
Ruff

118-118: f-string without any placeholders

pprint(model.tracker_cfg)
dataset = pred_cfg.get_dataset(mode="test")

dataset = pred_cfg.get_dataset(mode="test")
dataloader = pred_cfg.get_dataloader(dataset, mode="test")
preds = inference(model, dataloader)
for i, pred in enumerate(preds):
print(pred)
outdir = pred_cfg.cfg.outdir if "outdir" in pred_cfg.cfg else "./results"
os.makedirs(outdir, exist_ok=True)

trainer = pred_cfg.get_trainer()

preds = track(model, trainer, dataloader)

outdir = pred_cfg.cfg.outdir if "outdir" in pred_cfg.cfg else "./results"
os.makedirs(outdir, exist_ok=True)

run_num = 0
for i, pred in preds.items():
outpath = os.path.join(
outdir,
f"{Path(pred_cfg.cfg.dataset.test_dataset.slp_files[i]).stem}_tracking_results",
f"{Path(dataloader.dataset.label_files[i]).stem}.biogtr_inference.v{run_num}.slp",
)
print(f"Saving to {outpath}")
# TODO: Figure out how to overwrite sleap labels instance labels w pred instance labels then save as a new slp file
pred.to_csv(outpath, index=False)
if os.path.exists(outpath):
run_num += 1
outpath = outpath.replace(f".v{run_num-1}", f".v{run_num}")
print(f"Saving {preds} to {outpath}")
pred.save(outpath)

return preds


if __name__ == "__main__":
main()
# example calls:

# train with base config:
# python train.py --config-dir=./configs --config-name=inference

# override with params config:
# python train.py --config-dir=./configs --config-name=inference +params_config=configs/params.yaml

# override with params config, and specific params:
# python train.py --config-dir=./configs --config-name=inference +params_config=configs/params.yaml dataset.train_dataset.padding=10
run()
11 changes: 6 additions & 5 deletions biogtr/inference/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def sliding_inference(self, model: GlobalTrackingTransformer, frames: list[Frame
self.track_queue.end_tracks()

"""
Initialize tracks on first frame of video or first instance of detections.
Initialize tracks on first frame where detections appear.
"""
if len(self.track_queue) == 0:
if frame_to_track.has_instances():
Expand All @@ -167,12 +167,12 @@ def sliding_inference(self, model: GlobalTrackingTransformer, frames: list[Frame
curr_track_id = 0
for i, instance in enumerate(frames[batch_idx].instances):
instance.pred_track_id = instance.gt_track_id
curr_track_id = instance.pred_track_id
curr_track_id = max(curr_track_id, instance.pred_track_id)

for i, instance in enumerate(frames[batch_idx].instances):
if instance.pred_track_id == -1:
instance.pred_track_id = curr_track_id
curr_track += 1
instance.pred_track_id = curr_track_id

else:
if (
Expand Down Expand Up @@ -250,6 +250,7 @@ def _run_global_tracker(
overlap_thresh = self.overlap_thresh
mult_thresh = self.mult_thresh
n_traj = self.track_queue.n_tracks
curr_track = self.track_queue.curr_track

reid_features = torch.cat([frame.get_features() for frame in frames], dim=0)[
None
Expand Down Expand Up @@ -470,8 +471,8 @@ def _run_global_tracker(
if track_ids[i] < 0:
if self.verbose:
print(f"Creating new track {n_traj}")
track_ids[i] = n_traj
n_traj += 1
curr_track += 1
track_ids[i] = curr_track

query_frame.matches = (match_i, match_j)

Expand Down
Loading
Loading