Skip to content

Commit

Permalink
Added Exportable bits to T5 + unit test
Browse files Browse the repository at this point in the history
Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
  • Loading branch information
borisfom committed Aug 29, 2023
1 parent 2baef81 commit 6633f30
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 2 deletions.
12 changes: 11 additions & 1 deletion nemo/collections/tts/g2p/models/t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from nemo.collections.asr.metrics.wer import word_error_rate
from nemo.collections.tts.g2p.data.t5 import T5G2PDataset
from nemo.collections.tts.models.base import G2PModel
from nemo.core.classes import Exportable
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.neural_types import LabelsType, LossType, MaskType, NeuralType, TokenIndex
from nemo.utils import logging
Expand All @@ -38,7 +39,7 @@ class T5G2PConfig:
test_ds: Optional[Dict[Any, Any]] = None


class T5G2PModel(G2PModel):
class T5G2PModel(G2PModel, Exportable):
"""
T5-based grapheme-to-phoneme model.
"""
Expand Down Expand Up @@ -260,3 +261,12 @@ def setup_multiple_test_data(self, test_data_config: Union[DictConfig, Dict] = N
@classmethod
def list_available_models(cls) -> 'List[PretrainedModelInfo]':
    """
    List the pretrained checkpoints published for this model class.

    Returns:
        An empty list: no pretrained models are registered here.
    """
    no_models = []
    return no_models

# ===== export methods =====

def input_example(self, max_batch=1, max_dim=64, seq_len=16):
    """
    Build an example input tuple for export and trace checking.

    Args:
        max_batch: batch size of the generated tensors.
        max_dim: exclusive upper bound for the random token ids and labels.
        seq_len: sequence length of the generated tensors.

    Returns:
        A tuple ``(input_ids, attention_mask, labels)`` of integer tensors, each of
        shape ``(max_batch, seq_len)``, on the device of the model's first parameter.
    """
    device = next(self.parameters()).device
    shape = (max_batch, seq_len)
    input_ids = torch.randint(low=0, high=max_dim, size=shape, device=device)
    labels = torch.randint(low=0, high=max_dim, size=shape, device=device)
    # torch.randint's upper bound is exclusive, so randint(low=0, high=1) can only
    # yield zeros. An all-zero mask presumably hides every token (TODO confirm the
    # mask convention), so use an all-ones mask of the same integer dtype instead.
    attention_mask = torch.ones(shape, dtype=torch.long, device=device)
    return input_ids, attention_mask, labels
4 changes: 3 additions & 1 deletion nemo/core/classes/exportable.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,9 @@ def _export(
input_list, input_dict = parse_input_example(input_example)
input_names = self.input_names
output_names = self.output_names
output_example = tuple(self.forward(*input_list, **input_dict))
output_example = self.forward(*input_list, **input_dict)
if torch.is_tensor(output_example):
output_example = (output_example,)

if check_trace:
if isinstance(check_trace, bool):
Expand Down
8 changes: 8 additions & 0 deletions nemo/utils/export_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ def verify_torchscript(model, output, input_examples, check_tolerance=0.01):
# We disable autocast here to make sure exported TS will run under Triton or other C++ env
with torch.cuda.amp.autocast(enabled=False):
output_example = model.forward(*input_list, **input_dict)
if torch.is_tensor(output_example):
output_example = (output_example,)
ts_model = torch.jit.load(output)
all_good = all_good and run_ts_and_compare(
ts_model, input_list, input_dict, output_example, check_tolerance
Expand Down Expand Up @@ -172,6 +174,8 @@ def verify_runtime(model, output, input_examples, input_names, check_tolerance=0
for input_example in input_examples:
input_list, input_dict = parse_input_example(input_example)
output_example = model.forward(*input_list, **input_dict)
if torch.is_tensor(output_example):
output_example = (output_example,)
ort_input = to_onnxrt_input(ort_input_names, input_names, input_dict, input_list)
all_good = all_good and run_ort_and_compare(sess, ort_input, output_example, check_tolerance)
status = "SUCCESS" if all_good else "FAIL"
Expand All @@ -184,6 +188,8 @@ def run_ts_and_compare(ts_model, ts_input_list, ts_input_dict, output_example, c
ts_out = ts_model(*ts_input_list, **ts_input_dict)

all_good = True
if torch.is_tensor(ts_out):
ts_out = (ts_out,)
for i, out in enumerate(ts_out):
expected = output_example[i]

Expand All @@ -206,6 +212,8 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01):
# Verify the model can be read, and is valid
ort_out = sess.run(None, ort_input)
all_good = True
if torch.is_tensor(ort_out):
ort_out = (ort_out,)
for i, out in enumerate(ort_out):
expected = output_example[i]

Expand Down
65 changes: 65 additions & 0 deletions tests/collections/tts/g2p/test_export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile

import pytest
import torch
from omegaconf import OmegaConf

from nemo.collections.tts.g2p.models.t5 import T5G2PModel
from nemo.utils.app_state import AppState


@pytest.fixture()
def t5_model():
    """
    Build a T5G2PModel in eval mode from the example ``g2p_t5.yaml`` config.

    The config is loaded relative to this test file, its top-level manifest entries
    are cleared, and the model is constructed while
    ``AppState.is_model_being_restored`` is set.

    Returns:
        The constructed model, switched to eval mode.
    """
    this_test_dir = os.path.dirname(os.path.abspath(__file__))

    cfg = OmegaConf.load(os.path.join(this_test_dir, '../../../../examples/tts/g2p/conf/g2p_t5.yaml'))
    cfg.train_manifest = None
    cfg.validation_manifest = None

    # NOTE(review): the restore flag presumably lets construction skip work that needs
    # real data files - confirm against the model's __init__.
    app_state = AppState()
    app_state.is_model_being_restored = True
    try:
        model = T5G2PModel(cfg=cfg.model)
    finally:
        # Always clear the flag, even when construction fails, so a broken config
        # cannot leave it set for whatever runs next.
        app_state.is_model_being_restored = False
    model.eval()
    return model


def extra_cfg(cfg):
    """
    Apply extra test overrides to a loaded config, in place.

    The original body referenced ``cfg`` without defining it, so any call raised
    ``NameError``; the config is now passed in explicitly.

    Args:
        cfg: config object exposing ``model.train_ds.dataset`` and
            ``model.validation_ds.dataset`` by attribute access.

    Returns:
        None. ``cfg`` is modified in place.
    """
    # NOTE(review): nothing visible here calls this helper, and the pitch_* keys look
    # unrelated to a G2P model - confirm whether it is still needed.
    cfg.model.init_from_ptl_ckpt = None
    cfg.model.train_ds.dataset.manifest_filepath = "dummy.json"
    cfg.model.train_ds.dataset.sup_data_path = "dummy.json"
    cfg.model.validation_ds.dataset.manifest_filepath = "dummy.json"
    cfg.model.validation_ds.dataset.sup_data_path = "dummy.json"
    cfg.pitch_mean = 212.35
    cfg.pitch_std = 68.52


class TestExportable:
    """Export checks for T5G2PModel; both tests are marked to run only on GPU."""

    @pytest.mark.run_only_on('GPU')
    @pytest.mark.unit
    def test_T5Model_export_to_onnx(self, t5_model):
        """Export the model to ``fp.onnx`` in a temporary directory with trace checking on."""
        gpu_model = t5_model.cuda()
        export_options = dict(verbose=True, onnx_opset_version=18, check_trace=True)
        with tempfile.TemporaryDirectory() as workdir:
            onnx_path = os.path.join(workdir, 'fp.onnx')
            gpu_model.export(output=onnx_path, **export_options)

    @pytest.mark.run_only_on('GPU')
    @pytest.mark.unit
    def test_T5Model_export_to_ts(self, t5_model):
        """Export the model to ``fp.ts`` in a temporary directory with trace checking on."""
        gpu_model = t5_model.cuda()
        export_options = dict(verbose=True, check_trace=True)
        with tempfile.TemporaryDirectory() as workdir:
            ts_path = os.path.join(workdir, 'fp.ts')
            gpu_model.export(output=ts_path, **export_options)

0 comments on commit 6633f30

Please sign in to comment.