Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: migrate drop the wrong m2m field when model have multi m2m fields #376

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### [0.8.1](Unreleased)

#### Fixed
- Migrate drop the wrong m2m field when model have multi m2m fields. (#376)
- KeyError raised when removing or renaming an existing model (#386)
- fix: error when there is `__init__.py` in the migration folder (#272)
- Setting null=false on m2m field causes migration to fail. (#334)
Expand Down
15 changes: 10 additions & 5 deletions aerich/migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@

from aerich.ddl import BaseDDL
from aerich.models import MAX_VERSION_LENGTH, Aerich
from aerich.utils import get_app_connection, get_models_describe, is_default_function
from aerich.utils import (
get_app_connection,
get_models_describe,
is_default_function,
reorder_m2m_fields,
)

MIGRATE_TEMPLATE = """from tortoise import BaseDBAsyncClient

Expand Down Expand Up @@ -279,10 +284,10 @@ def diff_models(
# m2m fields
old_m2m_fields = cast(List[dict], old_model_describe.get("m2m_fields"))
new_m2m_fields = cast(List[dict], new_model_describe.get("m2m_fields"))
if old_m2m_fields and len(new_m2m_fields) >= 2:
length = len(old_m2m_fields)
field_index = {f["name"]: i for i, f in enumerate(new_m2m_fields)}
new_m2m_fields.sort(key=lambda field: field_index.get(field["name"], length))
if old_m2m_fields and new_m2m_fields:
old_m2m_fields, new_m2m_fields = reorder_m2m_fields(
old_m2m_fields, new_m2m_fields
)
for action, option, change in diff(old_m2m_fields, new_m2m_fields):
if (option and option[-1] == "nullable") or change[0][0] == "db_constraint":
continue
Expand Down
111 changes: 111 additions & 0 deletions aerich/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import importlib.util
import os
import re
Expand Down Expand Up @@ -101,3 +103,112 @@ def import_py_file(file: Union[str, Path]) -> ModuleType:
module = importlib.util.module_from_spec(spec) # type:ignore[arg-type]
spec.loader.exec_module(module) # type:ignore[union-attr]
return module


def _pick_match_to_head(m2m_fields: list[dict], field_info: dict) -> list[dict]:
"""
If there is a element in m2m_fields whose through or name is equal to field_info's
and this element is not at the first position, put it to the head then return the
new list, otherwise return the origin list.

Example::
>>> m2m_fields = [{'through': 'u1', 'name': 'u1'}, {'throught': 'u2', 'name': 'u2'}]
waketzheng marked this conversation as resolved.
Show resolved Hide resolved
>>> _pick_match_to_head(m2m_fields, {'through': 'u2', 'name': 'u2'})
[{'through': 'u2', 'name': 'u2'}, {'through': 'u1', 'name': 'u1'}]

"""
through = field_info["through"]
name = field_info["name"]
for index, field in enumerate(m2m_fields):
if field["through"] == through or field["name"] == name:
if index != 0:
m2m_fields = [field, *m2m_fields[:index], *m2m_fields[index + 1 :]]
break
return m2m_fields


def reorder_m2m_fields(
old_m2m_fields: list[dict], new_m2m_fields: list[dict]
) -> tuple[list[dict], list[dict]]:
"""
Reorder m2m fields to help dictdiffer.diff generate more precise changes
waketzheng marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please explain why we need to order the fields in this way? Is the code that follows this expects some sort of ordering? What are the criteria for this ordering?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. why we need to order the fields in this way?
    Good question~
    May be a custom diff function is a better solution:
-                if old_m2m_fields and new_m2m_fields:
-                    old_m2m_fields, new_m2m_fields = reorder_m2m_fields(
-                        old_m2m_fields, new_m2m_fields
-                    )
-                for action, option, change in diff(old_m2m_fields, new_m2m_fields):
+                for action, option, change in diff_plus(old_m2m_fields, new_m2m_fields):
  1. Is the code that follows this expects some sort of ordering? What are the criteria for this ordering?
    All cases are covered by unittest: https://github.com/tortoise/aerich/pull/376/files#diff-33c13e0b177bacd2f02e29bcb8aea5b49e7ce34901fd8f41fefb65defba1bd33R7

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the reply!

I'm not saying that a custom diff function is better, I'm just trying to understand why this ordering is required. if it's really hard to explain with a spoken language, it's a good sign that something isn't right.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@henadzit I updated the PR description

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@henadzit I updated the PR description

@henadzit Thanks for you comments. I created a new PR, which fix the issue by another solution: #390

:param old_m2m_fields: previous m2m field info list
:param new_m2m_fields: current m2m field info list
:return: ordered old/new m2m fields
"""
length_old, length_new = len(old_m2m_fields), len(new_m2m_fields)
if length_old == length_new == 1:
# No need to change order if both of them have only one element
pass
elif length_old == 1:
# If any element of new fields match the one in old fields, put it to head
new_m2m_fields = _pick_match_to_head(new_m2m_fields, old_m2m_fields[0])
elif length_new == 1:
old_m2m_fields = _pick_match_to_head(old_m2m_fields, new_m2m_fields[0])
else:
old_table_names = [f["through"] for f in old_m2m_fields]
new_table_names = [f["through"] for f in new_m2m_fields]
old_field_names = [f["name"] for f in old_m2m_fields]
new_field_names = [f["name"] for f in new_m2m_fields]
if old_table_names == new_table_names:
pass
elif sorted(old_table_names) == sorted(new_table_names):
# If table name are the same but order not match,
# reorder new fields by through to match the order of old.

# Case like::
# old_m2m_fields = [
# {'through': 'users_group', 'name': 'users',},
# {'through': 'admins_group', 'name': 'admins'},
# ]
# new_m2m_fields = [
# {'through': 'admins_group', 'name': 'admins_new'},
# {'through': 'users_group', 'name': 'users_new',},
# ]
new_m2m_fields = sorted(
new_m2m_fields, key=lambda f: old_table_names.index(f["through"])
)
elif old_field_names == new_field_names:
pass
elif sorted(old_field_names) == sorted(new_field_names):
# Case like:
# old_m2m_fields = [
# {'name': 'users', 'through': 'users_group'},
# {'name': 'admins', 'through': 'admins_group'},
# ]
# new_m2m_fields = [
# {'name': 'admins', 'through': 'admin_group_map'},
# {'name': 'users', 'through': 'user_group_map'},
# ]
new_m2m_fields = sorted(new_m2m_fields, key=lambda f: old_field_names.index(f["name"]))
elif unchanged_table_names := set(old_table_names) & set(new_table_names):
# If old/new m2m field list have one or some unchanged table names, put them to head of list.

# Case like::
# old_m2m_fields = [
# {'through': 'users_group', 'name': 'users',},
# {'through': 'staffs_group', 'name': 'users',},
# {'through': 'admins_group', 'name': 'admins'},
# ]
# new_m2m_fields = [
# {'through': 'admins_group', 'name': 'admins_new'},
# {'through': 'users_group', 'name': 'users_new',},
# ]
unchanged = sorted(unchanged_table_names, key=lambda name: old_table_names.index(name))
ordered_old_tables = unchanged + sorted(
set(old_table_names) - unchanged_table_names,
key=lambda name: old_table_names.index(name),
)
ordered_new_tables = unchanged + sorted(
set(new_table_names) - unchanged_table_names,
key=lambda name: new_table_names.index(name),
)
if ordered_old_tables != old_table_names:
old_m2m_fields = sorted(
old_m2m_fields, key=lambda f: ordered_old_tables.index(f["through"])
)
if ordered_new_tables != new_table_names:
new_m2m_fields = sorted(
new_m2m_fields, key=lambda f: ordered_new_tables.index(f["through"])
)
return old_m2m_fields, new_m2m_fields
3 changes: 3 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ class Meta:


class Config(Model):
categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField(
"models.Category", through="config_category_map", related_name="category_set"
)
label = fields.CharField(max_length=200)
key = fields.CharField(max_length=20)
value: dict = fields.JSONField()
Expand Down
4 changes: 4 additions & 0 deletions tests/old_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ class Product(Model):


class Config(Model):
category: fields.ManyToManyRelation[Category] = fields.ManyToManyField("models.Category")
categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField(
"models.Category", through="config_category_map", related_name="config_set"
)
name = fields.CharField(max_length=100, unique=True)
label = fields.CharField(max_length=200)
key = fields.CharField(max_length=20)
Expand Down
51 changes: 50 additions & 1 deletion tests/test_migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,48 @@
"backward_fk_fields": [],
"o2o_fields": [],
"backward_o2o_fields": [],
"m2m_fields": [],
"m2m_fields": [
{
"name": "category",
"field_type": "ManyToManyFieldInstance",
"python_type": "models.Category",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {},
"model_name": "models.Category",
"related_name": "configs",
"forward_key": "category_id",
"backward_key": "config_id",
"through": "config_category",
"on_delete": "CASCADE",
"_generated": False,
},
{
"name": "categories",
"field_type": "ManyToManyFieldInstance",
"python_type": "models.Category",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {},
"model_name": "models.Category",
"related_name": "config_set",
"forward_key": "category_id",
"backward_key": "config_id",
"through": "config_category_map",
"on_delete": "CASCADE",
"_generated": False,
},
],
},
"models.Email": {
"name": "models.Email",
Expand Down Expand Up @@ -898,6 +939,8 @@ def test_migrate(mocker: MockerFixture):
"ALTER TABLE `product` MODIFY COLUMN `body` LONGTEXT NOT NULL",
"ALTER TABLE `email` MODIFY COLUMN `is_primary` BOOL NOT NULL DEFAULT 0",
"CREATE TABLE `product_user` (\n `product_id` INT NOT NULL REFERENCES `product` (`id`) ON DELETE CASCADE,\n `user_id` INT NOT NULL REFERENCES `user` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4",
"CREATE TABLE `config_category_map` (\n `category_id` INT NOT NULL REFERENCES `category` (`id`) ON DELETE CASCADE,\n `config_id` INT NOT NULL REFERENCES `config` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4",
"DROP TABLE IF EXISTS `config_category`",
}
expected_downgrade_operators = {
"ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200) NOT NULL",
Expand Down Expand Up @@ -937,6 +980,8 @@ def test_migrate(mocker: MockerFixture):
"ALTER TABLE `user` MODIFY COLUMN `longitude` DECIMAL(12,9) NOT NULL",
"ALTER TABLE `product` MODIFY COLUMN `body` LONGTEXT NOT NULL",
"ALTER TABLE `email` MODIFY COLUMN `is_primary` BOOL NOT NULL DEFAULT 0",
"CREATE TABLE `config_category` (\n `config_id` INT NOT NULL REFERENCES `config` (`id`) ON DELETE CASCADE,\n `category_id` INT NOT NULL REFERENCES `category` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4",
"DROP TABLE IF EXISTS `config_category_map`",
}
assert not set(Migrate.upgrade_operators).symmetric_difference(expected_upgrade_operators)

Expand Down Expand Up @@ -983,6 +1028,8 @@ def test_migrate(mocker: MockerFixture):
'CREATE UNIQUE INDEX "uid_product_name_869427" ON "product" ("name", "type_db_alias")',
'CREATE UNIQUE INDEX "uid_user_usernam_9987ab" ON "user" ("username")',
'CREATE TABLE "product_user" (\n "product_id" INT NOT NULL REFERENCES "product" ("id") ON DELETE CASCADE,\n "user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE\n)',
'CREATE TABLE "config_category_map" (\n "category_id" INT NOT NULL REFERENCES "category" ("id") ON DELETE CASCADE,\n "config_id" INT NOT NULL REFERENCES "config" ("id") ON DELETE CASCADE\n)',
'DROP TABLE IF EXISTS "config_category"',
}
expected_downgrade_operators = {
'CREATE UNIQUE INDEX "uid_category_title_f7fc03" ON "category" ("title")',
Expand Down Expand Up @@ -1022,6 +1069,8 @@ def test_migrate(mocker: MockerFixture):
'DROP INDEX IF EXISTS "uid_product_name_869427"',
'DROP TABLE IF EXISTS "email_user"',
'DROP TABLE IF EXISTS "newmodel"',
'CREATE TABLE "config_category" (\n "config_id" INT NOT NULL REFERENCES "config" ("id") ON DELETE CASCADE,\n "category_id" INT NOT NULL REFERENCES "category" ("id") ON DELETE CASCADE\n)',
'DROP TABLE IF EXISTS "config_category_map"',
}
assert not set(Migrate.upgrade_operators).symmetric_difference(expected_upgrade_operators)
assert not set(Migrate.downgrade_operators).symmetric_difference(
Expand Down
Loading
Loading