From 96c4f20225f3a9e2bbd0358fb2f1aba3f4f5a676 Mon Sep 17 00:00:00 2001
From: wizyoung
Date: Thu, 7 Nov 2024 16:11:46 +0800
Subject: [PATCH] Squashed commit of the following:
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

commit ae7e13ba1eaf58e5066b5cd60dfddf4f66f3cfed
Merge: ede50df 280cb81
Author: Wizyoung
Date: Thu Nov 7 15:58:13 2024 +0800

Merge branch 'linkedin:main' into main

commit 280cb8139511753ab3a16f286ebffe694ddd1970
Author: Haoyi Wu <43395692+why-in-Shanghaitech@users.noreply.github.com>
Date: Thu Nov 7 13:45:16 2024 +0800

Improve compatibility to access the base models (#340)

## Summary
This PR resolves #337, which improves compatibility when accessing the base models through the `base_model_prefix` attribute.

## Details
One thing to mention: `mllama` seems to be an outlier. It has a text model and a vision model, so it is impossible to reach both through a single attribute; `base_model_prefix` also seems to have different semantics for the `mllama` model classes. I left the code for `mllama` unchanged. For the other models, I looked into the `transformers` library and manually checked the correctness.

## Testing Done
The changes passed `test/transformers/test_monkey_patch.py` by running `pytest`.
- Hardware Type: RTX 3090
- [ ] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

Co-authored-by: Byron Hsu

commit ab5e88be1950aba248555e5e01907de04329e4dc
Author: Tcc0403 <76503978+Tcc0403@users.noreply.github.com>
Date: Thu Nov 7 13:29:08 2024 +0800

Support Z Loss in CE (#239)

## Summary
This PR aims to resolve #197 by implementing z loss in LigerCrossEntropy.

Note: `lse_square_scale` is not exposed at the FLCE level yet; there are still issues passing the tests for it.

## Details
### For loss:
```math
\begin{align}
L_{total} &= L_{ce} + z\_loss \\
z\_loss &= lse\_square\_scale \cdot lse^2 \\
lse &= log \sum e^{X_i}
\end{align}
```
We can use $m = max(X_i)$ and $d = \sum e^{X_i - m}$, obtained from the online softmax algorithm, to calculate $lse$ directly:
```math
\begin{align}
lse &= log \sum e^{X_i} \\
&= log \sum e^{X_i - m + m} = log \sum e^{X_i - m} \cdot e^m \\
&= log\ e^m\sum e^{X_i - m} = m + log\ d
\end{align}
```
### For gradients:
First, we calculate the derivative of lse:
```math
\begin{align}
\frac{\partial}{\partial x_i}(lse) &= \frac{\partial}{\partial x_i}(log \sum e^{x_i}) \\
&= \frac{1}{\sum e^{x_i}} \cdot \frac{\partial}{\partial x_i} \sum e^{x_i} \\
&= \frac{e^{x_i}}{\sum e^{x_i}} = softmax(x_i).
\end{align}
```
Then we can obtain the derivative of z_loss by the chain rule:
```math
\frac{\partial z\_loss}{\partial x_i} = \frac{\partial}{\partial x_i}\left( lse\_square\_scale \cdot lse^2\right) = 2\cdot lse\_square\_scale \cdot lse \cdot softmax(x_i),
```
and we have the derivative of the cross entropy loss with label smoothing:
```math
\frac{\partial L_{ce}}{\partial x_i} = softmax(x_i) - (1 - \epsilon)\delta_{k,y} - \frac{\epsilon}{K} =
\begin{cases}
softmax(x_i) - \frac{\epsilon}{K}, & i \neq y \\
softmax(x_i) - \frac{\epsilon}{K} - (1 - \epsilon), & i = y
\end{cases}
```
where $\epsilon$ is label_smoothing and $K$ is the total number of classes. A small numerical sanity check of these identities is sketched below.
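As a quick numerical check of the identities above (the $lse = m + log\ d$ rewrite and the z_loss gradient $2\cdot lse\_square\_scale \cdot lse \cdot softmax(x_i)$), the following standalone PyTorch sketch compares them against autograd. It is illustrative only and is not the Triton kernel added by this PR.

```python
# Sketch only (not this PR's Triton kernel): check lse = m + log(d) and
# d(z_loss)/dx_i = 2 * lse_square_scale * lse * softmax(x_i) with autograd.
import torch

torch.manual_seed(0)
lse_square_scale = 1e-4
x = torch.randn(8, dtype=torch.float64, requires_grad=True)

# Online-softmax quantities: running max m and normalizer d = sum(exp(x - m)).
m = x.max()
d = torch.exp(x - m).sum()
lse = m + torch.log(d)  # identity: lse = m + log(d)
assert torch.allclose(lse, torch.logsumexp(x, dim=0))

z_loss = lse_square_scale * lse**2
z_loss.backward()

expected = 2 * lse_square_scale * lse.detach() * torch.softmax(x.detach(), dim=0)
assert torch.allclose(x.grad, expected)
print("lse and z_loss gradient identities hold")
```

Computing lse from the running max and normalizer that the online softmax already maintains is what lets the kernel add z_loss without an extra pass over the logits.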
Combining the CE and z_loss terms, the derivative of the total loss is
```math
\begin{align}
\frac{\partial}{\partial x_i}L_{total} &= \frac{\partial}{\partial x_i}L_{ce} + \frac{\partial}{\partial x_i}z\_loss \\
&= softmax(x_i) - \frac{\epsilon}{K} - (1 - \epsilon)\delta_{k,y} + 2\cdot lse\_square\_scale \cdot lse \cdot softmax(x_i) \\
&=
\begin{cases}
(1 + 2\cdot lse\_square\_scale \cdot lse)\ softmax(x_i) - \frac{\epsilon}{K}, & i \neq y \\
(1 + 2\cdot lse\_square\_scale \cdot lse)\ softmax(x_i) - \frac{\epsilon}{K} - (1 - \epsilon), & i = y
\end{cases}
\end{align}
```

### Reference
[PaLM: Scaling Language Modeling with Pathways](https://www.jmlr.org/papers/v24/22-1144.html)
[Chameleon: Mixed-Modal Early-Fusion Foundation Models](https://arxiv.org/abs/2405.09818)

## Testing Done
[benchmark gist](https://gist.github.com/Tcc0403/b9120282334196f66b5169d9f52bccaa)

Negligible difference in the speed benchmark. This benchmark was done on my machine, which is probably not accurate.
```
liger ce: 66.123ms Peak mem: 8.66200832
liger ce with zloss: 65.991ms Peak mem: 8.66200832
liger ce with zloss with return zloss: 65.951ms Peak mem: 8.662073856
```
- Hardware Type:
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Shao Tang
Co-authored-by: Byron Hsu

commit 85d34efbd423cd97d3e97525af419193fbb07354
Author: Pramodith Ballapuram <16939722+pramodith@users.noreply.github.com>
Date: Wed Nov 6 17:44:54 2024 +0000

BUG: Fix bug in layer norm tests. (#359)

## Summary
This PR fixes a bug in a layer norm test case where the gradient of x was incorrectly compared against itself, meaning the assertion would always succeed.

## Testing Done
Tested on A100-80G-SXM4.
- Hardware Type:
- [X] run `make test` to ensure correctness
- [X] run `make checkstyle` to ensure code style
- [X] run `make test-convergence` to ensure convergence

commit c131f0423ccef96e71a13d58bda168f5904bfa89
Author: Byron Hsu
Date: Tue Nov 5 16:50:38 2024 -0800

Update ci.yml

commit 985e6c74b61656061f28be74434a6de2de3aabfd
Author: Byron Hsu
Date: Tue Nov 5 16:13:49 2024 -0800

Update ci.yml

commit a8c085488f3c47b86b2d560a1225bc27ec59c68d
Author: Byron Hsu
Date: Tue Nov 5 15:58:11 2024 -0800

fixing ci

commit e985195bec82ea9d89b9d20a758356eee1650dc1
Author: Byron Hsu
Date: Tue Nov 5 14:10:52 2024 -0800

Update pyproject.toml

commit 98d77e077d7bf8335a4a7748067ea8fc3633e3ef
Author: Byron Hsu
Date: Tue Nov 5 14:05:27 2024 -0800

broadcast grad acc fix to all models (#354)

## Summary
Follow-up for https://github.com/linkedin/Liger-Kernel/pull/339. However, this identified a few issues:
1. revert patching causes flce not to take effect (comment out revert patching for now, and only test float32)
2. qwen2 vl flce is broken; we should fix it later
3. we should provide a real "on-instance" patch that does not use any monkey patch.
(Right now, the on-instance patch still relies on monkey patching.)

## Testing Done
- Hardware Type:
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

commit ef3f55dcd06b4fca95a5b75c9fe51ef1b7b7bfef
Author: Byron Hsu
Date: Mon Nov 4 17:04:47 2024 -0800

merge two tests into one (#349)

## Summary
Remove the launching overhead of the 2nd container.

## Testing Done
- Hardware Type:
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

commit b09fb65a37a045aa64e92b4d493897ba1c462ce8
Author: Byron Hsu
Date: Mon Nov 4 16:40:52 2024 -0800

Trim conv test (#348)

## Summary
Remove the non-FLCE convergence tests since most users are using FLCE.

## Testing Done
- Hardware Type:
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

commit fbcb52d615f46f54ce865cec028ce5c64a205a2a
Author: ByronHsu
Date: Mon Nov 4 22:54:09 2024 +0000

Move dependent license to a folder

commit a2dfa3cb2f7b6f0e23a65ad76b38a6b567404a2c
Author: Byron Hsu
Date: Mon Nov 4 14:04:40 2024 -0800

Aggressively trim test bloat (#346)

## Summary
1. Disable the tests for experimental kernels
2. Reduce tensor sizes where the tests take too long
3. Remove redundant tests that are testing the same thing

Make sure unit test time < 5 mins.

## Testing Done
- Hardware Type:
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

commit e68b291f11d2f1ab22c5db9b1038021ee1821a0e
Author: Byron Hsu
Date: Mon Nov 4 13:14:38 2024 -0800

avoid duplicate ci (#345)

## Summary

## Testing Done
- Hardware Type:
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

commit c34843c45eb8c3501d54f506fa359401e06d0166
Author: Byron Hsu
Date: Mon Nov 4 13:08:19 2024 -0800

set up modal ci (#344)

## Summary
Follow https://github.com/modal-labs/ci-on-modal

## Testing Done
- Hardware Type:
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

commit ac7b38a2fdd3368b648d5ee02f6c0fb8661d8005
Author: TJian
Date: Sun Nov 3 01:07:39 2024 +0800

[AMD] [ROCm] Pick `num_warps` based on platform (#326)

## Summary
This PR enables the kernels to run on AMD GPUs through initial changes to `num_warps`. The change was proposed by @Edenzzzz and @DocShotgun in https://github.com/linkedin/Liger-Kernel/issues/266

## Details
I have updated the `transformers` version requirement from `4.44.0` to `4.46.0`, and all unit tests passed on A100 and MI300X.

## Testing Done
- Hardware Type: AMD Instinct MI300X
- [x] run `make test` to ensure correctness
- Some tests failed due to numerical precision issues; they pass after relaxing the tolerance by one order of magnitude (following the advice in the Liger-Kernel technical report https://arxiv.org/pdf/2410.10989, **Footnote 12:** _Note that in practice, the tolerance may need further relaxation in some cases by one or two orders of magnitude, even for exact kernels. We use convergence tests to ensure exactness in cases where the tolerance for correctness needs to be loose._)
- The tests with relaxed tolerances involve `kl_div` and `jsd` in `float32`.
- The relaxed conditions are described by the following code snippet:
```
_DTYPE_PARAMS = (
    "dtype, atol, rtol",
    [
        pytest.param(
            torch.bfloat16,
            1e-8,
            5e-2,
            marks=pytest.mark.skipif(
                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
            ),
        ),
        (torch.float32, 1e-8 if not is_hip() else 1e-7, 1e-6),
        (torch.float16, 1e-3, 1e-3),
    ],
)
```
- To pass the tests, Triton must not be installed from source; it must be installed from PyPI with `pip install triton==3.0.0`. This is tracked in a Triton issue: https://github.com/triton-lang/triton/issues/5013
- ~~Something else is weird: if I run only the failed test with `pytest test/transformers/test_cross_entropy.py::test_correctness_with_ignore_index[10.0-dtype5-1e-08-1e-06-sum-2-4096-32000--100]`, the test passes. However, it fails if other tests run before it.~~
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
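For context, the platform-dependent `num_warps` selection this PR introduces can be sketched as follows. The `is_hip()` helper mirrors the check used in the test snippet above, but the function name, block-size thresholds, and warp counts here are illustrative assumptions, not the PR's exact code. (The full failure logs for the relaxed tests follow below.)

```python
# Illustrative sketch only: pick num_warps based on the platform, since AMD
# (ROCm/HIP) wavefronts are 64 threads wide vs. 32-thread warps on NVIDIA,
# so fewer Triton "warps" are requested on HIP. Thresholds/values are assumed.
import torch
import triton


def is_hip() -> bool:
    # torch.version.hip is non-None when PyTorch is built for ROCm.
    return torch.version.hip is not None


def calculate_settings(n_cols: int) -> tuple[int, int]:
    # Block size: next power of two covering one row of the input.
    block_size = triton.next_power_of_2(n_cols)
    num_warps = 4
    if block_size >= 32768:
        num_warps = 32 if not is_hip() else 16
    elif block_size >= 8192:
        num_warps = 16 if not is_hip() else 8
    elif block_size >= 2048:
        num_warps = 8
    return block_size, num_warps
```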
Failure Test Logs (Click to expand/collapse) ```bash ============================================================= FAILURES ============================================================= ________________________ test_correctness_with_ignore_index[10.0-dtype5-1e-08-1e-06-sum-2-4096-32000--100] _________________________ B = 2, T = 4096, V = 32000, ignore_index = -100, reduction = 'sum', scalar = 10.0, dtype = torch.float32, atol = 1e-08, rtol = 1e-06 @pytest.mark.parametrize( "B, T, V, ignore_index", [ (2, 4096, 32000, -100), # llama2, mistral (2, 4096, 32000, 2), # llama2, mistral (1, 4096, 128256, -300), # llama3 # weird shapes (3, 423, 32000, -123), ], ) @pytest.mark.parametrize("reduction", ["sum", "mean"]) @pytest.mark.parametrize( "scalar, dtype, atol, rtol", [ pytest.param( 0.1, torch.bfloat16, 1e-8, 5e-2, marks=pytest.mark.skipif( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), pytest.param( 1.0, torch.bfloat16, 1e-8, 5e-2, marks=pytest.mark.skipif( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), pytest.param( 10.0, torch.bfloat16, 1e-8, 5e-2, marks=pytest.mark.skipif( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), (0.1, torch.float32, 1e-8, 1e-6), (1.0, torch.float32, 1e-8, 1e-6), (10.0, torch.float32, 1e-8, 1e-6), ], ) @pytest.mark.skipif( torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000, reason="Needs 16GB+ GPU memory.", ) def test_correctness_with_ignore_index( B, T, V, ignore_index, reduction, scalar, dtype, atol, rtol ): liger_ce = LigerCrossEntropyLoss(ignore_index=ignore_index, reduction=reduction) > _test_correctness_with_ignore_index_once( liger_ce, B, T, V, ignore_index, reduction, scalar, dtype, atol, rtol ) test/transformers/test_cross_entropy.py:302: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ target_ce = LigerCrossEntropyLoss(), B = 2, T = 4096, V = 32000, ignore_index = -100, reduction = 'sum', scalar = 10.0 dtype = torch.float32, atol = 1e-08, rtol = 1e-06 def _test_correctness_with_ignore_index_once( target_ce, B, T, V, ignore_index, reduction, scalar, dtype, atol, rtol ): torch_ce = CrossEntropyLoss(ignore_index=ignore_index, reduction=reduction) _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar _input = _tensor.detach().clone().requires_grad_(True) _input2 = _tensor.detach().clone().requires_grad_(True) target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long) # Assign some random number of elements as ignore_index num_elements_to_assign = torch.randint( 1, B * T // 2, (1,) ).item() # Random number of elements to set to ignore_index indices_to_assign = torch.randperm(B * T)[ :num_elements_to_assign ] # Randomly select indices target[indices_to_assign] = ignore_index output = torch_ce(_input, target) output2 = target_ce(_input2, target) assert torch.allclose(output, output2, atol=atol, rtol=rtol) output.backward() output2.backward() > assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) E AssertionError: assert False E + where False = (tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19, ..., 1.3759e-13, 7.6381e-10,\n 4.4185e-23],\n [2.9569e-12, 3.8580e-19, 5.3756e-16, ..., 6.0166e-23, 1.4681e-17,\n 5.1994e-20],\n [4.7900e-26, 1.0599e-04, 7.0237e-19, ..., 1.1461e-20, 1.0415e-10,\n 1.0237e-19],\n ...,\n [6.9540e-17, 3.4471e-22, 2.7309e-14, ..., 2.5999e-26, 2.5635e-19,\n 7.0793e-16],\n [6.3721e-23, 1.2054e-13, 1.8638e-20, ..., 
1.2807e-23, 5.5705e-16,\n 2.3085e-13],\n [1.9623e-20, 2.4720e-11, 1.8808e-15, ..., 3.5100e-20, 3.6195e-15,\n 1.5356e-23]], device='cuda:0'), tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19, ..., 1.3759e-13, 7.6381e-10,\n 4.4185e-23],\n [2.9569e-12, 3.8580e-19, 5.3756e-16, ..., 6.0166e-23, 1.4681e-17,\n 5.1994e-20],\n [4.7900e-26, 1.0599e-04, 7.0237e-19, ..., 1.1461e-20, 1.0415e-10,\n 1.0237e-19],\n ...,\n [6.9540e-17, 3.4471e-22, 2.7309e-14, ..., 2.5999e-26, 2.5635e-19,\n 7.0793e-16],\n [6.3722e-23, 1.2054e-13, 1.8638e-20, ..., 1.2807e-23, 5.5705e-16,\n 2.3085e-13],\n [1.9623e-20, 2.4720e-11, 1.8808e-15, ..., 3.5100e-20, 3.6195e-15,\n 1.5356e-23]], device='cuda:0'), atol=1e-08, rtol=1e-06) E + where = torch.allclose E + and tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19, ..., 1.3759e-13, 7.6381e-10,\n 4.4185e-23],\n [2.9569e-12, 3.8580e-19, 5.3756e-16, ..., 6.0166e-23, 1.4681e-17,\n 5.1994e-20],\n [4.7900e-26, 1.0599e-04, 7.0237e-19, ..., 1.1461e-20, 1.0415e-10,\n 1.0237e-19],\n ...,\n [6.9540e-17, 3.4471e-22, 2.7309e-14, ..., 2.5999e-26, 2.5635e-19,\n 7.0793e-16],\n [6.3721e-23, 1.2054e-13, 1.8638e-20, ..., 1.2807e-23, 5.5705e-16,\n 2.3085e-13],\n [1.9623e-20, 2.4720e-11, 1.8808e-15, ..., 3.5100e-20, 3.6195e-15,\n 1.5356e-23]], device='cuda:0') = tensor([[ 6.0503, 3.7258, -0.3530, ..., 11.8853, 20.5071, -9.9739],\n [ 15.2597, -0.5924, 6.6471, ..., -9.3584, 3.0466, -2.5966],\n [-17.9122, 31.2363, -1.4114, ..., -5.5268, 17.4033, -3.3372],\n ...,\n [ 4.3242, -7.8904, 10.2973, ..., -17.3829, -1.2789, 6.6447],\n [-10.9055, 10.4553, -5.2270, ..., -12.5100, 5.0782, 11.1050],\n [ -5.8922, 15.0620, 5.5783, ..., -5.3107, 6.2329, -13.0452]],\n device='cuda:0', requires_grad=True).grad E + and tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19, ..., 1.3759e-13, 7.6381e-10,\n 4.4185e-23],\n [2.9569e-12, 3.8580e-19, 5.3756e-16, ..., 6.0166e-23, 1.4681e-17,\n 5.1994e-20],\n [4.7900e-26, 1.0599e-04, 7.0237e-19, ..., 1.1461e-20, 1.0415e-10,\n 1.0237e-19],\n ...,\n [6.9540e-17, 3.4471e-22, 2.7309e-14, ..., 2.5999e-26, 2.5635e-19,\n 7.0793e-16],\n [6.3722e-23, 1.2054e-13, 1.8638e-20, ..., 1.2807e-23, 5.5705e-16,\n 2.3085e-13],\n [1.9623e-20, 2.4720e-11, 1.8808e-15, ..., 3.5100e-20, 3.6195e-15,\n 1.5356e-23]], device='cuda:0') = tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19, ..., 1.3759e-13, 7.6381e-10,\n 4.4185e-23],\n [2.9569e-12, 3.8580e-19, 5.3756e-16, ..., 6.0166e-23, 1.4681e-17,\n 5.1994e-20],\n [4.7900e-26, 1.0599e-04, 7.0237e-19, ..., 1.1461e-20, 1.0415e-10,\n 1.0237e-19],\n ...,\n [6.9540e-17, 3.4471e-22, 2.7309e-14, ..., 2.5999e-26, 2.5635e-19,\n 7.0793e-16],\n [6.3722e-23, 1.2054e-13, 1.8638e-20, ..., 1.2807e-23, 5.5705e-16,\n 2.3085e-13],\n [1.9623e-20, 2.4720e-11, 1.8808e-15, ..., 3.5100e-20, 3.6195e-15,\n 1.5356e-23]], device='cuda:0', requires_grad=True).grad test/transformers/test_cross_entropy.py:61: AssertionError _________________________________ test_correctness_with_beta[0.1-dtype1-1e-08-1e-06-1-4096-128256] _________________________________ B = 1, T = 4096, V = 128256, beta = 0.1, dtype = torch.float32, atol = 1e-08, rtol = 1e-06 @pytest.mark.parametrize(*_SHAPE_PARAMS) @pytest.mark.parametrize(*_DTYPE_PARAMS) @pytest.mark.parametrize("beta", [0.1, 0.5, 0.9]) def test_correctness_with_beta(B, T, V, beta, dtype, atol, rtol): liger_jsd = LigerJSD(beta=beta) > _test_correctness_with_beta_once(liger_jsd, beta, B, T, V, dtype, atol, rtol) test/transformers/test_jsd.py:269: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
test/transformers/test_jsd.py:157: in _test_correctness_with_beta_once assert_verbose_allclose(output, output2, atol=atol, rtol=rtol) _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ tensor1 = tensor(0.0805, device='cuda:0', grad_fn=) tensor2 = tensor(0.0805, device='cuda:0', grad_fn=), rtol = 1e-06, atol = 1e-08, max_print = 5 def assert_verbose_allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08, max_print=5): """ Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches. Parameters: tensor1 (torch.Tensor): First tensor to compare. tensor2 (torch.Tensor): Second tensor to compare. rtol (float): Relative tolerance. atol (float): Absolute tolerance. max_print (int): Maximum number of mismatched elements to print. Raises: AssertionError: If the tensors are not all close within the given tolerance. """ # Check if the shapes of the tensors match if tensor1.shape != tensor2.shape: raise AssertionError("Input tensors must have the same shape.") # Calculate the difference between the tensors diff = torch.abs(tensor1 - tensor2) # Determine the tolerance tolerance = atol + rtol * torch.abs(tensor2) # Find tolerance mismatched elements tol_mismatched = diff > tolerance # Find nan mismatched elements nan_mismatched = torch.logical_xor(torch.isnan(tensor1), torch.isnan(tensor2)) # Find +inf mismatched elements posinf_mismatched = torch.logical_xor( torch.isposinf(tensor1), torch.isposinf(tensor2) ) # Find -inf mismatched elements neginf_mismatched = torch.logical_xor( torch.isneginf(tensor1), torch.isneginf(tensor2) ) # Find all mismatched elements mismatched = torch.logical_or( torch.logical_or(tol_mismatched, nan_mismatched), torch.logical_or(posinf_mismatched, neginf_mismatched), ) mismatched_indices = torch.nonzero(mismatched) # Count the number of mismatched elements num_mismatched = mismatched.sum().item() # Check if all elements are close all_close = num_mismatched == 0 # Raise AssertionError with detailed information if there are mismatches if not all_close and num_mismatched >= 1: mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] print_count = min(max_print, num_mismatched) for index in mismatched_indices[:print_count]: i = tuple(index.tolist()) mismatch_details.append( f"Mismatch at index {i}: tensor1[{i}] = {tensor1[i]}, tensor2[{i}] = {tensor2[i]}" ) if num_mismatched > max_print: mismatch_details.append( f"... and {num_mismatched - max_print} more mismatched elements." 
) > raise AssertionError("\n".join(mismatch_details)) E AssertionError: Number of mismatched elements: 1 E Mismatch at index (): tensor1[()] = 0.08054989576339722, tensor2[()] = 0.08054977655410767 test/utils.py:106: AssertionError _________________________________ test_correctness_with_beta[0.9-dtype1-1e-08-1e-06-1-4096-128256] _________________________________ B = 1, T = 4096, V = 128256, beta = 0.9, dtype = torch.float32, atol = 1e-08, rtol = 1e-06 @pytest.mark.parametrize(*_SHAPE_PARAMS) @pytest.mark.parametrize(*_DTYPE_PARAMS) @pytest.mark.parametrize("beta", [0.1, 0.5, 0.9]) def test_correctness_with_beta(B, T, V, beta, dtype, atol, rtol): liger_jsd = LigerJSD(beta=beta) > _test_correctness_with_beta_once(liger_jsd, beta, B, T, V, dtype, atol, rtol) test/transformers/test_jsd.py:269: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ test/transformers/test_jsd.py:157: in _test_correctness_with_beta_once assert_verbose_allclose(output, output2, atol=atol, rtol=rtol) _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ tensor1 = tensor(0.0805, device='cuda:0', grad_fn=) tensor2 = tensor(0.0805, device='cuda:0', grad_fn=), rtol = 1e-06, atol = 1e-08, max_print = 5 def assert_verbose_allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08, max_print=5): """ Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches. Parameters: tensor1 (torch.Tensor): First tensor to compare. tensor2 (torch.Tensor): Second tensor to compare. rtol (float): Relative tolerance. atol (float): Absolute tolerance. max_print (int): Maximum number of mismatched elements to print. Raises: AssertionError: If the tensors are not all close within the given tolerance. """ # Check if the shapes of the tensors match if tensor1.shape != tensor2.shape: raise AssertionError("Input tensors must have the same shape.") # Calculate the difference between the tensors diff = torch.abs(tensor1 - tensor2) # Determine the tolerance tolerance = atol + rtol * torch.abs(tensor2) # Find tolerance mismatched elements tol_mismatched = diff > tolerance # Find nan mismatched elements nan_mismatched = torch.logical_xor(torch.isnan(tensor1), torch.isnan(tensor2)) # Find +inf mismatched elements posinf_mismatched = torch.logical_xor( torch.isposinf(tensor1), torch.isposinf(tensor2) ) # Find -inf mismatched elements neginf_mismatched = torch.logical_xor( torch.isneginf(tensor1), torch.isneginf(tensor2) ) # Find all mismatched elements mismatched = torch.logical_or( torch.logical_or(tol_mismatched, nan_mismatched), torch.logical_or(posinf_mismatched, neginf_mismatched), ) mismatched_indices = torch.nonzero(mismatched) # Count the number of mismatched elements num_mismatched = mismatched.sum().item() # Check if all elements are close all_close = num_mismatched == 0 # Raise AssertionError with detailed information if there are mismatches if not all_close and num_mismatched >= 1: mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] print_count = min(max_print, num_mismatched) for index in mismatched_indices[:print_count]: i = tuple(index.tolist()) mismatch_details.append( f"Mismatch at index {i}: tensor1[{i}] = {tensor1[i]}, tensor2[{i}] = {tensor2[i]}" ) if num_mismatched > max_print: mismatch_details.append( f"... and {num_mismatched - max_print} more mismatched elements." 
) > raise AssertionError("\n".join(mismatch_details)) E AssertionError: Number of mismatched elements: 1 E Mismatch at index (): tensor1[()] = 0.08054172992706299, tensor2[()] = 0.08054161071777344 test/utils.py:106: AssertionError ___________________________________ test_correctness[dtype1-1e-08-1e-06-none-False-32-4096-1024] ___________________________________ B = 32, T = 4096, V = 1024, log_target = False, reduction = 'none', dtype = torch.float32, atol = 1e-08, rtol = 1e-06 @pytest.mark.parametrize(*_SHAPE_PARAMS) @pytest.mark.parametrize("log_target", [True, False]) @pytest.mark.parametrize("reduction", ["batchmean", "sum", "mean", "none"]) @pytest.mark.parametrize(*_DTYPE_PARAMS) def test_correctness(B, T, V, log_target, reduction, dtype, atol, rtol): liger_kldiv = LigerKLDIVLoss(reduction=reduction, log_target=log_target) > _test_correctness_once( liger_kldiv, B, T, V, dtype, atol, rtol, reduction, log_target ) test/transformers/test_kl_div.py:97: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ target_kldiv = LigerKLDIVLoss(), B = 32, T = 4096, V = 1024, dtype = torch.float32, atol = 1e-08, rtol = 1e-06, reduction = 'none' log_target = False, is_last_layer = True, device = 'cuda' def _test_correctness_once( target_kldiv, B, T, V, dtype, atol, rtol, reduction, log_target, is_last_layer=True, device="cuda", ): torch.manual_seed(0) torch_kldiv = KLDivLoss(reduction=reduction, log_target=log_target) input = torch.randn( B * T, V, device=device, dtype=dtype, requires_grad=True ).log_softmax(dim=-1) x1 = input.detach().clone().requires_grad_(True) x2 = input.detach().clone().requires_grad_(True) with torch.no_grad(): target = torch.randn(B * T, V, device=device).softmax(dim=-1) output = torch_kldiv(x1, target) output2 = target_kldiv(x2, target) > assert torch.allclose(output, output2, atol=atol, rtol=rtol) E AssertionError: assert False E + where False = (tensor([[ 3.8871e-04, 1.5342e-03, 9.7731e-04, ..., 1.5857e-04,\n 2.0651e-05, -2.0225e-04],\n [ 3.0436e-04, 1.4040e-03, -1.4338e-04, ..., -9.6487e-04,\n 3.6957e-04, -1.7970e-04],\n [ 1.3870e-02, 1.8989e-03, -2.3409e-04, ..., -9.2741e-05,\n -2.1325e-03, -3.6861e-04],\n ...,\n [ 1.6965e-04, 7.5081e-04, 1.7243e-03, ..., -3.3345e-04,\n 2.9291e-04, 4.6570e-03],\n [-8.5313e-04, 5.1247e-04, 2.9434e-03, ..., -1.6669e-04,\n 6.3304e-04, 8.2082e-04],\n [-1.0297e-03, -5.9040e-05, -4.5201e-04, ..., 1.1601e-03,\n 1.0437e-03, 2.4179e-04]], device='cuda:0', grad_fn=), tensor([[ 3.8871e-04, 1.5342e-03, 9.7731e-04, ..., 1.5857e-04,\n 2.0651e-05, -2.0225e-04],\n [ 3.0436e-04, 1.4040e-03, -1.4338e-04, ..., -9.6487e-04,\n 3.6957e-04, -1.7970e-04],\n [ 1.3870e-02, 1.8989e-03, -2.3409e-04, ..., -9.2741e-05,\n -2.1325e-03, -3.6861e-04],\n ...,\n [ 1.6965e-04, 7.5081e-04, 1.7243e-03, ..., -3.3345e-04,\n 2.9291e-04, 4.6570e-03],\n [-8.5313e-04, 5.1247e-04, 2.9434e-03, ..., -1.6669e-04,\n 6.3304e-04, 8.2082e-04],\n [-1.0297e-03, -5.9040e-05, -4.5201e-04, ..., 1.1601e-03,\n 1.0437e-03, 2.4179e-04]], device='cuda:0',\n grad_fn=), atol=1e-08, rtol=1e-06) E + where = torch.allclose test/transformers/test_kl_div.py:75: AssertionError ______________________________ test_correctness_not_last[dtype1-1e-08-1e-06-none-False-32-4096-1024] _______________________________ B = 32, T = 4096, V = 1024, log_target = False, reduction = 'none', dtype = torch.float32, atol = 1e-08, rtol = 1e-06 @pytest.mark.parametrize(*_SHAPE_PARAMS) @pytest.mark.parametrize("log_target", [True, False]) 
@pytest.mark.parametrize("reduction", ["batchmean", "sum", "mean", "none"]) @pytest.mark.parametrize(*_DTYPE_PARAMS) def test_correctness_not_last(B, T, V, log_target, reduction, dtype, atol, rtol): liger_kldiv = LigerKLDIVLoss(reduction=reduction, log_target=log_target) > _test_correctness_once( liger_kldiv, B, T, V, dtype, atol, rtol, reduction, log_target, is_last_layer=False, ) test/transformers/test_kl_div.py:108: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ target_kldiv = LigerKLDIVLoss(), B = 32, T = 4096, V = 1024, dtype = torch.float32, atol = 1e-08, rtol = 1e-06, reduction = 'none' log_target = False, is_last_layer = False, device = 'cuda' def _test_correctness_once( target_kldiv, B, T, V, dtype, atol, rtol, reduction, log_target, is_last_layer=True, device="cuda", ): torch.manual_seed(0) torch_kldiv = KLDivLoss(reduction=reduction, log_target=log_target) input = torch.randn( B * T, V, device=device, dtype=dtype, requires_grad=True ).log_softmax(dim=-1) x1 = input.detach().clone().requires_grad_(True) x2 = input.detach().clone().requires_grad_(True) with torch.no_grad(): target = torch.randn(B * T, V, device=device).softmax(dim=-1) output = torch_kldiv(x1, target) output2 = target_kldiv(x2, target) > assert torch.allclose(output, output2, atol=atol, rtol=rtol) E AssertionError: assert False E + where False = (tensor([[ 3.8871e-04, 1.5342e-03, 9.7731e-04, ..., 1.5857e-04,\n 2.0651e-05, -2.0225e-04],\n [ 3.0436e-04, 1.4040e-03, -1.4338e-04, ..., -9.6487e-04,\n 3.6957e-04, -1.7970e-04],\n [ 1.3870e-02, 1.8989e-03, -2.3409e-04, ..., -9.2741e-05,\n -2.1325e-03, -3.6861e-04],\n ...,\n [ 1.6965e-04, 7.5081e-04, 1.7243e-03, ..., -3.3345e-04,\n 2.9291e-04, 4.6570e-03],\n [-8.5313e-04, 5.1247e-04, 2.9434e-03, ..., -1.6669e-04,\n 6.3304e-04, 8.2082e-04],\n [-1.0297e-03, -5.9040e-05, -4.5201e-04, ..., 1.1601e-03,\n 1.0437e-03, 2.4179e-04]], device='cuda:0', grad_fn=), tensor([[ 3.8871e-04, 1.5342e-03, 9.7731e-04, ..., 1.5857e-04,\n 2.0651e-05, -2.0225e-04],\n [ 3.0436e-04, 1.4040e-03, -1.4338e-04, ..., -9.6487e-04,\n 3.6957e-04, -1.7970e-04],\n [ 1.3870e-02, 1.8989e-03, -2.3409e-04, ..., -9.2741e-05,\n -2.1325e-03, -3.6861e-04],\n ...,\n [ 1.6965e-04, 7.5081e-04, 1.7243e-03, ..., -3.3345e-04,\n 2.9291e-04, 4.6570e-03],\n [-8.5313e-04, 5.1247e-04, 2.9434e-03, ..., -1.6669e-04,\n 6.3304e-04, 8.2082e-04],\n [-1.0297e-03, -5.9040e-05, -4.5201e-04, ..., 1.1601e-03,\n 1.0437e-03, 2.4179e-04]], device='cuda:0',\n grad_fn=), atol=1e-08, rtol=1e-06) E + where = torch.allclose test/transformers/test_kl_div.py:75: AssertionError _________________________________________________ test_import_custom_cache_manager _________________________________________________ def test_import_custom_cache_manager(): from triton.runtime.cache import get_cache_manager from liger_kernel.triton import apply_liger_triton_cache_manager apply_liger_triton_cache_manager() > cache_manager = get_cache_manager(key="test_hash") test/triton/test_triton_monkey_patch.py:17: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ /opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/cache.py:277: in get_cache_manager return __cache_cls(_base64(key)) _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ key = 'test_hash' def _base64(key): # Assume key is a hex string. 
> return base64.urlsafe_b64encode(bytes.fromhex(key)).decode("utf-8").rstrip("=") E ValueError: non-hexadecimal number found in fromhex() arg at position 0 /opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/cache.py:261: ValueError ===================================================== short test summary info ====================================================== FAILED test/transformers/test_cross_entropy.py::test_correctness_with_ignore_index[10.0-dtype5-1e-08-1e-06-sum-2-4096-32000--100] - AssertionError: assert False + where False = (tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19, ..., 1.3759e-13, 7.6381e-10,\n 4.4185e-23],\n [2.9569e-12, 3.8580e-19, 5.3756e-16, ..., 6.0166e-23, 1.4681e-17,\n 5.1994e-20],\n [4.7900e-26, 1.0599e-04, 7.0237e-19, ..., 1.1461e-20, 1.0415e-10,\n 1.0237e-19],\n ...,\n [6.9540e-17, 3.4471e-22, 2.7309e-14, ..., 2.5999e-26, 2.5635e-19,\n 7.0793e-16],\n [6.3721e-23, 1.2054e-13, 1.8638e-20, ..., 1.2807e-23, 5.5705e-16,\n 2.3085e-13],\n [1.9623e-20, 2.4720e-11, 1.8808e-15, ..., 3.5100e-20, 3.6195e-15,\n 1.5356e-23]], device='cuda:0'), tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19, ..., 1.3759e-13, 7.6381e-10,\n 4.4185e-23],\n [2.9569e-12, 3.8580e-19, 5.3756e-16, ..., 6.0166e-23, 1.4681e-17,\n 5.1994e-20],\n [4.7900e-26, 1.0599e-04, 7.0237e-19, ..., 1.1461e-20, 1.0415e-10,\n 1.0237e-19],\n ...,\n [6.9540e-17, 3.4471e-22, 2.7309e-14, ..., 2.5999e-26, 2.5635e-19,\n 7.0793e-16],\n [6.3722e-23, 1.2054e-13, 1.8638e-20, ..., 1.2807e-23, 5.5705e-16,\n 2.3085e-13],\n [1.9623e-20, 2.4720e-11, 1.8808e-15, ..., 3.5100e-20, 3.6195e-15,\n 1.5356e-23]], device='cuda:0'), atol=1e-08, rtol=1e-06) + where = torch.allclose + and tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19, ..., 1.3759e-13, 7.6381e-10,\n 4.4185e-23],\n [2.9569e-12, 3.8580e-19, 5.3756e-16, ..., 6.0166e-23, 1.4681e-17,\n 5.1994e-20],\n [4.7900e-26, 1.0599e-04, 7.0237e-19, ..., 1.1461e-20, 1.0415e-10,\n 1.0237e-19],\n ...,\n [6.9540e-17, 3.4471e-22, 2.7309e-14, ..., 2.5999e-26, 2.5635e-19,\n 7.0793e-16],\n [6.3721e-23, 1.2054e-13, 1.8638e-20, ..., 1.2807e-23, 5.5705e-16,\n 2.3085e-13],\n [1.9623e-20, 2.4720e-11, 1.8808e-15, ..., 3.5100e-20, 3.6195e-15,\n 1.5356e-23]], device='cuda:0') = tensor([[ 6.0503, 3.7258, -0.3530, ..., 11.8853, 20.5071, -9.9739],\n [ 15.2597, -0.5924, 6.6471, ..., -9.3584, 3.0466, -2.5966],\n [-17.9122, 31.2363, -1.4114, ..., -5.5268, 17.4033, -3.3372],\n ...,\n [ 4.3242, -7.8904, 10.2973, ..., -17.3829, -1.2789, 6.6447],\n [-10.9055, 10.4553, -5.2270, ..., -12.5100, 5.0782, 11.1050],\n [ -5.8922, 15.0620, 5.5783, ..., -5.3107, 6.2329, -13.0452]],\n device='cuda:0', requires_grad=True).grad + and tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19, ..., 1.3759e-13, 7.6381e-10,\n 4.4185e-23],\n [2.9569e-12, 3.8580e-19, 5.3756e-16, ..., 6.0166e-23, 1.4681e-17,\n 5.1994e-20],\n [4.7900e-26, 1.0599e-04, 7.0237e-19, ..., 1.1461e-20, 1.0415e-10,\n 1.0237e-19],\n ...,\n [6.9540e-17, 3.4471e-22, 2.7309e-14, ..., 2.5999e-26, 2.5635e-19,\n 7.0793e-16],\n [6.3722e-23, 1.2054e-13, 1.8638e-20, ..., 1.2807e-23, 5.5705e-16,\n 2.3085e-13],\n [1.9623e-20, 2.4720e-11, 1.8808e-15, ..., 3.5100e-20, 3.6195e-15,\n 1.5356e-23]], device='cuda:0') = tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19, ..., 1.3759e-13, 7.6381e-10,\n 4.4185e-23],\n [2.9569e-12, 3.8580e-19, 5.3756e-16, ..., 6.0166e-23, 1.4681e-17,\n 5.1994e-20],\n [4.7900e-26, 1.0599e-04, 7.0237e-19, ..., 1.1461e-20, 1.0415e-10,\n 1.0237e-19],\n ...,\n [6.9540e-17, 3.4471e-22, 2.7309e-14, ..., 2.5999e-26, 2.5635e-19,\n 7.0793e-16],\n [6.3722e-23, 1.2054e-13, 
1.8638e-20, ..., 1.2807e-23, 5.5705e-16,\n 2.3085e-13],\n [1.9623e-20, 2.4720e-11, 1.8808e-15, ..., 3.5100e-20, 3.6195e-15,\n 1.5356e-23]], device='cuda:0', requires_grad=True).grad FAILED test/transformers/test_jsd.py::test_correctness_with_beta[0.1-dtype1-1e-08-1e-06-1-4096-128256] - AssertionError: Number of mismatched elements: 1 Mismatch at index (): tensor1[()] = 0.08054989576339722, tensor2[()] = 0.08054977655410767 FAILED test/transformers/test_jsd.py::test_correctness_with_beta[0.9-dtype1-1e-08-1e-06-1-4096-128256] - AssertionError: Number of mismatched elements: 1 Mismatch at index (): tensor1[()] = 0.08054172992706299, tensor2[()] = 0.08054161071777344 FAILED test/transformers/test_kl_div.py::test_correctness[dtype1-1e-08-1e-06-none-False-32-4096-1024] - AssertionError: assert False + where False = (tensor([[ 3.8871e-04, 1.5342e-03, 9.7731e-04, ..., 1.5857e-04,\n 2.0651e-05, -2.0225e-04],\n [ 3.0436e-04, 1.4040e-03, -1.4338e-04, ..., -9.6487e-04,\n 3.6957e-04, -1.7970e-04],\n [ 1.3870e-02, 1.8989e-03, -2.3409e-04, ..., -9.2741e-05,\n -2.1325e-03, -3.6861e-04],\n ...,\n [ 1.6965e-04, 7.5081e-04, 1.7243e-03, ..., -3.3345e-04,\n 2.9291e-04, 4.6570e-03],\n [-8.5313e-04, 5.1247e-04, 2.9434e-03, ..., -1.6669e-04,\n 6.3304e-04, 8.2082e-04],\n [-1.0297e-03, -5.9040e-05, -4.5201e-04, ..., 1.1601e-03,\n 1.0437e-03, 2.4179e-04]], device='cuda:0', grad_fn=), tensor([[ 3.8871e-04, 1.5342e-03, 9.7731e-04, ..., 1.5857e-04,\n 2.0651e-05, -2.0225e-04],\n [ 3.0436e-04, 1.4040e-03, -1.4338e-04, ..., -9.6487e-04,\n 3.6957e-04, -1.7970e-04],\n [ 1.3870e-02, 1.8989e-03, -2.3409e-04, ..., -9.2741e-05,\n -2.1325e-03, -3.6861e-04],\n ...,\n [ 1.6965e-04, 7.5081e-04, 1.7243e-03, ..., -3.3345e-04,\n 2.9291e-04, 4.6570e-03],\n [-8.5313e-04, 5.1247e-04, 2.9434e-03, ..., -1.6669e-04,\n 6.3304e-04, 8.2082e-04],\n [-1.0297e-03, -5.9040e-05, -4.5201e-04, ..., 1.1601e-03,\n 1.0437e-03, 2.4179e-04]], device='cuda:0',\n grad_fn=), atol=1e-08, rtol=1e-06) + where = torch.allclose FAILED test/transformers/test_kl_div.py::test_correctness_not_last[dtype1-1e-08-1e-06-none-False-32-4096-1024] - AssertionError: assert False + where False = (tensor([[ 3.8871e-04, 1.5342e-03, 9.7731e-04, ..., 1.5857e-04,\n 2.0651e-05, -2.0225e-04],\n [ 3.0436e-04, 1.4040e-03, -1.4338e-04, ..., -9.6487e-04,\n 3.6957e-04, -1.7970e-04],\n [ 1.3870e-02, 1.8989e-03, -2.3409e-04, ..., -9.2741e-05,\n -2.1325e-03, -3.6861e-04],\n ...,\n [ 1.6965e-04, 7.5081e-04, 1.7243e-03, ..., -3.3345e-04,\n 2.9291e-04, 4.6570e-03],\n [-8.5313e-04, 5.1247e-04, 2.9434e-03, ..., -1.6669e-04,\n 6.3304e-04, 8.2082e-04],\n [-1.0297e-03, -5.9040e-05, -4.5201e-04, ..., 1.1601e-03,\n 1.0437e-03, 2.4179e-04]], device='cuda:0', grad_fn=), tensor([[ 3.8871e-04, 1.5342e-03, 9.7731e-04, ..., 1.5857e-04,\n 2.0651e-05, -2.0225e-04],\n [ 3.0436e-04, 1.4040e-03, -1.4338e-04, ..., -9.6487e-04,\n 3.6957e-04, -1.7970e-04],\n [ 1.3870e-02, 1.8989e-03, -2.3409e-04, ..., -9.2741e-05,\n -2.1325e-03, -3.6861e-04],\n ...,\n [ 1.6965e-04, 7.5081e-04, 1.7243e-03, ..., -3.3345e-04,\n 2.9291e-04, 4.6570e-03],\n [-8.5313e-04, 5.1247e-04, 2.9434e-03, ..., -1.6669e-04,\n 6.3304e-04, 8.2082e-04],\n [-1.0297e-03, -5.9040e-05, -4.5201e-04, ..., 1.1601e-03,\n 1.0437e-03, 2.4179e-04]], device='cuda:0',\n grad_fn=), atol=1e-08, rtol=1e-06) + where = torch.allclose FAILED test/triton/test_triton_monkey_patch.py::test_import_custom_cache_manager - ValueError: non-hexadecimal number found in fromhex() arg at position 0 ================================ 6 failed, 1012 passed, 8 skipped, 72 warnings in 
630.02s (0:10:30) ================================ make: *** [Makefile:8: test] Error 1 ```
---------

Co-authored-by: tjtanaa
Co-authored-by: root

commit a2f301759e051278c1491a1acd2e8ae9d09d21c5
Author: hoshi-hiyouga
Date: Sat Nov 2 14:51:31 2024 +0800

Fix llama forward patch (#339)

## Summary
The current version of Liger Kernel uses `kwargs` in the model forward function, while transformers 4.46.0-4.46.1 passes the `num_items_in_batch` parameter only when `loss_kwargs` is in the model's forward signature [1][2]. Thus, we change `kwargs` to `loss_kwargs` to align with the transformers implementation [3].

[1] https://github.com/huggingface/transformers/blob/v4.46.1/src/transformers/trainer.py#L593
[2] https://github.com/huggingface/transformers/blob/v4.46.1/src/transformers/trainer.py#L3620-L3625
[3] https://github.com/huggingface/transformers/blob/v4.46.1/src/transformers/models/llama/modeling_llama.py#L1137-L1151

## Testing Done
- Hardware Type:
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

commit 1b04de6b47845f47473500ea18ed55b87e68a68e
Author: Byron Hsu
Date: Fri Nov 1 13:18:31 2024 -0700

Update pyproject.toml

After https://github.com/linkedin/Liger-Kernel/pull/274, triton needs to be >=2.3.1

commit ac2e8f4563289f7bee0ad9652926afec5c46747b
Author: Yun Dai
Date: Thu Oct 31 21:46:53 2024 -0700

Fix FusedLinearJSD precision issue when using AMP (#336)

## Summary
1. make sure all the computation from the logits to the final JSD loss happens in FP32
2. make sure FLJSD works properly in the mixed-precision scenario, and add a test to guard it
3. the Torch CE loss impl we use in testing FLCE misses the fp32 cast for logits; add it back

**NOTE: we should definitely just switch directly to the [HF impl](https://github.com/huggingface/transformers/blob/main/src/transformers/loss/loss_utils.py#L32) for testing to ensure an apples-to-apples comparison. See the added TODO item.**

## Testing Done
- Hardware Type:
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

commit 659d7d7856bf755c1cf26f2df6173da68841ba17
Author: Chiwan Park
Date: Fri Nov 1 08:24:06 2024 +0900

Fix incorrect training of first and last Medusa heads (#325)

## Summary
Currently, there are two errors in the Medusa training examples:
1. When we use Liger Kernel, the first head (`model.medusa_head[0]`) is not trained.
2. When we don't use Liger Kernel, the logits of the last head (`medusa_logits[-1]`) are ignored.

This PR fixes these errors.

## Testing Done
- Hardware Type: A100 80GB 8 GPUs
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

commit 827b51c45762d6fc0ffaa7655126467c16f06d44
Author: Byron Hsu
Date: Thu Oct 31 15:33:05 2024 -0700

Update llama.py

commit e28521bed9f13daacdc363b6975158a2e67ec3a4
Author: Byron Hsu
Date: Thu Oct 31 14:40:41 2024 -0700

Fix huggingface GA issue for llama (#333)

## Summary
To fix https://github.com/linkedin/Liger-Kernel/pull/322, this PR introduces a new `lce_forward` compatible with `transformers>=4.46.0` (after the grad acc fix) while ensuring backward compatibility. To be specific, I keep the original FLCE forward untouched and write a new one for `4.46.0`. If the HF version is `<4.46.0`, it shows a deprecation warning and falls back to the old FLCE forward.
```python
if transformer_version >= version.parse("4.46.0"):
    modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
else:  # if version < 4.46.0
    logger.warning(
        "Support for transformers versions < 4.46.0 will soon be discontinued due to issues with incorrect gradient accumulation. "
        "Please consider upgrading to avoid potential issues. See details: https://github.com/huggingface/transformers/pull/34191"
    )
    modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
```
For more context on the grad acc fix, please see https://github.com/huggingface/transformers/pull/34191

## TODO
- [ ] broadcast the changes to all models once the effect is verified.

## Testing Done
- Hardware Type:
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

commit 337bf9a8361740c1caf38ba28b9dc9f7303c9aca
Author: Anish <98446102+novanish@users.noreply.github.com>
Date: Thu Oct 31 06:04:25 2024 +0545

docs(CONTRIBUTING): fix typo (#331)

## Fix typo in CONTRIBUTING.md
This PR corrects a typo in the CONTRIBUTING.md file, changing "functionaility" to "functionality" in the semantic versioning section.

Co-authored-by: Yun Dai

commit 48aa62d3ecb0a46009d2b92510a63e39e860fe82
Author: Tcc0403 <76503978+Tcc0403@users.noreply.github.com>
Date: Thu Oct 31 01:15:12 2024 +0800

Add missing ignore_index tests (#310)

## Summary
`ignore_index` in fused_linear_cross_entropy was not tested.

## Testing Done
- Hardware Type: gpu-ci
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Byron Hsu
Co-authored-by: Yun Dai

commit 1c0c75c3455e788d575966bfc5edec3ef166835e
Author: Yun Dai
Date: Tue Oct 29 21:59:37 2024 -0700

fix fused JSD with ignore index (#330)

## Summary
1. There's currently a bug in fused linear JSD where we don't extract the correct subset of labels corresponding to the currently processed chunk
2.
add some tests to make sure results are correct when all tokens are ignored ## Testing Done - Hardware Type: - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence commit 6cdc93deee15ab6c843149d6ed660c297c5c2d4a Author: Yun Dai Date: Fri Oct 25 17:23:23 2024 -0700 fix FLCE AMP issue (#318) ## Summary fixes #305 : just rely on torch AMP to determine the input dtype when AMP context is enabled ## Testing Done - Hardware Type: - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence commit 9ad8f89373b2206e86e9bb1cdc6e63c37275bd81 Author: Byron Hsu Date: Fri Oct 25 09:53:42 2024 -0700 Update README.md commit 4e2f7c6b9185560294c24ee48c32c07cefc7e828 Author: Byron Hsu Date: Fri Oct 25 09:53:08 2024 -0700 remove torch compile section until the issue is fixed commit 99599091373f178e8ad6a69ecb1b32351d1d5c1f Author: Byron Hsu Date: Mon Oct 21 14:41:32 2024 -0700 Update README.md commit e49b83a4af985ef1f75c994bbdb7ed103b22ae11 Author: Byron Hsu Date: Mon Oct 21 14:40:01 2024 -0700 Update citation and add tech report (#317) ## Summary ## Testing Done - Hardware Type: - [ ] run `make test` to ensure correctness - [ ] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence commit 7da01b7188266342b94858fd2e01bf037099441c Author: Kürşat Aktaş Date: Tue Oct 22 00:22:41 2024 +0300 Introducing Liger Kernel Guru on Gurubase.io (#316) I created the [Liger Kernel Guru](https://gurubase.io/g/liger-kernel) badge on Gurubase.io upon request from @ByronHsu. Adding a new badge next to the Discord badge made all the badge text smaller, as the current style presents all badges in a table row. To address this, I added the Liger Kernel Guru badge to the index section. Please let me know if you'd like me to move it to a different section. commit 6ab3b9febc29f5045e6d2e27ba6bacaa4f041d91 Author: Shivam Sahni Date: Thu Oct 17 15:54:14 2024 -0700 Monkey patch layer norm in mllama (#302) ## Summary Monkey patches layer norm in mllama for conditional generation ## Testing Done Tested monkey patching works as intended - Hardware Type: - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence --------- Co-authored-by: Shivam Sahni commit 24a7efca81c7e4cd7558c539fdfd5e380e9f2f58 Author: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Wed Oct 16 09:17:08 2024 +0800 Add ignore_index and label to jsd and fl-jsd (#306) commit 31469169eb286b792d96f2e92d4fcff47538d01d Author: barbarian360 <94866865+barbarian360@users.noreply.github.com> Date: Tue Oct 15 02:38:57 2024 +0530 Added contributors and back to top (#304) ## Summary Added the contributors section in the readme and also added the back to top button. commit 04d5a0e1d442439c65170cc67b112eba42dc37ee Author: Matthew Hoffman Date: Mon Oct 14 14:08:03 2024 -0700 Move `logits.float()` call (#308) ## Summary The analogous `logits.float()` calls were moved in the Hugging Face modeling source code to be inside the `if labels is not None` block to avoid upcasting logits unless they are being used in a loss calculation; this avoids a memory spike during inference if the model is in lower precision. 
* https://github.com/huggingface/transformers/blob/37ea04013b34b39c01b51aeaacd8d56f2c62a7eb/src/transformers/models/llama/modeling_llama.py#L1211-L1212 * https://github.com/huggingface/transformers/blob/37ea04013b34b39c01b51aeaacd8d56f2c62a7eb/src/transformers/models/mixtral/modeling_mixtral.py#L1329-L1330 * https://github.com/huggingface/transformers/blob/37ea04013b34b39c01b51aeaacd8d56f2c62a7eb/src/transformers/models/phi3/modeling_phi3.py#L1303-L1304 * https://github.com/huggingface/transformers/blob/37ea04013b34b39c01b51aeaacd8d56f2c62a7eb/src/transformers/models/qwen2/modeling_qwen2.py#L1206-L1207 Some of your models already have this change: https://github.com/linkedin/Liger-Kernel/blob/ff6650bbcef5d31b7522694cbeb73a21169460e9/src/liger_kernel/transformers/model/mistral.py#L114-L116 https://github.com/linkedin/Liger-Kernel/blob/ff6650bbcef5d31b7522694cbeb73a21169460e9/src/liger_kernel/transformers/model/gemma.py#L114-L116 See also: * https://github.com/huggingface/transformers/issues/30860 ## Testing Done - Hardware Type: - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence commit ff6650bbcef5d31b7522694cbeb73a21169460e9 Author: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Sat Oct 12 02:22:25 2024 +0800 Add FusedLinearJSD (#300) ## Summary similar to the fuse linear CE. It handles the forward and backward pass of the final linear layer via JSD by avoiding the materialization of the large logits tensor. Since JSD is the last layer, we can compute the gradient at the forward pass. ## Testing Done Hidden size: 4096, Vocab size: 128256 ![fused_linear_jsd_memory](https://github.com/user-attachments/assets/231303d1-4734-49fb-8c69-8e60730563c2) ![fused_linear_jsd_speed](https://github.com/user-attachments/assets/d83c85ec-ab29-44e0-a3d9-ad85acf4577d) - Hardware Type: NVIDIA H100 80GB HBM3 (SXM5) - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence --------- Co-authored-by: Qingquan Song commit 9b10f48f17af28d7319869f2f8a0c18d93f60e21 Author: Tyler Romero Date: Thu Oct 10 15:16:00 2024 -0700 Monkeypatch for Llama 3.2-Vision (#282) ## Summary Add monkeypatch to support [Llama 3.2-Vision](https://github.com/huggingface/transformers/pull/33703/files#diff-d804e851851cdebeb8048938f1f8beec1cfa78bf7b1f06af86faa450f9d18def) models. ## Details Llama 3.2-Vision is a multimodal model. It is also only available in transformers>=4.45.0. Torchvision is required to run the multimodal tests for Llama 3.2-Vision (the image processor requires it). 
## Testing Done - Hardware Type: RTX 4090 - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence --------- Co-authored-by: Shao Tang Co-authored-by: Steven Shimizu Co-authored-by: shivam15s commit de12602d858a6e83aaacc56e5cb64ab218c75a0a Author: Yanning Chen Date: Fri Oct 4 17:53:08 2024 -0700 Apache and MIT license reference (#294) ## Summary Add dependent licenses ## Testing Done - Hardware Type: - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence commit b540f0a52e35fd1bbdec47cf2ac9474cce6d5948 Author: Ikko Eltociear Ashimine Date: Fri Oct 4 17:26:12 2024 +0900 chore: update cross_entropy.py (#293) ## Summary Orginal -> Original - Hardware Type: - [ ] run `make test` to ensure correctness - [ ] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence commit 6817c2d34d66fca9b469056b088b0a5702bd87ce Author: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Thu Oct 3 12:14:41 2024 +0800 Add beta support for jsd (#290) ## Summary Resolve #278 . ## Details ### Forward: ```math \begin{align} JSD(X, Y, \beta) &= JSD_{\beta}(P \Vert Q)\\ &= \beta\ KL(P \Vert \beta P + (1-\beta)Q) + (1-\beta)\ KL(Q \Vert \beta P + (1-\beta)Q)\\ &= \sum \beta\ PY + (1-\beta)QX - M\ logM \end{align} ``` where $X=logQ$, $Y=logP$ and $M=\beta P + (1-\beta)Q$. ### Gradients: ```math \frac{\partial}{\partial X_i} JSD(X, Y, \beta) = (1-\beta)Q_i(X_i - logM_i) ``` ## Testing Done ![jsd_memory](https://github.com/user-attachments/assets/a26e1a64-df4b-49fe-8564-01a6757cb76a) ![jsd_speed](https://github.com/user-attachments/assets/6f631bdb-5abf-44ed-875b-2596f3a30b8b) - Hardware Type: H100 - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence --------- Co-authored-by: Shao Tang commit 60640e11d9cbe37088093f0c8fabdb9673c6372d Author: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com> Date: Thu Oct 3 00:06:55 2024 +0200 FEAT Adding experimental feature : Triton mm int8xint2 (#195) ### Summary Introducing matrix multiplication int8xint2 in Triton as an experimental feature. This approach involves performing matmul with on-the-fly unpacking, utilizing cached tiling techniques. Currently, it leverages tl.dot with int8 values, which is the most optimized method available at this time. However, with future hardware advancements, this could become significantly more efficient, particularly when using ternary weights, potentially eliminating the need for multiplication altogether. --------- Co-authored-by: Shao Tang Co-authored-by: Byron Hsu commit e1e9d2e31a6543e902c0986a4e49cad39ef8887a Author: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Thu Oct 3 05:38:31 2024 +0800 RMSNorm aggregation (#255) ## Summary Resolve #179 ## Testing Done - Hardware Type: RTX-3080 - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence --------- Co-authored-by: Shao Tang commit 665751e2de7475e61b28542aeabc08c785445d6f Author: Wizyoung Date: Wed Oct 2 07:13:36 2024 +0800 FIX: tl.program_id() does indeed not have a cast method in triton2.3.1 (#274) ## Summary https://github.com/linkedin/Liger-Kernel/pull/251 casted program_id to int64 in SwiGLU/GeGLU, but it's not compatible for triton2.3.1. 
## Testing Done - Hardware Type: - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence Co-authored-by: Shao Tang commit ede50df2d49a69fee5d56016f45006799e7fff44 Merge: dc07725 92ae3fa Author: Shao Tang Date: Tue Oct 1 15:58:37 2024 -0700 Merge branch 'main' into main commit 92ae3fa12c7266c32e1362787cb05d1397c089cb Author: Tyler Romero Date: Tue Oct 1 15:57:42 2024 -0700 Add missing Qwen2-VL monkey patch test (#283) ## Summary Add a monkeypatch test for Qwen2-VL. Also fixes and re-enables Qwen2-VL multimodal tests ## Testing Done - Hardware Type: - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence --------- Co-authored-by: Shao Tang commit bc5d20f77197853dd61c43f6b69964714c97f978 Author: Qing Date: Wed Oct 2 06:40:36 2024 +0800 fix qwen2-vl: create correct rope position_ids when position_ids is None (#276) ## Summary When `position_ids` is None, we should call `get_rope_index` to create 3D rope index The code was copied from here: https://github.com/huggingface/transformers/pull/33487. ## Testing Done I am using qwen2-vl to train the grounding task. The red box shows the results before fixing, and the green box shows the results after fixing (correct results). WechatIMG4500_副本 - Hardware Type: 3090 - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence --------- Co-authored-by: Shao Tang commit 16bde6c4bfc9993987c7a33372ea77532f7838db Author: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Wed Oct 2 06:10:10 2024 +0800 Fix assert_verbose_allclose bugs (#261) ## Summary Fix #259 Adding more masks to cover all edge cases, including: + nan + inf + -inf We should merge #262 before this PR to pass all tests. ## Testing Done - Hardware Type: - [ ] run `make test` to ensure correctness - [ ] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence Co-authored-by: Shao Tang commit 744188e853784842c0b43cb3410c1aa51a69d4f1 Author: Tyler Romero Date: Tue Oct 1 14:47:04 2024 -0700 Cancel in-progress but out-of-date GPU actions (#289) ## Summary When updating a PR branch, a lot of concurrent CICD actions can be triggered. This creates a backlog and it takes time for the tests to actually get around to running on the most recent commit. Github makes it easy to cancel in-progress actions when new actions are trigged within a group. This will help avoid wasting GPU time running CICD on out of date commits. Avoids cancelations on main, so that a complete CICD history is available there. - Hardware Type: - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence --------- Co-authored-by: Steven Shimizu commit dc077250682d4a36769c6d525355d5648b8823ec Merge: 00f47dd 8e2f3a4 Author: Shao Tang Date: Tue Oct 1 14:40:00 2024 -0700 Merge branch 'main' into main commit 8e2f3a477871b3c72f3f054673edbd656497752d Author: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Wed Oct 2 05:26:42 2024 +0800 Add JSD kernel (#264) ## Summary Resolve #252 ## Details ### JSD We expect input $X$ and target $Y$ are distributions in log-space, i.e., $X = log Q$ and $Y = log P$. 
The Jensen-Shannon Divergence between two distributions $P$ and $Q$ is defined as:
```math
JSD(X, Y) = JSD(P\ \Vert \ Q) = \frac{1}{2} (KL(P\ \Vert\ M) + KL(Q\ \Vert\ M))
```
where $M = \frac{1}{2}(P + Q)$ is the average distribution and $KL$ is the Kullback-Leibler divergence.

Given that $X = log Q$ and $Y = log P$, we can simplify the JSD expression to:
```math
\begin{align}
JSD(X, Y) &= \frac{1}{2}(\sum_i P_i\ log\frac{P_i}{M_i} + \sum_i Q_i\ log\frac{Q_i}{M_i}) \\
&= \frac{1}{2} \sum_i (P_i\ log\ P_i - P_i\ log\ M_i + Q_i\ log\ Q_i - Q_i\ log\ M_i) \\
&= \frac{1}{2} \sum_i (P_i\ log\ P_i + Q_i\ log\ Q_i - 2M_i\ log\ M_i) \\
&= \frac{1}{2} \sum_i (P_i \cdot Y_i + Q_i\cdot X_i - 2M_i\ log\ M_i)
\end{align}
```
We define the point-wise JSD as:
```math
JSD(X_i, Y_i) = \frac{1}{2} (P_i \cdot Y_i + Q_i\cdot X_i - 2M_i\ log\ M_i)
```
With point-wise JSD, it's easier to implement JSDs with respect to different reduction methods in the future. The only downside is that it creates a torch.float32 tensor with the same shape as the input's. The current implementation is hardcoded to batchmean, which is the original JSD definition.

### Gradients
Given:
```math
JSD(X, Y) = JSD(P\ \Vert \ Q) = \frac{1}{2} (KL(P\ \Vert\ M) + KL(Q\ \Vert\ M))
```
where $Q = e^X$, $P = e^Y$, and $M = \frac{1}{2}(e^X + e^Y)$.

#### Gradients of $KL(P\ \Vert\ M)$ with respect to $X_i$:
```math
\begin{align}
\frac{\partial}{\partial X_i} \sum_j P_j\ log\frac{P_j}{M_j} &= \frac{\partial}{\partial X_i} \sum_j P_j (Y_j - log\ M_j) \\
&= \frac{\partial}{\partial X_i} \sum_j - P_j\ log\ M_j \\
&= - P_i \cdot \frac{1}{M_i}\cdot \frac{e^{X_i}}{2} = -P_i\cdot \frac{Q_i}{2M_i}
\end{align}
```
#### Gradients of $KL(Q\ \Vert\ M)$ with respect to $X_i$:
```math
\begin{align}
\frac{\partial}{\partial X_i} \sum_j Q_j\ log\frac{Q_j}{M_j} &= \frac{\partial}{\partial X_i} \sum_j Q_j (X_j - log\ M_j) \\
&= \sum_j\left( \frac{\partial Q_j}{\partial X_i}(X_j - log\ M_j) + Q_j \frac{\partial (X_j - log\ M_j)}{\partial X_i}\right) \\
&= Q_i(X_i - log\ M_i) + Q_i (1 - \frac{Q_i}{2M_i})
\end{align}
```
#### Final gradients of JSD:
Combine the results from the two KL divergence terms:
```math
\begin{align}
\frac{\partial}{\partial X_i} JSD(X, Y) &= \frac{1}{2}\left(-P_i\cdot \frac{Q_i}{2M_i} + Q_i(X_i - log\ M_i) + Q_i (1 - \frac{Q_i}{2M_i})\right)
\end{align}
```
Simplify this to:
```math
\begin{align}
\frac{\partial}{\partial X_i} JSD(X, Y) &= \frac{1}{2}\left(-P_i\cdot \frac{Q_i}{2M_i} + Q_i(X_i - log\ M_i) + Q_i (1 - \frac{Q_i}{2M_i})\right) \\
&= \frac{1}{2}\left(Q_i (X_i - log\ M_i + 1 - \frac{P_i + Q_i}{2M_i})\right),\ where\ 2M_i = P_i + Q_i \\
&= \frac{1}{2}\cdot Q_i \cdot (X_i - log\ M_i)
\end{align}
```
We store the gradients at X_ptr in the forward pass to save memory, then retrieve them through ctx in the backward function, as cross_entropy does (in place). Note: in-place operations on inputs might cause an issue with gradient computation. A small PyTorch reference that checks this closed form against autograd is sketched below.
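To make the result above concrete, here is a minimal PyTorch reference (not the Triton kernel from this PR) that evaluates the batchmean JSD on log-space inputs and checks the closed-form gradient $\frac{1}{2} Q_i (X_i - log\ M_i)$ against autograd:

```python
# Reference sketch (not the Triton kernel): batchmean JSD on log-space inputs
# X = log Q and Y = log P, plus a check that the analytical gradient
# 0.5 * Q * (X - log M) / n_rows matches autograd.
import torch

torch.manual_seed(0)
n_rows, n_cols = 4, 16
X = torch.randn(n_rows, n_cols, dtype=torch.float64).log_softmax(-1).requires_grad_(True)
Y = torch.randn(n_rows, n_cols, dtype=torch.float64).log_softmax(-1)

Q, P = X.exp(), Y.exp()
M = 0.5 * (P + Q)
# Point-wise JSD 0.5 * (P*Y + Q*X - 2*M*log M), summed, then batchmean over rows.
loss = 0.5 * (P * Y + Q * X - 2 * M * torch.log(M)).sum() / n_rows
loss.backward()

expected_grad = 0.5 * Q.detach() * (X.detach() - torch.log(M.detach())) / n_rows
assert torch.allclose(X.grad, expected_grad)
print(f"JSD = {loss.item():.6f}, gradient matches the closed form")
```

Because the analytical gradient only needs Q, X, and M, it can be produced in the same pass that computes the loss, which is why the kernel can store it in place of the input.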
## Testing Done ### With inplace (Storing gradients to inputs) reduce memory usage by 61.54% ![jsd_memory](https://github.com/user-attachments/assets/fc71a3b3-73aa-433f-ac62-0a0924d5c2de) increase speed by 53.64% ![jsd_speed](https://github.com/user-attachments/assets/98eac714-aee3-4326-8554-26efc473baac) ### Without inplace reduce memory usage by 53% ![jsd_memory](https://github.com/user-attachments/assets/4f435fdb-872e-4465-a086-9ebcaa41ba3f) increase speed by 61% ![jsd_speed](https://github.com/user-attachments/assets/4e88bccd-982e-4fca-9148-004e4ef97570) - Hardware Type: H100 - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence --------- Co-authored-by: Shao Tang Co-authored-by: Qingquan Song commit d4933b5c520833f47d4e9fe4f1866444c9b70cb6 Author: Yanning Chen Date: Tue Oct 1 14:10:42 2024 -0700 Acknowledgement in NOTICE file (#287) ## Summary Acknowledgement in NOTICE file ## Testing Done - Hardware Type: - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence --------- Co-authored-by: Shao Tang commit 1520999e60e34a9e034026d05917082de098be1e Author: Steven Shimizu Date: Tue Oct 1 13:26:17 2024 -0700 Release version 0.3.1 (#286) ## Summary - Release version 0.3.1 ## Testing Done none commit e62fc98bd5ed79348f8651f8fadd02b82e325914 Author: Steven Shimizu Date: Tue Oct 1 11:41:59 2024 -0700 Disable gemma2 and qwen2_vl tests (#288) ## Summary - Gemma2 convergence tests were erroneously passing before due to all tensors having NaN values. Using `attn_implementation="eager"` fixes the NaNs, but results don't pass convergence criteria. Will need to investigate further, but skipping these for now. - The discrepancy was revealed after transformers 4.44.2 -> 4.45.1 update which seems to have fixed to fall back on eager attn implementation - Qwen2_VL convergence tests are failing and also require access to internet (HF Hub), so having a hard time debugging. Skipping this for now. ## Testing Done - Hardware Type: A100 - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence commit a5035d1bee274e90ff241f5dd5fa8acd2710e490 Author: Steven Shimizu Date: Mon Sep 30 15:21:44 2024 -0700 Relaxed transformers dependency (#270) ## Summary - Make the `transformers` dependency optional so we only have `torch` and `triton` as required deps, which is helpful if you're not using `transformers` for modeling code. This was also causing installation issues for people using slightly older transformers versions. - If transformers is needed, make it compatible with any 4.x version. The specific model being used should dictate the transformers version compatibility. 
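As a side note on the optional dependency above: a common way to keep `torch` and `triton` as the only hard requirements is a lazy, guarded import at patch time. The sketch below is purely illustrative; the helper name and error message are made up and are not Liger-Kernel's actual code.

```python
import importlib.util

def _require_transformers() -> None:
    """Fail fast with an actionable message if the optional extra is missing."""
    if importlib.util.find_spec("transformers") is None:
        raise ImportError(
            "This API needs the optional `transformers` dependency. "
            "Install it with: pip install -e .[transformers]"
        )

def apply_patches(model_type: str) -> None:
    _require_transformers()
    import transformers  # imported lazily so core installs stay transformers-free
    print(f"Patching {model_type} with transformers {transformers.__version__}")
```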
## Testing Done `pip install -e .[transformers]` `pip install -e .[dev]` A100-80G-PCIe - Hardware Type: A100-80G-PCIe - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence commit f2b288c15435329ab4cd582249ce8473d2ea997a Author: Steven Shimizu Date: Mon Sep 30 14:32:55 2024 -0700 Post-init model patching fix (#280) ## Summary - Previously, the pre-trained weights were not being loaded if patching model post-initialization - Instead of loading weights, just patch the model instance module's forward method (see https://github.com/linkedin/Liger-Kernel/issues/279) ## Testing Done - In convergence tests, check that pre-init patching and post-init patching match results from original model - Hardware Type: A100 - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --> most tests working, waiting for other fixes for all tests to pass commit 00f47dd0bba2ec9e18efa0ee9c5d6f3b9e871538 Merge: b1d3781 1dc6555 Author: Shao Tang Date: Mon Sep 30 09:30:32 2024 -0700 Merge branch 'main' into main commit 1dc655574490719a8cfdd4675d37920dff31c9ff Author: S1ro <54212263+S1ro1@users.noreply.github.com> Date: Mon Sep 30 18:29:27 2024 +0200 Fix/kldiv (#262) ## Summary Fixes KLDiv being calculated wrong due to invalid tests as described in #259 . Now made KLDiv work for even bigger vocab size etc. The benchmarks remain +- the same. Therefore not adding those as I don't have access to the original GPU they were done on before. ## Testing Done Tests stay the same, except replacing `assert_verbose_all_close` with `assert torch.allclose` - Hardware Type: Nvidia RTX A5500 - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence --------- Co-authored-by: Shao Tang commit b1d378183b0987340e22287aee0dbfd7049f7af0 Author: wizyoung Date: Thu Sep 26 21:17:42 2024 +0800 FIX: tl.program_id() does indeed not have a cast method in triton2.3.1 commit 8e2bd26f67123d37333488206fb8f36614c9567c Author: Chiwan Park Date: Thu Sep 26 00:34:57 2024 +0900 Fix sharing a ResBlock layer for each head in Medusa example (#269) ## Summary There is a bug of incorrect weight sharing between layers for each Medusa head. Since `nn.Module` is Python reference, the original source code creates a list containing references to the same weights. This PR fixes the bug. ## Testing Done - Hardware Type: A100-80G-PCIe - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence commit dcc7c9e4137390f4fc1a522b0cb89849d9b947ce Author: Mark Saroufim Date: Tue Sep 24 10:10:41 2024 -0700 rename cuda mode to gpu mode (#267) ## Summary This is a documentation only change since the server has been recently renamed. See this tweet for context https://x.com/jeremyphoward/status/1838341110344880637 Hopefully this is OK to merge :) commit ed4e60cd1d101410f4cffb06f18d62347b5bffc5 Author: Tyler Romero Date: Sun Sep 22 10:01:51 2024 -0700 chore: Add Qwen2.5 and Phi3.5 to Readme (#265) ## Summary Qwen2.5 was released recently - it uses the same model architecture as Qwen2 (see: https://huggingface.co/Qwen/Qwen2.5-7B/blob/main/config.json#L3). Likewise for Phi3.5 (see: https://huggingface.co/microsoft/Phi-3.5-mini-instruct/blob/main/config.json#L4). 
Also, the `model_type`s in the above configs will allow AutoLigerKernelForCausalLM to work correctly for these models out of the box. Adding them to the readme for clarity / marketing reasons :) - Hardware Type: - [ x ] run `make test` to ensure correctness - [ x ] run `make checkstyle` to ensure code style - [ x ] run `make test-convergence` to ensure convergence commit dd86cbd2092177681acf75643ded1b23a785a816 Author: Shivam Sahni Date: Fri Sep 20 16:30:59 2024 -0700 Update contributing guide for adding a new model (#260) commit 1289cc41c2591df6a2c1e7d902f8733239991100 Author: Steven Shimizu Date: Fri Sep 20 16:17:12 2024 -0700 Fix AutoLigerKernelForCausalLM to pass through original kwargs (#263) ## Summary - Fixes https://github.com/linkedin/Liger-Kernel/issues/250 to correctly pass all original kwargs to .from_pretrained(). Previously we were only passing args that were part of the model config, but there are additional valid kwargs beyond that. - We still need to filter out the kwargs passed into the apply_liger_* functions, or else will result in model init errors ## Testing Done Tested on huggingface example with some of the args in https://github.com/linkedin/Liger-Kernel/issues/250 - Hardware Type: A100 - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence commit ce71d59b0b0894f9f3e7512f5a3bf3780c5a1499 Author: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Thu Sep 19 23:03:41 2024 +0800 Fix a comment typo in flce (#256) ## Summary os huge -> is huge ## Testing Done - Hardware Type: - [ ] run `make test` to ensure correctness - [ ] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence commit 58fd2bc85073fdb010164426c9b159cd8a0e9542 Author: Hanson Wang Date: Tue Sep 17 08:48:20 2024 -0700 [Easy] Cast program_id to int64 in SwiGLU/GeGLU kernels (#251) ## Summary I hit some memory corruption errors testing large batches of tokens with larger models - e.g. with Gemma2-27B and a batch size of 80K tokens you will hit 80K * 36864 = 2.949e9 elements in the intermediate dimension, greater than (signed) int32! `tl.program_id` needs to be casted to int64 like in the fused cross-entropy kernel. ## Testing Done I didn't add a unit test for this because it would require a fair bit of VRAM, but can do so if desired. Was able to verify that forward+backward works without corruption on a Gemma2-27B model. - Hardware Type: A100 80GB - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence Co-authored-by: Shao Tang commit d1343adcdd9314efe004dbfb431cd7ad91105c12 Author: Edoardo Luciani Date: Sat Sep 14 23:45:29 2024 +0100 Remove debug print statement (#247) ## Summary Just a removal of a debug print statement left by accident (I suppose). 
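Returning briefly to the int64 cast in #251 above, a quick arithmetic check of the overflow scenario it describes, using the illustrative numbers from that PR:

```python
# Gemma2-27B intermediate dim (36864) with a batch of ~80K tokens, as described in #251.
tokens = 80_000
intermediate_dim = 36_864
elements = tokens * intermediate_dim
int32_max = 2**31 - 1
print(elements, int32_max)   # 2949120000 2147483647
assert elements > int32_max  # int32 offsets would overflow, hence the int64 cast
```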
- Hardware Type: NVIDIA RTX 2070 - [x] run `make test` to ensure correctness (to the extent my GPU can) - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence commit 793785f2dc999a2aef78fac58616a6ea93034542 Author: Qingquan Song Date: Fri Sep 13 14:18:14 2024 -0700 Release Liger-Kernel version 0.3.0 (#246) Release Liger-Kernel version 0.3.0 ## Summary ## Testing Done - Hardware Type: - [ ] run `make test` to ensure correctness - [ ] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence commit 53d5934c796d7ecdcbdf9790dd9fec89a2205149 Author: Shivam Sahni Date: Fri Sep 13 11:53:07 2024 -0700 Reduction support for CrossEntropy and Division by 0 Fix (#153) commit 7a5d48425d816fac15195f56861787d80659e16b Author: Steven Shimizu Date: Fri Sep 13 10:39:46 2024 -0700 Support for patching post-model initialization (#199) ## Summary - Currently, calling the patching APIs after the model has been initialized will only partially patch with Liger kernels. For example, the following will still be patched: - Model `forward()` method (e.g. `modeling_llama.LlamaForCausalLM.forward = lce_forward`) - module functions (e.g. `modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb`) but not any modules that were already instantiated and set as instance variables on the model: - For example: `modeling_llama.LlamaRMSNorm = LigerRMSNorm` will not affect existing LlamaRMSNorm instances - This means that integrations with HF Trainer and SFTTrainer only partially work. In the case of SFTTrainer, the current integration only works fully if the user passes a path to the model (https://github.com/huggingface/trl/pull/1992), and SFTTrainer handles calling `AutoLigerKernelForCausalLM.from_pretrained`. However, both HF Trainer and SFTTrainer allow passing the model instance directly, in which case we would need a way of patching the model post-init. ### API ```python from liger_kernel.transformers import _apply_liger_kernel_to_instance llama_model = AutoModelForCausalLM.from_pretrained("/path/to/llama", ...) _apply_liger_kernel_to_instance(model=llama_model) # can also pass in model-specific args that will get passed into the correct apply_liger_kernel_to_{model_type} _apply_liger_kernel_to_instance(model=llama_model, rope=False) ``` Required changes post-PR: - Update HF Trainer and SFTTrainer to pass in model instead of model_type to `_apply_liger_kernel(model=model)` ## Testing Done - Tested HF example with no patching, pre-init patching, post-init patching using existing method, post-init patching of model instance variables showing that post-init instance patching results in same performance as pre-init patching: - Added unit tests that model instances are actually patched correctly and that each patching API supports model instance patching. 
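For intuition, the general idea behind instance-level patching can be sketched as rebinding `forward` on modules that already exist, instead of swapping classes (which only affects modules created afterwards). This is a minimal sketch under assumed names, not Liger-Kernel's actual implementation:

```python
import types
import torch.nn as nn

def patch_rmsnorm_instances(model: nn.Module, liger_rms_norm_forward) -> None:
    """Rebind forward on every already-created *RMSNorm submodule of `model`.

    `liger_rms_norm_forward(self, hidden_states)` is a stand-in for a kernel-backed
    forward function; the name-based match below is purely illustrative.
    """
    for module in model.modules():
        if module.__class__.__name__.endswith("RMSNorm"):
            # class-level patching misses instances constructed before the patch;
            # binding the new forward onto each existing instance does not.
            module.forward = types.MethodType(liger_rms_norm_forward, module)
```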
**Llama**

| Scenario | Memory Usage (GB) | Throughput (tokens/sec) |
|--------------|-------------------|-------------------------|
| No patching | 68.3 | 9161 |
| Patch post-init (existing method) | 39.9 | 10939 |
| Patch pre-init | 38.4 | 12313 |
| Patch post-init (instance patching) | 38.3 | 12409 |

**Mistral**

| Scenario | Memory Usage (GB) | Throughput (tokens/sec) |
|--------------|-------------------|-------------------------|
| No patching | 34.5 | 10976 |
| Patch post-init (existing method) | 34.5 | 10812 |
| Patch pre-init | 34.5 | 12286 |
| Patch post-init (instance patching) | 34.5 | 12086 |

**Mixtral**

OOM on the testing setup, but confirmed that the post-init patched model loaded correctly and started training.

**Gemma**

| Scenario | Memory Usage (GB) | Throughput (tokens/sec) |
|--------------|-------------------|-------------------------|
| No patching | 52.8 | 7199 |
| Patch post-init (existing method) | 40.7 | 9479 |
| Patch pre-init | 40.7 | 9209 |
| Patch post-init (instance patching) | 40.7 | 9981 |

**Gemma2**

| Scenario | Memory Usage (GB) | Throughput (tokens/sec) |
|--------------|-------------------|-------------------------|
| No patching | 66.9 | 12256 |
| Patch post-init (existing method) | 51.0 | 14288 |
| Patch pre-init | 50.8 | 17084 |
| Patch post-init (instance patching) | 46.8 | 15893 |

* Note: Gemma2 pre-init and post-init (instance patching) are not converging, showing NaN grad norm. Needs to be investigated separately.

**Qwen2**

| Scenario | Memory Usage (GB) | Throughput (tokens/sec) |
|--------------|-------------------|-------------------------|
| No patching | 41.2 | 10785 |
| Patch post-init (existing method) | 36.8 | 10491 |
| Patch pre-init | 36.8 | 11908 |
| Patch post-init (instance patching) | 36.8 | 12068 |

**Qwen2 VL**

- Patched, but couldn't test since it is not yet released in the latest transformers.

**Phi3**

| Scenario | Memory Usage (GB) | Throughput (tokens/sec) |
|--------------|-------------------|-------------------------|
| No patching | 49.4 | 15601 |
| Patch post-init (existing method) | 49.4 | 16259 |
| Patch pre-init | 42.8 | 20383 |
| Patch post-init (instance patching) | 48.6 | 18922 |

* Note: For Phi3, post-init patching consistently does not match pre-init patching; the cause is not yet understood.

- Hardware Type:
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

commit d4879dfe2974d2a4d85019f143c331709f6920f5 Author: Austin Liu Date: Fri Sep 13 13:09:00 2024 +0800 Restore monkey patched modules (#232) ## Summary Fixes https://github.com/linkedin/Liger-Kernel/issues/176 There are several ways to restore a monkey-patched library in Python, including using context managers, decorators, pytest fixtures, or reloading the entire module. This PR focuses on reverting monkey-patched modules when `with_liger` is disabled in convergence tests.

```python
import importlib
import target.module

importlib.reload(target.module)
```

These changes simplify resetting the affected patched library and help prevent unintended side effects. It is also easier than manually reassigning functions. ## Follow-up If this PR resolves https://github.com/linkedin/Liger-Kernel/issues/176, it might introduce other value-mismatch problems. We may need to adjust the convergence tolerance accordingly.
For instance, ``` ______________________ test_mini_model[mini_mixtral-32-0.0001-dtype11-1e-08-1e-05-0.1-1e-05-0.01-1e-05] _______________________ model_name = 'mini_mixtral', num_steps = 32, lr = 0.0001, dtype = torch.bfloat16, loss_atol = 1e-08, loss_rtol = 1e-05 logits_atol = 0.1, logits_rtol = 1e-05, param_atol = 0.01, param_rtol = 1e-05 @pytest.mark.parametrize( "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol", [ ("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 2e-5, 1e-4, 1e-5, 5e-3, 1e-5), pytest.param( "mini_llama3", 32, 1e-4, torch.bfloat16, 5e-3, 1e-5, 1e-2, 1e-5, 1e-2, 1e-5, marks=pytest.mark.skipif( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), ("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), pytest.param( "mini_qwen2", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-2, 1e-5, 1e-2, 1e-5, marks=pytest.mark.skipif( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), pytest.param( "mini_qwen2_vl", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5, marks=pytest.mark.skipif( not QWEN2_VL_AVAILABLE, reason="Qwen2-VL not available in this version of transformers", ), ), pytest.param( "mini_qwen2_vl", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-2, 1e-5, 1e-2, 1e-5, marks=[ pytest.mark.skipif( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), pytest.mark.skipif( not QWEN2_VL_AVAILABLE, reason="Qwen2-VL not available in this version of transformers", ), ], ), ("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), pytest.param( "mini_phi3", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-2, 1e-5, 1e-2, 1e-5, marks=pytest.mark.skipif( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), ("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), pytest.param( "mini_mistral", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-2, 1e-5, 1e-2, 1e-5, marks=pytest.mark.skipif( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), ("mini_mixtral", 32, 1e-4, torch.float32, 5e-4, 1e-4, 5e-3, 1e-5, 1e-2, 1e-5), pytest.param( "mini_mixtral", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5, marks=pytest.mark.skipif( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), # Gemma 1.1 and 2 has more tolerance because currently, the kernel is not a perfect match (casts are not done the same way) ("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5), pytest.param( "mini_gemma1", 32, 1e-4, torch.bfloat16, 1e-2, 1e-4, 2e-1, 1e-5, 1e-2, 1e-5, marks=pytest.mark.skipif( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), ("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5), pytest.param( "mini_gemma1.1", 32, 1e-4, torch.bfloat16, 1e-2, 1e-4, 2e-1, 1e-5, 1e-2, 1e-5, marks=pytest.mark.skipif( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), ("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5), pytest.param( "mini_gemma2", 32, 1e-4, torch.bfloat16, 1e-2, 1e-4, 2e-1, 1e-5, 1e-2, 1e-5, marks=pytest.mark.skipif( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), ], ) def test_mini_model( model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol, ): # Non-liger models should be initialized and tested first to avoid the module being overridden expected_output = run_mini_model( 
model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr ) actual_output = run_mini_model( model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True ) # Compare every step of the loss > assert_verbose_allclose( torch.tensor([expected_output["loss"]]), torch.tensor([actual_output["loss"]]), atol=loss_atol, rtol=loss_rtol, ) test/convergence/test_mini_models_no_logits.py:594: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ tensor1 = tensor([[10.9374, 7.0162, 4.8162, 3.2886, 2.4254, 1.9993, 1.6753, 1.7743, 1.4267, 1.4742, 1.4458, ... 1.0867, 0.8353, 0.9219, 0.8796, 0.8610, 0.8183, 0.7559, 0.8734, 0.9647, 0.7261, 1.0963, 0.8136]]) tensor2 = tensor([[10.9383, 7.0052, 4.8145, 3.3515, 2.3853, 2.0174, 1.6758, 1.7778, 1.4256, 1.4737, 1.4442, ... 1.0870, 0.8346, 0.9222, 0.8817, 0.8610, 0.8181, 0.7554, 0.8736, 0.9671, 0.7263, 1.0966, 0.8171]]) rtol = 1e-05, atol = 1e-08, max_print = 5 def assert_verbose_allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08, max_print=5): """ Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches. Parameters: tensor1 (torch.Tensor): First tensor to compare. tensor2 (torch.Tensor): Second tensor to compare. rtol (float): Relative tolerance. atol (float): Absolute tolerance. max_print (int): Maximum number of mismatched elements to print. Raises: AssertionError: If the tensors are not all close within the given tolerance. """ # Check if the shapes of the tensors match if tensor1.shape != tensor2.shape: raise AssertionError("Input tensors must have the same shape.") # Calculate the difference between the tensors diff = torch.abs(tensor1 - tensor2) # Determine the tolerance tolerance = atol + rtol * torch.abs(tensor2) # Find mismatched elements mismatched = diff > tolerance # Get the indices of mismatched elements mismatched_indices = torch.nonzero(mismatched) # Count the number of mismatched elements num_mismatched = mismatched.sum().item() # Check if all elements are close all_close = num_mismatched == 0 # Raise AssertionError with detailed information if there are mismatches if not all_close and num_mismatched > 1: mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] print_count = min(max_print, num_mismatched) for index in mismatched_indices[:print_count]: i = tuple(index.tolist()) mismatch_details.append( f"Mismatch at index {i}: tensor1[{i}] = {tensor1[i]}, tensor2[{i}] = {tensor2[i]}" ) if num_mismatched > max_print: mismatch_details.append( f"... and {num_mismatched - max_print} more mismatched elements." ) > raise AssertionError("\n".join(mismatch_details)) E AssertionError: Number of mismatched elements: 32 E Mismatch at index (0, 0): tensor1[(0, 0)] = 10.937411308288574, tensor2[(0, 0)] = 10.938319206237793 E Mismatch at index (0, 1): tensor1[(0, 1)] = 7.016175270080566, tensor2[(0, 1)] = 7.0052409172058105 E Mismatch at index (0, 2): tensor1[(0, 2)] = 4.8161821365356445, tensor2[(0, 2)] = 4.814478397369385 E Mismatch at index (0, 3): tensor1[(0, 3)] = 3.288573980331421, tensor2[(0, 3)] = 3.3514533042907715 E Mismatch at index (0, 4): tensor1[(0, 4)] = 2.425377368927002, tensor2[(0, 4)] = 2.3853368759155273 E ... and 27 more mismatched elements. test/utils.py:83: AssertionError ---------------------------------------------------- Captured stdout call ----------------------------------------------------- Liger kernel patches have been reverted. 
Step 0, Loss: 10.937411308288574 Step 1, Loss: 7.016175270080566 Step 2, Loss: 4.8161821365356445 Step 3, Loss: 3.288573980331421 Step 4, Loss: 2.425377368927002 Step 5, Loss: 1.999261736869812 Step 6, Loss: 1.675323486328125 Step 7, Loss: 1.7742501497268677 Step 8, Loss: 1.4266773462295532 Step 9, Loss: 1.474155068397522 Step 10, Loss: 1.4458246231079102 Step 11, Loss: 1.1540931463241577 Step 12, Loss: 1.3520232439041138 Step 13, Loss: 1.311019778251648 Step 14, Loss: 1.219789981842041 Step 15, Loss: 1.3071205615997314 Step 16, Loss: 1.2621395587921143 Step 17, Loss: 1.3119654655456543 Step 18, Loss: 1.1880946159362793 Step 19, Loss: 1.2357648611068726 Step 20, Loss: 1.0867037773132324 Step 21, Loss: 0.8352738618850708 Step 22, Loss: 0.9218576550483704 Step 23, Loss: 0.879619836807251 Step 24, Loss: 0.8610480427742004 Step 25, Loss: 0.8182975053787231 Step 26, Loss: 0.7558884620666504 Step 27, Loss: 0.8734312057495117 Step 28, Loss: 0.9646832942962646 Step 29, Loss: 0.7261283993721008 Step 30, Loss: 1.0963469743728638 Step 31, Loss: 0.8136419057846069 Step 0, Loss: 10.938319206237793 Step 1, Loss: 7.0052409172058105 Step 2, Loss: 4.814478397369385 Step 3, Loss: 3.3514533042907715 Step 4, Loss: 2.3853368759155273 Step 5, Loss: 2.0173795223236084 Step 6, Loss: 1.6758073568344116 Step 7, Loss: 1.777788519859314 Step 8, Loss: 1.4255633354187012 Step 9, Loss: 1.4737187623977661 Step 10, Loss: 1.4441752433776855 Step 11, Loss: 1.1313129663467407 Step 12, Loss: 1.3452619314193726 Step 13, Loss: 1.299330234527588 Step 14, Loss: 1.2130300998687744 Step 15, Loss: 1.3027563095092773 Step 16, Loss: 1.2582926750183105 Step 17, Loss: 1.3112103939056396 Step 18, Loss: 1.1886006593704224 Step 19, Loss: 1.235780954360962 Step 20, Loss: 1.0869864225387573 Step 21, Loss: 0.8346381187438965 Step 22, Loss: 0.9222478866577148 Step 23, Loss: 0.8816985487937927 Step 24, Loss: 0.8609745502471924 Step 25, Loss: 0.81810462474823 Step 26, Loss: 0.7554237246513367 Step 27, Loss: 0.8736312389373779 Step 28, Loss: 0.967080295085907 Step 29, Loss: 0.7262533903121948 Step 30, Loss: 1.0965538024902344 Step 31, Loss: 0.8171141147613525 =================================================== short test summary info =================================================== FAILED test/convergence/test_mini_models.py::test_mini_model[mini_gemma1-32-0.0001-dtype1-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 27 FAILED test/convergence/test_mini_models.py::test_mini_model[mini_gemma1.1-32-0.0001-dtype3-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 27 FAILED test/convergence/test_mini_models.py::test_mini_model[mini_llama3-32-0.0001-dtype7-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 29 FAILED test/convergence/test_mini_models.py::test_mini_model[mini_mistral-32-0.0001-dtype9-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 27 FAILED test/convergence/test_mini_models.py::test_mini_model[mini_qwen2-32-0.0001-dtype11-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 31 FAILED test/convergence/test_mini_models.py::test_mini_model[mini_phi3-32-0.0001-dtype13-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 31 FAILED test/convergence/test_mini_models_no_logits.py::test_mini_model[mini_qwen2-32-0.0001-dtype3-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 27 FAILED 
test/convergence/test_mini_models_no_logits.py::test_mini_model[mini_phi3-32-0.0001-dtype7-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 31 FAILED test/convergence/test_mini_models_no_logits.py::test_mini_model[mini_mistral-32-0.0001-dtype9-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 27 FAILED test/convergence/test_mini_models_no_logits.py::test_mini_model[mini_mixtral-32-0.0001-dtype11-1e-08-1e-05-0.1-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 32 ============================== 10 failed, 20 passed, 4 skipped, 4 warnings in 226.37s (0:03:46) =============================== make: *** [Makefile:23: test-convergence] Error 1 ``` ## Testing Done - Hardware Type: Nvidia A100 - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --------- Signed-off-by: Austin Liu Co-authored-by: ByronHsu commit 3d0653b035222cbb845435a1994854e4fd219107 Author: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Thu Sep 12 00:45:39 2024 +0800 Add label smoothing to FLCE and unit tests (#244) ## Summary Fix #243 ## Testing Done - Hardware Type: RTX-3080 - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence commit 83a66d85b9c409ad6f9b17f751886c7936e40290 Author: tastelikefeet <58414341+tastelikefeet@users.noreply.github.com> Date: Wed Sep 11 00:42:09 2024 +0800 SWIFT Trainer Integration (#240) ## Summary ## Testing Done - Hardware Type: - [ ] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence commit acd82728207ebafad28d448640502c108901a967 Author: Hanson Wang Date: Mon Sep 9 15:30:09 2024 -0700 Optimize fused_linear_cross_entropy when weight does not require grads (#237) ## Summary Add some easy checks for `weight.requires_grad` to skip allocating + calculating weight gradients if they're not needed. The weight gradient matrix can be pretty large, so this can also be a significant memory savings. Also, a small micro-optimization: skip the `.item()` call on `total_n_non_ignore` (the subsequent calculations work fine with the tensor form) to defer CUDA synchronization (otherwise it will wait for all the `torch.zeros` initializations on the preceding lines to synchronize, which may take a non-trivial amount of time.) ## Testing Done The existing unit test already has a case where the weight does not have gradients enabled, and it still passes forwards/backwards: https://github.com/linkedin/Liger-Kernel/blob/main/test/transformers/test_fused_linear_cross_entropy.py#L165 And the preceding test verifies the 'normal' case where the weight gradients are needed. - Hardware Type: A100 80G - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence commit b5d8cbf90d338ea2eda4e2e1863dcf0722599197 Author: Tyler Romero Date: Sun Sep 8 14:14:45 2024 -0700 Monkeypatch for Qwen2-VL (#175) ## Summary Monkeypatch for the recently-published [Qwen2-VL](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct). 
HF `transformers` modeling code: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py Feature Request: https://github.com/linkedin/Liger-Kernel/issues/165 ## Details Qwen2-VL in `transformers` is available on `transformers` main but is yet to be published in a release. ## Testing Done - Hardware Type: 4090 - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence --------- Co-authored-by: Shao Tang commit 9250546513c8549d51f62284610d04077e9589f4 Author: S1ro <54212263+S1ro1@users.noreply.github.com> Date: Sun Sep 8 03:44:04 2024 +0200 Feat: add kl div to readme (#229) ## Summary Adds newly implemented kl divergence loss to readme. Closes #188 finally. ## Testing Done No code changes --------- Co-authored-by: Shao Tang Co-authored-by: Byron Hsu commit 1cdb7f0d63701065ffb92399ed12f4206f95566b Author: S1ro <54212263+S1ro1@users.noreply.github.com> Date: Sun Sep 8 03:19:19 2024 +0200 Refactor/benchmarking visualizer (#212) ## Summary Implements a new script, `benchmark/benchmarks_visualizer.py`, that substitues the functionality provided by current `benchmark/benchmarks_visualizer.ipynb`. Resolves #211 . ## Details ```console $ python3 benchmarks_visualizer.py --help usage: benchmarks_visualizer.py [-h] --kernel-name KERNEL_NAME --metric-name METRIC_NAME --kernel-operation-mode KERNEL_OPERATION_MODE [--display] [--overwrite] options: -h, --help show this help message and exit --kernel-name KERNEL_NAME Kernel name to benchmark --metric-name METRIC_NAME Metric name to visualize (speed/memory) --kernel-operation-mode KERNEL_OPERATION_MODE Kernel operation mode to visualize (forward/backward/full) --display Display the visualization --overwrite Overwrite existing visualization, if none exist this flag has no effect as one are always created ``` ## Testing Done - Hardware Type: - [ ] run `make test` to ensure correctness - [ ] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --------- Co-authored-by: Shao Tang commit 18fd280b9a5681d489eae5354e14001751e2464f Author: Wizyoung Date: Sun Sep 8 07:56:04 2024 +0800 (fix) fix pyproject.toml (#226) ## Summary In https://github.com/linkedin/Liger-Kernel/pull/218, I fixed the `tool.setuptools.packages.find` field and tested it only in editable mode with `pip install -e .`. However, in production mode with `pip install .`, only the env_report.py file is copied to the Python site-packages directory. To fix this, adding "liger_kernel.*" to the include list will ensure that setuptools correctly includes all subpackages within liger_kernel. ## Testing Done - Hardware Type: - [ ] run `make test` to ensure correctness - [ ] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --------- Co-authored-by: Byron Hsu commit 638b31057d283a0d841a1795f742068a63b7dcdd Author: Wizyoung Date: Sat Sep 7 11:53:15 2024 +0800 add repr infomation for layer_norm and rms_norm (#220) ## Summary Add repr information for layernorm and rmsnorm class so that the useful layer information can be displayed after the model is printed. Other classes are not modified because they inherit from related torch.nn classes, or there are torch.nn sub-modules. 
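For reference, the standard PyTorch hook for surfacing layer information in `print(model)` is `nn.Module.extra_repr`. The sketch below illustrates the pattern with made-up class and field names; it is not the library's code:

```python
import torch
import torch.nn as nn

class MyRMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.variance_epsilon)

    def extra_repr(self) -> str:
        # shown when the module (or a model containing it) is printed
        return f"{self.weight.shape[0]}, eps={self.variance_epsilon}"

print(MyRMSNorm(4096))  # MyRMSNorm(4096, eps=1e-06)
```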
## Testing Done - Hardware Type: - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence --------- Co-authored-by: Byron Hsu Co-authored-by: Shao Tang commit 07804e43a5e6e019a829c37c9cb022a4c2aa4bed Author: Ivan Yashchuk Date: Sat Sep 7 06:30:32 2024 +0300 Update swiglu and geglu forward: zeros_like -> empty_like (#217) ## Summary This PR improves the performance of swiglu and geglu forward by replacing `zeros_like` with `empty_like`. The difference is that `empty_like` doesn't require a separate kernel launch. ## Testing Done Testing is covered by existing `test_geglu.py` and `test_swiglu.py`. - Hardware Type: A100-80G-PCIe - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence --------- Co-authored-by: Byron Hsu Co-authored-by: Shao Tang commit 6a75ddcaf4757003c4424338af85a70b0805db81 Author: Byron Hsu Date: Fri Sep 6 20:16:05 2024 -0700 Update README.md commit 8cf49e2830e44fa4ef845ebf1f9e6d229dbf1aae Author: Byron Hsu Date: Fri Sep 6 20:13:51 2024 -0700 Update README.md commit 53dcf02cd2c1efd8d32a15101755388e401df091 Author: Wizyoung Date: Sat Sep 7 07:13:28 2024 +0800 (fix) fix pyproject.toml (#218) ## Summary Fix `tool.setuptools.packages.find` field in pyproject.toml. Otherwise in local build mode with `pip install .`, python system fails to locate liger_kernel. Co-authored-by: Byron Hsu commit b42a27bd7006e84b01994ae429c6ae47fa3d07b4 Author: Steven Shimizu Date: Fri Sep 6 14:16:41 2024 -0700 Added HF use-case benchmark script (#223) ## Summary - Added Hugging Face training benchmarking script used for tech report - Writes files to `/results/${MODEL_TYPE}_use_liger_${USE_LIGER}_batch_size_${BATCH_SIZE}_rep_${i}.log` ## Testing Done - Ran benchmarking script - Hardware Type: A100 - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence commit 43cbd4e6b250218b2008cf81504b5dc9763ac228 Author: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Sat Sep 7 05:07:01 2024 +0800 Add label smoothing for cross entropy (#198) ## Summary Aim to solve #81. ## Details ### For loss: Label smoothing regularization ( LSR ) by replacing the label distribution $q(k) = \delta_{k,y}$ with ```math q'(k) = (1 - \epsilon)\delta_{k,y} + \frac{\epsilon}{K} ``` Considering cross entropy with LSR is ```math \begin{align} L' = H(q', p) &= -\sum^K_{k=1}log\ {p(k)}q'(k) = -\sum^K_{k=1}log\ {p(k)}((1 - \epsilon)\delta_{k,y} + \frac{\epsilon}{K})\\ &= -\sum^K_{k=1}log\ {p(k)}(1 - \epsilon)q(k) -\sum^K_{k=1}log\ {p(k)}\frac{\epsilon}{K} \\ &= (1 - \epsilon)H(q,p) + \frac{\epsilon}{K} \sum^K_{k=1} log\ softmax(x_k)\\ &= (1- \epsilon)L + \frac{\epsilon}{K}\ SmoothLoss, \end{align} ``` where $L = H(q,p)$ is the original loss and $\sum^K_{k=1} log\ softmax(x_k)$ is smooth loss. 
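Before moving to the gradients, a quick numerical check of the smoothed loss above against PyTorch's built-in `label_smoothing` argument (available in torch >= 1.10). This is a verification sketch only; the shapes and values are arbitrary:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, K, eps = 4, 10, 0.1
logits = torch.randn(B, K)
target = torch.randint(0, K, (B,))

# explicit form: L' = -sum_k q'(k) * log p(k), with q'(k) = (1 - eps) * delta_{k,y} + eps / K
log_p = F.log_softmax(logits, dim=-1)
q_prime = torch.full((B, K), eps / K)
q_prime[torch.arange(B), target] += 1.0 - eps
explicit = -(q_prime * log_p).sum(dim=-1).mean()

builtin = F.cross_entropy(logits, target, label_smoothing=eps)
assert torch.allclose(explicit, builtin, atol=1e-6)
```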
### For gradients: The original: ```math \begin{align} \frac{\partial L}{\partial x_i} &= p(k) - q(k)\\ &= \begin{cases} softmax(x_i) , & i \neq y \\ softmax(x_i) - 1, & i = y \end{cases} \end{align} ``` With LSR: ```math \begin{align} \frac{\partial L'}{\partial x_i} &= p(k) - q'(k)\\ &= softmax(x_i) - (1 - \epsilon)\delta_{k,y} + \frac{\epsilon}{K}\\ &= \begin{cases} softmax(x_i) - \frac{\epsilon}{K}, & i \neq y \\ softmax(x_i) - \frac{\epsilon}{K} - (1 - \epsilon) & i = y \end{cases} \end{align} ``` We can handle the $i = y$ case by simply adding $-(1-\epsilon)$ after computing all $i$. Reference: [Rethinking the Inception Architecture for Computer Vision](https://arxiv.org/abs/1512.00567) ## Testing Done Add a unit test for label smoothing. - Hardware Type: RTX-3080 - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence ```bash ❯ python3 -m pytest test/transformers/test_cross_entropy.py ============================================ test session starts ============================================= platform linux -- Python 3.10.12, pytest-8.3.2, pluggy-1.5.0 rootdir: /home/tcc/Liger-Kernel collected 94 items test/transformers/test_cross_entropy.py .............................................................. [ 65%] ...............................F [100%] ================================================== FAILURES ================================================== __________________________________ test_large_no_exception[8-16384-128256] ___________________________________ B = 8, T = 16384, V = 128256 @pytest.mark.parametrize( "B, T, V", [ ( 8, 8192, 128256, ), # _input = 16GB, total = ~32GB, 8405385216 > 2,147,483,647, so we need int64 (8, 16384, 128256), # _input = 32GB, total = ~64GB ], ) # @pytest.mark.skipif( # torch.cuda.get_device_properties(0).total_memory < 64 * 1000 * 1000 * 1000, # reason="Needs 64GB+ GPU memory.", # ) def test_large_no_exception(B, T, V): # The large inputs were hitting cuda illegal memory access because of # https://github.com/triton-lang/triton/issues/1058 > _full_pass_once(B, T, V) test/transformers/test_cross_entropy.py:401: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ B = 8, T = 16384, V = 128256 def _full_pass_once(B, T, V): torch.manual_seed(0) liger_ce = LigerCrossEntropyLoss() > _input = torch.randn( B * T, V, requires_grad=True, device="cuda", dtype=torch.bfloat16 ) E torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 31.31 GiB. GPU 0 has a total capacity of 10.00 GiB of which 8.84 GiB is free. Including non-PyTorch memory, this process has 17179869184.00 GiB memory in use. Of the allocated memory 0 bytes is allocated by PyTorch, and 0 bytes is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables) test/transformers/test_cross_entropy.py:374: OutOfMemoryError ========================================== short test summary info =========================================== FAILED test/transformers/test_cross_entropy.py::test_large_no_exception[8-16384-128256] - torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 31.31 GiB. GPU 0 has a total capacity of 10... 
================================== 1 failed, 93 passed in 130.88s (0:02:10) ================================== ``` ```bash ❯ make test python -m pytest --disable-warnings test/ --ignore=test/convergence ============================================ test session starts ============================================= platform linux -- Python 3.10.12, pytest-8.3.2, pluggy-1.5.0 rootdir: /home/tcc/Liger-Kernel collected 256 items test/transformers/test_auto_model.py . [ 0%] test/transformers/test_cross_entropy.py ssssssssssssssssssssssss............ssssssssssssssssssssssssss [ 24%] ssssssssssssssssssssssssssssssss [ 37%] test/transformers/test_embedding.py ........... [ 41%] test/transformers/test_fused_linear_cross_entropy.py ................ [ 47%] test/transformers/test_geglu.py ............ [ 52%] test/transformers/test_layer_norm.py ................ [ 58%] test/transformers/test_monkey_patch.py ..... [ 60%] test/transformers/test_rms_norm.py ............................................................ [ 83%] test/transformers/test_rope.py .................. [ 91%] test/transformers/test_swiglu.py .................... [ 98%] test/transformers/test_trainer_integration.py . [ 99%] test/triton/test_triton_monkey_patch.py .. [100%] ================================ 174 passed, 82 skipped in 123.06s (0:02:03) ================================= ``` ```bash ❯ make checkstyle flake8 .; flake8_status=$?; \ isort .; isort_status=$?; \ black .; black_status=$?; \ if [ $flake8_status -ne 0 ] || [ $isort_status -ne 0 ] || [ $black_status -ne 0 ]; then \ exit 1; \ fi Skipped 2 files All done! ✨ 🍰 ✨ 68 files left unchanged. ``` ```bash ❯ make test-convergence HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence ============================================ test session starts ============================================= platform linux -- Python 3.10.12, pytest-8.3.2, pluggy-1.5.0 rootdir: /home/tcc/Liger-Kernel collected 30 items test/convergence/test_mini_models.py .............. [ 46%] test/convergence/test_mini_models_no_logits.py ................ 
[100%] ======================================= 30 passed in 223.18s (0:03:43) ======================================= ``` commit 376fe0c2af65ff4d716dc36eb6fe5231662920a7 Author: Yanning Chen Date: Fri Sep 6 13:10:02 2024 -0700 Reference Unsloth in header (#216) ## Summary Reference Unsloth in header section ## Testing Done - Hardware Type: - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence commit c844f787e9828e69cf18016f018bf793d1823ea3 Author: Byron Hsu Date: Fri Sep 6 13:08:22 2024 -0700 Update README.md commit ec68ac0a0725d37d30d22596f1fedf7e67382367 Author: Byron Hsu Date: Fri Sep 6 13:07:18 2024 -0700 Add license in ack section (#224) ## Summary ## Testing Done - Hardware Type: - [ ] run `make test` to ensure correctness - [ ] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence commit ec6320096a823b3107c70a81babca1dff6589191 Author: Byron Hsu Date: Fri Sep 6 12:58:33 2024 -0700 Elaborate ack section (#222) ## Summary ## Testing Done - Hardware Type: - [ ] run `make test` to ensure correctness - [ ] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --- .github/workflows/ci.yml | 43 +- .github/workflows/gpu-ci.yml | 20 - CONTRIBUTING.md | 2 +- NOTICE | 54 ++ README.md | 91 ++- benchmark/data/all_benchmark_data.csv | 60 ++ .../scripts/benchmark_fused_linear_jsd.py | 272 +++++++ benchmark/scripts/benchmark_jsd.py | 154 ++++ dev/modal/tests.py | 23 + examples/medusa/medusa_util.py | 6 +- licenses/LICENSE-Apache-2.0 | 201 ++++++ licenses/LICENSE-MIT-AutoAWQ | 21 + licenses/LICENSE-MIT-Efficient-Cross-Entropy | 21 + licenses/LICENSE-MIT-llmc | 21 + licenses/LICENSE-MIT-triton | 23 + pyproject.toml | 20 +- src/liger_kernel/ops/cross_entropy.py | 162 +++-- .../ops/experimental/mm_int8int2.py | 355 ++++++++++ .../ops/fused_linear_cross_entropy.py | 37 +- src/liger_kernel/ops/fused_linear_jsd.py | 245 +++++++ src/liger_kernel/ops/jsd.py | 176 +++++ src/liger_kernel/ops/kl_div.py | 79 ++- src/liger_kernel/ops/rms_norm.py | 109 +-- src/liger_kernel/ops/utils.py | 63 +- src/liger_kernel/transformers/__init__.py | 3 + .../transformers/cross_entropy.py | 34 +- src/liger_kernel/transformers/functional.py | 4 + .../fused_linear_cross_entropy.py | 22 +- .../transformers/fused_linear_jsd.py | 98 +++ src/liger_kernel/transformers/jsd.py | 75 ++ src/liger_kernel/transformers/kl_div.py | 5 +- src/liger_kernel/transformers/model/gemma.py | 125 +++- src/liger_kernel/transformers/model/llama.py | 139 +++- .../transformers/model/mistral.py | 3 + .../transformers/model/mixtral.py | 155 +++- src/liger_kernel/transformers/model/mllama.py | 274 ++++++++ src/liger_kernel/transformers/model/phi3.py | 142 +++- src/liger_kernel/transformers/model/qwen2.py | 125 +++- .../transformers/model/qwen2_vl.py | 9 +- src/liger_kernel/transformers/monkey_patch.py | 484 ++++++++----- test/conftest.py | 8 + test/convergence/test_mini_models.py | 664 +++++++++++------- .../test_mini_models_multimodal.py | 200 +++++- .../convergence/test_mini_models_no_logits.py | 621 ---------------- .../tokenizer_config.json | 55 ++ .../tokenizer_config.json | 31 + test/transformers/test_cross_entropy.py | 457 ++++++------ test/transformers/test_embedding.py | 1 + .../test_fused_linear_cross_entropy.py | 126 +++- test/transformers/test_fused_linear_jsd.py | 474 +++++++++++++ test/transformers/test_geglu.py | 2 - test/transformers/test_jsd.py | 329 +++++++++ 
test/transformers/test_kl_div.py | 26 +- test/transformers/test_layer_norm.py | 40 +- test/transformers/test_mm_int8int2.py | 106 +++ test/transformers/test_monkey_patch.py | 579 +++++++++++++-- test/transformers/test_rms_norm.py | 26 +- test/transformers/test_swiglu.py | 12 +- test/utils.py | 82 ++- 59 files changed, 6116 insertions(+), 1678 deletions(-) delete mode 100644 .github/workflows/gpu-ci.yml create mode 100644 benchmark/scripts/benchmark_fused_linear_jsd.py create mode 100644 benchmark/scripts/benchmark_jsd.py create mode 100644 dev/modal/tests.py create mode 100644 licenses/LICENSE-Apache-2.0 create mode 100644 licenses/LICENSE-MIT-AutoAWQ create mode 100644 licenses/LICENSE-MIT-Efficient-Cross-Entropy create mode 100644 licenses/LICENSE-MIT-llmc create mode 100644 licenses/LICENSE-MIT-triton create mode 100644 src/liger_kernel/ops/experimental/mm_int8int2.py create mode 100644 src/liger_kernel/ops/fused_linear_jsd.py create mode 100644 src/liger_kernel/ops/jsd.py create mode 100644 src/liger_kernel/transformers/fused_linear_jsd.py create mode 100644 src/liger_kernel/transformers/jsd.py create mode 100644 src/liger_kernel/transformers/model/mllama.py create mode 100644 test/conftest.py delete mode 100644 test/convergence/test_mini_models_no_logits.py create mode 100644 test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json create mode 100644 test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json create mode 100644 test/transformers/test_fused_linear_jsd.py create mode 100644 test/transformers/test_jsd.py create mode 100644 test/transformers/test_mm_int8int2.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f41afdb6d..7e087b8cd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,12 +1,24 @@ -name: CI Pipeline +name: GitHub Actions CI on: push: branches: - main - pull_request: + paths: + - "src/**" + - "test/**" + # "pull_request_target" allows PR from forks to access github secrets: https://stackoverflow.com/questions/74957218/what-is-the-difference-between-pull-request-and-pull-request-target-event-in-git + pull_request_target: branches: - main + paths: + - "src/**" + - "test/**" + +concurrency: + # This causes it to cancel previous in-progress actions on the same PR / branch, + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true jobs: checkstyle: @@ -27,4 +39,29 @@ jobs: pip install flake8 isort black - name: Run checkstyle - run: make checkstyle \ No newline at end of file + run: make checkstyle + + tests: + runs-on: ubuntu-latest + needs: [checkstyle] + env: + MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }} + MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }} + + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v3 + with: + python-version: '3.10' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install modal + + - name: Run unit tests + run: | + modal run dev.modal.tests diff --git a/.github/workflows/gpu-ci.yml b/.github/workflows/gpu-ci.yml deleted file mode 100644 index 9ea8a5208..000000000 --- a/.github/workflows/gpu-ci.yml +++ /dev/null @@ -1,20 +0,0 @@ -name: GPU CI Pipeline - -on: - push: - branches: - - main - pull_request: - branches: - - main - -jobs: - gpu-ci-tests: - runs-on: ubuntu-latest - - steps: - - name: Run on GPU host - run: | - echo "Source ${{ github.head_ref }} base ref ${{ github.base_ref}} 
ref ${{ github.ref }}"; - curl -s -f -N -y 600 -Y 1 -H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \ - "https://gitpub.org/liger-kernel?pr=${{ github.ref }}&git_hash=${{ github.sha }}" diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index e8d02b709..af1ef1770 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -87,5 +87,5 @@ Fork the repo, copy and paste the successful test logs in the PR and submit the ### Notes on version: Here we follow the [sematic versioning](https://semver.org/). Denote the version as `major.minor.patch`, we increment: - Major version when there is backward incompatible change -- Minor version when there is new backward-compatible functionaility +- Minor version when there is new backward-compatible functionality - Patch version for bug fixes diff --git a/NOTICE b/NOTICE index 802e11302..ea2881754 100644 --- a/NOTICE +++ b/NOTICE @@ -2,3 +2,57 @@ Copyright 2024 LinkedIn Corporation All Rights Reserved. Licensed under the BSD 2-Clause License (the "License"). See License in the project root for license information. + +This product includes software developed by LinkedIn Corporation. + +This product contains code derived from the following open source projects: + +1. Unsloth + Copyright (c) 2023 Unsloth AI + Licensed under the Apache License, Version 2.0 + Source: https://github.com/unslothai/unsloth + + The `calculate_settings` function to determine block size and warp is reused for Norm and MLP operations. + Modifications and additions were made to the RMS Norm implementation. + +2. Triton + Copyright (c) 2023 OpenAI + Licensed under the MIT License + Source: https://github.com/openai/triton + + Modifications were made based on Triton tutorials for the RMS Norm implementation. + +3. Efficient Cross Entropy + Copyright (c) 2023 Mohamed Malek + Licensed under the MIT License + Source: https://github.com/mgmalek/efficient_cross_entropy + + The idea of gradient-in-forward and chunking was used in the Linear Cross Entropy implementation. + +4. Flash Attention + Copyright (c) 2023 Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré + Licensed under the BSD 3-Clause License + Source: https://github.com/Dao-AILab/flash-attention + + Optimization ideas such as tiling and recomputation were inspired by this work. + +5. AutoAWQ + Copyright (c) 2023 Casper Hansen + Licensed under the MIT License + Source: https://github.com/casper-hansen/AutoAWQ + + The design of the automodel was referenced from this project. + +6. llm.c + Copyright (c) 2023 Andrej Karpathy + Licensed under the MIT License + Source: https://github.com/karpathy/llm.c + + The design of end-to-end testing was referenced from this project. + +7. Tiny Shakespeare Dataset + Source: https://huggingface.co/datasets/karpathy/tiny_shakespeare + + This dataset is used to conduct convergence tests on mini models. + +For full license texts, please refer to the respective project repositories. 
diff --git a/README.md b/README.md index f0240c256..c4a26996d 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,5 @@ + + # Liger Kernel: Efficient Triton Kernels for LLM Training @@ -6,6 +8,7 @@ Stable Nightly Discord + Gurubase (experimental) @@ -33,6 +36,11 @@ Join Our Discord + + + Ask Liger Kernel Guru + + @@ -40,11 +48,12 @@ -[Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [APIs](#apis) | [Structure](#structure) | [Contributing](#contributing) | [Acknowledgement](#acknowledgement) +[Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [APIs](#apis) | [Cite our work](#cite-this-work)
Latest News 🔥 - + + - [2024/10/21] We have released the tech report of Liger Kernel on Arxiv: https://arxiv.org/pdf/2410.10989 - [2024/9/6] We release v0.2.1 ([X post](https://x.com/liger_kernel/status/1832168197002510649)). 2500+ Stars, 10+ New Contributors, 50+ PRs, 50k Downloads in two weeks! - [2024/8/31] CUDA MODE talk, [Liger-Kernel: Real-world Triton kernel for LLM Training](https://youtu.be/gWble4FreV4?si=dxPeIchhkJ36Mbns), [Slides](https://github.com/cuda-mode/lectures?tab=readme-ov-file#lecture-28-liger-kernel) - [2024/8/23] Official release: check out our [X post](https://x.com/hsu_byron/status/1827072737673982056) @@ -102,11 +111,21 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and ## Installation -### Dependencies +### Dependencies + +#### CUDA - `torch >= 2.1.2` - `triton >= 2.3.0` -- `transformers >= 4.42.0` + +#### ROCm + +- `torch >= 2.5.0` Install according to the instruction in Pytorch official webpage. +- `triton >= 3.0.0` Install from pypi. (e.g. `pip install triton==3.0.0`) + +### Optional Dependencies + +- `transformers >= 4.x`: Required if you plan to use the transformers models patching APIs. The specific model you are working will dictate the minimum version of transformers. > **Note:** > Our kernels inherit the full spectrum of hardware compatibility offered by [Triton](https://github.com/triton-lang/triton). @@ -129,7 +148,11 @@ To install from source: git clone https://github.com/linkedin/Liger-Kernel.git cd Liger-Kernel pip install -e . +# or if using transformers +pip install -e .[transformers] ``` + + ## Getting Started There are a couple of ways to apply Liger kernels, depending on the level of customization required. @@ -222,6 +245,7 @@ loss.backward() | **Model** | **API** | **Supported Operations** | |-------------|--------------------------------------------------------------|-------------------------------------------------------------------------| | LLaMA 2 & 3 | `liger_kernel.transformers.apply_liger_kernel_to_llama` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | +| LLaMA 3.2-Vision | `liger_kernel.transformers.apply_liger_kernel_to_mllama` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | | Mistral | `liger_kernel.transformers.apply_liger_kernel_to_mistral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | | Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | | Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy | @@ -244,6 +268,8 @@ loss.backward() | CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` | | FusedLinearCrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`| | KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` | +| JSD | `liger_kernel.transformers.LigerJSD` | +| FusedLinearJSD | `liger_kernel.transformers.LigerFusedLinearJSD` | - **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction. 
- **LayerNorm**: [LayerNorm](https://arxiv.org/pdf/1607.06450), which centers and normalizes activations across the feature dimension, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and achieves ~2X speedup. @@ -258,35 +284,23 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$ - **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128k vocab size. **This is highly effective for large batch size, large sequence length, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage. - **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward into a single triton kernel, with reduction done outside the kernel. It achieves ~1.5X speed and ~15% memory reduction for 128K vocab size. +- **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence), is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speed and ~54% memory reduction for 128k vocab size. +- **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128k vocab size where batch size $\times$ sequence length is 8192. + ### Experimental Kernels | **Kernel** | **API** | |---------------------------------|-------------------------------------------------------------| | Embedding | `liger_kernel.transformers.experimental.LigerEmbedding` | - +| Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul` - **Embedding**: [Embedding](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html) is implemented by fusing embedding lookup and output operations. It achieves a peak speedup of ~1.5x in the forward pass and an overall speedup of ~1.1x. - +- **Matmul int2xint8**: is implemented by using the cache tiled matrix multiplication and by fusing the matmul with the unpacking process which achieves a considerable speed up and performs on par with @torch.compile > **Note:** > Reported speedups and memory reductions are with respect to the LLaMA 3-8B Hugging Face layer implementations. All models use 4K hidden size and 4K sequence length and are evaluated based on memory usage and wall time for the forward+backward pass on a single NVIDIA A100 80G GPU using small batch sizes. Liger kernels exhibit more efficient scaling to larger batch sizes, detailed further in the [Benchmark](./benchmark) folder. -## Note on ML Compiler - -### Torch Compile - -Since Liger Kernel is 100% Triton-based, it works seamlessly with [`torch.compile`](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html). In the following example, Liger Kernel can further optimize the model on top of Torch Compile, reducing the memory by more than half. 
- -| Configuration | Throughput (tokens/sec) | Memory Reserved (GB) | -|--------------------------------|----------------------------|-------------------------| -| Torch Compile | 3780 | 66.4 | -| Torch Compile + Liger Kernel | 3702 | 31.0 | - -> **Note:** -> 1. Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Seq Len = 4096, Data Type = `bf16`, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s. -> 2. Tested on torch `2.5.0.dev20240731+cu118` - ## Contributing [CONTRIBUTING GUIDE](https://github.com/linkedin/Liger-Kernel/blob/main/CONTRIBUTING.md) @@ -320,7 +334,14 @@ Many thanks to the contributors to these projects for their invaluable work that ## License -[BSD 2-CLAUSE](https://github.com/linkedin/Liger-Kernel/blob/main/LICENSE) +This project is licensed under the [BSD 2-CLAUSE](https://github.com/linkedin/Liger-Kernel/blob/main/LICENSE) License (see `LICENSE` for details). +It also includes components from projects licensed under: + +- Apache License 2.0 (see `LICENSE-APACHE-2.0` for details). +- MIT License (see `LICENSE-MIT-AutoAWQ` for details). +- MIT License (see `LICENSE-MIT-Efficient Cross Entropy` for details). +- MIT License (see `LICENSE-MIT-llmc` for details). +- MIT License (see `LICENSE-MIT-triton` for details). ## Contact @@ -331,13 +352,29 @@ Many thanks to the contributors to these projects for their invaluable work that Biblatex entry: ```bib -@software{liger2024, - title = {Liger-Kernel: Efficient Triton Kernels for LLM Training}, - author = {Hsu, Pin-Lun and Dai, Yun and Kothapalli, Vignesh and Song, Qingquan and Tang, Shao and Zhu, Siyu}, - url = {https://github.com/linkedin/Liger-Kernel}, - year = {2024} +@article{hsu2024ligerkernelefficienttriton, + title={Liger Kernel: Efficient Triton Kernels for LLM Training}, + author={Pin-Lun Hsu and Yun Dai and Vignesh Kothapalli and Qingquan Song and Shao Tang and Siyu Zhu and Steven Shimizu and Shivam Sahni and Haowen Ning and Yanning Chen}, + year={2024}, + eprint={2410.10989}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/2410.10989}, + journal={arXiv preprint arXiv:2410.10989}, } ``` ## Star History [![Star History Chart](https://api.star-history.com/svg?repos=linkedin/Liger-Kernel&type=Date)](https://star-history.com/#linkedin/Liger-Kernel&Date) + +## Contributors + + + contributors + + +
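Beyond the README updates above, this patch adds an auxiliary z-loss term to the cross-entropy kernel (see the `src/liger_kernel/ops/cross_entropy.py` hunk further down). The following is only a minimal sketch of how that new option could be exercised, assuming the function is importable as `liger_kernel.ops.cross_entropy.LigerCrossEntropyFunction`; the tensor sizes, `cuda` device, and the `1e-4` scale are illustrative, and the positional argument order is taken from the updated `forward()` signature in the diff.

```python
import torch

# Assumed import path, inferred from src/liger_kernel/ops/cross_entropy.py in this patch.
from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction

B, T, V = 4, 2048, 32000  # illustrative batch size, sequence length, vocab size
logits = torch.randn(B * T, V, device="cuda", requires_grad=True)
target = torch.randint(0, V, (B * T,), device="cuda")

# Positional arguments mirror the updated forward() signature:
# (_input, target, ignore_index, lse_square_scale, label_smoothing, reduction, return_z_loss)
loss, z_loss = LigerCrossEntropyFunction.apply(
    logits,
    target,
    -100,    # ignore_index
    1e-4,    # lse_square_scale: weight of the auxiliary lse^2 (z-loss) term
    0.0,     # label_smoothing
    "mean",  # reduction
    True,    # return_z_loss: also return the z-loss component for logging
)

loss.backward()  # gradients already include the z-loss contribution
print(f"ce+z loss: {loss.item():.4f}, z loss: {z_loss.item():.4f}")
```

When `return_z_loss=False`, the second return value is `None`, matching the tuple return introduced in the autograd function.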

+ + ↑ Back to Top ↑ + +
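The patch also introduces the JSD and FusedLinearJSD kernels listed in the API table above, along with the benchmark scripts shown below (`benchmark/scripts/benchmark_jsd.py` and `benchmark_fused_linear_jsd.py`). The sketch below is modeled on those scripts and is only an assumption of typical usage; the sizes, `cuda` device, and hyperparameters are illustrative rather than part of the patch.

```python
import torch

# Import paths mirror the new benchmark scripts in this patch.
from liger_kernel.transformers.jsd import LigerJSD
from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD

B, T, H, V = 4, 2048, 4096, 32000  # illustrative sizes

# LigerJSD consumes log-probabilities: student (input) first, teacher (target) second.
jsd = LigerJSD()
log_q = torch.randn(B * T, V, device="cuda", requires_grad=True).log_softmax(dim=-1)
log_p = torch.randn(B * T, V, device="cuda").log_softmax(dim=-1)
jsd_loss = jsd(log_q, log_p)
jsd_loss.backward()

# LigerFusedLinearJSD fuses the student/teacher LM heads with the JSD,
# so it takes hidden states plus the head weights instead of logits.
fused_jsd = LigerFusedLinearJSD(jsd_beta=0.5, ignore_index=-100, temperature=1.0)
student_head = torch.nn.Linear(H, V, bias=False, device="cuda")
teacher_head = torch.nn.Linear(H, V, bias=False, device="cuda")
student_hidden = torch.randn(B * T, H, device="cuda", requires_grad=True)
teacher_hidden = torch.randn(B * T, H, device="cuda")
fused_loss = fused_jsd(
    student_hidden, student_head.weight, teacher_hidden, teacher_head.weight
)
fused_loss.backward()
```

An optional `label` tensor can be passed as a final argument to mask out `ignore_index` positions, as done in the benchmark scripts.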

diff --git a/benchmark/data/all_benchmark_data.csv b/benchmark/data/all_benchmark_data.csv index dcb5e30f0..32c8d01ab 100644 --- a/benchmark/data/all_benchmark_data.csv +++ b/benchmark/data/all_benchmark_data.csv @@ -445,3 +445,63 @@ kl_div,torch,full,speed,ms,V,vocab size,16384,11.124671936035156,11.122162818908 kl_div,torch,full,speed,ms,V,vocab size,32768,23.052032470703125,23.050334930419922,23.052589416503906,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:48,0.2.1 kl_div,torch,full,speed,ms,V,vocab size,65536,46.063167572021484,46.05990219116211,46.06643295288086,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:48,0.2.1 kl_div,torch,full,speed,ms,V,vocab size,131072,92.06393432617188,92.06393432617188,92.06393432617188,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:48,0.2.1 +jsd,liger,full,memory,MB,V,vocab size,4096,768.0029296875,768.0029296875,768.0029296875,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1 +jsd,liger,full,memory,MB,V,vocab size,8192,1536.0029296875,1536.0029296875,1536.0029296875,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1 +jsd,liger,full,memory,MB,V,vocab size,16384,3072.0048828125,3072.0048828125,3072.0048828125,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1 +jsd,liger,full,memory,MB,V,vocab size,32768,6144.0087890625,6144.0087890625,6144.0087890625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1 +jsd,liger,full,memory,MB,V,vocab size,65536,12288.0166015625,12288.0166015625,12288.0166015625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1 +jsd,liger,full,memory,MB,V,vocab size,131072,24576.015625,24576.015625,24576.015625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1 +jsd,torch,full,memory,MB,V,vocab size,4096,1664.0009765625,1664.0009765625,1664.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1 +jsd,torch,full,memory,MB,V,vocab size,8192,3328.0009765625,3328.0009765625,3328.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1 +jsd,torch,full,memory,MB,V,vocab size,16384,6656.0009765625,6656.0009765625,6656.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1 +jsd,torch,full,memory,MB,V,vocab size,32768,13312.0009765625,13312.0009765625,13312.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1 +jsd,torch,full,memory,MB,V,vocab size,65536,26624.0,26624.0,26624.0,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1 +jsd,torch,full,memory,MB,V,vocab size,131072,53248.0,53248.0,53248.0,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1 +jsd,liger,forward,speed,ms,V,vocab size,4096,0.4651840031147003,0.4636736214160919,0.4659839868545532,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1 +jsd,liger,forward,speed,ms,V,vocab size,8192,0.927888035774231,0.926751971244812,0.92952960729599,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1 +jsd,liger,forward,speed,ms,V,vocab size,16384,10.96003246307373,10.942886352539062,10.970770835876465,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1 +jsd,liger,forward,speed,ms,V,vocab size,32768,22.405792236328125,22.390380859375,22.41998863220215,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1 +jsd,liger,forward,speed,ms,V,vocab size,65536,43.49095916748047,43.47438049316406,43.50754165649414,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 
16:21:37,0.3.1 +jsd,liger,forward,speed,ms,V,vocab size,131072,87.0363540649414,87.0363540649414,87.0363540649414,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1 +jsd,torch,forward,speed,ms,V,vocab size,4096,2.4744958877563477,2.4725184440612793,2.4764864444732666,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1 +jsd,torch,forward,speed,ms,V,vocab size,8192,4.8528642654418945,4.851238250732422,4.854745864868164,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1 +jsd,torch,forward,speed,ms,V,vocab size,16384,9.532496452331543,9.528634071350098,9.535890579223633,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1 +jsd,torch,forward,speed,ms,V,vocab size,32768,18.91379165649414,18.911853790283203,18.919116973876953,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1 +jsd,torch,forward,speed,ms,V,vocab size,65536,37.70152282714844,37.70074462890625,37.70229721069336,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1 +jsd,torch,forward,speed,ms,V,vocab size,131072,75.37680053710938,75.37680053710938,75.37680053710938,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1 +jsd,liger,full,speed,ms,V,vocab size,4096,1.2074079513549805,1.1739968061447144,1.2760319709777832,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1 +jsd,liger,full,speed,ms,V,vocab size,8192,2.091792106628418,2.0771327018737793,2.106553554534912,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1 +jsd,liger,full,speed,ms,V,vocab size,16384,12.928031921386719,12.8988676071167,12.936230659484863,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1 +jsd,liger,full,speed,ms,V,vocab size,32768,26.55548858642578,26.550823211669922,26.570655822753906,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1 +jsd,liger,full,speed,ms,V,vocab size,65536,51.6833610534668,51.6833610534668,51.6833610534668,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1 +jsd,liger,full,speed,ms,V,vocab size,131072,103.12793731689453,103.12793731689453,103.12793731689453,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1 +jsd,torch,full,speed,ms,V,vocab size,4096,5.397359848022461,5.392876625061035,5.39998722076416,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1 +jsd,torch,full,speed,ms,V,vocab size,8192,10.60153579711914,10.597900390625,10.60470962524414,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1 +jsd,torch,full,speed,ms,V,vocab size,16384,20.9442081451416,20.94247055053711,20.9469051361084,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1 +jsd,torch,full,speed,ms,V,vocab size,32768,42.113216400146484,42.113216400146484,42.113216400146484,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1 +jsd,torch,full,speed,ms,V,vocab size,65536,83.9959716796875,83.9959716796875,83.9959716796875,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1 +jsd,torch,full,speed,ms,V,vocab size,131072,167.94175720214844,167.94175720214844,167.94175720214844,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1 +fused_linear_jsd,liger,forward,speed,ms,BT,B x T,1024,110.02185821533203,110.02185821533203,110.02185821533203,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:18,0.3.1 +fused_linear_jsd,liger,forward,speed,ms,BT,B x 
T,2048,124.14070129394531,124.14070129394531,124.14070129394531,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:18,0.3.1 +fused_linear_jsd,liger,forward,speed,ms,BT,B x T,4096,143.15420532226562,143.15420532226562,143.15420532226562,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:18,0.3.1 +fused_linear_jsd,liger,forward,speed,ms,BT,B x T,8192,180.90406799316406,180.90406799316406,180.90406799316406,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:18,0.3.1 +fused_linear_jsd,torch,forward,speed,ms,BT,B x T,1024,9.556896209716797,9.550745964050293,9.576268196105957,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:20,0.3.1 +fused_linear_jsd,torch,forward,speed,ms,BT,B x T,2048,18.73731231689453,18.732704162597656,18.737701416015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:20,0.3.1 +fused_linear_jsd,torch,forward,speed,ms,BT,B x T,4096,37.830482482910156,37.80821990966797,37.85274124145508,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:20,0.3.1 +fused_linear_jsd,torch,forward,speed,ms,BT,B x T,8192,75.15289306640625,75.15289306640625,75.15289306640625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:20,0.3.1 +fused_linear_jsd,liger,full,speed,ms,BT,B x T,1024,111.16019439697266,111.16019439697266,111.16019439697266,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:24,0.3.1 +fused_linear_jsd,liger,full,speed,ms,BT,B x T,2048,125.6825942993164,125.6825942993164,125.6825942993164,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:24,0.3.1 +fused_linear_jsd,liger,full,speed,ms,BT,B x T,4096,144.00784301757812,144.00784301757812,144.00784301757812,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:24,0.3.1 +fused_linear_jsd,liger,full,speed,ms,BT,B x T,8192,182.5832977294922,182.5832977294922,182.5832977294922,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:24,0.3.1 +fused_linear_jsd,torch,full,speed,ms,BT,B x T,1024,25.977184295654297,25.968351364135742,25.989356994628906,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:27,0.3.1 +fused_linear_jsd,torch,full,speed,ms,BT,B x T,2048,49.48417663574219,49.47330093383789,49.495052337646484,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:27,0.3.1 +fused_linear_jsd,torch,full,speed,ms,BT,B x T,4096,98.31510162353516,98.31510162353516,98.31510162353516,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:27,0.3.1 +fused_linear_jsd,torch,full,speed,ms,BT,B x T,8192,195.29539489746094,195.29539489746094,195.29539489746094,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB 
HBM3,2024-10-09 12:29:27,0.3.1 +fused_linear_jsd,liger,full,memory,MB,BT,B x T,1024,4652.48486328125,4652.48486328125,4652.48486328125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:33,0.3.1 +fused_linear_jsd,liger,full,memory,MB,BT,B x T,2048,5231.93798828125,5231.93798828125,5231.93798828125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:33,0.3.1 +fused_linear_jsd,liger,full,memory,MB,BT,B x T,4096,6391.87548828125,6391.87548828125,6391.87548828125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:33,0.3.1 +fused_linear_jsd,liger,full,memory,MB,BT,B x T,8192,8711.75,8711.75,8711.75,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:33,0.3.1 +fused_linear_jsd,torch,full,memory,MB,BT,B x T,1024,10609.005859375,10609.005859375,10609.005859375,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1 +fused_linear_jsd,torch,full,memory,MB,BT,B x T,2048,17146.009765625,17146.009765625,17146.009765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1 +fused_linear_jsd,torch,full,memory,MB,BT,B x T,4096,30220.017578125,30220.017578125,30220.017578125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1 +fused_linear_jsd,torch,full,memory,MB,BT,B x T,8192,56368.015625,56368.015625,56368.015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1 diff --git a/benchmark/scripts/benchmark_fused_linear_jsd.py b/benchmark/scripts/benchmark_fused_linear_jsd.py new file mode 100644 index 000000000..7f652de8a --- /dev/null +++ b/benchmark/scripts/benchmark_fused_linear_jsd.py @@ -0,0 +1,272 @@ +import torch +import triton +from utils import ( + QUANTILES, + SingleBenchmarkRunInput, + SingleBenchmarkRunOutput, + _test_memory, + parse_benchmark_script_args, + run_benchmarks, +) + +from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD + + +class TorchJSD(torch.nn.Module): + def __init__( + self, + beta: float = 0.5, + ignore_index: int = -100, + dtype: torch.dtype = torch.float, + ): + super(TorchJSD, self).__init__() + self.kl = torch.nn.KLDivLoss(reduction="none", log_target=True) + self.beta = beta + self.ignore_index = ignore_index + self.dtype = dtype + + def forward( + self, + log_q: torch.Tensor, # input + log_p: torch.Tensor, # target + label=None, + ): + log_p, log_q = log_p.to(torch.float), log_q.to(torch.float) + log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1)) + m = torch.lerp(torch.exp(log_q), torch.exp(log_p), self.beta) + loss = self.beta * self.kl(torch.log(m), log_p).sum(dim=-1) + ( + 1 - self.beta + ) * self.kl(torch.log(m), log_q).sum(dim=-1) + + if label is not None: + loss = torch.where(label != self.ignore_index, loss, 0.0) + n_non_ignore = (label != self.ignore_index).sum().item() + if n_non_ignore == 0: + loss = 0.0 + else: + loss = (loss / n_non_ignore).sum() + else: + loss = (loss / log_q.shape[0]).sum() + return loss.to(self.dtype) + + +class TorchLMHeadJSD(torch.nn.Module): + """Ground truth implementation of the linear 
fused with torch based jsd loss. + + :param H: hidden size + :param V: vocab size + :param temperature: softmax temperature + :param beta: jsd beta + """ + + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + device: torch.device, + beta: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, + ): + super().__init__() + self.student_lin = torch.nn.Linear( + in_features=H, out_features=V, bias=False, dtype=dtype, device=device + ) + self.teacher_lin = torch.nn.Linear( + in_features=H, out_features=V, bias=False, dtype=dtype, device=device + ) + self.jsd = TorchJSD(beta=beta, ignore_index=ignore_index, dtype=dtype) + self.temperature = temperature + + def forward(self, student_input, teacher_input, label=None): + student_logits = self.student_lin(student_input) + teacher_logits = self.teacher_lin(teacher_input) + student_prob = torch.log_softmax(student_logits / self.temperature, dim=-1) + teacher_prob = torch.log_softmax(teacher_logits / self.temperature, dim=-1) + + return self.jsd(student_prob, teacher_prob, label) + + +class LigerLMHeadJSD(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + device: torch.device, + beta: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, + ): + super().__init__() + self.student_lin = torch.nn.Linear( + in_features=H, out_features=V, bias=False, dtype=dtype, device=device + ) + self.teacher_lin = torch.nn.Linear( + in_features=H, out_features=V, bias=False, dtype=dtype, device=device + ) + self.fused_jsd = LigerFusedLinearJSD( + jsd_beta=beta, ignore_index=ignore_index, temperature=temperature + ) + + def forward(self, student_input, teacher_input, label=None): + return self.fused_jsd( + student_input, + self.student_lin.weight, + teacher_input, + self.teacher_lin.weight, + label, + ) + + +############################################################################# +# Test the memory consumption of the fused linear JSD +############################################################################# + + +def bench_memory_fused_linear_jsd( + input: SingleBenchmarkRunInput, +) -> SingleBenchmarkRunOutput: + BT = input.x + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + provider = input.kernel_provider + + device = "cuda" + torch_lm_head_jsd = TorchLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device) + liger_lm_head_jsd = LigerLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device) + + # init the linear in all FusedLinearJSDs with the same weights + torch_lm_head_jsd.student_lin.weight.data = ( + liger_lm_head_jsd.student_lin.weight.data + ) = torch.rand(V, H, device=device, dtype=dtype) + torch_lm_head_jsd.teacher_lin.weight.data = ( + liger_lm_head_jsd.teacher_lin.weight.data + ) = torch.rand(V, H, device=device, dtype=dtype) + + student_input = torch.rand(BT, H, requires_grad=True, dtype=dtype, device=device) + teacher_input = torch.rand(BT, H, dtype=dtype, device=device) + + def fwd(): + if provider == "liger": + return liger_lm_head_jsd(student_input, teacher_input) + elif provider == "torch": + return torch_lm_head_jsd(student_input, teacher_input) + + def full(): + y = fwd() + y.backward() + + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +# ############################################################################# +# # Test the speed of the fused linear JSD +# 
############################################################################# + + +def bench_speed_fused_linear_jsd( + input: SingleBenchmarkRunInput, +) -> SingleBenchmarkRunOutput: + BT = input.x + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + mode = input.kernel_operation_mode + + dtype = input.extra_benchmark_config["dtype"] + provider = input.kernel_provider + + device = "cuda" + torch_lm_head_jsd = TorchLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device) + liger_lm_head_jsd = LigerLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device) + + # init the linear in all FusedLinearJSDs with the same weights + torch_lm_head_jsd.student_lin.weight.data = ( + liger_lm_head_jsd.student_lin.weight.data + ) = torch.rand(V, H, device=device, dtype=dtype) + torch_lm_head_jsd.teacher_lin.weight.data = ( + liger_lm_head_jsd.teacher_lin.weight.data + ) = torch.rand(V, H, device=device, dtype=dtype) + + student_input = torch.rand(BT, H, requires_grad=True, dtype=dtype, device=device) + teacher_input = torch.rand(BT, H, dtype=dtype, device=device) + + def fwd(): + if provider == "liger": + return liger_lm_head_jsd(student_input, teacher_input) + elif provider == "torch": + return torch_lm_head_jsd(student_input, teacher_input) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + fwd, + rep=100, + quantiles=QUANTILES, + ) + elif mode == "backward": + y = fwd() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + grad_to_none=[ + student_input, + torch_lm_head_jsd.student_lin.weight, + torch_lm_head_jsd.teacher_lin.weight, + ], + rep=100, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = fwd() + y.backward() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, + rep=100, + quantiles=QUANTILES, + ) + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "fused_linear_jsd", + "x_name": "BT", + "x_label": "B x T", + "x_values": [2**i for i in range(10, 14)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [ + {"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16} + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_fused_linear_jsd, + kernel_operation_modes=["forward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs + ) + run_benchmarks( + bench_test_fn=bench_memory_fused_linear_jsd, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs + ) diff --git a/benchmark/scripts/benchmark_jsd.py b/benchmark/scripts/benchmark_jsd.py new file mode 100644 index 000000000..272008315 --- /dev/null +++ b/benchmark/scripts/benchmark_jsd.py @@ -0,0 +1,154 @@ +import torch +import triton +from utils import ( + QUANTILES, + SingleBenchmarkRunInput, + SingleBenchmarkRunOutput, + _test_memory, + parse_benchmark_script_args, + run_benchmarks, +) + +from liger_kernel.transformers.jsd import LigerJSD + + +class TorchJSD(torch.nn.Module): + def __init__( + self, + beta: float = 0.5, + ignore_index: int = -100, + dtype: torch.dtype = torch.float, + ): + super(TorchJSD, self).__init__() + self.kl = torch.nn.KLDivLoss(reduction="none", log_target=True) + self.beta = beta + self.ignore_index = ignore_index + self.dtype = dtype + + def forward( + self, + log_q: torch.Tensor, # input + log_p: torch.Tensor, # target + label=None, + 
): + log_p, log_q = log_p.to(torch.float), log_q.to(torch.float) + log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1)) + m = torch.lerp(torch.exp(log_q), torch.exp(log_p), self.beta) + loss = self.beta * self.kl(torch.log(m), log_p).sum(dim=-1) + ( + 1 - self.beta + ) * self.kl(torch.log(m), log_q).sum(dim=-1) + + if label is not None: + loss = torch.where(label != self.ignore_index, loss, 0.0) + n_non_ignore = (label != self.ignore_index).sum().item() + if n_non_ignore == 0: + loss = 0.0 + else: + loss = (loss / n_non_ignore).sum() + else: + loss = (loss / log_q.shape[0]).sum() + return loss.to(self.dtype) + + +def bench_speed_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + V = input.x + B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"] + torch_jsd = TorchJSD() + liger_jsd = LigerJSD() + + _input = torch.randn(B * T, V, requires_grad=True, device="cuda").log_softmax( + dim=-1 + ) + target = torch.randn(B * T, V, device="cuda").log_softmax(dim=-1) + + def fwd(): + if input.kernel_provider == "liger": + return liger_jsd(_input, target) + else: + return torch_jsd(_input, target) + + if input.kernel_operation_mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100) + elif input.kernel_operation_mode == "backward": + y = fwd() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + quantiles=QUANTILES, + grad_to_none=[_input], + rep=100, + ) + elif input.kernel_operation_mode == "full": + + def full(): + y = fwd() + y.backward(retain_graph=True) + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, quantiles=QUANTILES, rep=100 + ) + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +def bench_memory_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + torch_jsd = TorchJSD() + liger_jsd = LigerJSD() + + V = input.x + B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"] + + _input = torch.randn(B * T, V, requires_grad=True, device="cuda").log_softmax( + dim=-1 + ) + target = torch.randn(B * T, V, device="cuda").log_softmax(dim=-1) + + def fwd(): + if input.kernel_provider == "liger": + return liger_jsd(_input, target) + else: + return torch_jsd(_input, target) + + def full(): + y = fwd() + y.backward(retain_graph=True) + + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + common_args = { + "kernel_name": "jsd", + "x_name": "V", + "x_label": "vocab size", + "x_values": [2**i for i in range(12, 18)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [{"B": 4, "T": 2048}], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_memory_jsd, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_args, + ) + + run_benchmarks( + bench_test_fn=bench_speed_jsd, + kernel_operation_modes=["forward", "full"], + metric_name="speed", + metric_unit="ms", + **common_args, + ) diff --git a/dev/modal/tests.py b/dev/modal/tests.py new file mode 100644 index 000000000..880a2f299 --- /dev/null +++ b/dev/modal/tests.py @@ -0,0 +1,23 @@ +from pathlib import Path + +import modal + +ROOT_PATH = Path(__file__).parent.parent.parent + +image = modal.Image.debian_slim().pip_install_from_pyproject( + ROOT_PATH / "pyproject.toml", 
optional_dependencies=["dev"] +) + +app = modal.App("liger_tests", image=image) + +# mount: add local files to the remote container +repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path="/root/liger-kernel") + + +@app.function(gpu="A10G", mounts=[repo], timeout=60 * 10) +def liger_tests(): + import subprocess + + subprocess.run(["pip", "install", "-e", "."], check=True, cwd="/root/liger-kernel") + subprocess.run(["make", "test"], check=True, cwd="/root/liger-kernel") + subprocess.run(["make", "test-convergence"], check=True, cwd="/root/liger-kernel") diff --git a/examples/medusa/medusa_util.py b/examples/medusa/medusa_util.py index c6f0c5a2f..5b4f9ac9f 100644 --- a/examples/medusa/medusa_util.py +++ b/examples/medusa/medusa_util.py @@ -212,7 +212,7 @@ def forward( if with_liger: lce = LigerFusedLinearCrossEntropyLoss() - for i in range(model.medusa_num_heads): + for i in range(model.medusa_num_heads + 1): shift_hidden_states = ( hidden_states[..., : -(1 + i), :] .contiguous() @@ -223,7 +223,7 @@ def forward( weight = ( model.lm_head.weight if i == 0 - else model.medusa_head[i][-1].weight + else model.medusa_head[i - 1][-1].weight ) loss_i = lce(weight, shift_hidden_states, shift_labels) @@ -238,7 +238,7 @@ def forward( else: loss_fct = CrossEntropyLoss() - for i in range(model.medusa_num_heads): + for i in range(model.medusa_num_heads + 1): medusa_logits_i = ( medusa_logits[i, :, : -(1 + i)] .contiguous() diff --git a/licenses/LICENSE-Apache-2.0 b/licenses/LICENSE-Apache-2.0 new file mode 100644 index 000000000..0328c5ff0 --- /dev/null +++ b/licenses/LICENSE-Apache-2.0 @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [2024-] [Unsloth AI, Daniel Han-Chen & Michael Han-Chen] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/licenses/LICENSE-MIT-AutoAWQ b/licenses/LICENSE-MIT-AutoAWQ new file mode 100644 index 000000000..c8de3cf7f --- /dev/null +++ b/licenses/LICENSE-MIT-AutoAWQ @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 MIT HAN Lab + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-MIT-Efficient-Cross-Entropy b/licenses/LICENSE-MIT-Efficient-Cross-Entropy new file mode 100644 index 000000000..17736429b --- /dev/null +++ b/licenses/LICENSE-MIT-Efficient-Cross-Entropy @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 mgmalek + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-MIT-llmc b/licenses/LICENSE-MIT-llmc new file mode 100644 index 000000000..99d8f1f02 --- /dev/null +++ b/licenses/LICENSE-MIT-llmc @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Andrej Karpathy + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
\ No newline at end of file diff --git a/licenses/LICENSE-MIT-triton b/licenses/LICENSE-MIT-triton new file mode 100644 index 000000000..0f3852f09 --- /dev/null +++ b/licenses/LICENSE-MIT-triton @@ -0,0 +1,23 @@ +/* +* Copyright 2018-2020 Philippe Tillet +* Copyright 2020-2022 OpenAI +* +* Permission is hereby granted, free of charge, to any person obtaining +* a copy of this software and associated documentation files +* (the "Software"), to deal in the Software without restriction, +* including without limitation the rights to use, copy, modify, merge, +* publish, distribute, sublicense, and/or sell copies of the Software, +* and to permit persons to whom the Software is furnished to do so, +* subject to the following conditions: +* +* The above copyright notice and this permission notice shall be +* included in all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +*/ \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 3ff87301c..7e7d6a58d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,26 +4,30 @@ build-backend = "setuptools.build_meta" [project] name = "liger_kernel" -version = "0.3.0" +version = "0.4.0" description = "Efficient Triton kernels for LLM Training" urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" } readme = { file = "README.md", content-type = "text/markdown" } license = { file = "LICENSE" } -# dependencies = [ -# "torch>=2.1.2", -# "triton>=2.3.0", -# # "transformers>=4.42.0" -# ] +dependencies = [ + "torch>=2.1.2", + "triton>=2.3.1", +] [project.optional-dependencies] +transformers = [ + "transformers~=4.0" +] + dev = [ + "transformers>=4.44.2", "matplotlib>=3.7.2", "flake8>=4.0.1.1", "black>=24.4.2", "isort>=5.13.2", "pytest>=7.1.2", "datasets>=2.19.2", - "jupyter==1.0.0", + "torchvision>=0.16.2", "seaborn", ] @@ -33,7 +37,7 @@ include = ["liger_kernel", "liger_kernel.*"] [tool.pytest.ini_options] pythonpath = [ - "src", + "src", "." ] asyncio_mode = "auto" diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 66e03ae4a..455abc677 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -2,6 +2,11 @@ import triton import triton.language as tl +from liger_kernel.ops.utils import element_mul_kernel, is_hip + +_TRUE = tl.constexpr(1) +_FALSE = tl.constexpr(0) + @triton.jit def liger_cross_entropy_kernel( @@ -10,12 +15,15 @@ def liger_cross_entropy_kernel( Y_ptr, Y_stride, loss_ptr, + z_loss_ptr, loss_stride, n_cols, n_non_ignore, ignore_index, + lse_square_scale: tl.constexpr, label_smoothing: tl.constexpr, reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time + RETURN_Z_LOSS: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): """ @@ -28,11 +36,14 @@ def liger_cross_entropy_kernel( Y_ptr: Pointer to target tensor. Y_stride (int): The stride of the target tensor. loss_ptr: Pointer to tensor to store the loss. + z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0. 
loss_stride (int): The stride of the loss tensor. n_cols (int): The number of columns in the input tensor. n_non_ignore (int): The number of non-ignored elements in the batch. ignore_index (int): The index to ignore in the target. label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. + lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training. + RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1. reduction (str): The string for the reduction to apply BLOCK_SIZE (int): The block size for Triton operations. """ @@ -56,6 +67,7 @@ def liger_cross_entropy_kernel( return loss_ptr += program_id * loss_stride + z_loss_ptr += program_id * loss_stride # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax) # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867 @@ -85,32 +97,40 @@ def liger_cross_entropy_kernel( d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new)) m = m_new + # log (sum(e^(X_i))) = log (sum(e ^ (max(X) * e ^ (X_i - max(X))))) + # = log (e^(max(X)) * sum(e ^ (X_i - max(X)))) + # = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d + lse = m + tl.log(d) + # 4. [Online Softmax] Second pass: compute gradients # For 'mean' reduction, gradients are normalized by number of non-ignored elements (N) # dx_y = (softmax(x_y) - 1) / N # dx_i = softmax(x_i) / N, i != y # For label smoothing: - # dx_i = (softmax(x_y) - label_smoothing / V) / N, V = n_cols, i != y + # dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N # = dx_i - (1 - label_smoothing) / N - # + # With Z loss: + # dx_i = ((1 + 2 * lse_square_scale * lse) * softmax(x_i) - label_smoothing / V) / N, i != y + # dx_y = dx_i - (1 - label_smoothing) / N # For 'sum' reduction, no normalization is applied: # dx_y = softmax(x_y) - 1 # dx_i = softmax(x_i), for i ≠ y - # For label smoothing: - # dx_i = (softmax(x_y) - label_smoothing / V), V = n_cols, i != y - # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) - # = dx_i - (1 - label_smoothing) for i in range(0, n_cols, BLOCK_SIZE): X_offsets = i + tl.arange(0, BLOCK_SIZE) X_block = tl.load( X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf") ) + # softmax(x_i) + X_block = tl.exp(X_block - m) / d + # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i) + X_block += 2 * lse_square_scale * lse * X_block + # smoothing term + X_block += -eps + # reduction scale if reduction == "mean": - X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore) - else: - X_block = tl.exp(X_block - m) / d - eps + X_block = X_block / (n_non_ignore) tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols) @@ -122,11 +142,12 @@ def liger_cross_entropy_kernel( # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X)))) # = (X_y - max(X)) - log(sum(e ^ (X - max(X)))) + # = X_y - m - log d = X_y - lse # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1 # So we can safely calculate log (softmax(X_y)) without overflow - loss = -(ori_X_y - m - tl.log(d)) + loss = lse - ori_X_y - # Orginal loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps + # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps # H(q', p) = (1 - label_smoothing) * H(q, p) + 
label_smoothing * H(u, p) # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i)) # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as: @@ -135,11 +156,16 @@ def liger_cross_entropy_kernel( # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516 # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087 if label_smoothing > 0: - smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d)) + smooth_loss = scaled_x_sum + label_smoothing * lse loss = loss * (1 - label_smoothing) + smooth_loss + # An auxiliary loss, z_loss + # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html + z_loss = lse_square_scale * lse * lse + loss += z_loss # Normalize the loss by the number of non-ignored elements if reduction is "mean" if reduction == "mean": + z_loss = z_loss / n_non_ignore loss = loss / n_non_ignore # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N` @@ -150,6 +176,8 @@ def liger_cross_entropy_kernel( X_y += -(1 - label_smoothing) tl.store(loss_ptr, loss) + if RETURN_Z_LOSS == _TRUE: + tl.store(z_loss_ptr, z_loss) tl.store(X_ptr + y, X_y) @@ -159,43 +187,31 @@ def liger_cross_entropy_kernel( MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning -@triton.jit -def element_mul_kernel( - X_ptr, - X_stride, - grad_output_ptr, - n_cols, - BLOCK_SIZE: tl.constexpr, -): - """ - This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr. - The multiplication is performed in-place on the tensor pointed by X_ptr. - - Parameters: - X_ptr: Pointer to the input tensor. - X_stride (int): The stride of the input tensor. - grad_output_ptr: Pointer to the gradient output value. - n_cols (int): The number of columns in the input tensor. - BLOCK_SIZE (int): The block size for Triton operations. - """ - - # Get the program ID and convert it to int64 to avoid overflow - program_id = tl.program_id(0).to(tl.int64) - - # Locate the start index - X_ptr += program_id * X_stride +_bool_to_return_z_loss = { + True: _TRUE.value, + False: _FALSE.value, +} - # Load the gradient output value - grad_output = tl.load(grad_output_ptr) - - # Perform the element-wise multiplication - for i in range(0, n_cols, BLOCK_SIZE): - X_offsets = i + tl.arange(0, BLOCK_SIZE) - X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols) - tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols) +def cross_entropy_forward( + _input, + target, + ignore_index, + lse_square_scale, + label_smoothing, + reduction, + return_z_loss, +): + if not isinstance(return_z_loss, int): + assert ( + return_z_loss in _bool_to_return_z_loss + ), f"return_z_loss must be True or False. Got: {return_z_loss}" + return_z_loss = _bool_to_return_z_loss[return_z_loss] + else: + assert ( + return_z_loss in _bool_to_return_z_loss + ), f"return_z_loss must be True or False. 
Got: {return_z_loss}" -def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reduction): BT, V = _input.shape n_rows = BT @@ -203,6 +219,10 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reducti # unreduced loss loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) + if return_z_loss == _TRUE.value: + z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) + else: + z_loss_1d = loss_1d # dummy ptr when return_z_loss == False n_non_ignore = (target != ignore_index).sum().item() @@ -219,20 +239,28 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reducti Y_ptr=target, Y_stride=target.stride(-1), # always 1 loss_ptr=loss_1d, + z_loss_ptr=z_loss_1d, loss_stride=loss_1d.stride(-1), # always 1 n_cols=V, n_non_ignore=n_non_ignore, ignore_index=ignore_index, + lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, reduction=reduction, + RETURN_Z_LOSS=return_z_loss, BLOCK_SIZE=BLOCK_SIZE, # TODO: 32 seems to give the best performance # Performance is quite sensitive to num_warps - num_warps=32, + num_warps=32 if not is_hip() else 16, ) loss = torch.sum(loss_1d) - return loss, _input + if return_z_loss == _TRUE.value: + z_loss = torch.sum(z_loss_1d) + else: + z_loss = None + + return loss, z_loss, _input def cross_entropy_backward(_input, grad_output): @@ -253,7 +281,7 @@ def cross_entropy_backward(_input, grad_output): grad_output, V, BLOCK_SIZE=BLOCK_SIZE, - num_warps=32, + num_warps=32 if not is_hip() else 16, ) return _input @@ -267,7 +295,14 @@ class LigerCrossEntropyFunction(torch.autograd.Function): @staticmethod def forward( - ctx, _input, target, ignore_index=-100, label_smoothing=0.0, reduction="mean" + ctx, + _input, + target, + ignore_index=-100, + lse_square_scale=0.0, + label_smoothing=0.0, + reduction="mean", + return_z_loss=False, ): """ The forward pass of the Liger Cross Entropy loss. @@ -277,33 +312,46 @@ def forward( _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size. target (tensor): The target tensor of shape (BT) where each value is in [0, V-1]. ignore_index (int): The index to ignore in the target. + lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training. label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. reduction (str): The reduction to apply to the output: "none" | "mean | "sum". + return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False` Returns: - tensor: The computed loss. + tuple: A tuple with the compouted losses with respect to loss and z loss. The elements are tensors or None. """ - loss, _input = cross_entropy_forward( - _input, target, ignore_index, label_smoothing, reduction + loss, z_loss, _input = cross_entropy_forward( + _input, + target, + ignore_index, + lse_square_scale, + label_smoothing, + reduction, + return_z_loss, ) # TODO: investigation # If we don't detach the _input tensor, the memory will double # Not sure why but seems that there will be a time both grad and value exist but in different location ctx.save_for_backward(_input.detach()) - return loss + ctx.return_z_loss = return_z_loss + + return loss, z_loss @staticmethod - def backward(ctx, grad_output): + def backward(ctx, grad_output, grad_ouput2): """ The backward pass of the Liger Cross Entropy loss. 
Parameters: ctx : The context object with saved tensors. grad_output (tensor): The tensor containing the gradient of the loss with respect to the output. - + grad_output2 (tenosr): No use. Returns: tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None. """ + if ctx.return_z_loss: + del grad_ouput2 # z_loss is only for logging + (_input,) = ctx.saved_tensors _input = cross_entropy_backward(_input, grad_output) return ( @@ -312,4 +360,6 @@ def backward(ctx, grad_output): None, None, None, + None, + None, ) diff --git a/src/liger_kernel/ops/experimental/mm_int8int2.py b/src/liger_kernel/ops/experimental/mm_int8int2.py new file mode 100644 index 000000000..4de17124b --- /dev/null +++ b/src/liger_kernel/ops/experimental/mm_int8int2.py @@ -0,0 +1,355 @@ +import torch +import triton +import triton.language as tl + + +def unpack_weights(packed: torch.Tensor, bits: int = 2) -> torch.Tensor: + values_per_item = 8 // bits + packed_shape = packed.shape + + if len(packed_shape) == 1: + original_row_dim = packed_shape[0] * values_per_item + unpacked_shape = (original_row_dim,) + else: + original_row_dim = packed_shape[0] * values_per_item + unpacked_shape = (original_row_dim, *packed_shape[1:]) + + unpacked = torch.zeros(unpacked_shape, device=packed.device, dtype=torch.uint8) + + for i in range(values_per_item): + start = i * packed_shape[0] + end = start + packed_shape[0] + mask = 3 << (2 * i) + unpacked[start:end] = (packed & mask) >> (2 * i) + + unpacked = unpacked.to(torch.int32) - 1 + return unpacked + + +def pack_weights(intweights: torch.Tensor, bits: int = 2) -> torch.Tensor: + intweights += 1 + original_shape = intweights.shape + values_per_item = 8 // bits + row_dim = (original_shape[0] + values_per_item - 1) // values_per_item + + if len(original_shape) == 1: + packed_tensor_shape = (row_dim,) + else: + packed_tensor_shape = (row_dim, *original_shape[1:]) + + packed = torch.zeros( + packed_tensor_shape, device=intweights.device, dtype=torch.uint8 + ) + unpacked = intweights.to(torch.uint8) + + def lshift(t: torch.Tensor, bits: int): + return t << bits + + it = min(values_per_item, (original_shape[0] // row_dim) + 1) + for i in range(it): + start = i * row_dim + end = min(start + row_dim, original_shape[0]) + packed[: (end - start)] |= lshift(unpacked[start:end], bits * i) + + return packed + + +def get_autotune_config(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=3, + num_warps=8, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=3, + num_warps=8, + ), + triton.Config( + { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 
128, + "GROUP_SIZE_M": 8, + }, + num_stages=3, + num_warps=8, + ), + triton.Config( + { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + }, + num_stages=4, + num_warps=4, + ), + ] + + +@triton.autotune( + configs=get_autotune_config(), + key=["M", "N", "K"], +) +@triton.jit +def matmul_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K: tl.constexpr, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # We want K / 4 to be divisible by BLOCK_SIZE_K so that the multiplication can be aligned + tl.static_assert( + K % (4 * BLOCK_SIZE_K) == 0, + "K / 4 must be divisible by BLOCK_SIZE_K => K divisible by 4*BLOCK_SIZE_K", + ) + # determine the block id in the 1D grid, pid <=> blockId in cuda + pid = tl.program_id(axis=0) + # number of blocks we would need in the M dimension + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + # number of blocks we would need in the N dimension + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + # blocks are grouped along the M dimension. num_pid_in_group computes how many blocks are grouped together, + # and group_id calculates the group to which the current block (pid) belongs. + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + + # pid of the first block in the group that the current block belongs too + first_pid_m = group_id * GROUP_SIZE_M + + # pid_m : pid of the block along the M dimension of the output matrix, and pid_n : pid of the block along the N dimension of the output matrix + # remember that the grid of blocks is 1D, but we calculate pid_m and pid_n to locate the block pid place in the output matrix + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # offs_am represent the indices of elements within the block for matrices A with respect to the M dimension + # offs_bn represent the indices of elements within the block for matrices B with respect to the N dimension + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + + """ + This part of the code generates pointers to the specific blocks of matrices A and B that the current thread block will process. 
+ + As described in the PyTorch documentation, a stride refers to the step size needed to move from one element to the next along a given dimension: + + For matrix A: stride_am = A.stride(0) = K (stride along the rows), and stride_ak = A.stride(1) = 1 (stride along the columns). + For matrix B: stride_bk = B.stride(0) = N (stride along the rows), and stride_bn = B.stride(1) = 1 (stride along the columns). + Now, let's break down the pointer generation: + + offs_am[:, None] creates a column of shape [BLOCK_SIZE_M, 1], which represents the row indices of matrix A that this block is processing. It is multiplied by K (the number of columns in matrix A) since A is stored in row-major order. So, the element at position (i, j) in A is located at index i*K + j in memory. + offs_k[None, BLOCK_SIZE_K] creates a row vector representing the column indices of the block, i.e., a range from 0 to BLOCK_SIZE_K. This is used to compute the positions of the columns within the block. + When combined, the result has the shape [BLOCK_SIZE_M, BLOCK_SIZE_K], where each entry (i, j) points to the element in matrix A at position (i, j) for the current block. + + The same logic is applied to matrix B, but the resulting shape is [BLOCK_SIZE_K, BLOCK_SIZE_N], representing the block of matrix B that the thread block will work on. + """ + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + # An accumulator matrix is initialized with zeros. It stores the intermediate results of the block matrix multiplication. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32) + """ + We split the loop into two layers. The outer loop runs 4 times, and each iteration focuses on a specific portion of matrix A. + + For example, when i = 0, we’re only concerned with the blocks of matrix A that cover the range from 0 to K // (4 * BLOCK_SIZE_K). + Since matrix B is packed, its first dimension is effectively divided by 4. So, while we process the first segment of matrix A, + we still iterate over the entire first dimension of matrix B. + + In each of the 4 iterations of the outer loop, we go through the full blocks of matrix B, but what changes is the data we extract. + Matrix B elements contain 4 weights, all packed into an int8 format, and during each iteration of the outer loop, + we extract a different weight by using bitwise shifting operations. This way, we access a unique weight on each pass. + """ + for i in range(4): + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + for j in range(0, tl.cdiv(K // 4, BLOCK_SIZE_K)): + k = i * tl.cdiv(K // 4, BLOCK_SIZE_K) + j + # load the block of matrix A + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0) + # load the block of matrix B + b_uint8 = tl.load(b_ptrs, mask=offs_k[:, None] < K, other=0) + # when i = 0 for example, we only care about the first 2 bits of the elements of the matrix B, so we use the mask 00000011 to mask the other bits + mask = 3 << (2 * i) + # we shift the results after the mask + b = (b_uint8 & mask) >> (2 * i) + # During the packing of the weights, it's easier to pack 0, 1, 2, then -1, 0, 1, so we add 1 to the weight tensor, and we substract it here + tensor_full = tl.full((1,), 1, dtype=tl.int8) + # We accumulate the result of multiplication of the blocks along the K dimension on int32 to avoid any overflows or underflows. 
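For readers less comfortable with the bit manipulation above, here is a minimal eager-mode sketch of the same mask-and-shift unpacking applied to a single packed byte. The tensor names (`w`, `stored`, `packed`) are illustrative and are not part of the kernel:

```python
import torch

# four ternary weights in {-1, 0, 1}; they are stored as {0, 1, 2} (value + 1)
w = torch.tensor([-1, 0, 1, 1], dtype=torch.int8)
stored = (w + 1).to(torch.uint8)

# pack the four 2-bit values into one uint8: weight i occupies bits [2i, 2i+1]
packed = torch.zeros(1, dtype=torch.uint8)
for i in range(4):
    packed |= stored[i] << (2 * i)

# recover weight i with the same mask-and-shift the kernel uses, then subtract 1
for i in range(4):
    mask = 3 << (2 * i)                                   # 0b00000011 when i == 0
    recovered = ((packed & mask) >> (2 * i)).to(torch.int8) - 1
    assert recovered.item() == int(w[i])
```

The kernel performs the same extraction on whole tiles of `b_uint8` for each outer iteration `i`, accumulating the products in int32.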
+ accumulator += tl.dot(a, (b.to(tl.int8) - tensor_full), out_dtype=tl.int32) + # we move the pointers, for the a_ptrs we more in a horizontal way along the second dimension -> we use strid_ak=1 + # for b_ptrs we move in a vertical way, along the rows -> we use stride_bk=N + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + c = accumulator + # These lines compute the offsets into matrix C where the result of this block’s computation should be stored. + # stride_cm = N & stride_cn = 1 + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + # we do a boundary check to ensure only elements within matrix bounds are stored + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def matmul(a, b): + assert ( + a.shape[1] == b.shape[0] * 4 + ), "Incompatible dimensions, the weight matrix need to be packed" + assert a.is_contiguous(), "Matrix A must be contiguous" + M, K = a.shape + _, N = b.shape + # c is in int32 to avoid any overflows or underflows + c = torch.empty((M, N), device=a.device, dtype=torch.int32) + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + matmul_kernel[grid]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + ) + return c diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py index 73da9cd46..34016ee4c 100644 --- a/src/liger_kernel/ops/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -1,9 +1,12 @@ import torch import triton -from liger_kernel.ops.cross_entropy import ( +from liger_kernel.ops.cross_entropy import liger_cross_entropy_kernel +from liger_kernel.ops.utils import ( + amp_custom_bwd, + amp_custom_fwd, element_mul_kernel, - liger_cross_entropy_kernel, + is_hip, ) # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19 @@ -18,12 +21,11 @@ def fused_linear_cross_entropy_forward( target, bias=None, ignore_index=-100, + lse_square_scale=0.0, label_smoothing=0.0, reduction="mean", ): - dtype = ( - torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else _input.dtype - ) + dtype = _input.dtype device = _input.device # inputs have shape: BT x H @@ -85,14 +87,17 @@ def fused_linear_cross_entropy_forward( Y_ptr=target_chunk, Y_stride=target_chunk.stride(-1), # always 1 loss_ptr=loss_1d_slice, + z_loss_ptr=loss_1d_slice, # dummy ptr, not used loss_stride=loss_1d_slice.stride(-1), # always 1 n_cols=V, n_non_ignore=n_non_ignore, ignore_index=ignore_index, + lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, reduction=reduction, + RETURN_Z_LOSS=0, # False BLOCK_SIZE=BLOCK_SIZE, - num_warps=32, + num_warps=32 if not is_hip() else 16, ) # gradient of logits_chunk is computed in-place by the above triton kernel. 
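As a sanity reference for the `lse_square_scale` argument threaded through this forward, the loss the fused path is expected to match can be written in a few lines of plain PyTorch. This is only an illustrative sketch of the formula (label smoothing omitted for brevity), not the fused implementation:

```python
import torch
import torch.nn.functional as F

def ce_with_z_loss(logits, target, lse_square_scale=0.0, ignore_index=-100):
    # plain cross entropy summed over non-ignored tokens
    ce = F.cross_entropy(logits, target, ignore_index=ignore_index, reduction="sum")
    # auxiliary z-loss: lse_square_scale * logsumexp(logits)^2 per token
    lse = torch.logsumexp(logits.float(), dim=-1)
    valid = target != ignore_index
    z_loss = lse_square_scale * (lse[valid] ** 2).sum()
    n = valid.sum().clamp(min=1)
    # divide by the number of non-ignored tokens, matching the "mean" reduction
    return (ce + z_loss) / n

logits = torch.randn(8, 32000)
target = torch.randint(0, 32000, (8,))
print(ce_with_z_loss(logits, target, lse_square_scale=1e-4))
```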
@@ -157,7 +162,7 @@ def fused_linear_cross_entropy_backward( grad_output, H, BLOCK_SIZE=BLOCK_SIZE, - num_warps=32, + num_warps=32 if not is_hip() else 16, ) # handle grad_weight @@ -171,7 +176,7 @@ def fused_linear_cross_entropy_backward( grad_output, H, BLOCK_SIZE=BLOCK_SIZE, - num_warps=32, + num_warps=32 if not is_hip() else 16, ) if grad_bias is not None: @@ -184,13 +189,14 @@ def fused_linear_cross_entropy_backward( grad_output, 1, BLOCK_SIZE=BLOCK_SIZE, - num_warps=32, + num_warps=32 if not is_hip() else 16, ) return grad_input, grad_weight, grad_bias class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function): @staticmethod + @amp_custom_fwd def forward( ctx, _input, @@ -198,6 +204,7 @@ def forward( target, bias=None, ignore_index=-100, + lse_square_scale=0.0, label_smoothing=0.0, reduction="mean", ): @@ -219,7 +226,14 @@ def forward( reduction: reduction to apply """ loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward( - _input, weight, target, bias, ignore_index, label_smoothing, reduction + _input, + weight, + target, + bias, + ignore_index, + lse_square_scale, + label_smoothing, + reduction, ) # downcast to dtype and store for backward ctx.save_for_backward( @@ -230,9 +244,10 @@ def forward( return loss @staticmethod + @amp_custom_bwd def backward(ctx, grad_output): (grad_input, grad_weight, grad_bias) = ctx.saved_tensors grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward( grad_output, grad_input, grad_weight, grad_bias ) - return (grad_input, grad_weight, None, grad_bias, None, None, None) + return (grad_input, grad_weight, None, grad_bias, None, None, None, None) diff --git a/src/liger_kernel/ops/fused_linear_jsd.py b/src/liger_kernel/ops/fused_linear_jsd.py new file mode 100644 index 000000000..27ef3aa2f --- /dev/null +++ b/src/liger_kernel/ops/fused_linear_jsd.py @@ -0,0 +1,245 @@ +from typing import Optional + +import torch +import triton + +from liger_kernel.ops.jsd import _jsd_kernel +from liger_kernel.ops.utils import ( + amp_custom_bwd, + amp_custom_fwd, + element_mul_kernel, + is_hip, +) + +# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19 +# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling +# The optimal maximum block size depends on your hardware, your kernel, and your dtype +MAX_FUSED_SIZE = 65536 // 2 + + +def fused_linear_jsd_forward( + student_input, + student_weight, + teacher_input, + teacher_weight, + shift_labels, + jsd_beta, + ignore_index, + has_label, + temperature, +): + device = student_input.device + dtype = student_input.dtype + + # inputs have shape: BT x H + # materialized activations will have shape: BT x V + # the increase in memory = BT x V + # reduction can be achieved by partitioning the number of tokens BT into smaller chunks. 
+ # for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be: + # inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor + # for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048 + BT, H = student_input.shape + V = student_weight.shape[0] + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + + inc_factor = triton.cdiv(V, H) # (V + H - 1) // H + chunk_size = triton.next_power_of_2( + triton.cdiv(BT, inc_factor) + ) # (BT + inc_factor - 1) // inc_factor + num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size + + grad_weight = ( + torch.zeros_like(student_weight, device=device) + if student_weight.requires_grad + else None + ) + grad_input = torch.zeros_like(student_input) + # we use fp32 for loss accumulator + loss_1d = torch.zeros((BT, V), dtype=torch.float32, device=device) + + if has_label: + n_non_ignore = (shift_labels != ignore_index).sum().item() + else: + n_non_ignore = BT + + for chunk_id in range(num_chunks): + start_idx = chunk_id * chunk_size + end_idx = min((chunk_id + 1) * chunk_size, BT) + + # chunk both inputs, shape: chunk_size x H + student_input_chunk = student_input[start_idx:end_idx] + teacher_input_chunk = teacher_input[start_idx:end_idx] + + # shape: chunk_size x V + # For anything starting from logits to the final JSD loss, we do computation + # in FP32 to avoid losing numerical stability. + student_logits_chunk = (student_input_chunk @ student_weight.t()).to( + torch.float32 + ) + teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to( + torch.float32 + ) + chunk_n_rows = student_logits_chunk.shape[0] + + # unreduced loss + loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size + # log-softmax with temperature + student_logits_chunk = student_logits_chunk / temperature + teacher_logits_chunk = teacher_logits_chunk / temperature + student_prob_chunk = torch.log_softmax(student_logits_chunk, dim=-1) + teacher_prob_chunk = torch.log_softmax(teacher_logits_chunk, dim=-1) + + # ensure _input and target are contiguous + student_prob_chunk = student_prob_chunk.contiguous() + teacher_prob_chunk = teacher_prob_chunk.contiguous() + + # Here we calculate the gradient of prob_chunk in place so we can save memory. + _jsd_kernel[(chunk_n_rows,)]( + X_ptr=student_prob_chunk, + X_stride=student_prob_chunk.stride(-2), + Y_ptr=teacher_prob_chunk, + Y_stride=teacher_prob_chunk.stride(-2), + loss_ptr=loss_1d_slice, + loss_stride=loss_1d_slice.stride(-2), + dX_ptr=student_prob_chunk, + dX_stride=student_prob_chunk.stride(-2), + label_ptr=( + shift_labels[start_idx:end_idx] + if has_label + else torch.empty(1, device=device) + ), # dummy ptr if no label + beta=jsd_beta, + n_non_ignore=n_non_ignore, + ignore_index=ignore_index, + n_cols=V, + BLOCK_SIZE=BLOCK_SIZE, + HAS_LABEL=has_label, + ) + loss_1d[start_idx:end_idx] = loss_1d_slice + # gradients of prob_chunk in place, shape: chunk_size x V + # gradients of logits_chunk in place, shape: chunk_size x V + student_logits_chunk = ( + student_prob_chunk + - torch.softmax(student_logits_chunk, dim=-1) + * student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to( + student_prob_chunk.shape + ) + ) / temperature + # now we traverse back to grad w.r.t. input to `lm_head` and grad + # w.r.t. 
`lm_head` which should be computed in original dtype + student_logits_chunk = student_logits_chunk.to(dtype) + grad_input[start_idx:end_idx] = student_logits_chunk @ student_weight + + if grad_weight is not None: + grad_weight.add_(student_logits_chunk.t() @ student_input_chunk) + + loss = torch.sum(loss_1d) + return loss, grad_input, grad_weight + + +def fused_linear_jsd_backward(grad_output, grad_input, grad_weight): + # If JSD is the last layer, grad_output is 1.0. Skip the mul to save time + if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)): + # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place + # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton. + BT, H = grad_input.shape + n_rows = BT + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H)) + + element_mul_kernel[(n_rows,)]( + grad_input, + grad_input.stride(-2), + grad_output, + H, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32 if not is_hip() else 16, + ) + + # handle grad_weight + if grad_weight is not None: + V, H = grad_weight.shape + n_rows = V + + element_mul_kernel[(n_rows,)]( + grad_weight, + grad_weight.stride(-2), + grad_output, + H, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32 if not is_hip() else 16, + ) + + return grad_input, grad_weight + + +class LigerFusedLinearJSDFunction(torch.autograd.Function): + """ + Fusing the last linear layer with generalized JSD + + Handle the forward and backward pass of the final linear layer via JSD by avoiding + the materialization of the large logits tensor. Since JSD is the last layer, we can + compute the gradient at the forward pass. + """ + + @staticmethod + @amp_custom_fwd + def forward( + ctx, + student_input: torch.Tensor, + student_weight: torch.Tensor, + teacher_input: torch.Tensor, + teacher_weight: torch.Tensor, + shift_labels: Optional[torch.Tensor] = None, + jsd_beta: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, + ): + """ + Args: + + student_input (torch.tensor): input of the last projection layer in student model, with shape (B*T, H), where B is batch size, T is sequence length, H is hidden dimension. + student_weight (torch.tensor): the last projection layer in student model, with shape (V, H), where V is vocab size + teacher_input (torch.tensor): input of the last projection layer in teacher model, with shape (B*T, H), where B is batch size, T is sequence length, H is hidden dimension. + teacher_weight (torch.tensor): the last projection layer in teacher model, with shape (V, H), where V is vocab size + shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1]. + jsd_beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5` + ignore_index (int): the index to ignore. Default: -100 + temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0` + + Returns: + loss (torch.Tensor): generalized JSD + """ + has_label = False + if shift_labels is not None: + assert shift_labels.shape == ( + teacher_input.shape[0], + ), f"the shape of shift_labels must be (BT,). 
Got: {shift_labels.shape}" + shift_labels = shift_labels.contiguous() + has_label = True + + loss, grad_input, grad_weight = fused_linear_jsd_forward( + student_input, + student_weight, + teacher_input, + teacher_weight, + shift_labels, + jsd_beta, + ignore_index, + has_label, + temperature, + ) + # downcast to dtype and store for backward + ctx.save_for_backward( + grad_input.detach(), + grad_weight.detach() if grad_weight is not None else None, + ) + return loss + + @staticmethod + @amp_custom_bwd + def backward(ctx, grad_output): + (grad_input, grad_weight) = ctx.saved_tensors + grad_input, grad_weight = fused_linear_jsd_backward( + grad_output, grad_input, grad_weight + ) + return (grad_input, grad_weight, None, None, None, None, None, None) diff --git a/src/liger_kernel/ops/jsd.py b/src/liger_kernel/ops/jsd.py new file mode 100644 index 000000000..6ecf8dbe9 --- /dev/null +++ b/src/liger_kernel/ops/jsd.py @@ -0,0 +1,176 @@ +from typing import Optional + +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.utils import ensure_contiguous + + +@triton.jit +def _jsd_kernel( + X_ptr, # input in logspace, X = log Q + X_stride, + Y_ptr, # ground truth in logspace, Y = log P + Y_stride, + loss_ptr, + loss_stride, + dX_ptr, + dX_stride, + label_ptr, + beta, + n_non_ignore: int, + ignore_index: tl.constexpr, + n_cols, + BLOCK_SIZE: tl.constexpr, + HAS_LABEL: tl.constexpr, +): + # JSD(P || Q) = (KL(P || M) + KL(Q || M)) / 2, M = (1/2) * (P + Q) = (1/2) * (e ^ Y + e ^ X) + # = sum(P * log P + Q * log Q - 2 * M * log M) / 2 + # = sum(e ^ Y * Y + e ^ X * X - 2 * M * log M) / 2 + # grad_x_i = 0.5 * Q * (X - log_M) + pid = tl.program_id(0).to(tl.int64) + X_ptr += pid * X_stride + dX_ptr += pid * dX_stride + Y_ptr += pid * Y_stride + loss_ptr += pid * loss_stride + label_ptr += pid + + if HAS_LABEL: + label = tl.load(label_ptr) + if label == ignore_index: + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + tl.arange(0, BLOCK_SIZE) + tl.store(dX_ptr + offsets, 0.0, mask=offsets < n_cols) + return + + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_cols + X = tl.load(X_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32) + Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32) + + Q = tl.exp(X) + P = tl.exp(Y) + M = beta * P + (1 - beta) * Q + log_M = tl.log(M) + + loss = beta * P * Y + (1 - beta) * Q * X - M * log_M + # reduction == "batchmean" + loss = loss / n_non_ignore + tl.store(loss_ptr + offsets, loss, mask=mask) + + dX = (1 - beta) * Q * (X - log_M) / n_non_ignore + tl.store(dX_ptr + offsets, dX, mask=mask) + + +MAX_FUSED_SIZE = 65536 + + +def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label): + BT, V = _input.shape + n_rows = BT + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + # non reduction loss + loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device) + dX = torch.empty_like(_input) + + if has_label: + n_non_ignore = (shift_labels != ignore_index).sum().item() + else: + n_non_ignore = BT + + _jsd_kernel[(n_rows,)]( + X_ptr=_input, # input in logspace, X = log Q + X_stride=_input.stride(-2), + Y_ptr=target, # ground truth in logspace, Y = log P + Y_stride=target.stride(-2), + loss_ptr=loss, + loss_stride=loss.stride(-2), + dX_ptr=dX, + dX_stride=dX.stride(-2), + label_ptr=( + shift_labels if has_label else torch.empty(1, device=_input.device) + ), # dummy ptr if no label + beta=beta, + n_non_ignore=n_non_ignore, + 
ignore_index=ignore_index, + n_cols=V, + BLOCK_SIZE=BLOCK_SIZE, + HAS_LABEL=has_label, + ) + + loss = torch.sum(loss) + return loss.to(_input.dtype), dX + + +def jsd_backward(dX, grad_output): + # If jsd is the last layer, grad_output is 1.0. Skip the mul to save time + if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): + return dX + else: + return grad_output * dX + + +class LigerJSDFunction(torch.autograd.Function): + r""" + This class implements the forward and backward pass for the generalized Jensen-Shannon Divergence. + .. math:: + JSD(\beta)(P || Q) + = \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q)) + + .. note:: + As all the other losses in PyTorch, this function expects the first argument, + :attr:`_input`, to be the predictions, the output of the student model, in log-space + and the second, :attr:`target`, to be the observations, the output of the teacher model, in log-space. + This differs from the standard mathematical notation :math:`JSD(P || Q)` where + :math:`P` denotes the teacher model and :math:`Q` denotes the student model. + """ + + @staticmethod + @ensure_contiguous + def forward( + ctx, + _input: torch.Tensor, + target: torch.Tensor, + shift_labels: Optional[torch.Tensor] = None, + beta: float = 0.5, + ignore_index: int = -100, + ) -> torch.Tensor: + """ + Args: + _input (torch.Tensor): predict values with shape (BT, V) in logspace + target (torch.Tensor): ground truth values with shape (BT, V) in logspace + shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1]. + beta (float): coefficient beta of generalized JSD in the open interval (0, 1) + ignore_index (int): the index to ignore. Default: -100 + + Returns: + loss (torch.Tensor): generalized JSD + """ + has_label = False + if shift_labels is not None: + assert shift_labels.shape == ( + _input.shape[0], + ), f"the shape of shift_labels must be (BT,). 
Got: {shift_labels.shape}" + shift_labels = shift_labels.contiguous() + has_label = True + + loss, dX = jsd_forward( + _input, target, shift_labels, beta, ignore_index, has_label + ) + ctx.save_for_backward(dX) + return loss + + @staticmethod + @ensure_contiguous + def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: + (dX,) = ctx.saved_tensors + dX = jsd_backward(dX, grad_output) + return ( + dX, + None, + None, + None, + None, + ) diff --git a/src/liger_kernel/ops/kl_div.py b/src/liger_kernel/ops/kl_div.py index 215810f38..2e3c6e933 100644 --- a/src/liger_kernel/ops/kl_div.py +++ b/src/liger_kernel/ops/kl_div.py @@ -4,13 +4,13 @@ import triton import triton.language as tl -from liger_kernel.ops.utils import ensure_contiguous +from liger_kernel.ops.utils import ensure_contiguous, is_hip def get_num_warps(BLOCK_SIZE): num_warps = 4 if BLOCK_SIZE >= 32768: - num_warps = 32 + num_warps = 32 if not is_hip() else 16 elif BLOCK_SIZE >= 8192: num_warps = 16 elif BLOCK_SIZE >= 2048: @@ -45,6 +45,7 @@ def _kldiv_kernel_forward( loss_ptr, # [B] or [B, S] if reduction == _REDUCTION_MODE_NONE, output ptr loss_stride, # int, output stride n_cols, # int, number of columns in the input tensor + eps, BLOCK_SIZE: tl.constexpr, log_target: tl.constexpr = False, reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN, @@ -56,6 +57,7 @@ def _kldiv_kernel_forward( base_offsets = tl.arange(0, BLOCK_SIZE) + loss_sum = 0.0 for i in range(0, n_cols, BLOCK_SIZE): offsets = i + base_offsets mask = offsets < n_cols @@ -65,32 +67,33 @@ def _kldiv_kernel_forward( # KL(y_true || y) = y_true * (log(y_true) - log(y)) # We compute KL(y_true || y) with y in the log-space if not log_target: - loss = y_true * (tl.log(y_true) - y) + loss = y_true * (tl.log(tl.maximum(y_true, eps)) - y) else: loss = tl.exp(y_true) * (y_true - y) if reduction == _REDUCTION_MODE_NONE: tl.store(loss_ptr + offsets, loss, mask=mask) else: - loss = tl.sum(loss, axis=0) - tl.store(loss_ptr, loss) - loss_ptr += 1 # in case of reduction, the output tensor has dimensions [B,], therefore stride is always 1 + loss_sum += tl.sum(loss, axis=0) + + if reduction != _REDUCTION_MODE_NONE: + tl.store(loss_ptr, loss_sum) @triton.jit def _kldiv_kernel_backward( - input_ptr, - input_stride, target_ptr, target_stride, + new_grads_ptr, + new_grads_stride, n_cols, BLOCK_SIZE: tl.constexpr, log_target: tl.constexpr = False, ): pid = tl.program_id(0).to(tl.int64) - input_ptr += pid * input_stride target_ptr += pid * target_stride + new_grads_ptr += pid * new_grads_stride offsets = tl.arange(0, BLOCK_SIZE) mask = offsets < n_cols @@ -106,19 +109,19 @@ def _kldiv_kernel_backward( else: res = -tl.exp(target) - tl.store(input_ptr + offsets, res, mask=mask) + tl.store(new_grads_ptr + offsets, res, mask=mask) -def kldiv_forward_triton(y_pred, y_true, log_target, reduction): # [B, S] # [B, S] - B, S = y_pred.shape +def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V] + BT, V = y_pred.shape - BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(S)) + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) num_warps = get_num_warps(BLOCK_SIZE) - grid = (B,) + grid = (BT,) reduction = _str_to_reduction_mode[reduction] - out_size = (B, S) if reduction == _REDUCTION_MODE_NONE.value else (B,) + out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,) output_tensor = torch.zeros(out_size, device=y_pred.device, dtype=torch.float32) _kldiv_kernel_forward[grid]( @@ -128,7 +131,8 @@ def kldiv_forward_triton(y_pred, y_true, 
log_target, reduction): # [B, S] # [B y_true.stride(0), output_tensor, output_tensor.stride(0), - S, + V, + eps=eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, log_target=log_target, @@ -139,30 +143,30 @@ def kldiv_forward_triton(y_pred, y_true, log_target, reduction): # [B, S] # [B # https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html # https://github.com/pytorch/pytorch/blob/d7b57c4d63edb42e1deeeba9497fcb5f1f748ff2/torch/nn/functional.py#L3372 if reduction == _REDUCTION_MODE_BATCHMEAN.value: - return output_tensor.sum() / B + return output_tensor.sum() / BT elif reduction == _REDUCTION_MODE_SUM.value: return output_tensor.sum(dim=0) elif reduction == _REDUCTION_MODE_MEAN.value: - return output_tensor.mean(dim=0) + return output_tensor.sum() / (BT * V) else: return output_tensor -def kldiv_backward_triton(input, target, grad_output, log_target): - B, S = input.shape +def kldiv_backward_triton(target, grad_output, new_grads, log_target): + BT, V = target.shape - BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(S)) + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) num_warps = get_num_warps(BLOCK_SIZE) - grid = (B,) + grid = (BT,) # We store the gradients in-place in the input tensor _kldiv_kernel_backward[grid]( - input, - input.stride(0), target, target.stride(0), - S, + new_grads, + new_grads.stride(0), + V, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, log_target=log_target, @@ -170,9 +174,9 @@ def kldiv_backward_triton(input, target, grad_output, log_target): # If cross entropy is the last layer, grad_output is 1.0. Skip the mul then. if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): - return input + return new_grads - return input * grad_output + return new_grads * grad_output class LigerKLDivLossFunction(torch.autograd.Function): @@ -196,6 +200,7 @@ def forward( y_true: torch.Tensor, reduction: REDUCTION_LITERAL = "batchmean", log_target: bool = False, + eps: float = 1e-10, ) -> torch.Tensor: """A forward pass for the KL Divergence Loss. @@ -205,15 +210,16 @@ def forward( y_true (torch.Tensor): A tensor of shape (BT, V) containing the target values, expected to be either probabilities or log-probabilities, depending on the value of `log_target`. reduction (REDUCTION_LITERAL, optional): Reduction to be used. Defaults to "batchmean". log_target (bool, optional): If set to true, expects the ground truth to already be log-probabilities. Defaults to False. + eps: (float, optional): A small value to avoid division by zero. Defaults to 1e-10. Returns: torch.Tensor: The computed KL Divergence Loss, with shape (BT, V) if `reduction` is "none", else a scalar. """ - ctx.save_for_backward(y_pred, y_true) + ctx.save_for_backward(y_true) ctx.reduction = reduction ctx.log_target = log_target return kldiv_forward_triton( - y_pred, y_true, log_target=log_target, reduction=reduction + y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps ) @staticmethod @@ -226,22 +232,27 @@ def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: grad_output (torch.Tensor): The gradient of the loss with respect to the output. Returns: - tuple[torch.Tensor, None, None, None]: The gradient of the loss with respect to the inputs and None for the other arguments of the forward method. + tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs and None for the other arguments of the forward method. 
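For intuition, the gradient produced here (before the reduction scaling below) matches the analytic derivative of the KL divergence with respect to the log-space predictions; an illustrative autograd check, not the Triton path:

```python
>>> import torch
>>> import torch.nn.functional as F
>>> y_pred = torch.randn(2, 5).log_softmax(dim=-1).requires_grad_()
>>> y_true = torch.randn(2, 5).softmax(dim=-1)
>>> F.kl_div(y_pred, y_true, reduction="sum").backward()
>>> # d/dy_pred [ y_true * (log y_true - y_pred) ] = -y_true
>>> torch.allclose(y_pred.grad, -y_true)
True
```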
""" - y_pred, y_true = ctx.saved_tensors + (y_true,) = ctx.saved_tensors + + new_grads = torch.empty_like(y_true) - derivative = kldiv_backward_triton(y_pred, y_true, grad_output, ctx.log_target) + derivative = kldiv_backward_triton( + y_true, grad_output, new_grads, ctx.log_target + ) if ctx.reduction == "batchmean": - derivative = derivative / y_pred.shape[0] + derivative = derivative / y_true.shape[0] elif ctx.reduction == "sum" or ctx.reduction == "none": pass elif ctx.reduction == "mean": - derivative = derivative / (y_pred.shape[0] * y_pred.shape[1]) + derivative = derivative / (y_true.shape[0] * y_true.shape[1]) return ( derivative, None, None, None, + None, ) diff --git a/src/liger_kernel/ops/rms_norm.py b/src/liger_kernel/ops/rms_norm.py index 68fcf05d2..06819f124 100644 --- a/src/liger_kernel/ops/rms_norm.py +++ b/src/liger_kernel/ops/rms_norm.py @@ -10,6 +10,7 @@ Modifications made by Yanning Chen, 2024. """ +import math import operator import torch @@ -20,6 +21,7 @@ calculate_settings, compare_version, ensure_contiguous, + torch_to_triton_dtype, ) if compare_version("triton", operator.ge, "3.0.0"): @@ -84,6 +86,10 @@ def _rms_norm_forward_kernel( W_row = W_row.to(tl.float32) X_row = X_row.to(tl.float32) + if casting_mode == _CASTING_MODE_NONE: + eps = eps.to(X_row_dtype) + offset = offset.to(X_row_dtype) + mean_square = tl.sum(X_row * X_row, axis=0) / n_cols rstd = rsqrt(mean_square + eps) @@ -100,6 +106,9 @@ def _rms_norm_forward_kernel( Y_row = X_row * (offset + W_row) + if casting_mode == _CASTING_MODE_GEMMA: + Y_row = Y_row.to(X_row_dtype) + tl.store(Y_ptr + col_offsets, Y_row, mask=mask) @@ -109,14 +118,17 @@ def _rms_norm_backward_kernel( dY_row_stride, X_ptr, X_row_stride, + X_dtype: tl.constexpr, W_ptr, W_row_stride, RSTD_ptr, RSTD_row_stride, dW_ptr, dW_row_stride, + n_rows, n_cols, offset, + rows_per_program: tl.constexpr, casting_mode: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): @@ -125,54 +137,60 @@ def _rms_norm_backward_kernel( dw = sum(dy * (x / RMS)). 
summation over BxT dimension """ - row_idx = tl.program_id(0) + row_block_id = tl.program_id(0) + row_start = row_block_id * rows_per_program + row_end = min((row_block_id + 1) * rows_per_program, n_rows) col_offsets = tl.arange(0, BLOCK_SIZE) mask = col_offsets < n_cols - dY_ptr += row_idx * dY_row_stride - X_ptr += row_idx * X_row_stride - RSTD_ptr += row_idx * RSTD_row_stride - dW_ptr += row_idx * dW_row_stride + dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) - dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0) - X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0) - W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0) - original_x_dtype = X_row.dtype - - # Get cached rms - rstd_row = tl.load(RSTD_ptr) + dY_ptr += row_start * dY_row_stride + X_ptr += row_start * X_row_stride + RSTD_ptr += row_start + W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0) W_row = W_row + offset - X_row = X_row.to(tl.float32) + for _ in range(row_start, row_end): + dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0) + X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0) - # Different bacward graphs for different casting modes - if casting_mode == _CASTING_MODE_LLAMA: - m = (dY_row * W_row).to(tl.float32) + # Get cached rms + rstd_row = tl.load(RSTD_ptr) - elif casting_mode == _CASTING_MODE_GEMMA: - dY_row, W_row = ( - dY_row.to(tl.float32), - W_row.to(tl.float32), - ) + X_row = X_row.to(tl.float32) - m = dY_row * W_row + # Different bacward graphs for different casting modes + if casting_mode == _CASTING_MODE_LLAMA: + m = (dY_row * W_row).to(tl.float32) - dX_row = rstd_row * m + elif casting_mode == _CASTING_MODE_GEMMA: + dY_row = dY_row.to(tl.float32) + m = dY_row * W_row + else: + m = dY_row * W_row - dX_row += (rstd_row) * ( - -(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row - ) + dX_row = rstd_row * m - # calculate the gradient of W - if casting_mode == _CASTING_MODE_LLAMA: - dW_row = dY_row * (X_row * rstd_row).to(original_x_dtype) - else: - # here X_row is already in fp32 (see previous if block) - dW_row = dY_row * (X_row * rstd_row) + dX_row += (rstd_row) * ( + -(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row + ) + + # calculate the gradient of W + if casting_mode == _CASTING_MODE_LLAMA: + dW_row += dY_row * (X_row * rstd_row).to(X_dtype) + else: + # here X_row is already in fp32 (see previous if block) + dW_row += dY_row * (X_row * rstd_row) - tl.store(dY_ptr + col_offsets, dX_row, mask=mask) - tl.store(dW_ptr + col_offsets, dW_row, mask=mask) + tl.store(dY_ptr + col_offsets, dX_row.to(X_dtype), mask=mask) + + dY_ptr += dY_row_stride + X_ptr += X_row_stride + RSTD_ptr += RSTD_row_stride + + tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask) _str_to_casting_mode = { @@ -238,31 +256,38 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp dim = shape[-1] dY = dY.view(-1, dim) n_rows, n_cols = dY.shape - dW = torch.empty_like( - X, - dtype=(torch.float32 if casting_mode == _CASTING_MODE_GEMMA.value else W.dtype), - ) + sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count + # fp32 for numerical stability especially. 
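The `_dW` buffer allocated next keeps one fp32 partial weight-gradient row per streaming multiprocessor, and the final `dW` is recovered by summing those partials. A small eager-mode sketch of this split-accumulate-reduce pattern (all names illustrative):

```python
import math
import torch

n_rows, n_cols, sm_count = 1000, 64, 8
per_row_grad = torch.randn(n_rows, n_cols)        # stand-in for dy * (x * rstd) per row

rows_per_program = math.ceil(n_rows / sm_count)
partial = torch.zeros(sm_count, n_cols, dtype=torch.float32)
for pid in range(sm_count):                       # one iteration per "program"
    start = pid * rows_per_program
    end = min(start + rows_per_program, n_rows)
    partial[pid] = per_row_grad[start:end].sum(dim=0)

dW = partial.sum(dim=0)                           # same result as the one-shot reduction
assert torch.allclose(dW, per_row_grad.sum(dim=0), atol=1e-3)
```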
+ _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device) + + if n_cols > BLOCK_SIZE: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + rows_per_program = math.ceil(n_rows / sm_count) + grid = (sm_count,) # Here we use dY to store the value of dX to save memory - _rms_norm_backward_kernel[(n_rows,)]( + _rms_norm_backward_kernel[grid]( dY, dY.stride(0), X, X.stride(0), + torch_to_triton_dtype[X.dtype], W, W.stride(0), RSTD, RSTD.stride(0), - dW, - dW.stride(0), + _dW, + _dW.stride(0), + n_rows, n_cols, offset, + rows_per_program, casting_mode, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, ) dX = dY.view(*shape) - dW = torch.sum(dW, dim=0).to(W.dtype) + dW = _dW.sum(dim=0).to(W.dtype) return dX, dW diff --git a/src/liger_kernel/ops/utils.py b/src/liger_kernel/ops/utils.py index d89da288f..4a24223d0 100644 --- a/src/liger_kernel/ops/utils.py +++ b/src/liger_kernel/ops/utils.py @@ -12,13 +12,19 @@ import functools import importlib +import operator from typing import Callable import torch import triton +import triton.language as tl from packaging.version import Version +def is_hip() -> bool: + return torch.version.hip is not None + + def ensure_contiguous(fn): @functools.wraps(fn) def wrapper(ctx, *args, **kwargs): @@ -45,7 +51,7 @@ def calculate_settings(n): num_warps = 4 if BLOCK_SIZE >= 32768: - num_warps = 32 + num_warps = 32 if not is_hip() else 16 elif BLOCK_SIZE >= 8192: num_warps = 16 elif BLOCK_SIZE >= 2048: @@ -60,3 +66,58 @@ def compare_version(package: str, operator: Callable, target: str): return False pkg_version = Version(pkg.__version__) return operator(pkg_version, Version(target)) + + +def get_amp_custom_fwd_bwd() -> Callable: + if compare_version("torch", operator.ge, "2.4.0"): + return ( + functools.partial(torch.amp.custom_fwd, device_type="cuda"), + functools.partial(torch.amp.custom_bwd, device_type="cuda"), + ) + return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd + + +amp_custom_fwd, amp_custom_bwd = get_amp_custom_fwd_bwd() + + +torch_to_triton_dtype = { + torch.float32: tl.float32, + torch.float16: tl.float16, + torch.bfloat16: tl.bfloat16, +} + + +@triton.jit +def element_mul_kernel( + X_ptr, + X_stride, + grad_output_ptr, + n_cols, + BLOCK_SIZE: tl.constexpr, +): + """ + This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr. + The multiplication is performed in-place on the tensor pointed by X_ptr. + + Parameters: + X_ptr: Pointer to the input tensor. + X_stride (int): The stride of the input tensor. + grad_output_ptr: Pointer to the gradient output value. + n_cols (int): The number of columns in the input tensor. + BLOCK_SIZE (int): The block size for Triton operations. 
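In eager PyTorch terms, the net effect is equivalent to scaling the whole tensor in place by the scalar gradient (illustrative equivalent only, not the Triton code path):

```python
>>> import torch
>>> X = torch.randn(4, 8)
>>> grad_output = torch.tensor(0.5)
>>> _ = X.mul_(grad_output)  # in-place, matching the kernel's in-place store per row
```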
+ """ + + # Get the program ID and convert it to int64 to avoid overflow + program_id = tl.program_id(0).to(tl.int64) + + # Locate the start index + X_ptr += program_id * X_stride + + # Load the gradient output value + grad_output = tl.load(grad_output_ptr) + + # Perform the element-wise multiplication + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols) + tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols) diff --git a/src/liger_kernel/transformers/__init__.py b/src/liger_kernel/transformers/__init__.py index 5cf559011..ffb8235cc 100644 --- a/src/liger_kernel/transformers/__init__.py +++ b/src/liger_kernel/transformers/__init__.py @@ -5,7 +5,9 @@ from liger_kernel.transformers.fused_linear_cross_entropy import ( # noqa: F401 LigerFusedLinearCrossEntropyLoss, ) +from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD # noqa: F401 from liger_kernel.transformers.geglu import LigerGEGLUMLP # noqa: F401 +from liger_kernel.transformers.jsd import LigerJSD # noqa: F401 from liger_kernel.transformers.layer_norm import LigerLayerNorm # noqa: F401 from liger_kernel.transformers.monkey_patch import ( # noqa: F401 _apply_liger_kernel, @@ -15,6 +17,7 @@ apply_liger_kernel_to_llama, apply_liger_kernel_to_mistral, apply_liger_kernel_to_mixtral, + apply_liger_kernel_to_mllama, apply_liger_kernel_to_phi3, apply_liger_kernel_to_qwen2, apply_liger_kernel_to_qwen2_vl, diff --git a/src/liger_kernel/transformers/cross_entropy.py b/src/liger_kernel/transformers/cross_entropy.py index b2457481b..f612f6f4d 100644 --- a/src/liger_kernel/transformers/cross_entropy.py +++ b/src/liger_kernel/transformers/cross_entropy.py @@ -1,11 +1,24 @@ -from torch.nn import CrossEntropyLoss +import torch.nn as nn from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction -class LigerCrossEntropyLoss(CrossEntropyLoss): - def __init__(self, *args, **kwargs): - super(LigerCrossEntropyLoss, self).__init__(*args, **kwargs) +class LigerCrossEntropyLoss(nn.Module): + def __init__( + self, + ignore_index=-100, + lse_square_scale=0.0, + label_smoothing=0.0, + reduction="mean", + return_z_loss=False, + ): + super().__init__() + self.ignore_index = ignore_index + self.lse_square_scale = lse_square_scale + self.label_smoothing = label_smoothing + self.reduction = reduction + self.return_z_loss = return_z_loss + assert (self.label_smoothing >= 0) and ( self.label_smoothing <= 1 ), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}" @@ -16,6 +29,15 @@ def __init__(self, *args, **kwargs): }, f"reduction must be one of 'mean', 'sum', or 'none'. 
Got: {self.reduction}" def forward(self, _input, target): - return LigerCrossEntropyFunction.apply( - _input, target, self.ignore_index, self.label_smoothing, self.reduction + loss, z_loss = LigerCrossEntropyFunction.apply( + _input, + target, + self.ignore_index, + self.lse_square_scale, + self.label_smoothing, + self.reduction, + self.return_z_loss, ) + if not self.return_z_loss: + return loss + return loss, z_loss diff --git a/src/liger_kernel/transformers/functional.py b/src/liger_kernel/transformers/functional.py index d63045efb..f160887b8 100644 --- a/src/liger_kernel/transformers/functional.py +++ b/src/liger_kernel/transformers/functional.py @@ -2,7 +2,9 @@ from liger_kernel.ops.fused_linear_cross_entropy import ( LigerFusedLinearCrossEntropyFunction, ) +from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction from liger_kernel.ops.geglu import LigerGELUMulFunction +from liger_kernel.ops.jsd import LigerJSDFunction from liger_kernel.ops.kl_div import LigerKLDivLossFunction from liger_kernel.ops.layer_norm import LigerLayerNormFunction from liger_kernel.ops.rms_norm import LigerRMSNormFunction @@ -17,3 +19,5 @@ liger_rope = LigerRopeFunction.apply liger_layer_norm = LigerLayerNormFunction.apply liger_kl_div = LigerKLDivLossFunction.apply +liger_jsd = LigerJSDFunction.apply +liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply diff --git a/src/liger_kernel/transformers/fused_linear_cross_entropy.py b/src/liger_kernel/transformers/fused_linear_cross_entropy.py index 2a3971f2c..74c4b778a 100644 --- a/src/liger_kernel/transformers/fused_linear_cross_entropy.py +++ b/src/liger_kernel/transformers/fused_linear_cross_entropy.py @@ -1,13 +1,26 @@ -from torch.nn import CrossEntropyLoss +import torch.nn as nn from liger_kernel.ops.fused_linear_cross_entropy import ( LigerFusedLinearCrossEntropyFunction, ) -class LigerFusedLinearCrossEntropyLoss(CrossEntropyLoss): - def __init__(self, *args, **kwargs): - super(LigerFusedLinearCrossEntropyLoss, self).__init__(*args, **kwargs) +class LigerFusedLinearCrossEntropyLoss(nn.Module): + def __init__( + self, + ignore_index=-100, + label_smoothing=0.0, + reduction="mean", + lse_square_scale=0.0, + ): + super().__init__() + self.ignore_index = ignore_index + self.label_smoothing = label_smoothing + self.reduction = reduction + self.lse_square_scale = lse_square_scale + assert (self.label_smoothing >= 0) and ( + self.label_smoothing <= 1 + ), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}" def forward(self, lin_weight, _input, target, bias=None, label_smoothing=0.0): return LigerFusedLinearCrossEntropyFunction.apply( @@ -16,6 +29,7 @@ def forward(self, lin_weight, _input, target, bias=None, label_smoothing=0.0): target, bias, self.ignore_index, + self.lse_square_scale, self.label_smoothing, self.reduction, ) diff --git a/src/liger_kernel/transformers/fused_linear_jsd.py b/src/liger_kernel/transformers/fused_linear_jsd.py new file mode 100644 index 000000000..001174cc2 --- /dev/null +++ b/src/liger_kernel/transformers/fused_linear_jsd.py @@ -0,0 +1,98 @@ +from typing import Optional + +import torch + +from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction + + +class LigerFusedLinearJSD(torch.nn.Module): + r"""Fusing the last linear layer with generalized JSD + + Handle the forward and backward pass of the final linear layer via JSD by avoiding + the materialization of the large logits tensor. + + Args: + jsd_beta (float): coefficient beta of generalized JSD in the open interval (0, 1). 
Default: `0.5` + ignore_index (int): The index to ignore in the target. Default: `-100` + temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0` + + Shape: + - student_input: :math:`(BT, H)`, where B is batch size, T is sequence length, H is hidden dimension. + - student_weight: :math:`(V, H)`, where V is vocab size. + - teacher_input: :math:`(BT, H')`, where H' is hidden dimension of the teacher model. + - teacher_weight: :math:`(V, H')`, where hidden size H and H' can be different. + - shift_labels: :math:`(BT,)` + - Output: a scalar. + + Examples: + ```python + >>> (B, T, H_s, H_t, V) = (2, 2, 3, 5, 10) + >>> fused_jsd = LigerFusedLinearJSD(jsd_beta=0.1, temperature=2.0) + >>> # generate inputs and weights + >>> student_input = torch.rand(B * T, H_s, device="cuda", requires_grad=True) + >>> student_lin = torch.nn.Linear(H_s, V, bias=False, device="cuda") + >>> # teacher input doesn't require grad, hidden_dim can be different from student's + >>> teacher_input = torch.rand(B * T, H_t, device="cuda") + >>> teacher_lin = torch.nn.Linear(H_t, V, bias=False, device="cuda") + >>> output = fused_jsd(student_input, student_lin.weight, teacher_input, teacher_lin.weight) + >>> output.backward() + >>> + >>> # Example with labels for supervised fine-tuning (SFT) context: + >>> + >>> # Assume hidden_states, lm_heads and corresponding labels are given + >>> student_lm_head = torch.nn.Linear(H_s, V, bias=False) + >>> student_hidden_states = torch.randn(B * T, H_s, requires_grad=True).log_softmax(dim=-1) + >>> teacher_lm_head = torch.nn.Linear(H_t, V, bias=False) + >>> teacher_hidden_states = torch.randn(B * T, H_t).log_softmax(dim=-1) + >>> labels = torch.randint(0, V, (B * T,), torch.long) + >>> + >>> # Shift so that tokens < n predict n + >>> shift_student_hidden_states = student_hidden_states[..., :-1, :].contiguous() + >>> shift_teacher_hidden_states = teacher_hidden_states[..., :-1, :].contiguous() + >>> shift_labels = labels[..., 1:].contiguous() + >>> + >>> # Flatten tokens + >>> shift_student_hidden_states = shift_student_hidden_states.view(-1, V) + >>> shift_teacher_hidden_states = shift_teacher_hidden_states.view(-1, V) + >>> shift_labels = shift_labels.view(-1) + >>> + >>> # Calculate loss + >>> loss_fct = LigerJSD(beta=0.1) + >>> loss = loss_fct( + >>> shift_studetn_hidden_states, + >>> student_lm_head.weight, + >>> shift_teacher_hidden_states, + >>> teacher_lm_head.weight, + >>> shift_labels + >>> ) + ``` + """ + + def __init__(self, jsd_beta=0.5, ignore_index=-100, temperature=1.0): + super().__init__() + assert ( + jsd_beta > 0 and jsd_beta < 1 + ), f"beta must be greater than 0 and less than 1. Got: {jsd_beta}" + assert temperature != 0, "temperature cannot be 0." 
+ self.jsd_beta = jsd_beta + self.temperature = temperature + self.ignore_index = ignore_index + + def forward( + self, + student_input: torch.Tensor, + student_weight: torch.Tensor, + teacher_input: torch.Tensor, + teacher_weight: torch.Tensor, + shift_labels: Optional[torch.LongTensor], + ): + return LigerFusedLinearJSDFunction.apply( + student_input, + student_weight, + teacher_input, + teacher_weight, + shift_labels, + self.jsd_beta, + self.ignore_index, + self.temperature, + ) diff --git a/src/liger_kernel/transformers/jsd.py b/src/liger_kernel/transformers/jsd.py new file mode 100644 index 000000000..e218ca84b --- /dev/null +++ b/src/liger_kernel/transformers/jsd.py @@ -0,0 +1,75 @@ +from typing import Optional + +import torch + +from liger_kernel.ops.jsd import LigerJSDFunction + + +class LigerJSD(torch.nn.Module): + r"""The generalized Jensen-Shannon Divergence. + .. math:: + JSD(\beta)(P || Q) + = \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q)) + .. note:: + As all the other losses in PyTorch, this function expects the first argument, + :attr:`log_q`, to be the predictions, the output of the student model in log-space, + and the second, :attr:`log_p`, to be the observations, the output of the teacher model in log-space. + This differs from the standard mathematical notation :math:`JSD(P || Q)` where + :math:`P` denotes the teacher model and :math:`Q` denotes the student model. + + Args: + beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5` + ignore_index (int): The index to ignore in the target. Default: `-100` + + Shape: + - Input: :math:`(BT, V)`, where B is batch size, T is sequence length, V is vocab size. + - Target: :math:`(BT, V)`, same shape as the input. + - shift_labels (Optional): :math:`(BT,)` + - Output: a scalar. + + Examples: + ```python + >>> (B, T, V) = (2, 2, 5) + >>> jsd = LigerJSD(beta=0.1) + >>> # input should be a distribution in the log space + >>> input = torch.randn(B * T, V, requires_grad=True).log_softmax(dim=-1) + >>> target = torch.randn(B * T, V).log_softmax(dim=-1) + >>> output = jsd(input, target) + >>> + >>> # Example with labels for supervised fine-tuning (SFT) context + >>> # Assume logits and corresponding labels are given + >>> student_logits = torch.randn(B * T, V, requires_grad=True).log_softmax(dim=-1) + >>> teacher_logits = torch.randn(B * T, V).log_softmax(dim=-1) + >>> labels = torch.randint(0, V, (B * T,), torch.long) + >>> # Shift so that tokens < n predict n + >>> shift_student_logits = student_logits[..., :-1, :].contiguous() + >>> shift_teacher_logits = teacher_logits[..., :-1, :].contiguous() + >>> shift_labels = labels[..., 1:].contiguous() + >>> # Flatten tokens + >>> shift_student_logits = shift_student_logits.view(-1, V) + >>> shift_teacher_logits = shift_teacher_logits.view(-1, V) + >>> shift_labels = shift_labels.view(-1) + >>> # Calculate loss + >>> loss_fct = LigerJSD(beta=0.1) + >>> loss = loss_fct(shift_studetn_logits, shift_teacher_logits, shift_labels) + + ``` + """ + + def __init__(self, beta: float = 0.5, ignore_index: int = -100): + super().__init__() + assert ( + beta > 0 and beta < 1 + ), f"beta must be greater than 0 and less than 1. 
Got: {beta}" + self.beta = beta + self.ignore_index = ignore_index + + def forward( + self, + log_q: torch.Tensor, + log_p: torch.Tensor, + shift_labels: Optional[torch.LongTensor] = None, + ): + return LigerJSDFunction.apply( + log_q, log_p, shift_labels, self.beta, self.ignore_index + ) diff --git a/src/liger_kernel/transformers/kl_div.py b/src/liger_kernel/transformers/kl_div.py index 3c8785a7e..8bd50dad0 100644 --- a/src/liger_kernel/transformers/kl_div.py +++ b/src/liger_kernel/transformers/kl_div.py @@ -4,10 +4,11 @@ class LigerKLDIVLoss(nn.KLDivLoss): - def __init__(self, *args, **kwargs): + def __init__(self, eps: float = 1e-10, *args, **kwargs): super(LigerKLDIVLoss, self).__init__(*args, **kwargs) + self.eps = eps def forward(self, y_pred, y_true): return LigerKLDivLossFunction.apply( - y_pred, y_true, self.reduction, self.log_target + y_pred, y_true, self.reduction, self.log_target, self.eps ) diff --git a/src/liger_kernel/transformers/model/gemma.py b/src/liger_kernel/transformers/model/gemma.py index b6cdf1238..f7b9814e9 100644 --- a/src/liger_kernel/transformers/model/gemma.py +++ b/src/liger_kernel/transformers/model/gemma.py @@ -22,7 +22,7 @@ @replace_return_docstrings( output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC ) -def lce_forward( +def lce_forward_deprecated( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, @@ -136,3 +136,126 @@ def lce_forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING) +@replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC +) +def lce_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **loss_kwargs, +) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, GemmaForCausalLM + + >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b") + + >>> prompt = "What is your favorite condiment?" 
+ >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is your favorite condiment?" + ```""" + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + logits = None + loss = None + # if in training mode, don't materialize logits + if self.training and (labels is not None): + # We do the same thing as ForCausalLMLoss but using Liger FLCE + + shift_hidden_states = hidden_states[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + # flatten tokens + shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size) + shift_labels = shift_labels.view(-1) + + reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean" + lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction) + + loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels) + if reduction == "sum": + loss /= loss_kwargs["num_items_in_batch"] + + else: # if in inference mode materialize logits + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **loss_kwargs, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/liger_kernel/transformers/model/llama.py b/src/liger_kernel/transformers/model/llama.py index 9cf6ed446..b8d12c76a 100644 --- a/src/liger_kernel/transformers/model/llama.py +++ b/src/liger_kernel/transformers/model/llama.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple, Union +from typing import TYPE_CHECKING, List, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -17,17 +17,20 @@ LigerFusedLinearCrossEntropyLoss, ) +if TYPE_CHECKING: + from transformers.cache_utils import Cache + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) @replace_return_docstrings( output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC ) -def lce_forward( +def lce_forward_deprecated( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -120,8 +123,9 @@ 
def lce_forward( logits = torch.cat(logits, dim=-1) else: logits = self.lm_head(hidden_states) - logits = logits.float() if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -144,3 +148,130 @@ def lce_forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) +@replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC +) +def lce_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **loss_kwargs, +) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + if self.config.pretraining_tp > 1: + raise Exception("Liger Kernel does not support pretraining_tp!!") + + logits = None + loss = None + # if in training mode, don't materialize logits + if self.training and (labels is not None): + # We do the same thing as ForCausalLMLoss but using Liger FLCE + + shift_hidden_states = hidden_states[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + # flatten tokens + shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size) + shift_labels = shift_labels.view(-1) + + reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean" + lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction) + + loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels) + if reduction == "sum": + loss /= loss_kwargs["num_items_in_batch"] + + else: # if in inference mode materialize logits + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **loss_kwargs, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/liger_kernel/transformers/model/mistral.py b/src/liger_kernel/transformers/model/mistral.py index cd0f6f9d9..cc2ab9b76 100644 --- a/src/liger_kernel/transformers/model/mistral.py +++ b/src/liger_kernel/transformers/model/mistral.py @@ -136,3 +136,6 @@ def lce_forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +# Note: Grad Acc is not fixed in mistral at transformer 4.46.1 diff --git a/src/liger_kernel/transformers/model/mixtral.py b/src/liger_kernel/transformers/model/mixtral.py index f449284cf..22fea53da 100644 --- a/src/liger_kernel/transformers/model/mixtral.py +++ b/src/liger_kernel/transformers/model/mixtral.py @@ -22,7 +22,7 @@ @replace_return_docstrings( output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC ) -def lce_forward( +def lce_forward_deprecated( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, @@ -103,7 +103,6 @@ def lce_forward( hidden_states = outputs[0] logits = self.lm_head(hidden_states) - logits = logits.float() loss = None if self.training and (labels is not None): @@ -116,6 +115,8 @@ def lce_forward( lce = LigerFusedLinearCrossEntropyLoss() loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels) elif labels is not None: + # Upcast to float if we need to compute the loss to avoid 
potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -156,3 +157,153 @@ def lce_forward( attentions=outputs.attentions, router_logits=outputs.router_logits, ) + + +@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) +@replace_return_docstrings( + output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC +) +# Ignore copy +def lce_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **loss_kwargs, +) -> Union[Tuple, MoeCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MixtralForCausalLM + + >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_router_logits = ( + output_router_logits + if output_router_logits is not None + else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + logits = None + loss = None + # if in training mode, don't materialize logits + if self.training and (labels is not None): + # We do the same thing as ForCausalLMLoss but using Liger FLCE + + shift_hidden_states = hidden_states[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + # flatten tokens + shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size) + shift_labels = shift_labels.view(-1) + + reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean" + lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction) + + loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels) + if reduction == "sum": + loss /= loss_kwargs["num_items_in_batch"] + + else: # if in inference mode materialize logits + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **loss_kwargs, + ) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to( + loss.device + ) # make sure to reside in the same device + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) diff --git a/src/liger_kernel/transformers/model/mllama.py b/src/liger_kernel/transformers/model/mllama.py new file mode 100644 index 000000000..fcf45293e --- /dev/null +++ b/src/liger_kernel/transformers/model/mllama.py @@ -0,0 +1,274 @@ +from typing import List, Optional, Tuple, Union + +import torch +from torch.nn import CrossEntropyLoss +from transformers.cache_utils import Cache +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.mllama.modeling_mllama import MLLAMA_INPUTS_DOCSTRING +from transformers.utils import ( + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) + +from liger_kernel.transformers.fused_linear_cross_entropy import ( + LigerFusedLinearCrossEntropyLoss, +) + + 
+@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING) +@replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig" +) +def lce_forward_deprecated( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cross_attention_states: Optional[torch.LongTensor] = None, + cross_attention_mask: Optional[torch.LongTensor] = None, + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, +) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Copy paste mllama forward but replace torch cross entropy with liger fused linear cross entropy + + + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + Returns: + Example: + ```python + >>> from transformers import AutoTokenizer, MllamaForCausalLM + >>> model = MllamaForCausalLM.from_pretrained("Llama-3.2-11B-Vision") + >>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision") + >>> prompt = "If I had to write a haiku, it would be:" + >>> inputs = tokenizer(prompt, return_tensors="pt") + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6) + >>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + >>> print(result) + If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful. 
+ I love the idea of snowflakes gently falling, each one + ``` + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + cross_attention_states=cross_attention_states, + attention_mask=attention_mask, + position_ids=position_ids, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + loss = None + logits = None + + if self.training and (labels is not None): + kept_hidden_states = hidden_states[:, -num_logits_to_keep:, :] + + shift_hidden_states = kept_hidden_states[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + # flatten tokens + shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size) + shift_labels = shift_labels.view(-1) + + lce = LigerFusedLinearCrossEntropyLoss() + loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels) + + else: + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING) +@replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig" +) +def lce_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cross_attention_states: Optional[torch.LongTensor] = None, + cross_attention_mask: Optional[torch.LongTensor] = None, + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **loss_kwargs, +) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling 
loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MllamaForCausalLM + + >>> model = MllamaForCausalLM.from_pretrained("Llama-3.2-11B-Vision") + >>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision") + + >>> prompt = "If I had to write a haiku, it would be:" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6) + >>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + >>> print(result) + If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful. + I love the idea of snowflakes gently falling, each one + ``` + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + cross_attention_states=cross_attention_states, + attention_mask=attention_mask, + position_ids=position_ids, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + logits = None + loss = None + # if in training mode, don't materialize logits + if self.training and (labels is not None): + # We do the same thing as ForCausalLMLoss but using Liger FLCE + + shift_hidden_states = hidden_states[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + # flatten tokens + shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size) + shift_labels = shift_labels.view(-1) + + reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean" + lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction) + + loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels) + if reduction == "sum": + loss /= loss_kwargs["num_items_in_batch"] + + else: # if in inference mode materialize logits + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **loss_kwargs, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + 
past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/liger_kernel/transformers/model/phi3.py b/src/liger_kernel/transformers/model/phi3.py index 4cb7ec0ea..e860582ce 100644 --- a/src/liger_kernel/transformers/model/phi3.py +++ b/src/liger_kernel/transformers/model/phi3.py @@ -21,7 +21,7 @@ @replace_return_docstrings( output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC ) -def lce_forward( +def lce_forward_deprecated( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, @@ -108,10 +108,11 @@ def lce_forward( loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels) else: logits = self.lm_head(hidden_states) - logits = logits.float() loss = None if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -134,3 +135,140 @@ def lce_forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) +@replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC +) +def lce_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **loss_kwargs, +) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Phi3ForCausalLM + + >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct") + + >>> prompt = "This is an example script ." + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + 'This is an example script .\n Certainly! 
Below is a sample script that demonstrates a simple task, such as calculating the sum' + ```""" + + from transformers.models.phi3.modeling_phi3 import logging + + logger = logging.get_logger(__name__) + + if ( + use_cache + and self.config.rope_scaling + and cache_position is not None + and cache_position[0] == self.config.original_max_position_embeddings + ): + logger.warning( + f"If you are not using the generate method, you may encounter nonsensical outputs after the {self.config.original_max_position_embeddings}th token, as the KV cache needs to be recomputed." + ) + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + logits = None + loss = None + # if in training mode, don't materialize logits + if self.training and (labels is not None): + # We do the same thing as ForCausalLMLoss but using Liger FLCE + + shift_hidden_states = hidden_states[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + # flatten tokens + shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size) + shift_labels = shift_labels.view(-1) + + reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean" + lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction) + + loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels) + if reduction == "sum": + loss /= loss_kwargs["num_items_in_batch"] + + else: # if in inference mode materialize logits + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **loss_kwargs, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/liger_kernel/transformers/model/qwen2.py b/src/liger_kernel/transformers/model/qwen2.py index b8e9957e9..b019e4c88 100644 --- a/src/liger_kernel/transformers/model/qwen2.py +++ b/src/liger_kernel/transformers/model/qwen2.py @@ -21,7 +21,7 @@ @replace_return_docstrings( output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC ) -def lce_forward( +def lce_forward_deprecated( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, @@ -109,8 +109,9 @@ def lce_forward( else: logits = self.lm_head(hidden_states) - logits = logits.float() if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -133,3 +134,123 @@ def lce_forward( 
hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) +@replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC +) +def lce_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **loss_kwargs, +) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen2ForCausalLM + + >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + logits = None + loss = None + # if in training mode, don't materialize logits + if self.training and (labels is not None): + # We do the same thing as ForCausalLMLoss but using Liger FLCE + + shift_hidden_states = hidden_states[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + # flatten tokens + shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size) + shift_labels = shift_labels.view(-1) + + reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean" + lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction) + + loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels) + if reduction == "sum": + loss /= loss_kwargs["num_items_in_batch"] + + else: # if in inference mode materialize logits + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **loss_kwargs, + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/liger_kernel/transformers/model/qwen2_vl.py b/src/liger_kernel/transformers/model/qwen2_vl.py index cfb7a905b..68087c3e5 100644 --- a/src/liger_kernel/transformers/model/qwen2_vl.py +++ b/src/liger_kernel/transformers/model/qwen2_vl.py @@ -80,6 +80,7 @@ def lce_forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." 
```""" + # FIXME: The code is outdated and not compatible with transformer >= 4.46.1 output_attentions = ( output_attentions @@ -115,6 +116,11 @@ def lce_forward( inputs_embeds[video_mask] = video_embeds if attention_mask is not None: attention_mask = attention_mask.to(inputs_embeds.device) + # The code is copied from https://github.com/huggingface/transformers/pull/33487 + if position_ids is None and input_ids is not None: + position_ids, _ = self.get_rope_index( + input_ids, image_grid_thw, video_grid_thw, attention_mask + ) outputs = self.model( input_ids=None, @@ -145,8 +151,9 @@ def lce_forward( loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels) else: logits = self.lm_head(hidden_states) - logits = logits.float() if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index 52abc1170..bb489be19 100644 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -1,6 +1,11 @@ import inspect import logging from functools import partial +from typing import Callable + +import transformers +from packaging import version +from transformers import PreTrainedModel from torch import nn from transformers import PretrainedConfig, PreTrainedModel @@ -9,11 +14,26 @@ from liger_kernel.transformers.geglu import LigerGEGLUMLP from liger_kernel.transformers.layer_norm import LigerLayerNorm from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward +from liger_kernel.transformers.model.gemma import ( + lce_forward_deprecated as gemma_lce_forward_deprecated, +) from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward +from liger_kernel.transformers.model.llama import ( + lce_forward_deprecated as llama_lce_forward_deprecated, +) from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward +from liger_kernel.transformers.model.mixtral import ( + lce_forward_deprecated as mixtral_lce_forward_deprecated, +) from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward +from liger_kernel.transformers.model.phi3 import ( + lce_forward_deprecated as phi3_lce_forward_deprecated, +) from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward +from liger_kernel.transformers.model.qwen2 import ( + lce_forward_deprecated as qwen2_lce_forward_deprecated, +) from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.transformers.rope import liger_rotary_pos_emb from liger_kernel.transformers.swiglu import ( @@ -22,7 +42,35 @@ LigerSwiGLUMLP, ) +transformer_version = version.parse(transformers.__version__) + logger = logging.getLogger(__name__) +SUPPORTED_TRANSFORMER_VERSION = "4.46.1" +TRANSFORMER_DEPRECATION_WARNING = "Support for transformers versions < 4.46.1 will soon be discontinued due to issues with incorrect gradient accumulation. \n Please consider upgrading to avoid potential issues. 
See details: https://github.com/huggingface/transformers/pull/34191" + + +def _bind_method_to_module(module, method_name: str, new_method: Callable): + # Binds a new method to a module instance so that self is passed as the first argument + module.__dict__[method_name] = new_method.__get__(module, module.__class__) + + +def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama"): + module.offset = offset + module.casting_mode = casting_mode + module.variance_epsilon = ( + getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps + ) + _bind_method_to_module(module, "forward", LigerRMSNorm.forward) + _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr) + + +def _patch_layer_norm_module(module, eps=1e-6): + module.variance_epsilon = ( + getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps + ) + module.hidden_size = module.normalized_shape + _bind_method_to_module(module, "forward", LigerLayerNorm.forward) + _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr) def apply_liger_kernel_to_llama( @@ -54,6 +102,7 @@ def apply_liger_kernel_to_llama( ), "cross_entropy and fused_linear_cross_entropy cannot both be True." from transformers.models.llama import modeling_llama + from transformers.models.llama.modeling_llama import LlamaModel if rope: modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb @@ -64,7 +113,134 @@ def apply_liger_kernel_to_llama( if cross_entropy: modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss if fused_linear_cross_entropy: - modeling_llama.LlamaForCausalLM.forward = llama_lce_forward + if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION): + modeling_llama.LlamaForCausalLM.forward = llama_lce_forward + else: # if version < 4.46.1 + logger.warning(TRANSFORMER_DEPRECATION_WARNING) + modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated + + if model is not None: + # The model instance already exists, so we need to additionally patch the + # instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP) + + # get the base model from the model instance + base_model: LlamaModel = getattr(model, model.base_model_prefix, model) + + if rms_norm: + _patch_rms_norm_module(base_model.norm) + + for decoder_layer in base_model.layers: + if swiglu: + _bind_method_to_module( + decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward + ) + if rms_norm: + _patch_rms_norm_module(decoder_layer.input_layernorm) + _patch_rms_norm_module(decoder_layer.post_attention_layernorm) + + +def apply_liger_kernel_to_mllama( + rope: bool = True, + cross_entropy: bool = False, + fused_linear_cross_entropy: bool = True, + layer_norm: bool = True, + rms_norm: bool = True, + swiglu: bool = True, + model: PreTrainedModel = None, +) -> None: + """ + Apply Liger kernels to replace original implementation in HuggingFace MLlama models. + NOTE: MLlama is not available in transformers<4.45.0 + + Args: + rope (bool): Whether to apply Liger's rotary position embedding. Default is True. + cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False. + fused_linear_cross_entropy (bool): + Whether to apply Liger's fused linear cross entropy loss. Default is True. + `cross_entropy` and `fused_linear_cross_entropy` cannot both be True. + If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient. + rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. 
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True. + model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been + loaded. Default is None. + """ + + assert not ( + cross_entropy and fused_linear_cross_entropy + ), "cross_entropy and fused_linear_cross_entropy cannot both be True." + + from transformers.models.mllama import modeling_mllama + from transformers.models.mllama.modeling_mllama import ( + MllamaForCausalLM, + MllamaForConditionalGeneration, + MllamaTextModel, + MllamaVisionModel, + ) + + from liger_kernel.transformers.model.mllama import lce_forward as mllama_lce_forward + from liger_kernel.transformers.model.mllama import ( + lce_forward_deprecated as mllama_lce_forward_deprecated, + ) + + if rope: + modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb + if layer_norm: + modeling_mllama.nn.LayerNorm = LigerLayerNorm + if rms_norm: + modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm + if swiglu: + modeling_mllama.MllamaTextMLP = LigerSwiGLUMLP + if cross_entropy: + modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss + if fused_linear_cross_entropy: + if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION): + modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward + else: # if version < 4.46.1 + logger.warning(TRANSFORMER_DEPRECATION_WARNING) + modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated + + if model is not None: + # The model instance already exists, so we need to additionally patch the + # instance variables that reference already-instantiated modules + + if isinstance(model, MllamaForConditionalGeneration): + language_model: MllamaForCausalLM = model.language_model + vision_model: MllamaVisionModel = model.vision_model + text_model: MllamaTextModel = language_model.model + elif isinstance(model, MllamaForCausalLM): + text_model = model.model + vision_model = None + elif isinstance(model, MllamaTextModel): + text_model = model + vision_model = None + else: + raise ValueError(f"Unsupported Mllama model type: {type(model)}") + + if text_model: + if rms_norm: + _patch_rms_norm_module(text_model.norm) + for decoder_layer in text_model.layers: + if swiglu: + _bind_method_to_module( + decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward + ) + if rms_norm: + _patch_rms_norm_module(decoder_layer.input_layernorm) + _patch_rms_norm_module(decoder_layer.post_attention_layernorm) + + if vision_model: + _patch_layer_norm_module(vision_model.layernorm_pre) + _patch_layer_norm_module(vision_model.layernorm_post) + + for layer in vision_model.transformer.layers: + if layer_norm: + _patch_layer_norm_module(layer.input_layernorm) + _patch_layer_norm_module(layer.post_attention_layernorm) + + for layer in vision_model.global_transformer.layers: + if layer_norm: + _patch_layer_norm_module(layer.input_layernorm) + _patch_layer_norm_module(layer.post_attention_layernorm) if model is not None: # The model instance already exists, so we need to additionally patch the @@ -128,6 +304,7 @@ def apply_liger_kernel_to_mistral( ), "cross_entropy and fused_linear_cross_entropy cannot both be True." 
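Since the mllama patching path above dispatches on the concrete model class (`MllamaForConditionalGeneration`, `MllamaForCausalLM`, or `MllamaTextModel`), applying it to an already-loaded model is a one-liner. A hypothetical usage sketch; the checkpoint id is only a placeholder:

```python
# Hypothetical usage of the new Mllama entry point defined above.
from transformers import MllamaForConditionalGeneration

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama

model = MllamaForConditionalGeneration.from_pretrained(
    "meta-llama/Llama-3.2-11B-Vision"  # placeholder checkpoint id
)

# Passing `model` patches already-instantiated modules in place (RMSNorm/SwiGLU in
# the text decoder layers, LayerNorm in the vision transformer blocks), on top of
# the class-level monkey patches.
apply_liger_kernel_to_mllama(model=model)
```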
from transformers.models.mistral import modeling_mistral + from transformers.models.mistral.modeling_mistral import MistralModel if rope: modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb @@ -143,31 +320,21 @@ def apply_liger_kernel_to_mistral( if model is not None: # The model instance already exists, so we need to additionally patch the # instance variables that reference already-instantiated modules - config: PretrainedConfig = model.config - if hasattr(model, "model"): - # The case for MistralForCausalLM, MistralForTokenClassification for example - base_model = model.model - else: - # Direct MistralModel - base_model = model + # get the base model from the model instance + base_model: MistralModel = getattr(model, model.base_model_prefix, model) - torch_dtype = config.torch_dtype if rms_norm: - base_model.norm = LigerRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) + _patch_rms_norm_module(base_model.norm) for decoder_layer in base_model.layers: if swiglu: - decoder_layer.mlp = LigerSwiGLUMLP(config).to(torch_dtype) + _bind_method_to_module( + decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward + ) if rms_norm: - decoder_layer.input_layernorm = LigerRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) - decoder_layer.post_attention_layernorm = LigerRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) + _patch_rms_norm_module(decoder_layer.input_layernorm) + _patch_rms_norm_module(decoder_layer.post_attention_layernorm) def apply_liger_kernel_to_mixtral( @@ -199,6 +366,7 @@ def apply_liger_kernel_to_mixtral( ), "cross_entropy and fused_linear_cross_entropy cannot both be True." from transformers.models.mixtral import modeling_mixtral + from transformers.models.mixtral.modeling_mixtral import MixtralModel if rope: modeling_mixtral.apply_rotary_pos_emb = liger_rotary_pos_emb @@ -207,45 +375,33 @@ def apply_liger_kernel_to_mixtral( if cross_entropy: modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss if fused_linear_cross_entropy: - modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward + if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION): + modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward + else: # if version < 4.46.1 + logger.warning(TRANSFORMER_DEPRECATION_WARNING) + modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated if swiglu: modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP if model is not None: # The model instance already exists, so we need to additionally patch the # instance variables that reference already-instantiated modules - config: PretrainedConfig = model.config - if hasattr(model, "model"): - # The case for MixtralForCausalLM, MixtralForTokenClassification for example - base_model = model.model - else: - # Direct MixtralModel - base_model = model + # get the base model from the model instance + base_model: MixtralModel = getattr(model, model.base_model_prefix, model) - torch_dtype = config.torch_dtype if rms_norm: - base_model.norm = LigerRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) + _patch_rms_norm_module(base_model.norm) for decoder_layer in base_model.layers: if swiglu: - block_sparse_moe = decoder_layer.block_sparse_moe - patched_experts = nn.ModuleList( - [ - LigerBlockSparseTop2MLP(config) - for _ in range(block_sparse_moe.num_experts) - ] - ) - decoder_layer.block_sparse_moe.experts = patched_experts.to(torch_dtype) + for expert in 
decoder_layer.block_sparse_moe.experts: + _bind_method_to_module( + expert, "forward", LigerBlockSparseTop2MLP.forward + ) if rms_norm: - decoder_layer.input_layernorm = LigerRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) - decoder_layer.post_attention_layernorm = LigerRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) + _patch_rms_norm_module(decoder_layer.input_layernorm) + _patch_rms_norm_module(decoder_layer.post_attention_layernorm) def apply_liger_kernel_to_gemma( @@ -277,6 +433,15 @@ def apply_liger_kernel_to_gemma( ), "cross_entropy and fused_linear_cross_entropy cannot both be True." from transformers.models.gemma import modeling_gemma + from transformers.models.gemma.modeling_gemma import GemmaModel + + # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109 + LigerRMSNormForGemma = partial( + LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma" + ) + _patch_rms_norm_module_for_gemma = partial( + _patch_rms_norm_module, casting_mode="gemma", offset=1.0 + ) # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109 LigerRMSNormForGemma = partial( @@ -292,7 +457,30 @@ def apply_liger_kernel_to_gemma( if geglu: modeling_gemma.GemmaMLP = LigerGEGLUMLP if fused_linear_cross_entropy: - modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward + if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION): + modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward + else: # if version < 4.46.1 + logger.warning(TRANSFORMER_DEPRECATION_WARNING) + modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated + + if model is not None: + # The model instance already exists, so we need to additionally patch the + # instance variables that reference already-instantiated modules + + # get the base model from the model instance + base_model: GemmaModel = getattr(model, model.base_model_prefix, model) + + if rms_norm: + _patch_rms_norm_module_for_gemma(base_model.norm) + + for decoder_layer in base_model.layers: + if geglu: + _bind_method_to_module( + decoder_layer.mlp, "forward", LigerGEGLUMLP.forward + ) + if rms_norm: + _patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm) + _patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm) if model is not None: # The model instance already exists, so we need to additionally patch the @@ -344,6 +532,14 @@ def apply_liger_kernel_to_gemma2( loaded. Default is None. 
""" from transformers.models.gemma2 import modeling_gemma2 + from transformers.models.gemma2.modeling_gemma2 import Gemma2Model + + LigerRMSNormForGemma2 = partial( + LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros" + ) + _patch_rms_norm_module_for_gemma2 = partial( + _patch_rms_norm_module, offset=1.0, casting_mode="gemma" + ) LigerRMSNormForGemma2 = partial(LigerRMSNorm, offset=1.0, init_fn="zeros") if rope: @@ -359,37 +555,29 @@ def apply_liger_kernel_to_gemma2( if model is not None: # The model instance already exists, so we need to additionally patch the # instance variables that reference already-instantiated modules - config: PretrainedConfig = model.config - if hasattr(model, "model"): - # The case for Gemma2ForCausalLM, Gemma2ForTokenClassification for example - base_model = model.model - else: - # Direct Gemma2Model - base_model = model + # get the base model from the model instance + base_model: Gemma2Model = getattr(model, model.base_model_prefix, model) - torch_dtype = config.torch_dtype if rms_norm: - base_model.norm = LigerRMSNormForGemma2( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) + _patch_rms_norm_module_for_gemma2(base_model.norm) for decoder_layer in base_model.layers: if geglu: - decoder_layer.mlp = LigerGEGLUMLP(config).to(torch_dtype) + _bind_method_to_module( + decoder_layer.mlp, "forward", LigerGEGLUMLP.forward + ) if rms_norm: - decoder_layer.input_layernorm = LigerRMSNormForGemma2( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) - decoder_layer.post_attention_layernorm = LigerRMSNormForGemma2( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) - decoder_layer.pre_feedforward_layernorm = LigerRMSNormForGemma2( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) - decoder_layer.post_feedforward_layernorm = LigerRMSNormForGemma2( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) + _patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm) + _patch_rms_norm_module_for_gemma2( + decoder_layer.post_attention_layernorm + ) + _patch_rms_norm_module_for_gemma2( + decoder_layer.pre_feedforward_layernorm + ) + _patch_rms_norm_module_for_gemma2( + decoder_layer.post_feedforward_layernorm + ) def apply_liger_kernel_to_qwen2( @@ -420,6 +608,7 @@ def apply_liger_kernel_to_qwen2( ), "cross_entropy and fused_linear_cross_entropy cannot both be True." 
from transformers.models.qwen2 import modeling_qwen2 + from transformers.models.qwen2.modeling_qwen2 import Qwen2Model if rope: modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb @@ -427,39 +616,38 @@ def apply_liger_kernel_to_qwen2( modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm if cross_entropy: modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss + + # import pdb; pdb.set_trace() if fused_linear_cross_entropy: - modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward + + if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION): + modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward + else: # if version < 4.46.1 + logger.warning(TRANSFORMER_DEPRECATION_WARNING) + modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated + if swiglu: modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP if model is not None: # The model instance already exists, so we need to additionally patch the # instance variables that reference already-instantiated modules - config: PretrainedConfig = model.config - if hasattr(model, "model"): - # The case for Qwen2ForCausalLM, Qwen2ForTokenClassification for example - base_model = model.model - else: - # Direct Qwen2Model - base_model = model + # get the base model from the model instance + base_model: Qwen2Model = getattr(model, model.base_model_prefix, model) - torch_dtype = config.torch_dtype if rms_norm: - base_model.norm = LigerRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) + _patch_rms_norm_module(base_model.norm) for decoder_layer in base_model.layers: if swiglu: - decoder_layer.mlp = LigerSwiGLUMLP(config).to(torch_dtype) + _bind_method_to_module( + decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward + ) if rms_norm: - decoder_layer.input_layernorm = LigerRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) - decoder_layer.post_attention_layernorm = LigerRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) + _patch_rms_norm_module(decoder_layer.input_layernorm) + _patch_rms_norm_module(decoder_layer.post_attention_layernorm) + print("Applied Liger kernels to Qwen2") def apply_liger_kernel_to_qwen2_vl( @@ -472,7 +660,7 @@ def apply_liger_kernel_to_qwen2_vl( ) -> None: """ Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models. - NOTE: Qwen2-VL is not available in transformers<=4.44.2 + NOTE: Qwen2-VL is not available in transformers<4.45.0 Args: cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False. @@ -491,6 +679,7 @@ def apply_liger_kernel_to_qwen2_vl( ), "cross_entropy and fused_linear_cross_entropy cannot both be True." 
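The `loss_kwargs` handling shared by the new `lce_forward` implementations (llama, gemma, mixtral, mllama, phi3, qwen2) is what makes gradient accumulation exact on transformers >= 4.46.1: when the Trainer passes `num_items_in_batch`, the per-token losses are summed and normalized once by the true token count. Isolated from the model forwards, the pattern looks roughly like this (the helper name is illustrative, not part of the patch):

```python
# Illustrative helper isolating the gradient-accumulation-aware reduction logic
# that the new lce_forward implementations share.
from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)


def fused_causal_lm_loss(lm_head_weight, shift_hidden_states, shift_labels, **loss_kwargs):
    # transformers >= 4.46.1 passes num_items_in_batch, so the loss is summed here
    # and divided once by the total token count across accumulation steps.
    reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
    lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
    loss = lce(lm_head_weight, shift_hidden_states, shift_labels)
    if reduction == "sum":
        loss = loss / loss_kwargs["num_items_in_batch"]
    return loss
```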
from transformers.models.qwen2_vl import modeling_qwen2_vl + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel from liger_kernel.transformers.model.qwen2_vl import ( lce_forward as qwen2_vl_lce_forward, @@ -498,10 +687,9 @@ def apply_liger_kernel_to_qwen2_vl( # TODO: Support Qwen2-VL's multimodal RoPE implementation - LigerRMSNormForQwen2VL = partial(LigerRMSNorm, init_fn="ones", casting_mode="gemma") if rms_norm: # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439 - modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNormForQwen2VL + modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm if layer_norm: modeling_qwen2_vl.LayerNorm = LigerLayerNorm if cross_entropy: @@ -514,90 +702,27 @@ def apply_liger_kernel_to_qwen2_vl( if model is not None: # The model instance already exists, so we need to additionally patch the # instance variables that reference already-instantiated modules - config: PretrainedConfig = model.config - torch_dtype = config.torch_dtype - - if hasattr(model, "model"): - # The case for Qwen2VLForConditionalGeneration. - base_model = model.model - else: - # Direct Qwen2VLModel - base_model = model + # get the base model from the model instance + base_model: Qwen2VLModel = getattr(model, model.base_model_prefix, model) if hasattr(model, "visual"): # Patch Qwen2VisionTransformerPretrainedModel for vision_block in model.visual.blocks: if layer_norm: - vision_block.norm1 = LigerLayerNorm(config.embed_dim, eps=1e-6).to( - torch_dtype - ) - vision_block.norm2 = LigerLayerNorm(config.embed_dim, eps=1e-6).to( - torch_dtype - ) + _patch_layer_norm_module(vision_block.norm1) + _patch_layer_norm_module(vision_block.norm2) if rms_norm: - base_model.norm = LigerRMSNormForQwen2VL( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) + _patch_rms_norm_module(base_model.norm) for decoder_layer in base_model.layers: if swiglu: - decoder_layer.mlp = LigerSwiGLUMLP(config).to(torch_dtype) + _bind_method_to_module( + decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward + ) if rms_norm: - decoder_layer.input_layernorm = LigerRMSNormForQwen2VL( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) - decoder_layer.post_attention_layernorm = LigerRMSNormForQwen2VL( - config.hidden_size, eps=config.rms_norm_eps - ).to(torch_dtype) - - -def apply_liger_kernel_to_qwen2_vl( - cross_entropy: bool = False, - fused_linear_cross_entropy: bool = True, - rms_norm: bool = True, - layer_norm: bool = True, - swiglu: bool = True, -) -> None: - """ - Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models. - NOTE: Qwen2-VL is not available in transformers<=4.44.2 - - Args: - cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False. - fused_linear_cross_entropy (bool): - Whether to apply Liger's fused linear cross entropy loss. Default is True. - `cross_entropy` and `fused_linear_cross_entropy` cannot both be True. - If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient. - rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. - layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True. - swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True. - """ - assert not ( - cross_entropy and fused_linear_cross_entropy - ), "cross_entropy and fused_linear_cross_entropy cannot both be True." 
- - from transformers.models.qwen2_vl import modeling_qwen2_vl - - from liger_kernel.transformers.model.qwen2_vl import ( - lce_forward as qwen2_vl_lce_forward, - ) - - # TODO: Support Qwen2-VL's multimodal RoPE implementation - - if rms_norm: - # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439 - modeling_qwen2_vl.Qwen2RMSNorm = partial( - LigerRMSNorm, init_fn="ones", casting_mode="gemma" - ) - if layer_norm: - modeling_qwen2_vl.LayerNorm = LigerLayerNorm - if cross_entropy: - modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss - if fused_linear_cross_entropy: - modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward - if swiglu: - modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP + _patch_rms_norm_module(decoder_layer.input_layernorm) + _patch_rms_norm_module(decoder_layer.post_attention_layernorm) def apply_liger_kernel_to_phi3( @@ -628,6 +753,7 @@ def apply_liger_kernel_to_phi3( ), "cross_entropy and fused_linear_cross_entropy cannot both be True." from transformers.models.phi3 import modeling_phi3 + from transformers.models.phi3.modeling_phi3 import Phi3Model if rope: modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb # Same as Gemma @@ -638,7 +764,30 @@ def apply_liger_kernel_to_phi3( if cross_entropy: modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss if fused_linear_cross_entropy: - modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward + if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION): + modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward + else: # if version < 4.46.1 + logger.warning(TRANSFORMER_DEPRECATION_WARNING) + modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated + + if model is not None: + # The model instance already exists, so we need to additionally patch the + # instance variables that reference already-instantiated modules + + # get the base model from the model instance + base_model: Phi3Model = getattr(model, model.base_model_prefix, model) + + if rms_norm: + _patch_rms_norm_module(base_model.norm) + + for decoder_layer in base_model.layers: + if swiglu: + _bind_method_to_module( + decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward + ) + if rms_norm: + _patch_rms_norm_module(decoder_layer.input_layernorm) + _patch_rms_norm_module(decoder_layer.post_attention_layernorm) if model is not None: # The model instance already exists, so we need to additionally patch the @@ -675,6 +824,8 @@ def apply_liger_kernel_to_phi3( "gemma": apply_liger_kernel_to_gemma, "gemma2": apply_liger_kernel_to_gemma2, "llama": apply_liger_kernel_to_llama, + "mllama": apply_liger_kernel_to_mllama, + "mllama_text_model": apply_liger_kernel_to_mllama, "mistral": apply_liger_kernel_to_mistral, "mixtral": apply_liger_kernel_to_mixtral, "qwen2": apply_liger_kernel_to_qwen2, @@ -760,7 +911,6 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None: for key, value in kwargs.items() if key in apply_fn_signature.parameters } - logger.info( f"Applying Liger kernels to model instance with model type: {model_type} with kwargs: {applicable_kwargs}" ) diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 000000000..806fa8664 --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,8 @@ +import pytest +import torch + + +@pytest.fixture(autouse=True) +def clear_cuda_cache(): + yield + torch.cuda.empty_cache() diff --git a/test/convergence/test_mini_models.py b/test/convergence/test_mini_models.py index f648a88c2..72be62c0c 
100644 --- a/test/convergence/test_mini_models.py +++ b/test/convergence/test_mini_models.py @@ -1,5 +1,3 @@ -import functools -import os from test.utils import ( DEFAULT_DATASET_PATH, MiniModelConfig, @@ -9,11 +7,12 @@ revert_liger_kernel_to_llama, revert_liger_kernel_to_mistral, revert_liger_kernel_to_mixtral, + revert_liger_kernel_to_mllama, revert_liger_kernel_to_phi3, revert_liger_kernel_to_qwen2, + revert_liger_kernel_to_qwen2_vl, set_seed, simple_collate_fn, - supports_bfloat16, ) import pytest @@ -34,25 +33,35 @@ apply_liger_kernel_to_llama, apply_liger_kernel_to_mistral, apply_liger_kernel_to_mixtral, + apply_liger_kernel_to_mllama, apply_liger_kernel_to_phi3, apply_liger_kernel_to_qwen2, + apply_liger_kernel_to_qwen2_vl, ) -torch.use_deterministic_algorithms(True) +try: + # Mllama is only available in transformers>=4.45.0 + from transformers.models.mllama.configuration_mllama import MllamaTextConfig + from transformers.models.mllama.modeling_mllama import MllamaForCausalLM -# Only setting torch.use_deterministic_algorithms(True) throws the following error: -# RuntimeError: Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or `at::Context::setDeterministicAlgorithms(true)`, -# but this operation is not deterministic because it uses CuBLAS and you have CUDA >= 10.2. To enable deterministic behavior in this case, you must set an -# environment variable before running your PyTorch application: CUBLAS_WORKSPACE_CONFIG=:4096:8 or CUBLAS_WORKSPACE_CONFIG=:16:8. For more information, -# go to https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility + MLLAMA_AVAILABLE = True +except ImportError: + MLLAMA_AVAILABLE = False -os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" +try: + # Qwen2-VL is only available in transformers>4.44.2 + from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig + from transformers.models.qwen2_vl.modeling_qwen2_vl import ( + Qwen2VLForConditionalGeneration, + ) + + QWEN2_VL_AVAILABLE = True +except ImportError: + QWEN2_VL_AVAILABLE = False MINI_MODEL_SETUPS = { "mini_llama3": MiniModelConfig( - liger_kernel_patch_func=functools.partial( - apply_liger_kernel_to_llama, fused_linear_cross_entropy=False - ), + liger_kernel_patch_func=apply_liger_kernel_to_llama, liger_kernel_patch_revert_func=revert_liger_kernel_to_llama, model_class=LlamaForCausalLM, mini_model_config=LlamaConfig( @@ -76,7 +85,7 @@ rope_theta=500000.0, tie_word_embeddings=False, use_cache=True, - vocab_size=32000, # 128256 + vocab_size=32000, # 128256, # At rope backward # Eager produces incontiguous dq and dk # SDPA produces contiguous dq and incontiguous dk @@ -84,10 +93,112 @@ attn_implementation="sdpa", # default value, pytorch native attention ), ), - "mini_gemma1": MiniModelConfig( - liger_kernel_patch_func=functools.partial( - apply_liger_kernel_to_gemma, fused_linear_cross_entropy=False + "mini_qwen2": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen2, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen2, + model_class=Qwen2ForCausalLM, + mini_model_config=Qwen2Config( + attention_dropout=0.0, + bos_token_id=1, # 151643 + eos_token_id=2, # 151643 + hidden_act="silu", + hidden_size=896, + initializer_range=0.02, + intermediate_size=4864, + max_position_embeddings=32768, # 131072 + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + rms_norm_eps=1e-6, + rope_theta=1000000.0, + sliding_window=131072, + tie_word_embeddings=True, + use_cache=True, + vocab_size=32000, # 
151936 + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + ), + "mini_phi3": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_phi3, + liger_kernel_patch_revert_func=revert_liger_kernel_to_phi3, + model_class=Phi3ForCausalLM, + mini_model_config=Phi3Config( + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, # 32000 + hidden_act="silu", + hidden_size=896, # 3072 + initializer_range=0.02, + intermediate_size=4864, # 8192 + max_position_embeddings=4096, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 32 + num_key_value_heads=None, # defaults to num_attention_heads + rms_norm_eps=1e-5, + rope_theta=10000.0, + sliding_window=None, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32064, + attn_implementation="eager", ), + ), + "mini_mistral": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_mistral, + liger_kernel_patch_revert_func=revert_liger_kernel_to_mistral, + model_class=MistralForCausalLM, + mini_model_config=MistralConfig( + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=1024, + initializer_range=0.02, + intermediate_size=2048, + max_position_embeddings=32768, + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + rms_norm_eps=1e-5, + rope_theta=10000.0, + sliding_window=4096, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, + attn_implementation="sdpa", + ), + ), + "mini_mixtral": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_mixtral, + liger_kernel_patch_revert_func=revert_liger_kernel_to_mixtral, + model_class=MixtralForCausalLM, + mini_model_config=MixtralConfig( + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=512, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=32768, # 32768 + num_attention_heads=8, # 32 + num_hidden_layers=4, # 32 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-5, + rope_theta=10000.0, + sliding_window=4096, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, + attn_implementation="sdpa", + ), + ), + "mini_gemma1": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_gemma, liger_kernel_patch_revert_func=revert_liger_kernel_to_gemma, model_class=GemmaForCausalLM, mini_model_config=GemmaConfig( @@ -119,9 +230,7 @@ ), ), "mini_gemma1.1": MiniModelConfig( - liger_kernel_patch_func=functools.partial( - apply_liger_kernel_to_gemma, fused_linear_cross_entropy=False - ), + liger_kernel_patch_func=apply_liger_kernel_to_gemma, liger_kernel_patch_revert_func=revert_liger_kernel_to_gemma, model_class=GemmaForCausalLM, mini_model_config=GemmaConfig( @@ -174,127 +283,90 @@ rope_theta=10000.0, attention_bias=False, attention_dropout=0.0, + attn_implementation="eager", ), ), - "mini_mistral": MiniModelConfig( - liger_kernel_patch_func=functools.partial( - apply_liger_kernel_to_mistral, fused_linear_cross_entropy=False - ), - liger_kernel_patch_revert_func=revert_liger_kernel_to_mistral, - model_class=MistralForCausalLM, - mini_model_config=MistralConfig( - attention_dropout=0.0, - bos_token_id=1, - eos_token_id=2, - hidden_act="silu", - hidden_size=1024, # 4096 - initializer_range=0.02, - intermediate_size=2048, # 14336 - max_position_embeddings=32768, # 32768 - num_attention_heads=8, # 32 - num_hidden_layers=4, # 32 - 
num_key_value_heads=2, # 8 - rms_norm_eps=1e-5, - rope_theta=10000.0, - sliding_window=4096, - tie_word_embeddings=False, - use_cache=True, - vocab_size=32000, - attn_implementation="sdpa", - ), - ), - "mini_mixtral": MiniModelConfig( - liger_kernel_patch_func=apply_liger_kernel_to_mixtral, - liger_kernel_patch_revert_func=revert_liger_kernel_to_mixtral, - model_class=MixtralForCausalLM, - mini_model_config=MixtralConfig( - attention_dropout=0.0, - bos_token_id=1, - eos_token_id=2, +} + +if MLLAMA_AVAILABLE: + MINI_MODEL_SETUPS["mini_mllama"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_mllama, + liger_kernel_patch_revert_func=revert_liger_kernel_to_mllama, + model_class=MllamaForCausalLM, + mini_model_config=MllamaTextConfig( + bos_token_id=1, # 128000 + eos_token_id=2, # 128001 + pad_token_id=2, + cross_attention_layers=None, + dropout=0, hidden_act="silu", hidden_size=1024, # 4096 initializer_range=0.02, intermediate_size=2048, # 14336 - max_position_embeddings=32768, # 32768 + max_position_embeddings=131_072, num_attention_heads=8, # 32 - num_experts_per_tok=2, - num_hidden_layers=4, # 32 + num_hidden_layers=4, # 40 num_key_value_heads=2, # 8 - num_local_experts=8, - output_router_logits=False, rms_norm_eps=1e-5, - rope_theta=1000000.0, - router_aux_loss_coef=0.02, - sliding_window=None, + rope_scaling=dict( + factor=8.0, + high_freq_factor=4.0, + low_freq_factor=1.0, + original_max_position_embeddings=8192, + rope_type="llama3", + ), + rope_theta=500_000, tie_word_embeddings=False, use_cache=True, - vocab_size=32000, - # At rope backward - # Eager produces incontiguous dq and dk - # SDPA produces contiguous dq and incontiguous dk - # Flash_attn produces contiguous dq and dk + vocab_size=32000, # 128256, attn_implementation="sdpa", # default value, pytorch native attention ), - ), - "mini_qwen2": MiniModelConfig( - liger_kernel_patch_func=functools.partial( - apply_liger_kernel_to_qwen2, fused_linear_cross_entropy=False - ), - liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen2, - model_class=Qwen2ForCausalLM, - mini_model_config=Qwen2Config( + ) + +if QWEN2_VL_AVAILABLE: + MINI_MODEL_SETUPS["mini_qwen2_vl"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_qwen2_vl, + liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen2_vl, + model_class=Qwen2VLForConditionalGeneration, + mini_model_config=Qwen2VLConfig( attention_dropout=0.0, bos_token_id=1, # 151643 - eos_token_id=2, # 151643 + eos_token_id=2, # 151645 hidden_act="silu", - hidden_size=896, + hidden_size=1536, # 8192 initializer_range=0.02, - intermediate_size=4864, - max_position_embeddings=32768, # 131072 - num_attention_heads=8, - num_hidden_layers=4, - num_key_value_heads=2, - rms_norm_eps=1e-6, + intermediate_size=4864, # 29568 + max_position_embeddings=32768, + max_window_layers=4, # 80 + num_attention_heads=12, # 64 + num_hidden_layers=4, # 80 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-6, # 1e-5 rope_theta=1000000.0, - sliding_window=131072, - tie_word_embeddings=True, - use_cache=True, - vocab_size=32000, # 151936 - # At rope backward - # Eager produces incontiguous dq and dk - # SDPA produces contiguous dq and incontiguous dk - # Flash_attn produces contiguous dq and dk - attn_implementation="sdpa", # default value, pytorch native attention - ), - ), - "mini_phi3": MiniModelConfig( - liger_kernel_patch_func=functools.partial( - apply_liger_kernel_to_phi3, fused_linear_cross_entropy=False - ), - liger_kernel_patch_revert_func=revert_liger_kernel_to_phi3, - 
model_class=Phi3ForCausalLM, - mini_model_config=Phi3Config( - attention_dropout=0.0, - bos_token_id=1, - eos_token_id=2, # 32000 - hidden_act="silu", - hidden_size=896, # 3072 - initializer_range=0.02, - intermediate_size=4864, # 8192 - max_position_embeddings=4096, - num_attention_heads=8, # 32 - num_hidden_layers=4, # 32 - num_key_value_heads=None, # defaults to num_attention_heads - rms_norm_eps=1e-5, - rope_theta=10000.0, - sliding_window=None, + rope_scaling=dict( + type="mrope", + mrope_section=[16, 24, 24], # (temporal, height, width) + ), + sliding_window=4096, tie_word_embeddings=False, use_cache=True, - vocab_size=32064, - attn_implementation="eager", + vocab_size=32000, # 152064 + use_sliding_window=False, + vision_config={ + "depth": 4, # 32 + "embed_dim": 1280, + "mlp_ratio": 4, + "num_heads": 16, + "in_chans": 3, + "hidden_size": 128, # 1536 + "patch_size": 14, + "spatial_merge_size": 2, + "spatial_patch_size": 14, + "temporal_patch_size": 2, + }, + attn_implementation="sdpa", ), - ), -} + ) def create_model(model_name="mini_llama3"): @@ -323,21 +395,37 @@ def run_mini_model( if with_liger is True: kwargs = { - "rope": True, "rms_norm": True, - "cross_entropy": True, } + model_supports_rope = "qwen2_vl" not in model_name + if model_supports_rope: + kwargs["rope"] = True + + model_supports_layer_norm = "qwen2_vl" in model_name + if model_supports_layer_norm: + kwargs["layer_norm"] = True + if "gemma" in model_name: kwargs["geglu"] = True else: kwargs["swiglu"] = True + + model_support_flce = "gemma2" not in model_name + + if model_support_flce: + kwargs["fused_linear_cross_entropy"] = True + kwargs["cross_entropy"] = False + else: + kwargs["cross_entropy"] = True + MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs) else: - MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func() + ... 
+ # FIXME: disable revert because it will cause flce to not be patched + # MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func() model = create_model(model_name).to(dtype).to("cuda") train_dataset = load_from_disk(DEFAULT_DATASET_PATH) - loader = DataLoader( train_dataset, batch_size=16, shuffle=False, collate_fn=simple_collate_fn ) @@ -355,130 +443,220 @@ def run_mini_model( print(f"Step {i}, Loss: {output.loss.item()}") loss_list.append(output.loss.item()) - MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func() + # MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func() return {"loss": loss_list, "logits": output.logits, "model": model} @pytest.mark.parametrize( + # FIXME enable bf16 tests after revert is fixed "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol", [ - # Gemma 1 has more tolerance because currently, the kernel is not a perfect match (casts are not done the same way) - ("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 6e-4, 5e-3, 1e-5, 5e-3, 1e-5), - pytest.param( - "mini_gemma1", - 32, - 1e-4, - torch.bfloat16, - 1e-3, - 1e-2, - 1e-1, - 1e-2, - 1e-2, - 1e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), - ), - ("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), + ("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 2e-5, 1e-4, 1e-5, 5e-3, 1e-5), + # pytest.param( + # "mini_llama3", + # 32, + # 1e-4, + # torch.bfloat16, + # 1e-3, + # 1e-2, + # 1e-1, + # 1e-2, + # 1e-2, + # 1e-2, + # marks=pytest.mark.skipif( + # not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + # ), + # ), pytest.param( - "mini_gemma1.1", + "mini_mllama", 32, 1e-4, - torch.bfloat16, - 1e-3, - 1e-2, - 1e-1, - 1e-2, - 1e-2, - 1e-2, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), - ), - ("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), - pytest.param( - "mini_gemma2", - 32, - 1e-4, - torch.bfloat16, - 1e-3, - 1e-2, - 1e-1, - 1e-2, - 1e-2, - 1e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), - ), - ("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), - pytest.param( - "mini_llama3", - 32, - 1e-4, - torch.bfloat16, - 1e-3, - 1e-2, - 1e-1, - 1e-2, - 1e-2, - 1e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), - ), - # TODO: torch 2.5.0 nightly breaks mixtral test, but torch 2.3.0 works fine - # TODO: mixtral MoE structure makes the convergence flaky so disable the test for now. It needs high tol to pass. 
- # ("mini_mixtral", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 8e-3, 1e-5), - # ("mini_mixtral", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 2.0, 1e-5, 1e-2, 1e-5), - ("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), - pytest.param( - "mini_mistral", - 32, - 1e-4, - torch.bfloat16, - 1e-3, - 1e-2, - 1e-1, - 1e-2, - 1e-2, - 1e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + not MLLAMA_AVAILABLE, + reason="Mllama not available in this version of transformers", ), ), + # pytest.param( + # "mini_mllama", + # 32, + # 1e-4, + # torch.bfloat16, + # 1e-3, + # 1e-2, + # 1e-1, + # 1e-2, + # 1e-2, + # 1e-2, + # marks=[ + # pytest.mark.skipif( + # not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + # ), + # pytest.mark.skipif( + # not MLLAMA_AVAILABLE, + # reason="Mllama not available in this version of transformers", + # ), + # ], + # ), ("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), - pytest.param( - "mini_qwen2", - 32, - 1e-4, - torch.bfloat16, - 1e-3, - 1e-2, - 1e-1, - 1e-2, - 1e-2, - 1e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), - ), + # pytest.param( + # "mini_qwen2", + # 32, + # 1e-4, + # torch.bfloat16, + # 1e-3, + # 1e-2, + # 1e-1, + # 1e-2, + # 1e-2, + # 1e-2, + # marks=pytest.mark.skipif( + # not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + # ), + # ), + # FIXME qwen2 is broken and needs fix + # pytest.param( + # "mini_qwen2_vl", + # 32, + # 1e-4, + # torch.float32, + # 1e-8, + # 1e-5, + # 5e-3, + # 1e-5, + # 5e-3, + # 1e-5, + # marks=pytest.mark.skipif( + # not QWEN2_VL_AVAILABLE, + # reason="Qwen2-VL not available in this version of transformers", + # ), + # ), + # pytest.param( + # "mini_qwen2_vl", + # 32, + # 1e-4, + # torch.bfloat16, + # 1e-3, + # 1e-2, + # 1e-1, + # 1e-2, + # 1e-2, + # 1e-2, + # marks=[ + # pytest.mark.skipif( + # not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + # ), + # pytest.mark.skipif( + # not QWEN2_VL_AVAILABLE, + # reason="Qwen2-VL not available in this version of transformers", + # ), + # ], + # ), ("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), - pytest.param( - "mini_phi3", - 32, - 1e-4, - torch.bfloat16, - 1e-3, - 1e-2, - 1e-1, - 1e-2, - 1e-2, - 1e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), - ), + # pytest.param( + # "mini_phi3", + # 32, + # 1e-4, + # torch.bfloat16, + # 1e-3, + # 1e-2, + # 1e-1, + # 1e-2, + # 1e-2, + # 1e-2, + # marks=pytest.mark.skipif( + # not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + # ), + # ), + ("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), + # pytest.param( + # "mini_mistral", + # 32, + # 1e-4, + # torch.bfloat16, + # 1e-3, + # 1e-2, + # 1e-1, + # 1e-2, + # 1e-2, + # 1e-2, + # marks=pytest.mark.skipif( + # not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + # ), + # ), + # TODO: mixtral is flaky so disable the test for now + # ("mini_mixtral", 32, 1e-4, torch.float32, 5e-4, 1e-4, 5e-3, 1e-5, 1e-2, 1e-5), + # pytest.param( + # "mini_mixtral", + # 32, + # 1e-4, + # torch.bfloat16, + # 1e-3, + # 1e-2, + # 1e-1, + # 1e-2, + # 1e-1, + # 1e-2, + # marks=pytest.mark.skipif( + # not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + # ), + # ), + # Gemma 1.1 and 2 has more tolerance because currently, the kernel is not a 
perfect match (casts are not done the same way) + ("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5), + # pytest.param( + # "mini_gemma1", + # 32, + # 1e-4, + # torch.bfloat16, + # 1e-3, + # 1e-2, + # 1e-1, + # 1e-2, + # 1e-2, + # 1e-2, + # marks=pytest.mark.skipif( + # not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + # ), + # ), + ("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5), + # pytest.param( + # "mini_gemma1.1", + # 32, + # 1e-4, + # torch.bfloat16, + # 1e-3, + # 1e-2, + # 1e-1, + # 1e-2, + # 1e-2, + # 1e-2, + # marks=pytest.mark.skipif( + # not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + # ), + # ), + # TODO: Gemma2 tests are not passing within the tolerance range, need to investigate + # ("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5), + # pytest.param( + # "mini_gemma2", + # 32, + # 1e-4, + # torch.bfloat16, + # 1e-3, + # 1e-2, + # 1e-1, + # 1e-2, + # 1e-2, + # 1e-2, + # marks=pytest.mark.skipif( + # not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + # ), + # ), ], ) def test_mini_model( @@ -503,7 +681,7 @@ def test_mini_model( model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True ) - # Compare the loss of every step + # Compare every step of the loss assert_verbose_allclose( torch.tensor([expected_output["loss"]]), torch.tensor([actual_output["loss"]]), @@ -511,13 +689,15 @@ def test_mini_model( rtol=loss_rtol, ) - # Compare the logits from the last step - assert_verbose_allclose( - expected_output["logits"], - actual_output["logits"], - atol=logits_atol, - rtol=logits_rtol, - ) + # No logits are materialized + + # # Compare the logits from the last step + # assert_verbose_allclose( + # expected_output["logits"], + # actual_output["logits"], + # atol=logits_atol, + # rtol=logits_rtol, + # ) # Compare the params from the last step # Iterate over the model's parameters and compare them diff --git a/test/convergence/test_mini_models_multimodal.py b/test/convergence/test_mini_models_multimodal.py index 4c164ba58..c835df05d 100644 --- a/test/convergence/test_mini_models_multimodal.py +++ b/test/convergence/test_mini_models_multimodal.py @@ -1,34 +1,63 @@ import functools import os from test.utils import ( + FAKE_CONFIGS_PATH, UNTOKENIZED_DATASET_PATH, MiniModelConfig, assert_verbose_allclose, + load_tokenizer_config, multimodal_collate_fn, + revert_liger_kernel_to_mllama, revert_liger_kernel_to_qwen2_vl, set_seed, supports_bfloat16, + train_bpe_tokenizer, ) import pytest import torch from datasets import load_dataset from torch.utils.data import DataLoader -from transformers.models.auto.processing_auto import AutoProcessor +from transformers import PreTrainedTokenizerFast -from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl +from liger_kernel.transformers import ( + apply_liger_kernel_to_mllama, + apply_liger_kernel_to_qwen2_vl, +) try: - # Qwen2-VL is only available in transformers>4.44.2 + # Qwen2-VL is only available in transformers>=4.45.0 + from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig + from transformers.models.qwen2_vl.image_processing_qwen2_vl import ( + Qwen2VLImageProcessor, + ) from transformers.models.qwen2_vl.modeling_qwen2_vl import ( Qwen2VLForConditionalGeneration, ) + from transformers.models.qwen2_vl.processing_qwen2_vl import Qwen2VLProcessor QWEN2_VL_AVAILABLE = 
True except ImportError: QWEN2_VL_AVAILABLE = False +try: + # Mllama is only available in transformers>=4.45.0 + from transformers.models.mllama.configuration_mllama import ( + MllamaConfig, + MllamaTextConfig, + MllamaVisionConfig, + ) + from transformers.models.mllama.image_processing_mllama import MllamaImageProcessor + from transformers.models.mllama.modeling_mllama import ( + MllamaForConditionalGeneration, + ) + from transformers.models.mllama.processing_mllama import MllamaProcessor + + MLLAMA_AVAILABLE = True +except ImportError: + MLLAMA_AVAILABLE = False + torch.use_deterministic_algorithms(True) # Only setting torch.use_deterministic_algorithms(True) throws the following error: @@ -43,6 +72,64 @@ MINI_MODEL_SETUPS = {} + +if MLLAMA_AVAILABLE: + MINI_MODEL_SETUPS["mini_mllama"] = MiniModelConfig( + liger_kernel_patch_func=functools.partial( + apply_liger_kernel_to_mllama, fused_linear_cross_entropy=False + ), + liger_kernel_patch_revert_func=revert_liger_kernel_to_mllama, + model_class=MllamaForConditionalGeneration, + mini_model_config=MllamaConfig( + vision_config=MllamaVisionConfig( + hidden_act="gelu", + hidden_size=512, # 1280 + image_size=560, # 560 + initializer_range=0.02, + intermediate_layers_indices=[2], # [3, 7, 15, etc...] + intermediate_size=2048, # 5120 + max_num_tiles=1, # 4 + norm_eps=1e-5, + num_attention_heads=4, # 16 + num_channels=3, + num_global_layers=2, # 8 + num_hidden_layers=8, # 32 + patch_size=140, # 14 + supported_aspect_ratios=[[1, 1]], # [[1, 1], [1, 2], etc... ] + vision_output_dim=1024, # 7680 + ), + text_config=MllamaTextConfig( + bos_token_id=0, + eos_token_id=0, + pad_token_id=0, + cross_attention_layers=[2], # [3, 8, 13, 18, etc...] + dropout=0, + hidden_act="silu", + hidden_size=1024, # 4096 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=131_072, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 40 + num_key_value_heads=2, # 8 + rms_norm_eps=1e-5, + rope_scaling=dict( + factor=8.0, + high_freq_factor=4.0, + low_freq_factor=1.0, + original_max_position_embeddings=8192, + rope_type="llama3", + ), + rope_theta=500_000, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 128256, + ), + image_token_index=1, # NOTE: outside the vocab size + attn_implementation="sdpa", + ), + ) + if QWEN2_VL_AVAILABLE: MINI_MODEL_SETUPS["mini_qwen2_vl"] = MiniModelConfig( liger_kernel_patch_func=functools.partial( @@ -54,12 +141,12 @@ attention_dropout=0.0, # Token Ids and vocab size must match those in the tokenizer/processor # https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/blob/main/config.json - bos_token_id=151643, - eos_token_id=151645, - vision_start_token_id=151652, - vision_end_token_id=151653, - vision_token_id=151654, - image_token_id=151655, + bos_token_id=0, + eos_token_id=0, + vision_start_token_id=1, + vision_end_token_id=2, + vision_token_id=3, + image_token_id=4, hidden_act="silu", hidden_size=1024, # 8192 initializer_range=0.02, @@ -78,7 +165,7 @@ sliding_window=4096, tie_word_embeddings=True, use_cache=False, # True - vocab_size=152064, + vocab_size=32000, # 152064, use_sliding_window=False, vision_config={ "depth": 4, # 32 @@ -95,7 +182,51 @@ def create_processor(model_name): if model_name == "mini_qwen2_vl": - return AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") + tokenizer_config = load_tokenizer_config( + os.path.join( + FAKE_CONFIGS_PATH, "Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json" + ) + ) + tokenizer_base = train_bpe_tokenizer( + [ + token.content + for key, 
token in sorted( + tokenizer_config["added_tokens_decoder"].items(), + key=lambda x: int(x[0]), + ) + ] + ) + qwen_tokenizer = Qwen2TokenizerFast( + tokenizer_object=tokenizer_base, **tokenizer_config + ) + image_processor = Qwen2VLImageProcessor() + return Qwen2VLProcessor( + image_processor=image_processor, tokenizer=qwen_tokenizer + ) + + elif model_name == "mini_mllama": + tokenizer_config = load_tokenizer_config( + os.path.join( + FAKE_CONFIGS_PATH, + "meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json", + ) + ) + tokenizer_base = train_bpe_tokenizer( + [ + token.content + for key, token in sorted( + tokenizer_config["added_tokens_decoder"].items(), + key=lambda x: int(x[0]), + ) + ] + ) + fast_tokenizer = PreTrainedTokenizerFast( + tokenizer_object=tokenizer_base, **tokenizer_config + ) + image_processor = MllamaImageProcessor(size={"height": 560, "width": 560}) + return MllamaProcessor( + image_processor=image_processor, tokenizer=fast_tokenizer + ) else: raise ValueError(f"Processor not available for model {model_name}") @@ -129,7 +260,9 @@ def apply_chat_template(example): "content": [{"type": "text", "text": example["text"]}], }, ] - example["text"] = processor.apply_chat_template(conversation, tokenize=False) + example["text"] = processor.tokenizer.apply_chat_template( + conversation, tokenize=False + ) return example def preprocess_function(examples): @@ -140,6 +273,7 @@ def preprocess_function(examples): padding="max_length", truncation=True, max_length=1024, # longer than for text-only b/c images require quite a few tokens + return_tensors="pt", ) train_dataset = ( @@ -182,15 +316,12 @@ def run_mini_model_multimodal( kwargs = { "rms_norm": True, "cross_entropy": True, + "layer_norm": True, } model_supports_rope = "qwen2_vl" not in model_name if model_supports_rope: kwargs["rope"] = True - model_supports_layer_norm = "qwen2_vl" in model_name - if model_supports_layer_norm: - kwargs["layer_norm"] = True - if "gemma" in model_name: kwargs["geglu"] = True else: @@ -265,6 +396,43 @@ def run_mini_model_multimodal( ), ], ), + pytest.param( + "mini_mllama", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not MLLAMA_AVAILABLE, + reason="Mllama not available in this version of transformers", + ), + ), + pytest.param( + "mini_mllama", + 32, + 1e-4, + torch.bfloat16, + 1e-3, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif( + not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + ), + pytest.mark.skipif( + not MLLAMA_AVAILABLE, + reason="Mllama not available in this version of transformers", + ), + ], + ), ], ) def test_mini_model_multimodal( diff --git a/test/convergence/test_mini_models_no_logits.py b/test/convergence/test_mini_models_no_logits.py deleted file mode 100644 index 7dfaa00f1..000000000 --- a/test/convergence/test_mini_models_no_logits.py +++ /dev/null @@ -1,621 +0,0 @@ -from test.utils import ( - DEFAULT_DATASET_PATH, - MiniModelConfig, - assert_verbose_allclose, - revert_liger_kernel_to_gemma, - revert_liger_kernel_to_gemma2, - revert_liger_kernel_to_llama, - revert_liger_kernel_to_mistral, - revert_liger_kernel_to_mixtral, - revert_liger_kernel_to_phi3, - revert_liger_kernel_to_qwen2, - revert_liger_kernel_to_qwen2_vl, - set_seed, - simple_collate_fn, - supports_bfloat16, -) - -import pytest -import torch -from datasets import load_from_disk -from torch.utils.data import DataLoader -from transformers.models.gemma import GemmaConfig, GemmaForCausalLM -from 
transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM -from transformers.models.llama import LlamaConfig, LlamaForCausalLM -from transformers.models.mistral import MistralConfig, MistralForCausalLM -from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM -from transformers.models.phi3 import Phi3Config, Phi3ForCausalLM -from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM - -from liger_kernel.transformers import ( - apply_liger_kernel_to_gemma, - apply_liger_kernel_to_gemma2, - apply_liger_kernel_to_llama, - apply_liger_kernel_to_mistral, - apply_liger_kernel_to_mixtral, - apply_liger_kernel_to_phi3, - apply_liger_kernel_to_qwen2, - apply_liger_kernel_to_qwen2_vl, -) - -try: - # Qwen2-VL is only available in transformers>4.44.2 - from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig - from transformers.models.qwen2_vl.modeling_qwen2_vl import ( - Qwen2VLForConditionalGeneration, - ) - - QWEN2_VL_AVAILABLE = True -except ImportError: - QWEN2_VL_AVAILABLE = False - -MINI_MODEL_SETUPS = { - "mini_llama3": MiniModelConfig( - liger_kernel_patch_func=apply_liger_kernel_to_llama, - liger_kernel_patch_revert_func=revert_liger_kernel_to_llama, - model_class=LlamaForCausalLM, - mini_model_config=LlamaConfig( - attention_bias=False, - attention_dropout=0.0, - # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset - # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json - bos_token_id=1, # 128000 - eos_token_id=2, # 128001 - hidden_act="silu", - hidden_size=1024, # 4096 - initializer_range=0.02, - intermediate_size=2048, # 14336 - max_position_embeddings=8192, - num_attention_heads=8, # 32 - num_hidden_layers=4, # 32 - num_key_value_heads=2, # 8 - pretraining_tp=1, - rms_norm_eps=1e-5, - rope_scaling=None, - rope_theta=500000.0, - tie_word_embeddings=False, - use_cache=True, - vocab_size=32000, # 128256, - # At rope backward - # Eager produces incontiguous dq and dk - # SDPA produces contiguous dq and incontiguous dk - # Flash_attn produces contiguous dq and dk - attn_implementation="sdpa", # default value, pytorch native attention - ), - ), - "mini_qwen2": MiniModelConfig( - liger_kernel_patch_func=apply_liger_kernel_to_qwen2, - liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen2, - model_class=Qwen2ForCausalLM, - mini_model_config=Qwen2Config( - attention_dropout=0.0, - bos_token_id=1, # 151643 - eos_token_id=2, # 151643 - hidden_act="silu", - hidden_size=896, - initializer_range=0.02, - intermediate_size=4864, - max_position_embeddings=32768, # 131072 - num_attention_heads=8, - num_hidden_layers=4, - num_key_value_heads=2, - rms_norm_eps=1e-6, - rope_theta=1000000.0, - sliding_window=131072, - tie_word_embeddings=True, - use_cache=True, - vocab_size=32000, # 151936 - # At rope backward - # Eager produces incontiguous dq and dk - # SDPA produces contiguous dq and incontiguous dk - # Flash_attn produces contiguous dq and dk - attn_implementation="sdpa", # default value, pytorch native attention - ), - ), - "mini_phi3": MiniModelConfig( - liger_kernel_patch_func=apply_liger_kernel_to_phi3, - liger_kernel_patch_revert_func=revert_liger_kernel_to_phi3, - model_class=Phi3ForCausalLM, - mini_model_config=Phi3Config( - attention_dropout=0.0, - bos_token_id=1, - eos_token_id=2, # 32000 - hidden_act="silu", - hidden_size=896, # 3072 - initializer_range=0.02, - intermediate_size=4864, # 8192 - max_position_embeddings=4096, - num_attention_heads=8, # 32 - 
num_hidden_layers=4, # 32 - num_key_value_heads=None, # defaults to num_attention_heads - rms_norm_eps=1e-5, - rope_theta=10000.0, - sliding_window=None, - tie_word_embeddings=False, - use_cache=True, - vocab_size=32064, - attn_implementation="eager", - ), - ), - "mini_mistral": MiniModelConfig( - liger_kernel_patch_func=apply_liger_kernel_to_mistral, - liger_kernel_patch_revert_func=revert_liger_kernel_to_mistral, - model_class=MistralForCausalLM, - mini_model_config=MistralConfig( - attention_dropout=0.0, - bos_token_id=1, - eos_token_id=2, - hidden_act="silu", - hidden_size=1024, - initializer_range=0.02, - intermediate_size=2048, - max_position_embeddings=32768, - num_attention_heads=8, - num_hidden_layers=4, - num_key_value_heads=2, - rms_norm_eps=1e-5, - rope_theta=10000.0, - sliding_window=4096, - tie_word_embeddings=False, - use_cache=True, - vocab_size=32000, - attn_implementation="sdpa", - ), - ), - "mini_mixtral": MiniModelConfig( - liger_kernel_patch_func=apply_liger_kernel_to_mixtral, - liger_kernel_patch_revert_func=revert_liger_kernel_to_mixtral, - model_class=MixtralForCausalLM, - mini_model_config=MixtralConfig( - attention_dropout=0.0, - bos_token_id=1, - eos_token_id=2, - hidden_act="silu", - hidden_size=512, # 4096 - initializer_range=0.02, - intermediate_size=2048, # 14336 - max_position_embeddings=32768, # 32768 - num_attention_heads=8, # 32 - num_hidden_layers=4, # 32 - num_key_value_heads=2, # 8 - rms_norm_eps=1e-5, - rope_theta=10000.0, - sliding_window=4096, - tie_word_embeddings=False, - use_cache=True, - vocab_size=32000, - attn_implementation="sdpa", - ), - ), - "mini_gemma1": MiniModelConfig( - liger_kernel_patch_func=apply_liger_kernel_to_gemma, - liger_kernel_patch_revert_func=revert_liger_kernel_to_gemma, - model_class=GemmaForCausalLM, - mini_model_config=GemmaConfig( - vocab_size=32000, # 256000 - hidden_size=1024, # 3072 - intermediate_size=2048, # 24576 - num_hidden_layers=4, # 28 - num_attention_heads=4, # 16 - num_key_value_heads=4, # 16 - head_dim=256, - # gemma1 model config uses `hidden_act` and point it to gelu, - # https://huggingface.co/google/gemma-7b/blob/main/config.json#L10 - # but in reality it's ignored and HuggingFace will use tanh approximation: - # https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/gemma/modeling_gemma.py#L175 - hidden_act="gelu", - max_position_embeddings=8192, - initializer_range=0.02, - rms_norm_eps=1e-06, - use_cache=True, - pad_token_id=0, - # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset - # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json - bos_token_id=1, # 128000 - eos_token_id=2, # 128001 - tie_word_embeddings=True, - rope_theta=10000.0, - attention_bias=False, - attention_dropout=0.0, - ), - ), - "mini_gemma1.1": MiniModelConfig( - liger_kernel_patch_func=apply_liger_kernel_to_gemma, - liger_kernel_patch_revert_func=revert_liger_kernel_to_gemma, - model_class=GemmaForCausalLM, - mini_model_config=GemmaConfig( - vocab_size=32000, # 256000 - hidden_size=1024, # 3072 - intermediate_size=2048, # 24576 - num_hidden_layers=4, # 28 - num_attention_heads=4, # 16 - num_key_value_heads=4, # 16 - head_dim=256, - hidden_activation="gelu_pytorch_tanh", - max_position_embeddings=8192, - initializer_range=0.02, - rms_norm_eps=1e-06, - use_cache=True, - pad_token_id=0, - # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset - # 
https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json - bos_token_id=1, # 128000 - eos_token_id=2, # 128001 - tie_word_embeddings=True, - rope_theta=10000.0, - attention_bias=False, - attention_dropout=0.0, - ), - ), - "mini_gemma2": MiniModelConfig( - liger_kernel_patch_func=apply_liger_kernel_to_gemma2, - liger_kernel_patch_revert_func=revert_liger_kernel_to_gemma2, - model_class=Gemma2ForCausalLM, - mini_model_config=Gemma2Config( - vocab_size=32000, # 256000 - hidden_size=1024, # 3072 - intermediate_size=2048, # 24576 - num_hidden_layers=4, # 28 - num_attention_heads=4, # 16 - num_key_value_heads=4, # 16 - head_dim=256, - hidden_activation="gelu_pytorch_tanh", - max_position_embeddings=8192, - initializer_range=0.02, - rms_norm_eps=1e-06, - use_cache=True, - pad_token_id=0, - # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset - # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json - bos_token_id=1, # 128000 - eos_token_id=2, # 128001 - tie_word_embeddings=True, - rope_theta=10000.0, - attention_bias=False, - attention_dropout=0.0, - ), - ), -} - -if QWEN2_VL_AVAILABLE: - MINI_MODEL_SETUPS["mini_qwen2_vl"] = MiniModelConfig( - liger_kernel_patch_func=apply_liger_kernel_to_qwen2_vl, - liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen2_vl, - model_class=Qwen2VLForConditionalGeneration, - mini_model_config=Qwen2VLConfig( - attention_dropout=0.0, - bos_token_id=1, # 151643 - eos_token_id=2, # 151645 - hidden_act="silu", - hidden_size=1536, # 8192 - initializer_range=0.02, - intermediate_size=4864, # 29568 - max_position_embeddings=32768, - max_window_layers=4, # 80 - num_attention_heads=12, # 64 - num_hidden_layers=4, # 80 - num_key_value_heads=2, # 8 - rms_norm_eps=1e-6, # 1e-5 - rope_theta=1000000.0, - rope_scaling=dict( - type="mrope", - mrope_section=[16, 24, 24], # (temporal, height, width) - ), - sliding_window=4096, - tie_word_embeddings=False, - use_cache=True, - vocab_size=32000, # 152064 - use_sliding_window=False, - vision_config={ - "depth": 4, # 32 - "embed_dim": 1280, - "mlp_ratio": 4, - "num_heads": 16, - "in_chans": 3, - "hidden_size": 128, # 1536 - "patch_size": 14, - "spatial_merge_size": 2, - "spatial_patch_size": 14, - "temporal_patch_size": 2, - }, - attn_implementation="sdpa", - ), - ) - - -def create_model(model_name="mini_llama3"): - """ - Create a mini version model - The commented values are the original values - """ - model_config = MINI_MODEL_SETUPS[model_name].mini_model_config - model_class = MINI_MODEL_SETUPS[model_name].model_class - return model_class(model_config) - - -def run_mini_model( - model_name="mini_llama3", - num_steps=100, - dtype=torch.bfloat16, - lr=1e-5, - with_liger=False, -): - # If we move it to the beginning of test_mini_model, the two runs are initialized with different weights. - # This is due to RNG (Random Number Generator). The formula of RNG progression is x_(n+1) = (a * x_n + c) % m - # Everytime RNG is used, like randomly initialzing weight, the RNG progresses to the next state. - # Therefore, we have to reset RNG before we create the model to ensure the weight initialization started from the same RNG state. 
- - set_seed(42) - - if with_liger is True: - kwargs = { - "rms_norm": True, - } - model_supports_rope = "qwen2_vl" not in model_name - if model_supports_rope: - kwargs["rope"] = True - - model_supports_layer_norm = "qwen2_vl" in model_name - if model_supports_layer_norm: - kwargs["layer_norm"] = True - - if "gemma" in model_name: - kwargs["geglu"] = True - else: - kwargs["swiglu"] = True - - model_support_flce = "gemma2" not in model_name - if model_support_flce: - kwargs["fused_linear_cross_entropy"] = True - kwargs["cross_entropy"] = False - else: - kwargs["cross_entropy"] = True - - MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs) - else: - MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func() - - model = create_model(model_name).to(dtype).to("cuda") - train_dataset = load_from_disk(DEFAULT_DATASET_PATH) - loader = DataLoader( - train_dataset, batch_size=16, shuffle=False, collate_fn=simple_collate_fn - ) - loader_iter = iter(loader) - optimizer = torch.optim.AdamW(model.parameters(), lr=lr) - - loss_list = [] - - for i in range(num_steps): - batch = next(loader_iter).to(model.device) - optimizer.zero_grad() - output = model(**batch) - output.loss.backward() - optimizer.step() - print(f"Step {i}, Loss: {output.loss.item()}") - loss_list.append(output.loss.item()) - - MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func() - return {"loss": loss_list, "logits": output.logits, "model": model} - - -@pytest.mark.parametrize( - "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol", - [ - ("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 2e-5, 1e-4, 1e-5, 5e-3, 1e-5), - pytest.param( - "mini_llama3", - 32, - 1e-4, - torch.bfloat16, - 1e-3, - 1e-2, - 1e-1, - 1e-2, - 1e-2, - 1e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), - ), - ("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), - pytest.param( - "mini_qwen2", - 32, - 1e-4, - torch.bfloat16, - 1e-3, - 1e-2, - 1e-1, - 1e-2, - 1e-2, - 1e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), - ), - pytest.param( - "mini_qwen2_vl", - 32, - 1e-4, - torch.float32, - 1e-8, - 1e-5, - 5e-3, - 1e-5, - 5e-3, - 1e-5, - marks=pytest.mark.skipif( - not QWEN2_VL_AVAILABLE, - reason="Qwen2-VL not available in this version of transformers", - ), - ), - pytest.param( - "mini_qwen2_vl", - 32, - 1e-4, - torch.bfloat16, - 1e-3, - 1e-2, - 1e-1, - 1e-2, - 1e-2, - 1e-2, - marks=[ - pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), - pytest.mark.skipif( - not QWEN2_VL_AVAILABLE, - reason="Qwen2-VL not available in this version of transformers", - ), - ], - ), - ("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), - pytest.param( - "mini_phi3", - 32, - 1e-4, - torch.bfloat16, - 1e-3, - 1e-2, - 1e-1, - 1e-2, - 1e-2, - 1e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), - ), - ("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), - pytest.param( - "mini_mistral", - 32, - 1e-4, - torch.bfloat16, - 1e-3, - 1e-2, - 1e-1, - 1e-2, - 1e-2, - 1e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), - ), - # TODO: mixtral is flaky so disable the test for now - # ("mini_mixtral", 32, 1e-4, torch.float32, 5e-4, 1e-4, 5e-3, 1e-5, 1e-2, 1e-5), - # pytest.param( - 
# "mini_mixtral", - # 32, - # 1e-4, - # torch.bfloat16, - # 1e-3, - # 1e-2, - # 1e-1, - # 1e-2, - # 1e-1, - # 1e-2, - # marks=pytest.mark.skipif( - # not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - # ), - # ), - # Gemma 1.1 and 2 has more tolerance because currently, the kernel is not a perfect match (casts are not done the same way) - ("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5), - pytest.param( - "mini_gemma1", - 32, - 1e-4, - torch.bfloat16, - 1e-3, - 1e-2, - 1e-1, - 1e-2, - 1e-2, - 1e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), - ), - ("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5), - pytest.param( - "mini_gemma1.1", - 32, - 1e-4, - torch.bfloat16, - 1e-3, - 1e-2, - 1e-1, - 1e-2, - 1e-2, - 1e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), - ), - ("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5), - pytest.param( - "mini_gemma2", - 32, - 1e-4, - torch.bfloat16, - 1e-3, - 1e-2, - 1e-1, - 1e-2, - 1e-2, - 1e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), - ), - ], -) -def test_mini_model( - model_name, - num_steps, - lr, - dtype, - loss_atol, - loss_rtol, - logits_atol, - logits_rtol, - param_atol, - param_rtol, -): - # Non-liger models should be initialized and tested first to avoid the module being overridden - - expected_output = run_mini_model( - model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr - ) - - actual_output = run_mini_model( - model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True - ) - - # Compare every step of the loss - assert_verbose_allclose( - torch.tensor([expected_output["loss"]]), - torch.tensor([actual_output["loss"]]), - atol=loss_atol, - rtol=loss_rtol, - ) - - # No logits are materialized - - # # Compare the logits from the last step - # assert_verbose_allclose( - # expected_output["logits"], - # actual_output["logits"], - # atol=logits_atol, - # rtol=logits_rtol, - # ) - - # Compare the params from the last step - # Iterate over the model's parameters and compare them - for expected_param, actual_param in zip( - expected_output["model"].named_parameters(), - actual_output["model"].named_parameters(), - ): - assert_verbose_allclose( - expected_param[1], actual_param[1], atol=param_atol, rtol=param_rtol - ) diff --git a/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json b/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json new file mode 100644 index 000000000..e784b6882 --- /dev/null +++ b/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json @@ -0,0 +1,55 @@ +{ + "added_tokens_decoder": { + "0": { + "content": "<|unk|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "1": { + "content": "<|vision_start|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "2": { + "content": "<|vision_end|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "3": { + "content": "<|vision_pad|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "4": { + "content": "<|image_pad|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + 
"single_word": false, + "special": true + } + }, + "additional_special_tokens": ["<|im_start|>", "<|im_end|>", "<|object_ref_start|>","<|object_ref_end|>","<|box_start|>","<|box_end|>","<|quad_start|>","<|quad_end|>","<|vision_start|>","<|vision_end|>","<|vision_pad|>","<|image_pad|>","<|video_pad|>"], + "bos_token": null, + "chat_template": "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}", + "clean_up_tokenization_spaces": false, + "eos_token": "<|im_end|>", + "padding_side": "left", + "errors": "replace", + "model_max_length": 32768, + "pad_token": "<|endoftext|>", + "split_special_tokens": false, + "unk_token": null + } \ No newline at end of file diff --git a/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json b/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json new file mode 100644 index 000000000..f760c041e --- /dev/null +++ b/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json @@ -0,0 +1,31 @@ +{ + "added_tokens_decoder": { + "0": { + "content": "<|unk|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "1": { + "content": "<|image|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + } + }, + "bos_token": "<|begin_of_text|>", + "chat_template": "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. 
#}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- Find out if there are any images #}\n{% set image_ns = namespace(has_images=false) %} \n{%- for message in messages %}\n {%- for content in message['content'] %}\n {%- if content['type'] == 'image' %}\n {%- set image_ns.has_images = true %}\n {%- endif %}\n {%- endfor %}\n{%- endfor %}\n\n{#- Error out if there are images and system message #}\n{%- if image_ns.has_images and not system_message == \"\" %}\n {{- raise_exception(\"Prompting with images is incompatible with system messages.\") }}\n{%- endif %}\n\n{#- System message if there are no images #}\n{%- if not image_ns.has_images %}\n {{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n {%- if tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n {%- endif %}\n {{- \"Cutting Knowledge Date: December 2023\\n\" }}\n {{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n {%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {%- endif %}\n {{- system_message }}\n {{- \"<|eot_id|>\" }}\n{%- endif %}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n' }}\n {%- if message['content'] is string %}\n {{- message['content'] }}\n {%- else %}\n {%- for content in message['content'] %}\n {%- if content['type'] == 'image' %}\n {{- '<|image|>' }}\n {%- elif content['type'] == 'text' %}\n {{- content['text'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {{- \"<|eot_id|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n", + "clean_up_tokenization_spaces": true, + "eos_token": "<|eot_id|>", + "model_input_names": [ + "input_ids", + "attention_mask" + ], + "model_max_length": 131072, + "pad_token": "<|finetune_right_pad_id|>", + "tokenizer_class": "PreTrainedTokenizerFast" + } \ No newline at end of file diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index 023736596..66bec37ee 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -1,7 +1,8 @@ -from test.utils import set_seed, supports_bfloat16 +from test.utils import assert_verbose_allclose, set_seed, supports_bfloat16 import pytest import torch +import torch.nn.functional as F from torch.nn import CrossEntropyLoss from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction @@ -11,8 +12,63 @@ set_seed(42) -def _test_correctness_once(target_ce, B, T, V, reduction, scalar, dtype, atol, rtol): +class CrossEntropyWithZLoss(torch.nn.Module): + def __init__( + self, + lse_square_scale=0.0, + reduction="mean", + ignore_index=-100, + label_smoothing=0.0, + return_z_loss=False, + dtype=torch.float32, + ): + super().__init__() + self.lse_square_scale = lse_square_scale + self.reduction = reduction + self.ignore_index = ignore_index + self.return_z_loss = return_z_loss + self.label_smoothing = label_smoothing + self.dtype = dtype + + def forward(self, logits, targets): + # Loss calculations are all in float32 + logits = logits.to(torch.float32) + # Standard cross entropy loss + ce_loss = F.cross_entropy( + logits, + targets, + reduction=self.reduction, + label_smoothing=self.label_smoothing, + ignore_index=self.ignore_index, + ) + + # Compute log-sum-exp term + lse = torch.logsumexp(logits, dim=-1) + + # Z-loss term + z_loss = torch.where( + targets != self.ignore_index, self.lse_square_scale * (lse**2), 0.0 + ) + z_loss = z_loss.to(logits.dtype) + if 
self.reduction == "mean": + z_loss = z_loss.sum() / (targets != self.ignore_index).sum() + elif self.reduction == "sum": + z_loss = z_loss.sum() + else: + z_loss = z_loss + ce_loss = ce_loss.to(self.dtype) + z_loss = z_loss.to(self.dtype) + + # Final loss: cross-entropy loss + Z-loss + total_loss = ce_loss + z_loss + if self.return_z_loss: + return total_loss, z_loss + else: + return total_loss + +def _test_correctness_once(target_ce, B, T, V, reduction, scalar, dtype, atol, rtol): + torch.manual_seed(0) torch_ce = CrossEntropyLoss(reduction=reduction) _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar @@ -116,11 +172,24 @@ def _test_correctness_with_label_smoothing_with_ignore_index_once( assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) -def _test_correctness_with_label_smoothing_once( - target_ce, B, T, V, label_smoothing, scalar, dtype, atol, rtol +def _test_correctness_with_z_loss_once( + target_ce, + B, + T, + V, + scalar, + dtype, + atol, + rtol, + lse_square_scale, + return_z_loss, ): torch.manual_seed(0) - torch_ce = CrossEntropyLoss(label_smoothing=label_smoothing) + torch_ce = CrossEntropyWithZLoss( + lse_square_scale=lse_square_scale, + return_z_loss=return_z_loss, + dtype=dtype, + ) _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar _input = _tensor.detach().clone().requires_grad_(True) @@ -128,22 +197,48 @@ def _test_correctness_with_label_smoothing_once( target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long) - output = torch_ce(_input, target) - output2 = target_ce(_input2, target) + if return_z_loss: + output, z_output = torch_ce(_input, target) + output2, z_output2 = target_ce(_input2, target) + + else: + output = torch_ce(_input, target) + output2 = target_ce(_input2, target) assert torch.allclose(output, output2, atol=atol, rtol=rtol) + if return_z_loss: + assert torch.allclose(z_output, z_output2, atol=atol, rtol=rtol) + output.backward() output2.backward() + assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) -def _test_correctness_with_label_smoothing_with_ignore_index_once( - target_ce, B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol +def _test_correctness_with_z_loss_with_other_params_once( + target_ce, + B, + T, + V, + scalar, + dtype, + atol, + rtol, + lse_square_scale, + return_z_loss, + label_smoothing, + ignore_index, + reduction, ): torch.manual_seed(0) - torch_ce = CrossEntropyLoss( - ignore_index=ignore_index, label_smoothing=label_smoothing + torch_ce = CrossEntropyWithZLoss( + lse_square_scale=lse_square_scale, + return_z_loss=return_z_loss, + label_smoothing=label_smoothing, + ignore_index=ignore_index, + reduction=reduction, + dtype=dtype, ) _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar @@ -161,14 +256,27 @@ def _test_correctness_with_label_smoothing_with_ignore_index_once( ] # Randomly select indices target[indices_to_assign] = ignore_index - output = torch_ce(_input, target) - output2 = target_ce(_input2, target) + if return_z_loss: + output, z_output = torch_ce(_input, target) + output2, z_output2 = target_ce(_input2, target) + + else: + output = torch_ce(_input, target) + output2 = target_ce(_input2, target) assert torch.allclose(output, output2, atol=atol, rtol=rtol) + if return_z_loss: + assert torch.allclose(z_output, z_output2, atol=atol, rtol=rtol) + output.backward() output2.backward() - assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) + print(_input.grad) + print(_input2.grad) + + 
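+        # Debugging aid: print the summed gradient difference before the strict elementwise comparison below.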
print(f"{(_input.grad - _input2.grad).sum()=}") + + assert_verbose_allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) def _test_correctness_not_last_layer_once( @@ -204,10 +312,11 @@ def _test_correctness_functional(B, T, V, scalar, dtype, atol, rtol): target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long) - y1 = liger_cross_entropy(x1, target, 0) - y2 = LigerCrossEntropyFunction.apply(x2, target, 0) + y1, y1_z = liger_cross_entropy(x1, target, 0, 1e-4, 0.1, "mean", True) + y2, y2_z = LigerCrossEntropyFunction.apply(x2, target, 0, 1e-4, 0.1, "mean", True) assert torch.allclose(y1, y2, atol=atol, rtol=rtol) + assert torch.allclose(y1_z, y2_z, atol=atol, rtol=rtol) grad = torch.randn_like(y2) @@ -225,26 +334,14 @@ def _test_correctness_functional(B, T, V, scalar, dtype, atol, rtol): @pytest.mark.parametrize( "B, T, V", [ - (2, 4096, 32000), # llama2, mistral - (2, 4096, 32000), # llama2, mistral - (1, 4096, 128256), # llama3 - # # weird shapes - (3, 423, 32000), + (2, 4096, 32000), # llama + (3, 423, 32000), # weird shapes ], ) @pytest.mark.parametrize("reduction", ["sum", "mean"]) @pytest.mark.parametrize( "scalar, dtype, atol, rtol", [ - pytest.param( - 0.1, - torch.bfloat16, - 1e-8, - 5e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), - ), pytest.param( 1.0, torch.bfloat16, @@ -254,24 +351,9 @@ def _test_correctness_functional(B, T, V, scalar, dtype, atol, rtol): not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), - pytest.param( - 10.0, - torch.bfloat16, - 1e-7, - 5e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), - ), - (0.1, torch.float32, 1e-8, 1e-6), (1.0, torch.float32, 1e-8, 1e-6), - (10.0, torch.float32, 1e-8, 1e-6), ], ) -@pytest.mark.skipif( - torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000, - reason="Needs 16GB+ GPU memory.", -) def test_correctness(B, T, V, scalar, dtype, reduction, atol, rtol): liger_ce = LigerCrossEntropyLoss(reduction=reduction) _test_correctness_once(liger_ce, B, T, V, reduction, scalar, dtype, atol, rtol) @@ -288,12 +370,8 @@ def test_correctness(B, T, V, scalar, dtype, reduction, atol, rtol): @pytest.mark.parametrize( "scalar, dtype, atol, rtol", [ - (0.1, torch.bfloat16, 1e-8, 5e-2), (1.0, torch.bfloat16, 1e-8, 5e-2), - (10.0, torch.bfloat16, 1e-7, 5e-2), - (0.1, torch.float32, 1e-8, 1e-6), (1.0, torch.float32, 1e-8, 1e-6), - (10.0, torch.float32, 1e-8, 1e-6), ], ) def test_correctness_functional(B, T, V, scalar, dtype, atol, rtol): @@ -303,9 +381,7 @@ def test_correctness_functional(B, T, V, scalar, dtype, atol, rtol): @pytest.mark.parametrize( "B, T, V, ignore_index", [ - (2, 4096, 32000, -100), # llama2, mistral - (2, 4096, 32000, 2), # llama2, mistral - (1, 4096, 128256, -300), # llama3 + (2, 4096, 32000, 2), # weird shapes (3, 423, 32000, -123), ], @@ -314,15 +390,6 @@ def test_correctness_functional(B, T, V, scalar, dtype, atol, rtol): @pytest.mark.parametrize( "scalar, dtype, atol, rtol", [ - pytest.param( - 0.1, - torch.bfloat16, - 1e-8, - 5e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), - ), pytest.param( 1.0, torch.bfloat16, @@ -332,24 +399,9 @@ def test_correctness_functional(B, T, V, scalar, dtype, atol, rtol): not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), - pytest.param( - 10.0, - torch.bfloat16, - 1e-8, - 5e-2, - marks=pytest.mark.skipif( - not 
supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), - ), - (0.1, torch.float32, 1e-8, 1e-6), (1.0, torch.float32, 1e-8, 1e-6), - (10.0, torch.float32, 1e-8, 1e-6), ], ) -@pytest.mark.skipif( - torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000, - reason="Needs 16GB+ GPU memory.", -) def test_correctness_with_ignore_index( B, T, V, ignore_index, reduction, scalar, dtype, atol, rtol ): @@ -362,9 +414,7 @@ def test_correctness_with_ignore_index( @pytest.mark.parametrize( "B, T, V, label_smoothing", [ - (2, 4096, 32000, 0.1), # llama2, mistral - (2, 4096, 32000, 0.1), # llama2, mistral - (1, 4096, 128256, 0.1), # llama3 + (2, 4096, 32000, 0.1), # weird shapes (3, 423, 32000, 0.1), ], @@ -372,15 +422,6 @@ def test_correctness_with_ignore_index( @pytest.mark.parametrize( "scalar, dtype, atol, rtol", [ - pytest.param( - 0.1, - torch.bfloat16, - 1e-8, - 5e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), - ), pytest.param( 1.0, torch.bfloat16, @@ -390,24 +431,9 @@ def test_correctness_with_ignore_index( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), - pytest.param( - 10.0, - torch.bfloat16, - 1e-8, - 5e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), - ), - (0.1, torch.float32, 1e-8, 1e-6), (1.0, torch.float32, 1e-8, 1e-6), - (10.0, torch.float32, 1e-8, 1e-6), ], ) -@pytest.mark.skipif( - torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000, - reason="Needs 16GB+ GPU memory.", -) def test_correctness_with_label_smoothing_once( B, T, V, label_smoothing, scalar, dtype, atol, rtol ): @@ -420,9 +446,7 @@ def test_correctness_with_label_smoothing_once( @pytest.mark.parametrize( "B, T, V, ignore_index, label_smoothing", [ - (2, 4096, 32000, 1, 0.1), # llama2, mistral - (2, 4096, 32000, -100, 0.2), # llama2, mistral - (1, 4096, 128256, 2, 0.1), # llama3 + (2, 4096, 32000, 1, 0.1), # weird shapes (3, 423, 32000, -300, 0.2), ], @@ -430,15 +454,6 @@ def test_correctness_with_label_smoothing_once( @pytest.mark.parametrize( "scalar, dtype, atol, rtol", [ - pytest.param( - 0.1, - torch.bfloat16, - 1e-8, - 5e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), - ), pytest.param( 1.0, torch.bfloat16, @@ -448,24 +463,9 @@ def test_correctness_with_label_smoothing_once( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), - pytest.param( - 10.0, - torch.bfloat16, - 1e-6, - 5e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), - ), - (0.1, torch.float32, 1e-8, 1e-6), (1.0, torch.float32, 1e-8, 1e-6), - (10.0, torch.float32, 1e-8, 1e-6), ], ) -@pytest.mark.skipif( - torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000, - reason="Needs 16GB+ GPU memory.", -) def test_correctness_with_label_smoothing_with_ignore_index_once( B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol ): @@ -479,27 +479,17 @@ def test_correctness_with_label_smoothing_with_ignore_index_once( @pytest.mark.parametrize( - "B, T, V, label_smoothing", + "B, T, V", [ - (2, 4096, 32000, 0.1), # llama2, mistral - (2, 4096, 32000, 0.1), # llama2, mistral - (1, 4096, 128256, 0.1), # llama3 + (2, 4096, 32000), # llama2 # weird shapes - (3, 423, 32000, 0.1), + (3, 423, 32000), ], ) +@pytest.mark.parametrize("reduction", ["sum", "mean"]) @pytest.mark.parametrize( "scalar, dtype, atol, 
rtol", [ - pytest.param( - 0.1, - torch.bfloat16, - 1e-8, - 5e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), - ), pytest.param( 1.0, torch.bfloat16, @@ -509,55 +499,57 @@ def test_correctness_with_label_smoothing_with_ignore_index_once( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), - pytest.param( - 10.0, - torch.bfloat16, - 1e-8, - 5e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), - ), - (0.1, torch.float32, 1e-8, 1e-6), (1.0, torch.float32, 1e-8, 1e-6), - (10.0, torch.float32, 1e-8, 1e-6), ], ) -@pytest.mark.skipif( - torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000, - reason="Needs 16GB+ GPU memory.", +@pytest.mark.parametrize("return_z_loss", [True, False]) +@pytest.mark.parametrize( + "lse_square_scale", + [ + 1e-4, # PaLM + 1e-5, # Chameleon + ], ) -def test_correctness_with_label_smoothing_once( - B, T, V, label_smoothing, scalar, dtype, atol, rtol +def test_correctness_with_z_loss_once( + B, + T, + V, + scalar, + dtype, + atol, + rtol, + lse_square_scale, + return_z_loss, ): - liger_ce = LigerCrossEntropyLoss(label_smoothing=label_smoothing) - _test_correctness_with_label_smoothing_once( - liger_ce, B, T, V, label_smoothing, scalar, dtype, atol, rtol + test_ce = LigerCrossEntropyLoss( + lse_square_scale=lse_square_scale, + return_z_loss=return_z_loss, + ) + _test_correctness_with_z_loss_once( + test_ce, + B, + T, + V, + scalar, + dtype, + atol, + rtol, + lse_square_scale, + return_z_loss, ) @pytest.mark.parametrize( - "B, T, V, ignore_index, label_smoothing", + "B, T, V", [ - (2, 4096, 32000, 1, 0.1), # llama2, mistral - (2, 4096, 32000, -100, 0.2), # llama2, mistral - (1, 4096, 128256, 2, 0.1), # llama3 + (2, 4096, 32000), # llama2, mistral # weird shapes - (3, 423, 32000, -300, 0.2), + (3, 423, 32000), ], ) @pytest.mark.parametrize( "scalar, dtype, atol, rtol", [ - pytest.param( - 0.1, - torch.bfloat16, - 1e-8, - 5e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), - ), pytest.param( 1.0, torch.bfloat16, @@ -567,33 +559,58 @@ def test_correctness_with_label_smoothing_once( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), - pytest.param( - 10.0, - torch.bfloat16, - 1e-8, - 5e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), - ), - (0.1, torch.float32, 1e-8, 1e-6), (1.0, torch.float32, 1e-8, 1e-6), - (10.0, torch.float32, 1e-8, 1e-6), ], ) -@pytest.mark.skipif( - torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000, - reason="Needs 16GB+ GPU memory.", +@pytest.mark.parametrize( + "return_z_loss, lse_square_scale", + [ + (True, 1e-4), + (False, 1e-5), + ], ) -def test_correctness_with_label_smoothing_with_ignore_index_once( - B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol +@pytest.mark.parametrize( + "label_smoothing, ignore_index, reduction", + [ + (0.1, 42, "mean"), + (0.2, -42, "sum"), + ], +) +def test_correctness_with_z_loss_with_other_params_once( + B, + T, + V, + scalar, + dtype, + atol, + rtol, + lse_square_scale, + return_z_loss, + label_smoothing, + ignore_index, + reduction, ): - liger_ce = LigerCrossEntropyLoss( - ignore_index=ignore_index, + test_ce = LigerCrossEntropyLoss( + lse_square_scale=lse_square_scale, + return_z_loss=return_z_loss, label_smoothing=label_smoothing, + ignore_index=ignore_index, + 
reduction=reduction, ) - _test_correctness_with_label_smoothing_with_ignore_index_once( - liger_ce, B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol + _test_correctness_with_z_loss_with_other_params_once( + test_ce, + B, + T, + V, + scalar, + dtype, + atol, + rtol, + lse_square_scale, + return_z_loss, + label_smoothing, + ignore_index, + reduction, ) @@ -601,8 +618,6 @@ def test_correctness_with_label_smoothing_with_ignore_index_once( "B, T, V", [ (2, 4096, 32000), # llama2, mistral - (2, 4096, 32000), # llama2, mistral - (1, 4096, 128256), # llama3 # # weird shapes (3, 423, 32000), ], @@ -623,52 +638,8 @@ def test_correctness_with_label_smoothing_with_ignore_index_once( (1.0, torch.float32, 1e-8, 1e-6), ], ) -@pytest.mark.skipif( - torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000, - reason="Needs 16GB+ GPU memory.", -) def test_correctness_not_last_layer(B, T, V, reduction, scalar, dtype, atol, rtol): liger_ce = LigerCrossEntropyLoss(reduction=reduction) _test_correctness_not_last_layer_once( liger_ce, B, T, V, reduction, scalar, dtype, atol, rtol ) - - -############################################################################# -# Test full pass of the liger cross entropy loss to ensure it doesn't crash -############################################################################# - - -def _full_pass_once(B, T, V, reduction): - - liger_ce = LigerCrossEntropyLoss(reduction=reduction) - - _input = torch.randn( - B * T, V, requires_grad=True, device="cuda", dtype=torch.bfloat16 - ) - target = torch.randint(V, (B * T, 1), device="cuda").squeeze(1) - - output = liger_ce(_input, target) - output.backward() - - -@pytest.mark.parametrize( - "B, T, V", - [ - ( - 8, - 8192, - 128256, - ), # _input = 16GB, total = ~32GB, 8405385216 > 2,147,483,647, so we need int64 - (8, 16384, 128256), # _input = 32GB, total = ~64GB - ], -) -@pytest.mark.parametrize("reduction", ["sum", "mean"]) -@pytest.mark.skipif( - torch.cuda.get_device_properties(0).total_memory < 64 * 1000 * 1000 * 1000, - reason="Needs 64GB+ GPU memory.", -) -def test_large_no_exception(B, T, V, reduction): - # The large inputs were hitting cuda illegal memory access because of - # https://github.com/triton-lang/triton/issues/1058 - _full_pass_once(B, T, V, reduction) diff --git a/test/transformers/test_embedding.py b/test/transformers/test_embedding.py index b192835e3..998a544c5 100644 --- a/test/transformers/test_embedding.py +++ b/test/transformers/test_embedding.py @@ -7,6 +7,7 @@ SLEEP_SECONDS = 0.1 +@pytest.mark.skip(reason="LigerEmbedding is under experimentation") @pytest.mark.parametrize( "num_embeddings, embedding_dim, padding_idx", [ diff --git a/test/transformers/test_fused_linear_cross_entropy.py b/test/transformers/test_fused_linear_cross_entropy.py index 57e2cf534..2be9c9d10 100644 --- a/test/transformers/test_fused_linear_cross_entropy.py +++ b/test/transformers/test_fused_linear_cross_entropy.py @@ -1,3 +1,4 @@ +from test.transformers.test_cross_entropy import CrossEntropyWithZLoss from test.utils import assert_verbose_allclose, set_seed import pytest @@ -22,6 +23,12 @@ class TorchLMHeadCE(torch.nn.Module): :param V: vocab size :param ignore_index: index to ignore :param reduction: reduction method + :param label_smoothing: label_smoothing to apply on target + :param lse_square_scale: scaler of lse ^ 2 to compute z loss + + # TODO: if we bump CI env's `transformers` version to >= 4.46, we should just directly + # call 
https://github.com/huggingface/transformers/blob/main/src/transformers/loss/loss_utils.py#L32 + # to be consistent with Hugging Face model implementation. """ def __init__( @@ -31,6 +38,7 @@ def __init__( dtype: torch.dtype, bias: bool = False, ignore_index: int = -100, + lse_square_scale: float = 0.0, label_smoothing: float = 0.0, reduction: str = "mean", ): @@ -38,14 +46,15 @@ def __init__( self.lin = torch.nn.Linear( in_features=H, out_features=V, bias=bias, dtype=dtype ) - self.ce_loss = torch.nn.CrossEntropyLoss( + self.ce_loss = CrossEntropyWithZLoss( ignore_index=ignore_index, - reduction=reduction, + lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, + reduction=reduction, ) def forward(self, x, y): - logits = self.lin(x) + logits = self.lin(x).to(torch.float32) return self.ce_loss(logits, y) @@ -57,6 +66,7 @@ def __init__( dtype: torch.dtype, bias: bool = False, ignore_index: int = -100, + lse_square_scale: float = 0.0, label_smoothing: float = 0.0, reduction: str = "mean", ): @@ -66,8 +76,9 @@ def __init__( ) self.ce_loss = LigerFusedLinearCrossEntropyLoss( ignore_index=ignore_index, - reduction=reduction, + lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, + reduction=reduction, ) def forward(self, x, y): @@ -82,12 +93,8 @@ def forward(self, x, y): @pytest.mark.parametrize( "B, T, H, V", [ - # (2, 4, 512, 512), # The test does not work on some CI GPUs. Issue #160 - (8, 2048, 4096, 32000), # llama2, mistral - # Comment out to speed up testing - # (4, 2048, 4096, 128256), # llama3 8B - # (4, 1024, 8192, 128256), # llama3 70B - (4, 423, 8192, 32000), # random shape + (8, 128, 1024, 4096), + (4, 47, 31, 123), # random shape ], ) @pytest.mark.parametrize( @@ -100,16 +107,36 @@ def forward(self, x, y): ], ) @pytest.mark.parametrize("bias", [True, False]) -@pytest.mark.parametrize("label_smoothing", [0, 0.1]) +@pytest.mark.parametrize( + "label_smoothing, ignore_index, lse_square_scale", + [ + (0, -100, 0), + (0.1, 42, 1e-4), # Pass non-default values once to ensure all params work along + ], +) def test_correctness( - B, T, H, V, scalar, dtype, bias, label_smoothing, reduction, atol, rtol + B, + T, + H, + V, + scalar, + dtype, + bias, + lse_square_scale, + label_smoothing, + ignore_index, + reduction, + atol, + rtol, ): device = "cuda" torch_lm_head_ce = TorchLMHeadCE( H=H, V=V, bias=bias, + lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, + ignore_index=ignore_index, reduction=reduction, dtype=dtype, ).to(device) @@ -117,7 +144,9 @@ def test_correctness( H=H, V=V, bias=bias, + lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, + ignore_index=ignore_index, reduction=reduction, dtype=dtype, ).to(device) @@ -137,6 +166,14 @@ def test_correctness( _input2 = _tensor.detach().clone().requires_grad_(True) target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint( + 1, B * T // 2, (1,) + ).item() # Random number of elements to set to ignore_index + indices_to_assign = torch.randperm(B * T)[ + :num_elements_to_assign + ] # Randomly select indices + target[indices_to_assign] = ignore_index output1 = torch_lm_head_ce(_input1, target) output2 = liger_lm_head_ce(_input2, target) @@ -203,3 +240,68 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, bias, atol, rtol): y2.backward(grad_output) assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + 
(8, 128, 1024, 4096), + (4, 47, 31, 123), # random shape + ], +) +@pytest.mark.parametrize( + "cast_dtype, atol, rtol", + [ + (torch.bfloat16, 5e-3, 5e-2), + (torch.float16, 5e-3, 5e-2), + ], +) +def test_amp(B, T, H, V, cast_dtype, atol, rtol): + device = "cuda" + dtype = torch.float32 + torch_lm_head_ce = TorchLMHeadCE( + H=H, + V=V, + bias=True, + label_smoothing=0.0, + reduction="mean", + dtype=dtype, + ).to(device) + liger_lm_head_ce = LigerLMHeadCE( + H=H, + V=V, + bias=True, + label_smoothing=0.0, + reduction="mean", + dtype=dtype, + ).to(device) + + # init the linear in all CEs with the same weights + torch_lm_head_ce.lin.weight.data = liger_lm_head_ce.lin.weight.data = torch.rand( + V, H, device=device, dtype=dtype + ) + + _tensor = torch.randn(B * T, H, device=device, dtype=dtype) + _input1 = _tensor.detach().clone().requires_grad_(True) + _input2 = _tensor.detach().clone().requires_grad_(True) + + target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) + + with torch.autocast(device_type="cuda", dtype=cast_dtype): + output1 = torch_lm_head_ce(_input1, target) + output2 = liger_lm_head_ce(_input2, target) + + assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol) + + with torch.autocast(device_type="cuda", dtype=cast_dtype): + output1.backward() + output2.backward() + + assert_verbose_allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol) + + assert_verbose_allclose( + torch_lm_head_ce.lin.weight.grad, + liger_lm_head_ce.lin.weight.grad, + atol=atol, + rtol=rtol, + ) diff --git a/test/transformers/test_fused_linear_jsd.py b/test/transformers/test_fused_linear_jsd.py new file mode 100644 index 000000000..31a3ea103 --- /dev/null +++ b/test/transformers/test_fused_linear_jsd.py @@ -0,0 +1,474 @@ +from test.transformers.test_jsd import JSD as TorchJSD +from test.utils import assert_verbose_allclose, set_seed + +import pytest +import torch + +from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction +from liger_kernel.transformers.functional import liger_fused_linear_jsd +from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD + +set_seed(42) + + +class TorchLMHeadJSD(torch.nn.Module): + """Ground truth implementation of the linear fused with torch based jsd loss. 
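+    Applies a temperature-scaled log-softmax to the student and teacher logits, then computes the beta-weighted generalized JSD between them.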
+ + :param H: hidden size + :param V: vocab size + :param temperature: softmax temperature + :param beta: jsd beta + """ + + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + device: torch.device, + beta: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, + ): + super().__init__() + self.student_lin = torch.nn.Linear( + in_features=H // 2, out_features=V, bias=False, dtype=dtype, device=device + ) + self.teacher_lin = torch.nn.Linear( + in_features=H, out_features=V, bias=False, dtype=dtype, device=device + ) + self.jsd = TorchJSD(beta=beta, ignore_index=ignore_index, dtype=dtype) + self.temperature = temperature + + def forward(self, student_input, teacher_input, label=None): + student_logits = self.student_lin(student_input).to(torch.float32) + teacher_logits = self.teacher_lin(teacher_input).to(torch.float32) + student_prob = torch.log_softmax(student_logits / self.temperature, dim=-1) + teacher_prob = torch.log_softmax(teacher_logits / self.temperature, dim=-1) + + return self.jsd(student_prob, teacher_prob, label) + + +class LigerLMHeadJSD(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + device: torch.device, + beta: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, + ): + super().__init__() + self.student_lin = torch.nn.Linear( + in_features=H // 2, out_features=V, bias=False, dtype=dtype, device=device + ) + self.teacher_lin = torch.nn.Linear( + in_features=H, out_features=V, bias=False, dtype=dtype, device=device + ) + self.fused_jsd = LigerFusedLinearJSD( + jsd_beta=beta, ignore_index=ignore_index, temperature=temperature + ) + + def forward(self, student_input, teacher_input, label=None): + return self.fused_jsd( + student_input, + self.student_lin.weight, + teacher_input, + self.teacher_lin.weight, + label, + ) + + +############################################################################# +# Test the correctness of the fused linear JSD +############################################################################# + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (4, 423, 167, 1423), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-3, 5e-2), + (1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize( + "temperature, beta", + [ + (1.0, 0.5), + (2.0, 0.1), + ], +) +def test_correctness(B, T, H, V, scalar, dtype, beta, temperature, atol, rtol): + device = "cuda" + torch_lm_head_jsd = TorchLMHeadJSD( + H=H, + V=V, + dtype=dtype, + device=device, + temperature=temperature, + beta=beta, + ).to(device) + liger_lm_head_jsd = LigerLMHeadJSD( + H=H, + V=V, + dtype=dtype, + device=device, + temperature=temperature, + beta=beta, + ).to(device) + + # init the linear in all FusedLinearJSDs with the same weights + torch_lm_head_jsd.student_lin.weight.data = ( + liger_lm_head_jsd.student_lin.weight.data + ) = torch.rand(V, H // 2, device=device, dtype=dtype) + torch_lm_head_jsd.teacher_lin.weight.data = ( + liger_lm_head_jsd.teacher_lin.weight.data + ) = torch.rand(V, H, device=device, dtype=dtype) + + _tensor = torch.rand(B * T, H // 2, device=device, dtype=dtype) * scalar + _input1 = _tensor.detach().clone().requires_grad_(True) + _input2 = _tensor.detach().clone().requires_grad_(True) + + teacher_input = torch.rand(B * T, H, device=device, dtype=dtype) * scalar + + with torch.autograd.detect_anomaly(): + output1 = torch_lm_head_jsd(_input1, teacher_input) + output2 = liger_lm_head_jsd(_input2, 
teacher_input) + + assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol) + + output1.backward() + output2.backward() + + assert_verbose_allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol) + + assert_verbose_allclose( + torch_lm_head_jsd.student_lin.weight.grad, + liger_lm_head_jsd.student_lin.weight.grad, + atol=atol, + rtol=rtol, + ) + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (4, 423, 167, 1423), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-3, 5e-2), + (1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize( + "temperature, beta, ignore_index", + [ + (1.0, 0.5, 2), + (2.0, 0.1, 42), + ], +) +def test_correctness_with_ignore_index( + B, T, H, V, scalar, dtype, beta, ignore_index, temperature, atol, rtol +): + device = "cuda" + torch_lm_head_jsd = TorchLMHeadJSD( + H=H, + V=V, + dtype=dtype, + device=device, + temperature=temperature, + ignore_index=ignore_index, + beta=beta, + ).to(device) + liger_lm_head_jsd = LigerLMHeadJSD( + H=H, + V=V, + dtype=dtype, + device=device, + temperature=temperature, + ignore_index=ignore_index, + beta=beta, + ).to(device) + + # init the linear in all FusedLinearJSDs with the same weights + torch_lm_head_jsd.student_lin.weight.data = ( + liger_lm_head_jsd.student_lin.weight.data + ) = torch.rand(V, H // 2, device=device, dtype=dtype) + torch_lm_head_jsd.teacher_lin.weight.data = ( + liger_lm_head_jsd.teacher_lin.weight.data + ) = torch.rand(V, H, device=device, dtype=dtype) + + _tensor = torch.rand(B * T, H // 2, device=device, dtype=dtype) * scalar + _input1 = _tensor.detach().clone().requires_grad_(True) + _input2 = _tensor.detach().clone().requires_grad_(True) + + teacher_input = torch.rand(B * T, H, device=device, dtype=dtype) * scalar + + label = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) + + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint( + 1, B * T // 2, (1,) + ).item() # Random number of elements to set to ignore_index + indices_to_assign = torch.randperm(B * T)[ + :num_elements_to_assign + ] # Randomly select indices + label[indices_to_assign] = ignore_index + + output1 = torch_lm_head_jsd(_input1, teacher_input, label) + output2 = liger_lm_head_jsd(_input2, teacher_input, label) + + assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol) + + output1.backward() + output2.backward() + + assert_verbose_allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol) + + assert_verbose_allclose( + torch_lm_head_jsd.student_lin.weight.grad, + liger_lm_head_jsd.student_lin.weight.grad, + atol=atol, + rtol=rtol, + ) + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (2, 2, 8, 8), + # weird shapes + (9, 7, 41, 41), + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (0.5, torch.bfloat16, 5e-3, 5e-2), + (0.5, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize( + "temperature, beta, ignore_index", [(1.0, 0.5, -100), (2.0, 0.1, 42)] +) +def test_correctness_functional( + B, T, H, V, scalar, dtype, beta, ignore_index, temperature, atol, rtol +): + device = "cuda" + + # init the linear in all FusedLinearJSDs with the same weights + _weight = torch.rand(V, H // 2, device=device, dtype=dtype) + _weight1 = _weight.detach().clone().requires_grad_(True) + _weight2 = _weight.detach().clone().requires_grad_(True) + teacher_weight = torch.rand(V, H, device=device, dtype=dtype) + + _tensor = torch.rand(B * T, H // 2, device=device, 
dtype=dtype) * scalar + _input1 = _tensor.detach().clone().requires_grad_(True) + _input2 = _tensor.detach().clone().requires_grad_(True) + teacher_input = torch.rand(B * T, H, device=device, dtype=dtype) * scalar + + label = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) + + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint( + 1, B * T // 2, (1,) + ).item() # Random number of elements to set to ignore_index + indices_to_assign = torch.randperm(B * T)[ + :num_elements_to_assign + ] # Randomly select indices + label[indices_to_assign] = ignore_index + + output1 = liger_fused_linear_jsd( + _input1, + _weight1, + teacher_input, + teacher_weight, + label, + beta, + ignore_index, + temperature, + ) + output2 = LigerFusedLinearJSDFunction.apply( + _input2, + _weight2, + teacher_input, + teacher_weight, + label, + beta, + ignore_index, + temperature, + ) + + assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol) + + output1.backward() + output2.backward() + + assert_verbose_allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol) + + assert_verbose_allclose(_weight1.grad, _weight2.grad, atol=atol, rtol=rtol) + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (4, 423, 167, 1423), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-3, 5e-2), + (1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize( + "temperature, beta, ignore_index", + [ + (1.0, 0.5, 2), + (2.0, 0.1, 42), + ], +) +def test_correctness_all_ignored( + B, T, H, V, scalar, dtype, beta, ignore_index, temperature, atol, rtol +): + device = "cuda" + torch_lm_head_jsd = TorchLMHeadJSD( + H=H, + V=V, + dtype=dtype, + device=device, + temperature=temperature, + ignore_index=ignore_index, + beta=beta, + ).to(device) + liger_lm_head_jsd = LigerLMHeadJSD( + H=H, + V=V, + dtype=dtype, + device=device, + temperature=temperature, + ignore_index=ignore_index, + beta=beta, + ).to(device) + + # init the linear in all FusedLinearJSDs with the same weights + torch_lm_head_jsd.student_lin.weight.data = ( + liger_lm_head_jsd.student_lin.weight.data + ) = torch.rand(V, H // 2, device=device, dtype=dtype) + torch_lm_head_jsd.teacher_lin.weight.data = ( + liger_lm_head_jsd.teacher_lin.weight.data + ) = torch.rand(V, H, device=device, dtype=dtype) + + _tensor = torch.rand(B * T, H // 2, device=device, dtype=dtype) * scalar + _input1 = _tensor.detach().clone().requires_grad_(True) + _input2 = _tensor.detach().clone().requires_grad_(True) + + teacher_input = torch.rand(B * T, H, device=device, dtype=dtype) * scalar + + label = torch.full((B * T,), ignore_index, device=device, dtype=torch.long) + + output1 = torch_lm_head_jsd(_input1, teacher_input, label) + output2 = liger_lm_head_jsd(_input2, teacher_input, label) + + assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol) + assert_verbose_allclose(output2, torch.zeros_like(output2), atol=atol, rtol=rtol) + + output2.backward() + + assert_verbose_allclose( + torch.zeros_like(_input2.grad), _input2.grad, atol=atol, rtol=rtol + ) + + +@pytest.mark.parametrize( + "autocast_dtype, atol, rtol", + [ + (torch.bfloat16, 5e-3, 5e-2), + (torch.float16, 5e-3, 5e-2), + ], +) +def test_amp(autocast_dtype, atol, rtol): + B = 2 + T = 4 + H = 2048 + V = 3200 + scalar = 1.0 + ignore_index = -100 + temperature = 1.0 + beta = 0.5 + device = "cuda" + dtype = torch.float32 + torch_lm_head_jsd = TorchLMHeadJSD( + H=H, + V=V, + dtype=dtype, + 
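+        # Module weights stay in float32 here; torch.autocast casts the compute inside the forward pass to the autocast dtype.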
device=device, + temperature=temperature, + ignore_index=ignore_index, + beta=beta, + ).to(device) + liger_lm_head_jsd = LigerLMHeadJSD( + H=H, + V=V, + dtype=dtype, + device=device, + temperature=temperature, + ignore_index=ignore_index, + beta=beta, + ).to(device) + # init the linear in all FusedLinearJSDs with the same weights + torch_lm_head_jsd.student_lin.weight.data = ( + liger_lm_head_jsd.student_lin.weight.data + ) = torch.rand(V, H // 2, device=device, dtype=dtype) + torch_lm_head_jsd.teacher_lin.weight.data = ( + liger_lm_head_jsd.teacher_lin.weight.data + ) = torch.rand(V, H, device=device, dtype=dtype) + + _tensor = torch.rand(B * T, H // 2, device=device, dtype=autocast_dtype) * scalar + _input1 = _tensor.detach().clone().requires_grad_(True) + _input2 = _tensor.detach().clone().requires_grad_(True) + + teacher_input = torch.rand(B * T, H, device=device, dtype=autocast_dtype) * scalar + + label = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) + + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint( + 1, B * T // 2, (1,) + ).item() # Random number of elements to set to ignore_index + indices_to_assign = torch.randperm(B * T)[ + :num_elements_to_assign + ] # Randomly select indices + label[indices_to_assign] = ignore_index + + with torch.autocast(device_type="cuda", dtype=autocast_dtype): + output1 = torch_lm_head_jsd(_input1, teacher_input, label) + output2 = liger_lm_head_jsd(_input2, teacher_input, label) + + assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol) + + output1.backward() + output2.backward() + + assert_verbose_allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol) + assert_verbose_allclose( + torch_lm_head_jsd.student_lin.weight.grad, + liger_lm_head_jsd.student_lin.weight.grad, + atol=atol, + rtol=rtol, + ) diff --git a/test/transformers/test_geglu.py b/test/transformers/test_geglu.py index 4fa744656..cf7c5a3c5 100644 --- a/test/transformers/test_geglu.py +++ b/test/transformers/test_geglu.py @@ -20,11 +20,9 @@ @pytest.mark.parametrize( "bsz, seq_len, hidden_size, intermediate_size", [ - (2, 2048, 4096, 11008), (2, 2048, 2048, 4096), # weird shapes (9, 41, 341, 4231), - (6, 42, 256, 2048), ], ) @pytest.mark.parametrize( diff --git a/test/transformers/test_jsd.py b/test/transformers/test_jsd.py new file mode 100644 index 000000000..388b3a5c3 --- /dev/null +++ b/test/transformers/test_jsd.py @@ -0,0 +1,329 @@ +from test.utils import assert_verbose_allclose, set_seed, supports_bfloat16 +from typing import Optional + +import pytest +import torch +from torch.nn import KLDivLoss + +from liger_kernel.transformers.functional import liger_jsd +from liger_kernel.transformers.jsd import LigerJSD, LigerJSDFunction + +set_seed(42) + + +class JSD(torch.nn.Module): + def __init__( + self, + beta: float = 0.5, + ignore_index: int = -100, + dtype: torch.dtype = torch.float, + ): + super(JSD, self).__init__() + self.kl = KLDivLoss(reduction="none", log_target=True) + self.beta = beta + self.ignore_index = ignore_index + self.dtype = dtype + + def forward( + self, + log_q: torch.Tensor, # input + log_p: torch.Tensor, # target + label: Optional[torch.Tensor] = None, + ): + log_p, log_q = log_p.to(torch.float), log_q.to(torch.float) + log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1)) + m = torch.lerp(torch.exp(log_q), torch.exp(log_p), self.beta) + loss = self.beta * self.kl(torch.log(m), log_p).sum(dim=-1) + ( + 1 - self.beta + ) * self.kl(torch.log(m), log_q).sum(dim=-1) + + if 
label is not None: + loss = torch.where(label != self.ignore_index, loss, 0.0) + n_non_ignore = (label != self.ignore_index).sum().item() + if n_non_ignore == 0: + loss = torch.tensor(0.0).to(loss.device) + else: + loss = (loss / n_non_ignore).sum() + else: + loss = (loss / log_q.shape[0]).sum() + return loss.to(self.dtype) + + +_SHAPE_PARAMS = ( + "B, T, V", + [ + (2, 1024, 3200), + # weird shape + (41, 401, 1271), + ], +) + +_DTYPE_PARAMS = ( + "dtype, atol, rtol", + [ + pytest.param( + torch.bfloat16, + 1e-8, + 5e-2, + marks=pytest.mark.skipif( + not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + ), + ), + (torch.float32, 1e-8, 1e-6), + (torch.float16, 1e-3, 1e-3), + ], +) + + +def _test_correctness_once( + target_jsd, + B, + T, + V, + dtype, + atol, + rtol, + is_last_layer=True, + device="cuda", +): + torch_jsd = JSD(dtype=dtype) + + input = torch.randn( + B * T, V, device=device, dtype=dtype, requires_grad=True + ).log_softmax(dim=-1) + + x1 = input.detach().clone().requires_grad_(True) + x2 = input.detach().clone().requires_grad_(True) + x3 = input.detach().clone().requires_grad_(True) + + with torch.no_grad(): + target = torch.randn(B * T, V, dtype=dtype, device=device).log_softmax(dim=-1) + + output = torch_jsd(x1, target) + output2 = target_jsd(x2, target) + assert torch.allclose(output, output2, atol=atol, rtol=rtol) + # symmetry + output3 = target_jsd(target, x3) + assert torch.allclose(output3, output2, atol=atol, rtol=rtol) + if ( + not is_last_layer + ): # if the loss is the last layer, grad_output is 1.0 and mul op is skipped, testing for that reason + output = output * 2.0 + output2 = output2 * 2.0 + + output.backward() + output2.backward() + assert_verbose_allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) + + +def _test_correctness_with_beta_once( + target_jsd, + beta, + B, + T, + V, + dtype, + atol, + rtol, + is_last_layer=True, + device="cuda", +): + torch_jsd = JSD(beta=beta, dtype=dtype) + + input = torch.randn( + B * T, V, device=device, dtype=dtype, requires_grad=True + ).log_softmax(dim=-1) + + x1 = input.detach().clone().requires_grad_(True) + x2 = input.detach().clone().requires_grad_(True) + + with torch.no_grad(): + target = torch.randn(B * T, V, dtype=dtype, device=device).log_softmax(dim=-1) + + output = torch_jsd(x1, target) + output2 = target_jsd(x2, target) + assert_verbose_allclose(output, output2, atol=atol, rtol=rtol) + if ( + not is_last_layer + ): # if the loss is the last layer, grad_output is 1.0 and mul op is skipped, testing for that reason + output = output * 2.0 + output2 = output2 * 2.0 + + output.backward() + output2.backward() + assert_verbose_allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) + + +def _test_correctness_with_ignore_index_once( + target_jsd, + ignore_index, + B, + T, + V, + dtype, + atol, + rtol, + device="cuda", +): + torch_jsd = JSD(ignore_index=ignore_index, dtype=dtype) + + input = torch.randn( + B * T, V, device=device, dtype=dtype, requires_grad=True + ).log_softmax(dim=-1) + + x1 = input.detach().clone().requires_grad_(True) + x2 = input.detach().clone().requires_grad_(True) + + with torch.no_grad(): + target = torch.randn(B * T, V, dtype=dtype, device=device).log_softmax(dim=-1) + + label = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) + + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint( + 1, B * T // 2, (1,) + ).item() # Random number of elements to set to ignore_index + indices_to_assign = torch.randperm(B * T)[ + 
:num_elements_to_assign + ] # Randomly select indices + label[indices_to_assign] = ignore_index + + output = torch_jsd(x1, target, label) + output2 = target_jsd(x2, target, label) + assert_verbose_allclose(output, output2, atol=atol, rtol=rtol) + + output.backward() + output2.backward() + assert_verbose_allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) + + +def _test_correctness_functional( + B, T, V, beta, ignore_index, is_last_layer, dtype, atol, rtol, device="cuda" +): + input = torch.randn( + B * T, V, device=device, dtype=dtype, requires_grad=True + ).log_softmax(dim=-1) + + x1 = input.detach().clone().requires_grad_(True) + x2 = input.detach().clone().requires_grad_(True) + + with torch.no_grad(): + target = torch.randn(B * T, V, dtype=dtype, device=device).log_softmax(dim=-1) + + label = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) + + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint( + 1, B * T // 2, (1,) + ).item() # Random number of elements to set to ignore_index + indices_to_assign = torch.randperm(B * T)[ + :num_elements_to_assign + ] # Randomly select indices + label[indices_to_assign] = ignore_index + + output = LigerJSDFunction.apply(x1, target, label, beta, ignore_index) + output2 = liger_jsd(x2, target, label, beta, ignore_index) + assert torch.allclose(output, output2, atol=atol, rtol=rtol) + if ( + not is_last_layer + ): # if the loss is the last layer, grad_output is 1.0 and mul op is skipped, testing for that reason + output = output * 2.0 + output2 = output2 * 2.0 + output.backward() + output2.backward() + assert_verbose_allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) + + +@pytest.mark.parametrize(*_SHAPE_PARAMS) +@pytest.mark.parametrize(*_DTYPE_PARAMS) +def test_correctness(B, T, V, dtype, atol, rtol): + liger_jsd = LigerJSD() + _test_correctness_once(liger_jsd, B, T, V, dtype, atol, rtol) + + +@pytest.mark.parametrize(*_SHAPE_PARAMS) +@pytest.mark.parametrize(*_DTYPE_PARAMS) +def test_correctness_not_last(B, T, V, dtype, atol, rtol): + liger_jsd = LigerJSD() + + _test_correctness_once(liger_jsd, B, T, V, dtype, atol, rtol, is_last_layer=False) + + +@pytest.mark.parametrize(*_SHAPE_PARAMS) +@pytest.mark.parametrize(*_DTYPE_PARAMS) +@pytest.mark.parametrize("beta", [0.1, 0.5, 0.9]) +def test_correctness_with_beta(B, T, V, beta, dtype, atol, rtol): + liger_jsd = LigerJSD(beta=beta) + _test_correctness_with_beta_once(liger_jsd, beta, B, T, V, dtype, atol, rtol) + + +@pytest.mark.parametrize(*_SHAPE_PARAMS) +@pytest.mark.parametrize(*_DTYPE_PARAMS) +@pytest.mark.parametrize("ignore_index", [2, 42]) +def test_correctness_with_ignore_index(B, T, V, ignore_index, dtype, atol, rtol): + liger_jsd = LigerJSD(ignore_index=ignore_index) + _test_correctness_with_ignore_index_once( + liger_jsd, ignore_index, B, T, V, dtype, atol, rtol + ) + + +@pytest.mark.parametrize(*_SHAPE_PARAMS) +@pytest.mark.parametrize(*_DTYPE_PARAMS) +@pytest.mark.parametrize( + "beta, ignore_index, is_last_layer", + [ + (0.5, 2, False), + (0.1, 42, True), + ], +) +def test_correctness_functional( + B, T, V, beta, ignore_index, is_last_layer, dtype, atol, rtol +): + _test_correctness_functional( + B, T, V, beta, ignore_index, is_last_layer, dtype, atol, rtol + ) + + +# @pytest.mark.parametrize(*_SHAPE_PARAMS) +def test_correctness_with_all_indices_ignored( + B=2, + T=10, + V=32, + dtype=torch.bfloat16, + atol=1e-3, + rtol=1e-3, + device="cuda", +): + ignore_index = -100 + torch_jsd = JSD(ignore_index=ignore_index, dtype=dtype) + liger_jsd 
= LigerJSD(ignore_index=ignore_index) + + inp = torch.randn( + B * T, V, device=device, dtype=dtype, requires_grad=True + ).log_softmax(dim=-1) + + x1 = inp.detach().clone().requires_grad_(True) + x2 = inp.detach().clone().requires_grad_(True) + + with torch.no_grad(): + target = torch.randn(B * T, V, dtype=dtype, device=device).log_softmax(dim=-1) + + # label = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) + label = torch.full((B * T,), ignore_index, device=device, dtype=torch.long) + + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint( + 1, B * T // 2, (1,) + ).item() # Random number of elements to set to ignore_index + indices_to_assign = torch.randperm(B * T)[ + :num_elements_to_assign + ] # Randomly select indices + label[indices_to_assign] = ignore_index + + output = torch_jsd(x1, target, label) + output2 = liger_jsd(x2, target, label) + assert_verbose_allclose(output, output2, atol=atol, rtol=rtol) + assert_verbose_allclose(torch.zeros_like(output2), output2, atol=atol, rtol=rtol) + + output2.backward() + assert_verbose_allclose(torch.zeros_like(x2.grad), x2.grad, atol=atol, rtol=rtol) diff --git a/test/transformers/test_kl_div.py b/test/transformers/test_kl_div.py index db29047c7..5cc3eba6a 100644 --- a/test/transformers/test_kl_div.py +++ b/test/transformers/test_kl_div.py @@ -1,4 +1,4 @@ -from test.utils import assert_verbose_allclose, supports_bfloat16 +from test.utils import supports_bfloat16 import pytest import torch @@ -10,20 +10,8 @@ "B, T, V", [ (1, 4096, 32000), - (32, 4096, 1024), # weird shape (41, 401, 1271), - pytest.param( - 1, - 4096, - 128256, - marks=pytest.mark.skipif( - torch.cuda.get_device_properties(0).total_memory - < 36 * 1000 * 1000 * 1000, - reason="This test requires a GPU with at least 36GB of memory", - ), - ), - (3, 423, 32000), ], ) @@ -72,7 +60,7 @@ def _test_correctness_once( output = torch_kldiv(x1, target) output2 = target_kldiv(x2, target) - assert_verbose_allclose(output, output2, atol=atol, rtol=rtol) + assert torch.allclose(output, output2, atol=atol, rtol=rtol) if ( not is_last_layer @@ -85,12 +73,12 @@ def _test_correctness_once( output.backward() output2.backward() - assert_verbose_allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) + assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) @pytest.mark.parametrize(*_SHAPE_PARAMS) -@pytest.mark.parametrize("log_target", [False, True]) -@pytest.mark.parametrize("reduction", ["none", "batchmean", "mean", "sum"]) +@pytest.mark.parametrize("log_target", [True, False]) +@pytest.mark.parametrize("reduction", ["batchmean", "sum", "mean", "none"]) @pytest.mark.parametrize(*_DTYPE_PARAMS) def test_correctness(B, T, V, log_target, reduction, dtype, atol, rtol): liger_kldiv = LigerKLDIVLoss(reduction=reduction, log_target=log_target) @@ -100,8 +88,8 @@ def test_correctness(B, T, V, log_target, reduction, dtype, atol, rtol): @pytest.mark.parametrize(*_SHAPE_PARAMS) -@pytest.mark.parametrize("log_target", [False, True]) -@pytest.mark.parametrize("reduction", ["none", "batchmean", "mean", "sum"]) +@pytest.mark.parametrize("log_target", [True, False]) +@pytest.mark.parametrize("reduction", ["batchmean", "sum", "mean", "none"]) @pytest.mark.parametrize(*_DTYPE_PARAMS) def test_correctness_not_last(B, T, V, log_target, reduction, dtype, atol, rtol): liger_kldiv = LigerKLDIVLoss(reduction=reduction, log_target=log_target) diff --git a/test/transformers/test_layer_norm.py b/test/transformers/test_layer_norm.py index 840fd1155..3132c0d50 100644 
--- a/test/transformers/test_layer_norm.py +++ b/test/transformers/test_layer_norm.py @@ -7,20 +7,10 @@ @pytest.mark.parametrize( - "hidden_size", + "batch_size, seq_len, hidden_size", [ - 64, - 128, - 256, - 512, - ], -) -@pytest.mark.parametrize( - "batch_size, seq_len", - [ - (2, 8), - (4, 16), - (8, 32), + (2, 8, 64), + (4, 16, 128), ], ) @pytest.mark.parametrize( @@ -33,9 +23,11 @@ def test_liger_layer_norm(batch_size, seq_len, hidden_size, dtype, atol, rtol): torch.manual_seed(0) - x = torch.randn( - batch_size, seq_len, hidden_size, dtype=dtype, device="cuda", requires_grad=True - ) + x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda") + + liger_x = x.clone().requires_grad_(True) + torch_x = x.clone().requires_grad_(True) + liger_ln = LigerLayerNorm(hidden_size, eps=1e-6).to(dtype).cuda() torch_ln = torch.nn.LayerNorm(hidden_size, eps=1e-6).to(dtype).cuda() @@ -43,8 +35,8 @@ def test_liger_layer_norm(batch_size, seq_len, hidden_size, dtype, atol, rtol): torch_ln.weight.copy_(liger_ln.weight) torch_ln.bias.copy_(liger_ln.bias) - liger_output = liger_ln(x) - torch_output = torch_ln(x) + liger_output = liger_ln(liger_x) + torch_output = torch_ln(torch_x) assert torch.allclose(liger_output, torch_output, atol=atol, rtol=rtol) @@ -52,7 +44,7 @@ def test_liger_layer_norm(batch_size, seq_len, hidden_size, dtype, atol, rtol): liger_output.backward(grad_output, retain_graph=True) torch_output.backward(grad_output, retain_graph=True) - assert torch.allclose(x.grad, x.grad, atol=atol, rtol=rtol) + assert torch.allclose(liger_x.grad, torch_x.grad, atol=atol, rtol=rtol) assert torch.allclose( liger_ln.weight.grad, torch_ln.weight.grad, atol=atol, rtol=rtol ) @@ -60,14 +52,10 @@ def test_liger_layer_norm(batch_size, seq_len, hidden_size, dtype, atol, rtol): @pytest.mark.parametrize( - "hidden_size", - [8, 41], -) -@pytest.mark.parametrize( - "batch_size, seq_len", + "batch_size, seq_len, hidden_size", [ - (2, 2), - (9, 7), + (2, 8, 64), + (4, 16, 128), ], ) @pytest.mark.parametrize( diff --git a/test/transformers/test_mm_int8int2.py b/test/transformers/test_mm_int8int2.py new file mode 100644 index 000000000..d7d13a958 --- /dev/null +++ b/test/transformers/test_mm_int8int2.py @@ -0,0 +1,106 @@ +import pytest +import torch + +from liger_kernel.ops.experimental.mm_int8int2 import ( + matmul, + pack_weights, + unpack_weights, +) + + +# input_features = size*4 when the weight matrix is unpacked +@pytest.mark.skip(reason="mm_int8int2 is under experimentation") +@pytest.mark.parametrize( + "size", + [ + 2048, + 1024, + 512, + ], +) +@pytest.mark.parametrize( + "batch_size", + [1, 2, 3, 8], +) +@pytest.mark.parametrize( + "seq_len", + [1, 7, 16, 2048], +) +@pytest.mark.parametrize( + "out_features", + [ + 1024, + 2048, + 4096, + 10000, + ], +) +@pytest.mark.parametrize( + "atol, rtol, device", + [ + (1e-2, 1e-2, "cuda"), + ], +) +def test_kernel_correctness( + batch_size, seq_len, out_features, size, atol, rtol, device +): + print(f"\nTesting kernel with size: {size}, atol: {atol}, rtol: {rtol}") + + # Generate the random tensors + ht = torch.randint( + -127, 127, (batch_size, seq_len, size * 4), device=device, dtype=torch.int8 + ) + u = torch.randint(0, 255, (out_features, size), device=device, dtype=torch.uint8) + + # Calculate dimensions + B, M, N = ht.size() + + # Compute triton output + triton_output = matmul(ht.view(B * M, N), u.T.contiguous()).view(B, M, -1) + + # Unpack weights and compute torch output + unpacked = unpack_weights(u.T, bits=2).T + torch_output = 
torch.matmul( + ht.to(torch.float32), unpacked.T.contiguous().to(torch.float32) + ) + + # Print the results (optional, can be commented out) + print("triton_output =", triton_output) + print("torch_output =", torch_output) + + # Check if outputs are close within the given tolerances + assert torch.allclose( + triton_output, torch_output.to(torch.int32), atol=atol, rtol=rtol + ), "Results differ" + + +@pytest.mark.skip(reason="mm_int8int2 is under experimentation") +@pytest.mark.parametrize( + "size", + [ + 2048, + 1024, + 512, + ], +) +@pytest.mark.parametrize( + "out_features", + [ + 1024, + 2048, + 4096, + 10000, + ], +) +@pytest.mark.parametrize( + "device", + [ + "cuda", + ], +) +def test_unpack_pack_correctness(out_features, size, device): + u = torch.randint(0, 255, (out_features, size), device=device, dtype=torch.uint8) + + assert ( + pack_weights(unpack_weights(u.T), 2) == u.T + ).all(), "Packed weights do not match original weights." diff --git a/test/transformers/test_monkey_patch.py b/test/transformers/test_monkey_patch.py index bdb6ee11e..c62ea3575 100644 --- a/test/transformers/test_monkey_patch.py +++ b/test/transformers/test_monkey_patch.py @@ -15,6 +15,7 @@ LigerSwiGLUMLP, monkey_patch, ) +from liger_kernel.transformers.layer_norm import LigerLayerNorm from liger_kernel.transformers.monkey_patch import ( MODEL_TYPE_TO_APPLY_LIGER_FN, _apply_liger_kernel, @@ -31,6 +32,7 @@ def test_import_from_root(): apply_liger_kernel_to_llama, apply_liger_kernel_to_mistral, apply_liger_kernel_to_mixtral, + apply_liger_kernel_to_mllama, apply_liger_kernel_to_phi3, apply_liger_kernel_to_qwen2, apply_liger_kernel_to_qwen2_vl, @@ -81,16 +83,90 @@ def dummy_apply_liger_kernal_to_llama( with patch.dict(MODEL_TYPE_TO_APPLY_LIGER_FN, {"llama": mock_llama}): mock_llama.__signature__ = apply_liger_kernal_to_llama_sig - _apply_liger_kernel( - "llama", + ( + _apply_liger_kernel( + "llama", + rope=False, + fused_linear_cross_entropy=False, + cross_entropy=True, + foobar=True, + barbaz=False, + ), + ) + mock_llama.assert_called_once() + mock_llama.assert_called_once_with( rope=False, fused_linear_cross_entropy=False, cross_entropy=True, - foobar=True, - barbaz=False, - ), + ) + + +def test_apply_liger_kernel_to_instance_no_supported_model_type(): + # Test that calling _apply_liger_kernel_to_instance with an unsupported model type is a no-op + mock_mistral = Mock() + mock_unknown_model = MagicMock(spec=PreTrainedModel) + mock_unknown_model.config = {"model_type": "foobar"} + + with patch.dict(MODEL_TYPE_TO_APPLY_LIGER_FN, {"mistral": mock_mistral}): + _apply_liger_kernel_to_instance(model=mock_unknown_model) + MODEL_TYPE_TO_APPLY_LIGER_FN["mistral"].assert_not_called() + + +def test_apply_liger_kernel_to_instance_only_supported_model_type_called(): + # Test that liger kernel is applied only to the specified model + mock_gemma = Mock() + mock_llama = Mock() + mock_mistral = Mock() + + mock_llama_model_instance = MagicMock(spec=PreTrainedModel) + mock_llama_model_instance.config = MagicMock(spec=PretrainedConfig) + mock_llama_model_instance.config.model_type = "llama" + + with patch.dict( + MODEL_TYPE_TO_APPLY_LIGER_FN, + {"gemma": mock_gemma, "llama": mock_llama, "mistral": mock_mistral}, + ): + _apply_liger_kernel_to_instance(model=mock_llama_model_instance) + mock_llama.assert_called_once() + mock_gemma.assert_not_called() + mock_mistral.assert_not_called() + + +def test_apply_liger_kernel_to_instance_only_passes_valid_kwargs(): + # Test that keyword args that are not valid for the apply_liger_* 
function are not passed + mock_llama = Mock() + + mock_llama_model_instance = MagicMock(spec=PreTrainedModel) + mock_llama_model_instance.config = MagicMock(spec=PretrainedConfig) + mock_llama_model_instance.config.model_type = "llama" + + def dummy_apply_liger_kernel_to_llama( + rope=False, + cross_entropy=False, + fused_linear_cross_entropy=True, + rms_norm=True, + swiglu=True, + model=None, + ): + pass + + apply_liger_kernel_to_llama_sig = signature(dummy_apply_liger_kernel_to_llama) + + with patch.dict(MODEL_TYPE_TO_APPLY_LIGER_FN, {"llama": mock_llama}): + mock_llama.__signature__ = apply_liger_kernel_to_llama_sig + ( + _apply_liger_kernel_to_instance( + model=mock_llama_model_instance, + rope=False, + fused_linear_cross_entropy=False, + cross_entropy=True, + foobar=True, + barbaz=False, + ), + ) mock_llama.assert_called_once() mock_llama.assert_called_once_with( + model=mock_llama_model_instance, rope=False, fused_linear_cross_entropy=False, cross_entropy=True, @@ -199,7 +275,6 @@ def test_patching_apis_support_patching_model_instance(): def test_apply_liger_kernel_to_instance_for_llama(): # Ensure any monkey patching is cleaned up for subsequent tests with patch("transformers.models.llama.modeling_llama"): - # Instantiate a dummy model config = transformers.models.llama.configuration_llama.LlamaConfig( torch_dtype=torch.bfloat16, @@ -211,28 +286,213 @@ def test_apply_liger_kernel_to_instance_for_llama(): ) dummy_model_instance = AutoModelForCausalLM.from_config(config) + # Check that model instance variables are not yet patched with Liger modules + assert inspect.getsource( + dummy_model_instance.model.norm.forward + ) != inspect.getsource(LigerRMSNorm.forward) + for layer in dummy_model_instance.model.layers: + assert inspect.getsource(layer.mlp.forward) != inspect.getsource( + LigerSwiGLUMLP.forward + ) + assert inspect.getsource( + layer.input_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) + + # Test applying kernels to the model instance + _apply_liger_kernel_to_instance(model=dummy_model_instance) + + # Check that the model's instance variables were correctly patched with Liger modules + assert inspect.getsource( + dummy_model_instance.model.norm.forward + ) == inspect.getsource(LigerRMSNorm.forward) + for layer in dummy_model_instance.model.layers: + assert inspect.getsource(layer.mlp.forward) == inspect.getsource( + LigerSwiGLUMLP.forward + ) + assert inspect.getsource( + layer.input_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) + + +def test_apply_liger_kernel_to_instance_for_mllama_for_conditional_generation(): + # Ensure any monkey patching is cleaned up for subsequent tests + with patch("transformers.models.mllama.modeling_mllama"): + from transformers.models.mllama.modeling_mllama import ( + MllamaForConditionalGeneration, + ) + + # Instantiate a dummy model + config = transformers.models.mllama.configuration_mllama.MllamaConfig( + torch_dtype=torch.bfloat16, + text_config=transformers.models.mllama.configuration_mllama.MllamaTextConfig( + rms_norm_eps=1e-5, + hidden_size=32, + intermediate_size=64, + hidden_act="silu", + num_hidden_layers=2, + rope_scaling=dict( + factor=8.0, + high_freq_factor=4.0, + low_freq_factor=1.0, + original_max_position_embeddings=8192, + rope_type="llama3", + ), + ), + 
vision_config=transformers.models.mllama.configuration_mllama.MllamaVisionConfig( + rms_norm_eps=1e-5, + hidden_size=32, + intermediate_size=64, + hidden_act="gelu", + num_hidden_layers=2, + vision_output_dim=64, + ), + ) + dummy_model_instance = MllamaForConditionalGeneration._from_config(config) + + assert isinstance(dummy_model_instance, MllamaForConditionalGeneration) + + # Check that model instance variables are not yet patched with Liger modules + assert inspect.getsource( + dummy_model_instance.language_model.model.norm.forward + ) != inspect.getsource(LigerRMSNorm.forward) + for layer in dummy_model_instance.language_model.model.layers: + assert inspect.getsource(layer.mlp.forward) != inspect.getsource( + LigerSwiGLUMLP.forward + ) + assert inspect.getsource( + layer.input_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) + + assert inspect.getsource( + dummy_model_instance.vision_model.layernorm_pre.forward + ) != inspect.getsource(LigerLayerNorm.forward) + assert inspect.getsource( + dummy_model_instance.vision_model.layernorm_post.forward + ) != inspect.getsource(LigerLayerNorm.forward) + for layer in dummy_model_instance.vision_model.transformer.layers: + assert inspect.getsource( + layer.input_layernorm.forward + ) != inspect.getsource(LigerLayerNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) != inspect.getsource(LigerLayerNorm.forward) + for layer in dummy_model_instance.vision_model.global_transformer.layers: + assert inspect.getsource( + layer.input_layernorm.forward + ) != inspect.getsource(LigerLayerNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) != inspect.getsource(LigerLayerNorm.forward) + + # Test applying kernels to the model instance + _apply_liger_kernel_to_instance(model=dummy_model_instance) + + # Check that the model's instance variables were correctly patched with Liger modules + assert inspect.getsource( + dummy_model_instance.language_model.model.norm.forward + ) == inspect.getsource(LigerRMSNorm.forward) + for layer in dummy_model_instance.language_model.model.layers: + assert inspect.getsource(layer.mlp.forward) == inspect.getsource( + LigerSwiGLUMLP.forward + ) + assert inspect.getsource( + layer.input_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) + + assert inspect.getsource( + dummy_model_instance.vision_model.layernorm_pre.forward + ) == inspect.getsource(LigerLayerNorm.forward) + assert inspect.getsource( + dummy_model_instance.vision_model.layernorm_post.forward + ) == inspect.getsource(LigerLayerNorm.forward) + for layer in dummy_model_instance.vision_model.transformer.layers: + assert inspect.getsource( + layer.input_layernorm.forward + ) == inspect.getsource(LigerLayerNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) == inspect.getsource(LigerLayerNorm.forward) + for layer in dummy_model_instance.vision_model.global_transformer.layers: + assert inspect.getsource( + layer.input_layernorm.forward + ) == inspect.getsource(LigerLayerNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) == inspect.getsource(LigerLayerNorm.forward) + + +def test_apply_liger_kernel_to_instance_for_mllama_for_causal_lm(): + # Ensure any monkey patching is cleaned up 
for subsequent tests + with patch("transformers.models.mllama.modeling_mllama"): + from transformers.models.mllama.modeling_mllama import MllamaForCausalLM + + # Instantiate a dummy model + config = transformers.models.mllama.configuration_mllama.MllamaTextConfig( + rms_norm_eps=1e-5, + hidden_size=32, + intermediate_size=64, + hidden_act="silu", + num_hidden_layers=2, + rope_scaling=dict( + factor=8.0, + high_freq_factor=4.0, + low_freq_factor=1.0, + original_max_position_embeddings=8192, + rope_type="llama3", + ), + ) + + dummy_model_instance = MllamaForCausalLM._from_config(config) + + assert isinstance(dummy_model_instance, MllamaForCausalLM) + # Check that model instance variables are not yet patched with Liger modules assert not isinstance(dummy_model_instance.model.norm, LigerRMSNorm) for layer in dummy_model_instance.model.layers: - assert not isinstance(layer.mlp, LigerSwiGLUMLP) - assert not isinstance(layer.input_layernorm, LigerRMSNorm) - assert not isinstance(layer.post_attention_layernorm, LigerRMSNorm) + assert inspect.getsource(layer.mlp.forward) != inspect.getsource( + LigerSwiGLUMLP.forward + ) + assert inspect.getsource( + layer.input_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert isinstance(dummy_model_instance.model.norm, LigerRMSNorm) + assert inspect.getsource( + dummy_model_instance.model.norm.forward + ) == inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert isinstance(layer.mlp, LigerSwiGLUMLP) - assert isinstance(layer.input_layernorm, LigerRMSNorm) - assert isinstance(layer.post_attention_layernorm, LigerRMSNorm) + assert inspect.getsource(layer.mlp.forward) == inspect.getsource( + LigerSwiGLUMLP.forward + ) + assert inspect.getsource( + layer.input_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) def test_apply_liger_kernel_to_instance_for_mistral(): # Ensure any monkey patching is cleaned up for subsequent tests with patch("transformers.models.mistral.modeling_mistral"): - # Instantiate a dummy model config = transformers.models.mistral.configuration_mistral.MistralConfig( torch_dtype=torch.bfloat16, @@ -245,27 +505,42 @@ def test_apply_liger_kernel_to_instance_for_mistral(): dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules - assert not isinstance(dummy_model_instance.model.norm, LigerRMSNorm) + assert inspect.getsource( + dummy_model_instance.model.norm.forward + ) != inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert not isinstance(layer.mlp, LigerSwiGLUMLP) - assert not isinstance(layer.input_layernorm, LigerRMSNorm) - assert not isinstance(layer.post_attention_layernorm, LigerRMSNorm) + assert inspect.getsource(layer.mlp.forward) != inspect.getsource( + LigerSwiGLUMLP.forward + ) + assert inspect.getsource( + layer.input_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) # Test applying 
kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert isinstance(dummy_model_instance.model.norm, LigerRMSNorm) + assert inspect.getsource( + dummy_model_instance.model.norm.forward + ) == inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert isinstance(layer.mlp, LigerSwiGLUMLP) - assert isinstance(layer.input_layernorm, LigerRMSNorm) - assert isinstance(layer.post_attention_layernorm, LigerRMSNorm) + assert inspect.getsource(layer.mlp.forward) == inspect.getsource( + LigerSwiGLUMLP.forward + ) + assert inspect.getsource( + layer.input_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) def test_apply_liger_kernel_to_instance_for_mixtral(): # Ensure any monkey patching is cleaned up for subsequent tests with patch("transformers.models.mixtral.modeling_mixtral"): - # Instantiate a dummy model config = transformers.models.mixtral.configuration_mixtral.MixtralConfig( torch_dtype=torch.bfloat16, @@ -280,29 +555,44 @@ def test_apply_liger_kernel_to_instance_for_mixtral(): dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules - assert not isinstance(dummy_model_instance.model.norm, LigerRMSNorm) + assert inspect.getsource( + dummy_model_instance.model.norm.forward + ) != inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: for expert in layer.block_sparse_moe.experts: - assert not isinstance(expert, LigerBlockSparseTop2MLP) - assert not isinstance(layer.input_layernorm, LigerRMSNorm) - assert not isinstance(layer.post_attention_layernorm, LigerRMSNorm) + assert inspect.getsource(expert.forward) != inspect.getsource( + LigerBlockSparseTop2MLP.forward + ) + assert inspect.getsource( + layer.input_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert isinstance(dummy_model_instance.model.norm, LigerRMSNorm) + assert inspect.getsource( + dummy_model_instance.model.norm.forward + ) == inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: for expert in layer.block_sparse_moe.experts: - assert isinstance(expert, LigerBlockSparseTop2MLP) - assert isinstance(layer.input_layernorm, LigerRMSNorm) - assert isinstance(layer.post_attention_layernorm, LigerRMSNorm) + assert inspect.getsource(expert.forward) == inspect.getsource( + LigerBlockSparseTop2MLP.forward + ) + assert inspect.getsource( + layer.input_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) def test_apply_liger_kernel_to_instance_for_gemma(): # Ensure any monkey patching is cleaned up for subsequent tests with patch("transformers.models.gemma.modeling_gemma"): - # Instantiate a dummy model config = transformers.models.gemma.configuration_gemma.GemmaConfig( torch_dtype=torch.bfloat16, @@ -315,27 +605,42 @@ def 
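Editor's note on why these instance-patching tests now compare `inspect.getsource(...)` of the bound `forward` methods instead of using `isinstance`: in-place patching keeps the original module objects (and their weights) and only rebinds their `forward`, so a type check can no longer detect the patch, while a source comparison can. The toy sketch below illustrates the difference; the class names and the `types.MethodType` rebinding are stand-ins for the real transformers/Liger modules and patch mechanism.

```python
import inspect
import types

class VanillaNorm:
    def forward(self, x):
        return x  # placeholder reference implementation

class PatchedNorm:
    def forward(self, x):
        return x * 2  # placeholder "optimized" implementation

module = VanillaNorm()
# In-place patch: keep the same object, swap only the bound method.
module.forward = types.MethodType(PatchedNorm.forward, module)

# A type check cannot see the patch...
assert not isinstance(module, PatchedNorm)
# ...but comparing the source of the bound forward against the patched class can.
assert inspect.getsource(module.forward) == inspect.getsource(PatchedNorm.forward)
```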
test_apply_liger_kernel_to_instance_for_gemma(): dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules - assert not isinstance(dummy_model_instance.model.norm, LigerRMSNorm) + assert inspect.getsource( + dummy_model_instance.model.norm.forward + ) != inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert not isinstance(layer.mlp, LigerGEGLUMLP) - assert not isinstance(layer.input_layernorm, LigerRMSNorm) - assert not isinstance(layer.post_attention_layernorm, LigerRMSNorm) + assert inspect.getsource(layer.mlp.forward) != inspect.getsource( + LigerGEGLUMLP.forward + ) + assert inspect.getsource( + layer.input_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert isinstance(dummy_model_instance.model.norm, LigerRMSNorm) + assert inspect.getsource( + dummy_model_instance.model.norm.forward + ) == inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert isinstance(layer.mlp, LigerGEGLUMLP) - assert isinstance(layer.input_layernorm, LigerRMSNorm) - assert isinstance(layer.post_attention_layernorm, LigerRMSNorm) + assert inspect.getsource(layer.mlp.forward) == inspect.getsource( + LigerGEGLUMLP.forward + ) + assert inspect.getsource( + layer.input_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) def test_apply_liger_kernel_to_instance_for_gemma2(): # Ensure any monkey patching is cleaned up for subsequent tests with patch("transformers.models.gemma2.modeling_gemma2"): - # Instantiate a dummy model config = transformers.models.gemma2.configuration_gemma2.Gemma2Config( torch_dtype=torch.bfloat16, @@ -348,31 +653,54 @@ def test_apply_liger_kernel_to_instance_for_gemma2(): dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules - assert not isinstance(dummy_model_instance.model.norm, LigerRMSNorm) + assert inspect.getsource( + dummy_model_instance.model.norm.forward + ) != inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert not isinstance(layer.mlp, LigerGEGLUMLP) - assert not isinstance(layer.input_layernorm, LigerRMSNorm) - assert not isinstance(layer.post_attention_layernorm, LigerRMSNorm) - assert not isinstance(layer.pre_feedforward_layernorm, LigerRMSNorm) - assert not isinstance(layer.post_feedforward_layernorm, LigerRMSNorm) + assert inspect.getsource(layer.mlp.forward) != inspect.getsource( + LigerGEGLUMLP.forward + ) + assert inspect.getsource( + layer.input_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.pre_feedforward_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_feedforward_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) # Test applying kernels to the model instance 
_apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert isinstance(dummy_model_instance.model.norm, LigerRMSNorm) + assert inspect.getsource( + dummy_model_instance.model.norm.forward + ) == inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert isinstance(layer.mlp, LigerGEGLUMLP) - assert isinstance(layer.input_layernorm, LigerRMSNorm) - assert isinstance(layer.post_attention_layernorm, LigerRMSNorm) - assert isinstance(layer.pre_feedforward_layernorm, LigerRMSNorm) - assert isinstance(layer.post_feedforward_layernorm, LigerRMSNorm) + assert inspect.getsource(layer.mlp.forward) == inspect.getsource( + LigerGEGLUMLP.forward + ) + assert inspect.getsource( + layer.input_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.pre_feedforward_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_feedforward_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) def test_apply_liger_kernel_to_instance_for_qwen2(): # Ensure any monkey patching is cleaned up for subsequent tests with patch("transformers.models.qwen2.modeling_qwen2"): - # Instantiate a dummy model config = transformers.models.qwen2.configuration_qwen2.Qwen2Config( torch_dtype=torch.bfloat16, @@ -385,27 +713,120 @@ def test_apply_liger_kernel_to_instance_for_qwen2(): dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules - assert not isinstance(dummy_model_instance.model.norm, LigerRMSNorm) + assert inspect.getsource( + dummy_model_instance.model.norm.forward + ) != inspect.getsource(LigerRMSNorm.forward) + for layer in dummy_model_instance.model.layers: + assert inspect.getsource(layer.mlp.forward) != inspect.getsource( + LigerSwiGLUMLP.forward + ) + assert inspect.getsource( + layer.input_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) + + # Test applying kernels to the model instance + _apply_liger_kernel_to_instance(model=dummy_model_instance) + + # Check that the model's instance variables were correctly patched with Liger modules + assert inspect.getsource( + dummy_model_instance.model.norm.forward + ) == inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert not isinstance(layer.mlp, LigerSwiGLUMLP) - assert not isinstance(layer.input_layernorm, LigerRMSNorm) - assert not isinstance(layer.post_attention_layernorm, LigerRMSNorm) + assert inspect.getsource(layer.mlp.forward) == inspect.getsource( + LigerSwiGLUMLP.forward + ) + assert inspect.getsource( + layer.input_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) + + +def test_apply_liger_kernel_to_instance_for_qwen2_vl(): + # Ensure any monkey patching is cleaned up for subsequent tests + with patch("transformers.models.qwen2_vl.modeling_qwen2_vl"): + from transformers.models.qwen2_vl.modeling_qwen2_vl import ( + Qwen2VLForConditionalGeneration, + ) + + # Instantiate a dummy model + config = 
transformers.models.qwen2_vl.configuration_qwen2_vl.Qwen2VLConfig( + torch_dtype=torch.bfloat16, + rms_norm_eps=1e-5, + hidden_size=32, + intermediate_size=48, + embed_dim=16, + hidden_act="silu", + num_hidden_layers=2, + num_attention_heads=2, + max_position_embeddings=128, + vocab_size=1000, + vision_config={ + "depth": 4, + "embed_dim": 128, + "num_heads": 8, + "hidden_size": 1024, + }, + ) + dummy_model_instance = Qwen2VLForConditionalGeneration._from_config(config) + + assert isinstance(dummy_model_instance, Qwen2VLForConditionalGeneration) + + # Check that model instance variables are not yet patched with Liger modules + assert inspect.getsource( + dummy_model_instance.model.norm.forward + ) != inspect.getsource(LigerRMSNorm.forward) + for layer in dummy_model_instance.model.layers: + assert inspect.getsource(layer.mlp.forward) != inspect.getsource( + LigerSwiGLUMLP.forward + ) + assert inspect.getsource( + layer.input_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) + for vision_block in dummy_model_instance.visual.blocks: + assert inspect.getsource(vision_block.norm1.forward) != inspect.getsource( + LigerLayerNorm.forward + ) + assert inspect.getsource(vision_block.norm2.forward) != inspect.getsource( + LigerLayerNorm.forward + ) # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert isinstance(dummy_model_instance.model.norm, LigerRMSNorm) + assert inspect.getsource( + dummy_model_instance.model.norm.forward + ) == inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert isinstance(layer.mlp, LigerSwiGLUMLP) - assert isinstance(layer.input_layernorm, LigerRMSNorm) - assert isinstance(layer.post_attention_layernorm, LigerRMSNorm) + assert inspect.getsource(layer.mlp.forward) == inspect.getsource( + LigerSwiGLUMLP.forward + ) + assert inspect.getsource( + layer.input_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) + for vision_block in dummy_model_instance.visual.blocks: + assert inspect.getsource(vision_block.norm1.forward) == inspect.getsource( + LigerLayerNorm.forward + ) + assert inspect.getsource(vision_block.norm2.forward) == inspect.getsource( + LigerLayerNorm.forward + ) def test_apply_liger_kernel_to_instance_for_phi3(): # Ensure any monkey patching is cleaned up for subsequent tests with patch("transformers.models.phi3.modeling_phi3"): - # Instantiate a dummy model config = transformers.models.phi3.configuration_phi3.Phi3Config( torch_dtype=torch.bfloat16, @@ -418,18 +839,34 @@ def test_apply_liger_kernel_to_instance_for_phi3(): dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules - assert not isinstance(dummy_model_instance.model.norm, LigerRMSNorm) + assert inspect.getsource( + dummy_model_instance.model.norm.forward + ) != inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert not isinstance(layer.mlp, LigerPhi3SwiGLUMLP) - assert not isinstance(layer.input_layernorm, LigerRMSNorm) - assert not isinstance(layer.post_attention_layernorm, LigerRMSNorm) + assert inspect.getsource(layer.mlp.forward) != 
inspect.getsource( + LigerPhi3SwiGLUMLP.forward + ) + assert inspect.getsource( + layer.input_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) != inspect.getsource(LigerRMSNorm.forward) # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) # Check that the model's instance variables were correctly patched with Liger modules - assert isinstance(dummy_model_instance.model.norm, LigerRMSNorm) + assert inspect.getsource( + dummy_model_instance.model.norm.forward + ) == inspect.getsource(LigerRMSNorm.forward) for layer in dummy_model_instance.model.layers: - assert isinstance(layer.mlp, LigerPhi3SwiGLUMLP) - assert isinstance(layer.input_layernorm, LigerRMSNorm) - assert isinstance(layer.post_attention_layernorm, LigerRMSNorm) + assert inspect.getsource(layer.mlp.forward) == inspect.getsource( + LigerPhi3SwiGLUMLP.forward + ) + assert inspect.getsource( + layer.input_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource( + layer.post_attention_layernorm.forward + ) == inspect.getsource(LigerRMSNorm.forward) diff --git a/test/transformers/test_rms_norm.py b/test/transformers/test_rms_norm.py index d9e823e6d..1dd2299b8 100644 --- a/test/transformers/test_rms_norm.py +++ b/test/transformers/test_rms_norm.py @@ -1,5 +1,5 @@ import os -from test.utils import assert_verbose_allclose, supports_bfloat16 +from test.utils import assert_verbose_allclose, set_seed, supports_bfloat16 import pytest import torch @@ -9,6 +9,7 @@ from liger_kernel.transformers.functional import liger_rms_norm from liger_kernel.transformers.rms_norm import LigerRMSNorm +set_seed(42) torch.use_deterministic_algorithms(True) # Only setting torch.use_deterministic_algorithms(True) might throw the following error: @@ -73,14 +74,8 @@ def forward(self, x): "bs, sl, hd", [ (2, 128, 512), - (4, 256, 1024), - (8, 512, 2048), - (16, 1024, 4096), - # # weird shapes - (3, 423, 213), + # weird shapes (5, 123, 123), - (7, 341, 234), - (9, 236, 345), ], ) @pytest.mark.parametrize( @@ -95,7 +90,6 @@ def forward(self, x): not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), - (torch.float16, 2e-1, 2e-2), ], ) @pytest.mark.parametrize( @@ -107,9 +101,6 @@ def forward(self, x): ], ) def test_correctness(bs, sl, hd, dtype, atol, rtol, reference, offset, casting_mode): - if reference == BaseRMSNorm and dtype == torch.bfloat16: - pytest.skip("bfloat16 has larger errors for BaseRMSNorm") - _tensor = torch.randn(bs, sl, hd, device="cuda", dtype=dtype) h1 = _tensor.clone().requires_grad_(True) @@ -121,7 +112,7 @@ def test_correctness(bs, sl, hd, dtype, atol, rtol, reference, offset, casting_m # reference (llama or gemma) ref_rms = reference(hidden_size=hd).to("cuda").to(dtype) ref_o = ref_rms(h1) - ref_o.backward(do.clone(), retain_graph=True) + ref_o.backward(do, retain_graph=True) # triton triton_rms = ( @@ -130,20 +121,22 @@ def test_correctness(bs, sl, hd, dtype, atol, rtol, reference, offset, casting_m .to(dtype) ) triton_o = triton_rms(h2) - triton_o.backward(do.clone(), retain_graph=True) + triton_o.backward(do, retain_graph=True) assert_verbose_allclose(ref_o, triton_o, atol=atol, rtol=rtol) assert_verbose_allclose( ref_rms.weight.grad, triton_rms.weight.grad, atol=atol, rtol=rtol ) - assert_verbose_allclose(h1.grad, h2.grad, atol=atol, rtol=rtol) + print(f"{h1.grad=}") + print(f"{h2.grad=}") + assert_verbose_allclose(h1.grad, h2.grad, atol=atol, 
rtol=rtol, max_print=20) @pytest.mark.parametrize( "bs, sl, hd", [ (2, 2, 8), - # # weird shapes + # weird shapes (9, 7, 41), ], ) @@ -152,7 +145,6 @@ def test_correctness(bs, sl, hd, dtype, atol, rtol, reference, offset, casting_m [ (torch.float32, 1e-4, 1e-6), (torch.bfloat16, 2e-1, 2e-2), - (torch.float16, 2e-1, 2e-2), ], ) @pytest.mark.parametrize( diff --git a/test/transformers/test_swiglu.py b/test/transformers/test_swiglu.py index ccb395c98..be7aaef42 100644 --- a/test/transformers/test_swiglu.py +++ b/test/transformers/test_swiglu.py @@ -27,11 +27,9 @@ @pytest.mark.parametrize( "bsz, seq_len, hidden_size, intermediate_size", [ - (2, 2048, 4096, 11008), - (2, 2048, 2048, 4096), + (2, 256, 256, 512), # weird shapes - (9, 41, 341, 4231), - (6, 42, 256, 2048), + (6, 42, 123, 431), ], ) @pytest.mark.parametrize( @@ -109,11 +107,9 @@ def test_correctness_llamamlp( @pytest.mark.parametrize( "bsz, seq_len, hidden_size, intermediate_size", [ - (2, 2048, 4096, 11008), - (2, 2048, 2048, 4096), + (2, 256, 256, 512), # weird shapes - (9, 41, 341, 4231), - (6, 42, 256, 2048), + (6, 42, 123, 431), ], ) @pytest.mark.parametrize( diff --git a/test/utils.py b/test/utils.py index 748d84e64..ac9a13190 100644 --- a/test/utils.py +++ b/test/utils.py @@ -1,10 +1,15 @@ import importlib +import json import os import random from dataclasses import dataclass from typing import Any, Dict, List import torch +from tokenizers import AddedToken, Tokenizer +from tokenizers.models import BPE +from tokenizers.pre_tokenizers import Whitespace +from tokenizers.trainers import BpeTrainer from transformers import PretrainedConfig, PreTrainedModel from transformers.tokenization_utils_base import BatchEncoding @@ -55,10 +60,27 @@ def assert_verbose_allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08, max_print= # Determine the tolerance tolerance = atol + rtol * torch.abs(tensor2) - # Find mismatched elements - mismatched = diff > tolerance + # Find tolerance mismatched elements + tol_mismatched = diff > tolerance + + # Find nan mismatched elements + nan_mismatched = torch.logical_xor(torch.isnan(tensor1), torch.isnan(tensor2)) + + # Find +inf mismatched elements + posinf_mismatched = torch.logical_xor( + torch.isposinf(tensor1), torch.isposinf(tensor2) + ) + # Find -inf mismatched elements + neginf_mismatched = torch.logical_xor( + torch.isneginf(tensor1), torch.isneginf(tensor2) + ) + + # Find all mismatched elements + mismatched = torch.logical_or( + torch.logical_or(tol_mismatched, nan_mismatched), + torch.logical_or(posinf_mismatched, neginf_mismatched), + ) - # Get the indices of mismatched elements mismatched_indices = torch.nonzero(mismatched) # Count the number of mismatched elements @@ -68,7 +90,7 @@ def assert_verbose_allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08, max_print= all_close = num_mismatched == 0 # Raise AssertionError with detailed information if there are mismatches - if not all_close and num_mismatched > 1: + if not all_close and num_mismatched >= 1: mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] print_count = min(max_print, num_mismatched) for index in mismatched_indices[:print_count]: @@ -93,6 +115,10 @@ def assert_verbose_allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08, max_print= os.path.dirname(os.path.abspath(__file__)), "resources/tiny_shakespeare.txt" ) +FAKE_CONFIGS_PATH = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "resources/fake_configs" +) + @dataclass class MiniModelConfig: @@ -139,6 +165,41 @@ def multimodal_collate_fn(data: List[Dict[str, 
Any]]): return BatchEncoding(batch) +def load_tokenizer_config(config_path: str) -> dict: + """Load and process tokenizer configuration from a JSON file.""" + with open(config_path) as reader: + tokenizer_config = json.load(reader) + tokenizer_config["added_tokens_decoder"] = { + k: AddedToken(**v) for k, v in tokenizer_config["added_tokens_decoder"].items() + } + return tokenizer_config + + +def train_bpe_tokenizer(special_tokens: List[str], unk_token: str = "<|unk|>"): + """ + Train a tokenizer using the BPE algorithm. + + Parameters: + unk_token (str): The token to use for unknown tokens. + special_tokens (List[str]): A list of special tokens to use. + + Returns: + Tokenizer: The trained tokenizer. + """ + # Add unk_token to special_tokens if not already present + if unk_token not in special_tokens: + special_tokens.append(unk_token) + + tokenizer = Tokenizer(BPE(unk_token=unk_token)) + trainer = BpeTrainer(special_tokens=special_tokens) + + tokenizer.pre_tokenizer = Whitespace() + file = [UNTOKENIZED_DATASET_PATH] + tokenizer.train(file, trainer) + + return tokenizer + + def supports_bfloat16(): if not torch.cuda.is_available(): return False @@ -156,6 +217,19 @@ def revert_liger_kernel_to_llama(): print("Liger kernel patches have been reverted.") +def revert_liger_kernel_to_mllama(): + """ + Revert all Liger kernel patches applied to MLlama. + """ + + import torch.nn as nn + from transformers.models.mllama import modeling_mllama + + importlib.reload(nn) + importlib.reload(modeling_mllama) + print("Liger kernel patches have been reverted.") + + def revert_liger_kernel_to_mistral(): """ Revert all Liger kernel patches applied to Mistral.
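Editor's note: the new `train_bpe_tokenizer` helper in `test/utils.py` trains a small BPE tokenizer on the bundled `tiny_shakespeare.txt` corpus for the mini-model tests. A hedged usage sketch follows; the special token strings are placeholders chosen for illustration, not values mandated by the test suite, and the import assumes the tests are run from the repository root as they are elsewhere in this file.

```python
from test.utils import train_bpe_tokenizer

# Train a throwaway BPE tokenizer on the bundled tiny_shakespeare.txt corpus.
tokenizer = train_bpe_tokenizer(special_tokens=["<|begin_of_text|>", "<|end_of_text|>"])

encoding = tokenizer.encode("To be, or not to be")
print(encoding.tokens)  # learned subword pieces
print(encoding.ids)     # corresponding integer ids

# The default unk token is appended automatically when it is not passed in.
assert "<|unk|>" in tokenizer.get_vocab()
```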