Skip to content

Commit

Permalink
Patch small bugs (#77)
Browse files Browse the repository at this point in the history
  • Loading branch information
aaprasad authored Jul 30, 2024
1 parent db707f7 commit 2af0dd5
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 8 deletions.
7 changes: 2 additions & 5 deletions dreem/datasets/sleap_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram

vid_reader = self.videos[label_idx]

img = vid_reader.get_data(0)
# img = vid_reader.get_data(0)

skeleton = video.skeletons[-1]

Expand All @@ -162,10 +162,7 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram
lf = video[frame_ind]

try:
img = vid_reader.get_data(frame_ind)
if len(img.shape) == 2:
img = np.expand_dims(img, 0)
h, w, c = img.shape
img = vid_reader.get_data(int(lf.frame_idx))
except IndexError as e:
logger.warning(
f"Could not read frame {frame_ind} from {video_name} due to {e}"
Expand Down
3 changes: 3 additions & 0 deletions dreem/inference/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]:
logger.info(f"Using the following tracker:")
print(model.tracker)
model.metrics["test"] = eval_cfg.cfg.runner.metrics.test
model.persistent_tracking["test"] = eval_cfg.cfg.tracker.get(
"persistent_tracking", False
)
logger.info(f"Computing the following metrics:")
logger.info(model.metrics.test)
model.test_results["save_path"] = eval_cfg.cfg.runner.save_path
Expand Down
2 changes: 1 addition & 1 deletion dreem/models/gtr_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def forward(
self,
ref_instances: list["dreem.io.Instance"],
query_instances: list["dreem.io.Instance"] | None = None,
) -> torch.Tensor:
) -> list["AssociationMatrix"]:
"""Execute forward pass of the lightning module.
Args:
Expand Down
4 changes: 2 additions & 2 deletions dreem/models/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def forward(

window_length = len(ref_times.unique())

ref_temp_emb = self.temp_emb(ref_times / window_length)
ref_temp_emb = self.temp_emb(ref_times)

ref_pos_emb = self.pos_emb(ref_boxes)

Expand Down Expand Up @@ -218,7 +218,7 @@ def forward(

query_boxes = get_boxes(query_instances)
query_boxes = torch.nan_to_num(query_boxes, -1.0)
query_temp_emb = self.temp_emb(query_times / window_length)
query_temp_emb = self.temp_emb(query_times)

query_pos_emb = self.pos_emb(query_boxes)

Expand Down

0 comments on commit 2af0dd5

Please sign in to comment.