@@ -7,7 +7,7 @@
 import math
 from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Dict, Optional
+from typing import Optional
 
 import torch
 from torch.utils._python_dispatch import return_and_correct_aliasing
@@ -39,8 +39,6 @@
 
 aten = torch.ops.aten
 
-NVFP4_OPS_TABLE: Dict[Any, Any] = {}
-
 
 class NVFP4MMConfig(Enum):
     DYNAMIC = "dynamic"
@@ -55,18 +53,6 @@ class QuantizeTensorToNVFP4Kwargs(QuantizeTensorKwargs):
     use_dynamic_per_tensor_scale: bool = False
 
 
-# TODO(future PR): move over to TorchAOBaseTensor's dispatch
-def implements(aten_ops):
-    """Register aten ops to the NVFP4 op table"""
-
-    def decorator(func):
-        for op in aten_ops:
-            NVFP4_OPS_TABLE[op] = func
-        return func
-
-    return decorator
-
-
 class NVFP4Tensor(TorchAOBaseTensor):
     """NVIDIA FP4 (NVFP4) Tensor subclass.
 
@@ -141,14 +127,6 @@ def __repr__(self):
     def _quantization_type(self):
         return f"{self._is_swizzled_scales=}, {self.use_triton_kernel=}, {self.act_quant_kwargs=}"
 
-    @classmethod
-    def __torch_dispatch__(cls, func, types, args, kwargs=None):
-        # Use NVFP4-specific ops table
-        if func in NVFP4_OPS_TABLE:
-            return NVFP4_OPS_TABLE[func](func, types, args, kwargs)
-
-        raise NotImplementedError(f"{func} not implemented for NVFP4Tensor")
-
     @staticmethod
     def to_nvfp4(
         data_hp: torch.Tensor,
@@ -308,13 +286,10 @@ def _same_metadata(cls, self: "NVFP4Tensor", src: "NVFP4Tensor") -> bool:
         )
 
 
-@implements([aten.detach.default, aten.alias.default])
-def nvfp4_detach_alias(func, types, args, kwargs):
-    return return_and_correct_aliasing(
-        func, args, kwargs, args[0]._apply_fn_to_data(func)
-    )
+implements = NVFP4Tensor.implements
 
 
+# TODO(future PR): move this to AOBaseTensor (will require debugging/fixing CI)
 @implements([aten._to_copy.default])
 def nvfp4_to_copy(func, types, args, kwargs):
     """Autocast + device movement"""
@@ -354,33 +329,6 @@ def nvfp4_to_copy(func, types, args, kwargs):
     return tensor
 
 
-@implements([aten.copy_.default])
-def nvfp4_copy_(func, types, args, kwargs):
-    self = args[0]
-    src = args[1]
-    if NVFP4Tensor._same_metadata(self, src):
-        self_tensors = self.__tensor_flatten__()[0]
-        for tensor_name in self_tensors:
-            getattr(self, tensor_name).copy_(getattr(src, tensor_name))
-        return self
-    raise ValueError(
-        f"Not supported args for copy_ due to metadata mismatch: {self}, {src}"
-    )
-
-
-@implements([aten.clone.default])
-def nvfp4_clone(func, types, args, kwargs):
-    self = args[0]
-    memory_format = kwargs.get("memory_format", None)
-
-    if memory_format is not None:
-        clone_fn = lambda x: x.clone(memory_format=memory_format)
-    else:
-        clone_fn = lambda x: x.clone()
-
-    return self._apply_fn_to_data(clone_fn)
-
-
 @implements([aten.slice.Tensor])
 def nvfp4_slice(func, types, args, kwargs):
     x, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
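For context on the change above: the diff deletes NVFP4Tensor's private `NVFP4_OPS_TABLE`/`__torch_dispatch__` pair and reuses the dispatch machinery inherited from `TorchAOBaseTensor`, with the module-level alias `implements = NVFP4Tensor.implements` preserving the decorator syntax the deleted helper provided. Below is a minimal, self-contained sketch of that per-class op-table pattern; `OpTableTensor` and `_OPS_TABLE` are illustrative names, not torchao's actual internals, and the real base class presumably also ships default handlers for ops like `detach`, `alias`, `clone`, and `copy_`, which is why those registrations could be removed here without replacement.

```python
# Sketch only: illustrates the dispatch-table pattern this diff switches to.
# OpTableTensor and _OPS_TABLE are hypothetical names, not torchao's internals.
import torch

aten = torch.ops.aten


class OpTableTensor(torch.Tensor):
    # Per-class registry mapping an aten overload to its handler.
    _OPS_TABLE: dict = {}

    @classmethod
    def implements(cls, aten_ops):
        """Register the decorated function as the handler for each given op."""

        def decorator(func):
            for op in aten_ops:
                cls._OPS_TABLE[op] = func
            return func

        return decorator

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        # Route every aten call on the subclass through the op table.
        if func in cls._OPS_TABLE:
            return cls._OPS_TABLE[func](func, types, args, kwargs or {})
        raise NotImplementedError(f"{func} not implemented for {cls.__name__}")


# A subclass module can then alias the decorator, mirroring the diff's
# `implements = NVFP4Tensor.implements`, and register handlers with it:
implements = OpTableTensor.implements


@implements([aten.detach.default])
def _detach(func, types, args, kwargs):
    # handler body elided; real handlers return a new subclass instance
    ...
```

The design benefit the diff is after: each tensor subclass no longer maintains its own global table and `__torch_dispatch__` boilerplate, and common handlers live in one place on the base class.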