Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix wrong is_namedtuple implementation #2475

Merged
merged 4 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions src/accelerate/utils/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,10 @@ def is_tensor_information(tensor_info):

def is_namedtuple(data):
"""
Checks if `x` is a `namedtuple` or not. Can have false positives, but only if a user is trying to mimic a
Checks if `data` is a `namedtuple` or not. Can have false positives, but only if a user is trying to mimic a
`namedtuple` perfectly.
"""
data_type = type(data)
bases = data_type.__bases__
if len(bases) != 1 or bases[0] != tuple:
return False
fields = getattr(data_type, "_fields", None)
if not isinstance(fields, tuple):
return False
return all(isinstance(member, str) for member in fields)
return isinstance(data, tuple) and hasattr(data, "_asdict") and hasattr(data, "_fields")


def honor_type(obj, generator):
Expand Down
24 changes: 24 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import unittest
import warnings
from collections import UserDict, namedtuple
from typing import NamedTuple, Optional
from unittest.mock import Mock, patch

import torch
Expand All @@ -40,6 +41,7 @@
save,
send_to_device,
)
from accelerate.utils.operations import is_namedtuple


ExampleNamedTuple = namedtuple("ExampleNamedTuple", "a b c")
Expand Down Expand Up @@ -310,3 +312,25 @@ def test_send_to_device_compiles(self):
def test_convert_to_fp32(self):
compiled_convert_to_fp32 = torch.compile(convert_to_fp32, fullgraph=True)
compiled_convert_to_fp32(torch.zeros([1], dtype=torch.bfloat16))

def test_named_tuples(self):
class QuantTensorBase(NamedTuple):
value: torch.Tensor
scale: Optional[torch.Tensor]
zero_point: Optional[torch.Tensor]

class Second(QuantTensorBase):
pass

a = QuantTensorBase(torch.tensor(1.0), None, None)
b = Second(torch.tensor(1.0), None, None)

point = namedtuple("Point", ["x", "y"])
p = point(11, y=22)

self.assertTrue(is_namedtuple(a))
self.assertTrue(is_namedtuple(b))
self.assertTrue(is_namedtuple(p))
self.assertFalse(is_namedtuple((1, 2)))
self.assertFalse(is_namedtuple("hey"))
self.assertFalse(is_namedtuple(object()))
Loading