diff --git a/.github/workflows/build-and-publish.yml b/.github/workflows/build-and-publish.yml
index 8ebb698f..307ade0e 100644
--- a/.github/workflows/build-and-publish.yml
+++ b/.github/workflows/build-and-publish.yml
@@ -14,6 +14,7 @@ jobs:
- "framework"
- "accelerated-peft"
- "fused-ops-and-kernels"
+ - "attention-and-distributed-packing"
permissions:
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml
index d6d6b089..90f7210a 100644
--- a/.github/workflows/format.yml
+++ b/.github/workflows/format.yml
@@ -29,7 +29,7 @@ jobs:
- "framework"
- "accelerated-peft"
- "fused-ops-and-kernels"
- - "instruct-lab"
+ - "attention-and-distributed-packing"
steps:
- uses: actions/checkout@v4
diff --git a/README.md b/README.md
index f79026f4..1158550c 100644
--- a/README.md
+++ b/README.md
@@ -33,7 +33,7 @@ Plugin | Description | Depends | License | Status
[framework](./plugins/framework/README.md) | This acceleration framework for integration with huggingface trainers | | | Alpha
[accelerated-peft](./plugins/accelerated-peft/README.md) | For PEFT-training, e.g., 4bit QLoRA. | Huggingface<br>AutoGPTQ | Apache 2.0<br>MIT | Alpha
[fused-ops-and-kernels](./plugins/fused-ops-and-kernels/README.md) | Fused LoRA and triton kernels (e.g., fast cross-entropy, rms, rope) | -- | Apache 2.0 [(contains extracted code)](./plugins/fused-ops-and-kernels/README.md#code-extracted-from-unsloth)| Beta
-[instruct-lab](./plugins/instruct-lab/README.md) | Padding-Free Flash Attention Computation | flash-attn | Apache 2.0 | Beta
+[attention-and-distributed-packing](./plugins/attention-and-distributed-packing/README.md) | Padding-Free Flash Attention Computation | flash-attn | Apache 2.0 | Beta
MOE-training-acceleration | [MegaBlocks](https://github.com/databricks/megablocks) inspired triton kernels and accelerations for Mixture-of-Experts models | | Apache 2.0 | Coming Soon
## Usage with FMS HF Tuning
diff --git a/plugins/instruct-lab/.isort.cfg b/plugins/attention-and-distributed-packing/.isort.cfg
similarity index 80%
rename from plugins/instruct-lab/.isort.cfg
rename to plugins/attention-and-distributed-packing/.isort.cfg
index 4aa62fac..98382601 100644
--- a/plugins/instruct-lab/.isort.cfg
+++ b/plugins/attention-and-distributed-packing/.isort.cfg
@@ -8,6 +8,3 @@ import_heading_firstparty=First Party
import_heading_localfolder=Local
known_firstparty=
known_localfolder=tuning
-
-# skip code imported from unsloth
-skip_glob=**/unsloth*/**
diff --git a/plugins/instruct-lab/.pylintrc b/plugins/attention-and-distributed-packing/.pylintrc
similarity index 100%
rename from plugins/instruct-lab/.pylintrc
rename to plugins/attention-and-distributed-packing/.pylintrc
diff --git a/plugins/instruct-lab/README.md b/plugins/attention-and-distributed-packing/README.md
similarity index 96%
rename from plugins/instruct-lab/README.md
rename to plugins/attention-and-distributed-packing/README.md
index d76f327e..cad6ec63 100644
--- a/plugins/instruct-lab/README.md
+++ b/plugins/attention-and-distributed-packing/README.md
@@ -1,4 +1,4 @@
-# FMS Acceleration for Instruct Lab
+# FMS Acceleration for Attention and Distributed Packing
This library contains plugins to accelerate finetuning with the following optimizations:
diff --git a/plugins/instruct-lab/configs/instruct_lab.yaml b/plugins/attention-and-distributed-packing/configs/aadp.yaml
similarity index 100%
rename from plugins/instruct-lab/configs/instruct_lab.yaml
rename to plugins/attention-and-distributed-packing/configs/aadp.yaml
diff --git a/plugins/instruct-lab/pyproject.toml b/plugins/attention-and-distributed-packing/pyproject.toml
similarity index 81%
rename from plugins/instruct-lab/pyproject.toml
rename to plugins/attention-and-distributed-packing/pyproject.toml
index e6e4adb1..00f1a155 100644
--- a/plugins/instruct-lab/pyproject.toml
+++ b/plugins/attention-and-distributed-packing/pyproject.toml
@@ -3,9 +3,9 @@ requires = ["hatchling"]
build-backend = "hatchling.build"
[project]
-name = "fms-acceleration-ilab"
+name = "fms-acceleration-aadp"
version = '0.0.1'
-description = "FMS Acceleration Plugin for Functionalities Used in Instruct Lab Training"
+description = "FMS Acceleration Plugin for Attention and Distributed Packing Optimizations"
authors = [
{name = "Fabian Lim", email = "flim@sg.ibm.com"},
{name = "Aaron Chew", email = "aaron.chew1@ibm.com"},
@@ -24,7 +24,7 @@ classifiers=[
]
[tool.hatch.build.targets.wheel]
-only-include = ["src/fms_acceleration_ilab"]
+only-include = ["src/fms_acceleration_aadp"]
[tool.hatch.build.targets.wheel.sources]
"src" = ""
diff --git a/plugins/instruct-lab/src/fms_acceleration_ilab/__init__.py b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/__init__.py
similarity index 100%
rename from plugins/instruct-lab/src/fms_acceleration_ilab/__init__.py
rename to plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/__init__.py
diff --git a/plugins/instruct-lab/src/fms_acceleration_ilab/ilab_utils.py b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/aadp_utils.py
similarity index 75%
rename from plugins/instruct-lab/src/fms_acceleration_ilab/ilab_utils.py
rename to plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/aadp_utils.py
index 330bf5eb..08314e68 100644
--- a/plugins/instruct-lab/src/fms_acceleration_ilab/ilab_utils.py
+++ b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/aadp_utils.py
@@ -15,7 +15,7 @@
from dataclasses import dataclass
import warnings
from transformers import DefaultDataCollator, default_data_collator
-
+import torch
+
@dataclass
class DataCollatorWithFlattening(DefaultDataCollator):
@@ -51,4 +51,24 @@ def __call__(self, features, return_tensors=None):
ret["labels"] += [-100] + feature["labels"][1:]
else:
ret["labels"] += [-100] + feature["input_ids"][1:]
- return default_data_collator([ret], return_tensors)
+
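+ # a packed example starts wherever position_ids resets to 0; cu_seq_lens
+ # records these boundaries plus the total token count, in the cumulative
+ # format expected by flash_attn_varlen_func
+ # e.g. two packed examples of lengths 3 and 2 give position_ids
+ # [0, 1, 2, 0, 1] -> cu_seq_lens [0, 3, 5] and max_length 3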
+ position_ids = torch.tensor(ret["position_ids"]).flatten()
+ indices_q = torch.arange(
+ position_ids.size(0), device=position_ids.device, dtype=torch.int32
+ )
+ cu_seq_lens = torch.cat(
+ (
+ indices_q[position_ids == 0],
+ torch.tensor(
+ position_ids.size(), dtype=torch.int32
+ ),
+ )
+ )
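+ # longest packed sub-sequence = largest position id + 1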
+ max_length = position_ids.max() + 1
+
+ return {
+ **default_data_collator([ret], return_tensors),
+ "cu_seq_lens": cu_seq_lens,
+ "max_length": max_length,
+ }
diff --git a/plugins/instruct-lab/src/fms_acceleration_ilab/flash_attn.py b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/flash_attn.py
similarity index 80%
rename from plugins/instruct-lab/src/fms_acceleration_ilab/flash_attn.py
rename to plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/flash_attn.py
index 26e26d01..1ecb978c 100644
--- a/plugins/instruct-lab/src/fms_acceleration_ilab/flash_attn.py
+++ b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/flash_attn.py
@@ -32,36 +32,27 @@
inspect.signature(flash_attn_func).parameters
)
+# model id -> position_ids
+POSITION_IDS_CACHE = {}
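+# model id -> (q, k) cumulative and max sequence lengths, written by the
+# patched top-level forward from the data collator outputs and read back
+# inside the attention computation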
+CU_SEQ_LENS_CACHE = {}
+MAX_SEQ_LENS_CACHE = {}
-def prepare_fa2_from_position_ids(query, key, value, position_ids, query_length):
- query = query.view(-1, query.size(-2), query.size(-1))
- key = key.view(-1, key.size(-2), key.size(-1))
- value = value.view(-1, value.size(-2), value.size(-1))
- position_ids = position_ids.flatten()
- indices_q = torch.arange(
- position_ids.size(0), device=position_ids.device, dtype=torch.int32
- )
- cu_seq_lens = torch.cat(
- (
- indices_q[position_ids == 0],
- torch.tensor(
- position_ids.size(), device=position_ids.device, dtype=torch.int32
- ),
- )
- )
- max_length = position_ids.max() + 1
- return (
- query,
- key,
- value,
- indices_q,
- (cu_seq_lens, cu_seq_lens),
- (max_length, max_length),
- )
+# This is used to patch the top-level model forward to accept cu_seq_lens
+# and max_length as additional kwargs, which are cached for the attention
+# computation
+def build_toplevel_model_forward(
+ model: torch.nn.Module,
+ model_id: str,
+):
+ # keep a reference to the original forward so the wrapper can delegate to it
+ old_forward = model.forward
+ def forward(self, *args, cu_seq_lens, max_length, **kwargs):
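+ # cache the collator-provided lengths for this model, keyed by model_id,
+ # then delegate to the original forward without the extra kwargs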
+ CU_SEQ_LENS_CACHE[model_id] = (cu_seq_lens, cu_seq_lens)
+ MAX_SEQ_LENS_CACHE[model_id] = (max_length, max_length)
+ return old_forward(*args, **kwargs)
-# model id -> position_ids
-POSITION_IDS_CACHE = {}
+ return forward
# - needed to store position ids when they first come into the model
@@ -123,6 +114,8 @@ def _flash_attention_forward_with_posids(
):
# get the position ids out here
position_ids = POSITION_IDS_CACHE[model_id]
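+ # the (q, k) cumulative and max sequence lengths were computed by the data
+ # collator and cached by the patched top-level forward, so they no longer
+ # need to be derived from position_ids here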
+ cu_seqlens_q, cu_seqlens_k = CU_SEQ_LENS_CACHE[model_id]
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = MAX_SEQ_LENS_CACHE[model_id]
if not use_top_left_mask:
causal = is_causal
@@ -161,19 +154,10 @@ def _flash_attention_forward_with_posids(
assert attention_mask is None, "should not be using attention mask"
assert position_ids is not None, "should be expecting position ids"
batch_size = query_states.size(0)
- (
- query_states,
- key_states,
- value_states,
- _,
- cu_seq_lens,
- max_seq_lens,
- ) = prepare_fa2_from_position_ids(
- query_states, key_states, value_states, position_ids, query_length
- )
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
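+ # flatten the leading batch dimension so flash_attn_varlen_func receives
+ # packed (total_tokens, num_heads, head_dim) tensors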
+ query_states = query_states.view(-1, query_states.size(-2), query_states.size(-1))
+ key_states = key_states.view(-1, key_states.size(-2), key_states.size(-1))
+ value_states = value_states.view(-1, value_states.size(-2), value_states.size(-1))
attn_output = flash_attn_varlen_func(
query_states,
diff --git a/plugins/instruct-lab/src/fms_acceleration_ilab/framework_plugin_padding_free.py b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py
similarity index 92%
rename from plugins/instruct-lab/src/fms_acceleration_ilab/framework_plugin_padding_free.py
rename to plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py
index 33a592ee..70d6fcbf 100644
--- a/plugins/instruct-lab/src/fms_acceleration_ilab/framework_plugin_padding_free.py
+++ b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py
@@ -93,7 +93,7 @@ def _is_backbone(module: torch.nn.Module):
# - patch backbone
model_type = model.config.model_type
# pylint: disable=import-outside-toplevel
- from .flash_attn import build_backbone_forward
+ from .flash_attn import build_backbone_forward, build_toplevel_model_forward
ModelPatcher.register(
ModelPatcherRule(
@@ -106,6 +106,19 @@ def _is_backbone(module: torch.nn.Module):
),
)
+ # Need to patch the top-level model to accept and cache the additional
+ # kwargs, cu_seq_lens and max_length, supplied by the data collator
+ ModelPatcher.register(
+ ModelPatcherRule(
+ rule_id=f"{model_type}-cumseqlen-cache",
+ trigger=ModelPatcherTrigger(check=model.__class__),
+ forward_builder=partial(
+ build_toplevel_model_forward,
+ model_id=id(model),
+ ),
+ ),
+ )
+
# Next, the flash attention function needs to be patched
# how it is patched depends on the transformers version
try:
@@ -185,7 +198,7 @@ def _patch_dataloader(
except ImportError:
# Otherwise, use the locally implemented DataCollatorWithFlattening
# pylint: disable=import-outside-toplevel
- from .ilab_utils import (
+ from .aadp_utils import (
DataCollatorWithFlattening,
)
diff --git a/plugins/instruct-lab/tests/__init__.py b/plugins/attention-and-distributed-packing/tests/__init__.py
similarity index 100%
rename from plugins/instruct-lab/tests/__init__.py
rename to plugins/attention-and-distributed-packing/tests/__init__.py
diff --git a/plugins/instruct-lab/tests/test_ilab_plugin.py b/plugins/attention-and-distributed-packing/tests/test_aadp_plugin.py
similarity index 83%
rename from plugins/instruct-lab/tests/test_ilab_plugin.py
rename to plugins/attention-and-distributed-packing/tests/test_aadp_plugin.py
index c3185d83..ea38158b 100644
--- a/plugins/instruct-lab/tests/test_ilab_plugin.py
+++ b/plugins/attention-and-distributed-packing/tests/test_aadp_plugin.py
@@ -17,13 +17,13 @@
instantiate_framework,
read_configuration,
)
-from fms_acceleration_ilab import PaddingFreeAccelerationPlugin
+from fms_acceleration_aadp import PaddingFreeAccelerationPlugin
# configuration
DIRNAME = os.path.dirname(__file__)
-CONFIG_PATH_ILAB = os.path.join(DIRNAME, "../configs/instruct_lab.yaml")
+CONFIG_PATH_ILAB = os.path.join(DIRNAME, "../configs/aadp.yaml")
-def test_framework_installs_ilab_padding_free_plugin():
+def test_framework_installs_aadp_padding_free_plugin():
with instantiate_framework(
read_configuration(CONFIG_PATH_ILAB), require_packages_check=False
) as framework:
diff --git a/plugins/instruct-lab/tox.ini b/plugins/attention-and-distributed-packing/tox.ini
similarity index 100%
rename from plugins/instruct-lab/tox.ini
rename to plugins/attention-and-distributed-packing/tox.ini
diff --git a/plugins/framework/src/fms_acceleration/constants.py b/plugins/framework/src/fms_acceleration/constants.py
index 3cdef252..6a81d977 100644
--- a/plugins/framework/src/fms_acceleration/constants.py
+++ b/plugins/framework/src/fms_acceleration/constants.py
@@ -21,4 +21,4 @@
# and activated.
# - hence the plugins that have model loaders should be on top of this list
-PLUGINS = ["peft", "foak", "ilab"]
+PLUGINS = ["peft", "foak", "aadp"]
diff --git a/sample-configurations/CONTENTS.yaml b/sample-configurations/CONTENTS.yaml
index f5dc6819..e2eccbc1 100644
--- a/sample-configurations/CONTENTS.yaml
+++ b/sample-configurations/CONTENTS.yaml
@@ -33,7 +33,7 @@ framework_configs:
- fused-ops-and-kernels
filename: accelerated-peft-bnb-nf4-foak-sample-configuration.yaml
- - shortname: ilab-padding-free
+ - shortname: aadp-padding-free
plugins:
- - instruct-lab
- filename: ilab-padding-free-sample-configuration.yaml
\ No newline at end of file
+ - attention-and-distributed-packing
+ filename: aadp-padding-free-sample-configuration.yaml
\ No newline at end of file
diff --git a/sample-configurations/ilab-padding-free-sample-configuration.yaml b/sample-configurations/aadp-padding-free-sample-configuration.yaml
similarity index 100%
rename from sample-configurations/ilab-padding-free-sample-configuration.yaml
rename to sample-configurations/aadp-padding-free-sample-configuration.yaml
diff --git a/scripts/generate_sample_configurations.py b/scripts/generate_sample_configurations.py
index c147df6a..3dd80b92 100644
--- a/scripts/generate_sample_configurations.py
+++ b/scripts/generate_sample_configurations.py
@@ -144,7 +144,7 @@ def read_configuration(path: str) -> Dict:
KEY_BNB_NF4_BASELINE = "baseline-bnb-nf4"
KEY_AUTO_GPTQ_FOAK = "auto-gptq-foak"
KEY_BNB_NF4_FOAK = "bnb-nf4-foak"
-KEY_ILAB_PADDING_FREE = "ilab-padding-free"
+KEY_AADP_PADDING_FREE = "aadp-padding-free"
CONFIGURATIONS = {
KEY_AUTO_GPTQ: "plugins/accelerated-peft/configs/autogptq.yaml",
@@ -167,7 +167,7 @@ def read_configuration(path: str) -> Dict:
"plugins/fused-ops-and-kernels/configs/fast_quantized_peft.yaml",
[("peft.quantization.fused_ops_and_kernels.base_layer", "bitsandbytes")],
),
- KEY_ILAB_PADDING_FREE: "plugins/instruct-lab/configs/instruct_lab.yaml",
+ KEY_AADP_PADDING_FREE: "plugins/attention-and-distributed-packing/configs/aadp.yaml",
}
# list of (tag, combi) tuples
@@ -181,7 +181,7 @@ def read_configuration(path: str) -> Dict:
("baseline-peft-bnb-nf4", (KEY_BNB_NF4_BASELINE,)),
("accelerated-peft-autogptq-foak", (KEY_AUTO_GPTQ, KEY_AUTO_GPTQ_FOAK)),
("accelerated-peft-bnb-nf4-foak", (KEY_BNB_NF4, KEY_BNB_NF4_FOAK)),
- ("ilab-padding-free", (KEY_ILAB_PADDING_FREE,)),
+ ("aadp-padding-free", (KEY_AADP_PADDING_FREE,)),
]