
[megatron] support max_epochs #3677

Open
wants to merge 4 commits into base: main
2 changes: 2 additions & 0 deletions swift/megatron/argument/megatron_args.py
@@ -14,6 +14,8 @@ class ExtraMegatronArguments:
padded_vocab_size: Optional[int] = None
rope_scaling: Optional[Union[dict, str]] = None
torch_dtype: Optional[torch.dtype] = None
thread_count: Optional[int] = None
max_epochs: Optional[int] = None


@dataclass
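For orientation, a minimal sketch of the two new fields in isolation; the remaining ExtraMegatronArguments fields and swift-specific helpers are omitted, and the comments reflect how the rest of this PR consumes the values:

from dataclasses import dataclass
from typing import Optional

@dataclass
class ExtraMegatronArguments:
    # Writer threads for torch_dist sharded saves; None lets convert.py derive it.
    thread_count: Optional[int] = None
    # Upper bound on full dataset passes; None keeps the endless cyclic behavior.
    max_epochs: Optional[int] = None

args = ExtraMegatronArguments(max_epochs=3)
assert args.thread_count is None and args.max_epochs == 3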
4 changes: 2 additions & 2 deletions swift/megatron/train/sft.py
@@ -10,7 +10,7 @@
from swift.llm.train import SwiftSft
from swift.utils import get_logger, is_master, plot_images
from ..argument import MegatronTrainArguments
from ..utils import patch_megatron_tokenizer
from ..utils import patch_megatron
from .patcher import patch_megatron_data_collator, patch_training_log
from .utils import build_streaming_dataloader, forward_step, get_swift_datasets_provider

@@ -26,7 +26,7 @@ def __init__(self, args: Union[List[str], MegatronTrainArguments, None] = None)
super(SwiftSft, self).__init__(args)
args = self.args
_, self.processor = args.get_model_processor(load_model=False)
patch_megatron_tokenizer(self.processor)
patch_megatron(self.processor)
args.init_model_args(self.processor.model_info.config)
self._prepare_template()
self.template.use_megatron = True
27 changes: 16 additions & 11 deletions swift/megatron/train/utils.py
@@ -72,20 +72,23 @@ def _broadcast(item):
item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())

if mpu.get_tensor_model_parallel_rank() == 0:

if data_iterator is not None:
try:
data = next(data_iterator)
except StopIteration:
seq_length = -1
else:
data = None
tokens = data['input_ids']
seq_length = torch.tensor(tokens.shape[1]).cuda(non_blocking=True)
batch = {
'tokens': tokens.cuda(non_blocking=True),
'labels': data['labels'].cuda(non_blocking=True),
'attention_mask': None if 'attention_mask' not in data else data['attention_mask'].cuda(non_blocking=True),
'position_ids': data['position_ids'].cuda(non_blocking=True)
}
tokens = data['input_ids']
seq_length = tokens.shape[1]
batch = {
'tokens': tokens.cuda(non_blocking=True),
'labels': data['labels'].cuda(non_blocking=True),
'attention_mask': None if 'attention_mask' not in data else data['attention_mask'].cuda(non_blocking=True),
'position_ids': data['position_ids'].cuda(non_blocking=True)
}
seq_length = torch.tensor(seq_length).cuda(non_blocking=True)
_broadcast(seq_length)
if seq_length.item() == -1:
raise StopIteration
if args.pipeline_model_parallel_size == 1:
_broadcast(batch['tokens'])
_broadcast(batch['labels'])
@@ -104,6 +107,8 @@ def _broadcast(item):
else:
seq_length = torch.empty((), dtype=torch.int64, device=torch.cuda.current_device())
_broadcast(seq_length)
if seq_length.item() == -1:
raise StopIteration

micro_batch_size = 1 # use qkv_format 'thd'
tokens = torch.empty((micro_batch_size, seq_length), dtype=torch.int64, device=torch.cuda.current_device())
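The change above propagates dataset exhaustion across tensor-parallel ranks with a sentinel: the source rank turns StopIteration into seq_length = -1, the value is broadcast, and every rank then raises StopIteration in lockstep. A minimal standalone sketch of the same pattern, assuming an already initialized torch.distributed process group and a CUDA device; the function and variable names are illustrative, not part of this PR:

import torch
import torch.distributed as dist

def next_batch_or_stop(data_iterator, src_rank, group):
    """Rank `src_rank` reads the iterator; every rank learns when it is exhausted."""
    if dist.get_rank() == src_rank:
        try:
            data = next(data_iterator)
            seq_length = torch.tensor(data['input_ids'].shape[1], device='cuda')
        except StopIteration:
            data = None
            seq_length = torch.tensor(-1, device='cuda')  # sentinel: no data left
    else:
        data = None
        seq_length = torch.empty((), dtype=torch.int64, device='cuda')
    dist.broadcast(seq_length, src_rank, group=group)
    if seq_length.item() == -1:
        # Raised on every rank together, so the caller can end the epoch cleanly.
        raise StopIteration
    return data, seq_length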
2 changes: 1 addition & 1 deletion swift/megatron/utils/__init__.py
@@ -1,4 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from .convert import convert_hf2mcore, convert_mcore2hf
from .patcher import patch_megatron_tokenizer
from .patcher import patch_megatron
14 changes: 7 additions & 7 deletions swift/megatron/utils/convert.py
@@ -12,7 +12,7 @@
from swift.utils import get_logger, get_n_params_grads
from ..argument import MegatronArguments
from ..model import get_megatron_model_meta
from .patcher import patch_megatron_tokenizer, patch_torch_dist_shard
from .patcher import patch_megatron

logger = get_logger()

@@ -66,13 +66,13 @@ def convert_hf2mcore(args: ExportArguments) -> None:
if args.thread_count is None:
checkpoint_size = sum(get_n_params_grads(hf_model)[0]) * torch.finfo(args.torch_dtype).bits // 8e9
args.thread_count = max(math.ceil(checkpoint_size / 10), 2) # 10GB
patch_torch_dist_shard(args.thread_count)

megatron_model_meta = get_megatron_model_meta(args.model_type)
assert megatron_model_meta is not None, f'Model: {args.model} is not supported.'
kwargs = megatron_model_meta.convert_hf_config(processor.model_info.config)
megatron_args = MegatronArguments(**kwargs, **convert_kwargs, save=args.output_dir, torch_dtype=args.torch_dtype)
patch_megatron_tokenizer(processor)
megatron_args = MegatronArguments(
**kwargs, **convert_kwargs, save=args.output_dir, torch_dtype=args.torch_dtype, thread_count=args.thread_count)
patch_megatron(processor)
extra_args = megatron_args.parse_to_megatron()
initialize_megatron(args_defaults=extra_args)

@@ -93,13 +93,13 @@ def convert_mcore2hf(args: ExportArguments) -> None:
if args.thread_count is None:
checkpoint_size = sum(get_n_params_grads(hf_model)[0]) * torch.finfo(args.torch_dtype).bits // 8e9
args.thread_count = max(math.ceil(checkpoint_size / 10), 2) # 10GB
patch_torch_dist_shard(args.thread_count)

megatron_model_meta = get_megatron_model_meta(args.model_type)
assert megatron_model_meta is not None, f'Model: {args.model} is not supported.'
kwargs = megatron_model_meta.convert_hf_config(processor.model_info.config)
megatron_args = MegatronArguments(**kwargs, **convert_kwargs, load=args.mcore_model, torch_dtype=args.torch_dtype)
patch_megatron_tokenizer(processor)
megatron_args = MegatronArguments(
**kwargs, **convert_kwargs, load=args.mcore_model, torch_dtype=args.torch_dtype, thread_count=args.thread_count)
patch_megatron(processor)
extra_args = megatron_args.parse_to_megatron()
initialize_megatron(args_defaults=extra_args)
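A side note on the refactor above: thread_count is no longer bound when patch_torch_dist_shard is called; it now travels through MegatronArguments and is read back from the global Megatron args when the sharded-save strategy is constructed (see patcher.py below). A toy sketch of that late-binding monkey patch, using stand-in names instead of the real Megatron classes:

class SaveStrategy:
    """Stand-in for TorchDistSaveShardedStrategy."""

    def __init__(self, *, thread_count=1):
        self.thread_count = thread_count

_GLOBAL_ARGS = {'thread_count': 4}  # stand-in for megatron.training.get_args()

_original_init = SaveStrategy.__init__

def _patched_init(self, *args, **kwargs):
    # Look the value up at construction time instead of capturing it at patch time.
    thread_count = _GLOBAL_ARGS['thread_count']
    if thread_count is not None:
        kwargs['thread_count'] = thread_count
    return _original_init(self, *args, **kwargs)

SaveStrategy.__init__ = _patched_init
assert SaveStrategy().thread_count == 4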

32 changes: 27 additions & 5 deletions swift/megatron/utils/patcher.py
@@ -7,20 +7,42 @@
logger = get_logger()


def patch_megatron_tokenizer(tokenizer):
def patch_megatron(tokenizer):
# patch tokenizer
patch_tokenizer(tokenizer)
patch_torch_dist_shard()
patch_cyclic_iter()

def patch_tokenizer(tokenizer):
def build_tokenizer(args):
args.extra_vocab_size = args.padded_vocab_size - tokenizer.vocab_size
return tokenizer

global_vars.build_tokenizer = build_tokenizer


def patch_torch_dist_shard(thread_count):
def patch_torch_dist_shard():
__init__ = TorchDistSaveShardedStrategy.__init__

def __new_init__(*args, **kwargs):
kwargs['thread_count'] = thread_count
return __init__(*args, **kwargs)
def __new_init__(*_args, **kwargs):
args = get_args()
if args.thread_count is not None:
kwargs['thread_count'] = args.thread_count
return __init__(*_args, **kwargs)

TorchDistSaveShardedStrategy.__init__ = __new_init__


def patch_cyclic_iter():
from megatron.training import training
def cyclic_iter(iter):
args = get_args()
n_epoch = 0
while True:
for x in iter:
yield x
logger.info(f'Epoch {n_epoch} has ended.')
n_epoch += 1
if args.max_epochs is not None and n_epoch >= args.max_epochs:
break
training.cyclic_iter = cyclic_iter
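The patched cyclic_iter keeps Megatron's infinite-iterator contract but stops after max_epochs full passes (with max_epochs left as None it still cycles forever). A self-contained sketch of the same generator, with max_epochs passed explicitly instead of read from get_args():

def bounded_cyclic_iter(iterable, max_epochs=None):
    """Cycle over `iterable`; stop after `max_epochs` full passes (None = forever)."""
    n_epoch = 0
    while True:
        for x in iterable:
            yield x
        n_epoch += 1
        if max_epochs is not None and n_epoch >= max_epochs:
            break

# Two epochs over a three-sample dataset yield six items, then the iterator ends.
assert list(bounded_cyclic_iter([1, 2, 3], max_epochs=2)) == [1, 2, 3, 1, 2, 3]

Together with the StopIteration sentinel in swift/megatron/train/utils.py, this is what lets the new max_epochs argument bound a run that would otherwise cycle indefinitely over a finite, non-streaming dataset.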
2 changes: 1 addition & 1 deletion swift/utils/env.py
@@ -78,7 +78,7 @@ def is_mp() -> bool:
def is_mp_ddp() -> bool:
# patch_mp_ddp will occur when `import swift`.
if is_dist() and is_mp():
logger.info('Using MP + DDP(device_map)')
logger.info('Using MP(device_map) + DDP')
return True
return False

4 changes: 2 additions & 2 deletions tests/megatron/test_model.py
@@ -12,7 +12,7 @@ def get_mg_model_tokenizer(model_id):
model_info = processor.model_info
kwargs = megatron_model_meta.convert_hf_config(model_info.config)
megatron_args = MegatronArguments(**kwargs, seq_length=1, use_cpu_initialization=True, no_initialization=True)
patch_megatron_tokenizer(processor)
patch_megatron(processor)
extra_args = megatron_args.parse_to_megatron()
initialize_megatron(args_defaults=extra_args)
mg_model = megatron_model_meta.model_provider()
@@ -59,7 +59,7 @@ def test_align(hf_model, mg_model, processor):
from swift.utils import set_default_ddp_config
from swift.megatron.argument import MegatronArguments
from swift.megatron.model import get_megatron_model_meta
from swift.megatron.utils import patch_megatron_tokenizer
from swift.megatron.utils import patch_megatron
# test_bf16_fp32()
hf_model, mg_model, processor = get_mg_model_tokenizer(model_id)
test_align(hf_model, mg_model, processor)
4 changes: 2 additions & 2 deletions tests/megatron/test_save.py
@@ -22,7 +22,7 @@ def get_mg_model_tokenizer():
save='mcore-hf-test',
no_load_optim=True,
no_load_rng=True)
patch_megatron_tokenizer(processor)
patch_megatron(processor)
extra_args = megatron_args.parse_to_megatron()
initialize_megatron(args_defaults=extra_args)
mg_model = megatron_model_meta.model_provider()
@@ -57,5 +57,5 @@ def test_save():
from swift.utils import set_default_ddp_config
from swift.megatron.argument import MegatronArguments
from swift.megatron.model import get_megatron_model_meta
from swift.megatron.utils import patch_megatron_tokenizer
from swift.megatron.utils import patch_megatron
test_save()