Accelerate Utilities #193

Open
wants to merge 28 commits into main
Commits
bddc83c
wip
kylesayrs Oct 21, 2024
94d8c56
add modify_offload_module
kylesayrs Oct 23, 2024
f939e98
update docs
kylesayrs Oct 23, 2024
167e741
WIP
kylesayrs Oct 31, 2024
cb6edb1
cleanup functions, begin depreciation
kylesayrs Nov 18, 2024
cb70047
remove extra space
kylesayrs Nov 18, 2024
98a2889
revert get_offloaded_device
kylesayrs Nov 18, 2024
8cd69ef
update to align_module_device
kylesayrs Nov 19, 2024
0d23183
add requires skip for accelerate
kylesayrs Nov 19, 2024
82235b3
Merge remote-tracking branch 'origin' into kylesayrs/upstream-candidates
kylesayrs Nov 19, 2024
0b0d8b6
fix per token initialization
kylesayrs Nov 19, 2024
95e5907
remove align_module_device
kylesayrs Nov 19, 2024
a6a3198
Merge remote-tracking branch 'origin' into kylesayrs/upstream-candidates
kylesayrs Dec 2, 2024
e3c3f95
Merge remote-tracking branch 'origin' into kylesayrs/upstream-candidates
kylesayrs Dec 2, 2024
81a1eab
respond to nits
kylesayrs Dec 6, 2024
e7e1d81
Accelerate Utilities Follow-up (#224)
kylesayrs Dec 6, 2024
9af736f
rename
kylesayrs Dec 6, 2024
35fa1cd
implement recursive case
kylesayrs Dec 6, 2024
38765bd
remove print
kylesayrs Dec 6, 2024
64f4d98
support OffloadedWeightsLoader
kylesayrs Dec 6, 2024
b8ae387
add lifecycle docstring
kylesayrs Dec 10, 2024
870095e
implement offload_to_weights_map with recursive definition
kylesayrs Dec 16, 2024
77411ca
add docstring
kylesayrs Dec 16, 2024
a5b1792
fix type hint
kylesayrs Dec 16, 2024
ed9ee4e
add check_accelerate guard
kylesayrs Dec 16, 2024
1632cc3
make device used by clearer
kylesayrs Dec 16, 2024
1c55a10
update update_prefix_dict
kylesayrs Dec 17, 2024
9177650
reuse fixture
kylesayrs Dec 17, 2024
64 changes: 19 additions & 45 deletions src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -29,7 +29,11 @@
from compressed_tensors.quantization.quant_config import QuantizationStatus
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
from compressed_tensors.utils import get_execution_device, is_module_offloaded
from compressed_tensors.utils import (
disable_hf_hook,
has_offloaded_params,
register_offload_parameter,
)
from torch.nn import Module, Parameter


@@ -112,43 +116,10 @@ def initialize_module_for_quantization(
module.quantization_scheme = scheme
module.quantization_status = QuantizationStatus.INITIALIZED

offloaded = False
# What is this doing/why isn't this in the attn case?
if is_module_offloaded(module):
try:
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
from accelerate.utils import PrefixedDataset
except ModuleNotFoundError:
raise ModuleNotFoundError(
"Offloaded model detected. To use CPU offloading with "
"compressed-tensors the `accelerate` package must be installed, "
"run `pip install compressed-tensors[accelerate]`"
)

offloaded = True
hook = module._hf_hook
prefix_dict = module._hf_hook.weights_map
new_prefix = {}

# recreate the prefix dict (since it is immutable)
# and add quantization parameters
for key, data in module.named_parameters():
if key not in prefix_dict:
new_prefix[f"{prefix_dict.prefix}{key}"] = data
else:
new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key]
new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix)
remove_hook_from_module(module)

# wrap forward call of module to perform
# quantized actions based on calltime status
wrap_module_forward_quantized(module, scheme)

if offloaded:
# we need to re-add the hook for offloading now that we've wrapped forward
add_hook_to_module(module, hook)
if prefix_dict is not None:
module._hf_hook.weights_map = new_prefix_dict
with disable_hf_hook(module):
# wrap forward call of module to perform
# quantized actions based on calltime status
wrap_module_forward_quantized(module, scheme)


def is_attention_module(module: Module):
@@ -169,12 +140,15 @@ def _initialize_scale_zero_point(
if quantization_args.dynamic:
return

device = next(module.parameters()).device
if is_module_offloaded(module):
device = get_execution_device(module)
# begin on the same device as other parameters or cpu if offloaded
params_device = next(module.parameters()).device
device = "cpu" if has_offloaded_params(module) else params_device

# infer expected scale/zero point shape
expected_shape = 1 # per tensor
if quantization_args.strategy == QuantizationStrategy.TOKEN:
expected_shape = (1, 1)
else:
expected_shape = 1

if base_name == "weight" and weight_shape is not None:
if quantization_args.strategy == QuantizationStrategy.CHANNEL:
@@ -193,15 +167,15 @@
torch.empty(expected_shape, dtype=scale_dtype, device=device),
requires_grad=False,
)
module.register_parameter(f"{base_name}_scale", init_scale)
register_offload_parameter(module, f"{base_name}_scale", init_scale)

if force_zero_point or not quantization_args.symmetric:
zp_dtype = quantization_args.pytorch_dtype()
init_zero_point = Parameter(
torch.zeros(expected_shape, device=device, dtype=zp_dtype),
requires_grad=False,
)
module.register_parameter(f"{base_name}_zero_point", init_zero_point)
register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point)

# only grouped activation ordering has g_idx
if quantization_args.actorder == ActivationOrdering.GROUP:
@@ -211,7 +185,7 @@
torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype),
requires_grad=False,
)
module.register_parameter(f"{base_name}_g_idx", init_g_idx)
register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)


def _initialize_attn_scales(module: Module) -> None:
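For readers new to these helpers, a minimal sketch of the pattern the diff above adopts; the Linear module and parameter name are illustrative, while disable_hf_hook, has_offloaded_params, and register_offload_parameter are the utilities imported in this file:

import torch
from torch.nn import Linear, Parameter
from compressed_tensors.utils import (
    disable_hf_hook,
    has_offloaded_params,
    register_offload_parameter,
)

# illustrative layer; it may or may not carry an accelerate offload hook
module = Linear(64, 64)

# new parameters start on cpu when the module's weights are offloaded,
# otherwise on the same device as the existing parameters
device = "cpu" if has_offloaded_params(module) else next(module.parameters()).device
scale = Parameter(torch.empty(1, device=device), requires_grad=False)

# registers the parameter on the module and, when an offload hook is present,
# also records it in the hook's weights map
register_offload_parameter(module, "weight_scale", scale)

with disable_hf_hook(module):
    # the hook is detached inside this block and re-attached on exit, so the
    # forward method can be wrapped safely (as wrap_module_forward_quantized is above)
    pass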
65 changes: 64 additions & 1 deletion src/compressed_tensors/utils/helpers.py
@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional
import warnings
from functools import wraps
from typing import Any, Callable, Dict, Optional

import torch
from transformers import AutoConfig
@@ -24,6 +26,8 @@
"tensor_follows_mask_structure",
"replace_module",
"is_compressed_tensors_config",
"getattr_chain",
"deprecated",
"Aliasable",
]

@@ -122,6 +126,65 @@ def is_compressed_tensors_config(compression_config: Any) -> bool:
return False


def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any:
"""
Chain multiple getattr calls, separated by `.`

:param obj: base object whose attributes are being retrieved
:param chain_str: attribute names separated by `.`
:param default: optional default value returned when an attribute in the chain is missing; if no default is given, an AttributeError is raised
"""
if len(args) >= 1:
has_default = True
default = args[0]
elif "default" in kwargs:
has_default = True
default = kwargs["default"]
else:
has_default = False

attr_names = chain_str.split(".")

res = obj
for attr_name in attr_names:
if not hasattr(res, attr_name):
if has_default:
return default
else:
raise AttributeError(f"{res} object has no attribute {attr_name}")
res = getattr(res, attr_name)

return res
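
# Usage sketch for getattr_chain (not part of the diff); the nested object
# below is hypothetical and stands in for, e.g., a model config.
from types import SimpleNamespace

model = SimpleNamespace(
    config=SimpleNamespace(quantization_config=SimpleNamespace(format="pack-quantized"))
)
getattr_chain(model, "config.quantization_config.format")            # "pack-quantized"
getattr_chain(model, "config.sparsity_config.format", default=None)  # None, missing attribute falls back to the default
# getattr_chain(model, "config.sparsity_config.format")              # would raise AttributeError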


def deprecated(future_name: Optional[str] = None, message: Optional[str] = None):
"""
Decorator to mark functions as deprecated

:param future_name: name of the function to use in place of the deprecated one
:param message: deprecation message; replaces the default deprecation message
"""

def decorator(func: Callable[[Any], Any]):
nonlocal message

if message is None:
message = (
f"{func.__name__} is deprecated and will be removed in a future release"
)
if future_name is not None:
message += f". Please use {future_name} instead."

@wraps(func)
def wrapped(*args, **kwargs):
warnings.warn(message, DeprecationWarning, stacklevel=2)
return func(*args, **kwargs)

return wrapped

return decorator
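
# Usage sketch for the decorator above (not part of the diff); old_helper and
# new_helper are hypothetical names used only for illustration.
@deprecated(future_name="new_helper")
def old_helper(x: int) -> int:
    return x * 2

old_helper(3)  # warns "old_helper is deprecated ... Please use new_helper instead." and returns 6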


class Aliasable:
"""
A mixin for enums to allow aliasing of enum members