Skip to content

Commit

Permalink
Moved set_export_config
Browse files Browse the repository at this point in the history
Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
  • Loading branch information
borisfom committed Jul 6, 2023
2 parents 628edc9 + 6392c43 commit 7a5e750
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 12 deletions.
13 changes: 4 additions & 9 deletions nemo/collections/asr/models/asr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,15 +240,10 @@ def disabled_deployment_input_names(self):
def disabled_deployment_output_names(self):
return self.encoder.disabled_deployment_output_names

def set_export_config(self, **kwargs):
if 'cache_support' in kwargs:
enable = bool(kwargs['cache_support'])
def set_export_config(self, args):
if 'cache_support' in args:
enable = bool(args['cache_support'])
self.encoder.export_cache_support = enable
logging.info(f"Caching support enabled: {enable}")
self.encoder.setup_streaming_params()
if 'decoder_type' in kwargs:
if hasattr(self, 'change_decoding_strategy'):
self.change_decoding_strategy(decoder_type=kwargs['decoder_type'])
else:
raise Exception("Model does not have decoder type option")
super().set_export_config(**kwargs)
super().set_export_config(args)
8 changes: 8 additions & 0 deletions nemo/collections/asr/models/rnnt_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -960,6 +960,14 @@ def list_export_subnets(self):
def decoder_joint(self):
return RNNTDecoderJoint(self.decoder, self.joint)

def set_export_config(self, args):
if 'decoder_type' in args:
if hasattr(self, 'change_decoding_strategy'):
self.change_decoding_strategy(decoder_type=args['decoder_type'])
else:
raise Exception("Model does not have decoder type option")
super().set_export_config(args)

@classmethod
def list_available_models(cls) -> List[PretrainedModelInfo]:
"""
Expand Down
8 changes: 7 additions & 1 deletion nemo/core/classes/exportable.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,9 +304,15 @@ def list_export_subnets(self):
return ['self']

def get_export_config(self):
"""
Returns export_config dictionary
"""
return getattr(self, 'export_config', {})

def set_export_config(self, **kwargs):
def set_export_config(self, kwargs):
"""
Sets/updates export_config dictionary
"""
ex_config = self.get_export_config()
ex_config.update(kwargs)
self.export_config = ex_config
4 changes: 2 additions & 2 deletions scripts/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,11 @@ def nemo_export(argv):
max_dim = args.max_dim

if args.cache_support:
model.set_export_config(cache_support=True)
model.set_export_config({cache_support: True})

if args.config:
kv = dict(map(lambda s: s.split('='), args.config))
model.set_export_config(**kv)
model.set_export_config(kv)

autocast = nullcontext
if args.autocast:
Expand Down

0 comments on commit 7a5e750

Please sign in to comment.