Commit ca37677

6/x mx cleanup: make NVFP4Tensor use base implements

Summary: Refactors NVFP4Tensor to use the `implements` function from `AOBaseTensor`.

Test Plan:

```
pytest test/prototype/mx_formats/ -s -x
```

Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: 8370a1e
ghstack-comment-id: 3572568961
Pull-Request: #3387

1 parent e1734bb commit ca37677
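For context, the change replaces NVFP4Tensor's module-level `NVFP4_OPS_TABLE` and hand-rolled `__torch_dispatch__` (both deleted in the diff below) with the dispatch machinery inherited from `TorchAOBaseTensor`. A minimal sketch of how this decorator-plus-op-table pattern works, using a hypothetical `MyBaseTensor`/`MyFP4Tensor` rather than the real torchao internals:

```python
import torch

aten = torch.ops.aten


class MyBaseTensor(torch.Tensor):
    """Hypothetical stand-in for the shared base class; illustrative only."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        cls._OPS_TABLE = {}  # each subclass gets its own op table

    @classmethod
    def implements(cls, aten_ops):
        """Decorator: register a handler for one or more aten overloads."""

        def decorator(func):
            for op in aten_ops:
                cls._OPS_TABLE[op] = func
            return func

        return decorator

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        # Route every aten call on instances of the subclass to its op table.
        if func in cls._OPS_TABLE:
            return cls._OPS_TABLE[func](func, types, args, kwargs or {})
        raise NotImplementedError(f"{func} not implemented for {cls.__name__}")


class MyFP4Tensor(MyBaseTensor):
    pass


implements = MyFP4Tensor.implements


@implements([aten.detach.default])
def _detach(func, types, args, kwargs):
    # a real handler would unwrap, apply func to the inner data, and rewrap
    ...
```

With the table and the dispatch hook living on the base class, a subclass module only needs `implements = NVFP4Tensor.implements` plus per-op handlers, which is exactly the shape the diff below moves to.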

1 file changed: +3 −55 lines changed

torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 3 additions & 55 deletions
```diff
@@ -7,7 +7,7 @@
 import math
 from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Dict, Optional
+from typing import Optional
 
 import torch
 from torch.utils._python_dispatch import return_and_correct_aliasing
@@ -39,8 +39,6 @@
 
 aten = torch.ops.aten
 
-NVFP4_OPS_TABLE: Dict[Any, Any] = {}
-
 
 class NVFP4MMConfig(Enum):
     DYNAMIC = "dynamic"
@@ -55,18 +53,6 @@ class QuantizeTensorToNVFP4Kwargs(QuantizeTensorKwargs):
     use_dynamic_per_tensor_scale: bool = False
 
 
-# TODO(future PR): move over to TorchAOBaseTensor's dispatch
-def implements(aten_ops):
-    """Register aten ops to the NVFP4 op table"""
-
-    def decorator(func):
-        for op in aten_ops:
-            NVFP4_OPS_TABLE[op] = func
-        return func
-
-    return decorator
-
-
 class NVFP4Tensor(TorchAOBaseTensor):
     """NVIDIA FP4 (NVFP4) Tensor subclass.
@@ -141,14 +127,6 @@ def __repr__(self):
     def _quantization_type(self):
         return f"{self._is_swizzled_scales=}, {self.use_triton_kernel=}, {self.act_quant_kwargs=}"
 
-    @classmethod
-    def __torch_dispatch__(cls, func, types, args, kwargs=None):
-        # Use NVFP4-specific ops table
-        if func in NVFP4_OPS_TABLE:
-            return NVFP4_OPS_TABLE[func](func, types, args, kwargs)
-
-        raise NotImplementedError(f"{func} not implemented for NVFP4Tensor")
-
     @staticmethod
     def to_nvfp4(
         data_hp: torch.Tensor,
@@ -308,13 +286,10 @@ def _same_metadata(cls, self: "NVFP4Tensor", src: "NVFP4Tensor") -> bool:
     )
 
 
-@implements([aten.detach.default, aten.alias.default])
-def nvfp4_detach_alias(func, types, args, kwargs):
-    return return_and_correct_aliasing(
-        func, args, kwargs, args[0]._apply_fn_to_data(func)
-    )
+implements = NVFP4Tensor.implements
 
 
+# TODO(future PR): move this to AOBaseTensor (will require debugging/fixing CI)
 @implements([aten._to_copy.default])
 def nvfp4_to_copy(func, types, args, kwargs):
     """Autocast + device movement"""
@@ -354,33 +329,6 @@ def nvfp4_to_copy(func, types, args, kwargs):
     return tensor
 
 
-@implements([aten.copy_.default])
-def nvfp4_copy_(func, types, args, kwargs):
-    self = args[0]
-    src = args[1]
-    if NVFP4Tensor._same_metadata(self, src):
-        self_tensors = self.__tensor_flatten__()[0]
-        for tensor_name in self_tensors:
-            getattr(self, tensor_name).copy_(getattr(src, tensor_name))
-        return self
-    raise ValueError(
-        f"Not supported args for copy_ due to metadata mismatch: {self}, {src}"
-    )
-
-
-@implements([aten.clone.default])
-def nvfp4_clone(func, types, args, kwargs):
-    self = args[0]
-    memory_format = kwargs.get("memory_format", None)
-
-    if memory_format is not None:
-        clone_fn = lambda x: x.clone(memory_format=memory_format)
-    else:
-        clone_fn = lambda x: x.clone()
-
-    return self._apply_fn_to_data(clone_fn)
-
-
 @implements([aten.slice.Tensor])
 def nvfp4_slice(func, types, args, kwargs):
     x, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
```

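Note the net effect: the NVFP4-specific handlers for `detach`/`alias`, `copy_`, and `clone` are deleted with no local replacement, which suggests `TorchAOBaseTensor`'s shared dispatch now covers those ops. A hedged smoke test of that assumption (it presumes `to_nvfp4`'s remaining parameters have usable defaults and that the listed ops really are handled by the base class):

```python
import torch

from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor

# quantize a high-precision tensor; to_nvfp4 is the constructor shown in the diff
x_hp = torch.randn(128, 64, dtype=torch.bfloat16)
x = NVFP4Tensor.to_nvfp4(x_hp)

# these ops previously had NVFP4-specific handlers; after this commit they
# should be served by the base class's default implementations
y = x.detach()
z = x.clone()
z.copy_(x)
```

The commit's own test plan (`pytest test/prototype/mx_formats/ -s -x`) remains the authoritative check.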