Add CUDA Kernel for TreeSearch Ngram Blocking #4633
Conversation
@@ -1466,7 +1503,7 @@ def _block_block_list(self, logprobs: torch.Tensor) -> torch.Tensor:
logprobs[beam_id][ngram[-1]] = neginf(logprobs.dtype)
return logprobs

-def advance(self, logprobs):
+def advance(self, logprobs, step):
I think an easy workaround here, to avoid breaking the current tests, is to give step a default value (e.g. step=0) when it is not given, since advance is also used in non-GPU beam-blocking settings.
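A minimal sketch of that suggestion (the body here is illustrative, not the PR's code): the default keeps existing call sites that never pass a step working unchanged.

```python
def advance(self, logprobs, step=0):
    # `step` is only consulted by the GPU ngram-blocking path; CPU beam blocking
    # keeps calling advance(logprobs) with no step and sees no behavior change.
    ...
```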
The test failure seems to be caused by
NameError: name 'NGramRepeatBlock' is not defined
which indicates the CUDA C++ binding is not handled properly by the CI.
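One common way to keep the binding optional so machines without the compiled extension don't hit the NameError; a sketch under the assumption that the import is guarded (the module path here is illustrative):

```python
try:
    from ngram_repeat_block import NGramRepeatBlock  # module name assumed for illustration
except ImportError:
    # CPU-only environments (like this CI job) fall back to the list-based blocking path
    NGramRepeatBlock = None
```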
parlai/core/torch_generator_agent.py
Outdated
@@ -956,6 +965,7 @@ def _treesearch_factory(self, device, verbose=False):
)
elif method == 'beam':
return BeamSearch(
self.opt['gpu_beam_blocking'],
You have set self.gpu_beam_blocking, why not use it?
parlai/core/torch_generator_agent.py
Outdated
@@ -1443,15 +1459,36 @@ def _block_ngrams(
Source text to grab ngrams from. If None, it uses the current
hypothesis (i.e. self-blocking).
"""
context = None
if self.gpu_beam_blocking:
    if if_context_blocking:
Let's make it if self.gpu_beam_blocking and if_context_blocking?
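A minimal sketch of that collapse (the assignment inside the branch is assumed from context, since the diff truncates there):

```python
def _pick_context(self, if_context_blocking):
    # hypothetical helper illustrating the single combined condition
    context = None
    if self.gpu_beam_blocking and if_context_blocking:
        context = self.context
    return context
```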
…function to cast into list, set current ngram_size for context blocking, move path to cpu when needed
Is this ready for review, or still WIP?
…e kernel code" This reverts commit 9af834a.
This is ready for review.
I think this looks great! I'm approving, but it looks like we still have cleaninstall failures; any chance you could try fixing those?
parlai/core/torch_generator_agent.py
Outdated
@@ -1367,11 +1391,14 @@ def set_context(self: TSType, context: torch.LongTensor) -> TSType:
a LongTensor representing the input context; used for context
ngram blocking, if supplied
"""
-self.context = context.tolist()
+self.context = torch.Tensor(context.tolist()).long()
What is the default dtype here? Curious why the .long() cast is needed.
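For reference, the legacy torch.Tensor(...) constructor produces the default floating-point dtype, which is why an explicit cast would be needed for integer token ids:

```python
import torch

x = torch.Tensor([1, 2, 3])    # legacy constructor uses the default dtype
print(x.dtype)                 # torch.float32
print(x.long().dtype)          # torch.int64, suitable for token ids
```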
Outdated change; reverted to the code on main.
The cleaninstall failures seem to have disappeared without me changing anything.
step,
beam_size,
no_repeat_ngram_size,
if_context_blocking=False,
Is this parameter necessary? I.e., can't we assume that if context is passed, we block on it? Or does this make the downstream logic easier to handle in the kernel?
This makes the downstream kernel logic easier. Kernel parameters can't be None, so I initialize a placeholder tensor as the empty context; we still need a bool to tell whether we are doing self-blocking or context-blocking.
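A hedged sketch of that calling convention (argument and function names here are illustrative; the real call goes through the PR's compiled binding):

```python
import torch

def launch_ngram_blocking(hypotheses, context, if_context_blocking):
    # Kernel arguments cannot be None, so self-blocking passes a dummy placeholder
    # tensor and the boolean tells the kernel to ignore it.
    if not if_context_blocking:
        context = torch.zeros(1, dtype=torch.long, device=hypotheses.device)
    # the compiled op would be invoked here, e.g.
    # ngram_repeat_block(hypotheses, context, step, beam_size, ngram_size, if_context_blocking)
    return context, if_context_blocking
```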
parlai/core/torch_generator_agent.py
Outdated
@@ -1725,7 +1792,8 @@ def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection
voc_size = logprobs.size(-1)

# get the backtracking hypothesis id as a multiple of full voc_sizes
-hyp_ids = best_idxs // voc_size
+# hyp_ids = best_idxs // voc_size
nit: feel free to delete the commented-out line. Also, thank you for changing this (I've seen this warning several times).
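For context, the warning comes from // (floor division) on tensors being deprecated; a sketch of the usual replacement (the exact rounding mode used in the PR is assumed here):

```python
import torch

best_idxs = torch.tensor([5, 17, 42])
voc_size = 8
# equivalent to best_idxs // voc_size, without the floor-division deprecation warning
hyp_ids = torch.div(best_idxs, voc_size, rounding_mode='trunc')
print(hyp_ids)  # tensor([0, 2, 5])
```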
deleted
setup.py
Outdated
@@ -9,6 +9,8 @@

from setuptools import setup, find_packages

# from torch.utils.cpp_extension import BuildExtension, CUDAExtension
nit: can you please delete these commented lines
deleted
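For reference, the deleted comment pointed at PyTorch's extension-building API; a hedged sketch of that pattern (package and source file names here are illustrative, not the PR's actual build setup):

```python
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='ngram_repeat_block_cuda',  # illustrative name
    ext_modules=[
        CUDAExtension(
            'ngram_repeat_block_cuda',
            ['ngram_repeat_block_cuda.cpp', 'ngram_repeat_block_cuda_kernel.cu'],
        )
    ],
    cmdclass={'build_ext': BuildExtension},
)
```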
@@ -0,0 +1,50 @@
/*
Is our code freshly rewritten, or did we get it from fastseq? If it's from there, we need to maintain the original copyright headers (in addition to our own). Both are MIT so it's no problem, but we need to include Copyright (c) Microsoft, etc.
Thanks for pointing this out! We definitely referenced their code, but I've also made some significant changes. I think we should add the copyright here? Can you let me know what the right move is? I'm not super familiar with this.
Would something like this work?
/*
Copyright (c) Facebook, Inc. and its affiliates.
Copyright (c) Microsoft Corporation.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
*/
As we discussed in our meeting, you can cite the original file, say we adapted from it, and maintain the Microsoft copyright header.
wow great progress!
@@ -0,0 +1,112 @@
/*
Copyright (c) Facebook, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
And what @stephenroller mentioned about licensing should apply to this file as well, I suppose?
yes, edited this as well.
int vocab_size,
int no_repeat_ngram_size,
bool if_context_blocking) {
auto row = blockIdx.x;
Could you add some comments here on what row and col mean?
// final thread writes the end of previous ngram array to tokens_shm
if (col == blockDim.x - 1) {
for (int i=1; i<no_repeat_ngram_size; i++){
nit: spacing; you could apply a C++ formatter to this file.
@@ -422,7 +422,59 @@ def test_beamsearch_blocking(self):
assert '34 34' not in text

@pytest.mark.nofbcode
def test_beamsearch_contextblocking(self):
@testing_utils.skipUnlessGPU
Thanks for adding all these tests
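A hedged sketch of how such a GPU-gated test is typically structured (the decorators appear in the diff above; the class name and test body here are assumed):

```python
import pytest
import parlai.utils.testing as testing_utils


class TestGpuBeamBlocking:
    @pytest.mark.nofbcode
    @testing_utils.skipUnlessGPU
    def test_beamsearch_contextblocking_gpu(self):
        # would run a small generator with --gpu-beam-blocking True and assert that
        # ngrams from the input context never appear in the generated text
        ...
```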
Major changes
- New flag --gpu-beam-blocking True; functionalities are wrapped in the class NGramRepeatBlockFunction
- TreeSearch has a new boolean attribute gpu_beam_blocking that we use in _block_ngrams() (see the sketch after this list)
- self.partial_hyps and self.context of TorchGeneratorAgent are kept as tensors instead of lists
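As noted above, a minimal sketch of that dispatch (attribute and class names are from this description; the argument list and the CPU helper are assumed, not the PR's actual code):

```python
NGramRepeatBlockFunction = None  # placeholder for the autograd wrapper class added in this PR


class TreeSearchSketch:
    gpu_beam_blocking = False  # new boolean attribute

    def _block_ngrams(self, ngram_size, logprobs, step=0, if_context_blocking=False):
        if self.gpu_beam_blocking and NGramRepeatBlockFunction is not None:
            # GPU path: tensor-valued partial_hyps/context go straight to the CUDA kernel
            return NGramRepeatBlockFunction.apply(
                self.partial_hyps, logprobs, step, ngram_size, if_context_blocking
            )
        # CPU path: the original Python loop over hypothesis ngrams
        return self._block_ngrams_cpu(ngram_size, logprobs, if_context_blocking)  # hypothetical helper
```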
Testing steps
GPU tests (with CUDA enabled):
pytest tests/test_transformers.py -k test_beamsearch_blocking_gpu
pytest tests/test_transformers.py -k test_beamsearch_contextblocking_gpu
CPU tests:
pytest tests/test_transformers.py -k test_beamsearch_blocking_cpu
pytest tests/test_transformers.py -k test_beamsearch_contextblocking_cpu
Other information
parlai interactive --model-file "zoo:tutorial_transformer_generator/model" --gpu-beam-blocking True
Evaluation
Evaluation was done on the convai2 teacher task in 3 settings: (1) code on main, (2) new code with CPU, (3) new code with GPU kernel; results are shown below.