Skip to content

Commit

Permalink
Refined export_config
Browse files Browse the repository at this point in the history
Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
  • Loading branch information
borisfom committed Jul 17, 2023
1 parent 6db4097 commit d698ae9
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 42 deletions.
4 changes: 2 additions & 2 deletions docs/source/asr/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ You may find FastConformer variants of cache-aware streaming models under ``<NeM
Note that cache-aware streaming models are exported without caching support by default.
To include caching support, `model.set_export_config({'cache_support' : 'True'})` should be called before export.
Or, if ``<NeMo_git_root>/scripts/export.py`` is being used:
`python export.py cache_aware_conformer.nemo cache_aware_conformer.onnx --config cache_support=True`
`python export.py cache_aware_conformer.nemo cache_aware_conformer.onnx --export-config cache_support=True`

.. _LSTM-Transducer_model:

Expand Down Expand Up @@ -299,7 +299,7 @@ Similar example configs for FastConformer variants of Hybrid models can be found
Note that Hybrid models are exported as RNNT (encoder and decoder+joint parts) by default.
To export as CTC (single encoder+decoder graph), `model.set_export_config({'decoder_type' : 'ctc'})` should be called before export.
Or, if ``<NeMo_git_root>/scripts/export.py`` is being used:
`python export.py hybrid_transducer.nemo hybrid_transducer.onnx --config decoder_type=ctc`
`python export.py hybrid_transducer.nemo hybrid_transducer.onnx --export-config decoder_type=ctc`

.. _Conformer-HAT_model:

Expand Down
2 changes: 1 addition & 1 deletion docs/source/core/export.rst
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ An example can be found in ``<NeMo_git_root>/nemo/collections/asr/models/rnnt_mo
Here is an example of how the `set_export_config()` call is tied to command-line arguments in ``<NeMo_git_root>/scripts/export.py`` :

.. code-block:: Python
python scripts/export.py hybrid_conformer.nemo hybrid_conformer.onnx --config decoder_type=ctc
python scripts/export.py hybrid_conformer.nemo hybrid_conformer.onnx --export-config decoder_type=ctc
Exportable Model Code
~~~~~~~~~~~~~~~~~~~~~
Expand Down
39 changes: 37 additions & 2 deletions nemo/collections/tts/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@
from tqdm import tqdm

from nemo.collections.tts.parts.utils.helpers import OperationMode
from nemo.core.classes import ModelPT
from nemo.core.classes import Exportable, ModelPT
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.neural_types.elements import AudioSignal
from nemo.core.neural_types.neural_type import NeuralType
from nemo.utils import logging, model_utils


class SpectrogramGenerator(ModelPT, ABC):
class SpectrogramGenerator(ModelPT, Exportable, ABC):
""" Base class for all TTS models that turn text into a spectrogram """

@abstractmethod
Expand Down Expand Up @@ -330,3 +330,38 @@ def list_available_models(cls) -> 'List[PretrainedModelInfo]':
# recursively walk the subclasses to generate pretrained model info
list_of_models = model_utils.resolve_subclass_pretrained_model_info(cls)
return list_of_models

@property
def input_types(self):
    """Return whatever is currently stored in ``self._input_types``.

    The attribute is reset to ``None`` by ``_export_teardown``.
    """
    current_input_types = self._input_types
    return current_input_types

@property
def output_types(self):
    """Return whatever is currently stored in ``self._output_types``.

    The attribute is reset to ``None`` by ``_export_teardown``.
    """
    current_output_types = self._output_types
    return current_output_types

def _export_teardown(self):
    """Clear the cached export input/output type descriptions.

    Both ``_input_types`` and ``_output_types`` are reset to ``None``.
    """
    self._input_types = None
    self._output_types = None

@property
def disabled_deployment_input_names(self):
    """Return the set of input names that are disabled for export.

    An input is disabled when:
      * ``"speaker"``       - ``self.fastpitch.speaker_emb`` is ``None``;
      * ``"batch_lengths"`` - ``export_config["enable_ragged_batches"]`` is falsy;
      * ``"volume"``        - ``export_config["enable_volume"]`` is falsy.
    """
    # NOTE(review): this reads ``self.fastpitch``; confirm every class that
    # inherits this property actually exposes that attribute.
    speaker_missing = self.fastpitch.speaker_emb is None
    ragged_off = not self.export_config["enable_ragged_batches"]
    volume_off = not self.export_config["enable_volume"]

    input_flags = (
        ("speaker", speaker_missing),
        ("batch_lengths", ragged_off),
        ("volume", volume_off),
    )
    return {input_name for input_name, is_off in input_flags if is_off}

def set_export_config(self, args):
    """Apply TTS-specific export options from ``args`` and delegate the rest.

    Recognised keys are consumed (removed from ``args``) and written to
    ``self.export_config``:

      * ``enable_volume`` / ``enable_ragged_batches`` - stored as ``bool``.
        String values such as ``"False"``, ``"0"``, ``"no"`` or ``"off"``
        (case-insensitive) are treated as ``False``; other non-empty strings
        are ``True``. Non-string values use normal truthiness.
      * ``num_speakers`` - stored as ``int``.

    Whatever remains in ``args`` is passed to the parent implementation.

    Args:
        args: mutable mapping of option name to value. It is modified in
            place: handled keys are popped before delegating.

    Raises:
        Exception: if ``emb_range`` is supplied; it is not user-settable.
        ValueError: if ``num_speakers`` cannot be converted to ``int``.
    """

    def _as_bool(value):
        # Values coming from the command line (KEY=VALUE) are strings, and
        # bool("False") is True, so textual "false" spellings need explicit
        # handling instead of plain truthiness.
        if isinstance(value, str):
            return value.strip().lower() not in ("", "0", "false", "no", "off")
        return bool(value)

    for key in ('enable_volume', 'enable_ragged_batches'):
        if key in args:
            self.export_config[key] = _as_bool(args[key])
            args.pop(key)
    if 'num_speakers' in args:
        self.export_config['num_speakers'] = int(args['num_speakers'])
        args.pop('num_speakers')
    if 'emb_range' in args:
        raise Exception('embedding range is not user-settable')
    super().set_export_config(args)
26 changes: 1 addition & 25 deletions nemo/collections/tts/models/fastpitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
process_batch,
sample_tts_input,
)
from nemo.core.classes import Exportable
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.neural_types.elements import (
Index,
Expand Down Expand Up @@ -78,7 +77,7 @@ class TextTokenizerConfig:
text_tokenizer: TextTokenizer = TextTokenizer()


class FastPitchModel(SpectrogramGenerator, Exportable, FastPitchAdapterModelMixin):
class FastPitchModel(SpectrogramGenerator, FastPitchAdapterModelMixin):
"""FastPitch model (https://arxiv.org/abs/2006.06873) that is used to generate mel spectrogram from text."""

def __init__(self, cfg: DictConfig, trainer: Trainer = None):
Expand Down Expand Up @@ -819,29 +818,6 @@ def _prepare_for_export(self, **kwargs):
if self.export_config["enable_volume"]:
self._output_types["volume_aligned"] = NeuralType(('B', 'T'), RegressionValuesType())

def _export_teardown(self):
self._input_types = self._output_types = None

@property
def disabled_deployment_input_names(self):
"""Implement this method to return a set of input names disabled for export"""
disabled_inputs = set()
if self.fastpitch.speaker_emb is None:
disabled_inputs.add("speaker")
if not self.export_config["enable_ragged_batches"]:
disabled_inputs.add("batch_lengths")
if not self.export_config["enable_volume"]:
disabled_inputs.add("volume")
return disabled_inputs

@property
def input_types(self):
return self._input_types

@property
def output_types(self):
return self._output_types

def input_example(self, max_batch=1, max_dim=44):
"""
Generates input examples for tracing etc.
Expand Down
10 changes: 1 addition & 9 deletions nemo/collections/tts/models/radtts.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
regulate_len,
sample_tts_input,
)
from nemo.core.classes import Exportable
from nemo.core.classes.common import typecheck
from nemo.core.neural_types.elements import (
Index,
Expand All @@ -45,7 +44,7 @@


@experimental
class RadTTSModel(SpectrogramGenerator, Exportable):
class RadTTSModel(SpectrogramGenerator):
def __init__(self, cfg: DictConfig, trainer: Trainer = None):
if isinstance(cfg, dict):
cfg = OmegaConf.create(cfg)
Expand Down Expand Up @@ -434,13 +433,6 @@ def load_state_dict(self, state_dict, strict=True):
super().load_state_dict(new_state_dict, strict=strict)

# Methods for model exportability
@property
def input_types(self):
return self._input_types

@property
def output_types(self):
return self._output_types

def _prepare_for_export(self, **kwargs):
self.model.remove_norms()
Expand Down
12 changes: 9 additions & 3 deletions scripts/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def get_args(argv):
parser.add_argument("--device", default="cuda", help="Device to export for")
parser.add_argument("--check-tolerance", type=float, default=0.01, help="tolerance for verification")
parser.add_argument(
"--config",
"--export-config",
metavar="KEY=VALUE",
nargs='+',
help="Set a number of key-value pairs to model.export_config dictionary "
Expand Down Expand Up @@ -142,8 +142,14 @@ def nemo_export(argv):
if args.cache_support:
model.set_export_config({"cache_support": "True"})

if args.config:
kv = dict(map(lambda s: s.split('='), args.config))
if args.export_config:
kv = {}
for key_value in args.export_config:
lst = key_value.split("=")
if len(lst) != 2:
raise Exception("Use correct format for --export_config: k=v")
k, v = lst
kv[k] = v
model.set_export_config(kv)

autocast = nullcontext
Expand Down

0 comments on commit d698ae9

Please sign in to comment.