Skip to content

Commit

Permalink
Pass args from __init__ to __attrs_pre_init__ if requested (#1187)
Browse files Browse the repository at this point in the history
* Pass args from `__init__` to `__attrs_pre_init__` if requested

Detect if `__attrs_pre_init__` has arguments besides `self`
using `inspect.signature`. If so, pass `__attrs_pre_init__`
the same arguments that `__init__` (or `__attrs_init__`)
expects.

* Add changelog entry for `__attrs_pre_init__` args changes

* Don't use monkeypatch in new code

* Clarify docs

* Missed one monkeypatching

---------

Co-authored-by: Hynek Schlawack <hs@ox.cx>
  • Loading branch information
Tyrubias and hynek authored Sep 29, 2023
1 parent 46a03dc commit c2824ac
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 5 deletions.
2 changes: 2 additions & 0 deletions changelog.d/1187.change.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
If *attrs* detects that `__attrs_pre_init__` accepts more than just `self`, it will call it with the same arguments as `__init__` was called.
This allows you to, for example, pass arguments to `super().__init__()`.
3 changes: 2 additions & 1 deletion docs/init.md
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,8 @@ However, sometimes you need to do that one quick thing before or after your clas
For that purpose, *attrs* offers the following options:

- `__attrs_pre_init__` is automatically detected and run *before* *attrs* starts initializing.
This is useful if you need to inject a call to `super().__init__()`.
If `__attrs_pre_init__` takes more than the `self` argument, the *attrs*-generated `__init__` will call it with the same arguments it received itself.
This is useful if you need to inject a call to `super().__init__()` -- with or without arguments.

- `__attrs_post_init__` is automatically detected and run *after* *attrs* is done initializing your instance.
This is useful if you want to derive some attribute from others or perform some kind of validation over the whole instance.
Expand Down
26 changes: 26 additions & 0 deletions src/attr/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import contextlib
import copy
import enum
import inspect
import linecache
import sys
import types
Expand Down Expand Up @@ -624,6 +625,7 @@ class _ClassBuilder:
"_delete_attribs",
"_frozen",
"_has_pre_init",
"_pre_init_has_args",
"_has_post_init",
"_is_exc",
"_on_setattr",
Expand Down Expand Up @@ -670,6 +672,13 @@ def __init__(
self._weakref_slot = weakref_slot
self._cache_hash = cache_hash
self._has_pre_init = bool(getattr(cls, "__attrs_pre_init__", False))
self._pre_init_has_args = False
if self._has_pre_init:
# Check if the pre init method has more arguments than just `self`
# We want to pass arguments if pre init expects arguments
pre_init_func = cls.__attrs_pre_init__
pre_init_signature = inspect.signature(pre_init_func)
self._pre_init_has_args = len(pre_init_signature.parameters) > 1
self._has_post_init = bool(getattr(cls, "__attrs_post_init__", False))
self._delete_attribs = not bool(these)
self._is_exc = is_exc
Expand Down Expand Up @@ -974,6 +983,7 @@ def add_init(self):
self._cls,
self._attrs,
self._has_pre_init,
self._pre_init_has_args,
self._has_post_init,
self._frozen,
self._slots,
Expand All @@ -1000,6 +1010,7 @@ def add_attrs_init(self):
self._cls,
self._attrs,
self._has_pre_init,
self._pre_init_has_args,
self._has_post_init,
self._frozen,
self._slots,
Expand Down Expand Up @@ -1984,6 +1995,7 @@ def _make_init(
cls,
attrs,
pre_init,
pre_init_has_args,
post_init,
frozen,
slots,
Expand Down Expand Up @@ -2027,6 +2039,7 @@ def _make_init(
frozen,
slots,
pre_init,
pre_init_has_args,
post_init,
cache_hash,
base_attr_map,
Expand Down Expand Up @@ -2107,6 +2120,7 @@ def _attrs_to_init_script(
frozen,
slots,
pre_init,
pre_init_has_args,
post_init,
cache_hash,
base_attr_map,
Expand Down Expand Up @@ -2361,11 +2375,23 @@ def fmt_setter_with_converter(
lines.append(f"BaseException.__init__(self, {vals})")

args = ", ".join(args)
pre_init_args = args
if kw_only_args:
args += "%s*, %s" % (
", " if args else "", # leading comma
", ".join(kw_only_args), # kw_only args
)
pre_init_kw_only_args = ", ".join(
["%s=%s" % (kw_arg, kw_arg) for kw_arg in kw_only_args]
)
pre_init_args += (
", " if pre_init_args else ""
) # handle only kwargs and no regular args
pre_init_args += pre_init_kw_only_args

if pre_init and pre_init_has_args:
# If pre init method has arguments, pass same arguments as `__init__`
lines[0] = "self.__attrs_pre_init__(%s)" % pre_init_args

return (
"def %s(self, %s):\n %s\n"
Expand Down
8 changes: 7 additions & 1 deletion tests/test_dunders.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@


import copy
import inspect
import pickle

import pytest
Expand Down Expand Up @@ -84,10 +85,15 @@ def _add_init(cls, frozen):
This function used to be part of _make. It wasn't used anymore however
the tests for it are still useful to test the behavior of _make_init.
"""
has_pre_init = bool(getattr(cls, "__attrs_pre_init__", False))

cls.__init__ = _make_init(
cls,
cls.__attrs_attrs__,
getattr(cls, "__attrs_pre_init__", False),
has_pre_init,
len(inspect.signature(cls.__attrs_pre_init__).parameters) > 1
if has_pre_init
else False,
getattr(cls, "__attrs_post_init__", False),
frozen,
_is_slot_cls(cls),
Expand Down
74 changes: 71 additions & 3 deletions tests/test_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,21 +613,89 @@ class D:
assert C.D.__qualname__ == C.__qualname__ + ".D"

@pytest.mark.parametrize("with_validation", [True, False])
def test_pre_init(self, with_validation, monkeypatch):
def test_pre_init(self, with_validation):
"""
Verify that __attrs_pre_init__ gets called if defined.
"""
monkeypatch.setattr(_config, "_run_validators", with_validation)

@attr.s
class C:
def __attrs_pre_init__(self2):
self2.z = 30

c = C()
try:
attr.validators.set_disabled(not with_validation)
c = C()
finally:
attr.validators.set_disabled(False)

assert 30 == getattr(c, "z", None)

@pytest.mark.parametrize("with_validation", [True, False])
def test_pre_init_args(self, with_validation):
"""
Verify that __attrs_pre_init__ gets called with extra args if defined.
"""

@attr.s
class C:
x = attr.ib()

def __attrs_pre_init__(self2, x):
self2.z = x + 1

try:
attr.validators.set_disabled(not with_validation)
c = C(x=10)
finally:
attr.validators.set_disabled(False)

assert 11 == getattr(c, "z", None)

@pytest.mark.parametrize("with_validation", [True, False])
def test_pre_init_kwargs(self, with_validation):
"""
Verify that __attrs_pre_init__ gets called with extra args and kwargs if defined.
"""

@attr.s
class C:
x = attr.ib()
y = attr.field(kw_only=True)

def __attrs_pre_init__(self2, x, y):
self2.z = x + y + 1

try:
attr.validators.set_disabled(not with_validation)
c = C(10, y=11)
finally:
attr.validators.set_disabled(False)

assert 22 == getattr(c, "z", None)

@pytest.mark.parametrize("with_validation", [True, False])
def test_pre_init_kwargs_only(self, with_validation):
"""
Verify that __attrs_pre_init__ gets called with extra kwargs only if
defined.
"""

@attr.s
class C:
y = attr.field(kw_only=True)

def __attrs_pre_init__(self2, y):
self2.z = y + 1

try:
attr.validators.set_disabled(not with_validation)
c = C(y=11)
finally:
attr.validators.set_disabled(False)

assert 12 == getattr(c, "z", None)

@pytest.mark.parametrize("with_validation", [True, False])
def test_post_init(self, with_validation, monkeypatch):
"""
Expand Down

0 comments on commit c2824ac

Please sign in to comment.