Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented generic kv-pair setting of export_config from args #6978

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions nemo/collections/asr/models/asr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,11 @@ def disabled_deployment_input_names(self):
@property
def disabled_deployment_output_names(self):
return self.encoder.disabled_deployment_output_names

titu1994 marked this conversation as resolved.
Show resolved Hide resolved
def set_export_config(self, args):
if 'cache_support' in args:
enable = bool(args['cache_support'])
self.encoder.export_cache_support = enable
logging.info(f"Caching support enabled: {enable}")
self.encoder.setup_streaming_params()
super().set_export_config(args)
14 changes: 14 additions & 0 deletions nemo/core/classes/exportable.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,3 +302,17 @@ def list_export_subnets(self):
First goes the one receiving input (input_example)
"""
return ['self']

def get_export_config(self):
titu1994 marked this conversation as resolved.
Show resolved Hide resolved
"""
Returns export_config dictionary
"""
return getattr(self, 'export_config', {})

def set_export_config(self, kwargs):
"""
Sets/updates export_config dictionary
"""
ex_config = self.get_export_config()
ex_config.update(kwargs)
self.export_config = ex_config
19 changes: 15 additions & 4 deletions scripts/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,15 @@ def get_args(argv):
)
parser.add_argument("--device", default="cuda", help="Device to export for")
parser.add_argument("--check-tolerance", type=float, default=0.01, help="tolerance for verification")
parser.add_argument(
"--config",
metavar="KEY=VALUE",
nargs='+',
help="Set a number of key-value pairs to model.export_config dictionary "
"(do not put spaces before or after the = sign). "
"Note that values are always treated as strings.",
)

args = parser.parse_args(argv)
return args

Expand Down Expand Up @@ -130,10 +139,12 @@ def nemo_export(argv):
in_args["max_dim"] = args.max_dim
max_dim = args.max_dim

if args.cache_support and hasattr(model, "encoder") and hasattr(model.encoder, "export_cache_support"):
model.encoder.export_cache_support = True
logging.info("Caching support is enabled.")
model.encoder.setup_streaming_params()
if args.cache_support:
model.set_export_config({cache_support: True})

if args.config:
kv = dict(map(lambda s: s.split('='), args.config))
model.set_export_config(kv)

autocast = nullcontext
if args.autocast:
Expand Down