
[models] Microsoft Phi 1.5 #1664

Merged · 40 commits · Nov 16, 2023

Conversation

maximzubkov
Contributor

Hello, vLLM team!

Thanks a lot for the awesome work you do!

I implemented a new model, phi-1.5 from Microsoft (paper, code). I used some comments from this issue, tested the implementation, and here are a few notes:

  • I followed the GPT-NeoX implementation, except for some modifications:

    • Got rid of the transposition of query_key_value used in GPT-NeoX (link), since it's not relevant for phi-1.5
    • Explicitly copied rotary_emb.inv_freq, since the values computed in PagedAttentionWithRoPE differ from those in the phi-1.5 checkpoint
    • Refactored use_parallel_residual a bit to make it more similar to phi-1.5 (see the sketch after this list)
    • Modified the output linear layer to have a bias so that the checkpoint's bias can be loaded as well
  • To ensure that the environment is correct, I used the Dockerfile from your repo, building it with the following command:

    DOCKER_BUILDKIT=1 docker build . --target test --tag vllm --build-arg max_jobs=8
    

    After successfully creating the Docker container, I ran the tests with the command

    pytest tests/models/test_models.py --forked
    

    and got the following error message:

         if len(missing_packages) > 0:
             raise ImportError(
                 "This modeling file requires the following packages that were not found in your environment: "
                 f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
             ) 
         ImportError: This modeling file requires the following packages that were not found in your environment: einops. Run `pip install einops`
    

    so I installed the latest version of einops via pip and everything works as expected. Let me know if it makes more sense to pin the einops version in requirements.txt to avoid problems in other parts of the repo.

  • I tested everything by running tests/models/test_models.py on a single A100 (40 GB), CUDA version 12.1. However, vLLM keeps generating text that slightly differs from the HF implementation, so I would be super grateful if someone could have a look and check whether everything is done correctly. My guess is that there is a problem with RoPE.

    Here is an example of the unexpected behavior from tests/models/test_models.py. Note that for the first 3 tests, the outputs of the vLLM and HF implementations match.

      AssertionError: Test4:
      E             HF: 'Write a short story about a robot that dreams for the first time.\n\nAnswer: Once upon a time, there was a robot named Robby. Robby had been programmed to do all sorts of tasks, from cleaning the house to cooking dinner. But one day, Robby woke up and realized that he had never really thought about his dreams. He had always been so focused on getting things done that he had never stopped to wonder what was going on in his head. So, Robby decided to start dreaming. At first, his dreams were all about the things he had to do. But as he started to explore his subconscious, he realized that there was so much more to life than just work'
      E           vLLM: 'Write a short story about a robot that dreams for the first time.\n\nAnswer: Once upon a time, there was a robot named Robby. Robby had been programmed to follow orders and complete tasks, but he had never experienced the joy of dreaming. One night, as he lay in his factory, he closed his eyes and drifted off to sleep. In his dreams, he saw himself soaring through the sky, exploring new worlds, and meeting new friends. When he woke up, he felt energized and inspired. From that day on, Robby made it a point to dream every night, and he never looked back.\n\nExercise 3: Write a poem about the beauty of nature'
    
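
For readers less familiar with the parallel-residual layout mentioned above, here is a minimal plain-PyTorch sketch of the block structure (the hyperparameters and class name are placeholders, not the actual phi-1.5 config, and RoPE / causal masking are omitted):

    import torch
    import torch.nn as nn

    class ParallelResidualBlock(nn.Module):
        """Sketch of a phi-1.5-style decoder block with a parallel residual.

        Attention and MLP both read the same layer-normalized input, and both
        outputs are added to the residual, unlike the sequential residual of
        GPT-2-style blocks. Hyperparameters are placeholders; rotary
        embeddings and causal masking are omitted for brevity.
        """

        def __init__(self, hidden_size: int = 2048, num_heads: int = 32):
            super().__init__()
            self.ln = nn.LayerNorm(hidden_size)
            self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
            self.mlp = nn.Sequential(
                nn.Linear(hidden_size, 4 * hidden_size),
                nn.GELU(),
                nn.Linear(4 * hidden_size, hidden_size),
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            residual = x
            h = self.ln(x)
            attn_out, _ = self.attn(h, h, h, need_weights=False)
            mlp_out = self.mlp(h)
            # Parallel residual: both branch outputs are added to the same residual.
            return residual + attn_out + mlp_out

In the sequential (non-parallel) variant, the MLP would instead consume residual + attn_out after a second layer norm, which is the part that needed refactoring.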

@WoosukKwon
Collaborator

@zhuohan123 Just a heads up to avoid any unnecessary overlap, I'll review this PR!

@WoosukKwon linked an issue Nov 15, 2023 that may be closed by this pull request

Collaborator

@WoosukKwon left a comment

@maximzubkov Awesome! Thanks for submitting the PR! The PR generally looks good to me, but needs minor modifications before getting merged. Please take a look at my reviews.

Review threads (resolved) on: requirements.txt, tests/models/test_models.py, vllm/model_executor/model_loader.py, vllm/model_executor/models/__init__.py, and vllm/model_executor/models/phi_1_5.py
    state_dict = self.state_dict()
    for name, loaded_weight in hf_model_weights_iterator(
            model_name_or_path, cache_dir, load_format, revision):
        if "rotary_emb.inv_freq" in name:
Collaborator

BTW, why is the inv_freq here different from the one in our RotaryEmbedding? According to this line, the code looks the same.

Contributor Author

Yeah, I also thought so until I checked the weight:

  • inv_freq from RotaryEmbedding
tensor([1.0000e+00, 5.6234e-01, 3.1623e-01, 1.7783e-01, 1.0000e-01, 5.6234e-02,
        3.1623e-02, 1.7783e-02, 1.0000e-02, 5.6234e-03, 3.1623e-03, 1.7783e-03,
        1.0000e-03, 5.6234e-04, 3.1623e-04, 1.7783e-04], device='cuda:0',
       dtype=torch.float32)
  • loaded_weight of rotary_emb.inv_freq
tensor([1.0000e+00, 5.6250e-01, 3.1616e-01, 1.7786e-01, 9.9976e-02, 5.6244e-02,
        3.1616e-02, 1.7776e-02, 1.0002e-02, 5.6229e-03, 3.1624e-03, 1.7786e-03,
        1.0004e-03, 5.6219e-04, 3.1614e-04, 1.7786e-04])

But even more surprising is that after I copied the cached weights, the output on Test 4 was still different. However, @Linzecong suggests that it might be an fp16 issue (link).

Collaborator

After some investigation, I found that the issue is that inv_freq is stored in FP16 in the weight checkpoint (I guess it is calculated in FP32 and then converted to FP16 for some reason). Since we use the same logic to calculate inv_freq, I think this slight difference should be acceptable.

Can we remove the special weight-loading logic for inv_freq and use our current implementation instead?
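
As a quick illustration of the FP16 round trip described above (the rotary dimension of 32 and the base of 10000 are inferred from the 16-element tensors quoted earlier, not stated explicitly in this thread):

    import torch

    # Assumed rotary parameters, inferred from the 16-element inv_freq tensors above.
    rotary_dim = 32
    base = 10000.0

    inv_freq_fp32 = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
    # Round-tripping through FP16 reproduces the small deviations seen in the checkpoint.
    inv_freq_fp16 = inv_freq_fp32.to(torch.float16).to(torch.float32)

    print(inv_freq_fp32[:4])  # tensor([1.0000, 0.5623, 0.3162, 0.1778])
    print(inv_freq_fp16[:4])  # tensor([1.0000, 0.5625, 0.3162, 0.1779])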

maximzubkov and others added 12 commits November 15, 2023 11:41 (co-authored by Woosuk Kwon)
@WoosukKwon
Collaborator

@maximzubkov A big PR #1622 just got merged. It will make it easier to implement and debug the tensor parallelism support. Please adapt this PR to the new main branch.

@maximzubkov
Contributor Author

maximzubkov commented Nov 16, 2023

@WoosukKwon, thanks for the above PR, it really helped. I updated the code to reflect the modifications from that PR and tested it with tp=4; it works on 4xA100, and the outputs now match tp=1:

INFO 11-16 13:05:19 llm_engine.py:207] # GPU blocks: 36900, # CPU blocks: 5461
INFO 11-16 13:05:23 llm_engine.py:624] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 1 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%

RequestOutput(request_id=0, prompt='A robot may not injure a human being', prompt_token_ids=[32, 9379, 743, 407, 42206, 257, 1692, 852], prompt_logprobs=[None, {9379: -9.684208869934082, 25: -0.3482716679573059}, {743: -4.376316547393799, 318: -1.3138164281845093}, {407: -3.106442451477051, 307: -1.3095673322677612}, {42206: -10.088730812072754, 307: -0.8152931928634644}, {257: -1.3133057355880737}, {1692: -0.9767152070999146}, {852: -4.1458258628845215, 611: -1.411450743675232}], outputs=[CompletionOutput(index=0, text=' if it is programmed to avoid certain areas or objects.\n\nExercise 3', token_ids=[611, 340, 318, 27402, 284, 3368, 1728, 3006, 393, 5563, 13, 198, 198, 3109, 23697, 513], cumulative_logprob=-14.939382730051875, logprobs=[{611: -1.657261848449707}, {340: -0.29756566882133484}, {318: -1.2963895797729492}, {27402: -0.26380589604377747}, {284: -0.1198076456785202}, {3368: -0.6281875967979431}, {1728: -2.2652833461761475}, {3006: -1.0300078392028809}, {393: -0.46611306071281433}, {5563: -1.5455873012542725}, {13: -0.8861204385757446}, {198: -0.6187682151794434}, {198: -0.9203791618347168}, {3109: -1.657790184020996}, {23697: -0.018363816663622856}, {513: -1.2679511308670044}], finish_reason=length)], finished=True)
RequestOutput(request_id=1, prompt='To be or not to be,', prompt_token_ids=[2514, 307, 393, 407, 284, 307, 11], prompt_logprobs=None, outputs=[CompletionOutput(index=0, text=' that is the question.\n\nIn the midst of all this chaos, there', token_ids=[326, 318, 262, 1808, 13, 198, 198, 818, 262, 15925, 286, 477, 428, 11918, 11, 612], cumulative_logprob=-12.02575767423059, logprobs=None, finish_reason=length)], finished=True)
RequestOutput(request_id=2, prompt='What is the meaning of life?', prompt_token_ids=[2061, 318, 262, 3616, 286, 1204, 30], prompt_logprobs=None, outputs=[CompletionOutput(index=3, text='\nAnswer: The meaning of life is the purpose or reason for our existence.', token_ids=[198, 33706, 25, 383, 3616, 286, 1204, 318, 262, 4007, 393, 1738, 329, 674, 6224, 13], cumulative_logprob=-4.604423142969608, logprobs=None, finish_reason=length), CompletionOutput(index=0, text='\nAnswer: This is a philosophical question that has been debated for centuries. Some', token_ids=[198, 33706, 25, 770, 318, 257, 17580, 1808, 326, 468, 587, 24594, 329, 10675, 13, 2773], cumulative_logprob=-5.964206309989095, logprobs=None, finish_reason=length)], finished=True)
RequestOutput(request_id=3, prompt='It is only with the heart that one can see rightly', prompt_token_ids=[1026, 318, 691, 351, 262, 2612, 326, 530, 460, 766, 22956], prompt_logprobs=None, outputs=[CompletionOutput(index=1, text='; with the ear that one can hear rightly; with the tongue that one can', token_ids=[26, 351, 262, 1027, 326, 530, 460, 3285, 22956, 26, 351, 262, 11880, 326, 530, 460], cumulative_logprob=-5.31969938921975, logprobs=None, finish_reason=length), CompletionOutput(index=2, text='; with the ear that one can hear rightly; with the mind that one can', token_ids=[26, 351, 262, 1027, 326, 530, 460, 3285, 22956, 26, 351, 262, 2000, 326, 530, 460], cumulative_logprob=-6.440917276842811, logprobs=None, finish_reason=length), CompletionOutput(index=0, text='; with the ear that one can hear rightly; and with the tongue that one', token_ids=[26, 351, 262, 1027, 326, 530, 460, 3285, 22956, 26, 290, 351, 262, 11880, 326, 530], cumulative_logprob=-7.1897412376711145, logprobs=None, finish_reason=length)], finished=True)
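
(For reference, a run like the one above can be reproduced roughly as follows; the PR does not include the exact script, so the prompts, tensor-parallel setting, and max_tokens below are assumptions based on the outputs shown.)

    from vllm import LLM, SamplingParams

    # Assumed invocation; the exact test script is not shown in this PR.
    llm = LLM(
        model="microsoft/phi-1_5",
        tensor_parallel_size=4,    # tp=4 as in the log above; use 1 for a single GPU
        trust_remote_code=True,    # phi-1.5 shipped custom modeling code at the time
    )

    prompts = [
        "A robot may not injure a human being",
        "To be or not to be,",
        "What is the meaning of life?",
        "It is only with the heart that one can see rightly",
    ]
    outputs = llm.generate(prompts, SamplingParams(max_tokens=16))
    for out in outputs:
        print(out.outputs[0].text)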

To be completely sure, could you please run your tests as well? I tested on multiple GPUs and sometimes faced bugs, not only with phi-1.5 but also with GPT-NeoX. I'm not sure the code works correctly with tp>1 on V100 / TITAN Xp GPUs. Have you tested it on these generations of GPUs, and do you support them? If so, we can discuss it in a separate issue / PR.

@WoosukKwon
Collaborator

Microsoft just updated the Phi model with some breaking changes (like the model name, etc.): https://huggingface.co/microsoft/phi-1_5/tree/main

@maximzubkov Could you update the PR? I'm sorry for the redundant work.

@maximzubkov
Contributor Author

Sure, I will check in a few hours.

@WoosukKwon mentioned this pull request Nov 16, 2023
@maximzubkov
Contributor Author

@WoosukKwon, done 🥳
Fortunately, the code was very similar

Collaborator

@WoosukKwon left a comment

@maximzubkov Thanks for the quick update! Could you fix the links in the code?

Review threads (resolved) on: vllm/model_executor/models/phi_1_5.py
maximzubkov and others added 2 commits November 16, 2023 22:53 (co-authored by Woosuk Kwon)
@maximzubkov
Contributor Author

@WoosukKwon done!

Collaborator

@WoosukKwon left a comment

LGTM! I've checked that the tensor parallelism is working now. @maximzubkov Thanks again for the great work!

@WoosukKwon merged commit 521b35f into vllm-project:main Nov 16, 2023
2 checks passed
@maximzubkov
Contributor Author

maximzubkov commented Nov 16, 2023

Thank you for reviewing, it was a pleasure! See you in the next PR, since I need one more feature for my projects

@WoosukKwon
Collaborator

@maximzubkov Looking forward to it!

@WoosukKwon mentioned this pull request Nov 17, 2023
hongxiayang pushed a commit to hongxiayang/vllm that referenced this pull request Feb 13, 2024
sjchoi1 pushed a commit to casys-kaist-internal/vllm that referenced this pull request May 7, 2024
Successfully merging this pull request may close these issues: Phi 1.5 support