This repository has been archived by the owner on Nov 3, 2023. It is now read-only.
Add CUDA Kernel for TreeSearch Ngram Blocking #4633
Merged
Changes from all commits: 39 commits, all authored by pearlli98.
02fca24  add cuda and cpp code for ngram blocking
bd8ab04  add python wrapper
0eddfe7  modify agent to use cuda kernel for self-blocking
9d76ff0  Merge branch 'main' into pearlli-ngram-blocking-kernel
deda51e  add context blocking
189d35b  change load paths
9ccf003  add ninja to requirement
e864f49  modify setup script to install kernel ahead of time
49c9bfb  change circleci test to use gpu to build website
1d2448e  change back to JIT, switch directory when loading cuda module
0263c8f  add check for cuda
cbca092  get rid of ninja
f75dc56  remove unused param
c03c0d2  move "hyps to cuda" into _block_ngrams()
aae7127  set gpu_beam_blocking as attribute for TreeSearch, modify block_list …
f4d6bf1  fix lint formatting issues
7076ff4  add init file to new folders
05b81b6  add new line at end of file
d6d7ab6  new lint errors
6e17d5f  add ninja
538cd62  set protobuf
b5b1df1  cast tensor to list to pass gpu tests
e4b78c7  debug long gpu tests
7a80cfb  fix pointer bug in kernel code and change ngram_size param
fe327f4  add gpu unit tests and fix torch warning
44b879f  skip gpu test unless cuda enabled
acdafcf  use tolist() for conversion
9af834a  get rid of context's conversion to list, add check data before kernel…
48d9d97  Revert "get rid of context's conversion to list, add check data befor…
547d4fa  replace tensor with list for cpu code to make faster
986fe66  remove unused import
38c8d97  change botocore version
a73241c  change botocore again
0817758  Revert "change botocore again"
2f3754b  Revert "change botocore version"
e6dad90  modify pacer set_batch_context
5128346  remove comments and outdated changes
557ea12  add comments and copyright headers
1cb1048  format c++ and cu file
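Several of the commits above describe how the extension ends up being compiled: a switch back to JIT loading, a CUDA availability check, and dropping ninja as a hard requirement. A minimal sketch of how such a JIT load might look with torch.utils.cpp_extension follows; the source file names and the source_dir argument are illustrative assumptions, not ParlAI's actual paths.

# Sketch only: file names/paths are assumptions, not ParlAI's actual layout.
import os
import torch
from torch.utils.cpp_extension import load


def load_ngram_blocking_module(source_dir):
    """JIT-compile the C++/CUDA sources and return the bound module, or None on CPU-only hosts."""
    if not torch.cuda.is_available():
        # fall back to the existing CPU (list-based) blocking path
        return None
    return load(
        name="ngram_repeat_block_cuda",
        sources=[
            os.path.join(source_dir, "ngram_repeat_block_cuda.cpp"),
            os.path.join(source_dir, "ngram_repeat_block_cuda_kernel.cu"),
        ],
        verbose=False,
    )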
@@ -0,0 +1,5 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
@@ -0,0 +1,5 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
@@ -0,0 +1,54 @@
/*
Copyright (c) Facebook, Inc. and its affiliates.
Copyright (c) Microsoft Corporation.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
Code adapted from https://github.com/microsoft/fastseq/blob/main/fastseq/clib/cuda/ngram_repeat_block_cuda.cpp.
*/

#include <torch/extension.h>
#include <vector>

/*
CPP binding for the CUDA op.
*/

// CUDA forward declaration
torch::Tensor ngram_repeat_block_cuda_forward(const torch::Tensor hypothesis, const torch::Tensor context,
                                              torch::Tensor lprobs, int bsz,
                                              int step, int beam_size,
                                              int no_repeat_ngram_size,
                                              bool if_context_blocking);

#define CHECK_CUDA(x) \
    TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
    TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
    CHECK_CUDA(x);     \
    CHECK_CONTIGUOUS(x)

// Input check and call to CUDA op.
// Backward method not required.
torch::Tensor ngram_repeat_block_forward(const torch::Tensor hypothesis, const torch::Tensor context,
                                         torch::Tensor lprobs, int bsz,
                                         int step, int beam_size,
                                         int no_repeat_ngram_size,
                                         bool if_context_blocking)
{
    CHECK_INPUT(hypothesis);
    CHECK_INPUT(lprobs);
    assert(bsz > 0);
    assert(step >= 0);
    assert(beam_size > 0);
    assert(no_repeat_ngram_size > 0);

    return ngram_repeat_block_cuda_forward(hypothesis, context, lprobs, bsz, step, beam_size,
                                           no_repeat_ngram_size, if_context_blocking);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("forward", &ngram_repeat_block_forward,
          "No Repeat Ngram Block forward (CUDA)");
}
@@ -0,0 +1,126 @@
/*
Copyright (c) Facebook, Inc. and its affiliates.
Copyright (c) Microsoft Corporation.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
Code adapted from https://github.com/microsoft/fastseq/blob/main/fastseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu.
*/

Review comment: And what @stephenroller mentioned about licensing should apply to this file as well, I suppose?
Reply: Yes, edited this as well.

/*
Kernel implementation for blocking repeated n-grams.
*/

#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
#include <torch/extension.h>
#include <vector>

// Ban repeated ngrams of length = 'no_repeat_ngram_size'
__global__ void banRepeatedTokens(long *__restrict__ hypothesis_ptr,
                                  long *__restrict__ context_ptr,
                                  float *__restrict__ lprobs,
                                  int current_seq_length,
                                  int vocab_size,
                                  int no_repeat_ngram_size,
                                  bool if_context_blocking)
{
    // row: which (batch, beam) hypothesis this block handles
    auto row = blockIdx.x;

Review comment: Could you add some comments here on what row and col means here?

    // col: which candidate ngram start position this thread checks
    auto col = threadIdx.x;
    // start of this thread's ngram in the flattened hypothesis
    auto index = row * current_seq_length + col;
    // start of the last ngram of the hypothesis
    auto start_of_ngram = current_seq_length - no_repeat_ngram_size + 1;
    long *__restrict__ previous_ngram_ptr;

    if (if_context_blocking)
    {
        previous_ngram_ptr = &context_ptr[col];
    }
    else
    {
        previous_ngram_ptr = &hypothesis_ptr[index];
    }

    auto lprob_start = row * vocab_size;
    extern __shared__ long tokens_shm[];
    // each thread writes one token to the shared array
    tokens_shm[col] = *previous_ngram_ptr;

    // final thread writes the end of the previous ngram array to tokens_shm
    if (col == blockDim.x - 1)
    {
        for (int i = 1; i < no_repeat_ngram_size; i++)
        {
            tokens_shm[col + i] = previous_ngram_ptr[i];
        }
    }
    __syncthreads();

    // Each thread compares the ngram starting at its thread index
    // with the final (no_repeat_ngram_size - 1)-token prefix of the hypothesis
    for (int k = 0; k < no_repeat_ngram_size - 1; k++)
    {
        if (tokens_shm[col + k] != hypothesis_ptr[row * current_seq_length + start_of_ngram + k])
        {
            return;
        }
    }

    // reaching here means the token completing this ngram must be banned
    auto token_to_be_banned = tokens_shm[col + no_repeat_ngram_size - 1];
    lprobs[lprob_start + token_to_be_banned] = -INFINITY;
}

// Allocate blocks and threads based on
// batch size and sequence length and launch
// the kernel
torch::Tensor ngram_repeat_block_cuda_forward(const torch::Tensor hypothesis,
                                              const torch::Tensor context,
                                              torch::Tensor lprobs,
                                              int bsz,
                                              int step,
                                              int beam_size,
                                              int no_repeat_ngram_size,
                                              bool if_context_blocking)
{
    auto hypothesis_ptr = hypothesis.data_ptr<long>();
    auto context_ptr = context.data_ptr<long>();
    auto lprob_ptr = lprobs.data_ptr<float>();

    int context_length;
    if (if_context_blocking)
    {
        context_length = context.size(0);
    }
    else
    {
        // context is the previously generated word sequence for self-blocking
        context_length = hypothesis.size(1);
    }

    int threads = context_length - no_repeat_ngram_size + 1;
    if (step - no_repeat_ngram_size + 2 <= 0)
        return lprobs;
    int vocab_size = lprobs.size(1);
    int blocks = bsz * beam_size;
    int current_seq_length = hypothesis.size(1);
    int shared_mem_size = context_length * sizeof(long);

    // Launching N blocks where N is the number of samples in a batch (beams*bsz)
    // Launching T threads where T is the number of previous ngrams in a sample
    // Allocating shared mem per block for faster access of input tokens since
    // each token will be accessed N times to compare with the current ngram,
    // where N is the ngram size.

    banRepeatedTokens<<<blocks, threads, shared_mem_size>>>(
        hypothesis_ptr,
        context_ptr,
        lprob_ptr,
        current_seq_length,
        vocab_size,
        no_repeat_ngram_size,
        if_context_blocking);
    return lprobs;
}
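To make the launch comments above concrete: each block handles one (batch, beam) hypothesis, each thread checks one candidate ngram start position in either the hypothesis itself (self-blocking) or the context, and a match of the first n-1 tokens against the last n-1 generated tokens bans the token that completed that ngram earlier. Below is a rough pure-Python/PyTorch rendering of that logic, offered as a reference sketch for understanding the kernel; it is not the CPU path ParlAI actually ships, and the function name is hypothetical.

import torch


def ban_repeated_ngrams_reference(hypothesis, context, lprobs,
                                  no_repeat_ngram_size, if_context_blocking):
    """CPU reference for the kernel's logic: ban tokens that would repeat an n-gram."""
    n = no_repeat_ngram_size
    num_rows, seq_len = hypothesis.shape  # num_rows corresponds to bsz * beam_size
    for row in range(num_rows):          # one CUDA block per row
        hyp = hypothesis[row].tolist()
        # last n-1 generated tokens: the prefix a banned n-gram would complete
        prefix = hyp[seq_len - n + 1:]
        # self-blocking scans the hypothesis itself; context blocking scans the context
        source = context.tolist() if if_context_blocking else hyp
        for start in range(len(source) - n + 1):  # one CUDA thread per start position
            if source[start:start + n - 1] == prefix:
                banned_token = source[start + n - 1]
                lprobs[row, banned_token] = -float("inf")
    return lprobs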
Review comment: Is our code freshly rewritten, or did we get it from fastseq? We need to maintain the original copyright headers (in addition to our own) if that's the case. Both are MIT so it's no problem, but we need to include Copyright (c) Microsoft, etc.

Reply: Thanks for pointing this out! We definitely referenced their code, but I've also made some significant changes. I think we should add the copyright here? Can you let me know what's the right move? Not super familiar with this.

Reply: Would something like this work?

Reply: As we discussed in our meeting, you can cite the original file and say we adapted from it, and maintain the Microsoft copyright header.