
Commit 733992a

Authored by anhuong, achew010, and fabianlim
feat: add liger kernel with fused cross entropy loss (#93)
* initial implementation of fused-linear-loss on llama
* syntax fixes and remove unused code
* add new num_logits_to_keep arg for llama.forward()
* add mixtral model patch
* add mistral and granite model patch
* add benchmark
* add new liger benchmarks
* some fixes
* revise benches
* refactor to fused_ops
* fix fmt + lint
* update full benches and readme
* fix fast foak configs
* docs: update foak readme benchmarks

Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>
Signed-off-by: Anh Uong <anh.uong@ibm.com>
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
Co-authored-by: 1000850000 user <aaron.chew1@ibm.com>
Co-authored-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
1 parent c70ffe0 · commit 733992a

25 files changed: +1,326 −18 lines

plugins/framework/src/fms_acceleration/framework_plugin.py (+1 −1)

@@ -206,7 +206,7 @@ def _check_config_and_maybe_check_values(
             t = list(t.keys())[0]  # otherwise take the first value

         if t not in values:
-            if default is None:
+            if t is not None or default is None:
                 raise AccelerationPluginConfigError(
                     f"{self.__class__.__name__}: Value at '{key}' was '{t}'. "
                     f"Not found in expected set '{values}'."

plugins/fused-ops-and-kernels/.isort.cfg (+2 −1)

@@ -10,4 +10,5 @@ known_firstparty=
 known_localfolder=tuning

 # skip code imported from unsloth
-skip_glob=**/unsloth*/**
+skip_glob=**/unsloth*/**,
+    **/liger*/**

plugins/fused-ops-and-kernels/.pylintrc (+3 −1)

@@ -53,7 +53,9 @@ ignore=CVS,protobufs
 # format. Because '\\' represents the directory delimiter on Windows systems,
 # it can't be used as an escape character.
 # NOTE: do not lint code imported from unsloth
-ignore-paths=.*fused_ops/unsloth_lora.*,.*kernels/unsloth*
+ignore-paths=.*fused_ops/unsloth_lora.*,
+    .*fused_ops/liger_ce.*,
+    .*kernels/unsloth*,

 # Files or directories matching the regular expression patterns are skipped.
 # The regex matches against base names, not paths. The default value ignores

plugins/fused-ops-and-kernels/README.md (+14 −1)

@@ -79,10 +79,23 @@ It is relatively easy by following an existing template; in what follows we use
 )
 ```

+### Running Liger Kernel Benchmarks
+
+Using [scenarios-liger.yaml](../../scripts/benchmarks/scenarios-liger.yaml), this runs full fine tuning, LoRA PEFT, AutoGPTQ LoRA PEFT, and bits-and-bytes LoRA PEFT with the triton kernels (Fast RMS, RoPE, CrossEnt) as a base, and then runs with the Liger kernel for LigerFusedLinearCrossEntropy (as well as Fast RMS and RoPE) to compare results. It only runs against mistral and llama models.
+
+The benchmarks were run separately for each `num_gpu` entry; they can be run together in a single command, but running them separately is more efficient.
+
+```sh
+tox -e run-benches -- 1 "4 8 16 32" benchmark_outputs_1 scenarios-liger.yaml none
+tox -e run-benches -- 2 "8 16 32 64" benchmark_outputs_2 scenarios-liger.yaml none
+tox -e run-benches -- 4 "16 32 64 128" benchmark_outputs_3 scenarios-liger.yaml none
+```
+
 ## Known Issues

 - MixedPrecision `--fp16` or `--bf16` should be used with `fast_lora`.
 - `fast_lora` has issues with FSDP V1 with the `peft` style of FSDP wrapping.
   * This is because the adapter's forward functions are bypassed in the fused ops.
   * For AutoGPTQ/QLoRA this is addressed by distributing the adapters using DDP so they will be unsharded in time for the fused ops.
-- `fast_rope_embeddings` does not work with position_ids. Currently `position_ids` are ignored and could give wrong results.
+- `fast_rope_embeddings` does not work with `position_ids`; it appears HF has deprecated passing these ids into the rope embedding methods.

plugins/fused-ops-and-kernels/configs/fast_kernels.yaml (+1 −1)

@@ -22,4 +22,4 @@ training:
     fast_rms_layernorm: True

     # fast RoPE embedding triton kernels
-    fast_rope_embeddings: True
+    fast_rope_embeddings: True

(The removed and added lines are textually identical; this is a whitespace-only change, most likely adding a trailing newline at end of file.)
New file (+25): a fast-kernels training config that enables the Liger fused cross-entropy loss

@@ -0,0 +1,25 @@
+training:
+
+  fused_ops_and_kernels:
+
+    # if under the training stanza, then putting
+    # base_layer and fused_lora will be a misnomer
+    # - this should be in peft.quantized
+    # However, if it is specified, it will still
+    # be read. This is useful in use cases where
+    # the yaml is system generated and not shown
+    # to a user.
+
+    # activate various unsloth optimizations
+    # there are two versions of the plugin
+    # - the FastKernel version supports individual kernels
+    # - the FastQuantized version is all-or-nothing
+
+    # fast loss triton kernels
+    fast_loss: fused_ce_liger
+
+    # fast rms norm triton kernels
+    fast_rms_layernorm: True
+
+    # fast RoPE embedding triton kernels
+    fast_rope_embeddings: True
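A minimal sketch of how a config like the one above parses and validates, assuming the value sets from the `framework_plugin_fast_kernels.py` diff further down; the loader itself is illustrative, not the plugin's code.

```python
# Illustrative loader: parse the training stanza and check each kernel flag
# against the values the plugin accepts (per the plugin diff below).
import yaml

CFG = """
training:
  fused_ops_and_kernels:
    fast_loss: fused_ce_liger
    fast_rms_layernorm: True
    fast_rope_embeddings: True
"""

ALLOWED = {
    "fast_loss": [False, True, "fused_ce_liger"],
    "fast_rms_layernorm": [False, True],
    "fast_rope_embeddings": [False, True],
}

kernels = yaml.safe_load(CFG)["training"]["fused_ops_and_kernels"]
for key, allowed in ALLOWED.items():
    value = kernels.get(key, False)  # False mirrors the new plugin defaults
    assert value in allowed, f"{key}={value!r} not in {allowed}"
```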
New file (+30): a quantized-PEFT config that enables fused LoRA and the Liger fused cross-entropy loss

@@ -0,0 +1,30 @@
+# PEFT-related acceleration
+peft:
+
+  # quantization-related acceleration
+  # e.g., kernels for quantized base weights
+  quantization:
+
+    fused_ops_and_kernels:
+
+      # load unsloth optimizations for these 4bit base layer weights.
+      # currently only support "auto_gptq" and "bitsandbytes"
+      base_layer: auto_gptq
+
+      # activate various unsloth optimizations
+      # there are two versions of the plugin
+      # - the FastKernel version supports individual kernels
+      # - the FastQuantized version is all-or-nothing
+
+
+      # fused kernels for lora linear layers
+      fused_lora: True
+
+      # fast loss triton kernels
+      fast_loss: fused_ce_liger
+
+      # fast rms norm triton kernels
+      fast_rms_layernorm: True
+
+      # fast RoPE embedding triton kernels
+      fast_rope_embeddings: True

plugins/fused-ops-and-kernels/pyproject.toml (+8)

@@ -29,3 +29,11 @@ only-include = ["src/fms_acceleration_foak"]

 [tool.hatch.build.targets.wheel.sources]
 "src" = ""
+
+[tool.black]
+force-exclude = '''
+/(
+    .*unsloth.*
+  | .*liger.*
+)/
+'''
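As a sanity check on the pattern: black compiles a multi-line exclude regex in verbose mode and matches it against file paths, so any path containing an `unsloth` or `liger` segment between slashes is skipped. A small illustration (the sample paths are hypothetical):

```python
# Illustrative check of the force-exclude pattern above; black compiles
# multi-line exclude regexes with re.VERBOSE and matches them against paths.
import re

FORCE_EXCLUDE = re.compile(
    r"""
    /(
        .*unsloth.*
      | .*liger.*
    )/
    """,
    re.VERBOSE,
)

# hypothetical paths, for illustration only
assert FORCE_EXCLUDE.search("src/fms_acceleration_foak/fused_ops/liger_ce/foo.py")
assert FORCE_EXCLUDE.search("src/fms_acceleration_foak/kernels/unsloth/rope.py")
assert not FORCE_EXCLUDE.search("src/fms_acceleration_foak/framework_plugin_fast_kernels.py")
```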

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py (+10 −5)

@@ -73,7 +73,10 @@ def register_foak_model_patch_rules(
 # maybe we should define envvars
 FILTER_MAP = {
     "fused_lora": {"qkvo", "mlp"},
-    "fast_loss": "cross-ent",
+    "fast_loss": {
+        True: "cross-ent",
+        "fused_ce_liger": "fused-lce",
+    },
     "fast_rms_layernorm": "rms",
     "fast_rope_embeddings": "rope",
 }

@@ -109,19 +112,19 @@ def __init__(self, configurations: Dict[str, Dict]):
             key="base_layer", values=["auto_gptq", "bitsandbytes"], default="auto_gptq"
         )
         self.configurations["fused_lora"] = self._check_config_and_maybe_check_values(
-            key="fused_lora", values=[False, True], default=True
+            key="fused_lora", values=[False, True], default=False
         )
         self.configurations["fast_loss"] = self._check_config_and_maybe_check_values(
-            key="fast_loss", values=[False, True], default=True
+            key="fast_loss", values=[False, True, "fused_ce_liger"], default=False
         )
         self.configurations["fast_rms_layernorm"] = (
             self._check_config_and_maybe_check_values(
-                key="fast_rms_layernorm", values=[False, True], default=True
+                key="fast_rms_layernorm", values=[False, True], default=False
             )
         )
         self.configurations["fast_rope_embeddings"] = (
             self._check_config_and_maybe_check_values(
-                key="fast_rope_embeddings", values=[False, True], default=True
+                key="fast_rope_embeddings", values=[False, True], default=False
            )
        )

@@ -162,6 +165,8 @@ def augmentation(

         if k in FILTER_MAP and k not in omitted:
             ts = FILTER_MAP[k]
+            if isinstance(ts, dict) and v in ts:
+                ts = ts[v]
             if isinstance(ts, str):
                 ts = {ts}
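Reading the three hunks together: `FILTER_MAP` entries may now be either plain tag strings or dicts keyed by the configured value, and `augmentation` resolves dict entries before normalizing to a set. A standalone sketch of that resolution (`resolve_filter` is an illustrative helper, not a function in the plugin):

```python
# Standalone sketch of the filter resolution implied by the diff above;
# resolve_filter is illustrative, not a function in the plugin.
FILTER_MAP = {
    "fused_lora": {"qkvo", "mlp"},
    "fast_loss": {True: "cross-ent", "fused_ce_liger": "fused-lce"},
    "fast_rms_layernorm": "rms",
    "fast_rope_embeddings": "rope",
}

def resolve_filter(k, v):
    ts = FILTER_MAP[k]
    if isinstance(ts, dict) and v in ts:
        ts = ts[v]   # value-dependent entries, e.g. fast_loss
    if isinstance(ts, str):
        ts = {ts}    # plain strings become singleton tag sets
    return ts

assert resolve_filter("fast_loss", True) == {"cross-ent"}
assert resolve_filter("fast_loss", "fused_ce_liger") == {"fused-lce"}
assert resolve_filter("fast_rms_layernorm", True) == {"rms"}
```

So `fast_loss: True` keeps the existing cross-entropy kernel, while `fast_loss: fused_ce_liger` selects the new Liger fused linear cross-entropy patch rules.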
New file (+27): BSD-2-Clause license header and re-export of the Liger fused linear cross-entropy forward

@@ -0,0 +1,27 @@
+# Copyright 2024 Byron Hsu & Linkedin team. All rights reserved.
+#
+# BSD 2-CLAUSE LICENSE
+# Copyright 2024 LinkedIn Corporation
+# All Rights Reserved.
+# Redistribution and use in source and binary forms, with or
+# without modification, are permitted provided that the following
+# conditions are met:
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following
+# disclaimer in the documentation and/or other materials provided
+# with the distribution.
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+from .fused_linear_cross_entropy_loss import lce_forward
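The re-exported `lce_forward` is the patched model forward that routes loss computation through Liger's fused linear cross-entropy. The memory-saving idea behind fusing the LM head matmul with the loss, in a plain PyTorch sketch (concept only, with assumed shapes; Liger's real implementation is a Triton kernel, and this is not the plugin's code):

```python
# Concept sketch only: compute CE(hidden @ weight.T, labels) without ever
# materializing the full [N, vocab] logits tensor, by chunking over rows.
# Liger's actual LigerFusedLinearCrossEntropy does this inside a Triton kernel.
import torch
import torch.nn.functional as F

def chunked_linear_cross_entropy(hidden, weight, labels, chunk_size=1024):
    # hidden: [N, D] final hidden states; weight: [V, D] lm_head weight;
    # labels: [N] target ids, with -100 marking ignored positions.
    total = hidden.new_zeros(())
    count = 0
    for h, y in zip(hidden.split(chunk_size), labels.split(chunk_size)):
        logits = h @ weight.t()  # only one [chunk, V] slice is alive at a time
        total = total + F.cross_entropy(
            logits, y, ignore_index=-100, reduction="sum"
        )
        count += int((y != -100).sum())
    return total / max(count, 1)

# Tiny usage example with random tensors.
h = torch.randn(8, 16)
w = torch.randn(32, 16)
y = torch.randint(0, 32, (8,))
print(chunked_linear_cross_entropy(h, w, y, chunk_size=4))
```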
