-
Notifications
You must be signed in to change notification settings - Fork 346
Add parq utility to create an optimizer #3165
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+145
−0
Merged
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
from dataclasses import dataclass | ||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type | ||
|
||
import torch | ||
|
||
from torchao.prototype.parq.optim import QuantOptimizer | ||
from torchao.prototype.parq.quant import ( | ||
Quantizer, | ||
StretchedUnifTorchaoQuantizer, | ||
UnifTorchaoQuantizer, | ||
) | ||
|
||
|
||
@dataclass(frozen=True, slots=True) | ||
class QuantConfig: | ||
bitwidth: int | ||
group_size: Optional[int] = None | ||
quantizer: Optional[Quantizer] = None | ||
|
||
def __post_init__(self): | ||
if self.bitwidth < 2: | ||
raise ValueError("bitwidth must be >= 2") | ||
if self.group_size is not None and self.group_size <= 0: | ||
raise ValueError("group_size must be positive") | ||
|
||
if self.quantizer is None: | ||
if self.bitwidth in [2, 3]: | ||
q = StretchedUnifTorchaoQuantizer(b=self.bitwidth) | ||
else: | ||
q = UnifTorchaoQuantizer() | ||
object.__setattr__(self, "quantizer", q) | ||
|
||
|
||
def create_param_groups_and_group_quantizer_map( | ||
model: torch.nn.Module, | ||
quant_configs_and_filter_fns: List[ | ||
Tuple[QuantConfig, Callable[[torch.nn.Module, str], bool]] | ||
], | ||
): | ||
param_groups = [] | ||
group_quantizer_map = {} | ||
for idx, (config, _) in enumerate(quant_configs_and_filter_fns): | ||
params_quant = [] | ||
param_group = { | ||
"params": params_quant, | ||
"quant_bits": config.bitwidth, | ||
} | ||
if config.group_size is not None: | ||
param_group["quant_block_size"] = config.group_size | ||
param_group["_quantizer"] = config.quantizer | ||
param_groups.append(param_group) | ||
|
||
# Non-quantized group at end so that index in param_groups | ||
# is the index in the subset of quantized param groups, which is | ||
# used in defining group_quantizer_map | ||
params_no_quant = [] | ||
param_groups.append({"params": params_no_quant, "weight_decay": 0.0}) | ||
|
||
seen_data_ptrs = {} | ||
for param_name, param in model.named_parameters(): | ||
module_name, _, param_basename = param_name.rpartition(".") | ||
owning_module = model.get_submodule(module_name) if module_name else model | ||
|
||
data_ptr = param.data_ptr() | ||
if data_ptr in seen_data_ptrs: | ||
print( | ||
f"Not considering {param} because it shares a data_ptr with {seen_data_ptrs[data_ptr]}, which was previously considered" | ||
) | ||
continue | ||
seen_data_ptrs[data_ptr] = param_name | ||
|
||
print( | ||
"param_name", | ||
param_name, | ||
"module_type", | ||
type(owning_module), | ||
"matching_config:", | ||
end="", | ||
) | ||
matching_config = None | ||
for idx, (config, filter_fn) in enumerate(quant_configs_and_filter_fns): | ||
if filter_fn(owning_module, param_name): | ||
param_groups[idx]["params"].append(param) | ||
if matching_config is None: | ||
matching_config = config | ||
print(f"{config.bitwidth},{config.group_size}") | ||
else: | ||
raise ValueError( | ||
f"Found multiple matching configs for {param_name}. Previous match={matching_config}, new match={config}." | ||
) | ||
|
||
# If no match, add to no-quant group at last idx | ||
if matching_config is None: | ||
print("NONE") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this for debugging? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After each parameter, it prints out config (bitwidth,groupsize or NONE) |
||
param_groups[-1]["params"].append(param) | ||
|
||
# Filter out empty param groups | ||
param_groups = [pg for pg in param_groups if len(pg["params"]) > 0] | ||
|
||
# After filter define group_quantizer_map | ||
# The index in group_quantizer_map must correspond to index in | ||
# quantized params | ||
group_quantizer_map = {} | ||
for idx, param_group in enumerate(param_groups): | ||
if "_quantizer" in param_group: | ||
group_quantizer_map[idx] = param_group.pop("_quantizer") | ||
|
||
expected_n_params = sum(1 for p in model.parameters()) | ||
n_found_params = sum(len(pg["params"]) for pg in param_groups) | ||
assert n_found_params == expected_n_params, ( | ||
f"{n_found_params} != {expected_n_params=}" | ||
) | ||
|
||
return param_groups, group_quantizer_map | ||
|
||
|
||
from torchao.prototype.parq import ProxHardQuant | ||
|
||
|
||
def create_optimizer( | ||
model: torch.nn.Module, | ||
quant_configs_and_filter_fns: List[ | ||
Tuple[QuantConfig, Callable[[torch.nn.Module, str], bool]] | ||
], | ||
base_optimizer_cls: Type[torch.optim.Optimizer], | ||
base_optimizer_kwargs: Dict[str, Any], | ||
*, | ||
warmup_steps: int = 0, | ||
quant_period: int = 1, | ||
quant_per_channel: bool = True, | ||
): | ||
param_groups, group_quantizer_map = create_param_groups_and_group_quantizer_map( | ||
model, quant_configs_and_filter_fns | ||
) | ||
base_optimizer = base_optimizer_cls(param_groups, **base_optimizer_kwargs) | ||
optimizer = QuantOptimizer( | ||
base_optimizer, | ||
quantizer=UnifTorchaoQuantizer(), | ||
prox_map=ProxHardQuant(), | ||
warmup_steps=warmup_steps, | ||
quant_period=quant_period, | ||
quant_per_channel=quant_per_channel, | ||
group_quantizer_map=group_quantizer_map, | ||
) | ||
return optimizer |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Standardizing the indices this way is a great idea