lint + fmt and improvements to readme
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
fabianlim committed May 30, 2024
1 parent 97f013c commit a308a61
Showing 14 changed files with 157 additions and 135 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -31,7 +31,7 @@ Plugin | Description | Depends | License | Status
--|--|--|--|--
[framework](./plugins/framework/README.md) | The acceleration framework for integration with Huggingface trainers | | | Beta
[accelerated-peft](./plugins/accelerated-peft/README.md) | For PEFT-training, e.g., 4bit QLoRA. | Huggingface<br>AutoGPTQ | Apache 2.0<br>MIT | Beta
[fused-op-and-kernels](./plugins/fused-ops-and-kernels/README.md) | Fused LoRA and triton kernels (e.g., fast cross-entropy, rms, rope) | -- | Apache 2.0 with exclusions. | Coming Soon
[fused-op-and-kernels](./plugins/fused-ops-and-kernels/README.md) | Fused LoRA and triton kernels (e.g., fast cross-entropy, rms, rope) | -- | Apache 2.0 [(contains extracted code)](./plugins/fused-ops-and-kernels/README.md#code-extracted-from-unsloth)| Beta
MOE-training-acceleration | [MegaBlocks](https://github.com/databricks/megablocks) inspired triton Kernels and accelerations for Mixture-of-Expert models | | Apache 2.0 | Coming Soon

## Usage with FMS HF Tuning
--- another changed file ---
@@ -28,6 +28,7 @@
# consider making a map if patching more kernels
PATCH_FOR_FSDP_TRITON_V2 = ["qweight", "qzeros"]


# This function may be moved after merging
# https://github.com/foundation-model-stack/fms-acceleration/pull/25
def _patch_target_module(
@@ -123,6 +124,7 @@ def create_new_module_peft(
# if module cannot be found, return None which results in a raise in the call-stack
return new_module


# consider to move this somewhere more general
def patch_forward_to_view_attributes_before_call(
old_forward: Callable,
@@ -133,9 +135,9 @@ def patch_forward_to_view_attributes_before_call(
):
# patch old_forward to view attributes to torch_dtype
# before call

if submodule_names is None:
submodule_names = ''
submodule_names = ""
if isinstance(submodule_names, str):
submodule_names = [submodule_names]
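
The comment above describes the patching pattern: the wrapped forward first re-views selected attributes (for example the `qweight`/`qzeros` names listed in `PATCH_FOR_FSDP_TRITON_V2`) to a target dtype, then delegates to the original forward. Below is a minimal sketch of that pattern; the helper name and signature are assumptions for illustration, not the plugin's actual API.

```python
# Illustrative sketch only; names are assumptions, not the plugin's API.
from typing import Callable, List

import torch


def patch_view_before_forward(
    module: torch.nn.Module,
    attribute_names: List[str],   # e.g. ["qweight", "qzeros"] (registered buffers)
    torch_dtype: torch.dtype,     # dtype the kernel expects, e.g. torch.int32
) -> torch.nn.Module:
    old_forward: Callable = module.forward
    def _forward(*args, **kwargs):
        # view the (possibly FSDP-flattened / re-typed) buffers back to the
        # dtype the triton_v2 QuantLinear kernel expects before calling through
        for name in attribute_names:
            setattr(module, name, getattr(module, name).view(torch_dtype))
        return old_forward(*args, **kwargs)
    module.forward = _forward
    return module
```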

--- another changed file ---
@@ -51,13 +51,18 @@ def __init__(self, configurations: Dict[str, Dict]):
def model_loader(self, model_name: str, **kwargs):
# guarded imports
# Third Party
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig #pylint: disable=import-outside-toplevel,import-error
from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear #pylint: disable=import-outside-toplevel,import-error
from auto_gptq import ( # pylint: disable=import-outside-toplevel,import-error
AutoGPTQForCausalLM,
BaseQuantizeConfig,
)
from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import ( # pylint: disable=import-outside-toplevel,import-error
QuantLinear,
)

# Local
from .autogptq_utils import ( #pylint: disable=import-outside-toplevel
patch_forward_to_view_attributes_before_call,
PATCH_FOR_FSDP_TRITON_V2
from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
PATCH_FOR_FSDP_TRITON_V2,
patch_forward_to_view_attributes_before_call,
)

# Currently we allow only a quantized checkpoint to be loaded, we do not
@@ -214,8 +219,14 @@ def augmentation(
):
# guarded imports
# Third Party
from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear #pylint: disable=import-outside-toplevel,import-error
from auto_gptq.utils.peft_utils import GPTQLoraModel, get_gptq_peft_model #pylint: disable=import-outside-toplevel,import-error
from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import ( # pylint: disable=import-outside-toplevel,import-error
QuantLinear,
)
from auto_gptq.utils.peft_utils import ( # pylint: disable=import-outside-toplevel,import-error
GPTQLoraModel,
get_gptq_peft_model,
)

# Local
from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
create_new_module_peft,
32 changes: 9 additions & 23 deletions plugins/fused-ops-and-kernels/README.md
@@ -3,7 +3,7 @@
This library contains fused operations and custom kernels, to be expanded over time. Currently it contains the following:


1. Fused operations and kernels are extracted from [unsloth](#extracted-code-from-unsloth).
1. Fused operations and kernels extracted from [unsloth](#extracted-code-from-unsloth).
- Low-Rank Adapter Fused Operations
- Fast RoPE Triton Kernels
- Fast RMS LayerNorm Triton Kernels
@@ -13,30 +13,21 @@ This library contains fused operations and custom kernels, to be expanded over time

Plugin | Description | Depends | Loading | Augmentation | Callbacks
--|--|--|--|--|--
[fast_quantized_peft](./src/fms_accelerate_foak/framework_plugin_fast_quantized_peft.py) | Loads fused lora, fast cross-entropy, fast rms, fast RoPE | | | ✅
[fast_quantized_peft](./src/fms_accelerate_foak/framework_plugin_fast_quantized_peft.py) | LoRA fused ops, fast cross-entropy, fast rms, fast RoPE | Contains extracted code | | ✅

### Code Extracted from Unsloth

<!--
NOTE: the
- fused_ops/unsloth_lora -> unsloth main
* utils (fast_dequant, fast_gemv, fast_linear_forward, matmul_lora)
* geglu, swiglu (this can be reused across other models, but currently used inside MLP fused ops only)
* bnb (fast_lora.py)
* gtqp (fast_lora, triton) -> jeromeku
- kernels
* cross_ent, rms, rope -> unsloth main
-->

Notes on the extraction of code from [unsloth](https://github.com/unslothai/unsloth):
- while unsloth is released under Apache 2.0, there are [exceptions to the permissive licenses scattered in the code base](https://github.com/unslothai/unsloth/blob/ec19e61c854dcf9104386fa63fc6c4f2944d4f35/unsloth/models/llama.py#L1140-L1143).
- While unsloth is [released under Apache 2.0](https://github.com/unslothai/unsloth/blob/main/LICENSE), there are comments indicating some exceptions strewn throughout the code base, see [an example here](https://github.com/unslothai/unsloth/blob/ec19e61c854dcf9104386fa63fc6c4f2944d4f35/unsloth/models/llama.py#L1140-L1143).
```
it would require a commercial license if used to run on more than 4 GPUs, see
https://github.com/unslothai/unsloth/blob/d215fd902cf28feb8abcfde2d25281d0fbf9d28c/unsloth/models/llama.py#L1140-L1143
it would require a commercial license if used to run on more than 4 GPUs ...
```
- these exceptions appear around [Feb 2024 Release](https://github.com/unslothai/unsloth/commit/3e4c5a323c16bbda2c92212b790073c4e99c2a55), around the model files (namely `llama.py`, `mistral.py`, etc).
* These model files are **not extracted**.
- All code extracted here before the Feb 2024 Release, see table below.
- These exceptions appear to be located around the trainer improvements, see [another example here](https://github.com/unslothai/unsloth/blob/ec19e61c854dcf9104386fa63fc6c4f2944d4f35/unsloth/models/llama.py#L1177-L1183).
- These exceptions appear around [Feb 2024 Release](https://github.com/unslothai/unsloth/commit/3e4c5a323c16bbda2c92212b790073c4e99c2a55); any code that appears in any file where such exceptions occur **is not extracted**.
- In its place, we adopt a different approach: model patching, as opposed to unsloth's approach of rewriting the model. Our implementation is novel and **completely rewritten from scratch** (see the sketch after this list).
- All extracted code appears before the Feb 2024 Release.
- In the table below we record what was extracted, and the exact commit from which it was taken.
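
To make the model-patching idea concrete: rather than shipping rewritten model classes, patches are applied to an already-constructed Hugging Face model by swapping the `forward` of matched submodules. The sketch below is a simplified illustration only; the helper name is made up, `variance_epsilon` is the Llama/Mistral attribute convention, and the plugin drives this through its `ModelPatcher` rules (see the `ModelPatcherRule` definitions later in this diff) rather than a loop like this.

```python
import torch

def patch_rms_layernorm(model: torch.nn.Module, fast_rms_forward):
    # fast_rms_forward(weight, hidden_states, eps) is assumed to wrap a fused
    # Triton kernel such as the extracted fast_rms_layernorm
    for module in model.modules():
        if type(module).__name__.endswith("RMSNorm"):
            def _forward(hidden_states, _m=module):
                return fast_rms_forward(_m.weight, hidden_states, _m.variance_epsilon)
            module.forward = _forward
    return model
```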

Path | Description | Extracted From | Modifications | Date
--|--|--|--|--
@@ -45,11 +36,6 @@ Path | Description | Extracted From | Modifications | Date
[fused_ops/unsloth_lora/gptq](./src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq) | GPTQ fast dequant (triton_v2) | `jeromeku/main` @ [2839d39](https://github.com/jeromeku/unsloth/commit/2839d390ef3bb318904289bfb9a7751a782c4e44) | `fast_lora.py`<br>`triton/layers.py` | 6 Feb 2024
[kernels/unsloth](./src/fms_acceleration_foak/kernels/unsloth) | Fast RMS, RoPE, CrossEnt kernels | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | `cross_entropy_loss.py`<br>`rms_layernorm.py` | 28 Jan 2024

<!--
[models/](./src/fms_accelerate_unsloth/models/) | Model Forwards | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc)<br><br>`tohrnii/mixtral` @ [a55b7400](https://github.com/tohrnii/unsloth/commit/a55b740062b4fc8ce8f5196bfabe3cf860020ca7) | `llama.py`<br>`mistral.py`<br>`mixtral.py`| `llama.py`<br>`mistral.py`<br>`mixtral.py` | 6 Feb 2024<br><br> 22 Feb 2024
-->


## Known Issues

- MixedPrecision `--fp16` should be used with `fast_lora`. Also consider loading the model in `torch.float16` (see the sketch below).
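
A sketch of that suggestion (the model name is illustrative; `fp16=True` is the `TrainingArguments` equivalent of the `--fp16` flag):

```python
import torch
from transformers import AutoModelForCausalLM, TrainingArguments

# load the base model directly in fp16, as suggested for fast_lora
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",   # illustrative model name
    torch_dtype=torch.float16,
)
# enable fp16 mixed precision (the --fp16 command-line flag)
training_args = TrainingArguments(output_dir="output", fp16=True)
```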
--- another changed file ---
@@ -82,9 +82,7 @@ def __init__(self, configurations: Dict[str, Dict]):

self._base_layer = self._check_config_and_maybe_check_values(
key="peft.quantization.fused_ops_and_kernels.base_layer",
values=[
"auto_gptq", "bitsandbytes"
],
values=["auto_gptq", "bitsandbytes"],
)

# only support these at the moment
--- another changed file ---
@@ -15,10 +15,7 @@
# Local
from .model_patcher import ModelPatcher

PATCHES = [
".models.llama", ".models.mistral",
".models.mixtral"
]
PATCHES = [".models.llama", ".models.mistral", ".models.mixtral"]
PLUGIN_PREFIX = "fms_acceleration_foak"

# TODO: remove the need for the prefix
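
One common way such a patch list is consumed is to import each dotted path relative to the plugin prefix, so that module-level patch registrations run as a side effect. The helper below is an assumed illustration of that pattern, not the actual loader code (which is not shown in this hunk):

```python
import importlib

def load_patches(patches, prefix):
    # e.g. load_patches(PATCHES, PLUGIN_PREFIX) would import
    # fms_acceleration_foak.models.llama, .mistral and .mixtral
    for relative_name in patches:
        importlib.import_module(relative_name, package=prefix)
```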
--- another changed file ---
@@ -16,17 +16,24 @@
from functools import partial

# Third Party
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm
from transformers.models.llama.modeling_llama import LlamaMLP
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaMLP,
LlamaRMSNorm,
)

# Local
from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
from ..kernels.unsloth.rope_embedding import fast_rope_embedding
from .model_patcher import ModelPatcher, ModelPatcherRule, ModelPatcherTrigger
from .model_patcher import combine_triggers, combine_functions
from .utils import build_lora_fused_ops, trigger_fused_ops
from .utils import KEY_QKV, KEY_O, KEY_MLP
from .model_patcher import (
ModelPatcher,
ModelPatcherRule,
ModelPatcherTrigger,
combine_functions,
combine_triggers,
)
from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops

# TODO: have a generic version of this rule
# - do regex on RMSNorm class name
@@ -48,17 +55,19 @@
trigger=combine_triggers(
ModelPatcherTrigger(
check=partial(
trigger_fused_ops, attn_cls=LlamaAttention,
trigger_fused_ops,
attn_cls=LlamaAttention,
submodule_names=["q_proj", "k_proj", "v_proj"],
)
),
ModelPatcherTrigger(
check=partial(
trigger_fused_ops, attn_cls=LlamaAttention,
trigger_fused_ops,
attn_cls=LlamaAttention,
submodule_names=["o_proj"],
)
),
logic='OR',
logic="OR",
),
forward_builder=combine_functions(
partial(
@@ -71,7 +80,7 @@
submodule_names=["o_proj"],
fused_op=KEY_O,
),
logic='APPEND',
logic="APPEND",
),
forward_builder_args=["base_type"],
)
@@ -82,11 +91,12 @@
rule_id="llama-mlp",
trigger=ModelPatcherTrigger(
check=partial(
trigger_fused_ops, attn_cls=LlamaMLP,
trigger_fused_ops,
attn_cls=LlamaMLP,
submodule_names=["up_proj", "down_proj", "gate_proj"],
)
),
forward_builder= partial(
forward_builder=partial(
build_lora_fused_ops,
submodule_names=["up_proj", "down_proj", "gate_proj"],
fused_op=KEY_MLP,
--- another changed file ---
@@ -18,18 +18,22 @@
# Third Party
from transformers.models.mistral.modeling_mistral import (
MistralAttention,
MistralMLP,
MistralRMSNorm,
MistralMLP
)

# Local
from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
from ..kernels.unsloth.rope_embedding import fast_rope_embedding
from .model_patcher import ModelPatcher, ModelPatcherRule, ModelPatcherTrigger
from .model_patcher import combine_triggers, combine_functions
from .utils import build_lora_fused_ops, trigger_fused_ops
from .utils import KEY_QKV, KEY_O, KEY_MLP
from .model_patcher import (
ModelPatcher,
ModelPatcherRule,
ModelPatcherTrigger,
combine_functions,
combine_triggers,
)
from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops

# - do regex on RMSNorm class name
# - check on the tensors required for fast_rms_layernorm
@@ -47,17 +51,19 @@
trigger=combine_triggers(
ModelPatcherTrigger(
check=partial(
trigger_fused_ops, attn_cls=MistralAttention,
trigger_fused_ops,
attn_cls=MistralAttention,
submodule_names=["q_proj", "k_proj", "v_proj"],
)
),
ModelPatcherTrigger(
check=partial(
trigger_fused_ops, attn_cls=MistralAttention,
trigger_fused_ops,
attn_cls=MistralAttention,
submodule_names=["o_proj"],
)
),
logic='OR',
logic="OR",
),
forward_builder=combine_functions(
partial(
@@ -70,7 +76,7 @@
submodule_names=["o_proj"],
fused_op=KEY_O,
),
logic='APPEND',
logic="APPEND",
),
forward_builder_args=["base_type"],
)
@@ -81,11 +87,12 @@
rule_id="mistral-mlp",
trigger=ModelPatcherTrigger(
check=partial(
trigger_fused_ops, attn_cls=MistralMLP,
trigger_fused_ops,
attn_cls=MistralMLP,
submodule_names=["up_proj", "down_proj", "gate_proj"],
)
),
forward_builder= partial(
forward_builder=partial(
build_lora_fused_ops,
submodule_names=["up_proj", "down_proj", "gate_proj"],
fused_op=KEY_MLP,
--- another changed file ---
@@ -25,10 +25,14 @@
from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
from ..kernels.unsloth.rope_embedding import fast_rope_embedding
from .model_patcher import ModelPatcher, ModelPatcherRule, ModelPatcherTrigger
from .model_patcher import combine_triggers, combine_functions
from .utils import build_lora_fused_ops, trigger_fused_ops
from .utils import KEY_QKV, KEY_O
from .model_patcher import (
ModelPatcher,
ModelPatcherRule,
ModelPatcherTrigger,
combine_functions,
combine_triggers,
)
from .utils import KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops

# - do regex on RMSNorm class name
# - check on the tensors required for fast_rms_layernorm
@@ -46,17 +50,19 @@
trigger=combine_triggers(
ModelPatcherTrigger(
check=partial(
trigger_fused_ops, attn_cls=MixtralAttention,
trigger_fused_ops,
attn_cls=MixtralAttention,
submodule_names=["q_proj", "k_proj", "v_proj"],
)
),
ModelPatcherTrigger(
check=partial(
trigger_fused_ops, attn_cls=MixtralAttention,
trigger_fused_ops,
attn_cls=MixtralAttention,
submodule_names=["o_proj"],
)
),
logic='OR',
logic="OR",
),
forward_builder=combine_functions(
partial(
@@ -69,7 +75,7 @@
submodule_names=["o_proj"],
fused_op=KEY_O,
),
logic='APPEND',
logic="APPEND",
),
forward_builder_args=["base_type"],
)