Skip to content

Commit

Permalink
Refined export_config (NVIDIA#7053)
Browse files Browse the repository at this point in the history
* Refined export_config

Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>

* Rolling back hierarchy change

Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>

---------

Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>
  • Loading branch information
borisfom authored and zhehuaichen committed Oct 4, 2023
1 parent 48f7106 commit c6c33b7
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 8 deletions.
4 changes: 2 additions & 2 deletions docs/source/asr/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ You may find FastConformer variants of cache-aware streaming models under ``<NeM
Note cache-aware streaming models are being exported without caching support by default.
To include caching support, `model.set_export_config({'cache_support' : 'True'})` should be called before export.
Or, if ``<NeMo_git_root>/scripts/export.py`` is being used:
`python export.py cache_aware_conformer.nemo cache_aware_conformer.onnx --config cache_support=True`
`python export.py cache_aware_conformer.nemo cache_aware_conformer.onnx --export-config cache_support=True`

.. _LSTM-Transducer_model:

Expand Down Expand Up @@ -299,7 +299,7 @@ Similar example configs for FastConformer variants of Hybrid models can be found
Note Hybrid models are being exported as RNNT (encoder and decoder+joint parts) by default.
To export as CTC (single encoder+decoder graph), `model.set_export_config({'decoder_type' : 'ctc'})` should be called before export.
Or, if ``<NeMo_git_root>/scripts/export.py`` is being used:
`python export.py hybrid_transducer.nemo hybrid_transducer.onnx --config decoder_type=ctc`
`python export.py hybrid_transducer.nemo hybrid_transducer.onnx --export-config decoder_type=ctc`

.. _Conformer-HAT_model:

Expand Down
2 changes: 1 addition & 1 deletion docs/source/core/export.rst
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ An example can be found in ``<NeMo_git_root>/nemo/collections/asr/models/rnnt_mo
Here is example on now `set_export_config()` call is being tied to command line arguments in ``<NeMo_git_root>/scripts/export.py`` :

.. code-block:: Python
python scripts/export.py hybrid_conformer.nemo hybrid_conformer.onnx --config decoder_type=ctc
python scripts/export.py hybrid_conformer.nemo hybrid_conformer.onnx --export-config decoder_type=ctc
Exportable Model Code
~~~~~~~~~~~~~~~~~~~~~
Expand Down
12 changes: 12 additions & 0 deletions nemo/collections/tts/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,18 @@ def list_available_models(cls) -> 'List[PretrainedModelInfo]':
list_of_models.extend(subclass_models)
return list_of_models

def set_export_config(self, args):
for k in ['enable_volume', 'enable_ragged_batches']:
if k in args:
self.export_config[k] = bool(args[k])
args.pop(k)
if 'num_speakers' in args:
self.export_config['num_speakers'] = int(args['num_speakers'])
args.pop('num_speakers')
if 'emb_range' in args:
raise Exception('embedding range is not user-settable')
super().set_export_config(args)


class Vocoder(ModelPT, ABC):
"""
Expand Down
12 changes: 9 additions & 3 deletions scripts/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def get_args(argv):
parser.add_argument("--device", default="cuda", help="Device to export for")
parser.add_argument("--check-tolerance", type=float, default=0.01, help="tolerance for verification")
parser.add_argument(
"--config",
"--export-config",
metavar="KEY=VALUE",
nargs='+',
help="Set a number of key-value pairs to model.export_config dictionary "
Expand Down Expand Up @@ -142,8 +142,14 @@ def nemo_export(argv):
if args.cache_support:
model.set_export_config({"cache_support": "True"})

if args.config:
kv = dict(map(lambda s: s.split('='), args.config))
if args.export_config:
kv = {}
for key_value in args.export_config:
lst = key_value.split("=")
if len(lst) != 2:
raise Exception("Use correct format for --export_config: k=v")
k, v = lst
kv[k] = v
model.set_export_config(kv)

autocast = nullcontext
Expand Down
3 changes: 1 addition & 2 deletions tests/collections/tts/test_tts_exportables.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ def radtts_model():
model = RadTTSModel(cfg=cfg.model)
app_state.is_model_being_restored = False
model.eval()
model.export_config['enable_ragged_batches'] = True
model.export_config['enable_volume'] = True
model.set_export_config({'enable_ragged_batches': 'True', 'enable_volume': 'True'})
return model


Expand Down

0 comments on commit c6c33b7

Please sign in to comment.