From 55aec3f8a07714adab20b00f59029f328e6ee94e Mon Sep 17 00:00:00 2001
From: George Ohashi
Date: Thu, 18 Apr 2024 19:31:21 +0000
Subject: [PATCH 1/3] minmax channel wise

---
 .../quantization/observers/min_max.py | 67 +++++++++++++++----
 1 file changed, 54 insertions(+), 13 deletions(-)

diff --git a/src/compressed_tensors/quantization/observers/min_max.py b/src/compressed_tensors/quantization/observers/min_max.py
index 808f24c3..1816a059 100644
--- a/src/compressed_tensors/quantization/observers/min_max.py
+++ b/src/compressed_tensors/quantization/observers/min_max.py
@@ -45,21 +45,62 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
         """
         # TODO: Add support for full range of quantization Args, only supports 8bit
         # per tensor
-        min_val = torch.tensor([observed.min()])
-        max_val = torch.tensor([observed.max()])

-        # update running average
-        if self.counter > 0:
-            self.min_val = (self.min_val * self.counter + min_val) / (self.counter + 1)
-            self.max_val = (self.max_val * self.counter + max_val) / (self.counter + 1)
+        # channel wise quantization
+        if self.quantization_args.group_size == -1 and observed.dim() > 2:
+            scale_zero_points = []
+            for chan_i in range(observed.shape[0]):
+
+                min_val = torch.tensor([observed[chan_i:].min()])
+                max_val = torch.tensor([observed[chan_i:].max()])
+
+                # update running average
+                if self.counter > 0:
+                    self.min_val = (self.min_val * self.counter + min_val) / (
+                        self.counter + 1
+                    )
+                    self.max_val = (self.max_val * self.counter + max_val) / (
+                        self.counter + 1
+                    )
+                else:
+                    self.min_val = min_val
+                    self.max_val = max_val
+
+                # ensure that the zeros are in the range
+                min_val = torch.min(self.min_val, torch.zeros_like(self.min_val))
+                max_val = torch.max(self.max_val, torch.zeros_like(self.max_val))
+
+                scale_zero_points.append(
+                    calculate_qparams(min_val, max_val, self.quantization_args)
+                )
+
+            self.counter += 1
+
         else:
-            self.min_val = min_val
-            self.max_val = max_val
+            # regular quantization
+            # TODO: group size quantization
+
+            min_val = torch.tensor([observed.min()])
+            max_val = torch.tensor([observed.max()])
+
+            # update running average
+            if self.counter > 0:
+                self.min_val = (self.min_val * self.counter + min_val) / (
+                    self.counter + 1
+                )
+                self.max_val = (self.max_val * self.counter + max_val) / (
+                    self.counter + 1
+                )
+            else:
+                self.min_val = min_val
+                self.max_val = max_val
+
+            self.counter += 1

-        # ensure that the zeros are in the range
-        min_val = torch.min(self.min_val, torch.zeros_like(self.min_val))
-        max_val = torch.max(self.max_val, torch.zeros_like(self.max_val))
+            # ensure that the zeros are in the range
+            min_val = torch.min(self.min_val, torch.zeros_like(self.min_val))
+            max_val = torch.max(self.max_val, torch.zeros_like(self.max_val))

-        self.counter += 1
+            return calculate_qparams(min_val, max_val, self.quantization_args)

-        return calculate_qparams(min_val, max_val, self.quantization_args)
+        return scale_zero_points
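A quick standalone sketch of the per-channel pass this patch introduces (tensor sizes are made up, not from the patch): indexing with observed[chan_i] selects a single channel, while the slice observed[chan_i:] spans channels chan_i through the end, so the loop as written aggregates over a shrinking tail of channels rather than one channel at a time; the follow-up commit replaces the loop with a vectorized reduction.

import torch

# hypothetical activation tensor: (channels, rows, cols)
observed = torch.randn(4, 8, 8)

# one minimum per channel
per_channel_min = [observed[i].min() for i in range(observed.shape[0])]
# minimum over channels i..end -- what observed[chan_i:].min() computes
tail_min = [observed[i:].min() for i in range(observed.shape[0])]

print(per_channel_min)  # four distinct per-channel minima
print(tail_min)         # tail_min[0] equals observed.min(); later entries cover fewer channels
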
From 13828c01f687cc9206fc13d40c76160674f09671 Mon Sep 17 00:00:00 2001
From: George Ohashi
Date: Fri, 19 Apr 2024 13:25:21 +0000
Subject: [PATCH 2/3] comments

---
 .../quantization/observers/min_max.py | 72 ++++++++-----------
 1 file changed, 29 insertions(+), 43 deletions(-)

diff --git a/src/compressed_tensors/quantization/observers/min_max.py b/src/compressed_tensors/quantization/observers/min_max.py
index 1816a059..8a9ddfe9 100644
--- a/src/compressed_tensors/quantization/observers/min_max.py
+++ b/src/compressed_tensors/quantization/observers/min_max.py
@@ -46,61 +46,47 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
         # TODO: Add support for full range of quantization Args, only supports 8bit
         # per tensor

-        # channel wise quantization
+        # channel wise quantization -- group_size == -1
         if self.quantization_args.group_size == -1 and observed.dim() > 2:
-            scale_zero_points = []
-            for chan_i in range(observed.shape[0]):
-
-                min_val = torch.tensor([observed[chan_i:].min()])
-                max_val = torch.tensor([observed[chan_i:].max()])
-
-                # update running average
-                if self.counter > 0:
-                    self.min_val = (self.min_val * self.counter + min_val) / (
-                        self.counter + 1
-                    )
-                    self.max_val = (self.max_val * self.counter + max_val) / (
-                        self.counter + 1
-                    )
-                else:
-                    self.min_val = min_val
-                    self.max_val = max_val
-
-                # ensure that the zeros are in the range
-                min_val = torch.min(self.min_val, torch.zeros_like(self.min_val))
-                max_val = torch.max(self.max_val, torch.zeros_like(self.max_val))
-
-                scale_zero_points.append(
-                    calculate_qparams(min_val, max_val, self.quantization_args)
-                )

-            self.counter += 1
-
-        else:
-            # regular quantization
-            # TODO: group size quantization
-
-            min_val = torch.tensor([observed.min()])
-            max_val = torch.tensor([observed.max()])
+            reduce_dims = [1, 2]  # 0th dim for channel, 1st, 2nd contain data
+            min_vals = observed.amin(dim=reduce_dims, keepdim=True)
+            max_vals = observed.amax(dim=reduce_dims, keepdim=True)

             # update running average
             if self.counter > 0:
-                self.min_val = (self.min_val * self.counter + min_val) / (
+                self.min_vals = (self.min_vals * self.counter + min_vals) / (
                     self.counter + 1
                 )
-                self.max_val = (self.max_val * self.counter + max_val) / (
+                self.max_vals = (self.max_vals * self.counter + max_vals) / (
                     self.counter + 1
                 )
             else:
-                self.min_val = min_val
-                self.max_val = max_val
+                self.min_vals = min_vals
+                self.max_vals = max_vals

             self.counter += 1

-            # ensure that the zeros are in the range
-            min_val = torch.min(self.min_val, torch.zeros_like(self.min_val))
-            max_val = torch.max(self.max_val, torch.zeros_like(self.max_val))
+            return calculate_qparams(min_vals, max_vals, self.quantization_args)
+
+        # regular quantization
+        # TODO: group size quantization
+
+        min_val = torch.tensor([observed.min()])
+        max_val = torch.tensor([observed.max()])
+
+        # update running average
+        if self.counter > 0:
+            self.min_val = (self.min_val * self.counter + min_val) / (self.counter + 1)
+            self.max_val = (self.max_val * self.counter + max_val) / (self.counter + 1)
+        else:
+            self.min_val = min_val
+            self.max_val = max_val
+
+        self.counter += 1

-            return calculate_qparams(min_val, max_val, self.quantization_args)
+        # ensure that the zeros are in the range
+        min_val = torch.min(self.min_val, torch.zeros_like(self.min_val))
+        max_val = torch.max(self.max_val, torch.zeros_like(self.max_val))

-        return scale_zero_points
+        return calculate_qparams(min_val, max_val, self.quantization_args)
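A minimal sketch of what the vectorized path above produces, with a generic asymmetric int8 scale/zero-point computation standing in for this module's calculate_qparams helper (the helper's exact clamping, rounding, and dtype handling are not shown; tensor sizes are made up):

import torch

observed = torch.randn(4, 8, 8)  # (channels, rows, cols)

reduce_dims = [1, 2]  # keep dim 0 (channel), reduce the data dims
min_vals = observed.amin(dim=reduce_dims, keepdim=True)  # shape (4, 1, 1)
max_vals = observed.amax(dim=reduce_dims, keepdim=True)  # shape (4, 1, 1)

# keep zero inside each range so it stays exactly representable,
# as the per-tensor path does before computing qparams
min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
max_vals = torch.max(max_vals, torch.zeros_like(max_vals))

# generic asymmetric int8 qparams, one (scale, zero_point) pair per channel
qmin, qmax = -128, 127
scales = (max_vals - min_vals) / (qmax - qmin)
scales = torch.clamp(scales, min=1e-8)  # avoid division by zero for flat channels
zero_points = torch.round(qmin - min_vals / scales).clamp(qmin, qmax).to(torch.int8)
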
From f6769c3081752696c39dc17ef82ba262d6f682cc Mon Sep 17 00:00:00 2001
From: George Ohashi
Date: Fri, 19 Apr 2024 14:37:48 +0000
Subject: [PATCH 3/3] correct the reducer

---
 src/compressed_tensors/quantization/observers/min_max.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/compressed_tensors/quantization/observers/min_max.py b/src/compressed_tensors/quantization/observers/min_max.py
index 8a9ddfe9..163392f9 100644
--- a/src/compressed_tensors/quantization/observers/min_max.py
+++ b/src/compressed_tensors/quantization/observers/min_max.py
@@ -47,9 +47,10 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
         # per tensor

         # channel wise quantization -- group_size == -1
-        if self.quantization_args.group_size == -1 and observed.dim() > 2:
+        if self.quantization_args.group_size == -1:
+
+            reduce_dims = [1]  # everything thats not zero

-            reduce_dims = [1, 2]  # 0th dim for channel, 1st, 2nd contain data
             min_vals = observed.amin(dim=reduce_dims, keepdim=True)
             max_vals = observed.amax(dim=reduce_dims, keepdim=True)
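With the corrected reducer, the channel-wise path reduces over dim 1 only, i.e. it assumes dim 0 is the channel axis and dim 1 holds the data, as in a 2-D weight layout. A small illustration of the resulting shapes (sizes are made up, and the module's full observer flow is not reproduced here):

import torch

weight = torch.randn(16, 64)  # (output channels, input features)

# channel wise (group_size == -1): one min/max per channel, shape (16, 1)
min_vals = weight.amin(dim=[1], keepdim=True)
max_vals = weight.amax(dim=[1], keepdim=True)

# per tensor: a single min/max for the whole tensor, shape (1,)
min_val = torch.tensor([weight.min()])
max_val = torch.tensor([weight.max()])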