From 39c5a5c4c83748ad5070572697fde5b650ec832c Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Thu, 9 Nov 2023 17:43:13 +0000 Subject: [PATCH 1/2] quantize model to 8-bit during ckpt reshard --- llama/quant_util.py | 71 +++++++++++++++++++++++++++++++++++++ llama/xla_model_parallel.py | 43 ++++++++++++++++++---- reshard_checkpoints.py | 10 ++++-- 3 files changed, 116 insertions(+), 8 deletions(-) create mode 100644 llama/quant_util.py diff --git a/llama/quant_util.py b/llama/quant_util.py new file mode 100644 index 000000000..66c2e2b1a --- /dev/null +++ b/llama/quant_util.py @@ -0,0 +1,71 @@ +from copy import deepcopy +from dataclasses import dataclass + +import torch +import torch.ao.quantization.fx._decomposed +from typing import Optional + +EPS = torch.finfo(torch.float32).eps + +@dataclass +class TensorQConfig: + dtype: torch.dtype = torch.int8 + axis: int = -1 + quant_min: int = -128 + quant_max: int = 127 + symmetric_quant: bool = True + + +def _get_dtype_min_max(dtype: torch.dtype): + if dtype == torch.int8: + return -128, 127 + elif dtype == torch.uint8: + return 0, 127 + else: + assert False + +def _find_per_channel_min_max(x: torch.Tensor, axis: int): + x_dim = x.size() + new_axis_list = [i for i in range(len(x_dim))] + new_axis_list[axis] = 0 + new_axis_list[0] = axis + y = x.permute(new_axis_list) + y = torch.flatten(y, start_dim=1) + return torch.aminmax(y, dim=1) + +def _find_qparams(x: torch.Tensor, qconfig : TensorQConfig): + # Only support per-channel symmetric quant to int8 now + axis = qconfig.axis + dtype = qconfig.dtype + symmetric_quant = qconfig.symmetric_quant + quant_min = qconfig.quant_min + quant_max = qconfig.quant_max + assert axis >= 0 and axis < len(x.shape) + assert dtype == torch.int8 + min_val, max_val = _find_per_channel_min_max(x, axis) + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + scale = torch.ones(min_val_neg.size(), dtype=torch.float32) + if symmetric_quant: + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scale = max_val_pos / (float(quant_max - quant_min) / 2) + eps = torch.zeros_like(scale).fill_(EPS) + scale = torch.max(scale, eps) + return scale, None + else: + assert symmetric_quant + +def _quantize_to_dtype(x: torch.Tensor, qconfig: TensorQConfig, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor] = None): + if zero_point is None: + zero_point = torch.zeros_like(scale) + return torch.ops.quantized_decomposed.quantize_per_channel( + x, scale, zero_point, qconfig.axis, qconfig.quant_min, + qconfig.quant_max, qconfig.dtype + ) + +def quantize_tensor(x: torch.Tensor, qconfig : TensorQConfig): + scale, zp = _find_qparams(x, qconfig) + x_int = _quantize_to_dtype(x, qconfig, scale, zp) + return x_int, scale, zp diff --git a/llama/xla_model_parallel.py b/llama/xla_model_parallel.py index f0b323976..11d6a8094 100644 --- a/llama/xla_model_parallel.py +++ b/llama/xla_model_parallel.py @@ -1,3 +1,4 @@ +from copy import deepcopy from typing import Callable, Optional, List, Any import torch @@ -10,6 +11,8 @@ from fairscale.nn.model_parallel.utils import divide_and_check_no_remainder, split_tensor_along_last_dim +from .quant_util import TensorQConfig, quantize_tensor + import os USE_CUDA = os.environ.get('USE_CUDA', False) @@ -394,9 +397,8 @@ def __init__( if quant: self.weight = Parameter(torch.empty( (self.output_size_per_partition, self.in_features), - dtype=torch.int8), - requires_grad=False) - self.weight_scaler = Parameter(torch.zeros(1), requires_grad=False) + dtype=torch.int8), requires_grad=False) + self.weight_scaler = Parameter(torch.Tensor(self.output_size_per_partition)) else: self.weight = Parameter( torch.Tensor(self.output_size_per_partition, self.in_features)) @@ -427,6 +429,21 @@ def get_master_weight(self) -> torch.Tensor: self.weight.data.transpose(0, 1), self.groups, self.world_size, self.rank).transpose_(0, 1) + def quantize(self): + assert self.quant == False + fp_w = deepcopy(self.weight.data) + orig_dtype = fp_w.dtype + fp_w = fp_w.to(torch.float32) + self.weight = Parameter( + torch.empty((self.output_size_per_partition, self.in_features), dtype=torch.int8), + requires_grad=False, + ) + self.weight_scaler = Parameter(torch.Tensor(self.output_size_per_partition)) + qconfig = TensorQConfig(axis=0) + self.weight.data, scale, zero_point = quantize_tensor(fp_w, qconfig) + self.weight_scaler.data = scale.to(orig_dtype) + self.quant = True + def forward(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore # Set up backprop all-reduce. input_parallel = copy_to_model_parallel_region(input_, self.groups, @@ -523,9 +540,8 @@ def __init__( if quant: self.weight = Parameter(torch.empty( (self.out_features, self.input_size_per_partition), - dtype=torch.int8), - requires_grad=False) - self.weight_scaler = Parameter(torch.zeros(1), requires_grad=False) + dtype=torch.int8), requires_grad=False) + self.weight_scaler = Parameter(torch.Tensor(self.out_features)) else: self.weight = Parameter( torch.Tensor(self.out_features, self.input_size_per_partition)) @@ -555,6 +571,21 @@ def get_master_weight(self) -> torch.Tensor: return gather_from_model_parallel_region(self.weight.data, self.groups, self.world_size, self.rank) + def quantize(self): + assert self.quant == False + fp_w = deepcopy(self.weight.data) + orig_dtype = fp_w.dtype + fp_w = fp_w.to(torch.float32) + self.weight = Parameter( + torch.empty((self.out_features, self.input_size_per_partition), dtype=torch.int8), + requires_grad=False, + ) + self.weight_scaler = Parameter(torch.Tensor(self.out_features)) + qconfig = TensorQConfig(axis=0) + self.weight.data, scale, zero_point = quantize_tensor(fp_w, qconfig) + self.weight_scaler.data = scale.to(orig_dtype) + self.quant = True + def forward(self, input_: torch.Tensor) -> torch.Tensor: # type:ignore # Set up backprop all-reduce. if self.input_is_parallel: diff --git a/reshard_checkpoints.py b/reshard_checkpoints.py index 86e147586..6c6144274 100644 --- a/reshard_checkpoints.py +++ b/reshard_checkpoints.py @@ -17,7 +17,7 @@ @torch.no_grad() -def reshard(original_mp, target_mp, ckpt_dir, output_dir, tokenizer_path): +def reshard(original_mp, target_mp, ckpt_dir, output_dir, tokenizer_path, quantize): assert target_mp > original_mp > 0 factor = divide_and_check_no_remainder(target_mp, original_mp) @@ -109,6 +109,8 @@ def reshard(original_mp, target_mp, ckpt_dir, output_dir, tokenizer_path): factor)[shard_rank].contiguous() assert weight_shard.size() == module.weight.size() module.weight.copy_(weight_shard) + if quantize: + module.quantize() elif isinstance(module, ColumnParallelLinear): source_module = original_model.get_submodule(name) assert module.bias is None and source_module.bias is None @@ -122,14 +124,18 @@ def reshard(original_mp, target_mp, ckpt_dir, output_dir, tokenizer_path): factor // kv_head_duplicate)[shard_rank // kv_head_duplicate].transpose(0, 1).contiguous() assert weight_shard.size() == module.weight.size() module.weight.copy_(weight_shard) + if quantize: + module.quantize() state_dict = { k: v for k, v in target_model.state_dict().items() - if k in checkpoint.keys() + if k in checkpoint.keys() or "weight_scaler" in k # TODO: "weight_scaler" are new parameters after quant, add to state_dict in a more elegant way. } torch.save(state_dict, Path(output_dir) / f"{target_rank:03}.pth") + if quantize: + new_params['quant'] = True with open(Path(output_dir) / "params.json", "w") as f: json.dump(new_params, f) From 427af61aace7e5410d10ca1a30c51ffbead0bfd5 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Tue, 12 Dec 2023 10:59:41 +0000 Subject: [PATCH 2/2] add default param for quant in reshard script --- reshard_checkpoints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reshard_checkpoints.py b/reshard_checkpoints.py index 6c6144274..f488d79f3 100644 --- a/reshard_checkpoints.py +++ b/reshard_checkpoints.py @@ -17,7 +17,7 @@ @torch.no_grad() -def reshard(original_mp, target_mp, ckpt_dir, output_dir, tokenizer_path, quantize): +def reshard(original_mp, target_mp, ckpt_dir, output_dir, tokenizer_path, quantize=False): assert target_mp > original_mp > 0 factor = divide_and_check_no_remainder(target_mp, original_mp)