Skip to content

Commit

Permalink
delete inplace copy override
Browse files Browse the repository at this point in the history
  • Loading branch information
ebsmothers committed Aug 23, 2024
1 parent ce86906 commit 9ebfb42
Showing 1 changed file with 0 additions and 25 deletions.
25 changes: 0 additions & 25 deletions torchtune/modules/low_precision/_register_nf4_dispatch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,3 @@ def clone(func, *args, **kwargs):
in precision.
"""
return to_nf4(args[0][0].get_original_weight())


# TorchAO have `NF4.copy_` starting from `0.2.0`
# it's a superset of `inplace_copy` since it covers `NF4.copy_(NF4)`
@nf4_tensor_impl([torch.ops.aten.copy_.default])
def inplace_copy(func, *args, **kwargs):
"""
Performs an inplace copy of an incoming tensor into the tensor
being copied into. The inplace tensor is given by args[0][1] and the
tensor being copied into is given by args[0][0]. The copy is performed
by copying over all attributes. This method would have to be updated
if additional attributes are added to NF4Tensor.
"""
dest_tensor = args[0][0] # tensor we are inplace copying into
ref_tensor = to_nf4(
args[0][1].to(dest_tensor.device)
) # TODO check if nf4 tensor takes in device arg
dest_tensor.block_size = ref_tensor.block_size
dest_tensor.n_blocks = ref_tensor.n_blocks
dest_tensor.scaler_block_size = ref_tensor.scaler_block_size
dest_tensor.quantized_scalers = ref_tensor.quantized_scalers
dest_tensor.quantization_factor = ref_tensor.quantization_factor
dest_tensor.scaler_mean = ref_tensor.scaler_mean
dest_tensor.quantized_data = ref_tensor.quantized_data
dest_tensor.nf4 = ref_tensor.nf4

0 comments on commit 9ebfb42

Please sign in to comment.