
Generic RotaryEmbedding Layer #1180

Merged (11 commits, Aug 1, 2023)

Conversation

shivance
Collaborator

RotaryEmbedding is in the air: new SOTA models like Falcon, GPT-J, and LLaMA use it. I already contributed a RotaryEmbedding layer to KerasNLP.

This is the perfect time to move it into the modeling layers and expose it publicly; it would be a good addition to our API.

Closes #1087
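For reference, the core rotary computation from the RoFormer paper can be sketched in plain NumPy. This is a minimal illustration of the math, not the KerasNLP implementation; the helper name `rotate_half` is just the conventional one.

```python
import numpy as np

def rotate_half(x):
    # Split the feature dim in half and rotate pairs: (x1, x2) -> (-x2, x1).
    x1, x2 = np.split(x, 2, axis=-1)
    return np.concatenate([-x2, x1], axis=-1)

def rotary_embedding(x, max_wavelength=10000):
    # x: [batch, seq_len, feature_dim]; feature_dim must be even.
    seq_len, dim = x.shape[1], x.shape[2]
    # Geometrically increasing wavelengths, as in the RoFormer paper.
    inv_freq = 1.0 / max_wavelength ** (np.arange(0, dim, 2) / dim)
    freqs = np.outer(np.arange(seq_len), inv_freq)  # [seq_len, dim // 2]
    emb = np.concatenate([freqs, freqs], axis=-1)   # [seq_len, dim]
    return x * np.cos(emb) + rotate_half(x) * np.sin(emb)
```

Because each feature pair is rotated by a pure rotation, position 0 passes through unchanged and per-token vector norms are preserved.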

Member

@mattdangerw mattdangerw left a comment

Thanks! A few comments.

from keras_nlp.tests.test_case import TestCase


class RotaryEmbeddingTest(TestCase):
Member

We should definitely fill out the test case now that we are exposing this standalone. Maybe look at the SinePositionEncoding layer tests as a starting point.

Collaborator Author

Yup, added tests modeled on the SinePositionEncoding layer tests; they all now pass on all backends.

matrix. It calculates the rotary encoding with a mix of sine and
cosine functions with geometrically increasing wavelengths.
Defined and formulated in [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864v4).
Takes as input the query and key tensors. The input must have shape
Member

I think we should also allow this layer to take inputs of shape [batch_size, sequence_length, feature_dim]. I'll leave some more comments below.

Collaborator Author

sgtm.

Member

This still needs updating in the docstring.

References:
- [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864v4)
"""

def __init__(self, max_wavelength=10000, **kwargs):
Member

We should probably add two arguments here, sequence_axis=1 and feature_axis=-1, that users can set as desired.

Collaborator Author

This would be cool. Added.
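A rough sketch of what the constructor looks like after adding those arguments (plain Python for illustration; the real layer subclasses a Keras layer and accepts **kwargs):

```python
class RotaryEmbedding:
    """Sketch only: mirrors the constructor arguments discussed above."""

    def __init__(self, max_wavelength=10000, sequence_axis=1, feature_axis=-1):
        self.max_wavelength = max_wavelength
        # Axis of the input that indexes positions in the sequence.
        self.sequence_axis = sequence_axis
        # Axis of the input that holds the (even-sized) feature dimension.
        self.feature_axis = feature_axis
```

With the defaults, a standard [batch, seq_len, feature_dim] input needs no extra configuration.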

References:
- [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864v4)
"""

def __init__(self, max_wavelength=10000, **kwargs):
super().__init__(**kwargs)
self.max_wavelength = max_wavelength
Member

A few comments for below (GitHub won't let me comment inline). Why is the following necessary?

cos_emb = cos_emb[:, : ops.shape(tensor)[1], :, :]
sin_emb = sin_emb[:, : ops.shape(tensor)[1], :, :]

The cos/sin embeddings should already have seq_len shape when you compute them.

Lastly, if you wanted to make _compute_cos_sin_embedding work with any number of dimensions, you would need to update it. Here's a draft of a change, but I haven't tested it yet:

embedding = ops.concatenate((freqs, freqs), axis=-1)
for axis in range(len(x.shape)):
    if axis != self.sequence_axis and axis != self.feature_axis:
        embedding = ops.expand_dims(embedding, axis)

Collaborator Author

Thanks for the pointer; got it working with some tweaks!
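The axis-expansion trick from the draft above can be checked in isolation with NumPy. This is a hypothetical standalone helper illustrating the same idea: insert size-1 axes everywhere except the sequence and feature axes, so the 2D embedding broadcasts against an input of any rank.

```python
import numpy as np

def broadcast_embedding(embedding, x_ndim, sequence_axis=1, feature_axis=-1):
    # embedding: [seq_len, dim]. Normalize negative axes against the input
    # rank, then expand every other axis to size 1 for broadcasting.
    sequence_axis = sequence_axis % x_ndim
    feature_axis = feature_axis % x_ndim
    for axis in range(x_ndim):
        if axis != sequence_axis and axis != feature_axis:
            embedding = np.expand_dims(embedding, axis)
    return embedding
```

For a rank-4 input like [batch, seq_len, num_heads, head_dim] this yields a [1, seq_len, 1, dim] embedding, which broadcasts cleanly.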

@shivance shivance changed the title to Move RotaryEmbedding to Modeling Layers on Jul 29, 2023
@shivance shivance changed the title from Move RotaryEmbedding to Modeling Layers to General purpose RotaryEmbedding Layer on Jul 29, 2023
@shivance
Collaborator Author

shivance commented Jul 29, 2023

test_float16_dtype fails for tf.keras 👀

ValueError: Unsupported dtype19 for '{{node rotary_embedding/range}} = Range[Tidx=DT_HALF](rotary_embedding/range/start, rotary_embedding/range/Cast, rotary_embedding/range/delta)' with input shapes: [], [], [] and with computed input tensors: input[0] = <0>, input[1] = <32>, input[2] = <2>.

self.max_wavelength = max_wavelength
self.sequence_axis = sequence_axis
self.feature_axis = feature_axis
self.scaling_factor = scaling_factor
Collaborator Author

@mattdangerw LLaMA's rotary embedding layers use a scaling factor; I've added that too!
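The effect of the scaling factor can be sketched in NumPy. This assumes the factor divides the position indices (LLaMA-style position scaling); it is an illustration, not the merged code.

```python
import numpy as np

def rotary_angles(seq_len, dim, max_wavelength=10000, scaling_factor=1.0):
    # Dividing positions by scaling_factor slows the rotation per step,
    # which is the trick LLaMA variants use to stretch usable context.
    positions = np.arange(seq_len) / scaling_factor
    inv_freq = 1.0 / max_wavelength ** (np.arange(0, dim, 2) / dim)
    return np.outer(positions, inv_freq)  # shape [seq_len, dim // 2]
```

A scaling factor of 2 halves every rotation angle, so a sequence of length 2N occupies the angular range that length N did before.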

@shivance
Collaborator Author

/gcbrun

@shivance shivance requested a review from mattdangerw July 29, 2023 17:06
@shivance shivance changed the title from General purpose RotaryEmbedding Layer to Generic RotaryEmbedding Layer on Jul 29, 2023
@mattdangerw
Member

ValueError: Unsupported dtype19 for '{{node rotary_embedding/range}} = Range[Tidx=DT_HALF](rotary_embedding/range/start, rotary_embedding/range/Cast, rotary_embedding/range/delta)' with input shapes: [], [], [] and with computed input tensors: input[0] = <0>, input[1] = <32>, input[2] = <2>.

Potentially this is just arange not supporting half floats? You could maybe just do:

freq_range = ops.arange(0, rotary_dim, 2, "float32")
freq_range = ops.cast(freq_range, self.compute_dtype)
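The suggested workaround (compute the range in float32, then cast down to the compute dtype) can be mimicked in NumPy. This is illustrative only, since NumPy's arange has no such restriction; the restriction is in some backends' range op.

```python
import numpy as np

rotary_dim = 32
# Build the range in float32 first, since some backend range ops reject
# half-precision dtypes, then cast to the half-precision compute dtype.
freq_range = np.arange(0, rotary_dim, 2, dtype="float32")
freq_range = freq_range.astype("float16")  # stand-in for self.compute_dtype
```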

Member

@mattdangerw mattdangerw left a comment

Thanks! Looking generally good in terms of the call method, but some polish is needed.

matrix. It calculates the rotary encoding with a mix of sine and
cosine functions with geometrically increasing wavelengths.
Defined and formulated in [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864v4).
Takes as input the query and key tensors. The input must have shape
Member

This still needs updating in the docstring.

@shivance
Collaborator Author

shivance commented Aug 1, 2023

Potentially this is just arange not supporting half floats? You could maybe just do

freq_range = ops.arange(0, rotary_dim, 2, "float32")
freq_range = ops.cast(freq_range, self.compute_dtype)

The error persists either way. I've tried this and other typecasting approaches.

@shivance
Collaborator Author

shivance commented Aug 1, 2023

/gcbrun

Fix dtypes with arange.
@shivance
Collaborator Author

shivance commented Aug 1, 2023

/gcbrun

@mattdangerw
Member

/gcbrun

Member

@mattdangerw mattdangerw left a comment

Nice work! This ended up quite clean.

@mattdangerw mattdangerw merged commit 272ba83 into keras-team:master Aug 1, 2023
@shivance shivance deleted the move_rot_emb branch August 5, 2023 14:19
Successfully merging this pull request may close these issues.

Move RotaryEmbedding to layers