Skip to content

Commit 6f8b524

Browse files
committed
Merge branch 'fix-380-rename-multi-fields' of github.com:waketzheng/aerich into fix-380-rename-multi-fields
2 parents 5d460be + 19adfe8 commit 6f8b524

File tree

6 files changed

+150
-68
lines changed

6 files changed

+150
-68
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
### [0.8.1](Unreleased)
66

77
#### Fixed
8+
- fix: add o2o field does not create constraint when migrating. (#396)
89
- Migration with duplicate renaming of columns in some cases. (#395)
910
- fix: intermediate table for m2m relation not created. (#394)
1011
- Migrate add m2m field with custom through generate duplicated table. (#393)
@@ -16,6 +17,7 @@
1617
- Fix configuration file reading error when containing Chinese characters. (#286)
1718
- sqlite: failed to create/drop index. (#302)
1819
- PostgreSQL: Cannot drop constraint after deleting or rename FK on a model. (#378)
20+
- Fix create/drop indexes in every migration. (#377)
1921
- Sort m2m fields before comparing them with diff. (#271)
2022

2123
#### Changed

aerich/migrate.py

+113-67
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
import hashlib
1+
from __future__ import annotations
2+
23
import importlib
34
import os
45
from datetime import datetime
56
from pathlib import Path
67
from typing import Dict, Iterable, List, Optional, Set, Tuple, Type, Union, cast
78

89
import asyncclick as click
10+
import tortoise
911
from dictdiffer import diff
1012
from tortoise import BaseDBAsyncClient, ConfigurationError, Model, Tortoise
1113
from tortoise.exceptions import OperationalError
@@ -201,21 +203,25 @@ def _add_operator(cls, operator: str, upgrade=True, fk_m2m_index=False) -> None:
201203

202204
@classmethod
203205
def _handle_indexes(cls, model: Type[Model], indexes: List[Union[Tuple[str], Index]]) -> list:
204-
ret: list = []
205-
206-
def index_hash(self) -> str:
207-
h = hashlib.new("MD5", usedforsecurity=False) # type:ignore[call-arg]
208-
h.update(
209-
self.index_name(cls.ddl.schema_generator, model).encode()
210-
+ self.__class__.__name__.encode()
211-
)
212-
return h.hexdigest()
213-
214-
for index in indexes:
215-
if isinstance(index, Index):
216-
index.__hash__ = index_hash # type:ignore[method-assign,assignment]
217-
ret.append(index)
218-
return ret
206+
if tortoise.__version__ > "0.22.2":
207+
# The min version of tortoise is '0.11.0', so we can compare it by a `>`,
208+
# tortoise>0.22.2 have __eq__/__hash__ with Index class since 313ee76.
209+
return indexes
210+
if index_classes := set(index.__class__ for index in indexes if isinstance(index, Index)):
211+
# Leave magic patch here to compare with older version of tortoise-orm
212+
# TODO: limit tortoise>0.22.2 in pyproject.toml and remove this function when v0.9.0 released
213+
for index_cls in index_classes:
214+
if index_cls(fields=("id",)) != index_cls(fields=("id",)):
215+
216+
def _hash(self) -> int:
217+
return hash((tuple(sorted(self.fields)), self.name, self.expressions))
218+
219+
def _eq(self, other) -> bool:
220+
return type(self) is type(other) and self.__dict__ == other.__dict__
221+
222+
setattr(index_cls, "__hash__", _hash)
223+
setattr(index_cls, "__eq__", _eq)
224+
return indexes
219225

220226
@classmethod
221227
def _get_indexes(cls, model, model_describe: dict) -> Set[Union[Index, Tuple[str, ...]]]:
@@ -282,6 +288,68 @@ def _handle_m2m_fields(
282288
if add:
283289
cls._add_operator(cls.drop_m2m(table), upgrade, True)
284290

291+
@classmethod
292+
def _handle_relational(
293+
cls,
294+
key: str,
295+
old_model_describe: Dict,
296+
new_model_describe: Dict,
297+
model: Type[Model],
298+
old_models: Dict,
299+
new_models: Dict,
300+
upgrade=True,
301+
) -> None:
302+
old_fk_fields = cast(List[dict], old_model_describe.get(key))
303+
new_fk_fields = cast(List[dict], new_model_describe.get(key))
304+
305+
old_fk_fields_name: List[str] = [i.get("name", "") for i in old_fk_fields]
306+
new_fk_fields_name: List[str] = [i.get("name", "") for i in new_fk_fields]
307+
308+
# add
309+
for new_fk_field_name in set(new_fk_fields_name).difference(set(old_fk_fields_name)):
310+
fk_field = cls.get_field_by_name(new_fk_field_name, new_fk_fields)
311+
if fk_field.get("db_constraint"):
312+
ref_describe = cast(dict, new_models[fk_field["python_type"]])
313+
sql = cls._add_fk(model, fk_field, ref_describe)
314+
cls._add_operator(sql, upgrade, fk_m2m_index=True)
315+
# drop
316+
for old_fk_field_name in set(old_fk_fields_name).difference(set(new_fk_fields_name)):
317+
old_fk_field = cls.get_field_by_name(old_fk_field_name, cast(List[dict], old_fk_fields))
318+
if old_fk_field.get("db_constraint"):
319+
ref_describe = cast(dict, old_models[old_fk_field["python_type"]])
320+
sql = cls._drop_fk(model, old_fk_field, ref_describe)
321+
cls._add_operator(sql, upgrade, fk_m2m_index=True)
322+
323+
@classmethod
324+
def _handle_fk_fields(
325+
cls,
326+
old_model_describe: Dict,
327+
new_model_describe: Dict,
328+
model: Type[Model],
329+
old_models: Dict,
330+
new_models: Dict,
331+
upgrade=True,
332+
) -> None:
333+
key = "fk_fields"
334+
cls._handle_relational(
335+
key, old_model_describe, new_model_describe, model, old_models, new_models, upgrade
336+
)
337+
338+
@classmethod
339+
def _handle_o2o_fields(
340+
cls,
341+
old_model_describe: Dict,
342+
new_model_describe: Dict,
343+
model: Type[Model],
344+
old_models: Dict,
345+
new_models: Dict,
346+
upgrade=True,
347+
) -> None:
348+
key = "o2o_fields"
349+
cls._handle_relational(
350+
key, old_model_describe, new_model_describe, model, old_models, new_models, upgrade
351+
)
352+
285353
@classmethod
286354
def diff_models(
287355
cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True
@@ -296,7 +364,7 @@ def diff_models(
296364
_aerich = f"{cls.app}.{cls._aerich}"
297365
old_models.pop(_aerich, None)
298366
new_models.pop(_aerich, None)
299-
models_with_rename_field: Set[str] = set()
367+
models_with_rename_field: Set[str] = set() # models that trigger the click.prompt
300368

301369
for new_model_str, new_model_describe in new_models.items():
302370
model = cls._get_model(new_model_describe["name"].split(".")[1])
@@ -336,6 +404,13 @@ def diff_models(
336404
# current only support rename pk
337405
if action == "change" and option == "name":
338406
cls._add_operator(cls._rename_field(model, *change), upgrade)
407+
# fk fields
408+
args = (old_model_describe, new_model_describe, model, old_models, new_models)
409+
cls._handle_fk_fields(*args, upgrade=upgrade)
410+
# o2o fields
411+
cls._handle_o2o_fields(*args, upgrade=upgrade)
412+
old_o2o_columns = [i["raw_field"] for i in old_model_describe.get("o2o_fields", [])]
413+
new_o2o_columns = [i["raw_field"] for i in new_model_describe.get("o2o_fields", [])]
339414
# m2m fields
340415
cls._handle_m2m_fields(
341416
old_model_describe, new_model_describe, model, new_models, upgrade
@@ -369,12 +444,10 @@ def diff_models(
369444
new_data_fields_name = cast(List[str], [i.get("name") for i in new_data_fields])
370445

371446
# add fields or rename fields
372-
rename_fields: Dict[str, str] = {}
373447
for new_data_field_name in set(new_data_fields_name).difference(
374448
set(old_data_fields_name)
375449
):
376450
new_data_field = cls.get_field_by_name(new_data_field_name, new_data_fields)
377-
model_rename_fields = cls._rename_fields.get(new_model_str)
378451
is_rename = False
379452
field_type = new_data_field.get("field_type")
380453
db_column = new_data_field.get("db_column")
@@ -399,8 +472,11 @@ def diff_models(
399472
and old_data_field_name not in new_data_fields_name
400473
):
401474
if upgrade:
402-
if old_data_field_name in rename_fields or (
403-
new_data_field_name in rename_fields.values()
475+
if (
476+
rename_fields := cls._rename_fields.get(new_model_str)
477+
) and (
478+
old_data_field_name in rename_fields
479+
or new_data_field_name in rename_fields.values()
404480
):
405481
continue
406482
prefix = f"({new_model_str}) "
@@ -417,22 +493,18 @@ def diff_models(
417493
show_choices=True,
418494
)
419495
if is_rename:
496+
if rename_fields is None:
497+
rename_fields = cls._rename_fields[new_model_str] = {}
420498
rename_fields[old_data_field_name] = new_data_field_name
421499
else:
422500
is_rename = False
423-
if model_rename_fields and (
424-
rename_to := model_rename_fields.get(new_data_field_name)
501+
if rename_to := cls._rename_fields.get(new_model_str, {}).get(
502+
new_data_field_name
425503
):
426504
is_rename = True
427505
if rename_to != old_data_field_name:
428506
continue
429507
if is_rename:
430-
if upgrade:
431-
if new_model_str not in cls._rename_fields:
432-
cls._rename_fields[new_model_str] = {}
433-
cls._rename_fields[new_model_str][
434-
old_data_field_name
435-
] = new_data_field_name
436508
# only MySQL8+ has rename syntax
437509
if (
438510
cls.dialect == "mysql"
@@ -452,7 +524,10 @@ def diff_models(
452524
)
453525
if not is_rename:
454526
cls._add_operator(cls._add_field(model, new_data_field), upgrade)
455-
if new_data_field["indexed"]:
527+
if (
528+
new_data_field["indexed"]
529+
and new_data_field["db_column"] not in new_o2o_columns
530+
):
456531
cls._add_operator(
457532
cls._add_index(
458533
model, (new_data_field["db_column"],), new_data_field["unique"]
@@ -461,14 +536,14 @@ def diff_models(
461536
True,
462537
)
463538
# remove fields
464-
model_rename_fields = cls._rename_fields.get(new_model_str)
539+
rename_fields = cls._rename_fields.get(new_model_str)
465540
for old_data_field_name in set(old_data_fields_name).difference(
466541
set(new_data_fields_name)
467542
):
468543
# don't remove field if is renamed
469-
if model_rename_fields and (
470-
(upgrade and old_data_field_name in model_rename_fields)
471-
or (not upgrade and old_data_field_name in model_rename_fields.values())
544+
if rename_fields and (
545+
(upgrade and old_data_field_name in rename_fields)
546+
or (not upgrade and old_data_field_name in rename_fields.values())
472547
):
473548
continue
474549
old_data_field = cls.get_field_by_name(old_data_field_name, old_data_fields)
@@ -477,46 +552,17 @@ def diff_models(
477552
cls._remove_field(model, db_column),
478553
upgrade,
479554
)
480-
if old_data_field["indexed"]:
555+
if (
556+
old_data_field["indexed"]
557+
and old_data_field["db_column"] not in old_o2o_columns
558+
):
481559
is_unique_field = old_data_field.get("unique")
482560
cls._add_operator(
483561
cls._drop_index(model, {db_column}, is_unique_field),
484562
upgrade,
485563
True,
486564
)
487565

488-
old_fk_fields = cast(List[dict], old_model_describe.get("fk_fields"))
489-
new_fk_fields = cast(List[dict], new_model_describe.get("fk_fields"))
490-
491-
old_fk_fields_name: List[str] = [i.get("name", "") for i in old_fk_fields]
492-
new_fk_fields_name: List[str] = [i.get("name", "") for i in new_fk_fields]
493-
494-
# add fk
495-
for new_fk_field_name in set(new_fk_fields_name).difference(
496-
set(old_fk_fields_name)
497-
):
498-
fk_field = cls.get_field_by_name(new_fk_field_name, new_fk_fields)
499-
if fk_field.get("db_constraint"):
500-
ref_describe = cast(dict, new_models[fk_field["python_type"]])
501-
cls._add_operator(
502-
cls._add_fk(model, fk_field, ref_describe),
503-
upgrade,
504-
fk_m2m_index=True,
505-
)
506-
# drop fk
507-
for old_fk_field_name in set(old_fk_fields_name).difference(
508-
set(new_fk_fields_name)
509-
):
510-
old_fk_field = cls.get_field_by_name(
511-
old_fk_field_name, cast(List[dict], old_fk_fields)
512-
)
513-
if old_fk_field.get("db_constraint"):
514-
ref_describe = cast(dict, old_models[old_fk_field["python_type"]])
515-
cls._add_operator(
516-
cls._drop_fk(model, old_fk_field, ref_describe),
517-
upgrade,
518-
fk_m2m_index=True,
519-
)
520566
# change fields
521567
for field_name in set(new_data_fields_name).intersection(set(old_data_fields_name)):
522568
old_data_field = cls.get_field_by_name(field_name, old_data_fields)

tests/indexes.py

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from tortoise.indexes import Index
2+
3+
4+
class CustomIndex(Index):
5+
def __init__(self, *args, **kw) -> None:
6+
super().__init__(*args, **kw)
7+
self._foo = ""

tests/models.py

+10
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
from enum import IntEnum
44

55
from tortoise import Model, fields
6+
from tortoise.indexes import Index
7+
8+
from tests.indexes import CustomIndex
69

710

811
class ProductType(IntEnum):
@@ -33,13 +36,18 @@ class User(Model):
3336

3437
products: fields.ManyToManyRelation["Product"]
3538

39+
class Meta:
40+
# reverse indexes elements
41+
indexes = [CustomIndex(fields=("is_superuser",)), Index(fields=("username", "is_active"))]
42+
3643

3744
class Email(Model):
3845
email_id = fields.IntField(primary_key=True)
3946
email = fields.CharField(max_length=200, db_index=True)
4047
is_primary = fields.BooleanField(default=False)
4148
address = fields.CharField(max_length=200)
4249
users: fields.ManyToManyRelation[User] = fields.ManyToManyField("models.User")
50+
config: fields.OneToOneRelation["Config"] = fields.OneToOneField("models.Config")
4351

4452

4553
def default_name():
@@ -92,6 +100,8 @@ class Config(Model):
92100
"models.User", description="User"
93101
)
94102

103+
email: fields.OneToOneRelation["Email"]
104+
95105

96106
class NewModel(Model):
97107
name = fields.CharField(max_length=50)

tests/old_models.py

+6
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
from enum import IntEnum
33

44
from tortoise import Model, fields
5+
from tortoise.indexes import Index
6+
7+
from tests.indexes import CustomIndex
58

69

710
class ProductType(IntEnum):
@@ -31,6 +34,9 @@ class User(Model):
3134
intro = fields.TextField(default="")
3235
longitude = fields.DecimalField(max_digits=12, decimal_places=9)
3336

37+
class Meta:
38+
indexes = [Index(fields=("username", "is_active")), CustomIndex(fields=("is_superuser",))]
39+
3440

3541
class Email(Model):
3642
email = fields.CharField(max_length=200)

0 commit comments

Comments
 (0)