
Refactored convergence tests to be portable #41

Merged 7 commits into main on Aug 19, 2024

Conversation

@shimizust (Collaborator) commented Aug 17, 2024

Summary

  • Make the convergence tests more portable and easier to run by using pre-tokenized data. This removes internal paths and means users no longer have to download specific model tokenizers to a particular location (see the sketch below).
  • Since we're just testing convergence on a mini model with random weights, the specific tokenizer doesn't really matter.
  • Convergence tests also finish faster: ~95 sec -> ~60 sec.
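For illustration, a minimal sketch of how a test could consume the pre-tokenized data, assuming the Hugging Face datasets library; the path and column names here are hypothetical, not necessarily what this PR uses:

from datasets import load_from_disk

def get_batches(path="test/resources/tokenized_dataset", batch_size=4):
    # Hypothetical path; the dataset is tokenized and saved to disk ahead of
    # time, so no tokenizer download or network access is needed at test time.
    ds = load_from_disk(path)
    ds.set_format(type="torch", columns=["input_ids", "attention_mask"])
    for i in range(0, len(ds), batch_size):
        # Assumes sequences were padded to a fixed length when tokenized,
        # so slicing yields stackable tensors.
        batch = ds[i : i + batch_size]
        # Standard causal-LM setup: labels mirror the input ids.
        yield {**batch, "labels": batch["input_ids"].clone()}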

Alternatives considered:

  • Let users configure paths to the different models used in the convergence tests, or an HF token to download the tokenizers (inconvenient to have to configure/download things just to run tests, not portable, and different tokenizer versions could break tests).
  • Check the tokenizers into the repo (licensing issues).
  • Check a small, fully open-source tokenizer into the repo and use it across all tests (workable, but pre-tokenizing the data is more performant anyway).

Testing Done

Ran convergence tests successfully

  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence
jobuser [ ~/Liger-Kernel ]$ make checkstyle
flake8 .; flake8_status=$?; \
isort .; isort_status=$?; \
black .; black_status=$?; \
if [ $flake8_status -ne 0 ] || [ $isort_status -ne 0 ] || [ $black_status -ne 0 ]; then \
        exit 1; \
fi
Fixing /home/jobuser/Liger-Kernel/test/convergence/test_mini_models.py
Skipped 1 files
All done! ✨ 🍰 ✨
45 files left unchanged.
jobuser [ ~/Liger-Kernel ]$ make test
pytest --disable-warnings test/ --ignore=test/convergence
========================= test session starts =========================
platform linux -- Python 3.10.14, pytest-7.1.2, pluggy-1.0.0
rootdir: /home/jobuser/Liger-Kernel
plugins: lipy-config-base-30.6.1, lipy-fabric-35.2.3, lipy-test-8.0.52, datadir-1.3.1, lipy-mp-34.4.191
collected 111 items

test/transformers/test_cross_entropy.py .......................................................... [ 52%]
test/transformers/test_fused_linear_cross_entropy.py ...... [ 57%]
test/transformers/test_geglu.py ........ [ 64%]
test/transformers/test_rms_norm.py ................ [ 79%]
test/transformers/test_rope.py ............ [ 90%]
test/transformers/test_swiglu.py ........ [ 97%]
test/transformers/test_transformers_monkey_patch.py . [ 98%]
test/triton/test_triton_monkey_patch.py .. [100%]

========================= 111 passed in 60.81s (0:01:00) =========================
jobuser [ ~/Liger-Kernel ]$ make test-convergence
HF_DATASETS_OFFLINE=1 pytest --disable-warnings test/convergence
========================= test session starts =========================
platform linux -- Python 3.10.14, pytest-7.1.2, pluggy-1.0.0
rootdir: /home/jobuser/Liger-Kernel
plugins: lipy-config-base-30.6.1, lipy-fabric-35.2.3, lipy-test-8.0.52, datadir-1.3.1, lipy-mp-34.4.191
collected 8 items

test/convergence/test_mini_models.py ...... [ 75%]
test/convergence/test_mini_models_no_logits.py .. [100%]

========================= 8 passed in 58.41s =========================

@lancerts (Collaborator) commented Aug 17, 2024

Great work!! Can we paste the testing screenshot in the PR, as in #21? Thanks

Review thread on test/convergence/test_mini_models.py (resolved):
@@ -210,7 +172,7 @@ def run_mini_model(
      ("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 1e-4, 1e-5, 2e-3, 1e-5),
      ("mini_llama3", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5),
      # TODO: torch 2.5.0 nightly breaks mixtral test, but torch 2.3.0 works fine
-     ("mini_mixtral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 1e-3, 1e-5, 8e-3, 1e-5),
+     ("mini_mixtral", 32, 1e-4, torch.float32, 1e-8, 1e-4, 1e-3, 3e-2, 8e-3, 1e-5),
lancerts (Collaborator): Why do we need to relax the bound? Did the test fail? 1e-5 -> 3e-2 seems like too much.

shimizust (Collaborator, Author): Yes, the test failed with the previous tolerances. I'm not sure how to account for this; we should probably investigate further how the dataset and other parameters affect the expected tolerances. Thoughts @ByronHsu?

lancerts (Collaborator): Can we try ("mini_mixtral", 32, 1e-4, torch.float32, 1e-8, 1e-4, 2e-3, 1e-5, 8e-3, 1e-5)?

shimizust (Collaborator, Author): @lancerts Loss had a few errors:

>           raise AssertionError("\n".join(mismatch_details))
E           AssertionError: Number of mismatched elements: 4
E           Mismatch at index (0, 23): tensor1[(0, 23)] = 0.46933501958847046, tensor2[(0, 23)] = 0.4692351222038269
E           Mismatch at index (0, 24): tensor1[(0, 24)] = 0.4860617518424988, tensor2[(0, 24)] = 0.48613235354423523
E           Mismatch at index (0, 25): tensor1[(0, 25)] = 0.43753352761268616, tensor2[(0, 25)] = 0.4377014636993408
E           Mismatch at index (0, 26): tensor1[(0, 26)] = 0.36302775144577026, tensor2[(0, 26)] = 0.3631027042865753

test/utils.py:83: AssertionError
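For context, the report above is produced by a verbose allclose helper in test/utils.py; a minimal sketch of what such a helper can look like, built on torch.isclose (the repo's actual implementation may differ):

import torch

def assert_verbose_allclose(tensor1, tensor2, atol=1e-8, rtol=1e-5):
    # Elementwise check with the same rule torch.isclose uses:
    # |a - b| <= atol + rtol * |b|
    mismatch = ~torch.isclose(tensor1, tensor2, atol=atol, rtol=rtol)
    if mismatch.any():
        details = [f"Number of mismatched elements: {mismatch.sum().item()}"]
        for idx in torch.nonzero(mismatch)[:5]:  # report the first few only
            idx = tuple(idx.tolist())
            details.append(
                f"Mismatch at index {idx}: tensor1[{idx}] = "
                f"{tensor1[idx].item()}, tensor2[{idx}] = {tensor2[idx].item()}"
            )
        raise AssertionError("\n".join(details))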

shimizust (Collaborator, Author): This works: ("mini_mixtral", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 8e-3, 1e-5)

lancerts (Collaborator): Cool, let's use ("mini_mixtral", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 8e-3, 1e-5).

shimizust (Collaborator, Author): I wonder if the test tolerances should be refactored to use a single value instead of two degrees of freedom, or alternatively keep the absolute tolerance fixed and have each test define only the relative tolerance.
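For reference, the two degrees of freedom combine the way torch.allclose combines them: elements pass when |a - b| <= atol + rtol * |b|. The fixed-atol variant floated above could be a thin wrapper like this sketch (FIXED_ATOL is a hypothetical name):

import torch

FIXED_ATOL = 1e-8  # hypothetical: one absolute tolerance shared by all tests

def check_close(tensor1, tensor2, rtol):
    # With atol pinned, each test only chooses its relative tolerance.
    assert torch.allclose(tensor1, tensor2, atol=FIXED_ATOL, rtol=rtol)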

@@ -145,7 +105,7 @@ def run_mini_model(
  @pytest.mark.parametrize(
      "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol",
      [
-         ("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 1e-4, 1e-5, 5e-3, 1e-5),
+         ("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 2e-5, 1e-4, 1e-5, 5e-3, 1e-5),
lancerts (Collaborator): Same comment as above.

shimizust (Collaborator, Author): Yes, the test failed with the previous tolerance.

@ByronHsu (Collaborator) commented:

This looks awesome!! Can we also include the code for generating the tokenized dataset? Name it generate_tokenized_dataset.py.

@ByronHsu (Collaborator) commented:

Let's ensure this is in before we go public!

@ByronHsu added the p0 label on Aug 19, 2024
@shimizust (Collaborator, Author) commented:

> This looks awesome!! Can we also include the code for generating the tokenized dataset? Name it generate_tokenized_dataset.py.

Thanks, added the generation script.
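For context, a generation script along the requested lines might look like the sketch below; the tokenizer, source data, and output path are illustrative assumptions, not necessarily what was committed:

# generate_tokenized_dataset.py -- illustrative sketch; the committed script
# may use a different tokenizer, source data, and output path.
from datasets import load_dataset
from transformers import AutoTokenizer

def main():
    # Any small open tokenizer works here, since the mini models train from
    # random weights and the exact vocabulary doesn't matter.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

    dataset = load_dataset("text", data_files={"train": "data.txt"})["train"]

    def tokenize(examples):
        # Pad/truncate to a fixed length so test batches stack cleanly.
        return tokenizer(
            examples["text"], truncation=True, padding="max_length", max_length=512
        )

    tokenized = dataset.map(tokenize, batched=True, remove_columns=["text"])
    # Save to disk so the tests can load it offline via load_from_disk().
    tokenized.save_to_disk("test/resources/tokenized_dataset")

if __name__ == "__main__":
    main()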

@lancerts merged commit 8ce3b53 into main on Aug 19, 2024 (1 check passed).
@ByronHsu deleted the sshimizu/test-tokenizer-refactor branch on Aug 23, 2024.