adding TP llama example #2623
Codecov Report
@@ Coverage Diff @@
## master #2623 +/- ##
=======================================
Coverage 72.39% 72.39%
=======================================
Files 85 85
Lines 3956 3956
Branches 58 58
=======================================
Hits 2864 2864
Misses 1088 1088
Partials 4 4
LGTM overall, please see comments. A unit test for the handler would also be great. You can mock out the model etc but test the logic.
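A rough sketch of what such a test could look like, assuming the handler exposes `preprocess` and `inference` steps. `SketchHandler` and `run_handler_logic_check` are hypothetical stand-ins, not the classes in this PR; the model is replaced by a `MagicMock` so only the handler logic runs.

```python
from unittest.mock import MagicMock


class SketchHandler:
    """Hypothetical stand-in for the handler under review (not the PR's class)."""

    def __init__(self, model):
        self.model = model

    def preprocess(self, requests):
        # Pull the payload out of each request and decode raw bytes to text.
        texts = []
        for req in requests:
            text = req.get("data") or req.get("body")
            if isinstance(text, (bytes, bytearray)):
                text = text.decode("utf-8")
            texts.append(text)
        return texts

    def inference(self, texts):
        # The real handler would call the Llama model here; the test mocks it.
        return [self.model.generate(t) for t in texts]


def run_handler_logic_check():
    model = MagicMock()
    model.generate.return_value = "mocked reply"
    handler = SketchHandler(model)
    texts = handler.preprocess([{"data": b"Hey"}, {"body": "plain"}])
    outputs = handler.inference(texts)
    return texts, outputs, model.generate.call_count
```

The same pattern should carry over to the real handler: build it with a mocked model and tokenizer, then assert on what the pre- and post-processing steps return.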
### How to use it?

1- Make sure you have access to llama weights on [HF model hub](https://huggingface.co/meta-llama), there is form you need to fill up and within few mins you will get access. ANy model name on the hub **without -hf** is Meta/FAIR weight.
typo:
llama -> Llama
there is form -> there is a form
fill up -> fill out
nit: Any model name -> Any Llama model name
"""
if isinstance(input_text, (bytes, bytearray)):
    input_text = input_text.decode("utf-8")
logger.info("Received text: '%s'", input_text)
Should this be debug?
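For illustration, a minimal sketch of the same decoding step logging at debug level instead, so request text only shows up when debug logging is enabled. The function name `decode_input` and the logger name are assumptions, not taken from the PR.

```python
import logging

# Hypothetical logger name; the handler in the PR defines its own logger.
logger = logging.getLogger("llama_handler_sketch")


def decode_input(input_text):
    """Decode raw request bytes to str and log the text at debug level."""
    if isinstance(input_text, (bytes, bytearray)):
        input_text = input_text.decode("utf-8")
    # Lazy %-style arguments: the message is only formatted if debug is enabled.
    logger.debug("Received text: '%s'", input_text)
    return input_text
```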
@@ -0,0 +1 @@
Hey, are you conscious? Can you talk to me?
❤️ that
LGTM. I added some comments so that the model artifacts, including the converted checkpoint, can be uploaded to S3 and used directly by customers.
converted_ckpt_dir=ctx.model_yaml_config["handler"]["converted_ckpt_dir"],
tokenizer_path= ctx.model_yaml_config["handler"]["tokenizer_path"],
)
Could you replace them with the following:
converted_ckpt_dir=f'{model_dir}/{ctx.model_yaml_config["handler"]["converted_ckpt_dir"]}',
tokenizer_path= f'{model_dir}/{ctx.model_yaml_config["handler"]["tokenizer_path"]}',
)
This and the one below error out because we are not bundling these files into the mar file:
FileNotFoundError: [Errno 2] No such file or directory: '/tmp/models/38d7adf021da4f278f6711ff1584fac0/llama//data/home/hamidnazeri/fresh_ts/serve/examples/large_models/tp_llama/model_args.json'
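The `llama//data/home/...` segment in the error above is what you get when an absolute path from the config is appended to `model_dir` with an f-string. A small sketch of a guard for that case follows; `resolve_artifact` is a hypothetical helper, not part of the PR, and it uses `posixpath` so the behavior is the same on any platform.

```python
import posixpath


def resolve_artifact(model_dir, configured_path):
    """Resolve a config path against model_dir (hypothetical helper).

    Assumes artifact paths in the config are meant to be relative to the
    extracted model directory, as the suggested f-strings imply.
    """
    if posixpath.isabs(configured_path):
        raise ValueError(
            f"expected a path relative to the model directory, got {configured_path!r}"
        )
    return posixpath.join(model_dir, configured_path)


# What plain string concatenation does with an absolute config value:
model_dir = "/tmp/models/abc/llama"  # illustrative directory, not a real one
broken = f"{model_dir}/{'/data/home/user/model_args.json'}"
```

With relative entries in the config and the files placed inside the model directory, the joined path stays inside `model_dir`.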
logger.info("Instantiating Llama model")
model_load_start = time.perf_counter()
llama_model_and_tok= Llama.build(
    model_args=ctx.model_yaml_config["handler"]["model_args_path"],
Could you change this to:
model_args=f'{model_dir}/{ctx.model_yaml_config["handler"]["model_args_path"]}'
converted_ckpt_dir: "PATH/TO/converted_checkpoints"
tokenizer_path: "/PATH/TO/MODEL/CHECKPOINTS/tokenizer.model"
model_args_path: "PATH/TO/model_args.json"
Please remove "PATH".
What should I replace it with?
Just remove "PATH/".
Create the mar file using the following command here.

```
torch-model-archiver --model-name llama --version 1.0 --handler llama-handler.py --config-file model-config.yaml --archive-format tgz --extra-files "llama2.py,llama2_tokenizer.py,generate.py,checkpoint_converter.py"
```
Could you change it to:
torch-model-archiver --model-name llama --version 1.0 --handler llama-handler.py --config-file model-config.yaml --archive-format no-archive --extra-files "llama2.py,llama2_tokenizer.py,generate.py,checkpoint_converter.py"
mv TO llama/
@lxning Changed the packaging step as suggested; I just cp the files instead of mv.
Don't cp. You can copy my comments exactly.
Description
Adding a PyTorch TP (tensor parallel) example for Llama 2. The idea is to get the "meta/original" weights from the HF model hub, convert the checkpoint, and run distributed inference with TP. The "meta/original" Llama 2 model uses Fairscale TP; this example uses PyTorch TP instead.
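For readers new to TP, here is a toy, single-process sketch of the column-parallel idea: the weight matrix is split by columns across ranks, each rank computes its slice of the output, and the slices are concatenated. It is plain Python for illustration only and is not how this PR or PyTorch TP is implemented; all function names are made up for the sketch.

```python
def matmul(x, w):
    """Multiply a row vector x (length k) by a k-by-n matrix w (list of rows)."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


def split_columns(w, world_size):
    """Split w into world_size equal column shards, one per simulated rank."""
    n = len(w[0])
    assert n % world_size == 0, "columns must divide evenly across ranks"
    step = n // world_size
    return [[row[r * step:(r + 1) * step] for row in w] for r in range(world_size)]


def column_parallel_matmul(x, w, world_size):
    """Simulate a column-parallel linear layer in a single process."""
    shards = split_columns(w, world_size)
    partials = [matmul(x, shard) for shard in shards]  # one partial output per rank
    out = []
    for p in partials:
        out.extend(p)  # stands in for an all-gather along the output dimension
    return out
```

Each simulated rank only holds its own columns of the weight, which is why sharding a large layer this way reduces per-device memory.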
Fixes #(issue)
Type of change
Please delete options that are not relevant.
Feature/Issue validation/testing
Please describe the Unit or Integration tests that you ran to verify your changes and relevant result summary. Provide instructions so it can be reproduced.
Please also list any relevant details for your test configuration.
https://gist.github.com/HamidShojanazeri/1c5927759c0c04fcddc8a714bb26365b
13B client
https://gist.github.com/HamidShojanazeri/46d3882ac9fe55fcbdcac4a9dfcae9b5
Logs for Test B
Checklist: