
overwrite to() for QTensor and QBitsTensor #88

Closed
wants to merge 4 commits

Conversation

SunMarc (Member) commented Feb 16, 2024

What does this PR do?

This PR fixes the to() method so that it is applied to the inner tensors _data, _scale and _zeropoint when using QTensor or QBitsTensor. Before this PR, calling QLinear().to(device) would not change the device of these tensors.

I ran the following tests (which were not passing before): python -m pytest -sv test/nn/test_qlinear.py::test_move_qlinear

I need to check why these tests are no longer passing with this PR:

FAILED test/tensor/ops/test_quantized_dispatch.py::test_to_device[cuda] - AssertionError: assert 'cpu' == 'cuda'
FAILED test/tensor/ops/test_quantized_dispatch.py::test_softmax[cuda-5-5-1] - AssertionError
FAILED test/tensor/ops/test_quantized_dispatch.py::test_softmax[cuda-5-5-10] - AssertionError
FAILED test/tensor/ops/test_quantized_dispatch.py::test_softmax[cuda-32-32-1] - AssertionError
FAILED test/tensor/ops/test_quantized_dispatch.py::test_softmax[cuda-32-32-10] - AssertionError
FAILED test/tensor/ops/test_quantized_dispatch.py::test_softmax[cuda-10-32-1] - AssertionError
FAILED test/tensor/ops/test_quantized_dispatch.py::test_softmax[cuda-10-32-10] - AssertionError

cc @dacorvo

Comment on lines +11 to +16
def test_qtensor_move(device):
input_shape = (2, 4, 8)
qa = random_qtensor(input_shape, dtype=torch.float32)
qa = qa.to(device)
assert qa._data.device.type == device.type
assert qa._scale.device.type == device.type
SunMarc (Member, Author) commented Feb 16, 2024

The test above passes even before this PR; that is why you never saw this specific issue.

def to(self, *args, **kwargs):
self._data = self._data.to(*args, **kwargs)
self._scale = self._scale.to(*args, **kwargs)
return self
SunMarc (Member, Author) commented:
I wanted to return super().to(*args, **kwargs), but it was causing weird behavior in tests using QBitsTensor, and __torch_function__ was being called afterwards. To reproduce, return super().to(*args, **kwargs) and run python -m pytest -sv test/nn/test_qlinear.py::test_move_qlinear

dacorvo (Collaborator) commented Feb 19, 2024
Tensor subclasses are very special beasts: you should not override the base Tensor methods that way; instead, do it through the dispatch.

dacorvo (Collaborator) commented Feb 19, 2024

I am not sure there is actually an issue with QTensor, as .to() is already correctly dispatched here:

https://github.com/huggingface/quanto/blob/6302171b7569a3fd86f31e1731d01d390d1eb557/quanto/tensor/ops.py#L74

It won't work for QBitsTensor, which has an extra zeropoint inner tensor, but that should be fixed at the dispatch stage, not directly in the class declaration, IMHO.
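The idea behind "fix it at the dispatch stage" can be sketched in plain Python (this is an illustrative stand-in, not the quanto codebase: the names InnerTensor, QBits, _DISPATCH and dispatch are all hypothetical). Instead of overriding to() on the subclass, a per-op handler rebuilds a new quantized tensor with every inner tensor moved, including the extra zeropoint that QBitsTensor carries:

```python
class InnerTensor:
    """Stand-in for a torch.Tensor that records its device."""
    def __init__(self, device="cpu"):
        self.device = device

    def to(self, device):
        return InnerTensor(device)


class QBits:
    """Stand-in for QBitsTensor: three inner tensors, no to() override."""
    def __init__(self, data, scale, zeropoint):
        self._data, self._scale, self._zeropoint = data, scale, zeropoint


# Minimal dispatch table: op name -> handler, mimicking how a tensor
# subclass routes aten ops through a single dispatch function.
_DISPATCH = {}

def register(op_name):
    def wrapper(fn):
        _DISPATCH[op_name] = fn
        return fn
    return wrapper

@register("to")
def _to(qt, device):
    # Rebuild a new quantized tensor with *all* inner tensors moved,
    # including the zeropoint — the part the QTensor handler was missing.
    return QBits(qt._data.to(device),
                 qt._scale.to(device),
                 qt._zeropoint.to(device))

def dispatch(op_name, qt, *args):
    return _DISPATCH[op_name](qt, *args)


qt = QBits(InnerTensor(), InnerTensor(), InnerTensor())
moved = dispatch("to", qt, "cuda")  # returns a new, fully-moved QBits
```

The key design point is that the handler returns a new object rather than mutating the original, which matches how torch's functional dispatch expects ops to behave.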

dacorvo (Collaborator) commented Feb 19, 2024

I can reproduce the issue when moving a module. This happens because the move happens in two steps:

  • move the QTensor (this calls the dispatch): t -> new_t
  • assign it back to the module param.

The second step does not replace the original tensor, though: instead it does a shallow copy via t.data = new_t.

This indeed results in a weird situation where the moved tensor is stored as an attribute of the original one ...
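The two-step failure mode described above can be mimicked in a few lines of plain Python (FakeQTensor and module_apply are illustrative stand-ins, not torch code — in torch the second step is the param.data assignment done inside nn.Module._apply):

```python
class FakeQTensor:
    def __init__(self, device="cpu"):
        self.device = device  # stands in for the devices of _data/_scale
        self.data = self      # torch parameters expose a .data attribute

    def to(self, device):
        # Step 1: the dispatch correctly returns a fully-moved new tensor.
        return FakeQTensor(device)


def module_apply(param, device):
    new_t = param.to(device)  # step 1: move the tensor
    param.data = new_t        # step 2: shallow assignment back to the param
    return param              # the original object is kept alive


p = FakeQTensor("cpu")
p = module_apply(p, "cuda")
# p is still the original object, and the moved tensor only lives inside
# p.data — the "weird situation" described above.
```

This is why fixing the dispatch alone is not enough for module moves: the moved result has to actually replace the parameter, not be tucked into its .data attribute.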

dacorvo (Collaborator) commented Feb 19, 2024

Closing as obsoleted by #90

@dacorvo dacorvo closed this Feb 19, 2024