EmbeddingBag op and layer #2352

Merged
66 commits merged on May 28, 2021

Commits
32b3e6c
Adding EmbeddingBag layer and associated updates to BUILD files etc.
Rocketknight1 Jan 18, 2021
358d58d
Adding EmbeddingBag layer and associated updates to BUILD files etc.
Rocketknight1 Jan 18, 2021
0f05f51
Made sure the Layers know they depend on the new .so file
Rocketknight1 Jan 18, 2021
fa1600a
Bugfixes to test
Rocketknight1 Jan 18, 2021
5365329
Bugfixes to test
Rocketknight1 Jan 18, 2021
9cc8cb7
Fixing initializer arguments
Rocketknight1 Jan 19, 2021
251a9d1
Fixing names of input arguments -> "input_dim" and "output_dim"
Rocketknight1 Jan 19, 2021
7ee13dc
Fixes to get_config()
Rocketknight1 Jan 19, 2021
0b06706
Finish CPU forward part
WindQAQ Jan 19, 2021
0d22a0b
Fix cost estimation
WindQAQ Jan 19, 2021
9572fa4
Feed combiner into ops
WindQAQ Jan 19, 2021
30cc869
Run buildifier
WindQAQ Jan 19, 2021
1ed620f
Typecheckd
WindQAQ Jan 19, 2021
c99e5f0
Uppercase
WindQAQ Jan 19, 2021
961f389
Fix windows build
WindQAQ Jan 19, 2021
7c5c993
Pass tensor into functor
WindQAQ Jan 20, 2021
8d08349
Fix cost estimation
WindQAQ Jan 20, 2021
a4b1c59
CPU backward
WindQAQ Jan 20, 2021
19ce511
Fix type
WindQAQ Jan 20, 2021
01ee9a6
Watch list of tensors
WindQAQ Jan 20, 2021
be06d0e
Tests other input shapes
WindQAQ Jan 20, 2021
5682aff
Add const
WindQAQ Jan 20, 2021
497a26a
Fix value_grads rank
WindQAQ Jan 21, 2021
d624dc2
Enable GPU forward kernel and tests
WindQAQ Jan 21, 2021
ba0a15a
Add missing comma
WindQAQ Jan 21, 2021
a814e7f
Fix wrong name
WindQAQ Jan 21, 2021
7d7e532
Fix wrong name
WindQAQ Jan 21, 2021
2fc54db
Disable GPU backward functor
WindQAQ Jan 21, 2021
9e871c1
Cast for divup
WindQAQ Jan 21, 2021
582c826
Limiting indices and weights to rank 2 as discussed in the PR
Rocketknight1 Jan 21, 2021
9ba1269
Drop rank > 2 support
WindQAQ Jan 22, 2021
4637889
Support mean combiner and make it default
WindQAQ Jan 22, 2021
83dc9a5
Include header in backward files
WindQAQ Jan 22, 2021
97b3c11
Remove rank > 2 support
WindQAQ Jan 22, 2021
3d89a39
Fix build
WindQAQ Jan 22, 2021
12ddb94
Move op definition out of functor namespace
WindQAQ Jan 22, 2021
dc53639
Put files together and match tensorflow naming convention
WindQAQ Jan 22, 2021
55b3c55
Fix guard
WindQAQ Jan 22, 2021
90bebdb
Fix GPU build
WindQAQ Jan 22, 2021
6d95f92
Try to fix GPU build
WindQAQ Jan 22, 2021
b41e6af
Try to fix ubuntu build
WindQAQ Jan 22, 2021
ab903b9
Try to fix ubuntu build
WindQAQ Jan 22, 2021
d13fe4b
Fix indices shape error msg
WindQAQ Jan 22, 2021
131f3f0
Better cost estimation
WindQAQ Jan 22, 2021
0f22120
Fix gradient shape inference
WindQAQ Jan 22, 2021
8bfae5e
Cleanup shape inference
WindQAQ Jan 22, 2021
6f2943c
Rename values to params
WindQAQ Jan 24, 2021
829fd0a
Fix name
WindQAQ Jan 24, 2021
e6ffdd8
Fix weights dtype
WindQAQ Jan 24, 2021
430e769
Return sparse gradient
WindQAQ Jan 24, 2021
56673cb
Re-adding GPU support for backward pass
Rocketknight1 Feb 16, 2021
0770ed0
Limiting indices and weights to rank 2 as discussed in the PR
Rocketknight1 Mar 15, 2021
2743e1a
Clang-format updates
Rocketknight1 Mar 16, 2021
201bc78
Proper clang-format with Google formatting
Rocketknight1 Mar 16, 2021
43a5b1b
Fixing bazel formatting to pass checks
Rocketknight1 Mar 23, 2021
08ffc2e
Fixes to tests to match the API
Rocketknight1 Mar 24, 2021
1ddb54e
Fix some test bugs
Rocketknight1 Apr 5, 2021
3e63625
Revert formatting commit that blanked a couple of files by accident
Rocketknight1 Apr 5, 2021
0431858
Run precommit formatting
Rocketknight1 Apr 5, 2021
2e77fe2
Default the combiner to sum in the layer as well as the op
Rocketknight1 Apr 5, 2021
ef84621
Syntax fix to avoid an error when running non-eagerly
Rocketknight1 Apr 6, 2021
4b82c7f
Adding tf.function to tests
Rocketknight1 Apr 24, 2021
aedb074
Update test syntax
Rocketknight1 May 9, 2021
184f624
Stop returning IndexedSlices gradients
Rocketknight1 May 10, 2021
6480f67
Code style pass
Rocketknight1 May 10, 2021
3137c3c
Adding myself to CODEOWNERS
Rocketknight1 May 14, 2021
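
Taken together, the commits above pin down the layer's public surface: the embedding arguments are named input_dim and output_dim, indices and per-index weights are limited to rank 2, and the combiner defaults to "sum" in both the op and the layer. A rough usage sketch follows; the exact call signature is inferred from the commit messages rather than copied from the PR, so treat it as an assumption.

import numpy as np
import tensorflow_addons as tfa

# Toy shapes: 3 bags of 4 indices each, a 10-row embedding table, 16-dim embeddings.
indices = np.random.randint(0, 10, size=(3, 4)).astype(np.int64)  # rank-2, as required
weights = np.random.uniform(size=(3, 4)).astype(np.float32)       # one weight per index

layer = tfa.layers.EmbeddingBag(input_dim=10, output_dim=16, combiner="sum")
output = layer(indices, weights)  # shape (3, 16): weighted sum of the gathered rows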
2 changes: 2 additions & 0 deletions .github/CODEOWNERS
@@ -92,6 +92,8 @@
/tensorflow_addons/layers/tests/noisy_dense_test.py @leonshams @markub3327
/tensorflow_addons/layers/max_unpooling_2d.py @thaink
/tensorflow_addons/layers/tests/max_unpooling_2d_test.py @thaink
/tensorflow_addons/layers/embedding_bag.py @rocketknight1
/tensorflow_addons/layers/tests/embedding_bag_test.py @rocketknight1

/tensorflow_addons/losses/contrastive.py @windqaq
/tensorflow_addons/losses/tests/contrastive_test.py @windqaq
14 changes: 14 additions & 0 deletions tensorflow_addons/custom_ops/layers/BUILD
@@ -19,3 +19,17 @@ custom_op_library(
"cc/kernels/correlation_cost_op_gpu.cu.cc",
],
)

custom_op_library(
name = "_embedding_bag_ops.so",
srcs = [
"cc/kernels/embedding_bag_ops.cc",
"cc/kernels/embedding_bag_ops.h",
"cc/ops/embedding_bag_ops.cc",
],
cuda_srcs = [
"cc/kernels/embedding_bag_ops.h",
"cc/kernels/embedding_bag_ops_gpu.cu.cc",
"cc/kernels/embedding_bag_backward_kernels.cu.cc",
],
)
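
The new custom_op_library target builds _embedding_bag_ops.so from the CPU sources plus the CUDA kernels listed under cuda_srcs. On the Python side, Addons usually loads such a library lazily through its resource loader; the sketch below illustrates that pattern for this op. The wrapper function, the addons_embedding_bag binding name, and the upper-case combiner attribute are assumptions based on how other Addons custom ops are wired up, not code taken from this PR.

from tensorflow_addons.utils.resource_loader import LazySO

# Lazily load the shared object produced by the BUILD rule above.
_embedding_bag_so = LazySO("custom_ops/layers/_embedding_bag_ops.so")


def embedding_bag(indices, params, weights, combiner="sum"):
    # Assumed binding for the registered op; the "Uppercase" commit suggests the
    # combiner attribute is passed as "SUM" / "MEAN".
    return _embedding_bag_so.ops.addons_embedding_bag(
        indices, params, weights, combiner=combiner.upper()
    )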
241 changes: 241 additions & 0 deletions tensorflow_addons/custom_ops/layers/cc/kernels/embedding_bag_backward_kernels.cu.cc
@@ -0,0 +1,241 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include <thrust/device_ptr.h>
#include <thrust/fill.h>
#include <thrust/sort.h>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow_addons/custom_ops/layers/cc/kernels/embedding_bag_ops.h"

constexpr int MAX_THREADS_PER_BLOCK = 1024;

namespace tensorflow {
namespace addons {
namespace functor {

typedef Eigen::GpuDevice GPUDevice;

template <typename Tindices, const int kThreadsPerBlock>
__global__ void PrepTempArraysKernel(
const Tindices *__restrict__ indices, Tindices *__restrict__ sortedIndices,
Tindices *__restrict__ sortedIndicesCounter, const int indices_size) {
const int arrayIdx = (blockIdx.x * kThreadsPerBlock) + threadIdx.x;
if (arrayIdx <
indices_size) { // Make sure we don't run off the end of the actual array
sortedIndices[arrayIdx] = indices[arrayIdx];
sortedIndicesCounter[arrayIdx] = arrayIdx;
}
}

// Define the CUDA kernel.
template <typename T, typename Tindices, const int kThreadsPerBlock>
__global__ void EmbeddingBagWeightsGradKernel(
const int value_dim, const Tindices *__restrict__ indices,
const T *__restrict__ values, const T *__restrict__ dloss,
T *__restrict__ weights_grad) {
const int sample_idx = blockIdx.x;
const int bag_idx = blockIdx.y;
const int bag_dim = gridDim.y;
const int valueBaseIdx =
indices[(sample_idx * bag_dim) + bag_idx] * value_dim;
const int dlossBaseIdx = sample_idx * value_dim;
// Use a full-precision accumulator even for half-precision inputs
float partialDotProduct = 0.0f;
for (int i = threadIdx.x; i < value_dim;
i += blockDim.x) // Note that some threads may stop one iteration
// earlier if the block straddles the end of the array
{
partialDotProduct +=
static_cast<float>(values[valueBaseIdx + i] * dloss[dlossBaseIdx + i]);
}
unsigned activeMask = 0xffffffff;
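// Added note: warp-level tree reduction. After log2(kThreadsPerBlock) shuffle
// steps the per-thread partial sums collapse into lane 0 (threadIdx.x == 0).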
#pragma unroll
for (int offset = kThreadsPerBlock / 2; offset > 0; offset /= 2) {
partialDotProduct +=
__shfl_down_sync(activeMask, partialDotProduct, offset);
}
// Thread 0 now has the full dot product
if (threadIdx.x == 0) {
weights_grad[(sample_idx * bag_dim) + bag_idx] =
static_cast<T>(partialDotProduct);
}
}

template <typename T, typename Tindices>
__global__ void EmbeddingBagValuesGradKernel(
const int value_dim, const int bag_dim,
const Tindices *__restrict__ sortedIndices,
const Tindices *__restrict__ counter, const T *__restrict__ values,
const T *__restrict__ weights, const T *__restrict__ dloss,
T *__restrict__ values_grad) {
const int startIdx = blockIdx.x;
const int chunk = blockIdx.y;
const int kThreadsPerBlock = blockDim.x;
const int featureIdx = threadIdx.x + (chunk * kThreadsPerBlock);
// The core problem here is that we want to avoid parallel writes to the
// same element of the grads. We avoid that by pre-sorting a copy of the
// indices tensor, and also co-sorting a 'counter' array so that we still know
// which element of the incoming gradient tensor corresponds to each. Then, we
// take the slightly lazy approach of spinning up a warp for each element of
// the indices array, but having each warp check the previous element before
// it starts. If the two elements are the same, then the warp immediately
// returns without doing anything. If not, then the warp iterates forward and
// accumulates gradient until it hits a different index element, at which
// point it writes the accumulated value and returns. This ensures that each
// row of the values grad tensor is handled by one and exactly one warp.
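// Added illustration (not from the original source): with flattened indices
// [2, 0, 2, 1], the sorted copy is [0, 1, 2, 2] and the co-sorted counter is
// [1, 3, 0, 2]. The blocks launched for startIdx 0, 1 and 2 each own one
// distinct embedding row (rows 0, 1 and 2), with startIdx 2 accumulating both
// occurrences of row 2; the block for startIdx 3 sees that its predecessor
// holds the same index and returns without writing anything.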
const int valuesIdx = ldg(sortedIndices + startIdx);
if (startIdx > 0) {
const int prevIdx = ldg(sortedIndices + startIdx - 1);
if (prevIdx == valuesIdx) {
return; // Another block is handling this index, exit
}
}
int endIdx = startIdx;
while (endIdx < gridDim.x - 1) // Don't run off the end of the array
{
int nextIdx = endIdx + 1;
int nextValuesIdx = ldg(sortedIndices + nextIdx);
if (nextValuesIdx == valuesIdx) {
endIdx += 1;
} else {
break;
}
}
if (featureIdx < value_dim) // Don't run off the end of the row
{
const int outputOffset = (valuesIdx * value_dim) + featureIdx;
float accum = 0.0f; // Full precision even if the inputs aren't

for (int currentIdx = startIdx; currentIdx <= endIdx; ++currentIdx) {
int originalIdxPosition = ldg(counter + currentIdx);
T weight = weights[originalIdxPosition];
// The floor division on this line is correct and intentional
T featureDloss =
ldg(dloss + (originalIdxPosition / bag_dim) * value_dim + featureIdx);
accum += static_cast<float>(weight * featureDloss);
}
values_grad[outputOffset] = static_cast<T>(accum);
}
}

// Define the GPU implementation that launches the CUDA kernel.
template <typename T, typename Tindices>
struct EmbeddingBagBackwardFunctor<GPUDevice, T, Tindices> {
// indices should remain unchanged, but thrust complains if it's a const
// pointer
void operator()(const GPUDevice &d,
typename TTypes<Tindices, 2>::ConstTensor indices,
typename TTypes<T, 2>::ConstTensor params,
typename TTypes<T, 2>::ConstTensor weights,
typename TTypes<T, 2>::ConstTensor grads,
typename TTypes<T, 2>::Tensor params_grads,
typename TTypes<T, 2>::Tensor weights_grads,
Combiner combiner, OpKernelContext *context) {
// I copy-pasted this bit from histogram_op_gpu.cu.cc and I sure hope it
// works
tensorflow::AllocatorAttributes gpu_allocator;
gpu_allocator.set_on_host(false);
gpu_allocator.set_gpu_compatible(true);

Tensor sortedIndicesTensor;
Tensor sortedIndicesCounterTensor;

OP_REQUIRES_OK(context,
context->allocate_temp(DataTypeToEnum<Tindices>::value,
TensorShape({indices.size()}),
&sortedIndicesTensor, gpu_allocator));
OP_REQUIRES_OK(context, context->allocate_temp(
DataTypeToEnum<Tindices>::value,
TensorShape({indices.size()}),
&sortedIndicesCounterTensor, gpu_allocator));
auto sortedIndices = sortedIndicesTensor.flat<Tindices>();
auto sortedIndicesCounter = sortedIndicesCounterTensor.flat<Tindices>();
// Note: I tried splitting the two kernels into different streams but
// performance was barely affected.
const Eigen::Index batch_dim = indices.dimension(0);
const Eigen::Index bag_dim = indices.dimension(1);
const Eigen::Index output_dim = params.dimension(1);
const auto params_size = params.size();
const int kThreadsPerBlock = 32;
dim3 gridShape = dim3(batch_dim, bag_dim, 1);
TF_CHECK_OK(GpuLaunchKernel(
EmbeddingBagWeightsGradKernel<T, Tindices, kThreadsPerBlock>, gridShape,
kThreadsPerBlock, 0, d.stream(), output_dim, indices.data(),
params.data(), grads.data(), weights_grads.data()));

const int indices_size = indices.size();
const int values_size = params.size();
const int total_blocks = Eigen::divup(indices_size, kThreadsPerBlock);
gridShape = dim3(total_blocks, 1, 1);

TF_CHECK_OK(GpuLaunchKernel(
PrepTempArraysKernel<Tindices, kThreadsPerBlock>, gridShape,
kThreadsPerBlock, 0, d.stream(), indices.data(), sortedIndices.data(),
sortedIndicesCounter.data(), indices_size));

thrust::device_ptr<Tindices> sortedIndicesCounterDevicePtr(
sortedIndicesCounter.data());
thrust::device_ptr<Tindices> sortedIndicesDevicePtr(sortedIndices.data());
thrust::device_ptr<T> paramsGradDevicePtr(params_grads.data());
thrust::fill(paramsGradDevicePtr,
paramsGradDevicePtr + static_cast<int>(params_size),
static_cast<T>(0.0f));
thrust::sort_by_key(sortedIndicesDevicePtr,
sortedIndicesDevicePtr + indices_size,
sortedIndicesCounterDevicePtr);
// Handle each row with as few thread blocks as possible
int threadsPerBlock;
int blocksPerRow;
if (output_dim <= MAX_THREADS_PER_BLOCK) {
blocksPerRow = 1;
threadsPerBlock = output_dim;
} else {
blocksPerRow =
Eigen::divup(static_cast<int>(output_dim), MAX_THREADS_PER_BLOCK);
threadsPerBlock =
Eigen::divup(static_cast<int>(output_dim), blocksPerRow);
}
// int blocksPerRow = 1;
// while (threadsPerBlock > MAX_THREADS_PER_BLOCK) {
// threadsPerBlock = (threadsPerBlock + 1) / 2; // Ceiling division
// blocksPerRow *= 2;
// }
gridShape = dim3(indices_size, blocksPerRow, 1);
TF_CHECK_OK(GpuLaunchKernel(
EmbeddingBagValuesGradKernel<T, Tindices>, gridShape, threadsPerBlock,
0, d.stream(), output_dim, bag_dim, sortedIndices.data(),
sortedIndicesCounter.data(), params.data(), weights.data(),
grads.data(), params_grads.data()));
}
};

// Explicitly instantiate functors for the types of OpKernels registered.
template struct EmbeddingBagBackwardFunctor<GPUDevice, double, int32>;
template struct EmbeddingBagBackwardFunctor<GPUDevice, float, int32>;
template struct EmbeddingBagBackwardFunctor<GPUDevice, Eigen::half, int32>;
template struct EmbeddingBagBackwardFunctor<GPUDevice, double, int64>;
template struct EmbeddingBagBackwardFunctor<GPUDevice, float, int64>;
template struct EmbeddingBagBackwardFunctor<GPUDevice, Eigen::half, int64>;
} // namespace functor
} // namespace addons
} // namespace tensorflow

#endif // GOOGLE_CUDA
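
For readers auditing the kernels above, the following NumPy reference restates the gradients they compute for the "sum" combiner: one dot product per (sample, bag slot) for the weight gradient, and a weighted scatter-add into the gathered rows for the parameter gradient. This is an illustrative sketch written for this review, not code from the PR; the "mean" combiner, which presumably rescales by the bag size, is not covered.

import numpy as np


def embedding_bag_grad_reference(indices, params, weights, grads):
    """NumPy reference for the 'sum'-combiner gradients computed above.

    indices: (batch, bag)  int   -- rows of params gathered for each bag
    params:  (vocab, dim)  float -- the embedding matrix
    weights: (batch, bag)  float -- per-index weights
    grads:   (batch, dim)  float -- upstream gradient, one row per bag
    """
    batch, bag = indices.shape
    # EmbeddingBagWeightsGradKernel: one dot product per (sample, bag slot).
    weights_grad = np.einsum("bjd,bd->bj", params[indices], grads)
    # EmbeddingBagValuesGradKernel: scatter-add weight * upstream row into the
    # gathered embedding row; the sort/counter machinery in the kernel only
    # serialises these additions so each output row is written by one block.
    params_grad = np.zeros_like(params)
    for b in range(batch):
        for j in range(bag):
            params_grad[indices[b, j]] += weights[b, j] * grads[b]
    return params_grad, weights_grad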