Model Offloading Support Pt 2 #34

Merged: 22 commits, Jul 30, 2024
59 changes: 59 additions & 0 deletions examples/big_model_offloading/big_model_fp8.py
@@ -0,0 +1,59 @@
import torch
from transformers import AutoTokenizer

from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot
from llmcompressor.transformers.compression.helpers import ( # noqa
calculate_offload_device_map,
custom_offload_device_map,
)

# define a llmcompressor recipe for FP8 quantization
# this recipe requires no calibration data since inputs are dynamically quantized
recipe = """
quant_stage:
quant_modifiers:
QuantizationModifier:
ignore: ["lm_head"]
config_groups:
group_0:
weights:
num_bits: 8
type: float
strategy: channel
dynamic: false
symmetric: true
input_activations:
num_bits: 8
type: float
strategy: token
dynamic: true
symmetric: true
targets: ["Linear"]
"""

model_stub = "meta-llama/Meta-Llama-3-70B-Instruct"

# determine which layers to offload to cpu based on available resources
device_map = calculate_offload_device_map(
model_stub, reserve_for_hessians=False, num_gpus=1, torch_dtype=torch.float16
)

# alternatively, specify the maximum memory to allocate per GPU directly
# device_map = custom_offload_device_map(
# model_stub, max_memory_per_gpu="10GB", num_gpus=2, torch_dtype=torch.float16
# )

model = SparseAutoModelForCausalLM.from_pretrained(
model_stub, torch_dtype=torch.float16, device_map=device_map
)

output_dir = "./test_output_llama3b_70b_fp8"


oneshot(
model=model,
recipe=recipe,
output_dir=output_dir,
save_compressed=True,
tokenizer=AutoTokenizer.from_pretrained(model_stub),
)
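
Note: calculate_offload_device_map is expected to return an accelerate-style device map, i.e. a dict mapping module names to a GPU index, "cpu", or "disk". Below is a minimal sketch of how one might inspect the resulting placement before loading the model; the summary loop is illustrative and not part of this PR.

from collections import Counter

import torch

from llmcompressor.transformers.compression.helpers import calculate_offload_device_map

model_stub = "meta-llama/Meta-Llama-3-70B-Instruct"

# assumes an accelerate-style device map: {module_name: device}
device_map = calculate_offload_device_map(
    model_stub, reserve_for_hessians=False, num_gpus=1, torch_dtype=torch.float16
)

# count how many modules land on each device (e.g. GPU 0 vs. "cpu")
placement = Counter(str(device) for device in device_map.values())
for device, num_modules in sorted(placement.items()):
    print(f"{device}: {num_modules} modules")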
92 changes: 92 additions & 0 deletions examples/big_model_offloading/big_model_w8a8_calibrate.py
@@ -0,0 +1,92 @@
import torch
from datasets import load_dataset
from transformers import AutoTokenizer

from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot
from llmcompressor.transformers.compression.helpers import ( # noqa
calculate_offload_device_map,
custom_offload_device_map,
)

# define a llmcompressor recipe for W8A8 quantization (int8 weights, static fp8 activations)
# this recipe requires calibration data
recipe = """
quant_stage:
quant_modifiers:
GPTQModifier:
sequential_update: true
ignore: ["lm_head"]
config_groups:
group_0:
weights:
num_bits: 8
type: int
strategy: tensor
dynamic: false
symmetric: true
input_activations:
num_bits: 8
type: float
strategy: tensor
dynamic: false
symmetric: true
targets: ["Linear"]
"""

model_stub = "meta-llama/Meta-Llama-3-70B-Instruct"

device_map = custom_offload_device_map(
model_stub, max_memory_per_gpu="74GB", num_gpus=1, torch_dtype=torch.float16
)

model = SparseAutoModelForCausalLM.from_pretrained(
model_stub, torch_dtype=torch.float16, device_map=device_map
)
tokenizer = AutoTokenizer.from_pretrained(model_stub)
output_dir = "./output_llama3b_70b_w8a8"

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"
NUM_CALIBRATION_SAMPLES = 4
MAX_SEQUENCE_LENGTH = 512

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))


def preprocess(example):
return {
"text": tokenizer.apply_chat_template(
example["messages"],
tokenize=False,
)
}


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
return tokenizer(
sample["text"],
padding=False,
max_length=MAX_SEQUENCE_LENGTH,
truncation=True,
add_special_tokens=False,
)


ds = ds.map(tokenize, remove_columns=ds.column_names)


oneshot(
model=model,
dataset=ds,
recipe=recipe,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
save_compressed=True,
)
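
Because the oneshot run above is expensive for a 70B model, it can help to sanity-check the tokenized calibration set first. A minimal sketch using only the objects already defined in this example (standard datasets/tokenizers APIs):

# quick sanity check on the calibration data before launching oneshot
sample = ds[0]
print(f"num calibration samples: {len(ds)}")
print(f"first sample length in tokens: {len(sample['input_ids'])}")
# preview the chat-formatted text that will be fed to the model
print(tokenizer.decode(sample["input_ids"])[:500])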
@@ -14,6 +14,11 @@

import torch
import torch.nn as nn
from compressed_tensors.utils import (
get_offloaded_device,
is_module_offloaded,
update_prefix_dict,
)
from loguru import logger

__all__ = ["GPTQWrapper"]
@@ -74,6 +79,9 @@ def compress(
:param percdamp: Amount of dampening to apply to H, as a fraction of the
diagonal norm
"""
if is_module_offloaded(self.layer):
self.layer._hf_hook.pre_forward(self.layer)

final_shape = self.layer.weight.shape
final_dtype = self.layer.weight.dtype
W = self.layer.weight.data.clone()
@@ -106,7 +114,6 @@

update_layer_weight_quant_params(self.layer)


dead = torch.diag(self.H) == 0
self.H[dead, dead] = 1
W[:, dead] = 0
@@ -227,6 +234,15 @@ def compress(
self.layer.weight -= self.layer.weight
self.layer.weight += W

if is_module_offloaded(self.layer):
device = get_offloaded_device(self.layer)
update_prefix_dict(self.layer, "weight", self.layer.weight.to(device))
self.layer._hf_hook.post_forward(self.layer, None)

del W
del Losses
del diag

def free(self):
"""
Free the Hessian memory after the layer is complete
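
For context, the hunks above follow a recurring accelerate-offloading pattern: onload the layer's weight before mutating it, then write the result back to the offload device so the hook's cached copy stays in sync. A condensed sketch of that pattern, using the same compressed_tensors helpers imported in this diff (the wrapper function itself is illustrative, not part of the PR):

import torch
from compressed_tensors.utils import (
    get_offloaded_device,
    is_module_offloaded,
    update_prefix_dict,
)


def update_offloaded_weight(layer: torch.nn.Module, new_weight: torch.Tensor):
    """Update a (possibly offloaded) layer's weight in place while keeping the
    accelerate hook's cached copy consistent."""
    if is_module_offloaded(layer):
        # onload parameters onto the execution device
        layer._hf_hook.pre_forward(layer)

    layer.weight.data.copy_(new_weight.to(layer.weight.device))

    if is_module_offloaded(layer):
        offload_device = get_offloaded_device(layer)
        # refresh the cached copy so the update survives re-offloading
        update_prefix_dict(layer, "weight", layer.weight.to(offload_device))
        layer._hf_hook.post_forward(layer, None)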
25 changes: 24 additions & 1 deletion src/llmcompressor/modifiers/quantization/quantization/base.py
@@ -74,9 +74,10 @@ def on_initialize(
module = state.model

# initialize quantization in appropriate modules
self._apply_modifier_to_model(module)
config = self._apply_modifier_to_model(module)

if self.calculate_start() == -1: # one-shot
self._check_calibration_data(config)
module.apply(set_module_for_calibration)
self._calibrate_if_possible(module)
self._check_token_distribution(
@@ -167,9 +168,27 @@ def check_should_disable_observer(self, event: Event) -> bool:
return True
return False

def _check_calibration_data(self, config: QuantizationConfig):
has_calibration_data = self.calibration_dataloader_ is not None
requires_calibration = config.requires_calibration_data()
if self.calculate_start() == -1: # one shot
if requires_calibration and not has_calibration_data:
raise ValueError(
"The provided quantization configuration requires calibration data "
"but none was provided. Calibration data is required for static "
"quantization of input or output activations."
)
if not requires_calibration and has_calibration_data:
logger.info(
"Skipping QuantizationModifier calibration, it is not required for "
"the provided quantization config."
)
self.calibration_dataloader_ = None

def _apply_modifier_to_model(self, model: Module):
modifier_as_config = self.create_init_config()
apply_quantization_config(model, modifier_as_config)
return modifier_as_config

def _calibrate_if_possible(self, module: Module):
if self.num_calibration_steps == 0 and self.calibration_dataloader_:
@@ -226,6 +245,10 @@ def _check_token_distribution(
logger.debug("Skipping token distribution check. threshold is None.")
return

if self.calibration_dataloader_ is None:
logger.debug("Skipping token distribution check. No calibration data.")
return

all_tokens = self.calibration_dataloader_.dataset["input_ids"]
total_token_count = sum(len(sample) for sample in all_tokens)
counter = get_observer_token_count(model)
3 changes: 3 additions & 0 deletions src/llmcompressor/modifiers/utils/compression_wrapper.py
@@ -38,7 +38,10 @@ def __init__(self, name, layer):

self.name = name
self.layer = layer

self.dev = self.layer.weight.device
if hasattr(self.layer, "_hf_hook") and self.layer._hf_hook.offload:
self.dev = self.layer._hf_hook.execution_device

# Calculate weight shape to use during pruning
W = self.layer.weight
5 changes: 4 additions & 1 deletion src/llmcompressor/modifiers/utils/layer_compressor.py
@@ -114,7 +114,7 @@ def revert_layer_wrappers(self):
set_layer(full_name, module_wrapper.layer, self.model)
else:
set_layer(name, module_wrapper.layer, self.layer)
module_wrapper.free()
torch.cuda.empty_cache()
self.modules = None

def compress(self):
@@ -128,8 +128,11 @@ def compress_module(module):
full_name = self._get_full_submodule_name(module.name)
logger.info(f"Compressing {full_name}...")
module.compress(**self.args)
module.free()
print("done")

self.layer.apply(compress_module)
torch.cuda.empty_cache()

def _get_full_submodule_name(self, name):
full_name = ".".join(x for x in [self.name, name] if len(x) > 0)
4 changes: 4 additions & 0 deletions src/llmcompressor/modifiers/utils/pytorch_helpers.py
@@ -70,3 +70,7 @@ def run_calibration_forward(
batch = tensors_to_device(batch, model_device)
with torch.no_grad():
forward_fn(batch, module=model)
# TODO: not ideal, figure out where we aren't freeing memory instead
# currently without this we run OOM on the 2nd forward pass
torch.cuda.empty_cache()
torch.cuda.empty_cache()
2 changes: 1 addition & 1 deletion src/llmcompressor/pytorch/utils/helpers.py
@@ -543,7 +543,7 @@ def tensor_sparsity(
:return: the sparsity of the input tens, i.e. the fraction of numbers that are zero
"""
if dim is None:
zeros = (tens == 0).sum()
zeros = (tens.cpu() == 0).sum()
total = tens.numel()

return zeros.float() / float(total)