Skip to content

Commit a7859fd

Browse files
committed
mxtensor: support clone
Summary: This is needed for saving HF models with mxfp4 weights to disk. Test Plan: ```bash pytest test/prototype/mx_formats/test_mx_tensor.py -s -k test_clone ``` Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 300a78b ghstack-comment-id: 3336365476 Pull Request resolved: #3070
1 parent 8f3f438 commit a7859fd

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,20 @@ def test_view(elem_dtype):
455455
x_mx_2 = x_mx.view(2, 4) # noqa: F841
456456

457457

458+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
459+
def test_clone():
460+
data = torch.randn(8, 8, device="cuda", dtype=torch.bfloat16)
461+
block_size = 4
462+
data_mx = MXTensor.to_mx(data, torch.float8_e4m3fn, block_size)
463+
data_mx_c = data_mx.clone()
464+
torch.testing.assert_close(
465+
data_mx.to_dtype(torch.bfloat16),
466+
data_mx_c.to_dtype(torch.bfloat16),
467+
atol=0,
468+
rtol=0,
469+
)
470+
471+
458472
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
459473
@pytest.mark.parametrize("elem_dtype", [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2])
460474
@pytest.mark.parametrize("pack_fp6", [False, True])

torchao/prototype/mx_formats/mx_ops.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,3 +352,16 @@ def autocast_to_copy(func, types, args, kwargs):
352352

353353
# If only device was changed, return the device-changed tensor
354354
return tensor
355+
356+
357+
@implements([aten.clone.default])
358+
def mx_clone(func, types, args, kwargs):
359+
self = args[0]
360+
memory_format = kwargs.get("memory_format", None)
361+
362+
if memory_format is not None:
363+
clone_fn = lambda x: x.clone(memory_format=memory_format)
364+
else:
365+
clone_fn = lambda x: x.clone()
366+
367+
return self._apply_fn_to_data(clone_fn)

0 commit comments

Comments
 (0)