forked from hkwi/sqlalchemy_gevent
-
Notifications
You must be signed in to change notification settings - Fork 0
/
sqlalchemy_gevent.py
167 lines (141 loc) · 4.8 KB
/
sqlalchemy_gevent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import sqlalchemy.engine
from sqlalchemy.dialects import registry
import sqlalchemy.dialects.sqlite
import gevent
import gevent.threadpool
import importlib
import functools
from sqlalchemy.engine import interfaces
def call_in_gevent(tp_factory):
    """Build a decorator that runs the decorated callable on a threadpool.

    ``tp_factory`` is a zero-argument callable returning the pool to use. It is
    invoked anew on every call of the decorated function, so the pool is looked
    up lazily at call time. When ``tp_factory`` is None the decorator is a
    no-op and hands the function back unchanged.

    The call is dispatched with ``pool.apply_e(BaseException, func, args,
    kwargs)``; ``BaseException`` is passed as the expected-error type.
    NOTE(review): presumably so any error raised in the worker is re-raised in
    the calling greenlet -- confirm against the installed gevent version, as
    ``apply_e`` is an older ThreadPool API.
    """
    def decorate(func):
        if tp_factory is None:
            return func

        @functools.wraps(func)
        def run_in_pool(*args, **kwargs):
            pool = tp_factory()
            return pool.apply_e(BaseException, func, args, kwargs)

        return run_in_pool

    return decorate
class Proxy(object):
    """Attribute-forwarding wrapper around another object.

    Any attribute not found on the proxy itself is looked up on the wrapped
    object (``_inner``). If the attribute name is a key of the class-level
    ``_intercept`` mapping, the looked-up value is passed through the mapped
    callable before being returned. Subclasses supply their own mapping.
    """

    # attribute name -> callable applied to the inner object's attribute value
    _intercept = {}

    def __init__(self, inner):
        self._inner = inner

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If ``_inner`` itself
        # is missing (an instance created via __new__, e.g. by copy/pickle),
        # reading self._inner would re-enter here forever; fail cleanly with
        # AttributeError instead so hasattr()/getattr(default) keep working.
        if name == "_inner":
            raise AttributeError(name)
        obj = getattr(self._inner, name)
        if name in self._intercept:
            return self._intercept[name](obj)
        return obj
def cursor_proxy(tp_factory):
    """Return a decorator for a DB-API connection's ``cursor`` method.

    The decorated method is run through ``call_in_gevent(tp_factory)`` and the
    cursor it returns is wrapped in a ``CursorProxy``. On that proxy the DB-API
    cursor methods listed below are themselves wrapped with
    ``call_in_gevent(tp_factory)``; every other attribute is forwarded as-is.
    """
    run = call_in_gevent(tp_factory)
    # Cursor methods that are dispatched through the threadpool.
    intercepted = {
        name: run
        for name in ("callproc", "close", "execute", "executemany", "fetchone",
                     "fetchmany", "fetchall", "nextset", "setinputsizes",
                     "setoutputsize")
    }
    # Build the proxy class once per factory rather than once per cursor; the
    # intercept mapping was already shared, so instances behave the same.
    cursor_cls = type("CursorProxy", (Proxy,), {"_intercept": intercepted})

    def proxy(func):
        open_cursor = run(func)

        @functools.wraps(func)
        def wraps(*args, **kwargs):
            return cursor_cls(open_cursor(*args, **kwargs))
        return wraps
    return proxy
def connection_proxy(tp_factory):
    """Return a decorator for a DB-API module's ``connect`` function.

    The decorated function is run through ``call_in_gevent(tp_factory)`` and the
    connection it returns is wrapped in a ``ConnectionProxy``. On that proxy,
    ``close``/``commit``/``rollback`` are wrapped with
    ``call_in_gevent(tp_factory)`` and ``cursor`` with
    ``cursor_proxy(tp_factory)``; other attributes are forwarded unchanged.
    """
    run = call_in_gevent(tp_factory)
    intercepted = {name: run for name in ("close", "commit", "rollback")}
    intercepted["cursor"] = cursor_proxy(tp_factory)
    # One proxy class per factory instead of a fresh class per connection.
    connection_cls = type("ConnectionProxy", (Proxy,), {"_intercept": intercepted})

    def proxy(func):
        connect = run(func)

        @functools.wraps(func)
        def wraps(*args, **kwargs):
            return connection_cls(connect(*args, **kwargs))
        return wraps
    return proxy
def dbapi_proxy(tp_factory):
    """Return a ``DbapiProxy`` class (not an instance) for wrapping a DB-API module.

    Instances forward every attribute to the wrapped module except ``connect``,
    which is wrapped by ``connection_proxy(tp_factory)``.
    """
    intercepted = {"connect": connection_proxy(tp_factory)}
    return type("DbapiProxy", (Proxy,), {"_intercept": intercepted})
def dbapi_factory_proxy(tp_factory):
    """Return a decorator for a dialect's DB-API loader (its ``dbapi`` factory).

    The decorated loader is called with the caller's arguments; the module it
    returns is then wrapped in a fresh ``dbapi_proxy(tp_factory)`` instance,
    which is what the caller receives.
    """
    def decorate(load_dbapi):
        @functools.wraps(load_dbapi)
        def load_proxied_dbapi(*args, **kwargs):
            module = load_dbapi(*args, **kwargs)
            proxy_cls = dbapi_proxy(tp_factory)
            return proxy_cls(module)
        return load_proxied_dbapi
    return decorate
def dialect_on_connect_proxy(tp_factory):
    """Return a decorator for a dialect's ``on_connect`` method.

    The decorated method is called normally. If it returns a falsy value, the
    wrapper returns None. Otherwise the returned callback is wrapped so that:

    * any ``Proxy`` argument is replaced by the object it wraps (``_inner``);
    * the callback itself runs via ``call_in_gevent(tp_factory)``.
    """
    def unwrap(value):
        # Hand the raw object, not our Proxy wrapper, to the callback.
        if isinstance(value, Proxy):
            return value._inner
        return value

    def wrap_callback(callback):
        @functools.wraps(callback)
        def run_callback(*args, **kwargs):
            raw_args = [unwrap(arg) for arg in args]
            raw_kwargs = {key: unwrap(val) for key, val in kwargs.items()}
            return call_in_gevent(tp_factory)(callback)(*raw_args, **raw_kwargs)
        return run_callback

    def proxy(func):
        @functools.wraps(func)
        def wraps(*args, **kwargs):
            callback = func(*args, **kwargs)
            if not callback:
                return None
            return wrap_callback(callback)
        return wraps
    return proxy
class DialectProxy(object):
    """Wrapper around a SQLAlchemy dialect class or dialect instance.

    Attribute access is forwarded to the wrapped object (``_inner``), except:

    * ``get_dialect_cls`` returns a callable that ignores its arguments and
      returns this proxy, so the proxy is handed out as the dialect class;
    * ``dbapi``, when it is callable (the DB-API factory on a dialect class),
      is wrapped by ``dbapi_factory_proxy`` so the loaded module is proxied;
    * ``on_connect`` is wrapped by ``dialect_on_connect_proxy``.

    ``_tp_factory`` (overridden by subclasses as a staticmethod) supplies the
    threadpool used by those wrappers.
    """

    _tp_factory = None

    def __init__(self, inner):
        self._inner = inner

    def __getattr__(self, name):
        # Avoid unbounded recursion when ``_inner`` was never set (e.g. an
        # instance created via __new__); report a plain missing attribute.
        if name == "_inner":
            raise AttributeError(name)
        obj = getattr(self._inner, name)
        if name == "get_dialect_cls":  # Dialect
            return lambda *args: self
        if name == "dbapi" and callable(obj):  # DefaultDialect
            # On a dialect *class* ``dbapi`` is the factory that loads the
            # DB-API module and must be wrapped. On a dialect *instance* it
            # already holds the loaded (proxied) module: wrapping that would
            # turn it into a function, breaking lookups such as
            # ``dialect.dbapi.Error``, so non-callables pass through.
            return dbapi_factory_proxy(self._tp_factory)(obj)
        if name == "on_connect":  # DefaultDialect
            return dialect_on_connect_proxy(self._tp_factory)(obj)
        return obj
def dialect_init_wrap(tp_factory):
    """Return a decorator producing a ``__call__`` that builds proxied dialects.

    The resulting function is installed as ``__call__`` on a ``DialectProxy``
    subclass. Calling the proxy constructs the wrapped dialect class via
    ``call_in_gevent(tp_factory)``. The new dialect instance is then wrapped in
    a ``DialectProxy`` subclass named after ``self.__name__`` (forwarded to the
    wrapped class), carrying ``tp_factory`` as its ``_tp_factory``.
    """
    def decorate(dialect_cls):
        @functools.wraps(dialect_cls)
        def construct(self, *args, **kwargs):
            instance = call_in_gevent(tp_factory)(dialect_cls)(*args, **kwargs)
            proxy_cls = type(
                self.__name__,
                (DialectProxy,),
                {"_tp_factory": staticmethod(tp_factory)},
            )
            return proxy_cls(instance)
        return construct
    return decorate
# Threadpool with exactly one worker thread. dialect_maker routes every sqlite
# dialect through it, because pysqlite connections require single-threaded use.
single_pool = gevent.threadpool.ThreadPool(maxsize=1)
def dialect_name(*args):
    """Build the module-level attribute name for a dialect proxy.

    Falsy parts (e.g. a None driver) are skipped. Each remaining part gets its
    first character upper-cased, and the parts are concatenated with a
    ``Dialect`` suffix.

    >>> dialect_name("postgresql", "psycopg2")
    'PostgresqlPsycopg2Dialect'
    >>> dialect_name("sqlite")
    'SqliteDialect'
    """
    parts = (part[0].upper() + part[1:] for part in args if part)
    return "".join(parts) + "Dialect"
def dialect_maker(db, driver):
    """Create the gevent-aware proxy for SQLAlchemy dialect ``db``/``driver``.

    Imports ``sqlalchemy.dialects.<db>.<driver>`` (``base`` when ``driver`` is
    None) and takes its module-level ``dialect`` class. It returns an instance
    of a ``DialectProxy`` subclass (named like the dialect class) wrapping that
    class. Calling the returned object constructs the real dialect through
    ``dialect_init_wrap``.

    Raises:
        ImportError: if the dialect module does not exist in the installed
            SQLAlchemy.
    """
    if driver is None:
        driver = "base"
    dialect = importlib.import_module("sqlalchemy.dialects.%s.%s" % (db, driver)).dialect

    if db == "sqlite":
        # pysqlite dbapi connection requires single threaded access, so all
        # sqlite work goes through the shared one-worker pool.
        def tp_factory():
            return single_pool
    else:
        # Look the hub up at call time; gevent hubs are per-thread.
        def tp_factory():
            return gevent.get_hub().threadpool

    proxy_cls = type(dialect.__name__, (DialectProxy,), {
        "_tp_factory": staticmethod(tp_factory),
        "__call__": dialect_init_wrap(tp_factory)(dialect),
    })
    return proxy_cls(dialect)
# Dialects shipped with SQLAlchemy, keyed by database name, mapped to the
# driver submodules of ``sqlalchemy.dialects.<db>`` to wrap for each.
bundled_drivers = {
    "drizzle": ["mysqldb"],
    "firebird": ["kinterbasdb", "fdb"],
    "mssql": ["pyodbc", "adodbapi", "pymssql", "zxjdbc", "mxodbc"],
    "mysql": [
        "mysqldb", "oursql", "pyodbc", "zxjdbc",
        "mysqlconnector", "pymysql", "gaerdbms", "cymysql",
    ],
    "oracle": ["cx_oracle", "zxjdbc"],
    "postgresql": ["psycopg2", "pg8000", "pypostgresql", "zxjdbc"],
    "sqlite": ["pysqlite"],
    "sybase": ["pysybase", "pyodbc"],
}
# Build a gevent-aware proxy for every bundled dialect/driver, store it as a
# module attribute named by dialect_name(), and register it with SQLAlchemy
# under "gevent_<db>" / "gevent_<db>.<driver>".
for db, drivers in bundled_drivers.items():
    try:
        globals()[dialect_name(db)] = dialect_maker(db, None)
    except ImportError:
        # The whole database dialect is absent from the installed SQLAlchemy
        # (e.g. drizzle was removed in sqlalchemy v1.0); skip it.
        continue
    registry.register("gevent_%s" % db, "sqlalchemy_gevent", dialect_name(db))
    for driver in drivers:
        try:
            globals()[dialect_name(db, driver)] = dialect_maker(db, driver)
        except ImportError:
            # Only this driver module is missing (drivers get dropped between
            # SQLAlchemy releases); keep wrapping the remaining drivers.
            continue
        registry.register("gevent_%s.%s" % (db, driver), "sqlalchemy_gevent", dialect_name(db, driver))
def patch_all():
    """Make the plain SQLAlchemy dialect names resolve to this module's proxies.

    Re-registers ``<db>`` and ``<db>.<driver>`` (the names without the
    ``gevent_`` prefix) to point at the proxies in this module. Only proxies
    that were actually built at import time are registered. A dialect whose
    module failed to import is left with its stock registration, instead of
    being redirected to an attribute that does not exist here.
    """
    built = globals()
    for db, drivers in bundled_drivers.items():
        base_name = dialect_name(db)
        if base_name in built:
            registry.register(db, "sqlalchemy_gevent", base_name)
        for driver in drivers:
            driver_name = dialect_name(db, driver)
            if driver_name in built:
                registry.register("%s.%s" % (db, driver), "sqlalchemy_gevent", driver_name)