Skip to content

Commit

Permalink
Fixed array item types not being adapted or imported
Browse files Browse the repository at this point in the history
Fixes #46 and closes #47 and closes #53.
  • Loading branch information
agronholm committed May 19, 2018
1 parent deedc36 commit 09ff674
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 6 deletions.
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ Version history

* Fixed invalid class names being generated (fixes #60; PR by Dan O'Huiginn)

* Fixed array item types not being adapted or imported (thanks to Martin Glauer for help)


1.1.6
-----
Expand Down
17 changes: 11 additions & 6 deletions sqlacodegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import sqlalchemy
from sqlalchemy import (
Enum, ForeignKeyConstraint, PrimaryKeyConstraint, CheckConstraint, UniqueConstraint, Table,
Column)
Column, ARRAY)
from sqlalchemy.schema import ForeignKey
from sqlalchemy.sql.expression import TextClause
from sqlalchemy.types import Boolean, String
Expand Down Expand Up @@ -113,10 +113,7 @@ def __init__(self, table):

def _get_adapted_type(self, coltype):
for supercls in coltype.__class__.__mro__:
if (not supercls.__name__.startswith('_') and
supercls.__name__ != supercls.__name__.upper() and
hasattr(supercls, '__visit_name__')):

if not supercls.__name__.startswith('_') and hasattr(supercls, '__visit_name__'):
# Hack to fix adaptation of the Enum class which is broken since SQLAlchemy 1.2
kw = {}
if supercls is Enum:
Expand All @@ -127,7 +124,12 @@ def _get_adapted_type(self, coltype):
for key, value in kw.items():
setattr(coltype, key, value)

break
if isinstance(coltype, ARRAY):
coltype.item_type = self._get_adapted_type(coltype.item_type)

# Stop on the first valid non-uppercase column type class
if supercls.__name__ != supercls.__name__.upper():
break

return coltype

Expand All @@ -140,6 +142,9 @@ def add_imports(self, collector):
if column.server_default:
collector.add_literal_import('sqlalchemy', 'text')

if isinstance(column.type, ARRAY):
collector.add_import(column.type.item_type.__class__)

for constraint in sorted(self.table.constraints, key=_get_constraint_sort_key):
if isinstance(constraint, ForeignKeyConstraint):
if len(constraint.columns) > 1:
Expand Down
22 changes: 22 additions & 0 deletions tests/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from sqlalchemy.dialects.mysql import base as mysql
from sqlalchemy.dialects.mysql.base import TINYINT
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy.dialects.postgresql.base import BIGINT, DOUBLE_PRECISION, BOOLEAN, ENUM
from sqlalchemy.engine import create_engine
from sqlalchemy.schema import (
Expand Down Expand Up @@ -83,6 +84,27 @@ def test_boolean_detection(self):
Column('bool2', Boolean),
Column('bool3', Boolean)
)
"""

def test_arrays(self):
Table(
'simple_items', self.metadata,
Column('dp_array', ARRAY(DOUBLE_PRECISION(precision=53))),
Column('int_array', ARRAY(INTEGER))
)

assert self.generate_code() == """\
# coding: utf-8
from sqlalchemy import ARRAY, Column, Float, Integer, MetaData, Table
metadata = MetaData()
t_simple_items = Table(
'simple_items', metadata,
Column('dp_array', ARRAY(Float(precision=53))),
Column('int_array', ARRAY(Integer()))
)
"""

def test_enum_detection(self):
Expand Down

0 comments on commit 09ff674

Please sign in to comment.