Fix Gradient Accumulation issue #34191

Merged
merged 58 commits into from Oct 17, 2024

Conversation

ArthurZucker
Collaborator

@ArthurZucker ArthurZucker commented Oct 16, 2024

What does this PR do?

First draft

The end goal is to make it easy for anyone to:

  • change the loss for their model
  • contribute a new loss for a model (e.g. vision models, EnCodec, etc.)
  • allow passing arbitrary kwargs to the loss interface (see the sketch after the TODO list below)

TODO:

  • Fix deformable detr loss computation
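Below is a rough sketch of the kind of interface this is aiming at; the names are illustrative, not the final API:

```python
# Rough sketch only; function and registry names are illustrative, not the final API.
from typing import Optional
import torch
import torch.nn.functional as F

def for_causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor,
                       num_items_in_batch: Optional[int] = None,
                       ignore_index: int = -100) -> torch.Tensor:
    # Shift so that each position predicts the next token, then flatten.
    logits = logits[..., :-1, :].contiguous().view(-1, logits.size(-1))
    labels = labels[..., 1:].contiguous().view(-1)
    if num_items_in_batch is not None:
        # Grad-accumulation aware: sum over tokens, then divide by the total
        # number of target tokens across all accumulated micro-batches.
        loss = F.cross_entropy(logits, labels, ignore_index=ignore_index, reduction="sum")
        return loss / num_items_in_batch
    return F.cross_entropy(logits, labels, ignore_index=ignore_index)

# A registry would let models (vision, EnCodec, ...) contribute their own loss,
# while forward(**kwargs) forwards extra loss kwargs such as num_items_in_batch.
LOSS_FUNCTIONS = {"ForCausalLM": for_causal_lm_loss}
```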

Member

@BenjaminBossan BenjaminBossan left a comment

Thanks for coming forward with this fix so quickly. There is probably not much I can help with, but I took a look and added some comments.

@paulcx

paulcx commented Oct 22, 2024

As hiyouga/LLaMA-Factory#5747 (comment) suggested, maybe a few changes will make this fix work again?

@iridescentee

After using this fix, the following problem occurred: the grad norm and the loss (a custom loss from the trl CPO trainer) increased dramatically.
[image] green: llama34b-deepspeed-zero3-8gpus-fullparams-bs4-grad8-wo_fix, orange: llama34b-deepspeed-zero3-8gpus-fullparams-bs4-grad8-wo_fix

Can confirm the same issue.
Same here, encountered when fine-tuning Qwen2.5.

@muellerzr
Contributor

Fix is being worked on here: #34283

@thusinh1969

Oh my....

@taehyunzzz

taehyunzzz commented Oct 28, 2024

Hello all, I could not follow the whole context of this PR, but I am guessing a fix is underway for the loss becoming huge due to this code. The latest trainer scales the loss by the gradient accumulation steps (the previous version did not). That snippet runs inside a single training iteration step, where I think scaling the loss by grad_acc_steps is not appropriate unless the loss is de-scaled by grad_acc_steps inside the accelerator.backward call on the following line.

Does the latest accelerate backward() de-scale by grad_acc_steps (or is it being fixed to do so)? If so, maybe the output loss of training_step() should be de-scaled one more time by grad_acc_steps?
For now, is it correct to scale down the learning rate by grad_acc_steps to recover the original behavior?

@muellerzr
Contributor

Yes, accelerate de-scales the loss when calling backward(), which is why we do so. (Accelerate has always done this.)

@taehyunzzz

@muellerzr Thank you for the reply. I have just one more question to clear things up.

You've mentioned that accelerate has always internally de-scaled the loss by the gradient accumulation steps.
However, the previous trainer.py implementations did not scale (scale up) the loss by grad_acc_steps before accelerator.backward.
If accelerator.backward has been internally de-scaling the loss by grad_acc_steps, does that mean the previous trainer had been training a model on 1/grad_acc_steps * sum(loss) instead of sum(loss) when grad_acc_steps > 1?
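(For illustration only, a toy sketch of the two quantities being compared here; this is not actual Trainer/Accelerate code:)

```python
# Toy illustration of the question above; not actual Trainer/Accelerate code.
grad_acc_steps = 4
per_step_losses = [2.0, 3.0, 1.0, 2.0]  # mean loss of each micro-batch

# Previous behaviour: the trainer passes the per-step loss through unchanged and
# accelerator.backward() divides it by grad_acc_steps, so the accumulated
# gradient corresponds to the mean of the per-step losses.
effective_old = sum(loss / grad_acc_steps for loss in per_step_losses)  # 2.0

# Plain sum of the per-step losses, for comparison.
effective_sum = sum(per_step_losses)  # 8.0

print(effective_old, effective_sum)  # differ by exactly grad_acc_steps
```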

ByronHsu added a commit to linkedin/Liger-Kernel that referenced this pull request Oct 31, 2024
## Summary

To fix #322

This PR introduces a new `lce_forward` compatible with
`transformers>=4.46.0` (after the grad acc fix) while ensuring backward
compatibility.

To be specific, I keep the original flce untouched and write a new one
for `>=4.46.0`. If the HF version is `<4.46.0`, it shows a deprecation
warning and falls back to the old flce.


```python
        if transformer_version >= version.parse("4.46.0"):
            modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
        else:  # if version < 4.46.0
            logger.warning(
                "Support for transformers versions < 4.46.0 will soon be discontinued due to issues with incorrect gradient accumulation. "
                "Please consider upgrading to avoid potential issues. See details: huggingface/transformers#34191"
            )
            modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
```


For more context of grad acc fix, please see
huggingface/transformers#34191

## TODO

- [ ] broadcast the changes to all models once the effect is verified.


## Testing Done

- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
gheinrich added a commit to gheinrich/transformers that referenced this pull request Nov 2, 2024
This PR enables handling loss keyword arguments in the Mistral
forward() method. Specifically, if `num_items_in_batch` is passed,
the value is used to properly normalize the loss value.

This relates to the Gradient Accumulation fix (huggingface#34191)

Fixes huggingface#34575
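
A simplified sketch of the idea (illustrative only; see the commit for the actual Mistral changes):

```python
# Simplified sketch; the real forward() has many more arguments.
def forward(self, input_ids=None, labels=None, **loss_kwargs):
    hidden_states = self.model(input_ids)      # decoder backbone
    logits = self.lm_head(hidden_states)

    loss = None
    if labels is not None:
        # loss_kwargs may carry num_items_in_batch; when present, the loss is
        # normalized by the total number of target tokens in the accumulated
        # batch instead of a per-micro-batch mean.
        loss = self.loss_function(logits=logits, labels=labels,
                                  vocab_size=self.config.vocab_size, **loss_kwargs)
    return loss, logits
```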
wizyoung added a commit to wizyoung/Liger-Kernel that referenced this pull request Nov 7, 2024
commit ae7e13ba1eaf58e5066b5cd60dfddf4f66f3cfed
Merge: ede50df 280cb81
Author: Wizyoung <happyyanghehe@gmail.com>
Date:   Thu Nov 7 15:58:13 2024 +0800

    Merge branch 'linkedin:main' into main

commit 280cb8139511753ab3a16f286ebffe694ddd1970
Author: Haoyi Wu <43395692+why-in-Shanghaitech@users.noreply.github.com>
Date:   Thu Nov 7 13:45:16 2024 +0800

    Improve compatibility to access the base models (#340)

    ## Summary
    This PR resolves #337, which improves the compatibility to access the
    base models through the `base_model_prefix` attribute.

    ## Details
    One thing to mention: `mllama` seems to be an outlier. It has both a text
    model and a vision model, so it is impossible to access the backbone through
    a single attribute. Meanwhile, `base_model_prefix` seems to have different
    semantics for the `mllama` model classes. I left the code for `mllama`
    unchanged.

    For other models, I looked into the `transformers` library and manually
    checked correctness.
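
    A small sketch of the access pattern this enables (hypothetical helper, not code from this PR):

    ```python
    # Hypothetical helper illustrating access through base_model_prefix.
    def get_base_model(model):
        # base_model_prefix names the attribute holding the backbone,
        # e.g. "model" for a causal-LM wrapper -> model.model.
        prefix = getattr(model, "base_model_prefix", "")
        if prefix and hasattr(model, prefix):
            return getattr(model, prefix)
        return model  # e.g. mllama is handled separately, see above
    ```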

    ## Testing Done
    The changes passed `test/transformers/test_monkey_patch.py` by running
    `pytest`.


    - Hardware Type: RTX 3090
    - [ ] run `make test` to ensure correctness
    - [x] run `make checkstyle` to ensure code style
    - [ ] run `make test-convergence` to ensure convergence

    Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>

commit ab5e88be1950aba248555e5e01907de04329e4dc
Author: Tcc0403 <76503978+Tcc0403@users.noreply.github.com>
Date:   Thu Nov 7 13:29:08 2024 +0800

    Support Z Loss in CE (#239)

    ## Summary
    This PR aims to resolve #197

    Implemented z loss in LigerCrossEntropy.

    Note: `lse_square_scale` is not exposed at flce yet; there are issues passing
    the tests.
    ## Details
    ### For loss:
    ```math
    \begin{align}
    L_{total} &= L_{ce} + z\_loss \\
    z\_loss &= lse\_square\_scale \cdot lse^2 \\
    lse &= \log \sum e^{X_i}
    \end{align}
    ```
    We can use $m = \max(X_i)$ and $d = \sum e^{X_i - m}$, obtained from the
    online softmax algorithm, to calculate $lse$ directly.
    ```math
    \begin{align}
    lse &= \log \sum e^{X_i} \\
        &= \log \sum e^{X_i - m + m} = \log \sum e^{X_i - m} \cdot e^m \\
        &= \log e^m \sum e^{X_i - m} = m + \log d
    \end{align}
    ```
    ### For gradients:
    First, we calculate the derivative of lse
    ```math
    \begin{align}
    \frac{\partial}{\partial x_i}(lse) &= \frac{\partial}{\partial x_i}\left(\log \sum e^{x_i}\right) \\
                                       &= \frac{1}{\sum e^{x_i}} \cdot \frac{\partial}{\partial x_i} \sum e^{x_i} \\
                                       &= \frac{e^{x_i}}{\sum e^{x_i}} = softmax(x_i).
    \end{align}
    ```
    Then we can obtain the derivative of z_loss by the chain rule:
    ```math
    \frac{\partial z\_loss}{\partial x_i} = \frac{\partial}{\partial x_i}\left( lse\_square\_scale \cdot lse^2\right) = 2\cdot lse\_square\_scale \cdot lse \cdot softmax(x_i),
    ```
    and we have the derivative of the cross entropy loss with label smoothing
    ```math
    \frac{\partial L_{ce}}{\partial x_i} = softmax(x_i) - (1 - \epsilon)\delta_{i,y} - \frac{\epsilon}{K} = \begin{cases} softmax(x_i) - \frac{\epsilon}{K}, & i \neq y \\
    softmax(x_i) - \frac{\epsilon}{K} - (1 - \epsilon), & i = y \end{cases}
    ```
    where $\epsilon$ is label_smoothing and $K$ is the total number of
    classes.
    Thus, the derivative of the total loss is
    ```math
    \begin{align}
    \frac{\partial}{\partial x_i}L_{total} &= \frac{\partial}{\partial x_i}L_{ce} + \frac{\partial}{\partial x_i}z\_loss \\
                                           &= softmax(x_i) - \frac{\epsilon}{K} - (1 - \epsilon)\delta_{i,y} + 2\cdot lse\_square\_scale \cdot lse \cdot softmax(x_i) \\
                                           &= \begin{cases} (1 + 2\cdot lse\_square\_scale \cdot lse)\, softmax(x_i) - \frac{\epsilon}{K}, & i \neq y \\
    (1 + 2\cdot lse\_square\_scale \cdot lse)\, softmax(x_i) - \frac{\epsilon}{K} - (1 - \epsilon), & i = y \end{cases}
    \end{align}
    ```
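    A small PyTorch sketch of the formulas above (illustrative only; the actual implementation is a fused Triton kernel):
    ```python
    import torch
    import torch.nn.functional as F

    def ce_with_z_loss(logits, target, lse_square_scale=1e-4):
        # lse = logsumexp over the vocabulary dimension, reused for the z-loss term.
        lse = torch.logsumexp(logits, dim=-1)
        ce = F.cross_entropy(logits, target, reduction="none")
        z_loss = lse_square_scale * lse.pow(2)
        return (ce + z_loss).mean()
    ```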
    ### Reference
    [PaLM: Scaling Language Modeling with
    Pathways](https://www.jmlr.org/papers/v24/22-1144.html)
    [Chameleon: Mixed-Modal Early-Fusion Foundation
    Models](https://arxiv.org/abs/2405.09818)
    ## Testing Done
    [benchmark
    gist](https://gist.github.com/Tcc0403/b9120282334196f66b5169d9f52bccaa)
    Negligible error in the speed benchmark.

    This benchmark was done on my machine, which is probably not accurate.
    ```
    liger ce: 66.123ms
    Peak mem:  8.66200832

    liger ce with zloss: 65.991ms
    Peak mem:  8.66200832

    liger ce with zloss with return zloss: 65.951ms
    Peak mem:  8.662073856
    ```

    - Hardware Type: <BLANK>
    - [x] run `make test` to ensure correctness
    - [x] run `make checkstyle` to ensure code style
    - [x] run `make test-convergence` to ensure convergence

    ---------

    Co-authored-by: Shao Tang <tangshao28@gmail.com>
    Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>

commit 85d34efbd423cd97d3e97525af419193fbb07354
Author: Pramodith Ballapuram <16939722+pramodith@users.noreply.github.com>
Date:   Wed Nov 6 17:44:54 2024 +0000

    BUG: Fix bug in layer norm tests. (#359)

    ## Summary
    This PR fixes a bug in a test case for layer norm, where the assert on
    the gradient of x incorrectly compared the tensor against itself, meaning
    the assertion would always succeed.

    ## Testing Done
    Tested on A100-80G-SXM4

    - Hardware Type: <BLANK>
    - [X] run `make test` to ensure correctness
    - [X] run `make checkstyle` to ensure code style
    - [X] run `make test-convergence` to ensure convergence

commit c131f0423ccef96e71a13d58bda168f5904bfa89
Author: Byron Hsu <byronhsu1230@gmail.com>
Date:   Tue Nov 5 16:50:38 2024 -0800

    Update ci.yml

commit 985e6c74b61656061f28be74434a6de2de3aabfd
Author: Byron Hsu <byronhsu1230@gmail.com>
Date:   Tue Nov 5 16:13:49 2024 -0800

    Update ci.yml

commit a8c085488f3c47b86b2d560a1225bc27ec59c68d
Author: Byron Hsu <byronhsu1230@gmail.com>
Date:   Tue Nov 5 15:58:11 2024 -0800

    fixing ci

commit e985195bec82ea9d89b9d20a758356eee1650dc1
Author: Byron Hsu <byronhsu1230@gmail.com>
Date:   Tue Nov 5 14:10:52 2024 -0800

    Update pyproject.toml

commit 98d77e077d7bf8335a4a7748067ea8fc3633e3ef
Author: Byron Hsu <byronhsu1230@gmail.com>
Date:   Tue Nov 5 14:05:27 2024 -0800

    broadcast grad acc fix to all models (#354)

    ## Summary

    Follow-up for https://github.com/linkedin/Liger-Kernel/pull/339

    However, a few issues were identified:
    1. Reverting the patching causes flce to not take effect (revert patching is
    commented out for now, and only float32 is tested).
    2. The qwen2 vl flce is broken; we should fix it later.
    3. We should provide a real "on-instance" patch that does not use any
    monkey patching. Currently the on-instance patch still relies on monkey patching.

    ## Testing Done

    - Hardware Type: <BLANK>
    - [ ] run `make test` to ensure correctness
    - [ ] run `make checkstyle` to ensure code style
    - [ ] run `make test-convergence` to ensure convergence

commit ef3f55dcd06b4fca95a5b75c9fe51ef1b7b7bfef
Author: Byron Hsu <byronhsu1230@gmail.com>
Date:   Mon Nov 4 17:04:47 2024 -0800

    merge two tests into one (#349)

    ## Summary

    Remove the launching overhead of the 2nd container.

    ## Testing Done

    - Hardware Type: <BLANK>
    - [ ] run `make test` to ensure correctness
    - [ ] run `make checkstyle` to ensure code style
    - [ ] run `make test-convergence` to ensure convergence

commit b09fb65a37a045aa64e92b4d493897ba1c462ce8
Author: Byron Hsu <byronhsu1230@gmail.com>
Date:   Mon Nov 4 16:40:52 2024 -0800

    Trim conv test (#348)

    ## Summary

    Remove non flce convergence test since most users are using flce

    ## Testing Done

    - Hardware Type: <BLANK>
    - [ ] run `make test` to ensure correctness
    - [ ] run `make checkstyle` to ensure code style
    - [ ] run `make test-convergence` to ensure convergence

commit fbcb52d615f46f54ce865cec028ce5c64a205a2a
Author: ByronHsu <byronhsu1230@gmail.com>
Date:   Mon Nov 4 22:54:09 2024 +0000

    Move dependent license to a folder

commit a2dfa3cb2f7b6f0e23a65ad76b38a6b567404a2c
Author: Byron Hsu <byronhsu1230@gmail.com>
Date:   Mon Nov 4 14:04:40 2024 -0800

    Aggressively trim test bloat (#346)

    ## Summary

    1. Disable the tests for experimental kernels
    2. Reduce the tensor sizes if the tests take too long
    3. Remove redundant tests that are testing the same thing

    Make sure unit test time stays under 5 minutes.

    ## Testing Done

    - Hardware Type: <BLANK>
    - [ ] run `make test` to ensure correctness
    - [ ] run `make checkstyle` to ensure code style
    - [ ] run `make test-convergence` to ensure convergence

commit e68b291f11d2f1ab22c5db9b1038021ee1821a0e
Author: Byron Hsu <byronhsu1230@gmail.com>
Date:   Mon Nov 4 13:14:38 2024 -0800

    avoid duplicate ci (#345)

    ## Summary

    ## Testing Done

    - Hardware Type: <BLANK>
    - [ ] run `make test` to ensure correctness
    - [ ] run `make checkstyle` to ensure code style
    - [ ] run `make test-convergence` to ensure convergence

commit c34843c45eb8c3501d54f506fa359401e06d0166
Author: Byron Hsu <byronhsu1230@gmail.com>
Date:   Mon Nov 4 13:08:19 2024 -0800

    set up modal ci (#344)

    ## Summary

    Follows https://github.com/modal-labs/ci-on-modal

    ## Testing Done

    - Hardware Type: <BLANK>
    - [ ] run `make test` to ensure correctness
    - [ ] run `make checkstyle` to ensure code style
    - [ ] run `make test-convergence` to ensure convergence

commit ac7b38a2fdd3368b648d5ee02f6c0fb8661d8005
Author: TJian <tunjian1996@gmail.com>
Date:   Sun Nov 3 01:07:39 2024 +0800

    [AMD] [ROCm] Pick `num_warps` based on platform (#326)

    ## Summary

    This is a PR to enable the kernel to run on AMD GPUs through initial
    changes to `num_warps`.
    This change was proposed by @Edenzzzz and @DocShotgun in this issue:
    https://github.com/linkedin/Liger-Kernel/issues/266

    ## Details
    I have updated the `transformers` version requirement from `4.44.0` to
    `4.46.0`, and all unit tests pass on A100 and MI300X.
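
    A rough sketch of the platform-based selection (assuming a `torch.version.hip` check; the actual Liger-Kernel helper and values may differ):

    ```python
    import torch

    def is_hip() -> bool:
        # True on ROCm/HIP (AMD) builds of PyTorch.
        return torch.version.hip is not None

    def pick_num_warps(block_size: int) -> int:
        # AMD wavefronts are 64 lanes vs. 32-lane NVIDIA warps, so fewer
        # warps are typically requested for the same block size on ROCm.
        # The exact values here are illustrative.
        if block_size >= 32768:
            return 16 if is_hip() else 32
        return 4
    ```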

    ## Testing Done

    - Hardware Type: AMD Instinct MI300X
    - [x] run `make test` to ensure correctness
    - Some tests failed due to numerical precision issues. They pass after
    relaxing the tolerance by one order of magnitude (following the advice in
    the Liger-Kernel technical report,
    https://arxiv.org/pdf/2410.10989,
    **Footnote 12:** _Note that in practice, the tolerance may need further
    relaxation in some cases by one or two orders of magnitude, even for
    exact kernels. We use convergence tests to ensure exactness in cases
    where the tolerance for correctness needs to be loose._)
    - The tests whose tolerances were relaxed involve `kl_div` and `jsd` in the
    `float32` tests
        - The relax conditions are described by the following code snippet
          ```
          _DTYPE_PARAMS = (
              "dtype, atol, rtol",
              [
                  pytest.param(
                      torch.bfloat16,
                      1e-8,
                      5e-2,
                      marks=pytest.mark.skipif(
                          not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
                      ),
                  ),
                  (torch.float32, 1e-8 if not is_hip() else 1e-7, 1e-6),
                  (torch.float16, 1e-3, 1e-3),
              ],
          )

          ```
    - To pass the tests, triton must not be installed from source; it must be
    installed from PyPI via `pip install triton==3.0.0`. This issue will be
    tracked in a triton issue:
    https://github.com/triton-lang/triton/issues/5013.
    - ~~Something is weird as well, if I just run the failed test
    `test/transformers/test_cross_entropy.py::test_correctness_with_ignore_index[10.0-dtype5-1e-08-1e-06-sum-2-4096-32000--100]`,
    the test passed. By running `pytest
    test/transformers/test_cross_entropy.py::test_correctness_with_ignore_index[10.0-dtype5-1e-08-1e-06-sum-2-4096-32000--100]`.
    However it will failed if there are other tests running before this
    test.~~
    - [x] run `make checkstyle` to ensure code style
    - [x] run `make test-convergence` to ensure convergence
    <details>
    <summary> <s>Failure Test Logs (Click to expand/collapse) </s>
    </summary>
    ```bash
            ============================================================= FAILURES =============================================================
        ________________________ test_correctness_with_ignore_index[10.0-dtype5-1e-08-1e-06-sum-2-4096-32000--100] _________________________

        B = 2, T = 4096, V = 32000, ignore_index = -100, reduction = 'sum', scalar = 10.0, dtype = torch.float32, atol = 1e-08, rtol = 1e-06

            @pytest.mark.parametrize(
                "B, T, V, ignore_index",
                [
                    (2, 4096, 32000, -100),  # llama2, mistral
                    (2, 4096, 32000, 2),  # llama2, mistral
                    (1, 4096, 128256, -300),  # llama3
                    # weird shapes
                    (3, 423, 32000, -123),
                ],
            )
            @pytest.mark.parametrize("reduction", ["sum", "mean"])
            @pytest.mark.parametrize(
                "scalar, dtype, atol, rtol",
                [
                    pytest.param(
                        0.1,
                        torch.bfloat16,
                        1e-8,
                        5e-2,
                        marks=pytest.mark.skipif(
                            not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
                        ),
                    ),
                    pytest.param(
                        1.0,
                        torch.bfloat16,
                        1e-8,
                        5e-2,
                        marks=pytest.mark.skipif(
                            not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
                        ),
                    ),
                    pytest.param(
                        10.0,
                        torch.bfloat16,
                        1e-8,
                        5e-2,
                        marks=pytest.mark.skipif(
                            not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
                        ),
                    ),
                    (0.1, torch.float32, 1e-8, 1e-6),
                    (1.0, torch.float32, 1e-8, 1e-6),
                    (10.0, torch.float32, 1e-8, 1e-6),
                ],
            )
            @pytest.mark.skipif(
                torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000,
                reason="Needs 16GB+ GPU memory.",
            )
            def test_correctness_with_ignore_index(
                B, T, V, ignore_index, reduction, scalar, dtype, atol, rtol
            ):
                liger_ce = LigerCrossEntropyLoss(ignore_index=ignore_index, reduction=reduction)
        >       _test_correctness_with_ignore_index_once(
                    liger_ce, B, T, V, ignore_index, reduction, scalar, dtype, atol, rtol
                )

        test/transformers/test_cross_entropy.py:302:
        _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

        target_ce = LigerCrossEntropyLoss(), B = 2, T = 4096, V = 32000, ignore_index = -100, reduction = 'sum', scalar = 10.0
        dtype = torch.float32, atol = 1e-08, rtol = 1e-06

            def _test_correctness_with_ignore_index_once(
                target_ce, B, T, V, ignore_index, reduction, scalar, dtype, atol, rtol
            ):

                torch_ce = CrossEntropyLoss(ignore_index=ignore_index, reduction=reduction)

                _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar
                _input = _tensor.detach().clone().requires_grad_(True)
                _input2 = _tensor.detach().clone().requires_grad_(True)

                target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long)

                # Assign some random number of elements as ignore_index
                num_elements_to_assign = torch.randint(
                    1, B * T // 2, (1,)
                ).item()  # Random number of elements to set to ignore_index
                indices_to_assign = torch.randperm(B * T)[
                    :num_elements_to_assign
                ]  # Randomly select indices
                target[indices_to_assign] = ignore_index

                output = torch_ce(_input, target)
                output2 = target_ce(_input2, target)

                assert torch.allclose(output, output2, atol=atol, rtol=rtol)

                output.backward()
                output2.backward()
        >       assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)
        E       AssertionError: assert False
        E        +  where False = <built-in method allclose of type object at 0x7035c99e82c0>(tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19,  ..., 1.3759e-13, 7.6381e-10,\n         4.4185e-23],\n        [2.9569e-12, 3.8580e-19, 5.3756e-16,  ..., 6.0166e-23, 1.4681e-17,\n         5.1994e-20],\n        [4.7900e-26, 1.0599e-04, 7.0237e-19,  ..., 1.1461e-20, 1.0415e-10,\n         1.0237e-19],\n        ...,\n        [6.9540e-17, 3.4471e-22, 2.7309e-14,  ..., 2.5999e-26, 2.5635e-19,\n         7.0793e-16],\n        [6.3721e-23, 1.2054e-13, 1.8638e-20,  ..., 1.2807e-23, 5.5705e-16,\n         2.3085e-13],\n        [1.9623e-20, 2.4720e-11, 1.8808e-15,  ..., 3.5100e-20, 3.6195e-15,\n         1.5356e-23]], device='cuda:0'), tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19,  ..., 1.3759e-13, 7.6381e-10,\n         4.4185e-23],\n        [2.9569e-12, 3.8580e-19, 5.3756e-16,  ..., 6.0166e-23, 1.4681e-17,\n         5.1994e-20],\n        [4.7900e-26, 1.0599e-04, 7.0237e-19,  ..., 1.1461e-20, 1.0415e-10,\n         1.0237e-19],\n        ...,\n        [6.9540e-17, 3.4471e-22, 2.7309e-14,  ..., 2.5999e-26, 2.5635e-19,\n         7.0793e-16],\n        [6.3722e-23, 1.2054e-13, 1.8638e-20,  ..., 1.2807e-23, 5.5705e-16,\n         2.3085e-13],\n        [1.9623e-20, 2.4720e-11, 1.8808e-15,  ..., 3.5100e-20, 3.6195e-15,\n         1.5356e-23]], device='cuda:0'), atol=1e-08, rtol=1e-06)
        E        +    where <built-in method allclose of type object at 0x7035c99e82c0> = torch.allclose
        E        +    and   tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19,  ..., 1.3759e-13, 7.6381e-10,\n         4.4185e-23],\n        [2.9569e-12, 3.8580e-19, 5.3756e-16,  ..., 6.0166e-23, 1.4681e-17,\n         5.1994e-20],\n        [4.7900e-26, 1.0599e-04, 7.0237e-19,  ..., 1.1461e-20, 1.0415e-10,\n         1.0237e-19],\n        ...,\n        [6.9540e-17, 3.4471e-22, 2.7309e-14,  ..., 2.5999e-26, 2.5635e-19,\n         7.0793e-16],\n        [6.3721e-23, 1.2054e-13, 1.8638e-20,  ..., 1.2807e-23, 5.5705e-16,\n         2.3085e-13],\n        [1.9623e-20, 2.4720e-11, 1.8808e-15,  ..., 3.5100e-20, 3.6195e-15,\n         1.5356e-23]], device='cuda:0') = tensor([[  6.0503,   3.7258,  -0.3530,  ...,  11.8853,  20.5071,  -9.9739],\n        [ 15.2597,  -0.5924,   6.6471,  ...,  -9.3584,   3.0466,  -2.5966],\n        [-17.9122,  31.2363,  -1.4114,  ...,  -5.5268,  17.4033,  -3.3372],\n        ...,\n        [  4.3242,  -7.8904,  10.2973,  ..., -17.3829,  -1.2789,   6.6447],\n        [-10.9055,  10.4553,  -5.2270,  ..., -12.5100,   5.0782,  11.1050],\n        [ -5.8922,  15.0620,   5.5783,  ...,  -5.3107,   6.2329, -13.0452]],\n       device='cuda:0', requires_grad=True).grad
        E        +    and   tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19,  ..., 1.3759e-13, 7.6381e-10,\n         4.4185e-23],\n        [2.9569e-12, 3.8580e-19, 5.3756e-16,  ..., 6.0166e-23, 1.4681e-17,\n         5.1994e-20],\n        [4.7900e-26, 1.0599e-04, 7.0237e-19,  ..., 1.1461e-20, 1.0415e-10,\n         1.0237e-19],\n        ...,\n        [6.9540e-17, 3.4471e-22, 2.7309e-14,  ..., 2.5999e-26, 2.5635e-19,\n         7.0793e-16],\n        [6.3722e-23, 1.2054e-13, 1.8638e-20,  ..., 1.2807e-23, 5.5705e-16,\n         2.3085e-13],\n        [1.9623e-20, 2.4720e-11, 1.8808e-15,  ..., 3.5100e-20, 3.6195e-15,\n         1.5356e-23]], device='cuda:0') = tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19,  ..., 1.3759e-13, 7.6381e-10,\n         4.4185e-23],\n        [2.9569e-12, 3.8580e-19, 5.3756e-16,  ..., 6.0166e-23, 1.4681e-17,\n         5.1994e-20],\n        [4.7900e-26, 1.0599e-04, 7.0237e-19,  ..., 1.1461e-20, 1.0415e-10,\n         1.0237e-19],\n        ...,\n        [6.9540e-17, 3.4471e-22, 2.7309e-14,  ..., 2.5999e-26, 2.5635e-19,\n         7.0793e-16],\n        [6.3722e-23, 1.2054e-13, 1.8638e-20,  ..., 1.2807e-23, 5.5705e-16,\n         2.3085e-13],\n        [1.9623e-20, 2.4720e-11, 1.8808e-15,  ..., 3.5100e-20, 3.6195e-15,\n         1.5356e-23]], device='cuda:0', requires_grad=True).grad

        test/transformers/test_cross_entropy.py:61: AssertionError
        _________________________________ test_correctness_with_beta[0.1-dtype1-1e-08-1e-06-1-4096-128256] _________________________________

        B = 1, T = 4096, V = 128256, beta = 0.1, dtype = torch.float32, atol = 1e-08, rtol = 1e-06

            @pytest.mark.parametrize(*_SHAPE_PARAMS)
            @pytest.mark.parametrize(*_DTYPE_PARAMS)
            @pytest.mark.parametrize("beta", [0.1, 0.5, 0.9])
            def test_correctness_with_beta(B, T, V, beta, dtype, atol, rtol):
                liger_jsd = LigerJSD(beta=beta)
        >       _test_correctness_with_beta_once(liger_jsd, beta, B, T, V, dtype, atol, rtol)

        test/transformers/test_jsd.py:269:
        _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
        test/transformers/test_jsd.py:157: in _test_correctness_with_beta_once
            assert_verbose_allclose(output, output2, atol=atol, rtol=rtol)
        _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

        tensor1 = tensor(0.0805, device='cuda:0', grad_fn=<SumBackward0>)
        tensor2 = tensor(0.0805, device='cuda:0', grad_fn=<LigerJSDFunctionBackward>), rtol = 1e-06, atol = 1e-08, max_print = 5

            def assert_verbose_allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08, max_print=5):
                """
                Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches.

                Parameters:
                tensor1 (torch.Tensor): First tensor to compare.
                tensor2 (torch.Tensor): Second tensor to compare.
                rtol (float): Relative tolerance.
                atol (float): Absolute tolerance.
                max_print (int): Maximum number of mismatched elements to print.

                Raises:
                AssertionError: If the tensors are not all close within the given tolerance.
                """
                # Check if the shapes of the tensors match
                if tensor1.shape != tensor2.shape:
                    raise AssertionError("Input tensors must have the same shape.")

                # Calculate the difference between the tensors
                diff = torch.abs(tensor1 - tensor2)

                # Determine the tolerance
                tolerance = atol + rtol * torch.abs(tensor2)

                # Find tolerance mismatched elements
                tol_mismatched = diff > tolerance

                # Find nan mismatched elements
                nan_mismatched = torch.logical_xor(torch.isnan(tensor1), torch.isnan(tensor2))

                # Find +inf mismatched elements
                posinf_mismatched = torch.logical_xor(
                    torch.isposinf(tensor1), torch.isposinf(tensor2)
                )
                # Find -inf mismatched elements
                neginf_mismatched = torch.logical_xor(
                    torch.isneginf(tensor1), torch.isneginf(tensor2)
                )

                # Find all mismatched elements
                mismatched = torch.logical_or(
                    torch.logical_or(tol_mismatched, nan_mismatched),
                    torch.logical_or(posinf_mismatched, neginf_mismatched),
                )

                mismatched_indices = torch.nonzero(mismatched)

                # Count the number of mismatched elements
                num_mismatched = mismatched.sum().item()

                # Check if all elements are close
                all_close = num_mismatched == 0

                # Raise AssertionError with detailed information if there are mismatches
                if not all_close and num_mismatched >= 1:
                    mismatch_details = [f"Number of mismatched elements: {num_mismatched}"]
                    print_count = min(max_print, num_mismatched)
                    for index in mismatched_indices[:print_count]:
                        i = tuple(index.tolist())
                        mismatch_details.append(
                            f"Mismatch at index {i}: tensor1[{i}] = {tensor1[i]}, tensor2[{i}] = {tensor2[i]}"
                        )
                    if num_mismatched > max_print:
                        mismatch_details.append(
                            f"... and {num_mismatched - max_print} more mismatched elements."
                        )

        >           raise AssertionError("\n".join(mismatch_details))
        E           AssertionError: Number of mismatched elements: 1
        E           Mismatch at index (): tensor1[()] = 0.08054989576339722, tensor2[()] = 0.08054977655410767

        test/utils.py:106: AssertionError
        _________________________________ test_correctness_with_beta[0.9-dtype1-1e-08-1e-06-1-4096-128256] _________________________________

        B = 1, T = 4096, V = 128256, beta = 0.9, dtype = torch.float32, atol = 1e-08, rtol = 1e-06

            @pytest.mark.parametrize(*_SHAPE_PARAMS)
            @pytest.mark.parametrize(*_DTYPE_PARAMS)
            @pytest.mark.parametrize("beta", [0.1, 0.5, 0.9])
            def test_correctness_with_beta(B, T, V, beta, dtype, atol, rtol):
                liger_jsd = LigerJSD(beta=beta)
        >       _test_correctness_with_beta_once(liger_jsd, beta, B, T, V, dtype, atol, rtol)

        test/transformers/test_jsd.py:269:
        _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
        test/transformers/test_jsd.py:157: in _test_correctness_with_beta_once
            assert_verbose_allclose(output, output2, atol=atol, rtol=rtol)
        _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

        tensor1 = tensor(0.0805, device='cuda:0', grad_fn=<SumBackward0>)
        tensor2 = tensor(0.0805, device='cuda:0', grad_fn=<LigerJSDFunctionBackward>), rtol = 1e-06, atol = 1e-08, max_print = 5

            def assert_verbose_allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08, max_print=5):
                """
                Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches.

                Parameters:
                tensor1 (torch.Tensor): First tensor to compare.
                tensor2 (torch.Tensor): Second tensor to compare.
                rtol (float): Relative tolerance.
                atol (float): Absolute tolerance.
                max_print (int): Maximum number of mismatched elements to print.

                Raises:
                AssertionError: If the tensors are not all close within the given tolerance.
                """
                # Check if the shapes of the tensors match
                if tensor1.shape != tensor2.shape:
                    raise AssertionError("Input tensors must have the same shape.")

                # Calculate the difference between the tensors
                diff = torch.abs(tensor1 - tensor2)

                # Determine the tolerance
                tolerance = atol + rtol * torch.abs(tensor2)

                # Find tolerance mismatched elements
                tol_mismatched = diff > tolerance

                # Find nan mismatched elements
                nan_mismatched = torch.logical_xor(torch.isnan(tensor1), torch.isnan(tensor2))

                # Find +inf mismatched elements
                posinf_mismatched = torch.logical_xor(
                    torch.isposinf(tensor1), torch.isposinf(tensor2)
                )
                # Find -inf mismatched elements
                neginf_mismatched = torch.logical_xor(
                    torch.isneginf(tensor1), torch.isneginf(tensor2)
                )

                # Find all mismatched elements
                mismatched = torch.logical_or(
                    torch.logical_or(tol_mismatched, nan_mismatched),
                    torch.logical_or(posinf_mismatched, neginf_mismatched),
                )

                mismatched_indices = torch.nonzero(mismatched)

                # Count the number of mismatched elements
                num_mismatched = mismatched.sum().item()

                # Check if all elements are close
                all_close = num_mismatched == 0

                # Raise AssertionError with detailed information if there are mismatches
                if not all_close and num_mismatched >= 1:
                    mismatch_details = [f"Number of mismatched elements: {num_mismatched}"]
                    print_count = min(max_print, num_mismatched)
                    for index in mismatched_indices[:print_count]:
                        i = tuple(index.tolist())
                        mismatch_details.append(
                            f"Mismatch at index {i}: tensor1[{i}] = {tensor1[i]}, tensor2[{i}] = {tensor2[i]}"
                        )
                    if num_mismatched > max_print:
                        mismatch_details.append(
                            f"... and {num_mismatched - max_print} more mismatched elements."
                        )

        >           raise AssertionError("\n".join(mismatch_details))
        E           AssertionError: Number of mismatched elements: 1
        E           Mismatch at index (): tensor1[()] = 0.08054172992706299, tensor2[()] = 0.08054161071777344

        test/utils.py:106: AssertionError
        ___________________________________ test_correctness[dtype1-1e-08-1e-06-none-False-32-4096-1024] ___________________________________

        B = 32, T = 4096, V = 1024, log_target = False, reduction = 'none', dtype = torch.float32, atol = 1e-08, rtol = 1e-06

            @pytest.mark.parametrize(*_SHAPE_PARAMS)
            @pytest.mark.parametrize("log_target", [True, False])
            @pytest.mark.parametrize("reduction", ["batchmean", "sum", "mean", "none"])
            @pytest.mark.parametrize(*_DTYPE_PARAMS)
            def test_correctness(B, T, V, log_target, reduction, dtype, atol, rtol):
                liger_kldiv = LigerKLDIVLoss(reduction=reduction, log_target=log_target)
        >       _test_correctness_once(
                    liger_kldiv, B, T, V, dtype, atol, rtol, reduction, log_target
                )

        test/transformers/test_kl_div.py:97:
        _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

        target_kldiv = LigerKLDIVLoss(), B = 32, T = 4096, V = 1024, dtype = torch.float32, atol = 1e-08, rtol = 1e-06, reduction = 'none'
        log_target = False, is_last_layer = True, device = 'cuda'

            def _test_correctness_once(
                target_kldiv,
                B,
                T,
                V,
                dtype,
                atol,
                rtol,
                reduction,
                log_target,
                is_last_layer=True,
                device="cuda",
            ):
                torch.manual_seed(0)
                torch_kldiv = KLDivLoss(reduction=reduction, log_target=log_target)

                input = torch.randn(
                    B * T, V, device=device, dtype=dtype, requires_grad=True
                ).log_softmax(dim=-1)

                x1 = input.detach().clone().requires_grad_(True)
                x2 = input.detach().clone().requires_grad_(True)

                with torch.no_grad():
                    target = torch.randn(B * T, V, device=device).softmax(dim=-1)

                output = torch_kldiv(x1, target)
                output2 = target_kldiv(x2, target)
        >       assert torch.allclose(output, output2, atol=atol, rtol=rtol)
        E       AssertionError: assert False
        E        +  where False = <built-in method allclose of type object at 0x7035c99e82c0>(tensor([[ 3.8871e-04,  1.5342e-03,  9.7731e-04,  ...,  1.5857e-04,\n          2.0651e-05, -2.0225e-04],\n        [ 3.0436e-04,  1.4040e-03, -1.4338e-04,  ..., -9.6487e-04,\n          3.6957e-04, -1.7970e-04],\n        [ 1.3870e-02,  1.8989e-03, -2.3409e-04,  ..., -9.2741e-05,\n         -2.1325e-03, -3.6861e-04],\n        ...,\n        [ 1.6965e-04,  7.5081e-04,  1.7243e-03,  ..., -3.3345e-04,\n          2.9291e-04,  4.6570e-03],\n        [-8.5313e-04,  5.1247e-04,  2.9434e-03,  ..., -1.6669e-04,\n          6.3304e-04,  8.2082e-04],\n        [-1.0297e-03, -5.9040e-05, -4.5201e-04,  ...,  1.1601e-03,\n          1.0437e-03,  2.4179e-04]], device='cuda:0', grad_fn=<SubBackward0>), tensor([[ 3.8871e-04,  1.5342e-03,  9.7731e-04,  ...,  1.5857e-04,\n          2.0651e-05, -2.0225e-04],\n        [ 3.0436e-04,  1.4040e-03, -1.4338e-04,  ..., -9.6487e-04,\n          3.6957e-04, -1.7970e-04],\n        [ 1.3870e-02,  1.8989e-03, -2.3409e-04,  ..., -9.2741e-05,\n         -2.1325e-03, -3.6861e-04],\n        ...,\n        [ 1.6965e-04,  7.5081e-04,  1.7243e-03,  ..., -3.3345e-04,\n          2.9291e-04,  4.6570e-03],\n        [-8.5313e-04,  5.1247e-04,  2.9434e-03,  ..., -1.6669e-04,\n          6.3304e-04,  8.2082e-04],\n        [-1.0297e-03, -5.9040e-05, -4.5201e-04,  ...,  1.1601e-03,\n          1.0437e-03,  2.4179e-04]], device='cuda:0',\n       grad_fn=<LigerKLDivLossFunctionBackward>), atol=1e-08, rtol=1e-06)
        E        +    where <built-in method allclose of type object at 0x7035c99e82c0> = torch.allclose

        test/transformers/test_kl_div.py:75: AssertionError
        ______________________________ test_correctness_not_last[dtype1-1e-08-1e-06-none-False-32-4096-1024] _______________________________

        B = 32, T = 4096, V = 1024, log_target = False, reduction = 'none', dtype = torch.float32, atol = 1e-08, rtol = 1e-06

            @pytest.mark.parametrize(*_SHAPE_PARAMS)
            @pytest.mark.parametrize("log_target", [True, False])
            @pytest.mark.parametrize("reduction", ["batchmean", "sum", "mean", "none"])
            @pytest.mark.parametrize(*_DTYPE_PARAMS)
            def test_correctness_not_last(B, T, V, log_target, reduction, dtype, atol, rtol):
                liger_kldiv = LigerKLDIVLoss(reduction=reduction, log_target=log_target)
        >       _test_correctness_once(
                    liger_kldiv,
                    B,
                    T,
                    V,
                    dtype,
                    atol,
                    rtol,
                    reduction,
                    log_target,
                    is_last_layer=False,
                )

        test/transformers/test_kl_div.py:108:
        _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

        target_kldiv = LigerKLDIVLoss(), B = 32, T = 4096, V = 1024, dtype = torch.float32, atol = 1e-08, rtol = 1e-06, reduction = 'none'
        log_target = False, is_last_layer = False, device = 'cuda'

            def _test_correctness_once(
                target_kldiv,
                B,
                T,
                V,
                dtype,
                atol,
                rtol,
                reduction,
                log_target,
                is_last_layer=True,
                device="cuda",
            ):
                torch.manual_seed(0)
                torch_kldiv = KLDivLoss(reduction=reduction, log_target=log_target)

                input = torch.randn(
                    B * T, V, device=device, dtype=dtype, requires_grad=True
                ).log_softmax(dim=-1)

                x1 = input.detach().clone().requires_grad_(True)
                x2 = input.detach().clone().requires_grad_(True)

                with torch.no_grad():
                    target = torch.randn(B * T, V, device=device).softmax(dim=-1)

                output = torch_kldiv(x1, target)
                output2 = target_kldiv(x2, target)
        >       assert torch.allclose(output, output2, atol=atol, rtol=rtol)
        E       AssertionError: assert False
        E        +  where False = <built-in method allclose of type object at 0x7035c99e82c0>(tensor([[ 3.8871e-04,  1.5342e-03,  9.7731e-04,  ...,  1.5857e-04,\n          2.0651e-05, -2.0225e-04],\n        [ 3.0436e-04,  1.4040e-03, -1.4338e-04,  ..., -9.6487e-04,\n          3.6957e-04, -1.7970e-04],\n        [ 1.3870e-02,  1.8989e-03, -2.3409e-04,  ..., -9.2741e-05,\n         -2.1325e-03, -3.6861e-04],\n        ...,\n        [ 1.6965e-04,  7.5081e-04,  1.7243e-03,  ..., -3.3345e-04,\n          2.9291e-04,  4.6570e-03],\n        [-8.5313e-04,  5.1247e-04,  2.9434e-03,  ..., -1.6669e-04,\n          6.3304e-04,  8.2082e-04],\n        [-1.0297e-03, -5.9040e-05, -4.5201e-04,  ...,  1.1601e-03,\n          1.0437e-03,  2.4179e-04]], device='cuda:0', grad_fn=<SubBackward0>), tensor([[ 3.8871e-04,  1.5342e-03,  9.7731e-04,  ...,  1.5857e-04,\n          2.0651e-05, -2.0225e-04],\n        [ 3.0436e-04,  1.4040e-03, -1.4338e-04,  ..., -9.6487e-04,\n          3.6957e-04, -1.7970e-04],\n        [ 1.3870e-02,  1.8989e-03, -2.3409e-04,  ..., -9.2741e-05,\n         -2.1325e-03, -3.6861e-04],\n        ...,\n        [ 1.6965e-04,  7.5081e-04,  1.7243e-03,  ..., -3.3345e-04,\n          2.9291e-04,  4.6570e-03],\n        [-8.5313e-04,  5.1247e-04,  2.9434e-03,  ..., -1.6669e-04,\n          6.3304e-04,  8.2082e-04],\n        [-1.0297e-03, -5.9040e-05, -4.5201e-04,  ...,  1.1601e-03,\n          1.0437e-03,  2.4179e-04]], device='cuda:0',\n       grad_fn=<LigerKLDivLossFunctionBackward>), atol=1e-08, rtol=1e-06)
        E        +    where <built-in method allclose of type object at 0x7035c99e82c0> = torch.allclose

        test/transformers/test_kl_div.py:75: AssertionError
        _________________________________________________ test_import_custom_cache_manager _________________________________________________

            def test_import_custom_cache_manager():
                from triton.runtime.cache import get_cache_manager

                from liger_kernel.triton import apply_liger_triton_cache_manager

                apply_liger_triton_cache_manager()
        >       cache_manager = get_cache_manager(key="test_hash")

        test/triton/test_triton_monkey_patch.py:17:
        _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
        /opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/cache.py:277: in get_cache_manager
            return __cache_cls(_base64(key))
        _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

        key = 'test_hash'

            def _base64(key):
                # Assume key is a hex string.
        >       return base64.urlsafe_b64encode(bytes.fromhex(key)).decode("utf-8").rstrip("=")
        E       ValueError: non-hexadecimal number found in fromhex() arg at position 0

        /opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/cache.py:261: ValueError
        ===================================================== short test summary info ======================================================
        FAILED test/transformers/test_cross_entropy.py::test_correctness_with_ignore_index[10.0-dtype5-1e-08-1e-06-sum-2-4096-32000--100] - AssertionError: assert False
         +  where False = <built-in method allclose of type object at 0x7035c99e82c0>(tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19,  ..., 1.3759e-13, 7.6381e-10,\n         4.4185e-23],\n        [2.9569e-12, 3.8580e-19, 5.3756e-16,  ..., 6.0166e-23, 1.4681e-17,\n         5.1994e-20],\n        [4.7900e-26, 1.0599e-04, 7.0237e-19,  ..., 1.1461e-20, 1.0415e-10,\n         1.0237e-19],\n        ...,\n        [6.9540e-17, 3.4471e-22, 2.7309e-14,  ..., 2.5999e-26, 2.5635e-19,\n         7.0793e-16],\n        [6.3721e-23, 1.2054e-13, 1.8638e-20,  ..., 1.2807e-23, 5.5705e-16,\n         2.3085e-13],\n        [1.9623e-20, 2.4720e-11, 1.8808e-15,  ..., 3.5100e-20, 3.6195e-15,\n         1.5356e-23]], device='cuda:0'), tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19,  ..., 1.3759e-13, 7.6381e-10,\n         4.4185e-23],\n        [2.9569e-12, 3.8580e-19, 5.3756e-16,  ..., 6.0166e-23, 1.4681e-17,\n         5.1994e-20],\n        [4.7900e-26, 1.0599e-04, 7.0237e-19,  ..., 1.1461e-20, 1.0415e-10,\n         1.0237e-19],\n        ...,\n        [6.9540e-17, 3.4471e-22, 2.7309e-14,  ..., 2.5999e-26, 2.5635e-19,\n         7.0793e-16],\n        [6.3722e-23, 1.2054e-13, 1.8638e-20,  ..., 1.2807e-23, 5.5705e-16,\n         2.3085e-13],\n        [1.9623e-20, 2.4720e-11, 1.8808e-15,  ..., 3.5100e-20, 3.6195e-15,\n         1.5356e-23]], device='cuda:0'), atol=1e-08, rtol=1e-06)
         +    where <built-in method allclose of type object at 0x7035c99e82c0> = torch.allclose
         +    and   tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19,  ..., 1.3759e-13, 7.6381e-10,\n         4.4185e-23],\n        [2.9569e-12, 3.8580e-19, 5.3756e-16,  ..., 6.0166e-23, 1.4681e-17,\n         5.1994e-20],\n        [4.7900e-26, 1.0599e-04, 7.0237e-19,  ..., 1.1461e-20, 1.0415e-10,\n         1.0237e-19],\n        ...,\n        [6.9540e-17, 3.4471e-22, 2.7309e-14,  ..., 2.5999e-26, 2.5635e-19,\n         7.0793e-16],\n        [6.3721e-23, 1.2054e-13, 1.8638e-20,  ..., 1.2807e-23, 5.5705e-16,\n         2.3085e-13],\n        [1.9623e-20, 2.4720e-11, 1.8808e-15,  ..., 3.5100e-20, 3.6195e-15,\n         1.5356e-23]], device='cuda:0') = tensor([[  6.0503,   3.7258,  -0.3530,  ...,  11.8853,  20.5071,  -9.9739],\n        [ 15.2597,  -0.5924,   6.6471,  ...,  -9.3584,   3.0466,  -2.5966],\n        [-17.9122,  31.2363,  -1.4114,  ...,  -5.5268,  17.4033,  -3.3372],\n        ...,\n        [  4.3242,  -7.8904,  10.2973,  ..., -17.3829,  -1.2789,   6.6447],\n        [-10.9055,  10.4553,  -5.2270,  ..., -12.5100,   5.0782,  11.1050],\n        [ -5.8922,  15.0620,   5.5783,  ...,  -5.3107,   6.2329, -13.0452]],\n       device='cuda:0', requires_grad=True).grad
         +    and   tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19,  ..., 1.3759e-13, 7.6381e-10,\n         4.4185e-23],\n        [2.9569e-12, 3.8580e-19, 5.3756e-16,  ..., 6.0166e-23, 1.4681e-17,\n         5.1994e-20],\n        [4.7900e-26, 1.0599e-04, 7.0237e-19,  ..., 1.1461e-20, 1.0415e-10,\n         1.0237e-19],\n        ...,\n        [6.9540e-17, 3.4471e-22, 2.7309e-14,  ..., 2.5999e-26, 2.5635e-19,\n         7.0793e-16],\n        [6.3722e-23, 1.2054e-13, 1.8638e-20,  ..., 1.2807e-23, 5.5705e-16,\n         2.3085e-13],\n        [1.9623e-20, 2.4720e-11, 1.8808e-15,  ..., 3.5100e-20, 3.6195e-15,\n         1.5356e-23]], device='cuda:0') = tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19,  ..., 1.3759e-13, 7.6381e-10,\n         4.4185e-23],\n        [2.9569e-12, 3.8580e-19, 5.3756e-16,  ..., 6.0166e-23, 1.4681e-17,\n         5.1994e-20],\n        [4.7900e-26, 1.0599e-04, 7.0237e-19,  ..., 1.1461e-20, 1.0415e-10,\n         1.0237e-19],\n        ...,\n        [6.9540e-17, 3.4471e-22, 2.7309e-14,  ..., 2.5999e-26, 2.5635e-19,\n         7.0793e-16],\n        [6.3722e-23, 1.2054e-13, 1.8638e-20,  ..., 1.2807e-23, 5.5705e-16,\n         2.3085e-13],\n        [1.9623e-20, 2.4720e-11, 1.8808e-15,  ..., 3.5100e-20, 3.6195e-15,\n         1.5356e-23]], device='cuda:0', requires_grad=True).grad
        FAILED test/transformers/test_jsd.py::test_correctness_with_beta[0.1-dtype1-1e-08-1e-06-1-4096-128256] - AssertionError: Number of mismatched elements: 1
        Mismatch at index (): tensor1[()] = 0.08054989576339722, tensor2[()] = 0.08054977655410767
        FAILED test/transformers/test_jsd.py::test_correctness_with_beta[0.9-dtype1-1e-08-1e-06-1-4096-128256] - AssertionError: Number of mismatched elements: 1
        Mismatch at index (): tensor1[()] = 0.08054172992706299, tensor2[()] = 0.08054161071777344
        FAILED test/transformers/test_kl_div.py::test_correctness[dtype1-1e-08-1e-06-none-False-32-4096-1024] - AssertionError: assert False
         +  where False = <built-in method allclose of type object at 0x7035c99e82c0>(tensor([[ 3.8871e-04,  1.5342e-03,  9.7731e-04,  ...,  1.5857e-04,\n          2.0651e-05, -2.0225e-04],\n        [ 3.0436e-04,  1.4040e-03, -1.4338e-04,  ..., -9.6487e-04,\n          3.6957e-04, -1.7970e-04],\n        [ 1.3870e-02,  1.8989e-03, -2.3409e-04,  ..., -9.2741e-05,\n         -2.1325e-03, -3.6861e-04],\n        ...,\n        [ 1.6965e-04,  7.5081e-04,  1.7243e-03,  ..., -3.3345e-04,\n          2.9291e-04,  4.6570e-03],\n        [-8.5313e-04,  5.1247e-04,  2.9434e-03,  ..., -1.6669e-04,\n          6.3304e-04,  8.2082e-04],\n        [-1.0297e-03, -5.9040e-05, -4.5201e-04,  ...,  1.1601e-03,\n          1.0437e-03,  2.4179e-04]], device='cuda:0', grad_fn=<SubBackward0>), tensor([[ 3.8871e-04,  1.5342e-03,  9.7731e-04,  ...,  1.5857e-04,\n          2.0651e-05, -2.0225e-04],\n        [ 3.0436e-04,  1.4040e-03, -1.4338e-04,  ..., -9.6487e-04,\n          3.6957e-04, -1.7970e-04],\n        [ 1.3870e-02,  1.8989e-03, -2.3409e-04,  ..., -9.2741e-05,\n         -2.1325e-03, -3.6861e-04],\n        ...,\n        [ 1.6965e-04,  7.5081e-04,  1.7243e-03,  ..., -3.3345e-04,\n          2.9291e-04,  4.6570e-03],\n        [-8.5313e-04,  5.1247e-04,  2.9434e-03,  ..., -1.6669e-04,\n          6.3304e-04,  8.2082e-04],\n        [-1.0297e-03, -5.9040e-05, -4.5201e-04,  ...,  1.1601e-03,\n          1.0437e-03,  2.4179e-04]], device='cuda:0',\n       grad_fn=<LigerKLDivLossFunctionBackward>), atol=1e-08, rtol=1e-06)
         +    where <built-in method allclose of type object at 0x7035c99e82c0> = torch.allclose
        FAILED test/transformers/test_kl_div.py::test_correctness_not_last[dtype1-1e-08-1e-06-none-False-32-4096-1024] - AssertionError: assert False
         +  where False = <built-in method allclose of type object at 0x7035c99e82c0>(tensor([[ 3.8871e-04,  1.5342e-03,  9.7731e-04,  ...,  1.5857e-04,\n          2.0651e-05, -2.0225e-04],\n        [ 3.0436e-04,  1.4040e-03, -1.4338e-04,  ..., -9.6487e-04,\n          3.6957e-04, -1.7970e-04],\n        [ 1.3870e-02,  1.8989e-03, -2.3409e-04,  ..., -9.2741e-05,\n         -2.1325e-03, -3.6861e-04],\n        ...,\n        [ 1.6965e-04,  7.5081e-04,  1.7243e-03,  ..., -3.3345e-04,\n          2.9291e-04,  4.6570e-03],\n        [-8.5313e-04,  5.1247e-04,  2.9434e-03,  ..., -1.6669e-04,\n          6.3304e-04,  8.2082e-04],\n        [-1.0297e-03, -5.9040e-05, -4.5201e-04,  ...,  1.1601e-03,\n          1.0437e-03,  2.4179e-04]], device='cuda:0', grad_fn=<SubBackward0>), tensor([[ 3.8871e-04,  1.5342e-03,  9.7731e-04,  ...,  1.5857e-04,\n          2.0651e-05, -2.0225e-04],\n        [ 3.0436e-04,  1.4040e-03, -1.4338e-04,  ..., -9.6487e-04,\n          3.6957e-04, -1.7970e-04],\n        [ 1.3870e-02,  1.8989e-03, -2.3409e-04,  ..., -9.2741e-05,\n         -2.1325e-03, -3.6861e-04],\n        ...,\n        [ 1.6965e-04,  7.5081e-04,  1.7243e-03,  ..., -3.3345e-04,\n          2.9291e-04,  4.6570e-03],\n        [-8.5313e-04,  5.1247e-04,  2.9434e-03,  ..., -1.6669e-04,\n          6.3304e-04,  8.2082e-04],\n        [-1.0297e-03, -5.9040e-05, -4.5201e-04,  ...,  1.1601e-03,\n          1.0437e-03,  2.4179e-04]], device='cuda:0',\n       grad_fn=<LigerKLDivLossFunctionBackward>), atol=1e-08, rtol=1e-06)
         +    where <built-in method allclose of type object at 0x7035c99e82c0> = torch.allclose
        FAILED test/triton/test_triton_monkey_patch.py::test_import_custom_cache_manager - ValueError: non-hexadecimal number found in fromhex() arg at position 0
        ================================ 6 failed, 1012 passed, 8 skipped, 72 warnings in 630.02s (0:10:30) ================================
        make: *** [Makefile:8: test] Error 1
    ```
    </details>

    ---------

    Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
    Co-authored-by: root <tjtanaa>

commit a2f301759e051278c1491a1acd2e8ae9d09d21c5
Author: hoshi-hiyouga <hiyouga@buaa.edu.cn>
Date:   Sat Nov 2 14:51:31 2024 +0800

    Fix llama forward patch (#339)

    ## Summary
    <!--- This is a required section; please describe the main purpose of
    this proposed code change. --->

    The present version of Liger Kernel uses `kwargs` in the model forward
    function, while transformers 4.46.0-4.46.1 only passes the
    `num_items_in_batch` parameter when `loss_kwargs` is in the model's
    forward signature [1][2]; we therefore change `kwargs` to `loss_kwargs`
    to align with the transformers implementation [3]. A minimal sketch of
    the signature check follows below.

    [1]
    https://github.com/huggingface/transformers/blob/v4.46.1/src/transformers/trainer.py#L593
    [2]
    https://github.com/huggingface/transformers/blob/v4.46.1/src/transformers/trainer.py#L3620-L3625
    [3]
    https://github.com/huggingface/transformers/blob/v4.46.1/src/transformers/models/llama/modeling_llama.py#L1137-L1151
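
    As an illustration of the mechanism (a hand-written sketch, not code from
    either repository; the function names are made up), the Trainer-side check
    referenced in [1] only forwards `num_items_in_batch` when the forward
    signature literally contains a parameter named `loss_kwargs`:

    ```python
    import inspect


    def lce_forward_old(input_ids, labels=None, **kwargs):  # patched forward before this fix
        pass


    def lce_forward_new(input_ids, labels=None, **loss_kwargs):  # patched forward after this fix
        pass


    for fwd in (lce_forward_old, lce_forward_new):
        accepts = "loss_kwargs" in inspect.signature(fwd).parameters
        print(fwd.__name__, "would receive num_items_in_batch:", accepts)
    ```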

    <!---
    ## Details
    This is an optional section; is there anything specific that reviewers
    should be aware of?
    --->

    ## Testing Done
    <!--- This is a required section; please describe how this change was
    tested. --->

    <!--
    Replace BLANK with your device type. For example, A100-80G-PCIe

    Complete the following tasks before sending your PR, and replace `[ ]`
    with
    `[x]` to indicate you have done them.
    -->

    - Hardware Type: <BLANK>
    - [ ] run `make test` to ensure correctness
    - [ ] run `make checkstyle` to ensure code style
    - [ ] run `make test-convergence` to ensure convergence

commit 1b04de6b47845f47473500ea18ed55b87e68a68e
Author: Byron Hsu <byronhsu1230@gmail.com>
Date:   Fri Nov 1 13:18:31 2024 -0700

    Update pyproject.toml

    After https://github.com/linkedin/Liger-Kernel/pull/274, triton needs to be >=2.3.1

commit ac2e8f4563289f7bee0ad9652926afec5c46747b
Author: Yun Dai <yundai424@gmail.com>
Date:   Thu Oct 31 21:46:53 2024 -0700

    Fix FusedLinearJSD precision issue when using AMP (#336)

    ## Summary
    <!--- This is a required section; please describe the main purpose of
    this proposed code change. --->
    1. Make sure all the computation between the logits and the final JSD loss
       happens in FP32 (see the sketch below).
    2. Make sure FLJSD works properly under a mixed-precision scenario, and add
       a test to guard it.
    3. The Torch CE loss implementation we use in testing FLCE misses the FP32
       cast for logits; add it back. **NOTE: we should definitely just switch
       directly to the [HF
       impl](https://github.com/huggingface/transformers/blob/main/src/transformers/loss/loss_utils.py#L32)
       for testing to ensure an apples-to-apples comparison. See the added TODO
       item.**
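
    As a rough reference for point 1 (a minimal sketch assuming the generalized
    JSD formulation, not the actual fused kernel):

    ```python
    import torch
    import torch.nn.functional as F


    def jsd_loss_fp32(student_logits, teacher_logits, beta=0.5):
        # Upcast before softmax so everything from the logits to the final
        # loss stays in FP32, even inside a torch.autocast region.
        log_p = F.log_softmax(student_logits.float(), dim=-1)
        log_q = F.log_softmax(teacher_logits.float(), dim=-1)
        m = beta * log_p.exp() + (1.0 - beta) * log_q.exp()
        log_m = m.clamp_min(1e-20).log()
        kl_pm = F.kl_div(log_m, log_p, reduction="batchmean", log_target=True)  # KL(P || M)
        kl_qm = F.kl_div(log_m, log_q, reduction="batchmean", log_target=True)  # KL(Q || M)
        return beta * kl_pm + (1.0 - beta) * kl_qm
    ```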

    <!---
    ## Details
    This is an optional section; is there anything specific that reviewers
    should be aware of?
    --->

    ## Testing Done
    <!--- This is a required section; please describe how this change was
    tested. --->

    <!--
    Replace BLANK with your device type. For example, A100-80G-PCIe

    Complete the following tasks before sending your PR, and replace `[ ]`
    with
    `[x]` to indicate you have done them.
    -->

    - Hardware Type: <BLANK>
    - [x] run `make test` to ensure correctness
    - [x] run `make checkstyle` to ensure code style
    - [x] run `make test-convergence` to ensure convergence

commit 659d7d7856bf755c1cf26f2df6173da68841ba17
Author: Chiwan Park <chiwanpark@hotmail.com>
Date:   Fri Nov 1 08:24:06 2024 +0900

    Fix incorrect training of first and last Medusa heads (#325)

    ## Summary
    <!--- This is a required section; please describe the main purpose of
    this proposed code change. --->
    Currently, there are two errors in the Medusa training examples:

    1. When we use Liger Kernel, the first head (`model.medusa_head[0]`) is
       not trained.
    2. When we don't use Liger Kernel, the logits of the last head
       (`medusa_logits[-1]`) are ignored.

    This PR fixes these errors.

    <!---
    ## Details
    This is an optional section; is there anything specific that reviewers
    should be aware of?
    --->

    ## Testing Done
    <!--- This is a required section; please describe how this change was
    tested. --->

    <!--
    Replace BLANK with your device type. For example, A100-80G-PCIe

    Complete the following tasks before sending your PR, and replace `[ ]`
    with
    `[x]` to indicate you have done them.
    -->

    - Hardware Type: A100 80GB 8 GPUs
    - [x] run `make test` to ensure correctness
    - [x] run `make checkstyle` to ensure code style
    - [x] run `make test-convergence` to ensure convergence

commit 827b51c45762d6fc0ffaa7655126467c16f06d44
Author: Byron Hsu <byronhsu1230@gmail.com>
Date:   Thu Oct 31 15:33:05 2024 -0700

    Update llama.py

commit e28521bed9f13daacdc363b6975158a2e67ec3a4
Author: Byron Hsu <byronhsu1230@gmail.com>
Date:   Thu Oct 31 14:40:41 2024 -0700

    Fix huggingface GA issue for llama (#333)

    ## Summary
    <!--- This is a required section; please describe the main purpose of
    this proposed code change. --->

    <!---
    ## Details
    This is an optional section; is there anything specific that reviewers
    should be aware of?
    --->

    To fix https://github.com/linkedin/Liger-Kernel/pull/322

    This PR introduces a new `lce_forward` compatible with
    `transformers>=4.46.0` (after the grad acc fix) while ensuring backward
    compatibility.

    To be specific, I keep the original FLCE untouched and write a new one
    for `4.46.0`. If the HF version is `<4.46.0`, it shows a deprecation
    warning and falls back to the old FLCE.

    ```python
            if transformer_version >= version.parse("4.46.0"):
                modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
            else:  # if version < 4.46.0
                logger.warning(
                    "Support for transformers versions < 4.46.0 will soon be discontinued due to issues with incorrect gradient accumulation. "
                    "Please consider upgrading to avoid potential issues. See details: https://github.com/huggingface/transformers/pull/34191"
                )
                modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
    ```

    For more context on the grad acc fix, please see
    https://github.com/huggingface/transformers/pull/34191

    ## TODO

    - [ ] broadcast the changes to all models once the effect is verified.

    ## Testing Done
    <!--- This is a required section; please describe how this change was
    tested. --->

    <!--
    Replace BLANK with your device type. For example, A100-80G-PCIe

    Complete the following tasks before sending your PR, and replace `[ ]`
    with
    `[x]` to indicate you have done them.
    -->

    - Hardware Type: <BLANK>
    - [x] run `make test` to ensure correctness
    - [x] run `make checkstyle` to ensure code style
    - [x] run `make test-convergence` to ensure convergence

commit 337bf9a8361740c1caf38ba28b9dc9f7303c9aca
Author: Anish <98446102+novanish@users.noreply.github.com>
Date:   Thu Oct 31 06:04:25 2024 +0545

    docs(CONTRIBUTING): fix typo (#331)

    ## Fix typo in CONTRIBUTING.md

    This PR corrects a typo in the CONTRIBUTING.md file, changing
    "functionaility" to "functionality" in the semantic versioning section.

    Co-authored-by: Yun Dai <yundai424@gmail.com>

commit 48aa62d3ecb0a46009d2b92510a63e39e860fe82
Author: Tcc0403 <76503978+Tcc0403@users.noreply.github.com>
Date:   Thu Oct 31 01:15:12 2024 +0800

    Add missing ignore_index tests (#310)

    ## Summary
    `ignore_index` in fused_linear_cross_entropy was not tested
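
    As a plain-PyTorch reference for what such a test checks (an illustration,
    not the Liger test itself): positions labelled with `ignore_index`
    contribute nothing to the loss.

    ```python
    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 10)
    labels = torch.tensor([1, -100, 3, -100])  # -100 is the ignore_index
    loss = F.cross_entropy(logits, labels, ignore_index=-100)
    # Averaging over only the two non-ignored positions gives the same value.
    ref = F.cross_entropy(logits[[0, 2]], labels[[0, 2]])
    assert torch.allclose(loss, ref)
    ```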

    ## Testing Done

    - Hardware Type: gpu-ci
    - [x] run `make test` to ensure correctness
    - [x] run `make checkstyle` to ensure code style
    - [x] run `make test-convergence` to ensure convergence

    ---------

    Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
    Co-authored-by: Yun Dai <yundai424@gmail.com>

commit 1c0c75c3455e788d575966bfc5edec3ef166835e
Author: Yun Dai <yundai424@gmail.com>
Date:   Tue Oct 29 21:59:37 2024 -0700

    fix fused JSD with ignore index (#330)

    ## Summary
    <!--- This is a required section; please describe the main purpose of
    this proposed code change. --->
    1. There's currently a bug in fused linear JSD where we don't extract the
       correct subset of labels corresponding to the currently processed chunk
       (see the sketch below).
    2. Add some tests to make sure results are correct when all tokens are
       ignored.
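
    The chunked pattern in question looks roughly like this (a hand-written
    sketch with hypothetical shapes, not the actual kernel code): each chunk of
    the flattened hidden states must be paired with the matching slice of
    labels.

    ```python
    import torch

    hidden = torch.randn(8, 16)          # (num_tokens, hidden_size)
    labels = torch.randint(0, 10, (8,))  # one label per token
    chunk_size = 3

    for start in range(0, hidden.size(0), chunk_size):
        end = start + chunk_size
        hidden_chunk = hidden[start:end]
        label_chunk = labels[start:end]  # the fix: slice labels with the same range
        # ...compute this chunk's JSD contribution from hidden_chunk and
        # label_chunk, skipping positions equal to ignore_index...
    ```
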
    <!---
    ## Details
    This is an optional section; is there anything specific that reviewers
    should be aware of?
    --->

    ## Testing Done
    <!--- This is a required section; please describe how this change was
    tested. --->
    <!--
    Replace BLANK with your device type. For example, A100-80G-PCIe

    Complete the following tasks before sending your PR, and replace `[ ]`
    with
    `[x]` to indicate you have done them.
    -->

    - Hardware Type: <BLANK>
    - [x] run `make test` to ensure correctness
    - [x] run `make checkstyle` to ensure code style
    - [x] run `make test-convergence` to ensure convergence

commit 6cdc93deee15ab6c843149d6ed660c297c5c2d4a
Author: Yun Dai <yundai424@gmail.com>
Date:   Fri Oct 25 17:23:23 2024 -0700

    fix FLCE AMP issue (#318)

    ## Summary
    <!--- This is a required section; please describe the main purpose of
    this proposed code change. --->
    fixes #305: just rely on torch AMP to determine the input dtype when the
    AMP context is enabled (a minimal sketch follows below)
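
    A minimal sketch of that idea (assumed usage of the torch AMP helpers, not
    the actual kernel code):

    ```python
    import torch


    def pick_compute_dtype(weight: torch.Tensor) -> torch.dtype:
        # Inside an autocast region, use the dtype AMP would cast to;
        # otherwise fall back to the weight dtype as before.
        if torch.is_autocast_enabled():
            return torch.get_autocast_gpu_dtype()
        return weight.dtype
    ```
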
    <!---
    ## Details
    This is an optional section; is there anything specific that reviewers
    should be aware of?
    --->

    ## Testing Done
    <!--- This is a required section; please describe how this change was
    tested. --->

    <!--
    Replace BLANK with your device type. For example, A100-80G-PCIe

    Complete the following tasks before sending your PR, and replace `[ ]`
    with
    `[x]` to indicate you have done them.
    -->

    - Hardware Type: <BLANK>
    - [x] run `make test` to ensure correctness
    - [x] run `make checkstyle` to ensure code style
    - [x] run `make test-convergence` to ensure convergence

commit 9ad8f89373b2206e86e9bb1cdc6e63c37275bd81
Author: Byron Hsu <byronhsu1230@gmail.com>
Date:   Fri Oct 25 09:53:42 2024 -0700

    Update README.md

commit 4e2f7c6b9185560294c24ee48c32c07cefc7e828
Author: Byron Hsu <byronhsu1230@gmail.com>
Date:   Fri Oct 25 09:53:08 2024 -0700

    remove torch compile section until the issue is fixed

commit 99599091373f178e8ad6a69ecb1b32351d1d5c1f
Author: Byron Hsu <byronhsu1230@gmail.com>
Date:   Mon Oct 21 14:41:32 2024 -0700

    Update README.md

commit e49b83a4af985ef1f75c994bbdb7ed103b22ae11
Author: Byron Hsu <byronhsu1230@gmail.com>
Date:   Mon Oct 21 14:40:01 2024 -0700

    Update citation and add tech report (#317)

    ## Summary
    <!--- This is a required section; please describe the main purpose of
    this proposed code change. --->

    <!---
    ## Details
    This is an optional section; is there anything specific that reviewers
    should be aware of?
    --->

    ## Testing Done
    <!--- This is a required section; please describe how this change was
    tested. --->

    <!--
    Replace BLANK with your device type. For example, A100-80G-PCIe

    Complete the following tasks before sending your PR, and replace `[ ]`
    with
    `[x]` to indicate you have done them.
    -->

    - Hardware Type: <BLANK>
    - [ ] run `make test` to ensure correctness
    - [ ] run `make checkstyle` to ensure code style
    - [ ] run `make test-convergence` to ensure convergence

commit 7da01b7188266342b94858fd2e01bf037099441c
Author: Kürşat Aktaş <kursat.ce@gmail.com>
Date:   Tue Oct 22 00:22:41 2024 +0300

    Introducing Liger Kernel Guru on Gurubase.io (#316)

    I created the [Liger Kernel Guru](https://gurubase.io/g/liger-kernel)
    badge on Gurubase.io upon request from @ByronHsu.

    Adding a new badge next to the Discord badge made all the badge text
    smaller, as the current style presen…
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
* quick fix

* 3 losses

* oups

* fix

* nits

* check how it scales for special models

* propagate for conditiona detr

* propagate

* propagate

* propagate

* fixes

* propagate changes

* update

* fixup

* nits

* f string

* fixes

* more fixes

* ?

* nit

* arg annoying f string

* nits

* grumble

* update

* nit

* refactor

* fix fetch tests

* nit

* nit

* Update src/transformers/loss/loss_utils.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* update

* nit

* fixup

* make pass

* nits

* port code to more models

* fixup

* ntis

* arf

* update

* update

* nits

* update

* fix

* update

* nits

* fine

* agjkfslga.jsdlkgjklas

* nits

* fix fx?

* update

* update

* styel

* fix imports

* update

* update

* fixup to fix the torch fx?

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
@@ -1291,27 +1291,8 @@ def forward(
 
         loss = None
         if labels is not None:
-            labels = labels.to(logits.device)
@cphillippi-stripe cphillippi-stripe Dec 12, 2024

Hi, @ArthurZucker,

I'm seeing issues w/ multi-gpu sequence classification training after this change (using Mistral). I believe it is due to the removal of this line (which I'm having a very hard time monkeypatching via loss_function for some reason). I also see this line very often in quite a few different models here. Were multi-gpu setups tested here for training in sequence classification mode? I'm really curious how this passed if so, because I don't see any lines like this in the new ForSequenceClassificationLoss function that self.loss_function now seems to resolve to.

I'm seeing tracebacks like this (edited for privacy):

...
215       File ".../training_workflow.py", line 426, in train
216     trainer.train()
217       File ".../transformers/transformers/transformers/trainer.py", line 2123, in train
218     return inner_training_loop(
219            ^^^^^^^^^^^^^^^^^^^^
220       File ".../transformers/transformers/transformers/trainer.py", line 2481, in _inner_training_loop
221     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
222                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
223       File ".../transformers/transformers/transformers/trainer.py", line 3579, in training_step
224     loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
225            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
226       File ".../transformers/transformers/transformers/trainer.py", line 3633, in compute_loss
227     outputs = model(**inputs)
228               ^^^^^^^^^^^^^^^
229       File ".../torch/torch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
230     return self._call_impl(*args, **kwargs)
231            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
232       File ".../torch/torch/torch/nn/modules/module.py", line 1541, in _call_impl
233     return forward_call(*args, **kwargs)
234            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
235       File ".../accelerate/accelerate/accelerate/utils/operations.py", line 823, in forward
236     return model_forward(*args, **kwargs)
237            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
238       File ".../accelerate/accelerate/accelerate/utils/operations.py", line 811, in __call__
239     return convert_to_fp32(self.model_forward(*args, **kwargs))
240                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
241       File ".../torch/torch/torch/amp/autocast_mode.py", line 16, in decorate_autocast
242     return func(*args, **kwargs)
243            ^^^^^^^^^^^^^^^^^^^^^
244       File ".../peft/peft/peft/peft_model.py", line 1446, in forward
245     return self.base_model(
246            ^^^^^^^^^^^^^^^^
247       File ".../torch/torch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
248     return self._call_impl(*args, **kwargs)
249            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
250       File ".../torch/torch/torch/nn/modules/module.py", line 1541, in _call_impl
251     return forward_call(*args, **kwargs)
252            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
253       File ".../peft/peft/peft/tuners/tuners_utils.py", line 197, in forward
254     return self.model.forward(*args, **kwargs)
255            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
256       File ".../accelerate/accelerate/accelerate/hooks.py", line 170, in new_forward
257     output = module._old_forward(*args, **kwargs)
258              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
259       File ".../transformers/transformers/transformers/models/mistral/modeling_mistral.py", line 1200, in forward
260     loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
261            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
262       File ".../transformers/transformers/transformers/loss/loss_utils.py", line 67, in ForSequenceClassificationLoss
263     loss = fixed_cross_entropy(pooled_logits.view(-1, num_labels), labels.view(-1), **kwargs)
264            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
265       File ".../transformers/transformers/transformers/loss/loss_utils.py", line 26, in fixed_cross_entropy
266     loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
267            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
268       File ".../torch/torch/torch/nn/functional.py", line 3086, in cross_entropy
269     return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
270            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
271 
272 Message:
273 
274     RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:7 and cuda:0! (when checking argument for argument target in method wrapper_CUDA_nll_loss_forward)

For now, if you think you have a good way I can just patch this, I'd appreciate it, but replacing PretrainedModel.loss_function doesn't seem to do the trick...

@bauwenst
Contributor

bauwenst commented Dec 12, 2024

FYI: As also noted by @cphillippi-stripe, the @property PreTrainedModel.loss_function that was added by this PR breaks all software packages that have models which (1) inherit from PreTrainedModel and (2) set a loss_function of their own beforehand.

I maintain such a package, and indeed, training is now broken and my IDE complains that loss_function cannot be set.

[screenshot: IDE warning that loss_function cannot be assigned]
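
A minimal Python-only sketch of why such assignments now fail (a stand-in class, not the real PreTrainedModel):

class Base:
    @property
    def loss_function(self):
        return "default causal-lm loss"

class Downstream(Base):
    pass

m = Downstream()
try:
    m.loss_function = lambda *args, **kwargs: 0.0  # what downstream packages used to do
except AttributeError as err:
    print("broken:", err)  # the property has no setter, so assignment is rejected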

The default behaviour of falling back to the causal LM loss when no loss_type is defined in the config seems highly unwarranted to me. It is now assumed that either (1) models define this brand-new loss_type field in the config, or (2) the loss type can be inferred from the ForXYZ suffix:

if getattr(self.config, "loss_type", None) is not None:
    loss_type = self.config.loss_type
else:
    loss_type = self.__class__.__name__
    if loss_type not in LOSS_MAPPING:
        loss_groups = f"({'|'.join(LOSS_MAPPING)})"
        loss_type = re.findall(loss_groups, self.__class__.__name__)
        if len(loss_type) > 0:
            loss_type = loss_type[0]
        else:
            loss_type = None

Yet, this fails to take into account the following cases where transformers is used as a dependency:

  1. Model classes defined for tasks that don't have a known auto class in transformers. For example: if your package defines a ModelForDependencyParsing, the above regex search will not find its loss and default to causal loss, which is wrong.

     File ".../ArchIt/src/archit/instantiation/tasks.py", line 328, in computeLoss
       return self.loss_function(arc_scores, arc_labels) + self.loss_function(rel_scores, rel_labels)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
     TypeError: ForCausalLMLoss() missing 1 required positional argument: 'vocab_size'

  2. Model classes that have a known task, but it's not in the name (e.g. ModelForSentenceClassifying). Unless the config is from after October 2024, the config does not define its loss type.

  3. Loss is not inherent to a PreTrainedModel. There are plenty of papers that start from the model weights of one model and then change its loss function for continued pretraining (take this one as a recent example).

TL;DR: Seems like a design flaw to make loss_function a @property and seems like a mistake to overwrite it to ForCausalLMLoss for all models that don't fit the very narrow range of tasks supported by transformers.

@cphillippi-stripe

cphillippi-stripe commented Dec 12, 2024

(quoting @bauwenst's comment above in full)

Not a fix but a temporary workaround for me was this patch after loading a model:

from functools import lru_cache

model = AutoModelForSequenceClassification.from_pretrained(...)

old_loss_function = model.loss_function

@lru_cache
def get_loss_function(self):
    def fixed_loss(labels, pooled_logits, config, **kwargs):
        # Move the labels onto the same device as the pooled logits before
        # delegating to the original loss function.
        labels = labels.to(pooled_logits.device)
        return old_loss_function(labels, pooled_logits, config, **kwargs)
    return fixed_loss

model.__class__.loss_function = property(get_loss_function)

Also, @ArthurZucker, I believe this fix confirms removing the labels = labels.to(pooled_logits.device) line breaks the multi-gpu setup. I would be very surprised if only Mistral is affected here given similar lines were removed from other models. Initially, I tried patching this at the PreTrainedModel.loss_function level, but for whatever reason the loaded models don't seem to pick up the change, and I'm not sure why.

@ArthurZucker
Collaborator Author

Feedback taken, will open a PR to get to something better. It's hard to take everything into account, and it's super important to get your feedback! 🤗
Will merge the fix to the loss on multi-GPU (I don't think we test all tasks for training, and that is a flaw in our test environment for sure!) cc @ydshieh as a TODO 😉
