Skip to content

Commit

Permalink
Fix wrong label mapping in batch_inference for label_model (NVIDIA#5767) (NVIDIA#5870)
Browse files Browse the repository at this point in the history

* fix batch inference

* add test for batch

* fix device

Signed-off-by: fayejf <fayejf07@gmail.com>
Co-authored-by: fayejf <36722593+fayejf@users.noreply.github.com>
Signed-off-by: Jason <jasoli@nvidia.com>
  • Loading branch information
2 people authored and blisc committed Feb 10, 2023
1 parent f5ea3b4 commit 99de963
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
4 changes: 3 additions & 1 deletion nemo/collections/asr/models/label_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,9 @@ def batch_inference(self, manifest_filepath, batch_size=32, sample_rate=16000, d
mapped_labels = list(mapped_labels)

featurizer = WaveformFeaturizer(sample_rate=sample_rate)
dataset = AudioToSpeechLabelDataset(manifest_filepath=manifest_filepath, labels=None, featurizer=featurizer)
dataset = AudioToSpeechLabelDataset(
manifest_filepath=manifest_filepath, labels=mapped_labels, featurizer=featurizer
)

dataloader = torch.utils.data.DataLoader(
dataset=dataset, batch_size=batch_size, collate_fn=dataset.fixed_seq_collate_fn,
Expand Down
22 changes: 22 additions & 0 deletions tests/collections/asr/test_speaker_label_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import tempfile
from unittest import TestCase

import pytest
import torch
from omegaconf import DictConfig

from nemo.collections.asr.models import EncDecSpeakerLabelModel
Expand Down Expand Up @@ -170,3 +173,22 @@ def test_pretrained_ambernet_logits(self, test_data_dir):
label = lang_model.get_label(filename)

assert label == "en"

@pytest.mark.unit
def test_pretrained_ambernet_logits_batched(self, test_data_dir):
    """Check label mapping of ``batch_inference`` on a one-entry manifest.

    Loads the pretrained ``langid_ambernet`` model, writes a temporary
    manifest containing a single audio file labelled ``'en'``, and runs
    ``batch_inference`` on it. The test passes when the label looked up via
    the arg-max of the first row of logits equals the label looked up via the
    first ground-truth index, both resolved through the returned
    ``mapped_labels``.
    """
    lang_model = EncDecSpeakerLabelModel.from_pretrained('langid_ambernet')
    wav_path = os.path.join(test_data_dir, "an4_speaker/an4/wav/an4_clstk/fash/an255-fash-b.wav")
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'

    with tempfile.TemporaryDirectory() as workdir:
        manifest_path = os.path.join(workdir, 'manifest.json')
        with open(manifest_path, 'w', encoding='utf-8') as manifest:
            record = {"audio_filepath": wav_path, "duration": 4.5, "label": 'en'}
            manifest.write(json.dumps(record) + '\n')

        # Inference has to run while the temporary manifest still exists,
        # so it stays inside the TemporaryDirectory context.
        _embs, logits, gt_labels, mapped_labels = lang_model.batch_inference(manifest_path, device=device)

        # Resolve both indices through the same label list before comparing.
        predicted_index = logits.argmax(axis=-1)[0]
        expected_index = gt_labels[0]
        pred_label = mapped_labels[predicted_index]
        true_label = mapped_labels[expected_index]
        assert pred_label == true_label

0 comments on commit 99de963

Please sign in to comment.