diff --git a/projects/blenderbot2/agents/sub_modules.py b/projects/blenderbot2/agents/sub_modules.py index a5fe5999449..42ac93cbc28 100644 --- a/projects/blenderbot2/agents/sub_modules.py +++ b/projects/blenderbot2/agents/sub_modules.py @@ -183,6 +183,7 @@ def __init__(self, opt: Opt): overrides['beam_size'] = opt.get('query_generator_beam_size', 3) overrides['beam_min_length'] = opt.get('query_generator_beam_min_length', 2) overrides['model_parallel'] = opt['model_parallel'] + overrides['no_cuda'] = opt['no_cuda'] if self.opt['query_generator_truncate'] > 0: overrides['text_truncate'] = self.opt['query_generator_truncate'] overrides['truncate'] = self.opt['query_generator_truncate'] @@ -287,6 +288,7 @@ def __init__(self, opt: Opt): 'beam_size': opt.get('memory_decoder_beam_size', 3), 'beam_min_length': opt.get('memory_decoder_beam_min_length', 10), 'beam_block_ngram': 3, + 'no_cuda': opt.get('no_cuda', False), } if self.opt.get('memory_decoder_truncate', -1) > 0: overrides['text_truncate'] = self.opt['memory_decoder_truncate']