Skip to content

Commit f77675c

Browse files
committed
mxtensor: delete to_copy override
Summary: We get this for free from `TorchAOBaseTensor`, and that implementation handles more args, such as args we need for HF serialization to work. Test Plan: CI, and also unblocks saving an HF model with MXFP4 weights to disk. Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 580c44d ghstack-comment-id: 3336399088 Pull Request resolved: #3072
1 parent 4ab5f08 commit f77675c

File tree

1 file changed

+0
-43
lines changed

1 file changed

+0
-43
lines changed

torchao/prototype/mx_formats/mx_ops.py

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -311,49 +311,6 @@ def mx_copy_(func, types, args, kwargs):
311311
)
312312

313313

314-
@implements([aten._to_copy.default])
315-
def autocast_to_copy(func, types, args, kwargs):
316-
"""Autocast + device movement"""
317-
assert isinstance(args[0], MXTensor)
318-
319-
# Handle dtype parameter
320-
dtype = kwargs.pop("dtype", None)
321-
if dtype is not None:
322-
assert dtype in {
323-
torch.float16,
324-
torch.bfloat16,
325-
}, "Only support floating point conversion for autocast w/ MXTensor"
326-
327-
# Handle device parameter
328-
device = kwargs.pop("device", None)
329-
if device is not None:
330-
# Apply device change using _apply_fn_to_data
331-
tensor = args[0]._apply_fn_to_data(lambda x: func(x, device=device))
332-
tensor = return_and_correct_aliasing(func, args, {}, tensor)
333-
else:
334-
tensor = args[0]
335-
336-
# Verify no other kwargs remain
337-
assert len(kwargs) == 0, "Only support dtype and device kwargs for autocast"
338-
339-
# If dtype is specified, create a new MXTensor with the requested dtype
340-
if dtype is not None:
341-
res = MXTensor(
342-
tensor.qdata,
343-
tensor._scale_e8m0,
344-
tensor._elem_dtype,
345-
tensor._block_size,
346-
dtype,
347-
tensor._gemm_kernel_choice,
348-
tensor._pack_fp6,
349-
tensor.act_quant_kwargs,
350-
)
351-
return res
352-
353-
# If only device was changed, return the device-changed tensor
354-
return tensor
355-
356-
357314
@implements([aten.clone.default])
358315
def mx_clone(func, types, args, kwargs):
359316
self = args[0]

0 commit comments

Comments
 (0)