Slow training speed, low GPU utilization #1793

Open
dayunyan opened this issue Oct 30, 2024 · 8 comments
Labels
bug Something isn't working

Comments

@dayunyan

Describe the bug/ 问题描述 (Mandatory / 必填)
When fine-tuning the Qwen2.5-3B model with LoRA, the first ~10 training steps are fast, around 1–2 s/step, but training then gradually slows to more than 10 s/step. GPU utilization reaches 100% early on, yet after about 100 steps it sits at 2% for long stretches.

  • Hardware Environment(Ascend/GPU/CPU) / 硬件环境:

GPU

  • Software Environment / 软件环境 (Mandatory / 必填):
    -- MindSpore version (e.g., 1.7.0.Bxxx) : 2.2.14
    -- Python version (e.g., Python 3.7.5) : 3.9
    -- OS platform and distribution (e.g., Linux Ubuntu 16.04): 22.04
    -- GCC/Compiler version (if compiled from source):

  • Execute Mode / 执行模式 (Mandatory / 必填)(PyNative/Graph):

/mode pynative
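
For reference, the execution mode was presumably selected with MindSpore's set_context API along these lines (device_target="GPU" is assumed from the hardware environment above; the actual call is not shown in the snippet below):

    import mindspore as ms

    # Assumed setup, not part of the attached snippet:
    # run in PyNative (eager) mode on the GPU backend described above.
    ms.set_context(mode=ms.PYNATIVE_MODE, device_target="GPU")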

To Reproduce / 重现步骤 (Mandatory / 必填)
Steps to reproduce the behavior:

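    # Requires: import mindspore as ms; import numpy as np; from tqdm import tqdm
    # Assumes model, optimizer, lr_scheduler, tokenizer, train_dataset, eval_dataset,
    # compute_ce_loss, compute_bleu_metrics, num_epochs, num_batches and
    # num_batches_eval are defined elsewhere in the attached Code.zip.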
    def forward_fn(input_ids, attention_mask, labels):
        output = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        # loss = compute_ce_loss(output.logits, labels)
        return output.loss, output.logits

    grad_fn = ms.value_and_grad(
        forward_fn, None, model.trainable_params(), has_aux=True
    )

    def train_step(input_ids, attention_mask, labels):
        (loss, logits), grads = grad_fn(input_ids, attention_mask, labels)
        optimizer.step(grads)
        return loss, logits

    for epoch in tqdm(range(num_epochs), desc="Epoch"):
        model.set_train(True)
        total_loss, total_step = 0, 0
        with tqdm(total=num_batches, leave=False, position=1, desc="train_step") as t:
            for step, pack in enumerate(train_dataset.create_dict_iterator()):
                input_ids = pack["input_ids"]
                attention_mask = pack["attention_mask"]
                labels = pack["labels"]
                loss, logits = train_step(
                    input_ids=input_ids, attention_mask=attention_mask, labels=labels
                )
                total_loss += loss.asnumpy()
                lr_scheduler.step()
                total_step += 1
                curr_loss = total_loss / total_step
                t.set_postfix({"train-loss": f"{curr_loss:.2f}"})
                t.update(1)
                # if profiler is not None:
                #     if step == 10:
                #         profiler.start()
                #     if step == 100:
                #         profiler.stop()

        model.set_train(False)
        eval_loss = 0
        total_step = 0
        eval_preds = []
        total_text_labels = []
        with tqdm(
            total=num_batches_eval, leave=False, position=1, desc="eval_step"
        ) as t:
            for step, pack in enumerate(eval_dataset.create_dict_iterator()):
                input_ids = pack["input_ids"]
                attention_mask = pack["attention_mask"]
                labels = pack["labels"]
                text_inputs = pack["text_inputs"]
                text_labels = pack["text_labels"]
                with ms._no_grad():
                    outputs = model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                    )
                loss = compute_ce_loss(outputs.logits, labels)
                eval_loss += loss.asnumpy()
                total_step += 1
                curr_eval_loss = eval_loss / total_step
                eval_preds.extend(
                    tokenizer.batch_decode(
                        outputs.logits.argmax(axis=-1).asnumpy(),
                        skip_special_tokens=True,
                    )
                )

                total_text_labels.extend(text_labels.tolist())
                t.set_postfix({"eval-loss": f"{curr_eval_loss:.2f}"})
                t.update(1)
        bleu_avg = compute_bleu_metrics(eval_preds, total_text_labels)
        # accuracy = correct / total * 100
        # print(f"{accuracy=} % on the evaluation dataset")
        eval_epoch_loss = eval_loss / eval_dataset.get_dataset_size()
        eval_ppl = np.exp(eval_epoch_loss)
        train_epoch_loss = total_loss / train_dataset.get_dataset_size()
        train_ppl = np.exp(train_epoch_loss)
        tqdm.write(
            f"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=} {bleu_avg=}"
        )
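
The commented-out profiler hooks in the training loop suggest MindSpore's Profiler was being considered; below is a minimal sketch of wiring it in, assuming the MindSpore 2.2 Profiler API and the step window (10 to 100) taken from those comments. Comparing device operator time against host-side time in the generated reports should show whether the stall is in kernel execution, data loading, or Python dispatch.

    from mindspore import Profiler

    # Create the profiler up front, but only collect between steps 10 and 100.
    profiler = Profiler(start_profile=False, output_path="./profiler_data")

    for step, pack in enumerate(train_dataset.create_dict_iterator()):
        if step == 10:
            profiler.start()    # steps are still fast at this point
        loss, logits = train_step(
            input_ids=pack["input_ids"],
            attention_mask=pack["attention_mask"],
            labels=pack["labels"],
        )
        if step == 100:
            profiler.stop()     # the slowdown has usually set in by now
            profiler.analyse()  # write timeline/operator reports for inspection
            break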

Expected behavior / 预期结果 (Mandatory / 必填)
Training speed should remain stable and fast, and GPU utilization should stay steady rather than dropping to very low levels.

Screenshots/ 日志 / 截图 (Mandatory / 必填)
[two screenshots attached]

Additional context / 备注 (Optional / 选填)
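Per-step Python dispatch overhead is one common contributor to this kind of PyNative slowdown (not confirmed as the cause here). A minimal sketch of compiling the train step into a graph with ms.jit, assuming the same grad_fn and optimizer objects as in the reproduction code; whether the optimizer wrapper used there is graph-compatible has not been verified:

    import mindspore as ms

    # Hypothetical mitigation: jit-compile the per-step computation to cut
    # Python dispatch overhead; the body mirrors the original train_step.
    @ms.jit
    def train_step(input_ids, attention_mask, labels):
        (loss, logits), grads = grad_fn(input_ids, attention_mask, labels)
        optimizer.step(grads)
        return loss, logits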

@dayunyan dayunyan added the bug Something isn't working label Oct 30, 2024
@lvyufeng
Collaborator

Could you try a newer MindSpore version?

@lvyufeng
Collaborator

Please also attach the complete code; I'll see whether I can reproduce it.

@dayunyan
Author

Please also attach the complete code; I'll see whether I can reproduce it.

I don't have an Ascend environment for now, so I probably can't use version 2.3. The code has been uploaded, thanks for your help.
Code.zip

@lvyufeng
Collaborator

Join the QQ group 721548151, and I'll apply for some vouchers for you.

@dayunyan
Author

Join the QQ group 721548151, and I'll apply for some vouchers for you.

I joined just last night 😀

@EdwinWang37

@lvyufeng @dayunyan Did switching to a newer version solve this problem for anyone? I'm running into the same issue.

@dayunyan
Author

@EdwinWang37 No, I switched to Ascend later.

@EdwinWang37

@EdwinWang37 No, I switched to Ascend later.

@dayunyan Thanks a lot! I'll try switching to Ascend as well, but why it slows down is still a real mystery.
