Add support for llama 3.1 (8b) Fine Tuning & Inference #85

Closed

Conversation


@Bihan commented Aug 11, 2024

Changes I made in optimum-tpu:

  1. Bumped TGI version from 2.0.3 to 2.2.0
  2. Bumped Transformers from 4.41.1 to 4.43.3
  3. Bumped Rust from `1.77` to `1.79`
  4. Added an `otlp_service_name` argument to the `serve()` method with default value `text-generation-inference.server`
  5. Added a `max_input_tokens` argument to the `serve()` method with default value `None` (see the sketch after this list)
  6. Set the default `rope_scaling` type to `dynamic`
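
For items 4 and 5, here is a rough sketch of the resulting `serve()` signature. Only the arguments touched by this PR are shown; the positional `MODEL_ID` comes from the CLI usage printed later in this thread, and the typer-style CLI wrapper is assumed to match upstream TGI.

```python
# Sketch only: arguments other than those touched in this PR are elided,
# and the typer-based CLI wrapper is an assumption based on upstream TGI.
from typing import Optional

import typer

app = typer.Typer()

@app.command()
def serve(
    model_id: str,
    logger_level: str = "INFO",
    json_output: bool = False,
    otlp_service_name: str = "text-generation-inference.server",  # item 4
    max_input_tokens: Optional[int] = None,  # item 5
):
    ...
```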

Please feel free to provide any suggestions for improvement in this PR.

@tengomucho (Collaborator) left a comment:

Thank you for this contribution @Bihan! I just added a few comments. Could you please add a non-regression test that uses this model? You could just modify text-generation-inference/tests/test_decode.py and add a params set in test_decode_single_slow.
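
For reference, a hedged sketch of what such an entry could look like. `DecodeTestParams`, `_test_decode_single`, and the parametrization style are assumptions about tests/test_decode.py rather than its actual contents, and the values are illustrative (the expected text reuses the passage quoted later in this thread):

```python
# Hypothetical sketch: DecodeTestParams, _test_decode_single and
# sequence_length are assumptions about test_decode.py's structure.
import pytest

@pytest.mark.slow
@pytest.mark.parametrize(
    "params",
    [
        DecodeTestParams(
            model_id="meta-llama/Meta-Llama-3.1-8B",
            sequence_length=256,  # illustrative value
            expected_text=" Winston Smith, his chin nuzzled into his breast in an effort to escape the vile wind,",
        ),
    ],
    ids=["Meta-Llama-3.1-8B"],
)
def test_decode_single_slow(params):
    _test_decode_single(params)
```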

```
@@ -19,6 +19,7 @@ def serve(
    logger_level: str = "INFO",
    json_output: bool = False,
    otlp_service_name: str = "text-generation-inference.server",
    max_input_tokens: Optional[int] = None,
```

@tengomucho (Collaborator):
I think you do not need this, if you set the env var MAX_INPUT_LENGTH

@Bihan (Author):
> I think you do not need this, if you set the env var MAX_INPUT_LENGTH

@tengomucho I tried using --max-input-length with text-generation-launcher but received the error below. Could you please elaborate on how we can use MAX_INPUT_LENGTH as an env var? Also, the TGI documentation describes max_input_length as the legacy version of max_input_tokens; will this be a concern?

```
2024-08-13T16:04:40.833385Z  INFO text_generation_launcher: Args {
    model_id: "meta-llama/Meta-Llama-3.1-8B",
    ...
    ...
}
...

2024-08-13T16:04:43.826483Z  INFO download: text_generation_launcher: Successfully downloaded weights for meta-llama/Meta-Llama-3.1-8B
2024-08-13T16:04:43.826735Z  INFO shard-manager: text_generation_launcher: Starting shard rank=0
2024-08-13T16:04:43.927753Z ERROR shard-manager: text_generation_launcher: Shard complete standard error output:

Usage: text-generation-server serve [OPTIONS] MODEL_ID
Try 'text-generation-server serve --help' for help.

Error: No such option: --max-input-tokens rank=0
2024-08-13T16:04:44.026753Z ERROR text_generation_launcher: Shard 0 failed to start
2024-08-13T16:04:44.026765Z  INFO text_generation_launcher: Shutting down shards
Error: ShardCannotStart
```

@tengomucho (Collaborator):
You are right, the right parameter to use is "MAX_INPUT_TOKENS"; I mixed things up. My point is: what do you want to achieve by adding this to the CLI in the text generation server? For now, IIRC, this value is used in the launcher to define the maximum number of input tokens that can be passed from the router to the server; the server itself does not use it yet. It is OK to add it to the CLI, but to be effective you will also need to add it to the serve function and do something with it, otherwise it will not have any effect.
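
For example (hedged: this assumes the TGI 2.2.0 launcher options, and the model id and value are placeholders), the limit would normally be set at the launcher level rather than on the server CLI:

```shell
# Assumed TGI 2.2.0 launcher usage; model id and value are placeholders.
text-generation-launcher --model-id meta-llama/Meta-Llama-3.1-8B --max-input-tokens 4096

# or via the environment variable the launcher reads:
MAX_INPUT_TOKENS=4096 text-generation-launcher --model-id meta-llama/Meta-Llama-3.1-8B
```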

```diff
@@ -340,6 +340,8 @@ def _init_rope(self):
                 base=self.rope_theta,
             )
         else:
+            # Set Default rope_scaling to dynamic
+            self.config.rope_scaling.setdefault("type", "dynamic")
```

@tengomucho (Collaborator):
A possible alternative could be to do as it's done in transformers:

```python
if config.rope_scaling is not None:
    self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
    self.rope_type = "default"
```

So the value is read if available.
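
To illustrate how that fallback resolves (the dicts here are just examples; the first one mimics the Llama 3.1 config quoted below):

```python
# Illustration only: which rope type the .get() fallback picks.
def resolve_rope_type(rope_scaling):
    if rope_scaling is not None:
        return rope_scaling.get("rope_type", rope_scaling.get("type"))
    return "default"

print(resolve_rope_type({"rope_type": "llama3", "factor": 8.0}))  # -> llama3
print(resolve_rope_type({"type": "dynamic", "factor": 2.0}))      # -> dynamic
print(resolve_rope_type(None))                                    # -> default
```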

@Bihan (Author):
@tengomucho For Llama 3.1 I assume we need to replace rope_scaling["type"] with rope_scaling["rope_type"], which means that for Llama 3.1 scaling_type would come from rope_scaling["rope_type"]. If that is the case, should we create a new conditional branch for scaling_type == "llama3"? Below is the rope_scaling dict from the config.json of Llama 3.1, with rope_type set to "llama3":

"rope_scaling": {
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3"
  }

My draft implementation of _init_rope is below; please suggest how to fill in the llama3 branch:

```python
def _init_rope(self):
    if self.config.rope_scaling is None:
        self.rotary_emb = LlamaRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )
    else:
        scaling_type = self.config.rope_scaling.get("rope_type", self.config.rope_scaling.get("type"))
        scaling_factor = self.config.rope_scaling["factor"]
        if scaling_type == "linear":
            self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                scaling_factor=scaling_factor,
                base=self.rope_theta,
            )
        elif scaling_type == "dynamic":
            self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                scaling_factor=scaling_factor,
                base=self.rope_theta,
            )
        elif scaling_type == "llama3":
            # Your suggestion required.
            ...
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
```
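
For reference, the "llama3" branch would essentially rescale the inverse frequencies using the keys shown in the config above. The sketch below is adapted from transformers 4.43's _compute_llama3_parameters and is illustrative only, not the optimum-tpu implementation:

```python
import math

import torch

def llama3_scale_inv_freq(inv_freq: torch.Tensor, rope_scaling: dict) -> torch.Tensor:
    # Adapted from transformers 4.43 (modeling_rope_utils._compute_llama3_parameters);
    # the keys match the Llama 3.1 config.json quoted earlier.
    factor = rope_scaling["factor"]                                       # 8.0
    low_freq_factor = rope_scaling["low_freq_factor"]                     # 1.0
    high_freq_factor = rope_scaling["high_freq_factor"]                   # 4.0
    old_context_len = rope_scaling["original_max_position_embeddings"]    # 8192

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    wavelen = 2 * math.pi / inv_freq
    # Low frequencies (long wavelengths) are divided by the factor, high
    # frequencies are kept, and the band in between is smoothly interpolated.
    scaled = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
    smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    smoothed = (1 - smooth) * inv_freq / factor + smooth * inv_freq
    is_medium = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
    return torch.where(is_medium, smoothed, scaled)
```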

@tengomucho (Collaborator):

We could do that @Bihan, but that would mean we end up with code that diverges more from the original transformers code. I was suggesting staying as close as possible to their implementation to simplify maintenance: if there is a new update to transformers to support a new feature or fix a bug, it will be easier to support it if the optimum-tpu code is more similar.
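
As a hedged sketch of what staying close to transformers could look like: since 4.43, transformers dispatches RoPE variants through ROPE_INIT_FUNCTIONS in transformers.modeling_rope_utils, which already includes a "llama3" entry. How this would be wired into optimum-tpu's vendored modeling code is an assumption, not this PR's code:

```python
# Sketch only: relies on transformers >= 4.43 providing ROPE_INIT_FUNCTIONS;
# wiring this into optimum-tpu's _init_rope is an assumption.
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

def resolve_rope_init(config):
    if config.rope_scaling is not None:
        rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
    else:
        rope_type = "default"
    # Each init function returns (inv_freq, attention_scaling) for the given config.
    return ROPE_INIT_FUNCTIONS[rope_type](config, device=None)
```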

@Bihan (Author) commented Aug 14, 2024

> Thank you for this contribution @Bihan! I just added a few comments. Could you please add a non-regression test that uses this model? You could just modify text-generation-inference/tests/test_decode.py and add a params set in test_decode_single_slow.

@tengomucho I have added the test for Llama 3.1, and it is working. Below is a summary of the test runs:

```
...
...
2024-08-14 01:31:23.732 | DEBUG    | text_generation_server.generator:leave:933 - Joining...
2024-08-14 01:31:37.549 | DEBUG    | text_generation_server.generator:leave:935 - Generator loop finished
PASSED


Error introduced "Smithy" instead of "Smith"
Case: [expected_text=" Winston Smithy, his chin nuzzled into his breast in an effort to escape the vile wind,",]
...
...
2024-08-14 01:35:47.848 | DEBUG    | text_generation_server.generator:leave:933 - Joining...
2024-08-14 01:36:00.114 | DEBUG    | text_generation_server.generator:leave:935 - Generator loop finished
FAILED
```

@tengomucho (Collaborator) commented:

> @tengomucho I have added the test for Llama 3.1, and it is working. [...]

Nice, you added the test! Do you think you can fix it to make it pass and add it to the current branch?

@Bihan (Author) commented Aug 19, 2024

> Nice, you added the test! Do you think you can fix it to make it pass and add it to the current branch?

@tengomucho Yes, the test is working. I will add it to the PR along with the other issues.

@Bihan (Author) commented Aug 29, 2024

Closing this PR. A new PR will be sent from dstack.

@Bihan closed this Aug 29, 2024