Skip to content

Commit

Permalink
ENH: Convert post-processing to transforms, fix #537
Browse files Browse the repository at this point in the history
- Factor all functions out of labeled_timebins except has_unlabeled
- Only test has_unlabeled in test_labeled_timebins,
  parametrize the unit test
- Add 'multi_char_labels_to_single_char' to vak.labels
  - and add type annotations and clean up docstrings
- Clean up / add tests to tests/test_labels.py
- Add src/vak/transforms/labeled_timebins/functional.py
- Add src/vak/transforms/labeled_timebins/transforms.py
- Fix imports in vak/transforms/ __init__.py files
  - Import labeled_timebins transforms in transforms/__init__.py
  - Fix other imports
  - Import classes from transforms in
    src/vak/transforms/labeled_timebins/__init__.py
- TST: Move paths outside fixtures in fixtures/annot.py
  to be able to import directly in test modules
  to parametrize there.
- TST: Move path to constant outside fixture in fixtures/config.py
- Add tests/test_transforms/__init__.py
  so we can do relative import from tests/fixtures
- Add tests/test_transforms/test_labeled_timebins/test_functional.py
- Add tests/test_transforms/test_labeled_timebins/test_transforms.py
- Use `vak.transforms.labeled_timebins.from_segments`
  in vocal_dataset.py
- Use vak.transforms.labeled_timebins in engine/model.py
- Use transforms.labeled_timebins.from_segments
  in datasets/window_dataset.py
- Use transforms.labeled_timebins.to_segments in core/predict.py
- Use labeled_timebins transforms 'postprocess' and 'to_segments'
  in core/predict.py
  • Loading branch information
NickleDave committed Feb 7, 2023
1 parent 1ae478d commit 8fb7665
Show file tree
Hide file tree
Showing 18 changed files with 1,787 additions and 819 deletions.
14 changes: 10 additions & 4 deletions src/vak/core/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
constants,
files,
io,
labeled_timebins,
validators
)
from .. import models
Expand Down Expand Up @@ -223,12 +222,19 @@ def predict(

spect_dict = files.spect.load(spect_path)
t = spect_dict[timebins_key]
labels, onsets_s, offsets_s = labeled_timebins.lbl_tb2segments(

if majority_vote or min_segment_dur:
y_pred = transforms.labeled_timebins.postprocess(
y_pred,
timebin_dur=timebin_dur,
min_segment_dur=min_segment_dur,
majority_vote=majority_vote,
)

labels, onsets_s, offsets_s = transforms.labeled_timebins.to_segments(
y_pred,
labelmap=labelmap,
t=t,
min_segment_dur=min_segment_dur,
majority_vote=majority_vote,
)
if labels is None and onsets_s is None and offsets_s is None:
# handle the case when all time bins are predicted to be unlabeled
Expand Down
4 changes: 2 additions & 2 deletions src/vak/datasets/vocal_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from .. import annotation
from .. import files
from .. import labeled_timebins
from .. import transforms


class VocalDataset:
Expand Down Expand Up @@ -80,7 +80,7 @@ def __getitem__(self, idx):
annot = self.annots[idx]
lbls_int = [self.labelmap[lbl] for lbl in annot.seq.labels]
# "lbl_tb": labeled timebins. Target for output of network
lbl_tb = labeled_timebins.label_timebins(
lbl_tb = transforms.labeled_timebins.from_segments(
lbls_int,
annot.seq.onsets_s,
annot.seq.offsets_s,
Expand Down
16 changes: 9 additions & 7 deletions src/vak/datasets/window_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
import torch
from torchvision.datasets.vision import VisionDataset

from .. import annotation
from .. import files
from .. import io
from .. import labeled_timebins
from .. import validators
from .. import (
annotation,
files,
io,
transforms,
validators
)


class WindowDataset(VisionDataset):
Expand Down Expand Up @@ -206,7 +208,7 @@ def __get_window_labelvec(self, idx):
spect_id
] # "annot id" == spect_id if both were taken from rows of DataFrame
lbls_int = [self.labelmap[lbl] for lbl in annot.seq.labels]
lbl_tb = labeled_timebins.label_timebins(
lbl_tb = transforms.labeled_timebins.from_segments(
lbls_int,
annot.seq.onsets_s,
annot.seq.offsets_s,
Expand Down Expand Up @@ -694,7 +696,7 @@ def spect_vectors_from_df(
lbls_int = [labelmap[lbl] for lbl in annot.seq.labels]
timebins = spect_dict[timebins_key]
lbl_tb.append(
labeled_timebins.label_timebins(
transforms.labeled_timebins.from_segments(
lbls_int,
annot.seq.onsets_s,
annot.seq.offsets_s,
Expand Down
6 changes: 3 additions & 3 deletions src/vak/engine/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from tqdm import tqdm

from ..device import get_default as get_default_device
from ..labeled_timebins import lbl_tb2labels
from .. import transforms


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -267,10 +267,10 @@ def _eval(self, eval_data):
for metric_name in self.metrics.keys()
]
):
y_labels = lbl_tb2labels(
y_labels = transforms.labeled_timebins.lbl_tb2labels(
y.cpu().numpy(), eval_data.dataset.labelmap
)
y_pred_labels = lbl_tb2labels(
y_pred_labels = transforms.labeled_timebins.lbl_tb2labels(
y_pred.cpu().numpy(), eval_data.dataset.labelmap
)
else:
Expand Down
Loading

0 comments on commit 8fb7665

Please sign in to comment.