Shift Computation of PaddingFree Variable CuSeqLen from Flash Attention Forward to DataCollatorWithFlattening #65

Closed
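This PR moves the padding-free bookkeeping out of the patched flash-attention forward and into the data collator: `DataCollatorWithFlattening` now derives the cumulative sequence lengths (`cu_seq_lens`) and the maximum packed sequence length (`max_length`) once per batch, a patched top-level model forward caches them keyed by model id, and the attention forward reads the cached values instead of recomputing them from `position_ids` on every layer. A minimal sketch of the derivation the collator performs, using the same tensor ops as the diff (the concrete `position_ids` values below are illustrative only):

```python
import torch

# Two packed sequences of lengths 3 and 2, flattened into one row;
# position_ids restart at 0 at the start of each packed sequence.
position_ids = torch.tensor([0, 1, 2, 0, 1]).flatten()

# Index of every token in the flattened batch.
indices_q = torch.arange(position_ids.size(0), dtype=torch.int32)

# Sequence boundaries: every position where position_ids == 0,
# plus the total token count as the final boundary.
cu_seq_lens = torch.cat(
    (
        indices_q[position_ids == 0],
        torch.tensor(position_ids.size(), dtype=torch.int32),
    )
)
max_length = position_ids.max() + 1

print(cu_seq_lens)  # tensor([0, 3, 5], dtype=torch.int32)
print(max_length)   # tensor(3)
```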
1 change: 1 addition & 0 deletions .github/workflows/build-and-publish.yml
@@ -14,6 +14,7 @@ jobs:
- "framework"
- "accelerated-peft"
- "fused-ops-and-kernels"
- "attention-and-distributed-packing"

permissions:
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
2 changes: 1 addition & 1 deletion .github/workflows/format.yml
@@ -29,7 +29,7 @@ jobs:
- "framework"
- "accelerated-peft"
- "fused-ops-and-kernels"
- "instruct-lab"
- "attention-and-distributed-packing"

steps:
- uses: actions/checkout@v4
2 changes: 1 addition & 1 deletion README.md
@@ -33,7 +33,7 @@ Plugin | Description | Depends | License | Status
[framework](./plugins/framework/README.md) | This acceleration framework for integration with huggingface trainers | | | Alpha
[accelerated-peft](./plugins/accelerated-peft/README.md) | For PEFT-training, e.g., 4bit QLoRA. | Huggingface<br>AutoGPTQ | Apache 2.0<br>MIT | Alpha
[fused-op-and-kernels](./plugins/fused-ops-and-kernels/README.md) | Fused LoRA and triton kernels (e.g., fast cross-entropy, rms, rope) | -- | Apache 2.0 [(contains extracted code)](./plugins/fused-ops-and-kernels/README.md#code-extracted-from-unsloth)| Beta
[instruct-lab](./plugins/instruct-lab/README.md) | Padding-Free Flash Attention Computation | flash-attn | Apache 2.0 | Beta
[attention-and-distributed-packing](./plugins/attention-and-distributed-packing/README.md) | Padding-Free Flash Attention Computation | flash-attn | Apache 2.0 | Beta
MOE-training-acceleration | [MegaBlocks](https://github.com/databricks/megablocks) inspired triton Kernels and accelerations for Mixture-of-Expert models | | Apache 2.0 | Coming Soon

## Usage with FMS HF Tuning
@@ -8,6 +8,3 @@ import_heading_firstparty=First Party
import_heading_localfolder=Local
known_firstparty=
known_localfolder=tuning

# skip code imported from unsloth
skip_glob=**/unsloth*/**
@@ -1,4 +1,4 @@
# FMS Acceleration for Instruct Lab
# FMS Acceleration for Attention And Distributed Packing Plugin

This library contains plugins to accelerate finetuning with the following optimizations:

@@ -3,9 +3,9 @@ requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "fms-acceleration-ilab"
name = "fms-acceleration-aadp"
version = '0.0.1'
description = "FMS Acceleration Plugin for Functionalities Used in Instruct Lab Training"
description = "FMS Acceleration Plugin for Attention and Distributed Packing Optimizations"
authors = [
{name = "Fabian Lim", email = "flim@sg.ibm.com"},
{name = "Aaron Chew", email = "aaron.chew1@ibm.com"},
@@ -24,7 +24,7 @@ classifiers=[
]

[tool.hatch.build.targets.wheel]
only-include = ["src/fms_acceleration_ilab"]
only-include = ["src/fms_acceleration_aadp"]

[tool.hatch.build.targets.wheel.sources]
"src" = ""
@@ -15,7 +15,7 @@
from dataclasses import dataclass
import warnings
from transformers import DefaultDataCollator, default_data_collator

import torch

@dataclass
class DataCollatorWithFlattening(DefaultDataCollator):
@@ -51,4 +51,24 @@ def __call__(self, features, return_tensors=None):
ret["labels"] += [-100] + feature["labels"][1:]
else:
ret["labels"] += [-100] + feature["input_ids"][1:]
return default_data_collator([ret], return_tensors)

position_ids = torch.tensor(ret["position_ids"]).flatten()
indices_q = torch.arange(
position_ids.size(0), device=position_ids.device, dtype=torch.int32
)
cu_seq_lens = torch.cat(
(
indices_q[position_ids == 0],
torch.tensor(
position_ids.size(), dtype=torch.int32
),
)
)
max_length = position_ids.max() + 1

# return default_data_collator([ret], return_tensors)
return {
**default_data_collator([ret], return_tensors),
"cu_seq_lens": cu_seq_lens,
"max_length": max_length,
}
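For reference, a minimal usage sketch of the updated collator. The import path is inferred from the package layout in this PR (`fms_acceleration_aadp`, module `aadp_utils`), and the feature values are illustrative:

```python
from fms_acceleration_aadp.aadp_utils import DataCollatorWithFlattening

# Two examples to be packed into a single padding-free row.
features = [
    {"input_ids": [1, 2, 3], "labels": [1, 2, 3]},
    {"input_ids": [4, 5], "labels": [4, 5]},
]

collator = DataCollatorWithFlattening()
batch = collator(features, return_tensors="pt")

# Alongside the flattened input_ids / position_ids / labels produced before,
# the batch now also carries the packing metadata computed once here:
#   batch["cu_seq_lens"] -> tensor([0, 3, 5], dtype=torch.int32)
#   batch["max_length"]  -> tensor(3)
```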
@@ -32,36 +32,27 @@
inspect.signature(flash_attn_func).parameters
)

# model id -> position_ids
POSITION_IDS_CACHE = {}
CU_SEQ_LENS_CACHE = {}
MAX_SEQ_LENS_CACHE = {}

def prepare_fa2_from_position_ids(query, key, value, position_ids, query_length):
query = query.view(-1, query.size(-2), query.size(-1))
key = key.view(-1, key.size(-2), key.size(-1))
value = value.view(-1, value.size(-2), value.size(-1))
position_ids = position_ids.flatten()
indices_q = torch.arange(
position_ids.size(0), device=position_ids.device, dtype=torch.int32
)
cu_seq_lens = torch.cat(
(
indices_q[position_ids == 0],
torch.tensor(
position_ids.size(), device=position_ids.device, dtype=torch.int32
),
)
)
max_length = position_ids.max() + 1
return (
query,
key,
value,
indices_q,
(cu_seq_lens, cu_seq_lens),
(max_length, max_length),
)
# This is used to patch the top-level model to accept cuseqlen
# and maxseqlen as additional args that are cached for attention
# computation
def build_toplevel_model_forward(
model: torch.nn.Module,
model_id: str,
):
# forward
old_forward = model.forward

def forward(self, *args, cu_seq_lens, max_length, **kwargs):
CU_SEQ_LENS_CACHE[model_id] = (cu_seq_lens, cu_seq_lens)
MAX_SEQ_LENS_CACHE[model_id] = (max_length, max_length)
return old_forward(*args, **kwargs)

# model id -> position_ids
POSITION_IDS_CACHE = {}
return forward


# - needed to store position ids when first come into model
@@ -123,6 +114,8 @@ def _flash_attention_forward_with_posids(
):
# get the position ids out here
position_ids = POSITION_IDS_CACHE[model_id]
cu_seqlens_q, cu_seqlens_k = CU_SEQ_LENS_CACHE[model_id]
max_seqlen_in_batch_q, max_seqlen_in_batch_k = MAX_SEQ_LENS_CACHE[model_id]

if not use_top_left_mask:
causal = is_causal
@@ -161,19 +154,10 @@ def _flash_attention_forward_with_posids(
assert attention_mask is None, "should not be using attention mask"
assert position_ids is not None, "should be expecting position ids"
batch_size = query_states.size(0)
(
query_states,
key_states,
value_states,
_,
cu_seq_lens,
max_seq_lens,
) = prepare_fa2_from_position_ids(
query_states, key_states, value_states, position_ids, query_length
)

cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
query_states = query_states.view(-1, query_states.size(-2), query_states.size(-1))
key_states = key_states.view(-1, key_states.size(-2), key_states.size(-1))
value_states = value_states.view(-1, value_states.size(-2), value_states.size(-1))

attn_output = flash_attn_varlen_func(
query_states,
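On the consumer side, the patched `_flash_attention_forward_with_posids` now only reshapes the query/key/value states and forwards the cached lengths to the varlen kernel. A rough sketch of that call pattern, assuming the flash-attn 2 `flash_attn_varlen_func` interface (`cu_seqlens_q/k`, `max_seqlen_q/k`); this illustrates the shape handling and is not the plugin's exact code:

```python
import torch
from flash_attn import flash_attn_varlen_func  # requires flash-attn and a CUDA device

def varlen_attention_from_cache(query_states, key_states, value_states,
                                cu_seq_lens, max_length, causal=True):
    # Collapse (batch, seq, num_heads, head_dim) into (total_tokens, num_heads, head_dim),
    # matching the flattened layout described by cu_seq_lens.
    query_states = query_states.view(-1, query_states.size(-2), query_states.size(-1))
    key_states = key_states.view(-1, key_states.size(-2), key_states.size(-1))
    value_states = value_states.view(-1, value_states.size(-2), value_states.size(-1))
    return flash_attn_varlen_func(
        query_states,
        key_states,
        value_states,
        cu_seqlens_q=cu_seq_lens,
        cu_seqlens_k=cu_seq_lens,
        max_seqlen_q=int(max_length),  # may arrive as a 0-dim tensor from the collator
        max_seqlen_k=int(max_length),
        causal=causal,
    )
```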
@@ -93,7 +93,7 @@ def _is_backbone(module: torch.nn.Module):
# - patch backbone
model_type = model.config.model_type
# pylint: disable=import-outside-toplevel
from .flash_attn import build_backbone_forward
from .flash_attn import build_backbone_forward, build_toplevel_model_forward

ModelPatcher.register(
ModelPatcherRule(
@@ -106,6 +106,19 @@ def _is_backbone(module: torch.nn.Module):
),
)

# Need to patch the top-level model to accept and cache additional
# kwargs, cu_seq_lens and max_len from data collator
ModelPatcher.register(
ModelPatcherRule(
rule_id=f"{model_type}-cumseqlen-cache",
trigger=ModelPatcherTrigger(check=model.__class__),
forward_builder=partial(
build_toplevel_model_forward,
model_id=id(model),
),
),
)

# Next, the flash attention function needs to be patched
# how it is patched depends on the transformers version
try:
@@ -185,7 +198,7 @@ def _patch_dataloader(
except ImportError:
# Otherwise, use the locally implemented DataCollatorWithFlattening
# pylint: disable=import-outside-toplevel
from .ilab_utils import (
from .aadp_utils import (
DataCollatorWithFlattening,
)
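The top-level patch registered in this file builds its forward via `build_toplevel_model_forward`; conceptually it amounts to wrapping the model's `forward` so it strips and caches the two extra batch keys before delegating. A plain-PyTorch sketch of that idea (names here are for illustration, not the plugin's `ModelPatcher` API):

```python
from types import MethodType
import torch

# model id -> ((cu_seq_lens, cu_seq_lens), (max_length, max_length))
SEQLEN_CACHE = {}

def wrap_forward_to_cache(model: torch.nn.Module, model_id: str) -> torch.nn.Module:
    old_forward = model.forward  # already bound to the model instance

    def forward(self, *args, cu_seq_lens=None, max_length=None, **kwargs):
        # Stash the collator-provided metadata, then call the original forward
        # without the extra kwargs so inner modules see their usual signature.
        SEQLEN_CACHE[model_id] = ((cu_seq_lens, cu_seq_lens), (max_length, max_length))
        return old_forward(*args, **kwargs)

    model.forward = MethodType(forward, model)
    return model
```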

@@ -17,13 +17,13 @@
instantiate_framework,
read_configuration,
)
from fms_acceleration_ilab import PaddingFreeAccelerationPlugin
from fms_acceleration_aadp import PaddingFreeAccelerationPlugin

# configuration
DIRNAME = os.path.dirname(__file__)
CONFIG_PATH_ILAB = os.path.join(DIRNAME, "../configs/instruct_lab.yaml")
CONFIG_PATH_ILAB = os.path.join(DIRNAME, "../configs/aadp.yaml")

def test_framework_installs_ilab_padding_free_plugin():
def test_framework_installs_aadp_padding_free_plugin():
with instantiate_framework(
read_configuration(CONFIG_PATH_ILAB), require_packages_check=False
) as framework:
2 changes: 1 addition & 1 deletion plugins/framework/src/fms_acceleration/constants.py
@@ -21,4 +21,4 @@
# and activated.
# - hence the plugins that have model loaders should be on top of this list

PLUGINS = ["peft", "foak", "ilab"]
PLUGINS = ["peft", "foak", "aadp"]
6 changes: 3 additions & 3 deletions sample-configurations/CONTENTS.yaml
@@ -33,7 +33,7 @@ framework_configs:
- fused-ops-and-kernels
filename: accelerated-peft-bnb-nf4-foak-sample-configuration.yaml

- shortname: ilab-padding-free
- shortname: aadp-padding-free
plugins:
- instruct-lab
filename: ilab-padding-free-sample-configuration.yaml
- attention-and-distributed-packing
filename: aadp-padding-free-sample-configuration.yaml
6 changes: 3 additions & 3 deletions scripts/generate_sample_configurations.py
@@ -144,7 +144,7 @@ def read_configuration(path: str) -> Dict:
KEY_BNB_NF4_BASELINE = "baseline-bnb-nf4"
KEY_AUTO_GPTQ_FOAK = "auto-gptq-foak"
KEY_BNB_NF4_FOAK = "bnb-nf4-foak"
KEY_ILAB_PADDING_FREE = "ilab-padding-free"
KEY_AADP_PADDING_FREE = "aadp-padding-free"

CONFIGURATIONS = {
KEY_AUTO_GPTQ: "plugins/accelerated-peft/configs/autogptq.yaml",
@@ -167,7 +167,7 @@ def read_configuration(path: str) -> Dict:
"plugins/fused-ops-and-kernels/configs/fast_quantized_peft.yaml",
[("peft.quantization.fused_ops_and_kernels.base_layer", "bitsandbytes")],
),
KEY_ILAB_PADDING_FREE: "plugins/instruct-lab/configs/instruct_lab.yaml",
KEY_AADP_PADDING_FREE: "plugins/attention-and-distributed-packing/configs/aadp.yaml",
}

# list of (tag, combi) tuples
@@ -181,7 +181,7 @@ def read_configuration(path: str) -> Dict:
("baseline-peft-bnb-nf4", (KEY_BNB_NF4_BASELINE,)),
("accelerated-peft-autogptq-foak", (KEY_AUTO_GPTQ, KEY_AUTO_GPTQ_FOAK)),
("accelerated-peft-bnb-nf4-foak", (KEY_BNB_NF4, KEY_BNB_NF4_FOAK)),
("ilab-padding-free", (KEY_ILAB_PADDING_FREE,)),
("aadp-padding-free", (KEY_AADP_PADDING_FREE,)),
]

