Support sqlalchemy>=2 #6

Merged (2 commits) on Feb 7, 2024
66 changes: 66 additions & 0 deletions .github/workflows/build.yml
@@ -540,3 +540,69 @@ jobs:
# run: make test

##########################################################################################################################


#################################
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
#~~~~~~~~~|##########|~~~~~~~~~~#
#~~~~~~~~~|##|~~~~~~~~~~~~~~~~~~#
#~~~~~~~~~|##|~~~~~~~~~~~~~~~~~~#
#~~~~~~~~~|##########|~~~~~~~~~~#
#~~~~~~~~~|##|~~~~|##|~~~~~~~~~~#
#~~~~~~~~~|##|~~~~|##|~~~~~~~~~~#
#~~~~~~~~~|##########|~~~~~~~~~~#
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
# Test Dependencies/Regressions #
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
  test_dependencies:
    needs:
      - initialize
      - build

    strategy:
      matrix:
        os:
          - ubuntu-20.04
        python-version:
          - 3.9
        package:
          - "sqlalchemy>=2"
          - "sqlalchemy<2"

    runs-on: ${{ matrix.os }}

    steps:
      - name: Checkout
        uses: actions/checkout@v4
        with:
          submodules: recursive

      - name: Set up Python ${{ matrix.python-version }}
        uses: ./.github/actions/setup-python
        with:
          version: '${{ matrix.python-version }}'

      - name: Install python dependencies
        run: make requirements

      - name: Install test dependencies
        shell: bash
        run: sudo apt-get install graphviz

      # Download artifact
      - name: Download wheel
        uses: actions/download-artifact@v4
        with:
          name: csp-dist-${{ runner.os }}-${{ runner.arch }}-${{ matrix.python-version }}

      - name: Install wheel
        run: python -m pip install -U *manylinux2014*.whl --target .

      - name: Install package - ${{ matrix.package }}
        run: python -m pip install -U "${{ matrix.package }}"

      # Run tests
      - name: Python Test Steps
        run: make test TEST_ARGS="-k TestDBReader"
        if: ${{ contains(matrix.package, 'sqlalchemy') }}
7 changes: 4 additions & 3 deletions Makefile
@@ -5,7 +5,7 @@ EXTRA_ARGS :=
#########
.PHONY: develop build-py build install

requirements: ## install python dev dependnecies
requirements: ## install python dev and runtime dependencies
	python -m pip install toml
	python -m pip install `python -c 'import toml; c = toml.load("pyproject.toml"); print("\n".join(c["build-system"]["requires"]))'`
	python -m pip install `python -c 'import toml; c = toml.load("pyproject.toml"); print("\n".join(c["project"]["optional-dependencies"]["develop"]))'`
@@ -64,11 +64,12 @@ checks: check
#########
.PHONY: test-py coverage-py test tests

TEST_ARGS :=
test-py: ## Clean and Make unit tests
	python -m pytest -v csp/tests --junitxml=junit.xml
	python -m pytest -v csp/tests --junitxml=junit.xml $(TEST_ARGS)

coverage-py:
	python -m pytest -v csp/tests --junitxml=junit.xml --cov=csp --cov-report xml --cov-report html --cov-branch --cov-fail-under=80 --cov-report term-missing
	python -m pytest -v csp/tests --junitxml=junit.xml --cov=csp --cov-report xml --cov-report html --cov-branch --cov-fail-under=80 --cov-report term-missing $(TEST_ARGS)

test: test-py ## run the tests

72 changes: 59 additions & 13 deletions csp/adapters/db.py
@@ -9,14 +9,28 @@
from backports import zoneinfo

import pytz
import sqlalchemy as db
from importlib.metadata import PackageNotFoundError, version as get_package_version
from packaging import version

from csp import PushMode, ts
from csp.impl.adaptermanager import AdapterManagerImpl, ManagedSimInputAdapter
from csp.impl.wiring import py_managed_adapter_def

UTC = zoneinfo.ZoneInfo("UTC")

try:
    if version.parse(get_package_version("sqlalchemy")) >= version.parse("2"):
        _SQLALCHEMY_2 = True
    else:
        _SQLALCHEMY_2 = False

    import sqlalchemy as db

    _HAS_SQLALCHEMY = True
except (PackageNotFoundError, ValueError, TypeError, ImportError):
    _HAS_SQLALCHEMY = False
    db = None


class TimeAccessor(ABC):
    @abstractmethod
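For context on the guard above: `importlib.metadata` replaces the `pkg_resources` lookups used elsewhere in the codebase, and `packaging.version` compares versions semantically rather than lexically. A small standalone illustration, not part of the diff (the printed string depends on the local environment):

```python
from importlib.metadata import PackageNotFoundError, version as get_package_version
from packaging import version

# packaging compares versions numerically, so "2.0.25" >= "2" and "10.0" > "9.0"
assert version.parse("2.0.25") >= version.parse("2")
assert version.parse("1.4.51") < version.parse("2")
assert version.parse("10.0") > version.parse("9.0")  # a plain string comparison would get this wrong

try:
    print(get_package_version("sqlalchemy"))  # e.g. "2.0.25" or "1.4.51"
except PackageNotFoundError:
    print("sqlalchemy is not installed")
```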
@@ -185,6 +199,8 @@ def __init__(
        :param log_query: set to True to see what query was generated to access the data
        :param use_raw_user_query: Don't do any alteration to user query, assume it contains all the needed columns and sorting
        """
        if not _HAS_SQLALCHEMY:
            raise RuntimeError("Could not find SQLAlchemy installation")
        self._connection = connection
        self._table_name = table_name
        self._schema_name = schema_name
@@ -248,7 +264,7 @@ def schema_struct(self):
name = "DBDynStruct_{table}_{schema}".format(table=self._table_name or "", schema=self._schema_name or "")
if name not in globals():
db_metadata = db.MetaData(schema=self._schema_name)
table = db.Table(self._table_name, db_metadata, autoload=True, autoload_with=self._connection)
table = db.Table(self._table_name, db_metadata, autoload_with=self._connection)
struct_metadata = {col: col_obj.type.python_type for col, col_obj in table.columns.items()}

from csp.impl.struct import defineStruct
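SQLAlchemy 2 removed the `autoload=True` keyword; `autoload_with=` alone performs reflection on both 1.4 and 2.x, which is why the reflected-table call above can drop it. A minimal reflection sketch, not part of the diff, against a throwaway in-memory SQLite database (table and columns are illustrative):

```python
import sqlalchemy as db

engine = db.create_engine("sqlite:///:memory:")
with engine.begin() as conn:
    conn.execute(db.text("CREATE TABLE trades (SYMBOL TEXT, PRICE FLOAT)"))

metadata = db.MetaData()
trades = db.Table("trades", metadata, autoload_with=engine)  # reflection spelling shared by 1.4 and 2.x
print({col: obj.type.python_type for col, obj in trades.columns.items()})
# roughly: {'SYMBOL': <class 'str'>, 'PRICE': <class 'float'>}
```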
@@ -301,23 +317,44 @@ def __init__(self, engine, adapterRep):
        self._row = None

    def start(self, starttime, endtime):
        query = self.build_query(starttime, endtime)
        self._query = self.build_query(starttime, endtime)
        if self._rep._log_query:
            import logging

            logging.info("DBReader query: %s", query)
        self._q = self._rep._connection.execute(query)
            logging.info("DBReader query: %s", self._query)
        if _SQLALCHEMY_2:
            self._data_yielder = self._data_yielder_function()
        else:
            self._q = self._rep._connection.execute(self._query)

    def _data_yielder_function(self):
        # Connection yielder for SQLAlchemy 2
        with self._rep._connection.connect() as conn:
            for result in conn.execute(self._query).mappings():
                yield result
        # Signify the end
        yield None

    def build_query(self, starttime, endtime):
        if self._rep._table_name:
            metadata = db.MetaData(schema=self._rep._schema_name)
            table = db.Table(self._rep._table_name, metadata, autoload=True, autoload_with=self._rep._connection)
            cols = [table.c[colname] for colname in self._rep._requested_cols]
            q = db.select(cols)

            if _SQLALCHEMY_2:
                table = db.Table(self._rep._table_name, metadata, autoload_with=self._rep._connection)
                cols = [table.c[colname] for colname in self._rep._requested_cols]
                q = db.select(*cols)
            else:
                table = db.Table(self._rep._table_name, metadata, autoload=True, autoload_with=self._rep._connection)
                cols = [table.c[colname] for colname in self._rep._requested_cols]
                q = db.select(cols)

        elif self._rep._use_raw_user_query:
            return db.text(self._rep._query)
        else:  # self._rep._query
            from_obj = db.text(f"({self._rep._query}) AS user_query")
            if _SQLALCHEMY_2:
                from_obj = db.text(f"({self._rep._query})")
            else:
                from_obj = db.text(f"({self._rep._query}) AS user_query")

            time_columns = self._rep._time_accessor.get_time_columns(self._rep._connection)
            if time_columns:
@@ -330,7 +367,11 @@ def build_query(self, starttime, endtime):
                time_columns = []
                time_select = []
            select_cols = [db.column(colname) for colname in self._rep._requested_cols.difference(set(time_columns))]
            q = db.select(select_cols + time_select, from_obj=from_obj)

            if _SQLALCHEMY_2:
                q = db.select(*(select_cols + time_select)).select_from(from_obj)
            else:
                q = db.select(select_cols + time_select, from_obj=from_obj)

        cond = self._rep._time_accessor.get_time_constraint(starttime.replace(tzinfo=UTC), endtime.replace(tzinfo=UTC))

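The branching above exists because the two major lines spell `select()` differently: 1.x accepts a list of columns and a `from_obj=` keyword, while 2.x (and 1.4's 2.0-style API) takes columns positionally and chains `.select_from()`. A small sketch of the 2.x form, with the removed 1.x form left as a comment (table and column names are illustrative, not taken from the adapter):

```python
import sqlalchemy as db

select_cols = [db.column("SYMBOL"), db.column("PRICE")]
from_obj = db.text("(select * from trades) AS user_query")

# SQLAlchemy 1.x (removed in 2.0):
#   q = db.select(select_cols, from_obj=from_obj)

# SQLAlchemy 2.x (also available on 1.4):
q = db.select(*select_cols).select_from(from_obj)
print(q)  # prints roughly: SELECT "SYMBOL", "PRICE" FROM (select * from trades) AS user_query
```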
@@ -361,16 +402,21 @@ def register_input_adapter(self, symbol, adapter):

    def process_next_sim_timeslice(self, now):
        if self._row is None:
            self._row = self._q.fetchone()
            if _SQLALCHEMY_2:
                self._row = next(self._data_yielder)
            else:
                self._row = self._q.fetchone()

        now = now.replace(tzinfo=UTC)
        while self._row is not None:
            time = self._rep._time_accessor.get_time(self._row)
            if time > now:
                return time
            self.process_row(self._row)
            self._row = self._q.fetchone()

            if _SQLALCHEMY_2:
                self._row = next(self._data_yielder)
            else:
                self._row = self._q.fetchone()
        return None

    def process_row(self, row):
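Row consumption changed in the same way: 1.x `engine.execute(q)` returns a result whose `fetchone()` yields dict-like rows, while 2.x executes on an explicit connection and exposes dict-like rows through `.mappings()`. The `_data_yielder_function` generator above wraps the 2.x form so `process_next_sim_timeslice` can keep pulling one row at a time, with `None` as the end marker. A self-contained sketch of that loop, assuming SQLAlchemy 1.4 or 2.x and an illustrative table:

```python
import sqlalchemy as db

engine = db.create_engine("sqlite:///:memory:")
metadata = db.MetaData()
trades = db.Table("trades", metadata, db.Column("SYMBOL", db.String), db.Column("PRICE", db.Float))
metadata.create_all(engine)

with engine.begin() as conn:
    conn.execute(db.insert(trades), [{"SYMBOL": "AAPL", "PRICE": 500.0}, {"SYMBOL": "GM", "PRICE": 2.0}])

def rows(query):
    # mirrors _data_yielder_function: yield mapping-style rows, then None as an end marker
    with engine.connect() as conn:
        for row in conn.execute(query).mappings():
            yield row
    yield None

it = rows(db.select(trades.c.SYMBOL, trades.c.PRICE))
row = next(it)
while row is not None:
    print(row["SYMBOL"], row["PRICE"])
    row = next(it)
```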
13 changes: 8 additions & 5 deletions csp/adapters/output_adapters/parquet.py
@@ -1,6 +1,6 @@
import numpy
import os
import pkg_resources
from importlib.metadata import PackageNotFoundError, version as get_package_version
from packaging import version
from typing import Callable, Dict, Optional, TypeVar

@@ -37,10 +37,13 @@ def resolve_compression(self):


def _get_default_parquet_version():
    if version.parse(pkg_resources.get_distribution("pyarrow").version) >= version.parse("6.0.1"):
        return "2.6"
    else:
        return "2.0"
    try:
        if version.parse(get_package_version("pyarrow")) >= version.parse("6.0.1"):
            return "2.6"
    except PackageNotFoundError:
        # Don't need to do anything in particular
        ...
    return "2.0"


class ParquetWriter:
6 changes: 3 additions & 3 deletions csp/adapters/parquet.py
@@ -1,10 +1,10 @@
import datetime
import io
import numpy
import pkg_resources
import platform
import pyarrow
import pyarrow.parquet
from importlib.metadata import PackageNotFoundError, version as get_package_version
from packaging import version
from typing import TypeVar

@@ -28,9 +28,9 @@

try:
_CAN_READ_ARROW_BINARY = False
if version.parse(pkg_resources.get_distribution("pyarrow").version) >= version.parse("4.0.1"):
if version.parse(get_package_version("pyarrow")) >= version.parse("4.0.1"):
_CAN_READ_ARROW_BINARY = True
except (ValueError, TypeError):
except (PackageNotFoundError, ValueError, TypeError):
# Cannot read binary arrow
...

27 changes: 19 additions & 8 deletions csp/tests/adapters/test_db.py
@@ -5,7 +5,7 @@
from datetime import date, datetime, time

import csp
from csp.adapters.db import DateTimeAccessor, DBReader, EngineStartTimeAccessor, TimestampAccessor
from csp.adapters.db import _SQLALCHEMY_2, DateTimeAccessor, DBReader, EngineStartTimeAccessor, TimestampAccessor


class PriceQuantity(csp.Struct):
@@ -21,6 +21,15 @@ class PriceQuantity2(csp.Struct):
    side: str


def execute_with_commit(engine, query, values):
    if _SQLALCHEMY_2:
        with engine.connect() as conn:
            conn.execute(query, values)
            conn.commit()
    else:
        engine.execute(query, values)


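The helper exists because SQLAlchemy 2 removed connectionless `engine.execute()` and no longer autocommits DML: rows inserted on a connection are only visible later if the test calls `commit()` (or wraps the work in `engine.begin()`). A minimal sketch of the 2.x behavior, not taken from the test suite, run under SQLAlchemy 2 (table and values are illustrative):

```python
import sqlalchemy as db

engine = db.create_engine("sqlite:///:memory:")
metadata = db.MetaData()
t = db.Table("t", metadata, db.Column("SYMBOL", db.String), db.Column("PRICE", db.Float))
metadata.create_all(engine)

# SQLAlchemy 2: "commit as you go" -- without conn.commit() the insert is rolled back
# when the connection is returned to the pool.
with engine.connect() as conn:
    conn.execute(db.insert(t), [{"SYMBOL": "AAPL", "PRICE": 500.0}])
    conn.commit()

# On the 1.x line the same insert could be written as engine.execute(query, values),
# which autocommitted implicitly.
with engine.connect() as conn:
    print(conn.execute(db.select(db.func.count()).select_from(t)).scalar())  # 1
```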
class TestDBReader(unittest.TestCase):
    def _prepopulate_in_mem_engine(self):
        engine = db.create_engine("sqlite:///:memory:")  # in-memory sqlite db
@@ -46,7 +55,7 @@ def _prepopulate_in_mem_engine(self):
            {"TIME": starttime.replace(second=5), "SYMBOL": "AAPL", "PRICE": 200.0, "SIZE": 400, "SIDE": "BUY"},
            {"TIME": starttime.replace(second=6), "SYMBOL": "GM", "PRICE": 2.0, "SIZE": 1, "SIDE": "BUY"},
        ]
        engine.execute(query, values_list)
        execute_with_commit(engine, query, values_list)
        return engine

    def test_sqlite_basic(self):
@@ -92,7 +101,7 @@ def graph():

        # UTC
        result = csp.run(graph, starttime=datetime(2020, 3, 3, 9, 30))
        print(result)

        self.assertEqual(len(result["aapl"]), 4)
        self.assertTrue(all(v[1].SYMBOL == "AAPL" for v in result["aapl"]))

@@ -211,7 +220,8 @@ def test_sqlite_constraints(self):
                "SIDE": "BUY",
            },
        ]
        engine.execute(query, values_list)

        execute_with_commit(engine, query, values_list)

        def graph():
            time_accessor = DateTimeAccessor(date_column="DATE", time_column="TIME", tz=pytz.timezone("US/Eastern"))
@@ -310,7 +320,7 @@ def test_join_query(self):
            {"TIME": starttime.replace(second=5), "SYMBOL": "AAPL", "PRICE": 200.0},
            {"TIME": starttime.replace(second=6), "SYMBOL": "GM", "PRICE": 2.0},
        ]
        engine.execute(query, values_list1)
        execute_with_commit(engine, query, values_list1)

        query = db.insert(test2)
        values_list2 = [
@@ -322,15 +332,15 @@
            # { 'TIME': starttime.replace( second = 5 ), 'SIZE': 400, 'SIDE': 'BUY' },
            {"TIME": starttime.replace(second=6), "SIZE": 1, "SIDE": "BUY"},
        ]
        engine.execute(query, values_list2)
        execute_with_commit(engine, query, values_list2)

        metadata.create_all(engine)

        def graph():
            time_accessor = TimestampAccessor(time_column="TIME", tz=pytz.timezone("US/Eastern"))
            query = "select * from test1 inner join test2 on test2.TIME=test1.TIME"
            reader = DBReader.create_from_connection(
                connection=engine.connect(), query=query, time_accessor=time_accessor, symbol_column="SYMBOL"
                connection=engine, query=query, time_accessor=time_accessor, symbol_column="SYMBOL"
            )

            # Struct
@@ -414,7 +424,8 @@ def test_DateTimeAccessor(self):
            (datetime(2020, 3, 5, 12), 700.0),
        ]
        values_list = [{"DATE": v[0].date(), "TIME": v[0].time(), "SYMBOL": "AAPL", "PRICE": v[1]} for v in values]
        engine.execute(query, values_list)

        execute_with_commit(engine, query, values_list)

        def graph():
            time_accessor = DateTimeAccessor(date_column="DATE", time_column="TIME", tz=pytz.timezone("US/Eastern"))
2 changes: 1 addition & 1 deletion dev-environment.yml
@@ -11,7 +11,7 @@ dependencies:
  - ruamel.yaml
  - scikit-build
  - psutil
  - sqlalchemy<2
  - sqlalchemy
  - bump2version>=1.0.0
  - python-graphviz
  - httpx
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -61,7 +61,7 @@ develop = [
"pytest-cov",
"pytest-sugar",
"scikit-build",
"sqlalchemy<2",
"sqlalchemy",
"tornado",
]
