diff --git a/nemo/collections/asr/models/asr_model.py b/nemo/collections/asr/models/asr_model.py index c0f4c1cd0a70..20be6cc16203 100644 --- a/nemo/collections/asr/models/asr_model.py +++ b/nemo/collections/asr/models/asr_model.py @@ -239,3 +239,11 @@ def disabled_deployment_input_names(self): @property def disabled_deployment_output_names(self): return self.encoder.disabled_deployment_output_names + + def set_export_config(self, args): + if 'cache_support' in args: + enable = bool(args['cache_support']) + self.encoder.export_cache_support = enable + logging.info(f"Caching support enabled: {enable}") + self.encoder.setup_streaming_params() + super().set_export_config(args) diff --git a/nemo/core/classes/exportable.py b/nemo/core/classes/exportable.py index 3d2682f2304e..e6f131fa0617 100644 --- a/nemo/core/classes/exportable.py +++ b/nemo/core/classes/exportable.py @@ -302,3 +302,17 @@ def list_export_subnets(self): First goes the one receiving input (input_example) """ return ['self'] + + def get_export_config(self): + """ + Returns export_config dictionary + """ + return getattr(self, 'export_config', {}) + + def set_export_config(self, kwargs): + """ + Sets/updates export_config dictionary + """ + ex_config = self.get_export_config() + ex_config.update(kwargs) + self.export_config = ex_config diff --git a/scripts/export.py b/scripts/export.py index fe3b79ebdf28..e22023786448 100644 --- a/scripts/export.py +++ b/scripts/export.py @@ -62,6 +62,15 @@ def get_args(argv): ) parser.add_argument("--device", default="cuda", help="Device to export for") parser.add_argument("--check-tolerance", type=float, default=0.01, help="tolerance for verification") + parser.add_argument( + "--config", + metavar="KEY=VALUE", + nargs='+', + help="Set a number of key-value pairs to model.export_config dictionary " + "(do not put spaces before or after the = sign). " + "Note that values are always treated as strings.", + ) + args = parser.parse_args(argv) return args @@ -130,10 +139,12 @@ def nemo_export(argv): in_args["max_dim"] = args.max_dim max_dim = args.max_dim - if args.cache_support and hasattr(model, "encoder") and hasattr(model.encoder, "export_cache_support"): - model.encoder.export_cache_support = True - logging.info("Caching support is enabled.") - model.encoder.setup_streaming_params() + if args.cache_support: + model.set_export_config({cache_support: True}) + + if args.config: + kv = dict(map(lambda s: s.split('='), args.config)) + model.set_export_config(kv) autocast = nullcontext if args.autocast: