Skip to content

Commit

Permalink
Merge pull request #318 from microsoft/add-comments-support
Browse files Browse the repository at this point in the history
Add db_comment support
  • Loading branch information
dauinsight authored Dec 14, 2023
2 parents 32d6fb3 + bb2cb08 commit 072ea26
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 15 deletions.
1 change: 1 addition & 0 deletions mssql/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
requires_literal_defaults = True
requires_sqlparse_for_splitting = False
supports_boolean_expr_in_select_clause = False
supports_comments = True
supports_covering_indexes = True
supports_deferrable_unique_constraints = False
supports_expression_indexes = False
Expand Down
49 changes: 39 additions & 10 deletions mssql/introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
from django.db import DatabaseError
import pyodbc as Database

from collections import namedtuple

from django import VERSION
from django.db.backends.base.introspection import (
BaseDatabaseIntrospection, FieldInfo, TableInfo,
)
from django.db.backends.base.introspection import BaseDatabaseIntrospection
from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo
from django.db.backends.base.introspection import TableInfo as BaseTableInfo
from django.db.models.indexes import Index
from django.conf import settings

Expand All @@ -16,6 +18,8 @@
SQL_SMALLAUTOFIELD = -777333
SQL_TIMESTAMP_WITH_TIMEZONE = -155

FieldInfo = namedtuple("FieldInfo", BaseFieldInfo._fields + ("comment",))
TableInfo = namedtuple("TableInfo", BaseTableInfo._fields + ("comment",))

def get_schema_name():
return getattr(settings, 'SCHEMA_TO_INSPECT', 'SCHEMA_NAME()')
Expand Down Expand Up @@ -73,13 +77,26 @@ def get_table_list(self, cursor):
"""
Returns a list of table and view names in the current database.
"""
sql = 'SELECT TABLE_NAME, TABLE_TYPE FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = %s' % (
sql = """SELECT
TABLE_NAME,
TABLE_TYPE,
CAST(ep.value AS VARCHAR) AS COMMENT
FROM INFORMATION_SCHEMA.TABLES i
LEFT JOIN sys.tables t ON t.name = i.TABLE_NAME
LEFT JOIN sys.extended_properties ep ON t.object_id = ep.major_id
AND ((ep.name = 'MS_DESCRIPTION' AND ep.minor_id = 0) OR ep.value IS NULL)
AND i.TABLE_SCHEMA = %s""" % (
get_schema_name())
cursor.execute(sql)
types = {'BASE TABLE': 't', 'VIEW': 'v'}
return [TableInfo(row[0], types.get(row[1]))
for row in cursor.fetchall()
if row[0] not in self.ignored_tables]
if VERSION >= (4, 2):
return [TableInfo(row[0], types.get(row[1]), row[2])
for row in cursor.fetchall()
if row[0] not in self.ignored_tables]
else:
return [BaseTableInfo(row[0], types.get(row[1]))
for row in cursor.fetchall()
if row[0] not in self.ignored_tables]

def _is_auto_field(self, cursor, table_name, column_name):
"""
Expand Down Expand Up @@ -113,7 +130,7 @@ def get_table_description(self, cursor, table_name, identity_check=True):

if not columns:
raise DatabaseError(f"Table {table_name} does not exist.")

items = []
for column in columns:
if VERSION >= (3, 2):
Expand All @@ -128,7 +145,16 @@ def get_table_description(self, cursor, table_name, identity_check=True):
column.append(collation_name[0] if collation_name else '')
else:
column.append('')

if VERSION >= (4, 2):
sql = """select CAST(ep.value AS VARCHAR) AS COMMENT
FROM sys.columns c
INNER JOIN sys.tables t ON c.object_id = t.object_id
INNER JOIN sys.extended_properties ep ON c.object_id=ep.major_id AND ep.minor_id = c.column_id
WHERE t.name = '%s' AND c.name = '%s' AND ep.name = 'MS_Description'
""" % (table_name, column[0])
cursor.execute(sql)
comment = cursor.fetchone()
column.append(comment[0] if comment else '')
if identity_check and self._is_auto_field(cursor, table_name, column[0]):
if column[1] == Database.SQL_BIGINT:
column[1] = SQL_BIGAUTOFIELD
Expand All @@ -138,7 +164,10 @@ def get_table_description(self, cursor, table_name, identity_check=True):
column[1] = SQL_AUTOFIELD
if column[1] == Database.SQL_WVARCHAR and column[3] < 4000:
column[1] = Database.SQL_WCHAR
items.append(FieldInfo(*column))
if VERSION >= (4, 2):
items.append(FieldInfo(*column))
else:
items.append(BaseFieldInfo(*column))
return items

def get_sequences(self, cursor, table_name, table_fields=()):
Expand Down
99 changes: 94 additions & 5 deletions mssql/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,40 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
sql_rename_table = "EXEC sp_rename %(old_table)s, %(new_table)s"
sql_create_unique_null = "CREATE UNIQUE INDEX %(name)s ON %(table)s(%(columns)s) " \
"WHERE %(columns)s IS NOT NULL"

sql_alter_table_comment= """
IF NOT EXISTS (SELECT NULL FROM sys.extended_properties ep
WHERE ep.major_id = OBJECT_ID('%(table)s')
AND ep.name = 'MS_Description'
AND ep.minor_id = 0)
EXECUTE sp_addextendedproperty
@name = 'MS_Description', @value = %(comment)s,
@level0type = 'SCHEMA', @level0name = 'dbo',
@level1type = 'TABLE', @level1name = %(table)s
ELSE
EXECUTE sp_updateextendedproperty
@name = 'MS_Description', @value = %(comment)s,
@level0type = 'SCHEMA', @level0name = 'dbo',
@level1type = 'TABLE', @level1name = %(table)s
"""
sql_alter_column_comment= """
IF NOT EXISTS (SELECT NULL FROM sys.extended_properties ep
WHERE ep.major_id = OBJECT_ID('%(table)s')
AND ep.name = 'MS_Description'
AND ep.minor_id = (SELECT column_id FROM sys.columns
WHERE name = '%(column)s'
AND object_id = OBJECT_ID('%(table)s')))
EXECUTE sp_addextendedproperty
@name = 'MS_Description', @value = %(comment)s,
@level0type = 'SCHEMA', @level0name = 'dbo',
@level1type = 'TABLE', @level1name = %(table)s,
@level2type = 'COLUMN', @level2name = %(column)s
ELSE
EXECUTE sp_updateextendedproperty
@name = 'MS_Description', @value = %(comment)s,
@level0type = 'SCHEMA', @level0name = 'dbo',
@level1type = 'TABLE', @level1name = %(table)s,
@level2type = 'COLUMN', @level2name = %(column)s
"""
_deferred_unique_indexes = defaultdict(list)

def _alter_column_default_sql(self, model, old_field, new_field, drop=False):
Expand Down Expand Up @@ -138,7 +171,18 @@ def _alter_column_default_sql(self, model, old_field, new_field, drop=False):
},
params,
)


def _alter_column_comment_sql(self, model, new_field, new_type, new_db_comment):
return (
self.sql_alter_column_comment
% {
"table": self.quote_name(model._meta.db_table),
"column": new_field.column,
"comment": self._comment_sql(new_db_comment),
},
[],
)

def _alter_column_null_sql(self, model, old_field, new_field):
"""
Hook to specialize column null alteration.
Expand Down Expand Up @@ -316,7 +360,19 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type,

# Drop any FK constraints, we'll remake them later
fks_dropped = set()
if old_field.remote_field and old_field.db_constraint:
if (
old_field.remote_field
and old_field.db_constraint
and (django_version < (4,2)
or
(django_version >= (4, 2)
and self._field_should_be_altered(
old_field,
new_field,
ignore={"db_comment"})
)
)
):
# Drop index, SQL Server requires explicit deletion
if not hasattr(new_field, 'db_constraint') or not new_field.db_constraint:
index_names = self._constraint_names(model, [old_field.column], index=True)
Expand Down Expand Up @@ -446,8 +502,11 @@ def _alter_field(self, model, old_field, new_field, old_type, new_type,
actions = []
null_actions = []
post_actions = []
# Type change?
if old_type != new_type:
# Type or comment change?
if old_type != new_type or (django_version >= (4, 2) and
self.connection.features.supports_comments
and old_field.db_comment != new_field.db_comment
):
if django_version >= (4, 2):
fragment, other_actions = self._alter_column_type_sql(
model, old_field, new_field, new_type, old_collation=None, new_collation=None
Expand Down Expand Up @@ -922,6 +981,19 @@ def add_field(self, model, field):
"changes": changes_sql,
}
self.execute(sql, params)
# Add field comment, if required.
if django_version >= (4, 2):
if (
field.db_comment
and self.connection.features.supports_comments
and not self.connection.features.supports_comments_inline
):
field_type = db_params["type"]
self.execute(
*self._alter_column_comment_sql(
model, field, field_type, field.db_comment
)
)
# Add an index, if required
self.deferred_sql.extend(self._field_indexes_sql(model, field))
# Add any FK constraints later
Expand Down Expand Up @@ -1129,6 +1201,23 @@ def create_model(self, model):
# Prevent using [] as params, in the case a literal '%' is used in the definition
self.execute(sql, params or None)

if django_version >= (4, 2) and self.connection.features.supports_comments:
# Add table comment.
if model._meta.db_table_comment:
self.alter_db_table_comment(model, None, model._meta.db_table_comment)
# Add column comments.
if not self.connection.features.supports_comments_inline:
for field in model._meta.local_fields:
if field.db_comment:
field_db_params = field.db_parameters(
connection=self.connection
)
field_type = field_db_params["type"]
self.execute(
*self._alter_column_comment_sql(
model, field, field_type, field.db_comment
)
)
# Add any field index and index_together's (deferred as SQLite3 _remake_table needs it)
self.deferred_sql.extend(self._model_indexes_sql(model))
self.deferred_sql = list(set(self.deferred_sql))
Expand Down

0 comments on commit 072ea26

Please sign in to comment.