feat: e4m3fnuz added
root authored and dacorvo committed Sep 17, 2024
1 parent ec1f85e commit f37c58b
Showing 11 changed files with 55 additions and 32 deletions.
1 change: 1 addition & 0 deletions optimum/quanto/tensor/qtype.py
@@ -60,6 +60,7 @@ def qfloat(dtype: torch.dtype):


qfloat8_e4m3fn = qfloat(torch.float8_e4m3fn)
+qfloat8_e4m3fnuz = qfloat(torch.float8_e4m3fnuz)
qfloat8_e5m2 = qfloat(torch.float8_e5m2)

# Alias the float8 representation that has the better support and inference efficiency
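For context, torch.float8_e4m3fnuz is the e4m3 variant used on AMD ROCm devices. Like torch.float8_e4m3fn it has no infinities, but it additionally drops negative zero, keeps a single NaN encoding, and uses an exponent bias of 8 instead of 7, which lowers the maximum finite value from 448 to 240. A quick way to compare the two ranges (a sketch, assuming a PyTorch build recent enough to expose both dtypes):

```python
import torch

# Compare the numeric envelopes of the two e4m3 variants.
for dt in (torch.float8_e4m3fn, torch.float8_e4m3fnuz):
    fi = torch.finfo(dt)
    print(f"{dt}: max={fi.max}, smallest_normal={fi.smallest_normal}, eps={fi.eps}")

# Expected output (values per torch.finfo):
# torch.float8_e4m3fn: max=448.0, smallest_normal=0.015625, eps=0.125
# torch.float8_e4m3fnuz: max=240.0, smallest_normal=0.0078125, eps=0.125
```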
4 changes: 3 additions & 1 deletion test/library/test_mm.py
@@ -25,7 +25,9 @@
@pytest.mark.parametrize("input_features", [32, 50])
@pytest.mark.parametrize("output_features", [48, 50, 64])
@pytest.mark.parametrize("input_dtype", [None, torch.int8], ids=["i-as-out", "i-int8"])
-@pytest.mark.parametrize("weight_dtype", [torch.float8_e4m3fn, torch.int8], ids=["w-float8", "w-int8"])
+@pytest.mark.parametrize(
+    "weight_dtype", [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.int8], ids=["w-float8", "w-float8-uz", "w-int8"]
+)
@pytest.mark.parametrize("output_dtype", [torch.float16, torch.bfloat16], ids=["o-fp16", "o-bf16"])
def test_qbytes_mm(batch_size, input_features, input_dtype, weight_dtype, output_features, output_dtype, device):
if device.type == "mps" and weight_dtype.is_floating_point:
5 changes: 4 additions & 1 deletion test/library/test_quantize.py
@@ -21,6 +21,7 @@
absmax_scale,
qfloat8,
qfloat8_e4m3fn,
+qfloat8_e4m3fnuz,
qfloat8_e5m2,
qint2,
qint4,
@@ -50,7 +51,9 @@ def test_symmetric_quantize_int(input_shape, dtype, qtype, axis, device):
@pytest.mark.parametrize("input_shape", [(32, 32), (32, 10, 32)])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
@pytest.mark.parametrize(
-"qtype", [qfloat8, qfloat8_e4m3fn, qfloat8_e5m2], ids=["qfloat8", "qfloat8_e4m3fn", "qfloat8_e5m2"]
+"qtype",
+[qfloat8, qfloat8_e4m3fn, qfloat8_e4m3fnuz, qfloat8_e5m2],
+ids=["qfloat8", "qfloat8_e4m3fn", "qfloat8_e4m3fnuz", "qfloat8_e5m2"],
)
@pytest.mark.parametrize(
"axis",
10 changes: 5 additions & 5 deletions test/nn/test_calibrate.py
@@ -16,7 +16,7 @@
import torch
from helpers import random_qactivation

-from optimum.quanto import Calibration, qfloat8_e4m3fn, qfloat8_e5m2, qint8
+from optimum.quanto import Calibration, qfloat8_e4m3fn, qfloat8_e4m3fnuz, qfloat8_e5m2, qint8
from optimum.quanto.nn import QLinear


@@ -51,8 +51,8 @@ def test_calibrate_qlinear_activations_int8(batch_size, tokens, embeddings, use_
@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])
@pytest.mark.parametrize(
"activations",
-[qfloat8_e5m2, qfloat8_e4m3fn],
-ids=["a-qfloat8-e5m2", "a-qfloat8-e4m3"],
+[qfloat8_e5m2, qfloat8_e4m3fn, qfloat8_e4m3fnuz],
+ids=["a-qfloat8-e5m2", "a-qfloat8-e4m3", "a-qfloat8-e4m3-uz"],
)
@pytest.mark.skip_device("mps")
def test_calibrate_qlinear_activations_float8(batch_size, tokens, embeddings, use_bias, activations, device):
@@ -91,8 +91,8 @@ def test_calibrate_custom_module_activations_int8(device):

@pytest.mark.parametrize(
"activations",
-[qfloat8_e5m2, qfloat8_e4m3fn],
-ids=["a-qfloat8-e5m2", "a-qfloat8-e4m3"],
+[qfloat8_e5m2, qfloat8_e4m3fn, qfloat8_e4m3fnuz],
+ids=["a-qfloat8-e5m2", "a-qfloat8-e4m3", "a-qfloat8-e4m3-uz"],
)
@pytest.mark.skip_device("mps")
def test_calibrate_custom_module_activations_float8(activations, device):
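The calibration tests above simply add the fnuz qtype to the existing float8 parametrization. A minimal sketch of the flow they exercise (the layer and input sizes are illustrative, and the input_scale/output_scale buffer names are assumptions borrowed from quanto's calibration tests):

```python
import torch

from optimum.quanto import Calibration, qfloat8_e4m3fnuz, qint8
from optimum.quanto.nn import QLinear

# A quantized linear layer: int8 weights, fnuz float8 activations.
linear = torch.nn.Linear(32, 64)
qlinear = QLinear.from_module(linear, weights=qint8, activations=qfloat8_e4m3fnuz)

# Calibration records the activation ranges observed on real inputs,
# replacing the dummy scales the layer starts with.
with torch.no_grad(), Calibration():
    qlinear(torch.randn(10, 32))

print(qlinear.input_scale, qlinear.output_scale)
```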
7 changes: 3 additions & 4 deletions test/nn/test_qattention.py
@@ -21,7 +21,7 @@
from helpers import assert_similar, random_tensor
from torch import nn

-from optimum.quanto import Calibration, qfloat8_e4m3fn, qfloat8_e5m2, qint8, quantize
+from optimum.quanto import Calibration, qfloat8_e4m3fn, qfloat8_e4m3fnuz, qfloat8_e5m2, qint8, quantize


class RotaryEmbedding(nn.Module):
@@ -186,7 +186,6 @@ def _test_quantize_attention(device, dtype=torch.float32, weights=qint8, activat
else:
with torch.no_grad(), Calibration():
qoutputs = att(inputs)
-
assert_similar(outputs, qoutputs, atol=atol)


@@ -208,8 +207,8 @@ def test_quantize_attention_activations_int8(weights, device):
@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"])
@pytest.mark.parametrize(
"activations",
-[qfloat8_e5m2, qfloat8_e4m3fn],
-ids=["a-float8-e5m2", "a-float8-e4m3"],
+[qfloat8_e5m2, qfloat8_e4m3fn, qfloat8_e4m3fnuz],
+ids=["a-float8-e5m2", "a-float8-e4m3", "a-float8-e4m3-uz"],
)
@pytest.mark.skip_device("mps")
def test_quantize_attention_activations_float8(weights, activations, device):
22 changes: 16 additions & 6 deletions test/nn/test_qconv2d.py
@@ -16,7 +16,15 @@
import torch
from helpers import assert_similar, random_qactivation, random_tensor

-from optimum.quanto import ActivationQBytesTensor, Calibration, qfloat8_e4m3fn, qfloat8_e5m2, qint4, qint8
+from optimum.quanto import (
+    ActivationQBytesTensor,
+    Calibration,
+    qfloat8_e4m3fn,
+    qfloat8_e4m3fnuz,
+    qfloat8_e5m2,
+    qint4,
+    qint8,
+)
from optimum.quanto.nn import QConv2d


@@ -37,7 +45,9 @@ def _test_quantize_conv2d(batch_size, img_shape, out_channels, use_bias, weights
# We need to increase atol for float16 dtype
dtype_atol = {torch.float32: 1e-4, torch.float16: 1e-3}[dtype]
# We also need to increase atol for float8 qtypes
-atol = {None: dtype_atol, qint8: dtype_atol, qfloat8_e5m2: 5e-3, qfloat8_e4m3fn: 5e-3}[activations]
+atol = {None: dtype_atol, qint8: dtype_atol, qfloat8_e5m2: 5e-3, qfloat8_e4m3fn: 5e-3, qfloat8_e4m3fnuz: 5e-3}[
+    activations
+]
assert_similar(out, qout, atol=atol)


@@ -66,8 +76,8 @@ def test_quantize_conv2d_float32_activations_int8(batch_size, img_shape, out_cha
@pytest.mark.parametrize("weights", [qint4, qint8], ids=["w-int4", "w-int8"])
@pytest.mark.parametrize(
"activations",
-[qfloat8_e5m2, qfloat8_e4m3fn],
-ids=["a-float8-e5m2", "a-float8-e4m3"],
+[qfloat8_e5m2, qfloat8_e4m3fn, qfloat8_e4m3fnuz],
+ids=["a-float8-e5m2", "a-float8-e4m3", "a-float8-e4m3-uz"],
)
@pytest.mark.skip_device("mps")
def test_quantize_conv2d_float16_activations_float8(
@@ -83,8 +93,8 @@ def test_quantize_conv2d_float16_activations_float8(
@pytest.mark.parametrize("weights", [qint4, qint8], ids=["w-int4", "w-int8"])
@pytest.mark.parametrize(
"activations",
-[qfloat8_e5m2, qfloat8_e4m3fn],
-ids=["a-float8-e5m2", "a-float8-e4m3"],
+[qfloat8_e5m2, qfloat8_e4m3fn, qfloat8_e4m3fnuz],
+ids=["a-float8-e5m2", "a-float8-e4m3", "a-float8-e4m3-uz"],
)
@pytest.mark.skip_device("mps")
def test_quantize_conv2d_float32_activations_float8(
12 changes: 6 additions & 6 deletions test/nn/test_qlayernorm.py
@@ -16,7 +16,7 @@
import torch
from helpers import assert_similar, random_qactivation

-from optimum.quanto import ActivationQBytesTensor, Calibration, qfloat8_e4m3fn, qfloat8_e5m2, qint8
+from optimum.quanto import ActivationQBytesTensor, Calibration, qfloat8_e4m3fn, qfloat8_e4m3fnuz, qfloat8_e5m2, qint8
from optimum.quanto.nn import QLayerNorm


@@ -37,7 +37,7 @@ def _test_quantize_layernorm(batch_size, tokens, embeddings, dtype, activations,
# We need to increase atol for float16 dtype
dtype_atol = {torch.float32: 1e-4, torch.float16: 1e-3}[dtype]
# We also need to increase atol for float8 qtypes
-atol = {qint8: dtype_atol, qfloat8_e5m2: 5e-3, qfloat8_e4m3fn: 5e-3}[activations]
+atol = {qint8: dtype_atol, qfloat8_e5m2: 5e-3, qfloat8_e4m3fn: 5e-3, qfloat8_e4m3fnuz: 5e-3}[activations]
assert_similar(out, qout, atol=atol)


@@ -57,8 +57,8 @@ def test_quantize_layernorm_float32_activations_int8(batch_size, tokens, embeddi
@pytest.mark.parametrize("tokens, embeddings", [(32, 32), (10, 32)])
@pytest.mark.parametrize(
"activations",
-[qfloat8_e5m2, qfloat8_e4m3fn],
-ids=["a-float8-e5m2", "a-float8-e4m3"],
+[qfloat8_e5m2, qfloat8_e4m3fn, qfloat8_e4m3fnuz],
+ids=["a-float8-e5m2", "a-float8-e4m3", "a-float8-e4m3-uz"],
)
@pytest.mark.skip_device("mps")
def test_quantize_layernorm_float16_activations_float8(batch_size, tokens, embeddings, activations, device):
@@ -69,8 +69,8 @@ def test_quantize_layernorm_float16_activations_float8(batch_size, tokens, embed
@pytest.mark.parametrize("tokens, embeddings", [(32, 32), (10, 32)])
@pytest.mark.parametrize(
"activations",
-[qfloat8_e5m2, qfloat8_e4m3fn],
-ids=["a-float8-e5m2", "a-float8-e4m3"],
+[qfloat8_e5m2, qfloat8_e4m3fn, qfloat8_e4m3fnuz],
+ids=["a-float8-e5m2", "a-float8-e4m3", "a-float8-e4m3-uz"],
)
@pytest.mark.skip_device("mps")
def test_quantize_layernorm_float32_activations_float8(batch_size, tokens, embeddings, activations, device):
9 changes: 5 additions & 4 deletions test/nn/test_qlinear.py
@@ -25,6 +25,7 @@
absmax_scale,
qfloat8,
qfloat8_e4m3fn,
+qfloat8_e4m3fnuz,
qfloat8_e5m2,
qint4,
qint8,
@@ -82,8 +83,8 @@ def test_quantize_linear_float32_activations_int8(batch_size, tokens, embeddings
@pytest.mark.parametrize("weights", [qint4, qint8], ids=["w-qint4", "w-qint8"])
@pytest.mark.parametrize(
"activations",
-[qfloat8_e4m3fn],
-ids=["a-qfloat8-e4m3"],
+[qfloat8_e4m3fn, qfloat8_e4m3fnuz],
+ids=["a-qfloat8-e4m3", "a-qfloat8-e4m3-uz"],
)
@pytest.mark.skip_device("mps")
def test_quantize_linear_float16_activations_float8(
@@ -99,8 +100,8 @@ def test_quantize_linear_float16_activations_float8(
@pytest.mark.parametrize("weights", [qint4, qint8], ids=["w-qint4", "w-qint8"])
@pytest.mark.parametrize(
"activations",
-[qfloat8_e5m2, qfloat8_e4m3fn],
-ids=["a-qfloat8-e5m2", "a-qfloat8-e4m3"],
+[qfloat8_e5m2, qfloat8_e4m3fn, qfloat8_e4m3fnuz],
+ids=["a-qfloat8-e5m2", "a-qfloat8-e4m3", "a-qfloat8-e4m3-uz"],
)
@pytest.mark.skip_device("mps")
def test_quantize_linear_float32_activations_float8(
7 changes: 4 additions & 3 deletions test/quantize/test_quantize_mlp.py
@@ -28,6 +28,7 @@
absmax_scale,
freeze,
qfloat8_e4m3fn,
+qfloat8_e4m3fnuz,
qfloat8_e5m2,
qint4,
qint8,
@@ -102,13 +103,13 @@ def test_quantize_mlp_int8_activations(weights, frozen, device):
@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"])
@pytest.mark.parametrize(
"activations",
-[qfloat8_e5m2, qfloat8_e4m3fn],
-ids=["a-qfloat8-e5m2", "a-qfloat8-e4m3"],
+[qfloat8_e5m2, qfloat8_e4m3fn, qfloat8_e4m3fnuz],
+ids=["a-qfloat8-e5m2", "a-qfloat8-e4m3", "a-qfloat8-e4m3-uz"],
)
@pytest.mark.parametrize("frozen", [True, False], ids=["frozen", "non-frozen"])
@pytest.mark.skip_device("mps")
def test_quantize_mlp_float8_activations(weights, activations, frozen, device):
-atol = {qfloat8_e4m3fn: 1e-3, qfloat8_e5m2: 1e-2}[activations]
+atol = {qfloat8_e4m3fn: 1e-3, qfloat8_e4m3fnuz: 1e-3, qfloat8_e5m2: 1e-2}[activations]
_test_quantize_mlp(weights, activations, None, frozen, device, atol=atol)


5 changes: 4 additions & 1 deletion test/tensor/activations/test_activations_quantize.py
@@ -21,6 +21,7 @@
absmax_scale,
qfloat8,
qfloat8_e4m3fn,
+qfloat8_e4m3fnuz,
qfloat8_e5m2,
qint8,
)
@@ -44,7 +45,9 @@ def test_symmetric_quantize_int(input_shape, dtype, qtype, device):
@pytest.mark.parametrize("input_shape", [(32, 32), (32, 10, 32)])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
@pytest.mark.parametrize(
-"qtype", [qfloat8, qfloat8_e4m3fn, qfloat8_e5m2], ids=["qfloat8", "qfloat8_e4m3fn", "qfloat8_e5m2"]
+"qtype",
+[qfloat8, qfloat8_e4m3fn, qfloat8_e4m3fnuz, qfloat8_e5m2],
+ids=["qfloat8", "qfloat8_e4m3fn", "qfloat8_e4m3fnuz", "qfloat8_e5m2"],
)
def test_symmetric_quantize_float8(input_shape, dtype, qtype, device):
a = random_tensor(input_shape, dtype=dtype).to(device)
5 changes: 4 additions & 1 deletion test/tensor/weights/test_weights_quantize.py
@@ -21,6 +21,7 @@
absmax_scale,
qfloat8,
qfloat8_e4m3fn,
+qfloat8_e4m3fnuz,
qfloat8_e5m2,
qint8,
)
@@ -49,7 +50,9 @@ def test_symmetric_quantize_int(input_shape, dtype, qtype, axis, device):
@pytest.mark.parametrize("input_shape", [(32, 32), (32, 10, 32)])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
@pytest.mark.parametrize(
-"qtype", [qfloat8, qfloat8_e4m3fn, qfloat8_e5m2], ids=["qfloat8", "qfloat8_e4m3fn", "qfloat8_e5m2"]
+"qtype",
+[qfloat8, qfloat8_e4m3fn, qfloat8_e4m3fnuz, qfloat8_e5m2],
+ids=["qfloat8", "qfloat8_e4m3fn", "qfloat8_e4m3fnuz", "qfloat8_e5m2"],
)
@pytest.mark.parametrize(
"axis",
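Taken together, these changes make qfloat8_e4m3fnuz usable anywhere the other float8 qtypes are accepted. A minimal end-to-end sketch (the toy model and shapes are placeholders; the quantize/Calibration/freeze flow mirrors what the tests above exercise):

```python
import torch
from torch import nn

from optimum.quanto import Calibration, freeze, qfloat8_e4m3fnuz, qint8, quantize

model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 32))

# Quantize weights to int8 and activations to the new fnuz float8 variant.
quantize(model, weights=qint8, activations=qfloat8_e4m3fnuz)

# Record activation ranges on representative data.
with torch.no_grad(), Calibration():
    model(torch.randn(8, 32))

freeze(model)  # materialize the quantized weights
```

Note that the tests skip the float8 paths on mps, so this flow assumes a CPU or CUDA/ROCm device.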
