
Add CUDA Kernel for TreeSearch Ngram Blocking #4633

Merged (39 commits, Jul 5, 2022)
Commits
02fca24
add cuda and cpp code for ngram blocking
pearlli98 Jun 23, 2022
bd8ab04
add python wrapper
pearlli98 Jun 23, 2022
0eddfe7
modify agent to use cuda kernel for self-blocking
pearlli98 Jun 23, 2022
9d76ff0
Merge branch 'main' into pearlli-ngram-blocking-kernel
pearlli98 Jun 23, 2022
deda51e
add context blocking
pearlli98 Jun 24, 2022
189d35b
change load paths
pearlli98 Jun 27, 2022
9ccf003
add ninja to requirement
pearlli98 Jun 27, 2022
e864f49
modify setup script to install kernel ahead of time
pearlli98 Jun 27, 2022
49c9bfb
change circleci test to use gpu to build website
pearlli98 Jun 27, 2022
1d2448e
change back to JIT, switch directory when loading cuda module
pearlli98 Jun 27, 2022
0263c8f
add check for cuda
pearlli98 Jun 27, 2022
cbca092
get rid of ninja
pearlli98 Jun 27, 2022
f75dc56
remove unused param
pearlli98 Jun 27, 2022
c03c0d2
move "hyps to cuda" into _block_ngrams()
pearlli98 Jun 27, 2022
aae7127
set gpu_beam_blocking as attribute for TreeSearch, modify block_list …
pearlli98 Jun 27, 2022
f4d6bf1
fix lint formatting issues
pearlli98 Jun 27, 2022
7076ff4
add init file to new folders
pearlli98 Jun 27, 2022
05b81b6
add new line at end of file
pearlli98 Jun 27, 2022
d6d7ab6
new lint errors
pearlli98 Jun 27, 2022
6e17d5f
add ninja
pearlli98 Jun 27, 2022
538cd62
set protobuf
pearlli98 Jun 27, 2022
b5b1df1
cast tensor to list to pass gpu tests
pearlli98 Jun 28, 2022
e4b78c7
debug long gpu tests
pearlli98 Jun 28, 2022
7a80cfb
fix pointer bug in kernel code and change ngram_size param
pearlli98 Jun 28, 2022
fe327f4
add gpu unit tests and fix torch warning
pearlli98 Jun 28, 2022
44b879f
skip gpu test unless cuda enabled
pearlli98 Jun 28, 2022
acdafcf
use tolist() for conversion
pearlli98 Jun 28, 2022
9af834a
get rid of context's conversion to list, add check data before kernel…
pearlli98 Jun 29, 2022
48d9d97
Revert "get rid of context's conversion to list, add check data befor…
pearlli98 Jun 29, 2022
547d4fa
replace tensor with list for cpu code to make faster
pearlli98 Jun 30, 2022
986fe66
remove unused import
pearlli98 Jun 30, 2022
38c8d97
change botocore version
pearlli98 Jun 30, 2022
a73241c
change botocore again
pearlli98 Jun 30, 2022
0817758
Revert "change botocore again"
pearlli98 Jul 1, 2022
2f3754b
Revert "change botocore version"
pearlli98 Jul 1, 2022
e6dad90
modify pacer set_batch_context
pearlli98 Jul 1, 2022
5128346
remove comments and outdated changes
pearlli98 Jul 5, 2022
557ea12
add comments and copyright headers
pearlli98 Jul 5, 2022
1cb1048
format c++ and cu file
pearlli98 Jul 5, 2022
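
Several of the commits above refer to a Python wrapper that JIT-compiles the kernel when it is first needed ("add python wrapper", "change back to JIT, switch directory when loading cuda module"). Below is a minimal sketch of what such a JIT load might look like, using torch.utils.cpp_extension; the function name and directory layout are illustrative assumptions, not the PR's actual wrapper code.

# Hedged sketch only: module name and source paths are assumptions.
import os
from torch.utils.cpp_extension import load

def load_ngram_blocking_extension(clib_dir):
    """JIT-compile and load the CUDA n-gram blocking extension (illustrative)."""
    src = os.path.join(clib_dir, 'cuda')
    return load(
        name='ngram_repeat_block_cuda',
        sources=[
            os.path.join(src, 'ngram_repeat_block_cuda.cpp'),
            os.path.join(src, 'ngram_repeat_block_cuda_kernel.cu'),
        ],
        verbose=False,
    )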
5 changes: 5 additions & 0 deletions parlai/clib/__init__.py
@@ -0,0 +1,5 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
5 changes: 5 additions & 0 deletions parlai/clib/cuda/__init__.py
@@ -0,0 +1,5 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
54 changes: 54 additions & 0 deletions parlai/clib/cuda/ngram_repeat_block_cuda.cpp
@@ -0,0 +1,54 @@
/*
Copyright (c) Facebook, Inc. and its affiliates.
Copyright (c) Microsoft Corporation.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
Code adapted from https://github.com/microsoft/fastseq/blob/main/fastseq/clib/cuda/ngram_repeat_block_cuda.cpp.
*/

Review thread on this copyright header:

Contributor:
Is our code freshly rewritten, or did we get it from fastseq? We need to maintain the original copyright headers (in addition to our own) if that's the case. Both are MIT so it's no problem, but we need to include Copyright (c) Microsoft etc.

Contributor Author:
Thanks for pointing this out! We definitely referenced their code, but I've also made some significant changes. I think we should add the copyright here? Can you let me know what's the right move? Not super familiar with this.

Contributor Author:
Would something like this work?

/*
Copyright (c) Facebook, Inc. and its affiliates.
Copyright (c) Microsoft Corporation.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
*/

Contributor:
As we discussed in our meeting, you can cite the original file and say we adapted from it. And maintain the Microsoft copyright header.

#include <torch/extension.h>
#include <vector>

/*
CPP Binding for CUDA OP
*/

// CUDA forward declarations
torch::Tensor ngram_repeat_block_cuda_forward(const torch::Tensor hypothesis, const torch::Tensor context,
                                              torch::Tensor lprobs, int bsz,
                                              int step, int beam_size,
                                              int no_repeat_ngram_size,
                                              bool if_context_blocking);

#define CHECK_CUDA(x) \
  TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

// Input check and call to CUDA OP
// Backward method not required
torch::Tensor ngram_repeat_block_forward(const torch::Tensor hypothesis, const torch::Tensor context,
                                         torch::Tensor lprobs, int bsz,
                                         int step, int beam_size,
                                         int no_repeat_ngram_size,
                                         bool if_context_blocking)
{
  CHECK_INPUT(hypothesis);
  CHECK_INPUT(lprobs);
  assert(bsz > 0);
  assert(step >= 0);
  assert(beam_size > 0);
  assert(no_repeat_ngram_size > 0);

  return ngram_repeat_block_cuda_forward(hypothesis, context, lprobs, bsz, step, beam_size,
                                         no_repeat_ngram_size, if_context_blocking);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
  m.def("forward", &ngram_repeat_block_forward,
        "No Repeat Ngram Block forward (CUDA)");
}
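
A hedged usage sketch for the binding above (not code from this PR): per CHECK_INPUT, hypothesis and lprobs must be contiguous CUDA tensors, with int64 ("long") token ids and float32 log-probabilities. The loader function is the hypothetical one from the earlier sketch, and the step/sequence-length relationship is assumed for illustration only.

import torch

# Assumed: module returned by the hypothetical JIT-load sketch shown earlier.
ngram_blocker = load_ngram_blocking_extension('parlai/clib')

bsz, beam_size, vocab_size = 2, 3, 100
no_repeat_ngram_size, step = 3, 5  # assume step + 1 tokens generated so far
hyps = torch.randint(0, vocab_size, (bsz * beam_size, step + 1),
                     dtype=torch.long, device='cuda')
context = torch.zeros(1, dtype=torch.long, device='cuda')  # only read when context blocking
lprobs = torch.randn(bsz * beam_size, vocab_size,
                     dtype=torch.float32, device='cuda')

# In-place: tokens that would complete a repeated 3-gram get lprob = -inf.
lprobs = ngram_blocker.forward(hyps, context, lprobs, bsz, step, beam_size,
                               no_repeat_ngram_size, False)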
126 changes: 126 additions & 0 deletions parlai/clib/cuda/ngram_repeat_block_cuda_kernel.cu
@@ -0,0 +1,126 @@
/*
Copyright (c) Facebook, Inc. and its affiliates.
Copyright (c) Microsoft Corporation.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
Code adapted from https://github.com/microsoft/fastseq/blob/main/fastseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu.
*/

Review thread on this copyright header:

Contributor (@dexterju27, Jul 5, 2022):
And what @stephenroller mentioned about licensing should apply to this file as well, I suppose?

Contributor Author:
Yes, edited this as well.

/*
Kernel implementation for blocking repeated n-grams.
*/

#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
#include <torch/extension.h>
#include <vector>

// Ban repeated ngrams of length = 'no_repeat_ngram_size'
__global__ void banRepeatedTokens(long *__restrict__ hypothesis_ptr,
                                  long *__restrict__ context_ptr,
                                  float *__restrict__ lprobs,
                                  int current_seq_length,
                                  int vocab_size,
                                  int no_repeat_ngram_size,
                                  bool if_context_blocking)
{
  // row = which hypothesis (beam) this block handles; one block per bsz * beam_size rows
  auto row = blockIdx.x;
Contributor:
Could you add some comments here on what row and col mean here?

  // col = which candidate ngram start position this thread checks
  auto col = threadIdx.x;
  // start of this thread's candidate ngram within the flattened hypothesis
  auto index = row * current_seq_length + col;
  // start of the last (no_repeat_ngram_size - 1) tokens of the hypothesis,
  // i.e. the prefix that must not be repeated
  auto start_of_ngram = current_seq_length - no_repeat_ngram_size + 1;
  long *__restrict__ previous_ngram_ptr;

  if (if_context_blocking)
  {
    previous_ngram_ptr = &context_ptr[col];
  }
  else
  {
    previous_ngram_ptr = &hypothesis_ptr[index];
  }

  auto lprob_start = row * vocab_size;
  extern __shared__ long tokens_shm[];
  // each thread writes to shared array
  tokens_shm[col] = *previous_ngram_ptr;

  // final thread writes the end of previous ngram array to tokens_shm
  if (col == blockDim.x - 1)
  {
    for (int i = 1; i < no_repeat_ngram_size; i++)
    {
      tokens_shm[col + i] = previous_ngram_ptr[i];
    }
  }
  __syncthreads();

  // each thread compares the (n - 1)-token window starting at its index
  // with the final (n - 1) tokens of the hypothesis
  for (int k = 0; k < no_repeat_ngram_size - 1; k++)
  {
    if (tokens_shm[col + k] != hypothesis_ptr[row * current_seq_length + start_of_ngram + k])
    {
      return;
    }
  }

  // reaching here means the prefix matched: ban the token that followed it
  auto token_to_be_banned = tokens_shm[col + no_repeat_ngram_size - 1];
  lprobs[lprob_start + token_to_be_banned] = -INFINITY;
}

// Allocate blocks and threads based on
// batch size and sequence length and launch
// kernel
torch::Tensor ngram_repeat_block_cuda_forward(const torch::Tensor hypothesis,
                                              const torch::Tensor context,
                                              torch::Tensor lprobs,
                                              int bsz,
                                              int step,
                                              int beam_size,
                                              int no_repeat_ngram_size,
                                              bool if_context_blocking)
{
  auto hypothesis_ptr = hypothesis.data_ptr<long>();
  auto context_ptr = context.data_ptr<long>();
  auto lprob_ptr = lprobs.data_ptr<float>();

  int context_length;
  if (if_context_blocking)
  {
    context_length = context.size(0);
  }
  else
  {
    // for self-blocking, the "context" is the previously generated sequence itself
    context_length = hypothesis.size(1);
  }

  int threads = context_length - no_repeat_ngram_size + 1;
  // not enough generated tokens yet to form a repeated ngram
  if (step - no_repeat_ngram_size + 2 <= 0)
    return lprobs;
  int vocab_size = lprobs.size(1);
  int blocks = bsz * beam_size;
  int current_seq_length = hypothesis.size(1);
  int shared_mem_size = context_length * sizeof(long);

  // Launching N blocks where N is number of samples in a batch (beams*bsz)
  // Launching T threads where T is number of previous ngrams in a sample
  // Allocating shared mem per block for faster access of input tokens since
  // each token will be accessed N times to compare with current Ngram where
  // N is Ngram size.
  banRepeatedTokens<<<blocks, threads, shared_mem_size>>>(
      hypothesis_ptr,
      context_ptr,
      lprob_ptr,
      current_seq_length,
      vocab_size,
      no_repeat_ngram_size,
      if_context_blocking);
  return lprobs;
}
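
To spell out the rule the kernel implements (and mirroring the "add gpu unit tests" commits above), here is a hedged CPU reference in Python. It is an illustration of the banning logic under the assumptions stated in the comments, not the PR's actual test or fallback code.

import math

def ban_repeated_ngrams_reference(hypothesis, context, lprobs,
                                  no_repeat_ngram_size, if_context_blocking):
    """CPU sketch of the kernel's rule: if the last (n - 1) generated tokens of a
    hypothesis occur anywhere in the source sequence (the context for context
    blocking, the hypothesis itself for self-blocking), ban the token that
    followed that occurrence. Arguments are CPU torch tensors."""
    n = no_repeat_ngram_size
    seq_len = hypothesis.size(1)
    if seq_len < n:  # mirrors the kernel's early return for very short prefixes
        return lprobs
    for row in range(hypothesis.size(0)):
        hyp = hypothesis[row].tolist()
        source = context.tolist() if if_context_blocking else hyp
        prefix = hyp[seq_len - (n - 1):]  # last n - 1 tokens
        for start in range(len(source) - n + 1):
            if source[start:start + n - 1] == prefix:
                lprobs[row][source[start + n - 1]] = -math.inf
    return lprobs

Run on CPU copies of the tensors from the usage sketch above, this should mark the same tokens the kernel bans, which is the kind of cross-check the PR's GPU unit tests presumably perform.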