Move RotaryEmbedding layer from gpt_neo_x to layers #1092
Conversation
/gcbrun
This needs a full tests file, similar to other layers in keras_nlp/layers.
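For concreteness, a tests file along these lines might be a starting point. This is only a rough sketch: the import path assumes the layer lands in keras_nlp/layers/rotary_embedding.py, and the call convention (query and key in, rotated query and key out) and shapes are taken from the docstring hunks quoted later in this review, not from a finalized API.

```python
import tensorflow as tf

from keras_nlp.layers.rotary_embedding import RotaryEmbedding


class RotaryEmbeddingTest(tf.test.TestCase):
    def test_shapes_are_preserved(self):
        # Rotary encoding only rotates features, so output shapes should
        # match input shapes.
        layer = RotaryEmbedding(rotary_percentage=0.25)
        query = tf.random.uniform((2, 8, 16, 64))
        key = tf.random.uniform((2, 8, 16, 64))
        rotated_query, rotated_key = layer(query, key)
        self.assertEqual(rotated_query.shape, query.shape)
        self.assertEqual(rotated_key.shape, key.shape)

    def test_config_round_trip(self):
        # Assumes the layer implements `get_config()` for its constructor
        # arguments.
        layer = RotaryEmbedding(rotary_percentage=0.5, max_wavelength=1000)
        restored = RotaryEmbedding.from_config(layer.get_config())
        self.assertEqual(restored.get_config(), layer.get_config())
```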
keras_nlp/layers/rotary_embedding.py
Outdated
def __init__(self, rotary_percentage, max_wavelength=10000):
    """Rotary positional encoding layer.

    Tbjs layer encodes absolute positional information with rotation matrix and naturally
tbjs -> This. The alignment of this whole docstring looks wrong, just four spaces in should do it.
Also, please make sure to check everything besides links is <= 80 characters.
Is there any automated way to work this out for line length? I make sure to run ./shell/format.sh every time.
keras_nlp/layers/rotary_embedding.py
Outdated
](https://arxiv.org/abs/2104.09864v4).

Takes as input the query and key tensors. The input must have shape
[batch_size, num_heads, sequence_length, query_length]. This layer will return
We might want to consider the general form of this. I don't think we want to require a head axis, that seems a little too special-cased.
We can safely assume that the batch axis is 0, most layers do this. We can also assume the feature axis is -1. If we need to take in the sequence dim axis, maybe let's add that as an argument sequence_axis=1 and allow it to be specified. Then we should test this layer with "multi-head" inputs, and simpler (batch_size, sequence_length, feature_dim) inputs.
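As a sketch of how that proposal might read in user code (the `sequence_axis` argument and the single-tensor call are only the suggestion above, not an existing API):

```python
import tensorflow as tf

from keras_nlp.layers.rotary_embedding import RotaryEmbedding

# Hypothetical `sequence_axis` argument from the proposal above.
layer = RotaryEmbedding(rotary_percentage=1.0, sequence_axis=1)

# Simple (batch_size, sequence_length, feature_dim) input; the batch axis
# is assumed to be 0 and the feature axis -1.
simple_outputs = layer(tf.ones((2, 16, 64)))

# "Multi-head" (batch_size, sequence_length, num_heads, head_dim) input;
# the sequence axis is still 1, so the same layer instance works.
multi_head_outputs = layer(tf.ones((2, 16, 8, 64)))
```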
keras_nlp/layers/rotary_embedding.py
Outdated
incorporates explicit relative position dependency in self-attention formulation.
It layer calculates the rotary encoding with a mix of sine and cosine
functions with geometrically increasing wavelengths. Defined and formulized
in [RoFormer: Enhanced Transformer with Rotary Position Embedding
Keep the link on one line, this can exceed 80 chars.
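For reference, the computation that docstring describes, sinusoids with geometrically increasing wavelengths applied as a rotation of feature pairs, can be sketched roughly as follows. This is a plain restatement of the RoFormer formulation for a (batch_size, sequence_length, feature_dim) input, not the layer's actual implementation.

```python
import tensorflow as tf


def rotary_encode(x, max_wavelength=10000):
    """Rough sketch: rotate (even, odd) feature pairs by position angles."""
    # x: float tensor of shape (batch_size, sequence_length, feature_dim),
    # with an even feature_dim.
    seq_len, dim = tf.shape(x)[1], x.shape[-1]
    positions = tf.range(seq_len, dtype=tf.float32)
    # Geometrically increasing wavelengths across the feature dimension.
    freqs = 1.0 / (
        max_wavelength ** (tf.range(0, dim, 2, dtype=tf.float32) / dim)
    )
    angles = tf.einsum("i,j->ij", positions, freqs)  # (seq_len, dim // 2)
    cos = tf.repeat(tf.cos(angles), 2, axis=-1)  # (seq_len, dim)
    sin = tf.repeat(tf.sin(angles), 2, axis=-1)
    # Build [-x1, x0, -x3, x2, ...] so each (even, odd) pair is rotated.
    even, odd = x[..., ::2], x[..., 1::2]
    rotated = tf.reshape(tf.stack([-odd, even], axis=-1), tf.shape(x))
    return x * cos + rotated * sin
```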
keras_nlp/layers/rotary_embedding.py
Outdated
num_heads = 8
sequence_length = 256
query_length = key_length = 256
query = tf.ones((batch_size, num_heads, sequence_length, query_length))
This does not match our general ordering for dims, I think? After projecting to multi-headed space, I believe our shapes will look like (batch_size, sequence_length, num_heads, head_dim). It's important to follow KerasNLP conventions here, not the ones we picked up from gpt-neox.
Also, query_length is a bit of an odd term here, should that be head_dim? Or if this is the token length of the query, how is this different from sequence_length?
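To make the convention concrete, this is roughly how the per-head projection lays out its output; the dims here are made up for illustration.

```python
import tensorflow as tf
from tensorflow import keras

batch_size, sequence_length, hidden_dim = 2, 16, 512
num_heads, head_dim = 8, 64

hidden_states = tf.ones((batch_size, sequence_length, hidden_dim))

# Per-head projection in the style of `keras.layers.MultiHeadAttention`:
# the head axis comes after the sequence axis.
query_dense = keras.layers.EinsumDense(
    "abc,cde->abde", output_shape=(None, num_heads, head_dim)
)
query = query_dense(hidden_states)
print(query.shape)  # (2, 16, 8, 64) == (batch, sequence, heads, head_dim)
```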
@@ -14,7 +14,7 @@
import tensorflow as tf
from tensorflow import keras

from keras_nlp.models.gpt_neo_x.rotary_embedding import RotaryEmbedding
one more nit, let's pass the arguments to RotaryEmbedding via keyword args, not positionally, below.
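Spelled out, with placeholder argument values and the new import path from this PR assumed:

```python
from keras_nlp.layers.rotary_embedding import RotaryEmbedding

# Positional arguments are easy to misread:
# rotary_embedding = RotaryEmbedding(0.25, 10000)

# Keyword arguments instead (values here are placeholders):
rotary_embedding = RotaryEmbedding(rotary_percentage=0.25, max_wavelength=10000)
```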
968c71d to f68c256
* fix rotary emb
* refactor + remove unnecessary typecast
* fix formatting
* refactor
* formatting fix
* refactoring rotary emb
* added a kwarg in super().__init__()
Let's merge #1111 first, we will need that anyway.
Closing this for now (due to some reasons), will open a follow-up PR.
Closes #1087