
Add NTK-Aware interpolation "by parts" correction #1

Merged
2 commits, merged Jul 9, 2023

Conversation

bloc97 (Collaborator) commented Jul 7, 2023

This PR adds the new and improved "by parts" correction to the NTK-aware interpolation method.

This corrected method improves on the previous method in four ways:

  1. Decreases PPL at all context lengths on non-finetuned models compared to the previous NTK-Aware method, especially at higher context sizes, since the equivalent alpha value can be set much lower for the same context size.
  2. Removes the alpha parameter, which did not accurately predict the effective context length and varied across models. The method now uses the same scale parameter as linear interpolation, which is much more intuitive and less prone to mistakes/misuse. (This was made possible by fixing the alpha scale "drift" found in all LLaMA models.)
  3. Fixes the extrapolation regime that was breaking many fine-tunes when alpha was set to a non-optimal value. Fine-tuning should be much easier, and performance should in theory improve significantly, since there is no longer a need to search for an optimal alpha.
  4. This method generalizes extrapolation, NTK-Aware interpolation, and linear interpolation. For example, setting ntk_factor and extrapolation_factor to 0 yields results identical to linear interpolation (a minimal sketch illustrating this follows the list).
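
For reference, here is a minimal PyTorch sketch of the per-dimension "by parts" blending idea. The names (`find_correction_dim`, `linear_ramp`, `by_parts_inv_freq`) and the rotation-count cutoffs are illustrative assumptions rather than this PR's exact code, and the ntk_factor blending branch is omitted for brevity: high-frequency RoPE dimensions (short wavelengths that rotate many times within the original context) are extrapolated as-is, low-frequency dimensions are linearly interpolated, and a linear ramp blends the two regimes.

```python
import math
import torch

def find_correction_dim(num_rotations, dim, base=10000.0,
                        original_max_position_embeddings=2048):
    # Invert the RoPE frequency formula theta_i = base^(-2i/dim): return the
    # (fractional) pair index whose frequency completes `num_rotations` full
    # periods over the original pretrained context length.
    return (dim * math.log(original_max_position_embeddings /
                           (num_rotations * 2 * math.pi))) / (2 * math.log(base))

def linear_ramp(low, high, n):
    # Mask over pair indices: 0 below `low`, 1 above `high`, linear in between.
    if low == high:
        high += 1e-3  # avoid division by zero
    ramp = (torch.arange(n, dtype=torch.float32) - low) / (high - low)
    return torch.clamp(ramp, 0.0, 1.0)

def by_parts_inv_freq(dim, base=10000.0, scale=1.0, extrapolation_factor=1.0,
                      original_max_position_embeddings=2048):
    # Pure extrapolation (plain RoPE) and pure linear interpolation frequencies.
    inv_freq_extrapolation = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    inv_freq_interpolation = inv_freq_extrapolation / scale

    # Dimensions that rotate many times within the original context are safe
    # to extrapolate; dimensions completing less than roughly one rotation
    # must be interpolated. The cutoffs (32 and 1 rotations) are illustrative.
    low = max(math.floor(find_correction_dim(
        32, dim, base, original_max_position_embeddings)), 0)
    high = min(math.ceil(find_correction_dim(
        1, dim, base, original_max_position_embeddings)), dim // 2 - 1)

    # Mask is 1 on high-frequency dimensions (extrapolate), 0 on low-frequency
    # dimensions (interpolate), with a linear ramp in between.
    extrapolation_mask = (1 - linear_ramp(low, high, dim // 2)) * extrapolation_factor
    return (inv_freq_interpolation * (1 - extrapolation_mask)
            + inv_freq_extrapolation * extrapolation_mask)
```

Note that with extrapolation_factor=0 the mask vanishes and the result collapses to plain linear interpolation, which is the equivalence claimed in point 4 above.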

The scale parameter should be used the same way as in linear interpolation (e.g. scale=2 extends a 2048 base context to 4096).
extrapolation_factor and ntk_factor are used for validation purposes and should not be changed unless necessary.
Edit: max_position_embeddings was previously assumed to be the original pretrained model context size (it had to be left at 2048 for LLaMA models, and changing it would break the code). This has been fixed by adding an original_max_position_embeddings parameter to avoid any confusion.
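
A usage sketch with the illustrative by_parts_inv_freq above (again, assumed names, not necessarily this PR's exact API):

```python
import torch

# scale behaves like linear interpolation's scale: scale=2 extends a 2048
# base context to 4096. original_max_position_embeddings stays at the
# pretrained size (2048 for LLaMA).
inv_freq = by_parts_inv_freq(
    dim=128,                                # per-head dimension of LLaMA-7B
    scale=2.0,                              # 2048 base ctx extended to 4096
    original_max_position_embeddings=2048,  # leave at the pretrained size
)

# Sanity check of the generalization claim: extrapolation_factor=0 recovers
# plain linear interpolation (all RoPE frequencies divided by `scale`).
plain_rope = 1.0 / (10000.0 ** (torch.arange(0, 128, 2).float() / 128))
linear_only = by_parts_inv_freq(dim=128, scale=2.0, extrapolation_factor=0.0)
assert torch.allclose(linear_only, plain_rope / 2.0)
```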

[Comparison graph: the new corrected NTK-Aware method versus the previous non-corrected NTK-Aware method. Note that the new scale factor is still labeled alpha in this graph.]

Now all that is left is to validate this by fine-tuning!
