additional fixes for HFQuantizer compatibility #136

Closed · wants to merge 3 commits

Changes from 1 commit
4 changes: 4 additions & 0 deletions src/compressed_tensors/compressors/model_compressor.py
@@ -188,6 +188,10 @@ def parse_quantization_config(compression_config: Dict) -> Union[Dict, None]:
         if hasattr(compression_config, QUANTIZATION_CONFIG_NAME):
             # for loaded HFQuantizer config
             return getattr(compression_config, QUANTIZATION_CONFIG_NAME)
+        elif isinstance(compression_config, dict) and (
+            QUANTIZATION_CONFIG_NAME in compression_config
+        ):
+            return compression_config[QUANTIZATION_CONFIG_NAME]
 
         # SparseAutoModel format
         quantization_config = deepcopy(compression_config)

[Contributor review comment on the added elif branch: already merged]
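
For context, a minimal sketch of the two config shapes this branch distinguishes. Everything below is hypothetical illustration: it assumes QUANTIZATION_CONFIG_NAME == "quantization_config" and uses SimpleNamespace as a stand-in for a loaded HFQuantizer config object:

    # hypothetical values; assumes QUANTIZATION_CONFIG_NAME == "quantization_config"
    from types import SimpleNamespace

    QUANTIZATION_CONFIG_NAME = "quantization_config"

    # HFQuantizer-style config object: handled by the hasattr/getattr branch
    obj_config = SimpleNamespace(quantization_config={"format": "int-quantized"})
    assert getattr(obj_config, QUANTIZATION_CONFIG_NAME) == {"format": "int-quantized"}

    # plain dict, e.g. parsed straight from config.json: the new elif branch
    dict_config = {"quantization_config": {"format": "int-quantized"}}
    assert dict_config[QUANTIZATION_CONFIG_NAME] == {"format": "int-quantized"}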
6 changes: 3 additions & 3 deletions src/compressed_tensors/quantization/observers/helpers.py
@@ -38,9 +38,9 @@ def get_observer_token_count(module: torch.nn.Module) -> Counter:
     token_counts = Counter()
     for name, module in module.named_modules():
         if name.endswith(".input_observer"):
-            token_counts[name.replace(".input_observer", "")] = (
-                module._num_observed_tokens
-            )
+            token_counts[
+                name.replace(".input_observer", "")
+            ] = module._num_observed_tokens
     return token_counts
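
The reformatting above does not change behavior. As context, a self-contained sketch of what the helper collects, using toy modules (_num_observed_tokens is the running count each input observer keeps):

    import torch
    from collections import Counter

    class ToyObserver(torch.nn.Module):
        # stand-in for an input observer that has seen some tokens
        def __init__(self, num_tokens: int):
            super().__init__()
            self._num_observed_tokens = num_tokens

    root = torch.nn.Module()
    root.add_module("layer0", torch.nn.Module())
    root.layer0.add_module("input_observer", ToyObserver(512))

    # mirrors get_observer_token_count: Counter keyed by the observed layer
    counts = Counter()
    for name, sub in root.named_modules():
        if name.endswith(".input_observer"):
            counts[name.replace(".input_observer", "")] = sub._num_observed_tokens

    assert counts == Counter({"layer0": 512})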
16 changes: 14 additions & 2 deletions src/compressed_tensors/utils/offload.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import torch
-from torch.nn import Module
+from torch.nn import Module, Parameter
 
 
 __all__ = [
@@ -100,7 +100,19 @@ def update_parameter_data(
 
     parameter = getattr(module, param_name, None)
     dtype = parameter.dtype
-    parameter.data = new_param_data.to(device).to(dtype)
+    try:
+        parameter.data = new_param_data.to(device).to(dtype)
+    except RuntimeError:
+        # exception may occur when trying to overwrite meta device, overriding
+        # parameter directly
+        setattr(
+            module,
+            param_name,
+            Parameter(
+                data=new_param_data.to(device).to(dtype),
+                requires_grad=parameter.requires_grad,
+            ),
+        )
 
     if offloaded:
         prefix_dict = module._hf_hook.weights_map.dataset

[Contributor review comment on the try/except fallback: This is better handled by #193]
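
A minimal sketch of the fallback pattern this hunk introduces, with a toy Linear standing in for an offloaded module (safe_update is a hypothetical helper; exactly which operation raises in a real accelerate/meta setup is not reproduced here):

    import torch
    from torch.nn import Linear, Parameter

    def safe_update(module: torch.nn.Module, param_name: str, new_data: torch.Tensor):
        parameter = getattr(module, param_name)
        try:
            # in-place update; may raise RuntimeError for meta-backed parameters
            parameter.data = new_data.to(parameter.device).to(parameter.dtype)
        except RuntimeError:
            # meta tensors have no storage to overwrite, so rebind a fresh
            # Parameter on the module, preserving requires_grad
            setattr(
                module,
                param_name,
                Parameter(
                    new_data.to(parameter.dtype),
                    requires_grad=parameter.requires_grad,
                ),
            )

    layer = Linear(4, 4, bias=False)
    safe_update(layer, "weight", torch.zeros(4, 4))
    assert bool(torch.all(layer.weight == 0))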