From 2b57d47ed8996193d8a3376a0adf2955103b1f17 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 4 Jul 2023 18:28:20 -0700 Subject: [PATCH] Implemented generic kv-pair setting of export_config from args Signed-off-by: Boris Fomitchev --- nemo/collections/asr/models/asr_model.py | 8 ++++++++ nemo/core/classes/exportable.py | 8 ++++++++ scripts/export.py | 19 +++++++++++++++---- 3 files changed, 31 insertions(+), 4 deletions(-) diff --git a/nemo/collections/asr/models/asr_model.py b/nemo/collections/asr/models/asr_model.py index c0f4c1cd0a70..1e38b8b93062 100644 --- a/nemo/collections/asr/models/asr_model.py +++ b/nemo/collections/asr/models/asr_model.py @@ -239,3 +239,11 @@ def disabled_deployment_input_names(self): @property def disabled_deployment_output_names(self): return self.encoder.disabled_deployment_output_names + + def set_export_config(self, **kwargs): + if 'cache_support' in kwargs: + enable = bool(kwargs['cache_support']) + self.encoder.export_cache_support = enable + logging.info(f"Caching support enabled: {enable}") + self.encoder.setup_streaming_params() + super().set_export_config(**kwargs) diff --git a/nemo/core/classes/exportable.py b/nemo/core/classes/exportable.py index 3d2682f2304e..3f54fb9a0f02 100644 --- a/nemo/core/classes/exportable.py +++ b/nemo/core/classes/exportable.py @@ -302,3 +302,11 @@ def list_export_subnets(self): First goes the one receiving input (input_example) """ return ['self'] + + def get_export_config(self): + return getattr(self, 'export_config', {}) + + def set_export_config(self, **kwargs): + ex_config = self.get_export_config() + ex_config.update(kwargs) + self.export_config = ex_config diff --git a/scripts/export.py b/scripts/export.py index fe3b79ebdf28..7c68fa4cd41c 100644 --- a/scripts/export.py +++ b/scripts/export.py @@ -62,6 +62,15 @@ def get_args(argv): ) parser.add_argument("--device", default="cuda", help="Device to export for") parser.add_argument("--check-tolerance", type=float, default=0.01, help="tolerance for verification") + parser.add_argument( + "--config", + metavar="KEY=VALUE", + nargs='+', + help="Set a number of key-value pairs to model.export_config dictionary " + "(do not put spaces before or after the = sign). " + "Note that values are always treated as strings.", + ) + args = parser.parse_args(argv) return args @@ -130,10 +139,12 @@ def nemo_export(argv): in_args["max_dim"] = args.max_dim max_dim = args.max_dim - if args.cache_support and hasattr(model, "encoder") and hasattr(model.encoder, "export_cache_support"): - model.encoder.export_cache_support = True - logging.info("Caching support is enabled.") - model.encoder.setup_streaming_params() + if args.cache_support: + model.set_export_config(cache_support=True) + + if args.config: + kv = dict(map(lambda s: s.split('='), args.config)) + model.set_export_config(**kv) autocast = nullcontext if args.autocast: