
Commit 1064e83

Update

[ghstack-poisoned]

2 parents 7c1166e + d3306b2 · commit 1064e83

File tree

20 files changed (+795 lines, -100 lines)

examples/sam2_amg_server/generate_data.py

Lines changed: 3 additions & 0 deletions

@@ -60,6 +60,8 @@ def latencies_statistics(data):
     mean = np.mean(data_array)
     # Calculate the median
     median = np.median(data_array)
+    # Calculate the 90th percentile
+    p90 = np.percentile(data_array, 90)
     # Calculate the 95th percentile
     p95 = np.percentile(data_array, 95)
     # Calculate the 99th percentile
@@ -74,6 +76,7 @@ def latencies_statistics(data):
     {
         "mean": mean,
         "median": median,
+        "p90": p90,
         "p95": p95,
         "p99": p99,
         "p999": p999,

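Note: a minimal, self-contained sketch of what the new p90 statistic adds to the latency summary, using hypothetical sample data (not data from this repo):

    import numpy as np

    # Hypothetical latency samples in milliseconds.
    latencies_ms = [12.1, 15.3, 11.8, 90.4, 13.0, 14.2, 250.7, 12.9, 13.5, 16.0]
    data_array = np.array(latencies_ms)

    # np.percentile(a, q) returns the value below which q percent of the
    # samples fall, so p90 slots in between the median and the existing
    # p95/p99 tail statistics reported by latencies_statistics.
    stats = {
        "mean": np.mean(data_array),
        "median": np.median(data_array),
        "p90": np.percentile(data_array, 90),
        "p95": np.percentile(data_array, 95),
        "p99": np.percentile(data_array, 99),
    }
    print(stats)
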
examples/sam2_amg_server/result.csv

Lines changed: 70 additions & 70 deletions
Large diffs are not rendered by default.

ruff.toml

Lines changed: 6 additions & 0 deletions

@@ -2,3 +2,9 @@
 # Add linting rules here
 lint.select = ["F", "I"]
 lint.ignore = ["E731"]
+
+
+# Exclude third-party modules
+exclude = [
+    "third_party/*",
+]

setup.py

Lines changed: 5 additions & 4 deletions

@@ -215,10 +215,7 @@ def get_extensions():
     extra_link_args = []
     extra_compile_args = {
         "cxx": [f"-DPy_LIMITED_API={PY3_9_HEXCODE}"],
-        "nvcc": [
-            "-O3" if not debug_mode else "-O0",
-            "-t=0",
-        ],
+        "nvcc": ["-O3" if not debug_mode else "-O0", "-t=0", "-std=c++17"],
     }

     if not IS_WINDOWS:
@@ -257,12 +254,16 @@ def get_extensions():
         use_cutlass = True
         cutlass_dir = os.path.join(third_party_path, "cutlass")
         cutlass_include_dir = os.path.join(cutlass_dir, "include")
+        cutlass_tools_include_dir = os.path.join(
+            cutlass_dir, "tools", "util", "include"
+        )
         cutlass_extensions_include_dir = os.path.join(cwd, extensions_cuda_dir)
     if use_cutlass:
         extra_compile_args["nvcc"].extend(
             [
                 "-DTORCHAO_USE_CUTLASS",
                 "-I" + cutlass_include_dir,
+                "-I" + cutlass_tools_include_dir,
                 "-I" + cutlass_extensions_include_dir,
             ]
         )
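For context, a sketch of where an extra_compile_args dict like the one above ends up, assuming a CUDA toolchain is available; the extension name and source path below are hypothetical, not from this repo:

    from torch.utils.cpp_extension import CUDAExtension

    debug_mode = False
    extra_compile_args = {
        "cxx": ["-O3" if not debug_mode else "-O0"],
        # nvcc gets its own flag list: -std=c++17 matches the flag added in
        # this commit, and -I entries point nvcc at the CUTLASS headers.
        "nvcc": ["-O3" if not debug_mode else "-O0", "-t=0", "-std=c++17"],
    }

    ext = CUDAExtension(
        name="my_cuda_ops",             # hypothetical
        sources=["csrc/my_kernel.cu"],  # hypothetical
        extra_compile_args=extra_compile_args,
    )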

test/float8/test_base.py

Lines changed: 5 additions & 1 deletion

@@ -164,7 +164,10 @@ def test_transpose(self):

     @pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)])
     @pytest.mark.parametrize("axiswise_dim", [0, -1])
-    def test_axiswise_dynamic_cast(self, shape, axiswise_dim):
+    @pytest.mark.parametrize("round_scales_to_power_of_2", [True, False])
+    def test_axiswise_dynamic_cast(
+        self, shape, axiswise_dim, round_scales_to_power_of_2
+    ):
         a = torch.randn(*shape, dtype=torch.bfloat16)
         linear_mm_config = LinearMMConfig()
         a_fp8 = hp_tensor_to_float8_dynamic(
@@ -173,6 +176,7 @@ def test_axiswise_dynamic_cast(self, shape, axiswise_dim):
             linear_mm_config,
             scaling_granularity=ScalingGranularity.AXISWISE,
             axiswise_dim=axiswise_dim,
+            round_scales_to_power_of_2=round_scales_to_power_of_2,
         )
         a_dq = a_fp8.to_original_precision()
         sqnr = compute_error(a, a_dq)
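A quick aside on why one might round scales down to powers of two (a general floating-point fact, not something this diff states): scaling by a power of two only shifts the exponent, so the quantize/dequantize round trip through the scale is exact.

    import torch

    x = torch.tensor([1.2345678], dtype=torch.float32)
    pow2_scale = 0.5
    odd_scale = 0.3

    # Dividing and re-multiplying by a power of two touches only the
    # exponent bits, so the round trip is exact.
    print(torch.equal((x / pow2_scale) * pow2_scale, x))  # True

    # With an arbitrary scale, both operations round the mantissa, and
    # the round trip is typically (not always) inexact.
    print(torch.equal((x / odd_scale) * odd_scale, x))    # typically False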

test/float8/test_compile.py

Lines changed: 14 additions & 6 deletions

@@ -45,11 +45,7 @@
     hp_tensor_to_float8_delayed,
     hp_tensor_to_float8_dynamic,
 )
-from torchao.float8.float8_tensor import (
-    GemmInputRole,
-    LinearMMConfig,
-    ScaledMMConfig,
-)
+from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig
 from torchao.float8.float8_utils import config_has_stateful_scaling
 from torchao.float8.stateful_float8_linear import StatefulFloat8Linear
 from torchao.testing.float8.test_utils import get_test_float8_linear_config
@@ -420,13 +416,23 @@ def test_sync_amax_func_cuda_graph_success():
         torch.float16,
     ],
 )
-def test_dynamic_scale_numeric_parity(dtype: torch.dtype):
+@pytest.mark.parametrize(
+    "round_scales_to_power_of_2",
+    [
+        True,
+        False,
+    ],
+)
+def test_dynamic_scale_numeric_parity(
+    dtype: torch.dtype, round_scales_to_power_of_2: bool
+):
     scaling_type_weight = ScalingType.DYNAMIC
     torch.manual_seed(42)
     hp_tensor1 = torch.randn(16, 16, device="cuda", dtype=dtype)
     hp_tensor2 = hp_tensor1.detach().clone()
     float8_config = Float8LinearConfig(
         cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+        round_scales_to_power_of_2=round_scales_to_power_of_2,
     )
     linear_mm_config = LinearMMConfig(
         # output
@@ -456,13 +462,15 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype):
         e4m3_dtype,
         linear_mm_config,
         gemm_input_role=GemmInputRole.WEIGHT,
+        round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
     )
     torch._dynamo.reset()
     float8_compile = torch.compile(hp_tensor_to_float8_dynamic)(
         hp_tensor2,
         e4m3_dtype,
         linear_mm_config,
         gemm_input_role=GemmInputRole.WEIGHT,
+        round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
     )
     assert torch.equal(float8_eager._scale, float8_compile._scale)
     assert torch.equal(float8_eager._data, float8_compile._data)
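The test above follows the usual eager-vs-compiled parity pattern. A minimal, self-contained sketch of that pattern, with a hypothetical stand-in for hp_tensor_to_float8_dynamic:

    import torch

    # Hypothetical stand-in: scale by the max-abs value and cast down.
    def scale_and_cast(x: torch.Tensor) -> torch.Tensor:
        scale = x.abs().amax().clamp(min=1e-12)
        return (x / scale).to(torch.float16)

    x = torch.randn(16, 16)
    eager_out = scale_and_cast(x)

    torch._dynamo.reset()  # clear cached compiled graphs, as the test does
    compiled_out = torch.compile(scale_and_cast)(x.detach().clone())

    # Bitwise equality, mirroring the torch.equal assertions above.
    assert torch.equal(eager_out, compiled_out)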

test/float8/test_float8_utils.py

Lines changed: 65 additions & 0 deletions

@@ -0,0 +1,65 @@
+import unittest
+
+import pytest
+import torch
+
+from torchao.float8.float8_utils import _round_scale_down_to_power_of_2
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+
+if not TORCH_VERSION_AT_LEAST_2_5:
+    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+
+
+# source for notable single-precision cases:
+# https://en.wikipedia.org/wiki/Single-precision_floating-point_format
+@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        # ("test_case_name", input, expected result)
+        ("one", 1.0, 1.0),
+        ("inf", float("inf"), float("inf")),
+        ("nan", float("nan"), float("nan")),
+        ("smallest positive subnormal number", 2**-126 * 2**-23, 2**-126 * 2**-23),
+        ("largest normal number", 2**127 * (2 - 2**-23), float("inf")),
+        ("smallest positive normal number", 2**-126, 2**-126),
+        ("largest number less than one", 1.0 - 2**-24, 0.5),
+        ("smallest number larger than one", 1.0 + 2**-23, 1.0),
+        # TODO(danielvegamyhre): debug why creating a tensor with largest
+        # subnormal value in CI env for pytorch 2.5.1 truncates the value to 0.
+        # ("largest subnormal number", [2**-126 * (1 - 2**-23), 1.1754943508222875e-38]),
+    ],
+)
+def test_round_scale_down_to_power_of_2_valid_inputs(
+    test_case: dict,
+):
+    test_case_name, input, expected_result = test_case
+    input_tensor, expected_tensor = (
+        torch.tensor(input, dtype=torch.float32).cuda(),
+        torch.tensor(expected_result, dtype=torch.float32).cuda(),
+    )
+    result = _round_scale_down_to_power_of_2(input_tensor)
+
+    assert (
+        torch.equal(result, expected_tensor)
+        or (result.isnan() and expected_tensor.isnan())
+    ), f"test: {test_case_name}, input: {input_tensor}, expected {expected_tensor}, but got {result}"
+
+
+@pytest.mark.parametrize(
+    "invalid_dtype",
+    [
+        torch.bfloat16,
+        torch.float16,
+        torch.float64,
+        torch.int8,
+        torch.uint8,
+        torch.int32,
+        torch.uint32,
+        torch.int64,
+    ],
+)
+def test_non_float32_input(invalid_dtype: torch.dtype):
+    non_float32_tensor = torch.tensor([3.0], dtype=invalid_dtype)
+    with pytest.raises(AssertionError, match="scale must be float32 tensor"):
+        _round_scale_down_to_power_of_2(non_float32_tensor)
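Two expected values above are worth a note. If the rounding is implemented along the lines of exp2(floor(log2(x))) in float32 (an assumption; this diff does not show the implementation of _round_scale_down_to_power_of_2), the edge cases fall out of float32 rounding itself:

    import torch

    # Assumed reference behavior, not torchao's code.
    # Largest normal float32: log2(x) ~= 127.99999993, which is not
    # representable in float32 and rounds to exactly 128.0, so exp2
    # overflows to inf -- matching the expected value in the test.
    x = torch.tensor(2**127 * (2 - 2**-23), dtype=torch.float32)
    print(torch.exp2(torch.floor(torch.log2(x))))  # tensor(inf)

    # Largest float32 below one: log2(y) is a tiny negative number, so
    # floor gives -1.0 and the result is 0.5 -- again matching the test.
    y = torch.tensor(1.0 - 2**-24, dtype=torch.float32)
    print(torch.exp2(torch.floor(torch.log2(y))))  # tensor(0.5000)
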
Lines changed: 74 additions & 0 deletions

@@ -0,0 +1,74 @@
+import pytest
+import torch
+
+from torchao.float8.float8_utils import compute_error
+from torchao.ops import mx_fp4_bf16, mx_fp8_bf16
+from torchao.prototype.mx_formats.mx_tensor import DTYPE_FP4, MXTensor
+from torchao.prototype.mx_formats.utils import to_blocked
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_100
+
+if not TORCH_VERSION_AT_LEAST_2_4:
+    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+
+
+def run_matrix_test(M: int, K: int, N: int, format) -> float:
+    dtype = torch.bfloat16
+    device = torch.device("cuda")
+
+    a = torch.rand((M, K), dtype=dtype, device=device)
+    b = torch.rand((N, K), dtype=dtype, device=device)
+
+    fmt = torch.float8_e4m3fn if format == "fp8" else DTYPE_FP4
+    mx_func = mx_fp8_bf16 if format == "fp8" else mx_fp4_bf16
+
+    a_mx = MXTensor.to_mx(a, fmt, 32)
+    b_mx = MXTensor.to_mx(b, fmt, 32)
+
+    a_data = a_mx._data
+    b_data = b_mx._data
+    assert b_data.is_contiguous()
+    b_data = b_data.transpose(-1, -2)
+
+    a_scale = a_mx._scale_e8m0.view(M, K // 32)
+    b_scale = b_mx._scale_e8m0.view(N, K // 32)
+
+    a_scale_block = to_blocked(a_scale)
+    b_scale_block = to_blocked(b_scale)
+
+    out_hp = a_mx.to_dtype(torch.bfloat16) @ b_mx.to_dtype(torch.bfloat16).transpose(
+        -1, -2
+    )
+    out = mx_func(a_data, b_data, a_scale_block, b_scale_block)
+
+    return compute_error(out_hp, out).item()
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for mxfloat8"
+)
+@pytest.mark.parametrize(
+    "size",
+    [
+        (128, 128, 128),
+        (256, 256, 256),
+        (384, 384, 384),  # Small
+        (512, 512, 512),
+        (768, 768, 768),  # Medium
+        (1024, 1024, 1024),
+        (8192, 8192, 8192),  # Large
+        (128, 256, 384),
+        (256, 384, 512),  # Non-square
+        (129, 256, 384),
+        (133, 512, 528),  # Non-aligned
+    ],
+    ids=lambda x: f"{x[0]}x{x[1]}x{x[2]}",
+)
+@pytest.mark.parametrize("format", ["fp8", "fp4"])
+def test_matrix_multiplication(size, format):
+    M, K, N = size
+    sqnr = run_matrix_test(M, K, N, format)
+    threshold = 80.0
+    assert (
+        sqnr >= threshold
+    ), f"{format} SQNR {sqnr} below threshold for dims {M}x{K}x{N}"

third_party/cutlass

Submodule cutlass updated 361 files
