Add support for Llama 3.1 (8B) Fine-Tuning & Inference #85
Conversation
Thank you for this contribution @Bihan! I just added a few comments. Could you please add a non-regression test that uses this model? You could just modify text-generation-inference/tests/test_decode.py and add a params set in test_decode_single_slow.
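For illustration, a hedged sketch of the kind of params entry this suggests; the real test_decode.py defines its own container and decode helper, so every name below except the model id is an assumption:

from dataclasses import dataclass

import pytest


@dataclass
class DecodeTestParams:  # hypothetical stand-in for the container used in test_decode.py
    model_id: str
    sequence_length: int
    expected_text: str


@pytest.mark.slow
@pytest.mark.parametrize(
    "params",
    [
        DecodeTestParams(
            model_id="meta-llama/Meta-Llama-3.1-8B",
            sequence_length=256,
            expected_text="",  # to be filled with the observed greedy output
        ),
    ],
    ids=["Meta-Llama-3.1-8B"],
)
def test_decode_single_slow(params):
    # Would call the shared decode helper defined in test_decode.py.
    ...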
@@ -19,6 +19,7 @@ def serve(
     logger_level: str = "INFO",
     json_output: bool = False,
     otlp_service_name: str = "text-generation-inference.server",
+    max_input_tokens: Optional[int] = None,
I think you do not need this, if you set the env var MAX_INPUT_LENGTH
@tengomucho I tried using --max-input-length with text-generation-launcher but received the error below. Could you please elaborate on how we can use MAX_INPUT_LENGTH as an env var? Also, the TGI documentation mentions max_input_length as a legacy version of max_input_tokens; will this be a concern?
2024-08-13T16:04:40.833385Z INFO text_generation_launcher: Args {
model_id: "meta-llama/Meta-Llama-3.1-8B",
...
...
}
...
2024-08-13T16:04:43.826483Z INFO download: text_generation_launcher: Successfully downloaded weights for meta-llama/Meta-Llama-3.1-8B
2024-08-13T16:04:43.826735Z INFO shard-manager: text_generation_launcher: Starting shard rank=0
2024-08-13T16:04:43.927753Z ERROR shard-manager: text_generation_launcher: Shard complete standard error output:
Usage: text-generation-server serve [OPTIONS] MODEL_ID
Try 'text-generation-server serve --help' for help.
Error: No such option: --max-input-tokens rank=0
2024-08-13T16:04:44.026753Z ERROR text_generation_launcher: Shard 0 failed to start
2024-08-13T16:04:44.026765Z INFO text_generation_launcher: Shutting down shards
Error: ShardCannotStart
You are right, the right parameter to use is MAX_INPUT_TOKENS; I mixed things up. My point is: what do you want to achieve by adding this to the CLI in the text generation server? For now, IIRC, this value is used in the launcher to define the maximum number of input tokens that can be passed from the router to the server; the server itself does not use it. It is OK to add it to the CLI, but to be effective you will also need to add it to the serve function and do something with it, otherwise it will not have any effect.
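To illustrate that point, a rough sketch of what wiring it up could look like (only MAX_INPUT_TOKENS comes from the discussion above; the other names and the place the value would be consumed are assumptions):

import os
from typing import Optional


def serve(
    model_id: str,
    max_input_tokens: Optional[int] = None,
    # ... other CLI options elided ...
):
    # CLI flag wins; otherwise fall back to the env var exported by the launcher.
    if max_input_tokens is None and os.getenv("MAX_INPUT_TOKENS"):
        max_input_tokens = int(os.environ["MAX_INPUT_TOKENS"])

    # The value only has an effect if it is actually forwarded to whatever
    # builds the generator/model, e.g. (hypothetical call):
    # generator = build_generator(model_id, max_input_tokens=max_input_tokens)
    ...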
@@ -340,6 +340,8 @@ def _init_rope(self):
                 base=self.rope_theta,
             )
         else:
+            # Set default rope_scaling to dynamic
+            self.config.rope_scaling.setdefault("type", "dynamic")
A possible alternative could be to do as it's done in transformers:
if config.rope_scaling is not None:
    self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
    self.rope_type = "default"
So the value is read if available.
@tengomucho For Llama 3.1 I assume we need to replace rope_scaling["type"] with rope_scaling['rope_type'], which means that for Llama 3.1 scaling_type would be rope_scaling['rope_type']. If this is the case, should we create a new conditional branch for scaling_type == "llama3"? Below is the rope_scaling dict from the config.json of Llama 3.1, with rope_type set to "llama3":
"rope_scaling": {
"factor": 8.0,
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3"
}
My draft implementation of _init_rope is below; please suggest how to fill in the llama3 branch:
def _init_rope(self):
    if self.config.rope_scaling is None:
        self.rotary_emb = LlamaRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )
    else:
        scaling_type = self.config.rope_scaling.get("rope_type", self.config.rope_scaling.get("type"))
        scaling_factor = self.config.rope_scaling["factor"]
        if scaling_type == "linear":
            self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                scaling_factor=scaling_factor,
                base=self.rope_theta,
            )
        elif scaling_type == "dynamic":
            self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                scaling_factor=scaling_factor,
                base=self.rope_theta,
            )
        elif scaling_type == "llama3":
            # Your suggestion required.
            ......
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
We could do that @Bihan, but it would mean ending up with code that diverges more from the original transformers code. I was suggesting we stay as close as possible to their implementation to simplify maintenance: if transformers is updated to support a new feature or fix a bug, it will be easier to follow if the optimum-tpu code is more similar.
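For reference, a minimal sketch of the llama3 frequency scaling that the rope_scaling dict above configures, mirroring the logic recent transformers releases apply; the helper name and signature are assumptions, not the PR's final code:

import math

import torch


def scale_llama3_inv_freq(inv_freq: torch.Tensor, rope_scaling: dict) -> torch.Tensor:
    factor = rope_scaling["factor"]                                     # 8.0 for Llama 3.1
    low_freq_factor = rope_scaling["low_freq_factor"]                   # 1.0
    high_freq_factor = rope_scaling["high_freq_factor"]                 # 4.0
    old_context_len = rope_scaling["original_max_position_embeddings"]  # 8192

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    wavelen = 2 * math.pi / inv_freq

    # Long wavelengths (low frequencies) are divided by `factor`,
    # short wavelengths are left untouched.
    scaled = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)

    # Wavelengths between the two thresholds are smoothly interpolated.
    smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    smoothed = (1 - smooth) / factor * inv_freq + smooth * inv_freq
    is_medium = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
    return torch.where(is_medium, smoothed, scaled)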
@tengomucho I have added the test for Llama 3.1, which is working. Below is the summary of the test.
Nice, you added the test! Do you think you can fix it to make it pass and add it to the current branch?
@tengomucho Yes, the test is working. I will add it to the PR along with the other issues.
Closing this PR. A new PR will be sent from dstack.
Changes I made in optimum-tpu:
- Upgraded TGI from 2.0.3 to 2.2.0
- Upgraded Transformers from 4.41.1 to 4.43.3
- Upgraded Rust from 1.77 to 1.79
- Added an otlp_service_name argument in the serve() method with default value text-generation-inference.server
- Added a max_input_tokens argument in the serve() method with default value None
- Set the default rope_scaling type to dynamic
Please feel free to provide any suggestions for improvement in this PR.