EmbeddingBag op and layer #2352


Merged
merged 66 commits into tensorflow:master on May 28, 2021

Conversation

Rocketknight1
Contributor

Description

Brief Description of the PR:
This is a PR for the EmbeddingBag op. Please don't merge it yet! Although it works, testing is incomplete and the file structure needs to be cleaned up. I'm opening it now just to get some initial feedback. I'll keep working on several of the issues below (particularly 1, 3, 4, and 6), but I'll need feedback on 2 and 5, plus any other feedback you have on the rest!

Fixes #2201

Type of change

New layer and associated C++/CUDA op

Comments

There are a few issues that need to be resolved before I'd feel comfortable with this being merged. In no particular order, they are:

  1. The CUDA/C++ code is split with the forward and backward passes in separate files, which is not how other TensorFlow or Addons ops do it. This is just a style thing; I'll merge them soon.

  2. There are really two different entrypoints for users here: the function/op (analogous to tf.gather) and the layer (analogous to tf.keras.layers.Embedding). Like Embedding, the layer instantiates its own embeddings tensor and expects to be passed only indices and weights, whereas the function needs to be passed the embeddings as well. Following PyTorch's naming conventions, I called the op embeddingbag and the layer EmbeddingBag, but this is almost certainly not what you want. What is the right way to name these two? Should I make the function/op a stateless Layer rather than just a function? (See the sketch after this list for how the two entrypoints differ.)

  3. No support for float16/bfloat16 yet.

  4. Because context->AllocateTemp consistently segfaulted for me when I was compiling in the custom-op repo, I used AllocateOutput to make some dummy outputs and then just used them as temp arrays. Compiling in tensorflow_addons itself seems much more stable, but I still need to go back and switch those allocations to AllocateTemp properly.

  5. The CUDA/C++ ops expect a weight tensor. When no weights are passed, the Python wrapper instantiates dummy weights with tf.ones_like(). Is this acceptable?

  6. More tests! I don't have any gradient tests at all yet, and I should probably add additional tests with weird shapes.
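
A minimal sketch of the two entrypoints from item 2, assuming the PyTorch-style names used in this PR; the module path, binding names, and signatures here are illustrative, not a final API:

```python
import tensorflow as tf
import tensorflow_addons as tfa  # assumes the op ships in Addons

indices = tf.constant([[3, 1, 4], [1, 5, 9]])

# Function/op, analogous to tf.gather: the caller supplies the embedding
# table (and optionally per-index weights) explicitly.
params = tf.random.normal([16, 8])
pooled = tfa.layers.embedding_bag(indices, params)  # hypothetical binding

# Layer, analogous to tf.keras.layers.Embedding: it owns its own embeddings
# variable and is called with indices (and optional weights) only.
layer = tfa.layers.EmbeddingBag(input_dim=16, output_dim=8)
pooled = layer(indices)
```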

@google-cla

google-cla bot commented Jan 18, 2021

Thanks for your pull request. It looks like this may be your first contribution to a Google open source project (if not, look below for help). Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

📝 Please visit https://cla.developers.google.com/ to sign.

Once you've signed (or fixed any issues), please reply here with @googlebot I signed it! and we'll verify it.



@Rocketknight1
Contributor Author

@googlebot I signed it!

@google-cla google-cla bot added cla: yes and removed cla: no labels Jan 18, 2021
@bhack
Contributor

bhack commented Jan 18, 2021

Here are some other examples of ops with both custom and alternative Python implementations (now Python-only on master): #1114 (comment)

@Rocketknight1
Contributor Author

Sure, I'll add that soon while I'm fixing everything else up.

@WindQAQ
Member

Hi @Rocketknight1, thanks for the contribution. This is a huge PR. I have a suggestion: I can help you with the C++ CPU implementation (with threading and vectorization) and the Python interface, which should reduce back-and-forth during review. If that works for you, I'll work on your branch and comment out the GPU build (I'll also modify the style and the inputs/outputs), or I can file another PR for the CPU ops. Thank you!

The CUDA/C++ ops expect a weight tensor. When no weights are passed, the Python wrapper instantiates dummy weights with tf.ones_like(). Is this acceptable?

It is acceptable. The TensorFlow op registration mini-language does not support optional inputs. The workarounds are:

  1. Manipulate the inputs in Python (as you do now); see the sketch below.
  2. Or pass the input as a list of Tensors; that way, you can check the length of the list to identify whether the optional Tensor is present.
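
A minimal sketch of workaround 1, assuming a hypothetical `_embedding_bag_op` handle to the compiled custom op (the real binding name in this PR may differ):

```python
import tensorflow as tf

def embedding_bag(indices, params, weights=None, combiner="sum"):
    """Python wrapper that fills in the required weight input."""
    if weights is None:
        # The C++/CUDA op always expects a weight tensor; all-ones weights
        # make the weighted combination equal to the unweighted one.
        weights = tf.ones(tf.shape(indices), dtype=params.dtype)
    return _embedding_bag_op(indices, params, weights, combiner=combiner)
```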

@tanguycdls

tanguycdls commented Jan 19, 2021

Hi @Rocketknight1, thanks for your PR. We recently switched from Torch to TF, and EmbeddingBag was missing for us too!

In our use case we often work with lists of non-constant length, called ragged tensors in TF, which use a data format similar to a sparse CSR matrix:

https://www.tensorflow.org/guide/ragged_tensor

We have ragged tensors such as:

```python
import tensorflow as tf

offsets = [0, 0, 0, 1, 5, 7]  # row_splits: row i spans indices[offsets[i]:offsets[i + 1]]
indices = [12, 13, 14, 15, 16, 78, 16]
tf.RaggedTensor.from_row_splits(indices, offsets)
# <tf.RaggedTensor [[], [], [12], [13, 14, 15, 16], [78, 16]]>
```

For now we replace embedding bag by converting our ragged tensor to a SparseTensor (the x coordinate being the row and the y coordinate the position of the item within that row), with the values being the indices we want to gather. We also use a second SparseTensor that holds the weights instead of the indices.
We can then use tf.nn.safe_embedding_lookup_sparse and get better results than a simple gather-then-reduce; see the sketch below. I'm not very clear on the RAM usage of that one, but it does the embedding lookup on unique indices and then applies a gather to the result. (see)
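
Roughly, the conversion looks like this (an illustrative sketch with made-up shapes, not our exact pipeline code):

```python
import tensorflow as tf

ragged = tf.ragged.constant([[12], [13, 14, 15, 16], [78, 16]], dtype=tf.int64)
embeddings = tf.Variable(tf.random.normal([100, 8]))

sp_ids = ragged.to_sparse()  # SparseTensor whose values are embedding indices
pooled = tf.nn.safe_embedding_lookup_sparse(
    embedding_weights=embeddings,
    sparse_ids=sp_ids,
    sparse_weights=None,  # or a matching SparseTensor of per-index weights
    combiner="sum",
)  # shape [batch, embedding_dim]; empty rows come back as zeros
```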

Another workaround we found is to create a sparse indicator tensor: the x coordinate is the row of your batch, the y coordinate is the embedding index, and the values are the weights. You can then compute your embedding + sum as a sparse_dense_matmul between the sparse indicator and the embeddings matrix; see the sketch below. The sparse_dense_matmul itself is very fast; the cost is more in creating the sparse indicator. I'm not sure how that option behaves on memory, since the internals are handled by TF.
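
A rough sketch of the indicator construction (illustrative; duplicate indices within a row would need extra handling before the matmul):

```python
import tensorflow as tf

ragged = tf.ragged.constant([[12], [13, 14, 15, 16], [78, 16]], dtype=tf.int64)
embeddings = tf.random.normal([100, 8])
num_embeddings = tf.shape(embeddings, out_type=tf.int64)[0]

# Indicator: x = row in the batch, y = embedding index, values = weights
# (all ones here, which gives a plain sum).
rows = ragged.value_rowids()
cols = ragged.flat_values
indicator = tf.sparse.SparseTensor(
    indices=tf.stack([rows, cols], axis=1),
    values=tf.ones_like(cols, dtype=embeddings.dtype),
    dense_shape=[ragged.nrows(), num_embeddings],
)
indicator = tf.sparse.reorder(indicator)  # canonical order for the matmul
pooled = tf.sparse.sparse_dense_matmul(indicator, embeddings)  # [batch, dim]
```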

I did a few tests here:

https://colab.research.google.com/gist/tanguycdls/9c696097642844fc1e548c0cade48e11/sparseembeddings.ipynb

The performance depends a lot on the sparsity of the indices and the number of items per row in the ragged case.

I'll try to compile your branch to compare the performance of the sparse matmul against your EmbeddingBag! A few months ago we did a PyTorch vs TF benchmark, and embedding_bag was slightly better than the matmul in some cases.

I'd be happy to help with benchmarks if you need some for this PR!

@bhack
Contributor

bhack commented Jan 19, 2021

@tanguycdls Thanks. This seems interesting to explore.

@Rocketknight1
Contributor Author

Hi @tanguycdls, your workaround with sparse multiplications is really interesting! I'm also curious to know how it compares in terms of memory usage to the CUDA EmbeddingBag. Please note that my CPU implementations are not very optimized right now (as @WindQAQ pointed out), but the CUDA one should be at least reasonably performant.

@Rocketknight1
Contributor Author

Also, yes @WindQAQ, if you want to improve the CPU implementations, feel free to make those changes. I'm aware that they are currently 'reference' implementations rather than high-performance ones. Thank you!

@bhack
Contributor

bhack commented Jan 19, 2021

@tanguycdls It could also be interesting to benchmark against TF-nightly, as some sparse-tensor optimizations in the compiler stack are right at the development edge. See https://llvm.discourse.group/t/mlir-support-for-sparse-tensors/

@bhack
Contributor

bhack commented Jan 19, 2021

/cc @aartbik in case he is interested in the sparse code path.

@google-cla

google-cla bot commented Jan 19, 2021

All (the pull request submitter and all commit authors) CLAs are signed, but one or more commits were authored or co-authored by someone other than the pull request submitter.

We need to confirm that all authors are ok with their commits being contributed to this project. Please have them confirm that by leaving a comment that contains only @googlebot I consent. in this pull request.

Note to project maintainer: There may be cases where the author cannot leave a comment, or the comment is not properly detected as consent. In those cases, you can manually confirm consent of the commit author(s), and set the cla label to yes (if enabled on your project).


@google-cla google-cla bot added cla: no and removed cla: yes labels Jan 19, 2021
@Rocketknight1
Contributor Author

Hi @bhack, can you approve running the workflows? It's awkward for me to run all the tests locally, and it really helps if I can check quickly via CI!

bhack
bhack previously approved these changes May 7, 2021
@Rocketknight1 Rocketknight1 dismissed stale reviews from bhack and fsx950223 via aedb074 May 9, 2021 17:56
@Rocketknight1
Contributor Author

I have run into a problem: the code added by @WindQAQ to convert the parameter gradients to an IndexedSlices object does not seem to work in graph mode. I don't think this is a problem with the code (it looks good to me!), but the IndexedSlices gradient is returned as a dense tensor when I wrap _embedding_bag with tf.function().

I googled around and couldn't find a cause for this, but there's a suggestive issue here; it's possible this is because of a different TF bug that got monkey-patched: tensorflow/tensorflow#36236

Either way, I've edited the tests, and hopefully it should all work now.
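
A minimal sketch of the symptom, with illustrative shapes; `embedding_bag` here stands in for the addon's wrapper:

```python
import tensorflow as tf

params = tf.Variable(tf.random.normal([100, 8]))
indices = tf.constant([[0, 1, 2], [3, 4, 5]])

def grad_of_params():
    with tf.GradientTape() as tape:
        out = embedding_bag(indices, params)  # stand-in for the addon op
    return tape.gradient(out, params)

print(type(grad_of_params()))               # eager: tf.IndexedSlices
print(type(tf.function(grad_of_params)()))  # graph mode: comes back dense
```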

```python
sorted_unique_indices = tf.sort(unique_indices)
return [
    None,
    tf.IndexedSlices(
```
@Rocketknight1
Contributor Author
Good idea, I made the change!

@Rocketknight1
Contributor Author

@bhack I believe this PR should have fixed it! Can you approve running the tests?

@bhack
Contributor

bhack commented May 10, 2021

You have a lint error.

@Rocketknight1
Contributor Author

Fixed! I think.

@Rocketknight1
Contributor Author

We did it!

fsx950223
fsx950223 previously approved these changes May 10, 2021
@bhack
Contributor

bhack left a comment

@Rocketknight1
Contributor Author

@bhack Ah, sorry, just saw that. I've added myself!

@Rocketknight1
Contributor Author

@bhack Can we run tests? I think this is ready to merge now!

@Rocketknight1
Contributor Author

@bhack I made the requested change (adding myself to CODEOWNERS) so I think this is ready to merge!

@bhack
Contributor

bhack commented May 27, 2021

@fsx950223 Anything else?

@bhack bhack merged commit 3c662c6 into tensorflow:master May 28, 2021