Skip to content

Commit

Permalink
fix mpt run.py multi gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
bheilbrun committed Nov 13, 2023
1 parent 975726a commit 416eee2
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
1 change: 0 additions & 1 deletion examples/mpt/convert_hf_mpt_to_ft.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def write_zero_bias(weight_name: str, weight_file_path: str,
bias.tofile(bias_file_path)



def convert_weight_to_ft_each(out_dir: str, tensor_parallelism: int,
tensor_name: str, config: Dict[str, Any],
data: np.ndarray, data_type: torch.dtype):
Expand Down
5 changes: 5 additions & 0 deletions examples/mpt/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ def read_config(config_path: Path):
tokens_per_block = config['plugin_config']['tokens_per_block']
dtype = config['builder_config']['precision']

num_kv_heads = (num_kv_heads + world_size - 1) // world_size
assert (num_heads % world_size) == 0
num_heads = num_heads // world_size
hidden_size = hidden_size // world_size

model_config = ModelConfig(num_heads=num_heads,
num_kv_heads=num_kv_heads,
hidden_size=hidden_size,
Expand Down

0 comments on commit 416eee2

Please sign in to comment.