Skip to content

Commit

Permalink
Recursively evolve nested attrs classes
Browse files Browse the repository at this point in the history
Fixes: #634
  • Loading branch information
sscherfke committed Feb 14, 2021
1 parent c2712fd commit 695a0bc
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/attr/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,8 +337,13 @@ def evolve(inst, **changes):
continue
attr_name = a.name # To deal with private attributes.
init_name = attr_name if attr_name[0] != "_" else attr_name[1:]
value = getattr(inst, attr_name)
if init_name not in changes:
changes[init_name] = getattr(inst, attr_name)
# Add original value to changes
changes[init_name] = value
elif has(value):
# Evolve nested attrs classes
changes[init_name] = evolve(value, **changes[init_name])

return cls(**changes)

Expand Down
24 changes: 24 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,3 +597,27 @@ class C(object):
b = attr.ib(init=False, default=0)

assert evolve(C(1), a=2).a == 2

def test_recursive(self):
"""
evolve() recursively evolves nested attrs classes when a dict is
passed for an attribute.
"""

@attr.s
class N2:
e = attr.ib(type=int)

@attr.s
class N1:
c = attr.ib(type=N2)
d = attr.ib(type=int)

@attr.s
class C:
a = attr.ib(type=N1)
b = attr.ib(type=int)

c1 = C(N1(N2(1), 2), 3)
c2 = evolve(c1, a={"c": {"e": 23}})
assert c2 == C(N1(N2(23), 2), 3)

0 comments on commit 695a0bc

Please sign in to comment.