diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py
index 24559f591c..8ac7273daa 100644
--- a/swift/megatron/argument/megatron_args.py
+++ b/swift/megatron/argument/megatron_args.py
@@ -14,6 +14,8 @@ class ExtraMegatronArguments:
     padded_vocab_size: Optional[int] = None
     rope_scaling: Optional[Union[dict, str]] = None
     torch_dtype: Optional[torch.dtype] = None
+    thread_count: Optional[int] = None
+    max_epochs: Optional[int] = None
 
 
 @dataclass
diff --git a/swift/megatron/train/sft.py b/swift/megatron/train/sft.py
index 4749b2c7e8..68df8f09d1 100644
--- a/swift/megatron/train/sft.py
+++ b/swift/megatron/train/sft.py
@@ -10,7 +10,7 @@
 from swift.llm.train import SwiftSft
 from swift.utils import get_logger, is_master, plot_images
 from ..argument import MegatronTrainArguments
-from ..utils import patch_megatron_tokenizer
+from ..utils import patch_megatron
 from .patcher import patch_megatron_data_collator, patch_training_log
 from .utils import build_streaming_dataloader, forward_step, get_swift_datasets_provider
 
@@ -26,7 +26,7 @@ def __init__(self, args: Union[List[str], MegatronTrainArguments, None] = None)
         super(SwiftSft, self).__init__(args)
         args = self.args
         _, self.processor = args.get_model_processor(load_model=False)
-        patch_megatron_tokenizer(self.processor)
+        patch_megatron(self.processor)
         args.init_model_args(self.processor.model_info.config)
         self._prepare_template()
         self.template.use_megatron = True
diff --git a/swift/megatron/train/utils.py b/swift/megatron/train/utils.py
index 1a38f28a90..79b18b7162 100644
--- a/swift/megatron/train/utils.py
+++ b/swift/megatron/train/utils.py
@@ -72,20 +72,23 @@ def _broadcast(item):
                 item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())
 
     if mpu.get_tensor_model_parallel_rank() == 0:
-
-        if data_iterator is not None:
+        try:
             data = next(data_iterator)
+        except StopIteration:
+            seq_length = -1
         else:
-            data = None
-        tokens = data['input_ids']
-        seq_length = torch.tensor(tokens.shape[1]).cuda(non_blocking=True)
-        batch = {
-            'tokens': tokens.cuda(non_blocking=True),
-            'labels': data['labels'].cuda(non_blocking=True),
-            'attention_mask': None if 'attention_mask' not in data else data['attention_mask'].cuda(non_blocking=True),
-            'position_ids': data['position_ids'].cuda(non_blocking=True)
-        }
+            tokens = data['input_ids']
+            seq_length = tokens.shape[1]
+            batch = {
+                'tokens': tokens.cuda(non_blocking=True),
+                'labels': data['labels'].cuda(non_blocking=True),
+                'attention_mask': None if 'attention_mask' not in data else data['attention_mask'].cuda(non_blocking=True),
+                'position_ids': data['position_ids'].cuda(non_blocking=True)
+            }
+        seq_length = torch.tensor(seq_length).cuda(non_blocking=True)
         _broadcast(seq_length)
+        if seq_length.item() == -1:
+            raise StopIteration
         if args.pipeline_model_parallel_size == 1:
             _broadcast(batch['tokens'])
             _broadcast(batch['labels'])
@@ -104,6 +107,8 @@ def _broadcast(item):
     else:
         seq_length = torch.empty((), dtype=torch.int64, device=torch.cuda.current_device())
         _broadcast(seq_length)
+        if seq_length.item() == -1:
+            raise StopIteration
         micro_batch_size = 1  # use qkv_format 'thd'
         tokens = torch.empty((micro_batch_size, seq_length), dtype=torch.int64, device=torch.cuda.current_device())
diff --git a/swift/megatron/utils/__init__.py b/swift/megatron/utils/__init__.py
index 4d2b722a2c..173d1a28a6 100644
--- a/swift/megatron/utils/__init__.py
+++ b/swift/megatron/utils/__init__.py
@@ -1,4 +1,4 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from .convert import convert_hf2mcore, convert_mcore2hf
-from .patcher import patch_megatron_tokenizer
+from .patcher import patch_megatron
diff --git a/swift/megatron/utils/convert.py b/swift/megatron/utils/convert.py
index 42d37b945e..443c4d1adc 100644
--- a/swift/megatron/utils/convert.py
+++ b/swift/megatron/utils/convert.py
@@ -12,7 +12,7 @@
 from swift.utils import get_logger, get_n_params_grads
 from ..argument import MegatronArguments
 from ..model import get_megatron_model_meta
-from .patcher import patch_megatron_tokenizer, patch_torch_dist_shard
+from .patcher import patch_megatron
 
 logger = get_logger()
 
@@ -66,13 +66,13 @@ def convert_hf2mcore(args: ExportArguments) -> None:
     if args.thread_count is None:
         checkpoint_size = sum(get_n_params_grads(hf_model)[0]) * torch.finfo(args.torch_dtype).bits // 8e9
         args.thread_count = max(math.ceil(checkpoint_size / 10), 2)  # 10GB
-    patch_torch_dist_shard(args.thread_count)
 
     megatron_model_meta = get_megatron_model_meta(args.model_type)
     assert megatron_model_meta is not None, f'Model: {args.model} is not supported.'
     kwargs = megatron_model_meta.convert_hf_config(processor.model_info.config)
-    megatron_args = MegatronArguments(**kwargs, **convert_kwargs, save=args.output_dir, torch_dtype=args.torch_dtype)
-    patch_megatron_tokenizer(processor)
+    megatron_args = MegatronArguments(
+        **kwargs, **convert_kwargs, save=args.output_dir, torch_dtype=args.torch_dtype, thread_count=args.thread_count)
+    patch_megatron(processor)
     extra_args = megatron_args.parse_to_megatron()
     initialize_megatron(args_defaults=extra_args)
 
@@ -93,13 +93,13 @@ def convert_mcore2hf(args: ExportArguments) -> None:
     if args.thread_count is None:
         checkpoint_size = sum(get_n_params_grads(hf_model)[0]) * torch.finfo(args.torch_dtype).bits // 8e9
         args.thread_count = max(math.ceil(checkpoint_size / 10), 2)  # 10GB
-    patch_torch_dist_shard(args.thread_count)
 
     megatron_model_meta = get_megatron_model_meta(args.model_type)
     assert megatron_model_meta is not None, f'Model: {args.model} is not supported.'
     kwargs = megatron_model_meta.convert_hf_config(processor.model_info.config)
-    megatron_args = MegatronArguments(**kwargs, **convert_kwargs, load=args.mcore_model, torch_dtype=args.torch_dtype)
-    patch_megatron_tokenizer(processor)
+    megatron_args = MegatronArguments(
+        **kwargs, **convert_kwargs, load=args.mcore_model, torch_dtype=args.torch_dtype, thread_count=args.thread_count)
+    patch_megatron(processor)
     extra_args = megatron_args.parse_to_megatron()
     initialize_megatron(args_defaults=extra_args)
 
diff --git a/swift/megatron/utils/patcher.py b/swift/megatron/utils/patcher.py
index 5a4aed76fc..e5bff531ff 100644
--- a/swift/megatron/utils/patcher.py
+++ b/swift/megatron/utils/patcher.py
@@ -7,8 +7,13 @@
 logger = get_logger()
 
 
-def patch_megatron_tokenizer(tokenizer):
+def patch_megatron(tokenizer):
+    # patch tokenizer
+    patch_tokenizer(tokenizer)
+    patch_torch_dist_shard()
+    patch_cyclic_iter()
+def patch_tokenizer(tokenizer):
 
     def build_tokenizer(args):
         args.extra_vocab_size = args.padded_vocab_size - tokenizer.vocab_size
         return tokenizer
@@ -16,11 +21,28 @@ def build_tokenizer(args):
     global_vars.build_tokenizer = build_tokenizer
 
 
-def patch_torch_dist_shard(thread_count):
+def patch_torch_dist_shard():
     __init__ = TorchDistSaveShardedStrategy.__init__
 
-    def __new_init__(*args, **kwargs):
-        kwargs['thread_count'] = thread_count
-        return __init__(*args, **kwargs)
+    def __new_init__(*_args, **kwargs):
+        args = get_args()
+        if args.thread_count is not None:
+            kwargs['thread_count'] = args.thread_count
+        return __init__(*_args, **kwargs)
 
     TorchDistSaveShardedStrategy.__init__ = __new_init__
+
+
+def patch_cyclic_iter():
+    from megatron.training import training
+    def cyclic_iter(iter):
+        args = get_args()
+        n_epoch = 0
+        while True:
+            for x in iter:
+                yield x
+            logger.info(f'Epoch {n_epoch} has ended.')
+            n_epoch += 1
+            if args.max_epochs is not None and n_epoch >= args.max_epochs:
+                break
+    training.cyclic_iter = cyclic_iter
diff --git a/swift/utils/env.py b/swift/utils/env.py
index 29a6cbe791..b674f6aba3 100644
--- a/swift/utils/env.py
+++ b/swift/utils/env.py
@@ -78,7 +78,7 @@ def is_mp() -> bool:
 def is_mp_ddp() -> bool:
     # patch_mp_ddp will occur when `import swift`.
     if is_dist() and is_mp():
-        logger.info('Using MP + DDP(device_map)')
+        logger.info('Using MP(device_map) + DDP')
         return True
     return False
 
diff --git a/tests/megatron/test_model.py b/tests/megatron/test_model.py
index f4e154879d..21da170f3d 100644
--- a/tests/megatron/test_model.py
+++ b/tests/megatron/test_model.py
@@ -12,7 +12,7 @@ def get_mg_model_tokenizer(model_id):
     model_info = processor.model_info
     kwargs = megatron_model_meta.convert_hf_config(model_info.config)
     megatron_args = MegatronArguments(**kwargs, seq_length=1, use_cpu_initialization=True, no_initialization=True)
-    patch_megatron_tokenizer(processor)
+    patch_megatron(processor)
     extra_args = megatron_args.parse_to_megatron()
     initialize_megatron(args_defaults=extra_args)
     mg_model = megatron_model_meta.model_provider()
@@ -59,7 +59,7 @@ def test_align(hf_model, mg_model, processor):
     from swift.utils import set_default_ddp_config
     from swift.megatron.argument import MegatronArguments
     from swift.megatron.model import get_megatron_model_meta
-    from swift.megatron.utils import patch_megatron_tokenizer
+    from swift.megatron.utils import patch_megatron
     # test_bf16_fp32()
     hf_model, mg_model, processor = get_mg_model_tokenizer(model_id)
     test_align(hf_model, mg_model, processor)
diff --git a/tests/megatron/test_save.py b/tests/megatron/test_save.py
index cfc78182ae..9417467784 100644
--- a/tests/megatron/test_save.py
+++ b/tests/megatron/test_save.py
@@ -22,7 +22,7 @@ def get_mg_model_tokenizer():
         save='mcore-hf-test',
         no_load_optim=True,
         no_load_rng=True)
-    patch_megatron_tokenizer(processor)
+    patch_megatron(processor)
     extra_args = megatron_args.parse_to_megatron()
     initialize_megatron(args_defaults=extra_args)
     mg_model = megatron_model_meta.model_provider()
@@ -57,5 +57,5 @@ def test_save():
     from swift.utils import set_default_ddp_config
     from swift.megatron.argument import MegatronArguments
     from swift.megatron.model import get_megatron_model_meta
-    from swift.megatron.utils import patch_megatron_tokenizer
+    from swift.megatron.utils import patch_megatron
     test_save()