From 02fca24bedae76507b527a0c47ad0d08dbdd0bfb Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Thu, 23 Jun 2022 11:11:33 -0700 Subject: [PATCH 01/38] add cuda and cpp code for ngram blocking --- parlai/clib/cuda/ngram_repeat_block_cuda.cpp | 48 +++++++++++ .../cuda/ngram_repeat_block_cuda_kernel.cu | 82 +++++++++++++++++++ 2 files changed, 130 insertions(+) create mode 100644 parlai/clib/cuda/ngram_repeat_block_cuda.cpp create mode 100644 parlai/clib/cuda/ngram_repeat_block_cuda_kernel.cu diff --git a/parlai/clib/cuda/ngram_repeat_block_cuda.cpp b/parlai/clib/cuda/ngram_repeat_block_cuda.cpp new file mode 100644 index 00000000000..073296ba298 --- /dev/null +++ b/parlai/clib/cuda/ngram_repeat_block_cuda.cpp @@ -0,0 +1,48 @@ +/* +Copyright (c) Facebook, Inc. and its affiliates. +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +*/ + +#include +#include + +/* +CPP Binding for CUDA OP +*/ + +// CUDA forward declarations +torch::Tensor ngram_repeat_block_cuda_forward(torch::Tensor tokens, + torch::Tensor lprobs, int bsz, + int step, int beam_size, + int no_repeat_ngram_size); + +#define CHECK_CUDA(x) \ + TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +// Input check and call to CUDA OP +// Backward method not required +torch::Tensor ngram_repeat_block_forward(torch::Tensor tokens, + torch::Tensor lprobs, int bsz, + int step, int beam_size, + int no_repeat_ngram_size) { + CHECK_INPUT(tokens); + CHECK_INPUT(lprobs); + assert(bsz > 0); + assert(step >= 0); + assert(beam_size > 0); + assert(no_repeat_ngram_size > 0); + + return ngram_repeat_block_cuda_forward(tokens, lprobs, bsz, step, beam_size, + no_repeat_ngram_size); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &ngram_repeat_block_forward, + "No Repeat Ngram Block forward (CUDA)"); +} \ No newline at end of file diff --git a/parlai/clib/cuda/ngram_repeat_block_cuda_kernel.cu b/parlai/clib/cuda/ngram_repeat_block_cuda_kernel.cu new file mode 100644 index 00000000000..2f300345c6d --- /dev/null +++ b/parlai/clib/cuda/ngram_repeat_block_cuda_kernel.cu @@ -0,0 +1,82 @@ +/* +Copyright (c) Facebook, Inc. and its affiliates. +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +*/ + +/* +Kernel implementation for blocking repeated n-grams. 
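+
+Parallelization sketch (describing the kernel below): one CUDA block is
+launched per hypothesis (bsz * beam_size blocks in total) and one thread per
+candidate n-gram start position within that hypothesis. Each thread copies
+its token into shared memory, compares the n-gram beginning at its position
+against the last (no_repeat_ngram_size - 1) generated tokens, and on a match
+sets the log-probability of the token that would complete the repeated
+n-gram to -inf so it cannot be selected at this step.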
+*/ + +#include +#include +#include +#include +#include + +// Ban repeated ngrams of length = 'no_repeat_ngram_size' +__global__ void banRepeatedTokens(long* __restrict__ tokens, + float* __restrict__ lprobs, + int max_predict_len, int vocab_size, + int no_repeat_ngram_size) { + auto row = blockIdx.x; + auto col = threadIdx.x; + auto start = row * (max_predict_len) + col; + + // Each thread compares ngram starting from + // thread index with final ngram starting from + // step - no_repeat_ngram_size + 2 + + auto check_start_pos = blockDim.x; + auto lprob_start = row * vocab_size; + extern __shared__ long tokens_shm[]; + tokens_shm[col] = tokens[start]; + + if (col == blockDim.x - 1) { + for (int i=1; i(); + auto lprob_ptr = lprobs.data_ptr(); + int blocks = bsz * beam_size; + int shared_mem_size = (step + 1) * sizeof(long); + + // Launching N blocks where N is number of samples in a batch (beams*bsz) + // Launching T threads where T is number of previous ngrams in a sample + // Allocating shared mem per block for fastser access of input tokens since + // each token will be accessed N times to compare with current Ngram where + // N is Ngram size. + + // printf("blocks: %d, threads: %d, max_pred: %d\n", blocks, threads, step); + banRepeatedTokens<<>>( + token_ptr, lprob_ptr, step + 1, vocab_size, no_repeat_ngram_size); + return lprobs; +} From bd8ab04a3365dd5bcb7e9e37738921ea779db92b Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Thu, 23 Jun 2022 11:12:46 -0700 Subject: [PATCH 02/38] add python wrapper --- parlai/ops/ngram_repeat_block.py | 73 ++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 parlai/ops/ngram_repeat_block.py diff --git a/parlai/ops/ngram_repeat_block.py b/parlai/ops/ngram_repeat_block.py new file mode 100644 index 00000000000..6958debd3d4 --- /dev/null +++ b/parlai/ops/ngram_repeat_block.py @@ -0,0 +1,73 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" Wrapper for ngram_repeat_block cuda extension """ +from torch import nn +from torch.autograd import Function + +# import ngram_repeat_block_cuda +from torch.utils.cpp_extension import load + +ngram_repeat_block_cuda = load( + name='ngram_repeat_block_cuda', + sources=[ + 'parlai/clib/cuda/ngram_repeat_block_cuda.cpp', + 'parlai/clib/cuda/ngram_repeat_block_cuda_kernel.cu', + ], +) + + +class NGramRepeatBlockFunction(Function): + """ + forward inputs to ngram_repeat_block cuda extension + backward method not needed. 
+ + """ + + def forward(self, tokens, lprobs, bsz, step, beam_size, no_repeat_ngram_size): + """ + Args: + tokens(Tensor): Input tokens(Bsz*beam, seq_len) + lprobs(Tensor): likelihood probability + Expected to be updated in place.(Bsz*beam, vocab_size) + bsz(int): batch size + step(int): current step + beam_size(int): beam size + no_repeat_ngram_size(int): Ngram size + """ + outputs = ngram_repeat_block_cuda.forward( + tokens, lprobs, bsz, step, beam_size, no_repeat_ngram_size + ) + return outputs + + def backward(*args): + raise NotImplementedError + + +class NGramRepeatBlock(nn.Module): + """Wrapper class for calling ngram_repeat_block cuda extension""" + + def __init__(self): + super(NGramRepeatBlock, self).__init__() + + def reset_parameters(self): + pass + + def forward(self, tokens, lprobs, bsz, step, beam_size, no_repeat_ngram_size): + """ + Args: + tokens(Tensor): Input tokens(beam, current_sequence_length) + lprobs(Tensor): likelihood probability(beam, vocab_size) + bsz(int): batch size + step(int): current step + beam_size(int): beam size + no_repeat_ngram_size(int): Ngram size + """ + assert tokens.size(0) == bsz * beam_size + assert lprobs.size(0) == bsz * beam_size + tokens = tokens.contiguous() + lprobs = lprobs.contiguous() + return NGramRepeatBlockFunction.apply( + tokens, lprobs, bsz, step, beam_size, no_repeat_ngram_size + ) From 0eddfe75062e4ce4f568d73d7765549edaa4bcce Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Thu, 23 Jun 2022 11:24:37 -0700 Subject: [PATCH 03/38] modify agent to use cuda kernel for self-blocking --- parlai/core/torch_generator_agent.py | 55 +++++++++++++++++++++++----- 1 file changed, 45 insertions(+), 10 deletions(-) diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index 7c637e3323e..9a8fecca466 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -43,6 +43,7 @@ trainable_parameters, PipelineHelper, ) +from parlai.ops.ngram_repeat_block import NGramRepeatBlock class SearchBlocklist(object): @@ -1185,7 +1186,7 @@ def _generate( score[prefix_mask] = neginf(score.dtype) for i, b in enumerate(beams): if not b.is_done(): - b.advance(score[i]) + b.advance(score[i], _ts) incr_state_inds = torch.cat( [ beam_size * i + b.get_backtrack_from_current_step() @@ -1357,7 +1358,9 @@ def __init__( self.eos_top = False self.eos_top_ts = None self.n_best_counter = 0 - self.partial_hyps = [[self.bos] for i in range(beam_size)] + self.partial_hyps = torch.tensor([[self.bos] for i in range(beam_size)]) + # self.partial_hyps = [[self.bos] for i in range(beam_size)] + self.no_repeat_ngram_op = NGramRepeatBlock() def set_context(self: TSType, context: torch.LongTensor) -> TSType: """ @@ -1448,10 +1451,18 @@ def _block_ngrams( continue source_ = hyp if source is None else source ngrams = self._find_ngrams(source_, ngram_size) + print(ngrams) prefix = hyp[-(ngram_size - 1) :] for ngram in ngrams: - if ngram_size == 1 or prefix == list(ngram[:-1]): + # when doing context blocking, ngram is tuple where prefix is tensor + # need to cast both into lists for comparison + if ngram_size == 1 or list(prefix) == list(ngram[:-1]): logprobs[beam_id][ngram[-1]] = neginf(logprobs.dtype) + # if source != None: + # print("context blocking: going to ban stuff!") + # else: + # print("self blocking: going to ban stuff!") + return logprobs def _block_block_list(self, logprobs: torch.Tensor) -> torch.Tensor: @@ -1466,7 +1477,7 @@ def _block_block_list(self, logprobs: torch.Tensor) -> torch.Tensor: 
logprobs[beam_id][ngram[-1]] = neginf(logprobs.dtype) return logprobs - def advance(self, logprobs): + def advance(self, logprobs, step): """ Advance the beam one step. """ @@ -1487,7 +1498,20 @@ def advance(self, logprobs): # beam blocking if self.block_ngram > 0: - logprobs = self._block_ngrams(self.block_ngram, logprobs, None) + # self blocking + if not self.partial_hyps.is_cuda: + self.partial_hyps = self.partial_hyps.cuda() + if not logprobs.is_cuda: + logprobs = logprobs.cuda() + # logprobs = self._block_ngrams(self.block_ngram, logprobs, None) + logprobs = self.no_repeat_ngram_op( + tokens=self.partial_hyps, + lprobs=logprobs, + bsz=1, + step=step, + beam_size=self.beam_size, + no_repeat_ngram_size=self.block_ngram, + ) logprobs = self._block_block_list(logprobs) @@ -1496,6 +1520,7 @@ def advance(self, logprobs): raise ValueError( "Must use TreeSearch.set_context to use context blocking." ) + # context blocking logprobs = self._block_ngrams( self.context_block_ngram, logprobs, self.context ) @@ -1508,11 +1533,21 @@ def advance(self, logprobs): self.outputs.append(path_selection.token_ids) self.bookkeep.append(path_selection.hypothesis_ids) - tok_id_list = path_selection.token_ids.tolist() - self.partial_hyps = [ - self.partial_hyps[path_selection.hypothesis_ids[i]] + [tok_id_list[i]] - for i in range(self.beam_size) - ] + + # list + # self.partial_hyps = [ + # self.partial_hyps[path_selection.hypothesis_ids[i]] + + # [path_selection.token_ids.tolist()[i]] + # for i in range(self.beam_size)] + + # tensor + self.partial_hyps = torch.cat( + ( + self.partial_hyps[path_selection.hypothesis_ids.long()], + path_selection.token_ids.view(path_selection.token_ids.shape[0], -1), + ), + 1, + ) if self.verbose: assert path_selection.token_details From deda51e5d036936d732aea516d8b66e28e1a9889 Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Fri, 24 Jun 2022 16:28:38 -0700 Subject: [PATCH 04/38] add context blocking --- parlai/clib/cuda/ngram_repeat_block_cuda.cpp | 16 ++-- .../cuda/ngram_repeat_block_cuda_kernel.cu | 84 +++++++++++++------ parlai/core/torch_generator_agent.py | 82 +++++++++++------- parlai/ops/ngram_repeat_block.py | 62 +++++++++++--- 4 files changed, 169 insertions(+), 75 deletions(-) diff --git a/parlai/clib/cuda/ngram_repeat_block_cuda.cpp b/parlai/clib/cuda/ngram_repeat_block_cuda.cpp index 073296ba298..e13bb84c3cf 100644 --- a/parlai/clib/cuda/ngram_repeat_block_cuda.cpp +++ b/parlai/clib/cuda/ngram_repeat_block_cuda.cpp @@ -12,10 +12,11 @@ CPP Binding for CUDA OP */ // CUDA forward declarations -torch::Tensor ngram_repeat_block_cuda_forward(torch::Tensor tokens, +torch::Tensor ngram_repeat_block_cuda_forward(const torch::Tensor hypothesis, const torch::Tensor context, torch::Tensor lprobs, int bsz, int step, int beam_size, - int no_repeat_ngram_size); + int no_repeat_ngram_size, + bool if_context_blocking); #define CHECK_CUDA(x) \ TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") @@ -27,19 +28,20 @@ torch::Tensor ngram_repeat_block_cuda_forward(torch::Tensor tokens, // Input check and call to CUDA OP // Backward method not required -torch::Tensor ngram_repeat_block_forward(torch::Tensor tokens, +torch::Tensor ngram_repeat_block_forward(const torch::Tensor hypothesis, const torch::Tensor context, torch::Tensor lprobs, int bsz, int step, int beam_size, - int no_repeat_ngram_size) { - CHECK_INPUT(tokens); + int no_repeat_ngram_size, + bool if_context_blocking) { + CHECK_INPUT(hypothesis); CHECK_INPUT(lprobs); assert(bsz > 0); assert(step >= 0); assert(beam_size > 0); 
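+  // Note: when if_context_blocking is true the kernel also reads `context`,
+  // so the caller is expected to pass a contiguous CUDA tensor for it as
+  // well; only `hypothesis` and `lprobs` are checked here.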
assert(no_repeat_ngram_size > 0); - return ngram_repeat_block_cuda_forward(tokens, lprobs, bsz, step, beam_size, - no_repeat_ngram_size); + return ngram_repeat_block_cuda_forward(hypothesis, context, lprobs, bsz, step, beam_size, + no_repeat_ngram_size, if_context_blocking); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { diff --git a/parlai/clib/cuda/ngram_repeat_block_cuda_kernel.cu b/parlai/clib/cuda/ngram_repeat_block_cuda_kernel.cu index 2f300345c6d..0ca4f231e96 100644 --- a/parlai/clib/cuda/ngram_repeat_block_cuda_kernel.cu +++ b/parlai/clib/cuda/ngram_repeat_block_cuda_kernel.cu @@ -15,39 +15,48 @@ Kernel implementation for blocking repeated n-grams. #include // Ban repeated ngrams of length = 'no_repeat_ngram_size' -__global__ void banRepeatedTokens(long* __restrict__ tokens, +__global__ void banRepeatedTokens(long* __restrict__ hypothesis_ptr, + long* __restrict__ context_ptr, float* __restrict__ lprobs, - int max_predict_len, int vocab_size, - int no_repeat_ngram_size) { + int current_seq_length, + int vocab_size, + int no_repeat_ngram_size, + bool if_context_blocking) { auto row = blockIdx.x; auto col = threadIdx.x; - auto start = row * (max_predict_len) + col; + // start of context ngram on current thread + auto index = row * current_seq_length + col; + // start of last ngram of hypothesis + auto start_of_ngram = current_seq_length - no_repeat_ngram_size + 1; + long* previous_ngram_ptr; - // Each thread compares ngram starting from - // thread index with final ngram starting from - // step - no_repeat_ngram_size + 2 + if (if_context_blocking) { + previous_ngram_ptr = &context_ptr[col]; + } else { + previous_ngram_ptr = &hypothesis_ptr[index]; + } - auto check_start_pos = blockDim.x; auto lprob_start = row * vocab_size; extern __shared__ long tokens_shm[]; - tokens_shm[col] = tokens[start]; + // each thread writes to shared array + tokens_shm[col] = *previous_ngram_ptr; + // final thread writes the end of previous ngram array to tokens_shm if (col == blockDim.x - 1) { for (int i=1; i(); +torch::Tensor ngram_repeat_block_cuda_forward(const torch::Tensor hypothesis, + const torch::Tensor context, + torch::Tensor lprobs, + int bsz, + int step, + int beam_size, + int no_repeat_ngram_size, + bool if_context_blocking) { + + auto hypothesis_ptr = hypothesis.data_ptr(); + auto context_ptr = context.data_ptr(); auto lprob_ptr = lprobs.data_ptr(); + + int context_length; + if (if_context_blocking) { + context_length = context.size(0); + } else { + // context is previously generated word sequences for self-blocking + context_length = hypothesis.size(1); + } + + int threads = context_length - no_repeat_ngram_size + 1; + if (step - no_repeat_ngram_size + 2 <= 0) return lprobs; + int vocab_size = lprobs.size(1); int blocks = bsz * beam_size; - int shared_mem_size = (step + 1) * sizeof(long); + int current_seq_length = hypothesis.size(1); + int shared_mem_size = context_length * sizeof(long); + // Launching N blocks where N is number of samples in a batch (beams*bsz) // Launching T threads where T is number of previous ngrams in a sample @@ -75,8 +100,13 @@ torch::Tensor ngram_repeat_block_cuda_forward(const torch::Tensor tokens, // each token will be accessed N times to compare with current Ngram where // N is Ngram size. 
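+  // Example with hypothetical numbers: bsz=1, beam_size=10, a hypothesis of
+  // length 8 and no_repeat_ngram_size=3 launches 10 blocks of 8 - 3 + 1 = 6
+  // threads, each block using 8 * sizeof(long) bytes of shared memory.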
- // printf("blocks: %d, threads: %d, max_pred: %d\n", blocks, threads, step); banRepeatedTokens<<>>( - token_ptr, lprob_ptr, step + 1, vocab_size, no_repeat_ngram_size); + hypothesis_ptr, + context_ptr, + lprob_ptr, + current_seq_length, + vocab_size, + no_repeat_ngram_size, + if_context_blocking); return lprobs; } diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index 9a8fecca466..1ce435d2f15 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -454,6 +454,12 @@ def add_cmdline_args( default=False, help='if true, compute tokenized bleu scores', ) + parser.add_argument( + '--gpu-beam-blocking', + type='bool', + help='Set to use CUDA kernel for beam search ngram blocking', + default=False, + ) super().add_cmdline_args(parser, partial_opt=partial_opt) return agent @@ -957,6 +963,7 @@ def _treesearch_factory(self, device, verbose=False): ) elif method == 'beam': return BeamSearch( + self.opt['gpu_beam_blocking'], beam_size, min_length=self.beam_min_length, block_ngram=self.beam_block_ngram, @@ -1359,7 +1366,6 @@ def __init__( self.eos_top_ts = None self.n_best_counter = 0 self.partial_hyps = torch.tensor([[self.bos] for i in range(beam_size)]) - # self.partial_hyps = [[self.bos] for i in range(beam_size)] self.no_repeat_ngram_op = NGramRepeatBlock() def set_context(self: TSType, context: torch.LongTensor) -> TSType: @@ -1370,7 +1376,7 @@ def set_context(self: TSType, context: torch.LongTensor) -> TSType: a LongTensor representing the input context; used for context ngram blocking, if supplied """ - self.context = context.tolist() + self.context = torch.Tensor(context.tolist()).long() return self def set_batch_context( @@ -1386,7 +1392,7 @@ def set_batch_context( :param batch_idx: index of the batch """ - self.context = batch_context_list[batch_idx] + self.context = torch.Tensor(batch_context_list[batch_idx]).long() return self def set_block_list(self: TSType, block_list: Optional[SearchBlocklist]) -> TSType: @@ -1432,7 +1438,12 @@ def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection pass def _block_ngrams( - self, ngram_size: int, logprobs: torch.Tensor, source: torch.LongTensor = None + self, + ngram_size: int, + logprobs: torch.Tensor, + step: int, + GPU_BEAM_BLOCKING: bool = True, + if_context_blocking=False, ): """ Hard block ngrams from the logprobs, based on the source. @@ -1446,23 +1457,36 @@ def _block_ngrams( Source text to grab ngrams from. If None, it uses the current hypothesis (i.e. self-blocking). 
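+        :param step:
+            Current generation step; forwarded to the CUDA kernel when GPU
+            beam blocking is enabled.
+        :param if_context_blocking:
+            If True, ngrams are drawn from the context (context blocking)
+            rather than from the hypothesis itself (self-blocking).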
""" + context = None + if self.gpu_beam_blocking: + if if_context_blocking: + if not self.context.is_cuda: + self.context = self.context.cuda() + context = self.context + + logprobs = self.no_repeat_ngram_op( + hypothesis=self.partial_hyps, + context=context, + lprobs=logprobs, + bsz=1, + step=step, + beam_size=self.beam_size, + no_repeat_ngram_size=self.block_ngram, + if_context_blocking=if_context_blocking, + ) + return logprobs + for beam_id, hyp in enumerate(self.partial_hyps): if len(hyp) < ngram_size - 1: continue - source_ = hyp if source is None else source + source_ = hyp if if_context_blocking is False else self.context ngrams = self._find_ngrams(source_, ngram_size) - print(ngrams) prefix = hyp[-(ngram_size - 1) :] for ngram in ngrams: # when doing context blocking, ngram is tuple where prefix is tensor # need to cast both into lists for comparison if ngram_size == 1 or list(prefix) == list(ngram[:-1]): - logprobs[beam_id][ngram[-1]] = neginf(logprobs.dtype) - # if source != None: - # print("context blocking: going to ban stuff!") - # else: - # print("self blocking: going to ban stuff!") - + logprobs[beam_id][ngram[-1].long()] = neginf(logprobs.dtype) return logprobs def _block_block_list(self, logprobs: torch.Tensor) -> torch.Tensor: @@ -1497,20 +1521,16 @@ def advance(self, logprobs, step): self.scores[hyp_id] = neginf(self.scores.dtype) # beam blocking + if not self.partial_hyps.is_cuda: + self.partial_hyps = self.partial_hyps.cuda() + if self.block_ngram > 0: # self blocking - if not self.partial_hyps.is_cuda: - self.partial_hyps = self.partial_hyps.cuda() - if not logprobs.is_cuda: - logprobs = logprobs.cuda() - # logprobs = self._block_ngrams(self.block_ngram, logprobs, None) - logprobs = self.no_repeat_ngram_op( - tokens=self.partial_hyps, - lprobs=logprobs, - bsz=1, + logprobs = self._block_ngrams( + ngram_size=self.block_ngram, + logprobs=logprobs, step=step, - beam_size=self.beam_size, - no_repeat_ngram_size=self.block_ngram, + if_context_blocking=False, ) logprobs = self._block_block_list(logprobs) @@ -1522,7 +1542,10 @@ def advance(self, logprobs, step): ) # context blocking logprobs = self._block_ngrams( - self.context_block_ngram, logprobs, self.context + ngram_size=self.block_ngram, + logprobs=logprobs, + step=step, + if_context_blocking=True, ) path_selection = self.select_paths(logprobs, self.scores, current_length) @@ -1534,13 +1557,6 @@ def advance(self, logprobs, step): self.outputs.append(path_selection.token_ids) self.bookkeep.append(path_selection.hypothesis_ids) - # list - # self.partial_hyps = [ - # self.partial_hyps[path_selection.hypothesis_ids[i]] + - # [path_selection.token_ids.tolist()[i]] - # for i in range(self.beam_size)] - - # tensor self.partial_hyps = torch.cat( ( self.partial_hyps[path_selection.hypothesis_ids.long()], @@ -1745,6 +1761,10 @@ class BeamSearch(TreeSearch): Beam search. """ + def __init__(self, gpu_beam_blocking, *args, **kwargs): + super().__init__(*args, **kwargs) + self.gpu_beam_blocking = gpu_beam_blocking + def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection: """ Select the next vocabulary item in these beams. diff --git a/parlai/ops/ngram_repeat_block.py b/parlai/ops/ngram_repeat_block.py index 6958debd3d4..4340d0e180a 100644 --- a/parlai/ops/ngram_repeat_block.py +++ b/parlai/ops/ngram_repeat_block.py @@ -3,6 +3,7 @@ # LICENSE file in the root directory of this source tree. 
""" Wrapper for ngram_repeat_block cuda extension """ +import torch from torch import nn from torch.autograd import Function @@ -25,19 +26,37 @@ class NGramRepeatBlockFunction(Function): """ - def forward(self, tokens, lprobs, bsz, step, beam_size, no_repeat_ngram_size): + def forward( + self, + hypothesis, + context, + lprobs, + bsz, + step, + beam_size, + no_repeat_ngram_size, + if_context_blocking, + ): """ Args: - tokens(Tensor): Input tokens(Bsz*beam, seq_len) - lprobs(Tensor): likelihood probability - Expected to be updated in place.(Bsz*beam, vocab_size) + hypothesis(Tensor): (beam*bsz, current_sequence_length) + context(Tensor): context for context-blocking + lprobs(Tensor): likelihood probability(beam, vocab_size) bsz(int): batch size step(int): current step beam_size(int): beam size no_repeat_ngram_size(int): Ngram size + if_context_blocking(bool): whether to use context-blocking """ outputs = ngram_repeat_block_cuda.forward( - tokens, lprobs, bsz, step, beam_size, no_repeat_ngram_size + hypothesis, + context, + lprobs, + bsz, + step, + beam_size, + no_repeat_ngram_size, + if_context_blocking, ) return outputs @@ -54,20 +73,43 @@ def __init__(self): def reset_parameters(self): pass - def forward(self, tokens, lprobs, bsz, step, beam_size, no_repeat_ngram_size): + def forward( + self, + hypothesis, + context, + lprobs, + bsz, + step, + beam_size, + no_repeat_ngram_size, + if_context_blocking=False, + ): """ Args: - tokens(Tensor): Input tokens(beam, current_sequence_length) + hypothesis(Tensor): (beam*bsz, current_sequence_length) + context(Tensor): context for context-blocking lprobs(Tensor): likelihood probability(beam, vocab_size) bsz(int): batch size step(int): current step beam_size(int): beam size no_repeat_ngram_size(int): Ngram size + if_context_blocking(bool): whether to use context-blocking """ - assert tokens.size(0) == bsz * beam_size + # placeholder tensor to pass in to pass type check, won't be used + if not if_context_blocking: + context = torch.Tensor([0]).long() + assert hypothesis.size(0) == bsz * beam_size assert lprobs.size(0) == bsz * beam_size - tokens = tokens.contiguous() + hypothesis = hypothesis.contiguous() + context = context.contiguous() lprobs = lprobs.contiguous() return NGramRepeatBlockFunction.apply( - tokens, lprobs, bsz, step, beam_size, no_repeat_ngram_size + hypothesis, + context, + lprobs, + bsz, + step, + beam_size, + no_repeat_ngram_size, + if_context_blocking, ) From 189d35b07a18d68697bc9fa79bf8443d9bacd71f Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Mon, 27 Jun 2022 09:16:11 -0700 Subject: [PATCH 05/38] change load paths --- parlai/ops/ngram_repeat_block.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/parlai/ops/ngram_repeat_block.py b/parlai/ops/ngram_repeat_block.py index 4340d0e180a..ab3e562279d 100644 --- a/parlai/ops/ngram_repeat_block.py +++ b/parlai/ops/ngram_repeat_block.py @@ -13,8 +13,8 @@ ngram_repeat_block_cuda = load( name='ngram_repeat_block_cuda', sources=[ - 'parlai/clib/cuda/ngram_repeat_block_cuda.cpp', - 'parlai/clib/cuda/ngram_repeat_block_cuda_kernel.cu', + '../../parlai/clib/cuda/ngram_repeat_block_cuda.cpp', + '../../parlai/clib/cuda/ngram_repeat_block_cuda_kernel.cu', ], ) From 9ccf003fd019d4529786a03ae618795942a66203 Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Mon, 27 Jun 2022 09:29:13 -0700 Subject: [PATCH 06/38] add ninja to requirement --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 643cb7bbe7a..00b1be69c3a 100644 --- 
a/requirements.txt +++ b/requirements.txt @@ -56,3 +56,4 @@ jsonlines==1.2.0 numpy<=1.21 # Used to be `==1.17.5` before but tests -- pulling in latest at 1.22 not happy markdown<=3.3.2 # Pin to something that works so tests are happy jinja2==3.0.3 +ninja==1.10.2.3 From e864f4996632af145d86005a8ad7f6cc66f663a5 Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Mon, 27 Jun 2022 10:28:22 -0700 Subject: [PATCH 07/38] modify setup script to install kernel ahead of time --- parlai/ops/ngram_repeat_block.py | 19 ++++++++++--------- setup.py | 12 ++++++++++++ 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/parlai/ops/ngram_repeat_block.py b/parlai/ops/ngram_repeat_block.py index ab3e562279d..f9e6774f402 100644 --- a/parlai/ops/ngram_repeat_block.py +++ b/parlai/ops/ngram_repeat_block.py @@ -7,16 +7,17 @@ from torch import nn from torch.autograd import Function -# import ngram_repeat_block_cuda -from torch.utils.cpp_extension import load +import ngram_repeat_block_cuda -ngram_repeat_block_cuda = load( - name='ngram_repeat_block_cuda', - sources=[ - '../../parlai/clib/cuda/ngram_repeat_block_cuda.cpp', - '../../parlai/clib/cuda/ngram_repeat_block_cuda_kernel.cu', - ], -) +# from torch.utils.cpp_extension import load + +# ngram_repeat_block_cuda = load( +# name='ngram_repeat_block_cuda', +# sources=[ +# 'parlai/clib/cuda/ngram_repeat_block_cuda.cpp', +# 'parlai/clib/cuda/ngram_repeat_block_cuda_kernel.cu', +# ], +# ) class NGramRepeatBlockFunction(Function): diff --git a/setup.py b/setup.py index 0d88b17edaa..a38ccb6e932 100644 --- a/setup.py +++ b/setup.py @@ -8,6 +8,7 @@ import sys from setuptools import setup, find_packages +from torch.utils.cpp_extension import BuildExtension, CUDAExtension VERSION = '1.6.0' # if you update, update parlai/__init__.py too! 
@@ -26,6 +27,15 @@ if __name__ == '__main__': + extensions = [ + CUDAExtension( + 'ngram_repeat_block_cuda', + [ + 'parlai/clib/cuda/ngram_repeat_block_cuda.cpp', + 'parlai/clib/cuda/ngram_repeat_block_cuda_kernel.cu', + ], + ), + ] setup( name='parlai', version=VERSION, @@ -38,10 +48,12 @@ install_requires=reqs, include_package_data=True, package_data={'': ['*.txt', '*.md', '*.opt']}, + ext_modules=extensions, entry_points={ "flake8.extension": ["PAI = parlai.utils.flake8:ParlAIChecker"], "console_scripts": ["parlai=parlai.__main__:main"], }, + cmdclass={'build_ext': BuildExtension}, classifiers=[ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", From 49c9bfb1de9305d380c73f7ef2686b1f59de1f41 Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Mon, 27 Jun 2022 11:09:17 -0700 Subject: [PATCH 08/38] change circleci test to use gpu to build website --- .circleci/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 5ab664abdbf..ee1450f69de 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -375,7 +375,7 @@ jobs: pytest_flags: -v -s build_website: - executor: small_cpu38 + executor: gpu_small working_directory: ~/ParlAI parallelism: 1 steps: From 1d2448ea9b8f73ee50003f565ef7924e5baac74e Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Mon, 27 Jun 2022 11:46:12 -0700 Subject: [PATCH 09/38] change back to JIT, switch directory when loadingcuda moddule --- .circleci/config.yml | 2 +- parlai/ops/ngram_repeat_block.py | 27 ++++++++++++++++++--------- setup.py | 25 +++++++++++++------------ 3 files changed, 32 insertions(+), 22 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index ee1450f69de..5ab664abdbf 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -375,7 +375,7 @@ jobs: pytest_flags: -v -s build_website: - executor: gpu_small + executor: small_cpu38 working_directory: ~/ParlAI parallelism: 1 steps: diff --git a/parlai/ops/ngram_repeat_block.py b/parlai/ops/ngram_repeat_block.py index f9e6774f402..02d32b7a580 100644 --- a/parlai/ops/ngram_repeat_block.py +++ b/parlai/ops/ngram_repeat_block.py @@ -7,17 +7,26 @@ from torch import nn from torch.autograd import Function -import ngram_repeat_block_cuda -# from torch.utils.cpp_extension import load +import os -# ngram_repeat_block_cuda = load( -# name='ngram_repeat_block_cuda', -# sources=[ -# 'parlai/clib/cuda/ngram_repeat_block_cuda.cpp', -# 'parlai/clib/cuda/ngram_repeat_block_cuda_kernel.cu', -# ], -# ) +current = os.getcwd() +abspath = os.path.abspath(__file__) +dname = os.path.dirname(abspath) +os.chdir(dname) + +# import ngram_repeat_block_cuda + +from torch.utils.cpp_extension import load + +ngram_repeat_block_cuda = load( + name='ngram_repeat_block_cuda', + sources=[ + '../clib/cuda/ngram_repeat_block_cuda.cpp', + '../clib/cuda/ngram_repeat_block_cuda_kernel.cu', + ], +) +os.chdir(current) class NGramRepeatBlockFunction(Function): diff --git a/setup.py b/setup.py index a38ccb6e932..47e68b87ebd 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,8 @@ import sys from setuptools import setup, find_packages -from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +# from torch.utils.cpp_extension import BuildExtension, CUDAExtension VERSION = '1.6.0' # if you update, update parlai/__init__.py too! 
@@ -27,15 +28,15 @@ if __name__ == '__main__': - extensions = [ - CUDAExtension( - 'ngram_repeat_block_cuda', - [ - 'parlai/clib/cuda/ngram_repeat_block_cuda.cpp', - 'parlai/clib/cuda/ngram_repeat_block_cuda_kernel.cu', - ], - ), - ] + # extensions = [ + # CUDAExtension( + # 'ngram_repeat_block_cuda', + # [ + # 'parlai/clib/cuda/ngram_repeat_block_cuda.cpp', + # 'parlai/clib/cuda/ngram_repeat_block_cuda_kernel.cu', + # ], + # ), + # ] setup( name='parlai', version=VERSION, @@ -48,12 +49,12 @@ install_requires=reqs, include_package_data=True, package_data={'': ['*.txt', '*.md', '*.opt']}, - ext_modules=extensions, + # ext_modules=extensions, entry_points={ "flake8.extension": ["PAI = parlai.utils.flake8:ParlAIChecker"], "console_scripts": ["parlai=parlai.__main__:main"], }, - cmdclass={'build_ext': BuildExtension}, + # cmdclass={'build_ext': BuildExtension}, classifiers=[ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", From 0263c8f319ba1a6e099225d430489960c16f3ec8 Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Mon, 27 Jun 2022 12:05:00 -0700 Subject: [PATCH 10/38] add check for cuda --- parlai/core/torch_generator_agent.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index 1ce435d2f15..e42442a9b1a 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -43,7 +43,9 @@ trainable_parameters, PipelineHelper, ) -from parlai.ops.ngram_repeat_block import NGramRepeatBlock + +if torch.cuda.is_available(): + from parlai.ops.ngram_repeat_block import NGramRepeatBlock class SearchBlocklist(object): @@ -1764,6 +1766,7 @@ class BeamSearch(TreeSearch): def __init__(self, gpu_beam_blocking, *args, **kwargs): super().__init__(*args, **kwargs) self.gpu_beam_blocking = gpu_beam_blocking + print('using gpu beam blocking!') def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection: """ From cbca0926f4c11a01bead344105dc21e94fac0496 Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Mon, 27 Jun 2022 12:10:42 -0700 Subject: [PATCH 11/38] get rid of ninja --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 00b1be69c3a..643cb7bbe7a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -56,4 +56,3 @@ jsonlines==1.2.0 numpy<=1.21 # Used to be `==1.17.5` before but tests -- pulling in latest at 1.22 not happy markdown<=3.3.2 # Pin to something that works so tests are happy jinja2==3.0.3 -ninja==1.10.2.3 From f75dc56ccb924f20ff0666a8c8e94d3a871bcc61 Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Mon, 27 Jun 2022 12:37:23 -0700 Subject: [PATCH 12/38] remove unused param --- parlai/core/torch_generator_agent.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index e42442a9b1a..312f61803c5 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -965,7 +965,7 @@ def _treesearch_factory(self, device, verbose=False): ) elif method == 'beam': return BeamSearch( - self.opt['gpu_beam_blocking'], + self.opt.get('gpu_beam_blocking', False), beam_size, min_length=self.beam_min_length, block_ngram=self.beam_block_ngram, @@ -1368,7 +1368,8 @@ def __init__( self.eos_top_ts = None self.n_best_counter = 0 self.partial_hyps = torch.tensor([[self.bos] for i in range(beam_size)]) - self.no_repeat_ngram_op = NGramRepeatBlock() + if 
torch.cuda.is_available(): + self.no_repeat_ngram_op = NGramRepeatBlock() def set_context(self: TSType, context: torch.LongTensor) -> TSType: """ @@ -1444,7 +1445,6 @@ def _block_ngrams( ngram_size: int, logprobs: torch.Tensor, step: int, - GPU_BEAM_BLOCKING: bool = True, if_context_blocking=False, ): """ @@ -1460,7 +1460,7 @@ def _block_ngrams( hypothesis (i.e. self-blocking). """ context = None - if self.gpu_beam_blocking: + if self.gpu_beam_blocking == True: if if_context_blocking: if not self.context.is_cuda: self.context = self.context.cuda() @@ -1763,10 +1763,9 @@ class BeamSearch(TreeSearch): Beam search. """ - def __init__(self, gpu_beam_blocking, *args, **kwargs): + def __init__(self, gpu_beam_blocking=False, *args, **kwargs): super().__init__(*args, **kwargs) self.gpu_beam_blocking = gpu_beam_blocking - print('using gpu beam blocking!') def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection: """ From c03c0d2e0d6cbb65e478d404e9dea127e6d34e46 Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Mon, 27 Jun 2022 12:49:32 -0700 Subject: [PATCH 13/38] move "hyps to cuda" into _block_ngrams() --- parlai/core/torch_generator_agent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index 312f61803c5..eabc8b713d1 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -1461,6 +1461,8 @@ def _block_ngrams( """ context = None if self.gpu_beam_blocking == True: + if not self.partial_hyps.is_cuda: + self.partial_hyps = self.partial_hyps.cuda() if if_context_blocking: if not self.context.is_cuda: self.context = self.context.cuda() @@ -1523,8 +1525,6 @@ def advance(self, logprobs, step): self.scores[hyp_id] = neginf(self.scores.dtype) # beam blocking - if not self.partial_hyps.is_cuda: - self.partial_hyps = self.partial_hyps.cuda() if self.block_ngram > 0: # self blocking From aae7127f31744a24b6a8ea758958c6a9ded52a13 Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Mon, 27 Jun 2022 13:34:57 -0700 Subject: [PATCH 14/38] set gpu_beam_blocking as attribute for TreeSearch, modify block_list function to cast into list, set current ngram_size for context blocking, move path to cpu when needed --- parlai/core/torch_generator_agent.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index eabc8b713d1..071f32873a5 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -962,10 +962,10 @@ def _treesearch_factory(self, device, verbose=False): eos_token=self.END_IDX, device=device, verbose=verbose, + gpu_beam_blocking=self.opt.get('gpu_beam_blocking', False), ) elif method == 'beam': return BeamSearch( - self.opt.get('gpu_beam_blocking', False), beam_size, min_length=self.beam_min_length, block_ngram=self.beam_block_ngram, @@ -976,6 +976,7 @@ def _treesearch_factory(self, device, verbose=False): eos_token=self.END_IDX, device=device, verbose=verbose, + gpu_beam_blocking=self.opt.get('gpu_beam_blocking', False), ) elif method == 'delayedbeam': return DelayedBeamSearch( @@ -991,6 +992,7 @@ def _treesearch_factory(self, device, verbose=False): eos_token=self.END_IDX, device=device, verbose=verbose, + gpu_beam_blocking=self.opt.get('gpu_beam_blocking', False), ) elif method == 'topk': return TopKSampling( @@ -1005,6 +1007,7 @@ def _treesearch_factory(self, device, verbose=False): eos_token=self.END_IDX, 
device=device, verbose=verbose, + gpu_beam_blocking=self.opt.get('gpu_beam_blocking', False), ) elif method == 'nucleus': return NucleusSampling( @@ -1019,6 +1022,7 @@ def _treesearch_factory(self, device, verbose=False): eos_token=self.END_IDX, device=device, verbose=verbose, + gpu_beam_blocking=self.opt.get('gpu_beam_blocking', False), ) else: raise ValueError(f"Can't use inference method {method}") @@ -1311,6 +1315,7 @@ def __init__( device='cpu', length_penalty=0.65, verbose=False, + gpu_beam_blocking=False, ): """ Instantiate Beam object. @@ -1370,6 +1375,7 @@ def __init__( self.partial_hyps = torch.tensor([[self.bos] for i in range(beam_size)]) if torch.cuda.is_available(): self.no_repeat_ngram_op = NGramRepeatBlock() + self.gpu_beam_blocking = gpu_beam_blocking def set_context(self: TSType, context: torch.LongTensor) -> TSType: """ @@ -1501,7 +1507,7 @@ def _block_block_list(self, logprobs: torch.Tensor) -> torch.Tensor: for ngram_size, bad_ngrams in self.block_list.items(): prefix = hyp[-(ngram_size - 1) :] for ngram in bad_ngrams: - if (ngram_size == 1) or prefix == list(ngram[:-1]): + if (ngram_size == 1) or list(prefix) == list(ngram[:-1]): logprobs[beam_id][ngram[-1]] = neginf(logprobs.dtype) return logprobs @@ -1544,7 +1550,7 @@ def advance(self, logprobs, step): ) # context blocking logprobs = self._block_ngrams( - ngram_size=self.block_ngram, + ngram_size=self.context_block_ngram, logprobs=logprobs, step=step, if_context_blocking=True, @@ -1559,10 +1565,16 @@ def advance(self, logprobs, step): self.outputs.append(path_selection.token_ids) self.bookkeep.append(path_selection.hypothesis_ids) + # this checking for device seems suboptimal + # might need to change later + if not self.gpu_beam_blocking: + hyp_device = 'cpu' self.partial_hyps = torch.cat( ( self.partial_hyps[path_selection.hypothesis_ids.long()], - path_selection.token_ids.view(path_selection.token_ids.shape[0], -1), + path_selection.token_ids.view(path_selection.token_ids.shape[0], -1).to( + hyp_device + ), ), 1, ) @@ -1763,10 +1775,6 @@ class BeamSearch(TreeSearch): Beam search. """ - def __init__(self, gpu_beam_blocking=False, *args, **kwargs): - super().__init__(*args, **kwargs) - self.gpu_beam_blocking = gpu_beam_blocking - def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection: """ Select the next vocabulary item in these beams. From f4d6bf1ed4dafcdf798dbafd58c3c6248c877a2b Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Mon, 27 Jun 2022 13:38:12 -0700 Subject: [PATCH 15/38] fix lint formatting issues --- parlai/ops/ngram_repeat_block.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/parlai/ops/ngram_repeat_block.py b/parlai/ops/ngram_repeat_block.py index 02d32b7a580..f311b82f526 100644 --- a/parlai/ops/ngram_repeat_block.py +++ b/parlai/ops/ngram_repeat_block.py @@ -2,7 +2,9 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -""" Wrapper for ngram_repeat_block cuda extension """ +""" +Wrapper for ngram_repeat_block cuda extension. +""" import torch from torch import nn from torch.autograd import Function @@ -31,9 +33,7 @@ class NGramRepeatBlockFunction(Function): """ - forward inputs to ngram_repeat_block cuda extension - backward method not needed. - + forward inputs to ngram_repeat_block cuda extension backward method not needed. 
""" def forward( @@ -75,7 +75,9 @@ def backward(*args): class NGramRepeatBlock(nn.Module): - """Wrapper class for calling ngram_repeat_block cuda extension""" + """ + Wrapper class for calling ngram_repeat_block cuda extension. + """ def __init__(self): super(NGramRepeatBlock, self).__init__() From 7076ff439809a20956683b1269d43c1abc6c99f2 Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Mon, 27 Jun 2022 13:53:16 -0700 Subject: [PATCH 16/38] add init file to new folders --- parlai/clib/__init__.py | 5 +++++ parlai/clib/cuda/__init__.py | 5 +++++ parlai/ops/__init__.py | 5 +++++ 3 files changed, 15 insertions(+) create mode 100644 parlai/clib/__init__.py create mode 100644 parlai/clib/cuda/__init__.py create mode 100644 parlai/ops/__init__.py diff --git a/parlai/clib/__init__.py b/parlai/clib/__init__.py new file mode 100644 index 00000000000..240697e3247 --- /dev/null +++ b/parlai/clib/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/parlai/clib/cuda/__init__.py b/parlai/clib/cuda/__init__.py new file mode 100644 index 00000000000..240697e3247 --- /dev/null +++ b/parlai/clib/cuda/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/parlai/ops/__init__.py b/parlai/ops/__init__.py new file mode 100644 index 00000000000..240697e3247 --- /dev/null +++ b/parlai/ops/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. From 05b81b6299dd15607367d100783990900319dd50 Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Mon, 27 Jun 2022 13:56:18 -0700 Subject: [PATCH 17/38] add new line at end of file --- parlai/clib/cuda/ngram_repeat_block_cuda.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parlai/clib/cuda/ngram_repeat_block_cuda.cpp b/parlai/clib/cuda/ngram_repeat_block_cuda.cpp index e13bb84c3cf..f7f2044a3be 100644 --- a/parlai/clib/cuda/ngram_repeat_block_cuda.cpp +++ b/parlai/clib/cuda/ngram_repeat_block_cuda.cpp @@ -47,4 +47,4 @@ torch::Tensor ngram_repeat_block_forward(const torch::Tensor hypothesis, const t PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("forward", &ngram_repeat_block_forward, "No Repeat Ngram Block forward (CUDA)"); -} \ No newline at end of file +} From d6d7ab68138d2d21ffcaf6e83715da11cba36d3b Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Mon, 27 Jun 2022 14:02:12 -0700 Subject: [PATCH 18/38] new lint errors --- parlai/core/torch_generator_agent.py | 2 +- parlai/ops/ngram_repeat_block.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index 071f32873a5..74976ce506d 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -1466,7 +1466,7 @@ def _block_ngrams( hypothesis (i.e. self-blocking). 
""" context = None - if self.gpu_beam_blocking == True: + if self.gpu_beam_blocking: if not self.partial_hyps.is_cuda: self.partial_hyps = self.partial_hyps.cuda() if if_context_blocking: diff --git a/parlai/ops/ngram_repeat_block.py b/parlai/ops/ngram_repeat_block.py index f311b82f526..762c27a1636 100644 --- a/parlai/ops/ngram_repeat_block.py +++ b/parlai/ops/ngram_repeat_block.py @@ -1,3 +1,5 @@ +#!/usr/bin/env python3 + # Copyright (c) Facebook, Inc. and its affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. @@ -9,8 +11,8 @@ from torch import nn from torch.autograd import Function - import os +from torch.utils.cpp_extension import load current = os.getcwd() abspath = os.path.abspath(__file__) @@ -19,8 +21,6 @@ # import ngram_repeat_block_cuda -from torch.utils.cpp_extension import load - ngram_repeat_block_cuda = load( name='ngram_repeat_block_cuda', sources=[ From 6e17d5f86671c35d5143c8f3c02c11dcb075c782 Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Mon, 27 Jun 2022 14:55:34 -0700 Subject: [PATCH 19/38] add ninja --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 643cb7bbe7a..777c7f609b5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -56,3 +56,4 @@ jsonlines==1.2.0 numpy<=1.21 # Used to be `==1.17.5` before but tests -- pulling in latest at 1.22 not happy markdown<=3.3.2 # Pin to something that works so tests are happy jinja2==3.0.3 +ninja From 538cd6268aca7bfcc93753581b3bf5ecacfd4c5f Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Mon, 27 Jun 2022 15:16:33 -0700 Subject: [PATCH 20/38] set protobuf --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 777c7f609b5..e259b84d5da 100644 --- a/requirements.txt +++ b/requirements.txt @@ -57,3 +57,4 @@ numpy<=1.21 # Used to be `==1.17.5` before but tests -- pulling in latest at 1.2 markdown<=3.3.2 # Pin to something that works so tests are happy jinja2==3.0.3 ninja +protobuf~=3.20 From b5b1df19b9399bc7e213de9bec4f28a61722876b Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Mon, 27 Jun 2022 21:58:49 -0700 Subject: [PATCH 21/38] cast tensor to list in to pass gpu tests --- projects/light_whoami/agents/pacer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/projects/light_whoami/agents/pacer.py b/projects/light_whoami/agents/pacer.py index f8014d7655c..db34a4e1124 100644 --- a/projects/light_whoami/agents/pacer.py +++ b/projects/light_whoami/agents/pacer.py @@ -264,7 +264,8 @@ def modify_logprobs(self, logprobs: torch.Tensor) -> torch.Tensor: h for i in range(len(self.partial_hyps)) for h in [ - self.agent._v2t(self.partial_hyps[i][1:] + [ind]) for ind in inds[i] + self.agent._v2t(self.partial_hyps[i][1:].tolist() + [ind]) + for ind in inds[i] ] ] # Classify all beam outputs From e4b78c7731cdc224499ae26ed1a7a70d468e0361 Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Mon, 27 Jun 2022 22:39:26 -0700 Subject: [PATCH 22/38] debug long gpu tests --- parlai/core/torch_generator_agent.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index 74976ce506d..33caa4ec8f7 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -1496,7 +1496,7 @@ def _block_ngrams( # when doing context blocking, ngram is tuple where prefix is tensor # need to cast both into lists for comparison if ngram_size == 1 or 
list(prefix) == list(ngram[:-1]): - logprobs[beam_id][ngram[-1].long()] = neginf(logprobs.dtype) + logprobs[beam_id][ngram[-1]] = neginf(logprobs.dtype) return logprobs def _block_block_list(self, logprobs: torch.Tensor) -> torch.Tensor: @@ -1569,6 +1569,8 @@ def advance(self, logprobs, step): # might need to change later if not self.gpu_beam_blocking: hyp_device = 'cpu' + else: + hyp_device = 'cuda' self.partial_hyps = torch.cat( ( self.partial_hyps[path_selection.hypothesis_ids.long()], From 7a80cfb70712187575744f2ac5a32a3bcd15d70f Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Tue, 28 Jun 2022 10:50:37 -0700 Subject: [PATCH 23/38] fix pointer bug in kernel code and change ngram_size param --- parlai/clib/cuda/ngram_repeat_block_cuda_kernel.cu | 4 ++-- parlai/core/torch_generator_agent.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/parlai/clib/cuda/ngram_repeat_block_cuda_kernel.cu b/parlai/clib/cuda/ngram_repeat_block_cuda_kernel.cu index 0ca4f231e96..d8feffc1198 100644 --- a/parlai/clib/cuda/ngram_repeat_block_cuda_kernel.cu +++ b/parlai/clib/cuda/ngram_repeat_block_cuda_kernel.cu @@ -28,7 +28,7 @@ __global__ void banRepeatedTokens(long* __restrict__ hypothesis_ptr, auto index = row * current_seq_length + col; // start of last ngram of hypothesis auto start_of_ngram = current_seq_length - no_repeat_ngram_size + 1; - long* previous_ngram_ptr; + long* __restrict__ previous_ngram_ptr; if (if_context_blocking) { previous_ngram_ptr = &context_ptr[col]; @@ -44,7 +44,7 @@ __global__ void banRepeatedTokens(long* __restrict__ hypothesis_ptr, // final thread writes the end of previous ngram array to tokens_shm if (col == blockDim.x - 1) { for (int i=1; i Date: Tue, 28 Jun 2022 13:35:33 -0700 Subject: [PATCH 24/38] add gpu unit tests and fix torch warning --- parlai/core/torch_generator_agent.py | 3 +- tests/test_transformers.py | 84 +++++++++++++++++++++++++++- 2 files changed, 84 insertions(+), 3 deletions(-) diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index 1ce141570ac..915409cb704 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -1792,7 +1792,8 @@ def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection voc_size = logprobs.size(-1) # get the backtracking hypothesis id as a multiple of full voc_sizes - hyp_ids = best_idxs // voc_size + # hyp_ids = best_idxs // voc_size + hyp_ids = torch.div(best_idxs, voc_size, rounding_mode='trunc') # get the actual word id from residual of the same division tok_ids = best_idxs % voc_size diff --git a/tests/test_transformers.py b/tests/test_transformers.py index f4482bc2b68..ecaeb1282e4 100644 --- a/tests/test_transformers.py +++ b/tests/test_transformers.py @@ -379,7 +379,7 @@ def test_beamsearch_return_all_texts(self): self.assertEqual(len(response["beam_texts"]), size) @pytest.mark.nofbcode - def test_beamsearch_blocking(self): + def test_beamsearch_blocking_cpu(self): """ Test beamsearch blocking. """ @@ -422,7 +422,55 @@ def test_beamsearch_blocking(self): assert '34 34' not in text @pytest.mark.nofbcode - def test_beamsearch_contextblocking(self): + def test_beamsearch_blocking_gpu(self): + """ + Test beamsearch blocking. 
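+
+        Mirrors the CPU variant above, but passes gpu_beam_blocking=True so
+        the CUDA ngram-blocking kernel path is exercised.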
+ """ + with testing_utils.tempdir() as tmpdir: + agent = create_agent_from_model_file('zoo:unittest/beam_blocking/model') + agent.observe({'text': '5 5 5 5 5 5 5', 'episode_done': True}) + assert agent.act()['text'] == '5 5 5 5 5 5 5' + + agent = create_agent_from_model_file( + 'zoo:unittest/beam_blocking/model', + Opt(beam_block_ngram=1, gpu_beam_blocking=True), + ) + agent.observe({'text': '5 5 5 5 5 5 5', 'episode_done': True}) + assert '5 5' not in agent.act()['text'] + + agent = create_agent_from_model_file( + 'zoo:unittest/beam_blocking/model', + Opt(beam_block_ngram=2, gpu_beam_blocking=True), + ) + agent.observe({'text': '5 5 5 5 5 5 5', 'episode_done': True}) + assert '5 5 5' not in agent.act()['text'] + + with open(os.path.join(tmpdir, 'blocklist.txt'), 'w') as f: + f.write("38\n62\n34 34\n") + + agent = create_agent_from_model_file( + 'zoo:unittest/beam_blocking/model', + Opt( + beam_block_list_filename=os.path.join(tmpdir, 'blocklist.txt'), + gpu_beam_blocking=True, + ), + ) + agent.observe({'text': '4 4 4', 'episode_done': True}) + assert agent.act()['text'] == '4 4 4' + + agent.observe({'text': '38 38 38', 'episode_done': True}) + assert '38' not in agent.act()['text'] + + agent.observe({'text': '62 62 62', 'episode_done': True}) + assert '62' not in agent.act()['text'] + + agent.observe({'text': '34 34 34', 'episode_done': True}) + text = agent.act()['text'] + assert '34' in text + assert '34 34' not in text + + @pytest.mark.nofbcode + def test_beamsearch_contextblocking_cpu(self): """ Test beamsearch context blocking. """ @@ -451,6 +499,38 @@ def test_beamsearch_contextblocking(self): assert '4 3' not in text assert '3 2' not in text + @pytest.mark.nofbcode + def test_beamsearch_contextblocking_gpu(self): + """ + Test beamsearch context blocking. + """ + + agent = create_agent_from_model_file('zoo:unittest/context_blocking/model') + agent.observe({'text': '5 4 3 2', 'episode_done': True}) + assert agent.act()['text'] == '5 4 3 2' + + agent = create_agent_from_model_file( + 'zoo:unittest/context_blocking/model', + Opt(beam_context_block_ngram=1, gpu_beam_blocking=True), + ) + agent.observe({'text': '5 4 3 2', 'episode_done': True}) + text = agent.act()['text'] + assert '5' not in text + assert '4' not in text + assert '3' not in text + assert '2' not in text + + agent = create_agent_from_model_file( + 'zoo:unittest/context_blocking/model', + Opt(beam_context_block_ngram=2, gpu_beam_blocking=True), + ) + agent.observe({'text': '5 4 3 2', 'episode_done': True}) + text = agent.act()['text'] + assert '5' in text + assert '5 4' not in text + assert '4 3' not in text + assert '3 2' not in text + def test_nucleus(self): """ Test nucleus generation. From 44b879f9efc8ea3c66a12aa4e4d8312432f5cddf Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Tue, 28 Jun 2022 13:55:20 -0700 Subject: [PATCH 25/38] skip gpu test unless cuda enabled --- tests/test_transformers.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/test_transformers.py b/tests/test_transformers.py index ecaeb1282e4..3f7528da236 100644 --- a/tests/test_transformers.py +++ b/tests/test_transformers.py @@ -422,12 +422,16 @@ def test_beamsearch_blocking_cpu(self): assert '34 34' not in text @pytest.mark.nofbcode + @testing_utils.skipUnlessGPU def test_beamsearch_blocking_gpu(self): """ Test beamsearch blocking. 
""" with testing_utils.tempdir() as tmpdir: - agent = create_agent_from_model_file('zoo:unittest/beam_blocking/model') + agent = create_agent_from_model_file( + 'zoo:unittest/beam_blocking/model', + Opt(gpu_beam_blocking=True), + ) agent.observe({'text': '5 5 5 5 5 5 5', 'episode_done': True}) assert agent.act()['text'] == '5 5 5 5 5 5 5' @@ -500,12 +504,16 @@ def test_beamsearch_contextblocking_cpu(self): assert '3 2' not in text @pytest.mark.nofbcode + @testing_utils.skipUnlessGPU def test_beamsearch_contextblocking_gpu(self): """ Test beamsearch context blocking. """ - agent = create_agent_from_model_file('zoo:unittest/context_blocking/model') + agent = create_agent_from_model_file( + 'zoo:unittest/context_blocking/model', + Opt(gpu_beam_blocking=True), + ) agent.observe({'text': '5 4 3 2', 'episode_done': True}) assert agent.act()['text'] == '5 4 3 2' From acdafcfd5a81d8e0116273a24826ae22f2465840 Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Tue, 28 Jun 2022 15:29:32 -0700 Subject: [PATCH 26/38] use tolist() for conversion --- parlai/core/torch_generator_agent.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index 915409cb704..fe92d0cf9e6 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -1450,7 +1450,7 @@ def _block_ngrams( self, ngram_size: int, logprobs: torch.Tensor, - step: int, + step: int = 0, if_context_blocking=False, ): """ @@ -1495,7 +1495,7 @@ def _block_ngrams( for ngram in ngrams: # when doing context blocking, ngram is tuple where prefix is tensor # need to cast both into lists for comparison - if ngram_size == 1 or list(prefix) == list(ngram[:-1]): + if ngram_size == 1 or prefix.tolist() == list(ngram[:-1]): logprobs[beam_id][ngram[-1]] = neginf(logprobs.dtype) return logprobs @@ -1507,7 +1507,7 @@ def _block_block_list(self, logprobs: torch.Tensor) -> torch.Tensor: for ngram_size, bad_ngrams in self.block_list.items(): prefix = hyp[-(ngram_size - 1) :] for ngram in bad_ngrams: - if (ngram_size == 1) or list(prefix) == list(ngram[:-1]): + if (ngram_size == 1) or prefix.tolist() == list(ngram[:-1]): logprobs[beam_id][ngram[-1]] = neginf(logprobs.dtype) return logprobs From 9af834a435bcefb9bd2a049219fe078b7e62e9fd Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Wed, 29 Jun 2022 10:32:59 -0700 Subject: [PATCH 27/38] get rid of context's conversion to list, add check data before kernel code --- parlai/clib/cuda/ngram_repeat_block_cuda.cpp | 3 +++ parlai/core/torch_generator_agent.py | 18 ++++++------------ parlai/ops/ngram_repeat_block.py | 1 + 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/parlai/clib/cuda/ngram_repeat_block_cuda.cpp b/parlai/clib/cuda/ngram_repeat_block_cuda.cpp index f7f2044a3be..cd78c82b27e 100644 --- a/parlai/clib/cuda/ngram_repeat_block_cuda.cpp +++ b/parlai/clib/cuda/ngram_repeat_block_cuda.cpp @@ -34,6 +34,9 @@ torch::Tensor ngram_repeat_block_forward(const torch::Tensor hypothesis, const t int no_repeat_ngram_size, bool if_context_blocking) { CHECK_INPUT(hypothesis); + if(if_context_blocking) { + CHECK_INPUT(context); + } CHECK_INPUT(lprobs); assert(bsz > 0); assert(step >= 0); diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index fe92d0cf9e6..fb11810293e 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -1153,7 +1153,7 @@ def _generate( bsz = batch.batchsize if batch.text_vec is not None: batchsize = 
batch.batchsize - batch_context_list = self._get_batch_context(batch).tolist() + batch_context_list = self._get_batch_context(batch) beams = [ self._treesearch_factory(dev, verbose=self.show_token_details) .set_batch_context(batch_context_list, batch_idx) @@ -1372,7 +1372,7 @@ def __init__( self.eos_top = False self.eos_top_ts = None self.n_best_counter = 0 - self.partial_hyps = torch.tensor([[self.bos] for i in range(beam_size)]) + self.partial_hyps = torch.full((self.beam_size, 1), self.bos) if torch.cuda.is_available(): self.no_repeat_ngram_op = NGramRepeatBlock() self.gpu_beam_blocking = gpu_beam_blocking @@ -1385,11 +1385,11 @@ def set_context(self: TSType, context: torch.LongTensor) -> TSType: a LongTensor representing the input context; used for context ngram blocking, if supplied """ - self.context = torch.Tensor(context.tolist()).long() + self.context = context return self def set_batch_context( - self: TSType, batch_context_list: List[List[int]], batch_idx: int + self: TSType, batch_context_list: torch.LongTensor, batch_idx: int ) -> TSType: """ Version of .set_context() that operates on a single element of a batch. @@ -1401,7 +1401,7 @@ def set_batch_context( :param batch_idx: index of the batch """ - self.context = torch.Tensor(batch_context_list[batch_idx]).long() + self.context = batch_context_list[batch_idx] return self def set_block_list(self: TSType, block_list: Optional[SearchBlocklist]) -> TSType: @@ -1465,18 +1465,12 @@ def _block_ngrams( Source text to grab ngrams from. If None, it uses the current hypothesis (i.e. self-blocking). """ - context = None if self.gpu_beam_blocking: if not self.partial_hyps.is_cuda: self.partial_hyps = self.partial_hyps.cuda() - if if_context_blocking: - if not self.context.is_cuda: - self.context = self.context.cuda() - context = self.context - logprobs = self.no_repeat_ngram_op( hypothesis=self.partial_hyps, - context=context, + context=self.context, lprobs=logprobs, bsz=1, step=step, diff --git a/parlai/ops/ngram_repeat_block.py b/parlai/ops/ngram_repeat_block.py index 762c27a1636..2c0f1df0e0c 100644 --- a/parlai/ops/ngram_repeat_block.py +++ b/parlai/ops/ngram_repeat_block.py @@ -108,6 +108,7 @@ def forward( if_context_blocking(bool): whether to use context-blocking """ # placeholder tensor to pass in to pass type check, won't be used + # TODO: find a better way to do this? if not if_context_blocking: context = torch.Tensor([0]).long() assert hypothesis.size(0) == bsz * beam_size From 48d9d9701f64fd92d438b315a47ff038e05d1c8d Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Wed, 29 Jun 2022 12:33:03 -0700 Subject: [PATCH 28/38] Revert "get rid of context's conversion to list, add check data before kernel code" This reverts commit 9af834a435bcefb9bd2a049219fe078b7e62e9fd. 
--- parlai/clib/cuda/ngram_repeat_block_cuda.cpp | 3 --- parlai/core/torch_generator_agent.py | 18 ++++++++++++------ parlai/ops/ngram_repeat_block.py | 1 - 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/parlai/clib/cuda/ngram_repeat_block_cuda.cpp b/parlai/clib/cuda/ngram_repeat_block_cuda.cpp index cd78c82b27e..f7f2044a3be 100644 --- a/parlai/clib/cuda/ngram_repeat_block_cuda.cpp +++ b/parlai/clib/cuda/ngram_repeat_block_cuda.cpp @@ -34,9 +34,6 @@ torch::Tensor ngram_repeat_block_forward(const torch::Tensor hypothesis, const t int no_repeat_ngram_size, bool if_context_blocking) { CHECK_INPUT(hypothesis); - if(if_context_blocking) { - CHECK_INPUT(context); - } CHECK_INPUT(lprobs); assert(bsz > 0); assert(step >= 0); diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index fb11810293e..fe92d0cf9e6 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -1153,7 +1153,7 @@ def _generate( bsz = batch.batchsize if batch.text_vec is not None: batchsize = batch.batchsize - batch_context_list = self._get_batch_context(batch) + batch_context_list = self._get_batch_context(batch).tolist() beams = [ self._treesearch_factory(dev, verbose=self.show_token_details) .set_batch_context(batch_context_list, batch_idx) @@ -1372,7 +1372,7 @@ def __init__( self.eos_top = False self.eos_top_ts = None self.n_best_counter = 0 - self.partial_hyps = torch.full((self.beam_size, 1), self.bos) + self.partial_hyps = torch.tensor([[self.bos] for i in range(beam_size)]) if torch.cuda.is_available(): self.no_repeat_ngram_op = NGramRepeatBlock() self.gpu_beam_blocking = gpu_beam_blocking @@ -1385,11 +1385,11 @@ def set_context(self: TSType, context: torch.LongTensor) -> TSType: a LongTensor representing the input context; used for context ngram blocking, if supplied """ - self.context = context + self.context = torch.Tensor(context.tolist()).long() return self def set_batch_context( - self: TSType, batch_context_list: torch.LongTensor, batch_idx: int + self: TSType, batch_context_list: List[List[int]], batch_idx: int ) -> TSType: """ Version of .set_context() that operates on a single element of a batch. @@ -1401,7 +1401,7 @@ def set_batch_context( :param batch_idx: index of the batch """ - self.context = batch_context_list[batch_idx] + self.context = torch.Tensor(batch_context_list[batch_idx]).long() return self def set_block_list(self: TSType, block_list: Optional[SearchBlocklist]) -> TSType: @@ -1465,12 +1465,18 @@ def _block_ngrams( Source text to grab ngrams from. If None, it uses the current hypothesis (i.e. self-blocking). """ + context = None if self.gpu_beam_blocking: if not self.partial_hyps.is_cuda: self.partial_hyps = self.partial_hyps.cuda() + if if_context_blocking: + if not self.context.is_cuda: + self.context = self.context.cuda() + context = self.context + logprobs = self.no_repeat_ngram_op( hypothesis=self.partial_hyps, - context=self.context, + context=context, lprobs=logprobs, bsz=1, step=step, diff --git a/parlai/ops/ngram_repeat_block.py b/parlai/ops/ngram_repeat_block.py index 2c0f1df0e0c..762c27a1636 100644 --- a/parlai/ops/ngram_repeat_block.py +++ b/parlai/ops/ngram_repeat_block.py @@ -108,7 +108,6 @@ def forward( if_context_blocking(bool): whether to use context-blocking """ # placeholder tensor to pass in to pass type check, won't be used - # TODO: find a better way to do this? 
if not if_context_blocking: context = torch.Tensor([0]).long() assert hypothesis.size(0) == bsz * beam_size From 547d4fac10cbaee5a3e0fc324b8c5e8af82011ca Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Thu, 30 Jun 2022 14:20:05 -0700 Subject: [PATCH 29/38] replace tensor with list for cpu code to make faster --- parlai/core/torch_generator_agent.py | 48 ++++++++++++++-------------- parlai/ops/ngram_repeat_block.py | 48 +--------------------------- 2 files changed, 25 insertions(+), 71 deletions(-) diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index fe92d0cf9e6..0faacafbf72 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -1153,10 +1153,14 @@ def _generate( bsz = batch.batchsize if batch.text_vec is not None: batchsize = batch.batchsize - batch_context_list = self._get_batch_context(batch).tolist() + batch_context_list = self._get_batch_context(batch) beams = [ self._treesearch_factory(dev, verbose=self.show_token_details) - .set_batch_context(batch_context_list, batch_idx) + .set_batch_context( + batch_context_list, + batch_idx, + self.opt.get('gpu_beam_blocking', False), + ) .set_block_list(self.beam_block_list) for batch_idx in range(batchsize) ] @@ -1372,10 +1376,12 @@ def __init__( self.eos_top = False self.eos_top_ts = None self.n_best_counter = 0 + self.gpu_beam_blocking = gpu_beam_blocking self.partial_hyps = torch.tensor([[self.bos] for i in range(beam_size)]) + if self.gpu_beam_blocking: + self.partial_hyps = self.partial_hyps.cuda() if torch.cuda.is_available(): self.no_repeat_ngram_op = NGramRepeatBlock() - self.gpu_beam_blocking = gpu_beam_blocking def set_context(self: TSType, context: torch.LongTensor) -> TSType: """ @@ -1389,7 +1395,10 @@ def set_context(self: TSType, context: torch.LongTensor) -> TSType: return self def set_batch_context( - self: TSType, batch_context_list: List[List[int]], batch_idx: int + self: TSType, + batch_context_list: torch.LongTensor, + batch_idx: int, + gpu_beam_blocking: bool, ) -> TSType: """ Version of .set_context() that operates on a single element of a batch. @@ -1401,7 +1410,8 @@ def set_batch_context( :param batch_idx: index of the batch """ - self.context = torch.Tensor(batch_context_list[batch_idx]).long() + context = batch_context_list[batch_idx] + self.context = context if gpu_beam_blocking else context.tolist() return self def set_block_list(self: TSType, block_list: Optional[SearchBlocklist]) -> TSType: @@ -1465,15 +1475,8 @@ def _block_ngrams( Source text to grab ngrams from. If None, it uses the current hypothesis (i.e. self-blocking). 
""" - context = None if self.gpu_beam_blocking: - if not self.partial_hyps.is_cuda: - self.partial_hyps = self.partial_hyps.cuda() - if if_context_blocking: - if not self.context.is_cuda: - self.context = self.context.cuda() - context = self.context - + context = self.context if if_context_blocking else None logprobs = self.no_repeat_ngram_op( hypothesis=self.partial_hyps, context=context, @@ -1486,16 +1489,14 @@ def _block_ngrams( ) return logprobs - for beam_id, hyp in enumerate(self.partial_hyps): + for beam_id, hyp in enumerate(self.partial_hyps.tolist()): if len(hyp) < ngram_size - 1: continue - source_ = hyp if if_context_blocking is False else self.context - ngrams = self._find_ngrams(source_, ngram_size) + source = hyp if if_context_blocking is False else self.context prefix = hyp[-(ngram_size - 1) :] - for ngram in ngrams: - # when doing context blocking, ngram is tuple where prefix is tensor - # need to cast both into lists for comparison - if ngram_size == 1 or prefix.tolist() == list(ngram[:-1]): + for i in range(len(source) - ngram_size + 1): + ngram = source[i : i + ngram_size] + if ngram_size == 1 or prefix == ngram[:-1]: logprobs[beam_id][ngram[-1]] = neginf(logprobs.dtype) return logprobs @@ -1503,11 +1504,11 @@ def _block_block_list(self, logprobs: torch.Tensor) -> torch.Tensor: if self.block_list is None: return logprobs - for beam_id, hyp in enumerate(self.partial_hyps): + for beam_id, hyp in enumerate(self.partial_hyps.tolist()): for ngram_size, bad_ngrams in self.block_list.items(): prefix = hyp[-(ngram_size - 1) :] for ngram in bad_ngrams: - if (ngram_size == 1) or prefix.tolist() == list(ngram[:-1]): + if (ngram_size == 1) or prefix == list(ngram[:-1]): logprobs[beam_id][ngram[-1]] = neginf(logprobs.dtype) return logprobs @@ -1531,7 +1532,6 @@ def advance(self, logprobs, step): self.scores[hyp_id] = neginf(self.scores.dtype) # beam blocking - if self.block_ngram > 0: # self blocking logprobs = self._block_ngrams( @@ -1570,7 +1570,7 @@ def advance(self, logprobs, step): if self.partial_hyps.get_device() == -1: hyp_device = 'cpu' else: - hyp_device = 'cuda' + hyp_device = self.partial_hyps.get_device() self.partial_hyps = torch.cat( ( self.partial_hyps[path_selection.hypothesis_ids.long()], diff --git a/parlai/ops/ngram_repeat_block.py b/parlai/ops/ngram_repeat_block.py index 762c27a1636..f8d42240a28 100644 --- a/parlai/ops/ngram_repeat_block.py +++ b/parlai/ops/ngram_repeat_block.py @@ -31,49 +31,6 @@ os.chdir(current) -class NGramRepeatBlockFunction(Function): - """ - forward inputs to ngram_repeat_block cuda extension backward method not needed. - """ - - def forward( - self, - hypothesis, - context, - lprobs, - bsz, - step, - beam_size, - no_repeat_ngram_size, - if_context_blocking, - ): - """ - Args: - hypothesis(Tensor): (beam*bsz, current_sequence_length) - context(Tensor): context for context-blocking - lprobs(Tensor): likelihood probability(beam, vocab_size) - bsz(int): batch size - step(int): current step - beam_size(int): beam size - no_repeat_ngram_size(int): Ngram size - if_context_blocking(bool): whether to use context-blocking - """ - outputs = ngram_repeat_block_cuda.forward( - hypothesis, - context, - lprobs, - bsz, - step, - beam_size, - no_repeat_ngram_size, - if_context_blocking, - ) - return outputs - - def backward(*args): - raise NotImplementedError - - class NGramRepeatBlock(nn.Module): """ Wrapper class for calling ngram_repeat_block cuda extension. 
@@ -82,9 +39,6 @@ class NGramRepeatBlock(nn.Module): def __init__(self): super(NGramRepeatBlock, self).__init__() - def reset_parameters(self): - pass - def forward( self, hypothesis, @@ -115,7 +69,7 @@ def forward( hypothesis = hypothesis.contiguous() context = context.contiguous() lprobs = lprobs.contiguous() - return NGramRepeatBlockFunction.apply( + return ngram_repeat_block_cuda.forward( hypothesis, context, lprobs, From 986fe66578183293312748088e43b009f97281f9 Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Thu, 30 Jun 2022 14:29:50 -0700 Subject: [PATCH 30/38] remove unused import --- parlai/ops/ngram_repeat_block.py | 1 - 1 file changed, 1 deletion(-) diff --git a/parlai/ops/ngram_repeat_block.py b/parlai/ops/ngram_repeat_block.py index f8d42240a28..50c7c46f88c 100644 --- a/parlai/ops/ngram_repeat_block.py +++ b/parlai/ops/ngram_repeat_block.py @@ -9,7 +9,6 @@ """ import torch from torch import nn -from torch.autograd import Function import os from torch.utils.cpp_extension import load From 38c8d97aabebc20b109995b1f0413baefe75fc26 Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Thu, 30 Jun 2022 14:37:56 -0700 Subject: [PATCH 31/38] change botocore version --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e259b84d5da..88b920696ec 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # comment just to bump caches boto3==1.17.95 -botocore==1.20.95 +botocore==1.27.21 coloredlogs==14.0 datasets<2.2.2,>=1.4.1 docutils<0.16,>=0.14 From a73241c06586015c7c38897fe7aea26e9bca7f16 Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Thu, 30 Jun 2022 14:40:06 -0700 Subject: [PATCH 32/38] change botocore again --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 88b920696ec..8dab44171a0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # comment just to bump caches boto3==1.17.95 -botocore==1.27.21 +botocore>=1.27.21 coloredlogs==14.0 datasets<2.2.2,>=1.4.1 docutils<0.16,>=0.14 From 0817758d14d6c08c68c113857a9af25ed4fbba27 Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Thu, 30 Jun 2022 20:33:33 -0700 Subject: [PATCH 33/38] Revert "change botocore again" This reverts commit a73241c06586015c7c38897fe7aea26e9bca7f16. --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 8dab44171a0..88b920696ec 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # comment just to bump caches boto3==1.17.95 -botocore>=1.27.21 +botocore==1.27.21 coloredlogs==14.0 datasets<2.2.2,>=1.4.1 docutils<0.16,>=0.14 From 2f3754bea337f7f067ef36801ed2c0b744c9f039 Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Thu, 30 Jun 2022 20:34:12 -0700 Subject: [PATCH 34/38] Revert "change botocore version" This reverts commit 38c8d97aabebc20b109995b1f0413baefe75fc26. 
--- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 88b920696ec..e259b84d5da 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # comment just to bump caches boto3==1.17.95 -botocore==1.27.21 +botocore==1.20.95 coloredlogs==14.0 datasets<2.2.2,>=1.4.1 docutils<0.16,>=0.14 From e6dad903c9ac6d66bff6f825f557d8c0cc55a3c9 Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Thu, 30 Jun 2022 20:42:03 -0700 Subject: [PATCH 35/38] modify pacer set_batch_context --- projects/light_whoami/agents/pacer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/projects/light_whoami/agents/pacer.py b/projects/light_whoami/agents/pacer.py index db34a4e1124..7ea2674c47a 100644 --- a/projects/light_whoami/agents/pacer.py +++ b/projects/light_whoami/agents/pacer.py @@ -197,7 +197,10 @@ def get_target_character(self): return extract_characters(self.context_str)['_self_name'] def set_batch_context( - self: TSType, batch_context_list: List[List[int]], batch_idx: int + self: TSType, + batch_context_list: List[List[int]], + batch_idx: int, + gpu_beam_blocking: bool, ) -> TSType: """ Override to save de-tokenized version of context. From 5128346c7484eed6a886672d228154e2d39b45b7 Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Tue, 5 Jul 2022 10:18:48 -0700 Subject: [PATCH 36/38] remove comments and outdated changes --- parlai/core/torch_generator_agent.py | 3 +-- parlai/ops/ngram_repeat_block.py | 2 -- setup.py | 13 ------------- 3 files changed, 1 insertion(+), 17 deletions(-) diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index 0faacafbf72..d87cef2d163 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -1391,7 +1391,7 @@ def set_context(self: TSType, context: torch.LongTensor) -> TSType: a LongTensor representing the input context; used for context ngram blocking, if supplied """ - self.context = torch.Tensor(context.tolist()).long() + self.context = context.tolist() return self def set_batch_context( @@ -1792,7 +1792,6 @@ def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection voc_size = logprobs.size(-1) # get the backtracking hypothesis id as a multiple of full voc_sizes - # hyp_ids = best_idxs // voc_size hyp_ids = torch.div(best_idxs, voc_size, rounding_mode='trunc') # get the actual word id from residual of the same division tok_ids = best_idxs % voc_size diff --git a/parlai/ops/ngram_repeat_block.py b/parlai/ops/ngram_repeat_block.py index 50c7c46f88c..1af31f4f916 100644 --- a/parlai/ops/ngram_repeat_block.py +++ b/parlai/ops/ngram_repeat_block.py @@ -18,8 +18,6 @@ dname = os.path.dirname(abspath) os.chdir(dname) -# import ngram_repeat_block_cuda - ngram_repeat_block_cuda = load( name='ngram_repeat_block_cuda', sources=[ diff --git a/setup.py b/setup.py index 47e68b87ebd..0d88b17edaa 100644 --- a/setup.py +++ b/setup.py @@ -9,8 +9,6 @@ from setuptools import setup, find_packages -# from torch.utils.cpp_extension import BuildExtension, CUDAExtension - VERSION = '1.6.0' # if you update, update parlai/__init__.py too! 
if sys.version_info < (3, 8): @@ -28,15 +26,6 @@ if __name__ == '__main__': - # extensions = [ - # CUDAExtension( - # 'ngram_repeat_block_cuda', - # [ - # 'parlai/clib/cuda/ngram_repeat_block_cuda.cpp', - # 'parlai/clib/cuda/ngram_repeat_block_cuda_kernel.cu', - # ], - # ), - # ] setup( name='parlai', version=VERSION, @@ -49,12 +38,10 @@ install_requires=reqs, include_package_data=True, package_data={'': ['*.txt', '*.md', '*.opt']}, - # ext_modules=extensions, entry_points={ "flake8.extension": ["PAI = parlai.utils.flake8:ParlAIChecker"], "console_scripts": ["parlai=parlai.__main__:main"], }, - # cmdclass={'build_ext': BuildExtension}, classifiers=[ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", From 557ea125470c81d9386626e8ce5c6b779b94445a Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Tue, 5 Jul 2022 11:20:55 -0700 Subject: [PATCH 37/38] add comments and copyright headers --- parlai/clib/cuda/ngram_repeat_block_cuda.cpp | 2 ++ .../clib/cuda/ngram_repeat_block_cuda_kernel.cu | 2 ++ parlai/core/torch_generator_agent.py | 16 +++++++++++----- parlai/ops/ngram_repeat_block.py | 2 ++ 4 files changed, 17 insertions(+), 5 deletions(-) diff --git a/parlai/clib/cuda/ngram_repeat_block_cuda.cpp b/parlai/clib/cuda/ngram_repeat_block_cuda.cpp index f7f2044a3be..40cc4d17667 100644 --- a/parlai/clib/cuda/ngram_repeat_block_cuda.cpp +++ b/parlai/clib/cuda/ngram_repeat_block_cuda.cpp @@ -1,7 +1,9 @@ /* Copyright (c) Facebook, Inc. and its affiliates. +Copyright (c) Microsoft Corporation. This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree. +Code adapted from https://github.com/microsoft/fastseq/blob/main/fastseq/clib/cuda/ngram_repeat_block_cuda.cpp. */ #include diff --git a/parlai/clib/cuda/ngram_repeat_block_cuda_kernel.cu b/parlai/clib/cuda/ngram_repeat_block_cuda_kernel.cu index d8feffc1198..8b680248388 100644 --- a/parlai/clib/cuda/ngram_repeat_block_cuda_kernel.cu +++ b/parlai/clib/cuda/ngram_repeat_block_cuda_kernel.cu @@ -1,7 +1,9 @@ /* Copyright (c) Facebook, Inc. and its affiliates. +Copyright (c) Microsoft Corporation. This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree. +Code adapted from https://github.com/microsoft/fastseq/blob/main/fastseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu. */ /* diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index d87cef2d163..628951dd573 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -1409,6 +1409,9 @@ def set_batch_context( a list of lists, each one containing the context for one member of the batch :param batch_idx: index of the batch + :param gpu_beam_blocking: + whether we are using gpu kernel for beam blocking, if so return a tensor, + else return a list. """ context = batch_context_list[batch_idx] self.context = context if gpu_beam_blocking else context.tolist() @@ -1464,17 +1467,19 @@ def _block_ngrams( if_context_blocking=False, ): """ - Hard block ngrams from the logprobs, based on the source. + Hard block ngrams from the logprobs. :param ngram_size: The length of ngrams to block. Must be > 0. :param logprobs: Float or HalfTensor, representing the log-probabilities. This is modified in place. - :param source: - Source text to grab ngrams from. If None, it uses the current - hypothesis (i.e. self-blocking). 
+ :param step: + current step on generating utterances + :param if_context_blocking: + whether we are doing context blocking """ + # gpu beam blocking if self.gpu_beam_blocking: context = self.context if if_context_blocking else None logprobs = self.no_repeat_ngram_op( @@ -1489,6 +1494,7 @@ def _block_ngrams( ) return logprobs + # cpu beam blocking for beam_id, hyp in enumerate(self.partial_hyps.tolist()): if len(hyp) < ngram_size - 1: continue @@ -1508,7 +1514,7 @@ def _block_block_list(self, logprobs: torch.Tensor) -> torch.Tensor: for ngram_size, bad_ngrams in self.block_list.items(): prefix = hyp[-(ngram_size - 1) :] for ngram in bad_ngrams: - if (ngram_size == 1) or prefix == list(ngram[:-1]): + if (ngram_size == 1) or prefix == ngram[:-1]: logprobs[beam_id][ngram[-1]] = neginf(logprobs.dtype) return logprobs diff --git a/parlai/ops/ngram_repeat_block.py b/parlai/ops/ngram_repeat_block.py index 1af31f4f916..090b2813a5f 100644 --- a/parlai/ops/ngram_repeat_block.py +++ b/parlai/ops/ngram_repeat_block.py @@ -1,8 +1,10 @@ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Microsoft Corporation. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +# Code adapted from https://github.com/microsoft/fastseq/blob/main/fastseq/ops/ngram_repeat_block.py. """ Wrapper for ngram_repeat_block cuda extension. From 1cb104899adfd7fa62ccf74f195d57d67d7ece6b Mon Sep 17 00:00:00 2001 From: Pearl Li Date: Tue, 5 Jul 2022 11:34:46 -0700 Subject: [PATCH 38/38] format c++ and cu file --- parlai/clib/cuda/ngram_repeat_block_cuda.cpp | 6 +- .../cuda/ngram_repeat_block_cuda_kernel.cu | 76 +++++++++++-------- 2 files changed, 48 insertions(+), 34 deletions(-) diff --git a/parlai/clib/cuda/ngram_repeat_block_cuda.cpp b/parlai/clib/cuda/ngram_repeat_block_cuda.cpp index 40cc4d17667..d44724ed1d3 100644 --- a/parlai/clib/cuda/ngram_repeat_block_cuda.cpp +++ b/parlai/clib/cuda/ngram_repeat_block_cuda.cpp @@ -34,7 +34,8 @@ torch::Tensor ngram_repeat_block_forward(const torch::Tensor hypothesis, const t torch::Tensor lprobs, int bsz, int step, int beam_size, int no_repeat_ngram_size, - bool if_context_blocking) { + bool if_context_blocking) +{ CHECK_INPUT(hypothesis); CHECK_INPUT(lprobs); assert(bsz > 0); @@ -46,7 +47,8 @@ torch::Tensor ngram_repeat_block_forward(const torch::Tensor hypothesis, const t no_repeat_ngram_size, if_context_blocking); } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ m.def("forward", &ngram_repeat_block_forward, "No Repeat Ngram Block forward (CUDA)"); } diff --git a/parlai/clib/cuda/ngram_repeat_block_cuda_kernel.cu b/parlai/clib/cuda/ngram_repeat_block_cuda_kernel.cu index 8b680248388..4e7a0fb8cc7 100644 --- a/parlai/clib/cuda/ngram_repeat_block_cuda_kernel.cu +++ b/parlai/clib/cuda/ngram_repeat_block_cuda_kernel.cu @@ -17,44 +17,52 @@ Kernel implementation for blocking repeated n-grams. 
#include // Ban repeated ngrams of length = 'no_repeat_ngram_size' -__global__ void banRepeatedTokens(long* __restrict__ hypothesis_ptr, - long* __restrict__ context_ptr, - float* __restrict__ lprobs, - int current_seq_length, +__global__ void banRepeatedTokens(long *__restrict__ hypothesis_ptr, + long *__restrict__ context_ptr, + float *__restrict__ lprobs, + int current_seq_length, int vocab_size, int no_repeat_ngram_size, - bool if_context_blocking) { + bool if_context_blocking) +{ auto row = blockIdx.x; auto col = threadIdx.x; // start of context ngram on current thread auto index = row * current_seq_length + col; // start of last ngram of hypothesis auto start_of_ngram = current_seq_length - no_repeat_ngram_size + 1; - long* __restrict__ previous_ngram_ptr; + long *__restrict__ previous_ngram_ptr; - if (if_context_blocking) { + if (if_context_blocking) + { previous_ngram_ptr = &context_ptr[col]; - } else { + } + else + { previous_ngram_ptr = &hypothesis_ptr[index]; } auto lprob_start = row * vocab_size; extern __shared__ long tokens_shm[]; // each thread writes to shared array - tokens_shm[col] = *previous_ngram_ptr; + tokens_shm[col] = *previous_ngram_ptr; // final thread writes the end of previous ngram array to tokens_shm - if (col == blockDim.x - 1) { - for (int i=1; i(); auto context_ptr = context.data_ptr(); auto lprob_ptr = lprobs.data_ptr(); - + int context_length; - if (if_context_blocking) { + if (if_context_blocking) + { context_length = context.size(0); - } else { + } + else + { // context is previously generated word sequences for self-blocking context_length = hypothesis.size(1); } - - int threads = context_length - no_repeat_ngram_size + 1; - if (step - no_repeat_ngram_size + 2 <= 0) return lprobs; + + int threads = context_length - no_repeat_ngram_size + 1; + if (step - no_repeat_ngram_size + 2 <= 0) + return lprobs; int vocab_size = lprobs.size(1); - int blocks = bsz * beam_size; + int blocks = bsz * beam_size; int current_seq_length = hypothesis.size(1); int shared_mem_size = context_length * sizeof(long); - // Launching N blocks where N is number of samples in a batch (beams*bsz) // Launching T threads where T is number of previous ngrams in a sample @@ -103,12 +115,12 @@ torch::Tensor ngram_repeat_block_cuda_forward(const torch::Tensor hypothesis, // N is Ngram size. banRepeatedTokens<<>>( - hypothesis_ptr, - context_ptr, - lprob_ptr, - current_seq_length, - vocab_size, - no_repeat_ngram_size, + hypothesis_ptr, + context_ptr, + lprob_ptr, + current_seq_length, + vocab_size, + no_repeat_ngram_size, if_context_blocking); return lprobs; }
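
The new GPU tests in tests/test_transformers.py boil down to the following minimal sketch: create a generator agent with beam blocking (or context blocking) enabled and pass gpu_beam_blocking=True so the fused CUDA kernel applies the n-gram bans. This is an illustrative sketch, not part of the patch: it assumes a CUDA-capable machine and reuses the zoo:unittest models and flag names taken directly from the diffs above; the print statements and comments are mine.

from parlai.core.agents import create_agent_from_model_file
from parlai.core.opt import Opt

# Self-blocking on GPU: ban any 2-gram that already appears in the hypothesis.
agent = create_agent_from_model_file(
    'zoo:unittest/beam_blocking/model',
    Opt(beam_block_ngram=2, gpu_beam_blocking=True),
)
agent.observe({'text': '5 5 5 5 5 5 5', 'episode_done': True})
print(agent.act()['text'])  # '5 5 5' should never appear in the output

# Context blocking on GPU: ban any 2-gram that appears in the input context.
agent = create_agent_from_model_file(
    'zoo:unittest/context_blocking/model',
    Opt(beam_context_block_ngram=2, gpu_beam_blocking=True),
)
agent.observe({'text': '5 4 3 2', 'episode_done': True})
print(agent.act()['text'])  # '5 4', '4 3', '3 2' should all be blocked

As the kernel source above documents, the op launches bsz * beam_size CUDA blocks (one per hypothesis) with one thread per candidate n-gram start, staging the comparison tokens in shared memory; when gpu_beam_blocking is off, generation falls back to the list-based _block_ngrams / _block_block_list path reworked earlier in this series.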