1
- import hashlib
1
+ from __future__ import annotations
2
+
2
3
import importlib
3
4
import os
4
5
from datetime import datetime
5
6
from pathlib import Path
6
7
from typing import Dict , Iterable , List , Optional , Set , Tuple , Type , Union , cast
7
8
8
9
import asyncclick as click
10
+ import tortoise
9
11
from dictdiffer import diff
10
12
from tortoise import BaseDBAsyncClient , ConfigurationError , Model , Tortoise
11
13
from tortoise .exceptions import OperationalError
@@ -201,21 +203,25 @@ def _add_operator(cls, operator: str, upgrade=True, fk_m2m_index=False) -> None:
201
203
202
204
@classmethod
203
205
def _handle_indexes (cls , model : Type [Model ], indexes : List [Union [Tuple [str ], Index ]]) -> list :
204
- ret : list = []
205
-
206
- def index_hash (self ) -> str :
207
- h = hashlib .new ("MD5" , usedforsecurity = False ) # type:ignore[call-arg]
208
- h .update (
209
- self .index_name (cls .ddl .schema_generator , model ).encode ()
210
- + self .__class__ .__name__ .encode ()
211
- )
212
- return h .hexdigest ()
213
-
214
- for index in indexes :
215
- if isinstance (index , Index ):
216
- index .__hash__ = index_hash # type:ignore[method-assign,assignment]
217
- ret .append (index )
218
- return ret
206
+ if tortoise .__version__ > "0.22.2" :
207
+ # The min version of tortoise is '0.11.0', so we can compare it by a `>`,
208
+ # tortoise>0.22.2 have __eq__/__hash__ with Index class since 313ee76.
209
+ return indexes
210
+ if index_classes := set (index .__class__ for index in indexes if isinstance (index , Index )):
211
+ # Leave magic patch here to compare with older version of tortoise-orm
212
+ # TODO: limit tortoise>0.22.2 in pyproject.toml and remove this function when v0.9.0 released
213
+ for index_cls in index_classes :
214
+ if index_cls (fields = ("id" ,)) != index_cls (fields = ("id" ,)):
215
+
216
+ def _hash (self ) -> int :
217
+ return hash ((tuple (sorted (self .fields )), self .name , self .expressions ))
218
+
219
+ def _eq (self , other ) -> bool :
220
+ return type (self ) is type (other ) and self .__dict__ == other .__dict__
221
+
222
+ setattr (index_cls , "__hash__" , _hash )
223
+ setattr (index_cls , "__eq__" , _eq )
224
+ return indexes
219
225
220
226
@classmethod
221
227
def _get_indexes (cls , model , model_describe : dict ) -> Set [Union [Index , Tuple [str , ...]]]:
@@ -282,6 +288,68 @@ def _handle_m2m_fields(
282
288
if add :
283
289
cls ._add_operator (cls .drop_m2m (table ), upgrade , True )
284
290
291
+ @classmethod
292
+ def _handle_relational (
293
+ cls ,
294
+ key : str ,
295
+ old_model_describe : Dict ,
296
+ new_model_describe : Dict ,
297
+ model : Type [Model ],
298
+ old_models : Dict ,
299
+ new_models : Dict ,
300
+ upgrade = True ,
301
+ ) -> None :
302
+ old_fk_fields = cast (List [dict ], old_model_describe .get (key ))
303
+ new_fk_fields = cast (List [dict ], new_model_describe .get (key ))
304
+
305
+ old_fk_fields_name : List [str ] = [i .get ("name" , "" ) for i in old_fk_fields ]
306
+ new_fk_fields_name : List [str ] = [i .get ("name" , "" ) for i in new_fk_fields ]
307
+
308
+ # add
309
+ for new_fk_field_name in set (new_fk_fields_name ).difference (set (old_fk_fields_name )):
310
+ fk_field = cls .get_field_by_name (new_fk_field_name , new_fk_fields )
311
+ if fk_field .get ("db_constraint" ):
312
+ ref_describe = cast (dict , new_models [fk_field ["python_type" ]])
313
+ sql = cls ._add_fk (model , fk_field , ref_describe )
314
+ cls ._add_operator (sql , upgrade , fk_m2m_index = True )
315
+ # drop
316
+ for old_fk_field_name in set (old_fk_fields_name ).difference (set (new_fk_fields_name )):
317
+ old_fk_field = cls .get_field_by_name (old_fk_field_name , cast (List [dict ], old_fk_fields ))
318
+ if old_fk_field .get ("db_constraint" ):
319
+ ref_describe = cast (dict , old_models [old_fk_field ["python_type" ]])
320
+ sql = cls ._drop_fk (model , old_fk_field , ref_describe )
321
+ cls ._add_operator (sql , upgrade , fk_m2m_index = True )
322
+
323
+ @classmethod
324
+ def _handle_fk_fields (
325
+ cls ,
326
+ old_model_describe : Dict ,
327
+ new_model_describe : Dict ,
328
+ model : Type [Model ],
329
+ old_models : Dict ,
330
+ new_models : Dict ,
331
+ upgrade = True ,
332
+ ) -> None :
333
+ key = "fk_fields"
334
+ cls ._handle_relational (
335
+ key , old_model_describe , new_model_describe , model , old_models , new_models , upgrade
336
+ )
337
+
338
+ @classmethod
339
+ def _handle_o2o_fields (
340
+ cls ,
341
+ old_model_describe : Dict ,
342
+ new_model_describe : Dict ,
343
+ model : Type [Model ],
344
+ old_models : Dict ,
345
+ new_models : Dict ,
346
+ upgrade = True ,
347
+ ) -> None :
348
+ key = "o2o_fields"
349
+ cls ._handle_relational (
350
+ key , old_model_describe , new_model_describe , model , old_models , new_models , upgrade
351
+ )
352
+
285
353
@classmethod
286
354
def diff_models (
287
355
cls , old_models : Dict [str , dict ], new_models : Dict [str , dict ], upgrade = True
@@ -296,7 +364,7 @@ def diff_models(
296
364
_aerich = f"{ cls .app } .{ cls ._aerich } "
297
365
old_models .pop (_aerich , None )
298
366
new_models .pop (_aerich , None )
299
- models_with_rename_field : Set [str ] = set ()
367
+ models_with_rename_field : Set [str ] = set () # models that trigger the click.prompt
300
368
301
369
for new_model_str , new_model_describe in new_models .items ():
302
370
model = cls ._get_model (new_model_describe ["name" ].split ("." )[1 ])
@@ -336,6 +404,13 @@ def diff_models(
336
404
# current only support rename pk
337
405
if action == "change" and option == "name" :
338
406
cls ._add_operator (cls ._rename_field (model , * change ), upgrade )
407
+ # fk fields
408
+ args = (old_model_describe , new_model_describe , model , old_models , new_models )
409
+ cls ._handle_fk_fields (* args , upgrade = upgrade )
410
+ # o2o fields
411
+ cls ._handle_o2o_fields (* args , upgrade = upgrade )
412
+ old_o2o_columns = [i ["raw_field" ] for i in old_model_describe .get ("o2o_fields" , [])]
413
+ new_o2o_columns = [i ["raw_field" ] for i in new_model_describe .get ("o2o_fields" , [])]
339
414
# m2m fields
340
415
cls ._handle_m2m_fields (
341
416
old_model_describe , new_model_describe , model , new_models , upgrade
@@ -369,12 +444,10 @@ def diff_models(
369
444
new_data_fields_name = cast (List [str ], [i .get ("name" ) for i in new_data_fields ])
370
445
371
446
# add fields or rename fields
372
- rename_fields : Dict [str , str ] = {}
373
447
for new_data_field_name in set (new_data_fields_name ).difference (
374
448
set (old_data_fields_name )
375
449
):
376
450
new_data_field = cls .get_field_by_name (new_data_field_name , new_data_fields )
377
- model_rename_fields = cls ._rename_fields .get (new_model_str )
378
451
is_rename = False
379
452
field_type = new_data_field .get ("field_type" )
380
453
db_column = new_data_field .get ("db_column" )
@@ -399,8 +472,11 @@ def diff_models(
399
472
and old_data_field_name not in new_data_fields_name
400
473
):
401
474
if upgrade :
402
- if old_data_field_name in rename_fields or (
403
- new_data_field_name in rename_fields .values ()
475
+ if (
476
+ rename_fields := cls ._rename_fields .get (new_model_str )
477
+ ) and (
478
+ old_data_field_name in rename_fields
479
+ or new_data_field_name in rename_fields .values ()
404
480
):
405
481
continue
406
482
prefix = f"({ new_model_str } ) "
@@ -417,22 +493,18 @@ def diff_models(
417
493
show_choices = True ,
418
494
)
419
495
if is_rename :
496
+ if rename_fields is None :
497
+ rename_fields = cls ._rename_fields [new_model_str ] = {}
420
498
rename_fields [old_data_field_name ] = new_data_field_name
421
499
else :
422
500
is_rename = False
423
- if model_rename_fields and (
424
- rename_to := model_rename_fields . get ( new_data_field_name )
501
+ if rename_to := cls . _rename_fields . get ( new_model_str , {}). get (
502
+ new_data_field_name
425
503
):
426
504
is_rename = True
427
505
if rename_to != old_data_field_name :
428
506
continue
429
507
if is_rename :
430
- if upgrade :
431
- if new_model_str not in cls ._rename_fields :
432
- cls ._rename_fields [new_model_str ] = {}
433
- cls ._rename_fields [new_model_str ][
434
- old_data_field_name
435
- ] = new_data_field_name
436
508
# only MySQL8+ has rename syntax
437
509
if (
438
510
cls .dialect == "mysql"
@@ -452,7 +524,10 @@ def diff_models(
452
524
)
453
525
if not is_rename :
454
526
cls ._add_operator (cls ._add_field (model , new_data_field ), upgrade )
455
- if new_data_field ["indexed" ]:
527
+ if (
528
+ new_data_field ["indexed" ]
529
+ and new_data_field ["db_column" ] not in new_o2o_columns
530
+ ):
456
531
cls ._add_operator (
457
532
cls ._add_index (
458
533
model , (new_data_field ["db_column" ],), new_data_field ["unique" ]
@@ -461,14 +536,14 @@ def diff_models(
461
536
True ,
462
537
)
463
538
# remove fields
464
- model_rename_fields = cls ._rename_fields .get (new_model_str )
539
+ rename_fields = cls ._rename_fields .get (new_model_str )
465
540
for old_data_field_name in set (old_data_fields_name ).difference (
466
541
set (new_data_fields_name )
467
542
):
468
543
# don't remove field if is renamed
469
- if model_rename_fields and (
470
- (upgrade and old_data_field_name in model_rename_fields )
471
- or (not upgrade and old_data_field_name in model_rename_fields .values ())
544
+ if rename_fields and (
545
+ (upgrade and old_data_field_name in rename_fields )
546
+ or (not upgrade and old_data_field_name in rename_fields .values ())
472
547
):
473
548
continue
474
549
old_data_field = cls .get_field_by_name (old_data_field_name , old_data_fields )
@@ -477,46 +552,17 @@ def diff_models(
477
552
cls ._remove_field (model , db_column ),
478
553
upgrade ,
479
554
)
480
- if old_data_field ["indexed" ]:
555
+ if (
556
+ old_data_field ["indexed" ]
557
+ and old_data_field ["db_column" ] not in old_o2o_columns
558
+ ):
481
559
is_unique_field = old_data_field .get ("unique" )
482
560
cls ._add_operator (
483
561
cls ._drop_index (model , {db_column }, is_unique_field ),
484
562
upgrade ,
485
563
True ,
486
564
)
487
565
488
- old_fk_fields = cast (List [dict ], old_model_describe .get ("fk_fields" ))
489
- new_fk_fields = cast (List [dict ], new_model_describe .get ("fk_fields" ))
490
-
491
- old_fk_fields_name : List [str ] = [i .get ("name" , "" ) for i in old_fk_fields ]
492
- new_fk_fields_name : List [str ] = [i .get ("name" , "" ) for i in new_fk_fields ]
493
-
494
- # add fk
495
- for new_fk_field_name in set (new_fk_fields_name ).difference (
496
- set (old_fk_fields_name )
497
- ):
498
- fk_field = cls .get_field_by_name (new_fk_field_name , new_fk_fields )
499
- if fk_field .get ("db_constraint" ):
500
- ref_describe = cast (dict , new_models [fk_field ["python_type" ]])
501
- cls ._add_operator (
502
- cls ._add_fk (model , fk_field , ref_describe ),
503
- upgrade ,
504
- fk_m2m_index = True ,
505
- )
506
- # drop fk
507
- for old_fk_field_name in set (old_fk_fields_name ).difference (
508
- set (new_fk_fields_name )
509
- ):
510
- old_fk_field = cls .get_field_by_name (
511
- old_fk_field_name , cast (List [dict ], old_fk_fields )
512
- )
513
- if old_fk_field .get ("db_constraint" ):
514
- ref_describe = cast (dict , old_models [old_fk_field ["python_type" ]])
515
- cls ._add_operator (
516
- cls ._drop_fk (model , old_fk_field , ref_describe ),
517
- upgrade ,
518
- fk_m2m_index = True ,
519
- )
520
566
# change fields
521
567
for field_name in set (new_data_fields_name ).intersection (set (old_data_fields_name )):
522
568
old_data_field = cls .get_field_by_name (field_name , old_data_fields )
0 commit comments