Skip to content

Commit

Permalink
BUG: Fix incorrect assumptions about attached SQLite databases
Browse files Browse the repository at this point in the history
Author: Phillip Cloud <cpcloud@gmail.com>

Closes ibis-project#1937 from cpcloud/sqlite-sqlalchemy and squashes the following commits:

f3cda71 [Phillip Cloud] Remove constraints on windows
22913f2 [Phillip Cloud] Don't use type hints for now
a018895 [Phillip Cloud] Py35 lint
a4ec057 [Phillip Cloud] Disposal is in fact needed
c9e5d72 [Phillip Cloud] BUG: Fix incorrect assumptions about attached SQLite databases
  • Loading branch information
cpcloud authored Aug 22, 2019
1 parent 277ff1b commit b748593
Show file tree
Hide file tree
Showing 10 changed files with 113 additions and 81 deletions.
2 changes: 1 addition & 1 deletion ci/azure/windows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ jobs:
- script: conda update --all conda=$(conda.version)
displayName: 'Update conda and install an appropriate version'

- script: conda create --name $(conda.env) python=$(python.version) numpy pandas pytables ruamel.yaml jinja2 pyarrow multipledispatch pymysql "sqlalchemy<1.3.6" psycopg2 graphviz click mock plumbum geopandas toolz regex
- script: conda create --name $(conda.env) python=$(python.version) numpy pandas pytables ruamel.yaml jinja2 pyarrow multipledispatch pymysql "sqlalchemy>=1.1" psycopg2 graphviz click mock plumbum geopandas toolz regex
displayName: 'Create conda environment'

- script: |
Expand Down
2 changes: 1 addition & 1 deletion ci/requirements-3.5-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ dependencies:
- regex
- requests
- ruamel.yaml
- sqlalchemy>=1.1,<1.3.7
- sqlalchemy>=1.1
- thrift>=0.9.3
# required for impyla in case of py3
- thriftpy
Expand Down
2 changes: 1 addition & 1 deletion ci/requirements-3.6-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ dependencies:
- requests
- ruamel.yaml
- shapely
- sqlalchemy>=1.1,<1.3.7
- sqlalchemy>=1.1
- thrift>=0.9.3
- thriftpy2 # required for impyla in case of py3
- toolz
Expand Down
2 changes: 1 addition & 1 deletion ci/requirements-3.7-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ dependencies:
- requests
- ruamel.yaml
- shapely
- sqlalchemy>=1.1,<1.3.7
- sqlalchemy>=1.1
- thrift>=0.9.3
- thriftpy2 # required for impyla in case of py3
- toolz
Expand Down
83 changes: 58 additions & 25 deletions ibis/sql/alchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import functools
import operator
import sys
from typing import List, Optional

import pandas as pd
import sqlalchemy as sa
Expand Down Expand Up @@ -256,7 +257,7 @@ def schema_from_table(table, schema=None):
return sch.schema(pairs)


def table_from_schema(name, meta, schema):
def table_from_schema(name, meta, schema, database: Optional[str] = None):
# Convert Ibis schema to SQLA table
columns = []

Expand All @@ -265,7 +266,7 @@ def table_from_schema(name, meta, schema):
column = sa.Column(colname, satype, nullable=dtype.nullable)
columns.append(column)

return sa.Table(name, meta, *columns)
return sa.Table(name, meta, schema=database, *columns)


def _variance_reduction(func_name):
Expand Down Expand Up @@ -565,8 +566,10 @@ def _window(t, expr):
return t.translate(arg)

if window.max_lookback is not None:
raise NotImplementedError('Rows with max lookback is not implemented '
'for SQLAlchemy-based backends.')
raise NotImplementedError(
'Rows with max lookback is not implemented '
'for SQLAlchemy-based backends.'
)

# Some analytic functions need to have the expression of interest in
# the ORDER BY part of the window clause
Expand Down Expand Up @@ -1039,7 +1042,7 @@ class AlchemyClient(SQLClient):
dialect = AlchemyDialect
query_class = AlchemyQuery

def __init__(self, con):
def __init__(self, con: sa.engine.Engine) -> None:
super().__init__()
self.con = con
self.meta = sa.MetaData(bind=con)
Expand Down Expand Up @@ -1079,7 +1082,9 @@ def create_table(self, name, expr=None, schema=None, database=None):
schema = expr.schema()

self._schemas[self._fully_qualified_name(name, database)] = schema
t = table_from_schema(name, self.meta, schema)
t = self._table_from_schema(
name, schema, database=database or self.current_database
)

with self.begin() as bind:
t.create(bind=bind)
Expand All @@ -1088,15 +1093,34 @@ def create_table(self, name, expr=None, schema=None, database=None):
t.insert().from_select(list(expr.columns), expr.compile())
)

def _columns_from_schema(
self, name: str, schema: sch.Schema
) -> List[sa.Column]:
return [
sa.Column(colname, _to_sqla_type(dtype), nullable=dtype.nullable)
for colname, dtype in zip(schema.names, schema.types)
]

def _table_from_schema(
self, name: str, schema: sch.Schema, database: Optional[str] = None
) -> sa.Table:
columns = self._columns_from_schema(name, schema)
return sa.Table(name, self.meta, *columns)

@invalidates_reflection_cache
def drop_table(self, table_name, database=None, force=False):
if database is not None and database != self.engine.url.database:
def drop_table(
self,
table_name: str,
database: Optional[str] = None,
force: bool = False,
) -> None:
if database is not None and database != self.con.url.database:
raise NotImplementedError(
'Dropping tables from a different database is not yet '
'implemented'
)

t = sa.Table(table_name, self.meta)
t = self._get_sqla_table(table_name, schema=database, autoload=False)
t.drop(checkfirst=force)

assert (
Expand All @@ -1112,23 +1136,33 @@ def drop_table(self, table_name, database=None, force=False):
except KeyError: # schemas won't be cached if created with raw_sql
pass

def truncate_table(self, table_name, database=None):
self.meta.tables[table_name].delete().execute()
def truncate_table(
self, table_name: str, database: Optional[str] = None
) -> None:
t = self._get_sqla_table(table_name, schema=database)
t.delete().execute()

def list_tables(self, like=None, database=None, schema=None):
"""
List tables/views in the current (or indicated) database.
def list_tables(
self,
like: Optional[str] = None,
database: Optional[str] = None,
schema: Optional[str] = None,
) -> List[str]:
"""List tables/views in the current or indicated database.
Parameters
----------
like : string, default None
Checks for this string contained in name
database : string, default None
If not passed, uses the current/default database
like
Checks for this string contained in name
database
If not passed, uses the current database
schema
The schema namespace that tables should be listed from
Returns
-------
tables : list of strings
List[str]
"""
inspector = self.inspector
# inspector returns a mutable version of its names, so make a copy.
Expand All @@ -1138,19 +1172,18 @@ def list_tables(self, like=None, database=None, schema=None):
names = [x for x in names if like in x]
return sorted(names)

def _execute(self, query, results=True):
with self.begin() as con:
return AlchemyProxy(con.execute(query))
def _execute(self, query: str, results: bool = True):
return AlchemyProxy(self.con.execute(query))

@invalidates_reflection_cache
def raw_sql(self, query, results=False):
def raw_sql(self, query: str, results: bool = False):
return super().raw_sql(query, results=results)

def _build_ast(self, expr, context):
return build_ast(expr, context)

def _get_sqla_table(self, name, schema=None):
return sa.Table(name, self.meta, schema=schema, autoload=True)
def _get_sqla_table(self, name, schema=None, autoload=True):
return sa.Table(name, self.meta, schema=schema, autoload=autoload)

def _sqla_table_to_expr(self, table):
node = self.table_class(table, self)
Expand Down
38 changes: 12 additions & 26 deletions ibis/sql/postgres/client.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,6 @@
# Copyright 2015 Cloudera Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import getpass
from typing import Optional

import psycopg2 # NOQA fail early if the driver is missing
import sqlalchemy as sa
Expand Down Expand Up @@ -50,21 +37,20 @@ class PostgreSQLClient(alch.AlchemyClient):

def __init__(
self,
host='localhost',
user=None,
password=None,
port=5432,
database='public',
url=None,
driver='psycopg2',
host: str = 'localhost',
user: str = getpass.getuser(),
password: Optional[str] = None,
port: int = 5432,
database: str = 'public',
url: Optional[str] = None,
driver: str = 'psycopg2',
):
if url is None:
if driver != 'psycopg2':
raise NotImplementedError(
'psycopg2 is currently the only supported driver'
)
user = user or getpass.getuser()
url = sa.engine.url.URL(
sa_url = sa.engine.url.URL(
'postgresql+psycopg2',
host=host,
port=port,
Expand All @@ -73,10 +59,10 @@ def __init__(
database=database,
)
else:
url = sa.engine.url.make_url(url)
sa_url = sa.engine.url.make_url(url)

super().__init__(sa.create_engine(url))
self.database_name = url.database
super().__init__(sa.create_engine(sa_url))
self.database_name = sa_url.database

@contextlib.contextmanager
def begin(self):
Expand Down
44 changes: 30 additions & 14 deletions ibis/sql/sqlite/client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import errno
import functools
import inspect
import math
import os
from typing import Optional

import regex as re
import sqlalchemy as sa

import ibis.common.exceptions as com
import ibis.sql.alchemy as alch
from ibis.client import Database
from ibis.sql.sqlite.compiler import SQLiteDialect
Expand Down Expand Up @@ -312,19 +313,16 @@ def _register_aggregate(agg, con):


class SQLiteClient(alch.AlchemyClient):

"""
The Ibis SQLite client class
"""
"""The Ibis SQLite client class."""

dialect = SQLiteDialect
database_class = SQLiteDatabase
table_class = SQLiteTable

def __init__(self, path=None, create=False):
super().__init__(sa.create_engine('sqlite://'))
super().__init__(sa.create_engine("sqlite://"))
self.name = path
self.database_name = 'base'
self.database_name = "base"

if path is not None:
self.attach(self.database_name, path, create=create)
Expand All @@ -336,18 +334,18 @@ def __init__(self, path=None, create=False):
self.con.run_callable(functools.partial(_register_aggregate, agg))

@property
def current_database(self):
def current_database(self) -> Optional[str]:
return self.database_name

def list_databases(self):
raise NotImplementedError(
'Listing databases in SQLite is not implemented'
)

def set_database(self, name):
def set_database(self, name: str) -> None:
raise NotImplementedError('set_database is not implemented for SQLite')

def attach(self, name, path, create=False):
def attach(self, name, path, create: bool = False) -> None:
"""Connect another SQLite database file
Parameters
Expand All @@ -359,21 +357,32 @@ def attach(self, name, path, create=False):
create : boolean, optional
If file does not exist, create file if True otherwise raise an
Exception
"""
if not os.path.exists(path) and not create:
raise com.IbisError('File {!r} does not exist'.format(path))
raise FileNotFoundError(
errno.ENOENT, os.strerror(errno.ENOENT), path
)

quoted_name = self.con.dialect.identifier_preparer.quote(name)
self.raw_sql(
"ATTACH DATABASE {path!r} AS {name}".format(
path=path,
name=self.con.dialect.identifier_preparer.quote(name),
path=path, name=quoted_name
)
)

@property
def client(self):
return self

def _get_sqla_table(self, name, schema=None, autoload=True):
return sa.Table(
name,
self.meta,
schema=schema or self.current_database,
autoload=autoload,
)

def table(self, name, database=None):
"""
Create a table expression that references a particular table in the
Expand All @@ -387,7 +396,8 @@ def table(self, name, database=None):
Returns
-------
table : TableExpr
TableExpr
"""
alch_table = self._get_sqla_table(name, schema=database)
node = self.table_class(alch_table, self)
Expand All @@ -397,3 +407,9 @@ def list_tables(self, like=None, database=None, schema=None):
if database is None:
database = self.database_name
return super().list_tables(like, schema=database)

def _table_from_schema(
self, name, schema, database: Optional[str] = None
) -> sa.Table:
columns = self._columns_from_schema(name, schema)
return sa.Table(name, self.meta, schema=database, *columns)
6 changes: 1 addition & 5 deletions ibis/sql/sqlite/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,7 @@ def dbpath(data_directory):

@pytest.fixture(scope='module')
def con(dbpath):
con = ibis.sqlite.connect(dbpath)
try:
yield con
finally:
con.con.dispose()
return ibis.sqlite.connect(dbpath)


@pytest.fixture(scope='module')
Expand Down
Loading

0 comments on commit b748593

Please sign in to comment.