diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index ec4dcf47..f34499da 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -540,3 +540,69 @@ jobs: # run: make test ########################################################################################################################## + + + ################################# + #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# + #~~~~~~~~~|##########|~~~~~~~~~~# + #~~~~~~~~~|##|~~~~~~~~~~~~~~~~~~# + #~~~~~~~~~|##|~~~~~~~~~~~~~~~~~~# + #~~~~~~~~~|##########|~~~~~~~~~~# + #~~~~~~~~~|##|~~~~|##|~~~~~~~~~~# + #~~~~~~~~~|##|~~~~|##|~~~~~~~~~~# + #~~~~~~~~~|##########|~~~~~~~~~~# + #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# + #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# + # Test Dependencies/Regressions # + #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# + test_dependencies: + needs: + - initialize + - build + + strategy: + matrix: + os: + - ubuntu-20.04 + python-version: + - 3.9 + package: + - "sqlalchemy>=2" + - "sqlalchemy<2" + + runs-on: ${{ matrix.os }} + + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + submodules: recursive + + - name: Set up Python ${{ matrix.python-version }} + uses: ./.github/actions/setup-python + with: + version: '${{ matrix.python-version }}' + + - name: Install python dependencies + run: make requirements + + - name: Install test dependencies + shell: bash + run: sudo apt-get install graphviz + + # Download artifact + - name: Download wheel + uses: actions/download-artifact@v4 + with: + name: csp-dist-${{ runner.os }}-${{ runner.arch }}-${{ matrix.python-version }} + + - name: Install wheel + run: python -m pip install -U *manylinux2014*.whl --target . 
+ + - name: Install package - ${{ matrix.package }} + run: python -m pip install -U "${{ matrix.package }}" + + # Run tests + - name: Python Test Steps + run: make test TEST_ARGS="-k TestDBReader" + if: ${{ contains( matrix.package, 'sqlalchemy' )}} diff --git a/Makefile b/Makefile index 33eb855f..259f20bd 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ EXTRA_ARGS := ######### .PHONY: develop build-py build install -requirements: ## install python dev dependnecies +requirements: ## install python dev and runtime dependencies python -m pip install toml python -m pip install `python -c 'import toml; c = toml.load("pyproject.toml"); print("\n".join(c["build-system"]["requires"]))'` python -m pip install `python -c 'import toml; c = toml.load("pyproject.toml"); print("\n".join(c["project"]["optional-dependencies"]["develop"]))'` @@ -64,11 +64,12 @@ checks: check ######### .PHONY: test-py coverage-py test tests +TEST_ARGS := test-py: ## Clean and Make unit tests - python -m pytest -v csp/tests --junitxml=junit.xml + python -m pytest -v csp/tests --junitxml=junit.xml $(TEST_ARGS) coverage-py: - python -m pytest -v csp/tests --junitxml=junit.xml --cov=csp --cov-report xml --cov-report html --cov-branch --cov-fail-under=80 --cov-report term-missing + python -m pytest -v csp/tests --junitxml=junit.xml --cov=csp --cov-report xml --cov-report html --cov-branch --cov-fail-under=80 --cov-report term-missing $(TEST_ARGS) test: test-py ## run the tests diff --git a/csp/adapters/db.py b/csp/adapters/db.py index b8eb1267..1aba1240 100644 --- a/csp/adapters/db.py +++ b/csp/adapters/db.py @@ -9,7 +9,8 @@ from backports import zoneinfo import pytz -import sqlalchemy as db +from importlib.metadata import PackageNotFoundError, version as get_package_version +from packaging import version from csp import PushMode, ts from csp.impl.adaptermanager import AdapterManagerImpl, ManagedSimInputAdapter @@ -17,6 +18,19 @@ UTC = zoneinfo.ZoneInfo("UTC") +try: + if 
version.parse(get_package_version("sqlalchemy")) >= version.parse("2"): + _SQLALCHEMY_2 = True + else: + _SQLALCHEMY_2 = False + + import sqlalchemy as db + + _HAS_SQLALCHEMY = True +except (PackageNotFoundError, ValueError, TypeError, ImportError): + _HAS_SQLALCHEMY = False + db = None + class TimeAccessor(ABC): @abstractmethod @@ -185,6 +199,8 @@ def __init__( :param log_query: set to True to see what query was generated to access the data :param use_raw_user_query: Don't do any alteration to user query, assume it contains all the needed columns and sorting """ + if not _HAS_SQLALCHEMY: + raise RuntimeError("Could not find SQLAlchemy installation") self._connection = connection self._table_name = table_name self._schema_name = schema_name @@ -248,7 +264,7 @@ def schema_struct(self): name = "DBDynStruct_{table}_{schema}".format(table=self._table_name or "", schema=self._schema_name or "") if name not in globals(): db_metadata = db.MetaData(schema=self._schema_name) - table = db.Table(self._table_name, db_metadata, autoload=True, autoload_with=self._connection) + table = db.Table(self._table_name, db_metadata, autoload_with=self._connection) struct_metadata = {col: col_obj.type.python_type for col, col_obj in table.columns.items()} from csp.impl.struct import defineStruct @@ -301,23 +317,44 @@ def __init__(self, engine, adapterRep): self._row = None def start(self, starttime, endtime): - query = self.build_query(starttime, endtime) + self._query = self.build_query(starttime, endtime) if self._rep._log_query: import logging - logging.info("DBReader query: %s", query) - self._q = self._rep._connection.execute(query) + logging.info("DBReader query: %s", self._query) + if _SQLALCHEMY_2: + self._data_yielder = self._data_yielder_function() + else: + self._q = self._rep._connection.execute(self._query) + + def _data_yielder_function(self): + # Connection yielder for SQLAlchemy 2 + with self._rep._connection.connect() as conn: + for result in 
conn.execute(self._query).mappings(): + yield result + # Signify the end + yield None def build_query(self, starttime, endtime): if self._rep._table_name: metadata = db.MetaData(schema=self._rep._schema_name) - table = db.Table(self._rep._table_name, metadata, autoload=True, autoload_with=self._rep._connection) - cols = [table.c[colname] for colname in self._rep._requested_cols] - q = db.select(cols) + + if _SQLALCHEMY_2: + table = db.Table(self._rep._table_name, metadata, autoload_with=self._rep._connection) + cols = [table.c[colname] for colname in self._rep._requested_cols] + q = db.select(*cols) + else: + table = db.Table(self._rep._table_name, metadata, autoload=True, autoload_with=self._rep._connection) + cols = [table.c[colname] for colname in self._rep._requested_cols] + q = db.select(cols) + elif self._rep._use_raw_user_query: return db.text(self._rep._query) else: # self._rep._query - from_obj = db.text(f"({self._rep._query}) AS user_query") + if _SQLALCHEMY_2: + from_obj = db.text(f"({self._rep._query})") + else: + from_obj = db.text(f"({self._rep._query}) AS user_query") time_columns = self._rep._time_accessor.get_time_columns(self._rep._connection) if time_columns: @@ -330,7 +367,11 @@ def build_query(self, starttime, endtime): time_columns = [] time_select = [] select_cols = [db.column(colname) for colname in self._rep._requested_cols.difference(set(time_columns))] - q = db.select(select_cols + time_select, from_obj=from_obj) + + if _SQLALCHEMY_2: + q = db.select(*(select_cols + time_select)).select_from(from_obj) + else: + q = db.select(select_cols + time_select, from_obj=from_obj) cond = self._rep._time_accessor.get_time_constraint(starttime.replace(tzinfo=UTC), endtime.replace(tzinfo=UTC)) @@ -361,7 +402,10 @@ def register_input_adapter(self, symbol, adapter): def process_next_sim_timeslice(self, now): if self._row is None: - self._row = self._q.fetchone() + if _SQLALCHEMY_2: + self._row = next(self._data_yielder) + else: + self._row = 
self._q.fetchone() now = now.replace(tzinfo=UTC) while self._row is not None: @@ -369,8 +413,10 @@ def process_next_sim_timeslice(self, now): if time > now: return time self.process_row(self._row) - self._row = self._q.fetchone() - + if _SQLALCHEMY_2: + self._row = next(self._data_yielder) + else: + self._row = self._q.fetchone() return None def process_row(self, row): diff --git a/csp/adapters/output_adapters/parquet.py b/csp/adapters/output_adapters/parquet.py index adb908f0..1d6f8a18 100644 --- a/csp/adapters/output_adapters/parquet.py +++ b/csp/adapters/output_adapters/parquet.py @@ -1,6 +1,6 @@ import numpy import os -import pkg_resources +from importlib.metadata import PackageNotFoundError, version as get_package_version from packaging import version from typing import Callable, Dict, Optional, TypeVar @@ -37,10 +37,13 @@ def resolve_compression(self): def _get_default_parquet_version(): - if version.parse(pkg_resources.get_distribution("pyarrow").version) >= version.parse("6.0.1"): - return "2.6" - else: - return "2.0" + try: + if version.parse(get_package_version("pyarrow")) >= version.parse("6.0.1"): + return "2.6" + except PackageNotFoundError: + # Don't need to do anything in particular + ... 
+ return "2.0" class ParquetWriter: diff --git a/csp/adapters/parquet.py b/csp/adapters/parquet.py index 1ca5c475..85e457ee 100644 --- a/csp/adapters/parquet.py +++ b/csp/adapters/parquet.py @@ -1,10 +1,10 @@ import datetime import io import numpy -import pkg_resources import platform import pyarrow import pyarrow.parquet +from importlib.metadata import PackageNotFoundError, version as get_package_version from packaging import version from typing import TypeVar @@ -28,9 +28,9 @@ try: _CAN_READ_ARROW_BINARY = False - if version.parse(pkg_resources.get_distribution("pyarrow").version) >= version.parse("4.0.1"): + if version.parse(get_package_version("pyarrow")) >= version.parse("4.0.1"): _CAN_READ_ARROW_BINARY = True -except (ValueError, TypeError): +except (PackageNotFoundError, ValueError, TypeError): # Cannot read binary arrow ... diff --git a/csp/tests/adapters/test_db.py b/csp/tests/adapters/test_db.py index 45e9ad47..87c247cf 100644 --- a/csp/tests/adapters/test_db.py +++ b/csp/tests/adapters/test_db.py @@ -5,7 +5,7 @@ from datetime import date, datetime, time import csp -from csp.adapters.db import DateTimeAccessor, DBReader, EngineStartTimeAccessor, TimestampAccessor +from csp.adapters.db import _SQLALCHEMY_2, DateTimeAccessor, DBReader, EngineStartTimeAccessor, TimestampAccessor class PriceQuantity(csp.Struct): @@ -21,6 +21,15 @@ class PriceQuantity2(csp.Struct): side: str +def execute_with_commit(engine, query, values): + if _SQLALCHEMY_2: + with engine.connect() as conn: + conn.execute(query, values) + conn.commit() + else: + engine.execute(query, values) + + class TestDBReader(unittest.TestCase): def _prepopulate_in_mem_engine(self): engine = db.create_engine("sqlite:///:memory:") # in-memory sqlite db @@ -46,7 +55,7 @@ def _prepopulate_in_mem_engine(self): {"TIME": starttime.replace(second=5), "SYMBOL": "AAPL", "PRICE": 200.0, "SIZE": 400, "SIDE": "BUY"}, {"TIME": starttime.replace(second=6), "SYMBOL": "GM", "PRICE": 2.0, "SIZE": 1, "SIDE": "BUY"}, ] - 
engine.execute(query, values_list) + execute_with_commit(engine, query, values_list) return engine def test_sqlite_basic(self): @@ -92,7 +101,7 @@ def graph(): # UTC result = csp.run(graph, starttime=datetime(2020, 3, 3, 9, 30)) - print(result) + self.assertEqual(len(result["aapl"]), 4) self.assertTrue(all(v[1].SYMBOL == "AAPL" for v in result["aapl"])) @@ -211,7 +220,8 @@ def test_sqlite_constraints(self): "SIDE": "BUY", }, ] - engine.execute(query, values_list) + + execute_with_commit(engine, query, values_list) def graph(): time_accessor = DateTimeAccessor(date_column="DATE", time_column="TIME", tz=pytz.timezone("US/Eastern")) @@ -310,7 +320,7 @@ def test_join_query(self): {"TIME": starttime.replace(second=5), "SYMBOL": "AAPL", "PRICE": 200.0}, {"TIME": starttime.replace(second=6), "SYMBOL": "GM", "PRICE": 2.0}, ] - engine.execute(query, values_list1) + execute_with_commit(engine, query, values_list1) query = db.insert(test2) values_list2 = [ @@ -322,7 +332,7 @@ def test_join_query(self): # { 'TIME': starttime.replace( second = 5 ), 'SIZE': 400, 'SIDE': 'BUY' }, {"TIME": starttime.replace(second=6), "SIZE": 1, "SIDE": "BUY"}, ] - engine.execute(query, values_list2) + execute_with_commit(engine, query, values_list2) metadata.create_all(engine) @@ -330,7 +340,7 @@ def graph(): time_accessor = TimestampAccessor(time_column="TIME", tz=pytz.timezone("US/Eastern")) query = "select * from test1 inner join test2 on test2.TIME=test1.TIME" reader = DBReader.create_from_connection( - connection=engine.connect(), query=query, time_accessor=time_accessor, symbol_column="SYMBOL" + connection=engine, query=query, time_accessor=time_accessor, symbol_column="SYMBOL" ) # Struct @@ -414,7 +424,8 @@ def test_DateTimeAccessor(self): (datetime(2020, 3, 5, 12), 700.0), ] values_list = [{"DATE": v[0].date(), "TIME": v[0].time(), "SYMBOL": "AAPL", "PRICE": v[1]} for v in values] - engine.execute(query, values_list) + + execute_with_commit(engine, query, values_list) def graph(): 
time_accessor = DateTimeAccessor(date_column="DATE", time_column="TIME", tz=pytz.timezone("US/Eastern")) diff --git a/dev-environment.yml b/dev-environment.yml index 9f66f499..c6bf3996 100644 --- a/dev-environment.yml +++ b/dev-environment.yml @@ -11,7 +11,7 @@ dependencies: - ruamel.yaml - scikit-build - psutil - - sqlalchemy<2 + - sqlalchemy - bump2version>=1.0.0 - python-graphviz - httpx diff --git a/pyproject.toml b/pyproject.toml index 44bed976..c697fed5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ develop = [ "pytest-cov", "pytest-sugar", "scikit-build", - "sqlalchemy<2", + "sqlalchemy", "tornado", ]