diff --git a/optimum/quanto/library/extensions/cuda/__init__.py b/optimum/quanto/library/extensions/cuda/__init__.py
index becc715b..d44684db 100644
--- a/optimum/quanto/library/extensions/cuda/__init__.py
+++ b/optimum/quanto/library/extensions/cuda/__init__.py
@@ -118,7 +118,21 @@ def gemm_f16i4_awq(
     return ext.lib.awq_v2_gemm_f16i4(input, other, scales, shift)
 
 
-@torch.library.custom_op("quanto::fp8_marlin_gemm", mutates_args=(), device_types=["cuda"])
+torch.library.define(
+    "quanto::gemm_f16f8_marlin",
+    "(Tensor a,"
+    "Tensor b_q_weight,"
+    "Tensor b_scales,"
+    "Tensor workspace,"
+    "int num_bits,"
+    "int size_m,"
+    "int size_n,"
+    "int size_k)"
+    " -> Tensor",
+)
+
+
+@torch.library.impl("quanto::gemm_f16f8_marlin", ["CUDA"])
 def fp8_marlin_gemm(
     a: torch.Tensor,
     b_q_weight: torch.Tensor,
@@ -135,7 +149,13 @@ def fp8_marlin_gemm(
     return ext.lib.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k)
 
 
-@torch.library.custom_op("quanto::gptq_marlin_repack", mutates_args=(), device_types=["cuda"])
+torch.library.define(
+    "quanto::pack_fp8_marlin",
+    "(Tensor b_q_weight, Tensor perm, int size_k, int size_n, int num_bits) -> Tensor",
+)
+
+
+@torch.library.impl("quanto::pack_fp8_marlin", ["CUDA"])
 def gptq_marlin_repack(
     b_q_weight: torch.Tensor, perm: torch.Tensor, size_k: int, size_n: int, num_bits: int
 ) -> torch.Tensor:
diff --git a/optimum/quanto/tensor/weights/marlin/fp8/packed.py b/optimum/quanto/tensor/weights/marlin/fp8/packed.py
index 1ba750ca..f075c8ff 100644
--- a/optimum/quanto/tensor/weights/marlin/fp8/packed.py
+++ b/optimum/quanto/tensor/weights/marlin/fp8/packed.py
@@ -180,7 +180,7 @@ def pack(cls, tensor: torch.Tensor):
 
         perm = torch.empty(0, dtype=torch.int, device=tensor.device)
 
-        data_int32 = torch.ops.quanto.gptq_marlin_repack(
+        data_int32 = torch.ops.quanto.pack_fp8_marlin(
             b_q_weight=data_int32, perm=perm, size_k=in_features, size_n=out_features, num_bits=8
         )
 
diff --git a/optimum/quanto/tensor/weights/marlin/fp8/qbits.py b/optimum/quanto/tensor/weights/marlin/fp8/qbits.py
index 666b9721..3b9db6d0 100644
--- a/optimum/quanto/tensor/weights/marlin/fp8/qbits.py
+++ b/optimum/quanto/tensor/weights/marlin/fp8/qbits.py
@@ -34,7 +34,7 @@ def forward(ctx, input, other, bias=None):
         if input.ndim > 2:
             input = input.view(-1, input_shape[-1])
 
-        output = torch.ops.quanto.fp8_marlin_gemm(
+        output = torch.ops.quanto.gemm_f16f8_marlin(
             input,
             b_q_weight=other._data._data,
             b_scales=other._scale,  # .to(input.dtype)
diff --git a/test/library/test_mm.py b/test/library/test_mm.py
index 6714b2de..e4297b64 100644
--- a/test/library/test_mm.py
+++ b/test/library/test_mm.py
@@ -113,7 +113,7 @@ def test_fp8_marlin(tokens, in_features, out_features, dtype):
     other_data_int32 = pack_fp8_as_int32(other_data)
 
     perm = torch.empty(0, dtype=torch.int, device=device)
-    other_data_repack = torch.ops.quanto.gptq_marlin_repack(
+    other_data_repack = torch.ops.quanto.pack_fp8_marlin(
         b_q_weight=other_data_int32, perm=perm, size_k=in_features, size_n=out_features, num_bits=8
     )
     other_scale = torch.rand(1, out_features, dtype=dtype, device=device)
@@ -124,7 +124,7 @@ def test_fp8_marlin(tokens, in_features, out_features, dtype):
     other_scale = other_scale.reshape(-1, out_features).contiguous()
 
     workspace = torch.zeros(out_features // 64 * 16, dtype=torch.int, device=device)
-    lib_outputs = torch.ops.quanto.fp8_marlin_gemm(
+    lib_outputs = torch.ops.quanto.gemm_f16f8_marlin(
         a=inputs,
         b_q_weight=other_data_repack,
         b_scales=other_scale,
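
For context, the patch replaces the @torch.library.custom_op decorator with the explicit torch.library.define / torch.library.impl pair and renames the public ops to quanto::gemm_f16f8_marlin and quanto::pack_fp8_marlin. The snippet below is a minimal sketch of that registration pattern using a made-up mylib::scaled_mm op; the op name, schema, and CPU backing function are illustrative only and are not part of the patch. It should run on a stock PyTorch 2.x install without the quanto CUDA extension.

import torch
from torch import Tensor

# Declare the op schema first (hypothetical "mylib::scaled_mm", not a quanto op).
torch.library.define("mylib::scaled_mm", "(Tensor a, Tensor b, float scale) -> Tensor")


# Then register a backend implementation for one or more dispatch keys.
# The patch registers its kernels for ["CUDA"]; "CPU" is used here so the
# sketch runs without a GPU.
@torch.library.impl("mylib::scaled_mm", ["CPU"])
def scaled_mm(a: Tensor, b: Tensor, scale: float) -> Tensor:
    # Reference implementation dispatched for CPU tensors.
    return torch.matmul(a, b) * scale


# The op is then reachable under torch.ops, just like the renamed
# quanto::gemm_f16f8_marlin and quanto::pack_fp8_marlin ops in the patch.
out = torch.ops.mylib.scaled_mm(torch.randn(4, 8), torch.randn(8, 16), 0.5)
print(out.shape)  # torch.Size([4, 16])

Keeping the schema declaration separate from the backend registration is also why the patch can rename the public ops without renaming the Python functions (fp8_marlin_gemm, gptq_marlin_repack) that back them.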