Skip to content

Commit c406df2

Browse files
avizon-awsvkuzodanielvegamyhreXia-Weiwenjiayisunx
authored
Add support for MXFP8 All gather (#3435)
* add MXFP8 all gather support * added TODO for future feature * remove emoji from comment * fixed ruff formating * fixed ruff formatting * add mxfp8 and nvfp4 to Llama eval scripts (#3394) Update [ghstack-poisoned] * flip mx inference scaling setting to RCEIL (#3428) * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * add CLAUDE.local.md to gitignore (#3437) Summary: taking claude code for a more thorough spin, will start with local instructions and will see what makes sense to upstream Test Plan: Reviewers: Subscribers: Tasks: Tags: * bump python version in tutorial ci workflow (#3439) * [CPU] Reland qconv fp8 fusion passes (#3433) * [Reland][PT2E][X86] Add Inductor fusion passes of float8 qconv for X86Inductor backend * add torch version check for Qconv FP8 UTs * fix format issue * Skip tests for ROCm --------- Co-authored-by: Sun, Jiayi <jiayi.sun@intel.com> * Int8Tensor migration cleanup (#3407) * Int8Tensor migration Summary: This PR creates a new Int8Tensor and updates the configs to use the new Int8Tensor flow Test Plan: To ensure BC: ``` pytest test/quantization/test_quant_api.py ``` To test new Int8Tensor: ``` pytest test/quantization/quantize_/workflows/int8/test_int8_tensor.py ``` Reviewers: Subscribers: Tasks: Tags: * ruff fixes * add init * fix ruff again * update * wip * undo update tests * fix ruff * fix varname * fix typing * add tests * fix dtype * fix ci * address granularity cr * update _choose_quant_func_and_quantize_tensor * make block size required attribute * made dtype required as well * address nits * skip per tensor weight only test for now * [xpu][test] Port 2 test/dtypes_{floatx, bitpacking} UT files to intel XPU (#3368) * enable test/dtypes/test_bitpacking.py on intel xpu * enable test/dtypes/test_floatx.py * enable test/dtypes/test_floatx.py * fix format issue * fix format issue * update _DEVICES * [xpu][test] Port 2 test/quantization/pt2e/test_{quantize_pt2e, quantize_pt2e_qat} UT files to intel XPU (#3405) * add test/quantization/pt2e/test_quantize_pt2e.py * add test/quantization/pt2e/test_quantize_pt2e.py * test/quantization/pt2e/test_quantize_pt2e_qat.py * test/quantization/pt2e/test_quantize_pt2e_qat.py * fix format issue * update format * increase timeout for xpu * [Intel GPU] Enable optim SR test (#3055) * updated test with rebase changes * added checks to run only on CUDA with compatibility >=9 * updated test for H100 * added test to workflow --------- Co-authored-by: Vasiliy Kuznetsov <vkuzo@users.noreply.github.com> Co-authored-by: Daniel Vega-Myhre <danvm@meta.com> Co-authored-by: Xia Weiwen <weiwen.xia@intel.com> Co-authored-by: Sun, Jiayi <jiayi.sun@intel.com> Co-authored-by: Jesse Cai <jessecai@meta.com> Co-authored-by: xiangdong <40376367+zxd1997066@users.noreply.github.com> Co-authored-by: Artur Lesniak <artur.lesniak@intel.com>
1 parent 2ae2994 commit c406df2

File tree

4 files changed

+201
-0
lines changed

4 files changed

+201
-0
lines changed

.github/workflows/4xH100_tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,4 @@ jobs:
4747
pip install . --no-build-isolation
4848
./test/float8/test_everything_multi_gpu.sh
4949
./test/prototype/mx_formats/test_mx_dtensor.sh
50+
./test/prototype/mx_formats/test_mxfp8_allgather.sh
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import pytest
2+
import torch
3+
import torch.distributed as dist
4+
5+
from torchao.prototype.mx_formats.mx_tensor import MXTensor
6+
from torchao.utils import is_sm_at_least_90, torch_version_at_least
7+
8+
if not torch_version_at_least("2.7.0"):
9+
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
10+
11+
12+
def setup_distributed():
13+
dist.init_process_group("nccl")
14+
# seed must be the same in all processes
15+
torch.manual_seed(42)
16+
local_rank = torch.distributed.get_rank()
17+
torch.cuda.set_device(local_rank)
18+
return local_rank
19+
20+
21+
def _test_allgather(local_rank):
22+
golden_qdata = (
23+
torch.randint(0, 256, (256, 512), dtype=torch.uint8)
24+
.to(torch.float8_e5m2)
25+
.to(local_rank)
26+
)
27+
28+
# Random scale factors (typically float32 or uint8 for e8m0)
29+
golden_scale = (
30+
torch.randint(0, 256, (256, 16), dtype=torch.uint8)
31+
.view(torch.float8_e8m0fnu)
32+
.to(local_rank)
33+
)
34+
35+
# Create golden MXTensor
36+
golden_mx = MXTensor(
37+
golden_qdata,
38+
golden_scale,
39+
elem_dtype=torch.float8_e5m2,
40+
block_size=32,
41+
orig_dtype=torch.float32,
42+
kernel_preference=None,
43+
act_quant_kwargs=None,
44+
is_swizzled_scales=None,
45+
)
46+
47+
local_rank = torch.distributed.get_rank()
48+
world_size = torch.distributed.get_world_size()
49+
50+
# Each rank gets its shard (split along dim 0)
51+
shard_size = golden_qdata.shape[0] // world_size # 2 rows per rank
52+
start_idx = local_rank * shard_size
53+
end_idx = (local_rank + 1) * shard_size
54+
55+
# Create local MXTensor from shard
56+
local_mx = MXTensor(
57+
golden_qdata[start_idx:end_idx].clone().to(local_rank),
58+
golden_scale[start_idx:end_idx].clone().to(local_rank),
59+
elem_dtype=torch.float8_e5m2,
60+
block_size=32,
61+
orig_dtype=torch.float32,
62+
kernel_preference=None,
63+
act_quant_kwargs=None,
64+
is_swizzled_scales=None,
65+
)
66+
67+
# Perform all_gather
68+
gathered_mx = torch.ops._c10d_functional.all_gather_into_tensor.default(
69+
local_mx,
70+
world_size,
71+
"0",
72+
)
73+
gathered_mx = torch.ops._c10d_functional.wait_tensor.default(gathered_mx)
74+
75+
# Verify type
76+
assert isinstance(gathered_mx, MXTensor), (
77+
f"Expected MXTensor, got {type(gathered_mx)}"
78+
)
79+
80+
# Verify shape
81+
assert gathered_mx.shape == golden_mx.shape, (
82+
f"Shape mismatch: {gathered_mx.shape} vs {golden_mx.shape}"
83+
)
84+
85+
# Verify qdata matches golden exactly
86+
if not torch.equal(gathered_mx.qdata, golden_qdata):
87+
assert False, "qdata mismatch"
88+
89+
# Verify scale matches golden exactly
90+
if not torch.equal(
91+
gathered_mx.scale.view(torch.uint8),
92+
golden_scale.view(torch.uint8),
93+
):
94+
assert False, "scale mismatch"
95+
96+
assert gathered_mx.block_size == 32
97+
98+
99+
if __name__ == "__main__":
100+
local_rank = setup_distributed()
101+
102+
assert is_sm_at_least_90() == True, "SM must be > 9.0"
103+
104+
try:
105+
_test_allgather(local_rank)
106+
except Exception as e:
107+
raise e
108+
109+
torch.distributed.destroy_process_group()
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#!/bin/bash
2+
3+
# terminate script on first error
4+
set -e
5+
6+
if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False"; then
7+
echo "Skipping test_dtensor.sh because no CUDA devices are available."
8+
exit
9+
fi
10+
11+
# integration tests for TP/SP
12+
NCCL_DEBUG=WARN torchrun --nproc_per_node 2 test/prototype/mx_formats/test_mxfp8_allgather.py

torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -842,3 +842,82 @@ def mx_select(func, types, args, kwargs):
842842
old_mx_tensor._is_swizzled_scales,
843843
)
844844
return return_and_correct_aliasing(func, args, kwargs, new_mx_tensor)
845+
846+
847+
@implements([torch.ops._c10d_functional.all_gather_into_tensor.default])
848+
def mx_all_gather(func, types, args, kwargs):
849+
"""
850+
All-gather for MXTensor
851+
852+
Args:
853+
func: The operation (all_gather_into_tensor)
854+
types: Tensor types involved
855+
args: (mx_tensor, group_tag, ...)
856+
kwargs: Additional arguments
857+
"""
858+
mx_tensor = args[0]
859+
group_tag = args[1] if len(args) > 1 else "default"
860+
861+
# TODO: Add support for concat CC as a future optimization
862+
863+
# Gather both data and scale
864+
gathered_qdata = torch.ops._c10d_functional.all_gather_into_tensor.default(
865+
mx_tensor.qdata, # The quantized data
866+
group_tag,
867+
*args[2:],
868+
**kwargs,
869+
)
870+
871+
gathered_scale = torch.ops._c10d_functional.all_gather_into_tensor.default(
872+
mx_tensor.scale.view(
873+
torch.uint8
874+
), # The scale factors, Need to cast to uint8 as float8_e8m0fnu is not support for all gather.
875+
group_tag,
876+
*args[2:],
877+
**kwargs,
878+
)
879+
880+
gathered_scale = gathered_scale.view(torch.float8_e8m0fnu)
881+
882+
# Return new MXTensor with gathered data
883+
return MXTensor(
884+
gathered_qdata,
885+
gathered_scale,
886+
mx_tensor._elem_dtype,
887+
mx_tensor.block_size,
888+
mx_tensor._orig_dtype,
889+
mx_tensor.kernel_preference,
890+
mx_tensor.act_quant_kwargs,
891+
mx_tensor._is_swizzled_scales,
892+
)
893+
894+
895+
@implements([torch.ops._c10d_functional.wait_tensor.default])
896+
def mx_wait_tensor(func, types, args, kwargs):
897+
"""
898+
Wait for async collective to complete on MXTensor
899+
900+
This is called after collectives like all_gather to ensure
901+
the operation has completed before using the tensor.
902+
"""
903+
mx_tensor = args[0]
904+
905+
# Wait on both components
906+
waited_qdata = torch.ops._c10d_functional.wait_tensor.default(
907+
mx_tensor.qdata, *args[1:], **kwargs
908+
)
909+
910+
waited_scale = torch.ops._c10d_functional.wait_tensor.default(
911+
mx_tensor.scale, *args[1:], **kwargs
912+
)
913+
914+
return MXTensor(
915+
waited_qdata,
916+
waited_scale,
917+
mx_tensor._elem_dtype,
918+
mx_tensor.block_size,
919+
mx_tensor._orig_dtype,
920+
mx_tensor.kernel_preference,
921+
mx_tensor.act_quant_kwargs,
922+
mx_tensor._is_swizzled_scales,
923+
)

0 commit comments

Comments
 (0)