From 8befd7d3bd2b2bd2a43c7e451df79b1c617c81c3 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 20 Sep 2023 12:56:06 -0700 Subject: [PATCH] test_state: add test_mutable_copy for entire state and mutable vars Test case for #1841 --- tests/test_state.py | 51 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/tests/test_state.py b/tests/test_state.py index 763d36a6dd..cf59e5eb9c 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy import datetime import functools import sys @@ -1585,6 +1586,56 @@ def assert_custom_dirty(): assert_custom_dirty() +@pytest.mark.parametrize( + ("copy_func",), + [ + (copy.copy,), + (copy.deepcopy,), + ], +) +def test_mutable_copy(mutable_state, copy_func): + """Test that mutable types are copied correctly. + + Args: + mutable_state: A test state. + copy_func: A copy function. + """ + ms_copy = copy_func(mutable_state) + assert ms_copy is not mutable_state + for attr in ("array", "hashmap", "test_set", "custom"): + assert getattr(ms_copy, attr) == getattr(mutable_state, attr) + assert getattr(ms_copy, attr) is not getattr(mutable_state, attr) + ms_copy.custom.array.append(42) + assert "custom" in ms_copy.dirty_vars + if copy_func is copy.copy: + assert "custom" in mutable_state.dirty_vars + else: + assert not mutable_state.dirty_vars + + +@pytest.mark.parametrize( + ("copy_func",), + [ + (copy.copy,), + (copy.deepcopy,), + ], +) +def test_mutable_copy_vars(mutable_state, copy_func): + """Test that mutable types are copied correctly. + + Args: + mutable_state: A test state. + copy_func: A copy function. + """ + for attr in ("array", "hashmap", "test_set", "custom"): + var_orig = getattr(mutable_state, attr) + var_copy = copy_func(var_orig) + assert var_orig is not var_copy + assert var_orig == var_copy + # copied vars should never be proxies, as they by definition are no longer attached to the state. + assert not isinstance(var_copy, MutableProxy) + + def test_duplicate_substate_class(duplicate_substate): with pytest.raises(ValueError): duplicate_substate()