
Commit 275390b

[feat] Update sync model by tensor, fix tMbs problem, add qwen train benchmark.
1 parent d9b5f10

File tree: 16 files changed, +492 -60 lines

applications/ColossalChat/coati/distributed/comm.py

Lines changed: 29 additions & 0 deletions
@@ -55,3 +55,32 @@ def ray_broadcast_tensor_dict(
     if rank == src:
         out_dict = tensor_dict
     return out_dict
+
+
+def ray_broadcast_tensor_dict_and_load(
+    producer_obj, tensor_dict: Dict[str, torch.Tensor], src: int = 0, device=None, group_name: str = "default"
+):
+    rank = cc.get_rank(group_name)
+    if rank == src:
+        metadata = []
+        for k, v in tensor_dict.items():
+            metadata.append((k, v.shape, v.dtype))
+    else:
+        metadata = None
+    metadata = ray_broadcast_object(metadata, src, device, group_name)
+    for k, shape, dtype in metadata:
+        if "consumer_global_step" == k:
+            continue
+        if rank == src:
+            tensor = tensor_dict[k]
+        else:
+            out_dict = {}
+            tensor = torch.empty(shape, dtype=dtype, device=device)
+        cc.broadcast(tensor, src, group_name)
+        if rank != src:
+            out_dict[k] = tensor
+            producer_obj.load_state_dict(out_dict)
+            del out_dict
+            torch.npu.empty_cache()
+    if rank == src:
+        out_dict = tensor_dict
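
For context, the new helper broadcasts the weights one tensor at a time and, on the receiving side, loads each tensor into the producer as soon as it arrives, so the full state dict is never materialized on the inference devices. Below is a minimal sketch of the two call patterns, mirroring the consumer and producer changes later in this commit; the wrapper names `consumer_side_sync`/`producer_side_sync` and the `coati.distributed.comm` import path are illustrative assumptions inferred from the file layout, not part of the diff.

```python
from typing import Dict

import torch

from coati.distributed.comm import ray_broadcast_tensor_dict_and_load


def consumer_side_sync(consumer, state_dict: Dict[str, torch.Tensor]) -> None:
    """Source rank: owns the trained weights and has no producer object to fill."""
    ray_broadcast_tensor_dict_and_load(
        None,                        # producer_obj: nothing to load on the sending side
        state_dict,                  # full trained state dict, broadcast one tensor at a time
        src=consumer.num_producers,  # the consumer rank acts as the broadcast source
        device=consumer.device,
        group_name="sync_model",
    )


def producer_side_sync(producer) -> None:
    """Receiving ranks: pass the producer itself so each broadcast tensor is
    loaded via producer.load_state_dict({k: tensor}) and freed immediately,
    instead of accumulating the whole state dict in device memory first."""
    ray_broadcast_tensor_dict_and_load(
        producer,  # producer_obj: exposes load_state_dict
        None,      # receivers contribute no tensors
        producer.num_producers,
        device=producer.device,
        group_name="sync_model",
    )
```

Note that the helper skips the `consumer_global_step` key, which is why the producer-side bookkeeping for that entry disappears in `producer.py` below.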

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 11 additions & 4 deletions
@@ -15,7 +15,7 @@
 from colossalai.initialize import launch
 from colossalai.nn.optimizer import HybridAdam
 
-from .comm import ray_broadcast_tensor_dict
+from .comm import ray_broadcast_tensor_dict, ray_broadcast_tensor_dict_and_load
 from .utils import bind_batch, post_recv, unbind_batch
 
 
@@ -172,6 +172,8 @@ def loop(self) -> None:
                     )
                     self.profiler.enter("step")
                     loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
+                    del batch
+                    del raw_mini_batches_metric_dict
                     self.profiler.exit("step")
                     self.buffer = self.buffer[
                         effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
@@ -303,16 +305,21 @@ def loop(self) -> None:
                 state_dict = self.state_dict()
                 if self.pp_size > 1:
                     if self.tp_rank == 0 and self.dp_rank == 0:
-                        ray_broadcast_tensor_dict(
+                        ray_broadcast_tensor_dict_and_load(
+                            None,
                             state_dict,
                             src=self.num_producers,
                             device=self.device,
                             group_name=f"sync_model_{self.pp_rank}",
                         )
                 else:
                     if self.rank == 0:
-                        ray_broadcast_tensor_dict(
-                            state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
+                        ray_broadcast_tensor_dict_and_load(
+                            None,
+                            state_dict,
+                            src=self.num_producers,
+                            device=self.device,
+                            group_name="sync_model",
                         )
                 del state_dict
                 torch.npu.empty_cache()

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 1 addition & 0 deletions
@@ -62,6 +62,7 @@ def __init__(
             batch_size,
             model_config,
             plugin_config,
+            generate_config,
             minibatch_size,
             save_interval=save_interval,
             save_dir=save_dir,

applications/ColossalChat/coati/distributed/producer.py

Lines changed: 6 additions & 14 deletions
@@ -17,7 +17,7 @@
 from torch.utils.data import DataLoader, DistributedSampler
 from transformers import AutoTokenizer
 
-from .comm import ray_broadcast_tensor_dict
+from .comm import ray_broadcast_tensor_dict, ray_broadcast_tensor_dict_and_load
 from .inference_backend import BACKEND_MAP
 from .utils import safe_append_to_jsonl_file
 
@@ -191,6 +191,7 @@ def setup(self) -> None:
            )
        else:
            cc.init_collective_group(self.num_producers + 1, self.producer_idx, backend="hccl", group_name="sync_model")
+        cc.init_collective_group(self.num_producers, self.producer_idx, backend="hccl", group_name="producer_group")
 
     def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         raise NotImplementedError
@@ -340,25 +341,16 @@ def loop(self) -> None:
                         print(
                             f"[P{self.producer_idx}] Sync model PP stage {pp_idx} episode {episode} step {(i + 1) // self.num_microbatches - 1}"
                         )
-                        state_dict = ray_broadcast_tensor_dict(
-                            None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
+                        ray_broadcast_tensor_dict_and_load(
+                            self, None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
                         )
-                        if "consumer_global_step" in state_dict:
-                            self.consumer_global_step = state_dict.pop("consumer_global_step").item()
-                        self.load_state_dict(state_dict)
                 else:
                     print(
                         f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
                     )
-                    state_dict = ray_broadcast_tensor_dict(
-                        None, self.num_producers, device=self.device, group_name="sync_model"
+                    ray_broadcast_tensor_dict_and_load(
+                        self, None, self.num_producers, device=self.device, group_name=f"sync_model"
                     )
-                    if "consumer_global_step" in state_dict:
-                        self.consumer_global_step = state_dict.pop("consumer_global_step").item()
-                    self.load_state_dict(state_dict)
-                self.profiler.exit("sync_model")
-                del state_dict
-                torch.npu.empty_cache()
                 if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
                     "enable_sleep_mode", False
                 ):

applications/ColossalChat/rl_example.py

Lines changed: 3 additions & 1 deletion
@@ -166,6 +166,7 @@
     parser.add_argument(
         "--enable_profiling", action="store_true", default=False, help="Enable profiling for the training process."
     )
+    parser.add_argument("--cpu_offload", action="store_true", default=False, help="CPU offload.")
     args = parser.parse_args()
 
     if args.train_minibatch_size is None:
@@ -251,7 +252,7 @@
         )
         generate_config.update(
             dict(
-                max_tokens=args.max_new_tokens,  # max new tokens
+                max_tokens=args.max_new_tokens + args.max_prompt_tokens,  # max prompt + new tokens
                 include_stop_str_in_output=True,
                 stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
                 ignore_eos=True if args.reward_type == "think_answer_tags" else False,
@@ -344,6 +345,7 @@
                     1, args.train_microbatch_size // args.pipeline_parallel_size
                 ),  # microbatch size should be set to train_microbatch_size // pp_size
                 "zero_stage": args.zero_stage,
+                "cpu_offload": args.cpu_offload,
                 "max_norm": 1.0,
                 "enable_flash_attention": True,
                 "sp_size": args.tensor_parallel_size,

colossalai/shardformer/modeling/qwen2.py

Lines changed: 9 additions & 38 deletions
@@ -12,10 +12,7 @@
 )
 
 try:
-    from transformers.modeling_attn_mask_utils import (
-        _prepare_4d_causal_attention_mask,
-        _prepare_4d_causal_attention_mask_for_sdpa,
-    )
+    from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
     from transformers.models.qwen2.modeling_qwen2 import (
         Qwen2Attention,
         Qwen2ForCausalLM,
@@ -132,46 +129,20 @@ def qwen2_model_forward(
     else:
         position_ids = position_ids.view(-1, seq_length).long()
 
-    if (
-        not shard_config.enable_flash_attention
-        and attention_mask is not None
-        and self._attn_implementation == "flash_attention_2"
-        and use_cache
-    ):
-        is_padding_right = attention_mask[:, -1].sum().item() != batch_size
-        if is_padding_right:
-            raise ValueError(
-                "You are attempting to perform batched generation with padding_side='right'"
-                " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to "
-                " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
-            )
     # embed positions, for the first stage, hidden_states is the input embeddings,
     # for the other stages, hidden_states is the output of the previous stage
     if shard_config.enable_flash_attention:
        # in this case, attention_mask is a dict rather than a tensor
        attention_mask = None
    else:
-        if self._attn_implementation == "flash_attention_2":
-            # 2d mask is passed through the layers
-            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-        elif self._attn_implementation == "sdpa" and not output_attentions:
-            # output_attentions=True can not be supported when using SDPA, and we fall back on
-            # the manual implementation that requires a 4D causal mask in all cases.
-            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                attention_mask,
-                (batch_size, seq_length),
-                inputs_embeds,
-                past_key_values_length,
-            )
-        else:
-            # 4d mask is passed through the layers
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask,
-                (batch_size, seq_length),
-                hidden_states,
-                past_key_values_length,
-                sliding_window=self.config.sliding_window,
-            )
+        # 4d mask is passed through the layers
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask,
+            (batch_size, seq_length),
+            hidden_states,
+            past_key_values_length,
+            sliding_window=self.config.sliding_window,
+        )
 
     if stage_manager.is_first_stage():
         if shard_config.enable_sequence_parallelism:

examples/language/performance_evaluator.py

Lines changed: 3 additions & 3 deletions
@@ -161,18 +161,18 @@ def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
         ) * (
             1.0 + (seq_len / (6.0 * self.hidden_size)) + (self.vocab_size / (16.0 * self.num_layers * self.hidden_size))
         )
-        self.flop += batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint))
+        self.flop += batch_size * (seq_len // 1024) * self.model_numel * (3 + int(self.enable_grad_checkpoint))
 
     def on_fit_end(self) -> None:
         avg_duration = all_reduce_mean(self.timer.duration, self.coordinator.world_size)
         avg_throughput = self.num_samples * self.dp_world_size / (avg_duration + 1e-12)
         mp_world_size = self.coordinator.world_size // self.dp_world_size
-        avg_tflops_per_gpu_megatron = self.flop_megatron / 1e12 / (avg_duration + 1e-12) / mp_world_size
+        self.flop_megatron / 1e12 / (avg_duration + 1e-12) / mp_world_size
         avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size
         self.coordinator.print_on_master(
             f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop_megatron: {self.flop_megatron}, flop: {self.flop}, avg_duration: {avg_duration}, "
             f"avg_throughput: {avg_throughput}"
         )
         self.coordinator.print_on_master(
-            f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU by Megatron: {avg_tflops_per_gpu_megatron:.2f}, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}"
+            f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}"
        )
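
As an aside, per-GPU TFLOPS figures like the one printed above are usually derived from the standard dense-transformer estimate: about 2N FLOPs per token for a forward pass over an N-parameter model, three times that for forward plus backward, and roughly one extra forward pass when activation checkpointing recomputes activations. Below is a minimal sketch of that textbook estimate; the function name and the numbers in the example call are illustrative, and this is not the evaluator's exact formula (which, after this change, also rescales `seq_len` by 1024 and drops the factor of 2).

```python
def approx_training_tflops(
    model_numel: float,      # number of parameters N
    tokens_per_step: float,  # batch_size * seq_len, summed over data-parallel ranks
    step_time_s: float,      # measured wall-clock time per optimizer step
    num_gpus: int,
    grad_checkpoint: bool = False,
) -> float:
    # ~2*N FLOPs per token for a forward pass; forward + backward is ~3x that;
    # full activation checkpointing adds roughly one more forward pass.
    flops_per_token = 2.0 * model_numel * (3 + int(grad_checkpoint))
    return flops_per_token * tokens_per_step / step_time_s / num_gpus / 1e12


# Illustrative numbers only: a 7B-parameter model, 16 sequences of 4096 tokens
# per step, 2 seconds per step on 8 GPUs.
print(f"{approx_training_tflops(7e9, 16 * 4096, 2.0, 8, grad_checkpoint=True):.1f} TFLOPS/GPU")
```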

examples/language/qwen2/README.md

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@

# Pretraining LLaMA-1/2/3: best practices for building LLaMA-1/2/3-like base models

### LLaMA3
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/examples/images/LLaMA3-70B-H100.png" width=600/>
</p>

- 70 billion parameter LLaMA3 model training accelerated by 18%

### LLaMA2
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/llama2_pretraining.png" width=600/>
</p>

- 70 billion parameter LLaMA2 model training accelerated by 195%
[[blog]](https://www.hpc-ai.tech/blog/70b-llama2-training)

### LLaMA1
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/examples/images/LLaMA_pretraining.png" width=600/>
</p>

- 65-billion-parameter large model pretraining accelerated by 38%
[[blog]](https://www.hpc-ai.tech/blog/large-model-pretraining)

## Usage

> ⚠ This example only has a benchmarking script. For training/finetuning, please refer to [applications/Colossal-LLaMA](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA).

### 1. Installation

Please install the latest ColossalAI from source.

```bash
BUILD_EXT=1 pip install -U git+https://github.com/hpcaitech/ColossalAI
```

Then install the other dependencies.

```bash
pip install -r requirements.txt
```

### 2. Shell Script Examples

For your convenience, we provide shell scripts to run the benchmark with various configurations.

You can find them in the `scripts/benchmark_7B` and `scripts/benchmark_70B` directories. The main command should be in the format of:
```bash
colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
  benchmark.py --OTHER_CONFIGURATIONS
```
Here we show an example of how to run LLaMA pretraining with `gemini, batch_size=16, sequence_length=4096, gradient_checkpoint=True, flash_attn=True`.

#### a. Running environment
This experiment was performed on 4 computing nodes with 32 A800/H800 80GB GPUs in total for LLaMA-1 65B or LLaMA-2 70B. The nodes are connected with RDMA, and the GPUs within one node are fully connected with NVLink.

#### b. Running command

```bash
cd scripts/benchmark_7B
```

First, put your host file (`hosts.txt`) in this directory with your real host IPs or host names.

Here is a sample `hosts.txt`:
```text
hostname1
hostname2
hostname3
hostname4
```

Then add environment variables to the script if needed.

Finally, run the following command to start training:

```bash
bash gemini.sh
```

If you encounter an out-of-memory (OOM) error during training with the `gemini.sh` script, switching to `gemini_auto.sh` might be a solution, since gemini_auto sets an upper limit on GPU memory usage by offloading part of the model parameters and optimizer states to CPU memory. But there is a trade-off: `gemini_auto.sh` is a bit slower, since more data is transferred between CPU and GPU.

#### c. Results
If you run the above command successfully, you will get results like the following:
`max memory usage: 55491.10 MB, throughput: 24.26 samples/s, TFLOPS/GPU: 167.43`.


## Reference
```bibtex
@article{bian2021colossal,
  title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training},
  author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang},
  journal={arXiv preprint arXiv:2110.14883},
  year={2021}
}
```

```bibtex
@software{openlm2023openllama,
  author = {Geng, Xinyang and Liu, Hao},
  title = {OpenLLaMA: An Open Reproduction of LLaMA},
  month = May,
  year = 2023,
  url = {https://github.com/openlm-research/open_llama}
}
```

```bibtex
@software{together2023redpajama,
  author = {Together Computer},
  title = {RedPajama-Data: An Open Source Recipe to Reproduce LLaMA training dataset},
  month = April,
  year = 2023,
  url = {https://github.com/togethercomputer/RedPajama-Data}
}
```

```bibtex
@article{touvron2023llama,
  title={Llama: Open and efficient foundation language models},
  author={Touvron, Hugo and Lavril, Thibaut and Izacard, Gautier and Martinet, Xavier and Lachaux, Marie-Anne and Lacroix, Timoth{\'e}e and Rozi{\`e}re, Baptiste and Goyal, Naman and Hambro, Eric and Azhar, Faisal and others},
  journal={arXiv preprint arXiv:2302.13971},
  year={2023}
}
```
