Skip to content

Commit

Permalink
Support sqlalchemy>=2
Browse files Browse the repository at this point in the history
Signed-off-by: Tim Paine <timothy.paine@cubistsystematic.com>
  • Loading branch information
timkpaine committed Feb 2, 2024
1 parent 2acdade commit a03fff4
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 20 deletions.
52 changes: 52 additions & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -515,3 +515,55 @@ jobs:
# run: make test

##########################################################################################################################

test_dependencies:
needs:
- initialize
- build

strategy:
matrix:
os:
- ubuntu-20.04
python-version:
- 3.9
package:
- "sqlalchemy>=2"
- "sqlalchemy<2"

runs-on: ${{ matrix.os }}

steps:
- name: Checkout
uses: actions/checkout@v4
with:
submodules: recursive

- name: Set up Python ${{ matrix.python-version }}
uses: ./.github/actions/setup-python
with:
version: '${{ matrix.python-version }}'

- name: Install python dependencies
run: make requirements

- name: Install test dependencies (Linux)
shell: bash
run: sudo apt-get install graphviz

# Download artifact
- name: Download wheel
uses: actions/download-artifact@v4
with:
name: csp-dist-${{ runner.os }}-${{ matrix.python-version }}

- name: Install wheel (Linux)
run: python -m pip install -U *manylinux2014*.whl --target .

- name: Install wheel (Linux)
run: python -m pip install -U ${{ matrix.package }}

# Run tests
- name: Python Test Steps
run: make test TEST_ARGS="-k TestDBReader"
if: ${{ contains( 'sqlalchemy', matrix.package )}}
5 changes: 3 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,12 @@ checks: check
#########
.PHONY: test-py coverage-py test tests

TEST_ARGS :=
test-py: ## Clean and Make unit tests
python -m pytest -v csp/tests --junitxml=junit.xml
python -m pytest -v csp/tests --junitxml=junit.xml $(TEST_ARGS)

coverage-py:
python -m pytest -v csp/tests --junitxml=junit.xml --cov=csp --cov-report xml --cov-report html --cov-branch --cov-fail-under=80 --cov-report term-missing
python -m pytest -v csp/tests --junitxml=junit.xml --cov=csp --cov-report xml --cov-report html --cov-branch --cov-fail-under=80 --cov-report term-missing $(TEST_ARGS)

test: test-py ## run the tests

Expand Down
27 changes: 17 additions & 10 deletions csp/adapters/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def schema_struct(self):
name = "DBDynStruct_{table}_{schema}".format(table=self._table_name or "", schema=self._schema_name or "")
if name not in globals():
db_metadata = db.MetaData(schema=self._schema_name)
table = db.Table(self._table_name, db_metadata, autoload=True, autoload_with=self._connection)
table = db.Table(self._table_name, db_metadata, autoload_with=self._connection)
struct_metadata = {col: col_obj.type.python_type for col, col_obj in table.columns.items()}

from csp.impl.struct import defineStruct
Expand Down Expand Up @@ -301,23 +301,30 @@ def __init__(self, engine, adapterRep):
self._row = None

def start(self, starttime, endtime):
query = self.build_query(starttime, endtime)
self._query = self.build_query(starttime, endtime)
if self._rep._log_query:
import logging

logging.info("DBReader query: %s", query)
self._q = self._rep._connection.execute(query)
logging.info("DBReader query: %s", self._query)
self._data_yielder = self._data_yielder_function()

def _data_yielder_function(self):
with self._rep._connection.connect() as conn:
for result in conn.execute(self._query).mappings():
yield result
# Signify the end
yield None

def build_query(self, starttime, endtime):
if self._rep._table_name:
metadata = db.MetaData(schema=self._rep._schema_name)
table = db.Table(self._rep._table_name, metadata, autoload=True, autoload_with=self._rep._connection)
table = db.Table(self._rep._table_name, metadata, autoload_with=self._rep._connection)
cols = [table.c[colname] for colname in self._rep._requested_cols]
q = db.select(cols)
q = db.select(*cols)
elif self._rep._use_raw_user_query:
return db.text(self._rep._query)
else: # self._rep._query
from_obj = db.text(f"({self._rep._query}) AS user_query")
from_obj = db.text(f"({self._rep._query})")

time_columns = self._rep._time_accessor.get_time_columns(self._rep._connection)
if time_columns:
Expand All @@ -330,7 +337,7 @@ def build_query(self, starttime, endtime):
time_columns = []
time_select = []
select_cols = [db.column(colname) for colname in self._rep._requested_cols.difference(set(time_columns))]
q = db.select(select_cols + time_select, from_obj=from_obj)
q = db.select(*(select_cols + time_select)).select_from(from_obj)

cond = self._rep._time_accessor.get_time_constraint(starttime.replace(tzinfo=UTC), endtime.replace(tzinfo=UTC))

Expand Down Expand Up @@ -361,15 +368,15 @@ def register_input_adapter(self, symbol, adapter):

def process_next_sim_timeslice(self, now):
if self._row is None:
self._row = self._q.fetchone()
self._row = next(self._data_yielder)

now = now.replace(tzinfo=UTC)
while self._row is not None:
time = self._rep._time_accessor.get_time(self._row)
if time > now:
return time
self.process_row(self._row)
self._row = self._q.fetchone()
self._row = next(self._data_yielder)

return None

Expand Down
24 changes: 17 additions & 7 deletions csp/tests/adapters/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ def _prepopulate_in_mem_engine(self):
{"TIME": starttime.replace(second=5), "SYMBOL": "AAPL", "PRICE": 200.0, "SIZE": 400, "SIDE": "BUY"},
{"TIME": starttime.replace(second=6), "SYMBOL": "GM", "PRICE": 2.0, "SIZE": 1, "SIDE": "BUY"},
]
engine.execute(query, values_list)
with engine.connect() as conn:
conn.execute(query, values_list)
conn.commit()
return engine

def test_sqlite_basic(self):
Expand Down Expand Up @@ -92,7 +94,7 @@ def graph():

# UTC
result = csp.run(graph, starttime=datetime(2020, 3, 3, 9, 30))
print(result)

self.assertEqual(len(result["aapl"]), 4)
self.assertTrue(all(v[1].SYMBOL == "AAPL" for v in result["aapl"]))

Expand Down Expand Up @@ -211,7 +213,9 @@ def test_sqlite_constraints(self):
"SIDE": "BUY",
},
]
engine.execute(query, values_list)
with engine.connect() as conn:
conn.execute(query, values_list)
conn.commit()

def graph():
time_accessor = DateTimeAccessor(date_column="DATE", time_column="TIME", tz=pytz.timezone("US/Eastern"))
Expand Down Expand Up @@ -310,7 +314,9 @@ def test_join_query(self):
{"TIME": starttime.replace(second=5), "SYMBOL": "AAPL", "PRICE": 200.0},
{"TIME": starttime.replace(second=6), "SYMBOL": "GM", "PRICE": 2.0},
]
engine.execute(query, values_list1)
with engine.connect() as conn:
conn.execute(query, values_list1)
conn.commit()

query = db.insert(test2)
values_list2 = [
Expand All @@ -322,15 +328,17 @@ def test_join_query(self):
# { 'TIME': starttime.replace( second = 5 ), 'SIZE': 400, 'SIDE': 'BUY' },
{"TIME": starttime.replace(second=6), "SIZE": 1, "SIDE": "BUY"},
]
engine.execute(query, values_list2)
with engine.connect() as conn:
conn.execute(query, values_list2)
conn.commit()

metadata.create_all(engine)

def graph():
time_accessor = TimestampAccessor(time_column="TIME", tz=pytz.timezone("US/Eastern"))
query = "select * from test1 inner join test2 on test2.TIME=test1.TIME"
reader = DBReader.create_from_connection(
connection=engine.connect(), query=query, time_accessor=time_accessor, symbol_column="SYMBOL"
connection=engine, query=query, time_accessor=time_accessor, symbol_column="SYMBOL"
)

# Struct
Expand Down Expand Up @@ -414,7 +422,9 @@ def test_DateTimeAccessor(self):
(datetime(2020, 3, 5, 12), 700.0),
]
values_list = [{"DATE": v[0].date(), "TIME": v[0].time(), "SYMBOL": "AAPL", "PRICE": v[1]} for v in values]
engine.execute(query, values_list)
with engine.connect() as conn:
conn.execute(query, values_list)
conn.commit()

def graph():
time_accessor = DateTimeAccessor(date_column="DATE", time_column="TIME", tz=pytz.timezone("US/Eastern"))
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ develop = [
"pytest-cov",
"pytest-sugar",
"scikit-build",
"sqlalchemy<2",
"sqlalchemy",
]

[tool.check-manifest]
Expand Down

0 comments on commit a03fff4

Please sign in to comment.