diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e29e4da06..787cb0321 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -515,3 +515,55 @@ jobs: # run: make test ########################################################################################################################## + + test_dependencies: + needs: + - initialize + - build + + strategy: + matrix: + os: + - ubuntu-20.04 + python-version: + - 3.9 + package: + - "sqlalchemy>=2" + - "sqlalchemy<2" + + runs-on: ${{ matrix.os }} + + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + submodules: recursive + + - name: Set up Python ${{ matrix.python-version }} + uses: ./.github/actions/setup-python + with: + version: '${{ matrix.python-version }}' + + - name: Install python dependencies + run: make requirements + + - name: Install test dependencies (Linux) + shell: bash + run: sudo apt-get install graphviz + + # Download artifact + - name: Download wheel + uses: actions/download-artifact@v4 + with: + name: csp-dist-${{ runner.os }}-${{ matrix.python-version }} + + - name: Install wheel (Linux) + run: python -m pip install -U *manylinux2014*.whl --target . + + - name: Install wheel (Linux) + run: python -m pip install -U ${{ matrix.package }} + + # Run tests + - name: Python Test Steps + run: make test TEST_ARGS="-k TestDBReader" + if: ${{ contains( 'sqlalchemy', matrix.package )}} diff --git a/Makefile b/Makefile index 4e483ddd1..ee102337b 100644 --- a/Makefile +++ b/Makefile @@ -64,11 +64,12 @@ checks: check ######### .PHONY: test-py coverage-py test tests +TEST_ARGS := test-py: ## Clean and Make unit tests - python -m pytest -v csp/tests --junitxml=junit.xml + python -m pytest -v csp/tests --junitxml=junit.xml $(TEST_ARGS) coverage-py: - python -m pytest -v csp/tests --junitxml=junit.xml --cov=csp --cov-report xml --cov-report html --cov-branch --cov-fail-under=80 --cov-report term-missing + python -m pytest -v csp/tests --junitxml=junit.xml --cov=csp --cov-report xml --cov-report html --cov-branch --cov-fail-under=80 --cov-report term-missing $(TEST_ARGS) test: test-py ## run the tests diff --git a/csp/adapters/db.py b/csp/adapters/db.py index b8eb12671..c6ffc9935 100644 --- a/csp/adapters/db.py +++ b/csp/adapters/db.py @@ -248,7 +248,7 @@ def schema_struct(self): name = "DBDynStruct_{table}_{schema}".format(table=self._table_name or "", schema=self._schema_name or "") if name not in globals(): db_metadata = db.MetaData(schema=self._schema_name) - table = db.Table(self._table_name, db_metadata, autoload=True, autoload_with=self._connection) + table = db.Table(self._table_name, db_metadata, autoload_with=self._connection) struct_metadata = {col: col_obj.type.python_type for col, col_obj in table.columns.items()} from csp.impl.struct import defineStruct @@ -301,23 +301,30 @@ def __init__(self, engine, adapterRep): self._row = None def start(self, starttime, endtime): - query = self.build_query(starttime, endtime) + self._query = self.build_query(starttime, endtime) if self._rep._log_query: import logging - logging.info("DBReader query: %s", query) - self._q = self._rep._connection.execute(query) + logging.info("DBReader query: %s", self._query) + self._data_yielder = self._data_yielder_function() + + def _data_yielder_function(self): + with self._rep._connection.connect() as conn: + for result in conn.execute(self._query).mappings(): + yield result + # Signify the end + yield None def build_query(self, starttime, endtime): if self._rep._table_name: metadata = db.MetaData(schema=self._rep._schema_name) - table = db.Table(self._rep._table_name, metadata, autoload=True, autoload_with=self._rep._connection) + table = db.Table(self._rep._table_name, metadata, autoload_with=self._rep._connection) cols = [table.c[colname] for colname in self._rep._requested_cols] - q = db.select(cols) + q = db.select(*cols) elif self._rep._use_raw_user_query: return db.text(self._rep._query) else: # self._rep._query - from_obj = db.text(f"({self._rep._query}) AS user_query") + from_obj = db.text(f"({self._rep._query})") time_columns = self._rep._time_accessor.get_time_columns(self._rep._connection) if time_columns: @@ -330,7 +337,7 @@ def build_query(self, starttime, endtime): time_columns = [] time_select = [] select_cols = [db.column(colname) for colname in self._rep._requested_cols.difference(set(time_columns))] - q = db.select(select_cols + time_select, from_obj=from_obj) + q = db.select(*(select_cols + time_select)).select_from(from_obj) cond = self._rep._time_accessor.get_time_constraint(starttime.replace(tzinfo=UTC), endtime.replace(tzinfo=UTC)) @@ -361,7 +368,7 @@ def register_input_adapter(self, symbol, adapter): def process_next_sim_timeslice(self, now): if self._row is None: - self._row = self._q.fetchone() + self._row = next(self._data_yielder) now = now.replace(tzinfo=UTC) while self._row is not None: @@ -369,7 +376,7 @@ def process_next_sim_timeslice(self, now): if time > now: return time self.process_row(self._row) - self._row = self._q.fetchone() + self._row = next(self._data_yielder) return None diff --git a/csp/tests/adapters/test_db.py b/csp/tests/adapters/test_db.py index 45e9ad470..0ca31fcfb 100644 --- a/csp/tests/adapters/test_db.py +++ b/csp/tests/adapters/test_db.py @@ -46,7 +46,9 @@ def _prepopulate_in_mem_engine(self): {"TIME": starttime.replace(second=5), "SYMBOL": "AAPL", "PRICE": 200.0, "SIZE": 400, "SIDE": "BUY"}, {"TIME": starttime.replace(second=6), "SYMBOL": "GM", "PRICE": 2.0, "SIZE": 1, "SIDE": "BUY"}, ] - engine.execute(query, values_list) + with engine.connect() as conn: + conn.execute(query, values_list) + conn.commit() return engine def test_sqlite_basic(self): @@ -92,7 +94,7 @@ def graph(): # UTC result = csp.run(graph, starttime=datetime(2020, 3, 3, 9, 30)) - print(result) + self.assertEqual(len(result["aapl"]), 4) self.assertTrue(all(v[1].SYMBOL == "AAPL" for v in result["aapl"])) @@ -211,7 +213,9 @@ def test_sqlite_constraints(self): "SIDE": "BUY", }, ] - engine.execute(query, values_list) + with engine.connect() as conn: + conn.execute(query, values_list) + conn.commit() def graph(): time_accessor = DateTimeAccessor(date_column="DATE", time_column="TIME", tz=pytz.timezone("US/Eastern")) @@ -310,7 +314,9 @@ def test_join_query(self): {"TIME": starttime.replace(second=5), "SYMBOL": "AAPL", "PRICE": 200.0}, {"TIME": starttime.replace(second=6), "SYMBOL": "GM", "PRICE": 2.0}, ] - engine.execute(query, values_list1) + with engine.connect() as conn: + conn.execute(query, values_list1) + conn.commit() query = db.insert(test2) values_list2 = [ @@ -322,7 +328,9 @@ def test_join_query(self): # { 'TIME': starttime.replace( second = 5 ), 'SIZE': 400, 'SIDE': 'BUY' }, {"TIME": starttime.replace(second=6), "SIZE": 1, "SIDE": "BUY"}, ] - engine.execute(query, values_list2) + with engine.connect() as conn: + conn.execute(query, values_list2) + conn.commit() metadata.create_all(engine) @@ -330,7 +338,7 @@ def graph(): time_accessor = TimestampAccessor(time_column="TIME", tz=pytz.timezone("US/Eastern")) query = "select * from test1 inner join test2 on test2.TIME=test1.TIME" reader = DBReader.create_from_connection( - connection=engine.connect(), query=query, time_accessor=time_accessor, symbol_column="SYMBOL" + connection=engine, query=query, time_accessor=time_accessor, symbol_column="SYMBOL" ) # Struct @@ -414,7 +422,9 @@ def test_DateTimeAccessor(self): (datetime(2020, 3, 5, 12), 700.0), ] values_list = [{"DATE": v[0].date(), "TIME": v[0].time(), "SYMBOL": "AAPL", "PRICE": v[1]} for v in values] - engine.execute(query, values_list) + with engine.connect() as conn: + conn.execute(query, values_list) + conn.commit() def graph(): time_accessor = DateTimeAccessor(date_column="DATE", time_column="TIME", tz=pytz.timezone("US/Eastern")) diff --git a/pyproject.toml b/pyproject.toml index e62fe0410..5b4151a41 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,7 @@ develop = [ "pytest-cov", "pytest-sugar", "scikit-build", - "sqlalchemy<2", + "sqlalchemy", ] [tool.check-manifest]