refactor(library): reduce overhead in marlin op
Using torch.library.custom_op introduces an overhead, so register the Marlin ops with torch.library.define and torch.library.impl instead.
dacorvo committed Sep 30, 2024
1 parent 5f88400 commit 7b73aae
Showing 4 changed files with 26 additions and 6 deletions.
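
For context, the refactor swaps one op-registration style for another. Below is a minimal sketch of the two styles on a made-up op (`mylib::scale_add` is not part of quanto): `torch.library.custom_op` infers the schema from type hints and wraps the function in extra machinery, while `torch.library.define` plus `torch.library.impl` declares the schema once and attaches a backend implementation directly.

```python
import torch

# Style 1 (removed by this commit): decorator-based registration.
# The schema is inferred from the type hints; calls go through the
# custom_op wrapper, which the commit message identifies as the
# source of overhead.
@torch.library.custom_op("mylib::scale_add", mutates_args=())
def scale_add(x: torch.Tensor, y: torch.Tensor, alpha: float) -> torch.Tensor:
    return x + alpha * y

# Style 2 (adopted by this commit): declare the schema explicitly,
# then register a backend-specific implementation. quanto binds its
# Marlin kernels to CUDA; "CPU" is used here only so the sketch runs
# anywhere.
torch.library.define("mylib::scale_add2", "(Tensor x, Tensor y, float alpha) -> Tensor")

@torch.library.impl("mylib::scale_add2", "CPU")
def scale_add2(x: torch.Tensor, y: torch.Tensor, alpha: float) -> torch.Tensor:
    return x + alpha * y

x, y = torch.randn(4), torch.randn(4)
print(torch.ops.mylib.scale_add(x, y, 0.5))
print(torch.ops.mylib.scale_add2(x, y, 0.5))
```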
24 changes: 22 additions & 2 deletions optimum/quanto/library/extensions/cuda/__init__.py
@@ -118,7 +118,21 @@ def gemm_f16i4_awq(
     return ext.lib.awq_v2_gemm_f16i4(input, other, scales, shift)
 
 
-@torch.library.custom_op("quanto::fp8_marlin_gemm", mutates_args=(), device_types=["cuda"])
+torch.library.define(
+    "quanto::gemm_f16f8_marlin",
+    "(Tensor a,"
+    "Tensor b_q_weight,"
+    "Tensor b_scales,"
+    "Tensor workspace,"
+    "int num_bits,"
+    "int size_m,"
+    "int size_n,"
+    "int size_k)"
+    " -> Tensor",
+)
+
+
+@torch.library.impl("quanto::gemm_f16f8_marlin", ["CUDA"])
 def fp8_marlin_gemm(
     a: torch.Tensor,
     b_q_weight: torch.Tensor,
@@ -135,7 +149,13 @@ def fp8_marlin_gemm(
     return ext.lib.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k)
 
 
-@torch.library.custom_op("quanto::gptq_marlin_repack", mutates_args=(), device_types=["cuda"])
+torch.library.define(
+    "quanto::pack_fp8_marlin",
+    "(Tensor b_q_weight, Tensor perm, int size_k, int size_n, int num_bits) -> Tensor",
+)
+
+
+@torch.library.impl("quanto::pack_fp8_marlin", ["CUDA"])
 def gptq_marlin_repack(
     b_q_weight: torch.Tensor, perm: torch.Tensor, size_k: int, size_n: int, num_bits: int
 ) -> torch.Tensor:
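To see the overhead the commit message refers to, a rough micro-benchmark of the two registration styles on a trivial clone kernel can be run on any machine (op names, iteration count, and CPU dispatch are illustrative choices, not taken from the commit):

```python
import time

import torch

# The same kernel body is registered both ways, so any timing gap
# comes from the registration/dispatch path, not the computation.
@torch.library.custom_op("bench::clone_a", mutates_args=())
def clone_a(x: torch.Tensor) -> torch.Tensor:
    return x.clone()

torch.library.define("bench::clone_b", "(Tensor x) -> Tensor")

@torch.library.impl("bench::clone_b", "CPU")
def clone_b(x: torch.Tensor) -> torch.Tensor:
    return x.clone()

def per_call_seconds(op, x, iters=10_000):
    op(x)  # warm-up
    start = time.perf_counter()
    for _ in range(iters):
        op(x)
    return (time.perf_counter() - start) / iters

x = torch.randn(8)
print("custom_op    :", per_call_seconds(torch.ops.bench.clone_a, x))
print("define + impl:", per_call_seconds(torch.ops.bench.clone_b, x))
```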
2 changes: 1 addition & 1 deletion optimum/quanto/tensor/weights/marlin/fp8/packed.py
@@ -180,7 +180,7 @@ def pack(cls, tensor: torch.Tensor):
 
         perm = torch.empty(0, dtype=torch.int, device=tensor.device)
 
-        data_int32 = torch.ops.quanto.gptq_marlin_repack(
+        data_int32 = torch.ops.quanto.pack_fp8_marlin(
             b_q_weight=data_int32, perm=perm, size_k=in_features, size_n=out_features, num_bits=8
         )
 
2 changes: 1 addition & 1 deletion optimum/quanto/tensor/weights/marlin/fp8/qbits.py
@@ -34,7 +34,7 @@ def forward(ctx, input, other, bias=None):
         if input.ndim > 2:
             input = input.view(-1, input_shape[-1])
 
-        output = torch.ops.quanto.fp8_marlin_gemm(
+        output = torch.ops.quanto.gemm_f16f8_marlin(
             input,
             b_q_weight=other._data._data,
             b_scales=other._scale,  # .to(input.dtype)
4 changes: 2 additions & 2 deletions test/library/test_mm.py
@@ -113,7 +113,7 @@ def test_fp8_marlin(tokens, in_features, out_features, dtype):
     other_data_int32 = pack_fp8_as_int32(other_data)
     perm = torch.empty(0, dtype=torch.int, device=device)
 
-    other_data_repack = torch.ops.quanto.gptq_marlin_repack(
+    other_data_repack = torch.ops.quanto.pack_fp8_marlin(
         b_q_weight=other_data_int32, perm=perm, size_k=in_features, size_n=out_features, num_bits=8
     )
     other_scale = torch.rand(1, out_features, dtype=dtype, device=device)
@@ -124,7 +124,7 @@ def test_fp8_marlin(tokens, in_features, out_features, dtype):
     other_scale = other_scale.reshape(-1, out_features).contiguous()
 
     workspace = torch.zeros(out_features // 64 * 16, dtype=torch.int, device=device)
-    lib_outputs = torch.ops.quanto.fp8_marlin_gemm(
+    lib_outputs = torch.ops.quanto.gemm_f16f8_marlin(
         a=inputs,
         b_q_weight=other_data_repack,
         b_scales=other_scale,
