Compute A @ B^T where B is reconstructed from bit-twiddled 4-bit quantized data and per-block scales, returning bfloat16 results.
Converts the quantized matrix `qB` to floating point via `torch_convert_bit_twiddling`, applies a per-element scale factor of 2^(Scale - 127), where each Scale entry covers a group of 32 columns of B, computes the matrix product A · B^T in float, and casts the result to bfloat16.
Parameters:
A (torch.Tensor): Left operand with shape (M, K); cast to float for the computation.
qB (torch.Tensor): Quantized representation of B (packed 4-bit values) compatible with torch_convert_bit_twiddling.
Scale (torch.Tensor): Per-column-group scale values; Scale indices correspond to groups of 32 columns in B.
Bias (torch.Tensor): Bias tensor with shape (M, N).
Returns:
torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16.
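The dequantize-then-matmul pipeline described above can be sketched in NumPy. The function name, the assumption that the 4-bit values are already unpacked to integers, and the use of float32 in place of bfloat16 are all illustrative, not the library's actual API:

```python
import numpy as np

def dequant_matmul_ref(A, intB, scale_exp, group=32):
    """Reference sketch of C = A @ B^T with per-group exponent scales.

    Assumptions (hypothetical): intB holds the already-unpacked 4-bit
    integer values of B with shape (N, K); scale_exp[i, g] is the biased
    exponent for columns g*group .. (g+1)*group - 1 of row i, applied as
    2**(scale_exp - 127). float32 stands in for bfloat16 here.
    """
    N, K = intB.shape
    # Broadcast each per-group exponent across its block of `group` columns.
    scales = np.repeat(2.0 ** (scale_exp.astype(np.float32) - 127.0),
                       group, axis=1)[:, :K]
    B = intB.astype(np.float32) * scales          # dequantized B, shape (N, K)
    return A.astype(np.float32) @ B.T             # result C, shape (M, N)
```

When every stored exponent equals the bias 127, the scale is 1.0 and the result reduces to a plain matmul over the integer values of B.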
Compute a BF16 matrix product A · B^T from a quantized B with simple (non-twiddling) dequantization.
Converts the quantized tensor `qB` to a floating-point B via `torch_convert`, applies a per-element scale factor computed as 2^(Scale[i][j//32] - 127) (Scale supplies exponent offsets in 32-column groups), then computes C = A · B^T and returns the result converted to bfloat16.
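The per-element scale matrix implied by the formula 2^(Scale[i][j//32] - 127) can be materialized directly with integer indexing; `expand_scales` is a hypothetical helper illustrating the indexing scheme, not part of the library:

```python
import numpy as np

def expand_scales(scale_exp, K, group=32):
    """Per-element scale for B: element (i, j) gets
    2**(scale_exp[i, j // group] - 127).

    scale_exp has shape (N, K // group); the result has shape (N, K).
    """
    cols = np.arange(K) // group            # group index g for each column j
    return 2.0 ** (scale_exp[:, cols].astype(np.float32) - 127.0)
```

A stored value equal to the float32 exponent bias (127) decodes to a scale of 1.0; 128 decodes to 2.0, 126 to 0.5, and so on.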
Parameters:
- A: 2D tensor representing the left operand (will be cast to float32 for the matmul).
- qB: Quantized representation of B accepted by `torch_convert`.
- Scale: 2D tensor of exponent offsets; Scale[i][g] is applied to columns j where g == j // 32.
- Bias: 2D bias tensor (will be cast to float32 for the matmul).
Returns:
- 2D bfloat16 tensor C containing the matrix product A · B^T.
No input is modified in place; the scaling is applied to a local floating-point copy of B.
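The exact bit layout consumed by `torch_convert` is not specified in these docstrings. A common packing stores two 4-bit values per byte, which can be sketched as follows; the low-nibble-first order and the two's-complement mapping to [-8, 7] are assumptions for illustration:

```python
import numpy as np

def unpack_nibbles(packed):
    """Sketch of unpacking two 4-bit values per byte (low nibble first).

    `packed` is a 2D uint8 array; each byte yields two signed 4-bit
    values, mapped from [0, 15] to [-8, 7] via two's complement. The
    layout actually used by torch_convert may differ.
    """
    lo = packed & 0x0F
    hi = (packed >> 4) & 0x0F
    vals = np.stack([lo, hi], axis=-1).reshape(packed.shape[0], -1)
    vals = vals.astype(np.int16)
    # Reinterpret the 4-bit patterns as two's-complement signed values.
    return np.where(vals > 7, vals - 16, vals)
```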