
Move RotaryEmbedding layer from gpt_neo_x to layers #1092

Closed · wants to merge 4 commits

Conversation

@shivance (Collaborator):

Closes #1087

@mattdangerw (Member):

/gcbrun

@shivance mentioned this pull request Jun 27, 2023
@mattdangerw (Member) left a comment:

This needs a full tests file, similar to other layers in keras_nlp/layers.
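For reference, a minimal sketch of what such a tests file could look like, assuming the layer is moved under keras_nlp/layers and keeps the constructor shown in the diff below; the import path, shapes, call convention, and assertions here are illustrative assumptions, not the actual KerasNLP test suite:

```python
# Hypothetical keras_nlp/layers/rotary_embedding_test.py sketch.
# Assumes RotaryEmbedding(rotary_percentage, max_wavelength=10000) as in the
# diff under review, and that calling the layer on (query, key) returns the
# rotated (query, key) pair.
import tensorflow as tf

from keras_nlp.layers.rotary_embedding import RotaryEmbedding


class RotaryEmbeddingTest(tf.test.TestCase):
    def test_output_shapes_are_preserved(self):
        layer = RotaryEmbedding(rotary_percentage=0.25, max_wavelength=10000)
        # (batch_size, sequence_length, num_heads, head_dim)
        query = tf.random.uniform((2, 16, 8, 64))
        key = tf.random.uniform((2, 16, 8, 64))
        query_out, key_out = layer(query, key)
        self.assertEqual(query_out.shape, query.shape)
        self.assertEqual(key_out.shape, key.shape)

    def test_get_config(self):
        # Assumes the layer serializes its constructor arguments in get_config().
        layer = RotaryEmbedding(rotary_percentage=0.25, max_wavelength=10000)
        config = layer.get_config()
        self.assertEqual(config["rotary_percentage"], 0.25)
        self.assertEqual(config["max_wavelength"], 10000)


if __name__ == "__main__":
    tf.test.main()
```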

def __init__(self, rotary_percentage, max_wavelength=10000):
"""Rotary positional encoding layer.

Tbjs layer encodes absolute positional information with rotation matrix and naturally
@mattdangerw (Member):

tbjs -> This. The alignment of this whole docstring looks wrong; just four spaces in should do it.

Also please make sure to check everything besides links is <= 80 characters.

@shivance (Collaborator, Author) replied Jun 27, 2023:

Is there any automated way to work this out for line length?
I make sure to run ./shell/format.sh every time.

](https://arxiv.org/abs/2104.09864v4).

Takes as input the query and key tensors. The input must have shape
[batch_size, num_heads, sequence_length, query_length]. This layer will return
@mattdangerw (Member):

We might want to consider the general form of this. I don't think we want to require a head axis; that seems a little too special-cased.

We can safely assume that the batch axis is 0; most layers do this. We can also assume the feature axis is -1. If we need to take in the sequence dim axis, maybe let's add that as an argument sequence_axis=1 and allow it to be specified. Then we should test this layer with "multi-head" inputs and with simpler (batch_size, sequence_length, feature_dim) inputs (see the sketch below).
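A rough sketch of what that more general form could look like, assuming batch axis 0 and feature axis -1 with a configurable sequence_axis; the class name, argument names, and defaults are illustrative, and the rotary_percentage handling from the original layer is omitted for brevity:

```python
# Illustrative sketch only, not the final keras_nlp API.
import tensorflow as tf
from tensorflow import keras


class RotaryEmbedding(keras.layers.Layer):
    def __init__(self, max_wavelength=10000, sequence_axis=1, **kwargs):
        super().__init__(**kwargs)
        self.max_wavelength = max_wavelength
        self.sequence_axis = sequence_axis

    def call(self, inputs):
        seq_len = tf.shape(inputs)[self.sequence_axis]
        feature_dim = inputs.shape[-1]  # static feature size, assumed even

        # Sine/cosine angles with geometrically increasing wavelengths.
        positions = tf.range(seq_len, dtype=self.compute_dtype)
        freq_range = tf.range(0, feature_dim, 2, dtype=self.compute_dtype)
        inverse_freq = 1.0 / (self.max_wavelength ** (freq_range / feature_dim))
        angles = tf.einsum("i,j->ij", positions, inverse_freq)
        angles = tf.concat([angles, angles], axis=-1)

        # Broadcast cos/sin across any extra axes (e.g. a heads axis), so the
        # same layer handles both (batch, seq, feature) and
        # (batch, seq, heads, head_dim) inputs.
        broadcast_shape = [1] * len(inputs.shape)
        broadcast_shape[self.sequence_axis] = seq_len
        broadcast_shape[-1] = feature_dim
        cos = tf.reshape(tf.cos(angles), broadcast_shape)
        sin = tf.reshape(tf.sin(angles), broadcast_shape)

        # Standard "rotate half" formulation of rotary embeddings.
        half_1, half_2 = tf.split(inputs, 2, axis=-1)
        rotated = tf.concat([-half_2, half_1], axis=-1)
        return inputs * cos + rotated * sin
```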

incorporates explicit relative position dependency in self-attention formulation.
It layer calculates the rotary encoding with a mix of sine and cosine
functions with geometrically increasing wavelengths. Defined and formulized
in [RoFormer: Enhanced Transformer with Rotary Position Embedding
@mattdangerw (Member):

Keep the link on one line, this can exceed 80 chars.

num_heads = 8
sequence_length = 256
query_length = key_length = 256
query = tf.ones((batch_size, num_heads, sequence_length, query_length))
@mattdangerw (Member) commented Jun 27, 2023:

This does not match our general ordering for dims, I think? After projecting to multi-headed space, I believe our shapes will look like (batch_size, sequence_length, num_heads, head_dim). It is important to follow KerasNLP conventions here, not the ones we picked up from gpt-neox.

Also, query_length is a bit of an odd term here; should that be head_dim? Or, if this is the token length of the query, how is it different from sequence_length?
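For comparison, a docstring example following that ordering might look like this (values are illustrative, and the single-tensor call reuses the layer sketched earlier in this thread rather than the actual API):

```python
import tensorflow as tf

# Shapes follow (batch_size, sequence_length, num_heads, head_dim).
batch_size = 16
sequence_length = 256
num_heads = 8
head_dim = 64

rotary_embedding = RotaryEmbedding(max_wavelength=10000)  # sketched layer above
query = tf.ones((batch_size, sequence_length, num_heads, head_dim))
key = tf.ones((batch_size, sequence_length, num_heads, head_dim))
query_rotated = rotary_embedding(query)
key_rotated = rotary_embedding(key)
```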

@@ -14,7 +14,7 @@
import tensorflow as tf
from tensorflow import keras

from keras_nlp.models.gpt_neo_x.rotary_embedding import RotaryEmbedding
@mattdangerw (Member):

One more nit: let's pass the arguments to RotaryEmbedding via keyword args, not positionally, below.
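That is, something along these lines (the argument values are illustrative; the parameter names come from the __init__ signature quoted above):

```python
# Positional arguments (flagged above as the current style):
rotary_embedding = RotaryEmbedding(0.25, 10000)

# Keyword arguments (preferred, and clearer at the call site):
rotary_embedding = RotaryEmbedding(rotary_percentage=0.25, max_wavelength=10000)
```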

@shivance closed this Jul 6, 2023
@shivance force-pushed the master branch 2 times, most recently from 968c71d to f68c256 on July 6, 2023 at 17:44
shivance and others added 4 commits July 6, 2023 23:15
* fix rotary emb

* refactor + remove unnecessary typecast

* fix formatting

* refactor

* formatting fix

* refactoring rotary emb

* added a kwarg in super().__init__()
@shivance reopened this Jul 6, 2023
@mattdangerw (Member):

Let's merge #1111 first; we will need that anyway.

@shivance (Collaborator, Author):

Closing this for now (for a few reasons); will open a follow-up PR.

@shivance closed this Jul 13, 2023
Successfully merging this pull request may close these issues: Move RotaryEmbedding to layers
2 participants