From 3ae0387f8ce3d4f4559b7d1363f216ea6cbc613f Mon Sep 17 00:00:00 2001 From: xiechen Date: Wed, 11 Aug 2021 17:36:07 +0800 Subject: [PATCH] increase: 1. Inspectdb adds DECIMAL, DOUBLE, CHAR, TIME data type matching; 2. Add exception handling, avoid the need to manually create the entire table because a certain data type is not supported. --- aerich/inspectdb.py | 90 +++++++++++++++++++++++++++++++-------------- 1 file changed, 63 insertions(+), 27 deletions(-) diff --git a/aerich/inspectdb.py b/aerich/inspectdb.py index b2eb3aa..c934b4a 100644 --- a/aerich/inspectdb.py +++ b/aerich/inspectdb.py @@ -8,15 +8,34 @@ class InspectDb: _table_template = "class {table}(Model):\n" _field_template_mapping = { - "INT": " {field} = fields.IntField({pk}{unique}{comment})", - "SMALLINT": " {field} = fields.IntField({pk}{unique}{comment})", - "TINYINT": " {field} = fields.BooleanField({null}{default}{comment})", - "VARCHAR": " {field} = fields.CharField({pk}{unique}{length}{null}{default}{comment})", - "LONGTEXT": " {field} = fields.TextField({null}{default}{comment})", - "TEXT": " {field} = fields.TextField({null}{default}{comment})", - "DATETIME": " {field} = fields.DatetimeField({null}{default}{comment})", - "FLOAT": " {field} = fields.FloatField({null}{default}{comment})", + # Numerical type + "TINYINT": "{field} = fields.BooleanField({null}{default}{comment})", + "SMALLINT": "{field} = fields.SmallIntField({pk}{unique}{comment})", + "INT": "{field} = fields.IntField({pk}{unique}{comment})", + "BIGINT": "{field} = fields.SmallIntField({pk}{unique}{comment})", + "FLOAT": "{field} = fields.FloatField({null}{default}{comment})", + "DECIMAL": "{field} = fields.DecimalField({null}{default}{other})", # 新增 + "DOUBLE": "{field} = fields.FloatField({null}{default}{comment})", # 新增 + + # String type + "CHAR": "{field} = fields.CharField({pk}{unique}{length}{null}{default}{comment})", # 新增 + "VARCHAR": "{field} = fields.CharField({pk}{unique}{length}{null}{default}{comment})", + "LONGTEXT": "{field} = fields.TextField({null}{default}{comment})", + "TEXT": "{field} = fields.TextField({null}{default}{comment})", + + # Date and time type + "DATETIME": "{field} = fields.DatetimeField({null}{default}{comment})", + "TIME": "{field} = fields.TimeDeltaField({null}{default}{comment})" # 新增 } + _text = """ + _____ _ _ _ + |_ _| | | | | | + | | _ __ ___ _ __ ___ ___| |_ __| | |__ + | | | '_ \/ __| '_ \ / _ \/ __| __/ _` | '_ \ + _| |_| | | \__ \ |_) | __/ (__| || (_| | |_) | + |_____|_| |_|___/ .__/ \___|\___|\__\__,_|_.__/ + | | + |_| \n\n\n""" def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None): self.conn = conn @@ -49,14 +68,16 @@ async def inspect(self): model = self._table_template.format(table=name) for column_name, column in columns.items(): comment = default = length = unique = null = pk = "" - if column.primary_key: - pk = "pk=True, " - if column.unique: - unique = "unique=True, " + other = {} + if column.data_type == "VARCHAR": length = f"max_length={column.length}, " - if not column.not_null: - null = "null=True, " + elif column.data_type == "CHAR": + length = f"max_length={column.length}, " + elif column.data_type == "DECIMAL": + other['max_digits'] = column.length + other['decimal_places'] = column.scale + if column.default is not None: if column.data_type == "TINYINT": default = f"default={'True' if column.default == '1' else 'False'}, " @@ -66,21 +87,36 @@ async def inspect(self): default = "auto_now_add=True, " else: default = "auto_now=True, " - else: - default = f"default={column.default}, " + elif column.default != "NULL": + default = f'default={column.default}, ' + + if not column.not_null: + null = "null=True, " + + if column.primary_key: + pk = "pk=True, " + + if column.unique: + unique = "unique=True, " if column.comment: comment = f"description='{column.comment}', " - field = self._field_template_mapping[column.data_type].format( - field=column_name, - pk=pk, - unique=unique, - length=length, - null=null, - default=default, - comment=comment, - ) + try: + field = self._field_template_mapping[column.data_type].format( + field=column_name, + pk=pk, + unique=unique, + length=length, + null=null, + default=default, + comment=comment, + other=', '.join([f'{parameter}={value}' for parameter, value in other.items()]) + ) + # Avoid the need to manually create the entire table because a certain data type is not supported + except KeyError: + field = f"{column_name} = The {column.data_type} data type type is not currently supported, please add it manually." fields.append(field) - tables.append(model + "\n".join(fields)) - sys.stdout.write(result + "\n\n\n".join(tables)) + tables.append(model + " " + "\n ".join(fields)) + sys.stdout.write(self._text) + sys.stdout.write(result + "\n\n\n".join(tables) + '\n\n')