Support Sharding with PP_Degree > 1 #2717
Conversation
```diff
@@ -32,6 +32,8 @@

 CHUNK_MB = 8

+CONFIG_FILENAME = "sagemaker-fast-model-loader-manifest.json"
```
Looks like this is unused, can remove
```diff
-                 pp_degree: int, tp_degree: int,
-                 chunk_mb: int) -> None:
+                 pp_degree: int, tp_degree: int, chunk_mb: int,
+                 pp_rank_to_shard: int) -> None:
```
Maybe let's rename `pp_rank_to_shard` to `target_pp_rank` for clarity.
```diff
     def save_configs(self,
                      input_dir: str,
                      output_dir: str,
                      configs: list,
+                     pp_rank: int = 0) -> None:
         for entry in configs:
-            conf.add_shard(
+            if entry["pp"] != pp_rank:
+                continue
+            self.shard_config.add_shard(
                 pipeline_parallel_degree=int(entry["pp"]),
                 tensor_parallel_degree=int(entry["tp"]),
                 shard_config=entry["config"],
             )

-        conf.save(output_dir=output_dir)
-        logging.info(
-            f"SageMaker Fast Model Loader config file saved to {output_dir}")
-        self.copy_non_safetensors_files(input_dir, output_dir)
-        logging.info(f"Other non-Safetensors files copied to {output_dir}")
+        if pp_rank == self.pp_degree - 1:
+            self.shard_config.save(output_dir=output_dir)
+            logging.info(
+                f"SageMaker Fast Model Loader config file saved to {output_dir}"
+            )
+            self.copy_non_safetensors_files(input_dir, output_dir)
+            logging.info(f"Other non-Safetensors files copied to {output_dir}")
```
What if we reorganize the logic here to make the intention more clear, maybe something along the lines of:

- instead of passing in `pp_rank` here, we pass a boolean like `should_save_final_config`, and move the check for setting this value to the caller
- instead of checking for a misaligned `pp_rank`, we check for an empty config object to determine if we should skip adding a shard to the ModelConfig
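For illustration, a rough Python sketch of that suggestion; the `should_save_final_config` name comes from this comment, and everything else (the class context, `logging` setup, and the shape of `configs` entries) is assumed from the diff above rather than taken from the actual PR:

```python
import logging

# Hypothetical sketch of the suggested refactor, not the PR's actual code.
# Assumed to be a method on NeoShardingService, as in the diff above.
def save_configs(self,
                 input_dir: str,
                 output_dir: str,
                 configs: list,
                 should_save_final_config: bool = False) -> None:
    for entry in configs:
        # Skip empty config objects instead of comparing pp ranks here;
        # the caller now decides which configs belong to this pass.
        if not entry.get("config"):
            continue
        self.shard_config.add_shard(
            pipeline_parallel_degree=int(entry["pp"]),
            tensor_parallel_degree=int(entry["tp"]),
            shard_config=entry["config"],
        )

    if should_save_final_config:
        self.shard_config.save(output_dir=output_dir)
        logging.info(
            f"SageMaker Fast Model Loader config file saved to {output_dir}")
        self.copy_non_safetensors_files(input_dir, output_dir)
        logging.info(f"Other non-Safetensors files copied to {output_dir}")
```

The caller would then pass `should_save_final_config=(pp_rank == self.pp_degree - 1)`, keeping the "last rank" decision in one place.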
```diff
                 tp_degree=int(
                     self.properties["option.tensor_parallel_degree"]),
                 chunk_mb=CHUNK_MB)
+        for i in range(self.pp_degree):
```
These changes only enable sharding for multi-node when the nodes are split by pipeline parallelism, e.g. tp=8, pp=2; we also need to account for the tensor-parallel-only multi-node case, e.g. tp=16, which is rarer but could be required or more performant in some scenarios.
So we need to check whether the tensor parallel degree is greater than what fits on a single node, and if so, partition the ranks accordingly and shard each node's world sequentially, as we are doing here for pipeline parallelism.
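A minimal sketch of that rank partitioning, assuming a hypothetical `gpus_per_node` value and helper name (neither is part of this PR):

```python
# Hypothetical sketch: split a tensor-parallel-only world (e.g. tp=16) into
# contiguous per-node rank groups so each node's world can be sharded
# sequentially, mirroring the per-pp_rank loop introduced in this PR.
def partition_tp_ranks(tp_degree: int, gpus_per_node: int) -> list:
    return [
        list(range(start, min(start + gpus_per_node, tp_degree)))
        for start in range(0, tp_degree, gpus_per_node)
    ]

if __name__ == "__main__":
    # tp=16 on 8-GPU nodes -> [[0, ..., 7], [8, ..., 15]]
    print(partition_tp_ranks(16, 8))
```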
Description
Support Sharding with PP_Degree > 1.

- Updated `save_configs` in the `NeoShardingService` class so that it saves info to the model config object gradually; it only writes the file to disk when it gets the last pp_rank.
- Added a `pp_rank_to_shard` param to `shard_lmi_dist_model`.
- In `run_sharding`, we create a loop over pp_degree, so it calls `shard_lmi_dist_model` for each pp_rank (see the sketch below). Note: this will break if lmi_dist is not updated.
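Conceptually, the new `run_sharding` flow described above looks roughly like this; a sketch only, and the real signatures in `NeoShardingService` may differ:

```python
# Illustrative sketch of the per-pp_rank sharding loop described above;
# method and argument names other than pp_rank_to_shard are assumed.
def run_sharding(self, input_dir: str, output_dir: str) -> None:
    for pp_rank in range(self.pp_degree):
        # Shard one pipeline stage at a time. save_configs accumulates
        # shard entries and only writes the manifest on the last pp_rank.
        self.shard_lmi_dist_model(input_dir,
                                  output_dir,
                                  pp_rank_to_shard=pp_rank)
```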
Checklist:
`pytest tests.py -k "TestCorrectnessLmiDist" -m "lmi_dist"`
Feature/Issue validation/testing
Tested with:

- llama3 405b, pp=2 tp=8: sharded on one instance, and tested serving with two p4de instances.
- llama3 8b, pp=2 tp=4: sharded on one instance, and tested serving with two g6.12xlarge instances.
- llama3 8b, pp=1 tp=4: sharded on one instance, and tested serving with one g6.12xlarge instance.