From da0315fd839dcb32434d51441ac1deb6bac605a6 Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Mon, 19 Feb 2024 14:07:01 +0100 Subject: [PATCH] fix(qtensor): prevent shallow copies By overriding a native tensor method, we prevent torch from blindly assigning the content of a QTensor to another using Tensor.data. --- quanto/tensor/func.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/quanto/tensor/func.py b/quanto/tensor/func.py index 49461c79..294defbf 100644 --- a/quanto/tensor/func.py +++ b/quanto/tensor/func.py @@ -33,6 +33,12 @@ def get_qtensor_func(func): return _QTENSOR_FUNC_TABLE.get(func, None) +@register_qtensor_func([torch._has_compatible_shallow_copy_type]) +def has_compatible_shallow_copy_type(func, input: torch.Tensor, from_: torch.Tensor): + # Prevent torch from trying to shallow copy one QTensor to another + return False + + @register_qtensor_func( [ torch.nn.functional.cross_entropy,