
Add batched implementation of mauve #1

Merged · 1 commit merged into main on Feb 4, 2022
Conversation

@wade3han (Contributor) commented Dec 5, 2021

Here is a simplified version of a batched MAUVE implementation; there are still a few places in the code left to clean up.
If you think this implementation is correct, it would be nice to replace the original MAUVE implementation with the batched version.
I tested it with the sample data and it gives a MAUVE score of 0.991679398536574.
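For context, the core of a batched featurization looks roughly like the following. This is only a minimal sketch with assumed names (featurize_batched, pad_id), not the PR's actual fast_featurize_tokens_from_model; since GPT-2 has no pad token, padded positions are simply masked out:

import torch

def featurize_batched(model, tokenized_texts, batch_size=8, pad_id=0):
    """Sketch: run the model over padded batches and keep the last layer's
    hidden state at each sequence's final real (non-pad) token."""
    feats = []
    model.eval()
    with torch.no_grad():
        for i in range(0, len(tokenized_texts), batch_size):
            chunk = tokenized_texts[i:i + batch_size]
            lengths = [len(t) for t in chunk]
            max_len = max(lengths)
            # Right-pad every sequence in the chunk to a common length.
            ids = torch.full((len(chunk), max_len), pad_id, dtype=torch.long)
            mask = torch.zeros((len(chunk), max_len), dtype=torch.long)
            for j, t in enumerate(chunk):
                ids[j, :len(t)] = torch.as_tensor(t)
                mask[j, :len(t)] = 1
            out = model(input_ids=ids, attention_mask=mask,
                        output_hidden_states=True)
            h = out.hidden_states[-1]  # (batch, max_len, d)
            for j, n in enumerate(lengths):
                feats.append(h[j, n - 1])  # feature at the last real token
    return torch.stack(feats)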

@krishnap25 (Owner)

Hi @wade3han, thanks for the fantastic work! The implementation looks correct to me from a quick glance.

Would it be possible for you to make a couple of changes before we merge it in?

  1. Refactoring: Could you add batch_size as an argument to the existing compute_mauve function, with a default value of 1? Then propagate it into get_features_from_input so that you call your new fast_featurize_tokens_from_model there if batch_size > 1, and otherwise use the old featurize_tokens_from_model (see the sketch below).
  2. Tests: It would be great to see a test that the features obtained are correct. In particular, you could return p_features and q_features in the output of compute_mauve, and print the difference between the features from the original implementation and your batched implementation. Once the features are the same (or extremely close), we can be sure that the final MAUVE score will be the same.
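A minimal sketch of the requested dispatch, with abbreviated argument lists that do not match the repo's exact signatures:

def get_features_from_input(model, tokenized_texts, batch_size=1):
    # Default batch_size=1 preserves the existing (unbatched) behavior.
    if batch_size > 1:
        # New batched path from this PR.
        return fast_featurize_tokens_from_model(model, tokenized_texts,
                                                batch_size)
    # Old single-example path.
    return featurize_tokens_from_model(model, tokenized_texts)

def compute_mauve(p_tokens, q_tokens, model, batch_size=1):
    p_features = get_features_from_input(model, p_tokens, batch_size=batch_size)
    q_features = get_features_from_input(model, q_tokens, batch_size=batch_size)
    # ... the rest of the MAUVE computation on p_features / q_features ...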

Thank you so much for your contribution -- we really appreciate it!

@wade3han (Contributor, Author) commented Dec 6, 2021

❯❯❯ pytest tests
=========================== test session starts ===========================
platform linux -- Python 3.8.12, pytest-6.2.5, py-1.11.0, pluggy-1.0.0 -- /home/wade/.miniconda3/envs/mauve/bin/python
cachedir: .pytest_cache
rootdir: /home/wade/hyper/mauve
collected 7 items

tests/test_mauve.py::TestMauve::test_default_mauve PASSED        [ 14%]
tests/test_mauve.py::TestMauve::test_batchify_mauve[1] PASSED    [ 28%]
tests/test_mauve.py::TestMauve::test_batchify_mauve[2] PASSED    [ 42%]
tests/test_mauve.py::TestMauve::test_batchify_mauve[3] PASSED    [ 57%]
tests/test_mauve.py::TestMauve::test_batchify_mauve[4] PASSED    [ 71%]
tests/test_mauve.py::TestMauve::test_batchify_mauve[8] PASSED    [ 85%]
tests/test_mauve.py::TestMauve::test_batchify_mauve[16] PASSED   [100%]

======================= 7 passed in 60.87s (0:01:00) =======================

I've addressed your comments! It was an honor to contribute to this wonderful work 😄
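For reference, a parametrized test matching the output above might look roughly like this (a sketch: p_text and q_text are assumed fixtures loading the sample data, and batch_size is the new argument from this PR):

import pytest
import mauve

@pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 8, 16])
def test_batchify_mauve(batch_size, p_text, q_text):
    # Reference score from the original (unbatched) code path.
    out_ref = mauve.compute_mauve(p_text=p_text, q_text=q_text)
    # Score from the new batched code path.
    out_batched = mauve.compute_mauve(p_text=p_text, q_text=q_text,
                                      batch_size=batch_size)
    # The batched score should closely match the unbatched reference.
    assert abs(out_ref.mauve - out_batched.mauve) < 1e-4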

@krishnap25 (Owner)

Hi @wade3han,

It would be nice to see that the features from your batched implementation are close to the features from the original implementation. If the features are close, the actual value of MAUVE will of course be close.

Here is a concrete example of how to do that. Could you cover a case like this in your test?

import numpy as np

p_features_original = ...  # original code without batching, shape = (n, d)
p_features_batched = ...   # your new code with batching, shape = (n, d)
norm_of_difference = np.linalg.norm(p_features_original - p_features_batched, axis=1)  # shape = (n,)
# Ensure that the new features are close to the old features.
assert norm_of_difference.max() < 1e-5 * np.linalg.norm(p_features_original, axis=1).max()
# Repeat the same for q_features.

Thanks!

@wade3han (Contributor, Author) commented Dec 14, 2021

I misunderstood what you were asking for at first!

I tried to check whether the features are close. However, I found a slight difference between the features, and I am now trying to find the reason.

So far, this is what I have figured out: the GPT output differs when the batch size changes. I also reported this phenomenon on HuggingFace (huggingface/transformers#14743).

There was also a similar issue reported before, but it remains unresolved (huggingface/transformers#2401).

Any ideas? Thanks 🙏

@krishnap25 (Owner)

Hi @wade3han, thanks for the great detective work!

This might just be a floating point error, which can be quite large for float32. There are a couple of ways to figure this out.

  • What is the relative error? The HuggingFace issue only looks at the absolute error (feat_batch - feat_nobatch), but the relative error (feat_batch - feat_nobatch) / feat_nobatch is more important. Generally, a floating point error of around 1e-5 or 1e-6 is reasonable for float32, but 1e-4 is too large.
  • Please rerun the above experiment with the model and data cast to float64. The extra numerical precision should reduce the numerical error to a very small number (< 1e-9 or so); see the sketch after this list.
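For example, with a HuggingFace GPT-2 model, both checks might look like the following sketch, where outs / outs2 denote forward passes on a sequence alone versus inside a batch (the model name and input text are placeholders):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tok = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").double().eval()  # cast to float64

ids = tok("MAUVE compares model text to human text.",
          return_tensors="pt").input_ids

with torch.no_grad():
    outs = model(ids, output_hidden_states=True)                 # sequence alone
    outs2 = model(ids.repeat(2, 1), output_hidden_states=True)   # inside a batch

feat_nobatch = outs.hidden_states[-1][0]
feat_batch = outs2.hidden_states[-1][0]

abs_err = (feat_batch - feat_nobatch).abs().max()
# Guard the denominator so exact zeros don't produce inf/nan.
rel_err = ((feat_batch - feat_nobatch)
           / feat_nobatch.abs().clamp_min(1e-12)).abs().max()
print(abs_err.item(), rel_err.item())  # both should be tiny in float64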

Based on these two, we can decide whether the difference you see is purely due to numerical errors or if there is some other underlying bug. I am happy to accept the pull request if the problem is purely due to numerical errors.

Thanks again for your fantastic work!

@wade3han (Contributor, Author) commented Feb 3, 2022

Sorry for the late response!

  • As mentioned in a reply to the HuggingFace issue, the error becomes zero when it is run on CPU.
  • Moreover, as you suggested, I reran the above experiment using float64, and both the absolute error and the relative error go to around zero:
>>> print((outs.hidden_states[-1][0] - outs2.hidden_states[-1][0]) / outs2.hidden_states[-1][0]) 

tensor([[ 9.6949e-16,  5.6669e-15, -5.3828e-15,  ...,  0.0000e+00,
          2.7774e-13,  4.4508e-15],
        [-3.0176e-15, -1.8830e-15, -3.8597e-15,  ...,  8.1987e-15,
         -5.8965e-15, -8.6832e-16],
        [ 1.5195e-15,  1.8600e-15,  4.8238e-15,  ..., -6.2110e-16,
          2.1444e-14, -1.2429e-14],
        [-0.0000e+00,  1.2500e-16,  1.5938e-15,  ...,  1.2114e-15,
          2.0344e-14,  7.2175e-15],
        [-1.4844e-15, -3.8043e-16,  7.0142e-15,  ...,  1.9496e-16,
          2.7474e-14,  1.0162e-13]], device='cuda:0', dtype=torch.float64,
       grad_fn=<DivBackward0>)
       
>>> print(outs.hidden_states[-1][0] - outs2.hidden_states[-1][0])

tensor([[-1.1102e-15, -3.5527e-15, -2.1094e-15,  ...,  0.0000e+00,
          3.9569e-15, -7.7716e-15],
        [ 2.3315e-15,  3.6082e-16,  1.3878e-15,  ...,  1.6098e-15,
         -3.3307e-16, -5.5511e-16],
        [-7.7716e-16,  8.3267e-16, -2.3315e-15,  ...,  6.6613e-16,
          9.0900e-16,  6.3144e-16],
        [ 0.0000e+00,  5.5511e-17, -7.7716e-16,  ..., -1.3323e-15,
          1.1172e-15, -3.6082e-16],
        [ 7.7716e-16, -1.6653e-16, -3.4972e-15,  ..., -2.2204e-16,
          1.7139e-15, -2.0331e-15]], device='cuda:0', dtype=torch.float64,
       grad_fn=<SubBackward0>)
  • I added the test case you mentioned. After switching the model to float64 (I used model.double()), it successfully passed the test!

@krishnap25 (Owner)

Hi @wade3han,

Thank you for the wonderful detective work! Two minor (and final) changes before we merge this in:

  • We want to use a float64 model only for testing the batched implementation; it is not a good idea to always use float64. Could you delete the model.double() call from get_model() and use it only in your test (see the sketch below)?
  • Could you squash all your commits into one?
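A sketch of keeping float64 confined to the test (get_model, the featurizer names, and the p_tokens fixture are assumed):

import torch

def test_batched_features_match(p_tokens):
    model = get_model()     # repo helper; stays float32 by default
    model = model.double()  # cast to float64 inside this test only
    feats_ref = featurize_tokens_from_model(model, p_tokens)
    feats_batched = fast_featurize_tokens_from_model(model, p_tokens,
                                                     batch_size=4)
    # In float64, batched and unbatched features should agree very closely.
    assert torch.allclose(feats_ref, feats_batched, atol=1e-9)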

Thanks again -- we really appreciate your work!
Krishna

@wade3han (Contributor, Author) commented Feb 4, 2022

I've addressed your feedback :) Thanks!

@krishnap25 krishnap25 merged commit 20613ee into krishnap25:main Feb 4, 2022
@krishnap25 (Owner)

Merged! Thank you for the contribution!
