Pull thread-local into _compat module to fix cloudpickling.
Because cloudpickle tries to pickle a function's globals, pickling an
attrs instance meant pickling the generated `__repr__` method and its
globals, which included a `threading.local`. This broke cloudpickle for
all attrs classes unless they explicitly specified `repr=False`.
Modules, however, are pickled by reference, not by value, so moving the
thread-local into a different module means we can put `_compat` into
the function's globals and not worry about a direct reference to the
thread-local.
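
A minimal sketch of the mechanism, not part of the commit: the `make_repr`
helper, the `ctx` name, and the use of the `types` module as a stand-in for
`attr._compat` are all illustrative. It assumes cloudpickle is installed and
only aims to show why a thread-local in a generated function's globals breaks
pickling while a module reference does not.

```python
import threading
import types

import cloudpickle

repr_context = threading.local()  # stand-in for the old module-level thread-local


def make_repr(globs):
    # Mimic attrs' code generation: exec a __repr__ whose globals dict we
    # control. cloudpickle serializes such dynamic functions by value,
    # including the globals they actually reference.
    exec("def __repr__(self):\n    return str(type(ctx))\n", globs)
    return globs["__repr__"]


# Old layout: the thread-local itself sits in the generated function's globals.
broken = make_repr({"ctx": repr_context})
try:
    cloudpickle.dumps(broken)
except Exception as exc:  # thread-locals cannot be pickled
    print("direct thread-local reference fails:", exc)

# New layout: only a module is referenced, and modules pickle by reference.
fixed = make_repr({"ctx": types})  # `types` stands in for attr._compat here
cloudpickle.dumps(fixed)
print("module reference pickles fine")
```
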
thetorpedodog committed Nov 2, 2021
1 parent f812ec5 commit 0d710b4
Showing 2 changed files with 45 additions and 53 deletions.
15 changes: 15 additions & 0 deletions src/attr/_compat.py
@@ -2,6 +2,7 @@

import platform
import sys
import threading
import types
import warnings

@@ -243,3 +244,17 @@ def func():


set_closure_cell = make_set_closure_cell()

# Thread-local global to track attrs instances which are already being repr'd.
# This is needed because there is no other (thread-safe) way to pass info
# about the instances that are already being repr'd through the call stack
# in order to ensure we don't perform infinite recursion.
#
# For instance, if an instance contains a dict which contains that instance,
# we need to know that we're already repr'ing the outside instance from within
# the dict's repr() call.
#
# This lives here rather than in _make.py so that the functions in _make.py
# don't have a direct reference to the thread-local in their globals dict.
# If they have such a reference, it breaks cloudpickle.
repr_context = threading.local()
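
For reference, a standalone sketch (simplified names, not attrs code) of the
working-set pattern the comment above describes: a thread-local set of ids
stops a recursive `repr()` from looping forever.

```python
import threading

repr_context = threading.local()


class Node:
    def __init__(self, child=None):
        self.child = child

    def __repr__(self):
        try:
            working_set = repr_context.working_set
        except AttributeError:
            working_set = set()
            repr_context.working_set = working_set
        if id(self) in working_set:
            return "..."  # this instance is already being repr'd further up
        working_set.add(id(self))
        try:
            return "Node(child={!r})".format(self.child)
        finally:
            working_set.remove(id(self))


n = Node()
n.child = {"me": n}  # a container that refers back to its owner
print(repr(n))       # Node(child={'me': ...}) instead of a RecursionError
```
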
83 changes: 30 additions & 53 deletions src/attr/_make.py
@@ -4,24 +4,11 @@
import inspect
import linecache
import sys
import threading
import warnings

from operator import itemgetter

from . import _config, setters
from ._compat import (
HAS_F_STRINGS,
PY2,
PY310,
PYPY,
isclass,
iteritems,
metadata_proxy,
new_class,
ordered_dict,
set_closure_cell,
)
from . import _compat, _config, setters
from .exceptions import (
DefaultAlreadySetError,
FrozenInstanceError,
@@ -31,7 +18,7 @@
)


if not PY2:
if not _compat.PY2:
import typing


@@ -53,7 +40,7 @@
# (when slots=True)
_hash_cache_field = "_attrs_cached_hash"

_empty_metadata_singleton = metadata_proxy({})
_empty_metadata_singleton = _compat.metadata_proxy({})

# Unique object for unequivocal getattr() defaults.
_sentinel = object()
@@ -103,7 +90,7 @@ class _CacheHashWrapper(int):
See GH #613 for more details.
"""

if PY2:
if _compat.PY2:
# For some reason `type(None)` isn't callable in Python 2, but we don't
# actually need a constructor for None objects, we just need any
# available function that returns None.
@@ -521,9 +508,9 @@ def _transform_attrs(
anns = _get_annotations(cls)

if these is not None:
ca_list = [(name, ca) for name, ca in iteritems(these)]
ca_list = [(name, ca) for name, ca in _compat.iteritems(these)]

if not isinstance(these, ordered_dict):
if not isinstance(these, _compat.ordered_dict):
ca_list.sort(key=_counter_getter)
elif auto_attribs is True:
ca_names = {
@@ -613,7 +600,7 @@ def _transform_attrs(
return _Attributes((AttrsClass(attrs), base_attrs, base_attr_map))


if PYPY:
if _compat.PYPY:

def _frozen_setattrs(self, name, value):
"""
@@ -795,7 +782,7 @@ def _create_slots_class(self):
"""
cd = {
k: v
for k, v in iteritems(self._cls_dict)
for k, v in _compat.iteritems(self._cls_dict)
if k not in tuple(self._attr_names) + ("__dict__", "__weakref__")
}

@@ -850,7 +837,7 @@ def _create_slots_class(self):
# we collect them here and update the class dict
reused_slots = {
slot: slot_descriptor
for slot, slot_descriptor in iteritems(existing_slots)
for slot, slot_descriptor in _compat.iteritems(existing_slots)
if slot in slot_names
}
slot_names = [name for name in slot_names if name not in reused_slots]
@@ -893,7 +880,7 @@ def _create_slots_class(self):
pass
else:
if match:
set_closure_cell(cell, cls)
_compat.set_closure_cell(cell, cls)

return cls

@@ -1475,7 +1462,7 @@ def attrs(
.. versionchanged:: 21.1.0 *cmp* undeprecated
.. versionadded:: 21.3.0 *match_args*
"""
if auto_detect and PY2:
if auto_detect and _compat.PY2:
raise PythonTooOldError(
"auto_detect only works on Python 3 and later."
)
@@ -1591,7 +1578,7 @@ def wrap(cls):
)

if (
PY310
_compat.PY310
and match_args
and not _has_own_attribute(cls, "__match_args__")
):
@@ -1614,7 +1601,7 @@ def wrap(cls):
"""


if PY2:
if _compat.PY2:

def _has_frozen_base_class(cls):
"""
@@ -1666,7 +1653,7 @@ def _make_hash(cls, attrs, frozen, cache_hash):
if not cache_hash:
hash_def += "):"
else:
if not PY2:
if not _compat.PY2:
hash_def += ", *"

hash_def += (
@@ -1864,17 +1851,7 @@ def _add_eq(cls, attrs=None):
return cls


# Thread-local global to track attrs instances which are already being repr'd.
# This is needed because there is no other (thread-safe) way to pass info
# about the instances that are already being repr'd through the call stack
# in order to ensure we don't perform infinite recursion.
#
# For instance, if an instance contains a dict which contains that instance,
# we need to know that we're already repr'ing the outside instance from within
# the dict's repr() call.
_already_repring = threading.local()

if HAS_F_STRINGS:
if _compat.HAS_F_STRINGS:

def _make_repr(attrs, ns, cls):
unique_filename = "repr"
@@ -1891,7 +1868,7 @@ def _make_repr(attrs, ns, cls):
for name, r, _ in attr_names_with_reprs
if r != repr
}
globs["_already_repring"] = _already_repring
globs["_compat"] = _compat
globs["AttributeError"] = AttributeError
globs["NOTHING"] = NOTHING
attribute_fragments = []
@@ -1919,10 +1896,10 @@ def _make_repr(attrs, ns, cls):
lines = []
lines.append("def __repr__(self):")
lines.append(" try:")
lines.append(" working_set = _already_repring.working_set")
lines.append(" working_set = _compat.repr_context.working_set")
lines.append(" except AttributeError:")
lines.append(" working_set = {id(self),}")
lines.append(" _already_repring.working_set = working_set")
lines.append(" _compat.repr_context.working_set = working_set")
lines.append(" else:")
lines.append(" if id(self) in working_set:")
lines.append(" return '...'")
@@ -1962,10 +1939,10 @@ def __repr__(self):
Automatically created by attrs.
"""
try:
working_set = _already_repring.working_set
working_set = _compat.repr_context.working_set
except AttributeError:
working_set = set()
_already_repring.working_set = working_set
_compat.repr_context.working_set = working_set

if id(self) in working_set:
return "..."
@@ -2035,7 +2012,7 @@ def fields(cls):
.. versionchanged:: 16.2.0 Returned tuple allows accessing the fields
by name.
"""
if not isclass(cls):
if not _compat.isclass(cls):
raise TypeError("Passed object must be a class.")
attrs = getattr(cls, "__attrs_attrs__", None)
if attrs is None:
@@ -2063,14 +2040,14 @@ def fields_dict(cls):
.. versionadded:: 18.1.0
"""
if not isclass(cls):
if not _compat.isclass(cls):
raise TypeError("Passed object must be a class.")
attrs = getattr(cls, "__attrs_attrs__", None)
if attrs is None:
raise NotAnAttrsClassError(
"{cls!r} is not an attrs-decorated class.".format(cls=cls)
)
return ordered_dict(((a.name, a) for a in attrs))
return _compat.ordered_dict(((a.name, a) for a in attrs))


def validate(inst):
@@ -2223,7 +2200,7 @@ def _assign_with_converter(attr_name, value_var, has_on_setattr):
)


if PY2:
if _compat.PY2:

def _unpack_kw_only_py2(attr_name, default=None):
"""
@@ -2497,7 +2474,7 @@ def fmt_setter_with_converter(
if a.init is True:
if a.type is not None and a.converter is None:
annotations[arg_name] = a.type
elif a.converter is not None and not PY2:
elif a.converter is not None and not _compat.PY2:
# Try to get the type from the converter.
sig = None
try:
@@ -2554,7 +2531,7 @@ def fmt_setter_with_converter(

args = ", ".join(args)
if kw_only_args:
if PY2:
if _compat.PY2:
lines = _unpack_kw_only_lines_py2(kw_only_args) + lines

args += "%s**_kw_only" % (", " if args else "",) # leading comma
@@ -2673,7 +2650,7 @@ def __init__(
bound_setattr(
"metadata",
(
metadata_proxy(metadata)
_compat.metadata_proxy(metadata)
if metadata
else _empty_metadata_singleton
),
@@ -2768,7 +2745,7 @@ def _setattrs(self, name_values_pairs):
else:
bound_setattr(
name,
metadata_proxy(value)
_compat.metadata_proxy(value)
if value
else _empty_metadata_singleton,
)
@@ -3044,7 +3021,7 @@ def make_class(name, attrs, bases=(object,), **attributes_arguments):
if user_init is not None:
body["__init__"] = user_init

type_ = new_class(name, bases, {}, lambda ns: ns.update(body))
type_ = _compat.new_class(name, bases, {}, lambda ns: ns.update(body))

# For pickling to work, the __module__ variable needs to be set to the
# frame where the class is created. Bypass this step in environments where
@@ -3131,7 +3108,7 @@ def pipe_converter(val):

return val

if not PY2:
if not _compat.PY2:
if not converters:
# If the converter list is empty, pipe_converter is the identity.
A = typing.TypeVar("A")
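
A hedged usage sketch of the user-visible effect of this change, assuming a
cloudpickle installation and an attrs release containing this fix, and that
the class is one cloudpickle serializes by value (e.g. defined in `__main__`
or a notebook):

```python
import attr
import cloudpickle


@attr.s(auto_attribs=True)
class Point:
    x: int
    y: int


# Before this change, dumping an instance of a by-value-pickled attrs class
# failed unless the class was declared with repr=False; now it round-trips.
blob = cloudpickle.dumps(Point(1, 2))
print(cloudpickle.loads(blob))  # Point(x=1, y=2)
```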
