Skip to content

Commit

Permalink
SQL: support for other databases with SQLAlchemy
Browse files Browse the repository at this point in the history
  • Loading branch information
PrimozGodec committed Nov 10, 2020
1 parent 81ccf30 commit 81e243c
Show file tree
Hide file tree
Showing 10 changed files with 411 additions and 47 deletions.
2 changes: 2 additions & 0 deletions Orange/data/sql/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,5 @@
from .mssql import PymssqlBackend
except ImportError:
pass

from .alchemy_base import SQLAlchemyBackend, MSSqlAlchemy, MySqlAlchemy
259 changes: 259 additions & 0 deletions Orange/data/sql/backend/alchemy_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
from contextlib import contextmanager
from datetime import date, datetime, time
from typing import Optional, List, Iterable, Tuple, Any, Union

from sqlalchemy import create_engine, MetaData, select, Table, text, func
from sqlalchemy.exc import NoSuchTableError, ProgrammingError
from sqlalchemy.sql import Select
from sqlalchemy.sql.ddl import DDLElement
from sqlalchemy.sql.elements import and_, TextClause
from sqlalchemy.sql.sqltypes import NullType
from sqlalchemy.ext import compiler

from Orange.data import (
StringVariable,
TimeVariable,
ContinuousVariable,
DiscreteVariable, Domain,
)
from Orange.data.sql.backend import Backend
from Orange.data.sql.backend.base import ToSql, BackendError, TableDesc


class CreateTableAs(DDLElement):
    """DDL element describing a ``CREATE TABLE <name> AS <selectable>`` statement.

    The element only stores its two parts; the SQL text is produced by the
    compiler function registered for this class.

    Parameters
    ----------
    name
        Name of the table to create.
    selectable
        The query (or its SQL text) whose result fills the new table.
    """

    def __init__(self, name, selectable):
        # Both parts are kept verbatim for the compile step.
        self.name, self.selectable = name, selectable


@compiler.compiles(CreateTableAs)
def compile(element, _, **kw):
    """Render a ``CreateTableAs`` element as SQL text.

    Parameters
    ----------
    element
        The ``CreateTableAs`` instance being compiled.
    _
        The SQL compiler passed in by SQLAlchemy; not used here.
    **kw
        Extra compiler keyword arguments; not used here.

    Returns
    -------
    str
        ``"CREATE TABLE <name> AS <selectable>"``.
    """
    # In case any backend uses a different syntax, register a separate
    # compile function for it.
    template = "CREATE TABLE {} AS {}"
    return template.format(element.name, element.selectable)


class SQLAlchemyBackend(Backend):
    """Backend that reaches the database through a SQLAlchemy engine.

    Subclasses choose the database by overriding ``dialect_driver`` and, when
    the connection URL has a different shape, ``connection_string``.
    """

    # Template of the URL handed to ``create_engine``; it is formatted with
    # ``dialect_driver`` and the entries of ``connection_params``.
    connection_string = (
        "{dialect_driver}://{user}:{password}@{host}:"
        "{port}/{database}?charset=utf8"
    )
    # Dialect/driver prefix of the URL; ``None`` here, set by subclasses.
    dialect_driver = None

    def __init__(self, connection_params: dict):
        """Create the engine for the database described by ``connection_params``.

        Parameters
        ----------
        connection_params
            Mapping whose keys fill the placeholders of ``connection_string``.
        """
        super().__init__(connection_params)
        self.engine = create_engine(
            self.connection_string.format(
                dialect_driver=self.dialect_driver,
                **connection_params)
        )

    @staticmethod
    def _split_table_name(name: str) -> Tuple[Optional[str], str]:
        """Split ``"schema.table"`` at the last dot into ``(schema, table)``.

        The schema is ``None`` when ``name`` contains no dot.
        """
        schema, _, table = name.rpartition(".")
        return schema or None, table

    @staticmethod
    def _split_alias(field: str) -> Tuple[str, Optional[str]]:
        """Split ``"<expr> AS <alias>"`` into ``(expr, alias)``.

        The keyword must be surrounded by spaces, so that column names that
        merely contain the letters "AS" (``CLASS``, ``BASE``) are not taken
        for aliased expressions. The alias is ``None`` when there is none.
        """
        expr, sep, alias = field.rpartition(" AS ")
        if not sep:
            return field, None
        return expr, alias

    def list_tables(self, schema: Optional[str] = None):
        """Return a ``TableDesc`` for every table the engine lists.

        Parameters
        ----------
        schema
            Schema to list; an empty value means the default schema.
        """
        schema = schema or None
        return [
            TableDesc(
                name, schema, ".".join((schema, name)) if schema else name
            )
            for name in self.engine.table_names(schema=schema)
        ]

    def create_sql_query(
            self,
            table_name: str,
            fields: List[str],
            filters: Iterable[str] = (),
            group_by: Optional[List[str]] = None,
            order_by: Optional[List[str]] = None,
            offset: Optional[int] = None,
            limit: Optional[int] = None,
            use_time_sample: Optional[int] = None,
    ) -> Select:
        """Build a SELECT over ``table_name``.

        Parameters
        ----------
        table_name
            Table name, optionally prefixed with ``schema.``. When no such
            table can be reflected the text is used verbatim as the FROM
            clause (custom SQL in Orange).
        fields
            Column names, ``"<column> AS <alias>"`` pairs, ``"*"`` or
            function expressions.
        filters
            SQL conditions; all of them must hold.
        group_by, order_by
            SQL fragments for the respective clauses.
        offset, limit
            Row window of the query.
        use_time_sample
            When not ``None`` the query is sampled with ``system_time``.

        Returns
        -------
        Select
            The composed query.
        """
        schema, bare_name = self._split_table_name(table_name)
        meta = MetaData(bind=self.engine, schema=schema)
        try:
            table = Table(bare_name, meta, autoload=True)
        except NoSuchTableError:
            # Custom SQL: fall back to the complete, unsplit text so that
            # dots inside the statement do not truncate it.
            table = text(table_name)

        is_raw_sql = isinstance(table, TextClause)
        columns = []
        for field in fields:
            expr, alias = self._split_alias(field)
            if is_raw_sql:
                columns.append(text(field))
            elif alias is not None:
                columns.append(
                    table.c[expr.strip("() ")].label(alias.strip("() "))
                )
            elif "(" in field or field == "*":
                # field is a function
                # TODO: think about not allowing this
                #  make separate functions for e.g. count
                columns.append(text(field))
            else:
                columns.append(table.c[field])

        query = select(columns).select_from(table)
        # MSSQL requires an order_by when using an OFFSET or a non-simple
        # LIMIT clause; default to the alias of the first field, if any.
        # TODO: check if order_by(None) would be fine
        if offset and not order_by and fields:
            _, first_alias = self._split_alias(fields[0])
            order_by = [] if first_alias is None else [first_alias.strip('" ')]

        if use_time_sample is not None:
            # NOTE(review): the sample argument is fixed at 1000 and the value
            # of use_time_sample is not used - confirm this is intended.
            query = query.tablesample(func.system_time(1000))
        conditions = [text(f) for f in filters]
        if conditions:
            query = query.where(and_(*conditions))
        if order_by:
            query = query.order_by(*[text(o) for o in order_by])
        if limit is not None:
            query = query.limit(limit)
        if offset is not None:
            query = query.offset(offset)
        if group_by is not None:
            query = query.group_by(*[text(g) for g in group_by])
        return query

    @contextmanager
    def execute_sql_query(self, query: Union[Select, str], params: Optional[Tuple[Any]] = ()):
        """Run ``query`` and yield its result inside a ``with`` block.

        Parameters
        ----------
        query
            A SQLAlchemy construct, or SQL text which is wrapped in ``text``.
        params
            Positional arguments forwarded to ``connection.execute``.

        Raises
        ------
        BackendError
            When a ``ProgrammingError`` is raised while the query is executed
            or while the caller works with the result.
        """
        statement = text(query) if isinstance(query, str) else query
        with self.engine.connect() as connection:
            try:
                result = connection.execute(statement, *(params or ()))
                try:
                    yield result
                finally:
                    # Close the result even if the caller's block raised.
                    result.close()
            except ProgrammingError as ex:
                raise BackendError(str(ex)) from ex

    def get_fields(self, table_name: str):
        """Return ``(column name, Python type)`` pairs for ``table_name``.

        Types come from the query's column metadata where available; for the
        remaining columns they are taken from the values of the first rows.

        Raises
        ------
        BackendError
            When the sampled values of one column have different types.
        """
        query = self.create_sql_query(table_name, ["*"], limit=3)
        types = {
            c.name: c.type.python_type for c in query.inner_columns
            if not isinstance(c.type, NullType)
        }

        # For a plain textual SQL query types cannot be retrieved from the
        # query, so the missing types are taken from the data.
        with self.execute_sql_query(query) as cur:
            rows = cur.fetchall()
            # Follow the cursor's column order to keep the output stable.
            for col in cur.keys():
                if col in types:
                    continue
                # NULLs carry no type information.
                seen = {type(r[col]) for r in rows if r[col] is not None}
                if not seen:
                    types[col] = str
                elif len(seen) == 1:
                    types[col], = seen
                else:
                    raise BackendError(
                        "Column '{}' has values of different types: {}".format(
                            col, ", ".join(sorted(t.__name__ for t in seen))
                        )
                    )
        return list(types.items())

    def _guess_variable(self, field_name: str, field_metadata: Tuple, inspect_table: Optional[str]):
        """Choose an Orange variable for a column from its Python type.

        Parameters
        ----------
        field_name
            Name of the column and of the created variable.
        field_metadata
            Sequence whose first item is the column's Python type.
        inspect_table
            When given, distinct values are read from this table to decide
            whether an ``int`` or ``str`` column is discrete.
        """
        type_ = field_metadata[0]

        if type_ == float:
            return ContinuousVariable.make(field_name)

        if type_ in (datetime, date, time):
            return TimeVariable(
                field_name,
                have_date=type_ in (date, datetime),
                have_time=type_ in (time, datetime)
            )

        if type_ == int:
            if inspect_table:
                values = self.get_distinct_values(field_name, inspect_table)
                if values:
                    return DiscreteVariable(field_name, values)
            return ContinuousVariable(field_name)

        if type_ == bool:
            return DiscreteVariable(field_name, ["false", "true"])

        if type_ == str and inspect_table:
            values = self.get_distinct_values(field_name, inspect_table)
            # Remove trailing spaces; values that differed only in padding
            # collapse into one (first occurrence keeps its position).
            values = list(dict.fromkeys(v.rstrip() for v in values))
            if values:
                return DiscreteVariable(field_name, values)

        # Any other type, and strings that are not discrete.
        return StringVariable(field_name)

    def create_variable(
            self, field_name: str, field_metadata: Tuple[Any], type_hints: Domain, inspect_table: Optional[str] = None
    ):
        """Return the variable for a column and attach its ``to_sql``.

        A variable named ``field_name`` in ``type_hints`` takes precedence;
        otherwise the variable is guessed from ``field_metadata``.
        """
        if field_name in type_hints:
            var = type_hints[field_name]
        else:
            var = self._guess_variable(
                field_name, field_metadata, inspect_table
            )

        # Every kind of variable maps to the (quoted) column itself.
        var.to_sql = ToSql(self.quote_identifier(field_name))
        return var

    def count_approx(self, query: Select):
        """
        Return the number of rows of ``query`` with ``COUNT(*)``.

        Count is faster than fetching complete table
        """
        subquery = query.alias("subquery")
        counting = select([text("COUNT(*)")]).select_from(subquery)
        with self.execute_sql_query(counting) as cur:
            return cur.fetchone()[0]

    def unquote_identifier(self, quoted_name: str) -> str:
        """Return ``quoted_name`` unchanged."""
        return quoted_name

    def quote_identifier(self, name: str) -> str:
        """Return ``name`` unchanged."""
        return name

    def create_table(self, name: str, sql: str) -> None:
        """Create table ``name`` from the result of ``sql`` in a transaction."""
        with self.engine.begin() as conn:
            conn.execute(CreateTableAs(name, sql))

    def drop_table(self, name):
        """Drop table ``name`` (optionally ``schema.table``) if it exists."""
        schema, bare_name = self._split_table_name(name)
        meta = MetaData(bind=self.engine, schema=schema)
        try:
            table = Table(bare_name, meta, autoload=True)
        except NoSuchTableError:
            # Nothing to drop.
            return
        table.drop()

    def table_exists(self, name: str) -> bool:
        """Tell whether table ``name`` (optionally ``schema.table``) exists."""
        schema, bare_name = self._split_table_name(name)
        return self.engine.dialect.has_table(
            self.engine, bare_name, schema=schema
        )


class MSSqlAlchemy(SQLAlchemyBackend):
    """SQLAlchemy backend using the ``mssql`` dialect with the pymssql driver."""

    # dialect+driver prefix of the connection URL
    dialect_driver = "mssql+pymssql"
    # human-readable name of the backend
    display_name = "MS Server Alchemy"


class MySqlAlchemy(SQLAlchemyBackend):
    """SQLAlchemy backend using the ``mysql`` dialect with the mysqldb driver."""

    # The driver is provided by mysqlclient from PyPI
    # (installed via: pip install mysqlclient).
    dialect_driver = "mysql+mysqldb"
    # human-readable name of the backend
    display_name = "MySQL Alchemy"


class SqliteAlchemy(SQLAlchemyBackend):
    """SQLAlchemy backend using the ``sqlite`` dialect with the pysqlite driver."""

    # Human-readable name of the backend (spelling fixed from "Sqllite").
    display_name = "SQLite"
    # Requires sqlite3, which is included in the standard library.
    dialect_driver = "sqlite+pysqlite"
    # The URL only names the database; there is no user, host or port part.
    connection_string = "{dialect_driver}:///{database}"
36 changes: 35 additions & 1 deletion Orange/data/sql/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ def __init__(self, connection_params):
@classmethod
def available_backends(cls):
    """Return a list of all available backends.

    Every class in ``cls.registry`` is returned except the one named
    ``SQLAlchemyBackend``.
    NOTE(review): presumably it is skipped because it is only a common base
    for the dialect-specific backends - confirm.
    """
    return [
        backend for backend in cls.registry.values()
        if backend.__name__ != "SQLAlchemyBackend"
    ]

# "meta" methods

Expand Down Expand Up @@ -220,6 +223,36 @@ def unquote_identifier(self, quoted_name):
"""
raise NotImplementedError

def create_table(self, name: str, sql: str) -> None:
    """
    Create a new SQL table with the provided name from the SQL query.

    This base implementation only raises; backends provide the behaviour.

    Parameters
    ----------
    name
        The name of the new table
    sql
        The SQL query

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError

def drop_table(self, name: str) -> None:
    """
    Drop a table from the database.

    This base implementation only raises; backends provide the behaviour.

    Parameters
    ----------
    name
        Name of the table to drop

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError

def table_exists(self, name: str) -> bool:
    """
    Check if a table exists in the database.

    This base implementation only raises; backends provide the behaviour.

    Parameters
    ----------
    name
        Name of the table to look for

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError


class TableDesc:
def __init__(self, name, schema, sql):
Expand All @@ -230,6 +263,7 @@ def __init__(self, name, schema, sql):
def __str__(self) -> str:
    """Return the ``name`` attribute as the string form of the object."""
    return self.name


class ToSql:
def __init__(self, sql):
self.sql = sql
Expand Down
21 changes: 20 additions & 1 deletion Orange/data/sql/backend/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def parse_ex(ex: Exception) -> str:


class PymssqlBackend(Backend):
display_name = "SQL Server"
display_name = "Ms SQL Server"

def __init__(self, connection_params):
connection_params["server"] = connection_params.pop("host", None)
Expand Down Expand Up @@ -78,6 +78,7 @@ def create_sql_query(self, table_name, fields, filters=(),

@contextmanager
def execute_sql_query(self, query, params=()):
print(query)
try:
with self.connection.cursor() as cur:
cur.execute(query, *params)
Expand Down Expand Up @@ -152,3 +153,21 @@ def count_approx(self, query):
warnings.warn("SHOWPLAN permission denied, count approximates will not be used")
return None
raise BackendError(parse_ex(ex)) from ex

def create_table(self, name: str, sql: str) -> None:
    """Store the result of ``sql`` in a new table called ``name``.

    The statement has the form ``SELECT * INTO <name> FROM (<sql>)
    source_table`` and is run through ``execute_sql_query``.
    """
    statement = "SELECT * INTO {} FROM ({}) source_table".format(name, sql)
    # Entered only to run the statement; there is no result to read.
    with self.execute_sql_query(statement):
        pass

def drop_table(self, name):
    """Drop table ``name``; the statement uses ``IF EXISTS``."""
    statement = "DROP TABLE IF EXISTS {}".format(name)
    # Entered only to run the statement; there is no result to read.
    with self.execute_sql_query(statement):
        pass

def table_exists(self, name: str) -> bool:
    """Tell whether ``INFORMATION_SCHEMA.TABLES`` lists a table called ``name``.

    Parameters
    ----------
    name
        Table name, compared with the ``TABLE_NAME`` column.

    Returns
    -------
    bool
        ``True`` when the lookup returns at least one row.
    """
    # Double embedded single quotes so that the name cannot terminate the
    # SQL string literal it is placed in.
    literal = name.replace("'", "''")
    stmt = (
        "SELECT * FROM INFORMATION_SCHEMA.TABLES "
        f"WHERE TABLE_NAME = '{literal}'"
    )
    with self.execute_sql_query(stmt) as cur:
        return bool(cur.fetchone())
20 changes: 19 additions & 1 deletion Orange/data/sql/backend/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,27 @@ def _guess_variable(self, field_name, field_metadata, inspect_table):
def count_approx(self, query):
    """Return the row estimate that ``EXPLAIN`` reports for ``query``.

    Parameters
    ----------
    query
        SQL text of the query; ``"EXPLAIN "`` is prepended to it.

    Returns
    -------
    int
        The first ``rows=<n>`` figure found in the EXPLAIN output.
    """
    sql = "EXPLAIN " + query
    with self.execute_sql_query(sql) as cur:
        # The first column of every returned row holds one line of the plan.
        plan = ''.join(row[0] for row in cur.fetchall())
    return int(re.findall(r'rows=(\d*)', plan)[0])

def create_table(self, name: str, sql: str) -> None:
    """Store the result of ``sql`` in a new table called ``name``.

    The statement has the form ``CREATE TABLE <name> AS <sql>`` and is run
    through ``execute_sql_query``.
    """
    statement = "CREATE TABLE {} AS {}".format(name, sql)
    # Entered only to run the statement; there is no result to read.
    with self.execute_sql_query(statement):
        pass

def drop_table(self, name):
    """Drop table ``name``; the statement uses ``IF EXISTS``."""
    statement = "DROP TABLE IF EXISTS {}".format(name)
    # Entered only to run the statement; there is no result to read.
    with self.execute_sql_query(statement):
        pass

def table_exists(self, name: str) -> bool:
    """Tell whether ``information_schema.tables`` lists a table called ``name``.

    Parameters
    ----------
    name
        Table name, compared with the ``TABLE_NAME`` column.

    Returns
    -------
    bool
        ``True`` when the lookup returns at least one row.
    """
    # Double embedded single quotes so that the name cannot terminate the
    # SQL string literal it is placed in.
    literal = name.replace("'", "''")
    stmt = (
        "SELECT * FROM information_schema.tables "
        f"WHERE TABLE_NAME = '{literal}'"
    )
    with self.execute_sql_query(stmt) as cur:
        return bool(cur.fetchone())

def __getstate__(self):
# Drop connection_pool from state as it cannot be pickled
state = dict(self.__dict__)
Expand Down
Loading

0 comments on commit 81e243c

Please sign in to comment.