Skip to content

Commit

Permalink
!15271 quantbmmv3 bias support fp16/fp32 in mix-core senario
Browse files Browse the repository at this point in the history
Merge pull request !15271 from ZitaoWang/c20bias
  • Loading branch information
ZitaoWang authored and it-is-a-robot committed Oct 26, 2024
1 parent e4a82db commit 84c05f5
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 13 deletions.
9 changes: 8 additions & 1 deletion test/test_fake_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1740,7 +1740,6 @@ def test_npu_quant_matmul_meta(self):
x1 = torch.randint(-1, 1, (1, 1, 1024), dtype=torch.int8).npu()
x2 = torch.randint(-1, 1, (1, 1024, 100), dtype=torch.int8).npu()
expect_ret = torch.randint(-1, 1, (1, 1, 100), dtype=torch.int8).npu()
scale = torch.randn(1, dtype=torch.bfloat16).npu()
scale = torch.randn(1, dtype=torch.float32).npu()
offset = torch.randn(1, dtype=torch.float32).npu()
bias = torch.randint(-1, -1, (1, 1, 100), dtype=torch.int32).npu()
Expand All @@ -1755,6 +1754,14 @@ def test_npu_quant_matmul_meta(self):
self.assertTrue(expect_ret_bf16.shape == res_bf16.shape)
self.assertTrue(expect_ret_bf16.dtype == res_bf16.dtype)

expect_ret_fp16 = torch.randint(-1, 1, (1, 1, 100), dtype=torch.float16).npu()
bias_fp32 = torch.randint(-1, -1, (1, 1, 100), dtype=torch.float32).npu()
pertoken_scale = torch.randn(1, dtype=torch.float32).npu()
res_fp16 = torch_npu.npu_quant_matmul(x1, x2, scale, offset=None, pertoken_scale=pertoken_scale,
bias=bias_fp32, output_dtype=torch.float16)
self.assertTrue(expect_ret_fp16.shape == res_fp16.shape)
self.assertTrue(expect_ret_fp16.dtype == res_fp16.dtype)

x1 = torch.randint(-1, 1, (16, 8), dtype=torch.int32).npu()
x2 = torch.randint(-1, 1, (64, 5), dtype=torch.int32).npu()
expect_ret = torch.randint(-1, 1, (16, 40), dtype=torch.float16).npu()
Expand Down
47 changes: 35 additions & 12 deletions torch_npu/meta/_meta_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,8 +450,40 @@ def quant_matmul_shape_check(*args):
)


def quant_matmul_bias_dtype_check(bias, pertoken_scale, output_dtype):
bias_dtype_supported_list = [torch.int32, torch.bfloat16, torch.float32, torch.float16]
torch._check(
bias.dtype in bias_dtype_supported_list,
lambda: "bias's type supported for int32, bfloat16, float16 and float32, but bias.dtype is " + str(bias.dtype) + ops_error(ErrCode.TYPE),
)
if bias.dtype == torch.bfloat16:
torch._check(
output_dtype == torch.bfloat16,
lambda: "When bias dtype is bfloat16, output_dtype must be bfloat16, but it is " +
str(output_dtype) + ops_error(ErrCode.TYPE),
)
if pertoken_scale is not None:
if bias.dtype == torch.float16:
torch._check(
output_dtype == torch.float16,
lambda: "When bias dtype is float16 and pertoken is given, output_dtype must be float16, but it is " +
str(output_dtype) + ops_error(ErrCode.TYPE),
)
else:
torch._check(
bias.dtype != torch.float16,
lambda: "Bias dtype cannot be float16 when pertoken not given." + ops_error(ErrCode.TYPE),
)
if bias.dtype == torch.float32:
torch._check(
output_dtype == torch.bfloat16,
lambda: "When bias dtype is float32 and pertoken not given, output_dtype must be bfloat16, but it is " +
str(output_dtype) + ops_error(ErrCode.TYPE),
)


def quant_matmul_dtype_check(*args):
x1, x2, scale, offset, pertoken_scale, bias, is_a4w4 = args
x1, x2, scale, offset, pertoken_scale, bias, output_dtype, is_a4w4 = args
torch._check(
x1.dtype == x2.dtype,
lambda: "x1's type and x2's type should be same, but x1.dtype is " + str(x1.dtype) + " and x2.dtype is " +
Expand Down Expand Up @@ -479,10 +511,7 @@ def quant_matmul_dtype_check(*args):
str(offset.dtype) + ops_error(ErrCode.TYPE),
)
if bias is not None:
torch._check(
bias.dtype == torch.int32 or bias.dtype == torch.bfloat16,
lambda: "bias's type supported for int32 and bfloat16, but bias.dtype is " + str(bias.dtype) + ops_error(ErrCode.TYPE),
)
quant_matmul_bias_dtype_check(bias, pertoken_scale, output_dtype)


def quant_matmul_scale_offset_out_check(scale, offset, pertoken_scale, output_dtype, is_a4w4):
Expand Down Expand Up @@ -553,19 +582,13 @@ def npu_quant_matmul_meta(x1, x2, scale, *, offset=None, pertoken_scale=None, bi
dim_list.append(dimn)
quant_matmul_shape_check(x1, x2, scale, offset, pertoken_scale, is_a4w4, transpose_x2)
if bias is not None:
if bias.dtype == torch.bfloat16:
torch._check(
output_dtype == torch.bfloat16,
lambda: "When bias dtype is bfloat16, output_dtype must be bfloat16, but it is " +
str(output_dtype) + ops_error(ErrCode.TYPE),
)
if bias.dim() == 3:
torch._check(
len(dim_list) == 3,
lambda:"when bias dim is 3, out dim need to be 3" + ops_error(ErrCode.TYPE),
)
bias_shape_check(x2, bias, batch_val, is_a4w4, transpose_x2)
quant_matmul_dtype_check(x1, x2, scale, offset, pertoken_scale, bias, is_a4w4)
quant_matmul_dtype_check(x1, x2, scale, offset, pertoken_scale, bias, output_dtype, is_a4w4)
quant_matmul_scale_offset_out_check(scale, offset, pertoken_scale, output_dtype, is_a4w4)
if output_dtype == torch.float16:
return shape_long.new_empty(tuple(dim_list), dtype=torch.float16)
Expand Down

0 comments on commit 84c05f5

Please sign in to comment.