From 2b57d47ed8996193d8a3376a0adf2955103b1f17 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 4 Jul 2023 18:28:20 -0700 Subject: [PATCH 1/3] Implemented generic kv-pair setting of export_config from args Signed-off-by: Boris Fomitchev --- nemo/collections/asr/models/asr_model.py | 8 ++++++++ nemo/core/classes/exportable.py | 8 ++++++++ scripts/export.py | 19 +++++++++++++++---- 3 files changed, 31 insertions(+), 4 deletions(-) diff --git a/nemo/collections/asr/models/asr_model.py b/nemo/collections/asr/models/asr_model.py index c0f4c1cd0a70..1e38b8b93062 100644 --- a/nemo/collections/asr/models/asr_model.py +++ b/nemo/collections/asr/models/asr_model.py @@ -239,3 +239,11 @@ def disabled_deployment_input_names(self): @property def disabled_deployment_output_names(self): return self.encoder.disabled_deployment_output_names + + def set_export_config(self, **kwargs): + if 'cache_support' in kwargs: + enable = bool(kwargs['cache_support']) + self.encoder.export_cache_support = enable + logging.info(f"Caching support enabled: {enable}") + self.encoder.setup_streaming_params() + super().set_export_config(**kwargs) diff --git a/nemo/core/classes/exportable.py b/nemo/core/classes/exportable.py index 3d2682f2304e..3f54fb9a0f02 100644 --- a/nemo/core/classes/exportable.py +++ b/nemo/core/classes/exportable.py @@ -302,3 +302,11 @@ def list_export_subnets(self): First goes the one receiving input (input_example) """ return ['self'] + + def get_export_config(self): + return getattr(self, 'export_config', {}) + + def set_export_config(self, **kwargs): + ex_config = self.get_export_config() + ex_config.update(kwargs) + self.export_config = ex_config diff --git a/scripts/export.py b/scripts/export.py index fe3b79ebdf28..7c68fa4cd41c 100644 --- a/scripts/export.py +++ b/scripts/export.py @@ -62,6 +62,15 @@ def get_args(argv): ) parser.add_argument("--device", default="cuda", help="Device to export for") parser.add_argument("--check-tolerance", type=float, default=0.01, help="tolerance for verification") + parser.add_argument( + "--config", + metavar="KEY=VALUE", + nargs='+', + help="Set a number of key-value pairs to model.export_config dictionary " + "(do not put spaces before or after the = sign). " + "Note that values are always treated as strings.", + ) + args = parser.parse_args(argv) return args @@ -130,10 +139,12 @@ def nemo_export(argv): in_args["max_dim"] = args.max_dim max_dim = args.max_dim - if args.cache_support and hasattr(model, "encoder") and hasattr(model.encoder, "export_cache_support"): - model.encoder.export_cache_support = True - logging.info("Caching support is enabled.") - model.encoder.setup_streaming_params() + if args.cache_support: + model.set_export_config(cache_support=True) + + if args.config: + kv = dict(map(lambda s: s.split('='), args.config)) + model.set_export_config(**kv) autocast = nullcontext if args.autocast: From 251aec806fc319ada281e00aa1bf2757601c5efd Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 5 Jul 2023 17:06:33 -0700 Subject: [PATCH 2/3] Changed from **kwargs Signed-off-by: Boris Fomitchev --- nemo/collections/asr/models/asr_model.py | 8 ++++---- nemo/core/classes/exportable.py | 2 +- scripts/export.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/nemo/collections/asr/models/asr_model.py b/nemo/collections/asr/models/asr_model.py index 1e38b8b93062..20be6cc16203 100644 --- a/nemo/collections/asr/models/asr_model.py +++ b/nemo/collections/asr/models/asr_model.py @@ -240,10 +240,10 @@ def disabled_deployment_input_names(self): def disabled_deployment_output_names(self): return self.encoder.disabled_deployment_output_names - def set_export_config(self, **kwargs): - if 'cache_support' in kwargs: - enable = bool(kwargs['cache_support']) + def set_export_config(self, args): + if 'cache_support' in args: + enable = bool(args['cache_support']) self.encoder.export_cache_support = enable logging.info(f"Caching support enabled: {enable}") self.encoder.setup_streaming_params() - super().set_export_config(**kwargs) + super().set_export_config(args) diff --git a/nemo/core/classes/exportable.py b/nemo/core/classes/exportable.py index 3f54fb9a0f02..0a77ccfee41a 100644 --- a/nemo/core/classes/exportable.py +++ b/nemo/core/classes/exportable.py @@ -306,7 +306,7 @@ def list_export_subnets(self): def get_export_config(self): return getattr(self, 'export_config', {}) - def set_export_config(self, **kwargs): + def set_export_config(self, kwargs): ex_config = self.get_export_config() ex_config.update(kwargs) self.export_config = ex_config diff --git a/scripts/export.py b/scripts/export.py index 7c68fa4cd41c..e22023786448 100644 --- a/scripts/export.py +++ b/scripts/export.py @@ -140,11 +140,11 @@ def nemo_export(argv): max_dim = args.max_dim if args.cache_support: - model.set_export_config(cache_support=True) + model.set_export_config({cache_support: True}) if args.config: kv = dict(map(lambda s: s.split('='), args.config)) - model.set_export_config(**kv) + model.set_export_config(kv) autocast = nullcontext if args.autocast: From 6392c43cee2a46739227b8753437f2614723958a Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 5 Jul 2023 17:30:46 -0700 Subject: [PATCH 3/3] Docstring Signed-off-by: Boris Fomitchev --- nemo/core/classes/exportable.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/nemo/core/classes/exportable.py b/nemo/core/classes/exportable.py index 0a77ccfee41a..e6f131fa0617 100644 --- a/nemo/core/classes/exportable.py +++ b/nemo/core/classes/exportable.py @@ -304,9 +304,15 @@ def list_export_subnets(self): return ['self'] def get_export_config(self): + """ + Returns export_config dictionary + """ return getattr(self, 'export_config', {}) def set_export_config(self, kwargs): + """ + Sets/updates export_config dictionary + """ ex_config = self.get_export_config() ex_config.update(kwargs) self.export_config = ex_config