Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sa default # 49 #206

Merged
merged 4 commits into from
Nov 25, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions aiopg/sa/connection.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import asyncio

from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.dml import UpdateBase
from sqlalchemy.sql.ddl import DDLElement
from sqlalchemy.sql.dml import UpdateBase

from . import exc
from .result import ResultProxy
Expand All @@ -12,7 +12,6 @@


class SAConnection:

def __init__(self, connection, engine):
self._connection = connection
self._transaction = None
Expand Down Expand Up @@ -83,8 +82,8 @@ def _execute(self, query, *multiparams, **params):
raise exc.ArgumentError("Don't mix sqlalchemy SELECT "
"clause with positional "
"parameters")
compiled_parameters = [compiled.construct_params(
dp)]

compiled_parameters = [compiled.construct_params(dp)]
processed_parameters = []
processors = compiled._bind_processors
for compiled_params in compiled_parameters:
Expand Down
23 changes: 21 additions & 2 deletions aiopg/sa/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,39 @@
import json

import aiopg

from .connection import SAConnection
from .exc import InvalidRequestError
from ..utils import PY_35, _PoolContextManager, _PoolAcquireContextManager
from ..connection import TIMEOUT

from ..utils import PY_35, _PoolContextManager, _PoolAcquireContextManager

try:
from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
from sqlalchemy.dialects.postgresql.psycopg2 import PGCompiler_psycopg2
except ImportError: # pragma: no cover
raise ImportError('aiopg.sa requires sqlalchemy')


class APGCompiler_psycopg2(PGCompiler_psycopg2):
def construct_params(self, params=None, _group_number=None, _check=True):
pd = super().construct_params(params, _group_number, _check)

for column in self.prefetch:
pd[column.key] = self._exec_default(column.default)

return pd

def _exec_default(self, default):
if default.is_callable:
return default.arg(self.dialect)
else:
return default.arg


_dialect = PGDialect_psycopg2(json_serializer=json.dumps,
json_deserializer=lambda x: x)

_dialect.statement_compiler = APGCompiler_psycopg2
_dialect.implicit_returning = True
_dialect.supports_native_enum = True
_dialect.supports_smallserial = True # 9.2+
Expand Down
89 changes: 89 additions & 0 deletions tests/test_sa_default.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import asyncio
import datetime

import pytest
import sqlalchemy as sa
from sqlalchemy.sql.ddl import CreateTable

meta = sa.MetaData()
tbl = sa.Table('sa_tbl4', meta,
sa.Column('id', sa.Integer, nullable=False, primary_key=True),
sa.Column('id_sequence', sa.Integer, nullable=False,
default=sa.Sequence('id_sequence_seq')),
sa.Column('name', sa.String(255), nullable=False,
default='default test'),
sa.Column('count', sa.Integer, default=100, nullable=None),
sa.Column('date', sa.DateTime, default=datetime.datetime.now),
sa.Column('count_str', sa.Integer,
default=sa.func.length('abcdef')),
sa.Column('is_active', sa.Boolean, default=True))


@pytest.fixture
def engine(make_engine, loop):
@asyncio.coroutine
def start():
engine = yield from make_engine()
with (yield from engine) as conn:
yield from conn.execute('DROP TABLE IF EXISTS sa_tbl4')
yield from conn.execute('DROP SEQUENCE IF EXISTS id_sequence_seq')
yield from conn.execute(CreateTable(tbl))
yield from conn.execute('CREATE SEQUENCE id_sequence_seq')

return engine

return loop.run_until_complete(start())


@asyncio.coroutine
def test_default_fields(engine):
with (yield from engine) as conn:
yield from conn.execute(tbl.insert().values())

res = yield from conn.execute(tbl.select())
row = yield from res.fetchone()
assert row.count == 100
assert row.id_sequence == 1
assert row.count_str == 6
assert row.name == 'default test'
assert row.is_active is True
assert type(row.date) == datetime.datetime


@asyncio.coroutine
def test_default_fields_isnull(engine):
with (yield from engine) as conn:
yield from conn.execute(tbl.insert().values(
is_active=False,
date=None,
))

res = yield from conn.execute(tbl.select())
row = yield from res.fetchone()
assert row.count == 100
assert row.id_sequence == 1
assert row.count_str == 6
assert row.name == 'default test'
assert row.is_active is False
assert row.date is None


@asyncio.coroutine
def test_default_fields_edit(engine):
with (yield from engine) as conn:
date = datetime.datetime.now()
yield from conn.execute(tbl.insert().values(
name='edit name',
is_active=False,
date=date,
count=1,
))

res = yield from conn.execute(tbl.select())
row = yield from res.fetchone()
assert row.count == 1
assert row.id_sequence == 1
assert row.count_str == 6
assert row.name == 'edit name'
assert row.is_active is False
assert row.date == date