Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refined export_config #7053

Merged
merged 3 commits into from
Jul 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/asr/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ You may find FastConformer variants of cache-aware streaming models under ``<NeM
Note cache-aware streaming models are being exported without caching support by default.
To include caching support, `model.set_export_config({'cache_support' : 'True'})` should be called before export.
Or, if ``<NeMo_git_root>/scripts/export.py`` is being used:
`python export.py cache_aware_conformer.nemo cache_aware_conformer.onnx --config cache_support=True`
`python export.py cache_aware_conformer.nemo cache_aware_conformer.onnx --export-config cache_support=True`

.. _LSTM-Transducer_model:

Expand Down Expand Up @@ -299,7 +299,7 @@ Similar example configs for FastConformer variants of Hybrid models can be found
Note Hybrid models are being exported as RNNT (encoder and decoder+joint parts) by default.
To export as CTC (single encoder+decoder graph), `model.set_export_config({'decoder_type' : 'ctc'})` should be called before export.
Or, if ``<NeMo_git_root>/scripts/export.py`` is being used:
`python export.py hybrid_transducer.nemo hybrid_transducer.onnx --config decoder_type=ctc`
`python export.py hybrid_transducer.nemo hybrid_transducer.onnx --export-config decoder_type=ctc`

.. _Conformer-HAT_model:

Expand Down
2 changes: 1 addition & 1 deletion docs/source/core/export.rst
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ An example can be found in ``<NeMo_git_root>/nemo/collections/asr/models/rnnt_mo
Here is example on now `set_export_config()` call is being tied to command line arguments in ``<NeMo_git_root>/scripts/export.py`` :

.. code-block:: Python
python scripts/export.py hybrid_conformer.nemo hybrid_conformer.onnx --config decoder_type=ctc
python scripts/export.py hybrid_conformer.nemo hybrid_conformer.onnx --export-config decoder_type=ctc
Exportable Model Code
~~~~~~~~~~~~~~~~~~~~~
Expand Down
12 changes: 12 additions & 0 deletions nemo/collections/tts/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,18 @@ def list_available_models(cls) -> 'List[PretrainedModelInfo]':
list_of_models.extend(subclass_models)
return list_of_models

def set_export_config(self, args):
for k in ['enable_volume', 'enable_ragged_batches']:
if k in args:
self.export_config[k] = bool(args[k])
args.pop(k)
if 'num_speakers' in args:
self.export_config['num_speakers'] = int(args['num_speakers'])
args.pop('num_speakers')
if 'emb_range' in args:
raise Exception('embedding range is not user-settable')
super().set_export_config(args)


class Vocoder(ModelPT, ABC):
"""
Expand Down
12 changes: 9 additions & 3 deletions scripts/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def get_args(argv):
parser.add_argument("--device", default="cuda", help="Device to export for")
parser.add_argument("--check-tolerance", type=float, default=0.01, help="tolerance for verification")
parser.add_argument(
"--config",
"--export-config",
metavar="KEY=VALUE",
nargs='+',
help="Set a number of key-value pairs to model.export_config dictionary "
Expand Down Expand Up @@ -142,8 +142,14 @@ def nemo_export(argv):
if args.cache_support:
model.set_export_config({"cache_support": "True"})

if args.config:
kv = dict(map(lambda s: s.split('='), args.config))
if args.export_config:
kv = {}
for key_value in args.export_config:
lst = key_value.split("=")
if len(lst) != 2:
raise Exception("Use correct format for --export_config: k=v")
k, v = lst
kv[k] = v
model.set_export_config(kv)

autocast = nullcontext
Expand Down
3 changes: 1 addition & 2 deletions tests/collections/tts/test_tts_exportables.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ def radtts_model():
model = RadTTSModel(cfg=cfg.model)
app_state.is_model_being_restored = False
model.eval()
model.export_config['enable_ragged_batches'] = True
model.export_config['enable_volume'] = True
model.set_export_config({'enable_ragged_batches': 'True', 'enable_volume': 'True'})
return model


Expand Down