Fix test unit (#321)
* modify Dockerfile to compile tritonbackend

* fix test_ls_ops

Co-authored-by: zhoubofan <zhoubofan@bytedance.com>
hexisyztem authored Jun 8, 2022
1 parent 932e42e commit 51f7530
Showing 5 changed files with 70 additions and 66 deletions.
9 changes: 6 additions & 3 deletions lightseq/training/ops/pytorch/quantization.py
@@ -16,11 +16,14 @@
 act_quant_config = QuantDescriptor(
     num_bits=8, narrow_range=True, learn_amax=True, amax=16.0
 )
+out_quant_config = QuantDescriptor(
+    num_bits=8, narrow_range=True, learn_amax=False, amax=16.0
+)
 relu_quant_config = QuantDescriptor(
     num_bits=8, narrow_range=True, learn_amax=True, amax=16.0, unsigned=True
 )
 weight_quant_config = QuantDescriptor(
-    num_bits=8, narrow_range=True, learn_amax=True, amax=1.0
+    num_bits=8, narrow_range=True, learn_amax=False, amax=1.0
 )


@@ -36,8 +39,8 @@ def __init__(self, in_features, out_features, pre_activation=None, *args, **kwar
         if pre_activation != "encoder_out":
             self.input_quant = TensorQuantizer(input_quant_config)
         self.output_quant = None
-        if pre_activation is None:
-            self.output_quant = TensorQuantizer(act_quant_config)
+        # if pre_activation is None:
+        self.output_quant = TensorQuantizer(out_quant_config)
         self.weight_quant = TensorQuantizer(weight_quant_config)

     def forward(self, input):
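The hunk above routes the layer's output through the new fixed-amax out_quant_config instead of the learnable act_quant_config. A rough sketch of how such a quantized linear forward typically chains these quantizers; this is hypothetical and not the file's actual code:

import torch.nn.functional as F

# Hypothetical sketch of a fake-quantized linear forward pass; attribute names
# mirror the quantizers set up in __init__ above, the real class may differ.
def quant_linear_forward(layer, input):
    if layer.input_quant is not None:
        input = layer.input_quant(input)       # fake-quantize activations
    weight = layer.weight_quant(layer.weight)  # fake-quantize weights
    output = F.linear(input, weight, layer.bias)
    if layer.output_quant is not None:
        output = layer.output_quant(output)    # fixed-amax output quantizer
    return output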
@@ -77,8 +77,8 @@ class TensorQuantizer(nn.Module):
     def __init__(
         self,
         quant_desc=QuantDescriptor(),
-        disabled=False,
-        if_quant=True,
+        disabled=True,
+        if_quant=False,
         if_clip=False,
         if_calib=False,
     ):
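With the new defaults a freshly constructed TensorQuantizer acts as a pass-through until it is explicitly enabled, which is what the enable_quant helper later in this commit does. A short usage sketch under that assumption (TensorQuantizer and act_quant_config as defined in the surrounding modules):

import torch

# Sketch: quantization is now off by default and must be opted into.
x = torch.randn(4, 1024)
quantizer = TensorQuantizer(act_quant_config)  # disabled=True, if_quant=False
y = quantizer(x)          # passes x through unchanged while disabled

quantizer.enable()        # clear the disabled flag
quantizer.enable_quant()  # switch the quant stage back on
y_q = quantizer(x)        # now fake-quantized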
@@ -173,15 +173,17 @@ def disable_clip(self):
         """Disable clip stage"""
         self._if_clip = False
         # self.clip.clip_value_min.required_grad = False
-        self.clip.clip_value_max.required_grad = False
+        if hasattr(self.clip, "clip_value_max"):
+            self.clip.clip_value_max.required_grad = False

     def enable_clip(self):
         """Enable clip stage"""
         # logger.warning("Enable `clip` stage for amax learning.")
         if not self._learn_amax:
             raise ValueError("learn_amax is False. Cannot enable clip.")
         # self.clip.clip_value_min.required_grad = True
-        self.clip.clip_value_max.required_grad = True
+        if hasattr(self.clip, "clip_value_max"):
+            self.clip.clip_value_max.required_grad = True
         self._if_clip = True

     def disable_calib(self):
@@ -467,6 +469,11 @@ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
 def enable_quant(m):
     if isinstance(m, TensorQuantizer):
         m.enable()
+        m.enable_quant()
+
+    elif isinstance(m, torch.nn.Module):
+        if hasattr(m, "enable_quant"):
+            m.enable_quant()


 def disable_quant(m):
@@ -475,6 +482,9 @@ def disable_quant(m):
         m.disable_quant()
         m.disable_calib()
         m.disable_clip()
+    elif isinstance(m, torch.nn.Module):
+        if hasattr(m, "disable_quant"):
+            m.disable_quant()


 def qat_mode(m):
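These helpers are meant to be driven by Module.apply, which visits every submodule; the new elif branch lets non-TensorQuantizer modules that expose their own enable_quant/disable_quant hooks participate as well. A usage sketch, with model standing in for any torch.nn.Module:

# Sketch: toggling quantization across a whole model with Module.apply.
model.apply(enable_quant)   # calibration / QAT phase
model.apply(disable_quant)  # back to plain floating-point execution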
4 changes: 2 additions & 2 deletions lightseq/training/pytorch_quantization/tensor_quant.py
@@ -335,7 +335,7 @@ def forward(ctx, inputs, amax, num_bits=8, unsigned=False, narrow_range=True):
         # ctx.save_for_backward(inputs, amax)
         outputs, scale = _tensor_quant(inputs, amax, num_bits, unsigned, narrow_range)
         if unsigned:
-            outputs += 127
+            outputs += (2.0 ** (num_bits - 1)) - 1.0
         return (outputs * scale).to(inputs.dtype)

     @staticmethod
@@ -394,7 +394,7 @@ def _tensor_quant(inputs, amax, num_bits=8, unsigned=False, narrow_range=True):
     # (x + 0.5).floor() match the implementation of tensorflow fake_quant
     outputs = inputs / scale
     if unsigned:
-        outputs -= 127
+        outputs -= (2.0 ** (num_bits - 1)) - 1.0
     outputs = (outputs + 0.5).floor_()
     outputs = torch.clamp(outputs, min_bound, max_bound)

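The replaced constant generalizes the unsigned shift from a hard-coded 127 to (2 ** (num_bits - 1)) - 1, so the forward shift and its inverse stay consistent for bit widths other than 8. A quick numeric check of that expression:

# The offset applied for unsigned quantization, for a few bit widths.
for num_bits in (4, 8, 16):
    offset = (2.0 ** (num_bits - 1)) - 1.0
    print(num_bits, offset)  # 4 -> 7.0, 8 -> 127.0, 16 -> 32767.0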
59 changes: 14 additions & 45 deletions tests/test_ls_ops.py
@@ -318,12 +318,8 @@ def custom():
             grads[11],
             grads[2],
             grads[3],
-            grads[0][:shs],
-            grads[1][:hidden_size],
-            grads[0][shs : shs * 2],
-            grads[1][hidden_size : hidden_size * 2],
-            grads[0][shs * 2 : shs * 3],
-            grads[1][hidden_size * 2 : hidden_size * 3],
+            grads[0],
+            grads[1],
             grads[4],
             grads[5],
         ]
@@ -352,12 +348,8 @@ def baseline():
             curl.final_layer_norm.bias,
             curl.self_attn.out_proj.weight,
             curl.self_attn.out_proj.bias,
-            curl.self_attn.q_proj.weight,
-            curl.self_attn.q_proj.bias,
-            curl.self_attn.k_proj.weight,
-            curl.self_attn.k_proj.bias,
-            curl.self_attn.v_proj.weight,
-            curl.self_attn.v_proj.bias,
+            curl.self_attn.qkv_proj.weight,
+            curl.self_attn.qkv_proj.bias,
             curl.self_attn_layer_norm.weight,
             curl.self_attn_layer_norm.bias,
         ]
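The baseline parameter list now assumes fairseq's attention module exposes a fused qkv_proj, one (3 * hidden, hidden) weight and one (3 * hidden,) bias, matching the layout of grads[0] and grads[1] on the custom side. A shape sketch of that assumption:

import torch

# Sketch of the fused projection layout assumed by the updated test.
hidden_size = 1024
qkv_weight = torch.randn(3 * hidden_size, hidden_size)  # stand-in for qkv_proj.weight
qkv_bias = torch.randn(3 * hidden_size)                 # stand-in for qkv_proj.bias
print(qkv_weight.shape, qkv_bias.shape)  # torch.Size([3072, 1024]) torch.Size([3072])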
@@ -435,12 +427,8 @@ def custom():
             grads[11],
             grads[2],
             grads[3],
-            grads[0][:shs],
-            grads[1][:hidden_size],
-            grads[0][shs : shs * 2],
-            grads[1][hidden_size : hidden_size * 2],
-            grads[0][shs * 2 : shs * 3],
-            grads[1][hidden_size * 2 : hidden_size * 3],
+            grads[0],
+            grads[1],
             grads[4],
             grads[5],
         ]
@@ -463,12 +451,8 @@ def baseline():
             curl.final_layer_norm.bias,
             curl.self_attn.out_proj.weight,
             curl.self_attn.out_proj.bias,
-            curl.self_attn.q_proj.weight,
-            curl.self_attn.q_proj.bias,
-            curl.self_attn.k_proj.weight,
-            curl.self_attn.k_proj.bias,
-            curl.self_attn.v_proj.weight,
-            curl.self_attn.v_proj.bias,
+            curl.self_attn.qkv_proj.weight,
+            curl.self_attn.qkv_proj.bias,
             curl.self_attn_layer_norm.weight,
             curl.self_attn_layer_norm.bias,
         ]
@@ -528,7 +512,8 @@ def test_decoder_layer_backward():
     batch_size, enc_seq_len = kt.bs_sl()
     _, dec_seq_len = kt.bs_sl(batch_size)
     print(
-        f"(batch_size, enc_seq_len, dec_seq_len): ({batch_size}, {enc_seq_len}, {dec_seq_len})"
+        f"(batch_size, enc_seq_len, dec_seq_len): ({batch_size}, {enc_seq_len},"
+        f" {dec_seq_len})"
     )
     hidden_size = 1024
     shs = hidden_size * hidden_size
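The reformatted print relies on Python joining adjacent f-string literals at compile time, so the message is unchanged while the source line fits the length limit. A minimal illustration:

batch_size, enc_seq_len, dec_seq_len = 8, 32, 16
msg = (
    f"(batch_size, enc_seq_len, dec_seq_len): ({batch_size}, {enc_seq_len},"
    f" {dec_seq_len})"
)
print(msg)  # (batch_size, enc_seq_len, dec_seq_len): (8, 32, 16)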
@@ -572,12 +557,8 @@ def custom():
             grads[17],
             grads[2],
             grads[3],
-            grads[0][:shs],
-            grads[1][:hidden_size],
-            grads[0][shs : shs * 2],
-            grads[1][hidden_size : hidden_size * 2],
-            grads[0][shs * 2 : shs * 3],
-            grads[1][hidden_size * 2 : hidden_size * 3],
+            grads[0],
+            grads[1],
             grads[4],
             grads[5],
             # encdec grad
@@ -637,22 +618,10 @@ def baseline():
                 .self_attn.out_proj.bias.grad.contiguous()
                 .detach(),
                 fairseq_dec_layer_list[i]
-                .self_attn.q_proj.weight.grad.contiguous()
-                .detach(),
-                fairseq_dec_layer_list[i]
-                .self_attn.q_proj.bias.grad.contiguous()
-                .detach(),
-                fairseq_dec_layer_list[i]
-                .self_attn.k_proj.weight.grad.contiguous()
-                .detach(),
-                fairseq_dec_layer_list[i]
-                .self_attn.k_proj.bias.grad.contiguous()
-                .detach(),
-                fairseq_dec_layer_list[i]
-                .self_attn.v_proj.weight.grad.contiguous()
+                .self_attn.qkv_proj.weight.grad.contiguous()
                 .detach(),
                 fairseq_dec_layer_list[i]
-                .self_attn.v_proj.bias.grad.contiguous()
+                .self_attn.qkv_proj.bias.grad.contiguous()
                 .detach(),
                 fairseq_dec_layer_list[i]
                 .self_attn_layer_norm.weight.grad.contiguous()
46 changes: 34 additions & 12 deletions tests/util.py
@@ -221,12 +221,23 @@ def get_fairseq_enc_params(fairseq_layer):
     initial_weights = []
     initial_biases = []

-    initial_weights.append(fairseq_layer.self_attn.q_proj.weight.detach().clone())
-    initial_biases.append(fairseq_layer.self_attn.q_proj.bias.detach().clone())
-    initial_weights.append(fairseq_layer.self_attn.k_proj.weight.detach().clone())
-    initial_biases.append(fairseq_layer.self_attn.k_proj.bias.detach().clone())
-    initial_weights.append(fairseq_layer.self_attn.v_proj.weight.detach().clone())
-    initial_biases.append(fairseq_layer.self_attn.v_proj.bias.detach().clone())
+    if hasattr(fairseq_layer.self_attn, "qkv_proj"):
+        hidden_size = fairseq_layer.self_attn.out_proj.weight.shape[0]
+        initial_weights.extend(
+            fairseq_layer.self_attn.qkv_proj.weight.detach()
+            .clone()
+            .split(hidden_size, 0)
+        )
+        initial_biases.extend(
+            fairseq_layer.self_attn.qkv_proj.bias.detach().clone().split(hidden_size, 0)
+        )
+    else:
+        initial_weights.append(fairseq_layer.self_attn.q_proj.weight.detach().clone())
+        initial_biases.append(fairseq_layer.self_attn.q_proj.bias.detach().clone())
+        initial_weights.append(fairseq_layer.self_attn.k_proj.weight.detach().clone())
+        initial_biases.append(fairseq_layer.self_attn.k_proj.bias.detach().clone())
+        initial_weights.append(fairseq_layer.self_attn.v_proj.weight.detach().clone())
+        initial_biases.append(fairseq_layer.self_attn.v_proj.bias.detach().clone())
     initial_weights.append(fairseq_layer.self_attn.out_proj.weight.detach().clone())
     initial_biases.append(fairseq_layer.self_attn.out_proj.bias.detach().clone())
     initial_weights.append(fairseq_layer.self_attn_layer_norm.weight.detach().clone())
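split(hidden_size, 0) cuts the fused (3 * hidden, hidden) weight back into three (hidden, hidden) chunks along dim 0, so the rest of the helper can keep handing out separate q/k/v tensors (assuming the fused parameter is stacked in q, k, v order). A standalone sketch of that split:

import torch

hidden_size = 4
qkv_weight = torch.randn(3 * hidden_size, hidden_size)  # stand-in for qkv_proj.weight
q_w, k_w, v_w = qkv_weight.split(hidden_size, 0)
print(q_w.shape, k_w.shape, v_w.shape)  # three torch.Size([4, 4]) chunks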
@@ -245,12 +256,23 @@ def get_fairseq_dec_params(fairseq_layer):
     initial_weights = []
     initial_biases = []

-    initial_weights.append(fairseq_layer.self_attn.q_proj.weight.detach().clone())
-    initial_biases.append(fairseq_layer.self_attn.q_proj.bias.detach().clone())
-    initial_weights.append(fairseq_layer.self_attn.k_proj.weight.detach().clone())
-    initial_biases.append(fairseq_layer.self_attn.k_proj.bias.detach().clone())
-    initial_weights.append(fairseq_layer.self_attn.v_proj.weight.detach().clone())
-    initial_biases.append(fairseq_layer.self_attn.v_proj.bias.detach().clone())
+    if hasattr(fairseq_layer.self_attn, "qkv_proj"):
+        hidden_size = fairseq_layer.self_attn.out_proj.weight.shape[0]
+        initial_weights.extend(
+            fairseq_layer.self_attn.qkv_proj.weight.detach()
+            .clone()
+            .split(hidden_size, 0)
+        )
+        initial_biases.extend(
+            fairseq_layer.self_attn.qkv_proj.bias.detach().clone().split(hidden_size, 0)
+        )
+    else:
+        initial_weights.append(fairseq_layer.self_attn.q_proj.weight.detach().clone())
+        initial_biases.append(fairseq_layer.self_attn.q_proj.bias.detach().clone())
+        initial_weights.append(fairseq_layer.self_attn.k_proj.weight.detach().clone())
+        initial_biases.append(fairseq_layer.self_attn.k_proj.bias.detach().clone())
+        initial_weights.append(fairseq_layer.self_attn.v_proj.weight.detach().clone())
+        initial_biases.append(fairseq_layer.self_attn.v_proj.bias.detach().clone())
     initial_weights.append(fairseq_layer.self_attn.out_proj.weight.detach().clone())
     initial_biases.append(fairseq_layer.self_attn.out_proj.bias.detach().clone())
     initial_weights.append(fairseq_layer.self_attn_layer_norm.weight.detach().clone())
