From 51f75308298ea208771b615fa395fd2e04e0665b Mon Sep 17 00:00:00 2001
From: hexisyztem
Date: Wed, 8 Jun 2022 19:35:50 +0800
Subject: [PATCH] Fix test unit (#321)

* modify Dockerfile to compile tritonbackend

* fix test_ls_ops

Co-authored-by: zhoubofan
---
 lightseq/training/ops/pytorch/quantization.py |  9 ++-
 .../nn/modules/tensor_quantizer.py            | 18 ++++--
 .../pytorch_quantization/tensor_quant.py      |  4 +-
 tests/test_ls_ops.py                          | 59 +++++--------------
 tests/util.py                                 | 46 +++++++++++----
 5 files changed, 70 insertions(+), 66 deletions(-)

diff --git a/lightseq/training/ops/pytorch/quantization.py b/lightseq/training/ops/pytorch/quantization.py
index fe284671..68d1ca5d 100644
--- a/lightseq/training/ops/pytorch/quantization.py
+++ b/lightseq/training/ops/pytorch/quantization.py
@@ -16,11 +16,14 @@
 act_quant_config = QuantDescriptor(
     num_bits=8, narrow_range=True, learn_amax=True, amax=16.0
 )
+out_quant_config = QuantDescriptor(
+    num_bits=8, narrow_range=True, learn_amax=False, amax=16.0
+)
 
 relu_quant_config = QuantDescriptor(
     num_bits=8, narrow_range=True, learn_amax=True, amax=16.0, unsigned=True
 )
 
 weight_quant_config = QuantDescriptor(
-    num_bits=8, narrow_range=True, learn_amax=True, amax=1.0
+    num_bits=8, narrow_range=True, learn_amax=False, amax=1.0
 )
@@ -36,8 +39,8 @@ def __init__(self, in_features, out_features, pre_activation=None, *args, **kwar
         if pre_activation != "encoder_out":
             self.input_quant = TensorQuantizer(input_quant_config)
         self.output_quant = None
-        if pre_activation is None:
-            self.output_quant = TensorQuantizer(act_quant_config)
+        # if pre_activation is None:
+        self.output_quant = TensorQuantizer(out_quant_config)
         self.weight_quant = TensorQuantizer(weight_quant_config)
 
     def forward(self, input):
diff --git a/lightseq/training/pytorch_quantization/nn/modules/tensor_quantizer.py b/lightseq/training/pytorch_quantization/nn/modules/tensor_quantizer.py
index 8e3124c1..191d16e4 100644
--- a/lightseq/training/pytorch_quantization/nn/modules/tensor_quantizer.py
+++ b/lightseq/training/pytorch_quantization/nn/modules/tensor_quantizer.py
@@ -77,8 +77,8 @@ class TensorQuantizer(nn.Module):
     def __init__(
         self,
         quant_desc=QuantDescriptor(),
-        disabled=False,
-        if_quant=True,
+        disabled=True,
+        if_quant=False,
         if_clip=False,
         if_calib=False,
     ):
@@ -173,7 +173,8 @@ def disable_clip(self):
         """Disable clip stage"""
         self._if_clip = False
         # self.clip.clip_value_min.required_grad = False
-        self.clip.clip_value_max.required_grad = False
+        if hasattr(self.clip, "clip_value_max"):
+            self.clip.clip_value_max.required_grad = False
 
     def enable_clip(self):
         """Enable clip stage"""
@@ -181,7 +182,8 @@ def enable_clip(self):
         if not self._learn_amax:
             raise ValueError("learn_amax is False. Cannot enable clip.")
         # self.clip.clip_value_min.required_grad = True
-        self.clip.clip_value_max.required_grad = True
+        if hasattr(self.clip, "clip_value_max"):
+            self.clip.clip_value_max.required_grad = True
         self._if_clip = True
 
     def disable_calib(self):
@@ -467,6 +469,11 @@ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
 def enable_quant(m):
     if isinstance(m, TensorQuantizer):
         m.enable()
+        m.enable_quant()
+
+    elif isinstance(m, torch.nn.Module):
+        if hasattr(m, "enable_quant"):
+            m.enable_quant()
 
 
 def disable_quant(m):
@@ -475,6 +482,9 @@ def disable_quant(m):
         m.disable_quant()
         m.disable_calib()
         m.disable_clip()
+    elif isinstance(m, torch.nn.Module):
+        if hasattr(m, "disable_quant"):
+            m.disable_quant()
 
 
 def qat_mode(m):
diff --git a/lightseq/training/pytorch_quantization/tensor_quant.py b/lightseq/training/pytorch_quantization/tensor_quant.py
index 8a05c42b..da31c565 100644
--- a/lightseq/training/pytorch_quantization/tensor_quant.py
+++ b/lightseq/training/pytorch_quantization/tensor_quant.py
@@ -335,7 +335,7 @@ def forward(ctx, inputs, amax, num_bits=8, unsigned=False, narrow_range=True):
         # ctx.save_for_backward(inputs, amax)
         outputs, scale = _tensor_quant(inputs, amax, num_bits, unsigned, narrow_range)
         if unsigned:
-            outputs += 127
+            outputs += (2.0 ** (num_bits - 1)) - 1.0
         return (outputs * scale).to(inputs.dtype)
 
     @staticmethod
@@ -394,7 +394,7 @@ def _tensor_quant(inputs, amax, num_bits=8, unsigned=False, narrow_range=True):
     # (x + 0.5).floor() match the implementation of tensorflow fake_quant
     outputs = inputs / scale
     if unsigned:
-        outputs -= 127
+        outputs -= (2.0 ** (num_bits - 1)) - 1.0
     outputs = (outputs + 0.5).floor_()
     outputs = torch.clamp(outputs, min_bound, max_bound)
 
diff --git a/tests/test_ls_ops.py b/tests/test_ls_ops.py
index 6a5e1f51..d1345059 100644
--- a/tests/test_ls_ops.py
+++ b/tests/test_ls_ops.py
@@ -318,12 +318,8 @@ def custom():
             grads[11],
             grads[2],
             grads[3],
-            grads[0][:shs],
-            grads[1][:hidden_size],
-            grads[0][shs : shs * 2],
-            grads[1][hidden_size : hidden_size * 2],
-            grads[0][shs * 2 : shs * 3],
-            grads[1][hidden_size * 2 : hidden_size * 3],
+            grads[0],
+            grads[1],
             grads[4],
             grads[5],
         ]
@@ -352,12 +348,8 @@ def baseline():
             curl.final_layer_norm.bias,
             curl.self_attn.out_proj.weight,
             curl.self_attn.out_proj.bias,
-            curl.self_attn.q_proj.weight,
-            curl.self_attn.q_proj.bias,
-            curl.self_attn.k_proj.weight,
-            curl.self_attn.k_proj.bias,
-            curl.self_attn.v_proj.weight,
-            curl.self_attn.v_proj.bias,
+            curl.self_attn.qkv_proj.weight,
+            curl.self_attn.qkv_proj.bias,
             curl.self_attn_layer_norm.weight,
             curl.self_attn_layer_norm.bias,
         ]
@@ -435,12 +427,8 @@ def custom():
             grads[11],
             grads[2],
             grads[3],
-            grads[0][:shs],
-            grads[1][:hidden_size],
-            grads[0][shs : shs * 2],
-            grads[1][hidden_size : hidden_size * 2],
-            grads[0][shs * 2 : shs * 3],
-            grads[1][hidden_size * 2 : hidden_size * 3],
+            grads[0],
+            grads[1],
             grads[4],
             grads[5],
         ]
@@ -463,12 +451,8 @@ def baseline():
             curl.final_layer_norm.bias,
             curl.self_attn.out_proj.weight,
             curl.self_attn.out_proj.bias,
-            curl.self_attn.q_proj.weight,
-            curl.self_attn.q_proj.bias,
-            curl.self_attn.k_proj.weight,
-            curl.self_attn.k_proj.bias,
-            curl.self_attn.v_proj.weight,
-            curl.self_attn.v_proj.bias,
+            curl.self_attn.qkv_proj.weight,
+            curl.self_attn.qkv_proj.bias,
             curl.self_attn_layer_norm.weight,
             curl.self_attn_layer_norm.bias,
         ]
@@ -528,7 +512,8 @@ def test_decoder_layer_backward():
     batch_size, enc_seq_len = kt.bs_sl()
     _, dec_seq_len = kt.bs_sl(batch_size)
     print(
-        f"(batch_size, enc_seq_len, dec_seq_len): ({batch_size}, {enc_seq_len}, {dec_seq_len})"
+        f"(batch_size, enc_seq_len, dec_seq_len): ({batch_size}, {enc_seq_len},"
+        f" {dec_seq_len})"
     )
     hidden_size = 1024
     shs = hidden_size * hidden_size
@@ -572,12 +557,8 @@ def custom():
             grads[17],
             grads[2],
             grads[3],
-            grads[0][:shs],
-            grads[1][:hidden_size],
-            grads[0][shs : shs * 2],
-            grads[1][hidden_size : hidden_size * 2],
-            grads[0][shs * 2 : shs * 3],
-            grads[1][hidden_size * 2 : hidden_size * 3],
+            grads[0],
+            grads[1],
             grads[4],
             grads[5],
             # encdec grad
@@ -637,22 +618,10 @@ def baseline():
                 .self_attn.out_proj.bias.grad.contiguous()
                 .detach(),
                 fairseq_dec_layer_list[i]
-                .self_attn.q_proj.weight.grad.contiguous()
-                .detach(),
-                fairseq_dec_layer_list[i]
-                .self_attn.q_proj.bias.grad.contiguous()
-                .detach(),
-                fairseq_dec_layer_list[i]
-                .self_attn.k_proj.weight.grad.contiguous()
-                .detach(),
-                fairseq_dec_layer_list[i]
-                .self_attn.k_proj.bias.grad.contiguous()
-                .detach(),
-                fairseq_dec_layer_list[i]
-                .self_attn.v_proj.weight.grad.contiguous()
+                .self_attn.qkv_proj.weight.grad.contiguous()
                 .detach(),
                 fairseq_dec_layer_list[i]
-                .self_attn.v_proj.bias.grad.contiguous()
+                .self_attn.qkv_proj.bias.grad.contiguous()
                 .detach(),
                 fairseq_dec_layer_list[i]
                 .self_attn_layer_norm.weight.grad.contiguous()
diff --git a/tests/util.py b/tests/util.py
index 47258773..680c5763 100644
--- a/tests/util.py
+++ b/tests/util.py
@@ -221,12 +221,23 @@ def get_fairseq_enc_params(fairseq_layer):
     initial_weights = []
     initial_biases = []
 
-    initial_weights.append(fairseq_layer.self_attn.q_proj.weight.detach().clone())
-    initial_biases.append(fairseq_layer.self_attn.q_proj.bias.detach().clone())
-    initial_weights.append(fairseq_layer.self_attn.k_proj.weight.detach().clone())
-    initial_biases.append(fairseq_layer.self_attn.k_proj.bias.detach().clone())
-    initial_weights.append(fairseq_layer.self_attn.v_proj.weight.detach().clone())
-    initial_biases.append(fairseq_layer.self_attn.v_proj.bias.detach().clone())
+    if hasattr(fairseq_layer.self_attn, "qkv_proj"):
+        hidden_size = fairseq_layer.self_attn.out_proj.weight.shape[0]
+        initial_weights.extend(
+            fairseq_layer.self_attn.qkv_proj.weight.detach()
+            .clone()
+            .split(hidden_size, 0)
+        )
+        initial_biases.extend(
+            fairseq_layer.self_attn.qkv_proj.bias.detach().clone().split(hidden_size, 0)
+        )
+    else:
+        initial_weights.append(fairseq_layer.self_attn.q_proj.weight.detach().clone())
+        initial_biases.append(fairseq_layer.self_attn.q_proj.bias.detach().clone())
+        initial_weights.append(fairseq_layer.self_attn.k_proj.weight.detach().clone())
+        initial_biases.append(fairseq_layer.self_attn.k_proj.bias.detach().clone())
+        initial_weights.append(fairseq_layer.self_attn.v_proj.weight.detach().clone())
+        initial_biases.append(fairseq_layer.self_attn.v_proj.bias.detach().clone())
     initial_weights.append(fairseq_layer.self_attn.out_proj.weight.detach().clone())
     initial_biases.append(fairseq_layer.self_attn.out_proj.bias.detach().clone())
     initial_weights.append(fairseq_layer.self_attn_layer_norm.weight.detach().clone())
@@ -245,12 +256,23 @@ def get_fairseq_dec_params(fairseq_layer):
     initial_weights = []
     initial_biases = []
 
-    initial_weights.append(fairseq_layer.self_attn.q_proj.weight.detach().clone())
-    initial_biases.append(fairseq_layer.self_attn.q_proj.bias.detach().clone())
-    initial_weights.append(fairseq_layer.self_attn.k_proj.weight.detach().clone())
-    initial_biases.append(fairseq_layer.self_attn.k_proj.bias.detach().clone())
-    initial_weights.append(fairseq_layer.self_attn.v_proj.weight.detach().clone())
-    initial_biases.append(fairseq_layer.self_attn.v_proj.bias.detach().clone())
+    if hasattr(fairseq_layer.self_attn, "qkv_proj"):
+        hidden_size = fairseq_layer.self_attn.out_proj.weight.shape[0]
+        initial_weights.extend(
+            fairseq_layer.self_attn.qkv_proj.weight.detach()
+            .clone()
+            .split(hidden_size, 0)
+        )
+        initial_biases.extend(
+            fairseq_layer.self_attn.qkv_proj.bias.detach().clone().split(hidden_size, 0)
+        )
+    else:
+        initial_weights.append(fairseq_layer.self_attn.q_proj.weight.detach().clone())
+        initial_biases.append(fairseq_layer.self_attn.q_proj.bias.detach().clone())
+        initial_weights.append(fairseq_layer.self_attn.k_proj.weight.detach().clone())
+        initial_biases.append(fairseq_layer.self_attn.k_proj.bias.detach().clone())
+        initial_weights.append(fairseq_layer.self_attn.v_proj.weight.detach().clone())
+        initial_biases.append(fairseq_layer.self_attn.v_proj.bias.detach().clone())
     initial_weights.append(fairseq_layer.self_attn.out_proj.weight.detach().clone())
     initial_biases.append(fairseq_layer.self_attn.out_proj.bias.detach().clone())
     initial_weights.append(fairseq_layer.self_attn_layer_norm.weight.detach().clone())
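
The tests/util.py change above assumes a fused qkv_proj module whose weight stacks the q, k and v projections along dim 0, so split(hidden_size, 0) recovers the per-projection tensors that the old q_proj/k_proj/v_proj path returned. A minimal standalone sketch of that split; the nn.Linear stand-in and sizes below are illustrative, not the actual fairseq module:

import torch

hidden_size = 4
# Hypothetical stand-in for fairseq_layer.self_attn.qkv_proj: weight shape (3 * hidden, hidden).
qkv_proj = torch.nn.Linear(hidden_size, 3 * hidden_size)

# Same call pattern as get_fairseq_enc_params / get_fairseq_dec_params in tests/util.py.
q_w, k_w, v_w = qkv_proj.weight.detach().clone().split(hidden_size, 0)
q_b, k_b, v_b = qkv_proj.bias.detach().clone().split(hidden_size, 0)

assert q_w.shape == (hidden_size, hidden_size) and q_b.shape == (hidden_size,)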