-
Notifications
You must be signed in to change notification settings - Fork 27.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
FlaxGPTNeo #12493
FlaxGPTNeo #12493
Conversation
|
||
self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool") | ||
if self.attention_type == "local": | ||
self.causal_mask = self.causal_mask ^ jnp.tril(self.causal_mask, -config.window_size) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe an additional comment here would be nice
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome - very clean!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is very nice, thanks for working on it @patil-suraj!
prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0 | ||
prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1 | ||
pt_model = pt_model_class(config).eval() | ||
fx_model = model_class(config, dtype=jnp.float32) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is fx_model
a common name for Flax models? It reminds of torch.fx
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Aah, yeah this is confusing. Maybe we could use flx
or just flax
for flax models. (cc @patrickvonplaten )
What does this PR do?
This PR adds the Flax version of GPTNeo. For local attention, it uses the fix proposed by @finetuneanon in #11630.
Thanks a lot, @finetuneanon for proposing the solution, it's especially important in JAX/Flax where we can't have dynamic shapes.
Official GPTNeo flax checkpoints are up on the hub and slow tests are passing.