Skip to content

Commit

Permalink
Refactor config and patch small typo (#87)
Browse files Browse the repository at this point in the history
  • Loading branch information
aaprasad authored Oct 17, 2024
1 parent 2af0dd5 commit 38c4cb3
Show file tree
Hide file tree
Showing 10 changed files with 197 additions and 86 deletions.
2 changes: 1 addition & 1 deletion docs/configs/index.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# DREEM Config API

We utilize `.yaml` based configs with `hydra` and `omegaconf` for config parsing.
We utilize `.yaml` based configs with [`hydra`](https://hydra.cc) and [`omegaconf`](https://omegaconf.readthedocs.io/en/2.3_branch/) for config parsing.
2 changes: 1 addition & 1 deletion docs/configs/training.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

Here, we describe the hyperparameters used for setting up training. Please see [here](./training.md#example-config) for an example training config.

> Note: for using defaults, simply leave the field blank or don't include the key. Using `null` will initialize the value to `None` e.g
> Note: for using defaults, simply leave the field blank or don't include the key. Using `null` will initialize the value to `None` which we use to represent turning off certain features such as logging, early stopping etc. e.g
> ```YAML
> model:
> d_model: #defaults to 1024
Expand Down
19 changes: 18 additions & 1 deletion dreem/inference/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dreem.models import GTRRunner
from omegaconf import DictConfig
from pathlib import Path
from datetime import datetime

import hydra
import os
Expand All @@ -14,9 +15,21 @@
import sleap_io as sio
import logging


logger = logging.getLogger("dreem.inference")


def get_timestamp() -> str:
"""Get current timestamp.
Returns:
the current timestamp in /m/d/y-H:M:S format
"""
date_time = datetime.now().strftime("%m-%d-%Y-%H:%M:%S")
print(date_time)
return date_time


def export_trajectories(
frames_pred: list["dreem.io.Frame"], save_path: str | None = None
) -> pd.DataFrame:
Expand Down Expand Up @@ -129,7 +142,11 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]:
)
dataloader = pred_cfg.get_dataloader(dataset, mode="test")
preds = track(model, trainer, dataloader)
outpath = os.path.join(outdir, f"{Path(label_file).stem}.dreem_inference.slp")
outpath = os.path.join(
outdir, f"{Path(label_file).stem}.dreem_inference.{get_timestamp()}.slp"
)
if os.path.exists(outpath):
outpath.replace(".slp", ".")
preds.save(outpath)

return preds
Expand Down
2 changes: 1 addition & 1 deletion dreem/inference/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def sliding_inference(

for i, instance in enumerate(frames[batch_idx].instances):
if instance.pred_track_id == -1:
curr_track += 1
curr_track_id += 1
instance.pred_track_id = curr_track_id

else:
Expand Down
Loading

0 comments on commit 38c4cb3

Please sign in to comment.