add open-llama model with ckpt #22795

Merged: 19 commits merged into huggingface:main on Apr 28, 2023

Conversation

@s-JoL (Contributor) commented Apr 16, 2023

This PR adds a new model called Open-Llama, which is based on Llama's implementation in Transformers.
In Open-Llama, memory-efficient attention has been added, resulting in a 30% improvement in training efficiency. Additionally, hidden dropout and attention dropout have been added for better generalization during training.

We have also added two optional features: stable embedding from BLOOM and shared input-output embeddings from PaLM, which have been tested and found to improve training stability and performance.
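
A minimal sketch (not the PR's actual code) of what these two options mean in practice, assuming plain PyTorch modules; the class and argument names below are illustrative:

import torch
import torch.nn as nn

class TiedEmbeddingLM(nn.Module):
    # Illustrative only: "stable embedding" adds a LayerNorm right after the token
    # embedding (as in BLOOM); "shared input-output embedding" ties the output
    # projection to the input embedding matrix (as in PaLM).
    def __init__(self, vocab_size, hidden_size, use_stable_embedding=True, shared_input_output_embedding=True):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.embed_layer_norm = nn.LayerNorm(hidden_size) if use_stable_embedding else nn.Identity()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        if shared_input_output_embedding:
            # Reuse the input embedding weights as the output projection.
            self.lm_head.weight = self.embed_tokens.weight

    def forward(self, input_ids):
        hidden_states = self.embed_layer_norm(self.embed_tokens(input_ids))
        # ... transformer layers would run here ...
        return self.lm_head(hidden_states)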

The following code snippet shows the implementation of memory-efficient attention:

try:
    from xformers import ops as xops
except ImportError:
    xops = None
    print("xformers is not installed correctly.")

# Use xformers' memory-efficient attention during training when it is available.
if self.config.use_memorry_efficient_attention and xops is not None and self.training:
    attn_weights = None
    # xformers expects tensors shaped (batch, seq_len, num_heads, head_dim).
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)
    attn_output = xops.memory_efficient_attention(
        query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask(), p=self.dropout_prob
    )

At the same time, for maximum compatibility, we have made xformers an optional dependency so that the original implementation can still be used for training and inference if it is not installed.
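
For reference, a hedged sketch (not the exact merged code) of the eager attention path that runs when xformers is not installed: standard scaled dot-product attention with a causal mask, as in the original Llama implementation.

import math
import torch

def eager_causal_attention(query_states, key_states, value_states, dropout_p=0.0, training=False):
    # All tensors are shaped (batch, num_heads, seq_len, head_dim).
    head_dim = query_states.size(-1)
    seq_len = query_states.size(-2)
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
    # Causal mask: position i may only attend to positions <= i.
    causal_mask = torch.full((seq_len, seq_len), float("-inf"), device=query_states.device).triu(1)
    attn_weights = attn_weights + causal_mask
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout_p, training=training)
    return torch.matmul(attn_weights, value_states)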

We implemented pre-training of the Llama model based on transformers + accelerate, incorporating the modifications described above: Open-Llama.

The pre-trained model has already been open-sourced on s-JoL/Open-Llama-V1.
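
Assuming that checkpoint id, loading it would look roughly like the following (a usage sketch only; as noted later in this thread, the repository may no longer be available on the Hub):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("s-JoL/Open-Llama-V1")
model = AutoModelForCausalLM.from_pretrained("s-JoL/Open-Llama-V1")

inputs = tokenizer("Hello, my name is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))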

ref: #22386

cc: @sgugger

@HuggingFaceDocBuilderDev commented Apr 20, 2023

The documentation is not available anymore as the PR was closed or merged.

@s-JoL changed the title from "Dev" to "add open-llama model with ckpt" on Apr 21, 2023
@sgugger (Collaborator) commented Apr 21, 2023

cc @ArthurZucker and @younesbelkada

@s-JoL (Contributor, Author) commented Apr 25, 2023

Please help me review this pull request. @ArthurZucker @younesbelkada

@ArthurZucker (Collaborator) commented:

Hey! Thanks, will review now.

@ArthurZucker (Collaborator) left a comment

Thanks for working on this! Seems like the model is overall very similar, so a bunch of "Copied from" statements are missing here and there. Most importantly, I don't think we need a new tokenizer; it's still the Llama tokenizer.

Outdated review threads (resolved) on: README.md, docs/source/en/_toctree.yml, docs/source/en/model_doc/open-llama.mdx, src/transformers/models/auto/tokenization_auto.py
Collaborator:

Not convinced that you need a new configuration file either. Args can be added kind of on the fly rather than living in the default Llama config. WDYT?

Contributor Author:

I'm concerned that using the default LlamaConfig directly may result in missing parameters and cause errors.
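
For illustration of that concern, a hedged sketch of a dedicated config carrying the extra knobs discussed in this PR (the class and attribute names here are assumptions, not the merged file):

from transformers import PretrainedConfig

class OpenLlamaConfigSketch(PretrainedConfig):
    model_type = "open-llama"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        use_memory_efficient_attention=True,
        hidden_dropout_prob=0.1,
        attention_dropout_prob=0.1,
        use_stable_embedding=True,
        shared_input_output_embedding=True,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        # Extra attributes a plain LlamaConfig does not define; a dedicated class
        # guarantees they exist with sane defaults instead of raising AttributeError.
        self.use_memory_efficient_attention = use_memory_efficient_attention
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_dropout_prob = attention_dropout_prob
        self.use_stable_embedding = use_stable_embedding
        self.shared_input_output_embedding = shared_input_output_embedding
        super().__init__(**kwargs)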

src/transformers/models/open_llama/modeling_open_llama.py (outdated review thread, resolved)
"""


@add_start_docstrings(
Collaborator:

Missing "Copied from" statements.

Contributor Author:

Sorry, I didn't quite understand how to add the "Copied from" statements for this class; there are slight differences here.

Collaborator:

Ok you can keep it as is!
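
(For context, a "Copied from" marker in transformers is a comment placed right above a class or function duplicated from another model, which the repo-consistency tooling then keeps in sync. A hedged example, with the class name and a simplified RMSNorm body assumed for illustration:)

import torch
import torch.nn as nn

# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->OpenLlama
class OpenLlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Root-mean-square normalization (simplified for illustration).
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states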

@ArthurZucker (Collaborator) left a comment

LGTM, waiting for @sgugger's review

Collaborator:

Same comment here, is this not the same as in the llama folder?

Contributor Author:

Thank you for the reminder. This file is identical to the one in Llama, and since I trained directly with Transformers, there is no need for any conversion. I will delete it.

"""


@add_start_docstrings(
Collaborator:

Ok you can keep it as is!

Comment on lines +745 to +756
loss = None
if labels is not None:
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens
    loss_fct = CrossEntropyLoss()
    shift_logits = shift_logits.view(-1, self.config.vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)
    loss = loss_fct(shift_logits, shift_labels)
Collaborator:

Suggested change

Remove:

loss = None
if labels is not None:
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens
    loss_fct = CrossEntropyLoss()
    shift_logits = shift_logits.view(-1, self.config.vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)
    loss = loss_fct(shift_logits, shift_labels)

Replace with:

lm_loss = None
if labels is not None:
    # we are doing next-token prediction; shift prediction scores and input ids by one
    shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
    labels = labels[:, 1:].contiguous()
    loss_fct = CrossEntropyLoss()
    lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

We usually just use this, but I am guessing the point of the PR is fast / model parallelism, so ignore my comment if this doesn't work (we leave parallelism to accelerate).
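
To make the shifting concrete, a small self-contained toy example (illustrative sizes only):

import torch
from torch.nn import CrossEntropyLoss

batch, seq_len, vocab_size = 2, 5, 11
logits = torch.randn(batch, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch, seq_len))

# Position i's logits predict the token at position i + 1.
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
shift_labels = labels[..., 1:].contiguous().view(-1)
loss = CrossEntropyLoss()(shift_logits, shift_labels)
print(loss)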

Comment on lines +19 to +20
The model is mainly based on LLaMA with some modifications, incorporating memory-efficient attention from Xformers, stable embedding from BLOOM, and shared input-output embedding from PaLM. The model is also pre-trained on both Chinese and English, which gives it better performance on Chinese-language tasks.
Collaborator:

If you have them, would be cool to add the performance gains here!

Contributor Author:

This is a great suggestion, but I have not yet run a complete ablation experiment. I plan to gradually add the numbers to the documentation once the experiments are done.

@ArthurZucker requested a review from @sgugger on April 28, 2023 13:27
@sgugger (Collaborator) left a comment

Very clean, thanks a lot for adding this! I have just a comment on the config and default checkpoint.

Comment on lines 41 to 43
warnings.warn(
    "Xformers is not installed correctly. If you want to use memorry_efficient_attention to accelerate training use the following command to install Xformers\npip install xformers."
)
Collaborator:

Should use our logger here with logger.warn (so move this after the logger is defined below).
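
A minimal sketch of the usual transformers logging pattern being suggested (define the module-level logger first, then warn through it); the wording of the message is illustrative:

try:
    from xformers import ops as xops
except ImportError:
    xops = None

from transformers.utils import logging

logger = logging.get_logger(__name__)

if xops is None:
    logger.warning(
        "Xformers is not installed correctly. If you want to use memory_efficient_attention to "
        "accelerate training, install it with `pip install xformers`."
    )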

@@ -42,6 +42,7 @@
"VisionEncoderDecoderConfig",
"VisionTextDualEncoderConfig",
"LlamaConfig",
"OpenLlamaConfig",
Collaborator:

Should be removed as there is a checkpoint for OpenLlama.

r"""
This is the configuration class to store the configuration of a [`OpenLlamaModel`]. It is used to instantiate an
Open-Llama model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the Open-Llama-7B.
Collaborator:

Put the full checkpoint name here and link to the Hub. Example we have for GPT-2:

a similar configuration to that of the [gpt2](https://huggingface.co/gpt2) architecture.

It wasn't there for Llama since there is no official checkpoint on the Hub.

Contributor Author:

Thank you for the review. The three issues mentioned have been fixed.

@sgugger merged commit c2c99dc into huggingface:main on Apr 28, 2023
@sgugger (Collaborator) commented Apr 28, 2023

Thanks a lot for your contribution!

@s-JoL (Contributor, Author) commented May 11, 2023

> Thanks a lot for your contribution!

Hello, I have a question: why can't the Open-Llama model be found when searching the Transformers documentation? Is there something I forgot to add?


@amyeroberts (Collaborator) commented:

Hi @s-JoL, thanks for notifying.

There was an issue in the doc rendering (resolved with 1, 2) leading to some pages not being retrievable in search. Should be working now!

@PenutChen (Contributor) commented:

@s-JoL I noticed that the links pertaining to Open-LLaMA are currently leading to 404 errors. Could you please provide some information on what might have happened?

@heya5 (Contributor) commented May 24, 2023

@s-JoL Hi, I can't find an Open-LLaMA checkpoint, and I noticed you deleted your original repo. What happened? How can I try Open-LLaMA?

gojiteji pushed a commit to gojiteji/transformers that referenced this pull request Jun 5, 2023
* update Open-Llama model

* update

* update format

* update doc

* update

* update stable embedding test

* update test case

* update format

* update readme

* fix typo

* update name

* remove tokenizer and update format

* remove convert_open_llama_weights_to_hf

* update warning and doc_string

---------

Co-authored-by: songliang.bayesian <songliang.bayesian@bytedance.com>
@PenutChen (Contributor) commented:

@heya5 Possibly due to some controversies surrounding this project, the original author has closed the original project.
https://github.com/chenfeng357/open-Chinese-ChatLLaMA/issues/1

novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023 (same commit message as above)
tomaarsen added a commit to tomaarsen/transformers that referenced this pull request Jul 19, 2023