Skip to content

Commit

Permalink
Refactor inference script (#51)
Browse files Browse the repository at this point in the history
Co-authored-by: Talmo Pereira <talmo@salk.edu>
  • Loading branch information
aaprasad and talmo authored Jun 3, 2024
1 parent 1ad8a47 commit 63506c8
Show file tree
Hide file tree
Showing 16 changed files with 260 additions and 165 deletions.
10 changes: 7 additions & 3 deletions biogtr/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ class BaseDataset(Dataset):

def __init__(
self,
files: list[str],
label_files: list[str],
vid_files: list[str],
padding: int,
crop_size: int,
chunk: bool,
Expand All @@ -27,7 +28,9 @@ def __init__(
"""Initialize Dataset.
Args:
files: a list of files, file types are combined in subclasses
label_files: a list of paths to label files.
should at least contain detections for inference, detections + tracks for training.
vid_files: list of paths to video files.
padding: amount of padding around object crops
crop_size: the size of the object crops
chunk: whether or not to chunk the dataset into batches
Expand All @@ -42,7 +45,8 @@ def __init__(
gt_list: An optional path to .txt file containing ground truth for
cell tracking challenge datasets.
"""
self.files = files
self.vid_files = vid_files
self.label_files = label_files
self.padding = padding
self.crop_size = crop_size
self.chunk = chunk
Expand Down
3 changes: 2 additions & 1 deletion biogtr/datasets/cell_tracking_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def __init__(
"end_frame", "parent_id"
"""
super().__init__(
raw_images + gt_images,
gt_images,
raw_images,
padding,
crop_size,
chunk,
Expand Down
3 changes: 2 additions & 1 deletion biogtr/datasets/microscopy_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def __init__(
seed: set a seed for reproducibility
"""
super().__init__(
videos + tracks,
tracks,
videos,
padding,
crop_size,
chunk,
Expand Down
3 changes: 2 additions & 1 deletion biogtr/datasets/sleap_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ def __init__(
verbose: boolean representing whether to print
"""
super().__init__(
slp_files + video_files,
slp_files,
video_files,
padding,
crop_size,
chunk,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ tracker:
decay_time: 0.9
iou: "mult"
max_center_dist: 1.0
persistent_tracking: True

dataset:
test_dataset:
slp_files: ["../training/190612_110405_wt_18159111_rig2.2@11730.slp", "../training/190612_110405_wt_18159111_rig2.2@11730.slp"]
video_files: ["../training/190612_110405_wt_18159111_rig2.2@11730.mp4", "../training/190612_110405_wt_18159111_rig2.2@11730.mp4"]
chunk: True
clip_length: 32
anchor: "centroid"

dataloader:
test_dataloader:
Expand Down
103 changes: 54 additions & 49 deletions biogtr/inference/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pandas as pd
import pytorch_lightning as pl
import torch
import sleap_io as sio


def export_trajectories(frames_pred: list["biogtr.io.Frame"], save_path: str = None):
Expand Down Expand Up @@ -50,60 +51,45 @@ def export_trajectories(frames_pred: list["biogtr.io.Frame"], save_path: str = N
return save_df


def track(
    model: GTRRunner, trainer: pl.Trainer, dataloader: torch.utils.data.DataLoader
) -> dict[int, "sio.Labels"]:
    """Run inference and group predicted frames by video.

    Args:
        model: GTRRunner model loaded from checkpoint used for inference.
        trainer: lightning Trainer object used for handling inference logic.
        dataloader: dataloader containing inference data.

    Returns:
        Dict mapping video index to a `sio.Labels` object holding the
        predicted labeled frames for that video. Videos that received no
        predictions remain mapped to an empty list.
    """
    num_videos = len(dataloader.dataset.vid_files)
    preds = trainer.predict(model, dataloader)

    vid_trajectories = {i: [] for i in range(num_videos)}

    # `tracks` accumulates sleap-io Track objects across frames so that
    # track identities stay consistent within a video.
    tracks = {}
    for batch in preds:
        for frame in batch:
            lf, tracks = frame.to_slp(tracks)
            if frame.frame_id.item() == 0:
                print(f"Video: {lf.video}")
            vid_trajectories[frame.video_id.item()].append(lf)

    for vid_id, video in vid_trajectories.items():
        if len(video) > 0:
            try:
                vid_trajectories[vid_id] = sio.Labels(video)
            except AttributeError:
                # Surface which video failed before propagating; bare
                # `raise` preserves the original traceback.
                print(video[0].video)
                raise

    return vid_trajectories


@hydra.main(config_path="configs", config_name=None, version_base=None)
def main(cfg: DictConfig):
def run(cfg: DictConfig) -> dict[int, sio.Labels]:
"""Run inference based on config file.
Args:
Expand All @@ -116,37 +102,56 @@ def main(cfg: DictConfig):
index = int(os.environ["POD_INDEX"])
# For testing without deploying a job on runai
except KeyError:
print("Pod Index Not found! Setting index to 0")
index = 0
index = input("Pod Index Not found! Please choose a pod index: ")

print(f"Pod Index: {index}")

checkpoints = pd.read_csv(cfg.checkpoints)
checkpoint = checkpoints.iloc[index]
else:
checkpoint = pred_cfg.get_ckpt_path()
checkpoint = pred_cfg.cfg.ckpt_path

model = GTRRunner.load_from_checkpoint(checkpoint)
tracker_cfg = pred_cfg.get_tracker_cfg()
print("Updating tracker hparams")
model.tracker_cfg = tracker_cfg
print(f"Using the following params for tracker:")
pprint(model.tracker_cfg)
dataset = pred_cfg.get_dataset(mode="test")

dataset = pred_cfg.get_dataset(mode="test")
dataloader = pred_cfg.get_dataloader(dataset, mode="test")
preds = inference(model, dataloader)
for i, pred in enumerate(preds):
print(pred)
outdir = pred_cfg.cfg.outdir if "outdir" in pred_cfg.cfg else "./results"
os.makedirs(outdir, exist_ok=True)

trainer = pred_cfg.get_trainer()

preds = track(model, trainer, dataloader)

outdir = pred_cfg.cfg.outdir if "outdir" in pred_cfg.cfg else "./results"
os.makedirs(outdir, exist_ok=True)

run_num = 0
for i, pred in preds.items():
outpath = os.path.join(
outdir,
f"{Path(pred_cfg.cfg.dataset.test_dataset.slp_files[i]).stem}_tracking_results",
f"{Path(dataloader.dataset.label_files[i]).stem}.biogtr_inference.v{run_num}.slp",
)
print(f"Saving to {outpath}")
# TODO: Figure out how to overwrite sleap labels instance labels w pred instance labels then save as a new slp file
pred.to_csv(outpath, index=False)
if os.path.exists(outpath):
run_num += 1
outpath = outpath.replace(f".v{run_num-1}", f".v{run_num}")
print(f"Saving {preds} to {outpath}")
pred.save(outpath)

return preds


if __name__ == "__main__":
    # Example calls:

    # Run inference with the base config:
    # python biogtr/inference/track.py --config-dir=./configs --config-name=inference

    # Override with a params config:
    # python biogtr/inference/track.py --config-dir=./configs --config-name=inference +params_config=configs/params.yaml

    # Override with a params config, plus specific params:
    # python biogtr/inference/track.py --config-dir=./configs --config-name=inference +params_config=configs/params.yaml dataset.test_dataset.padding=10
    run()
11 changes: 6 additions & 5 deletions biogtr/inference/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def sliding_inference(self, model: GlobalTrackingTransformer, frames: list[Frame
self.track_queue.end_tracks()

"""
Initialize tracks on first frame of video or first instance of detections.
Initialize tracks on first frame where detections appear.
"""
if len(self.track_queue) == 0:
if frame_to_track.has_instances():
Expand All @@ -167,12 +167,12 @@ def sliding_inference(self, model: GlobalTrackingTransformer, frames: list[Frame
curr_track_id = 0
for i, instance in enumerate(frames[batch_idx].instances):
instance.pred_track_id = instance.gt_track_id
curr_track_id = instance.pred_track_id
curr_track_id = max(curr_track_id, instance.pred_track_id)

for i, instance in enumerate(frames[batch_idx].instances):
if instance.pred_track_id == -1:
instance.pred_track_id = curr_track_id
curr_track += 1
instance.pred_track_id = curr_track_id

else:
if (
Expand Down Expand Up @@ -250,6 +250,7 @@ def _run_global_tracker(
overlap_thresh = self.overlap_thresh
mult_thresh = self.mult_thresh
n_traj = self.track_queue.n_tracks
curr_track = self.track_queue.curr_track

reid_features = torch.cat([frame.get_features() for frame in frames], dim=0)[
None
Expand Down Expand Up @@ -470,8 +471,8 @@ def _run_global_tracker(
if track_ids[i] < 0:
if self.verbose:
print(f"Creating new track {n_traj}")
track_ids[i] = n_traj
n_traj += 1
curr_track += 1
track_ids[i] = curr_track

query_frame.matches = (match_i, match_j)

Expand Down
Loading

0 comments on commit 63506c8

Please sign in to comment.