diff --git a/nemo/collections/nlp/modules/common/text_generation_utils.py b/nemo/collections/nlp/modules/common/text_generation_utils.py index 545ea5cb346c..36b30aae47b9 100644 --- a/nemo/collections/nlp/modules/common/text_generation_utils.py +++ b/nemo/collections/nlp/modules/common/text_generation_utils.py @@ -396,6 +396,7 @@ def synced_generate( context_length_tensor, tokens_to_generate, all_probs, + compute_attention_mask=compute_attention_mask, temperature=temperature, ) else: @@ -825,6 +826,7 @@ def tab_sample_sequence_batch( context_lengths, tokens_to_generate, all_probs=True, + compute_attention_mask=True, type_ids=None, temperature=None, ): @@ -848,7 +850,7 @@ def tab_sample_sequence_batch( # initialize the batch with torch.no_grad(): context_length = context_lengths.min().item() - inference_strategy.init_batch(context_tokens, context_length) + inference_strategy.init_batch(context_tokens, context_length, compute_attention_mask) context = context_tokens[:, :context_length] # the context may start in the middle of the row, # calculate the offset according to the position of '\n' or '<|endoftext|>' @@ -882,7 +884,7 @@ def tab_sample_sequence_batch( while context_length < maxlen: batch, tensor_shape = inference_strategy.prepare_batch_at_step( - tokens, maxlen, micro_batch_size, counter, context_length + tokens, maxlen, micro_batch_size, counter, context_length, compute_attention_mask ) output = inference_strategy.forward_step(batch, tensor_shape)