Skip to content

Commit

Permalink
Implemented generic kv-pair setting of export_config from args
Browse files Browse the repository at this point in the history
Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
  • Loading branch information
borisfom committed Jul 5, 2023
1 parent 0b6e4e6 commit 2b57d47
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 4 deletions.
8 changes: 8 additions & 0 deletions nemo/collections/asr/models/asr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,11 @@ def disabled_deployment_input_names(self):
@property
def disabled_deployment_output_names(self):
return self.encoder.disabled_deployment_output_names

def set_export_config(self, **kwargs):
if 'cache_support' in kwargs:
enable = bool(kwargs['cache_support'])
self.encoder.export_cache_support = enable
logging.info(f"Caching support enabled: {enable}")
self.encoder.setup_streaming_params()
super().set_export_config(**kwargs)
8 changes: 8 additions & 0 deletions nemo/core/classes/exportable.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,3 +302,11 @@ def list_export_subnets(self):
First goes the one receiving input (input_example)
"""
return ['self']

def get_export_config(self):
return getattr(self, 'export_config', {})

def set_export_config(self, **kwargs):
ex_config = self.get_export_config()
ex_config.update(kwargs)
self.export_config = ex_config
19 changes: 15 additions & 4 deletions scripts/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,15 @@ def get_args(argv):
)
parser.add_argument("--device", default="cuda", help="Device to export for")
parser.add_argument("--check-tolerance", type=float, default=0.01, help="tolerance for verification")
parser.add_argument(
"--config",
metavar="KEY=VALUE",
nargs='+',
help="Set a number of key-value pairs to model.export_config dictionary "
"(do not put spaces before or after the = sign). "
"Note that values are always treated as strings.",
)

args = parser.parse_args(argv)
return args

Expand Down Expand Up @@ -130,10 +139,12 @@ def nemo_export(argv):
in_args["max_dim"] = args.max_dim
max_dim = args.max_dim

if args.cache_support and hasattr(model, "encoder") and hasattr(model.encoder, "export_cache_support"):
model.encoder.export_cache_support = True
logging.info("Caching support is enabled.")
model.encoder.setup_streaming_params()
if args.cache_support:
model.set_export_config(cache_support=True)

if args.config:
kv = dict(map(lambda s: s.split('='), args.config))
model.set_export_config(**kv)

autocast = nullcontext
if args.autocast:
Expand Down

0 comments on commit 2b57d47

Please sign in to comment.