Kv cache e2e add #1000

Closed
wants to merge 6 commits
17 changes: 17 additions & 0 deletions src/llmcompressor/modifiers/smoothquant/base.py
@@ -2,6 +2,7 @@
from typing import Callable, Dict, List, Optional, Tuple

import torch
from compressed_tensors.utils.offload import is_module_offloaded
from loguru import logger
from torch.nn import Module

@@ -282,6 +283,10 @@ def _apply_smoothing(self, model: Module):

@torch.no_grad()
def smooth(module):
offloaded = is_module_offloaded(module)
if offloaded:
module._hf_hook.pre_forward(module)

if module in balance_layers:
module.weight.mul_(scales.view(1, -1))
elif module == smooth_layer:
@@ -292,6 +297,9 @@ def smooth(module):
if hasattr(module, "bias") and module.bias is not None:
module.bias.div_(scales)

if offloaded:
module._hf_hook.post_forward(module, None)

parent = get_fsdp_parent(mapping.smooth_name, model)
if parent is not None:
parent.apply(smooth)
@@ -318,8 +326,16 @@ def _calculate_smoothing_scales(
# get the channel-wise dynamic range for each layer to be balanced
weight_scales = []
for layer in balance_layers:
offloaded = is_module_offloaded(layer)
if offloaded:
layer._hf_hook.pre_forward(layer)

scale = layer.weight.abs().max(dim=0, keepdim=True)[0]
weight_scales.append(scale)

if offloaded:
layer._hf_hook.post_forward(layer, None)

weight_scales = 2.0 * torch.cat(weight_scales, dim=0).max(dim=0)[0]

# calculate the amount of smoothing to apply
@@ -329,4 +345,5 @@
1 - self.smoothing_strength
)
scales = torch.where(weight_scales > 0.0, scales, activation_scales)

return scales
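
The same onload/mutate/offload pattern appears in both hunks above. Below is a minimal standalone sketch of that pattern, relying only on the `is_module_offloaded` helper and the accelerate-style `_hf_hook` already imported and used in this diff; the `onloaded` context manager and `scale_weight` wrapper are illustrative names, not part of the PR.

```python
from contextlib import contextmanager

import torch
from compressed_tensors.utils.offload import is_module_offloaded
from torch.nn import Module


@contextmanager
def onloaded(module: Module):
    """Keep an offloaded module's weights on the execution device for the block."""
    offloaded = is_module_offloaded(module)
    if offloaded:
        # accelerate's hook materializes the real weights on-device
        module._hf_hook.pre_forward(module)
    try:
        yield module
    finally:
        if offloaded:
            # return the (now modified) weights to their offload location
            module._hf_hook.post_forward(module, None)


@torch.no_grad()
def scale_weight(layer: Module, scales: torch.Tensor) -> None:
    # in-place weight edits are only safe while the module is onloaded
    with onloaded(layer):
        layer.weight.mul_(scales.view(1, -1))
```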
12 changes: 12 additions & 0 deletions src/llmcompressor/modifiers/smoothquant/utils.py
@@ -44,6 +44,17 @@
),
]

GLM_SMOOTHQUANT_MAPPINGS: List[LayerMap] = [
LayerMap(
balance_layers=["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"],
smooth_layers="re:.*input_layernorm",
),
LayerMap(
balance_layers=["re:.*gate_up_proj"],
smooth_layers="re:.*post_attention_layernorm",
),
]


# Registry of layer mappings for different architectures
# Add more mappings here
@@ -53,6 +64,7 @@
"MistralForCausalLM": DEFAULT_SMOOTHQUANT_MAPPINGS,
"Qwen2ForCausalLM": DEFAULT_SMOOTHQUANT_MAPPINGS,
"BloomForCausalLM": BLOOM_SMOOTHQUANT_MAPPINGS,
"GlmForCausalLM": GLM_SMOOTHQUANT_MAPPINGS,
}


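The `re:`-prefixed entries above select modules by name with regular expressions. A rough, self-contained illustration of how those GLM patterns resolve follows; this is not llmcompressor's internal matcher, and the fully qualified module names are assumed for illustration only.

```python
import re
from typing import List

# assumed names for one GLM decoder layer, for illustration only
GLM_LAYER_MODULES: List[str] = [
    "model.layers.0.input_layernorm",
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
    "model.layers.0.post_attention_layernorm",
    "model.layers.0.mlp.gate_up_proj",
    "model.layers.0.mlp.down_proj",
]


def resolve(pattern: str, names: List[str]) -> List[str]:
    """Treat 're:'-prefixed patterns as regexes; anything else is an exact name."""
    if pattern.startswith("re:"):
        regex = re.compile(pattern[len("re:"):])
        return [name for name in names if regex.match(name)]
    return [name for name in names if name == pattern]


print(resolve("re:.*gate_up_proj", GLM_LAYER_MODULES))
# ['model.layers.0.mlp.gate_up_proj']
print(resolve("re:.*input_layernorm", GLM_LAYER_MODULES))
# ['model.layers.0.input_layernorm']
```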
12 changes: 7 additions & 5 deletions src/llmcompressor/pytorch/model_load/helpers.py
@@ -9,6 +9,7 @@

from llmcompressor.core import active_session, create_session, pre_initialize_structure
from llmcompressor.pytorch.utils import ModuleSparsificationInfo
from llmcompressor.typing import Processor

COMPLETED_STAGES_FILENAME = "completed_stages.json"

@@ -92,15 +93,16 @@ def initialize_recipe(model: Module, recipe_path: str):
def save_model_and_recipe(
model: Module,
save_path: str,
tokenizer: Optional[Any] = None,
processor: Optional[Processor] = None,
save_safetensors: bool = False,
save_compressed: bool = False,
):
"""
Save a model, tokenizer and the currently loaded recipe to file
Save a model, processor and the currently loaded recipe to file

:param model: pytorch model to save
:param save_path: path to save output to
:param tokenizer: model tokenizer to save
:param processor: model processor or tokenizer to save
:param save_safetensors: whether to save as safetensors or pickle (bin)
:param save_compressed: whether to compress sparse weights on disk
"""
@@ -111,8 +113,8 @@ def save_model_and_recipe(
save_path, save_compressed=save_compressed, safe_serialization=save_safetensors
)

if tokenizer is not None:
tokenizer.save_pretrained(save_path)
if processor is not None:
processor.save_pretrained(save_path)

logger.info("Saving output to {}".format(os.path.abspath(save_path)))

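The rename is safe because tokenizers and processors share the `save_pretrained` interface, so the helper treats them uniformly. A minimal sketch of just that part, with the checkpoint and output directory as placeholders:

```python
from transformers import AutoTokenizer

# a bare tokenizer also satisfies the `processor` parameter; a multimodal
# model would pass the object returned by AutoProcessor instead
processor = AutoTokenizer.from_pretrained("gpt2")

# this mirrors what save_model_and_recipe does when `processor is not None`
processor.save_pretrained("./compressed-model")
```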
44 changes: 31 additions & 13 deletions src/llmcompressor/transformers/finetune/data/base.py
@@ -3,14 +3,14 @@
from compressed_tensors.registry import RegistryMixin
from datasets import Dataset, IterableDataset
from loguru import logger
from transformers import AutoTokenizer

from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments
from llmcompressor.transformers.finetune.data.data_helpers import (
LABELS_MASK_VALUE,
get_custom_datasets_from_path,
get_raw_dataset,
)
from llmcompressor.typing import Processor


class TextGenerationDataset(RegistryMixin):
@@ -30,10 +30,10 @@ def __init__(
text_column: str,
data_args: DataTrainingArguments,
split: str,
tokenizer: AutoTokenizer,
processor: Processor,
):
self.text_column = text_column
self.tokenizer = tokenizer
self.processor = processor
self.data_args = data_args
self.raw_kwargs = data_args.raw_kwargs or {}
self.split = split
@@ -50,20 +50,38 @@ def __init__(
else:
self.padding = False

if self.tokenizer:
# get tokenizer
self.tokenizer = getattr(self.processor, "tokenizer", self.processor)

if self.tokenizer is not None:
# fill in pad token
if not self.tokenizer.pad_token:
self.tokenizer.pad_token = self.tokenizer.eos_token

# configure sequence length
max_seq_length = data_args.max_seq_length
model_max_length = tokenizer.model_max_length if tokenizer else max_seq_length
if self.tokenizer and max_seq_length > model_max_length:
logger.warning(
f"The max_seq_length passed ({max_seq_length}) is larger than "
f"the maximum length for the model ({tokenizer.model_max_length}). "
f"Using max_seq_length={tokenizer.model_max_length}."
# configure sequence length
max_seq_length = data_args.max_seq_length
if data_args.max_seq_length > self.tokenizer.model_max_length:
logger.warning(
f"The max_seq_length passed ({max_seq_length}) is larger than "
f"maximum length for model ({self.tokenizer.model_max_length}). "
f"Using max_seq_length={self.tokenizer.model_max_length}."
)
self.max_seq_length = min(
data_args.max_seq_length, self.tokenizer.model_max_length
)

# configure padding
self.padding = (
False
if self.data_args.concatenate_data
else "max_length"
if self.data_args.pad_to_max_length
else False
)
self.max_seq_length = min(data_args.max_seq_length, model_max_length)

else:
self.max_seq_length = None
self.padding = False

def get_raw_dataset(self, cache_dir: Optional[str] = None) -> Dataset:
"""
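The `getattr(self.processor, "tokenizer", self.processor)` line above is the compatibility shim: a full processor exposes its tokenizer as an attribute, while a bare tokenizer is returned unchanged. A small sketch of the same resolution outside the class, with the checkpoint name as a placeholder:

```python
from transformers import AutoTokenizer


def resolve_tokenizer(processor):
    """Return a processor's inner tokenizer, or the object itself if it already is one."""
    return getattr(processor, "tokenizer", processor)


tokenizer = resolve_tokenizer(AutoTokenizer.from_pretrained("gpt2"))

# same pad-token fallback as in __init__ above
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# sequence length is capped by the model maximum, mirroring the warning above
max_seq_length = min(2048, tokenizer.model_max_length)
print(max_seq_length)  # 1024 for gpt2
```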
6 changes: 3 additions & 3 deletions src/llmcompressor/transformers/finetune/data/c4.py
@@ -10,12 +10,12 @@ class C4Dataset(TextGenerationDataset):

:param data_args: configuration settings for dataset loading
:param split: split from dataset to load, for instance `test` or `train[:5%]`
:param tokenizer: tokenizer to use on dataset
:param processor: processor or tokenizer to use on dataset
"""

def __init__(self, data_args, split, tokenizer):
def __init__(self, data_args, split, processor):
data_args = deepcopy(data_args)
data_args.dataset = "allenai/c4"
super().__init__(
text_column="text", data_args=data_args, split=split, tokenizer=tokenizer
text_column="text", data_args=data_args, split=split, processor=processor
)
6 changes: 3 additions & 3 deletions src/llmcompressor/transformers/finetune/data/cnn_dailymail.py
@@ -24,18 +24,18 @@ class CNNDailyMailDataset(TextGenerationDataset):

:param data_args: configuration settings for dataset loading
:param split: split from dataset to load, for instance `test` or `train[:5%]`
:param tokenizer: tokenizer to use on dataset
:param processor: processor or tokenizer to use on dataset
"""

SAMPLE_TEMPLATE = "Article:\n{article}\n\n### Summarization:\n{highlights}\n"

def __init__(self, data_args, split, tokenizer):
def __init__(self, data_args, split, processor):
data_args = deepcopy(data_args)
data_args.dataset = "cnn_dailymail"
data_args.dataset_config_name = "3.0.0"

super().__init__(
text_column="text", data_args=data_args, split=split, tokenizer=tokenizer
text_column="text", data_args=data_args, split=split, processor=processor
)

def get_raw_dataset(self, cache_dir: Optional[str] = None):
6 changes: 3 additions & 3 deletions src/llmcompressor/transformers/finetune/data/custom.py
@@ -32,17 +32,17 @@ class CustomDataset(TextGenerationDataset):
:param data_args: configuration settings for dataset loading
:param split: split from dataset to load, for instance `test` or `train[:5%]`
Can also be set to None to load all the splits
:param tokenizer: tokenizer to use on dataset
:param processor: processor or tokenizer to use on dataset

"""

def __init__(self, data_args, split, tokenizer):
def __init__(self, data_args, split, processor):
data_args = deepcopy(data_args)
super().__init__(
text_column=data_args.text_column,
data_args=data_args,
split=split,
tokenizer=tokenizer,
processor=processor,
)
self.preprocessing_func = data_args.preprocessing_func
self.remove_columns = data_args.remove_columns
@@ -24,7 +24,7 @@ class EvolCodeAlpacaDataset(TextGenerationDataset):

:param data_args: configuration settings for dataset loading
:param split: split from dataset to load, for instance `test` or `train[:5%]`
:param tokenizer: tokenizer to use on dataset
:param processor: processor or tokenizer to use on dataset
"""

EVOL_ALPACA_TEMPLATE = (
@@ -34,11 +34,11 @@ class EvolCodeAlpacaDataset(TextGenerationDataset):
"\n\n### Response:\n"
)

def __init__(self, data_args, split, tokenizer):
def __init__(self, data_args, split, processor):
data_args = deepcopy(data_args)
data_args.dataset = "theblackcat102/evol-codealpaca-v1"
super().__init__(
text_column="text", data_args=data_args, split=split, tokenizer=tokenizer
text_column="text", data_args=data_args, split=split, processor=processor
)

def get_raw_dataset(self, cache_dir: Optional[str] = None):
6 changes: 3 additions & 3 deletions src/llmcompressor/transformers/finetune/data/gsm8k.py
@@ -11,16 +11,16 @@ class GSM8KDataset(TextGenerationDataset):

:param data_args: configuration settings for dataset loading
:param split: split from dataset to load, for instance `test` or `train[:5%]`
:param tokenizer: tokenizer to use on dataset
:param processor: processor or tokenizer to use on dataset
"""

GSM_TEMPLATE = "Question: {question}\nAnswer:"

def __init__(self, data_args, split, tokenizer):
def __init__(self, data_args, split, processor):
data_args = deepcopy(data_args)
data_args.dataset = "gsm8k"
super().__init__(
text_column="text", data_args=data_args, split=split, tokenizer=tokenizer
text_column="text", data_args=data_args, split=split, processor=processor
)

def get_raw_dataset(self, cache_dir: Optional[str] = None):
6 changes: 3 additions & 3 deletions src/llmcompressor/transformers/finetune/data/open_platypus.py
@@ -24,7 +24,7 @@ class OpenPlatypusDataset(TextGenerationDataset):

:param data_args: configuration settings for dataset loading
:param split: split from dataset to load, for instance `test` or `train[:5%]`
:param tokenizer: tokenizer to use on dataset
:param processor: processor or tokenizer to use on dataset
"""

ALPACA_TEMPLATE = {
@@ -37,11 +37,11 @@
"instruction}\n\n### Response:\n",
}

def __init__(self, data_args, split, tokenizer):
def __init__(self, data_args, split, processor):
data_args = deepcopy(data_args)
data_args.dataset = "garage-bAInd/Open-Platypus"
super().__init__(
text_column="text", data_args=data_args, split=split, tokenizer=tokenizer
text_column="text", data_args=data_args, split=split, processor=processor
)

def get_raw_dataset(self, cache_dir: Optional[str] = None):
6 changes: 3 additions & 3 deletions src/llmcompressor/transformers/finetune/data/ptb.py
@@ -10,15 +10,15 @@ class PtbDataset(TextGenerationDataset):

:param data_args: configuration settings for dataset loading
:param split: split from dataset to load, for instance `test` or `train[:5%]`
:param tokenizer: tokenizer to use on dataset
:param processor: processor or tokenizer to use on dataset
"""

def __init__(self, data_args, split, tokenizer):
def __init__(self, data_args, split, processor):
data_args = deepcopy(data_args)
data_args.dataset = "ptb_text_only"
super().__init__(
text_column="sentence",
data_args=data_args,
split=split,
tokenizer=tokenizer,
processor=processor,
)
10 changes: 6 additions & 4 deletions src/llmcompressor/transformers/finetune/data/ultrachat_200k.py
@@ -24,7 +24,7 @@ class UltraChatDataset(TextGenerationDataset):

:param data_args: configuration settings for dataset loading
:param split: split from dataset to load, for instance `test` or `train[:5%]`
:param tokenizer: tokenizer to use on dataset
:param processor: processor or tokenizer to use on dataset
"""

DEFAULT_CHAT_TEMPLATE = (
@@ -40,7 +40,7 @@ class UltraChatDataset(TextGenerationDataset):
"{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
)

def __init__(self, data_args, split, tokenizer):
def __init__(self, data_args, split, processor):
data_args = deepcopy(data_args)
data_args.dataset = "HuggingFaceH4/ultrachat_200k"

@@ -51,13 +51,15 @@ def __init__(self, data_args, split, tokenizer):
text_column="messages",
data_args=data_args,
split=split,
tokenizer=tokenizer,
processor=processor,
)

if (
not hasattr(self.tokenizer, "chat_template")
or self.tokenizer.chat_template is None
):
# note that since tokenizer is a member of processor,
# this change affects processor.apply_chat_template
self.tokenizer.chat_template = self.DEFAULT_CHAT_TEMPLATE

def get_raw_dataset(self, cache_dir: Optional[str] = None):
@@ -75,7 +77,7 @@ def restructure_fn(sample):
if sample["messages"][0]["role"] != "system":
sample["messages"].insert(0, {"role": "system", "content": ""})

sample["messages"] = self.tokenizer.apply_chat_template(
sample["messages"] = self.processor.apply_chat_template(
sample["messages"], tokenize=False, add_generation_prompt=False
)
return sample
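Because the tokenizer is a member of the processor, setting `tokenizer.chat_template` above also changes what `processor.apply_chat_template` renders in `restructure_fn`. A rough standalone illustration of that call, using a placeholder checkpoint that ships a chat template:

```python
from transformers import AutoTokenizer

# placeholder checkpoint; any tokenizer with a chat_template behaves the same way
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

messages = [
    {"role": "system", "content": ""},
    {"role": "user", "content": "Explain SmoothQuant in one sentence."},
]

# tokenize=False renders the prompt string instead of token ids, which is what
# the dataset writes back into its "messages" column
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
print(text)
```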
6 changes: 3 additions & 3 deletions src/llmcompressor/transformers/finetune/data/wikitext.py
@@ -8,10 +8,10 @@ class WikiTextDataset(TextGenerationDataset):

:param data_args: configuration settings for dataset loading
:param split: split from dataset to load, for instance `test` or `train[:5%]`
:param tokenizer: tokenizer to use on dataset
:param processor: processor or tokenizer to use on dataset
"""

def __init__(self, data_args, split, tokenizer):
def __init__(self, data_args, split, processor):
super().__init__(
text_column="text", data_args=data_args, split=split, tokenizer=tokenizer
text_column="text", data_args=data_args, split=split, processor=processor
)
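
For reference, a hypothetical dataset wired against the updated signature; the registry decorator, dataset id, and import path below are assumptions made from the pattern of the files above (the decorators themselves sit in collapsed diff context), not part of this PR.

```python
from copy import deepcopy

from llmcompressor.transformers.finetune.data import TextGenerationDataset  # assumed import path


@TextGenerationDataset.register(name="my_text_dataset")  # assumed registry decorator
class MyTextDataset(TextGenerationDataset):
    def __init__(self, data_args, split, processor):
        data_args = deepcopy(data_args)
        data_args.dataset = "my-org/my-text-dataset"  # placeholder dataset id
        super().__init__(
            text_column="text",
            data_args=data_args,
            split=split,
            processor=processor,  # replaces the former `tokenizer` argument
        )
```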