Skip to content

Commit

Permalink
refactor: return new lists instead of change argument values in function
Browse files Browse the repository at this point in the history
  • Loading branch information
waketzheng committed Dec 6, 2024
1 parent aa8e5af commit 75ad00a
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 59 deletions.
4 changes: 3 additions & 1 deletion aerich/migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,9 @@ def diff_models(
old_m2m_fields = cast(List[dict], old_model_describe.get("m2m_fields"))
new_m2m_fields = cast(List[dict], new_model_describe.get("m2m_fields"))
if old_m2m_fields and new_m2m_fields:
reorder_m2m_fields(old_m2m_fields, new_m2m_fields)
old_m2m_fields, new_m2m_fields = reorder_m2m_fields(
old_m2m_fields, new_m2m_fields
)
for action, _, change in diff(old_m2m_fields, new_m2m_fields):
if change[0][0] == "db_constraint":
continue
Expand Down
128 changes: 104 additions & 24 deletions aerich/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,32 +105,112 @@ def import_py_file(file: Union[str, Path]) -> ModuleType:
return module


def reorder_m2m_fields(old_m2m_fields: list[dict], new_m2m_fields: list[dict]) -> None:
def _pick_match_to_head(m2m_fields: list[dict], field_info: dict) -> list[dict]:
"""
If there is a element in m2m_fields whose through or name is equal to field_info's
and this element is not at the first position, put it to the head then return the
new list, otherwise return the origin list.
Example::
>>> m2m_fields = [{'through': 'u1', 'name': 'u1'}, {'throught': 'u2', 'name': 'u2'}]
>>> _pick_match_to_head(m2m_fields, {'through': 'u2', 'name': 'u2'})
[{'through': 'u2', 'name': 'u2'}, {'throught': 'u1', 'name': 'u1'}]
"""
through = field_info["through"]
name = field_info["name"]
for index, field in enumerate(m2m_fields):
if field["through"] == through or field["name"] == name:
if index != 0:
m2m_fields = [field, *m2m_fields[:index], *m2m_fields[index + 1 :]]
break
return m2m_fields


def reorder_m2m_fields(
old_m2m_fields: list[dict], new_m2m_fields: list[dict]
) -> tuple[list[dict], list[dict]]:
"""
Reorder m2m fields to help dictdiffer.diff generate more precise changes
:param old_m2m_fields: previous m2m field info list
:param new_m2m_fields: current m2m field info list
:return:
:return: ordered old/new m2m fields
"""
old_table_names: list[str] = [f["through"] for f in old_m2m_fields]
new_table_names: list[str] = [f["through"] for f in new_m2m_fields]
if old_table_names == new_table_names:
return
if sorted(old_table_names) == sorted(new_table_names):
new_m2m_fields.sort(key=lambda field: old_table_names.index(field["through"]))
return
old_field_names: list[str] = [f["name"] for f in old_m2m_fields]
new_field_names: list[str] = [f["name"] for f in new_m2m_fields]
if old_field_names == new_field_names:
return
if sorted(old_field_names) == sorted(new_field_names):
new_m2m_fields.sort(key=lambda field: old_field_names.index(field["name"]))
return
if unchanged_tables := set(old_table_names) & set(new_table_names):
unchanged = sorted(unchanged_tables)
ordered_old_tables = unchanged + sorted(set(old_table_names) - unchanged_tables)
ordered_new_tables = unchanged + sorted(set(new_table_names) - unchanged_tables)
if ordered_old_tables != old_table_names:
old_m2m_fields.sort(key=lambda field: ordered_old_tables.index(field["through"]))
if ordered_new_tables != new_table_names:
new_m2m_fields.sort(key=lambda field: ordered_new_tables.index(field["through"]))
length_old, length_new = len(old_m2m_fields), len(new_m2m_fields)
if length_old == length_new == 1:
# No need to change order if both of them have only one element
pass
elif length_old == 1:
# If any element of new fields match the one in old fields, put it to head
new_m2m_fields = _pick_match_to_head(new_m2m_fields, old_m2m_fields[0])
elif length_new == 1:
old_m2m_fields = _pick_match_to_head(old_m2m_fields, new_m2m_fields[0])
else:
old_table_names = [f["through"] for f in old_m2m_fields]
new_table_names = [f["through"] for f in new_m2m_fields]
old_field_names = [f["name"] for f in old_m2m_fields]
new_field_names = [f["name"] for f in new_m2m_fields]
if old_table_names == new_table_names:
pass
elif sorted(old_table_names) == sorted(new_table_names):
# If table name are the same but order not match,
# reorder new fields by through to match the order of old.

# Case like::
# old_m2m_fields = [
# {'through': 'users_group', 'name': 'users',},
# {'through': 'admins_group', 'name': 'admins'},
# ]
# new_m2m_fields = [
# {'through': 'admins_group', 'name': 'admins_new'},
# {'through': 'users_group', 'name': 'users_new',},
# ]
new_m2m_fields = sorted(
new_m2m_fields, key=lambda f: old_table_names.index(f["through"])
)
elif old_field_names == new_field_names:
pass
elif sorted(old_field_names) == sorted(new_field_names):
# Case like:
# old_m2m_fields = [
# {'name': 'users', 'through': 'users_group'},
# {'name': 'admins', 'through': 'admins_group'},
# ]
# new_m2m_fields = [
# {'name': 'admins', 'through': 'admin_group_map'},
# {'name': 'users', 'through': 'user_group_map'},
# ]
new_m2m_fields = sorted(new_m2m_fields, key=lambda f: old_field_names.index(f["name"]))
elif unchanged_table_names := set(old_table_names) & set(new_table_names):
# If old/new m2m field list have one or some unchanged table names, put them to head of list.

# Case like::
# old_m2m_fields = [
# {'through': 'users_group', 'name': 'users',},
# {'through': 'staffs_group', 'name': 'users',},
# {'through': 'admins_group', 'name': 'admins'},
# ]
# new_m2m_fields = [
# {'through': 'admins_group', 'name': 'admins_new'},
# {'through': 'users_group', 'name': 'users_new',},
# ]
unchanged = sorted(
unchanged_table_names, key=lambda name: old_table_names.index(name)
)
ordered_old_tables = unchanged + sorted(
set(old_table_names) - unchanged_table_names,
key=lambda name: old_table_names.index(name),
)
ordered_new_tables = unchanged + sorted(
set(new_table_names) - unchanged_table_names,
key=lambda name: new_table_names.index(name),
)
if ordered_old_tables != old_table_names:
old_m2m_fields = sorted(
old_m2m_fields, key=lambda f: ordered_old_tables.index(f["through"])
)
if ordered_new_tables != new_table_names:
new_m2m_fields = sorted(
new_m2m_fields, key=lambda f: ordered_new_tables.index(f["through"])
)
return old_m2m_fields, new_m2m_fields
89 changes: 55 additions & 34 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@ def test_the_same_through_order(self) -> None:
{"name": "members", "through": "users_group"},
{"name": "admins", "through": "admins_group"},
]
old_backup, new_backup = old[:], new[:]
reorder_m2m_fields(old, new)
assert old == old_backup
assert new == new_backup
sorted_old, sorted_new = reorder_m2m_fields(old, new)
assert sorted_old == old
assert sorted_new == new

def test_same_through_with_different_orders(self) -> None:
old = [
Expand All @@ -30,10 +29,9 @@ def test_same_through_with_different_orders(self) -> None:
{"name": "admins", "through": "admins_group"},
{"name": "members", "through": "users_group"},
]
old_backup, new_backup = old[:], new[:]
reorder_m2m_fields(old, new)
assert old == old_backup
assert new == new_backup[::-1]
sorted_old, sorted_new = reorder_m2m_fields(old, new)
assert sorted_old == old
assert sorted_new == new[::-1]

def test_the_same_field_name_order(self) -> None:
old = [
Expand All @@ -44,10 +42,9 @@ def test_the_same_field_name_order(self) -> None:
{"name": "users", "through": "user_groups"},
{"name": "admins", "through": "admin_groups"},
]
old_backup, new_backup = old[:], new[:]
reorder_m2m_fields(old, new)
assert old == old_backup
assert new == new_backup
sorted_old, sorted_new = reorder_m2m_fields(old, new)
assert sorted_old == old
assert sorted_new == new

def test_same_field_name_with_different_orders(self) -> None:
old = [
Expand All @@ -58,10 +55,9 @@ def test_same_field_name_with_different_orders(self) -> None:
{"name": "users", "through": "user_groups"},
{"name": "admins", "through": "admin_groups"},
]
old_backup, new_backup = old[:], new[:]
reorder_m2m_fields(old, new)
assert old == old_backup
assert new == new_backup[::-1]
sorted_old, sorted_new = reorder_m2m_fields(old, new)
assert sorted_old == old
assert sorted_new == new[::-1]

def test_drop_one(self) -> None:
old = [
Expand All @@ -71,11 +67,10 @@ def test_drop_one(self) -> None:
new = [
{"name": "admins", "through": "admins_group"},
]
old_backup, new_backup = old[:], new[:]
reorder_m2m_fields(old, new)
assert new == new_backup
assert old == old_backup[::-1]
assert old[0] == new[0]
sorted_old, sorted_new = reorder_m2m_fields(old, new)
assert sorted_old == old[::-1]
assert sorted_new == new
assert sorted_old[0] == new[0]

def test_add_one(self) -> None:
old = [
Expand All @@ -85,11 +80,10 @@ def test_add_one(self) -> None:
{"name": "users", "through": "users_group"},
{"name": "admins", "through": "admins_group"},
]
old_backup, new_backup = old[:], new[:]
reorder_m2m_fields(old, new)
assert old == old_backup
assert new == new_backup[::-1]
assert new[0] == old[0]
sorted_old, sorted_new = reorder_m2m_fields(old, new)
assert sorted_old == old
assert sorted_new == new[::-1]
assert sorted_old[0] == sorted_new[0]

def test_drop_some(self) -> None:
old = [
Expand All @@ -100,10 +94,9 @@ def test_drop_some(self) -> None:
new = [
{"name": "admins", "through": "admins_group"},
]
old_backup, new_backup = old[:], new[:]
reorder_m2m_fields(old, new)
assert new == new_backup
assert old[0] == old_backup[1] == new[0]
sorted_old, sorted_new = reorder_m2m_fields(old, new)
assert sorted_new == new
assert sorted_old[0] == old[1] == sorted_new[0]

def test_add_some(self) -> None:
old = [
Expand All @@ -114,7 +107,35 @@ def test_add_some(self) -> None:
{"name": "admins", "through": "admins_group"},
{"name": "staffs", "through": "staffs_group"},
]
old_backup, new_backup = old[:], new[:]
reorder_m2m_fields(old, new)
assert old == old_backup
assert new[0] == new_backup[-1] == old[0]
sorted_old, sorted_new = reorder_m2m_fields(old, new)
assert sorted_old == old
assert sorted_new[0] == new[-1] == sorted_old[0]

def test_some_through_unchanged(self) -> None:
old = [
{"name": "staffs", "through": "staffs_group"},
{"name": "admins", "through": "admins_group"},
]
new = [
{"name": "users", "through": "users_group"},
{"name": "admins_new", "through": "admins_group"},
{"name": "staffs_new", "through": "staffs_group"},
]
sorted_old, sorted_new = reorder_m2m_fields(old, new)
assert sorted_old == old
assert [i["through"] for i in sorted_new][: len(old)] == [i["through"] for i in sorted_old]

def test_some_unchanged_without_drop_or_add(self) -> None:
old = [
{"name": "staffs", "through": "staffs_group"},
{"name": "admins", "through": "admins_group"},
{"name": "users", "through": "users_group"},
]
new = [
{"name": "users_new", "through": "users_group"},
{"name": "admins_new", "through": "admins_group"},
{"name": "staffs_new", "through": "staffs_group"},
]
sorted_old, sorted_new = reorder_m2m_fields(old, new)
assert sorted_old == old
assert [i["through"] for i in sorted_new] == [i["through"] for i in sorted_old]

0 comments on commit 75ad00a

Please sign in to comment.