From 773a4322846742433dd201118a6b723f1164ac58 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 21 Feb 2024 13:26:21 +0100 Subject: [PATCH 1/2] fix --- src/accelerate/utils/operations.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/accelerate/utils/operations.py b/src/accelerate/utils/operations.py index 6f675846350..c27b172dfbb 100644 --- a/src/accelerate/utils/operations.py +++ b/src/accelerate/utils/operations.py @@ -66,17 +66,10 @@ def is_tensor_information(tensor_info): def is_namedtuple(data): """ - Checks if `x` is a `namedtuple` or not. Can have false positives, but only if a user is trying to mimic a + Checks if `data` is a `namedtuple` or not. Can have false positives, but only if a user is trying to mimic a `namedtuple` perfectly. """ - data_type = type(data) - bases = data_type.__bases__ - if len(bases) != 1 or bases[0] != tuple: - return False - fields = getattr(data_type, "_fields", None) - if not isinstance(fields, tuple): - return False - return all(isinstance(member, str) for member in fields) + return isinstance(data, tuple) and hasattr(data, "_asdict") and hasattr(data, "_fields") def honor_type(obj, generator): From e793fc0ad9210750e15e6add92ef3a6691a8d544 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 21 Feb 2024 14:10:34 +0100 Subject: [PATCH 2/2] add test --- tests/test_utils.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/test_utils.py b/tests/test_utils.py index 4658e44ef1a..1a588a9b6d3 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -17,6 +17,7 @@ import unittest import warnings from collections import UserDict, namedtuple +from typing import NamedTuple, Optional from unittest.mock import Mock, patch import torch @@ -39,6 +40,7 @@ save, send_to_device, ) +from accelerate.utils.operations import is_namedtuple ExampleNamedTuple = namedtuple("ExampleNamedTuple", "a b c") @@ -302,3 +304,25 @@ def test_slice_and_concatenate(self): result = pad_input_tensors(batch, batch_size, num_processes) # We should expect there to be 66 items now assert result.shape == torch.Size([66, 4, 4]) + + def test_named_tuples(self): + class QuantTensorBase(NamedTuple): + value: torch.Tensor + scale: Optional[torch.Tensor] + zero_point: Optional[torch.Tensor] + + class Second(QuantTensorBase): + pass + + a = QuantTensorBase(torch.tensor(1.0), None, None) + b = Second(torch.tensor(1.0), None, None) + + point = namedtuple("Point", ["x", "y"]) + p = point(11, y=22) + + self.assertTrue(is_namedtuple(a)) + self.assertTrue(is_namedtuple(b)) + self.assertTrue(is_namedtuple(p)) + self.assertFalse(is_namedtuple((1, 2))) + self.assertFalse(is_namedtuple("hey")) + self.assertFalse(is_namedtuple(object()))