Fixes to enable FSDP one-shot #58

Merged 12 commits on Jun 3, 2024
2 changes: 2 additions & 0 deletions src/compressed_tensors/compressors/model_compressor.py
@@ -36,6 +36,7 @@
iter_named_leaf_modules,
)
from compressed_tensors.utils import get_safetensors_folder
from compressed_tensors.utils.helpers import fix_fsdp_module_name
from torch import Tensor
from torch.nn import Module, Parameter
from tqdm import tqdm
@@ -260,6 +261,7 @@ def _get_weight_arg_mappings(model: Module) -> Dict:
for name, submodule in iter_named_leaf_modules(model):
if is_module_quantized(submodule):
if submodule.quantization_scheme.weights is not None:
name = fix_fsdp_module_name(name)
quantized_modules_to_args[name] = submodule.quantization_scheme.weights

return quantized_modules_to_args
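
Without the rename, the mapping built above is keyed by the FSDP-wrapped module names, so a later lookup by the plain module name misses. A minimal illustration of the mismatch, using hypothetical layer names rather than real model output:

from compressed_tensors.utils.helpers import fix_fsdp_module_name

# name as reported when iterating the leaves of an FSDP-wrapped model
wrapped_name = "_fsdp_wrapped_module.model.layers.0.self_attn.q_proj"
# name used later when matching against the saved state dict
clean_name = "model.layers.0.self_attn.q_proj"

quantized_modules_to_args = {wrapped_name: "weight-args-placeholder"}
print(clean_name in quantized_modules_to_args)           # False: the lookup misses
print(fix_fsdp_module_name(wrapped_name) == clean_name)  # True once the prefix is stripped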
17 changes: 16 additions & 1 deletion src/compressed_tensors/quantization/lifecycle/apply.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import re
from collections import OrderedDict
from typing import Dict, Iterable, Optional
@@ -35,6 +36,7 @@
infer_quantization_status,
iter_named_leaf_modules,
)
from compressed_tensors.utils.helpers import fix_fsdp_module_name
from compressed_tensors.utils.safetensors_load import get_safetensors_folder
from torch.nn import Module

@@ -50,6 +52,9 @@
from compressed_tensors.utils.safetensors_load import get_quantization_state_dict


_LOGGER = logging.getLogger(__name__)


def load_pretrained_quantization(model: Module, model_name_or_path: str):
"""
Loads the quantization parameters (scale and zero point) from model_name_or_path to
@@ -105,15 +110,24 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):
for target in scheme.targets:
target_to_scheme[target] = scheme

# list of submodules to ignore
ignored_submodules = []
# mark appropriate layers for quantization by setting their quantization schemes
for name, submodule in iter_named_leaf_modules(model):
# potentially fix module name to remove FSDP wrapper prefix
name = fix_fsdp_module_name(name)
if find_first_name_or_class_match(name, submodule, config.ignore):
ignored_submodules.append(name)
continue # layer matches ignore list, continue
target = find_first_name_or_class_match(name, submodule, target_to_scheme)
if target is not None:
# target matched - add layer and scheme to target list
submodule.quantization_scheme = target_to_scheme[target]

if set(config.ignore) - set(ignored_submodules):
_LOGGER.warning(
"Some layers that were to be ignored were "
f"not found in the model: {set(config.ignore) - set(ignored_submodules)}"
)
# apply current quantization status across all targeted layers
apply_quantization_status(model, config.quantization_status)
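
A standalone sketch of the set-difference warning added above, with hypothetical layer names; it is not the library's API, only an illustration of how unmatched ignore entries surface:

import logging

logging.basicConfig(level=logging.WARNING)
_LOGGER = logging.getLogger(__name__)

# hypothetical ignore list from a quantization config
config_ignore = ["lm_head", "model.layers.0.mlp.gate"]
# hypothetical submodule names that actually matched while walking the model
ignored_submodules = ["lm_head"]

unmatched = set(config_ignore) - set(ignored_submodules)
if unmatched:
    _LOGGER.warning(
        "Some layers that were to be ignored were "
        f"not found in the model: {unmatched}"
    )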

@@ -157,6 +171,7 @@ def _find_first_match(
# returns first element of target that matches value either
# exactly or as a regex after 're:'. if check_contains is set to True,
# additionally checks if the target string is contained within value.

for target in targets:
if target.startswith("re:"):
pattern = target[3:]
8 changes: 8 additions & 0 deletions src/compressed_tensors/quantization/lifecycle/forward.py
@@ -57,6 +57,14 @@ def quantize(
:param dtype: optional dtype to cast the quantized output to
:return: fake quantized tensor
"""
# ensure all tensors are on the same device
# assumes that the target device is the input
# tensor's device
if x.device != scale.device:
scale = scale.to(x.device)
if x.device != zero_point.device:
zero_point = zero_point.to(x.device)

return _process_quantization(
x=x,
scale=scale,
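The motivation for the device check: under FSDP, activations can arrive on a CUDA device while scale and zero_point still live on CPU, and the elementwise quantization math would then fail with a device mismatch. A minimal sketch of the same alignment pattern (the helper name is illustrative, not part of the library):

import torch

def align_quant_params(x: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor):
    """Move scale and zero_point onto the input tensor's device before quantizing."""
    if x.device != scale.device:
        scale = scale.to(x.device)
    if x.device != zero_point.device:
        zero_point = zero_point.to(x.device)
    return scale, zero_point

# illustrative usage: activations on GPU, quantization parameters defaulting to CPU
if torch.cuda.is_available():
    x = torch.randn(4, 8, device="cuda")
    scale, zero_point = torch.tensor(0.1), torch.tensor(0)
    scale, zero_point = align_quant_params(x, scale, zero_point)
    assert scale.device == x.device and zero_point.device == x.device
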
27 changes: 23 additions & 4 deletions src/compressed_tensors/utils/helpers.py
@@ -16,24 +16,27 @@
from typing import Optional

from compressed_tensors.base import SPARSITY_CONFIG_NAME
from compressed_tensors.compressors import ModelCompressor
from compressed_tensors.config import CompressionConfig
from transformers import AutoConfig


__all__ = ["infer_compressor_from_model_config"]
__all__ = ["infer_compressor_from_model_config", "fix_fsdp_module_name"]

FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"


def infer_compressor_from_model_config(
pretrained_model_name_or_path: str,
) -> Optional[ModelCompressor]:
) -> Optional["ModelCompressor"]: # noqa: F821
"""
Given a path to a model config, extract a sparsity config if it exists and return
the associated ModelCompressor

:param pretrained_model_name_or_path: path to model config on disk or HF hub
:return: matching compressor if config contains a sparsity config
"""
from compressed_tensors.compressors import ModelCompressor
from compressed_tensors.config import CompressionConfig

config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
sparsity_config = getattr(config, SPARSITY_CONFIG_NAME, None)
if sparsity_config is None:
@@ -43,3 +46,19 @@ def infer_compressor_from_model_config(
sparsity_config = CompressionConfig.load_from_registry(format, **sparsity_config)
compressor = ModelCompressor.load_from_registry(format, config=sparsity_config)
return compressor


# TODO: There is already the same function in
# SparseML, should be moved to a shared location
# in the future
def fix_fsdp_module_name(name: str) -> str:
"""
Remove FSDP wrapper prefixes from a module name
Accounts for scenario where FSDP_WRAPPER_NAME is
at the end of the name, as well as in the middle.
:param name: name to strip
:return: stripped name
"""
return name.replace(FSDP_WRAPPER_NAME + ".", "").replace(
"." + FSDP_WRAPPER_NAME, ""
)
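
As a quick check, the helper strips the wrapper segment whether it appears as a leading prefix or between submodule names (example names are hypothetical):

from compressed_tensors.utils.helpers import fix_fsdp_module_name

# wrapper prefix at the start of the name
print(fix_fsdp_module_name("_fsdp_wrapped_module.model.layers.0.self_attn.q_proj"))
# -> model.layers.0.self_attn.q_proj

# wrapper segment in the middle of the name
print(fix_fsdp_module_name("model._fsdp_wrapped_module.layers.0.self_attn.q_proj"))
# -> model.layers.0.self_attn.q_proj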