diff --git a/torchtune/modules/low_precision/_register_nf4_dispatch_ops.py b/torchtune/modules/low_precision/_register_nf4_dispatch_ops.py index b057a1fbc3..037891ac36 100644 --- a/torchtune/modules/low_precision/_register_nf4_dispatch_ops.py +++ b/torchtune/modules/low_precision/_register_nf4_dispatch_ops.py @@ -6,7 +6,6 @@ import torch from torchao.dtypes.nf4tensor import implements as nf4_tensor_impl, to_nf4 -from torchtune.modules.low_precision._utils import _get_torchao_version @nf4_tensor_impl([torch.ops.aten.clone.default]) @@ -18,37 +17,3 @@ def clone(func, *args, **kwargs): in precision. """ return to_nf4(args[0][0].get_original_weight()) - - -should_define_inplace_copy = True -ao_version, is_nightly = _get_torchao_version() -if ao_version: - if (is_nightly and ao_version >= "2024.5.20") or ( - not is_nightly and ao_version >= "0.2.0" - ): - should_define_inplace_copy = False - -if should_define_inplace_copy: - # TorchAO have `NF4.copy_` starting from `0.2.0` - # it's a superset of `inplace_copy` since it covers `NF4.copy_(NF4)` - @nf4_tensor_impl([torch.ops.aten.copy_.default]) - def inplace_copy(func, *args, **kwargs): - """ - Performs an inplace copy of an incoming tensor into the tensor - being copied into. The inplace tensor is given by args[0][1] and the - tensor being copied into is given by args[0][0]. The copy is performed - by copying over all attributes. This method would have to be updated - if additional attributes are added to NF4Tensor. - """ - dest_tensor = args[0][0] # tensor we are inplace copying into - ref_tensor = to_nf4( - args[0][1].to(dest_tensor.device) - ) # TODO check if nf4 tensor takes in device arg - dest_tensor.block_size = ref_tensor.block_size - dest_tensor.n_blocks = ref_tensor.n_blocks - dest_tensor.scaler_block_size = ref_tensor.scaler_block_size - dest_tensor.quantized_scalers = ref_tensor.quantized_scalers - dest_tensor.quantization_factor = ref_tensor.quantization_factor - dest_tensor.scaler_mean = ref_tensor.scaler_mean - dest_tensor.quantized_data = ref_tensor.quantized_data - dest_tensor.nf4 = ref_tensor.nf4