diff --git a/changelog.d/759.change.rst b/changelog.d/759.change.rst new file mode 100644 index 000000000..e627f5ea3 --- /dev/null +++ b/changelog.d/759.change.rst @@ -0,0 +1 @@ +``attrs.evolve()`` now works recursively with nested ``attrs`` classes. diff --git a/docs/examples.rst b/docs/examples.rst index 0fac312a0..83810f034 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -618,6 +618,27 @@ In Clojure that function is called `assoc >> i1 == i2 False +This functions also works for nested ``attrs`` classes. +Pass a (possibly nested) dictionary with changes for an attribute: + +.. doctest:: + + >>> @attr.s(frozen=True) + ... class Child(object): + ... x = attr.ib() + ... y = attr.ib() + >>> @attr.s(frozen=True) + ... class Parent(object): + ... child = attr.ib() + >>> i1 = Parent(Child(1, 2)) + >>> i1 + Parent(child=Child(x=1, y=2)) + >>> i2 = attr.evolve(i1, child={"y": 3}) + >>> i2 + Parent(child=Child(x=1, y=3)) + >>> i1 == i2, i1.child == i2.child + (False, False) + Other Goodies ------------- diff --git a/src/attr/_funcs.py b/src/attr/_funcs.py index e6c930cbd..30200260a 100644 --- a/src/attr/_funcs.py +++ b/src/attr/_funcs.py @@ -319,7 +319,8 @@ def evolve(inst, **changes): Create a new instance, based on *inst* with *changes* applied. :param inst: Instance of a class with ``attrs`` attributes. - :param changes: Keyword changes in the new copy. + :param changes: Keyword changes in the new copy. Nested ``attrs`` classes + can be updated by passing (nested) dicts of values. :return: A copy of inst with *changes* incorporated. @@ -337,8 +338,13 @@ def evolve(inst, **changes): continue attr_name = a.name # To deal with private attributes. init_name = attr_name if attr_name[0] != "_" else attr_name[1:] + value = getattr(inst, attr_name) if init_name not in changes: - changes[init_name] = getattr(inst, attr_name) + # Add original value to changes + changes[init_name] = value + elif has(value): + # Evolve nested attrs classes + changes[init_name] = evolve(value, **changes[init_name]) return cls(**changes) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 2fc73dced..6b9052acb 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -597,3 +597,28 @@ class C(object): b = attr.ib(init=False, default=0) assert evolve(C(1), a=2).a == 2 + + def test_recursive(self): + """ + evolve() recursively evolves nested attrs classes when a dict is + passed for an attribute. + """ + + @attr.s + class N2(object): + e = attr.ib(type=int) + + @attr.s + class N1(object): + c = attr.ib(type=N2) + d = attr.ib(type=int) + + @attr.s + class C(object): + a = attr.ib(type=N1) + b = attr.ib(type=int) + + c1 = C(N1(N2(1), 2), 3) + c2 = evolve(c1, a={"c": {"e": 23}}, b=42) + + assert c2 == C(N1(N2(23), 2), 42)