diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7450f785..eed612eb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ repos: - id: ruff-check args: [--fix, --exit-non-zero-on-fix] - repo: https://github.com/RobertCraigie/pyright-python - rev: v1.1.402 + rev: v1.1.408 hooks: - id: pyright - repo: https://github.com/pycqa/isort diff --git a/nummus/commands/export.py b/nummus/commands/export.py index 0fa6162e..2fe0251a 100644 --- a/nummus/commands/export.py +++ b/nummus/commands/export.py @@ -15,11 +15,6 @@ import argparse import io - from sqlalchemy import orm - - from nummus.models.currency import Currency - from nummus.models.transaction import TransactionSplit - class Export(Command): """Export transactions.""" @@ -85,35 +80,20 @@ def setup_args(cls, parser: argparse.ArgumentParser) -> None: @override def run(self) -> int: - # Defer for faster time to main - from nummus.models.transaction import TransactionSplit - - with self._p.begin_session() as s: - query = ( - s.query(TransactionSplit) - .where( - TransactionSplit.asset_id.is_(None), - ) - .with_entities(TransactionSplit.amount) - ) - if self._start is not None: - query = query.where( - TransactionSplit.date_ord >= self._start.toordinal(), - ) - if self._end is not None: - query = query.where( - TransactionSplit.date_ord <= self._end.toordinal(), - ) - - with self._csv_path.open("w", encoding="utf-8") as file: - n = write_csv(file, query, no_bars=self._no_bars) + + with ( + self._p.begin_session(), + self._csv_path.open("w", encoding="utf-8") as file, + ): + n = write_csv(file, self._start, self._end, no_bars=self._no_bars) print(f"{Fore.GREEN}{n} transactions exported to {self._csv_path}") return 0 def write_csv( file: io.TextIOBase, - transactions_query: orm.Query[TransactionSplit], + start: datetime.date | None, + end: datetime.date | None, *, no_bars: bool, ) -> int: @@ -121,7 +101,8 @@ def write_csv( Args: file: Destination file to write to - transactions_query: ORM query to obtain TransactionSplits + start: Start date to filter transactions + end: End date to filter transactions no_bars: True will disable progress bars Returns: @@ -131,34 +112,43 @@ def write_csv( # Defer for faster time to main import tqdm + from nummus import sql from nummus.models.account import Account - from nummus.models.base import YIELD_PER from nummus.models.currency import CURRENCY_FORMATS from nummus.models.transaction import TransactionCategory, TransactionSplit - from nummus.models.utils import query_count - s = transactions_query.session - - query = s.query(Account).with_entities( + query = Account.query( Account.id_, Account.name, Account.currency, ) - accounts: dict[int, tuple[str, Currency]] = { - r[0]: (r[1], r[2]) for r in query.yield_per(YIELD_PER) - } - - categories = TransactionCategory.map_name_emoji(s) - - query = transactions_query.with_entities( - TransactionSplit.date_ord, - TransactionSplit.account_id, - TransactionSplit.payee, - TransactionSplit.memo, - TransactionSplit.category_id, - TransactionSplit.amount, - ).order_by(TransactionSplit.date_ord) - n = query_count(query) + accounts = sql.to_dict_tuple(query) + + categories = TransactionCategory.map_name_emoji() + + query = ( + TransactionSplit.query( + TransactionSplit.date_ord, + TransactionSplit.account_id, + TransactionSplit.payee, + TransactionSplit.memo, + TransactionSplit.category_id, + TransactionSplit.amount, + ) + .order_by(TransactionSplit.date_ord) + .where( + TransactionSplit.asset_id.is_(None), + ) + ) + if start is not None: + query = query.where( + TransactionSplit.date_ord >= start.toordinal(), + ) + if end is not None: + query = query.where( + TransactionSplit.date_ord <= end.toordinal(), + ) + n = sql.count(query) header = [ "Date", @@ -177,7 +167,7 @@ def write_csv( t_cat_id, amount, ) in tqdm.tqdm( - query.yield_per(YIELD_PER), + sql.yield_(query), total=n, desc="Exporting", disable=no_bars, @@ -189,8 +179,8 @@ def write_csv( [ datetime.date.fromordinal(date).isoformat(), acct_name, - payee, - memo, + payee or "", + memo or "", categories[t_cat_id], cf(amount), ], diff --git a/nummus/commands/health.py b/nummus/commands/health.py index daa38e97..1912f81e 100644 --- a/nummus/commands/health.py +++ b/nummus/commands/health.py @@ -113,13 +113,15 @@ def run(self) -> int: p = self._p - with p.begin_session() as s: + with p.begin_session(): if self._clear_ignores: - s.query(HealthCheckIssue).delete() + HealthCheckIssue.query().delete() elif self._ignores: # Set ignore for all specified issues + print(self._ignores) + print(HealthCheckIssue.uri_to_id(self._ignores[0])) ids = {HealthCheckIssue.uri_to_id(uri) for uri in self._ignores} - s.query(HealthCheckIssue).where(HealthCheckIssue.id_.in_(ids)).update( + HealthCheckIssue.query().where(HealthCheckIssue.id_.in_(ids)).update( {HealthCheckIssue.ignore: True}, ) @@ -141,8 +143,8 @@ def run(self) -> int: # Update LAST_HEALTH_CHECK_TS utc_now = datetime.datetime.now(datetime.UTC) - with p.begin_session() as s: - Config.set_(s, ConfigKey.LAST_HEALTH_CHECK_TS, utc_now.isoformat()) + with p.begin_session(): + Config.set_(ConfigKey.LAST_HEALTH_CHECK_TS, utc_now.isoformat()) if any_severe_issues: return -2 if any_issues: @@ -164,8 +166,8 @@ def _test_check(self, check_type: type[HealthCheck]) -> str | None: no_ignores=self._no_ignores, no_description_typos=self._no_description_typos, ) - with self._p.begin_session() as s: - c.test(s) + with self._p.begin_session(): + c.test() n_issues = len(c.issues) if n_issues == 0: print(f"{Fore.GREEN}Check '{c.name()}' has no issues") diff --git a/nummus/commands/migrate.py b/nummus/commands/migrate.py index f18fcec3..2821c94d 100644 --- a/nummus/commands/migrate.py +++ b/nummus/commands/migrate.py @@ -54,8 +54,8 @@ def run(self) -> int: # Back up Portfolio _, tar_ver = p.backup() - with p.begin_session() as s: - v_db = Config.db_version(s) + with p.begin_session(): + v_db = Config.db_version() any_migrated = False try: @@ -78,13 +78,13 @@ def run(self) -> int: m.migrate(p) # no comments print(f"{Fore.GREEN}Portfolio model schemas updated") - with p.begin_session() as s: + with p.begin_session(): v = max( Version(__version__), *[m.min_version() for m in MIGRATORS], ) - Config.set_(s, ConfigKey.VERSION, str(v)) + Config.set_(ConfigKey.VERSION, str(v)) except Exception: # pragma: no cover # No immediate exception thrown, can't easily test portfolio.Portfolio.restore(p, tar_ver=tar_ver) diff --git a/nummus/commands/summarize.py b/nummus/commands/summarize.py index 999e049a..7bd49fc3 100644 --- a/nummus/commands/summarize.py +++ b/nummus/commands/summarize.py @@ -108,41 +108,33 @@ def _get_summary( # Defer for faster time to main from sqlalchemy import func - from nummus import utils + from nummus import sql, utils from nummus.models.account import Account from nummus.models.asset import Asset, AssetCategory, AssetValuation from nummus.models.config import Config from nummus.models.currency import CURRENCY_FORMATS from nummus.models.transaction import TransactionSplit - from nummus.models.utils import query_count today = datetime.datetime.now().astimezone().date() today_ord = today.toordinal() - with self._p.begin_session() as s: - accts = {acct.id_: acct for acct in s.query(Account).all()} - assets = { - a.id_: a - for a in ( - s.query(Asset).where(Asset.category != AssetCategory.INDEX).all() - ) - } + with self._p.begin_session(): + accts = {acct.id_: acct for acct in Account.all()} + query = Asset.query().where(Asset.category != AssetCategory.INDEX) + assets = {a.id_: a for a in sql.yield_(query)} # Get the inception date - start_date_ord: int = ( - s.query( - func.min(TransactionSplit.date_ord), - ).scalar() - or datetime.date(1970, 1, 1).toordinal() + query = TransactionSplit.query( + func.min(TransactionSplit.date_ord), ) + start_date_ord = sql.scalar(query) or datetime.date(1970, 1, 1).toordinal() n_accounts = len(accts) - n_transactions = query_count(s.query(TransactionSplit)) + n_transactions = TransactionSplit.count() n_assets = len(assets) - n_valuations = query_count(s.query(AssetValuation)) + n_valuations = AssetValuation.count() value_accts, profit_accts, value_assets = Account.get_value_all( - s, start_date_ord, today_ord, ) @@ -180,7 +172,6 @@ def _get_summary( ) profit_assets = Account.get_profit_by_asset_all( - s, start_date_ord, today_ord, ) @@ -224,7 +215,7 @@ def _get_summary( "total_asset_value": total_asset_value, "assets": summary_assets, "db_size": self._p.path.stat().st_size, - "cf": CURRENCY_FORMATS[Config.base_currency(s)], + "cf": CURRENCY_FORMATS[Config.base_currency()], } @classmethod diff --git a/nummus/controllers/accounts.py b/nummus/controllers/accounts.py index d7eca261..b0cbd786 100644 --- a/nummus/controllers/accounts.py +++ b/nummus/controllers/accounts.py @@ -11,11 +11,10 @@ from sqlalchemy import func from nummus import exceptions as exc -from nummus import utils, web +from nummus import sql, utils, web from nummus.controllers import base, transactions from nummus.models.account import Account, AccountCategory from nummus.models.asset import Asset, AssetCategory -from nummus.models.base import YIELD_PER from nummus.models.config import Config from nummus.models.currency import ( Currency, @@ -24,13 +23,11 @@ ) from nummus.models.transaction import Transaction, TransactionSplit from nummus.models.transaction_category import TransactionCategory -from nummus.models.utils import query_to_dict if TYPE_CHECKING: import datetime import werkzeug - from sqlalchemy import orm from nummus.models.currency import CurrencyFormat @@ -123,12 +120,12 @@ def page_all() -> flask.Response: """ p = web.portfolio - with p.begin_session() as s: + with p.begin_session(): include_closed = "include-closed" in flask.request.args return base.page( "accounts/page-all.jinja", "Accounts", - ctx=ctx_accounts(s, base.today_client(), include_closed=include_closed), + ctx=ctx_accounts(base.today_client(), include_closed=include_closed), ) @@ -144,11 +141,10 @@ def page(uri: str) -> flask.Response: """ p = web.portfolio today = base.today_client() - with p.begin_session() as s: - acct = base.find(s, Account, uri) + with p.begin_session(): + acct = base.find(Account, uri) args = flask.request.args txn_table, title = transactions.ctx_table( - s, today, args.get("search"), args.get("account"), @@ -163,16 +159,15 @@ def page(uri: str) -> flask.Response: title = title.removeprefix("Transactions").strip() title = f"{acct.name}, {title}" if title else f"{acct.name}" - ctx = ctx_account(s, acct, today) + ctx = ctx_account(acct, today) if acct.category == AccountCategory.INVESTMENT: ctx["performance"] = ctx_performance( - s, acct, today, args.get("chart-period"), CURRENCY_FORMATS[acct.currency], ) - ctx["assets"] = ctx_assets(s, acct, today) + ctx["assets"] = ctx_assets(acct, today) return base.page( "accounts/page.jinja", title=title, @@ -192,7 +187,7 @@ def new() -> str | flask.Response: """ p = web.portfolio with p.begin_session() as s: - base_currency = Config.base_currency(s) + base_currency = Config.base_currency() if flask.request.method == "GET": ctx: AccountContext = { "uri": None, @@ -236,7 +231,7 @@ def new() -> str | flask.Response: try: with s.begin_nested(): - acct = Account( + Account.create( institution=institution, name=name, number=number, @@ -245,7 +240,6 @@ def new() -> str | flask.Response: budgeted=budgeted, currency=currency, ) - s.add(acct) except (exc.IntegrityError, exc.InvalidORMValueError) as e: return base.error(e) @@ -267,18 +261,18 @@ def account(uri: str) -> str | werkzeug.Response: today_ord = today.toordinal() with p.begin_session() as s: - base_currency = Config.base_currency(s) + base_currency = Config.base_currency() - acct = base.find(s, Account, uri) + acct = base.find(Account, uri) if flask.request.method == "GET": return flask.render_template( "accounts/edit.jinja", - acct=ctx_account(s, acct, today), + acct=ctx_account(acct, today), ) if flask.request.method == "DELETE": with s.begin_nested(): - s.delete(acct) + acct.delete() return flask.redirect(flask.url_for("accounts.page_all")) values, _, _ = acct.get_value(today_ord, today_ord) @@ -326,14 +320,13 @@ def performance(uri: str) -> flask.Response: """ p = web.portfolio args = flask.request.args - with p.begin_session() as s: - acct = base.find(s, Account, uri) + with p.begin_session(): + acct = base.find(Account, uri) html = flask.render_template( "accounts/performance.jinja", acct={ "uri": uri, "performance": ctx_performance( - s, acct, base.today_client(), args.get("chart-period"), @@ -365,13 +358,13 @@ def validation() -> str: p = web.portfolio # dict{key: (required, prop if unique required)} - properties: dict[str, tuple[bool, orm.QueryableAttribute | None]] = { + properties: dict[str, tuple[bool, sql.Column | None]] = { "name": (True, Account.name), "institution": (True, None), "number": (False, Account.number), } - with p.begin_session() as s: + with p.begin_session(): args = flask.request.args uri = args.get("uri") for key, (required, prop) in properties.items(): @@ -380,7 +373,7 @@ def validation() -> str: return base.validate_string( args[key], is_required=required, - session=s, + cls=Account, no_duplicates=prop, no_duplicate_wheres=( None if uri is None else [Account.id_ != Account.uri_to_id(uri)] @@ -391,7 +384,6 @@ def validation() -> str: def ctx_account( - s: orm.Session, acct: Account, today: datetime.date, *, @@ -400,7 +392,6 @@ def ctx_account( """Get the context to build the account details. Args: - s: SQL session to use acct: Account to generate context for today: Today's date skip_today: True will skip fetching today's value @@ -423,32 +414,24 @@ def ctx_account( None if updated_on_ord is None else today_ord - updated_on_ord ) - query = ( - s.query(Transaction) - .with_entities( - func.count(Transaction.id_), - func.sum(Transaction.amount), - ) - .where( - Transaction.date_ord == today_ord, - Transaction.account_id == acct.id_, - ) + query = Transaction.query( + func.count(Transaction.id_), + func.sum(Transaction.amount), + ).where( + Transaction.date_ord == today_ord, + Transaction.account_id == acct.id_, ) - n_today, change_today = query.one() + n_today, change_today = sql.one(query) change_today: Decimal = change_today or Decimal() - query = ( - s.query(Transaction) - .with_entities( - func.count(Transaction.id_), - func.sum(Transaction.amount), - ) - .where( - Transaction.date_ord > today_ord, - Transaction.account_id == acct.id_, - ) + query = Transaction.query( + func.count(Transaction.id_), + func.sum(Transaction.amount), + ).where( + Transaction.date_ord > today_ord, + Transaction.account_id == acct.id_, ) - n_future, change_future = query.one() + n_future, change_future = sql.one(query) values, _, _ = acct.get_value(today_ord, today_ord) current_value = values[0] @@ -478,7 +461,6 @@ def ctx_account( def ctx_performance( - s: orm.Session, acct: Account, today: datetime.date, period: str | None, @@ -487,7 +469,6 @@ def ctx_performance( """Get the context to build the account performance details. Args: - s: SQL session to use acct: Account to generate context for today: Today's date period: Period string to get data for @@ -502,25 +483,30 @@ def ctx_performance( end_ord = end.toordinal() start_ord = acct.opened_on_ord or end_ord if start is None else start.toordinal() - query = s.query(TransactionCategory.id_, TransactionCategory.name).where( + query = TransactionCategory.query( + TransactionCategory.id_, + TransactionCategory.name, + ).where( TransactionCategory.is_profit_loss.is_(True), ) - pnl_categories: dict[int, str] = query_to_dict(query) + pnl_categories: dict[int, str] = sql.to_dict(query) # Calculate total cost basis total_cost_basis = Decimal() dividends = Decimal() fees = Decimal() query = ( - s.query(TransactionSplit) - .with_entities(TransactionSplit.category_id, func.sum(TransactionSplit.amount)) + TransactionSplit.query( + TransactionSplit.category_id, + func.sum(TransactionSplit.amount), + ) .where( TransactionSplit.date_ord <= end_ord, TransactionSplit.account_id == acct.id_, ) .group_by(TransactionSplit.category_id) ) - for cat_id, value in query.yield_per(YIELD_PER): + for cat_id, value in sql.yield_(query): name = pnl_categories.get(cat_id) if name is None: total_cost_basis += value @@ -566,14 +552,12 @@ def ctx_performance( def ctx_assets( - s: orm.Session, acct: Account, today: datetime.date, ) -> list[AssetContext] | None: """Get the context to build the account assets. Args: - s: SQL session to use acct: Account to generate context for today: Today's date @@ -591,32 +575,32 @@ def ctx_assets( return None # Not an investment account # Include all assets every held - query = s.query(TransactionSplit.asset_id).where( - TransactionSplit.account_id == acct.id_, - TransactionSplit.asset_id.is_not(None), + query = ( + TransactionSplit.query(TransactionSplit.asset_id) + .where( + TransactionSplit.account_id == acct.id_, + TransactionSplit.asset_id.is_not(None), + ) + .distinct() ) - a_ids = {a_id for a_id, in query.distinct()} + a_ids = {a_id for a_id in sql.col0(query) if a_id} - end_prices = Asset.get_value_all(s, today_ord, today_ord, ids=a_ids) + end_prices = Asset.get_value_all(today_ord, today_ord, ids=a_ids) asset_profits = acct.get_profit_by_asset(start_ord, today_ord) # Sum of profits should match final profit value, add any mismatch to cash - query = ( - s.query(Asset) - .with_entities( - Asset.id_, - Asset.name, - Asset.ticker, - Asset.category, - ) - .where(Asset.id_.in_(a_ids)) - ) + query = Asset.query( + Asset.id_, + Asset.name, + Asset.ticker, + Asset.category, + ).where(Asset.id_.in_(a_ids)) assets: list[AssetContext] = [] total_value = Decimal() total_profit = Decimal() - for a_id, name, ticker, category in query.yield_per(YIELD_PER): + for a_id, name, ticker, category in sql.yield_(query): end_qty = asset_qtys[a_id] end_price = end_prices[a_id][0] end_value = end_qty * end_price @@ -639,12 +623,11 @@ def ctx_assets( assets.append(ctx_asset) # Add in cash too - cash: Decimal = ( - s.query(func.sum(TransactionSplit.amount)) - .where(TransactionSplit.account_id == acct.id_) - .where(TransactionSplit.date_ord <= today_ord) - .one()[0] + query = TransactionSplit.query(func.sum(TransactionSplit.amount)).where( + TransactionSplit.account_id == acct.id_, + TransactionSplit.date_ord <= today_ord, ) + cash: Decimal = sql.one(query) total_value += cash ctx_asset = { "uri": None, @@ -676,7 +659,6 @@ def ctx_assets( def ctx_accounts( - s: orm.Session, today: datetime.date, *, include_closed: bool = False, @@ -684,7 +666,6 @@ def ctx_accounts( """Get the context to build the accounts table. Args: - s: SQL session to use today: Today's date include_closed: True will include Accounts marked closed, False will exclude @@ -705,71 +686,59 @@ def ctx_accounts( # Get basic info accounts: dict[int, AccountContext] = {} currencies: dict[int, Currency] = {} - query = s.query(Account).order_by(Account.category) + query = Account.query().order_by(Account.category) if not include_closed: query = query.where(Account.closed.is_(False)) - for acct in query.all(): - accounts[acct.id_] = ctx_account(s, acct, today, skip_today=True) + for acct in sql.yield_(query): + accounts[acct.id_] = ctx_account(acct, today, skip_today=True) currencies[acct.id_] = acct.currency if acct.closed: n_closed += 1 # Get updated_on query = ( - s.query(Transaction) - .with_entities( + Transaction.query( Transaction.account_id, func.max(Transaction.date_ord), ) .group_by(Transaction.account_id) .where(Transaction.account_id.in_(accounts)) ) - for acct_id, updated_on_ord in query.all(): + for acct_id, updated_on_ord in sql.yield_(query): acct_id: int updated_on_ord: int accounts[acct_id]["updated_days_ago"] = today_ord - updated_on_ord # Get n_today query = ( - s.query(Transaction) - .with_entities( + Transaction.query( Transaction.account_id, func.count(Transaction.id_), func.sum(Transaction.amount), ) - .where(Transaction.date_ord == today_ord) + .where(Transaction.date_ord == today_ord, Transaction.account_id.in_(accounts)) .group_by(Transaction.account_id) - .where(Transaction.account_id.in_(accounts)) ) - for acct_id, n_today, change_today in query.all(): - acct_id: int - n_today: int - change_today: Decimal | None + for acct_id, n_today, change_today in sql.yield_(query): accounts[acct_id]["n_today"] = n_today accounts[acct_id]["change_today"] = change_today or Decimal() # Get n_future query = ( - s.query(Transaction) - .with_entities( + Transaction.query( Transaction.account_id, func.count(Transaction.id_), func.sum(Transaction.amount), ) - .where(Transaction.date_ord > today_ord) + .where(Transaction.date_ord > today_ord, Transaction.account_id.in_(accounts)) .group_by(Transaction.account_id) - .where(Transaction.account_id.in_(accounts)) ) - for acct_id, n_future, change_future in query.all(): - acct_id: int - n_future: int - change_future: Decimal + for acct_id, n_future, change_future in sql.yield_(query): accounts[acct_id]["n_future"] = n_future accounts[acct_id]["change_future"] = change_future - base_currency = Config.base_currency(s) + base_currency = Config.base_currency() forex = Asset.get_forex( - s, today_ord, today_ord, base_currency, @@ -777,7 +746,7 @@ def ctx_accounts( ) # Get all Account values - acct_values, _, _ = Account.get_value_all(s, today_ord, today_ord, ids=accounts) + acct_values, _, _ = Account.get_value_all(today_ord, today_ord, ids=accounts) for acct_id, ctx in accounts.items(): v = acct_values[acct_id][0] ctx["value"] = v @@ -822,7 +791,7 @@ def ctx_accounts( }, "include_closed": include_closed, "n_closed": n_closed, - "currency_format": CURRENCY_FORMATS[Config.base_currency(s)], + "currency_format": CURRENCY_FORMATS[Config.base_currency()], } @@ -840,9 +809,8 @@ def txns(uri: str) -> str | flask.Response: args = flask.request.args first_page = "page" not in args - with p.begin_session() as s: + with p.begin_session(): txn_table, title = transactions.ctx_table( - s, base.today_client(), args.get("search"), args.get("account"), @@ -855,7 +823,7 @@ def txns(uri: str) -> str | flask.Response: acct_uri=uri, ) title = title.removeprefix("Transactions").strip() - acct = base.find(s, Account, uri) + acct = base.find(Account, uri) title = f"{acct.name}, {title}" if title else f"{acct.name}" html_title = f"{title} - nummus\n" html = html_title + flask.render_template( @@ -893,8 +861,8 @@ def txns_options(uri: str) -> str: """ p = web.portfolio - with p.begin_session() as s: - accounts = Account.map_name(s) + with p.begin_session(): + accounts = Account.map_name() args = flask.request.args uncleared = "uncleared" in args @@ -905,7 +873,6 @@ def txns_options(uri: str) -> str: selected_end = args.get("end") tbl_query = transactions.table_query( - s, None, selected_account, selected_period, @@ -918,7 +885,7 @@ def txns_options(uri: str) -> str: tbl_query, base.today_client(), accounts, - base.tranaction_category_groups(s), + base.tranaction_category_groups(), selected_account, selected_category, ) diff --git a/nummus/controllers/allocation.py b/nummus/controllers/allocation.py index a1e2d161..2f68e4f9 100644 --- a/nummus/controllers/allocation.py +++ b/nummus/controllers/allocation.py @@ -7,11 +7,10 @@ from decimal import Decimal from typing import TYPE_CHECKING, TypedDict -from nummus import web +from nummus import sql, web from nummus.controllers import base from nummus.models.account import Account from nummus.models.asset import Asset, AssetSector -from nummus.models.base import YIELD_PER from nummus.models.config import Config from nummus.models.currency import CURRENCY_FORMATS @@ -19,7 +18,6 @@ import datetime import flask - from sqlalchemy import orm from nummus.models.asset import AssetCategory, USSector from nummus.models.currency import CurrencyFormat @@ -76,19 +74,18 @@ def page() -> flask.Response: """ p = web.portfolio - with p.begin_session() as s: + with p.begin_session(): return base.page( "allocation/page.jinja", title="Asset allocation", - allocation=ctx_allocation(s, base.today_client()), + allocation=ctx_allocation(base.today_client()), ) -def ctx_allocation(s: orm.Session, today: datetime.date) -> AllocationContext: +def ctx_allocation(today: datetime.date) -> AllocationContext: """Get the context to build the allocation chart. Args: - s: SQL session to use today: Today's date Returns: @@ -98,7 +95,7 @@ def ctx_allocation(s: orm.Session, today: datetime.date) -> AllocationContext: today_ord = today.toordinal() asset_qtys: dict[int, Decimal] = defaultdict(Decimal) - acct_qtys = Account.get_asset_qty_all(s, today_ord, today_ord) + acct_qtys = Account.get_asset_qty_all(today_ord, today_ord) for acct_qty in acct_qtys.values(): for a_id, values in acct_qty.items(): asset_qtys[a_id] += values[0] @@ -107,7 +104,6 @@ def ctx_allocation(s: orm.Session, today: datetime.date) -> AllocationContext: asset_prices = { a_id: values[0] for a_id, values in Asset.get_value_all( - s, today_ord, today_ord, ids=set(asset_qtys), @@ -117,13 +113,13 @@ def ctx_allocation(s: orm.Session, today: datetime.date) -> AllocationContext: asset_values = {a_id: qty * asset_prices[a_id] for a_id, qty in asset_qtys.items()} asset_sectors: dict[int, dict[USSector, Decimal]] = defaultdict(dict) - for a_sector in s.query(AssetSector).yield_per(YIELD_PER): + for a_sector in sql.yield_(AssetSector.query()): asset_sectors[a_sector.asset_id][a_sector.sector] = a_sector.weight assets_by_category: dict[AssetCategory, list[AssetContext]] = defaultdict(list) assets_by_sector: dict[USSector, list[AssetContext]] = defaultdict(list) - query = s.query(Asset).where(Asset.id_.in_(asset_qtys)).order_by(Asset.name) - for asset in query.yield_per(YIELD_PER): + query = Asset.query().where(Asset.id_.in_(asset_qtys)).order_by(Asset.name) + for asset in sql.yield_(query): qty = asset_qtys[asset.id_] value = asset_values[asset.id_] @@ -174,7 +170,7 @@ def chart_assets(assets: list[AssetContext]) -> list[ChartAssetContext]: for a in assets ] - cf = CURRENCY_FORMATS[Config.base_currency(s)] + cf = CURRENCY_FORMATS[Config.base_currency()] return { "categories": sorted(categories, key=operator.itemgetter("name")), diff --git a/nummus/controllers/assets.py b/nummus/controllers/assets.py index 975344a2..560b5ba6 100644 --- a/nummus/controllers/assets.py +++ b/nummus/controllers/assets.py @@ -12,7 +12,7 @@ from sqlalchemy import func from nummus import exceptions as exc -from nummus import utils, web +from nummus import sql, utils, web from nummus.controllers import base from nummus.models.account import Account from nummus.models.asset import ( @@ -22,19 +22,15 @@ AssetSplit, AssetValuation, ) -from nummus.models.base import YIELD_PER from nummus.models.config import Config from nummus.models.currency import ( Currency, CURRENCY_FORMATS, ) from nummus.models.transaction import TransactionSplit -from nummus.models.utils import query_count if TYPE_CHECKING: - import sqlalchemy import werkzeug - from sqlalchemy import orm from nummus.models.currency import CurrencyFormat @@ -132,8 +128,8 @@ def page_all() -> flask.Response: p = web.portfolio include_unheld = "include-unheld" in flask.request.args - with p.begin_session() as s: - categories = ctx_rows(s, base.today_client(), include_unheld=include_unheld) + with p.begin_session(): + categories = ctx_rows(base.today_client(), include_unheld=include_unheld) return base.page( "assets/page-all.jinja", @@ -160,10 +156,9 @@ def page(uri: str) -> flask.Response: """ p = web.portfolio args = flask.request.args - with p.begin_session() as s: - a = base.find(s, Asset, uri) + with p.begin_session(): + a = base.find(Asset, uri) ctx = ctx_asset( - s, a, base.today_client(), args.get("period"), @@ -190,7 +185,7 @@ def new() -> str | flask.Response: with p.begin_session() as s: if flask.request.method == "GET": - currency = Config.base_currency(s) + currency = Config.base_currency() ctx: AssetContext = { "uri": None, "name": "", @@ -222,14 +217,13 @@ def new() -> str | flask.Response: try: with s.begin_nested(): - a = Asset( + Asset.create( name=name, description=description, category=category, ticker=ticker, currency=currency, ) - s.add(a) except (exc.IntegrityError, exc.InvalidORMValueError) as e: return base.error(e) @@ -248,14 +242,13 @@ def asset(uri: str) -> str | werkzeug.Response: """ p = web.portfolio with p.begin_session() as s: - a = base.find(s, Asset, uri) + a = base.find(Asset, uri) if flask.request.method == "GET": args = flask.request.args return flask.render_template( "assets/edit.jinja", asset=ctx_asset( - s, a, base.today_client(), args.get("period"), @@ -267,10 +260,10 @@ def asset(uri: str) -> str | werkzeug.Response: ) if flask.request.method == "DELETE": with s.begin_nested(): - s.query(AssetSector).where(AssetSector.asset_id == a.id_).delete() - s.query(AssetSplit).where(AssetSplit.asset_id == a.id_).delete() - s.query(AssetValuation).where(AssetValuation.asset_id == a.id_).delete() - s.delete(a) + AssetSector.query().where(AssetSector.asset_id == a.id_).delete() + AssetSplit.query().where(AssetSplit.asset_id == a.id_).delete() + AssetValuation.query().where(AssetValuation.asset_id == a.id_).delete() + a.delete() return flask.redirect(flask.url_for("assets.page_all")) form = flask.request.form @@ -302,14 +295,14 @@ def performance(uri: str) -> flask.Response: """ p = web.portfolio - with p.begin_session() as s: - a = base.find(s, Asset, uri) + with p.begin_session(): + a = base.find(Asset, uri) period = flask.request.args.get("chart-period") html = flask.render_template( "assets/performance.jinja", asset={ "uri": uri, - "performance": ctx_performance(s, a, base.today_client(), period), + "performance": ctx_performance(a, base.today_client(), period), }, ) response = flask.make_response(html) @@ -337,11 +330,10 @@ def table(uri: str) -> str | flask.Response: """ p = web.portfolio args = flask.request.args - with p.begin_session() as s: - a = base.find(s, Asset, uri) + with p.begin_session(): + a = base.find(Asset, uri) cf = CURRENCY_FORMATS[a.currency] val_table = ctx_table( - s, a, base.today_client(), args.get("period"), @@ -382,13 +374,13 @@ def validation() -> str: """ p = web.portfolio # dict{key: (required, prop if unique required)} - properties: dict[str, tuple[bool, orm.QueryableAttribute | None]] = { + properties: dict[str, tuple[bool, sql.Column | None]] = { "name": (True, Asset.name), "description": (False, None), "ticker": (False, Asset.ticker), } - with p.begin_session() as s: + with p.begin_session(): args = flask.request.args uri = args.get("uri") for key, (required, prop) in properties.items(): @@ -398,7 +390,7 @@ def validation() -> str: args[key], is_required=required, check_length=key != "ticker", - session=s, + cls=Asset, no_duplicates=prop, no_duplicate_wheres=( None if uri is None else [Asset.id_ != Asset.uri_to_id(uri)] @@ -406,7 +398,7 @@ def validation() -> str: ) if "date" in args: - wheres: list[sqlalchemy.ColumnExpressionArgument] = [] + wheres: list[sql.ColumnClause] = [] if uri: wheres.append( AssetValuation.asset_id == Asset.uri_to_id(uri), @@ -420,7 +412,7 @@ def validation() -> str: args["date"], base.today_client(), is_required=True, - session=s, + cls=AssetValuation, no_duplicates=AssetValuation.date_ord, no_duplicate_wheres=wheres, ) @@ -470,14 +462,13 @@ def new_valuation(uri: str) -> str | flask.Response: try: p = web.portfolio - with p.begin_session() as s: - a = base.find(s, Asset, uri) - v = AssetValuation( + with p.begin_session(): + a = base.find(Asset, uri) + AssetValuation.create( asset_id=a.id_, date_ord=date.toordinal(), value=value, ) - s.add(v) except exc.IntegrityError as e: # Get the line that starts with (...IntegrityError) orig = str(e.orig) @@ -505,7 +496,7 @@ def valuation(uri: str) -> str | flask.Response: today = base.today_client() with p.begin_session() as s: - v = base.find(s, AssetValuation, uri) + v = base.find(AssetValuation, uri) date_max = today + datetime.timedelta(days=utils.DAYS_IN_WEEK) if flask.request.method == "GET": @@ -521,7 +512,7 @@ def valuation(uri: str) -> str | flask.Response: ) if flask.request.method == "DELETE": date = v.date - s.delete(v) + v.delete() return base.dialog_swap( event="valuation", snackbar=f"{date} valuation deleted", @@ -564,8 +555,8 @@ def update() -> str | flask.Response: """ p = web.portfolio - with p.begin_session() as s: - n = query_count(s.query(Asset).where(Asset.ticker.is_not(None))) + with p.begin_session(): + n = sql.count(Asset.query().where(Asset.ticker.is_not(None))) if flask.request.method == "GET": return flask.render_template( "assets/update.jinja", @@ -600,7 +591,6 @@ def update() -> str | flask.Response: def ctx_rows( - s: orm.Session, today: datetime.date, *, include_unheld: bool, @@ -608,7 +598,6 @@ def ctx_rows( """Get the context to build the page all rows. Args: - s: SQL session to use today: Today's date include_unheld: True will include assets with zero current quantity @@ -620,7 +609,7 @@ def ctx_rows( today_ord = today.toordinal() - accounts = Account.get_asset_qty_all(s, today_ord, today_ord) + accounts = Account.get_asset_qty_all(today_ord, today_ord) qtys: dict[int, Decimal] = defaultdict(Decimal) for acct_qtys in accounts.values(): for a_id, values in acct_qtys.items(): @@ -628,14 +617,14 @@ def ctx_rows( held_ids = {a_id for a_id, qty in qtys.items() if qty} query = ( - s.query(Asset) + Asset.query() .where(Asset.category != AssetCategory.INDEX) .order_by(Asset.category) ) if not include_unheld: query = query.where(Asset.id_.in_(held_ids)) - prices = Asset.get_value_all(s, today_ord, today_ord, held_ids) - for asset in query.yield_per(YIELD_PER): + prices = Asset.get_value_all(today_ord, today_ord, held_ids) + for asset in sql.yield_(query): qty = qtys[asset.id_] price = prices[asset.id_][0] value = qty * price @@ -655,7 +644,6 @@ def ctx_rows( def ctx_asset( - s: orm.Session, a: Asset, today: datetime.date, period: str | None, @@ -667,7 +655,6 @@ def ctx_asset( """Get the context to build the asset details. Args: - s: SQL session to use a: Asset to generate context for today: Today's date period: Period to get table for @@ -681,7 +668,7 @@ def ctx_asset( """ valuation = ( - s.query(AssetValuation) + AssetValuation.query() .where(AssetValuation.asset_id == a.id_) .order_by(AssetValuation.date_ord.desc()) .first() @@ -692,18 +679,14 @@ def ctx_asset( else: current_value = valuation.value current_date = valuation.date - deletable = ( - s.query(TransactionSplit.id_) - .where(TransactionSplit.asset_id == a.id_) - .limit(1) - .scalar() - is None + query = TransactionSplit.query(TransactionSplit.id_).where( + TransactionSplit.asset_id == a.id_, ) + deletable = not sql.any_(query) - accounts = Account.map_name(s) + accounts = Account.map_name() query = ( - s.query(TransactionSplit) - .with_entities( + TransactionSplit.query( TransactionSplit.account_id, func.sum(TransactionSplit.asset_quantity), ) @@ -722,7 +705,7 @@ def ctx_asset( qty, qty * current_value, ) - for acct_id, qty in query.yield_per(YIELD_PER) + for acct_id, qty in sql.yield_(query) if qty ] @@ -738,15 +721,14 @@ def ctx_asset( "currency_format": CURRENCY_FORMATS[a.currency], "value": current_value, "value_date": current_date, - "performance": ctx_performance(s, a, today, period_chart), - "table": ctx_table(s, a, today, period, start, end, page), + "performance": ctx_performance(a, today, period_chart), + "table": ctx_table(a, today, period, start, end, page), "holdings": sorted(holdings, key=operator.itemgetter(2), reverse=True), "deletable": deletable, } def ctx_performance( - s: orm.Session, a: Asset, today: datetime.date, period: str | None, @@ -754,7 +736,6 @@ def ctx_performance( """Get the context to build the asset performance details. Args: - s: SQL session to use a: Asset to generate context for today: Today's date period: Chart-period to fetch performance for @@ -767,12 +748,10 @@ def ctx_performance( start, end = base.parse_period(period, today) end_ord = end.toordinal() if start is None: - start_ord = ( - s.query(func.min(AssetValuation.date_ord)) - .where(AssetValuation.asset_id == a.id_) - .scalar() - or end_ord + query = AssetValuation.query(func.min(AssetValuation.date_ord)).where( + AssetValuation.asset_id == a.id_, ) + start_ord = sql.one(query) or end_ord else: start_ord = start.toordinal() @@ -782,12 +761,11 @@ def ctx_performance( **base.chart_data(start_ord, end_ord, values), "period": period, "period_options": base.PERIOD_OPTIONS, - "currency_format": CURRENCY_FORMATS[Config.base_currency(s)]._asdict(), + "currency_format": CURRENCY_FORMATS[Config.base_currency()]._asdict(), } def ctx_table( - s: orm.Session, a: Asset, today: datetime.date, period: str | None, @@ -798,7 +776,6 @@ def ctx_table( """Get the context to build the valuations table. Args: - s: SQL session to use a: Asset to get valuations for today: Today's date period: Period to get table for @@ -813,7 +790,7 @@ def ctx_table( page_start = None if page is None else datetime.date.fromisoformat(page).toordinal() query = ( - s.query(AssetValuation) + AssetValuation.query() .where(AssetValuation.asset_id == a.id_) .order_by(AssetValuation.date_ord.desc()) ) @@ -854,7 +831,7 @@ def ctx_table( "date_max": None, "value": v.value, } - for v in query.limit(PAGE_LEN).yield_per(YIELD_PER) + for v in sql.yield_(query.limit(PAGE_LEN)) ] next_page = ( diff --git a/nummus/controllers/auth.py b/nummus/controllers/auth.py index 6b33e8c9..792b0af4 100644 --- a/nummus/controllers/auth.py +++ b/nummus/controllers/auth.py @@ -124,18 +124,18 @@ def login() -> str | werkzeug.Response: if not password: return base.error("Password must not be blank") - with p.begin_session() as s: - expected_encoded = Config.fetch(s, ConfigKey.WEB_KEY) + with p.begin_session(): + expected_encoded = Config.fetch(ConfigKey.WEB_KEY) - expected = p.decrypt(expected_encoded) - if password.encode() != expected: - return base.error("Bad password") + expected = p.decrypt(expected_encoded) + if password.encode() != expected: + return base.error("Bad password") - web_user = WebUser() - flask_login.login_user(web_user, remember=True) + web_user = WebUser() + flask_login.login_user(web_user, remember=True) - next_url = form.get("next") - return flask.redirect(next_url or flask.url_for("common.page_dashboard")) + next_url = form.get("next") + return flask.redirect(next_url or flask.url_for("common.page_dashboard")) def logout() -> str | werkzeug.Response: diff --git a/nummus/controllers/base.py b/nummus/controllers/base.py index 8bd2d913..9f367fd1 100644 --- a/nummus/controllers/base.py +++ b/nummus/controllers/base.py @@ -8,29 +8,23 @@ import textwrap from decimal import Decimal from pathlib import Path -from typing import NamedTuple, overload, TYPE_CHECKING, TypedDict +from typing import NamedTuple, overload, TypedDict import flask import flask.typing from nummus import exceptions as exc -from nummus import utils, web +from nummus import sql, utils, web from nummus.models.base import ( Base, BaseEnum, - YIELD_PER, ) from nummus.models.transaction_category import ( TransactionCategory, TransactionCategoryGroup, ) -from nummus.models.utils import query_count from nummus.version import __version__ -if TYPE_CHECKING: - import sqlalchemy - from sqlalchemy import orm - type Routes = dict[str, tuple[flask.typing.RouteCallable, list[str]]] @@ -445,11 +439,10 @@ def change_redirect_to_htmx(response: flask.Response) -> flask.Response: return response -def find[T: Base](s: orm.Session, cls: type[T], uri: str) -> T: +def find[T: Base](cls: type[T], uri: str) -> T: """Find the matching object by URI. Args: - s: SQL session to search cls: Type of object to find uri: URI to find @@ -466,7 +459,7 @@ def find[T: Base](s: orm.Session, cls: type[T], uri: str) -> T: except (exc.InvalidURIError, exc.WrongURITypeError) as e: raise exc.http.BadRequest(str(e)) from e try: - obj = s.query(cls).where(cls.id_ == id_).one() + obj = sql.one(cls.query().where(cls.id_ == id_)) except exc.NoResultFound as e: msg = f"{cls.__name__} {uri} not found in Portfolio" raise exc.http.NotFound(msg) from e @@ -558,9 +551,9 @@ def validate_string( *, is_required: bool = False, check_length: bool = True, - session: orm.Session | None = None, - no_duplicates: orm.QueryableAttribute | None = None, - no_duplicate_wheres: list[sqlalchemy.ColumnExpressionArgument] | None = None, + cls: type[Base] | None = None, + no_duplicates: sql.Column | None = None, + no_duplicate_wheres: list[sql.ColumnClause] | None = None, ) -> str: """Validate a string matches requirements. @@ -568,7 +561,7 @@ def validate_string( value: String to test is_required: True will require the value be non-empty check_length: True will require value to be MIN_STR_LEN long - session: SQL session to use for no_duplicates + cls: Model class to test for duplicate values no_duplicates: Property to test for duplicate values no_duplicate_wheres: Additional where clauses to add to no_duplicates @@ -582,11 +575,11 @@ def validate_string( if check_length and len(value) < utils.MIN_STR_LEN: # Ticker can be short return f"{utils.MIN_STR_LEN} characters required" - if no_duplicates is None: + if no_duplicates is None or cls is None: return "" return _test_duplicates( value, - session, + cls, no_duplicates, no_duplicate_wheres, ) @@ -594,19 +587,15 @@ def validate_string( def _test_duplicates( value: object, - session: orm.Session | None, - no_duplicates: orm.QueryableAttribute, - no_duplicate_wheres: list[sqlalchemy.ColumnExpressionArgument] | None, + cls: type[Base], + no_duplicates: sql.Column, + no_duplicate_wheres: list[sql.ColumnClause] | None, ) -> str: - if session is None: - msg = "Cannot test no_duplicates without a session" - raise TypeError(msg) - query = session.query(no_duplicates.parent).where( + query = cls.query().where( no_duplicates == value, *(no_duplicate_wheres or []), ) - n = query_count(query) - if n != 0: + if sql.any_(query): return "Must be unique" return "" @@ -617,9 +606,9 @@ def validate_date( *, is_required: bool = False, max_future: int | None = utils.DAYS_IN_WEEK, - session: orm.Session | None = None, - no_duplicates: orm.QueryableAttribute | None = None, - no_duplicate_wheres: list[sqlalchemy.ColumnExpressionArgument] | None = None, + cls: type[Base] | None = None, + no_duplicates: sql.Column | None = None, + no_duplicate_wheres: list[sql.ColumnClause] | None = None, ) -> str: """Validate a date string matches requirements. @@ -628,7 +617,7 @@ def validate_date( today: Today's date is_required: True will require the value be non-empty max_future: Maximum number of days date is allowed in the future - session: SQL session to use for no_duplicates + cls: Model class to test for duplicate values no_duplicates: Property to test for duplicate values no_duplicate_wheres: Additional where clauses to add to no_duplicates @@ -652,11 +641,11 @@ def validate_date( ): return f"Only up to {utils.format_days(max_future)} in advance" - if no_duplicates is None: + if no_duplicates is None or cls is None: return "" return _test_duplicates( date.toordinal(), - session, + cls, no_duplicates, no_duplicate_wheres, ) @@ -760,34 +749,27 @@ def parse_date( return date -def tranaction_category_groups(s: orm.Session) -> CategoryGroups: +def tranaction_category_groups() -> CategoryGroups: """Get TransactionCategory by groups. - Args: - s: SQL session to use - Returns: dict{group: list[CategoryContext]} """ - query = ( - s.query(TransactionCategory) - .with_entities( - TransactionCategory.id_, - TransactionCategory.name, - TransactionCategory.emoji_name, - TransactionCategory.asset_linked, - TransactionCategory.group, - ) - .order_by(TransactionCategory.name) - ) + query = TransactionCategory.query( + TransactionCategory.id_, + TransactionCategory.name, + TransactionCategory.emoji_name, + TransactionCategory.asset_linked, + TransactionCategory.group, + ).order_by(TransactionCategory.name) category_groups: CategoryGroups = { TransactionCategoryGroup.INCOME: [], TransactionCategoryGroup.EXPENSE: [], TransactionCategoryGroup.TRANSFER: [], TransactionCategoryGroup.OTHER: [], } - for t_cat_id, name, emoji_name, asset_linked, group in query.yield_per(YIELD_PER): + for t_cat_id, name, emoji_name, asset_linked, group in sql.yield_(query): category_groups[group].append( CategoryContext( TransactionCategory.id_to_uri(t_cat_id), diff --git a/nummus/controllers/budgeting.py b/nummus/controllers/budgeting.py index e2eda4cc..d77607ef 100644 --- a/nummus/controllers/budgeting.py +++ b/nummus/controllers/budgeting.py @@ -10,12 +10,10 @@ from typing import NamedTuple, NotRequired, TYPE_CHECKING, TypedDict import flask -from sqlalchemy import sql from nummus import exceptions as exc -from nummus import utils, web +from nummus import sql, utils, web from nummus.controllers import base -from nummus.models.base import YIELD_PER from nummus.models.budget import ( BudgetAssignment, BudgetGroup, @@ -29,11 +27,9 @@ TransactionCategory, TransactionCategoryGroup, ) -from nummus.models.utils import query_count if TYPE_CHECKING: import werkzeug.datastructures - from sqlalchemy import orm from nummus.models.budget import BudgetAvailableCategory from nummus.models.currency import CurrencyFormat @@ -158,10 +154,9 @@ def page() -> flask.Response: ) sidebar_uri = args.get("sidebar") or None - with p.begin_session() as s: - data = BudgetAssignment.get_monthly_available(s, month) + with p.begin_session(): + data = BudgetAssignment.get_monthly_available(month) budget, title = ctx_budget( - s, today, month, data.categories, @@ -169,7 +164,6 @@ def page() -> flask.Response: flask.session.get("groups_open", []), ) sidebar = ctx_sidebar( - s, today, month, data.categories, @@ -253,21 +247,21 @@ def assign(uri: str) -> str: form = flask.request.form amount = utils.evaluate_real_statement(form["amount"]) or Decimal() - with p.begin_session() as s: - cat = base.find(s, TransactionCategory, uri) + with p.begin_session(): + cat = base.find(TransactionCategory, uri) group_uri = ( None if cat.budget_group_id is None else BudgetGroup.id_to_uri(cat.budget_group_id) ) if amount == 0: - s.query(BudgetAssignment).where( + BudgetAssignment.query().where( BudgetAssignment.month_ord == month_ord, BudgetAssignment.category_id == cat.id_, ).delete() else: a = ( - s.query(BudgetAssignment) + BudgetAssignment.query() .where( BudgetAssignment.month_ord == month_ord, BudgetAssignment.category_id == cat.id_, @@ -275,18 +269,16 @@ def assign(uri: str) -> str: .one_or_none() ) if a is None: - a = BudgetAssignment( + a = BudgetAssignment.create( month_ord=month_ord, category_id=cat.id_, amount=amount, ) - s.add(a) else: a.amount = amount - data = BudgetAssignment.get_monthly_available(s, month) + data = BudgetAssignment.get_monthly_available(month) budget, _ = ctx_budget( - s, today, month, data.categories, @@ -295,7 +287,6 @@ def assign(uri: str) -> str: ) sidebar_uri = form.get("sidebar") or None sidebar = ctx_sidebar( - s, today, month, data.categories, @@ -328,15 +319,15 @@ def move(uri: str) -> str | flask.Response: month = datetime.date.fromisoformat(month_str + "-01") month_ord = month.toordinal() - with p.begin_session() as s: - cf = CURRENCY_FORMATS[Config.base_currency(s)] - data = BudgetAssignment.get_monthly_available(s, month) + with p.begin_session(): + cf = CURRENCY_FORMATS[Config.base_currency()] + data = BudgetAssignment.get_monthly_available(month) if uri == "income": src_cat = None src_cat_id = None src_available = data.assignable else: - src_cat = base.find(s, TransactionCategory, uri) + src_cat = base.find(TransactionCategory, uri) src_cat_id = src_cat.id_ src_available = data.categories[src_cat_id].available @@ -358,7 +349,7 @@ def move(uri: str) -> str | flask.Response: # Max of the negative number is min of the positive/abs to_move = max(src_available, -dest_available) - BudgetAssignment.move(s, month_ord, src_cat_id, dest_cat_id, to_move) + BudgetAssignment.move(month_ord, src_cat_id, dest_cat_id, to_move) return base.dialog_swap( event="budget", @@ -373,8 +364,7 @@ def move(uri: str) -> str | flask.Response: destination = args.get("destination") query = ( - s.query(TransactionCategory) - .with_entities( + TransactionCategory.query( TransactionCategory.id_, TransactionCategory.emoji_name, TransactionCategory.group, @@ -386,11 +376,7 @@ def move(uri: str) -> str | flask.Response: ) .order_by(TransactionCategory.group, TransactionCategory.name) ) - for t_cat_id, name, group in query.yield_per(YIELD_PER): - t_cat_id: int - name: str - group: TransactionCategoryGroup - + for t_cat_id, name, group in sql.yield_(query): t_cat_uri = TransactionCategory.id_to_uri(t_cat_id) available = data.categories[t_cat_id].available if destination or src_available > 0 or available > 0: @@ -425,7 +411,7 @@ def reorder() -> str: t_cat_uris = form.getlist("category-uri") groups = form.getlist("group") - with p.begin_session() as s: + with p.begin_session(): g_positions = { BudgetGroup.uri_to_id(g_uri): i for i, g_uri in enumerate(group_uris) } @@ -452,7 +438,7 @@ def reorder() -> str: last_group = g_uri # Set all to None first so swapping can occur without unique violations - s.query(TransactionCategory).update( + TransactionCategory.query().update( { TransactionCategory.budget_group_id: None, TransactionCategory.budget_position: None, @@ -460,11 +446,11 @@ def reorder() -> str: ) # Delete any groups - s.query(BudgetGroup).where(BudgetGroup.id_.not_in(g_positions)).delete() + BudgetGroup.query().where(BudgetGroup.id_.not_in(g_positions)).delete() if g_positions: # Set all to -index first so swapping can occur without unique violations - s.query(BudgetGroup).update( + BudgetGroup.query().update( { BudgetGroup.position: sql.case( {g_id: -i - 1 for i, g_id in enumerate(g_positions)}, @@ -474,7 +460,7 @@ def reorder() -> str: ) # Set new group positions - s.query(BudgetGroup).update( + BudgetGroup.query().update( { BudgetGroup.position: sql.case( g_positions, @@ -485,7 +471,7 @@ def reorder() -> str: if t_cat_positions: # Set new category positions - s.query(TransactionCategory).update( + TransactionCategory.query().update( { TransactionCategory.budget_group_id: sql.case( t_cat_groups, @@ -525,8 +511,8 @@ def group(uri: str) -> str: flask.session["groups_open"] = groups_open elif uri != "ungrouped": try: - with p.begin_session() as s: - g = base.find(s, BudgetGroup, uri) + with p.begin_session(): + g = base.find(BudgetGroup, uri) g.name = name except (exc.IntegrityError, exc.InvalidORMValueError) as e: return base.error(e) @@ -547,27 +533,26 @@ def new_group() -> str: """ p = web.portfolio name = "New group" - with p.begin_session() as s: - cf = CURRENCY_FORMATS[Config.base_currency(s)] + with p.begin_session(): + cf = CURRENCY_FORMATS[Config.base_currency()] # Ensure the name isn't a duplicate i = 1 - n = query_count(s.query(BudgetGroup).where(BudgetGroup.name == name)) - while n != 0: + + query = BudgetGroup.query().where(BudgetGroup.name == name) + while sql.any_(query): i += 1 name = f"New group {i}" - n = query_count(s.query(BudgetGroup).where(BudgetGroup.name == name)) + query = BudgetGroup.query().where(BudgetGroup.name == name) # Move existing groups down one - n = query_count(s.query(BudgetGroup)) + n = sql.count(BudgetGroup.query()) for i in range(n, -1, -1): # Do one at a time in reverse order to prevent duplicate value - s.query(BudgetGroup).where(BudgetGroup.position == i).update( + BudgetGroup.query().where(BudgetGroup.position == i).update( {BudgetGroup.position: i + 1}, ) - g = BudgetGroup(name=name, position=0) - s.add(g) - s.flush() + g = BudgetGroup.create(name=name, position=0) g_uri = g.uri ctx: GroupContext = { "position": 0, @@ -607,17 +592,16 @@ def target(uri: str) -> str | flask.Response: with p.begin_session() as s: try: - tar = base.find(s, Target, uri) + tar = base.find(Target, uri) t_cat_id = tar.category_id except exc.http.BadRequest: t_cat_id = TransactionCategory.uri_to_id(uri) - tar = s.query(Target).where(Target.category_id == t_cat_id).one_or_none() + tar = Target.query().where(Target.category_id == t_cat_id).one_or_none() - emoji_name = ( - s.query(TransactionCategory.emoji_name) - .where(TransactionCategory.id_ == t_cat_id) - .one()[0] + query = TransactionCategory.query(TransactionCategory.emoji_name).where( + TransactionCategory.id_ == t_cat_id, ) + emoji_name = sql.one(query) new_target = tar is None if tar is None: @@ -631,7 +615,7 @@ def target(uri: str) -> str | flask.Response: repeat_every=1, ) elif flask.request.method == "DELETE": - s.delete(tar) + tar.delete() return base.dialog_swap( event="budget", snackbar=f"{emoji_name} target deleted", @@ -679,7 +663,7 @@ def target(uri: str) -> str | flask.Response: "from_amount": ( flask.request.headers.get("HX-Trigger") == "budgeting-amount" ), - "currency_format": CURRENCY_FORMATS[Config.base_currency(s)], + "currency_format": CURRENCY_FORMATS[Config.base_currency()], } # Don't make the changes s.rollback() @@ -775,13 +759,9 @@ def sidebar() -> flask.Response: ) uri = args.get("uri") - with p.begin_session() as s: - data = BudgetAssignment.get_monthly_available( - s, - month, - ) + with p.begin_session(): + data = BudgetAssignment.get_monthly_available(month) sidebar = ctx_sidebar( - s, today, month, data.categories, @@ -792,7 +772,7 @@ def sidebar() -> flask.Response: "budgeting/sidebar.jinja", ctx={ "month": month_str, - "currency_format": CURRENCY_FORMATS[Config.base_currency(s)], + "currency_format": CURRENCY_FORMATS[Config.base_currency()], }, budget_sidebar=sidebar, ) @@ -810,7 +790,6 @@ def sidebar() -> flask.Response: def ctx_sidebar( - s: orm.Session, today: datetime.date, month: datetime.date, categories: dict[int, BudgetAvailableCategory], @@ -820,7 +799,6 @@ def ctx_sidebar( """Get the context to build the budgeting sidebar. Args: - s: SQL session to use today: Today's date month: Month of table categories: Dict of categories from Budget.get_monthly_available @@ -839,13 +817,13 @@ def ctx_sidebar( total_assigned = Decimal() total_activity = Decimal() - query = s.query(TransactionCategory.id_).where( + query = TransactionCategory.query(TransactionCategory.id_).where( TransactionCategory.group == TransactionCategoryGroup.INCOME, ) - income_ids = {row[0] for row in query.all()} + income_ids = set(sql.col0(query)) targets: dict[int, Target] = { - t.category_id: t for t in s.query(Target).yield_per(YIELD_PER) + t.category_id: t for t in sql.yield_(Target.query()) } no_target: set[int] = set() @@ -873,8 +851,7 @@ def ctx_sidebar( total_to_go += target_ctx["to_go"] query = ( - s.query(TransactionCategory) - .with_entities( + TransactionCategory.query( TransactionCategory.id_, TransactionCategory.emoji_name, ) @@ -883,7 +860,7 @@ def ctx_sidebar( ) no_target_names: dict[str, str] = { TransactionCategory.id_to_uri(t_cat_id): name - for t_cat_id, name in query.all() + for t_cat_id, name in sql.yield_(query) } return { @@ -901,11 +878,11 @@ def ctx_sidebar( "no_target": no_target_names, "target": None, } - t_cat = base.find(s, TransactionCategory, uri) + t_cat = base.find(TransactionCategory, uri) t_cat_id = t_cat.id_ assigned, activity, available, leftover = categories[t_cat_id] - tar = s.query(Target).where(Target.category_id == t_cat_id).one_or_none() + tar = Target.query().where(Target.category_id == t_cat_id).one_or_none() if tar is None: return { "uri": uri, @@ -1082,7 +1059,6 @@ def ctx_target( def ctx_budget( - s: orm.Session, today: datetime.date, month: datetime.date, categories: dict[int, BudgetAvailableCategory], @@ -1092,7 +1068,6 @@ def ctx_budget( """Get the context to build the budgeting table. Args: - s: SQL session to use today: Today's date month: Month of table categories: Dict of categories from Budget.get_monthly_available @@ -1105,13 +1080,10 @@ def ctx_budget( """ n_overspent = 0 - targets: dict[int, Target] = { - t.category_id: t for t in s.query(Target).yield_per(YIELD_PER) - } + targets: dict[int, Target] = {t.category_id: t for t in sql.yield_(Target.query())} groups: dict[int | None, GroupContext] = {} - query = s.query(BudgetGroup) - for g in query.all(): + for g in sql.yield_(BudgetGroup.query()): groups[g.id_] = { "position": g.position, "name": g.name, @@ -1135,8 +1107,7 @@ def ctx_budget( "has_error": False, } - query = s.query(TransactionCategory) - for t_cat in query.yield_per(YIELD_PER): + for t_cat in sql.yield_(TransactionCategory.query()): assigned, activity, available, leftover = categories[t_cat.id_] tar = targets.get(t_cat.id_) # Skip category if all numbers are 0 and not grouped @@ -1209,7 +1180,7 @@ def ctx_budget( "assignable": assignable, "groups": groups_list, "n_overspent": n_overspent, - "currency_format": CURRENCY_FORMATS[Config.base_currency(s)], + "currency_format": CURRENCY_FORMATS[Config.base_currency()], }, title, ) diff --git a/nummus/controllers/emergency_fund.py b/nummus/controllers/emergency_fund.py index ce28204e..8a6942a0 100644 --- a/nummus/controllers/emergency_fund.py +++ b/nummus/controllers/emergency_fund.py @@ -16,8 +16,6 @@ if TYPE_CHECKING: from decimal import Decimal - from sqlalchemy import orm - from nummus.models.currency import CurrencyFormat @@ -62,11 +60,11 @@ def page() -> flask.Response: """ p = web.portfolio - with p.begin_session() as s: + with p.begin_session(): return base.page( "emergency-fund/page.jinja", "Emergency fund", - ctx=ctx_page(s, base.today_client()), + ctx=ctx_page(base.today_client()), ) @@ -78,18 +76,17 @@ def dashboard() -> str: """ p = web.portfolio - with p.begin_session() as s: + with p.begin_session(): return flask.render_template( "emergency-fund/dashboard.jinja", - ctx=ctx_page(s, base.today_client()), + ctx=ctx_page(base.today_client()), ) -def ctx_page(s: orm.Session, today: datetime.date) -> EFundContext: +def ctx_page(today: datetime.date) -> EFundContext: """Get the context to build the emergency fund page. Args: - s: SQL session to use today: Today's date Returns: @@ -103,7 +100,6 @@ def ctx_page(s: orm.Session, today: datetime.date) -> EFundContext: t_lowers, t_uppers, balances, categories, categories_total = ( BudgetAssignment.get_emergency_fund( - s, start_ord, today_ord, utils.DAYS_IN_QUARTER, @@ -144,7 +140,7 @@ def ctx_page(s: orm.Session, today: datetime.date) -> EFundContext: key=lambda item: (-round(item["monthly"], 2), item["name"]), ) - cf = CURRENCY_FORMATS[Config.base_currency(s)] + cf = CURRENCY_FORMATS[Config.base_currency()] return { "chart": { "labels": [d.isoformat() for d in dates], diff --git a/nummus/controllers/health.py b/nummus/controllers/health.py index 107a15c6..de25a212 100644 --- a/nummus/controllers/health.py +++ b/nummus/controllers/health.py @@ -5,20 +5,16 @@ import datetime import operator from collections import defaultdict -from typing import TYPE_CHECKING, TypedDict +from typing import TypedDict import flask -from nummus import web +from nummus import sql, web from nummus.controllers import base from nummus.health_checks.top import HEALTH_CHECKS -from nummus.models.base import YIELD_PER from nummus.models.config import Config, ConfigKey from nummus.models.health_checks import HealthCheckIssue -if TYPE_CHECKING: - from sqlalchemy import orm - class HealthContext(TypedDict): """Type definition for health page context.""" @@ -45,11 +41,11 @@ def page() -> flask.Response: """ p = web.portfolio - with p.begin_session() as s: + with p.begin_session(): return base.page( "health/page.jinja", title="Health", - ctx=ctx_checks(s, run=False), + ctx=ctx_checks(run=False), ) @@ -61,10 +57,10 @@ def refresh() -> str: """ p = web.portfolio - with p.begin_session() as s: + with p.begin_session(): return flask.render_template( "health/checks.jinja", - ctx=ctx_checks(s, run=True), + ctx=ctx_checks(run=True), include_oob=True, ) @@ -80,12 +76,12 @@ def ignore(uri: str) -> str: """ p = web.portfolio - with p.begin_session() as s: - c = base.find(s, HealthCheckIssue, uri) + with p.begin_session(): + c = base.find(HealthCheckIssue, uri) c.ignore = True name = c.check - checks = ctx_checks(s, run=False)["checks"] + checks = ctx_checks(run=False)["checks"] return flask.render_template( "health/check-row.jinja", @@ -94,11 +90,10 @@ def ignore(uri: str) -> str: ) -def ctx_checks(s: orm.Session, *, run: bool) -> HealthContext: +def ctx_checks(*, run: bool) -> HealthContext: """Get the context to build the health checks. Args: - s: SQL session to use run: True will rerun health checks Returns: @@ -109,17 +104,17 @@ def ctx_checks(s: orm.Session, *, run: bool) -> HealthContext: issues: dict[str, dict[str, str]] = defaultdict(dict) if run: - Config.set_(s, ConfigKey.LAST_HEALTH_CHECK_TS, utc_now.isoformat()) + Config.set_(ConfigKey.LAST_HEALTH_CHECK_TS, utc_now.isoformat()) last_update = utc_now else: - last_update_str = Config.fetch(s, ConfigKey.LAST_HEALTH_CHECK_TS, no_raise=True) + last_update_str = Config.fetch(ConfigKey.LAST_HEALTH_CHECK_TS, no_raise=True) last_update = ( None if last_update_str is None else datetime.datetime.fromisoformat(last_update_str) ) - query = s.query(HealthCheckIssue).where(HealthCheckIssue.ignore.is_(False)) - for i in query.yield_per(YIELD_PER): + query = HealthCheckIssue.query().where(HealthCheckIssue.ignore.is_(False)) + for i in sql.yield_(query): issues[i.check][i.uri] = i.msg checks: list[HealthCheckContext] = [] @@ -128,7 +123,7 @@ def ctx_checks(s: orm.Session, *, run: bool) -> HealthContext: if run: c = check_type() - c.test(s) + c.test() c_issues = c.issues else: c_issues = issues[name] diff --git a/nummus/controllers/income.py b/nummus/controllers/income.py index 38076bb7..ba42a2d1 100644 --- a/nummus/controllers/income.py +++ b/nummus/controllers/income.py @@ -17,10 +17,9 @@ def page() -> flask.Response: """ args = flask.request.args p = web.portfolio - with p.begin_session() as s: + with p.begin_session(): today = base.today_client() ctx, title = spending.ctx_chart( - s, today, args.get("account"), args.get("category"), @@ -48,10 +47,9 @@ def chart() -> flask.Response: """ args = flask.request.args p = web.portfolio - with p.begin_session() as s: + with p.begin_session(): today = base.today_client() ctx, title = spending.ctx_chart( - s, today, args.get("account"), args.get("category"), @@ -89,10 +87,9 @@ def dashboard() -> str: """ p = web.portfolio - with p.begin_session() as s: + with p.begin_session(): today = base.today_client() ctx, _ = spending.ctx_chart( - s, today, None, None, diff --git a/nummus/controllers/labels.py b/nummus/controllers/labels.py index 477b9835..110caeb2 100644 --- a/nummus/controllers/labels.py +++ b/nummus/controllers/labels.py @@ -2,19 +2,13 @@ from __future__ import annotations -from typing import TYPE_CHECKING - import flask from nummus import exceptions as exc -from nummus import web +from nummus import sql, web from nummus.controllers import base -from nummus.models.base import YIELD_PER from nummus.models.label import Label, LabelLink -if TYPE_CHECKING: - from sqlalchemy import orm - def page() -> flask.Response: """GET /labels. @@ -25,11 +19,11 @@ def page() -> flask.Response: """ p = web.portfolio - with p.begin_session() as s: + with p.begin_session(): return base.page( "labels/page.jinja", "Labels", - labels=ctx_labels(s), + labels=ctx_labels(), ) @@ -45,7 +39,7 @@ def label(uri: str) -> str | flask.Response: """ p = web.portfolio with p.begin_session() as s: - label = base.find(s, Label, uri) + label = base.find(Label, uri) if flask.request.method == "GET": ctx: dict[str, object] = { @@ -59,11 +53,10 @@ def label(uri: str) -> str | flask.Response: ) if flask.request.method == "DELETE": - - s.query(LabelLink).where( + LabelLink.query().where( LabelLink.label_id == label.id_, ).delete() - s.delete(label) + label.delete() return base.dialog_swap( event="label", @@ -96,11 +89,11 @@ def validation() -> str: args = flask.request.args uri = args["uri"] if "name" in args: - with p.begin_session() as s: + with p.begin_session(): return base.validate_string( args["name"], is_required=True, - session=s, + cls=Label, no_duplicates=Label.name, no_duplicate_wheres=([Label.id_ != Label.uri_to_id(uri)]), ) @@ -108,20 +101,15 @@ def validation() -> str: raise NotImplementedError -def ctx_labels(s: orm.Session) -> list[base.NamePair]: +def ctx_labels() -> list[base.NamePair]: """Get the context required to build the labels table. - Args: - s: SQL session to use - Returns: List of HTML context """ - query = s.query(Label).order_by(Label.name) - return [ - base.NamePair(label.uri, label.name) for label in query.yield_per(YIELD_PER) - ] + query = Label.query().order_by(Label.name) + return [base.NamePair(label.uri, label.name) for label in sql.yield_(query)] ROUTES: base.Routes = { diff --git a/nummus/controllers/net_worth.py b/nummus/controllers/net_worth.py index fe0f42a0..901d2e3e 100644 --- a/nummus/controllers/net_worth.py +++ b/nummus/controllers/net_worth.py @@ -9,7 +9,7 @@ import flask from sqlalchemy import func -from nummus import utils, web +from nummus import sql, utils, web from nummus.controllers import base from nummus.models.account import Account from nummus.models.asset import Asset @@ -21,7 +21,6 @@ from nummus.models.transaction import TransactionSplit if TYPE_CHECKING: - from sqlalchemy import orm from nummus.controllers.base import Routes from nummus.models.currency import Currency, CurrencyFormat @@ -63,9 +62,8 @@ def page() -> flask.Response: """ args = flask.request.args p = web.portfolio - with p.begin_session() as s: + with p.begin_session(): ctx = ctx_chart( - s, base.today_client(), args.get("period", base.DEFAULT_PERIOD), ) @@ -86,8 +84,8 @@ def chart() -> flask.Response: args = flask.request.args period = args.get("period", base.DEFAULT_PERIOD) p = web.portfolio - with p.begin_session() as s: - ctx = ctx_chart(s, base.today_client(), period) + with p.begin_session(): + ctx = ctx_chart(base.today_client(), period) html = flask.render_template( "net-worth/chart-data.jinja", ctx=ctx, @@ -113,9 +111,8 @@ def dashboard() -> str: """ p = web.portfolio - with p.begin_session() as s: + with p.begin_session(): ctx = ctx_chart( - s, base.today_client(), base.DEFAULT_PERIOD, ) @@ -126,14 +123,12 @@ def dashboard() -> str: def ctx_chart( - s: orm.Session, today: datetime.date, period: str, ) -> Context: """Get the context to build the net worth chart. Args: - s: SQL session to use today: Today's date period: Selected chart period @@ -144,22 +139,22 @@ def ctx_chart( start, end = base.parse_period(period, today) if start is None: - query = s.query(func.min(TransactionSplit.date_ord)).where( + query = TransactionSplit.query(func.min(TransactionSplit.date_ord)).where( TransactionSplit.asset_id.is_(None), ) - start_ord = query.scalar() + start_ord = sql.scalar(query) start = datetime.date.fromordinal(start_ord) if start_ord else end start_ord = start.toordinal() end_ord = end.toordinal() - query = s.query(Account) account_currencies: dict[int, Currency] = { - acct.id_: acct.currency for acct in query.all() if acct.do_include(start_ord) + acct.id_: acct.currency + for acct in sql.yield_(Account.query()) + if acct.do_include(start_ord) } - base_currency = Config.base_currency(s) + base_currency = Config.base_currency() forex = Asset.get_forex( - s, start_ord, end_ord, base_currency, @@ -167,7 +162,6 @@ def ctx_chart( ) acct_values, _, _ = Account.get_value_all( - s, start_ord, end_ord, ids=account_currencies.keys(), @@ -182,7 +176,7 @@ def ctx_chart( ] or [Decimal()] * (end_ord - start_ord + 1) data_tuple = base.chart_data(start_ord, end_ord, (total, *acct_values.values())) - mapping = Account.map_name(s) + mapping = Account.map_name() ctx_accounts: list[AccountContext] = [ { diff --git a/nummus/controllers/performance.py b/nummus/controllers/performance.py index 2d743e3a..e18a8c2d 100644 --- a/nummus/controllers/performance.py +++ b/nummus/controllers/performance.py @@ -11,7 +11,7 @@ import flask from sqlalchemy import func -from nummus import utils, web +from nummus import sql, utils, web from nummus.controllers import base from nummus.models.account import Account, AccountCategory from nummus.models.asset import ( @@ -24,9 +24,9 @@ CURRENCY_FORMATS, ) from nummus.models.transaction import TransactionSplit +from nummus.sql import yield_ if TYPE_CHECKING: - from sqlalchemy import orm from nummus.models.currency import Currency, CurrencyFormat @@ -92,9 +92,8 @@ def page() -> flask.Response: """ args = flask.request.args p = web.portfolio - with p.begin_session() as s: + with p.begin_session(): ctx = ctx_chart( - s, base.today_client(), args.get("period", base.DEFAULT_PERIOD), args.get("index", _DEFAULT_INDEX), @@ -119,9 +118,8 @@ def chart() -> flask.Response: index = args.get("index", _DEFAULT_INDEX) excluded_accounts = args.getlist("exclude") p = web.portfolio - with p.begin_session() as s: + with p.begin_session(): ctx = ctx_chart( - s, base.today_client(), period, index, @@ -154,8 +152,8 @@ def dashboard() -> str: """ p = web.portfolio - with p.begin_session() as s: - acct_ids = Account.ids(s, AccountCategory.INVESTMENT) + with p.begin_session(): + acct_ids = Account.ids(AccountCategory.INVESTMENT) end = base.today_client() start = end - datetime.timedelta(days=90) start_ord = start.toordinal() @@ -164,16 +162,15 @@ def dashboard() -> str: indices: dict[str, Decimal] = {} query = ( - s.query(Asset.name) + Asset.query(Asset.name) .where(Asset.category == AssetCategory.INDEX) .order_by(Asset.name) ) - for (name,) in query.all(): - twrr = Asset.index_twrr(s, name, start_ord, end_ord) + for (name,) in yield_(query): + twrr = Asset.index_twrr(name, start_ord, end_ord) indices[name] = twrr[-1] acct_values, acct_profits, _ = Account.get_value_all( - s, start_ord, end_ord, ids=acct_ids, @@ -191,7 +188,7 @@ def dashboard() -> str: "pnl": total_profit[-1], "twrr": twrr[-1], "indices": indices, - "currency_format": CURRENCY_FORMATS[Config.base_currency(s)], + "currency_format": CURRENCY_FORMATS[Config.base_currency()], } return flask.render_template( "performance/dashboard.jinja", @@ -200,7 +197,6 @@ def dashboard() -> str: def ctx_chart( - s: orm.Session, today: datetime.date, period: str, index: str, @@ -209,7 +205,6 @@ def ctx_chart( """Get the context to build the performance chart. Args: - s: SQL session to use today: Today's date period: Selected chart period index: Selected index to compare against @@ -223,34 +218,34 @@ def ctx_chart( ctx_accounts: list[AccountContext] = [] - acct_ids = Account.ids(s, AccountCategory.INVESTMENT) + acct_ids = Account.ids(AccountCategory.INVESTMENT) if start is None: - query = s.query(func.min(TransactionSplit.date_ord)).where( + query = TransactionSplit.query(func.min(TransactionSplit.date_ord)).where( TransactionSplit.asset_id.is_(None), TransactionSplit.account_id.in_(acct_ids), ) - start_ord = query.scalar() + start_ord = sql.scalar(query) start = datetime.date.fromordinal(start_ord) if start_ord else end start_ord = start.toordinal() end_ord = end.toordinal() n = end_ord - start_ord + 1 query = ( - s.query(Asset.name) + Asset.query(Asset.name) .where(Asset.category == AssetCategory.INDEX) .order_by(Asset.name) ) - indices: list[str] = [name for (name,) in query.all()] - index_description: str | None = ( - s.query(Asset.description).where(Asset.name == index).scalar() + indices: list[str] = list(sql.col0(query)) + index_description: str | None = sql.scalar( + Asset.query(Asset.description).where(Asset.name == index), ) - query = s.query(Account).where(Account.id_.in_(acct_ids)) + query = Account.query().where(Account.id_.in_(acct_ids)) mapping: dict[int, str] = {} currencies: dict[int, Currency] = {} acct_ids.clear() account_options: list[base.NamePairState] = [] - for acct in query.all(): + for acct in sql.yield_(query): if acct.do_include(start_ord): excluded = acct.id_ in excluded_accounts account_options.append( @@ -261,9 +256,8 @@ def ctx_chart( currencies[acct.id_] = acct.currency acct_ids.add(acct.id_) - base_currency = Config.base_currency(s) + base_currency = Config.base_currency() forex = Asset.get_forex( - s, start_ord, end_ord, base_currency, @@ -271,7 +265,6 @@ def ctx_chart( ) acct_values, acct_profits, _ = Account.get_value_all( - s, start_ord, end_ord, ids=acct_ids, @@ -287,7 +280,7 @@ def ctx_chart( twrr = utils.twrr(total, total_profit) mwrr = utils.mwrr(total, total_profit) - index_twrr = Asset.index_twrr(s, index, start_ord, end_ord) + index_twrr = Asset.index_twrr(index, start_ord, end_ord) sum_cash_flow = Decimal(0) diff --git a/nummus/controllers/settings.py b/nummus/controllers/settings.py index c97d6b15..df08276b 100644 --- a/nummus/controllers/settings.py +++ b/nummus/controllers/settings.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, TypedDict +from typing import TypedDict import flask @@ -11,9 +11,6 @@ from nummus.models.config import Config, ConfigKey from nummus.models.currency import Currency -if TYPE_CHECKING: - from sqlalchemy import orm - class SettingsContext(TypedDict): """Type definition for settings context.""" @@ -30,11 +27,11 @@ def page() -> flask.Response: """ p = web.portfolio - with p.begin_session() as s: + with p.begin_session(): return base.page( "settings/page.jinja", "Settings", - ctx=ctx_settings(s), + ctx=ctx_settings(), ) @@ -48,26 +45,23 @@ def edit() -> flask.Response: p = web.portfolio currency = flask.request.form.get("currency", type=Currency) if currency: - with p.begin_session() as s: - Config.set_(s, ConfigKey.BASE_CURRENCY, str(currency.value)) + with p.begin_session(): + Config.set_(ConfigKey.BASE_CURRENCY, str(currency.value)) else: raise NotImplementedError return base.dialog_swap(event="config", snackbar="All changes saved") -def ctx_settings(s: orm.Session) -> SettingsContext: +def ctx_settings() -> SettingsContext: """Get the context to build the settings page. - Args: - s: SQL session to use - Returns: SettingsContext """ return { - "currency": Config.base_currency(s), + "currency": Config.base_currency(), "currency_type": Currency, } diff --git a/nummus/controllers/spending.py b/nummus/controllers/spending.py index b314da5b..bdd7372f 100644 --- a/nummus/controllers/spending.py +++ b/nummus/controllers/spending.py @@ -9,10 +9,9 @@ import flask from sqlalchemy import func -from nummus import utils, web +from nummus import sql, utils, web from nummus.controllers import base from nummus.models.account import Account -from nummus.models.base import YIELD_PER from nummus.models.config import Config from nummus.models.currency import ( Currency, @@ -24,12 +23,10 @@ TransactionCategory, TransactionCategoryGroup, ) -from nummus.models.utils import query_count if TYPE_CHECKING: from decimal import Decimal - import sqlalchemy from sqlalchemy import orm from nummus.models.currency import Currency, CurrencyFormat @@ -55,7 +52,7 @@ class Context(OptionsContext): start: str | None end: str | None by_account: list[tuple[str, Decimal]] - by_payee: list[tuple[str, Decimal]] + by_payee: list[tuple[str | None, Decimal]] by_category: list[tuple[str, Decimal]] by_label: list[tuple[str | None, Decimal]] currency_format: CurrencyFormat @@ -65,7 +62,7 @@ class DataQuery(NamedTuple): """Type definition for result of data_query().""" query: orm.Query[TransactionSplit] - clauses: dict[str, sqlalchemy.ColumnElement] + clauses: dict[str, sql.ColumnClause] any_filters: bool @property @@ -83,10 +80,9 @@ def page() -> flask.Response: """ args = flask.request.args p = web.portfolio - with p.begin_session() as s: + with p.begin_session(): today = base.today_client() ctx, title = ctx_chart( - s, today, args.get("account"), args.get("category"), @@ -114,10 +110,9 @@ def chart() -> flask.Response: """ args = flask.request.args p = web.portfolio - with p.begin_session() as s: + with p.begin_session(): today = base.today_client() ctx, title = ctx_chart( - s, today, args.get("account"), args.get("category"), @@ -155,10 +150,9 @@ def dashboard() -> str: """ p = web.portfolio - with p.begin_session() as s: + with p.begin_session(): today = base.today_client() ctx, _ = ctx_chart( - s, today, None, None, @@ -177,7 +171,6 @@ def dashboard() -> str: def data_query( - s: orm.Session, selected_currency: Currency, selected_account: str | None = None, selected_period: str | None = None, @@ -191,7 +184,6 @@ def data_query( """Create transactions data query. Args: - s: SQL session to use selected_currency: Currency to filter by selected_account: URI of account from args selected_period: Name of period from args @@ -214,18 +206,18 @@ def data_query( ), TransactionCategoryGroup.TRANSFER, } - query = s.query(TransactionCategory.id_).where( + query = TransactionCategory.query(TransactionCategory.id_).where( (TransactionCategory.name == "securities traded") | TransactionCategory.group.in_(skip_groups), ) - skip_ids = {r[0] for r in query.yield_per(YIELD_PER)} - query = s.query(Account.id_).where(Account.currency != selected_currency) - skip_acct_ids = {r[0] for r in query.yield_per(YIELD_PER)} - query = s.query(TransactionSplit).where( + skip_ids = set(sql.col0(query)) + query = Account.query(Account.id_).where(Account.currency != selected_currency) + skip_acct_ids = set(sql.col0(query)) + query = TransactionSplit.query().where( TransactionSplit.category_id.not_in(skip_ids), TransactionSplit.account_id.not_in(skip_acct_ids), ) - clauses: dict[str, sqlalchemy.ColumnElement] = {} + clauses: dict[str, sql.ColumnClause] = {} any_filters = False @@ -264,11 +256,11 @@ def data_query( label_id = Label.uri_to_id(selected_label) label_query = ( query.join(LabelLink) - .with_entities(TransactionSplit.id_) + .with_entities(TransactionSplit.id_) # nummus: ignore .where(LabelLink.label_id == label_id) .distinct() ) - t_split_ids: set[int] = {r[0] for r in label_query.yield_per(YIELD_PER)} + t_split_ids: set[int] = {r[0] for r in sql.yield_(label_query)} clauses["label"] = TransactionSplit.id_.in_(t_split_ids) return DataQuery(query, clauses, any_filters) @@ -315,14 +307,14 @@ def ctx_options( clauses = dat_query.clauses.copy() clauses.pop("account", None) query_options = ( - query.with_entities(TransactionSplit.account_id) + query.with_entities(TransactionSplit.account_id) # nummus: ignore .where(*clauses.values()) .distinct() ) options_account = sorted( [ base.NamePair(Account.id_to_uri(acct_id), accounts[acct_id]) - for acct_id, in query_options.yield_per(YIELD_PER) + for acct_id, in sql.yield_(query_options) ], key=operator.itemgetter(0), ) @@ -333,12 +325,12 @@ def ctx_options( clauses = dat_query.clauses.copy() clauses.pop("category", None) query_options = ( - query.with_entities(TransactionSplit.category_id) + query.with_entities(TransactionSplit.category_id) # nummus: ignore .where(*clauses.values()) .distinct() ) options_uris = { - TransactionCategory.id_to_uri(r[0]) for r in query_options.yield_per(YIELD_PER) + TransactionCategory.id_to_uri(r[0]) for r in sql.yield_(query_options) } if selected_category: options_uris.add(selected_category) @@ -356,14 +348,14 @@ def ctx_options( clauses.pop("label", None) query_options = ( query.join(LabelLink) - .with_entities(LabelLink.label_id) + .with_entities(LabelLink.label_id) # nummus: ignore .where(*clauses.values()) .distinct() ) options_label = sorted( [ base.NamePair(Label.id_to_uri(label_id), labels[label_id]) - for label_id, in query_options.yield_per(YIELD_PER) + for label_id, in sql.yield_(query_options) ], key=operator.itemgetter(1), ) @@ -380,7 +372,6 @@ def ctx_options( def ctx_chart( - s: orm.Session, today: datetime.date, selected_account: str | None, selected_category: str | None, @@ -394,7 +385,6 @@ def ctx_chart( """Get the context to build the chart data. Args: - s: SQL session to use today: Today's date selected_account: Selected account for filtering selected_category: Selected category for filtering @@ -409,13 +399,12 @@ def ctx_chart( tuple(Context, title) """ - accounts = Account.map_name(s) - base_currency = Config.base_currency(s) - categories_emoji = TransactionCategory.map_name_emoji(s) - labels = Label.map_name(s) + accounts = Account.map_name() + base_currency = Config.base_currency() + categories_emoji = TransactionCategory.map_name_emoji() + labels = Label.map_name() dat_query = data_query( - s, base_currency, selected_account, selected_period, @@ -429,7 +418,7 @@ def ctx_chart( dat_query, today, accounts, - base.tranaction_category_groups(s), + base.tranaction_category_groups(), labels, selected_account, selected_category, @@ -437,51 +426,51 @@ def ctx_chart( ) final_query = dat_query.final_query - n_matches = query_count(final_query) + n_matches = sql.count(final_query) if not n_matches: # If no matches, reset period to all selected_period = None dat_query.clauses.pop("start", None) dat_query.clauses.pop("end", None) final_query = dat_query.final_query - n_matches = query_count(final_query) + n_matches = sql.count(final_query) - query = final_query.with_entities( + query = final_query.with_entities( # nummus: ignore TransactionSplit.account_id, func.sum(TransactionSplit.amount), ).group_by(TransactionSplit.account_id) by_account: list[tuple[str, Decimal]] = [ (accounts[account_id], amount if is_income else -amount) - for account_id, amount in query.yield_per(YIELD_PER) + for account_id, amount in sql.yield_(query) if amount ] by_account = sorted(by_account, key=operator.itemgetter(1), reverse=True) - query = final_query.with_entities( + query = final_query.with_entities( # nummus: ignore TransactionSplit.payee, func.sum(TransactionSplit.amount), ).group_by(TransactionSplit.payee) - by_payee: list[tuple[str, Decimal]] = [ + by_payee: list[tuple[str | None, Decimal]] = [ (payee, amount if is_income else -amount) - for payee, amount in query.yield_per(YIELD_PER) + for payee, amount in sql.yield_(query) if amount ] by_payee = sorted(by_payee, key=operator.itemgetter(1), reverse=True) - query = final_query.with_entities( + query = final_query.with_entities( # nummus: ignore TransactionSplit.category_id, func.sum(TransactionSplit.amount), ).group_by(TransactionSplit.category_id) by_category: list[tuple[str, Decimal]] = [ (categories_emoji[cat_id], amount if is_income else -amount) - for cat_id, amount in query.yield_per(YIELD_PER) + for cat_id, amount in sql.yield_(query) if amount ] by_category = sorted(by_category, key=operator.itemgetter(1), reverse=True) query = ( final_query.join(LabelLink, full=True) - .with_entities( + .with_entities( # nummus: ignore LabelLink.label_id, func.sum(TransactionSplit.amount), ) @@ -490,11 +479,11 @@ def ctx_chart( selected_label_id = selected_label and Label.uri_to_id(selected_label) by_label: list[tuple[str | None, Decimal, bool]] = [ ( - label_id and labels[label_id], + labels.get(label_id), amount if is_income else -amount, - label_id and label_id == selected_label_id, + label_id == selected_label_id if label_id else False, ) - for label_id, amount in query.yield_per(YIELD_PER) + for label_id, amount in sql.yield_(query) if amount ] by_label = sorted(by_label, key=operator.itemgetter(1), reverse=True) diff --git a/nummus/controllers/transaction_categories.py b/nummus/controllers/transaction_categories.py index 2816dee6..af1ec7c3 100644 --- a/nummus/controllers/transaction_categories.py +++ b/nummus/controllers/transaction_categories.py @@ -2,23 +2,16 @@ from __future__ import annotations -from typing import TYPE_CHECKING - import flask from nummus import exceptions as exc -from nummus import utils, web +from nummus import sql, utils, web from nummus.controllers import base -from nummus.models.base import YIELD_PER from nummus.models.transaction import TransactionSplit from nummus.models.transaction_category import ( TransactionCategory, TransactionCategoryGroup, ) -from nummus.models.utils import query_count - -if TYPE_CHECKING: - from sqlalchemy import orm def page() -> flask.Response: @@ -30,11 +23,11 @@ def page() -> flask.Response: """ p = web.portfolio - with p.begin_session() as s: + with p.begin_session(): return base.page( "transaction-categories/page.jinja", "Transaction categories", - groups=ctx_categories(s), + groups=ctx_categories(), ) @@ -65,8 +58,8 @@ def new() -> str | flask.Response: try: p = web.portfolio - with p.begin_session() as s: - cat = TransactionCategory( + with p.begin_session(): + TransactionCategory.create( emoji_name=name, group=group, locked=False, @@ -74,7 +67,6 @@ def new() -> str | flask.Response: asset_linked=False, essential_spending=essential_spending, ) - s.add(cat) except (exc.IntegrityError, exc.InvalidORMValueError) as e: return base.error(e) @@ -99,7 +91,7 @@ def category(uri: str) -> str | flask.Response: """ p = web.portfolio with p.begin_session() as s: - cat = base.find(s, TransactionCategory, uri) + cat = base.find(TransactionCategory, uri) if flask.request.method == "GET": ctx: dict[str, object] = { @@ -122,12 +114,12 @@ def category(uri: str) -> str | flask.Response: msg = f"Locked category {cat.name} cannot be modified" raise exc.http.Forbidden(msg) # Move all transactions to uncategorized - uncategorized_id, _ = TransactionCategory.uncategorized(s) + uncategorized_id, _ = TransactionCategory.uncategorized() - s.query(TransactionSplit).where( + TransactionSplit.query().where( TransactionSplit.category_id == cat.id_, ).update({"category_id": uncategorized_id}) - s.delete(cat) + cat.delete() return base.dialog_swap( event="category", @@ -177,39 +169,30 @@ def validation() -> str: return "Required" if len(name) < utils.MIN_STR_LEN: return f"{utils.MIN_STR_LEN} characters required" - with p.begin_session() as s: + with p.begin_session(): # Only get original name if locked - locked_name = ( - s.query(TransactionCategory.name) - .where( + locked_name = sql.scalar( + TransactionCategory.query(TransactionCategory.name).where( TransactionCategory.id_ == category_id, TransactionCategory.locked, - ) - .scalar() + ), ) if locked_name and locked_name != name: return "May only add/remove emojis" - n = query_count( - s.query(TransactionCategory).where( - TransactionCategory.name == name, - TransactionCategory.id_ != category_id, - ), + query = TransactionCategory.query().where( + TransactionCategory.name == name, + TransactionCategory.id_ != category_id, ) - if n != 0: + if sql.any_(query): return "Must be unique" return "" raise NotImplementedError -def ctx_categories( - s: orm.Session, -) -> dict[TransactionCategoryGroup, list[base.NamePair]]: +def ctx_categories() -> dict[TransactionCategoryGroup, list[base.NamePair]]: """Get the context required to build the categories table. - Args: - s: SQL session to use - Returns: List of HTML context @@ -220,8 +203,8 @@ def ctx_categories( TransactionCategoryGroup.TRANSFER: [], TransactionCategoryGroup.OTHER: [], } - query = s.query(TransactionCategory).order_by(TransactionCategory.name) - for cat in query.yield_per(YIELD_PER): + query = TransactionCategory.query().order_by(TransactionCategory.name) + for cat in sql.yield_(query): cat_d = base.NamePair(cat.uri, cat.emoji_name) if cat.group != TransactionCategoryGroup.OTHER or cat.name == "uncategorized": groups[cat.group].append(cat_d) diff --git a/nummus/controllers/transactions.py b/nummus/controllers/transactions.py index 62bb4f62..42e45c47 100644 --- a/nummus/controllers/transactions.py +++ b/nummus/controllers/transactions.py @@ -6,17 +6,17 @@ import operator from collections import defaultdict from decimal import Decimal +from itertools import starmap from typing import NamedTuple, NotRequired, TYPE_CHECKING, TypedDict import flask from sqlalchemy import func from nummus import exceptions as exc -from nummus import utils, web +from nummus import sql, utils, web from nummus.controllers import base from nummus.models.account import Account from nummus.models.asset import Asset -from nummus.models.base import YIELD_PER from nummus.models.config import Config from nummus.models.currency import ( Currency, @@ -25,15 +25,9 @@ from nummus.models.label import Label, LabelLink from nummus.models.transaction import Transaction, TransactionSplit from nummus.models.transaction_category import TransactionCategory -from nummus.models.utils import ( - obj_session, - query_count, - query_to_dict, - update_rows_list, -) +from nummus.models.utils import update_rows_list if TYPE_CHECKING: - import sqlalchemy from sqlalchemy import orm from nummus.models.currency import Currency, CurrencyFormat @@ -122,7 +116,7 @@ class TableQuery(NamedTuple): """Type definition for result of table_query().""" query: orm.Query[TransactionSplit] - clauses: dict[str, sqlalchemy.ColumnElement] + clauses: dict[str, sql.ColumnClause] any_filters: bool @property @@ -130,7 +124,7 @@ def final_query(self) -> orm.Query[TransactionSplit]: """Build the final query with clauses.""" return self.query.where(*self.clauses.values()) - def where(self, **clauses: sqlalchemy.ColumnElement) -> TableQuery: + def where(self, **clauses: sql.ColumnClause) -> TableQuery: """Add clauses to query. Args: @@ -155,9 +149,8 @@ def page_all() -> flask.Response: args = flask.request.args p = web.portfolio - with p.begin_session() as s: + with p.begin_session(): txn_table, title = ctx_table( - s, base.today_client(), args.get("search"), args.get("account"), @@ -186,9 +179,8 @@ def table() -> str | flask.Response: args = flask.request.args first_page = "page" not in args p = web.portfolio - with p.begin_session() as s: + with p.begin_session(): txn_table, title = ctx_table( - s, base.today_client(), args.get("search"), args.get("account"), @@ -229,8 +221,8 @@ def table_options() -> str: """ p = web.portfolio - with p.begin_session() as s: - accounts = Account.map_name(s) + with p.begin_session(): + accounts = Account.map_name() args = flask.request.args uncleared = "uncleared" in args @@ -241,7 +233,6 @@ def table_options() -> str: selected_end = args.get("end") tbl_query = table_query( - s, None, selected_account, selected_period, @@ -254,7 +245,7 @@ def table_options() -> str: tbl_query, base.today_client(), accounts, - base.tranaction_category_groups(s), + base.tranaction_category_groups(), selected_account, selected_category, ) @@ -287,25 +278,27 @@ def new() -> str | flask.Response: with p.begin_session() as s: query = ( - s.query(Account) - .with_entities(Account.id_, Account.name, Account.currency) + Account.query(Account.id_, Account.name, Account.currency) .where(Account.closed.is_(False)) .order_by(Account.name) ) accounts: dict[int, tuple[str, Currency]] = { - r[0]: (r[1], r[2]) for r in query.yield_per(YIELD_PER) + r[0]: (r[1], r[2]) for r in sql.yield_(query) } - uncategorized_id, uncategorized_uri = TransactionCategory.uncategorized(s) + uncategorized_id, uncategorized_uri = TransactionCategory.uncategorized() - query = s.query(Transaction.payee) + query = Transaction.query(Transaction.payee).distinct() payees = sorted( - filter(None, (item for item, in query.distinct())), + filter(None, (item for item, in sql.yield_(query))), key=lambda item: item.lower(), ) - query = s.query(Label.name) - labels = sorted(item for item, in query.distinct()) + query = Label.query(Label.name).distinct() + labels = sorted( + filter(None, (item for item, in sql.yield_(query))), + key=lambda item: item.lower(), + ) acct_uri = ( flask.request.form.get("account") or flask.request.args.get("account") or "" @@ -313,7 +306,7 @@ def new() -> str | flask.Response: if acct_uri: cf = CURRENCY_FORMATS[accounts[Account.uri_to_id(acct_uri)][1]] else: - cf = CURRENCY_FORMATS[Config.base_currency(s)] + cf = CURRENCY_FORMATS[Config.base_currency()] empty_split: SplitContext = { "parent_uri": "", @@ -340,7 +333,7 @@ def new() -> str | flask.Response: "statement": "Manually created", "payee": None, "splits": [empty_split], - "category_groups": base.tranaction_category_groups(s), + "category_groups": base.tranaction_category_groups(), "payees": payees, "labels": labels, "similar_uri": None, @@ -427,7 +420,7 @@ def new() -> str | flask.Response: return base.error(err) s.add(txn) s.flush() - if err := _transaction_split_edit(s, txn): + if err := _transaction_split_edit(txn): return base.error(err) except (exc.IntegrityError, exc.InvalidORMValueError) as e: return base.error(e) @@ -452,7 +445,7 @@ def transaction(uri: str) -> str | flask.Response: p = web.portfolio today = base.today_client() with p.begin_session() as s: - txn = base.find(s, Transaction, uri) + txn = base.find(Transaction, uri) if flask.request.method == "GET": return flask.render_template( @@ -461,7 +454,7 @@ def transaction(uri: str) -> str | flask.Response: ) if flask.request.method == "PATCH": txn.cleared = True - s.query(TransactionSplit).where( + TransactionSplit.query().where( TransactionSplit.parent_id == txn.id_, ).update({"cleared": True}) return base.dialog_swap( @@ -472,15 +465,15 @@ def transaction(uri: str) -> str | flask.Response: if txn.cleared: return base.error("Cannot delete cleared transaction") date = txn.date - query = s.query(TransactionSplit.id_).where( + query = TransactionSplit.query(TransactionSplit.id_).where( TransactionSplit.parent_id == txn.id_, ) - t_split_ids = {r[0] for r in query.yield_per(YIELD_PER)} - s.query(LabelLink).where(LabelLink.t_split_id.in_(t_split_ids)).delete() - s.query(TransactionSplit).where( + t_split_ids = set(sql.col0(query)) + LabelLink.query().where(LabelLink.t_split_id.in_(t_split_ids)).delete() + TransactionSplit.query().where( TransactionSplit.id_.in_(t_split_ids), ).delete() - s.delete(txn) + txn.delete() return base.dialog_swap( # update-account since transaction was deleted event="account", @@ -493,7 +486,7 @@ def transaction(uri: str) -> str | flask.Response: if err := _transaction_edit(txn, today): return base.error(err) s.flush() - if err := _transaction_split_edit(s, txn): + if err := _transaction_split_edit(txn): return base.error(err) except (exc.IntegrityError, exc.InvalidORMValueError) as e: return base.error(e) @@ -536,11 +529,10 @@ def _transaction_edit(txn: Transaction, today: datetime.date) -> str: return "" -def _transaction_split_edit(s: orm.Session, txn: Transaction) -> str: +def _transaction_split_edit(txn: Transaction) -> str: """Edit transaction from form. Args: - s: SQL session to use txn: Transaction to edit Returns: @@ -564,16 +556,15 @@ def _transaction_split_edit(s: orm.Session, txn: Transaction) -> str: remaining = txn.amount - sum(filter(None, split_amounts)) if remaining != 0: - currency = ( - s.query(Account.currency).where(Account.id_ == txn.account_id).one()[0] - ) + query = Account.query(Account.currency).where(Account.id_ == txn.account_id) + currency = sql.one(query) cf = CURRENCY_FORMATS[currency] if remaining < 0: return f"Remove {cf(-remaining)} from splits" return f"Assign {cf(remaining)} to splits" - splits = [ + splits: list[dict[str, object]] = [ { "parent": txn, "category_id": cat_id, @@ -589,19 +580,16 @@ def _transaction_split_edit(s: orm.Session, txn: Transaction) -> str: if amount ] query = ( - s.query(TransactionSplit) + TransactionSplit.query() .where(TransactionSplit.parent_id == txn.id_) .order_by(TransactionSplit.id_) ) t_split_ids = update_rows_list( - s, TransactionSplit, query, splits, ) - s.flush() LabelLink.add_links( - s, { t_split_id: set(form.getlist(f"label-{i}")) for i, t_split_id in enumerate(t_split_ids) @@ -624,12 +612,13 @@ def split(uri: str) -> str: p = web.portfolio form = flask.request.form - with p.begin_session() as s: - txn = base.find(s, Transaction, uri) + with p.begin_session(): + txn = base.find(Transaction, uri) parent_amount = utils.parse_real(form["amount"]) or Decimal() account_id = Account.uri_to_id(form["account"]) - currency = s.query(Account.currency).where(Account.id_ == account_id).one()[0] + query = Account.query(Account.currency).where(Account.id_ == account_id) + currency = sql.one(query) payee = form["payee"] date = utils.parse_date(form["date"]) @@ -650,7 +639,7 @@ def split(uri: str) -> str: split_labels.append(set()) split_amounts.append(None) - _, uncategorized_uri = TransactionCategory.uncategorized(s) + _, uncategorized_uri = TransactionCategory.uncategorized() cf = CURRENCY_FORMATS[currency] @@ -781,14 +770,14 @@ def _validate_splits() -> str: else: uri = args.get("account") p = web.portfolio - with p.begin_session() as s: - currency = ( - s.query(Account.currency) - .where(Account.id_ == Account.uri_to_id(uri)) - .one()[0] - if uri - else Config.base_currency(s) - ) + with p.begin_session(): + if uri: + query = Account.query(Account.currency).where( + Account.id_ == Account.uri_to_id(uri), + ) + currency = sql.one(query) + else: + currency = Config.base_currency() cf = CURRENCY_FORMATS[currency] msg = ( f"Assign {cf(remaining)} to splits" @@ -805,7 +794,6 @@ def _validate_splits() -> str: def table_query( - s: orm.Session, acct_uri: str | None = None, selected_account: str | None = None, selected_period: str | None = None, @@ -818,7 +806,6 @@ def table_query( """Create transactions table query. Args: - s: SQL session to use acct_uri: Account URI to filter to selected_account: URI of account from args selected_period: Name of period from args @@ -832,14 +819,14 @@ def table_query( """ selected_account = acct_uri or selected_account - query = s.query(TransactionSplit).order_by( + query = TransactionSplit.query().order_by( TransactionSplit.date_ord.desc(), TransactionSplit.account_id, TransactionSplit.payee, TransactionSplit.category_id, TransactionSplit.memo, ) - clauses: dict[str, sqlalchemy.ColumnElement] = {} + clauses: dict[str, sql.ColumnClause] = {} any_filters = False @@ -906,38 +893,26 @@ def ctx_txn( Dictionary HTML context """ - s = obj_session(txn) - account_id = txn.account_id if account_id is None else account_id - query = ( - s.query(Account) - .with_entities( - Account.id_, - Account.name, - Account.closed, - Account.currency, - ) - .order_by(Account.name) - ) - accounts: dict[int, tuple[str, bool, Currency]] = { - r[0]: (r[1], r[2], r[3]) for r in query.yield_per(YIELD_PER) - } - query = s.query(Asset).with_entities(Asset.id_, Asset.name, Asset.ticker) - assets: dict[int, tuple[str, str | None]] = { - r[0]: (r[1], r[2]) for r in query.yield_per(YIELD_PER) - } - query = s.query(Label.id_, Label.name) - labels: dict[int, str] = query_to_dict(query) + query = Account.query( + Account.id_, + Account.name, + Account.closed, + Account.currency, + ).order_by(Account.name) + accounts: dict[int, tuple[str, bool, Currency]] = sql.to_dict_tuple(query) + query = Asset.query(Asset.id_, Asset.name, Asset.ticker) + assets = sql.to_dict_tuple(query) + query = Label.query(Label.id_, Label.name) + labels: dict[int, str] = sql.to_dict(query) cf = CURRENCY_FORMATS[accounts[account_id][2]] - query = ( - s.query(LabelLink) - .with_entities(LabelLink.t_split_id, LabelLink.label_id) - .where(LabelLink.t_split_id.in_(t_split.id_ for t_split in txn.splits)) + query = LabelLink.query(LabelLink.t_split_id, LabelLink.label_id).where( + LabelLink.t_split_id.in_(t_split.id_ for t_split in txn.splits), ) label_links: dict[int, set[int]] = defaultdict(set) - for t_split_id, label_id in query.yield_per(YIELD_PER): + for t_split_id, label_id in sql.yield_(query): label_links[t_split_id].add(label_id) ctx_splits: list[SplitContext] = ( @@ -955,9 +930,9 @@ def ctx_txn( ) any_asset_splits = any(split.get("asset_name") for split in ctx_splits) - query = s.query(Transaction.payee) + query = Transaction.query(Transaction.payee).distinct() payees = sorted( - filter(None, (item for item, in query.distinct())), + filter(None, sql.col0(query)), key=lambda item: item.lower(), ) @@ -980,7 +955,7 @@ def ctx_txn( "statement": txn.statement, "payee": txn.payee if payee is None else payee, "splits": ctx_splits, - "category_groups": base.tranaction_category_groups(s), + "category_groups": base.tranaction_category_groups(), "payees": payees, "labels": sorted(labels.values()), "similar_uri": similar_uri, @@ -1105,14 +1080,14 @@ def ctx_options( clauses = tbl_query.clauses.copy() clauses.pop("account", None) query_options = ( - query.with_entities(TransactionSplit.account_id) + query.with_entities(TransactionSplit.account_id) # nummus: ignore .where(*clauses.values()) .distinct() ) options_account = sorted( [ base.NamePair(Account.id_to_uri(acct_id), accounts[acct_id]) - for acct_id, in query_options.yield_per(YIELD_PER) + for acct_id, in sql.yield_(query_options) ], key=operator.itemgetter(0), ) @@ -1123,13 +1098,13 @@ def ctx_options( clauses = tbl_query.clauses.copy() clauses.pop("category", None) query_options = ( - query.with_entities(TransactionSplit.category_id) + query.with_entities(TransactionSplit.category_id) # nummus: ignore .where(*clauses.values()) .distinct() ) - options_uris = { - TransactionCategory.id_to_uri(r[0]) for r in query_options.yield_per(YIELD_PER) - } + options_uris = set( + starmap(TransactionCategory.id_to_uri, sql.yield_(query_options)), + ) if selected_category: options_uris.add(selected_category) options_category = { @@ -1150,7 +1125,6 @@ def ctx_options( def ctx_table( - s: orm.Session, today: datetime.date, search_str: str | None, selected_account: str | None, @@ -1166,7 +1140,6 @@ def ctx_table( """Get the context to build the transaction table. Args: - s: SQL session to use today: Today's date search_str: String to search for selected_account: Selected account for filtering @@ -1182,24 +1155,22 @@ def ctx_table( tuple(TableContext, title) """ - query = s.query(Account).with_entities(Account.id_, Account.name, Account.currency) + query = Account.query(Account.id_, Account.name, Account.currency) accounts: dict[int, str] = {} currency_formats: dict[int, CurrencyFormat] = {} - for acct_id, name, currency in query.yield_per(YIELD_PER): + for acct_id, name, currency in sql.yield_(query): accounts[acct_id] = name currency_formats[acct_id] = CURRENCY_FORMATS[currency] - categories_emoji = TransactionCategory.map_name_emoji(s) + categories_emoji = TransactionCategory.map_name_emoji() categories = { cat_id: TransactionCategory.clean_emoji_name(name) for cat_id, name in categories_emoji.items() } - query = s.query(Asset).with_entities(Asset.id_, Asset.name, Asset.ticker) - assets: dict[int, tuple[str, str | None]] = { - r[0]: (r[1], r[2]) for r in query.yield_per(YIELD_PER) - } - labels = Label.map_name(s) + query = Asset.query(Asset.id_, Asset.name, Asset.ticker) + assets = sql.to_dict_tuple(query) + labels = Label.map_name() if page_start is None: page_start_int = None @@ -1210,7 +1181,6 @@ def ctx_table( page_start_int = datetime.date.fromisoformat(page_start).toordinal() tbl_query = table_query( - s, acct_uri, selected_account, selected_period, @@ -1223,7 +1193,7 @@ def ctx_table( tbl_query, today, accounts, - base.tranaction_category_groups(s), + base.tranaction_category_groups(), selected_account, selected_category, ) @@ -1245,7 +1215,9 @@ def ctx_table( t_split_order = {} final_query = tbl_query.final_query - query_total = final_query.with_entities(func.sum(TransactionSplit.amount)) + query_total = final_query.with_entities( # nummus: ignore + func.sum(TransactionSplit.amount), + ) if matches is not None: i_start = page_start_int or 0 @@ -1256,7 +1228,7 @@ def ctx_table( # Find the fewest dates to include that will make page at least # PAGE_LEN long included_date_ords: set[int] = set() - query_page_count = final_query.with_entities( + query_page_count = final_query.with_entities( # nummus: ignore TransactionSplit.date_ord, func.count(), ).group_by(TransactionSplit.date_ord) @@ -1266,9 +1238,7 @@ def ctx_table( ) page_count = 0 # Limit to PAGE_LEN since at most there is one txn per day - for date_ord, count in query_page_count.limit(PAGE_LEN).yield_per( - YIELD_PER, - ): + for date_ord, count in sql.yield_(query_page_count.limit(PAGE_LEN)): included_date_ords.add(date_ord) page_count += count if page_count >= PAGE_LEN: @@ -1284,7 +1254,7 @@ def ctx_table( else datetime.date.fromordinal(min(included_date_ords) - 1) ) - n_matches = query_count(final_query) + n_matches = sql.count(final_query) groups = _table_results( final_query, assets, @@ -1306,7 +1276,7 @@ def ctx_table( return { "uri": acct_uri, "transactions": groups, - "query_total": query_total.scalar() or Decimal(), + "query_total": sql.scalar(query_total) or Decimal(), "no_matches": n_matches == 0 and page_start_int is None, "next_page": None if n_matches < PAGE_LEN else str(next_page), "any_filters": tbl_query.any_filters, @@ -1318,7 +1288,7 @@ def ctx_table( "uncleared": uncleared, "start": selected_start, "end": selected_end, - "currency_format": CURRENCY_FORMATS[Config.base_currency(s)], + "currency_format": CURRENCY_FORMATS[Config.base_currency()], }, title @@ -1350,19 +1320,17 @@ def _table_results( )] """ - s = query.session - # Iterate first to get required second query t_splits: list[TransactionSplit] = [] parent_ids: set[int] = set() - for t_split in query.yield_per(YIELD_PER): + for t_split in sql.yield_(query): t_splits.append(t_split) parent_ids.add(t_split.parent_id) # There are no more if there wasn't enough for a full page query_has_splits = ( - s.query(Transaction.id_) + Transaction.query(Transaction.id_) .join(TransactionSplit) .where( Transaction.id_.in_(parent_ids), @@ -1370,15 +1338,13 @@ def _table_results( .group_by(Transaction.id_) .having(func.count() > 1) ) - has_splits = {r[0] for r in query_has_splits.yield_per(YIELD_PER)} + has_splits = set(sql.col0(query_has_splits)) - query_labels = ( - s.query(LabelLink) - .with_entities(LabelLink.t_split_id, LabelLink.label_id) - .where(LabelLink.t_split_id.in_(t_split.id_ for t_split in t_splits)) + query_labels = LabelLink.query(LabelLink.t_split_id, LabelLink.label_id).where( + LabelLink.t_split_id.in_(t_split.id_ for t_split in t_splits), ) label_links: dict[int, set[int]] = defaultdict(set) - for t_split_id, label_id in query_labels.yield_per(YIELD_PER): + for t_split_id, label_id in sql.yield_(query_labels): label_links[t_split_id].add(label_id) t_splits_flat: list[tuple[RowContext, int]] = [] diff --git a/nummus/encryption/top.py b/nummus/encryption/top.py index 8c05f092..4a745629 100644 --- a/nummus/encryption/top.py +++ b/nummus/encryption/top.py @@ -14,7 +14,8 @@ logger.warning("Install libsqlcipher: apt install libsqlcipher-dev") logger.warning("Install encrypt extra: pip install nummus[encrypt]") Encryption = NoEncryption - ENCRYPTION_AVAILABLE = False + encryption_available = False else: Encryption = EncryptionAES - ENCRYPTION_AVAILABLE = True + encryption_available = True +ENCRYPTION_AVAILABLE = encryption_available diff --git a/nummus/exceptions.py b/nummus/exceptions.py index b720632e..592db581 100644 --- a/nummus/exceptions.py +++ b/nummus/exceptions.py @@ -39,8 +39,8 @@ "MissingAssetError", "MultipleResultsFound", "NoAssetWebSourceError", - "NoIDError", "NoImporterBufferError", + "NoKeywordArgumentsError", "NoResultFound", "NoURIError", "NonAssetTransactionError", @@ -175,10 +175,6 @@ class ProtectedObjectNotFoundError(Exception): """Error when a protected object (non-deletable) could not be found.""" -class NoIDError(Exception): - """Error when model does not have id_ yet, likely a flush is needed.""" - - class NoURIError(Exception): """Error when a URI is requested for a model without one.""" @@ -260,3 +256,7 @@ class InvalidAssetTransactionCategoryError(Exception): class InvalidKeyError(Exception): """Error when a key does not meet minimum requirements.""" + + +class NoKeywordArgumentsError(Exception): + """Error when function is given kwargs when not expected.""" diff --git a/nummus/health_checks/base.py b/nummus/health_checks/base.py index 972b5fc0..f92d08e2 100644 --- a/nummus/health_checks/base.py +++ b/nummus/health_checks/base.py @@ -3,16 +3,12 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import ClassVar, TYPE_CHECKING +from typing import ClassVar -from nummus import utils -from nummus.models.base import YIELD_PER +from nummus import sql, utils from nummus.models.health_checks import HealthCheckIssue from nummus.models.utils import update_rows -if TYPE_CHECKING: - from sqlalchemy import orm - class HealthCheck(ABC): """Base health check class.""" @@ -24,7 +20,7 @@ def __init__( self, *, no_ignores: bool = False, - **_, + **_: object, ) -> None: """Initialize Base health check. @@ -79,26 +75,20 @@ def is_severe(cls) -> bool: return cls._SEVERE @abstractmethod - def test(self, s: orm.Session) -> None: - """Run the health check on a portfolio. - - Args: - s: SQL session to use - - """ + def test(self) -> None: + """Run the health check on a portfolio.""" raise NotImplementedError @classmethod - def ignore(cls, s: orm.Session, values: list[str] | set[str]) -> None: + def ignore(cls, values: list[str] | set[str]) -> None: """Ignore false positive issues. Args: - s: SQL session to use values: List of issues to ignore """ ( - s.query(HealthCheckIssue) + HealthCheckIssue.query() .where( HealthCheckIssue.check == cls.name(), HealthCheckIssue.value.in_(values), @@ -106,40 +96,36 @@ def ignore(cls, s: orm.Session, values: list[str] | set[str]) -> None: .update({"ignore": True}) ) - def _commit_issues(self, s: orm.Session, issues: dict[str, str]) -> None: + def _commit_issues(self, issues: dict[str, str]) -> None: """Commit issues to Portfolio. Args: - s: SQL session to use issues: dict{value: message} """ - query = s.query(HealthCheckIssue.value).where( + query = HealthCheckIssue.query(HealthCheckIssue.value).where( HealthCheckIssue.check == self.name(), HealthCheckIssue.ignore.is_(True), ) - ignored = {r[0] for r in query.yield_per(YIELD_PER)} + ignored = set(sql.col0(query)) updates: dict[object, dict[str, object]] = { value: {"check": self.name(), "ignore": value in ignored, "msg": msg} for value, msg in issues.items() } - query = s.query(HealthCheckIssue).where( + query = HealthCheckIssue.query().where( HealthCheckIssue.check == self.name(), ) - update_rows(s, HealthCheckIssue, query, "value", updates) - s.flush() + update_rows(HealthCheckIssue, query, "value", updates) - query = ( - s.query(HealthCheckIssue) - .with_entities(HealthCheckIssue.id_, HealthCheckIssue.msg) - .where( - HealthCheckIssue.check == self.name(), - ) + query = HealthCheckIssue.query( + HealthCheckIssue.id_, + HealthCheckIssue.msg, + ).where( + HealthCheckIssue.check == self.name(), ) if not self._no_ignores: query = query.where(HealthCheckIssue.ignore.is_(False)) self._issues = { - HealthCheckIssue.id_to_uri(id_): msg - for id_, msg in query.yield_per(YIELD_PER) + HealthCheckIssue.id_to_uri(id_): msg for id_, msg in sql.yield_(query) } diff --git a/nummus/health_checks/category_direction.py b/nummus/health_checks/category_direction.py index c56f095f..d62c734f 100644 --- a/nummus/health_checks/category_direction.py +++ b/nummus/health_checks/category_direction.py @@ -4,25 +4,17 @@ import datetime import textwrap -from typing import override, TYPE_CHECKING +from typing import override +from nummus import sql from nummus.health_checks.base import HealthCheck from nummus.models.account import Account -from nummus.models.base import YIELD_PER from nummus.models.currency import CURRENCY_FORMATS from nummus.models.transaction import TransactionSplit from nummus.models.transaction_category import ( TransactionCategory, TransactionCategoryGroup, ) -from nummus.models.utils import query_to_dict - -if TYPE_CHECKING: - from decimal import Decimal - - from sqlalchemy import orm - - from nummus.models.currency import Currency class CategoryDirection(HealthCheck): @@ -36,39 +28,36 @@ class CategoryDirection(HealthCheck): _SEVERE = True @override - def test(self, s: orm.Session) -> None: - query = s.query(Account).with_entities( + def test(self) -> None: + query = Account.query( Account.id_, Account.name, Account.currency, ) - accounts: dict[int, tuple[str, Currency]] = { - r[0]: (r[1], r[2]) for r in query.yield_per(YIELD_PER) - } + accounts = sql.to_dict_tuple(query) if len(accounts) == 0: - self._commit_issues(s, {}) + self._commit_issues({}) return acct_len = max(len(acct[0]) for acct in accounts.values()) issues: dict[str, str] = {} - query = s.query( + query = TransactionCategory.query( TransactionCategory.id_, TransactionCategory.emoji_name, ).where( TransactionCategory.group == TransactionCategoryGroup.INCOME, ) - cat_income_ids: dict[int, str] = query_to_dict(query) - query = s.query( + cat_income_ids: dict[int, str] = sql.to_dict(query) + query = TransactionCategory.query( TransactionCategory.id_, TransactionCategory.emoji_name, ).where( TransactionCategory.group == TransactionCategoryGroup.EXPENSE, ) - cat_expense_ids: dict[int, str] = query_to_dict(query) + cat_expense_ids: dict[int, str] = sql.to_dict(query) query = ( - s.query(TransactionSplit) - .with_entities( + TransactionSplit.query( TransactionSplit.id_, TransactionSplit.account_id, TransactionSplit.date_ord, @@ -82,13 +71,7 @@ def test(self, s: orm.Session) -> None: ) .order_by(TransactionSplit.date_ord) ) - for t_id, acct_id, date_ord, payee, amount, t_cat_id in query.yield_per( - YIELD_PER, - ): - acct_id: int - date_ord: int - payee: str - amount: Decimal + for t_id, acct_id, date_ord, payee, amount, t_cat_id in sql.yield_(query): uri = TransactionSplit.id_to_uri(t_id) acct_name, currency = accounts[acct_id] @@ -104,8 +87,7 @@ def test(self, s: orm.Session) -> None: issues[uri] = msg query = ( - s.query(TransactionSplit) - .with_entities( + TransactionSplit.query( TransactionSplit.id_, TransactionSplit.account_id, TransactionSplit.date_ord, @@ -119,13 +101,7 @@ def test(self, s: orm.Session) -> None: ) .order_by(TransactionSplit.date_ord) ) - for t_id, acct_id, date_ord, payee, amount, t_cat_id in query.yield_per( - YIELD_PER, - ): - acct_id: int - date_ord: int - payee: str - amount: Decimal + for t_id, acct_id, date_ord, payee, amount, t_cat_id in sql.yield_(query): uri = TransactionSplit.id_to_uri(t_id) acct_name, currency = accounts[acct_id] @@ -140,4 +116,4 @@ def test(self, s: orm.Session) -> None: ) issues[uri] = msg - self._commit_issues(s, issues) + self._commit_issues(issues) diff --git a/nummus/health_checks/database_integrity.py b/nummus/health_checks/database_integrity.py index 9cf9ba10..f147064b 100644 --- a/nummus/health_checks/database_integrity.py +++ b/nummus/health_checks/database_integrity.py @@ -6,7 +6,9 @@ import sqlalchemy +from nummus import sql from nummus.health_checks.base import HealthCheck +from nummus.models.base import Base if TYPE_CHECKING: from sqlalchemy import orm @@ -19,11 +21,13 @@ class DatabaseIntegrity(HealthCheck): _SEVERE = True @override - def test(self, s: orm.Session) -> None: - result = s.execute(sqlalchemy.text("PRAGMA integrity_check")) - rows = [row for row, in result.all()] + def test(self) -> None: + query: orm.query.RowReturningQuery[tuple[str]] = Base.session().execute( # type: ignore[attr-defined] + sqlalchemy.text("PRAGMA integrity_check"), + ) + rows = list(sql.col0(query)) if len(rows) != 1 or rows[0] != "ok": issues = {str(i): row for i, row in enumerate(rows)} else: issues = {} - self._commit_issues(s, issues) + self._commit_issues(issues) diff --git a/nummus/health_checks/duplicate_transactions.py b/nummus/health_checks/duplicate_transactions.py index 353c494c..cbc50666 100644 --- a/nummus/health_checks/duplicate_transactions.py +++ b/nummus/health_checks/duplicate_transactions.py @@ -3,23 +3,16 @@ from __future__ import annotations import datetime -from typing import override, TYPE_CHECKING +from typing import override from sqlalchemy import func +from nummus import sql from nummus.health_checks.base import HealthCheck from nummus.models.account import Account -from nummus.models.base import YIELD_PER from nummus.models.currency import CURRENCY_FORMATS from nummus.models.transaction import Transaction -if TYPE_CHECKING: - from decimal import Decimal - - from sqlalchemy import orm - - from nummus.models.currency import Currency - class DuplicateTransactions(HealthCheck): """Checks for transactions with same amount, date, and statement.""" @@ -28,21 +21,18 @@ class DuplicateTransactions(HealthCheck): _SEVERE = True @override - def test(self, s: orm.Session) -> None: - query = s.query(Account).with_entities( + def test(self) -> None: + query = Account.query( Account.id_, Account.name, Account.currency, ) - accounts: dict[int, tuple[str, Currency]] = { - r[0]: (r[1], r[2]) for r in query.yield_per(YIELD_PER) - } + accounts = sql.to_dict_tuple(query) issues: list[tuple[str, str, str]] = [] query = ( - s.query(Transaction) - .with_entities( + Transaction.query( Transaction.date_ord, Transaction.account_id, Transaction.amount, @@ -58,10 +48,7 @@ def test(self, s: orm.Session) -> None: .order_by(Transaction.date_ord) .having(func.count() > 1) ) - for date_ord, acct_id, amount in query.yield_per(YIELD_PER): - date_ord: int - acct_id: int - amount: Decimal + for date_ord, acct_id, amount in sql.yield_(query): amount_raw = Transaction.amount.type.process_bind_param(amount, None) # Create a robust uri for this duplicate uri = f"{acct_id}.{date_ord}.{amount_raw}" @@ -81,7 +68,6 @@ def test(self, s: orm.Session) -> None: amount_len = 0 self._commit_issues( - s, { uri: f"{source:{source_len}}: {amount_str:>{amount_len}}" for uri, source, amount_str in issues diff --git a/nummus/health_checks/empty_fields.py b/nummus/health_checks/empty_fields.py index 1a04b082..8fe72c9c 100644 --- a/nummus/health_checks/empty_fields.py +++ b/nummus/health_checks/empty_fields.py @@ -3,18 +3,15 @@ from __future__ import annotations import datetime -from typing import override, TYPE_CHECKING +from typing import override +from nummus import sql from nummus.health_checks.base import HealthCheck from nummus.models.account import Account from nummus.models.asset import Asset -from nummus.models.base import YIELD_PER from nummus.models.transaction import Transaction, TransactionSplit from nummus.models.transaction_category import TransactionCategory -if TYPE_CHECKING: - from sqlalchemy import orm - class EmptyFields(HealthCheck): """Checks for empty fields that are better when populated.""" @@ -23,53 +20,36 @@ class EmptyFields(HealthCheck): _SEVERE = False @override - def test(self, s: orm.Session) -> None: - accounts = Account.map_name(s) + def test(self) -> None: + accounts = Account.map_name() # List of (uri, source, field) issues: list[tuple[str, str, str]] = [] - query = ( - s.query(Account) - .with_entities(Account.id_, Account.name) - .where(Account.number.is_(None)) - ) - for acct_id, name in query.yield_per(YIELD_PER): - acct_id: int - name: str + query = Account.query(Account.id_, Account.name).where(Account.number.is_(None)) + for acct_id, name in sql.yield_(query): uri = Account.id_to_uri(acct_id) issues.append( (f"{uri}.number", f"Account {name}", "has an empty number"), ) - query = ( - s.query(Asset) - .with_entities(Asset.id_, Asset.name) - .where(Asset.description.is_(None)) + query = Asset.query(Asset.id_, Asset.name).where( + Asset.description.is_(None), ) - for a_id, name in query.yield_per(YIELD_PER): - a_id: int - name: str + for a_id, name in sql.yield_(query): uri = Asset.id_to_uri(a_id) issues.append( (f"{uri}.description", f"Asset {name}", "has an empty description"), ) - query = ( - s.query(Transaction) - .with_entities( - Transaction.id_, - Transaction.date_ord, - Transaction.account_id, - ) - .where( - Transaction.payee.is_(None), - ) + query = Transaction.query( + Transaction.id_, + Transaction.date_ord, + Transaction.account_id, + ).where( + Transaction.payee.is_(None), ) - for t_id, date_ord, acct_id in query.yield_per(YIELD_PER): - t_id: int - date_ord: int - acct_id: int + for t_id, date_ord, acct_id in sql.yield_(query): uri = Transaction.id_to_uri(t_id) date = datetime.date.fromordinal(date_ord) @@ -78,20 +58,13 @@ def test(self, s: orm.Session) -> None: (f"{uri}.payee", source, "has an empty payee"), ) - uncategorized_id, _ = TransactionCategory.uncategorized(s) - query = ( - s.query(TransactionSplit) - .with_entities( - TransactionSplit.id_, - TransactionSplit.date_ord, - TransactionSplit.account_id, - ) - .where(TransactionSplit.category_id == uncategorized_id) - ) - for t_id, date_ord, acct_id in query.yield_per(YIELD_PER): - t_id: int - date_ord: int - acct_id: int + uncategorized_id, _ = TransactionCategory.uncategorized() + query = TransactionSplit.query( + TransactionSplit.id_, + TransactionSplit.date_ord, + TransactionSplit.account_id, + ).where(TransactionSplit.category_id == uncategorized_id) + for t_id, date_ord, acct_id in sql.yield_(query): uri = TransactionSplit.id_to_uri(t_id) date = datetime.date.fromordinal(date_ord) @@ -101,6 +74,5 @@ def test(self, s: orm.Session) -> None: source_len = max(len(item[1]) for item in issues) if issues else 0 self._commit_issues( - s, {uri: f"{source:{source_len}} {field}" for uri, source, field in issues}, ) diff --git a/nummus/health_checks/missing_asset_link.py b/nummus/health_checks/missing_asset_link.py index f58330e7..2a5f48d6 100644 --- a/nummus/health_checks/missing_asset_link.py +++ b/nummus/health_checks/missing_asset_link.py @@ -5,9 +5,9 @@ import datetime from typing import override, TYPE_CHECKING +from nummus import sql from nummus.health_checks.base import HealthCheck from nummus.models.account import Account -from nummus.models.base import YIELD_PER from nummus.models.currency import CURRENCY_FORMATS from nummus.models.transaction import TransactionSplit from nummus.models.transaction_category import TransactionCategory @@ -15,10 +15,6 @@ if TYPE_CHECKING: from decimal import Decimal - from sqlalchemy import orm - - from nummus.models.currency import Currency - class MissingAssetLink(HealthCheck): """Checks for transactions that should be linked to an asset that aren't.""" @@ -27,50 +23,39 @@ class MissingAssetLink(HealthCheck): _SEVERE = False @override - def test(self, s: orm.Session) -> None: - query = s.query(Account).with_entities( + def test(self) -> None: + query = Account.query( Account.id_, Account.name, Account.currency, ) - accounts: dict[int, tuple[str, Currency]] = { - r[0]: (r[1], r[2]) for r in query.yield_per(YIELD_PER) - } + accounts = sql.to_dict_tuple(query) if len(accounts) == 0: - self._commit_issues(s, {}) + self._commit_issues({}) return acct_len = max(len(acct[0]) for acct in accounts.values()) issues: dict[str, str] = {} - categories = TransactionCategory.map_name_emoji(s) + categories = TransactionCategory.map_name_emoji() # These categories should be linked to an asset - query = s.query(TransactionCategory.id_).where( + query = TransactionCategory.query(TransactionCategory.id_).where( TransactionCategory.asset_linked.is_(True), ) - categories_assets_id = {r for r, in query.all()} + categories_assets_id = set(sql.col0(query)) # Get transactions in these categories that do not have an asset - query = ( - s.query(TransactionSplit) - .with_entities( - TransactionSplit.id_, - TransactionSplit.date_ord, - TransactionSplit.account_id, - TransactionSplit.category_id, - TransactionSplit.amount, - ) - .where( - TransactionSplit.category_id.in_(categories_assets_id), - TransactionSplit.asset_id.is_(None), - ) + query = TransactionSplit.query( + TransactionSplit.id_, + TransactionSplit.date_ord, + TransactionSplit.account_id, + TransactionSplit.category_id, + TransactionSplit.amount, + ).where( + TransactionSplit.category_id.in_(categories_assets_id), + TransactionSplit.asset_id.is_(None), ) - for t_id, date_ord, acct_id, cat_id, amount in query.yield_per(YIELD_PER): - t_id: int - date_ord: int - acct_id: int - cat_id: int - amount: Decimal + for t_id, date_ord, acct_id, cat_id, amount in sql.yield_(query): uri = TransactionSplit.id_to_uri(t_id) acct_name, currency = accounts[acct_id] @@ -85,21 +70,17 @@ def test(self, s: orm.Session) -> None: issues[uri] = msg # Get transactions not in these categories that do have an asset - query = ( - s.query(TransactionSplit) - .with_entities( - TransactionSplit.id_, - TransactionSplit.date_ord, - TransactionSplit.account_id, - TransactionSplit.category_id, - TransactionSplit.amount, - ) - .where( - TransactionSplit.category_id.not_in(categories_assets_id), - TransactionSplit.asset_id.is_not(None), - ) + query = TransactionSplit.query( + TransactionSplit.id_, + TransactionSplit.date_ord, + TransactionSplit.account_id, + TransactionSplit.category_id, + TransactionSplit.amount, + ).where( + TransactionSplit.category_id.not_in(categories_assets_id), + TransactionSplit.asset_id.is_not(None), ) - for t_id, date_ord, acct_id, cat_id, amount in query.yield_per(YIELD_PER): + for t_id, date_ord, acct_id, cat_id, amount in sql.yield_(query): t_id: int date_ord: int acct_id: int @@ -118,4 +99,4 @@ def test(self, s: orm.Session) -> None: ) issues[uri] = msg - self._commit_issues(s, issues) + self._commit_issues(issues) diff --git a/nummus/health_checks/missing_valuations.py b/nummus/health_checks/missing_valuations.py index 2131ec49..46933d4a 100644 --- a/nummus/health_checks/missing_valuations.py +++ b/nummus/health_checks/missing_valuations.py @@ -3,20 +3,17 @@ from __future__ import annotations import datetime -from typing import override, TYPE_CHECKING +from typing import override from sqlalchemy import func +from nummus import sql from nummus.health_checks.base import HealthCheck from nummus.models.asset import ( Asset, AssetValuation, ) from nummus.models.transaction import TransactionSplit -from nummus.models.utils import query_to_dict - -if TYPE_CHECKING: - from sqlalchemy import orm class MissingAssetValuations(HealthCheck): @@ -26,30 +23,25 @@ class MissingAssetValuations(HealthCheck): _SEVERE = True @override - def test(self, s: orm.Session) -> None: - assets = Asset.map_name(s) + def test(self) -> None: + assets = Asset.map_name() issues: dict[str, str] = {} query = ( - s.query(TransactionSplit) - .with_entities( + TransactionSplit.query( TransactionSplit.asset_id, func.min(TransactionSplit.date_ord), ) .where(TransactionSplit.asset_id.isnot(None)) .group_by(TransactionSplit.asset_id) ) - first_date_ords: dict[int | None, int] = query_to_dict(query) + first_date_ords: dict[int | None, int] = sql.to_dict(query) - query = ( - s.query(AssetValuation) - .with_entities( - AssetValuation.asset_id, - func.min(AssetValuation.date_ord), - ) - .group_by(AssetValuation.asset_id) - ) - first_valuations: dict[int, int] = query_to_dict(query) + query = AssetValuation.query( + AssetValuation.asset_id, + func.min(AssetValuation.date_ord), + ).group_by(AssetValuation.asset_id) + first_valuations: dict[int, int] = sql.to_dict(query) for a_id, date_ord in first_date_ords.items(): if a_id is None: @@ -68,4 +60,4 @@ def test(self, s: orm.Session) -> None: ) issues[uri] = msg - self._commit_issues(s, issues) + self._commit_issues(issues) diff --git a/nummus/health_checks/outlier_asset_price.py b/nummus/health_checks/outlier_asset_price.py index eecb7e57..3f9a2363 100644 --- a/nummus/health_checks/outlier_asset_price.py +++ b/nummus/health_checks/outlier_asset_price.py @@ -9,17 +9,14 @@ from sqlalchemy import func -from nummus import utils +from nummus import sql, utils from nummus.health_checks.base import HealthCheck from nummus.models.account import Account from nummus.models.asset import Asset -from nummus.models.base import YIELD_PER from nummus.models.currency import CURRENCY_FORMATS from nummus.models.transaction import TransactionSplit -from nummus.models.utils import query_to_dict if TYPE_CHECKING: - from sqlalchemy import orm from nummus.models.currency import Currency @@ -38,39 +35,37 @@ class OutlierAssetPrice(HealthCheck): _RANGE = Decimal("0.4") @override - def test(self, s: orm.Session) -> None: + def test(self) -> None: today = datetime.datetime.now(datetime.UTC).date() today_ord = today.toordinal() - start_ord = ( - s.query(func.min(TransactionSplit.date_ord)) - .where(TransactionSplit.asset_id.isnot(None)) - .scalar() + start_ord = sql.scalar( + TransactionSplit.query(func.min(TransactionSplit.date_ord)).where( + TransactionSplit.asset_id.isnot(None), + ), ) if start_ord is None: # No asset transactions at all - self._commit_issues(s, {}) + self._commit_issues({}) return # List of (uri, source, field) issues: list[tuple[str, str, str]] = [] - assets = Asset.map_name(s) + assets = Asset.map_name() asset_valuations = Asset.get_value_all( - s, start_ord, today_ord + utils.DAYS_IN_WEEK, ) - query = s.query(Account).with_entities( + query = Account.query( Account.id_, Account.currency, ) - accounts: dict[int, Currency] = query_to_dict(query) + accounts: dict[int, Currency] = sql.to_dict(query) query = ( - s.query(TransactionSplit) - .with_entities( + TransactionSplit.query( TransactionSplit.id_, TransactionSplit.account_id, TransactionSplit.date_ord, @@ -84,19 +79,14 @@ def test(self, s: orm.Session) -> None: ) .where(TransactionSplit.asset_id.isnot(None)) ) - for t_id, acct_id, date_ord, a_id, amount, qty in query.yield_per( - YIELD_PER, - ): - t_id: int - acct_id: int - date_ord: int - a_id: int - amount: Decimal - qty: Decimal + for t_id, acct_id, date_ord, a_id, amount, qty in sql.yield_(query): uri = TransactionSplit.id_to_uri(t_id) - if qty == 0: + if not qty: continue + if TYPE_CHECKING: + # Enforced by query and SQL constraints + assert a_id is not None # Transaction asset price t_price = -amount / qty @@ -131,6 +121,5 @@ def test(self, s: orm.Session) -> None: source_len = max(len(item[1]) for item in issues) if issues else 0 self._commit_issues( - s, {uri: f"{source:{source_len}} {field}" for uri, source, field in issues}, ) diff --git a/nummus/health_checks/overdrawn_accounts.py b/nummus/health_checks/overdrawn_accounts.py index e6533bc1..86dc7881 100644 --- a/nummus/health_checks/overdrawn_accounts.py +++ b/nummus/health_checks/overdrawn_accounts.py @@ -4,21 +4,16 @@ import datetime from decimal import Decimal -from typing import override, TYPE_CHECKING +from typing import override from sqlalchemy import func +from nummus import sql from nummus.health_checks.base import HealthCheck from nummus.models.account import Account, AccountCategory -from nummus.models.base import YIELD_PER from nummus.models.currency import CURRENCY_FORMATS from nummus.models.transaction import TransactionSplit -if TYPE_CHECKING: - from sqlalchemy import orm - - from nummus.models.currency import Currency - class OverdrawnAccounts(HealthCheck): """Checks for accounts that had a negative cash balance when they shouldn't.""" @@ -27,38 +22,30 @@ class OverdrawnAccounts(HealthCheck): _SEVERE = True @override - def test(self, s: orm.Session) -> None: + def test(self) -> None: # Get a list of accounts subject to overdrawn so not credit and loans categories_exclude = [ AccountCategory.CREDIT, AccountCategory.LOAN, AccountCategory.MORTGAGE, ] - query = ( - s.query(Account) - .with_entities(Account.id_, Account.name, Account.currency) - .where(Account.category.not_in(categories_exclude)) + query = Account.query(Account.id_, Account.name, Account.currency).where( + Account.category.not_in(categories_exclude), ) - accounts: dict[int, tuple[str, Currency]] = { - r[0]: (r[1], r[2]) for r in query.yield_per(YIELD_PER) - } + accounts = sql.to_dict_tuple(query) acct_ids = set(accounts) issues: list[tuple[str, str, str]] = [] - start_ord, end_ord = ( - s.query( + start_ord, end_ord = sql.one( + TransactionSplit.query( func.min(TransactionSplit.date_ord), func.max(TransactionSplit.date_ord), - ) - .where(TransactionSplit.account_id.in_(acct_ids)) - .one() + ).where(TransactionSplit.account_id.in_(acct_ids)), ) - start_ord: int | None - end_ord: int | None - if start_ord is None or end_ord is None: - # No asset transactions at all - self._commit_issues(s, {}) + if start_ord is None or end_ord is None: # type: ignore[attr-defined] + # No transactions at all + self._commit_issues({}) return n = end_ord - start_ord + 1 @@ -66,20 +53,13 @@ def test(self, s: orm.Session) -> None: cf = CURRENCY_FORMATS[currency] # Get cash holdings across all time cash_flow: list[Decimal | None] = [None] * n - query = ( - s.query(TransactionSplit) - .with_entities( - TransactionSplit.date_ord, - TransactionSplit.amount, - ) - .where( - TransactionSplit.account_id == acct_id, - ) + query = TransactionSplit.query( + TransactionSplit.date_ord, + TransactionSplit.amount, + ).where( + TransactionSplit.account_id == acct_id, ) - for date_ord, amount in query.yield_per(YIELD_PER): - date_ord: int - amount: Decimal - + for date_ord, amount in sql.yield_(query): i = date_ord - start_ord v = cash_flow[i] @@ -111,7 +91,6 @@ def test(self, s: orm.Session) -> None: amount_len = 0 self._commit_issues( - s, { uri: f"{source:{source_len}}: {amount_str:>{amount_len}}" for uri, source, amount_str in issues diff --git a/nummus/health_checks/typos.py b/nummus/health_checks/typos.py index 30083751..e180fc19 100644 --- a/nummus/health_checks/typos.py +++ b/nummus/health_checks/typos.py @@ -9,23 +9,18 @@ from typing import override, TYPE_CHECKING import spellchecker -from sqlalchemy import func, orm +from sqlalchemy import func -from nummus import utils +from nummus import sql, utils from nummus.health_checks.base import HealthCheck from nummus.models.account import Account from nummus.models.asset import ( Asset, AssetCategory, ) -from nummus.models.base import YIELD_PER from nummus.models.label import Label from nummus.models.transaction import TransactionSplit -if TYPE_CHECKING: - from sqlalchemy import orm - - _LIMIT_FREQUENCY = 10 @@ -60,18 +55,18 @@ def __init__( self._proper_nouns: set[str] = set() @override - def test(self, s: orm.Session) -> None: + def test(self) -> None: spell = spellchecker.SpellChecker() - accounts = Account.map_name(s) - assets = Asset.map_name(s) + accounts = Account.map_name() + assets = Asset.map_name() issues: dict[str, tuple[str, str, str]] = {} self._proper_nouns.update(accounts.values()) self._proper_nouns.update(assets.values()) - issues.update(self._test_accounts(s, accounts)) - issues.update(self._test_labels(s)) - issues.update(self._test_transaction_nouns(s, accounts)) + issues.update(self._test_accounts(accounts)) + issues.update(self._test_labels()) + issues.update(self._test_transaction_nouns(accounts)) # Escape words and sort to replace longest words first # So long words aren't partially replaced if they contain a short word @@ -82,8 +77,8 @@ def test(self, s: orm.Session) -> None: # Remove proper nouns indicated by word boundary or space at end re_cleaner = re.compile(rf"\b(?:{'|'.join(proper_nouns_re)})(?:\b|(?= |$))") - issues.update(self._test_transaction_texts(s, accounts, re_cleaner, spell)) - issues.update(self._test_assets(s, assets, re_cleaner, spell)) + issues.update(self._test_transaction_texts(accounts, re_cleaner, spell)) + issues.update(self._test_assets(assets, re_cleaner, spell)) source_len = 0 field_len = 0 @@ -95,7 +90,6 @@ def test(self, s: orm.Session) -> None: # Getting a suggested correction is slow and error prone, # Just say if a word is outside of the dictionary self._commit_issues( - s, { uri: f"{source:{source_len}} {field:{field_len}}: {word}" for uri, (word, source, field) in issues.items() @@ -138,29 +132,22 @@ def _create_issues(self) -> dict[str, tuple[str, str, str]]: def _test_accounts( self, - s: orm.Session, accounts: dict[int, str], ) -> dict[str, tuple[str, str, str]]: - query = s.query(Account).with_entities( + query = Account.query( Account.id_, Account.institution, ) - for acct_id, institution in query.yield_per(YIELD_PER): - acct_id: int - institution: str + for acct_id, institution in sql.yield_(query): name = accounts[acct_id] source = f"Account {name}" self._add(institution, source, "institution", 1) self._proper_nouns.add(institution) return self._create_issues() - def _test_labels( - self, - s: orm.Session, - ) -> dict[str, tuple[str, str, str]]: - query = s.query(Label.name) - for (name,) in query.yield_per(YIELD_PER): - name: str + def _test_labels(self) -> dict[str, tuple[str, str, str]]: + query = Label.query(Label.name) + for name in sql.col0(query): source = f"Label {name}" self._add(name, source, "name", 1) self._proper_nouns.add(name) @@ -168,7 +155,6 @@ def _test_labels( def _test_transaction_nouns( self, - s: orm.Session, accounts: dict[int, str], ) -> dict[str, tuple[str, str, str]]: issues: dict[str, tuple[str, str, str]] = {} @@ -177,8 +163,7 @@ def _test_transaction_nouns( ] for field in txn_fields: query = ( - s.query(TransactionSplit) - .with_entities( + TransactionSplit.query( TransactionSplit.date_ord, TransactionSplit.account_id, field, @@ -187,10 +172,10 @@ def _test_transaction_nouns( .where(field.is_not(None)) .group_by(field) ) - for date_ord, acct_id, value, count in query.yield_per(YIELD_PER): - date_ord: int - acct_id: int - value: str + for date_ord, acct_id, value, count in sql.yield_(query): + if TYPE_CHECKING: + # Enforced by query and SQL constraints + assert value is not None date = datetime.date.fromordinal(date_ord) source = f"{date} - {accounts[acct_id]}" self._add(value, source, field.key, count) @@ -200,27 +185,24 @@ def _test_transaction_nouns( def _test_transaction_texts( self, - s: orm.Session, accounts: dict[int, str], - re_cleaner: re.Pattern, + re_cleaner: re.Pattern[str], spell: spellchecker.SpellChecker, ) -> dict[str, tuple[str, str, str]]: query = ( - s.query(TransactionSplit) - .with_entities( + TransactionSplit.query( TransactionSplit.date_ord, TransactionSplit.account_id, TransactionSplit.memo, func.count(), ) .group_by(TransactionSplit.memo) + .where(TransactionSplit.memo.is_not(None)) ) - for date_ord, acct_id, value, count in query.yield_per(YIELD_PER): - date_ord: int - acct_id: int - value: str | None - if value is None: - continue + for date_ord, acct_id, value, count in sql.yield_(query): + if TYPE_CHECKING: + # Enforced by query and SQL constraints + assert value is not None date = datetime.date.fromordinal(date_ord) source = f"{date} - {accounts[acct_id]}" cleaned = re_cleaner.sub("", value).lower() @@ -237,25 +219,21 @@ def _test_transaction_texts( def _test_assets( self, - s: orm.Session, assets: dict[int, str], - re_cleaner: re.Pattern, + re_cleaner: re.Pattern[str], spell: spellchecker.SpellChecker, ) -> dict[str, tuple[str, str, str]]: - query = ( - s.query(Asset) - .with_entities( - Asset.id_, - Asset.description, - ) - .where( - Asset.category != AssetCategory.INDEX, - Asset.description.is_not(None), - ) + query = Asset.query( + Asset.id_, + Asset.description, + ).where( + Asset.category != AssetCategory.INDEX, + Asset.description.is_not(None), ) - for a_id, value in query.yield_per(YIELD_PER): - a_id: int - value: str + for a_id, value in sql.yield_(query): + if TYPE_CHECKING: + # Enforced by query and SQL constraints + assert value is not None source = f"Asset {assets[a_id]}" cleaned = re_cleaner.sub("", value).lower() for word in self._RE_WORDS.split(cleaned): diff --git a/nummus/health_checks/unbalanced_transfers.py b/nummus/health_checks/unbalanced_transfers.py index b0f60e78..136cec9e 100644 --- a/nummus/health_checks/unbalanced_transfers.py +++ b/nummus/health_checks/unbalanced_transfers.py @@ -9,19 +9,17 @@ from decimal import Decimal from typing import override, TYPE_CHECKING +from nummus import sql from nummus.health_checks.base import HealthCheck from nummus.models.account import Account -from nummus.models.base import YIELD_PER from nummus.models.currency import CURRENCY_FORMATS from nummus.models.transaction import TransactionSplit from nummus.models.transaction_category import ( TransactionCategory, TransactionCategoryGroup, ) -from nummus.models.utils import query_to_dict if TYPE_CHECKING: - from sqlalchemy import orm from nummus.models.currency import Currency @@ -37,28 +35,25 @@ class UnbalancedTransfers(HealthCheck): _SEVERE = True @override - def test(self, s: orm.Session) -> None: + def test(self) -> None: issues: dict[str, str] = {} - query = s.query( + query = TransactionCategory.query( TransactionCategory.id_, TransactionCategory.emoji_name, ).where( TransactionCategory.group == TransactionCategoryGroup.TRANSFER, ) - cat_transfers_ids: dict[int, str] = query_to_dict(query) + cat_transfers_ids = sql.to_dict(query) - query = s.query(Account).with_entities( + query = Account.query( Account.id_, Account.name, Account.currency, ) - accounts: dict[int, tuple[str, Currency]] = { - r[0]: (r[1], r[2]) for r in query.yield_per(YIELD_PER) - } + accounts = sql.to_dict_tuple(query) query = ( - s.query(TransactionSplit) - .with_entities( + TransactionSplit.query( TransactionSplit.account_id, TransactionSplit.date_ord, TransactionSplit.amount, @@ -68,12 +63,9 @@ def test(self, s: orm.Session) -> None: .order_by(TransactionSplit.date_ord, TransactionSplit.amount) ) current_date_ord: int | None = None - total = defaultdict(Decimal) + total: dict[int, Decimal] = defaultdict(Decimal) current_splits: dict[int, list[tuple[int, Decimal]]] = defaultdict(list) - for acct_id, date_ord, amount, t_cat_id in query.yield_per(YIELD_PER): - acct_id: int - date_ord: int - amount: Decimal + for acct_id, date_ord, amount, t_cat_id in sql.yield_(query): if current_date_ord is None: current_date_ord = date_ord if date_ord != current_date_ord: @@ -101,7 +93,7 @@ def test(self, s: orm.Session) -> None: ) issues[uri] = msg - self._commit_issues(s, issues) + self._commit_issues(issues) @classmethod def _create_issue( diff --git a/nummus/health_checks/uncleared_transactions.py b/nummus/health_checks/uncleared_transactions.py index 45c39462..cc955f28 100644 --- a/nummus/health_checks/uncleared_transactions.py +++ b/nummus/health_checks/uncleared_transactions.py @@ -4,21 +4,14 @@ import datetime import textwrap -from typing import override, TYPE_CHECKING +from typing import override +from nummus import sql from nummus.health_checks.base import HealthCheck from nummus.models.account import Account -from nummus.models.base import YIELD_PER from nummus.models.currency import CURRENCY_FORMATS from nummus.models.transaction import TransactionSplit -if TYPE_CHECKING: - from decimal import Decimal - - from sqlalchemy import orm - - from nummus.models.currency import Currency - class UnclearedTransactions(HealthCheck): """Checks for uncleared transactions.""" @@ -31,38 +24,27 @@ class UnclearedTransactions(HealthCheck): _SEVERE = False @override - def test(self, s: orm.Session) -> None: - query = s.query(Account).with_entities( + def test(self) -> None: + query = Account.query( Account.id_, Account.name, Account.currency, ) - accounts: dict[int, tuple[str, Currency]] = { - r[0]: (r[1], r[2]) for r in query.yield_per(YIELD_PER) - } + accounts = sql.to_dict_tuple(query) if len(accounts) == 0: - self._commit_issues(s, {}) + self._commit_issues({}) return acct_len = max(len(acct[0]) for acct in accounts.values()) issues: dict[str, str] = {} - query = ( - s.query(TransactionSplit) - .with_entities( - TransactionSplit.id_, - TransactionSplit.date_ord, - TransactionSplit.account_id, - TransactionSplit.payee, - TransactionSplit.amount, - ) - .where(TransactionSplit.cleared.is_(False)) - ) - for t_id, date_ord, acct_id, payee, amount in query.yield_per(YIELD_PER): - t_id: int - date_ord: int - acct_id: int - payee: str - amount: Decimal + query = TransactionSplit.query( + TransactionSplit.id_, + TransactionSplit.date_ord, + TransactionSplit.account_id, + TransactionSplit.payee, + TransactionSplit.amount, + ).where(TransactionSplit.cleared.is_(False)) + for t_id, date_ord, acct_id, payee, amount in sql.yield_(query): uri = TransactionSplit.id_to_uri(t_id) acct_name, currency = accounts[acct_id] @@ -76,4 +58,4 @@ def test(self, s: orm.Session) -> None: ) issues[uri] = msg - self._commit_issues(s, issues) + self._commit_issues(issues) diff --git a/nummus/health_checks/unnecessary_slits.py b/nummus/health_checks/unnecessary_slits.py index 55f73725..2ae9f634 100644 --- a/nummus/health_checks/unnecessary_slits.py +++ b/nummus/health_checks/unnecessary_slits.py @@ -3,19 +3,16 @@ from __future__ import annotations import datetime -from typing import NamedTuple, override, TYPE_CHECKING +from typing import NamedTuple, override from sqlalchemy import func +from nummus import sql from nummus.health_checks.base import HealthCheck from nummus.models.account import Account -from nummus.models.base import YIELD_PER from nummus.models.transaction import TransactionSplit from nummus.models.transaction_category import TransactionCategory -if TYPE_CHECKING: - from sqlalchemy import orm - class RawIssue(NamedTuple): """Type definition for a raw issue.""" @@ -33,15 +30,14 @@ class UnnecessarySplits(HealthCheck): _SEVERE = False @override - def test(self, s: orm.Session) -> None: - accounts = Account.map_name(s) - categories = TransactionCategory.map_name_emoji(s) + def test(self) -> None: + accounts = Account.map_name() + categories = TransactionCategory.map_name_emoji() issues: list[RawIssue] = [] query = ( - s.query(TransactionSplit) - .with_entities( + TransactionSplit.query( TransactionSplit.date_ord, TransactionSplit.account_id, TransactionSplit.parent_id, @@ -55,14 +51,7 @@ def test(self, s: orm.Session) -> None: .order_by(TransactionSplit.date_ord) .having(func.count() > 1) ) - for date_ord, acct_id, t_id, payee, t_cat_id in query.yield_per( - YIELD_PER, - ): - date_ord: int - acct_id: int - t_id: int - payee: str | None - t_cat_id: int + for date_ord, acct_id, t_id, payee, t_cat_id in sql.yield_(query): # Create a robust uri for this duplicate uri = f"{t_id}.{payee}.{t_cat_id}" @@ -82,7 +71,6 @@ def test(self, s: orm.Session) -> None: t_cat_len = 0 self._commit_issues( - s, { issue.uri: ( f"{issue.source:{source_len}}: " diff --git a/nummus/health_checks/unused_categories.py b/nummus/health_checks/unused_categories.py index 4d45017d..92e83eb5 100644 --- a/nummus/health_checks/unused_categories.py +++ b/nummus/health_checks/unused_categories.py @@ -2,16 +2,13 @@ from __future__ import annotations -from typing import override, TYPE_CHECKING +from typing import override +from nummus import sql from nummus.health_checks.base import HealthCheck from nummus.models.budget import BudgetAssignment from nummus.models.transaction import TransactionSplit from nummus.models.transaction_category import TransactionCategory -from nummus.models.utils import query_to_dict - -if TYPE_CHECKING: - from sqlalchemy import orm class UnusedCategories(HealthCheck): @@ -21,23 +18,22 @@ class UnusedCategories(HealthCheck): _SEVERE = False @override - def test(self, s: orm.Session) -> None: + def test(self) -> None: # Only check unlocked categories - query = ( - s.query(TransactionCategory) - .with_entities(TransactionCategory.id_, TransactionCategory.emoji_name) - .where(TransactionCategory.locked.is_(False)) - ) - categories: dict[int, str] = query_to_dict(query) + query = TransactionCategory.query( + TransactionCategory.id_, + TransactionCategory.emoji_name, + ).where(TransactionCategory.locked.is_(False)) + categories: dict[int, str] = sql.to_dict(query) if len(categories) == 0: - self._commit_issues(s, {}) + self._commit_issues({}) return - query = s.query(TransactionSplit.category_id) - used_categories = {r[0] for r in query.distinct()} + query = TransactionSplit.query(TransactionSplit.category_id) + used_categories = set(sql.col0(query)) - query = s.query(BudgetAssignment.category_id) - used_categories.update(r[0] for r in query.distinct()) + query = BudgetAssignment.query(BudgetAssignment.category_id) + used_categories.update(sql.col0(query)) categories = { t_cat_id: name @@ -49,7 +45,6 @@ def test(self, s: orm.Session) -> None: ) self._commit_issues( - s, { TransactionCategory.id_to_uri(t_cat_id): ( f"{name:{category_len}} has no " diff --git a/nummus/migrations/base.py b/nummus/migrations/base.py index 1992d53f..ced1eecf 100644 --- a/nummus/migrations/base.py +++ b/nummus/migrations/base.py @@ -12,13 +12,11 @@ from sqlalchemy.schema import CreateTable from nummus import sql +from nummus.models.base import Base from nummus.models.utils import dump_table_configs, get_constraints if TYPE_CHECKING: - from sqlalchemy import orm - from nummus import portfolio - from nummus.models.base import Base class Migrator(ABC): @@ -55,20 +53,19 @@ def min_version(cls) -> Version: def add_column( self, - s: orm.Session, model: type[Base], - column: orm.QueryableAttribute, + column: sql.Column, initial_value: object | None = None, ) -> None: """Add a column to a table. Args: - s: SQL session to use model: Table to modify column: Column to add initial_value: Value to set all rows to """ + s = model.session() engine = s.get_bind().engine col_name = sql.escape(column.name) @@ -77,13 +74,12 @@ def add_column( s.execute(sqlalchemy.text(stmt)) if initial_value is not None: - s.query(model).update({column: initial_value}) + model.query().update({column: initial_value}) self.pending_schema_updates.add(model) def rename_column( self, - s: orm.Session, model: type[Base], old_name: str, new_name: str, @@ -91,7 +87,6 @@ def rename_column( """Rename a column in a table. Args: - s: SQL session to use model: Table to modify old_name: Current name of column new_name: New name of column @@ -100,7 +95,7 @@ def rename_column( old_name = sql.escape(old_name) new_name = sql.escape(new_name) stmt = f'ALTER TABLE "{model.__tablename__}" RENAME {old_name} TO {new_name}' - s.execute(sqlalchemy.text(stmt)) + model.session().execute(sqlalchemy.text(stmt)) # RENAME modifies column references but not constraint names # Need to update schema to update those @@ -108,32 +103,29 @@ def rename_column( def drop_column( self, - s: orm.Session, model: type[Base], col_name: str, ) -> None: """Rename a column in a table. Args: - s: SQL session to use model: Table to modify col_name: Name of column to drop """ - constraints = get_constraints(s, model) + constraints = get_constraints(model) if any(col_name in sql_text for _, sql_text in constraints): - self.recreate_table(s, model, drop={col_name}) + self.recreate_table(model, drop={col_name}) else: # Able to drop directly col_name = sql.escape(col_name) stmt = f'ALTER TABLE "{model.__tablename__}" DROP {col_name}' - s.execute(sqlalchemy.text(stmt)) + model.session().execute(sqlalchemy.text(stmt)) # DROP does not need updated schema def recreate_table( self, - s: orm.Session, model: type[Base], *, drop: set[str] | None = None, @@ -142,13 +134,13 @@ def recreate_table( """Rebuild table, optionally dropping columns. Args: - s: SQL session to use model: Table to modify drop: Set of column names to drop create_stmt: Statement to execute to create new table, None will modify existing config """ + s = model.session() drop = drop or set() # In SQLite we can do the hacky way or recreate the table # Opt for recreate @@ -159,7 +151,7 @@ def recreate_table( new_config = create_stmt.splitlines() else: # Edit table config, dropping any columns - old_config = dump_table_configs(s, model) + old_config = dump_table_configs(model) new_config: list[str] = [] re_column = re.compile(r" +([a-z_]+) [A-Z ]+") re_constraint = re.compile(r' +[A-Z ]+(?:"[^\"]+" [A-Z ]+)?\(([^\)]+)\)') @@ -191,7 +183,7 @@ def recreate_table( s.execute(sqlalchemy.text(stmt)) # Drop old table - self.drop_table(s, name) + self.drop_table(name) # Rename new into old stmt = f'ALTER TABLE "migration_temp" RENAME TO "{name}"' @@ -204,16 +196,15 @@ def recreate_table( self.pending_schema_updates.add(model) @staticmethod - def drop_table(s: orm.Session, table_name: str) -> None: + def drop_table(table_name: str) -> None: """Drop a table. Args: - s: SQL session to use table_name: Name of table to drop """ stmt = f'DROP TABLE "{table_name}"' - s.execute(sqlalchemy.text(stmt)) + Base.session().execute(sqlalchemy.text(stmt)) class SchemaMigrator(Migrator): @@ -235,5 +226,5 @@ def migrate(self, p: portfolio.Portfolio) -> list[str]: with p.begin_session() as s: table: sqlalchemy.Table = model.sql_table() create_stmt = CreateTable(table).compile(s.get_bind()).string.strip() - self.recreate_table(s, model, create_stmt=create_stmt) + self.recreate_table(model, create_stmt=create_stmt) return [] diff --git a/nummus/migrations/v0_10.py b/nummus/migrations/v0_10.py index 4e7bc61b..c0734c76 100644 --- a/nummus/migrations/v0_10.py +++ b/nummus/migrations/v0_10.py @@ -5,7 +5,6 @@ from typing import override, TYPE_CHECKING from nummus.migrations.base import Migrator -from nummus.models.base import YIELD_PER from nummus.models.health_checks import HealthCheckIssue from nummus.models.transaction_category import TransactionCategory @@ -23,16 +22,14 @@ def migrate(self, p: portfolio.Portfolio) -> list[str]: comments: list[str] = [] - with p.begin_session() as s: + with p.begin_session(): self.rename_column( - s, TransactionCategory, "essential", "essential_spending", ) - query = s.query(HealthCheckIssue) - for issue in query.yield_per(YIELD_PER): + for issue in HealthCheckIssue.all(): issue.check = issue.check.capitalize() return comments diff --git a/nummus/migrations/v0_13.py b/nummus/migrations/v0_13.py index 3c7c8f18..3bb0a348 100644 --- a/nummus/migrations/v0_13.py +++ b/nummus/migrations/v0_13.py @@ -40,15 +40,13 @@ def migrate(self, p: portfolio.Portfolio) -> list[str]: for t_split_id, name in s.execute(sqlalchemy.text(stmt)): tag_mapping[name].add(t_split_id) - labels = [Label(name=name) for name in tag_mapping] - s.add_all(labels) - s.flush() + labels = [Label.create(name=name) for name in tag_mapping] for label in labels: for t_split_id in tag_mapping[label.name]: - s.add(LabelLink(label_id=label.id_, t_split_id=t_split_id)) + LabelLink.create(label_id=label.id_, t_split_id=t_split_id) - with p.begin_session() as s: - self.drop_column(s, TransactionSplit, "tag") + with p.begin_session(): + self.drop_column(TransactionSplit, "tag") return comments diff --git a/nummus/migrations/v0_15.py b/nummus/migrations/v0_15.py index 616eabb4..a719e8b8 100644 --- a/nummus/migrations/v0_15.py +++ b/nummus/migrations/v0_15.py @@ -7,12 +7,15 @@ import sqlalchemy from nummus import exceptions as exc +from nummus import sql from nummus.migrations.base import Migrator from nummus.models.base import Base from nummus.models.label import Label, LabelLink from nummus.models.utils import dump_table_configs if TYPE_CHECKING: + from sqlalchemy import orm + from nummus import portfolio @@ -30,7 +33,7 @@ def migrate(self, p: portfolio.Portfolio) -> list[str]: with p.begin_session() as s: # Already have Label from updated v0.13 migrator, skip this one try: - dump_table_configs(s, Label) + dump_table_configs(Label) except exc.NoResultFound: pass else: @@ -43,19 +46,20 @@ def migrate(self, p: portfolio.Portfolio) -> list[str]: # Move existing tags to labels with p.begin_session() as s: stmt = "SELECT id_, name FROM tag" - # Hand crafted SQL statement can't use query_to_dict - tags: dict[int, str] = dict(s.execute(sqlalchemy.text(stmt)).all()) # type: ignore[attr-defined] + query: orm.query.RowReturningQuery[tuple[int, str]] = s.execute( # type: ignore[attr-defined] + sqlalchemy.text(stmt), + ) + tags: dict[int, str] = sql.to_dict(query) - labels = [Label(id_=tag_id, name=name) for tag_id, name in tags.items()] - s.add_all(labels) - s.flush() + for tag_id, name in tags.items(): + Label.create(id_=tag_id, name=name) stmt = "SELECT tag_id, t_split_id FROM tag_link" for tag_id, t_split_id in s.execute(sqlalchemy.text(stmt)): - s.add(LabelLink(label_id=tag_id, t_split_id=t_split_id)) + LabelLink.create(label_id=tag_id, t_split_id=t_split_id) - with p.begin_session() as s: - self.drop_table(s, "tag_link") - self.drop_table(s, "tag") + with p.begin_session(): + self.drop_table("tag_link") + self.drop_table("tag") return comments diff --git a/nummus/migrations/v0_16.py b/nummus/migrations/v0_16.py index 09b76d25..76cc6077 100644 --- a/nummus/migrations/v0_16.py +++ b/nummus/migrations/v0_16.py @@ -27,12 +27,13 @@ def migrate(self, p: portfolio.Portfolio) -> list[str]: f"Portfolio currency set to {DEFAULT_CURRENCY.pretty}, use web to edit", ] - with p.begin_session() as s: - s.add( - Config(key=ConfigKey.BASE_CURRENCY, value=str(DEFAULT_CURRENCY.value)), + with p.begin_session(): + Config.create( + key=ConfigKey.BASE_CURRENCY, + value=str(DEFAULT_CURRENCY.value), ) - self.add_column(s, Account, Account.currency, DEFAULT_CURRENCY) - self.add_column(s, Asset, Asset.currency, DEFAULT_CURRENCY) + self.add_column(Account, Account.currency, DEFAULT_CURRENCY) + self.add_column(Asset, Asset.currency, DEFAULT_CURRENCY) return comments diff --git a/nummus/migrations/v0_2.py b/nummus/migrations/v0_2.py index 8bb11716..284046bb 100644 --- a/nummus/migrations/v0_2.py +++ b/nummus/migrations/v0_2.py @@ -4,12 +4,12 @@ from typing import override, TYPE_CHECKING -from sqlalchemy import func, sql +from sqlalchemy import func +from nummus import sql from nummus.migrations.base import Migrator from nummus.models.account import Account from nummus.models.asset import Asset -from nummus.models.base import YIELD_PER from nummus.models.budget import BudgetGroup from nummus.models.config import Config from nummus.models.health_checks import HealthCheckIssue @@ -32,29 +32,29 @@ def migrate(self, p: portfolio.Portfolio) -> list[str]: comments: list[str] = [] - with p.begin_session() as s: + with p.begin_session(): # Update TransactionSplit to add text_fields - self.add_column(s, TransactionSplit, TransactionSplit.text_fields) - self.rename_column(s, TransactionSplit, "description", "memo") - self.rename_column(s, TransactionSplit, "linked", "cleared") - self.drop_column(s, TransactionSplit, "locked") + self.add_column(TransactionSplit, TransactionSplit.text_fields) + self.rename_column(TransactionSplit, "description", "memo") + self.rename_column(TransactionSplit, "linked", "cleared") + self.drop_column(TransactionSplit, "locked") - with p.begin_session() as s: + with p.begin_session(): # Update Transaction to add payee - self.add_column(s, Transaction, Transaction.payee) - self.rename_column(s, Transaction, "linked", "cleared") - self.drop_column(s, Transaction, "locked") + self.add_column(Transaction, Transaction.payee) + self.rename_column(Transaction, "linked", "cleared") + self.drop_column(Transaction, "locked") - with p.begin_session() as s: + with p.begin_session(): # Check which ones have more than one payee - accounts = Account.map_name(s) + accounts = Account.map_name() query = ( - s.query(TransactionSplit) + TransactionSplit.query() .group_by(TransactionSplit.parent_id) .having(func.count(TransactionSplit.payee.distinct()) > 1) .order_by(TransactionSplit.date_ord) ) - for t_split in query.yield_per(YIELD_PER): + for t_split in sql.yield_(query): msg = ( "This transaction had multiple payees, only one allowed: " f"{t_split.date} {accounts[t_split.account_id]}, please validate" @@ -62,21 +62,20 @@ def migrate(self, p: portfolio.Portfolio) -> list[str]: comments.append(msg) sub_query = ( - s.query(TransactionSplit.payee) + TransactionSplit.query(TransactionSplit.payee) .where( TransactionSplit.parent_id == Transaction.id_, ) .scalar_subquery() ) - s.query(Transaction).update( + Transaction.query().update( {Transaction.payee: sub_query}, ) - with p.begin_session() as s: + with p.begin_session(): n_batch = 100 # Update text_fields after payee is set - query = s.query(TransactionSplit) - for t_split in query.yield_per(YIELD_PER): + for t_split in TransactionSplit.all(): t_split.parent = t_split.parent t_split.memo = t_split.memo @@ -85,8 +84,7 @@ def migrate(self, p: portfolio.Portfolio) -> list[str]: offset = 0 while has_more: query = ( - s.query(TransactionCategory) - .with_entities( + TransactionCategory.query( TransactionCategory.id_, TransactionCategory.emoji_name, ) @@ -96,9 +94,9 @@ def migrate(self, p: portfolio.Portfolio) -> list[str]: ) values = { id_: TransactionCategory.clean_emoji_name(v) - for id_, v in query.yield_per(YIELD_PER) + for id_, v in sql.yield_(query) } - s.query(TransactionCategory).where( + TransactionCategory.query().where( TransactionCategory.id_.in_(values), ).update( { diff --git a/nummus/models/account.py b/nummus/models/account.py index c62c1eea..af167511 100644 --- a/nummus/models/account.py +++ b/nummus/models/account.py @@ -8,7 +8,7 @@ from sqlalchemy import func, orm, UniqueConstraint -from nummus import utils +from nummus import sql, utils from nummus.models.asset import Asset from nummus.models.base import ( Base, @@ -18,12 +18,10 @@ ORMStrOpt, SQLEnum, string_column_args, - YIELD_PER, ) from nummus.models.currency import Currency from nummus.models.transaction import Transaction, TransactionSplit from nummus.models.transaction_category import TransactionCategory -from nummus.models.utils import obj_session if TYPE_CHECKING: from collections.abc import Iterable @@ -92,6 +90,8 @@ class Account(Base): *string_column_args("institution"), ) + _SEARCH_PROPERTIES = ("number", "institution", "name") + @orm.validates("name", "number", "institution") def validate_strings(self, key: str, field: str | None) -> str | None: """Validate string fields satisfy constraints. @@ -109,25 +109,22 @@ def validate_strings(self, key: str, field: str | None) -> str | None: @property def opened_on_ord(self) -> int | None: """Date ordinal of first Transaction.""" - s = obj_session(self) - query = s.query(func.min(Transaction.date_ord)).where( + query = Transaction.query(func.min(Transaction.date_ord)).where( Transaction.account_id == self.id_, ) - return query.scalar() + return sql.scalar(query) @property def updated_on_ord(self) -> int | None: """Date ordinal of latest Transaction.""" - s = obj_session(self) - query = s.query(func.max(Transaction.date_ord)).where( + query = Transaction.query(func.max(Transaction.date_ord)).where( Transaction.account_id == self.id_, ) - return query.scalar() + return sql.scalar(query) @classmethod def get_value_all( cls, - s: orm.Session, start_ord: int, end_ord: int, ids: Iterable[int] | None = None, @@ -136,7 +133,6 @@ def get_value_all( """Get the value of all Accounts from start to end date. Args: - s: SQL session to use start_ord: First date ordinal to evaluate end_ord: Last date ordinal to evaluate (inclusive) ids: Limit results to specific Accounts by ID @@ -151,10 +147,11 @@ def get_value_all( n = end_ord - start_ord + 1 if not ids and ids is not None: - acct_values = defaultdict(lambda: [Decimal()] * n) - acct_profit = defaultdict(lambda: [Decimal()] * n) - asset_values = defaultdict(lambda: [Decimal()] * n) - return ValueResultAll(acct_values, acct_profit, asset_values) + return ValueResultAll( + defaultdict(lambda: [Decimal()] * n), + defaultdict(lambda: [Decimal()] * n), + defaultdict(lambda: [Decimal()] * n), + ) cash_flow_accounts: dict[int, list[Decimal | None]] = defaultdict( lambda: [None] * n, @@ -162,21 +159,20 @@ def get_value_all( cost_basis_accounts: dict[int, list[Decimal | None]] = defaultdict( lambda: [None] * n, ) - ids = ids or {r[0] for r in s.query(Account.id_).all()} + ids = ids or set(sql.col0(Account.query(Account.id_))) # Profit = Interest + dividends + rewards + change in asset value - fees # Dividends, fees, and change in value can be assigned to an asset # Change in value = current value - basis # Get list of transaction categories not included in cost basis - query = s.query(TransactionCategory.id_).where( + query = TransactionCategory.query(TransactionCategory.id_).where( TransactionCategory.is_profit_loss.is_(True), ) - cost_basis_skip_ids = {t_cat_id for t_cat_id, in query.all()} + cost_basis_skip_ids = set(sql.col0(query)) # Get Account cash value on start date query = ( - s.query(TransactionSplit) - .with_entities( + TransactionSplit.query( TransactionSplit.account_id, func.sum(TransactionSplit.amount), ) @@ -186,15 +182,12 @@ def get_value_all( ) .group_by(TransactionSplit.account_id) ) - for acct_id, iv in query.all(): - acct_id: int - iv: Decimal + for acct_id, iv in sql.yield_(query): cash_flow_accounts[acct_id][0] = iv # Calculate cost basis on first day query = ( - s.query(TransactionSplit) - .with_entities( + TransactionSplit.query( TransactionSplit.account_id, func.sum(TransactionSplit.amount), ) @@ -205,36 +198,24 @@ def get_value_all( ) .group_by(TransactionSplit.account_id) ) - for acct_id, iv in query.all(): - acct_id: int - iv: Decimal + for acct_id, iv in sql.yield_(query): cost_basis_accounts[acct_id][0] = -iv if start_ord != end_ord: # Get cash_flow on each day between start and end # Not Account.get_cash_flow because being categorized doesn't matter and # slows it down - query = ( - s.query(TransactionSplit) - .with_entities( - TransactionSplit.account_id, - TransactionSplit.date_ord, - TransactionSplit.amount, - TransactionSplit.category_id, - ) - .where( - TransactionSplit.date_ord <= end_ord, - TransactionSplit.date_ord > start_ord, - TransactionSplit.account_id.in_(ids), - ) + query = TransactionSplit.query( + TransactionSplit.account_id, + TransactionSplit.date_ord, + TransactionSplit.amount, + TransactionSplit.category_id, + ).where( + TransactionSplit.date_ord <= end_ord, + TransactionSplit.date_ord > start_ord, + TransactionSplit.account_id.in_(ids), ) - - for acct_id, date_ord, amount, t_cat_id in query.yield_per(YIELD_PER): - acct_id: int - date_ord: int - amount: Decimal - t_cat_id: int - + for acct_id, date_ord, amount, t_cat_id in sql.yield_(query): i = date_ord - start_ord v = cash_flow_accounts[acct_id][i] @@ -248,33 +229,29 @@ def get_value_all( # Get assets for all Accounts assets_accounts = cls.get_asset_qty_all( - s, start_ord, end_ord, list(cash_flow_accounts.keys()), ) # Get day one asset transactions to add to profit & loss - query = ( - s.query(TransactionSplit) - .with_entities( - TransactionSplit.account_id, - TransactionSplit.asset_id, - TransactionSplit.asset_quantity, - ) - .where( - TransactionSplit.asset_id.isnot(None), - TransactionSplit.date_ord == start_ord, - TransactionSplit.account_id.in_(ids), - ) + query = TransactionSplit.query( + TransactionSplit.account_id, + TransactionSplit.asset_id, + TransactionSplit.asset_quantity, + ).where( + TransactionSplit.asset_id.isnot(None), + TransactionSplit.date_ord == start_ord, + TransactionSplit.account_id.in_(ids), ) assets_day_zero: dict[int, dict[int, Decimal]] = defaultdict( lambda: defaultdict(Decimal), ) - for acct_id, a_id, qty in query.yield_per(YIELD_PER): - acct_id: int - a_id: int - qty: Decimal + for acct_id, a_id, qty in sql.yield_(query): + if TYPE_CHECKING: + # Enforced by query and SQL constraints + assert a_id is not None + assert qty is not None assets_day_zero[acct_id][a_id] += qty # Remove zeros @@ -295,17 +272,15 @@ def get_value_all( a_ids: set[int] = utils.set_sub_keys(assets_accounts) a_ids.update(utils.set_sub_keys(assets_day_zero)) - asset_prices = Asset.get_value_all(s, start_ord, end_ord, a_ids) + asset_prices = Asset.get_value_all(start_ord, end_ord, a_ids) forex_by_account: dict[int, list[Decimal]] | None = None if forex is not None: - query = ( - s.query(Account) - .with_entities(Account.id_, Account.currency) - .where(Account.id_.in_(ids)) + query = Account.query(Account.id_, Account.currency).where( + Account.id_.in_(ids), ) forex_by_account = { - acct_id: forex[currency] for acct_id, currency in query.all() + acct_id: forex[currency] for acct_id, currency in sql.yield_(query) } return cls._merge_value_data( @@ -384,23 +359,23 @@ def get_value( self, start_ord: int, end_ord: int, + forex: dict[Currency, list[Decimal]] | None = None, ) -> ValueResult: """Get the value of Account from start to end date. Args: start_ord: First date ordinal to evaluate end_ord: Last date ordinal to evaluate (inclusive) + forex: Currency exchange rates, None will not normalize Returns: ValueResult """ - s = obj_session(self) - # Not reusing get_value_all is faster by ~2ms, # not worth maintaining two almost identical implementations - r = self.get_value_all(s, start_ord, end_ord, [self.id_]) + r = self.get_value_all(start_ord, end_ord, [self.id_], forex=forex) return ValueResult( r.values_by_account[self.id_], r.profits[self.id_], @@ -410,7 +385,6 @@ def get_value( @classmethod def get_cash_flow_all( cls, - s: orm.Session, start_ord: int, end_ord: int, ids: Iterable[int] | None = None, @@ -420,7 +394,6 @@ def get_cash_flow_all( Does not separate results by account. Args: - s: SQL session to use start_ord: First date ordinal to evaluate end_ord: Last date ordinal to evaluate (inclusive) ids: Limit results to specific Accounts by ID @@ -435,26 +408,18 @@ def get_cash_flow_all( categories: dict[int, list[Decimal]] = defaultdict(lambda: [Decimal()] * n) # Transactions between start and end - query = ( - s.query(TransactionSplit) - .with_entities( - TransactionSplit.date_ord, - TransactionSplit.amount, - TransactionSplit.category_id, - ) - .where( - TransactionSplit.date_ord <= end_ord, - TransactionSplit.date_ord >= start_ord, - ) + query = TransactionSplit.query( + TransactionSplit.date_ord, + TransactionSplit.amount, + TransactionSplit.category_id, + ).where( + TransactionSplit.date_ord <= end_ord, + TransactionSplit.date_ord >= start_ord, ) if ids is not None: query = query.where(TransactionSplit.account_id.in_(ids)) - for t_date_ord, amount, category_id in query.yield_per(YIELD_PER): - t_date_ord: int - amount: Decimal - category_id: int - + for t_date_ord, amount, category_id in sql.yield_(query): categories[category_id][t_date_ord - start_ord] += amount return categories @@ -478,13 +443,11 @@ def get_cash_flow( Includes None in categories """ - s = obj_session(self) - return self.get_cash_flow_all(s, start_ord, end_ord, [self.id_]) + return self.get_cash_flow_all(start_ord, end_ord, [self.id_]) @classmethod def get_asset_qty_all( cls, - s: orm.Session, start_ord: int, end_ord: int, ids: Iterable[int] | None = None, @@ -492,7 +455,6 @@ def get_asset_qty_all( """Get the quantity of Assets held from start to end date. Args: - s: SQL session to use start_ord: First date ordinal to evaluate end_ord: Last date ordinal to evaluate (inclusive) ids: Limit results to specific Accounts by ID @@ -509,13 +471,15 @@ def get_asset_qty_all( lambda: defaultdict(lambda: [Decimal()] * n), ) - iv_accounts: dict[int, dict[int, Decimal]] = defaultdict(dict) - ids = ids or {r[0] for r in s.query(Account.id_).all()} + # Daily delta in qty + deltas_accounts: dict[int, dict[int, list[Decimal | None]]] = defaultdict( + lambda: defaultdict(lambda: [None] * n), + ) + ids = ids or set(sql.col0(Account.query(Account.id_))) # Get Asset quantities on start date query = ( - s.query(TransactionSplit) - .with_entities( + TransactionSplit.query( TransactionSplit.account_id, TransactionSplit.asset_id, func.sum(TransactionSplit.asset_quantity), @@ -530,27 +494,17 @@ def get_asset_qty_all( TransactionSplit.asset_id, ) ) - - for acct_id, a_id, qty in query.yield_per(YIELD_PER): - acct_id: int - a_id: int - qty: Decimal - iv_accounts[acct_id][a_id] = qty - - # Daily delta in qty - deltas_accounts: dict[int, dict[int, list[Decimal | None]]] = defaultdict( - lambda: defaultdict(lambda: [None] * n), - ) - for acct_id, iv in iv_accounts.items(): - deltas = deltas_accounts[acct_id] - for a_id, v in iv.items(): - deltas[a_id][0] = v + for acct_id, a_id, qty in sql.yield_(query): + if TYPE_CHECKING: + # Enforced by query and SQL constraints + assert a_id is not None + assert qty is not None + deltas_accounts[acct_id][a_id][0] = qty if start_ord != end_ord: # Transactions between start and end query = ( - s.query(TransactionSplit) - .with_entities( + TransactionSplit.query( TransactionSplit.date_ord, TransactionSplit.account_id, TransactionSplit.asset_id, @@ -565,15 +519,14 @@ def get_asset_qty_all( .order_by(TransactionSplit.account_id) ) - current_acct_id = None + current_acct_id: int | None = None deltas = {} - for date_ord, acct_id, a_id, qty in query.yield_per(YIELD_PER): - date_ord: int - acct_id: int - a_id: int - qty: Decimal - + for date_ord, acct_id, a_id, qty in sql.yield_(query): + if TYPE_CHECKING: + # Enforced by query and SQL constraints + assert a_id is not None + assert qty is not None i = date_ord - start_ord if acct_id != current_acct_id: @@ -608,13 +561,11 @@ def get_asset_qty( dict{Asset.id_: list[values]} """ - s = obj_session(self) - return self.get_asset_qty_all(s, start_ord, end_ord, [self.id_])[self.id_] + return self.get_asset_qty_all(start_ord, end_ord, [self.id_])[self.id_] @classmethod def get_profit_by_asset_all( cls, - s: orm.Session, start_ord: int, end_ord: int, ids: Iterable[int] | None = None, @@ -622,7 +573,6 @@ def get_profit_by_asset_all( """Get the profit of Assets on end_date since start_ord. Args: - s: SQL session to use start_ord: First date ordinal to evaluate end_ord: Last date ordinal to evaluate (inclusive) ids: Limit results to specific Accounts by ID @@ -633,9 +583,9 @@ def get_profit_by_asset_all( """ # Get Asset quantities on start date + initial_qty: dict[int, Decimal] = defaultdict(Decimal) query = ( - s.query(TransactionSplit) - .with_entities( + TransactionSplit.query( TransactionSplit.asset_id, func.sum(TransactionSplit.asset_quantity), ) @@ -647,40 +597,38 @@ def get_profit_by_asset_all( ) if ids is not None: query = query.where(TransactionSplit.account_id.in_(ids)) - - initial_qty: dict[int, Decimal] = defaultdict( - Decimal, - {a_id: qty for a_id, qty in query.yield_per(YIELD_PER) if qty != 0}, - ) - - query = ( - s.query(TransactionSplit) - .with_entities( - TransactionSplit.asset_id, - TransactionSplit.asset_quantity, - TransactionSplit.amount, - ) - .where( - TransactionSplit.asset_id.is_not(None), - TransactionSplit.date_ord >= start_ord, - TransactionSplit.date_ord <= end_ord, - ) + for a_id, qty in sql.yield_(query): + if TYPE_CHECKING: + # Enforced by query and SQL constraints + assert a_id is not None + assert qty is not None + initial_qty[a_id] = qty + + query = TransactionSplit.query( + TransactionSplit.asset_id, + TransactionSplit.asset_quantity, + TransactionSplit.amount, + ).where( + TransactionSplit.asset_id.is_not(None), + TransactionSplit.date_ord >= start_ord, + TransactionSplit.date_ord <= end_ord, ) if ids is not None: query = query.where(TransactionSplit.account_id.in_(ids)) cost_basis: dict[int, Decimal] = defaultdict(Decimal) end_qty: dict[int, Decimal] = initial_qty.copy() - for a_id, qty, amount in query.yield_per(YIELD_PER): - a_id: int - qty: Decimal - amount: Decimal + for a_id, qty, amount in sql.yield_(query): + if TYPE_CHECKING: + # Enforced by query and SQL constraints + assert a_id is not None + assert qty is not None end_qty[a_id] += qty cost_basis[a_id] += amount a_ids = set(end_qty) - initial_price = Asset.get_value_all(s, start_ord, start_ord, ids=a_ids) - end_price = Asset.get_value_all(s, end_ord, end_ord, ids=a_ids) + initial_price = Asset.get_value_all(start_ord, start_ord, ids=a_ids) + end_price = Asset.get_value_all(end_ord, end_ord, ids=a_ids) profits: dict[int, Decimal] = defaultdict(Decimal) for a_id in a_ids: @@ -707,23 +655,21 @@ def get_profit_by_asset( dict{Asset.id_: profit} """ - s = obj_session(self) - return self.get_profit_by_asset_all(s, start_ord, end_ord, [self.id_]) + return self.get_profit_by_asset_all(start_ord, end_ord, [self.id_]) @classmethod - def ids(cls, s: orm.Session, category: AccountCategory) -> set[int]: + def ids(cls, category: AccountCategory) -> set[int]: """Get Account ids for a specific category. Args: - s: SQL session to use category: AccountCategory to filter Returns: set{Account.id_} """ - query = s.query(Account.id_).where(Account.category == category) - return {acct_id for acct_id, in query.all()} + query = Account.query(Account.id_).where(Account.category == category) + return {acct_id for acct_id, in sql.yield_(query)} def do_include(self, date_ord: int) -> bool: """Test if account should be included for data. diff --git a/nummus/models/asset.py b/nummus/models/asset.py index 170d47ae..8ba6ea6e 100644 --- a/nummus/models/asset.py +++ b/nummus/models/asset.py @@ -8,13 +8,12 @@ from decimal import Decimal from typing import override, TYPE_CHECKING -import pandas as pd import yfinance import yfinance.exceptions from sqlalchemy import CheckConstraint, ForeignKey, func, Index, orm, UniqueConstraint from nummus import exceptions as exc -from nummus import utils +from nummus import sql, utils from nummus.models.base import ( Base, BaseEnum, @@ -26,11 +25,10 @@ ORMStrOpt, SQLEnum, string_column_args, - YIELD_PER, ) from nummus.models.currency import Currency, DEFAULT_CURRENCY from nummus.models.transaction import TransactionSplit -from nummus.models.utils import obj_session, query_count, query_to_dict, update_rows +from nummus.models.utils import update_rows if TYPE_CHECKING: from collections.abc import Iterable, Mapping @@ -230,6 +228,8 @@ class Asset(Base): *string_column_args("ticker", short_check=False), ) + _SEARCH_PROPERTIES = ("ticker", "name") + @orm.validates("name", "description", "ticker") def validate_strings(self, key: str, field: str | None) -> str | None: """Validate string fields satisfy constraints. @@ -247,7 +247,6 @@ def validate_strings(self, key: str, field: str | None) -> str | None: @classmethod def get_value_all( cls, - s: orm.Session, start_ord: int, end_ord: int, ids: Iterable[int] | None = None, @@ -255,7 +254,6 @@ def get_value_all( """Get the value of all Assets from start to end date. Args: - s: SQL session to use start_ord: First date ordinal to evaluate end_ord: Last date ordinal to evaluate (inclusive) ids: Limit results to specific Assets by ID @@ -269,15 +267,14 @@ def get_value_all( # Get a list of valuations (date offset, value) for each Asset valuations_assets: dict[int, list[tuple[int, Decimal]]] = defaultdict(list) - query = s.query(Asset.id_).where(Asset.interpolate) + query = Asset.query(Asset.id_).where(Asset.interpolate) if ids is not None: query = query.where(Asset.id_.in_(ids)) - interpolated_assets: set[int] = {r[0] for r in query.all()} + interpolated_assets: set[int] = {r[0] for r in sql.yield_(query)} # Get latest Valuation before or including start date query = ( - s.query(AssetValuation) - .with_entities( + AssetValuation.query( AssetValuation.asset_id, func.max(AssetValuation.date_ord), AssetValuation.value, @@ -287,41 +284,30 @@ def get_value_all( ) if ids is not None: query = query.where(AssetValuation.asset_id.in_(ids)) - for a_id, date_ord, v in query.all(): - a_id: int - date_ord: int - v: Decimal + for a_id, date_ord, v in sql.yield_(query): i = date_ord - start_ord valuations_assets[a_id] = [(i, v)] if start_ord != end_ord: # Transactions between start and end - query = ( - s.query(AssetValuation) - .with_entities( - AssetValuation.asset_id, - AssetValuation.date_ord, - AssetValuation.value, - ) - .where( - AssetValuation.date_ord <= end_ord, - AssetValuation.date_ord > start_ord, - ) + query = AssetValuation.query( + AssetValuation.asset_id, + AssetValuation.date_ord, + AssetValuation.value, + ).where( + AssetValuation.date_ord <= end_ord, + AssetValuation.date_ord > start_ord, ) if ids is not None: query = query.where(AssetValuation.asset_id.in_(ids)) - for a_id, date_ord, v in query.yield_per(YIELD_PER): - a_id: int - date_ord: int - v: Decimal + for a_id, date_ord, v in sql.yield_(query): i = date_ord - start_ord valuations_assets[a_id].append((i, v)) # Get interpolation point for assets with interpolation query = ( - s.query(AssetValuation) - .with_entities( + AssetValuation.query( AssetValuation.asset_id, func.min(AssetValuation.date_ord), AssetValuation.value, @@ -332,10 +318,7 @@ def get_value_all( ) .group_by(AssetValuation.asset_id) ) - for a_id, date_ord, v in query.all(): - a_id: int - date_ord: int - v: Decimal + for a_id, date_ord, v in sql.yield_(query): i = date_ord - start_ord valuations_assets[a_id].append((i, v)) @@ -360,12 +343,9 @@ def get_value(self, start_ord: int, end_ord: int) -> list[Decimal]: list[values] """ - s = obj_session(self) - # Not reusing get_value_all is faster by ~2ms, # not worth maintaining two almost identical implementations - - return self.get_value_all(s, start_ord, end_ord, [self.id_])[self.id_] + return self.get_value_all(start_ord, end_ord, [self.id_])[self.id_] def update_splits(self) -> None: """Recalculate adjusted TransactionSplit.asset_quantity based on all splits. @@ -375,21 +355,16 @@ def update_splits(self) -> None: # This function is best here but need to avoid circular imports from nummus.models.transaction import TransactionSplit # noqa: PLC0415 - s = obj_session(self) - multiplier = Decimal(1) splits: list[tuple[int, Decimal]] = [] query = ( - s.query(AssetSplit) - .with_entities(AssetSplit.date_ord, AssetSplit.multiplier) + AssetSplit.query(AssetSplit.date_ord, AssetSplit.multiplier) .where(AssetSplit.asset_id == self.id_) .order_by(AssetSplit.date_ord.desc()) ) - for s_date_ord, s_multiplier in query.yield_per(YIELD_PER): - s_date_ord: int - s_multiplier: Decimal + for s_date_ord, s_multiplier in sql.yield_(query): # Compound splits as we go multiplier *= s_multiplier splits.append((s_date_ord, multiplier)) @@ -400,7 +375,7 @@ def update_splits(self) -> None: multiplier = Decimal(1) query = ( - s.query(TransactionSplit) + TransactionSplit.query() .where( TransactionSplit.asset_id == self.id_, TransactionSplit._asset_qty_unadjusted.isnot(None), # noqa: SLF001 @@ -410,9 +385,7 @@ def update_splits(self) -> None: sum_unadjusted = Decimal() sum_adjusted = Decimal() - for t_split in query.yield_per(YIELD_PER): - # Query whole object okay, need to set things - t_split: TransactionSplit + for t_split in sql.yield_(query): # If txn is on/after the split, update the multiplier while len(splits) >= 1 and t_split.date_ord >= splits[0][0]: splits.pop(0) @@ -441,7 +414,6 @@ def prune_valuations(self) -> int: if self.category == AssetCategory.INDEX: # If asset is an INDEX, do not prune return 0 - s = obj_session(self) # Date when quantity is zero date_ord_zero: int | None = None @@ -451,26 +423,25 @@ def prune_valuations(self) -> int: periods_zero: list[tuple[int | None, int | None]] = [] query = ( - s.query(TransactionSplit) - .with_entities( + TransactionSplit.query( TransactionSplit.date_ord, TransactionSplit.asset_quantity, ) .where(TransactionSplit.asset_id == self.id_) .order_by(TransactionSplit.date_ord) ) - if query_count(query) == 0: + if not sql.any_(query): # No transactions, prune all return ( - s.query(AssetValuation) + AssetValuation.query() .where(AssetValuation.asset_id == self.id_) .delete() ) - for date_ord, qty in query.yield_per(YIELD_PER): - date_ord: int - qty: Decimal - + for date_ord, qty in sql.yield_(query): + if TYPE_CHECKING: + # Ensured by query and constraints + assert qty is not None if current_qty == 0: # Bought some, record the period when zero date_ord_non_zero = date_ord @@ -500,32 +471,31 @@ def _delete_valuations( Number of valuations deleted """ - s = obj_session(self) n_deleted = 0 for date_ord_sell, date_ord_buy in periods_zero: trim_start: int | None = None trim_end: int | None = None if date_ord_sell is not None: # Get date of oldest valuation after or on the sell - query = s.query(func.min(AssetValuation.date_ord)).where( + query = AssetValuation.query(func.min(AssetValuation.date_ord)).where( AssetValuation.asset_id == self.id_, AssetValuation.date_ord >= date_ord_sell, ) - trim_start = query.scalar() + trim_start = sql.scalar(query) if date_ord_buy is not None: # Get date of most recent valuation or on before the buy - query = s.query(func.max(AssetValuation.date_ord)).where( + query = AssetValuation.query(func.max(AssetValuation.date_ord)).where( AssetValuation.asset_id == self.id_, AssetValuation.date_ord <= date_ord_buy, ) - trim_end = query.scalar() + trim_end = sql.scalar(query) if trim_start is None and trim_end is None: # Can happen if no valuations exist before/after a transaction continue - query = s.query(AssetValuation).where(AssetValuation.asset_id == self.id_) + query = AssetValuation.query().where(AssetValuation.asset_id == self.id_) if trim_start: query = query.where(AssetValuation.date_ord > trim_start) if trim_end: @@ -552,18 +522,15 @@ def update_valuations( Raises: NoAssetWebSourceError: If Asset has no ticker AssetWebError: If failed to download data - TypeError: If returned DataFrame indices aren't Timestamps """ if self.ticker is None: raise exc.NoAssetWebSourceError - s = obj_session(self) - today = datetime.datetime.now(datetime.UTC).date() today_ord = today.toordinal() - query = s.query(TransactionSplit).with_entities( + query = TransactionSplit.query( func.min(TransactionSplit.date_ord), func.max(TransactionSplit.date_ord), ) @@ -575,10 +542,8 @@ def update_valuations( through_today = True else: query = query.where(TransactionSplit.asset_id == self.id_) - start_ord, end_ord = query.one() - start_ord: int | None - end_ord: int | None - if start_ord is None or end_ord is None: + start_ord, end_ord = sql.one(query) + if not start_ord or not end_ord: return None, None start_ord -= utils.DAYS_IN_WEEK @@ -601,15 +566,12 @@ def update_valuations( # yfinance raises Exception if no data found raise exc.AssetWebError(e) from e - valuations: dict[int, float] = {} - for k, v in raw["Close"].items(): - if not isinstance(k, pd.Timestamp): - raise TypeError - valuations[k.to_pydatetime().date().toordinal()] = float(v) + valuations: dict[int, float] = utils.pd_series_to_dict( + raw["Close"], # type: ignore[attr-defined] + ) - query = s.query(AssetValuation).where(AssetValuation.asset_id == self.id_) + query = AssetValuation.query().where(AssetValuation.asset_id == self.id_) update_rows( - s, AssetValuation, query, "date_ord", @@ -620,13 +582,12 @@ def update_valuations( }, ) - raw_splits = raw.loc[raw["Stock Splits"] != 0]["Stock Splits"] - splits: dict[int, float] = { - k.to_pydatetime().date().toordinal(): v for k, v in raw_splits.items() - } - query = s.query(AssetSplit).where(AssetSplit.asset_id == self.id_) + splits: dict[int, float] = utils.pd_series_to_dict( + raw.loc[raw["Stock Splits"] != 0]["Stock Splits"], # type: ignore[attr-defined] + ) + + query = AssetSplit.query().where(AssetSplit.asset_id == self.id_) update_rows( - s, AssetSplit, query, "date_ord", @@ -651,8 +612,6 @@ def update_sectors(self) -> None: if self.ticker is None: raise exc.NoAssetWebSourceError - s = obj_session(self) - yf_ticker = yfinance.Ticker(self.ticker) funds = yf_ticker.funds_data try: @@ -665,13 +624,12 @@ def update_sectors(self) -> None: # Not a fund sector = yf_ticker.info.get("sector") if sector is None: - s.query(AssetSector).where(AssetSector.asset_id == self.id_).delete() + AssetSector.query().where(AssetSector.asset_id == self.id_).delete() return weights = {USSector(sector): Decimal(1)} - query = s.query(AssetSector).where(AssetSector.asset_id == self.id_) + query = AssetSector.query().where(AssetSector.asset_id == self.id_) update_rows( - s, AssetSector, query, "sector", @@ -684,7 +642,6 @@ def update_sectors(self) -> None: @classmethod def index_twrr( cls, - s: orm.Session, name: str, start_ord: int, end_ord: int, @@ -692,7 +649,6 @@ def index_twrr( """Get the TWRR for an index from start to end date. Args: - s: SQL session to use name: Name of index start_ord: First date ordinal to evaluate end_ord: Last date ordinal to evaluate (inclusive) @@ -705,22 +661,17 @@ def index_twrr( """ try: - a_id = s.query(Asset.id_).where(Asset.name == name).one()[0] + a_id = sql.one(Asset.query(Asset.id_).where(Asset.name == name)) except exc.NoResultFound as e: msg = f"Could not find asset index {name}" raise exc.ProtectedObjectNotFoundError(msg) from e - values = cls.get_value_all(s, start_ord, end_ord, ids=[a_id])[a_id] + values = cls.get_value_all(start_ord, end_ord, ids=[a_id])[a_id] cost_basis = values[0] return utils.twrr(values, [v - cost_basis for v in values]) @classmethod - def add_indices(cls, s: orm.Session) -> None: - """Add Asset indices used for performance comparison. - - Args: - s: SQL session to use - - """ + def add_indices(cls) -> None: + """Add Asset indices used for performance comparison.""" indices: dict[str, dict[str, str]] = { "^GSPC": { "name": "S&P 500", @@ -766,7 +717,7 @@ def add_indices(cls, s: orm.Session) -> None: }, } for ticker, item in indices.items(): - a = Asset( + cls.create( name=item["name"], description=item["description"], category=AssetCategory.INDEX, @@ -774,21 +725,18 @@ def add_indices(cls, s: orm.Session) -> None: ticker=ticker, currency=DEFAULT_CURRENCY, ) - s.add(a) def autodetect_interpolate(self) -> None: """Autodetect if Asset needs interpolation. Does not commit changes, call s.commit() afterwards. """ - s = obj_session(self) - query = ( - s.query(AssetValuation.date_ord) + AssetValuation.query(AssetValuation.date_ord) .where(AssetValuation.asset_id == self.id_) .order_by(AssetValuation.date_ord) ) - date_ords = [r[0] for r in query.yield_per(YIELD_PER)] + date_ords = [r[0] for r in sql.yield_(query)] has_dailys = any( (date_ords[i] - date_ords[i - 1]) == 1 for i in range(1, len(date_ords)) ) @@ -798,14 +746,12 @@ def autodetect_interpolate(self) -> None: @classmethod def create_forex( cls, - s: orm.Session, base: Currency, others: set[Currency], ) -> None: """Create foreign exchange rate assets. Args: - s: SQL session to use base: Base currency to get FOREX referenced to others: Other currencys to get @@ -813,8 +759,16 @@ def create_forex( if base in others: others.discard(base) - query = s.query(Asset.ticker).where(Asset.category == AssetCategory.FOREX) - existing: set[str] = {r[0] for r in query.all()} + query = Asset.query(Asset.ticker).where( + Asset.category == AssetCategory.FOREX, + Asset.ticker.is_not(None), + ) + existing: set[str] = set() + for (ticker,) in sql.yield_(query): + if TYPE_CHECKING: + # Ensured by query + assert ticker is not None + existing.add(ticker) for other in others: ticker = f"{other.name}{base.name}=X" @@ -822,31 +776,28 @@ def create_forex( existing.discard(ticker) continue - asset = Asset( + cls.create( name=f"{other.name} to {base.name}", description=f"Exchange rate from {other.pretty} to {base.pretty}", category=AssetCategory.FOREX, ticker=ticker, currency=base, ) - s.add(asset) existing.discard(ticker) # existing has unused FOREX assets # TODO (WattsUp): #463 Handle when accounts hold FOREX - to_delete = { - r[0] for r in s.query(Asset.id_).where(Asset.ticker.in_(existing)).all() - } - s.query(AssetValuation).where(AssetValuation.asset_id.in_(to_delete)).delete() - s.query(AssetSplit).where(AssetSplit.asset_id.in_(to_delete)).delete() - s.query(AssetSector).where(AssetSector.asset_id.in_(to_delete)).delete() - s.query(Asset).where(Asset.id_.in_(to_delete)).delete() + query = Asset.query(Asset.id_).where(Asset.ticker.in_(existing)) + to_delete = {r[0] for r in sql.yield_(query)} + AssetValuation.query().where(AssetValuation.asset_id.in_(to_delete)).delete() + AssetSplit.query().where(AssetSplit.asset_id.in_(to_delete)).delete() + AssetSector.query().where(AssetSector.asset_id.in_(to_delete)).delete() + Asset.query().where(Asset.id_.in_(to_delete)).delete() @classmethod def get_forex( cls, - s: orm.Session, start_ord: int, end_ord: int, base: Currency, @@ -855,7 +806,6 @@ def get_forex( """Get foreign exchange rate over time. Args: - s: SQL session to use start_ord: First date ordinal to evaluate end_ord: Last date ordinal to evaluate (inclusive) base: Base currency to exchange to @@ -875,14 +825,14 @@ def get_forex( f"{other.name}{base.name}=X": other for other in currencies } # null ticker filtered out by query - query = s.query(Asset.id_, Asset.ticker).where( + query = Asset.query(Asset.id_, Asset.ticker).where( Asset.category == AssetCategory.FOREX, Asset.currency == base, ) query = query.where(Asset.ticker.in_(currencies_by_ticker)) - assets = query_to_dict(query) + assets = sql.to_dict(query) - values = cls.get_value_all(s, start_ord, end_ord, assets.keys()) + values = cls.get_value_all(start_ord, end_ord, assets.keys()) forex: dict[Currency, list[Decimal]] = defaultdict( lambda: [Decimal(1)] * (end_ord - start_ord + 1), diff --git a/nummus/models/base.py b/nummus/models/base.py index 3805b72b..e3f0f485 100644 --- a/nummus/models/base.py +++ b/nummus/models/base.py @@ -2,23 +2,23 @@ from __future__ import annotations +import contextlib import enum from decimal import Decimal -from typing import ClassVar, override, TYPE_CHECKING +from typing import ClassVar, NamedTuple, overload, override, Self, TYPE_CHECKING import sqlalchemy -from sqlalchemy import CheckConstraint, orm, sql, types +from sqlalchemy import CheckConstraint, orm, types from nummus import exceptions as exc -from nummus import utils +from nummus import sql, utils from nummus.models import base_uri if TYPE_CHECKING: - from collections.abc import Iterable, Mapping + from collections.abc import Generator, Iterable, Mapping + import sqlalchemy.sql.roles as sql_roles -# Yield per instead of fetch all is faster -YIELD_PER = 100 ORMBool = orm.Mapped[bool] ORMBoolOpt = orm.Mapped[bool | None] @@ -30,7 +30,202 @@ ORMRealOpt = orm.Mapped[Decimal | None] -class Base(orm.DeclarativeBase): +class NamePair(NamedTuple): + """ID & name pair.""" + + id_: int + name: str | None + + +class SessionMixIn: + """Mix-in that provides a session reference to the type.""" + + _sessions: ClassVar[list[orm.Session]] = [] + + @classmethod + @contextlib.contextmanager + def set_session(cls, s: orm.Session) -> Generator[None]: + """Set session used by active record. + + Yields: + SQL session + + """ + cls._sessions.append(s) + try: + yield + finally: + cls._sessions.pop() + + @classmethod + def session(cls) -> orm.Session: + """Get scoped session. + + Returns: + SQL session + + Raises: + UnboundExecutionError: set_session has not been called yet + + """ + if not cls._sessions: + raise exc.UnboundExecutionError + return cls._sessions[-1] + + +class QueryMixIn(SessionMixIn): + """Mix-in that provides a query interface to the type.""" + + @classmethod + def create(cls, **kwargs: object) -> Self: + """Create a new instance. + + Args: + kwargs: Passed to init + + Returns: + New instance + + """ + i = cls(**kwargs) + s = cls.session() + s.add(i) # nummus: ignore + s.flush() + return i + + def delete(self) -> None: + """Delete an instance.""" + self.session().delete(self) + + def refresh(self) -> None: + """Refresh an instance.""" + self.session().refresh(self) + + @overload + @classmethod + def query(cls) -> orm.Query[Self]: ... + + @overload + @classmethod + def query[T0]( + cls, + c0: sql_roles.TypedColumnsClauseRole[T0], + ) -> orm.query.RowReturningQuery[tuple[T0]]: ... + + @overload + @classmethod + def query[T0, T1]( + cls, + c0: sql_roles.TypedColumnsClauseRole[T0], + c1: sql_roles.TypedColumnsClauseRole[T1], + ) -> orm.query.RowReturningQuery[tuple[T0, T1]]: ... + + @overload + @classmethod + def query[T0, T1, T2]( + cls, + c0: sql_roles.TypedColumnsClauseRole[T0], + c1: sql_roles.TypedColumnsClauseRole[T1], + c2: sql_roles.TypedColumnsClauseRole[T2], + ) -> orm.query.RowReturningQuery[tuple[T0, T1, T2]]: ... + + @overload + @classmethod + def query[T0, T1, T2, T3]( + cls, + c0: sql_roles.TypedColumnsClauseRole[T0], + c1: sql_roles.TypedColumnsClauseRole[T1], + c2: sql_roles.TypedColumnsClauseRole[T2], + c3: sql_roles.TypedColumnsClauseRole[T3], + ) -> orm.query.RowReturningQuery[tuple[T0, T1, T2, T3]]: ... + + @overload + @classmethod + def query[T0, T1, T2, T3, T4]( + cls, + c0: sql_roles.TypedColumnsClauseRole[T0], + c1: sql_roles.TypedColumnsClauseRole[T1], + c2: sql_roles.TypedColumnsClauseRole[T2], + c3: sql_roles.TypedColumnsClauseRole[T3], + c4: sql_roles.TypedColumnsClauseRole[T4], + ) -> orm.query.RowReturningQuery[tuple[T0, T1, T2, T3, T4]]: ... + + @overload + @classmethod + def query[T0, T1, T2, T3, T4, T5]( + cls, + c0: sql_roles.TypedColumnsClauseRole[T0], + c1: sql_roles.TypedColumnsClauseRole[T1], + c2: sql_roles.TypedColumnsClauseRole[T2], + c3: sql_roles.TypedColumnsClauseRole[T3], + c4: sql_roles.TypedColumnsClauseRole[T4], + c5: sql_roles.TypedColumnsClauseRole[T5], + ) -> orm.query.RowReturningQuery[tuple[T0, T1, T2, T3, T4, T5]]: ... + + @classmethod + def query[T]( + cls, + *columns: sql_roles.TypedColumnsClauseRole[object], + **kwargs: T, + ) -> orm.Query[Self] | orm.Query[T]: + """Create a new query. + + Returns: + Query of table + + Raises: + NoKeywordArgumentsError: if kwargs are provided + + """ + if kwargs: + raise exc.NoKeywordArgumentsError + query: orm.Query[Self] = cls.session().query(cls) + if columns: + return query.with_entities(*columns) # nummus: ignore + return query + + @classmethod + def all(cls) -> list[Self]: + """Fetch all rows. + + Returns: + List of each row object + + """ + return list(sql.yield_(cls.query())) + + @classmethod + def one(cls) -> Self: + """Fetch one rows. + + Returns: + Only row + + """ + return sql.one(cls.query()) + + @classmethod + def first(cls) -> Self | None: + """Fetch first rows. + + Returns: + First row + + """ + return cls.query().first() + + @classmethod + def count(cls) -> int: + """Count number of rows. + + Returns: + Number of rows + + """ + return sql.count(cls.query()) + + +class Base(orm.DeclarativeBase, QueryMixIn): """Base ORM model. Attributes: @@ -43,6 +238,10 @@ class Base(orm.DeclarativeBase): __table_id__: int | None + id_: ORMInt = orm.mapped_column(primary_key=True, autoincrement=True) + + _SEARCH_PROPERTIES: tuple[str, ...] = () + @override def __init_subclass__(cls, *, skip_register: bool = False, **kw: object) -> None: super().__init_subclass__(**kw) @@ -61,15 +260,13 @@ def __init_subclass__(cls, *, skip_register: bool = False, **kw: object) -> None i += 1 @classmethod - def metadata_create_all(cls, s: orm.Session) -> None: + def metadata_create_all(cls) -> None: """Create all tables for nummus models. Creates tables then commits - Args: - s: Session to create tables for - """ + s = cls.session() cls.metadata.create_all(s.get_bind(), [m.sql_table() for m in cls._MODELS]) s.commit() @@ -88,8 +285,6 @@ def sql_table(cls) -> sqlalchemy.Table: return cls.__table__ raise TypeError - id_: ORMInt = orm.mapped_column(primary_key=True, autoincrement=True) - @classmethod def id_to_uri(cls, id_: int) -> str: """Uniform Resource Identifier derived from id_ and __table_id__. @@ -132,15 +327,7 @@ def uri_to_id(cls, uri: str) -> int: @property def uri(self) -> str: - """Uniform Resource Identifier derived from id_ and __table_id__. - - Raises: - NoIDError: If object does not have id_ - - """ - if self.id_ is None: - msg = f"{self.__class__.__name__} does not have an id_, maybe flush" - raise exc.NoIDError(msg) + """Uniform Resource Identifier derived from id_ and __table_id__.""" return self.id_to_uri(self.id_) @override @@ -163,12 +350,9 @@ def __ne__(self, other: Base | object) -> bool: return not isinstance(other, Base) or self.uri != other.uri @classmethod - def map_name(cls, s: orm.Session) -> dict[int, str]: + def map_name(cls) -> dict[int, str]: """Get mapping between id and names. - Args: - s: SQL session to use - Returns: Dictionary {id: name} @@ -176,13 +360,12 @@ def map_name(cls, s: orm.Session) -> dict[int, str]: KeyError: if model does not have name property """ - attr = getattr(cls, "name", None) + attr: orm.QueryableAttribute[str] | None = getattr(cls, "name", None) if not attr: msg = f"{cls.__name__} does not have name column" raise KeyError(msg) - query = s.query(cls).with_entities(cls.id_, attr) - return dict(query.all()) + return sql.to_dict(cls.query(cls.id_, attr)) @classmethod def clean_strings( @@ -246,6 +429,66 @@ def clean_emoji_name(cls, s: str) -> str: """ return utils.strip_emojis(s).strip().lower() + @classmethod + def find( + cls, + search: str, + cache: dict[str, NamePair], + ) -> NamePair: + """Find a matching object by uri, or field value. + + Args: + search: Search query + cache: Cache results to speed up look ups + + Returns: + tuple(id_, name) + + Raises: + NoResultFound: if object not found + + """ + pair = cache.get(search) + if pair is not None: + return pair + + def cache_and_return(m: Self) -> NamePair: + id_ = m.id_ + name: str | None = getattr(m, "name", None) + pair = NamePair(id_, name) + cache[search] = pair + return pair + + try: + # See if query is an URI + id_ = cls.uri_to_id(search) + except (exc.InvalidURIError, exc.WrongURITypeError): + pass + else: + query = cls.query().where(cls.id_ == id_) + if m := sql.scalar(query): + return cache_and_return(m) + + for name in cls._SEARCH_PROPERTIES: + prop: sql.Column = getattr(cls, name) + # Exact? + query = cls.query().where(prop == search) + if m := sql.scalar(query): + return cache_and_return(m) + + # Exact lower case? + query = cls.query().where(prop.ilike(search)) + if m := sql.scalar(query): + return cache_and_return(m) + + # For account number, see if there is one ending in the search + query = cls.query().where(prop.ilike(f"%{search}")) + if name == "number" and (m := sql.scalar(query)): + return cache_and_return(m) + + msg = f"{cls.__name__} matching '{search}' could not be found" + raise exc.NoResultFound(msg) + class BaseEnum(enum.IntEnum): """Enum class with a parser.""" @@ -295,7 +538,7 @@ def pretty(self) -> str: return self.name.replace("_", " ").title() -class SQLEnum(types.TypeDecorator): +class SQLEnum(types.TypeDecorator[BaseEnum]): """SQL type for enumeration, stores as integer.""" impl = types.Integer @@ -361,7 +604,7 @@ def process_result_value( return self._enum_type(value) -class Decimal6(types.TypeDecorator): +class Decimal6(types.TypeDecorator[Decimal]): """SQL type for fixed point numbers, stores as micro-integer.""" impl = types.BigInteger @@ -453,7 +696,7 @@ def string_column_args( Tuple of constraints """ - name_col = f"`{name}`" if name in sql.compiler.RESERVED_WORDS else name + name_col = sql.escape(name) checks = [ ( CheckConstraint( diff --git a/nummus/models/base_uri.py b/nummus/models/base_uri.py index 48abe232..07453d4f 100644 --- a/nummus/models/base_uri.py +++ b/nummus/models/base_uri.py @@ -22,7 +22,7 @@ _ROUNDS = 3 -_CIPHER: Cipher +_cipher: Cipher class Cipher: @@ -203,7 +203,7 @@ def from_bytes(buf: bytes) -> Cipher: msg = f"Buf is {len(buf)}B long, expected {n}B" raise ValueError(msg) - keys = [] + keys: list[int] = [] for _ in range(_ROUNDS): buf_next = buf[:ID_BYTES] buf = buf[ID_BYTES:] @@ -216,14 +216,14 @@ def from_bytes(buf: bytes) -> Cipher: def load_cipher(buf: bytes) -> None: - """Load a Cipher from bytes into _CIPHER. + """Load a Cipher from bytes into _cipher. Args: buf: Bytes to load """ - global _CIPHER - _CIPHER = Cipher.from_bytes(buf) + global _cipher + _cipher = Cipher.from_bytes(buf) def id_to_uri(id_: int) -> str: @@ -236,7 +236,7 @@ def id_to_uri(id_: int) -> str: URI, hex encoded, 1:1 mapping """ - return _CIPHER.encode(id_).to_bytes(ID_BYTES, _ORDER).hex() + return _cipher.encode(id_).to_bytes(ID_BYTES, _ORDER).hex() def uri_to_id(uri: str) -> int: @@ -261,4 +261,4 @@ def uri_to_id(uri: str) -> int: msg = f"URI is not a hex number: {uri}" raise exc.InvalidURIError(msg) from e else: - return _CIPHER.decode(uri_int) + return _cipher.decode(uri_int) diff --git a/nummus/models/budget.py b/nummus/models/budget.py index 2627e474..28583fa8 100644 --- a/nummus/models/budget.py +++ b/nummus/models/budget.py @@ -9,7 +9,7 @@ from sqlalchemy import CheckConstraint, ForeignKey, func, Index, orm, UniqueConstraint -from nummus import utils +from nummus import sql, utils from nummus.models.account import Account from nummus.models.base import ( Base, @@ -21,14 +21,12 @@ ORMStr, SQLEnum, string_column_args, - YIELD_PER, ) from nummus.models.transaction import TransactionSplit from nummus.models.transaction_category import ( TransactionCategory, TransactionCategoryGroup, ) -from nummus.models.utils import query_to_dict class BudgetAvailableCategory(NamedTuple): @@ -129,13 +127,11 @@ def validate_decimals(self, key: str, field: Decimal | None) -> Decimal | None: @classmethod def get_monthly_available( cls, - s: orm.Session, month: datetime.date, ) -> BudgetAvailable: """Get available budget for a month. Args: - s: SQL session to use month: Month to compute budget during Returns: @@ -147,44 +143,37 @@ def get_monthly_available( """ month_ord = month.toordinal() - query = s.query(Account).where(Account.budgeted) + query = Account.query().where(Account.budgeted) accounts = { - acct.id_: acct.name for acct in query.all() if acct.do_include(month_ord) + acct.id_: acct.name + for acct in sql.yield_(query) + if acct.do_include(month_ord) } # Starting balance - query = ( - s.query(TransactionSplit) - .with_entities( - func.sum(TransactionSplit.amount), - ) - .where( - TransactionSplit.account_id.in_(accounts), - TransactionSplit.date_ord < month_ord, - ) + query = TransactionSplit.query( + func.sum(TransactionSplit.amount), + ).where( + TransactionSplit.account_id.in_(accounts), + TransactionSplit.date_ord < month_ord, ) - starting_balance = query.scalar() or Decimal() + starting_balance = sql.scalar(query) or Decimal() ending_balance = starting_balance total_available = Decimal() # Check all categories not INCOME - budget_categories = { - t_cat_id - for t_cat_id, in ( - s.query(TransactionCategory.id_) - .where(TransactionCategory.group != TransactionCategoryGroup.INCOME) - .all() - ) - } + query = TransactionCategory.query(TransactionCategory.id_).where( + TransactionCategory.group != TransactionCategoryGroup.INCOME, + ) + budget_categories = {t_cat_id for t_cat_id, in sql.yield_(query)} # Current month's assignment - query = ( - s.query(BudgetAssignment) - .with_entities(BudgetAssignment.category_id, BudgetAssignment.amount) - .where(BudgetAssignment.month_ord == month_ord) - ) - categories_assigned: dict[int, Decimal] = query_to_dict(query) + query = BudgetAssignment.query( + BudgetAssignment.category_id, + BudgetAssignment.amount, + ).where(BudgetAssignment.month_ord == month_ord) + categories_assigned: dict[int, Decimal] = sql.to_dict(query) # Prior months' assignment min_month_ord = month_ord @@ -192,8 +181,7 @@ def get_monthly_available( t_cat_id: {} for t_cat_id in budget_categories } query = ( - s.query(BudgetAssignment) - .with_entities( + BudgetAssignment.query( BudgetAssignment.category_id, BudgetAssignment.amount, BudgetAssignment.month_ord, @@ -201,7 +189,7 @@ def get_monthly_available( .where(BudgetAssignment.month_ord < month_ord) .order_by(BudgetAssignment.month_ord) ) - for cat_id, amount, m_ord in query.yield_per(YIELD_PER): + for cat_id, amount, m_ord in sql.yield_(query): prior_assigned[cat_id][m_ord] = amount min_month_ord = min(min_month_ord, m_ord) @@ -210,8 +198,7 @@ def get_monthly_available( t_cat_id: {} for t_cat_id in budget_categories } query = ( - s.query(TransactionSplit) - .with_entities( + TransactionSplit.query( TransactionSplit.category_id, func.sum(TransactionSplit.amount), TransactionSplit.month_ord, @@ -227,7 +214,7 @@ def get_monthly_available( TransactionSplit.month_ord, ) ) - for cat_id, amount, m_ord in query.yield_per(YIELD_PER): + for cat_id, amount, m_ord in sql.yield_(query): prior_activity[cat_id][m_ord] = amount # Carry over leftover to next months to get current month's leftover amounts @@ -246,17 +233,14 @@ def get_monthly_available( date = utils.date_add_months(date, 1) # Future months' assignment - query = ( - s.query(BudgetAssignment) - .with_entities(func.sum(BudgetAssignment.amount)) - .where(BudgetAssignment.month_ord > month_ord) + query = BudgetAssignment.query(func.sum(BudgetAssignment.amount)).where( + BudgetAssignment.month_ord > month_ord, ) - future_assigned = query.scalar() or Decimal() + future_assigned = sql.scalar(query) or Decimal() # Current month's activity query = ( - s.query(TransactionSplit) - .with_entities( + TransactionSplit.query( TransactionSplit.category_id, func.sum(TransactionSplit.amount), ) @@ -266,14 +250,14 @@ def get_monthly_available( ) .group_by(TransactionSplit.category_id) ) - categories_activity: dict[int, Decimal] = query_to_dict(query) + categories_activity: dict[int, Decimal] = sql.to_dict(query) categories: dict[int, BudgetAvailableCategory] = {} - query = s.query(TransactionCategory).with_entities( + query = TransactionCategory.query( TransactionCategory.id_, TransactionCategory.group, ) - for t_cat_id, group in query.yield_per(YIELD_PER): + for t_cat_id, group in sql.yield_(query): activity = categories_activity.get(t_cat_id, Decimal()) assigned = categories_assigned.get(t_cat_id, Decimal()) leftover = categories_leftover.get(t_cat_id, Decimal()) @@ -299,7 +283,6 @@ def get_monthly_available( @classmethod def get_emergency_fund( cls, - s: orm.Session, start_ord: int, end_ord: int, n_lower: int, @@ -308,7 +291,6 @@ def get_emergency_fund( """Get the emergency fund target range and assigned balance. Args: - s: SQL session to use start_ord: First day of calculated range end_ord: Last day of calculated range n_lower: Number of days in sliding lower period @@ -321,30 +303,21 @@ def get_emergency_fund( n = end_ord - start_ord + 1 n_smoothing = 15 - query = ( - s.query(Account) - .with_entities(Account.id_, Account.name) - .where(Account.budgeted) - ) - accounts: dict[int, str] = query_to_dict(query) + query = Account.query(Account.id_, Account.name).where(Account.budgeted) + accounts: dict[int, str] = sql.to_dict(query) - t_cat_id, _ = TransactionCategory.emergency_fund(s) + t_cat_id, _ = TransactionCategory.emergency_fund() - balance = ( - s.query(func.sum(BudgetAssignment.amount)) - .where( - BudgetAssignment.category_id == t_cat_id, - BudgetAssignment.month_ord <= start_ord, - ) - .scalar() - or Decimal() + query = BudgetAssignment.query(func.sum(BudgetAssignment.amount)).where( + BudgetAssignment.category_id == t_cat_id, + BudgetAssignment.month_ord <= start_ord, ) + balance = sql.scalar(query) or Decimal() balances: list[Decimal] = [] query = ( - s.query(BudgetAssignment) - .with_entities(BudgetAssignment.month_ord, BudgetAssignment.amount) + BudgetAssignment.query(BudgetAssignment.month_ord, BudgetAssignment.amount) .where( BudgetAssignment.category_id == t_cat_id, BudgetAssignment.month_ord > start_ord, @@ -353,7 +326,7 @@ def get_emergency_fund( .order_by(BudgetAssignment.month_ord) ) date_ord = start_ord - for b_ord, amount in query.all(): + for b_ord, amount in sql.yield_(query): while date_ord < b_ord: balances.append(balance) date_ord += 1 @@ -368,23 +341,18 @@ def get_emergency_fund( daily = Decimal() dailys: list[Decimal] = [] - query = ( - s.query(TransactionCategory) - .with_entities( - TransactionCategory.id_, - TransactionCategory.name, - TransactionCategory.emoji_name, - ) - .where(TransactionCategory.essential_spending) - ) - for t_cat_id, name, emoji_name in query.all(): + query = TransactionCategory.query( + TransactionCategory.id_, + TransactionCategory.name, + TransactionCategory.emoji_name, + ).where(TransactionCategory.essential_spending) + for t_cat_id, name, emoji_name in sql.yield_(query): categories[t_cat_id] = name, emoji_name categories_total[t_cat_id] = Decimal() start_ord_dailys = start_ord - n_upper - n_smoothing query = ( - s.query(TransactionSplit) - .with_entities( + TransactionSplit.query( TransactionSplit.date_ord, TransactionSplit.category_id, func.sum(TransactionSplit.amount), @@ -397,7 +365,7 @@ def get_emergency_fund( .group_by(TransactionSplit.date_ord, TransactionSplit.category_id) ) date_ord = start_ord_dailys - for t_ord, t_cat_id, amount in query.yield_per(YIELD_PER): + for t_ord, t_cat_id, amount in sql.yield_(query): while date_ord < t_ord: dailys.append(daily) date_ord += 1 @@ -438,7 +406,6 @@ def get_emergency_fund( @classmethod def move( cls, - s: orm.Session, month_ord: int, src_cat_id: int | None, dest_cat_id: int | None, @@ -447,7 +414,6 @@ def move( """Move funds between budget assignments. Args: - s: SQL session to use month_ord: Month of BudgetAssignment src_cat_id: Source category ID, or None dest_cat_id: Destination category ID, or None @@ -457,7 +423,7 @@ def move( if src_cat_id is not None: # Remove to_move from src_cat_id a = ( - s.query(BudgetAssignment) + BudgetAssignment.query() .where( BudgetAssignment.category_id == src_cat_id, BudgetAssignment.month_ord == month_ord, @@ -465,20 +431,19 @@ def move( .one_or_none() ) if a is None: - a = BudgetAssignment( + BudgetAssignment.create( month_ord=month_ord, amount=-to_move, category_id=src_cat_id, ) - s.add(a) elif a.amount == to_move: - s.delete(a) + a.delete() else: a.amount -= to_move if dest_cat_id is not None: a = ( - s.query(BudgetAssignment) + BudgetAssignment.query() .where( BudgetAssignment.category_id == dest_cat_id, BudgetAssignment.month_ord == month_ord, @@ -486,12 +451,11 @@ def move( .one_or_none() ) if a is None: - a = BudgetAssignment( + BudgetAssignment.create( month_ord=month_ord, amount=to_move, category_id=dest_cat_id, ) - s.add(a) else: a.amount += to_move diff --git a/nummus/models/config.py b/nummus/models/config.py index 8fe926cd..4d92983c 100644 --- a/nummus/models/config.py +++ b/nummus/models/config.py @@ -8,6 +8,7 @@ from sqlalchemy import orm from nummus import exceptions as exc +from nummus import sql from nummus.models.base import Base, BaseEnum, ORMStr, SQLEnum, string_column_args from nummus.models.currency import Currency @@ -56,24 +57,22 @@ def validate_strings(self, key: str, field: str | None) -> str | None: return self.clean_strings(key, field) @classmethod - def set_(cls, s: orm.Session, key: ConfigKey, value: str) -> None: + def set_(cls, key: ConfigKey, value: str) -> None: """Set a Configuration value. Args: - s: SQL session to use key: ConfigKey to query value: Value to set """ - if s.query(Config).where(Config.key == key).update({"value": value}): + if Config.query().where(Config.key == key).update({"value": value}): return - s.add(Config(key=key, value=value)) + Config.create(key=key, value=value) @overload @classmethod def fetch( cls, - s: orm.Session, key: ConfigKey, *, no_raise: Literal[False] = False, @@ -83,7 +82,6 @@ def fetch( @classmethod def fetch( cls, - s: orm.Session, key: ConfigKey, *, no_raise: Literal[True], @@ -92,7 +90,6 @@ def fetch( @classmethod def fetch( cls, - s: orm.Session, key: ConfigKey, *, no_raise: bool = False, @@ -100,7 +97,6 @@ def fetch( """Fetch a Configuration value. Args: - s: SQL session to use key: ConfigKey to query no_raise: True will return None if missing @@ -112,7 +108,7 @@ def fetch( """ try: - return s.query(Config.value).where(Config.key == key).one()[0] + return sql.one(cls.query(cls.value).where(cls.key == key)) except exc.NoResultFound as e: if no_raise: return None @@ -120,27 +116,21 @@ def fetch( raise exc.ProtectedObjectNotFoundError(msg) from e @classmethod - def db_version(cls, s: orm.Session) -> Version: + def db_version(cls) -> Version: """Query the database version. - Args: - s: SQL session to use - Returns: Version of database """ - return Version(Config.fetch(s, ConfigKey.VERSION)) + return Version(Config.fetch(ConfigKey.VERSION)) @classmethod - def base_currency(cls, s: orm.Session) -> Currency: + def base_currency(cls) -> Currency: """Query the basse currency. - Args: - s: SQL session to use - Returns: Base currency all accounts are converted into """ - return Currency(int(Config.fetch(s, ConfigKey.BASE_CURRENCY))) + return Currency(int(Config.fetch(ConfigKey.BASE_CURRENCY))) diff --git a/nummus/models/label.py b/nummus/models/label.py index b2131708..39b10a93 100644 --- a/nummus/models/label.py +++ b/nummus/models/label.py @@ -4,13 +4,14 @@ from sqlalchemy import ForeignKey, Index, orm, UniqueConstraint +from nummus import sql from nummus.models.base import ( Base, ORMInt, ORMStr, string_column_args, ) -from nummus.models.utils import query_to_dict, update_rows +from nummus.models.utils import update_rows class LabelLink(Base): @@ -34,12 +35,11 @@ class LabelLink(Base): Index("label_link_t_split_id", "t_split_id"), ) - @staticmethod - def add_links(s: orm.Session, split_labels: dict[int, set[str]]) -> None: + @classmethod + def add_links(cls, split_labels: dict[int, set[str]]) -> None: """Add links between TransactionSplits and Labels. Args: - s: SQL session to use split_labels: dict {TransactionSplit: {label names to link} """ @@ -51,23 +51,16 @@ def add_links(s: orm.Session, split_labels: dict[int, set[str]]) -> None: for labels in split_labels.values(): label_names.update(labels) - query = ( - s.query(Label) - .with_entities(Label.name, Label.id_) - .where(Label.name.in_(label_names)) - ) - mapping: dict[str, int] = query_to_dict(query) + query = Label.query(Label.name, Label.id_).where(Label.name.in_(label_names)) + mapping: dict[str, int] = sql.to_dict(query) - to_add = [Label(name=name) for name in label_names if name not in mapping] - if to_add: - s.add_all(to_add) - s.flush() - mapping.update({label.name: label.id_ for label in to_add}) + for name in label_names: + if name not in mapping: + mapping[name] = Label.create(name=name).id_ for t_split_id, labels in split_labels.items(): - query = s.query(LabelLink).where(LabelLink.t_split_id == t_split_id) + query = LabelLink.query().where(LabelLink.t_split_id == t_split_id) update_rows( - s, LabelLink, query, "label_id", diff --git a/nummus/models/transaction.py b/nummus/models/transaction.py index 74b417b4..b856f5f5 100644 --- a/nummus/models/transaction.py +++ b/nummus/models/transaction.py @@ -11,10 +11,10 @@ import sqlalchemy from rapidfuzz import process -from sqlalchemy import CheckConstraint, event, ForeignKey, Index, orm +from sqlalchemy import CheckConstraint, ForeignKey, Index, orm from nummus import exceptions as exc -from nummus import utils +from nummus import sql, utils from nummus.models.base import ( Base, Decimal6, @@ -27,11 +27,9 @@ ORMStr, ORMStrOpt, string_column_args, - YIELD_PER, ) from nummus.models.label import Label, LabelLink from nummus.models.transaction_category import TransactionCategory -from nummus.models.utils import obj_session if TYPE_CHECKING: from decimal import Decimal @@ -92,6 +90,10 @@ class TransactionSplit(Base): "(asset_quantity IS NOT NULL) == (_asset_qty_unadjusted IS NOT NULL)", name="asset_quantity and unadjusted must be same null state", ), + CheckConstraint( + "(asset_id IS NOT NULL) == (_asset_qty_unadjusted IS NOT NULL)", + name="asset_id and asset quantity must be same null state", + ), CheckConstraint( "amount != 0", "transaction_split.amount must be non-zero", @@ -225,15 +227,10 @@ def adjust_asset_quantity_residual(self, residual: Decimal) -> None: @property def parent(self) -> Transaction: """Parent Transaction.""" - s = obj_session(self) - query = s.query(Transaction).where(Transaction.id_ == self.parent_id) - return query.one() + return sql.one(Transaction.query().where(Transaction.id_ == self.parent_id)) @parent.setter def parent(self, parent: Transaction) -> None: - if parent.id_ is None: - self.parent_tmp = parent - return super().__setattr__("parent_id", parent.id_) super().__setattr__("date_ord", parent.date_ord) super().__setattr__("month_ord", parent.month_ord) @@ -270,9 +267,9 @@ def search( query = query.join(LabelLink, full=True) tokens_must, tokens_can, tokens_not = utils.tokenize_search_str(search_str) - category_names = category_names or TransactionCategory.map_name(query.session) + category_names = category_names or TransactionCategory.map_name() category_names_rev = {v: k for k, v in category_names.items()} - label_names = label_names or Label.map_name(query.session) + label_names = label_names or Label.map_name() label_names_rev = {v.lower(): k for k, v in label_names.items()} query = cls._search_must( @@ -283,17 +280,18 @@ def search( ) query = cls._search_not(query, tokens_not, category_names_rev, label_names_rev) - sub_query = query.with_entities(TransactionSplit.id_).scalar_subquery() - query_modified = ( - query.session.query(LabelLink) - .with_entities(LabelLink.t_split_id, LabelLink.label_id) - .where(LabelLink.t_split_id.in_(sub_query)) - ) + sub_query = query.with_entities( # nummus: ignore + TransactionSplit.id_, + ).scalar_subquery() + query_modified = LabelLink.query( + LabelLink.t_split_id, + LabelLink.label_id, + ).where(LabelLink.t_split_id.in_(sub_query)) split_labels: dict[int, set[int]] = defaultdict(set) - for t_split_id, label_id in query_modified.yield_per(YIELD_PER): + for t_split_id, label_id in sql.yield_(query_modified): split_labels[t_split_id].add(label_id) - query_modified = query.with_entities( + query_modified = query.with_entities( # nummus: ignore TransactionSplit.id_, TransactionSplit.date_ord, TransactionSplit.category_id, @@ -306,12 +304,7 @@ def search( date_ord, cat_id, text_fields, - ) in query_modified.yield_per(YIELD_PER): - t_id: int - date_ord: int - cat_id: int - text_fields: str | None - + ) in sql.yield_(query_modified): full_text = f"{category_names[cat_id]} {text_fields or ''} " + " ".join( label_names[label_id] for label_id in split_labels[t_id] ) @@ -356,7 +349,7 @@ def _search_must( continue - clauses_or: list[sqlalchemy.ColumnExpressionArgument] = [] + clauses_or: list[sql.ColumnClause] = [] categories = { cat_id for cat_name, cat_id in category_names.items() @@ -418,24 +411,6 @@ def _search_not( return query -@event.listens_for(TransactionSplit, "before_insert") -def before_insert_transaction_split( - _: orm.Mapper, - __: sqlalchemy.Connection, - target: TransactionSplit, -) -> None: - """Handle event before insert of TransactionSplit. - - Args: - target: TransactionSplit being inserted - - """ - # If TransactionSplit has parent_tmp set, move it to real parent - if hasattr(target, "parent_tmp"): - target.parent = target.parent_tmp - delattr(target, "parent_tmp") - - class Transaction(Base): """Transaction model for storing an exchange of cash for an asset (or none). @@ -534,8 +509,6 @@ def find_similar( Most similar Transaction.id_ """ - s = obj_session(self) - if cache_ok and self.similar_txn_id is not None: return self.similar_txn_id @@ -553,26 +526,21 @@ def set_match(matching_row: int | Row[tuple[int]]) -> int: id_ = matching_row if isinstance(matching_row, int) else matching_row[0] if set_property: self.similar_txn_id = id_ - s.flush() return id_ # Convert txn.amount to the raw SQL value to make a raw query amount_raw = Transaction.amount.type.process_bind_param(self.amount, None) sort_closest_amount = sqlalchemy.text(f"abs({amount_raw} - amount)") - cat_asset_linked = { - t_cat_id - for t_cat_id, in ( - s.query(TransactionCategory.id_) - .where(TransactionCategory.asset_linked.is_(True)) - .all() - ) - } + query = TransactionCategory.query(TransactionCategory.id_).where( + TransactionCategory.asset_linked.is_(True), + ) + cat_asset_linked = {t_cat_id for t_cat_id, in sql.yield_(query)} # Check within Account first, exact matches # If this matches, great, no post filtering needed query = ( - s.query(Transaction.id_) + Transaction.query(Transaction.id_) .where( Transaction.account_id == self.account_id, Transaction.id_ != self.id_, @@ -588,7 +556,7 @@ def set_match(matching_row: int | Row[tuple[int]]) -> int: # Maybe exact statement but different account query = ( - s.query(Transaction.id_) + Transaction.query(Transaction.id_) .where( Transaction.id_ != self.id_, Transaction.amount >= amount_min, @@ -603,7 +571,7 @@ def set_match(matching_row: int | Row[tuple[int]]) -> int: # Maybe exact statement but different amount query = ( - s.query(Transaction.id_) + Transaction.query(Transaction.id_) .where( Transaction.id_ != self.id_, Transaction.statement == self.statement, @@ -616,8 +584,7 @@ def set_match(matching_row: int | Row[tuple[int]]) -> int: # No statements match, choose highest fuzzy matching statement query = ( - s.query(Transaction) - .with_entities( + Transaction.query( Transaction.id_, Transaction.statement, ) @@ -630,22 +597,20 @@ def set_match(matching_row: int | Row[tuple[int]]) -> int: ) statements: dict[int, str] = { t_id: re.sub(r"[0-9]+", "", statement).lower() - for t_id, statement in query.yield_per(YIELD_PER) + for t_id, statement in sql.yield_(query) } if len(statements) == 0: return None # Don't match a Transaction if it has a Securities Traded split - has_asset_linked = { - id_ - for id_, in ( - s.query(TransactionSplit.parent_id) - .where( - TransactionSplit.parent_id.in_(statements), - TransactionSplit.category_id.in_(cat_asset_linked), - ) - .distinct() + query = ( + Transaction.query(TransactionSplit.parent_id) + .where( + TransactionSplit.parent_id.in_(statements), + TransactionSplit.category_id.in_(cat_asset_linked), ) - } + .distinct() + ) + has_asset_linked = {id_ for id_, in sql.yield_(query)} statements = { t_id: statement for t_id, statement in statements.items() @@ -669,17 +634,13 @@ def set_match(matching_row: int | Row[tuple[int]]) -> int: ) # Add a bonuse points for closeness in price and same account - query = ( - s.query(Transaction) - .with_entities( - Transaction.id_, - Transaction.account_id, - Transaction.amount, - ) - .where(Transaction.id_.in_(matches)) - ) + query = Transaction.query( + Transaction.id_, + Transaction.account_id, + Transaction.amount, + ).where(Transaction.id_.in_(matches)) matches_bonus: dict[int, float] = {} - for t_id, acct_id, amount in query.yield_per(YIELD_PER): + for t_id, acct_id, amount in sql.yield_(query): # 5% off will reduce score by 5% amount_diff_percent = abs(amount - self.amount) / self.amount # Extra 10 points for same account diff --git a/nummus/models/transaction_category.py b/nummus/models/transaction_category.py index 03892808..a63e7621 100644 --- a/nummus/models/transaction_category.py +++ b/nummus/models/transaction_category.py @@ -7,6 +7,7 @@ from sqlalchemy import CheckConstraint, ForeignKey, orm, UniqueConstraint from nummus import exceptions as exc +from nummus import sql from nummus.models.base import ( Base, BaseEnum, @@ -16,7 +17,6 @@ SQLEnum, string_column_args, ) -from nummus.models.utils import query_to_dict class TransactionCategoryGroup(BaseEnum): @@ -127,12 +127,9 @@ def validate_essential_spending(self, _: str, field: object) -> bool: return field @staticmethod - def add_default(s: orm.Session) -> dict[str, TransactionCategory]: + def add_default() -> dict[str, TransactionCategory]: """Create default transaction categories. - Args: - s: SQL session to use - Returns: Dictionary {name: category} @@ -232,7 +229,7 @@ class Spec(NamedTuple): for group, categories in groups.items(): for name, spec in categories.items(): - cat = TransactionCategory( + cat = TransactionCategory.create( emoji_name=name, group=group, locked=spec.locked, @@ -240,7 +237,6 @@ class Spec(NamedTuple): asset_linked=spec.asset_linked, essential_spending=spec.essential_spending, ) - s.add(cat) d[cat.name] = cat return d @@ -248,14 +244,12 @@ class Spec(NamedTuple): @classmethod def map_name( cls, - s: orm.Session, *, no_asset_linked: bool = False, ) -> dict[int, str]: """Get mapping between id and names. Args: - s: SQL session to use no_asset_linked: True will not include asset_linked categories Returns: @@ -265,29 +259,20 @@ def map_name( KeyError if model does not have name property """ - query = ( - s.query(TransactionCategory) - .with_entities( - TransactionCategory.id_, - TransactionCategory.name, - ) - .order_by(TransactionCategory.name) - ) + query = cls.query(cls.id_, cls.name).order_by(cls.name) if no_asset_linked: - query = query.where(TransactionCategory.asset_linked.is_(False)) - return query_to_dict(query) + query = query.where(cls.asset_linked.is_(False)) + return sql.to_dict(query) @classmethod def map_name_emoji( cls, - s: orm.Session, *, no_asset_linked: bool = False, ) -> dict[int, str]: """Get mapping between id and names with emojis. Args: - s: SQL session to use no_asset_linked: True will not include asset_linked categories Returns: @@ -297,24 +282,16 @@ def map_name_emoji( KeyError if model does not have name property """ - query = ( - s.query(TransactionCategory) - .with_entities( - TransactionCategory.id_, - TransactionCategory.emoji_name, - ) - .order_by(TransactionCategory.name) - ) + query = cls.query(cls.id_, cls.emoji_name).order_by(cls.name) if no_asset_linked: - query = query.where(TransactionCategory.asset_linked.is_(False)) - return query_to_dict(query) + query = query.where(cls.asset_linked.is_(False)) + return sql.to_dict(query) @classmethod - def _get_protected_id(cls, s: orm.Session, name: str) -> tuple[int, str]: + def _get_protected_id(cls, name: str) -> tuple[int, str]: """Get the ID and URI of a protected category. Args: - s: SQL session to use name: Name of protected category to fetch Returns: @@ -325,23 +302,16 @@ def _get_protected_id(cls, s: orm.Session, name: str) -> tuple[int, str]: """ try: - id_ = ( - s.query(TransactionCategory.id_) - .where(TransactionCategory.name == name) - .one()[0] - ) + id_ = sql.one(cls.query(cls.id_).where(TransactionCategory.name == name)) except exc.NoResultFound as e: msg = f"Category {name} not found" raise exc.ProtectedObjectNotFoundError(msg) from e return id_, cls.id_to_uri(id_) @classmethod - def uncategorized(cls, s: orm.Session) -> tuple[int, str]: + def uncategorized(cls) -> tuple[int, str]: """Get the ID and URI of the uncategorized category. - Args: - s: SQL session to use - Returns: tuple(id_, URI) @@ -349,15 +319,12 @@ def uncategorized(cls, s: orm.Session) -> tuple[int, str]: ProtectedObjectNotFound if not found """ - return cls._get_protected_id(s, "uncategorized") + return cls._get_protected_id("uncategorized") @classmethod - def emergency_fund(cls, s: orm.Session) -> tuple[int, str]: + def emergency_fund(cls) -> tuple[int, str]: """Get the ID and URI of the emergency fund category. - Args: - s: SQL session to use - Returns: tuple(id_, URI) @@ -365,15 +332,12 @@ def emergency_fund(cls, s: orm.Session) -> tuple[int, str]: ProtectedObjectNotFound if not found """ - return cls._get_protected_id(s, "emergency fund") + return cls._get_protected_id("emergency fund") @classmethod - def securities_traded(cls, s: orm.Session) -> tuple[int, str]: + def securities_traded(cls) -> tuple[int, str]: """Get the ID and URI of the securities traded category. - Args: - s: SQL session to use - Returns: tuple(id_, URI) @@ -381,4 +345,4 @@ def securities_traded(cls, s: orm.Session) -> tuple[int, str]: ProtectedObjectNotFound if not found """ - return cls._get_protected_id(s, "securities traded") + return cls._get_protected_id("securities traded") diff --git a/nummus/models/utils.py b/nummus/models/utils.py index 952c6388..88704c9b 100644 --- a/nummus/models/utils.py +++ b/nummus/models/utils.py @@ -9,69 +9,25 @@ from sqlalchemy import ( CheckConstraint, ForeignKeyConstraint, - func, - orm, UniqueConstraint, ) -from nummus import exceptions as exc -from nummus.models.base import YIELD_PER +from nummus import sql if TYPE_CHECKING: from sqlalchemy import ( Constraint, + orm, ) from nummus.models.base import Base -def query_to_dict[K, V](query: orm.query.RowReturningQuery[tuple[K, V]]) -> dict[K, V]: - """Fetch results from query and return a dict. - - Args: - query: Query that returns 2 columns - - Returns: - dict{first column: second column} - - """ - # pyright is happier with comprehension - # ruff is happier with dict() - return dict(query.yield_per(YIELD_PER)) # type: ignore[attr-defined] - - -def query_count(query: orm.Query) -> int: - """Count the number of result a query will return. - - Args: - query: Session query to execute - - Returns: - Number of instances query will return upon execution - - Raises: - TypeError: if query.statement is not a Select - - """ - # From here: - # https://datawookie.dev/blog/2021/01/sqlalchemy-efficient-counting/ - col_one = sqlalchemy.literal_column("1") - stmt = query.statement - if not isinstance(stmt, sqlalchemy.Select): - raise TypeError - counter = stmt.with_only_columns( - func.count(col_one), - maintain_column_froms=True, - ) - counter = counter.order_by(None) - return query.session.execute(counter).scalar() or 0 - - -def paginate( - query: orm.Query[Base], +def paginate[T: Base]( + query: orm.Query[T], limit: int, offset: int, -) -> tuple[list[Base], int, int | None]: +) -> tuple[list[T], int, int | None]: """Paginate query response for smaller results. Args: @@ -87,12 +43,12 @@ def paginate( offset = max(0, offset) # Get amount number from filters - count = query_count(query) + count = sql.count(query) # Apply limiting, and offset query = query.limit(limit).offset(offset) - results = query.all() + results = list(sql.yield_(query)) # Compute next_offset n_current = len(results) @@ -102,14 +58,10 @@ def paginate( return results, count, next_offset -def dump_table_configs( - s: orm.Session, - model: type[Base], -) -> list[str]: +def dump_table_configs(model: type[Base]) -> list[str]: """Get the table configs (columns and constraints) and print. Args: - s: SQL session to use model: Filter to specific table Returns: @@ -123,26 +75,26 @@ def dump_table_configs( type='table' AND name='{model.__tablename__}' """.strip() # noqa: S608 - result = s.execute(sqlalchemy.text(stmt)).one()[0] - result: str + query: orm.query.RowReturningQuery[tuple[str]] = model.session().execute( # type: ignore[attr-defined] + sqlalchemy.text(stmt), + ) + result: str = sql.one(query) return [s.replace("\t", " ") for s in result.splitlines()] def get_constraints( - s: orm.Session, model: type[Base], ) -> list[tuple[type[Constraint], str]]: """Get constraints of a table. Args: - s: SQL session to use model: Filter to specific table Returns: list[(Constraint type, construction text)] """ - config = "\n".join(dump_table_configs(s, model)) + config = "\n".join(dump_table_configs(model)) constraints: list[tuple[type[Constraint], str]] = [] re_unique = re.compile(r"UNIQUE \(([^\)]+)\)") @@ -163,36 +115,15 @@ def get_constraints( return constraints -def obj_session(m: Base) -> orm.Session: - """Get the SQL session for an object. - - Args: - m: Model to get from - - Returns: - Session - - Raises: - UnboundExecutionError: if model is unbound - - """ - s = orm.object_session(m) - if s is None: - raise exc.UnboundExecutionError - return s - - -def update_rows( - s: orm.Session, - cls: type[Base], - query: orm.Query, +def update_rows[T: Base]( + cls: type[T], + query: orm.Query[T], id_key: str, updates: dict[object, dict[str, object]], ) -> None: """Update many rows, reusing leftovers when possible. Args: - s: SQL session to use cls: Type of model to update query: Query to fetch all applicable models id_key: Name of property used for identification @@ -202,7 +133,7 @@ def update_rows( updates = updates.copy() leftovers: list[Base] = [] - for m in query.yield_per(YIELD_PER): + for m in sql.yield_(query): update = updates.pop(getattr(m, id_key), None) if update is None: # No longer needed @@ -211,6 +142,8 @@ def update_rows( for k, v in update.items(): setattr(m, k, v) + s = cls.session() + # Add any missing ones for id_, update in updates.items(): if leftovers: @@ -219,24 +152,21 @@ def update_rows( for k, v in update.items(): setattr(m, k, v) else: - m = cls(**{id_key: id_, **update}) - s.add(m) + cls.create(**{id_key: id_, **update}) # Delete any leftovers for m in leftovers: s.delete(m) -def update_rows_list( - s: orm.Session, - cls: type[Base], - query: orm.Query, +def update_rows_list[T: Base]( + cls: type[T], + query: orm.Query[T], updates: list[dict[str, object]], ) -> list[int]: """Update many rows, reusing leftovers when possible. Args: - s: SQL session to use cls: Type of model to update query: Query to fetch all applicable models updates: list[{parameter: value}] @@ -250,7 +180,7 @@ def update_rows_list( updates = updates.copy() leftovers: list[Base] = [] - for m in query.yield_per(YIELD_PER): + for m in sql.yield_(query): if len(updates) == 0: # No longer needed leftovers.append(m) @@ -260,6 +190,8 @@ def update_rows_list( setattr(m, k, v) ids.append(m.id_) + s = cls.session() + to_add = [cls(**update) for update in updates] s.add_all(to_add) @@ -271,17 +203,3 @@ def update_rows_list( ids.extend(m.id_ for m in to_add) return ids - - -def one_or_none[T](query: orm.Query[T]) -> T | None: - """Return one result. - - Returns: - One result - If no results or multiple, return None - - """ - try: - return query.one_or_none() - except (exc.NoResultFound, exc.MultipleResultsFound): - return None diff --git a/nummus/portfolio.py b/nummus/portfolio.py index cb43e63c..18c40a90 100644 --- a/nummus/portfolio.py +++ b/nummus/portfolio.py @@ -3,6 +3,7 @@ from __future__ import annotations import base64 +import contextlib import datetime import hashlib import io @@ -27,27 +28,20 @@ from nummus.migrations.top import MIGRATORS from nummus.models.account import Account from nummus.models.asset import Asset -from nummus.models.base import Base, YIELD_PER +from nummus.models.base import Base from nummus.models.base_uri import Cipher, load_cipher from nummus.models.config import Config, ConfigKey -from nummus.models.currency import ( - Currency, - DEFAULT_CURRENCY, -) +from nummus.models.currency import DEFAULT_CURRENCY from nummus.models.imported_file import ImportedFile from nummus.models.transaction import Transaction, TransactionSplit from nummus.models.transaction_category import TransactionCategory -from nummus.models.utils import ( - one_or_none, - query_to_dict, -) from nummus.version import __version__ if TYPE_CHECKING: - import contextlib + from collections.abc import Iterator from nummus.importers.base import TxnDict - from nummus.models.currency import Currency + from nummus.models.base import NamePair class AssetUpdate(NamedTuple): @@ -200,9 +194,9 @@ def create(cls, path: str | Path, key: str | None = None) -> Portfolio: test_value = enc.encrypt(Portfolio._ENCRYPTION_TEST_VALUE) engine = sql.get_engine(path_db, enc) - with orm.Session(engine) as s: + with orm.Session(engine) as s, Base.set_session(s): with s.begin(): - Base.metadata_create_all(s) + Base.metadata_create_all() with s.begin(): # If developing a migration, current version will be less @@ -212,20 +206,20 @@ def create(cls, path: str | Path, key: str | None = None) -> Portfolio: *[m.min_version() for m in MIGRATORS], ) - Config.set_(s, ConfigKey.VERSION, str(v)) - Config.set_(s, ConfigKey.ENCRYPTION_TEST, test_value) - Config.set_(s, ConfigKey.CIPHER, cipher_b64) - Config.set_(s, ConfigKey.SECRET_KEY, secrets.token_hex()) - Config.set_(s, ConfigKey.BASE_CURRENCY, str(DEFAULT_CURRENCY.value)) + Config.set_(ConfigKey.VERSION, str(v)) + Config.set_(ConfigKey.ENCRYPTION_TEST, test_value) + Config.set_(ConfigKey.CIPHER, cipher_b64) + Config.set_(ConfigKey.SECRET_KEY, secrets.token_hex()) + Config.set_(ConfigKey.BASE_CURRENCY, str(DEFAULT_CURRENCY.value)) if enc is not None and key is not None: - Config.set_(s, ConfigKey.WEB_KEY, enc.encrypt(key)) + Config.set_(ConfigKey.WEB_KEY, enc.encrypt(key)) path_db.chmod(0o600) # Only owner can read/write p = Portfolio(path_db, key) - with p.begin_session() as s: - TransactionCategory.add_default(s) - Asset.add_indices(s) + with p.begin_session(): + TransactionCategory.add_default() + Asset.add_indices() return p def _unlock(self) -> dict[ConfigKey, str]: @@ -240,9 +234,9 @@ def _unlock(self) -> dict[ConfigKey, str]: """ try: - with self.begin_session() as s: - query = s.query(Config).with_entities(Config.key, Config.value) - configs: dict[ConfigKey, str] = query_to_dict(query) + with self.begin_session(): + query = Config.query(Config.key, Config.value) + configs: dict[ConfigKey, str] = sql.to_dict(query) except exc.DatabaseError as e: msg = f"Failed to open database {self._path_db}" raise exc.UnlockingError(msg) from e @@ -280,14 +274,17 @@ def get_engine(self) -> sqlalchemy.Engine: """ return sql.get_engine(self._path_db, self._enc) - def begin_session(self) -> contextlib.AbstractContextManager[orm.Session]: + @contextlib.contextmanager + def begin_session(self) -> Iterator[orm.Session]: """Get SQL Session to the database. - Returns: + Yields: Open Session """ - return self._session_maker.begin() + s = self._session_maker() + with s, s.begin(), Base.set_session(s): + yield s def encrypt(self, secret: bytes | str) -> str: """Encrypt a secret using the key. @@ -346,8 +343,8 @@ def migration_required(self, version_str: str | None) -> Version | None: """ if version_str is None: - with self.begin_session() as s: - v_db = Config.db_version(s) + with self.begin_session(): + v_db = Config.db_version() else: v_db = Version(version_str) for m in MIGRATORS[::-1]: @@ -375,11 +372,13 @@ def import_file(self, path: Path, path_debug: Path, *, force: bool = False) -> N sha = hashlib.sha256() sha.update(path.read_bytes()) h = sha.hexdigest() - with self.begin_session() as s: + with self.begin_session(): if force: - s.query(ImportedFile).where(ImportedFile.hash_ == h).delete() - existing_date_ord: int | None = ( - s.query(ImportedFile.date_ord).where(ImportedFile.hash_ == h).scalar() + ImportedFile.query().where(ImportedFile.hash_ == h).delete() + existing_date_ord: int | None = sql.scalar( + ImportedFile.query(ImportedFile.date_ord).where( + ImportedFile.hash_ == h, + ), ) if existing_date_ord is not None: date = datetime.date.fromordinal(existing_date_ord) @@ -388,12 +387,12 @@ def import_file(self, path: Path, path_debug: Path, *, force: bool = False) -> N i = get_importer(path, path_debug, self._importers) today = datetime.datetime.now(datetime.UTC).date() - categories = TransactionCategory.map_name(s) + categories = TransactionCategory.map_name() # Reverse categories for LUT categories = {v: k for k, v in categories.items()} # Cache a mapping from account/asset name to the ID - acct_mapping: dict[str, tuple[int, str | None]] = {} - asset_mapping: dict[str, tuple[int, str | None]] = {} + acct_mapping: dict[str, NamePair] = {} + asset_mapping: dict[str, NamePair] = {} try: txns_raw = i.run() except Exception as e: @@ -406,25 +405,25 @@ def import_file(self, path: Path, path_debug: Path, *, force: bool = False) -> N # Create a single split for each transaction acct_raw = d["account"] - acct_id, _ = self.find(s, Account, acct_raw, acct_mapping) + acct_id, _ = Account.find(acct_raw, acct_mapping) asset_raw = d["asset"] asset_id: int | None = None if asset_raw: - asset_id, asset_name = self.find(s, Asset, asset_raw, asset_mapping) + asset_id, asset_name = Asset.find(asset_raw, asset_mapping) if not d["statement"]: d["statement"] = f"Asset Transaction {asset_name}" - self._import_transaction(s, d, acct_id, asset_id, categories) + self._import_transaction(d, acct_id, asset_id, categories) # Add file hash to prevent importing again - s.add(ImportedFile(hash_=h)) + ImportedFile.create(hash_=h) # Update splits on each touched - query = s.query(Asset).where( + query = Asset.query().where( Asset.id_.in_(a_id for a_id, _ in asset_mapping.values()), ) - for asset in query.all(): + for asset in sql.yield_(query): asset.update_splits() # If successful, delete the temp file @@ -432,42 +431,36 @@ def import_file(self, path: Path, path_debug: Path, *, force: bool = False) -> N def _import_transaction( self, - s: orm.Session, d: TxnDict, acct_id: int, asset_id: int | None, categories: dict[str, int], ) -> None: if asset_id is not None: - self._import_asset_transaction(s, d, acct_id, asset_id, categories) + self._import_asset_transaction(d, acct_id, asset_id, categories) return # See if anything matches date_ord = d["date"].toordinal() - matches = list( - s.query(Transaction) - .with_entities(Transaction.id_, Transaction.date_ord) - .where( - Transaction.account_id == acct_id, - Transaction.amount == d["amount"], - Transaction.date_ord >= date_ord - 5, - Transaction.date_ord <= date_ord + 5, - Transaction.cleared.is_(False), - ) - .all(), + query = Transaction.query(Transaction.id_, Transaction.date_ord).where( + Transaction.account_id == acct_id, + Transaction.amount == d["amount"], + Transaction.date_ord >= date_ord - 5, + Transaction.date_ord <= date_ord + 5, + Transaction.cleared.is_(False), ) - matches = sorted(matches, key=lambda x: abs(x[1] - date_ord)) + matches = sorted(sql.yield_(query), key=lambda x: abs(x[1] - date_ord)) # If only one match on closest day, link transaction if len(matches) == 1 or (len(matches) > 1 and matches[0][1] != matches[1][1]): match_id = matches[0][0] statement_clean = Transaction.clean_strings("statement", d["statement"]) - s.query(Transaction).where(Transaction.id_ == match_id).update( + Transaction.query().where(Transaction.id_ == match_id).update( { "cleared": True, "statement": statement_clean, }, ) - s.query(TransactionSplit).where( + TransactionSplit.query().where( TransactionSplit.parent_id == match_id, ).update({"cleared": True}) return @@ -478,7 +471,7 @@ def _import_transaction( except KeyError: category_id = categories["uncategorized"] - txn = Transaction( + txn = Transaction.create( account_id=acct_id, amount=d["amount"], date=d["date"], @@ -486,21 +479,19 @@ def _import_transaction( payee=d["payee"], cleared=True, ) - t_split = TransactionSplit( + TransactionSplit.create( + parent=txn, amount=d["amount"], memo=d["memo"], category_id=category_id, ) - t_split.parent = txn - s.add_all((txn, t_split)) @classmethod def _import_asset_transaction( cls, - s: orm.Session, d: TxnDict, acct_id: int, - asset_id: int, + asset_id: int | None, categories: dict[str, int], ) -> None: category_name = (d["category"] or "uncategorized").lower() @@ -513,7 +504,7 @@ def _import_asset_transaction( raise exc.MissingAssetError(msg) qty = abs(qty) - txn = Transaction( + txn = Transaction.create( account_id=acct_id, amount=0, date=d["date"], @@ -521,7 +512,7 @@ def _import_asset_transaction( payee=d["payee"], cleared=True, ) - t_split_0 = TransactionSplit( + TransactionSplit.create( parent=txn, amount=amount, memo=d["memo"], @@ -529,7 +520,7 @@ def _import_asset_transaction( asset_id=asset_id, asset_quantity_unadjusted=-qty, ) - t_split_1 = TransactionSplit( + TransactionSplit.create( parent=txn, amount=-amount, memo=d["memo"], @@ -537,7 +528,6 @@ def _import_asset_transaction( asset_id=asset_id, asset_quantity_unadjusted=0, ) - s.add_all((txn, t_split_0, t_split_1)) return if category_name == "dividends received": # Associate dividends with asset @@ -548,7 +538,7 @@ def _import_asset_transaction( raise exc.MissingAssetError(msg) qty = abs(qty) - txn = Transaction( + txn = Transaction.create( account_id=acct_id, amount=0, date=d["date"], @@ -556,7 +546,7 @@ def _import_asset_transaction( payee=d["payee"], cleared=True, ) - t_split_0 = TransactionSplit( + TransactionSplit.create( parent=txn, amount=amount, memo=d["memo"], @@ -564,21 +554,19 @@ def _import_asset_transaction( asset_id=asset_id, asset_quantity_unadjusted=0, ) - t_split_1 = TransactionSplit( - parent=txn, - amount=-amount, - memo=d["memo"], - category_id=categories["securities traded"], - asset_id=asset_id, - asset_quantity_unadjusted=qty, - ) - s.add_all((txn, t_split_0)) if qty != 0: # Zero quantity means cash dividends, not reinvested - s.add(t_split_1) + TransactionSplit.create( + parent=txn, + amount=-amount, + memo=d["memo"], + category_id=categories["securities traded"], + asset_id=asset_id, + asset_quantity_unadjusted=qty, + ) return if category_name == "securities traded": - txn = Transaction( + txn = Transaction.create( account_id=acct_id, amount=d["amount"], date=d["date"], @@ -586,86 +574,19 @@ def _import_asset_transaction( payee=d["payee"], cleared=True, ) - t_split = TransactionSplit( + TransactionSplit( + parent=txn, amount=d["amount"], memo=d["memo"], category_id=categories[category_name], asset_id=asset_id, asset_quantity_unadjusted=d["asset_quantity"], ) - t_split.parent = txn - s.add_all((txn, t_split)) return msg = f"'{category_name}' is not a valid category for asset transaction" raise exc.InvalidAssetTransactionCategoryError(msg) - @classmethod - def find( - cls, - s: orm.Session, - model: type[Base], - search: str, - cache: dict[str, tuple[int, str | None]], - ) -> tuple[int, str | None]: - """Find a matching object by uri, or field value. - - Args: - s: Session to use - model: Type of model to search for - search: Search query - cache: Cache results to speed up look ups - - Returns: - tuple(id_, name) - - Raises: - NoResultFound: if object not found - - """ - id_, name = cache.get(search, (None, None)) - if id_ is not None: - return id_, name - - def cache_and_return(m: Base) -> tuple[int, str | None]: - id_ = m.id_ - name: str | None = getattr(m, "name", None) - cache[search] = id_, name - return id_, name - - try: - # See if query is an URI - id_ = model.uri_to_id(search) - except (exc.InvalidURIError, exc.WrongURITypeError): - pass - else: - query = s.query(model).where(model.id_ == id_) - if m := one_or_none(query): - return cache_and_return(m) - - properties: list[orm.QueryableAttribute] = { - Account: [Account.number, Account.institution, Account.name], - Asset: [Asset.ticker, Asset.name], - }[model] - for prop in properties: - # Exact? - query = s.query(model).where(prop == search) - if m := one_or_none(query): - return cache_and_return(m) - - # Exact lower case? - query = s.query(model).where(prop.ilike(search)) - if m := one_or_none(query): - return cache_and_return(m) - - # For account number, see if there is one ending in the search - query = s.query(model).where(prop.ilike(f"%{search}")) - if prop is Account.number and (m := one_or_none(query)): - return cache_and_return(m) - - msg = f"{model.__name__} matching '{search}' could not be found" - raise exc.NoResultFound(msg) - def backup(self) -> tuple[Path, int]: """Back up database, duplicates files. @@ -765,7 +686,7 @@ def clean(self) -> tuple[int, int]: # Prune unused AssetValuations with self.begin_session() as s: - for asset in s.query(Asset).yield_per(YIELD_PER): + for asset in Asset.all(): asset.prune_valuations() asset.autodetect_interpolate() @@ -924,17 +845,18 @@ def update_assets(self, *, no_bars: bool = False) -> list[AssetUpdate]: today_ord = today.toordinal() updated: list[AssetUpdate] = [] - with self.begin_session() as s: + with self.begin_session(): # Get FOREXes, add if need be - currencies: set[Currency] = {r[0] for r in s.query(Account.currency).all()} - base_currency = Config.base_currency(s) - Asset.create_forex(s, base_currency, currencies) + currencies = set(sql.col0(Account.query(Account.currency))) + base_currency = Config.base_currency() + Asset.create_forex(base_currency, currencies) - assets = s.query(Asset).where(Asset.ticker.isnot(None)).all() + query = Asset.query().where(Asset.ticker.isnot(None)) + assets = list(sql.yield_(query)) ids = [asset.id_ for asset in assets] # Get currently held assets - asset_qty = Account.get_asset_qty_all(s, today_ord, today_ord) + asset_qty = Account.get_asset_qty_all(today_ord, today_ord) currently_held_assets: set[int] = set() for acct_assets in asset_qty.values(): for a_id in ids: @@ -958,7 +880,7 @@ def update_assets(self, *, no_bars: bool = False) -> list[AssetUpdate]: # start & end are None if there are no transactions for the Asset # Auto update if asset needs interpolation - for asset in s.query(Asset).all(): + for asset in Asset.all(): asset.autodetect_interpolate() return updated @@ -1030,8 +952,8 @@ def filter_(tables: list[sqlalchemy.Table]) -> list[sqlalchemy.Table]: conn_dst.commit() # Use new encryption key - with self.begin_session() as s: - value_encrypted = Config.fetch(s, ConfigKey.WEB_KEY) + with self.begin_session(): + value_encrypted = Config.fetch(ConfigKey.WEB_KEY, no_raise=True) value = key if value_encrypted is None else self.decrypt_s(value_encrypted) dst.change_web_key(value) @@ -1062,5 +984,5 @@ def change_web_key(self, key: str) -> None: raise exc.InvalidKeyError(msg) key_encrypted = self.encrypt(key) - with self.begin_session() as s: - Config.set_(s, ConfigKey.WEB_KEY, key_encrypted) + with self.begin_session(): + Config.set_(ConfigKey.WEB_KEY, key_encrypted) diff --git a/nummus/sql.py b/nummus/sql.py index 15672ca5..f1cb8a34 100644 --- a/nummus/sql.py +++ b/nummus/sql.py @@ -4,13 +4,17 @@ import base64 import sys -from typing import TYPE_CHECKING +from collections.abc import Sequence +from typing import overload, TYPE_CHECKING import sqlalchemy import sqlalchemy.event +from sqlalchemy import func, orm +from sqlalchemy.sql import case if TYPE_CHECKING: import sqlite3 + from collections.abc import Generator, Iterable from pathlib import Path from nummus.encryption.base import EncryptionInterface @@ -23,6 +27,16 @@ _ENGINE_ARGS: dict[str, object] = {} +Column = ( + orm.InstrumentedAttribute[str] + | orm.InstrumentedAttribute[str | None] + | orm.InstrumentedAttribute[int] + | orm.InstrumentedAttribute[int | None] +) +ColumnClause = sqlalchemy.ColumnElement[bool] + +__all__ = ["case"] + @sqlalchemy.event.listens_for(sqlalchemy.engine.Engine, "connect") def set_sqlite_pragma(db_connection: sqlite3.Connection, *_) -> None: @@ -80,3 +94,201 @@ def escape(s: str) -> str: """ return f"`{s}`" if s in sqlalchemy.sql.compiler.RESERVED_WORDS else s + + +@overload +def to_dict_tuple[K, T0, T1]( + query: orm.query.RowReturningQuery[tuple[K, T0, T1]], +) -> dict[K, tuple[T0, T1]]: ... + + +@overload +def to_dict_tuple[K, T0, T1, T2]( + query: orm.query.RowReturningQuery[tuple[K, T0, T1, T2]], +) -> dict[K, tuple[T0, T1, T2]]: ... + + +@overload +def to_dict_tuple[K, T0, T1, T2, T3]( + query: orm.query.RowReturningQuery[tuple[K, T0, T1, T2, T3]], +) -> dict[K, tuple[T0, T1, T2, T3]]: ... + + +def to_dict_tuple[T: tuple[object, ...]]( # type: ignore[attr-defined] + query: orm.query.RowReturningQuery[T], +) -> dict[object, tuple[object, ...]]: + """Fetch results from query and return a dict. + + Args: + query: Query that returns 2 columns + + Returns: + dict{first column: second column} + or + dict{first column: tuple(other columns)} + + """ + return {r[0]: r[1:] for r in yield_(query)} + + +def to_dict[K, V]( + query: orm.query.RowReturningQuery[tuple[K, V]], +) -> dict[K, V]: + """Fetch results from query and return a dict. + + Args: + query: Query that returns 2 columns + + Returns: + dict{first column: second column} + + """ + return {r[0]: r[1] for r in yield_(query)} + + +def count[T](query: orm.Query[T]) -> int: + """Count the number of result a query will return. + + Args: + query: Session query to execute + + Returns: + Number of instances query will return upon execution + + Raises: + TypeError: if query.statement is not a Select + + """ + # From here: + # https://datawookie.dev/blog/2021/01/sqlalchemy-efficient-counting/ + col_one: sqlalchemy.ColumnClause[object] = sqlalchemy.literal_column("1") + stmt = query.statement + if not isinstance(stmt, sqlalchemy.Select): + raise TypeError + counter = stmt.with_only_columns( + func.count(col_one), + maintain_column_froms=True, + ) + counter = counter.order_by(None) + return query.session.execute(counter).scalar() or 0 # nummus: ignore + + +def any_[T](query: orm.Query[T]) -> bool: + """Check if any rows exists in query. + + Args: + query: Session query to execute + + Returns: + True if any results + + """ + return count(query.limit(1)) != 0 + + +@overload +def one[T0]( + query: orm.query.RowReturningQuery[tuple[T0]], +) -> T0: ... + + +@overload +def one[T0, T1]( + query: orm.query.RowReturningQuery[tuple[T0, T1]], +) -> tuple[T0, T1]: ... + + +@overload +def one[T](query: orm.Query[T]) -> T: ... + + +def one[T](query: orm.Query[T]) -> object: + """Check if any rows exists in query. + + Args: + query: Session query to execute + + Returns: + One result + + """ + ret: T | Sequence[T] = query.one() # nummus: ignore + if not isinstance(ret, Sequence): + return ret + if len(ret) == 1: # type: ignore[attr-defined] + return ret[0] # type: ignore[attr-defined] + return ret[0:] # type: ignore[attr-defined] + + +@overload +def scalar[T0]( + query: orm.query.RowReturningQuery[tuple[T0]], +) -> T0 | None: ... + + +@overload +def scalar[T0, T1]( + query: orm.query.RowReturningQuery[tuple[T0, T1]], +) -> T0 | None: ... + + +@overload +def scalar[T0, T1, T2]( + query: orm.query.RowReturningQuery[tuple[T0, T1, T2]], +) -> T0 | None: ... + + +@overload +def scalar[T](query: orm.Query[T]) -> T | None: ... + + +def scalar[T](query: orm.Query[T]) -> object | None: + """Check if any rows exists in query. + + Args: + query: Session query to execute + + Returns: + One result + + """ + return query.scalar() # nummus: ignore + + +@overload +def yield_[T: tuple[object, ...]]( + query: orm.query.RowReturningQuery[T], +) -> Iterable[T]: ... + + +@overload +def yield_[T](query: orm.Query[T]) -> Iterable[T]: ... + + +def yield_[T](query: orm.Query[T]) -> Iterable[object]: + """Yield a query. + + Args: + query: Query to yield + + Yields: + Rows + + """ + # Yield per instead of fetch all is faster + for r in query.yield_per(100): # nummus: ignore + yield r[0:] if isinstance(r, Sequence) else r + + +def col0[T](query: orm.query.RowReturningQuery[tuple[T]]) -> Generator[T]: + """Yield a query into a list. + + Args: + query: Query to yield + + Yields: + first column + + """ + for (r,) in yield_(query): + yield r diff --git a/nummus/utils.py b/nummus/utils.py index b586bc43..8c139da0 100644 --- a/nummus/utils.py +++ b/nummus/utils.py @@ -16,6 +16,7 @@ from typing import NamedTuple, overload, TYPE_CHECKING import emoji as emoji_mod +import pandas as pd from colorama import Fore from rapidfuzz import process from scipy import optimize @@ -822,7 +823,7 @@ def pretty_table(table: list[list[str] | None]) -> list[str]: col_widths[i] = n margin += n_trim extra = False - excess.append(0 if cell[-1] == "/" else col_widths[i] - n_label) + excess.append(0 if cell[-1] == "/" else max(0, col_widths[i] - n_label)) # Distribute excess while margin < 0 and any(excess): @@ -832,7 +833,7 @@ def pretty_table(table: list[list[str] | None]) -> list[str]: excess[i] -= 1 margin += 1 - formats = [] + formats: list[str] = [] for cell, n in zip(header_raw, col_widths, strict=True): align = cell[0] align = align if align in "<>^" else "" @@ -1073,3 +1074,26 @@ def set_sub_keys[_, T, V](dicts: dict[_, dict[T, V]]) -> set[T]: for d in dicts.values(): keys.update(d.keys()) return keys + + +def pd_series_to_dict(s: pd.Series[float]) -> dict[int, float]: + """Convert pandas series to dict. + + Args: + s: pandas series + + Returns: + dict{date ordinal: value} + + Raises: + TypeError: if columns are not date, floatable + + """ + d: dict[int, float] = {} + for k, v in s.items(): + if not isinstance(k, pd.Timestamp): + raise TypeError + if not isinstance(v, float): + raise TypeError + d[k.to_pydatetime().date().toordinal()] = float(v) + return d diff --git a/nummus/web.py b/nummus/web.py index cc81e8fb..f2b19e33 100644 --- a/nummus/web.py +++ b/nummus/web.py @@ -3,7 +3,9 @@ from __future__ import annotations import datetime +import functools import os +from decimal import Decimal from pathlib import Path from typing import TYPE_CHECKING @@ -68,28 +70,30 @@ def init_app(self, app: flask.Flask) -> None: self._init_metrics(app) # Inject common variables into templates - app.context_processor( - lambda: { - "url_args": {}, - }, - ) + args: dict[str, dict[str, object]] = { + "url_args": {}, + } + app.context_processor(lambda: args) @classmethod - def _open_portfolio(cls, config: flask.Config) -> Portfolio: - path = ( - Path( - config.get("PORTFOLIO", "~/.nummus/portfolio.db"), - ) - .expanduser() - .absolute() - ) + def _open_portfolio(cls, config: dict[str, object]) -> Portfolio: + s = config.get("PORTFOLIO", "~/.nummus/portfolio.db") + if not isinstance(s, str): + raise TypeError + path = Path(s).expanduser().absolute() key = config.get("KEY") if key is None: path_key = config.get("KEY_PATH") - path_key = Path(path_key).expanduser().absolute() if path_key else None + path_key = ( + Path(path_key).expanduser().absolute() + if isinstance(path_key, str) + else None + ) if path_key and path_key.exists(): key = path_key.read_text("utf-8").strip() + elif not isinstance(key, str): + raise TypeError return Portfolio(path, key) @@ -129,27 +133,25 @@ def _add_routes(cls, app: flask.Flask) -> None: @classmethod def _init_auth(cls, app: flask.Flask, p: Portfolio) -> None: - with p.begin_session() as s: - secret_key = Config.fetch(s, ConfigKey.SECRET_KEY) + with p.begin_session(): + secret_key = Config.fetch(ConfigKey.SECRET_KEY) app.secret_key = secret_key - app.config.update( - SESSION_COOKIE_SECURE=True, - SESSION_COOKIE_HTTPONLY=True, - SESSION_COOKIE_SAMESITE="Lax", - REMEMBER_COOKIE_SECURE=True, - REMEMBER_COOKIE_HTTPONLY=True, - REMEMBER_COOKIE_SAMESITE="Lax", - REMEMBER_COOKIE_DURATION=datetime.timedelta(days=28), - ) + config: dict[str, object] = app.config + config["SESSION_COOKIE_SECURE"] = True + config["SESSION_COOKIE_HTTPONLY"] = True + config["SESSION_COOKIE_SAMESITE"] = "Lax" + config["REMEMBER_COOKIE_SECURE"] = True + config["REMEMBER_COOKIE_HTTPONLY"] = True + config["REMEMBER_COOKIE_SAMESITE"] = "Lax" + config["REMEMBER_COOKIE_DURATION"] = datetime.timedelta(days=28) app.after_request(base.update_client_timezone) app.after_request(base.change_redirect_to_htmx) login_manager = flask_login.LoginManager() login_manager.init_app(app) login_manager.user_loader(auth.get_user) - # LoginManager.login_view not typed to str | None - login_manager.login_view = "auth.page_login" # type: ignore[attr-defined] + login_manager.login_view = "auth.page_login" if p.is_encrypted: # Only can have authentiation with encrypted portfolio @@ -159,25 +161,41 @@ def _init_auth(cls, app: flask.Flask, p: Portfolio) -> None: def _init_jinja_env(cls, env: jinja2.Environment) -> None: env.filters["seconds"] = utils.format_seconds env.filters["days"] = utils.format_days - env.filters["days_abv"] = lambda x: utils.format_days( - x, - ["days", "wks", "mos", "yrs"], + env.filters["days_abv"] = functools.partial( + utils.format_days, + labels=["days", "wks", "mos", "yrs"], ) env.filters["comma"] = lambda x: f"{x:,.2f}" env.filters["qty"] = lambda x: f"{x:,.6f}" - env.filters["percent"] = lambda x: f"{x * 100:5.2f}%" - env.filters["pnl_color"] = lambda x: ( - "" if x is None or x == 0 else ("text-primary" if x > 0 else "text-error") - ) - env.filters["pnl_arrow"] = lambda x: ( - "" - if x is None or x == 0 - else ("arrow_upward" if x > 0 else "arrow_downward") - ) env.filters["no_emojis"] = utils.strip_emojis env.filters["tojson"] = base.ctx_to_json env.filters["input_value"] = lambda x: str(x or "").rstrip("0").rstrip(".") + def percent(x: Decimal | float | object) -> str: + if not isinstance(x, Decimal | float): + raise TypeError + return f"{x * 100:5.2f}%" + + env.filters["percent"] = percent + + def pnl_color(x: Decimal | float | object) -> str: + if not x: + return "" + if not isinstance(x, Decimal | float): + raise TypeError + return "text-primary" if x > 0 else "text-error" + + env.filters["pnl_color"] = pnl_color + + def pnl_arrow(x: Decimal | float | object) -> str: + if not x: + return "" + if not isinstance(x, Decimal | float): + raise TypeError + return "arrow_upward" if x > 0 else "arrow_downward" + + env.filters["pnl_arrow"] = pnl_arrow + @classmethod def _init_metrics(cls, app: flask.Flask) -> None: multiproc = "PROMETHEUS_MULTIPROC_DIR" in os.environ diff --git a/nummus/web_assets.py b/nummus/web_assets.py index e9ca943b..2536681f 100644 --- a/nummus/web_assets.py +++ b/nummus/web_assets.py @@ -29,7 +29,7 @@ class TailwindCSSFilter(webassets.filter.Filter): DEBUG = False @override - def output(self, _in: io.StringIO, out: io.StringIO, **_) -> None: + def output(self, _in: io.StringIO, out: io.StringIO, **_: object) -> None: if pytailwindcss is None: raise NotImplementedError path_root = Path(__file__).parent.resolve() @@ -54,7 +54,7 @@ class JSMinFilter(webassets.filter.Filter): """webassets Filter for running jsmin over.""" @override - def output(self, _in: io.StringIO, out: io.StringIO, **_) -> None: + def output(self, _in: io.StringIO, out: io.StringIO, **_: object) -> None: if jsmin is None: raise NotImplementedError # Add back tick to quote_chars for template strings diff --git a/pyproject.toml b/pyproject.toml index 60e0be11..c76360d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,8 +76,10 @@ dev = [ "isort", "pre-commit", "djlint>=1.36.4", - "pyright>=1.1.402", + "pyright>=1.1.408", "taplo", + "scipy-stubs", + "pandas-stubs", ] build = ["build"] build-docker = ["nummus-financial[build,encrypt]", "setuptools-scm>=8"] @@ -148,9 +150,20 @@ force_alphabetical_sort_within_sections = true [tool.pyright] include = ["nummus", "tests", "tools"] -exclude = ["**/__pycache__"] +exclude = ["**/__pycache__", "typing"] venvPath = "." venv = ".venv" +typeCheckingMode = "strict" + +reportUnknownParameterType = "warning" +reportUnknownArgumentType = "warning" +reportUnknownLambdaType = "warning" +reportUnknownVariableType = "warning" +reportUnknownMemberType = "warning" +reportUnnecessaryTypeIgnoreComment = "warning" + +# Checked by ruff +reportPrivateUsage = "none" [tool.pytest.ini_options] markers = ["encryption: tests that require encryption"] @@ -158,6 +171,7 @@ addopts = ["--durations=10", "--import-mode=importlib"] [tool.ruff] target-version = "py312" +exclude = ["typings"] [tool.ruff.lint] select = ["ALL"] @@ -212,6 +226,7 @@ ignore = [ "FBT001", # Boolean positional arguments "T201", # Allow printing for extra context "SLF001", # Allow access to privates + "ARG001", # Allow unused arguments for fixtures ] "tests/controllers/*.py" = [ "N802", # Allow CAPS methods diff --git a/tests/commands/test_backup.py b/tests/commands/test_backup.py index d2853064..e6108f44 100644 --- a/tests/commands/test_backup.py +++ b/tests/commands/test_backup.py @@ -14,7 +14,7 @@ from nummus.portfolio import Portfolio -def test_backup(capsys: pytest.CaptureFixture, empty_portfolio: Portfolio) -> None: +def test_backup(capsys: pytest.CaptureFixture[str], empty_portfolio: Portfolio) -> None: c = Backup(empty_portfolio.path, None) assert c.run() == 0 @@ -30,7 +30,10 @@ def test_backup(capsys: pytest.CaptureFixture, empty_portfolio: Portfolio) -> No assert not captured.err -def test_restore(capsys: pytest.CaptureFixture, empty_portfolio: Portfolio) -> None: +def test_restore( + capsys: pytest.CaptureFixture[str], + empty_portfolio: Portfolio, +) -> None: empty_portfolio.backup() c = Restore(empty_portfolio.path, None, tar_ver=None, list_ver=False) assert c.run() == 0 @@ -45,7 +48,7 @@ def test_restore(capsys: pytest.CaptureFixture, empty_portfolio: Portfolio) -> N def test_restore_missing( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], empty_portfolio: Portfolio, ) -> None: c = Restore(empty_portfolio.path, None, tar_ver=None, list_ver=False) @@ -58,7 +61,7 @@ def test_restore_missing( def test_restore_list_empty( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], empty_portfolio: Portfolio, ) -> None: c = Restore(empty_portfolio.path, None, tar_ver=None, list_ver=True) @@ -71,7 +74,7 @@ def test_restore_list_empty( def test_restore_list( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], empty_portfolio: Portfolio, utc_frozen: datetime.datetime, ) -> None: diff --git a/tests/commands/test_base.py b/tests/commands/test_base.py index a7cd28ae..050a0b27 100644 --- a/tests/commands/test_base.py +++ b/tests/commands/test_base.py @@ -32,7 +32,7 @@ class MockCommand(Command): @classmethod def setup_args(cls, parser: argparse.ArgumentParser) -> None: - _ = parser + pass @override def run(self) -> int: @@ -43,7 +43,7 @@ def test_no_unlock(tmp_path: Path) -> None: MockCommand(tmp_path / "fake.db", None, do_unlock=False) -def test_no_file(capsys: pytest.CaptureFixture, tmp_path: Path) -> None: +def test_no_file(capsys: pytest.CaptureFixture[str], tmp_path: Path) -> None: path = tmp_path / "fake.db" with pytest.raises(SystemExit): MockCommand(path, None) @@ -54,7 +54,7 @@ def test_no_file(capsys: pytest.CaptureFixture, tmp_path: Path) -> None: assert captured.err == target -def test_unlock(capsys: pytest.CaptureFixture, empty_portfolio: Portfolio) -> None: +def test_unlock(capsys: pytest.CaptureFixture[str], empty_portfolio: Portfolio) -> None: MockCommand(empty_portfolio.path, None) captured = capsys.readouterr() @@ -63,7 +63,10 @@ def test_unlock(capsys: pytest.CaptureFixture, empty_portfolio: Portfolio) -> No assert not captured.err -def test_migration_required(capsys: pytest.CaptureFixture, data_path: Path) -> None: +def test_migration_required( + capsys: pytest.CaptureFixture[str], + data_path: Path, +) -> None: with pytest.raises(SystemExit): MockCommand(data_path / "old_versions" / "v0.1.16.db", None) @@ -80,7 +83,7 @@ def test_migration_required(capsys: pytest.CaptureFixture, data_path: Path) -> N @pytest.mark.skipif(not ENCRYPTION_AVAILABLE, reason="Encryption is not installed") @pytest.mark.encryption def test_unlock_encrypted_path( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], empty_portfolio_encrypted: tuple[Portfolio, str], tmp_path: Path, ) -> None: @@ -99,7 +102,7 @@ def test_unlock_encrypted_path( @pytest.mark.skipif(not ENCRYPTION_AVAILABLE, reason="Encryption is not installed") @pytest.mark.encryption def test_unlock_encrypted_path_bad_key( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], empty_portfolio_encrypted: tuple[Portfolio, str], tmp_path: Path, ) -> None: @@ -119,7 +122,7 @@ def test_unlock_encrypted_path_bad_key( @pytest.mark.skipif(not ENCRYPTION_AVAILABLE, reason="Encryption is not installed") @pytest.mark.encryption def test_unlock_encrypted( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch, empty_portfolio_encrypted: tuple[Portfolio, str], ) -> None: @@ -147,7 +150,7 @@ def mock_get_pass(to_print: str) -> str | None: @pytest.mark.skipif(not ENCRYPTION_AVAILABLE, reason="Encryption is not installed") @pytest.mark.encryption def test_unlock_encrypted_cancel( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch, empty_portfolio_encrypted: tuple[Portfolio, str], ) -> None: @@ -171,7 +174,7 @@ def mock_get_pass(to_print: str) -> str | None: @pytest.mark.skipif(not ENCRYPTION_AVAILABLE, reason="Encryption is not installed") @pytest.mark.encryption def test_unlock_encrypted_failed( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch, empty_portfolio_encrypted: tuple[Portfolio, str], ) -> None: diff --git a/tests/commands/test_change_password.py b/tests/commands/test_change_password.py index 7e26ec8d..2ac0b1b3 100644 --- a/tests/commands/test_change_password.py +++ b/tests/commands/test_change_password.py @@ -26,7 +26,7 @@ def change_web_key(self, key: str) -> None: def test_no_change_unencrypted( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], empty_portfolio: Portfolio, tmp_path: Path, ) -> None: @@ -49,7 +49,7 @@ def test_no_change_unencrypted( ], ) def test_change( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch, empty_portfolio: Portfolio, tmp_path: Path, diff --git a/tests/commands/test_clean.py b/tests/commands/test_clean.py index 563ee5dc..8d6cd964 100644 --- a/tests/commands/test_clean.py +++ b/tests/commands/test_clean.py @@ -12,7 +12,7 @@ from nummus.portfolio import Portfolio -def test_clean(capsys: pytest.CaptureFixture, empty_portfolio: Portfolio) -> None: +def test_clean(capsys: pytest.CaptureFixture[str], empty_portfolio: Portfolio) -> None: c = Clean(empty_portfolio.path, None) assert c.run() == 0 diff --git a/tests/commands/test_create.py b/tests/commands/test_create.py index 0ca4dd38..49a0c2d4 100644 --- a/tests/commands/test_create.py +++ b/tests/commands/test_create.py @@ -28,7 +28,7 @@ def create(cls, path: str | Path, key: str | None = None) -> Portfolio: def test_create_existing( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], tmp_path: Path, ) -> None: path = tmp_path / "new.db" @@ -44,7 +44,7 @@ def test_create_existing( def test_create_unencrypted_forced( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch, tmp_path: Path, ) -> None: @@ -63,7 +63,7 @@ def test_create_unencrypted_forced( def test_create_unencrypted( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch, tmp_path: Path, ) -> None: @@ -81,7 +81,7 @@ def test_create_unencrypted( def test_create_encrypted( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch, tmp_path: Path, rand_str: str, @@ -108,7 +108,7 @@ def mock_get_pass(_: str) -> str | None: def test_create_encrypted_pass_file( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch, tmp_path: Path, rand_str: str, @@ -130,7 +130,7 @@ def test_create_encrypted_pass_file( def test_create_encrypted_cancelled( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch, tmp_path: Path, ) -> None: diff --git a/tests/commands/test_export.py b/tests/commands/test_export.py index 0b44833b..d84e46bc 100644 --- a/tests/commands/test_export.py +++ b/tests/commands/test_export.py @@ -18,7 +18,7 @@ def test_export_empty( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], empty_portfolio: Portfolio, tmp_path: Path, ) -> None: @@ -43,7 +43,7 @@ def test_export_empty( def test_export( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], empty_portfolio: Portfolio, account: Account, transactions: list[Transaction], diff --git a/tests/commands/test_health.py b/tests/commands/test_health.py index 5f34c9ce..5a33d96a 100644 --- a/tests/commands/test_health.py +++ b/tests/commands/test_health.py @@ -9,7 +9,6 @@ from nummus.health_checks.top import HEALTH_CHECKS from nummus.models.config import Config, ConfigKey from nummus.models.health_checks import HealthCheckIssue -from nummus.models.utils import query_count from nummus.portfolio import Portfolio if TYPE_CHECKING: @@ -20,7 +19,7 @@ def test_issues( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], empty_portfolio: Portfolio, utc_frozen: datetime.datetime, ) -> None: @@ -47,13 +46,13 @@ def test_issues( assert f"{Fore.MAGENTA}Use web interface to fix issues" in captured.out assert not captured.err - with empty_portfolio.begin_session() as s: - v = Config.fetch(s, ConfigKey.LAST_HEALTH_CHECK_TS) + with empty_portfolio.begin_session(): + v = Config.fetch(ConfigKey.LAST_HEALTH_CHECK_TS) assert v == utc_frozen.isoformat() def test_no_limit_severe( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch, empty_portfolio: Portfolio, utc_frozen: datetime.datetime, @@ -85,25 +84,24 @@ def test_no_limit_severe( assert f"{Fore.MAGENTA}Use web interface to fix issues" in captured.out assert not captured.err - with empty_portfolio.begin_session() as s: - v = Config.fetch(s, ConfigKey.LAST_HEALTH_CHECK_TS) + with empty_portfolio.begin_session(): + v = Config.fetch(ConfigKey.LAST_HEALTH_CHECK_TS) assert v == utc_frozen.isoformat() def test_ignore_all( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], empty_portfolio: Portfolio, utc_frozen: datetime.datetime, ) -> None: ignores: list[str] = [] for check_type in HEALTH_CHECKS: c = check_type() - with empty_portfolio.begin_session() as s: - c.test(s) + with empty_portfolio.begin_session(): + c.test() ignores.extend(c.issues.keys()) - with empty_portfolio.begin_session() as s: + with empty_portfolio.begin_session(): Config.set_( - s, ConfigKey.LAST_HEALTH_CHECK_TS, (utc_frozen - datetime.timedelta(days=1)).isoformat(), ) @@ -131,22 +129,22 @@ def test_ignore_all( assert f"{Fore.MAGENTA}Use web interface to fix issues" not in captured.out assert not captured.err - with empty_portfolio.begin_session() as s: - v = Config.fetch(s, ConfigKey.LAST_HEALTH_CHECK_TS) + with empty_portfolio.begin_session(): + v = Config.fetch(ConfigKey.LAST_HEALTH_CHECK_TS) assert v == utc_frozen.isoformat() - assert query_count(s.query(HealthCheckIssue)) == len(ignores) + assert HealthCheckIssue.count() == len(ignores) def test_clear_ignores( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], empty_portfolio: Portfolio, ) -> None: ignores: list[str] = [] for check_type in HEALTH_CHECKS: c = check_type() - with empty_portfolio.begin_session() as s: - c.test(s) + with empty_portfolio.begin_session(): + c.test() ignores.extend(c.issues.keys()) c = Health( @@ -168,5 +166,5 @@ def test_clear_ignores( assert "has no transactions nor budget assignments" in captured.out assert not captured.err - with empty_portfolio.begin_session() as s: - assert query_count(s.query(HealthCheckIssue)) == len(ignores) + with empty_portfolio.begin_session(): + assert HealthCheckIssue.count() == len(ignores) diff --git a/tests/commands/test_import_files.py b/tests/commands/test_import_files.py index 63463484..c7b4e8ac 100644 --- a/tests/commands/test_import_files.py +++ b/tests/commands/test_import_files.py @@ -17,7 +17,7 @@ from nummus.portfolio import Portfolio -def test_empty(capsys: pytest.CaptureFixture, empty_portfolio: Portfolio) -> None: +def test_empty(capsys: pytest.CaptureFixture[str], empty_portfolio: Portfolio) -> None: path_debug = empty_portfolio.path.with_suffix(".importer_debug") c = Import(empty_portfolio.path, None, [], force=False) @@ -32,7 +32,7 @@ def test_empty(capsys: pytest.CaptureFixture, empty_portfolio: Portfolio) -> Non def test_non_existant( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], empty_portfolio: Portfolio, tmp_path: Path, ) -> None: @@ -55,7 +55,7 @@ def test_non_existant( def test_data_dir( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], empty_portfolio: Portfolio, account: Account, account_investments: Account, @@ -63,9 +63,6 @@ def test_data_dir( tmp_path: Path, data_path: Path, ) -> None: - _ = account - _ = account_investments - _ = asset files = [ "transactions_required.csv", "transactions_extras.csv", @@ -87,7 +84,7 @@ def test_data_dir( def test_unknown_importer( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], empty_portfolio: Portfolio, account: Account, account_investments: Account, @@ -95,9 +92,6 @@ def test_unknown_importer( tmp_path: Path, data_path: Path, ) -> None: - _ = account - _ = account_investments - _ = asset file = "transactions_lacking.csv" path = tmp_path / file shutil.copyfile(data_path / file, path) @@ -120,7 +114,7 @@ def test_unknown_importer( def test_duplicate( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], today: datetime.date, empty_portfolio: Portfolio, account: Account, @@ -129,9 +123,6 @@ def test_duplicate( tmp_path: Path, data_path: Path, ) -> None: - _ = account - _ = account_investments - _ = asset file = "transactions_required.csv" path = tmp_path / file shutil.copyfile(data_path / file, path) @@ -155,7 +146,7 @@ def test_duplicate( @pytest.mark.xfail def test_data_dir_no_account( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], empty_portfolio: Portfolio, data_path: Path, ) -> None: diff --git a/tests/commands/test_migrate.py b/tests/commands/test_migrate.py index b9259419..b4322ff3 100644 --- a/tests/commands/test_migrate.py +++ b/tests/commands/test_migrate.py @@ -16,7 +16,7 @@ def test_not_required( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], empty_portfolio: Portfolio, ) -> None: @@ -33,7 +33,7 @@ def test_not_required( def test_v0_1_migration( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], tmp_path: Path, data_path: Path, ) -> None: diff --git a/tests/commands/test_summarize.py b/tests/commands/test_summarize.py index 1d142895..b191d765 100644 --- a/tests/commands/test_summarize.py +++ b/tests/commands/test_summarize.py @@ -52,9 +52,6 @@ def test_non_empty_summary( transactions: list[Transaction], asset_valuation: AssetValuation, ) -> None: - _ = transactions - _ = asset_valuation - c = Summarize(empty_portfolio.path, None, include_all=False) utc = utc.replace(tzinfo=zoneinfo.ZoneInfo("UTC")) @@ -102,10 +99,8 @@ def test_exclude_empty( account: Account, asset: Asset, ) -> None: - account.closed = True - session.commit() - _ = asset - + with session.begin_nested(): + account.closed = True c = Summarize(empty_portfolio.path, None, include_all=False) result = c._get_summary() @@ -125,10 +120,9 @@ def test_exclude_empty( def test_empty_print( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], empty_portfolio: Portfolio, ) -> None: - c = Summarize(empty_portfolio.path, None, include_all=False) assert c.run() == 0 @@ -142,15 +136,12 @@ def test_empty_print( def test_non_empty_print( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], utc: datetime.datetime, empty_portfolio: Portfolio, transactions: list[Transaction], asset_valuation: AssetValuation, ) -> None: - _ = transactions - _ = asset_valuation - c = Summarize(empty_portfolio.path, None, include_all=False) utc = utc.replace(tzinfo=zoneinfo.ZoneInfo("UTC")) diff --git a/tests/commands/test_unlock.py b/tests/commands/test_unlock.py index 8f17c199..03844c52 100644 --- a/tests/commands/test_unlock.py +++ b/tests/commands/test_unlock.py @@ -14,7 +14,7 @@ def test_empty( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], empty_portfolio: Portfolio, ) -> None: diff --git a/tests/commands/test_update_assets.py b/tests/commands/test_update_assets.py index e396bd2a..e5afeef5 100644 --- a/tests/commands/test_update_assets.py +++ b/tests/commands/test_update_assets.py @@ -20,10 +20,9 @@ def test_empty( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], empty_portfolio: Portfolio, ) -> None: - c = UpdateAssets(empty_portfolio.path, None, no_bars=True) assert c.run() == 0 @@ -38,15 +37,14 @@ def test_empty( def test_one( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], today: datetime.date, empty_portfolio: Portfolio, transactions: list[Transaction], asset: Asset, ) -> None: - _ = transactions - with empty_portfolio.begin_session() as s: - s.query(Asset).where(Asset.category == AssetCategory.INDEX).delete() + with empty_portfolio.begin_session(): + Asset.query().where(Asset.category == AssetCategory.INDEX).delete() c = UpdateAssets(empty_portfolio.path, None, no_bars=True) assert c.run() == 0 @@ -62,15 +60,15 @@ def test_one( def test_failed( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], empty_portfolio: Portfolio, transactions: list[Transaction], asset: Asset, ) -> None: - _ = transactions - with empty_portfolio.begin_session() as s: - s.query(Asset).where(Asset.category == AssetCategory.INDEX).delete() - s.query(Asset).update({"ticker": "FAKE"}) + with empty_portfolio.begin_session(): + Asset.query().where(Asset.category == AssetCategory.INDEX).delete() + Asset.query().update({"ticker": "FAKE"}) + asset.refresh() c = UpdateAssets(empty_portfolio.path, None, no_bars=True) assert c.run() != 0 diff --git a/tests/conftest.py b/tests/conftest.py index cad6d7db..4c324fe0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -26,6 +26,7 @@ AssetValuation, USSector, ) +from nummus.models.base import Base from nummus.models.budget import ( BudgetAssignment, BudgetGroup, @@ -41,6 +42,8 @@ from tests.mock_yfinance import MockTicker if TYPE_CHECKING: + from collections.abc import Generator + import time_machine @@ -48,7 +51,7 @@ def id_func(val: object) -> str | None: if isinstance(val, datetime.date): return val.isoformat() if isinstance(val, Iterable | Decimal | Path): - return str(val) + return str(val) # type: ignore[attr-defined] if callable(val): return val.__name__ return None @@ -211,21 +214,23 @@ def empty_portfolio_encrypted( return p, key -@pytest.fixture -def session(empty_portfolio: Portfolio) -> orm.Session: +@pytest.fixture(autouse=True) +def session(empty_portfolio: Portfolio) -> Generator[orm.Session]: """Create SQL session. - Returns: - Session generator + Yields: + Session """ - return orm.Session(sql.get_engine(empty_portfolio.path, None)) + s = orm.Session(sql.get_engine(empty_portfolio.path, None)) + with Base.set_session(s): + yield s -@pytest.fixture(autouse=True) +@pytest.fixture(autouse=True, scope="session") def uri_cipher() -> None: """Generate a URI cipher.""" - base_uri._CIPHER = base_uri.Cipher.generate() + base_uri._cipher = base_uri.Cipher.generate() @pytest.fixture(autouse=True) @@ -308,18 +313,16 @@ def account(session: orm.Session, rand_str_generator: RandomStringGenerator) -> Checking Account, not closed, budgeted """ - acct = Account( - name="Monkey bank checking", - institution="Monkey bank", - category=AccountCategory.CASH, - closed=False, - budgeted=True, - currency=DEFAULT_CURRENCY, - number=rand_str_generator(), - ) - session.add(acct) - session.commit() - return acct + with session.begin_nested(): + return Account.create( + name="Monkey bank checking", + institution="Monkey bank", + category=AccountCategory.CASH, + closed=False, + budgeted=True, + currency=DEFAULT_CURRENCY, + number=rand_str_generator(), + ) @pytest.fixture @@ -330,19 +333,17 @@ def account_savings(session: orm.Session) -> Account: Savings Account, not closed, not budgeted """ - acct = Account( - # capital case for HTML header check - name="Monkey bank savings", - institution="Monkey bank", - category=AccountCategory.CASH, - closed=False, - budgeted=False, - currency=DEFAULT_CURRENCY, - number="1234", - ) - session.add(acct) - session.commit() - return acct + with session.begin_nested(): + return Account.create( + # capital case for HTML header check + name="Monkey bank savings", + institution="Monkey bank", + category=AccountCategory.CASH, + closed=False, + budgeted=False, + currency=DEFAULT_CURRENCY, + number="1234", + ) @pytest.fixture @@ -353,18 +354,16 @@ def account_investments(session: orm.Session) -> Account: Investments Account, not closed, not budgeted """ - acct = Account( - name="Monkey bank investments", - institution="Monkey bank", - category=AccountCategory.INVESTMENT, - closed=False, - budgeted=False, - currency=DEFAULT_CURRENCY, - number="1235", - ) - session.add(acct) - session.commit() - return acct + with session.begin_nested(): + return Account.create( + name="Monkey bank investments", + institution="Monkey bank", + category=AccountCategory.INVESTMENT, + closed=False, + budgeted=False, + currency=DEFAULT_CURRENCY, + number="1235", + ) @pytest.fixture @@ -375,7 +374,7 @@ def categories(session: orm.Session) -> dict[str, int]: dict{name: category id} """ - return {name: id_ for id_, name in TransactionCategory.map_name(session).items()} + return {name: id_ for id_, name in TransactionCategory.map_name().items()} @pytest.fixture @@ -386,10 +385,11 @@ def labels(session: orm.Session) -> dict[str, int]: dict{name: label id} """ - labels = {"engineer", "fruit", "apartments 4 U"} - session.add_all(Label(name=name) for name in labels) - session.commit() - return {name: id_ for id_, name in Label.map_name(session).items()} + with session.begin_nested(): + labels = {"engineer", "fruit", "apartments 4 U"} + for name in labels: + Label.create(name=name) + return {name: id_ for id_, name in Label.map_name().items()} @pytest.fixture @@ -400,16 +400,14 @@ def asset(session: orm.Session) -> Asset: Banana Incorporated, STOCKS """ - asset = Asset( - name="Banana incorporated", - category=AssetCategory.STOCKS, - ticker="BANANA", - description="Banana Incorporated makes bananas", - currency=DEFAULT_CURRENCY, - ) - session.add(asset) - session.commit() - return asset + with session.begin_nested(): + return Asset.create( + name="Banana incorporated", + category=AssetCategory.STOCKS, + ticker="BANANA", + description="Banana Incorporated makes bananas", + currency=DEFAULT_CURRENCY, + ) @pytest.fixture @@ -420,16 +418,14 @@ def asset_etf(session: orm.Session) -> Asset: Banana ETF, STOCKS """ - asset = Asset( - name="Banana ETF", - category=AssetCategory.STOCKS, - ticker="BANANA_ETF", - description="Banana ETF", - currency=DEFAULT_CURRENCY, - ) - session.add(asset) - session.commit() - return asset + with session.begin_nested(): + return Asset.create( + name="Banana ETF", + category=AssetCategory.STOCKS, + ticker="BANANA_ETF", + description="Banana ETF", + currency=DEFAULT_CURRENCY, + ) @pytest.fixture @@ -444,10 +440,8 @@ def asset_valuation( AssetValuation on today of $10 """ - v = AssetValuation(asset_id=asset.id_, date_ord=today_ord, value=2) - session.add(v) - session.commit() - return v + with session.begin_nested(): + return AssetValuation.create(asset_id=asset.id_, date_ord=today_ord, value=2) @pytest.fixture @@ -462,10 +456,8 @@ def asset_split( AssetSplit on today of 10:1 """ - v = AssetSplit(asset_id=asset.id_, date_ord=today_ord, multiplier=10) - session.add(v) - session.commit() - return v + with session.begin_nested(): + return AssetSplit.create(asset_id=asset.id_, date_ord=today_ord, multiplier=10) @pytest.fixture @@ -479,19 +471,18 @@ def asset_sectors( 20% BASIC_MATERIALS, 80% TECHNOLOGY """ - s0 = AssetSector( - asset_id=asset.id_, - sector=USSector.BASIC_MATERIALS, - weight=Decimal("0.2"), - ) - s1 = AssetSector( - asset_id=asset.id_, - sector=USSector.TECHNOLOGY, - weight=Decimal("0.8"), - ) - session.add_all((s0, s1)) - session.commit() - return s0, s1 + with session.begin_nested(): + s0 = AssetSector.create( + asset_id=asset.id_, + sector=USSector.BASIC_MATERIALS, + weight=Decimal("0.2"), + ) + s1 = AssetSector.create( + asset_id=asset.id_, + sector=USSector.TECHNOLOGY, + weight=Decimal("0.8"), + ) + return s0, s1 @pytest.fixture @@ -505,10 +496,8 @@ def budget_group( BudgetGroup with position 0 """ - g = BudgetGroup(name=rand_str_generator(), position=0) - session.add(g) - session.commit() - return g + with session.begin_nested(): + return BudgetGroup.create(name=rand_str_generator(), position=0) @pytest.fixture @@ -521,84 +510,80 @@ def transactions( categories: dict[str, int], labels: dict[str, int], ) -> list[Transaction]: - # Fund account on 3 days before today - txn = Transaction( - account_id=account.id_, - date=today - datetime.timedelta(days=3), - amount=100, - statement=rand_str_generator(), - payee="Monkey Bank", - cleared=True, - ) - t_split_0 = TransactionSplit( - parent=txn, - amount=txn.amount, - category_id=categories["other income"], - ) - session.add_all((txn, t_split_0)) - - # Buy asset on 2 days before today - txn = Transaction( - account_id=account.id_, - date=today - datetime.timedelta(days=2), - amount=-10, - statement=rand_str_generator(), - payee="Monkey Bank", - cleared=True, - ) - t_split_1 = TransactionSplit( - parent=txn, - amount=txn.amount, - asset_id=asset.id_, - asset_quantity_unadjusted=10, - category_id=categories["securities traded"], - ) - session.add_all((txn, t_split_1)) - - # Sell asset tomorrow - txn = Transaction( - account_id=account.id_, - date=today + datetime.timedelta(days=1), - amount=50, - statement=rand_str_generator(), - payee="Monkey Bank", - cleared=True, - ) - t_split = TransactionSplit( - parent=txn, - amount=txn.amount, - asset_id=asset.id_, - asset_quantity_unadjusted=-5, - category_id=categories["securities traded"], - memo="for rent", - ) - session.add_all((txn, t_split)) - - # Sell remaining next week - txn = Transaction( - account_id=account.id_, - date=today + datetime.timedelta(days=7), - amount=50, - statement=rand_str_generator(), - payee="Monkey Bank", - cleared=True, - ) - t_split = TransactionSplit( - parent=txn, - amount=txn.amount, - asset_id=asset.id_, - asset_quantity_unadjusted=-5, - category_id=categories["securities traded"], - memo="rent transfer", - ) - session.add_all((txn, t_split)) + with session.begin_nested(): + # Fund account on 3 days before today + txn = Transaction.create( + account_id=account.id_, + date=today - datetime.timedelta(days=3), + amount=100, + statement=rand_str_generator(), + payee="Monkey Bank", + cleared=True, + ) + t_split_0 = TransactionSplit.create( + parent=txn, + amount=txn.amount, + category_id=categories["other income"], + ) - session.commit() + # Buy asset on 2 days before today + txn = Transaction.create( + account_id=account.id_, + date=today - datetime.timedelta(days=2), + amount=-10, + statement=rand_str_generator(), + payee="Monkey Bank", + cleared=True, + ) + t_split_1 = TransactionSplit.create( + parent=txn, + amount=txn.amount, + asset_id=asset.id_, + asset_quantity_unadjusted=10, + category_id=categories["securities traded"], + ) + + # Sell asset tomorrow + txn = Transaction.create( + account_id=account.id_, + date=today + datetime.timedelta(days=1), + amount=50, + statement=rand_str_generator(), + payee="Monkey Bank", + cleared=True, + ) + TransactionSplit.create( + parent=txn, + amount=txn.amount, + asset_id=asset.id_, + asset_quantity_unadjusted=-5, + category_id=categories["securities traded"], + memo="for rent", + ) - session.add(LabelLink(label_id=labels["engineer"], t_split_id=t_split_0.id_)) - session.add(LabelLink(label_id=labels["engineer"], t_split_id=t_split_1.id_)) - session.commit() - return session.query(Transaction).order_by(Transaction.date_ord).all() + # Sell remaining next week + txn = Transaction.create( + account_id=account.id_, + date=today + datetime.timedelta(days=7), + amount=50, + statement=rand_str_generator(), + payee="Monkey Bank", + cleared=True, + ) + TransactionSplit.create( + parent=txn, + amount=txn.amount, + asset_id=asset.id_, + asset_quantity_unadjusted=-5, + category_id=categories["securities traded"], + memo="rent transfer", + ) + + LabelLink.create(label_id=labels["engineer"], t_split_id=t_split_0.id_) + LabelLink.create(label_id=labels["engineer"], t_split_id=t_split_1.id_) + return ( + Transaction.query().order_by(Transaction.date_ord).all() # nummus: ignore + ) @pytest.fixture @@ -612,59 +597,113 @@ def transactions_spending( categories: dict[str, int], labels: dict[str, int], ) -> list[Transaction]: - statement_income = rand_str_generator() - statement_groceries = rand_str_generator() - statement_rent = rand_str_generator() - specs = [ - (account, Decimal(100), statement_income, "other income"), - (account, Decimal(100), statement_income, "other income"), - (account, Decimal(120), statement_income, "other income"), - (account, Decimal(-10), statement_groceries, "groceries"), - (account, Decimal(-10), statement_groceries + " other word", "groceries"), - (account, Decimal(-50), statement_rent, "rent"), - (account, Decimal(1000), rand_str_generator(), "other income"), - (account_savings, Decimal(100), statement_income, "other income"), - ] - for acct, amount, statement, category in specs: - txn = Transaction( - account_id=acct.id_, + with session.begin_nested(): + statement_income = rand_str_generator() + statement_groceries = rand_str_generator() + statement_rent = rand_str_generator() + specs = [ + (account, Decimal(100), statement_income, "other income"), + (account, Decimal(100), statement_income, "other income"), + (account, Decimal(120), statement_income, "other income"), + (account, Decimal(-10), statement_groceries, "groceries"), + (account, Decimal(-10), statement_groceries + " other word", "groceries"), + (account, Decimal(-50), statement_rent, "rent"), + (account, Decimal(1000), rand_str_generator(), "other income"), + (account_savings, Decimal(100), statement_income, "other income"), + ] + for acct, amount, statement, category in specs: + txn = Transaction.create( + account_id=acct.id_, + date=today, + amount=amount, + statement=statement, + ) + TransactionSplit.create( + parent=txn, + amount=txn.amount, + category_id=categories[category], + ) + + txn = Transaction.create( + account_id=account.id_, date=today, - amount=amount, - statement=statement, + amount=-50, + statement=statement_rent + " other word", ) - t_split = TransactionSplit( + TransactionSplit.create( parent=txn, amount=txn.amount, - category_id=categories[category], + asset_id=asset.id_, + asset_quantity_unadjusted=10, + category_id=categories["securities traded"], ) - session.add_all((txn, t_split)) - txn = Transaction( - account_id=account.id_, - date=today, - amount=-50, - statement=statement_rent + " other word", - ) - t_split = TransactionSplit( - parent=txn, - amount=txn.amount, - asset_id=asset.id_, - asset_quantity_unadjusted=10, - category_id=categories["securities traded"], - ) - session.add_all((txn, t_split)) + query = TransactionSplit.query(TransactionSplit.id_).where( + TransactionSplit.category_id == categories["rent"], + ) + t_split_id = sql.one(query) + LabelLink.create(label_id=labels["apartments 4 U"], t_split_id=t_split_id) - session.commit() + return ( + Transaction.query().order_by(Transaction.date_ord).all() # nummus: ignore + ) - t_split_id = ( - session.query(TransactionSplit.id_) - .where(TransactionSplit.category_id == categories["rent"]) - .one()[0] - ) - session.add(LabelLink(label_id=labels["apartments 4 U"], t_split_id=t_split_id)) - session.commit() - return session.query(Transaction).order_by(Transaction.date_ord).all() +@pytest.fixture +def budget_assignments( + month: datetime.date, + month_ord: int, + session: orm.Session, + categories: dict[str, int], +) -> list[BudgetAssignment]: + """Create BudgetAssignments. + + Returns: + [ + BudgetAssignment this month for $50 of groceries, + BudgetAssignment this month for $100 of emergency fund, + BudgetAssignment next month for $2000 of rent, + ] + + """ + with session.begin_nested(): + BudgetAssignment.create( + month_ord=month_ord, + amount=Decimal(50), + category_id=categories["groceries"], + ) + BudgetAssignment.create( + month_ord=month_ord, + amount=Decimal(100), + category_id=categories["emergency fund"], + ) + BudgetAssignment.create( + month_ord=utils.date_add_months(month, 1).toordinal(), + amount=Decimal(2000), + category_id=categories["rent"], + ) + return BudgetAssignment.all() + + +@pytest.fixture +def budget_target( + session: orm.Session, + categories: dict[str, int], +) -> Target: + """Create a budget target. + + Returns: + Target for Emergency Fund, $1000, no due date + + """ + with session.begin_nested(): + return Target.create( + category_id=categories["emergency fund"], + amount=Decimal(1000), + type_=TargetType.BALANCE, + period=TargetPeriod.ONCE, + repeat_every=0, + ) @pytest.fixture(autouse=True) @@ -719,8 +758,7 @@ def __init__( class MockExtension(web.FlaskExtension): @override @classmethod - def _open_portfolio(cls, config: flask.Config) -> Portfolio: - _ = config + def _open_portfolio(cls, config: dict[str, object]) -> Portfolio: return generator()[0] self._ext = MockExtension() @@ -797,65 +835,3 @@ def flask_app_encrypted( """ return flask_app_encrypted_generator(empty_portfolio_encrypted[0]) - - -@pytest.fixture -def budget_assignments( - month: datetime.date, - month_ord: int, - session: orm.Session, - categories: dict[str, int], -) -> list[BudgetAssignment]: - """Create BudgetAssignments. - - Returns: - [ - BudgetAssignment this month for $50 of groceries, - BudgetAssignment this month for $100 of emergency fund, - BudgetAssignment next month for $2000 of rent, - ] - - """ - b = BudgetAssignment( - month_ord=month_ord, - amount=Decimal(50), - category_id=categories["groceries"], - ) - session.add(b) - b = BudgetAssignment( - month_ord=month_ord, - amount=Decimal(100), - category_id=categories["emergency fund"], - ) - session.add(b) - b = BudgetAssignment( - month_ord=utils.date_add_months(month, 1).toordinal(), - amount=Decimal(2000), - category_id=categories["rent"], - ) - session.add(b) - session.commit() - return list(session.query(BudgetAssignment).all()) - - -@pytest.fixture -def budget_target( - session: orm.Session, - categories: dict[str, int], -) -> Target: - """Create a budget target. - - Returns: - Target for Emergency Fund, $1000, no due date - - """ - target = Target( - category_id=categories["emergency fund"], - amount=Decimal(1000), - type_=TargetType.BALANCE, - period=TargetPeriod.ONCE, - repeat_every=0, - ) - session.add(target) - session.commit() - return target diff --git a/tests/controllers/accounts/conftest.py b/tests/controllers/accounts/conftest.py index dc1ef3ce..35efc12f 100644 --- a/tests/controllers/accounts/conftest.py +++ b/tests/controllers/accounts/conftest.py @@ -25,56 +25,55 @@ def transactions( categories: dict[str, int], transactions: list[Transaction], ) -> list[Transaction]: - _ = transactions - # Add dividends yesterday - txn = Transaction( - account_id=account.id_, - date=today - datetime.timedelta(days=1), - amount=0, - statement=rand_str_generator(), - payee="Monkey Bank", - cleared=True, - ) - t_split_0 = TransactionSplit( - parent=txn, - amount=-1, - asset_id=asset.id_, - asset_quantity_unadjusted=1, - category_id=categories["securities traded"], - ) - t_split_1 = TransactionSplit( - parent=txn, - amount=1, - asset_id=asset.id_, - asset_quantity_unadjusted=0, - category_id=categories["dividends received"], - ) - session.add_all((txn, t_split_0, t_split_1)) + with session.begin_nested(): + # Add dividends yesterday + txn = Transaction.create( + account_id=account.id_, + date=today - datetime.timedelta(days=1), + amount=0, + statement=rand_str_generator(), + payee="Monkey Bank", + cleared=True, + ) + TransactionSplit.create( + parent=txn, + amount=-1, + asset_id=asset.id_, + asset_quantity_unadjusted=1, + category_id=categories["securities traded"], + ) + TransactionSplit.create( + parent=txn, + amount=1, + asset_id=asset.id_, + asset_quantity_unadjusted=0, + category_id=categories["dividends received"], + ) - # Add fee today - txn = Transaction( - account_id=account.id_, - date=today, - amount=0, - statement=rand_str_generator(), - payee="Monkey Bank", - cleared=True, - ) - t_split_0 = TransactionSplit( - parent=txn, - amount=2, - asset_id=asset.id_, - asset_quantity_unadjusted=-2, - category_id=categories["securities traded"], - ) - t_split_1 = TransactionSplit( - parent=txn, - amount=-2, - asset_id=asset.id_, - asset_quantity_unadjusted=0, - category_id=categories["investment fees"], - ) - session.add_all((txn, t_split_0, t_split_1)) + # Add fee today + txn = Transaction.create( + account_id=account.id_, + date=today, + amount=0, + statement=rand_str_generator(), + payee="Monkey Bank", + cleared=True, + ) + TransactionSplit.create( + parent=txn, + amount=2, + asset_id=asset.id_, + asset_quantity_unadjusted=-2, + category_id=categories["securities traded"], + ) + TransactionSplit.create( + parent=txn, + amount=-2, + asset_id=asset.id_, + asset_quantity_unadjusted=0, + category_id=categories["investment fees"], + ) - session.commit() - return session.query(Transaction).order_by(Transaction.date_ord).all() + return ( + Transaction.query().order_by(Transaction.date_ord).all() # nummus: ignore + ) diff --git a/tests/controllers/accounts/test_contexts.py b/tests/controllers/accounts/test_contexts.py index 4be565f6..21c1f3d0 100644 --- a/tests/controllers/accounts/test_contexts.py +++ b/tests/controllers/accounts/test_contexts.py @@ -27,11 +27,10 @@ @pytest.mark.parametrize("skip_today", [False, True]) def test_ctx_account_empty( today: datetime.date, - session: orm.Session, account: Account, skip_today: bool, ) -> None: - ctx = accounts.ctx_account(session, account, today, skip_today=skip_today) + ctx = accounts.ctx_account(account, today, skip_today=skip_today) target: accounts.AccountContext = { "uri": account.uri, @@ -60,11 +59,10 @@ def test_ctx_account_empty( def test_ctx_account( today: datetime.date, - session: orm.Session, account: Account, transactions: list[Transaction], ) -> None: - ctx = accounts.ctx_account(session, account, today) + ctx = accounts.ctx_account(account, today) target: accounts.AccountContext = { "uri": account.uri, @@ -93,14 +91,12 @@ def test_ctx_account( def test_ctx_performance_empty( today: datetime.date, - session: orm.Session, account: Account, ) -> None: start = utils.date_add_months(today, -12) labels, mode = base.date_labels(start.toordinal(), today.toordinal()) ctx = accounts.ctx_performance( - session, account, today, "1yr", @@ -134,12 +130,11 @@ def test_ctx_performance( asset_valuation: AssetValuation, transactions: list[Transaction], ) -> None: - asset_valuation.date_ord -= 7 - session.commit() + with session.begin_nested(): + asset_valuation.date_ord -= 7 labels, mode = base.date_labels(transactions[0].date_ord, today.toordinal()) ctx = accounts.ctx_performance( - session, account, today, "max", @@ -172,10 +167,9 @@ def test_ctx_performance( def test_ctx_assets_empty( today: datetime.date, - session: orm.Session, account: Account, ) -> None: - assert accounts.ctx_assets(session, account, today) is None + assert accounts.ctx_assets(account, today) is None def test_ctx_assets( @@ -186,11 +180,10 @@ def test_ctx_assets( asset_valuation: AssetValuation, transactions: list[Transaction], ) -> None: - _ = transactions - asset_valuation.date_ord -= 7 - session.commit() + with session.begin_nested(): + asset_valuation.date_ord -= 7 - ctx = accounts.ctx_assets(session, account, today) + ctx = accounts.ctx_assets(account, today) target: list[accounts.AssetContext] = [ { @@ -219,8 +212,8 @@ def test_ctx_assets( assert ctx == target -def test_ctx_accounts_empty(today: datetime.date, session: orm.Session) -> None: - ctx = accounts.ctx_accounts(session, today) +def test_ctx_accounts_empty(today: datetime.date) -> None: + ctx = accounts.ctx_accounts(today) target: accounts.AllAccountsContext = { "net_worth": Decimal(), @@ -244,12 +237,10 @@ def test_ctx_accounts( transactions: list[Transaction], asset_valuation: AssetValuation, ) -> None: - _ = transactions - _ = asset_valuation - account_investments.closed = True - session.commit() + with session.begin_nested(): + account_investments.closed = True - ctx = accounts.ctx_accounts(session, today, include_closed=True) + ctx = accounts.ctx_accounts(today, include_closed=True) target: accounts.AllAccountsContext = { "net_worth": Decimal(108), @@ -260,11 +251,11 @@ def test_ctx_accounts( "categories": { account.category: ( Decimal(108), - [accounts.ctx_account(session, account, today)], + [accounts.ctx_account(account, today)], ), account_investments.category: ( Decimal(), - [accounts.ctx_account(session, account_investments, today)], + [accounts.ctx_account(account_investments, today)], ), }, "include_closed": True, diff --git a/tests/controllers/accounts/test_endpoints.py b/tests/controllers/accounts/test_endpoints.py index fbcdd8a2..b189c0e6 100644 --- a/tests/controllers/accounts/test_endpoints.py +++ b/tests/controllers/accounts/test_endpoints.py @@ -45,7 +45,6 @@ def test_txns_options( account: Account, transactions: list[Transaction], ) -> None: - _ = transactions result, _ = web_client.GET(("accounts.txns_options", {"uri": account.uri})) assert 'name="period"' in result assert 'name="category"' in result @@ -78,7 +77,6 @@ def test_page( account: Account, transactions: list[Transaction], ) -> None: - _ = transactions result, _ = web_client.GET(("accounts.page", {"uri": account.uri})) assert "Transactions" in result assert "Balance" in result @@ -94,9 +92,8 @@ def test_page_performance( account: Account, transactions: list[Transaction], ) -> None: - _ = transactions - account.category = AccountCategory.INVESTMENT - session.commit() + with session.begin_nested(): + account.category = AccountCategory.INVESTMENT result, _ = web_client.GET(("accounts.page", {"uri": account.uri})) assert "Transactions" in result @@ -133,7 +130,7 @@ def test_new( assert "All changes saved" in result assert "account" in headers["HX-Trigger"] - account = session.query(Account).one() + account = Account.one() assert account.name == "New name" assert account.category == AccountCategory.INVESTMENT assert account.currency == Currency.USD @@ -186,7 +183,6 @@ def test_account_get( account: Account, transactions: list[Transaction], ) -> None: - _ = transactions result, _ = web_client.GET(("accounts.account", {"uri": account.uri})) assert account.name in result assert account.institution in result @@ -242,7 +238,6 @@ def test_account_edit_error( target: str, transactions: list[Transaction], ) -> None: - _ = transactions form = { "name": name, "category": "INVESTMENT", @@ -268,7 +263,6 @@ def test_performance( account: Account, transactions: list[Transaction], ) -> None: - _ = transactions result, headers = web_client.GET(("accounts.performance", {"uri": account.uri})) assert headers["HX-Push-URL"] == web_client.url_for( "accounts.page", @@ -300,7 +294,6 @@ def test_validation( value: str, target: str, ) -> None: - _ = account_investments result, _ = web_client.GET( ( "accounts.validation", diff --git a/tests/controllers/allocation/test_contexts.py b/tests/controllers/allocation/test_contexts.py index c537a592..a455dda1 100644 --- a/tests/controllers/allocation/test_contexts.py +++ b/tests/controllers/allocation/test_contexts.py @@ -9,8 +9,6 @@ if TYPE_CHECKING: import datetime - from sqlalchemy import orm - from nummus.models.asset import ( Asset, AssetSector, @@ -19,8 +17,8 @@ from nummus.models.transaction import Transaction -def test_ctx_empty(today: datetime.date, session: orm.Session) -> None: - ctx = allocation.ctx_allocation(session, today) +def test_ctx_empty(today: datetime.date) -> None: + ctx = allocation.ctx_allocation(today) target: allocation.AllocationContext = { "chart": { @@ -37,17 +35,12 @@ def test_ctx_empty(today: datetime.date, session: orm.Session) -> None: def test_ctx( today: datetime.date, - session: orm.Session, asset: Asset, transactions: list[Transaction], asset_valuation: AssetValuation, asset_sectors: tuple[AssetSector, AssetSector], ) -> None: - _ = transactions - _ = asset_valuation - _ = asset_sectors - - ctx = allocation.ctx_allocation(session, today) + ctx = allocation.ctx_allocation(today) target: allocation.AllocationContext = { "chart": { diff --git a/tests/controllers/assets/test_contexts.py b/tests/controllers/assets/test_contexts.py index 8c7ba68d..968aa74d 100644 --- a/tests/controllers/assets/test_contexts.py +++ b/tests/controllers/assets/test_contexts.py @@ -12,8 +12,6 @@ if TYPE_CHECKING: import datetime - from sqlalchemy import orm - from nummus.models.account import Account from nummus.models.asset import ( Asset, @@ -25,11 +23,10 @@ def test_ctx_performance_empty( today: datetime.date, - session: orm.Session, asset: Asset, ) -> None: start = utils.date_add_months(today, -12) - ctx = assets.ctx_performance(session, asset, today, "1yr") + ctx = assets.ctx_performance(asset, today, "1yr") labels, mode = base.date_labels(start.toordinal(), today.toordinal()) target: assets.PerformanceContext = { "mode": mode, @@ -46,11 +43,10 @@ def test_ctx_performance_empty( def test_ctx_performance( today: datetime.date, - session: orm.Session, asset: Asset, asset_valuation: AssetValuation, ) -> None: - ctx = assets.ctx_performance(session, asset, today, "max") + ctx = assets.ctx_performance(asset, today, "max") labels, mode = base.date_labels(asset_valuation.date_ord, today.toordinal()) target: assets.PerformanceContext = { "mode": mode, @@ -68,10 +64,9 @@ def test_ctx_performance( def test_ctx_table_empty( today: datetime.date, month: datetime.date, - session: orm.Session, asset: Asset, ) -> None: - ctx = assets.ctx_table(session, asset, today, None, None, None, None) + ctx = assets.ctx_table(asset, today, None, None, None, None) last_months = [utils.date_add_months(month, i) for i in range(0, -3, -1)] options_period = [ @@ -111,7 +106,6 @@ def test_ctx_table_empty( ) def test_ctx_table( today: datetime.date, - session: orm.Session, asset: Asset, asset_valuation: AssetValuation, period: str | None, @@ -121,7 +115,7 @@ def test_ctx_table( any_filters: bool, has_valuation: bool, ) -> None: - ctx = assets.ctx_table(session, asset, today, period, start, end, page) + ctx = assets.ctx_table(asset, today, period, start, end, page) if page is None: assert ctx["first_page"] @@ -151,10 +145,9 @@ def test_ctx_table( def test_ctx_asset_empty( today: datetime.date, - session: orm.Session, asset: Asset, ) -> None: - ctx = assets.ctx_asset(session, asset, today, None, None, None, None, None) + ctx = assets.ctx_asset(asset, today, None, None, None, None, None) assert ctx["uri"] == asset.uri assert ctx["name"] == asset.name assert ctx["category"] == asset.category @@ -166,14 +159,12 @@ def test_ctx_asset_empty( def test_ctx_asset( today: datetime.date, - session: orm.Session, account: Account, asset: Asset, asset_valuation: AssetValuation, transactions: list[Transaction], ) -> None: - _ = transactions - ctx = assets.ctx_asset(session, asset, today, None, None, None, None, None) + ctx = assets.ctx_asset(asset, today, None, None, None, None, None) assert ctx["uri"] == asset.uri assert ctx["name"] == asset.name assert ctx["category"] == asset.category @@ -185,17 +176,16 @@ def test_ctx_asset( ] -def test_ctx_rows_empty(today: datetime.date, session: orm.Session) -> None: - ctx = assets.ctx_rows(session, today, include_unheld=True) +def test_ctx_rows_empty(today: datetime.date) -> None: + ctx = assets.ctx_rows(today, include_unheld=True) assert ctx == {} def test_ctx_rows_unheld( today: datetime.date, - session: orm.Session, asset: Asset, ) -> None: - ctx = assets.ctx_rows(session, today, include_unheld=True) + ctx = assets.ctx_rows(today, include_unheld=True) target: dict[AssetCategory, list[assets.RowContext]] = { asset.category: [ { @@ -214,13 +204,11 @@ def test_ctx_rows_unheld( def test_ctx_rows( today: datetime.date, - session: orm.Session, asset: Asset, asset_valuation: AssetValuation, transactions: list[Transaction], ) -> None: - _ = transactions - ctx = assets.ctx_rows(session, today, include_unheld=False) + ctx = assets.ctx_rows(today, include_unheld=False) target: dict[AssetCategory, list[assets.RowContext]] = { asset.category: [ { diff --git a/tests/controllers/assets/test_endpoints.py b/tests/controllers/assets/test_endpoints.py index a36c785f..ae412313 100644 --- a/tests/controllers/assets/test_endpoints.py +++ b/tests/controllers/assets/test_endpoints.py @@ -4,6 +4,7 @@ import pytest +from nummus import sql from nummus.controllers import base from nummus.models.asset import ( Asset, @@ -11,7 +12,6 @@ AssetValuation, ) from nummus.models.currency import Currency, CURRENCY_FORMATS, DEFAULT_CURRENCY -from nummus.models.utils import query_count if TYPE_CHECKING: import datetime @@ -56,8 +56,8 @@ def test_new( web_client: WebClient, session: orm.Session, ) -> None: - session.query(Asset).delete() - session.commit() + with session.begin_nested(): + Asset.query().delete() result, headers = web_client.POST( "assets.new", @@ -73,7 +73,7 @@ def test_new( assert "All changes saved" in result assert "asset" in headers["HX-Trigger"] - a = session.query(Asset).one() + a = Asset.one() assert a.name == "New name" assert a.category == AssetCategory.STOCKS assert a.currency == Currency.USD @@ -112,7 +112,6 @@ def test_asset_get( asset: Asset, transactions: list[Transaction], ) -> None: - _ = transactions result, _ = web_client.GET(("assets.asset", {"uri": asset.uri})) assert asset.name in result assert asset.ticker is not None @@ -124,7 +123,7 @@ def test_asset_get( assert "Delete" not in result -def test_asset_edit(web_client: WebClient, session: orm.Session, asset: Asset) -> None: +def test_asset_edit(web_client: WebClient, asset: Asset) -> None: result, headers = web_client.PUT( ("assets.asset", {"uri": asset.uri}), data={ @@ -139,7 +138,7 @@ def test_asset_edit(web_client: WebClient, session: orm.Session, asset: Asset) - assert "All changes saved" in result assert "asset" in headers["HX-Trigger"] - session.refresh(asset) + asset.refresh() assert asset.name == "New name" assert asset.category == AssetCategory.BONDS assert asset.currency == Currency.EUR @@ -166,7 +165,6 @@ def test_account_delete( asset: Asset, asset_valuation: AssetValuation, ) -> None: - _ = asset_valuation result, headers = web_client.DELETE(("assets.asset", {"uri": asset.uri})) assert not result assert headers["HX-Redirect"] == web_client.url_for("assets.page_all") @@ -225,7 +223,6 @@ def test_validation( value: str, target: str, ) -> None: - _ = asset_etf args = {prop: value} if include_asset: args["uri"] = asset.uri @@ -247,7 +244,6 @@ def test_new_valuation_get( def test_new_valuation( today: datetime.date, - session: orm.Session, web_client: WebClient, asset: Asset, rand_real: Decimal, @@ -263,7 +259,7 @@ def test_new_valuation( assert "All changes saved" in result assert "valuation" in headers["HX-Trigger"] - v = session.query(AssetValuation).one() + v = AssetValuation.one() assert v.asset_id == asset.id_ assert v.date == today assert v.value == rand_real @@ -323,7 +319,6 @@ def test_valuation_get( def test_valuation_delete( - session: orm.Session, web_client: WebClient, asset_valuation: AssetValuation, ) -> None: @@ -334,13 +329,12 @@ def test_valuation_delete( assert f"{asset_valuation.date} valuation deleted" in result assert "valuation" in headers["HX-Trigger"] - v = session.query(AssetValuation).one_or_none() + v = AssetValuation.query().one_or_none() assert v is None def test_valuation_edit( tomorrow: datetime.date, - session: orm.Session, web_client: WebClient, asset_valuation: AssetValuation, rand_real: Decimal, @@ -356,7 +350,7 @@ def test_valuation_edit( assert "All changes saved" in result assert "valuation" in headers["HX-Trigger"] - session.refresh(asset_valuation) + asset_valuation.refresh() assert asset_valuation.date == tomorrow assert asset_valuation.value == rand_real @@ -396,13 +390,12 @@ def test_valuation_duplicate( asset: Asset, asset_valuation: AssetValuation, ) -> None: - v = AssetValuation( - asset_id=asset.id_, - date_ord=tomorrow.toordinal(), - value=asset_valuation.value, - ) - session.add(v) - session.commit() + with session.begin_nested(): + AssetValuation.create( + asset_id=asset.id_, + date_ord=tomorrow.toordinal(), + value=asset_valuation.value, + ) result, _ = web_client.PUT( ("assets.valuation", {"uri": asset_valuation.uri}), @@ -415,8 +408,8 @@ def test_valuation_duplicate( def test_update_get_empty(session: orm.Session, web_client: WebClient) -> None: - session.query(Asset).delete() - session.commit() + with session.begin_nested(): + Asset.query().delete() result, _ = web_client.GET("assets.update") assert "Update assets" in result @@ -428,17 +421,16 @@ def test_update_get_one( web_client: WebClient, asset: Asset, ) -> None: - _ = asset - session.query(Asset).where(Asset.category == AssetCategory.INDEX).delete() - session.commit() + with session.begin_nested(): + Asset.query().where(Asset.category == AssetCategory.INDEX).delete() result, _ = web_client.GET("assets.update") assert "Update assets" in result assert "There is one asset with ticker to update" in result -def test_update_get(session: orm.Session, web_client: WebClient) -> None: - n = query_count(session.query(Asset)) +def test_update_get(web_client: WebClient) -> None: + n = sql.count(Asset.query()) result, _ = web_client.GET("assets.update") assert "Update assets" in result @@ -456,10 +448,8 @@ def test_update( asset: Asset, transactions: list[Transaction], ) -> None: - _ = asset - _ = transactions - session.query(Asset).where(Asset.category == AssetCategory.INDEX).delete() - session.commit() + with session.begin_nested(): + Asset.query().where(Asset.category == AssetCategory.INDEX).delete() result, headers = web_client.POST("assets.update") assert "snackbar.show" in result @@ -471,6 +461,5 @@ def test_update_error( web_client: WebClient, transactions: list[Transaction], ) -> None: - _ = transactions result, _ = web_client.POST("assets.update") assert "No timezone found, symbol may be delisted" in result diff --git a/tests/controllers/budgeting/test_contexts.py b/tests/controllers/budgeting/test_contexts.py index 93bc5d02..bd7fbfe8 100644 --- a/tests/controllers/budgeting/test_contexts.py +++ b/tests/controllers/budgeting/test_contexts.py @@ -6,7 +6,7 @@ import pytest -from nummus import utils +from nummus import sql, utils from nummus.controllers import budgeting from nummus.models.budget import ( BudgetAssignment, @@ -33,18 +33,13 @@ def test_ctx_sidebar_global( today: datetime.date, month: datetime.date, - session: orm.Session, transactions_spending: list[Transaction], budget_assignments: list[BudgetAssignment], budget_target: Target, ) -> None: - _ = transactions_spending - _ = budget_assignments - _ = budget_target - data = BudgetAssignment.get_monthly_available(session, month) + data = BudgetAssignment.get_monthly_available(month) ctx = budgeting.ctx_sidebar( - session, today, month, data.categories, @@ -52,11 +47,11 @@ def test_ctx_sidebar_global( None, ) - query = session.query(TransactionCategory).where( + query = TransactionCategory.query().where( TransactionCategory.name.not_in({"emergency fund"}), TransactionCategory.group != TransactionCategoryGroup.INCOME, ) - categories = {t_cat.uri: t_cat.emoji_name for t_cat in query.all()} + categories = {t_cat.uri: t_cat.emoji_name for t_cat in sql.yield_(query)} target: budgeting.SidebarContext = { "uri": None, @@ -77,18 +72,14 @@ def test_ctx_sidebar_global( def test_ctx_sidebar_no_target( today: datetime.date, month: datetime.date, - session: orm.Session, transactions_spending: list[Transaction], budget_assignments: list[BudgetAssignment], categories: dict[str, int], ) -> None: - _ = transactions_spending - _ = budget_assignments - data = BudgetAssignment.get_monthly_available(session, month) + data = BudgetAssignment.get_monthly_available(month) uri = TransactionCategory.id_to_uri(categories["emergency fund"]) ctx = budgeting.ctx_sidebar( - session, today, month, data.categories, @@ -113,19 +104,15 @@ def test_ctx_sidebar_no_target( def test_ctx_sidebar( today: datetime.date, month: datetime.date, - session: orm.Session, transactions_spending: list[Transaction], budget_assignments: list[BudgetAssignment], budget_target: Target, ) -> None: - _ = transactions_spending - _ = budget_assignments - data = BudgetAssignment.get_monthly_available(session, month) + data = BudgetAssignment.get_monthly_available(month) data_cat = data.categories[budget_target.category_id] uri = TransactionCategory.id_to_uri(budget_target.category_id) ctx = budgeting.ctx_sidebar( - session, today, month, data.categories, @@ -191,8 +178,8 @@ def test_ctx_target_once( budget_target: Target, ) -> None: due_date = utils.date_add_months(month, 8) - budget_target.due_date_ord = due_date.toordinal() - session.commit() + with session.begin_nested(): + budget_target.due_date_ord = due_date.toordinal() ctx = budgeting.ctx_target( budget_target, @@ -226,11 +213,11 @@ def test_ctx_target_weekly_refil( budget_target: Target, ) -> None: month = utils.date_add_months(month, 1) - budget_target.period = TargetPeriod.WEEK - budget_target.type_ = TargetType.REFILL - budget_target.due_date_ord = today.toordinal() - budget_target.repeat_every = 1 - session.commit() + with session.begin_nested(): + budget_target.period = TargetPeriod.WEEK + budget_target.type_ = TargetType.REFILL + budget_target.due_date_ord = today.toordinal() + budget_target.repeat_every = 1 n_weekdays = utils.weekdays_in_month(today.weekday(), month) ctx = budgeting.ctx_target( @@ -264,11 +251,11 @@ def test_ctx_target_weekly_accumulate( session: orm.Session, budget_target: Target, ) -> None: - budget_target.period = TargetPeriod.WEEK - budget_target.type_ = TargetType.ACCUMULATE - budget_target.due_date_ord = today.toordinal() - budget_target.repeat_every = 1 - session.commit() + with session.begin_nested(): + budget_target.period = TargetPeriod.WEEK + budget_target.type_ = TargetType.ACCUMULATE + budget_target.due_date_ord = today.toordinal() + budget_target.repeat_every = 1 n_weekdays = utils.weekdays_in_month(today.weekday(), month) ctx = budgeting.ctx_target( @@ -303,11 +290,11 @@ def test_ctx_target_monthly_refil( budget_target: Target, ) -> None: month = utils.date_add_months(month, 1) - budget_target.period = TargetPeriod.MONTH - budget_target.type_ = TargetType.REFILL - budget_target.due_date_ord = today.toordinal() - budget_target.repeat_every = 2 - session.commit() + with session.begin_nested(): + budget_target.period = TargetPeriod.MONTH + budget_target.type_ = TargetType.REFILL + budget_target.due_date_ord = today.toordinal() + budget_target.repeat_every = 2 ctx = budgeting.ctx_target( budget_target, @@ -340,11 +327,11 @@ def test_ctx_target_monthly_accumulate( session: orm.Session, budget_target: Target, ) -> None: - budget_target.period = TargetPeriod.MONTH - budget_target.type_ = TargetType.ACCUMULATE - budget_target.due_date_ord = today.toordinal() - budget_target.repeat_every = 1 - session.commit() + with session.begin_nested(): + budget_target.period = TargetPeriod.MONTH + budget_target.type_ = TargetType.ACCUMULATE + budget_target.due_date_ord = today.toordinal() + budget_target.repeat_every = 1 ctx = budgeting.ctx_target( budget_target, @@ -374,11 +361,9 @@ def test_ctx_target_monthly_accumulate( def test_ctx_budget_empty( today: datetime.date, month: datetime.date, - session: orm.Session, ) -> None: - data = BudgetAssignment.get_monthly_available(session, month) + data = BudgetAssignment.get_monthly_available(month) ctx, title = budgeting.ctx_budget( - session, today, month, data.categories, @@ -419,21 +404,18 @@ def test_ctx_budget( budget_target: Target, categories: dict[str, int], ) -> None: - session.query(TransactionCategory).where( - TransactionCategory.name == "groceries", - ).update( - { - TransactionCategory.budget_group_id: budget_group.id_, - TransactionCategory.budget_position: 0, - }, - ) - session.commit() - _ = transactions_spending - _ = budget_assignments - data = BudgetAssignment.get_monthly_available(session, month) + with session.begin_nested(): + TransactionCategory.query().where( + TransactionCategory.name == "groceries", + ).update( + { + TransactionCategory.budget_group_id: budget_group.id_, + TransactionCategory.budget_position: 0, + }, + ) + data = BudgetAssignment.get_monthly_available(month) ctx, title = budgeting.ctx_budget( - session, today, month, data.categories, diff --git a/tests/controllers/budgeting/test_endpoints.py b/tests/controllers/budgeting/test_endpoints.py index 1c2e568c..0aedd8da 100644 --- a/tests/controllers/budgeting/test_endpoints.py +++ b/tests/controllers/budgeting/test_endpoints.py @@ -7,6 +7,7 @@ import pytest import werkzeug.datastructures +from nummus import sql from nummus.controllers import base, budgeting from nummus.models.budget import ( BudgetAssignment, @@ -16,7 +17,6 @@ TargetType, ) from nummus.models.transaction_category import TransactionCategory -from nummus.models.utils import query_count if TYPE_CHECKING: import flask @@ -53,7 +53,6 @@ def test_validation( def test_assign_new( month: datetime.date, - session: orm.Session, web_client: WebClient, categories: dict[str, int], rand_real: Decimal, @@ -68,7 +67,7 @@ def test_assign_new( assert "Ungrouped" in result - a = session.query(BudgetAssignment).one() + a = BudgetAssignment.one() assert a.category_id == t_cat_id assert a.month_ord == month.toordinal() assert a.amount == round(rand_real, 2) @@ -76,7 +75,6 @@ def test_assign_new( def test_assign_edit( month: datetime.date, - session: orm.Session, web_client: WebClient, budget_assignments: list[BudgetAssignment], ) -> None: @@ -90,14 +88,13 @@ def test_assign_edit( assert "Ungrouped" in result - session.refresh(a) + a.refresh() assert a.month_ord == month.toordinal() assert a.amount == Decimal(10) def test_assign_remove( month: datetime.date, - session: orm.Session, web_client: WebClient, budget_assignments: list[BudgetAssignment], ) -> None: @@ -111,11 +108,7 @@ def test_assign_remove( assert "Ungrouped" in result - a = ( - session.query(BudgetAssignment) - .where(BudgetAssignment.id_ == a.id_) - .one_or_none() - ) + a = BudgetAssignment.query().where(BudgetAssignment.id_ == a.id_).one_or_none() assert a is None @@ -125,8 +118,6 @@ def test_move_get_income( transactions_spending: list[Transaction], budget_assignments: list[BudgetAssignment], ) -> None: - _ = transactions_spending - _ = budget_assignments result, _ = web_client.GET( ("budgeting.move", {"uri": "income", "month": month.isoformat()[:7]}), ) @@ -143,8 +134,6 @@ def test_move_get( categories: dict[str, int], budget_assignments: list[BudgetAssignment], ) -> None: - _ = transactions_spending - _ = budget_assignments t_cat_uri = TransactionCategory.id_to_uri(categories["groceries"]) result, _ = web_client.GET( @@ -163,8 +152,6 @@ def test_move_get_destination( categories: dict[str, int], budget_assignments: list[BudgetAssignment], ) -> None: - _ = transactions_spending - _ = budget_assignments t_cat_uri = TransactionCategory.id_to_uri(categories["groceries"]) result, _ = web_client.GET( @@ -185,7 +172,6 @@ def test_move_get_overspending( transactions_spending: list[Transaction], categories: dict[str, int], ) -> None: - _ = transactions_spending t_cat_uri = TransactionCategory.id_to_uri(categories["groceries"]) result, _ = web_client.GET( @@ -199,13 +185,11 @@ def test_move_get_overspending( def test_move_overspending( month: datetime.date, - session: orm.Session, web_client: WebClient, transactions_spending: list[Transaction], budget_assignments: list[BudgetAssignment], categories: dict[str, int], ) -> None: - _ = transactions_spending a = budget_assignments[0] uri = TransactionCategory.id_to_uri(categories["securities traded"]) dest_uri = TransactionCategory.id_to_uri(a.category_id) @@ -218,19 +202,17 @@ def test_move_overspending( assert "$30.00 reallocated" in result assert "budget" in headers["HX-Trigger"] - session.refresh(a) + a.refresh() assert a.month_ord == month.toordinal() assert a.amount == Decimal(20) def test_move_to_income( month: datetime.date, - session: orm.Session, web_client: WebClient, transactions_spending: list[Transaction], budget_assignments: list[BudgetAssignment], ) -> None: - _ = transactions_spending a = budget_assignments[0] t_cat_uri = TransactionCategory.id_to_uri(a.category_id) @@ -242,7 +224,7 @@ def test_move_to_income( assert "$10.00 reallocated" in result assert "budget" in headers["HX-Trigger"] - session.refresh(a) + a.refresh() assert a.month_ord == month.toordinal() assert a.amount == Decimal(40) @@ -263,11 +245,9 @@ def test_move_error( def test_reorder_empty( - session: orm.Session, web_client: WebClient, budget_group: BudgetGroup, ) -> None: - _ = budget_group result, _ = web_client.PUT( "budgeting.reorder", data={ @@ -278,37 +258,29 @@ def test_reorder_empty( ) assert not result - query = session.query(BudgetGroup) - assert query_count(query) == 0 - query = session.query(TransactionCategory).where( + assert not sql.any_(BudgetGroup.query()) + query = TransactionCategory.query().where( TransactionCategory.budget_group_id.is_not(None), ) - assert query_count(query) == 0 - query = session.query(TransactionCategory).where( + assert not sql.any_(query) + query = TransactionCategory.query().where( TransactionCategory.budget_position.is_not(None), ) - assert query_count(query) == 0 + assert not sql.any_(query) def test_reorder( - session: orm.Session, web_client: WebClient, budget_group: BudgetGroup, ) -> None: - t_cat_0 = ( - session.query(TransactionCategory) - .where(TransactionCategory.name == "groceries") - .one() + t_cat_0 = sql.one( + TransactionCategory.query().where(TransactionCategory.name == "groceries"), ) - t_cat_1 = ( - session.query(TransactionCategory) - .where(TransactionCategory.name == "rent") - .one() + t_cat_1 = sql.one( + TransactionCategory.query().where(TransactionCategory.name == "rent"), ) - t_cat_2 = ( - session.query(TransactionCategory) - .where(TransactionCategory.name == "transfers") - .one() + t_cat_2 = sql.one( + TransactionCategory.query().where(TransactionCategory.name == "transfers"), ) result, _ = web_client.PUT( @@ -321,15 +293,15 @@ def test_reorder( ) assert not result - session.refresh(t_cat_0) + t_cat_0.refresh() assert t_cat_0.budget_group_id == budget_group.id_ assert t_cat_0.budget_position == 0 - session.refresh(t_cat_1) + t_cat_1.refresh() assert t_cat_1.budget_group_id == budget_group.id_ assert t_cat_1.budget_position == 1 - session.refresh(t_cat_2) + t_cat_2.refresh() assert t_cat_2.budget_group_id is None assert t_cat_2.budget_position is None @@ -500,10 +472,10 @@ def test_target_get_once( web_client: WebClient, budget_target: Target, ) -> None: - budget_target.type_ = TargetType.BALANCE - budget_target.period = TargetPeriod.ONCE - budget_target.due_date_ord = today_ord - session.commit() + with session.begin_nested(): + budget_target.type_ = TargetType.BALANCE + budget_target.period = TargetPeriod.ONCE + budget_target.due_date_ord = today_ord result, _ = web_client.GET(("budgeting.target", {"uri": budget_target.uri})) assert "Edit target" in result @@ -511,7 +483,6 @@ def test_target_get_once( def test_target_new( today_ord: int, - session: orm.Session, web_client: WebClient, categories: dict[str, int], ) -> None: @@ -527,7 +498,7 @@ def test_target_new( assert "Groceries target created" in result assert "budget" in headers["HX-Trigger"] - tar = session.query(Target).one() + tar = Target.one() assert tar.category_id == t_cat_id assert tar.amount == Decimal(10) assert tar.type_ == TargetType.ACCUMULATE @@ -554,7 +525,6 @@ def test_target_new_error( def test_target_put( - session: orm.Session, web_client: WebClient, budget_target: Target, ) -> None: @@ -566,12 +536,11 @@ def test_target_put( assert "All changes saved" in result assert "budget" in headers["HX-Trigger"] - session.refresh(budget_target) + budget_target.refresh() assert budget_target.amount == Decimal(10) def test_target_delete( - session: orm.Session, web_client: WebClient, budget_target: Target, ) -> None: @@ -582,7 +551,7 @@ def test_target_delete( assert "Emergency Fund target deleted" in result assert "budget" in headers["HX-Trigger"] - tar = session.query(Target).one_or_none() + tar = Target.query().one_or_none() assert tar is None diff --git a/tests/controllers/conftest.py b/tests/controllers/conftest.py index 7558cd91..d918e5d0 100644 --- a/tests/controllers/conftest.py +++ b/tests/controllers/conftest.py @@ -254,7 +254,7 @@ def has_valid_classes(self, inner_html: str) -> bool: ResultType = dict[str, object] | str | bytes -Queries = dict[str, str] | dict[str, str | bool | list[str | bool]] +Queries = dict[str, str] | dict[str, str | bool | list[str] | list[str | bool]] class HTMLValidator: diff --git a/tests/controllers/emergency_fund/test_contexts.py b/tests/controllers/emergency_fund/test_contexts.py index e31280aa..95d1f921 100644 --- a/tests/controllers/emergency_fund/test_contexts.py +++ b/tests/controllers/emergency_fund/test_contexts.py @@ -19,12 +19,12 @@ from nummus.models.budget import BudgetAssignment -def test_empty(today: datetime.date, session: orm.Session) -> None: +def test_empty(today: datetime.date) -> None: start = today - datetime.timedelta(days=utils.DAYS_IN_QUARTER * 2) dates = utils.range_date(start.toordinal(), today.toordinal()) n = len(dates) - ctx = emergency_fund.ctx_page(session, today) + ctx = emergency_fund.ctx_page(today) target: emergency_fund.EFundContext = { "chart": { @@ -56,26 +56,23 @@ def test_ctx_underfunded( budget_assignments: list[BudgetAssignment], rand_str: str, ) -> None: - _ = transactions_spending - _ = budget_assignments - session.query(TransactionCategory).where( - TransactionCategory.name == "groceries", - ).update({"essential_spending": True}) - txn = Transaction( - account_id=account.id_, - date=today - datetime.timedelta(days=100), - amount=-1000, - statement=rand_str, - ) - t_split = TransactionSplit( - parent=txn, - amount=txn.amount, - category_id=categories["groceries"], - ) - session.add_all((txn, t_split)) - session.commit() - - ctx = emergency_fund.ctx_page(session, today) + with session.begin_nested(): + TransactionCategory.query().where( + TransactionCategory.name == "groceries", + ).update({"essential_spending": True}) + txn = Transaction.create( + account_id=account.id_, + date=today - datetime.timedelta(days=100), + amount=-1000, + statement=rand_str, + ) + TransactionSplit.create( + parent=txn, + amount=txn.amount, + category_id=categories["groceries"], + ) + + ctx = emergency_fund.ctx_page(today) assert ctx["current"] == Decimal(100) assert ctx["days"] == pytest.approx(Decimal(34), abs=Decimal(1)) @@ -98,26 +95,23 @@ def test_ctx_overfunded( budget_assignments: list[BudgetAssignment], rand_str: str, ) -> None: - _ = transactions_spending - _ = budget_assignments - session.query(TransactionCategory).where( - TransactionCategory.name == "groceries", - ).update({"essential_spending": True}) - txn = Transaction( - account_id=account.id_, - date=today - datetime.timedelta(days=100), - amount=-50, - statement=rand_str, - ) - t_split = TransactionSplit( - parent=txn, - amount=txn.amount, - category_id=categories["groceries"], - ) - session.add_all((txn, t_split)) - session.commit() - - ctx = emergency_fund.ctx_page(session, today) + with session.begin_nested(): + TransactionCategory.query().where( + TransactionCategory.name == "groceries", + ).update({"essential_spending": True}) + txn = Transaction.create( + account_id=account.id_, + date=today - datetime.timedelta(days=100), + amount=-50, + statement=rand_str, + ) + TransactionSplit.create( + parent=txn, + amount=txn.amount, + category_id=categories["groceries"], + ) + + ctx = emergency_fund.ctx_page(today) assert ctx["current"] == Decimal(100) assert ctx["days"] == pytest.approx(Decimal(347), abs=Decimal(1)) @@ -140,26 +134,23 @@ def test_ctx( budget_assignments: list[BudgetAssignment], rand_str: str, ) -> None: - _ = transactions_spending - _ = budget_assignments - session.query(TransactionCategory).where( - TransactionCategory.name == "groceries", - ).update({"essential_spending": True}) - txn = Transaction( - account_id=account.id_, - date=today - datetime.timedelta(days=100), - amount=-200, - statement=rand_str, - ) - t_split = TransactionSplit( - parent=txn, - amount=txn.amount, - category_id=categories["groceries"], - ) - session.add_all((txn, t_split)) - session.commit() - - ctx = emergency_fund.ctx_page(session, today) + with session.begin_nested(): + TransactionCategory.query().where( + TransactionCategory.name == "groceries", + ).update({"essential_spending": True}) + txn = Transaction.create( + account_id=account.id_, + date=today - datetime.timedelta(days=100), + amount=-200, + statement=rand_str, + ) + TransactionSplit.create( + parent=txn, + amount=txn.amount, + category_id=categories["groceries"], + ) + + ctx = emergency_fund.ctx_page(today) assert ctx["current"] == Decimal(100) assert ctx["days"] == pytest.approx(Decimal(119), abs=Decimal(1)) diff --git a/tests/controllers/health/test_contexts.py b/tests/controllers/health/test_contexts.py index 9ee0f0ab..46857d02 100644 --- a/tests/controllers/health/test_contexts.py +++ b/tests/controllers/health/test_contexts.py @@ -1,18 +1,13 @@ from __future__ import annotations -from typing import TYPE_CHECKING - +from nummus import sql from nummus.controllers import health from nummus.health_checks.top import HEALTH_CHECKS from nummus.models.transaction_category import TransactionCategory -from nummus.models.utils import query_count - -if TYPE_CHECKING: - from sqlalchemy import orm -def test_ctx_empty(session: orm.Session) -> None: - ctx = health.ctx_checks(session, run=False) +def test_ctx_empty() -> None: + ctx = health.ctx_checks(run=False) assert ctx["last_update_ago"] is None checks = ctx["checks"] @@ -21,8 +16,8 @@ def test_ctx_empty(session: orm.Session) -> None: assert not has_issues -def test_ctx_empty_run(session: orm.Session) -> None: - ctx = health.ctx_checks(session, run=True) +def test_ctx_empty_run() -> None: + ctx = health.ctx_checks(run=True) assert ctx["last_update_ago"] == 0 checks = ctx["checks"] @@ -33,7 +28,7 @@ def test_ctx_empty_run(session: orm.Session) -> None: assert c["name"] == "Unused categories" # All unused - query = session.query(TransactionCategory).where( + query = TransactionCategory.query().where( TransactionCategory.locked.is_(False), ) - assert len(c["issues"]) == query_count(query) + assert len(c["issues"]) == sql.count(query) diff --git a/tests/controllers/health/test_endpoints.py b/tests/controllers/health/test_endpoints.py index b36571f1..24844fae 100644 --- a/tests/controllers/health/test_endpoints.py +++ b/tests/controllers/health/test_endpoints.py @@ -35,9 +35,9 @@ def test_refresh(web_client: WebClient, n_runs: int) -> None: def test_ignore(web_client: WebClient, session: orm.Session) -> None: - c = UnusedCategories() - c.test(session) - session.commit() + with session.begin_nested(): + c = UnusedCategories() + c.test() uri = next(iter(c.issues.keys())) diff --git a/tests/controllers/income/test_endpoints.py b/tests/controllers/income/test_endpoints.py index 09285e0a..47f59bc1 100644 --- a/tests/controllers/income/test_endpoints.py +++ b/tests/controllers/income/test_endpoints.py @@ -17,7 +17,6 @@ def test_page( account: Account, transactions: list[Transaction], ) -> None: - _ = transactions result, _ = web_client.GET(("income.page", {"period": "all"})) assert "Income" in result assert account.name in result @@ -30,7 +29,6 @@ def test_chart( account: Account, transactions: list[Transaction], ) -> None: - _ = transactions result, _ = web_client.GET(("income.chart", {"period": "all"})) assert "Income" in result assert account.name in result @@ -45,9 +43,9 @@ def test_dashboard( account: Account, transactions: list[Transaction], ) -> None: - transactions[0].date = today - transactions[0].splits[0].parent = transactions[0] - session.commit() + with session.begin_nested(): + transactions[0].date = today + transactions[0].splits[0].parent = transactions[0] result, _ = web_client.GET("income.dashboard") assert "Income" in result diff --git a/tests/controllers/labels/test_contexts.py b/tests/controllers/labels/test_contexts.py index 545e19c4..d6967247 100644 --- a/tests/controllers/labels/test_contexts.py +++ b/tests/controllers/labels/test_contexts.py @@ -1,17 +1,12 @@ from __future__ import annotations -from typing import TYPE_CHECKING - from nummus.controllers import base from nummus.controllers import labels as label_controller from nummus.models.label import Label -if TYPE_CHECKING: - from sqlalchemy import orm - -def test_ctx(session: orm.Session, labels: dict[str, int]) -> None: - ctx = label_controller.ctx_labels(session) +def test_ctx(labels: dict[str, int]) -> None: + ctx = label_controller.ctx_labels() target: list[base.NamePair] = [ base.NamePair(Label.id_to_uri(label_id), name) diff --git a/tests/controllers/labels/test_endpoints.py b/tests/controllers/labels/test_endpoints.py index 69d9e40f..b8413511 100644 --- a/tests/controllers/labels/test_endpoints.py +++ b/tests/controllers/labels/test_endpoints.py @@ -4,12 +4,11 @@ import pytest +from nummus import sql from nummus.controllers import base from nummus.models.label import Label, LabelLink -from nummus.models.utils import query_count if TYPE_CHECKING: - from sqlalchemy import orm from nummus.models.transaction import Transaction from tests.controllers.conftest import WebClient @@ -53,10 +52,8 @@ def test_label_get(web_client: WebClient, labels: dict[str, int]) -> None: def test_label_delete( web_client: WebClient, labels: dict[str, int], - session: orm.Session, transactions: list[Transaction], ) -> None: - _ = transactions uri = Label.id_to_uri(labels["engineer"]) result, headers = web_client.DELETE( @@ -66,14 +63,12 @@ def test_label_delete( assert "Deleted label engineer" in result assert "label" in headers["HX-Trigger"] - n = query_count(session.query(LabelLink)) - assert n == 0 + assert not sql.any_(LabelLink.query()) def test_label_edit( web_client: WebClient, labels: dict[str, int], - session: orm.Session, ) -> None: uri = Label.id_to_uri(labels["engineer"]) @@ -85,7 +80,7 @@ def test_label_edit( assert "All changes saved" in result assert "label" in headers["HX-Trigger"] - label = session.query(Label).where(Label.name == "new label").one() + label = sql.one(Label.query().where(Label.name == "new label")) assert label.id_ == labels["engineer"] diff --git a/tests/controllers/net_worth/test_contexts.py b/tests/controllers/net_worth/test_contexts.py index 5b06b2c6..f7ffdfa8 100644 --- a/tests/controllers/net_worth/test_contexts.py +++ b/tests/controllers/net_worth/test_contexts.py @@ -19,10 +19,8 @@ def test_ctx_chart_empty( today: datetime.date, account: Account, - session: orm.Session, ) -> None: - _ = account - ctx = net_worth.ctx_chart(session, today, "max") + ctx = net_worth.ctx_chart(today, "max") chart: base.ChartData = { "labels": [today.isoformat()], @@ -51,9 +49,8 @@ def test_ctx_chart_empty( def test_ctx_chart_this_year( today: datetime.date, - session: orm.Session, ) -> None: - ctx = net_worth.ctx_chart(session, today, "ytd") + ctx = net_worth.ctx_chart(today, "ytd") assert ctx["start"] == today.replace(month=1, day=1) assert ctx["end"] == today @@ -69,28 +66,25 @@ def test_ctx_chart( session: orm.Session, categories: dict[str, int], ) -> None: - _ = asset_valuation - _ = transactions - # Make account_investments negative - txn = Transaction( - account_id=account_investments.id_, - date=today, - amount=-100, - statement=rand_str_generator(), - payee="Monkey Bank", - cleared=True, - ) - t_split = TransactionSplit( - parent=txn, - amount=txn.amount, - category_id=categories["groceries"], - ) - session.add_all((txn, t_split)) - session.commit() + with session.begin_nested(): + # Make account_investments negative + txn = Transaction.create( + account_id=account_investments.id_, + date=today, + amount=-100, + statement=rand_str_generator(), + payee="Monkey Bank", + cleared=True, + ) + TransactionSplit.create( + parent=txn, + amount=txn.amount, + category_id=categories["groceries"], + ) start = today - datetime.timedelta(days=3) end = today + datetime.timedelta(days=3) - ctx = net_worth.ctx_chart(session, end, "max") + ctx = net_worth.ctx_chart(end, "max") chart: base.ChartData = { "labels": base.date_labels(start.toordinal(), end.toordinal())[0], diff --git a/tests/controllers/net_worth/test_endpoints.py b/tests/controllers/net_worth/test_endpoints.py index 55d99ae5..f1668cf9 100644 --- a/tests/controllers/net_worth/test_endpoints.py +++ b/tests/controllers/net_worth/test_endpoints.py @@ -23,9 +23,6 @@ def test_page( asset_valuation: AssetValuation, transactions: list[Transaction], ) -> None: - _ = asset_valuation - _ = transactions - result, _ = web_client.GET("net_worth.page") assert "Net worth" in result assert "Assets" in result @@ -39,10 +36,6 @@ def test_chart( asset_valuation: AssetValuation, transactions: list[Transaction], ) -> None: - _ = account - _ = asset_valuation - _ = transactions - result, headers = web_client.GET("net_worth.chart") assert headers["HX-Push-URL"] == web_client.url_for( "net_worth.page", @@ -57,10 +50,6 @@ def test_dashboard( asset_valuation: AssetValuation, transactions: list[Transaction], ) -> None: - _ = account - _ = asset_valuation - _ = transactions - result, _ = web_client.GET("net_worth.dashboard") assert "Net worth" in result assert "JSON.parse" in result diff --git a/tests/controllers/performance/test_contexts.py b/tests/controllers/performance/test_contexts.py index eac08cf2..195f3035 100644 --- a/tests/controllers/performance/test_contexts.py +++ b/tests/controllers/performance/test_contexts.py @@ -4,13 +4,13 @@ from decimal import Decimal from typing import TYPE_CHECKING +from nummus import sql from nummus.controllers import base, performance from nummus.models.account import AccountCategory from nummus.models.asset import ( Asset, AssetCategory, ) -from nummus.models.base import YIELD_PER from nummus.models.currency import CURRENCY_FORMATS, DEFAULT_CURRENCY if TYPE_CHECKING: @@ -25,10 +25,8 @@ def test_ctx_chart_empty( today: datetime.date, account: Account, - session: orm.Session, ) -> None: - _ = account - ctx = performance.ctx_chart(session, today, "max", "S&P 500", set()) + ctx = performance.ctx_chart(today, "max", "S&P 500", set()) chart: performance.ChartData = { "labels": [today.isoformat()], @@ -55,10 +53,10 @@ def test_ctx_chart_empty( "currency_format": CURRENCY_FORMATS[DEFAULT_CURRENCY], } - query = session.query(Asset.name).order_by(Asset.name) - indices: list[str] = [r[0] for r in query.yield_per(YIELD_PER)] + query = Asset.query(Asset.name).order_by(Asset.name) + indices: list[str] = list(sql.col0(query)) - desc = session.query(Asset.description).where(Asset.name == "S&P 500").one()[0] + desc = sql.one(Asset.query(Asset.description).where(Asset.name == "S&P 500")) target: performance.Context = { "start": today, @@ -79,11 +77,11 @@ def test_ctx_chart_this_year( session: orm.Session, account: Account, ) -> None: - account.category = AccountCategory.INVESTMENT - account.closed = True - session.commit() + with session.begin_nested(): + account.category = AccountCategory.INVESTMENT + account.closed = True - ctx = performance.ctx_chart(session, today, "ytd", "S&P 500", set()) + ctx = performance.ctx_chart(today, "ytd", "S&P 500", set()) assert ctx["start"] == today.replace(month=1, day=1) assert ctx["end"] == today @@ -97,14 +95,12 @@ def test_ctx_chart( transactions: list[Transaction], session: orm.Session, ) -> None: - account.category = AccountCategory.INVESTMENT - session.commit() - _ = asset_valuation - _ = transactions + with session.begin_nested(): + account.category = AccountCategory.INVESTMENT start = today - datetime.timedelta(days=3) end = today + datetime.timedelta(days=3) - ctx = performance.ctx_chart(session, end, "max", "S&P 500", set()) + ctx = performance.ctx_chart(end, "max", "S&P 500", set()) chart: performance.ChartData = { "labels": base.date_labels(start.toordinal(), end.toordinal())[0], @@ -152,13 +148,13 @@ def test_ctx_chart( } query = ( - session.query(Asset.name) + Asset.query(Asset.name) .where(Asset.category == AssetCategory.INDEX) .order_by(Asset.name) ) - indices: list[str] = [r[0] for r in query.yield_per(YIELD_PER)] + indices: list[str] = list(sql.col0(query)) - desc = session.query(Asset.description).where(Asset.name == "S&P 500").one()[0] + desc = sql.one(Asset.query(Asset.description).where(Asset.name == "S&P 500")) target: performance.Context = { "start": start, @@ -181,14 +177,12 @@ def test_ctx_chart_exclude( transactions: list[Transaction], session: orm.Session, ) -> None: - account.category = AccountCategory.INVESTMENT - session.commit() - _ = asset_valuation - _ = transactions + with session.begin_nested(): + account.category = AccountCategory.INVESTMENT start = today - datetime.timedelta(days=3) end = today + datetime.timedelta(days=3) - ctx = performance.ctx_chart(session, end, "max", "S&P 500", {account.id_}) + ctx = performance.ctx_chart(end, "max", "S&P 500", {account.id_}) chart: performance.ChartData = { "labels": base.date_labels(start.toordinal(), end.toordinal())[0], @@ -218,13 +212,14 @@ def test_ctx_chart_exclude( } query = ( - session.query(Asset.name) + Asset.query(Asset.name) .where(Asset.category == AssetCategory.INDEX) .order_by(Asset.name) ) - indices: list[str] = [r[0] for r in query.yield_per(YIELD_PER)] - desc = session.query(Asset.description).where(Asset.name == "S&P 500").one()[0] + indices: list[str] = list(sql.col0(query)) + + desc = sql.one(Asset.query(Asset.description).where(Asset.name == "S&P 500")) target: performance.Context = { "start": start, diff --git a/tests/controllers/performance/test_endpoints.py b/tests/controllers/performance/test_endpoints.py index c6f1efe8..21e49572 100644 --- a/tests/controllers/performance/test_endpoints.py +++ b/tests/controllers/performance/test_endpoints.py @@ -28,10 +28,8 @@ def test_page( asset_valuation: AssetValuation, transactions: list[Transaction], ) -> None: - account.category = AccountCategory.INVESTMENT - session.commit() - _ = asset_valuation - _ = transactions + with session.begin_nested(): + account.category = AccountCategory.INVESTMENT result, _ = web_client.GET( ("performance.page", {"index": "Dow Jones Industrial Average"}), @@ -49,10 +47,8 @@ def test_chart( asset_valuation: AssetValuation, transactions: list[Transaction], ) -> None: - account.category = AccountCategory.INVESTMENT - session.commit() - _ = asset_valuation - _ = transactions + with session.begin_nested(): + account.category = AccountCategory.INVESTMENT result, headers = web_client.GET("performance.chart") assert headers["HX-Push-URL"] == web_client.url_for( @@ -70,10 +66,8 @@ def test_dashboard( asset_valuation: AssetValuation, transactions: list[Transaction], ) -> None: - account.category = AccountCategory.INVESTMENT - session.commit() - _ = asset_valuation - _ = transactions + with session.begin_nested(): + account.category = AccountCategory.INVESTMENT result, _ = web_client.GET("performance.dashboard") assert "Investing performance" in result diff --git a/tests/controllers/settings/test_contexts.py b/tests/controllers/settings/test_contexts.py index cdaf90d5..42f727a1 100644 --- a/tests/controllers/settings/test_contexts.py +++ b/tests/controllers/settings/test_contexts.py @@ -1,16 +1,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING - from nummus.controllers import settings from nummus.models.currency import Currency, DEFAULT_CURRENCY -if TYPE_CHECKING: - from sqlalchemy import orm - -def test_ctx(session: orm.Session) -> None: - ctx = settings.ctx_settings(session) +def test_ctx() -> None: + ctx = settings.ctx_settings() target: settings.SettingsContext = { "currency": DEFAULT_CURRENCY, diff --git a/tests/controllers/settings/test_endpoints.py b/tests/controllers/settings/test_endpoints.py index f09ca4f5..b22e0a9c 100644 --- a/tests/controllers/settings/test_endpoints.py +++ b/tests/controllers/settings/test_endpoints.py @@ -6,8 +6,6 @@ from nummus.models.currency import Currency if TYPE_CHECKING: - from sqlalchemy import orm - from tests.controllers.conftest import WebClient @@ -16,10 +14,10 @@ def test_page(web_client: WebClient) -> None: assert "Base currency" in result -def test_edit_currency(web_client: WebClient, session: orm.Session) -> None: +def test_edit_currency(web_client: WebClient) -> None: result, headers = web_client.PATCH("settings.edit", data={"currency": "CHF"}) assert "snackbar.show" in result assert "All changes saved" in result assert "config" in headers["HX-Trigger"] - assert Config.base_currency(session) == Currency.CHF + assert Config.base_currency() == Currency.CHF diff --git a/tests/controllers/spending/test_contexts.py b/tests/controllers/spending/test_contexts.py index d56727bf..b2f69f13 100644 --- a/tests/controllers/spending/test_contexts.py +++ b/tests/controllers/spending/test_contexts.py @@ -4,6 +4,7 @@ import pytest +from nummus import sql from nummus.controllers import base, spending from nummus.models.account import Account from nummus.models.currency import DEFAULT_CURRENCY @@ -13,13 +14,10 @@ TransactionCategory, TransactionCategoryGroup, ) -from nummus.models.utils import query_count if TYPE_CHECKING: import datetime - from sqlalchemy import orm - from nummus.models.transaction import Transaction @@ -48,7 +46,6 @@ ], ) def test_data_query( - session: orm.Session, account: Account, transactions_spending: list[Transaction], categories: dict[str, int], @@ -62,9 +59,7 @@ def test_data_query( is_income: bool, target: tuple[int, bool], ) -> None: - _ = transactions_spending dat_query = spending.data_query( - session, DEFAULT_CURRENCY, account.uri if include_account else None, period, @@ -75,20 +70,18 @@ def test_data_query( is_income=is_income, ) assert dat_query.any_filters == target[1] - assert query_count(dat_query.final_query) == target[0] + assert sql.count(dat_query.final_query) == target[0] def test_ctx_options( today: datetime.date, - session: orm.Session, account: Account, transactions: list[Transaction], categories: dict[str, int], labels: dict[str, int], ) -> None: - _ = transactions dat_query = spending.DataQuery( - session.query(TransactionSplit), + TransactionSplit.query(), {}, any_filters=False, ) @@ -96,9 +89,9 @@ def test_ctx_options( ctx = spending.ctx_options( dat_query, today, - Account.map_name(session), - base.tranaction_category_groups(session), - Label.map_name(session), + Account.map_name(), + base.tranaction_category_groups(), + Label.map_name(), ) assert ctx["options_account"] == [base.NamePair(account.uri, account.name)] @@ -130,13 +123,12 @@ def test_ctx_options( def test_ctx_options_selected( today: datetime.date, - session: orm.Session, account: Account, categories: dict[str, int], labels: dict[str, int], ) -> None: dat_query = spending.DataQuery( - session.query(TransactionSplit), + TransactionSplit.query(), {}, any_filters=False, ) @@ -144,9 +136,9 @@ def test_ctx_options_selected( ctx = spending.ctx_options( dat_query, today, - Account.map_name(session), - base.tranaction_category_groups(session), - Label.map_name(session), + Account.map_name(), + base.tranaction_category_groups(), + Label.map_name(), account.uri, TransactionCategory.id_to_uri(categories["other income"]), Label.id_to_uri(labels["engineer"]), @@ -173,11 +165,8 @@ def test_ctx_options_selected( def test_ctx_chart_empty( today: datetime.date, account: Account, - session: orm.Session, ) -> None: - _ = account ctx, title = spending.ctx_chart( - session, today, selected_account=None, selected_category=None, @@ -205,12 +194,9 @@ def test_ctx_chart_empty( def test_ctx_chart( today: datetime.date, - session: orm.Session, transactions_spending: list[Transaction], ) -> None: - _ = transactions_spending ctx, title = spending.ctx_chart( - session, today, None, None, diff --git a/tests/controllers/spending/test_endpoints.py b/tests/controllers/spending/test_endpoints.py index a57af1eb..a3edd185 100644 --- a/tests/controllers/spending/test_endpoints.py +++ b/tests/controllers/spending/test_endpoints.py @@ -13,7 +13,6 @@ def test_page( account: Account, transactions_spending: list[Transaction], ) -> None: - _ = transactions_spending result, _ = web_client.GET("spending.page") assert "Spending" in result assert account.name in result @@ -27,7 +26,6 @@ def test_chart( account: Account, transactions_spending: list[Transaction], ) -> None: - _ = transactions_spending result, _ = web_client.GET("spending.chart") assert "Spending" in result assert account.name in result @@ -41,7 +39,6 @@ def test_dashboard( account: Account, transactions_spending: list[Transaction], ) -> None: - _ = transactions_spending result, _ = web_client.GET("spending.dashboard") assert "Spending" in result assert account.name in result diff --git a/tests/controllers/test_base.py b/tests/controllers/test_base.py index 3258410c..9bf4d906 100644 --- a/tests/controllers/test_base.py +++ b/tests/controllers/test_base.py @@ -24,25 +24,24 @@ from collections.abc import Callable import werkzeug.test - from sqlalchemy import orm from nummus.models.asset import Asset from tests.conftest import RandomStringGenerator from tests.controllers.conftest import HTMLValidator, WebClient -def test_find(session: orm.Session, account: Account) -> None: - assert base.find(session, Account, account.uri) == account +def test_find(account: Account) -> None: + assert base.find(Account, account.uri) == account -def test_find_404(session: orm.Session) -> None: +def test_find_404() -> None: with pytest.raises(exc.http.NotFound): - base.find(session, Account, Account.id_to_uri(0)) + base.find(Account, Account.id_to_uri(0)) -def test_find_400(session: orm.Session) -> None: +def test_find_400() -> None: with pytest.raises(exc.http.BadRequest): - base.find(session, Account, "fake") + base.find(Account, "fake") @pytest.mark.parametrize( @@ -115,7 +114,7 @@ class Fake: ], ids=conftest.id_func, ) -def test_validate_required(func: Callable) -> None: +def test_validate_required(func: Callable[..., object]) -> None: assert func("", is_required=True) == "Required" @@ -132,24 +131,19 @@ def test_validate_string_short() -> None: assert base.validate_string("a", check_length=True) == "2 characters required" -def test_validate_string_no_session() -> None: - with pytest.raises(TypeError): - base.validate_string("abc", no_duplicates=Account.name) - - -def test_validate_string_duplicate(session: orm.Session, account: Account) -> None: +def test_validate_string_duplicate(account: Account) -> None: err = base.validate_string( account.name, - session=session, + cls=Account, no_duplicates=Account.name, ) assert err == "Must be unique" -def test_validate_string_duplicate_self(session: orm.Session, account: Account) -> None: +def test_validate_string_duplicate_self(account: Account) -> None: err = base.validate_string( account.name, - session=session, + cls=Account, no_duplicates=Account.name, no_duplicate_wheres=[Account.id_ != account.id_], ) @@ -170,7 +164,7 @@ def test_validate_date(today: datetime.date, s: str, max_future: int | None) -> ], ids=conftest.id_func, ) -def test_validate_unable_to_parse(func: Callable) -> None: +def test_validate_unable_to_parse(func: Callable[..., object]) -> None: assert func("a") == "Unable to parse" @@ -222,13 +216,12 @@ def test_parse_date( def test_validate_date_duplicate( today: datetime.date, - session: orm.Session, asset_valuation: AssetValuation, ) -> None: err = base.validate_date( asset_valuation.date.isoformat(), today, - session=session, + cls=AssetValuation, no_duplicates=AssetValuation.date_ord, ) assert err == "Must be unique" @@ -327,12 +320,10 @@ def test_error_str( def test_error_empty_field( - session: orm.Session, valid_html: HTMLValidator, ) -> None: - session.add(Account()) try: - session.commit() + Account.create() except exc.IntegrityError as e: html = base.error(e) assert valid_html(html) @@ -342,21 +333,18 @@ def test_error_empty_field( def test_error_unique( - session: orm.Session, account: Account, valid_html: HTMLValidator, ) -> None: - new_account = Account( - name=account.name, - institution=account.institution, - category=account.category, - closed=False, - budgeted=False, - currency=DEFAULT_CURRENCY, - ) - session.add(new_account) try: - session.commit() + Account.create( + name=account.name, + institution=account.institution, + category=account.category, + closed=False, + budgeted=False, + currency=DEFAULT_CURRENCY, + ) except exc.IntegrityError as e: html = base.error(e) assert valid_html(html) @@ -366,13 +354,11 @@ def test_error_unique( def test_error_check( - session: orm.Session, account: Account, valid_html: HTMLValidator, ) -> None: - _ = account try: - session.query(Account).update({"name": "a"}) + Account.query().update({"name": "a"}) except exc.IntegrityError as e: html = base.error(e) assert valid_html(html) @@ -496,10 +482,9 @@ def test_change_redirect_no_htmx(web_client: WebClient) -> None: def test_tranaction_category_groups( - session: orm.Session, categories: dict[str, int], ) -> None: - groups = base.tranaction_category_groups(session) + groups = base.tranaction_category_groups() assert len(groups) == len(TransactionCategoryGroup) assert sum(len(group) for group in groups.values()) == len(categories) diff --git a/tests/controllers/test_import_file.py b/tests/controllers/test_import_file.py index 5ffe4082..e42af863 100644 --- a/tests/controllers/test_import_file.py +++ b/tests/controllers/test_import_file.py @@ -48,7 +48,7 @@ def test_no_file(web_client: WebClient) -> None: ], ) def test_error( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], web_client: WebClient, data_path: Path, file: str, @@ -56,7 +56,6 @@ def test_error( traceback: bool, account: Account, ) -> None: - _ = account path = data_path / file result, _ = web_client.POST( "import_file.import_file", @@ -78,8 +77,6 @@ def test_import_file( account: Account, account_investments: Account, ) -> None: - _ = account - _ = account_investments path = data_path / "transactions_required.csv" result, headers = web_client.POST( "import_file.import_file", @@ -96,8 +93,6 @@ def test_duplicate( account: Account, account_investments: Account, ) -> None: - _ = account - _ = account_investments path = data_path / "transactions_required.csv" web_client.POST( "import_file.import_file", @@ -120,8 +115,6 @@ def test_duplicate_force( account: Account, account_investments: Account, ) -> None: - _ = account - _ = account_investments path = data_path / "transactions_required.csv" web_client.POST( "import_file.import_file", diff --git a/tests/controllers/transaction_categories/test_contexts.py b/tests/controllers/transaction_categories/test_contexts.py index 7157e198..eaea7a74 100644 --- a/tests/controllers/transaction_categories/test_contexts.py +++ b/tests/controllers/transaction_categories/test_contexts.py @@ -1,26 +1,21 @@ from __future__ import annotations -from typing import TYPE_CHECKING - +from nummus import sql from nummus.controllers import base, transaction_categories -from nummus.models.base import YIELD_PER from nummus.models.transaction_category import ( TransactionCategory, TransactionCategoryGroup, ) -if TYPE_CHECKING: - from sqlalchemy import orm - -def test_ctx(session: orm.Session) -> None: - groups = transaction_categories.ctx_categories(session) +def test_ctx() -> None: + groups = transaction_categories.ctx_categories() exclude = {"securities traded"} for g in TransactionCategoryGroup: query = ( - session.query(TransactionCategory) + TransactionCategory.query() .where( TransactionCategory.group == g, TransactionCategory.name.not_in(exclude), @@ -28,7 +23,6 @@ def test_ctx(session: orm.Session) -> None: .order_by(TransactionCategory.name) ) target: list[base.NamePair] = [ - base.NamePair(t_cat.uri, t_cat.emoji_name) - for t_cat in query.yield_per(YIELD_PER) + base.NamePair(t_cat.uri, t_cat.emoji_name) for t_cat in sql.yield_(query) ] assert groups[g] == target diff --git a/tests/controllers/transaction_categories/test_endpoints.py b/tests/controllers/transaction_categories/test_endpoints.py index 05baa14b..b7c52b13 100644 --- a/tests/controllers/transaction_categories/test_endpoints.py +++ b/tests/controllers/transaction_categories/test_endpoints.py @@ -4,6 +4,7 @@ import pytest +from nummus import sql from nummus.controllers import base from nummus.models.transaction import TransactionSplit from nummus.models.transaction_category import ( @@ -12,7 +13,6 @@ ) if TYPE_CHECKING: - from sqlalchemy import orm from nummus.models.transaction import Transaction from tests.controllers.conftest import WebClient @@ -62,7 +62,6 @@ def test_new_get(web_client: WebClient) -> None: def test_new( web_client: WebClient, rand_str: str, - session: orm.Session, ) -> None: form = { "name": rand_str, @@ -75,11 +74,10 @@ def test_new( assert f"Created category {rand_str}" in result assert "category" in headers["HX-Trigger"] - t_cat = ( - session.query(TransactionCategory) - .where(TransactionCategory.emoji_name == rand_str) - .one() + query = TransactionCategory.query().where( + TransactionCategory.emoji_name == rand_str, ) + t_cat = sql.one(query) assert t_cat.group == TransactionCategoryGroup.EXPENSE assert t_cat.is_profit_loss assert t_cat.essential_spending @@ -133,12 +131,10 @@ def test_category_delete_locked( def test_category_delete_unlocked( web_client: WebClient, categories: dict[str, int], - session: orm.Session, transactions_spending: list[Transaction], ) -> None: - _ = transactions_spending t_split = ( - session.query(TransactionSplit) + TransactionSplit.query() .where(TransactionSplit.category_id == categories["groceries"]) .first() ) @@ -153,20 +149,19 @@ def test_category_delete_unlocked( assert "category" in headers["HX-Trigger"] t_cat = ( - session.query(TransactionCategory) + TransactionCategory.query() .where(TransactionCategory.name == "groceries") .one_or_none() ) assert t_cat is None - session.refresh(t_split) + t_split.refresh() assert t_split.category_id == categories["uncategorized"] def test_category_edit_unlocked( web_client: WebClient, categories: dict[str, int], - session: orm.Session, ) -> None: uri = TransactionCategory.id_to_uri(categories["groceries"]) @@ -178,11 +173,8 @@ def test_category_edit_unlocked( assert "All changes saved" in result assert "category" in headers["HX-Trigger"] - t_cat = ( - session.query(TransactionCategory) - .where(TransactionCategory.name == "food") - .one() - ) + query = TransactionCategory.query().where(TransactionCategory.name == "food") + t_cat = sql.one(query) assert t_cat.emoji_name == "Food" assert t_cat.group == TransactionCategoryGroup.EXPENSE assert not t_cat.is_profit_loss @@ -192,7 +184,6 @@ def test_category_edit_unlocked( def test_category_edit_locked( web_client: WebClient, categories: dict[str, int], - session: orm.Session, ) -> None: uri = TransactionCategory.id_to_uri(categories["uncategorized"]) @@ -204,11 +195,10 @@ def test_category_edit_locked( assert "All changes saved" in result assert "category" in headers["HX-Trigger"] - t_cat = ( - session.query(TransactionCategory) - .where(TransactionCategory.name == "uncategorized") - .one() + query = TransactionCategory.query().where( + TransactionCategory.name == "uncategorized", ) + t_cat = sql.one(query) assert t_cat.emoji_name == "Uncategorized 🤷" assert t_cat.group == TransactionCategoryGroup.OTHER assert not t_cat.is_profit_loss diff --git a/tests/controllers/transactions/test_contexts.py b/tests/controllers/transactions/test_contexts.py index f3279152..25fcdcb8 100644 --- a/tests/controllers/transactions/test_contexts.py +++ b/tests/controllers/transactions/test_contexts.py @@ -6,12 +6,11 @@ import pytest -from nummus import utils +from nummus import sql, utils from nummus.controllers import base from nummus.controllers import transactions as txn_controller from nummus.models.account import Account from nummus.models.asset import Asset -from nummus.models.base import YIELD_PER from nummus.models.currency import CURRENCY_FORMATS, DEFAULT_CURRENCY from nummus.models.label import Label, LabelLink from nummus.models.transaction import TransactionSplit @@ -19,7 +18,6 @@ TransactionCategory, TransactionCategoryGroup, ) -from nummus.models.utils import query_count if TYPE_CHECKING: from sqlalchemy import orm @@ -43,7 +41,6 @@ ], ) def test_table_query( - session: orm.Session, account: Account, transactions: list[Transaction], categories: dict[str, int], @@ -55,9 +52,7 @@ def test_table_query( uncleared: bool, target: tuple[int, bool], ) -> None: - _ = transactions tbl_query = txn_controller.table_query( - session, None, account.uri if include_account else None, period, @@ -67,7 +62,7 @@ def test_table_query( uncleared=uncleared, ) assert tbl_query.any_filters == target[1] - assert query_count(tbl_query.final_query) == target[0] + assert sql.count(tbl_query.final_query) == target[0] def test_ctx_txn( @@ -91,14 +86,11 @@ def test_ctx_txn( def test_ctx_split( - session: orm.Session, transactions: list[Transaction], labels: dict[str, int], ) -> None: - query = session.query(Asset).with_entities(Asset.id_, Asset.name, Asset.ticker) - assets: dict[int, tuple[str, str | None]] = { - r[0]: (r[1], r[2]) for r in query.yield_per(YIELD_PER) - } + query = Asset.query(Asset.id_, Asset.name, Asset.ticker) + assets = sql.to_dict_tuple(query) txn = transactions[0] t_split = txn.splits[0] @@ -123,15 +115,12 @@ def test_ctx_split( def test_ctx_split_asset( - session: orm.Session, asset: Asset, transactions: list[Transaction], labels: dict[str, int], ) -> None: - query = session.query(Asset).with_entities(Asset.id_, Asset.name, Asset.ticker) - assets: dict[int, tuple[str, str | None]] = { - r[0]: (r[1], r[2]) for r in query.yield_per(YIELD_PER) - } + query = Asset.query(Asset.id_, Asset.name, Asset.ticker) + assets = sql.to_dict_tuple(query) txn = transactions[1] t_split = txn.splits[0] @@ -156,23 +145,20 @@ def test_ctx_split_asset( def test_ctx_row( - session: orm.Session, account: Account, transactions: list[Transaction], labels: dict[str, int], ) -> None: - query = session.query(Asset).with_entities(Asset.id_, Asset.name, Asset.ticker) - assets: dict[int, tuple[str, str | None]] = { - r[0]: (r[1], r[2]) for r in query.yield_per(YIELD_PER) - } + query = Asset.query(Asset.id_, Asset.name, Asset.ticker) + assets = sql.to_dict_tuple(query) txn = transactions[0] t_split = txn.splits[0] ctx = txn_controller.ctx_row( t_split, assets, - Account.map_name(session), - TransactionCategory.map_name_emoji(session), + Account.map_name(), + TransactionCategory.map_name_emoji(), {labels["engineer"]: "engineer"}, set(), CURRENCY_FORMATS[DEFAULT_CURRENCY], @@ -199,14 +185,12 @@ def test_ctx_row( def test_ctx_options( today: datetime.date, - session: orm.Session, account: Account, transactions: list[Transaction], categories: dict[str, int], ) -> None: - _ = transactions tbl_query = txn_controller.TableQuery( - session.query(TransactionSplit), + TransactionSplit.query(), {}, any_filters=False, ) @@ -214,8 +198,8 @@ def test_ctx_options( ctx = txn_controller.ctx_options( tbl_query, today, - Account.map_name(session), - base.tranaction_category_groups(session), + Account.map_name(), + base.tranaction_category_groups(), None, None, ) @@ -251,7 +235,7 @@ def test_ctx_options_selected( categories: dict[str, int], ) -> None: tbl_query = txn_controller.TableQuery( - session.query(TransactionSplit), + TransactionSplit.query(), {}, any_filters=False, ) @@ -259,8 +243,8 @@ def test_ctx_options_selected( ctx = txn_controller.ctx_options( tbl_query, today, - Account.map_name(session), - base.tranaction_category_groups(session), + Account.map_name(), + base.tranaction_category_groups(), account.uri, TransactionCategory.id_to_uri(categories["other income"]), ) @@ -332,21 +316,17 @@ def test_table_title( assert title == target -def test_table_results_empty( - session: orm.Session, -) -> None: - query = session.query(Asset).with_entities(Asset.id_, Asset.name, Asset.ticker) - assets: dict[int, tuple[str, str | None]] = { - r[0]: (r[1], r[2]) for r in query.yield_per(YIELD_PER) - } +def test_table_results_empty() -> None: + query = Asset.query(Asset.id_, Asset.name, Asset.ticker) + assets = sql.to_dict_tuple(query) - accounts = Account.map_name(session) + accounts = Account.map_name() result = txn_controller._table_results( - session.query(TransactionSplit), + TransactionSplit.query(), assets, accounts, - TransactionCategory.map_name_emoji(session), - Label.map_name(session), + TransactionCategory.map_name_emoji(), + Label.map_name(), {}, dict.fromkeys(accounts, CURRENCY_FORMATS[DEFAULT_CURRENCY]), ) @@ -357,16 +337,14 @@ def test_table_results( session: orm.Session, transactions: list[Transaction], ) -> None: - query = session.query(Asset).with_entities(Asset.id_, Asset.name, Asset.ticker) - assets: dict[int, tuple[str, str | None]] = { - r[0]: (r[1], r[2]) for r in query.yield_per(YIELD_PER) - } - accounts = Account.map_name(session) - labels = Label.map_name(session) - categories = TransactionCategory.map_name_emoji(session) + query = Asset.query(Asset.id_, Asset.name, Asset.ticker) + assets = sql.to_dict_tuple(query) + accounts = Account.map_name() + labels = Label.map_name() + categories = TransactionCategory.map_name_emoji() result = txn_controller._table_results( - session.query(TransactionSplit).order_by(TransactionSplit.date_ord), + TransactionSplit.query().order_by(TransactionSplit.date_ord), assets, accounts, categories, @@ -385,7 +363,7 @@ def test_table_results( categories, { label_id: labels[label_id] - for label_id, in session.query(LabelLink.label_id).where( + for label_id, in LabelLink.query(LabelLink.label_id).where( LabelLink.t_split_id == txn.splits[0].id_, ) }, @@ -399,9 +377,8 @@ def test_table_results( assert result == target -def test_ctx_table_empty(today: datetime.date, session: orm.Session) -> None: +def test_ctx_table_empty(today: datetime.date) -> None: ctx, title = txn_controller.ctx_table( - session, today, None, None, @@ -431,11 +408,9 @@ def test_ctx_table_empty(today: datetime.date, session: orm.Session) -> None: def test_ctx_table( today: datetime.date, - session: orm.Session, transactions: list[Transaction], ) -> None: ctx, title = txn_controller.ctx_table( - session, today, None, None, @@ -466,12 +441,10 @@ def test_ctx_table( def test_ctx_table_paging( today: datetime.date, monkeypatch: pytest.MonkeyPatch, - session: orm.Session, transactions: list[Transaction], ) -> None: monkeypatch.setattr(txn_controller, "PAGE_LEN", 2) ctx, _ = txn_controller.ctx_table( - session, today, None, None, @@ -489,12 +462,9 @@ def test_ctx_table_paging( def test_ctx_table_search( today: datetime.date, - session: orm.Session, transactions: list[Transaction], ) -> None: - _ = transactions ctx, _ = txn_controller.ctx_table( - session, today, "rent", None, @@ -512,12 +482,9 @@ def test_ctx_table_search( def test_ctx_table_search_paging( today: datetime.date, - session: orm.Session, transactions: list[Transaction], ) -> None: - _ = transactions ctx, _ = txn_controller.ctx_table( - session, today, "rent", None, diff --git a/tests/controllers/transactions/test_endpoints.py b/tests/controllers/transactions/test_endpoints.py index 6ed51b22..ecd7db78 100644 --- a/tests/controllers/transactions/test_endpoints.py +++ b/tests/controllers/transactions/test_endpoints.py @@ -6,9 +6,9 @@ import flask import pytest +from nummus import sql from nummus.controllers import base from nummus.models.account import Account -from nummus.models.base import YIELD_PER from nummus.models.label import Label, LabelLink from nummus.models.transaction import Transaction, TransactionSplit from nummus.models.transaction_category import TransactionCategory @@ -50,7 +50,6 @@ def test_table_options( account: Account, transactions: list[Transaction], ) -> None: - _ = transactions result, _ = web_client.GET("transactions.table_options") assert 'name="period"' in result assert 'name="category"' in result @@ -151,7 +150,6 @@ def test_new_put_bad_date( def test_new( today: datetime.date, - session: orm.Session, web_client: WebClient, account: Account, categories: dict[str, int], @@ -173,7 +171,7 @@ def test_new( assert "Transaction created" in result assert "account" in headers["HX-Trigger"] - txn = session.query(Transaction).one() + txn = Transaction.one() assert txn.account_id == account.id_ assert txn.date == today assert txn.amount == round(rand_real, 2) @@ -186,13 +184,12 @@ def test_new( assert t_split.category_id == categories["other income"] assert t_split.memo is None - labels = Label.map_name(session) + labels = Label.map_name() assert not labels def test_new_split( today: datetime.date, - session: orm.Session, web_client: WebClient, account: Account, categories: dict[str, int], @@ -221,7 +218,7 @@ def test_new_split( assert "Transaction created" in result assert "account" in headers["HX-Trigger"] - txn = session.query(Transaction).one() + txn = Transaction.one() assert txn.account_id == account.id_ assert txn.date == today assert txn.amount == round(rand_real, 2) @@ -230,23 +227,27 @@ def test_new_split( splits = txn.splits assert len(splits) == 2 - labels = Label.map_name(session) + labels = Label.map_name() assert len(labels) == 2 t_split = splits[0] assert t_split.amount == Decimal(10) assert t_split.category_id == categories["other income"] assert t_split.memo is None - query = session.query(LabelLink.label_id).where(LabelLink.t_split_id == t_split.id_) - split_labels = {labels[label_id] for label_id, in query.yield_per(YIELD_PER)} + query = LabelLink.query(LabelLink.label_id).where( + LabelLink.t_split_id == t_split.id_, + ) + split_labels = {labels[label_id] for label_id in sql.col0(query)} assert split_labels == {"Engineer", "Salary"} t_split = splits[1] assert t_split.amount == round(rand_real - 10, 2) assert t_split.category_id == categories["groceries"] assert t_split.memo == "bananas" - query = session.query(LabelLink.label_id).where(LabelLink.t_split_id == t_split.id_) - split_labels = {labels[label_id] for label_id, in query.yield_per(YIELD_PER)} + query = LabelLink.query(LabelLink.label_id).where( + LabelLink.t_split_id == t_split.id_, + ) + split_labels = {labels[label_id] for label_id in sql.col0(query)} assert not split_labels @@ -345,8 +346,8 @@ def test_transaction_get_uncleared( transactions: list[Transaction], ) -> None: txn = transactions[0] - txn.cleared = False - session.commit() + with session.begin_nested(): + txn.cleared = False result, _ = web_client.GET(("transactions.transaction", {"uri": txn.uri})) assert "Edit transaction" in result @@ -375,8 +376,8 @@ def test_transaction_clear( transactions: list[Transaction], ) -> None: txn = transactions[0] - txn.cleared = False - session.commit() + with session.begin_nested(): + txn.cleared = False result, headers = web_client.PATCH(("transactions.transaction", {"uri": txn.uri})) assert "snackbar.show" in result @@ -386,12 +387,9 @@ def test_transaction_clear( session.refresh(txn) assert txn.cleared - t = ( - session.query(TransactionSplit) - .where(TransactionSplit.parent_id == txn.id_) - .one() - ) - assert t.cleared + query = TransactionSplit.query().where(TransactionSplit.parent_id == txn.id_) + t_split = sql.one(query) + assert t_split.cleared def test_transaction_delete_uncleared( @@ -400,23 +398,23 @@ def test_transaction_delete_uncleared( transactions: list[Transaction], ) -> None: txn = transactions[0] - txn.cleared = False - session.commit() + with session.begin_nested(): + txn.cleared = False result, headers = web_client.DELETE(("transactions.transaction", {"uri": txn.uri})) assert "snackbar.show" in result assert f"Transaction on {txn.date} deleted" in result assert "account" in headers["HX-Trigger"] - t = session.query(Transaction).where(Transaction.id_ == txn.id_).one_or_none() + t = Transaction.query().where(Transaction.id_ == txn.id_).one_or_none() assert t is None - t = ( - session.query(TransactionSplit) + t_split = ( + TransactionSplit.query() .where(TransactionSplit.parent_id == txn.id_) .one_or_none() ) - assert t is None + assert t_split is None def test_transaction_delete( @@ -601,21 +599,23 @@ def test_validation( @pytest.mark.parametrize( - ("split_amount", "split", "target"), + ("split_amount", "split", "include_account", "target"), [ # Just amount with a single split is okay - ([], False, ""), - (["10"], False, ""), - (["11"], False, "Remove $1.00 from splits"), - (["9"], False, "Assign $1.00 to splits"), - (["9"], True, "Assign $1.00 to splits"), + ([], False, False, ""), + (["10"], False, False, ""), + (["11"], False, False, "Remove $1.00 from splits"), + (["9"], False, False, "Assign $1.00 to splits"), + (["9"], True, True, "Assign $1.00 to splits"), ], ) def test_validation_amounts( flask_app: flask.Flask, web_client: WebClient, + account: Account, split_amount: list[str], split: bool, + include_account: bool, target: str, ) -> None: result, _ = web_client.GET( @@ -625,6 +625,7 @@ def test_validation_amounts( "amount": "10", "split-amount": split_amount, "split": split, + "account": account.uri if include_account else "", }, ), ) diff --git a/tests/encryption/test_aes.py b/tests/encryption/test_aes.py index 422ada09..9eb912ad 100644 --- a/tests/encryption/test_aes.py +++ b/tests/encryption/test_aes.py @@ -15,10 +15,12 @@ try: from nummus.encryption.aes import EncryptionAES as Encryption except ImportError: - NO_ENCRYPTION = True + no_encryption = True from nummus.encryption.base import NoEncryption as Encryption else: - NO_ENCRYPTION = False + no_encryption = False + +NO_ENCRYPTION = no_encryption @pytest.fixture diff --git a/tests/health_checks/test_base.py b/tests/health_checks/test_base.py index a857b642..8dddc7c6 100644 --- a/tests/health_checks/test_base.py +++ b/tests/health_checks/test_base.py @@ -4,13 +4,12 @@ import pytest +from nummus import sql from nummus.health_checks.base import HealthCheck from nummus.health_checks.top import HEALTH_CHECKS from nummus.models.health_checks import HealthCheckIssue -from nummus.models.utils import query_count if TYPE_CHECKING: - from sqlalchemy import orm from tests.conftest import RandomStringGenerator @@ -20,23 +19,22 @@ class MockCheck(HealthCheck): _SEVERE = True @override - def test(self, s: orm.Session) -> None: - self._commit_issues(s, {}) + def test(self) -> None: + self._commit_issues({}) @pytest.fixture def issues( - session: orm.Session, rand_str_generator: RandomStringGenerator, ) -> list[tuple[str, int]]: value_0 = rand_str_generator() value_1 = rand_str_generator() c = MockCheck() d = {value_0: "msg 0", value_1: "msg 1"} - c._commit_issues(session, d) - c.ignore(session, [value_0]) + c._commit_issues(d) + c.ignore([value_0]) - return [(i.value, i.id_) for i in session.query(HealthCheckIssue).all()] + return [(i.value, i.id_) for i in HealthCheckIssue.all()] def test_init_properties() -> None: @@ -57,7 +55,6 @@ def test_any_issues(rand_str: str) -> None: @pytest.mark.parametrize("no_ignores", [False, True]) def test_commit_issues( - session: orm.Session, rand_str_generator: RandomStringGenerator, no_ignores: bool, ) -> None: @@ -65,17 +62,19 @@ def test_commit_issues( value_1 = rand_str_generator() c = MockCheck(no_ignores=no_ignores) d = {value_0: "msg 0", value_1: "msg 1"} - c._commit_issues(session, d) - c.ignore(session, [value_0]) + c._commit_issues(d) + c.ignore([value_0]) # Refresh c.issues - c._commit_issues(session, d) + c._commit_issues(d) - i_0 = session.query(HealthCheckIssue).where(HealthCheckIssue.value == value_0).one() + query = HealthCheckIssue.query().where(HealthCheckIssue.value == value_0) + i_0 = sql.one(query) assert i_0.check == MockCheck.name() assert i_0.msg == "msg 0" assert i_0.ignore - i_1 = session.query(HealthCheckIssue).where(HealthCheckIssue.value == value_1).one() + query = HealthCheckIssue.query().where(HealthCheckIssue.value == value_1) + i_1 = sql.one(query) assert i_1.check == MockCheck.name() assert i_1.msg == "msg 1" assert not i_1.ignore @@ -88,31 +87,24 @@ def test_commit_issues( assert c.issues == target -def test_ignore_empty(session: orm.Session, rand_str: str) -> None: - MockCheck.ignore(session, {rand_str}) - assert query_count(session.query(HealthCheckIssue)) == 0 +def test_ignore_empty(rand_str: str) -> None: + MockCheck.ignore({rand_str}) + assert not sql.any_(HealthCheckIssue.query()) def test_ignore( - session: orm.Session, issues: list[tuple[str, int]], ) -> None: - MockCheck.ignore(session, [issues[0][0]]) - i = ( - session.query(HealthCheckIssue) - .where(HealthCheckIssue.id_ == issues[0][1]) - .one() - ) + MockCheck.ignore([issues[0][0]]) + query = HealthCheckIssue.query().where(HealthCheckIssue.id_ == issues[0][1]) + i = sql.one(query) assert i.check == MockCheck.name() assert i.value == issues[0][0] assert i.msg == "msg 0" assert i.ignore - i = ( - session.query(HealthCheckIssue) - .where(HealthCheckIssue.id_ == issues[1][1]) - .one() - ) + query = HealthCheckIssue.query().where(HealthCheckIssue.id_ == issues[1][1]) + i = sql.one(query) assert i.check == MockCheck.name() assert i.value == issues[1][0] assert i.msg == "msg 1" diff --git a/tests/health_checks/test_category_direction.py b/tests/health_checks/test_category_direction.py index 6f95d2bf..af612fc0 100644 --- a/tests/health_checks/test_category_direction.py +++ b/tests/health_checks/test_category_direction.py @@ -5,11 +5,11 @@ import pytest +from nummus import sql from nummus.health_checks.category_direction import CategoryDirection from nummus.models.currency import CURRENCY_FORMATS, DEFAULT_CURRENCY from nummus.models.health_checks import HealthCheckIssue from nummus.models.transaction import Transaction, TransactionSplit -from nummus.models.utils import query_count if TYPE_CHECKING: import datetime @@ -19,9 +19,9 @@ from nummus.models.account import Account -def test_empty(session: orm.Session) -> None: +def test_empty() -> None: c = CategoryDirection() - c.test(session) + c.test() assert c.issues == {} @@ -42,27 +42,26 @@ def test_check( amount: Decimal, rand_str: str, ) -> None: - txn = Transaction( - account_id=account.id_, - date=today, - amount=amount, - statement=rand_str, - ) - t_split = TransactionSplit( - parent=txn, - amount=txn.amount, - category_id=categories[category], - ) - session.add_all((txn, t_split)) - session.commit() + with session.begin_nested(): + txn = Transaction.create( + account_id=account.id_, + date=today, + amount=amount, + statement=rand_str, + ) + t_split = TransactionSplit.create( + parent=txn, + amount=txn.amount, + category_id=categories[category], + ) t_uri = t_split.uri c = CategoryDirection() - c.test(session) + c.test() - assert query_count(session.query(HealthCheckIssue)) == 1 + assert sql.count(HealthCheckIssue.query()) == 1 - i = session.query(HealthCheckIssue).one() + i = HealthCheckIssue.one() assert i.check == c.name() assert i.value == t_uri uri = i.uri diff --git a/tests/health_checks/test_database_integrity.py b/tests/health_checks/test_database_integrity.py index a96f7501..e950eda8 100644 --- a/tests/health_checks/test_database_integrity.py +++ b/tests/health_checks/test_database_integrity.py @@ -1,27 +1,27 @@ from __future__ import annotations +import textwrap from typing import TYPE_CHECKING import sqlalchemy -from pandas.core.indexes.api import textwrap +from nummus import sql from nummus.health_checks.database_integrity import DatabaseIntegrity from nummus.models.config import Config from nummus.models.health_checks import HealthCheckIssue -from nummus.models.utils import query_count if TYPE_CHECKING: from sqlalchemy import orm -def test_empty(session: orm.Session) -> None: +def test_empty() -> None: c = DatabaseIntegrity() - c.test(session) + c.test() assert c.issues == {} def test_corrupt(session: orm.Session) -> None: - n = session.query(Config).update({"value": "abc"}) + n = Config.query().update({"value": "abc"}) # Simulating a corrupt database is difficult # Instead update the Account schema to have unique constraint # PRAGMA integrity_check should catch duplicates @@ -45,18 +45,17 @@ def test_corrupt(session: orm.Session) -> None: session.commit() c = DatabaseIntegrity() - c.test(session) + c.test() - assert query_count(session.query(HealthCheckIssue)) == n - 1 + assert sql.count(HealthCheckIssue.query()) == n - 1 - i = session.query(HealthCheckIssue).first() + i = HealthCheckIssue.first() assert i is not None assert i.check == c.name() assert i.value == "0" # The balanced $100 transfer also on this day will not show up target = { - i.uri: f"non-unique entry in index {index_name}" - for i in session.query(HealthCheckIssue).all() + i.uri: f"non-unique entry in index {index_name}" for i in HealthCheckIssue.all() } assert c.issues == target diff --git a/tests/health_checks/test_duplicate_transactions.py b/tests/health_checks/test_duplicate_transactions.py index 1077d3e8..9de6b4c5 100644 --- a/tests/health_checks/test_duplicate_transactions.py +++ b/tests/health_checks/test_duplicate_transactions.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING +from nummus import sql from nummus.health_checks.duplicate_transactions import DuplicateTransactions from nummus.models.currency import ( CURRENCY_FORMATS, @@ -9,57 +10,51 @@ ) from nummus.models.health_checks import HealthCheckIssue from nummus.models.transaction import Transaction, TransactionSplit -from nummus.models.utils import query_count if TYPE_CHECKING: from sqlalchemy import orm -def test_empty(session: orm.Session) -> None: +def test_empty() -> None: c = DuplicateTransactions() - c.test(session) + c.test() assert c.issues == {} def test_no_issues( - session: orm.Session, transactions: list[Transaction], ) -> None: - _ = transactions c = DuplicateTransactions() - c.test(session) - assert query_count(session.query(HealthCheckIssue)) == 0 + c.test() + assert not sql.any_(HealthCheckIssue.query()) def test_duplicate( session: orm.Session, transactions: list[Transaction], ) -> None: - _ = transactions - txn_to_copy = transactions[0] # Fund account on 3 days before today - txn = Transaction( - account_id=txn_to_copy.account_id, - date=txn_to_copy.date, - amount=txn_to_copy.amount, - statement=txn_to_copy.statement, - ) - t_split = TransactionSplit( - parent=txn, - amount=txn.amount, - category_id=txn_to_copy.splits[0].category_id, - ) - session.add_all((txn, t_split)) - session.commit() + with session.begin_nested(): + txn = Transaction.create( + account_id=txn_to_copy.account_id, + date=txn_to_copy.date, + amount=txn_to_copy.amount, + statement=txn_to_copy.statement, + ) + TransactionSplit.create( + parent=txn, + amount=txn.amount, + category_id=txn_to_copy.splits[0].category_id, + ) c = DuplicateTransactions() - c.test(session) - assert query_count(session.query(HealthCheckIssue)) == 1 + c.test() + assert HealthCheckIssue.count() == 1 - i = session.query(HealthCheckIssue).one() + i = HealthCheckIssue.one() assert i.check == c.name() amount_raw = Transaction.amount.type.process_bind_param(txn.amount, None) assert i.value == f"{txn.account_id}.{txn.date_ord}.{amount_raw}" diff --git a/tests/health_checks/test_empty_fields.py b/tests/health_checks/test_empty_fields.py index a6132252..537832ec 100644 --- a/tests/health_checks/test_empty_fields.py +++ b/tests/health_checks/test_empty_fields.py @@ -2,10 +2,10 @@ from typing import TYPE_CHECKING +from nummus import sql from nummus.health_checks.empty_fields import EmptyFields from nummus.models.health_checks import HealthCheckIssue from nummus.models.transaction_category import TransactionCategory -from nummus.models.utils import query_count if TYPE_CHECKING: from sqlalchemy import orm @@ -15,20 +15,18 @@ from nummus.models.transaction import Transaction -def test_empty(session: orm.Session) -> None: +def test_empty() -> None: c = EmptyFields() - c.test(session) + c.test() assert c.issues == {} def test_no_issues( - session: orm.Session, transactions: list[Transaction], ) -> None: - _ = transactions c = EmptyFields() - c.test(session) - assert query_count(session.query(HealthCheckIssue)) == 0 + c.test() + assert not sql.any_(HealthCheckIssue.query()) def test_no_account_number( @@ -36,14 +34,13 @@ def test_no_account_number( account: Account, transactions: list[Transaction], ) -> None: - account.number = None - session.commit() - _ = transactions + with session.begin_nested(): + account.number = None c = EmptyFields() - c.test(session) - assert query_count(session.query(HealthCheckIssue)) == 1 + c.test() + assert HealthCheckIssue.count() == 1 - i = session.query(HealthCheckIssue).one() + i = HealthCheckIssue.one() assert i.check == c.name() assert i.value == f"{account.uri}.number" uri = i.uri @@ -57,14 +54,13 @@ def test_no_asset_description( asset: Asset, transactions: list[Transaction], ) -> None: - asset.description = None - session.commit() - _ = transactions + with session.begin_nested(): + asset.description = None c = EmptyFields() - c.test(session) - assert query_count(session.query(HealthCheckIssue)) == 1 + c.test() + assert HealthCheckIssue.count() == 1 - i = session.query(HealthCheckIssue).one() + i = HealthCheckIssue.one() assert i.check == c.name() assert i.value == f"{asset.uri}.description" uri = i.uri @@ -78,14 +74,14 @@ def test_no_txn_payee( account: Account, transactions: list[Transaction], ) -> None: - txn = transactions[0] - txn.payee = None - session.commit() + with session.begin_nested(): + txn = transactions[0] + txn.payee = None c = EmptyFields() - c.test(session) - assert query_count(session.query(HealthCheckIssue)) == 1 + c.test() + assert HealthCheckIssue.count() == 1 - i = session.query(HealthCheckIssue).one() + i = HealthCheckIssue.one() assert i.check == c.name() assert i.value == f"{txn.uri}.payee" uri = i.uri @@ -99,14 +95,14 @@ def test_uncategorized( account: Account, transactions: list[Transaction], ) -> None: - t_split = transactions[0].splits[0] - t_split.category_id = TransactionCategory.uncategorized(session)[0] - session.commit() + with session.begin_nested(): + t_split = transactions[0].splits[0] + t_split.category_id = TransactionCategory.uncategorized()[0] c = EmptyFields() - c.test(session) - assert query_count(session.query(HealthCheckIssue)) == 1 + c.test() + assert HealthCheckIssue.count() == 1 - i = session.query(HealthCheckIssue).one() + i = HealthCheckIssue.one() assert i.check == c.name() assert i.value == f"{t_split.uri}.category" uri = i.uri diff --git a/tests/health_checks/test_missing_asset_link.py b/tests/health_checks/test_missing_asset_link.py index e9bb4a53..42a3c332 100644 --- a/tests/health_checks/test_missing_asset_link.py +++ b/tests/health_checks/test_missing_asset_link.py @@ -6,7 +6,6 @@ from nummus.health_checks.missing_asset_link import MissingAssetLink from nummus.models.currency import CURRENCY_FORMATS, DEFAULT_CURRENCY from nummus.models.health_checks import HealthCheckIssue -from nummus.models.utils import query_count if TYPE_CHECKING: from sqlalchemy import orm @@ -16,20 +15,18 @@ from nummus.models.transaction import Transaction -def test_empty(session: orm.Session) -> None: +def test_empty() -> None: c = MissingAssetLink() - c.test(session) + c.test() assert c.issues == {} def test_no_issues( - session: orm.Session, transactions: list[Transaction], ) -> None: - _ = transactions c = MissingAssetLink() - c.test(session) - assert query_count(session.query(HealthCheckIssue)) == 0 + c.test() + assert HealthCheckIssue.count() == 0 def test_missing_link( @@ -37,16 +34,15 @@ def test_missing_link( account: Account, transactions: list[Transaction], ) -> None: - t_split = transactions[-1].splits[0] - t_split.asset_id = None - t_split.asset_quantity_unadjusted = None - session.commit() - _ = transactions + with session.begin_nested(): + t_split = transactions[-1].splits[0] + t_split.asset_id = None + t_split.asset_quantity_unadjusted = None c = MissingAssetLink() - c.test(session) - assert query_count(session.query(HealthCheckIssue)) == 1 + c.test() + assert HealthCheckIssue.count() == 1 - i = session.query(HealthCheckIssue).one() + i = HealthCheckIssue.one() assert i.check == c.name() assert i.value == t_split.uri uri = i.uri @@ -65,16 +61,15 @@ def test_extra_link( asset: Asset, transactions: list[Transaction], ) -> None: - t_split = transactions[0].splits[0] - t_split.asset_id = asset.id_ - t_split.asset_quantity_unadjusted = Decimal() - session.commit() - _ = transactions + with session.begin_nested(): + t_split = transactions[0].splits[0] + t_split.asset_id = asset.id_ + t_split.asset_quantity_unadjusted = Decimal() c = MissingAssetLink() - c.test(session) - assert query_count(session.query(HealthCheckIssue)) == 1 + c.test() + assert HealthCheckIssue.count() == 1 - i = session.query(HealthCheckIssue).one() + i = HealthCheckIssue.one() assert i.check == c.name() assert i.value == t_split.uri uri = i.uri diff --git a/tests/health_checks/test_missing_valuations.py b/tests/health_checks/test_missing_valuations.py index 58240ad2..065f1207 100644 --- a/tests/health_checks/test_missing_valuations.py +++ b/tests/health_checks/test_missing_valuations.py @@ -2,11 +2,9 @@ from typing import TYPE_CHECKING +from nummus import sql from nummus.health_checks.missing_valuations import MissingAssetValuations from nummus.models.health_checks import HealthCheckIssue -from nummus.models.utils import ( - query_count, -) if TYPE_CHECKING: from sqlalchemy import orm @@ -18,9 +16,9 @@ from nummus.models.transaction import Transaction -def test_empty(session: orm.Session) -> None: +def test_empty() -> None: c = MissingAssetValuations() - c.test(session) + c.test() assert c.issues == {} @@ -29,25 +27,23 @@ def test_no_issues( transactions: list[Transaction], asset_valuation: AssetValuation, ) -> None: - txn = transactions[1] - asset_valuation.date_ord = txn.date_ord - session.commit() + with session.begin_nested(): + txn = transactions[1] + asset_valuation.date_ord = txn.date_ord c = MissingAssetValuations() - c.test(session) - assert query_count(session.query(HealthCheckIssue)) == 0 + c.test() + assert not sql.any_(HealthCheckIssue.query()) def test_no_valuations( - session: orm.Session, asset: Asset, transactions: list[Transaction], ) -> None: - _ = transactions c = MissingAssetValuations() - c.test(session) - assert query_count(session.query(HealthCheckIssue)) == 1 + c.test() + assert HealthCheckIssue.count() == 1 - i = session.query(HealthCheckIssue).one() + i = HealthCheckIssue.one() assert i.check == c.name() assert i.value == asset.uri uri = i.uri @@ -57,17 +53,16 @@ def test_no_valuations( def test_no_valuations_before_txn( - session: orm.Session, asset: Asset, transactions: list[Transaction], asset_valuation: AssetValuation, ) -> None: txn = transactions[1] c = MissingAssetValuations() - c.test(session) - assert query_count(session.query(HealthCheckIssue)) == 1 + c.test() + assert HealthCheckIssue.count() == 1 - i = session.query(HealthCheckIssue).one() + i = HealthCheckIssue.one() assert i.check == c.name() assert i.value == asset.uri uri = i.uri diff --git a/tests/health_checks/test_outlier_asset_price.py b/tests/health_checks/test_outlier_asset_price.py index a29e0d4b..a6757d7f 100644 --- a/tests/health_checks/test_outlier_asset_price.py +++ b/tests/health_checks/test_outlier_asset_price.py @@ -5,10 +5,10 @@ import pytest +from nummus import sql from nummus.health_checks.outlier_asset_price import OutlierAssetPrice from nummus.models.currency import CURRENCY_FORMATS, DEFAULT_CURRENCY from nummus.models.health_checks import HealthCheckIssue -from nummus.models.utils import query_count if TYPE_CHECKING: from sqlalchemy import orm @@ -20,9 +20,9 @@ from nummus.models.transaction import Transaction -def test_empty(session: orm.Session) -> None: +def test_empty() -> None: c = OutlierAssetPrice() - c.test(session) + c.test() assert c.issues == {} @@ -31,15 +31,15 @@ def test_zero_quantity( transactions: list[Transaction], asset_valuation: AssetValuation, ) -> None: - t_split = transactions[1].splits[0] - t_split.asset_quantity_unadjusted = Decimal() - asset_valuation.date_ord = t_split.date_ord - asset_valuation.value = Decimal(10) - session.commit() + with session.begin_nested(): + t_split = transactions[1].splits[0] + t_split.asset_quantity_unadjusted = Decimal() + asset_valuation.date_ord = t_split.date_ord + asset_valuation.value = Decimal(10) c = OutlierAssetPrice() - c.test(session) - assert query_count(session.query(HealthCheckIssue)) == 0 + c.test() + assert not sql.any_(HealthCheckIssue.query()) @pytest.mark.parametrize( @@ -58,21 +58,21 @@ def test_check( amount: Decimal, target_word: str | None, ) -> None: - t_split = transactions[1].splits[0] - asset_valuation.date_ord = t_split.date_ord - asset_valuation.value = Decimal(10) - t_split.amount = amount - session.commit() + with session.begin_nested(): + t_split = transactions[1].splits[0] + asset_valuation.date_ord = t_split.date_ord + asset_valuation.value = Decimal(10) + t_split.amount = amount c = OutlierAssetPrice() - c.test(session) + c.test() if target_word is None: - assert query_count(session.query(HealthCheckIssue)) == 0 + assert not sql.any_(HealthCheckIssue.query()) return - assert query_count(session.query(HealthCheckIssue)) == 1 + assert HealthCheckIssue.count() == 1 - i = session.query(HealthCheckIssue).one() + i = HealthCheckIssue.one() assert i.check == c.name() assert i.value == t_split.uri uri = i.uri diff --git a/tests/health_checks/test_overdrawn_accounts.py b/tests/health_checks/test_overdrawn_accounts.py index 71ad9c95..a5c4fa4d 100644 --- a/tests/health_checks/test_overdrawn_accounts.py +++ b/tests/health_checks/test_overdrawn_accounts.py @@ -3,10 +3,10 @@ from decimal import Decimal from typing import TYPE_CHECKING +from nummus import sql from nummus.health_checks.overdrawn_accounts import OverdrawnAccounts from nummus.models.currency import CURRENCY_FORMATS, DEFAULT_CURRENCY from nummus.models.health_checks import HealthCheckIssue -from nummus.models.utils import query_count if TYPE_CHECKING: from sqlalchemy import orm @@ -15,21 +15,18 @@ from nummus.models.transaction import Transaction -def test_empty(session: orm.Session) -> None: +def test_empty() -> None: c = OverdrawnAccounts() - c.test(session) + c.test() assert c.issues == {} def test_no_issues( - session: orm.Session, transactions: list[Transaction], ) -> None: - _ = transactions - session.commit() c = OverdrawnAccounts() - c.test(session) - assert query_count(session.query(HealthCheckIssue)) == 0 + c.test() + assert not sql.any_(HealthCheckIssue.query()) def test_check( @@ -37,13 +34,14 @@ def test_check( account: Account, transactions: list[Transaction], ) -> None: - t_split = transactions[0].splits[0] - t_split.amount = Decimal(-1) + with session.begin_nested(): + t_split = transactions[0].splits[0] + t_split.amount = Decimal(-1) c = OverdrawnAccounts() - c.test(session) - assert query_count(session.query(HealthCheckIssue)) == 1 + c.test() + assert HealthCheckIssue.count() == 1 - i = session.query(HealthCheckIssue).one() + i = HealthCheckIssue.one() assert i.check == c.name() assert i.value == f"{account.id_}.{t_split.date_ord}" uri = i.uri diff --git a/tests/health_checks/test_typos.py b/tests/health_checks/test_typos.py index 62a7c2e4..8095eee9 100644 --- a/tests/health_checks/test_typos.py +++ b/tests/health_checks/test_typos.py @@ -4,9 +4,9 @@ import pytest +from nummus import sql from nummus.health_checks.typos import Typos from nummus.models.health_checks import HealthCheckIssue -from nummus.models.utils import query_count if TYPE_CHECKING: from sqlalchemy import orm @@ -16,20 +16,18 @@ from nummus.models.transaction import Transaction -def test_empty(session: orm.Session) -> None: +def test_empty() -> None: c = Typos() - c.test(session) + c.test() assert c.issues == {} def test_no_issues( - session: orm.Session, transactions: list[Transaction], ) -> None: - _ = transactions c = Typos() - c.test(session) - assert query_count(session.query(HealthCheckIssue)) == 0 + c.test() + assert not sql.any_(HealthCheckIssue.query()) def test_mispelled_proper_noun( @@ -37,14 +35,14 @@ def test_mispelled_proper_noun( account: Account, account_savings: Account, ) -> None: - # institution is proper noun so make a almost the same - account.institution = account_savings.institution + "a" - session.commit() + with session.begin_nested(): + # institution is proper noun so make a almost the same + account.institution = account_savings.institution + "a" c = Typos() - c.test(session) - assert query_count(session.query(HealthCheckIssue)) == 1 + c.test() + assert HealthCheckIssue.count() == 1 - i = session.query(HealthCheckIssue).one() + i = HealthCheckIssue.one() assert i.check == c.name() assert i.value == account.institution uri = i.uri @@ -59,14 +57,14 @@ def test_mispelled( asset: Asset, no_description_typos: bool, ) -> None: - # asset description is checked for dictionary spelling - asset.description = "Banana mispel & 1234 bananas" - session.commit() + with session.begin_nested(): + # asset description is checked for dictionary spelling + asset.description = "Banana mispel & 1234 bananas" c = Typos(no_description_typos=no_description_typos) - c.test(session) - assert query_count(session.query(HealthCheckIssue)) == 1 + c.test() + assert HealthCheckIssue.count() == 1 - i = session.query(HealthCheckIssue).one() + i = HealthCheckIssue.one() assert i.check == c.name() assert i.value == "mispel" uri = i.uri diff --git a/tests/health_checks/test_unbalanced_transfers.py b/tests/health_checks/test_unbalanced_transfers.py index 35d996f9..6cbd39cc 100644 --- a/tests/health_checks/test_unbalanced_transfers.py +++ b/tests/health_checks/test_unbalanced_transfers.py @@ -4,10 +4,10 @@ from decimal import Decimal from typing import TYPE_CHECKING +from nummus import sql from nummus.health_checks.unbalanced_transfers import UnbalancedTransfers from nummus.models.currency import CURRENCY_FORMATS, DEFAULT_CURRENCY from nummus.models.health_checks import HealthCheckIssue -from nummus.models.utils import query_count if TYPE_CHECKING: from sqlalchemy import orm @@ -16,20 +16,18 @@ from nummus.models.transaction import Transaction -def test_empty(session: orm.Session) -> None: +def test_empty() -> None: c = UnbalancedTransfers() - c.test(session) + c.test() assert c.issues == {} def test_no_transfers( - session: orm.Session, transactions: list[Transaction], ) -> None: - _ = transactions c = UnbalancedTransfers() - c.test(session) - assert query_count(session.query(HealthCheckIssue)) == 0 + c.test() + assert not sql.any_(HealthCheckIssue.query()) def test_no_issues( @@ -38,25 +36,25 @@ def test_no_issues( transactions_spending: list[Transaction], categories: dict[str, int], ) -> None: - amount = Decimal(100) - spec = [ - (0, amount), - (0, -amount), - (1, amount), - (1, -amount), - ] - for i, (dt, a) in enumerate(spec): - txn = transactions_spending[i] - txn.date = today + datetime.timedelta(days=dt) - t_split = txn.splits[0] - t_split.category_id = categories["transfers"] - t_split.amount = a - t_split.parent = txn - session.commit() + with session.begin_nested(): + amount = Decimal(100) + spec = [ + (0, amount), + (0, -amount), + (1, amount), + (1, -amount), + ] + for i, (dt, a) in enumerate(spec): + txn = transactions_spending[i] + txn.date = today + datetime.timedelta(days=dt) + t_split = txn.splits[0] + t_split.category_id = categories["transfers"] + t_split.amount = a + t_split.parent = txn c = UnbalancedTransfers() - c.test(session) - assert query_count(session.query(HealthCheckIssue)) == 0 + c.test() + assert not sql.any_(HealthCheckIssue.query()) def test_wrong_amount( @@ -66,19 +64,19 @@ def test_wrong_amount( transactions_spending: list[Transaction], categories: dict[str, int], ) -> None: - amount = Decimal(100) - spec = [amount, -amount * 2] - for i, a in enumerate(spec): - t_split = transactions_spending[i].splits[0] - t_split.category_id = categories["transfers"] - t_split.amount = a - session.commit() + with session.begin_nested(): + amount = Decimal(100) + spec = [amount, -amount * 2] + for i, a in enumerate(spec): + t_split = transactions_spending[i].splits[0] + t_split.category_id = categories["transfers"] + t_split.amount = a c = UnbalancedTransfers() - c.test(session) - assert query_count(session.query(HealthCheckIssue)) == 1 + c.test() + assert HealthCheckIssue.count() == 1 - i = session.query(HealthCheckIssue).one() + i = HealthCheckIssue.one() assert i.check == c.name() assert i.value == today.isoformat() uri = i.uri @@ -99,19 +97,19 @@ def test_one_pair( transactions_spending: list[Transaction], categories: dict[str, int], ) -> None: - amount = Decimal(100) - spec = [amount, -amount, -amount] - for i, a in enumerate(spec): - t_split = transactions_spending[i].splits[0] - t_split.category_id = categories["transfers"] - t_split.amount = a - session.commit() + with session.begin_nested(): + amount = Decimal(100) + spec = [amount, -amount, -amount] + for i, a in enumerate(spec): + t_split = transactions_spending[i].splits[0] + t_split.category_id = categories["transfers"] + t_split.amount = a c = UnbalancedTransfers() - c.test(session) - assert query_count(session.query(HealthCheckIssue)) == 1 + c.test() + assert HealthCheckIssue.count() == 1 - i = session.query(HealthCheckIssue).one() + i = HealthCheckIssue.one() assert i.check == c.name() assert i.value == today.isoformat() uri = i.uri @@ -131,33 +129,30 @@ def test_wrong_date( transactions_spending: list[Transaction], categories: dict[str, int], ) -> None: - amount = Decimal(100) - spec = [ - (0, amount), - (0, -amount), - (0, amount), - (1, -amount), - ] - for i, (dt, a) in enumerate(spec): - txn = transactions_spending[i] - txn.date = today + datetime.timedelta(days=dt) - t_split = txn.splits[0] - t_split.category_id = categories["transfers"] - t_split.amount = a - t_split.parent = txn - amount = Decimal(100) - session.commit() + with session.begin_nested(): + amount = Decimal(100) + spec = [ + (0, amount), + (0, -amount), + (0, amount), + (1, -amount), + ] + for i, (dt, a) in enumerate(spec): + txn = transactions_spending[i] + txn.date = today + datetime.timedelta(days=dt) + t_split = txn.splits[0] + t_split.category_id = categories["transfers"] + t_split.amount = a + t_split.parent = txn + amount = Decimal(100) tomorrow = today + datetime.timedelta(days=1) c = UnbalancedTransfers() - c.test(session) - assert query_count(session.query(HealthCheckIssue)) == 2 + c.test() + assert HealthCheckIssue.count() == 2 - i = ( - session.query(HealthCheckIssue) - .where(HealthCheckIssue.value == today.isoformat()) - .one() - ) + query = HealthCheckIssue.query().where(HealthCheckIssue.value == today.isoformat()) + i = sql.one(query) assert i.check == c.name() cf = CURRENCY_FORMATS[DEFAULT_CURRENCY] lines = ( @@ -166,11 +161,10 @@ def test_wrong_date( ) assert i.msg == "\n".join(lines) - i = ( - session.query(HealthCheckIssue) - .where(HealthCheckIssue.value == tomorrow.isoformat()) - .one() + query = HealthCheckIssue.query().where( + HealthCheckIssue.value == tomorrow.isoformat(), ) + i = sql.one(query) assert i.check == c.name() assert i.value == tomorrow.isoformat() lines = ( diff --git a/tests/health_checks/test_uncleared_transactions.py b/tests/health_checks/test_uncleared_transactions.py index 67f755cc..8dfcfd79 100644 --- a/tests/health_checks/test_uncleared_transactions.py +++ b/tests/health_checks/test_uncleared_transactions.py @@ -2,10 +2,10 @@ from typing import TYPE_CHECKING +from nummus import sql from nummus.health_checks.uncleared_transactions import UnclearedTransactions from nummus.models.currency import CURRENCY_FORMATS, DEFAULT_CURRENCY from nummus.models.health_checks import HealthCheckIssue -from nummus.models.utils import query_count if TYPE_CHECKING: from sqlalchemy import orm @@ -14,20 +14,18 @@ from nummus.models.transaction import Transaction -def test_empty(session: orm.Session) -> None: +def test_empty() -> None: c = UnclearedTransactions() - c.test(session) + c.test() assert c.issues == {} def test_no_issues( - session: orm.Session, transactions: list[Transaction], ) -> None: - _ = transactions c = UnclearedTransactions() - c.test(session) - assert query_count(session.query(HealthCheckIssue)) == 0 + c.test() + assert not sql.any_(HealthCheckIssue.query()) def test_check( @@ -35,17 +33,17 @@ def test_check( account: Account, transactions: list[Transaction], ) -> None: - txn = transactions[0] - txn.cleared = False - t_split = txn.splits[0] - t_split.parent = txn - session.commit() + with session.begin_nested(): + txn = transactions[0] + txn.cleared = False + t_split = txn.splits[0] + t_split.parent = txn c = UnclearedTransactions() - c.test(session) - assert query_count(session.query(HealthCheckIssue)) == 1 + c.test() + assert HealthCheckIssue.count() == 1 - i = session.query(HealthCheckIssue).one() + i = HealthCheckIssue.one() assert i.check == c.name() assert i.value == t_split.uri uri = i.uri diff --git a/tests/health_checks/test_unnecessary_splits.py b/tests/health_checks/test_unnecessary_splits.py index 90e2d1a8..5b7ef548 100644 --- a/tests/health_checks/test_unnecessary_splits.py +++ b/tests/health_checks/test_unnecessary_splits.py @@ -2,10 +2,10 @@ from typing import TYPE_CHECKING +from nummus import sql from nummus.health_checks.unnecessary_slits import UnnecessarySplits from nummus.models.health_checks import HealthCheckIssue from nummus.models.transaction import TransactionSplit -from nummus.models.utils import query_count if TYPE_CHECKING: from sqlalchemy import orm @@ -14,20 +14,18 @@ from nummus.models.transaction import Transaction -def test_empty(session: orm.Session) -> None: +def test_empty() -> None: c = UnnecessarySplits() - c.test(session) + c.test() assert c.issues == {} def test_no_issues( - session: orm.Session, transactions: list[Transaction], ) -> None: - _ = transactions c = UnnecessarySplits() - c.test(session) - assert query_count(session.query(HealthCheckIssue)) == 0 + c.test() + assert not sql.any_(HealthCheckIssue.query()) def test_check( @@ -35,21 +33,20 @@ def test_check( account: Account, transactions: list[Transaction], ) -> None: - txn = transactions[0] - t_split = txn.splits[0] - t_split = TransactionSplit( - parent=txn, - amount=t_split.amount, - category_id=t_split.category_id, - ) - session.add(t_split) - session.commit() + with session.begin_nested(): + txn = transactions[0] + t_split = txn.splits[0] + t_split = TransactionSplit.create( + parent=txn, + amount=t_split.amount, + category_id=t_split.category_id, + ) c = UnnecessarySplits() - c.test(session) - assert query_count(session.query(HealthCheckIssue)) == 1 + c.test() + assert HealthCheckIssue.count() == 1 - i = session.query(HealthCheckIssue).one() + i = HealthCheckIssue.one() assert i.check == c.name() assert i.value == f"{txn.id_}.{t_split.payee}.{t_split.category_id}" uri = i.uri diff --git a/tests/health_checks/test_unused_categories.py b/tests/health_checks/test_unused_categories.py index 1992a9cf..e2a64525 100644 --- a/tests/health_checks/test_unused_categories.py +++ b/tests/health_checks/test_unused_categories.py @@ -5,7 +5,6 @@ from nummus.health_checks.unused_categories import UnusedCategories from nummus.models.health_checks import HealthCheckIssue from nummus.models.transaction_category import TransactionCategory -from nummus.models.utils import query_count if TYPE_CHECKING: from sqlalchemy import orm @@ -13,13 +12,12 @@ from nummus.models.transaction import Transaction -def test_empty(session: orm.Session) -> None: +def test_empty() -> None: # Mark all locked since those are excluded - session.query(TransactionCategory).update({"locked": True}) - session.commit() + TransactionCategory.query().update({"locked": True}) c = UnusedCategories() - c.test(session) + c.test() assert c.issues == {} @@ -28,18 +26,16 @@ def test_one( transactions: list[Transaction], categories: dict[str, int], ) -> None: - _ = transactions # Mark all but groceries and other income locked since those are excluded - session.query(TransactionCategory).where( + TransactionCategory.query().where( TransactionCategory.name.not_in({"groceries", "other income"}), ).update({"locked": True}) - session.commit() c = UnusedCategories() - c.test(session) - assert query_count(session.query(HealthCheckIssue)) == 1 + c.test() + assert HealthCheckIssue.count() == 1 - i = session.query(HealthCheckIssue).one() + i = HealthCheckIssue.one() assert i.check == c.name() assert i.value == TransactionCategory.id_to_uri(categories["groceries"]) uri = i.uri diff --git a/tests/migrations/test_base.py b/tests/migrations/test_base.py index 195dd76f..58227ed5 100644 --- a/tests/migrations/test_base.py +++ b/tests/migrations/test_base.py @@ -27,7 +27,6 @@ class MockMigrator(Migrator): @override def migrate(self, p: Portfolio) -> list[str]: - _ = p return ["Comments"] @@ -38,85 +37,85 @@ def test_version() -> None: def test_drop_column(session: orm.Session) -> None: m = MockMigrator() - m.drop_column(session, Asset, "category") - session.commit() + with session.begin_nested(): + m.drop_column(Asset, "category") assert m.pending_schema_updates == set() - result = "\n".join(dump_table_configs(session, Asset)) + result = "\n".join(dump_table_configs(Asset)) assert "category" not in result def test_drop_column_with_constraints(session: orm.Session) -> None: m = MockMigrator() - m.drop_column(session, AssetValuation, "value") - session.commit() + with session.begin_nested(): + m.drop_column(AssetValuation, "value") assert m.pending_schema_updates == {AssetValuation} - result = "\n".join(dump_table_configs(session, AssetValuation)) + result = "\n".join(dump_table_configs(AssetValuation)) assert "value" not in result def test_add_column_no_value_set(session: orm.Session, asset: Asset) -> None: m = MockMigrator() - m.drop_column(session, Asset, "category") - session.commit() + with session.begin_nested(): + m.drop_column(Asset, "category") m.pending_schema_updates.clear() - m.add_column(session, Asset, Asset.category) - session.commit() + with session.begin_nested(): + m.add_column(Asset, Asset.category) assert m.pending_schema_updates == {Asset} - result = "\n".join(dump_table_configs(session, Asset)) + result = "\n".join(dump_table_configs(Asset)) assert "category" in result + asset.refresh() assert asset.category is None def test_add_column_value_set(session: orm.Session, asset: Asset) -> None: m = MockMigrator() - m.drop_column(session, Asset, "category") - session.commit() + with session.begin_nested(): + m.drop_column(Asset, "category") m.pending_schema_updates.clear() - m.add_column(session, Asset, Asset.category, AssetCategory.STOCKS) - session.commit() + with session.begin_nested(): + m.add_column(Asset, Asset.category, AssetCategory.STOCKS) assert m.pending_schema_updates == {Asset} - result = "\n".join(dump_table_configs(session, Asset)) + result = "\n".join(dump_table_configs(Asset)) assert "category" in result + asset.refresh() assert asset.category == AssetCategory.STOCKS def test_rename_column(session: orm.Session) -> None: m = MockMigrator() - m.rename_column(session, Asset, "category", "class") - session.commit() + with session.begin_nested(): + m.rename_column(Asset, "category", "class") assert m.pending_schema_updates == {Asset} - result = "\n".join(dump_table_configs(session, Asset)) + result = "\n".join(dump_table_configs(Asset)) assert "category" not in result assert "class" in result def test_migrate_schemas_no_value_set(empty_portfolio: Portfolio, asset: Asset) -> None: - _ = asset m = SchemaMigrator(set()) - with empty_portfolio.begin_session() as s: - m.drop_column(s, Asset, "category") - with empty_portfolio.begin_session() as s: - m.add_column(s, Asset, Asset.category) + with empty_portfolio.begin_session(): + m.drop_column(Asset, "category") + with empty_portfolio.begin_session(): + m.add_column(Asset, Asset.category) with pytest.raises(exc.IntegrityError): m.migrate(empty_portfolio) def test_migrate_schemas_value_set(empty_portfolio: Portfolio, asset: Asset) -> None: - _ = asset m = SchemaMigrator(set()) - with empty_portfolio.begin_session() as s: - m.drop_column(s, Asset, "category") - with empty_portfolio.begin_session() as s: - m.add_column(s, Asset, Asset.category, AssetCategory.STOCKS) + with empty_portfolio.begin_session(): + m.drop_column(Asset, "category") + with empty_portfolio.begin_session(): + m.add_column(Asset, Asset.category, AssetCategory.STOCKS) assert m.migrate(empty_portfolio) == [] diff --git a/tests/migrations/test_v0_10.py b/tests/migrations/test_v0_10.py index 2af46874..163cb3ee 100644 --- a/tests/migrations/test_v0_10.py +++ b/tests/migrations/test_v0_10.py @@ -23,7 +23,7 @@ def test_migrate(tmp_path: Path, data_path: Path) -> None: target = [] assert result == target - with p.begin_session() as s: - result = "\n".join(dump_table_configs(s, TransactionCategory)) + with p.begin_session(): + result = "\n".join(dump_table_configs(TransactionCategory)) assert "essential_spending" in result assert TransactionCategory in m.pending_schema_updates diff --git a/tests/migrations/test_v0_13.py b/tests/migrations/test_v0_13.py index bca8b5c1..0025d6be 100644 --- a/tests/migrations/test_v0_13.py +++ b/tests/migrations/test_v0_13.py @@ -3,13 +3,11 @@ import shutil from typing import TYPE_CHECKING +from nummus import sql from nummus.migrations.v0_13 import MigratorV0_13 from nummus.models.label import Label, LabelLink from nummus.models.transaction import TransactionSplit -from nummus.models.utils import ( - dump_table_configs, - query_count, -) +from nummus.models.utils import dump_table_configs from nummus.portfolio import Portfolio if TYPE_CHECKING: @@ -27,17 +25,15 @@ def test_migrate(tmp_path: Path, data_path: Path) -> None: target = [] assert result == target - with p.begin_session() as s: - result = "\n".join(dump_table_configs(s, TransactionSplit)) + with p.begin_session(): + result = "\n".join(dump_table_configs(TransactionSplit)) assert "label" not in result - result = "\n".join(dump_table_configs(s, Label)) + result = "\n".join(dump_table_configs(Label)) assert "name" in result - n = query_count(s.query(Label)) - assert n == 1 + assert sql.count(Label.query()) == 1 - result = "\n".join(dump_table_configs(s, LabelLink)) + result = "\n".join(dump_table_configs(LabelLink)) assert "label_id" in result - n = query_count(s.query(LabelLink)) - assert n == 100 + assert sql.count(LabelLink.query()) == 100 diff --git a/tests/migrations/test_v0_15.py b/tests/migrations/test_v0_15.py index d7c28d69..39c9b6ea 100644 --- a/tests/migrations/test_v0_15.py +++ b/tests/migrations/test_v0_15.py @@ -3,12 +3,10 @@ import shutil from typing import TYPE_CHECKING +from nummus import sql from nummus.migrations.v0_15 import MigratorV0_15 from nummus.models.label import Label, LabelLink -from nummus.models.utils import ( - dump_table_configs, - query_count, -) +from nummus.models.utils import dump_table_configs from nummus.portfolio import Portfolio if TYPE_CHECKING: @@ -26,14 +24,12 @@ def test_migrate(tmp_path: Path, data_path: Path) -> None: target = [] assert result == target - with p.begin_session() as s: - result = "\n".join(dump_table_configs(s, Label)) + with p.begin_session(): + result = "\n".join(dump_table_configs(Label)) assert "name" in result - n = query_count(s.query(Label)) - assert n == 1 + assert sql.count(Label.query()) == 1 - result = "\n".join(dump_table_configs(s, LabelLink)) + result = "\n".join(dump_table_configs(LabelLink)) assert "label_id" in result - n = query_count(s.query(LabelLink)) - assert n == 100 + assert sql.count(LabelLink.query()) == 100 diff --git a/tests/migrations/test_v0_16.py b/tests/migrations/test_v0_16.py index 41fe0f42..0798889d 100644 --- a/tests/migrations/test_v0_16.py +++ b/tests/migrations/test_v0_16.py @@ -28,11 +28,11 @@ def test_migrate(tmp_path: Path, data_path: Path) -> None: ] assert result == target - with p.begin_session() as s: - result = "\n".join(dump_table_configs(s, Account)) + with p.begin_session(): + result = "\n".join(dump_table_configs(Account)) assert "currency" in result - result = "\n".join(dump_table_configs(s, Asset)) + result = "\n".join(dump_table_configs(Asset)) assert "currency" in result - assert Config.base_currency(s) == Currency.USD + assert Config.base_currency() == Currency.USD diff --git a/tests/migrations/test_v0_2.py b/tests/migrations/test_v0_2.py index 1eeb038e..5d1e6367 100644 --- a/tests/migrations/test_v0_2.py +++ b/tests/migrations/test_v0_2.py @@ -28,14 +28,14 @@ def test_migrate(tmp_path: Path, data_path: Path) -> None: ] assert result == target - with p.begin_session() as s: - result = "\n".join(dump_table_configs(s, Transaction)) + with p.begin_session(): + result = "\n".join(dump_table_configs(Transaction)) assert "linked" not in result assert "locked" not in result assert "cleared" in result assert "payee" in result - result = "\n".join(dump_table_configs(s, TransactionSplit)) + result = "\n".join(dump_table_configs(TransactionSplit)) assert "linked" not in result assert "locked" not in result assert "cleared" in result diff --git a/tests/mock_yfinance.py b/tests/mock_yfinance.py index 28f10069..046ec150 100644 --- a/tests/mock_yfinance.py +++ b/tests/mock_yfinance.py @@ -84,6 +84,6 @@ def history( dt += datetime.timedelta(days=1) return pd.DataFrame( - index=pd.to_datetime(dates), + index=pd.to_datetime(dates), # type: ignore[attr-defined] data={"Close": close, "Stock Splits": split}, ) diff --git a/tests/models/asset/test_asset.py b/tests/models/asset/test_asset.py index 3b4f9d2a..8adf82d5 100644 --- a/tests/models/asset/test_asset.py +++ b/tests/models/asset/test_asset.py @@ -7,6 +7,7 @@ import pytest from nummus import exceptions as exc +from nummus import sql from nummus.models.asset import ( Asset, AssetCategory, @@ -16,27 +17,18 @@ ) from nummus.models.currency import Currency, DEFAULT_CURRENCY from nummus.models.label import LabelLink -from nummus.models.utils import ( - query_count, - query_to_dict, - update_rows, -) +from nummus.models.utils import update_rows from tests import conftest if TYPE_CHECKING: - from sqlalchemy import orm - from nummus.models.account import Account - from nummus.models.asset import ( - AssetSplit, - ) + from nummus.models.asset import AssetSplit from nummus.models.transaction import Transaction from tests.conftest import RandomStringGenerator @pytest.fixture def valuations( - session: orm.Session, today_ord: int, asset: Asset, ) -> list[AssetValuation]: @@ -47,15 +39,12 @@ def valuations( today_ord + 3: {"value": Decimal(10), "asset_id": a_id}, } - query = session.query(AssetValuation) - update_rows(session, AssetValuation, query, "date_ord", updates) - session.commit() - return query.all() + update_rows(AssetValuation, AssetValuation.query(), "date_ord", updates) + return AssetValuation.all() @pytest.fixture def valuations_five( - session: orm.Session, today_ord: int, asset: Asset, ) -> list[AssetValuation]: @@ -68,14 +57,11 @@ def valuations_five( today_ord + 7: {"value": Decimal(10), "asset_id": a_id}, } - query = session.query(AssetValuation) - update_rows(session, AssetValuation, query, "date_ord", updates) - session.commit() - return query.all() + update_rows(AssetValuation, AssetValuation.query(), "date_ord", updates) + return AssetValuation.all() def test_init_properties( - session: orm.Session, rand_str_generator: RandomStringGenerator, ) -> None: d = { @@ -86,9 +72,7 @@ def test_init_properties( "currency": DEFAULT_CURRENCY, } - a = Asset(**d) - session.add(a) - session.commit() + a = Asset.create(**d) assert a.name == d["name"] assert a.description == d["description"] @@ -103,7 +87,6 @@ def test_short() -> None: def test_get_value_empty( today_ord: int, - session: orm.Session, asset: Asset, ) -> None: start_ord = today_ord - 3 @@ -111,7 +94,7 @@ def test_get_value_empty( result = asset.get_value(start_ord, end_ord) assert result == [Decimal(0)] * 7 - result = Asset.get_value_all(session, start_ord, end_ord) + result = Asset.get_value_all(start_ord, end_ord) assert result == {} @@ -120,7 +103,6 @@ def test_get_value( asset: Asset, valuations: list[AssetValuation], ) -> None: - _ = valuations start_ord = today_ord - 3 end_ord = today_ord + 3 result = asset.get_value(start_ord, end_ord) @@ -142,7 +124,6 @@ def test_get_value_interpolate( valuations: list[AssetValuation], ) -> None: asset.interpolate = True - _ = valuations start_ord = today_ord - 3 end_ord = today_ord + 3 result = asset.get_value(start_ord, end_ord) @@ -163,7 +144,6 @@ def test_get_value_today( asset: Asset, valuations: list[AssetValuation], ) -> None: - _ = valuations result = asset.get_value(today_ord, today_ord) assert result == [Decimal(100)] @@ -174,7 +154,6 @@ def test_get_value_tomorrow_interpolate( valuations: list[AssetValuation], ) -> None: asset.interpolate = True - _ = valuations result = asset.get_value(today_ord + 1, today_ord + 1) assert result == [Decimal(70)] @@ -185,7 +164,6 @@ def test_update_splits_empty( asset: Asset, transactions: list[Transaction], ) -> None: - _ = transactions asset.update_splits() assets = account.get_asset_qty(today_ord, today_ord) assert assets == {asset.id_: [Decimal(10)]} @@ -198,8 +176,6 @@ def test_update_splits( asset_split: AssetSplit, transactions: list[Transaction], ) -> None: - _ = transactions - _ = asset_split asset.update_splits() assets = account.get_asset_qty(today_ord, today_ord) assert assets == {asset.id_: [Decimal(100)]} @@ -216,8 +192,6 @@ def test_prune_valuations_none( valuations: list[AssetValuation], transactions: list[Transaction], ) -> None: - _ = valuations - _ = transactions assert asset.prune_valuations() == 0 @@ -236,7 +210,6 @@ def test_prune_valuations_none( ids=conftest.id_func, ) def test_prune_valuations_first_txn( - session: orm.Session, asset: Asset, valuations_five: list[AssetValuation], transactions: list[Transaction], @@ -246,31 +219,28 @@ def test_prune_valuations_first_txn( for i in to_delete: txn = transactions[i] for t_split in txn.splits: - session.query(LabelLink).where(LabelLink.t_split_id == t_split.id_).delete() - session.delete(t_split) - session.delete(txn) - _ = valuations_five + LabelLink.query().where(LabelLink.t_split_id == t_split.id_).delete() + t_split.delete() + txn.delete() assert asset.prune_valuations() == target def test_prune_valuations_index(asset: Asset, valuations: list[AssetValuation]) -> None: - _ = valuations asset.category = AssetCategory.INDEX assert asset.prune_valuations() == 0 -def test_update_valuations_none(session: orm.Session, asset: Asset) -> None: +def test_update_valuations_none(asset: Asset) -> None: asset.ticker = None - session.commit() with pytest.raises(exc.NoAssetWebSourceError): asset.update_valuations(through_today=True) -def test_update_valuations_empty(session: orm.Session, asset: Asset) -> None: +def test_update_valuations_empty(asset: Asset) -> None: start, end = asset.update_valuations(through_today=True) assert start is None assert end is None - assert query_count(session.query(AssetValuation)) == 0 + assert not sql.any_(AssetValuation.query()) @pytest.mark.parametrize( @@ -283,7 +253,6 @@ def test_update_valuations_empty(session: orm.Session, asset: Asset) -> None: ) def test_update_valuations( today: datetime.date, - session: orm.Session, transactions: list[Transaction], category: AssetCategory, asset: Asset, @@ -304,22 +273,20 @@ def test_update_valuations( while start <= end: n += 0 if start.weekday() in {5, 6} else 1 start += datetime.timedelta(days=1) - assert query_count(session.query(AssetValuation)) == n + assert sql.count(AssetValuation.query()) == n def test_update_valuations_delisted( asset: Asset, transactions: list[Transaction], ) -> None: - _ = transactions asset.ticker = "APPLE" with pytest.raises(exc.AssetWebError): asset.update_valuations(through_today=True) -def test_update_sectors_none(session: orm.Session, asset: Asset) -> None: +def test_update_sectors_none(asset: Asset) -> None: asset.ticker = None - session.commit() with pytest.raises(exc.NoAssetWebSourceError): asset.update_sectors() @@ -347,43 +314,39 @@ def test_update_sectors_none(session: orm.Session, asset: Asset) -> None: ], ) def test_update_sectors( - session: orm.Session, asset: Asset, ticker: str, target: dict[USSector, Decimal], ) -> None: asset.ticker = ticker asset.update_sectors() - session.commit() - query = ( - session.query(AssetSector) - .with_entities(AssetSector.sector, AssetSector.weight) - .where(AssetSector.asset_id == asset.id_) + query = AssetSector.query(AssetSector.sector, AssetSector.weight).where( + AssetSector.asset_id == asset.id_, ) - sectors: dict[USSector, Decimal] = query_to_dict(query) + sectors: dict[USSector, Decimal] = sql.to_dict(query) assert sectors == target -def test_index_twrr_none(today_ord: int, session: orm.Session) -> None: +def test_index_twrr_none(today_ord: int) -> None: with pytest.raises(exc.ProtectedObjectNotFoundError): - Asset.index_twrr(session, "Fake Index", today_ord, today_ord) + Asset.index_twrr("Fake Index", today_ord, today_ord) -def test_index_twrr(today_ord: int, session: orm.Session, asset: Asset) -> None: +def test_index_twrr(today_ord: int, asset: Asset) -> None: asset.category = AssetCategory.INDEX - result = Asset.index_twrr(session, asset.name, today_ord - 3, today_ord + 3) + result = Asset.index_twrr(asset.name, today_ord - 3, today_ord + 3) # utils.twrr and Asset.get_value already tested, just check they connect well assert result == [Decimal(0)] * 7 -def test_index_twrr_today(today_ord: int, session: orm.Session, asset: Asset) -> None: +def test_index_twrr_today(today_ord: int, asset: Asset) -> None: asset.category = AssetCategory.INDEX - result = Asset.index_twrr(session, asset.name, today_ord, today_ord) + result = Asset.index_twrr(asset.name, today_ord, today_ord) assert result == [Decimal(0)] -def test_add_indices(session: orm.Session) -> None: - for asset in session.query(Asset).all(): +def test_add_indices() -> None: + for asset in Asset.all(): assert asset.name is not None assert asset.description is not None assert not asset.interpolate @@ -399,7 +362,6 @@ def test_autodetect_interpolate_sparse( asset: Asset, valuations: list[AssetValuation], ) -> None: - _ = valuations asset.autodetect_interpolate() assert asset.interpolate @@ -410,36 +372,34 @@ def test_autodetect_interpolate_daily( ) -> None: for i, v in enumerate(valuations): v.date_ord = valuations[0].date_ord + i - _ = valuations asset.autodetect_interpolate() assert not asset.interpolate -def test_create_forex(session: orm.Session, asset: Asset) -> None: +def test_create_forex(asset: Asset) -> None: asset.ticker = "EURUSD=X" asset.category = AssetCategory.FOREX - Asset.create_forex(session, Currency.USD, {*Currency}) + Asset.create_forex(Currency.USD, {*Currency}) - query = session.query(Asset).where(Asset.category == AssetCategory.FOREX) + query = Asset.query().where(Asset.category == AssetCategory.FOREX) # -1 since don't need USDUSD=x - assert query_count(query) == len(Currency) - 1 + assert sql.count(query) == len(Currency) - 1 -def test_create_forex_none(session: orm.Session) -> None: - Asset.create_forex(session, Currency.USD, set()) +def test_create_forex_none() -> None: + Asset.create_forex(Currency.USD, set()) - query = session.query(Asset).where(Asset.category == AssetCategory.FOREX) - assert query_count(query) == 0 + query = Asset.query().where(Asset.category == AssetCategory.FOREX) + assert not sql.any_(query) -def test_get_forex_empty(session: orm.Session, today_ord: int) -> None: - result = Asset.get_forex(session, today_ord, today_ord, DEFAULT_CURRENCY) +def test_get_forex_empty(today_ord: int) -> None: + result = Asset.get_forex(today_ord, today_ord, DEFAULT_CURRENCY) assert result[DEFAULT_CURRENCY] == [Decimal(1)] def test_get_forex( - session: orm.Session, today_ord: int, asset: Asset, asset_valuation: AssetValuation, @@ -449,7 +409,6 @@ def test_get_forex( asset.currency = Currency.USD result = Asset.get_forex( - session, today_ord, today_ord, Currency.USD, diff --git a/tests/models/asset/test_asset_sector.py b/tests/models/asset/test_asset_sector.py index 54eec082..35695e14 100644 --- a/tests/models/asset/test_asset_sector.py +++ b/tests/models/asset/test_asset_sector.py @@ -8,14 +8,11 @@ from nummus.models.asset import AssetSector, USSector if TYPE_CHECKING: - from sqlalchemy import orm - from nummus.models.asset import Asset from tests.conftest import RandomRealGenerator def test_init_properties( - session: orm.Session, asset: Asset, rand_real_generator: RandomRealGenerator, ) -> None: @@ -25,45 +22,35 @@ def test_init_properties( "weight": rand_real_generator(1, 10), } - v = AssetSector(**d) - session.add(v) - session.commit() + v = AssetSector.create(**d) assert v.asset_id == d["asset_id"] assert v.sector == d["sector"] assert v.weight == d["weight"] -def test_weight_negative(session: orm.Session, asset: Asset) -> None: - v = AssetSector(asset_id=asset.id_, sector=USSector.REAL_ESTATE, weight=-1) - session.add(v) +def test_weight_negative(asset: Asset) -> None: with pytest.raises(exc.IntegrityError): - session.commit() + AssetSector.create(asset_id=asset.id_, sector=USSector.REAL_ESTATE, weight=-1) -def test_weight_zero(session: orm.Session, asset: Asset) -> None: - v = AssetSector(asset_id=asset.id_, sector=USSector.REAL_ESTATE, weight=0) - session.add(v) +def test_weight_zero(asset: Asset) -> None: with pytest.raises(exc.IntegrityError): - session.commit() + AssetSector.create(asset_id=asset.id_, sector=USSector.REAL_ESTATE, weight=0) def test_duplicate_sectors( - session: orm.Session, asset: Asset, rand_real_generator: RandomRealGenerator, ) -> None: - v = AssetSector( - asset_id=asset.id_, - sector=USSector.REAL_ESTATE, - weight=rand_real_generator(1, 10), - ) - session.add(v) - v = AssetSector( + AssetSector.create( asset_id=asset.id_, sector=USSector.REAL_ESTATE, weight=rand_real_generator(1, 10), ) - session.add(v) with pytest.raises(exc.IntegrityError): - session.commit() + AssetSector.create( + asset_id=asset.id_, + sector=USSector.REAL_ESTATE, + weight=rand_real_generator(1, 10), + ) diff --git a/tests/models/asset/test_asset_split.py b/tests/models/asset/test_asset_split.py index 5b177099..85e3df39 100644 --- a/tests/models/asset/test_asset_split.py +++ b/tests/models/asset/test_asset_split.py @@ -10,8 +10,6 @@ if TYPE_CHECKING: import datetime - from sqlalchemy import orm - from nummus.models.asset import Asset from tests.conftest import RandomRealGenerator @@ -19,7 +17,6 @@ def test_init_properties( today: datetime.date, today_ord: int, - session: orm.Session, asset: Asset, rand_real_generator: RandomRealGenerator, ) -> None: @@ -29,9 +26,7 @@ def test_init_properties( "date_ord": today_ord, } - v = AssetSplit(**d) - session.add(v) - session.commit() + v = AssetSplit.create(**d) assert v.asset_id == d["asset_id"] assert v.multiplier == d["multiplier"] @@ -41,39 +36,30 @@ def test_init_properties( def test_multiplier_negative( today_ord: int, - session: orm.Session, asset: Asset, ) -> None: - v = AssetSplit(asset_id=asset.id_, date_ord=today_ord, multiplier=-1) - session.add(v) with pytest.raises(exc.IntegrityError): - session.commit() + AssetSplit.create(asset_id=asset.id_, date_ord=today_ord, multiplier=-1) -def test_multiplier_zero(today_ord: int, session: orm.Session, asset: Asset) -> None: - v = AssetSplit(asset_id=asset.id_, date_ord=today_ord, multiplier=0) - session.add(v) +def test_multiplier_zero(today_ord: int, asset: Asset) -> None: with pytest.raises(exc.IntegrityError): - session.commit() + AssetSplit.create(asset_id=asset.id_, date_ord=today_ord, multiplier=0) def test_duplicate_dates( today_ord: int, - session: orm.Session, asset: Asset, rand_real_generator: RandomRealGenerator, ) -> None: - v = AssetSplit( - asset_id=asset.id_, - date_ord=today_ord, - multiplier=rand_real_generator(1, 10), - ) - session.add(v) - v = AssetSplit( + AssetSplit.create( asset_id=asset.id_, date_ord=today_ord, multiplier=rand_real_generator(1, 10), ) - session.add(v) with pytest.raises(exc.IntegrityError): - session.commit() + AssetSplit.create( + asset_id=asset.id_, + date_ord=today_ord, + multiplier=rand_real_generator(1, 10), + ) diff --git a/tests/models/asset/test_asset_valuation.py b/tests/models/asset/test_asset_valuation.py index 94bfc57a..15fe81cc 100644 --- a/tests/models/asset/test_asset_valuation.py +++ b/tests/models/asset/test_asset_valuation.py @@ -11,8 +11,6 @@ import datetime from decimal import Decimal - from sqlalchemy import orm - from nummus.models.asset import Asset from tests.conftest import RandomRealGenerator @@ -20,7 +18,6 @@ def test_init_properties( today: datetime.date, today_ord: int, - session: orm.Session, asset: Asset, rand_real: Decimal, ) -> None: @@ -30,9 +27,7 @@ def test_init_properties( "value": rand_real, } - v = AssetValuation(**d) - session.add(v) - session.commit() + v = AssetValuation.create(**d) assert v.asset_id == d["asset_id"] assert v.value == d["value"] @@ -42,32 +37,25 @@ def test_init_properties( def test_multiplier_negative( today_ord: int, - session: orm.Session, asset: Asset, ) -> None: - v = AssetValuation(asset_id=asset.id_, date_ord=today_ord, value=-1) - session.add(v) with pytest.raises(exc.IntegrityError): - session.commit() + AssetValuation.create(asset_id=asset.id_, date_ord=today_ord, value=-1) def test_duplicate_dates( today_ord: int, - session: orm.Session, asset: Asset, rand_real_generator: RandomRealGenerator, ) -> None: - v = AssetValuation( - asset_id=asset.id_, - date_ord=today_ord, - value=rand_real_generator(), - ) - session.add(v) - v = AssetValuation( + AssetValuation.create( asset_id=asset.id_, date_ord=today_ord, value=rand_real_generator(), ) - session.add(v) with pytest.raises(exc.IntegrityError): - session.commit() + AssetValuation.create( + asset_id=asset.id_, + date_ord=today_ord, + value=rand_real_generator(), + ) diff --git a/tests/models/base/test_orm_base.py b/tests/models/base/test_orm_base.py index 177da751..0aea5fea 100644 --- a/tests/models/base/test_orm_base.py +++ b/tests/models/base/test_orm_base.py @@ -1,19 +1,36 @@ from __future__ import annotations +import re from decimal import Decimal +from pathlib import Path from typing import TYPE_CHECKING import pytest from sqlalchemy import ForeignKey, orm +import nummus +import tests from nummus import exceptions as exc from nummus import sql -from nummus.models import base +from nummus.models.base import ( + Base, + BaseEnum, + Decimal6, + ORMInt, + ORMIntOpt, + ORMRealOpt, + ORMStrOpt, + SQLEnum, + string_column_args, +) +from tests import conftest if TYPE_CHECKING: - from collections.abc import Mapping - from pathlib import Path + from collections.abc import Callable, Generator, Mapping + from nummus.models.base import ( + NamePair, + ) from tests.conftest import RandomStringGenerator @@ -28,7 +45,7 @@ def __hash__(self) -> int: return hash(self._data) -class Derived(base.BaseEnum): +class Derived(BaseEnum): RED = 1 BLUE = 2 SEAFOAM_GREEN = 3 @@ -38,15 +55,17 @@ def lut(cls) -> Mapping[str, Derived]: return {"r": cls.RED, "b": cls.BLUE} -class Parent(base.Base, skip_register=True): +class Parent(Base, skip_register=True): __tablename__ = "parent" __table_id__ = 0xF0000000 - generic_column: base.ORMIntOpt - name: base.ORMStrOpt + generic_column: ORMIntOpt + name: ORMStrOpt children: orm.Mapped[list[Child]] = orm.relationship(back_populates="parent") - __table_args__ = (*base.string_column_args("name"),) + __table_args__ = (*string_column_args("name"),) + + _SEARCH_PROPERTIES = ("name",) @orm.validates("name") def validate_strings(self, key: str, field: str | None) -> str | None: @@ -63,46 +82,47 @@ def uri_bytes(self) -> Bytes: return Bytes(self.uri) -class Child(base.Base, skip_register=True): +class Child(Base, skip_register=True): __tablename__ = "child" __table_id__ = 0xE0000000 - parent_id: base.ORMInt = orm.mapped_column(ForeignKey("parent.id_")) + parent_id: ORMInt = orm.mapped_column(ForeignKey("parent.id_")) parent: orm.Mapped[Parent] = orm.relationship(back_populates="children") - height: base.ORMRealOpt = orm.mapped_column(base.Decimal6) + height: ORMRealOpt = orm.mapped_column(Decimal6) - color: orm.Mapped[Derived | None] = orm.mapped_column(base.SQLEnum(Derived)) + color: orm.Mapped[Derived | None] = orm.mapped_column(SQLEnum(Derived)) @orm.validates("height") def validate_decimals(self, key: str, field: Decimal | None) -> Decimal | None: return self.clean_decimals(key, field) -class NoURI(base.Base, skip_register=True): +class NoURI(Base, skip_register=True): __tablename__ = "no_uri" __table_id__ = None @pytest.fixture -def session(tmp_path: Path) -> orm.Session: +def session(tmp_path: Path) -> Generator[orm.Session]: """Create SQL session. Args: tmp_path: Temp path to create DB in - Returns: + Yields: Session generator """ path = tmp_path / "sql.db" s = orm.Session(sql.get_engine(path, None)) - base.Base.metadata.create_all( - s.get_bind(), - tables=[Parent.sql_table(), Child.sql_table()], - ) - s.commit() - return s + with s.begin_nested(): + Base.metadata.create_all( + s.get_bind(), + tables=[Parent.sql_table(), Child.sql_table()], + ) + with Base.set_session(s): + yield s @pytest.fixture @@ -113,10 +133,8 @@ def parent(session: orm.Session) -> Parent: Parent """ - p = Parent() - session.add(p) - session.commit() - return p + with session.begin_nested(): + return Parent.create() @pytest.fixture @@ -127,17 +145,8 @@ def child(session: orm.Session, parent: Parent) -> Child: Child """ - c = Child(parent=parent) - session.add(c) - session.commit() - return c - - -def test_detached() -> None: - parent = Parent() - assert parent.id_ is None - with pytest.raises(exc.NoIDError): - _ = parent.uri + with session.begin_nested(): + return Child.create(parent=parent) def test_init_properties(parent: Parent) -> None: @@ -147,13 +156,6 @@ def test_init_properties(parent: Parent) -> None: assert hash(parent) == parent.id_ -def test_detached_child() -> None: - child = Child() - assert child.id_ is None - assert child.parent is None - assert child.parent_id is None - - def test_link_child(parent: Parent, child: Child) -> None: assert child.id_ is not None assert child.parent == parent @@ -165,23 +167,20 @@ def test_wrong_uri_type(parent: Parent) -> None: Child.uri_to_id(parent.uri) -def test_set_decimal_none(session: orm.Session, child: Child) -> None: +def test_set_decimal_none(child: Child) -> None: child.height = None - session.commit() assert child.height is None -def test_set_decimal_value(session: orm.Session, child: Child) -> None: +def test_set_decimal_value(child: Child) -> None: height = Decimal("1.2") child.height = height - session.commit() assert isinstance(child.height, Decimal) assert child.height == height -def test_set_enum(session: orm.Session, child: Child) -> None: +def test_set_enum(child: Child) -> None: child.color = Derived.RED - session.commit() assert isinstance(child.color, Derived) assert child.color == Derived.RED @@ -192,11 +191,9 @@ def test_no_uri() -> None: _ = no_uri.uri -def test_comparators_same_session(session: orm.Session) -> None: - parent_a = Parent() - parent_b = Parent() - session.add_all([parent_a, parent_b]) - session.commit() +def test_comparators_same_session() -> None: + parent_a = Parent.create() + parent_b = Parent.create() assert parent_a == parent_a # noqa: PLR0124 assert parent_a != parent_b @@ -213,25 +210,20 @@ def test_comparators_different_session(session: orm.Session, parent: Parent) -> assert parent == parent_a_queried -def test_map_name_none(session: orm.Session) -> None: +def test_map_name_none() -> None: with pytest.raises(KeyError, match="Base does not have name column"): - base.Base.map_name(session) + Base.map_name() -def test_map_name_parent( - session: orm.Session, - rand_str_generator: RandomStringGenerator, -) -> None: - parent_a = Parent(name=rand_str_generator()) - parent_b = Parent(name=rand_str_generator()) - session.add_all([parent_a, parent_b]) - session.commit() +def test_map_name_parent(rand_str_generator: RandomStringGenerator) -> None: + parent_a = Parent.create(name=rand_str_generator()) + parent_b = Parent.create(name=rand_str_generator()) target = { parent_a.id_: parent_a.name, parent_b.id_: parent_b.name, } - assert Parent.map_name(session) == target + assert Parent.map_name() == target def test_clean_strings_none(parent: Parent) -> None: @@ -258,28 +250,28 @@ def test_clean_strings_short(parent: Parent) -> None: parent.name = "a" -def test_string_check_none(session: orm.Session, parent: Parent) -> None: +def test_string_check_none(parent: Parent) -> None: with pytest.raises(exc.IntegrityError): - session.query(Parent).where(Parent.id_ == parent.id_).update({Parent.name: ""}) + Parent.query().where(Parent.id_ == parent.id_).update({Parent.name: ""}) -def test_string_check_leading(session: orm.Session, parent: Parent) -> None: +def test_string_check_leading(parent: Parent) -> None: with pytest.raises(exc.IntegrityError): - session.query(Parent).where(Parent.id_ == parent.id_).update( + Parent.query().where(Parent.id_ == parent.id_).update( {Parent.name: " leading"}, ) -def test_string_check_trailing(session: orm.Session, parent: Parent) -> None: +def test_string_check_trailing(parent: Parent) -> None: with pytest.raises(exc.IntegrityError): - session.query(Parent).where(Parent.id_ == parent.id_).update( + Parent.query().where(Parent.id_ == parent.id_).update( {Parent.name: "trailing "}, ) -def test_string_check_short(session: orm.Session, parent: Parent) -> None: +def test_string_check_short(parent: Parent) -> None: with pytest.raises(exc.IntegrityError): - session.query(Parent).where(Parent.id_ == parent.id_).update({Parent.name: "a"}) + Parent.query().where(Parent.id_ == parent.id_).update({Parent.name: "a"}) def test_clean_decimals() -> None: @@ -293,9 +285,160 @@ def test_clean_decimals() -> None: def test_clean_emoji_name(rand_str: str) -> None: text = rand_str.lower() - assert base.Base.clean_emoji_name(text + " 😀 ") == text + assert Base.clean_emoji_name(text + " 😀 ") == text def test_clean_emoji_name_upper(rand_str: str) -> None: text = rand_str.lower() - assert base.Base.clean_emoji_name(text.upper() + " 😀 ") == text + assert Base.clean_emoji_name(text.upper() + " 😀 ") == text + + +def test_query_kwargs() -> None: + with pytest.raises(exc.NoKeywordArgumentsError): + # Intentional bad argument + Parent.query(kw=None) # type: ignore[attr-defined] + + +def test_unbound_error() -> None: + s = Base._sessions.pop() + with pytest.raises(exc.UnboundExecutionError): + Base.session() + Base._sessions.append(s) + + +def noop[T](x: T) -> T: + return x + + +def lower(s: str) -> str: + return s.lower() + + +def upper(s: str) -> str: + return s.upper() + + +@pytest.mark.parametrize( + ("prop", "value_adjuster"), + [ + ("uri", noop), + ("name", noop), + ("name", lower), + ("name", upper), + ], +) +def test_find( + parent: Parent, + prop: str, + value_adjuster: Callable[[str], str], +) -> None: + parent.name = "Fake" + query = value_adjuster(getattr(parent, prop)) + + cache: dict[str, NamePair] = {} + + result = Parent.find(query, cache) + assert result.id_ == parent.id_ + assert result.name == parent.name + + assert cache == {query: result} + + +def test_find_missing(parent: Parent) -> None: + query = Parent.id_to_uri(parent.id_ + 1) + + cache: dict[str, NamePair] = {} + with pytest.raises(exc.NoResultFound): + Parent.find(query, cache) + + assert not cache + + +def check_no_session_add(line: str) -> str: + if re.match(r"^ *(s|session)\.add\(\w\)", line): + return "Use of session.add found, use Model.create()" + return "" + + +def check_no_session_query(line: str) -> str: + if re.search(r"[( ](s|session)\.query\(", line): + return "Use of session.query found, use Model.query()" + return "" + + +def check_no_query_with_entities(line: str) -> str: + if ".with_entities" in line: # nummus: ignore + return "Use of with_entities found, use Model.query(col, ...)" + return "" + + +def check_no_query_scalar(line: str) -> str: + if (m := re.search(r"(\w*)\.scalar\(", line)) and m.group(1) != "sql": + return "Use of query.scalar found, use sql.scalar()" + return "" + + +def check_no_query_one(line: str) -> str: + if not (m := re.search(r"(\w*)\.one\(", line)): + return "" + g = m.group(1) + if (g and g[0] == g[0].upper()) or g == "sql": + # use of Model.one() + return "" + return "Use of query.one found, use sql.one()" + + +def check_no_query_all(line: str) -> str: + if not (m := re.search(r"(\w*)\.all\(", line)): + return "" + g = m.group(1) + if (g and g[0] == g[0].upper()) or g == "sql": + # use of Model.all() + return "" + return "Use of query.all found, use sql.yield_()" + + +def check_no_query_col0(line: str) -> str: + if re.search(r"for \w+,? in query", line): + return "Use of first column iterator found, use sql.col0()" + return "" + + +@pytest.mark.parametrize( + "path", + sorted( + [ + *Path(nummus.__file__).parent.glob("**/*.py"), + *Path(tests.__file__).parent.glob("**/*.py"), + ], + ), + ids=conftest.id_func, +) +def test_use_of_mixins(path: Path) -> None: + lines = path.read_text("utf-8").splitlines() + + ignore = "# nummus: ignore" + + errors: list[str] = [] + + for i, line in enumerate(lines): + checks = [ + check_no_session_add(line), + check_no_session_query(line), + check_no_query_with_entities(line), + check_no_query_scalar(line), + check_no_query_one(line), + check_no_query_all(line), + check_no_query_col0(line), + ] + checks = [f"{path:}:{i + 1}: {c}" for c in checks if c] + if checks: + if not line.endswith(ignore): + errors.extend(checks) + elif line.endswith(ignore): + errors.append( + f"{path}:{i + 1}: Use of unnecessary 'nummus: ignore'", + ) + + print("\n".join(errors)) + assert not errors diff --git a/tests/models/budget/test_budget_assignment.py b/tests/models/budget/test_budget_assignment.py index f662e27c..00e02ca2 100644 --- a/tests/models/budget/test_budget_assignment.py +++ b/tests/models/budget/test_budget_assignment.py @@ -16,14 +16,11 @@ from nummus.models.transaction_category import TransactionCategory if TYPE_CHECKING: - from sqlalchemy import orm - from nummus.models.account import Account def test_init_properties( month_ord: int, - session: orm.Session, categories: dict[str, int], rand_real: Decimal, ) -> None: @@ -33,9 +30,7 @@ def test_init_properties( "category_id": categories["uncategorized"], } - b = BudgetAssignment(**d) - session.add(b) - session.commit() + b = BudgetAssignment.create(**d) assert b.month_ord == d["month_ord"] assert b.amount == d["amount"] @@ -44,33 +39,27 @@ def test_init_properties( def test_duplicate_months( month_ord: int, - session: orm.Session, categories: dict[str, int], rand_real: Decimal, ) -> None: - b = BudgetAssignment( - month_ord=month_ord, - amount=rand_real, - category_id=categories["uncategorized"], - ) - session.add(b) - b = BudgetAssignment( + BudgetAssignment.create( month_ord=month_ord, amount=rand_real, category_id=categories["uncategorized"], ) - session.add(b) with pytest.raises(exc.IntegrityError): - session.commit() + BudgetAssignment.create( + month_ord=month_ord, + amount=rand_real, + category_id=categories["uncategorized"], + ) def test_get_monthly_available_empty( month: datetime.date, - session: orm.Session, categories: dict[str, int], ) -> None: availables, assignable, future_assigned = BudgetAssignment.get_monthly_available( - session, month, ) assert set(availables.keys()) == set(categories.values()) @@ -82,15 +71,11 @@ def test_get_monthly_available_empty( def test_get_monthly_available( month: datetime.date, - session: orm.Session, categories: dict[str, int], transactions_spending: list[Transaction], budget_assignments: list[BudgetAssignment], ) -> None: - _ = transactions_spending - _ = budget_assignments availables, assignable, future_assigned = BudgetAssignment.get_monthly_available( - session, month, ) availables.pop(categories["other income"]) @@ -131,15 +116,11 @@ def test_get_monthly_available( def test_get_monthly_available_next_month( month: datetime.date, - session: orm.Session, categories: dict[str, int], transactions_spending: list[Transaction], budget_assignments: list[BudgetAssignment], ) -> None: - _ = transactions_spending - _ = budget_assignments availables, assignable, future_assigned = BudgetAssignment.get_monthly_available( - session, utils.date_add_months(month, 1), ) availables.pop(categories["other income"]) @@ -173,7 +154,6 @@ def test_get_monthly_available_next_month( def test_get_emergency_fund_empty( today_ord: int, - session: orm.Session, ) -> None: start_ord = today_ord - 3 end_ord = today_ord + 3 @@ -181,7 +161,6 @@ def test_get_emergency_fund_empty( n_lower = 20 n_upper = 40 result = BudgetAssignment.get_emergency_fund( - session, start_ord, end_ord, n_lower, @@ -197,33 +176,28 @@ def test_get_emergency_fund_empty( def test_get_emergency_fund( today: datetime.date, today_ord: int, - session: orm.Session, account: Account, categories: dict[str, int], transactions_spending: list[Transaction], budget_assignments: list[BudgetAssignment], rand_str: str, ) -> None: - session.query(TransactionCategory).where( + TransactionCategory.query().where( TransactionCategory.name == "groceries", ).update({"essential_spending": True}) # Add a transaction 30 days ago - txn = Transaction( + txn = Transaction.create( account_id=account.id_, date=today - datetime.timedelta(days=30), amount=-50, statement=rand_str, ) - t_split = TransactionSplit( + TransactionSplit.create( parent=txn, amount=txn.amount, category_id=categories["groceries"], ) - session.add_all((txn, t_split)) - session.commit() - _ = transactions_spending - _ = budget_assignments start_ord = today_ord - 3 end_ord = today_ord + 3 @@ -231,7 +205,6 @@ def test_get_emergency_fund( n_lower = 20 n_upper = 40 result = BudgetAssignment.get_emergency_fund( - session, start_ord, end_ord, n_lower, @@ -272,17 +245,14 @@ def test_get_emergency_fund( def test_get_emergency_fund_balance( month_ord: int, - session: orm.Session, budget_assignments: list[BudgetAssignment], ) -> None: - _ = budget_assignments start_ord = month_ord - 3 end_ord = month_ord + 3 n_lower = 20 n_upper = 40 result = BudgetAssignment.get_emergency_fund( - session, start_ord, end_ord, n_lower, @@ -297,15 +267,13 @@ def test_get_emergency_fund_balance( def test_move_from_income( month_ord: int, - session: orm.Session, categories: dict[str, int], ) -> None: src_cat_id = None dest_cat_id = categories["groceries"] - BudgetAssignment.move(session, month_ord, src_cat_id, dest_cat_id, Decimal(100)) - session.commit() + BudgetAssignment.move(month_ord, src_cat_id, dest_cat_id, Decimal(100)) - a = session.query(BudgetAssignment).one() + a = BudgetAssignment.one() assert a.category_id == dest_cat_id assert a.month_ord == month_ord assert a.amount == 100 @@ -324,7 +292,6 @@ def test_move_from_income( ) def test_move_to_income_partial( month_ord: int, - session: orm.Session, categories: dict[str, int], budget_assignments: list[BudgetAssignment], src: str | None, @@ -333,14 +300,12 @@ def test_move_to_income_partial( target_src: Decimal | None, target_dest: Decimal | None, ) -> None: - _ = budget_assignments src_cat_id = None if src is None else categories[src] dest_cat_id = None if dest is None else categories[dest] - BudgetAssignment.move(session, month_ord, src_cat_id, dest_cat_id, to_move) - session.commit() + BudgetAssignment.move(month_ord, src_cat_id, dest_cat_id, to_move) a = ( - session.query(BudgetAssignment) + BudgetAssignment.query() .where( BudgetAssignment.category_id == src_cat_id, BudgetAssignment.month_ord == month_ord, @@ -355,7 +320,7 @@ def test_move_to_income_partial( assert a.amount == target_src a = ( - session.query(BudgetAssignment) + BudgetAssignment.query() .where( BudgetAssignment.category_id == dest_cat_id, BudgetAssignment.month_ord == month_ord, diff --git a/tests/models/budget/test_budget_group.py b/tests/models/budget/test_budget_group.py index 32f81737..36bef2e1 100644 --- a/tests/models/budget/test_budget_group.py +++ b/tests/models/budget/test_budget_group.py @@ -8,55 +8,41 @@ from nummus.models.budget import BudgetGroup if TYPE_CHECKING: - from sqlalchemy import orm - from tests.conftest import RandomStringGenerator -def test_init_properties(session: orm.Session, rand_str: str) -> None: +def test_init_properties(rand_str: str) -> None: d = { "name": rand_str, "position": 0, } - g = BudgetGroup(**d) - session.add(g) - session.commit() + g = BudgetGroup.create(**d) assert g.name == d["name"] assert g.position == d["position"] def test_duplicate_names( - session: orm.Session, rand_str: str, ) -> None: - g = BudgetGroup(name=rand_str, position=0) - session.add(g) - g = BudgetGroup(name=rand_str, position=1) - session.add(g) + BudgetGroup.create(name=rand_str, position=0) with pytest.raises(exc.IntegrityError): - session.commit() + BudgetGroup.create(name=rand_str, position=0) def test_duplicate_position( - session: orm.Session, rand_str_generator: RandomStringGenerator, ) -> None: - g = BudgetGroup(name=rand_str_generator(), position=0) - session.add(g) - g = BudgetGroup(name=rand_str_generator(), position=0) - session.add(g) + BudgetGroup.create(name=rand_str_generator(), position=0) with pytest.raises(exc.IntegrityError): - session.commit() + BudgetGroup.create(name=rand_str_generator(), position=0) -def test_empty(session: orm.Session) -> None: - g = BudgetGroup(name="", position=0) - session.add(g) +def test_empty() -> None: with pytest.raises(exc.IntegrityError): - session.commit() + BudgetGroup.create(name="", position=0) def test_short() -> None: with pytest.raises(exc.InvalidORMValueError): - BudgetGroup(name="a", position=0) + BudgetGroup.create(name="a", position=0) diff --git a/tests/models/budget/test_target.py b/tests/models/budget/test_target.py index 439a1023..897a6d64 100644 --- a/tests/models/budget/test_target.py +++ b/tests/models/budget/test_target.py @@ -12,13 +12,10 @@ import datetime from decimal import Decimal - from sqlalchemy import orm - def test_init_properties( today: datetime.date, today_ord: int, - session: orm.Session, rand_real: Decimal, categories: dict[str, int], ) -> None: @@ -31,9 +28,7 @@ def test_init_properties( "repeat_every": 0, } - t = Target(**d) - session.add(t) - session.commit() + t = Target.create(**d) assert t.category_id == d["category_id"] assert t.amount == d["amount"] @@ -69,7 +64,6 @@ def test_init_properties( ) def test_check_constraints( today_ord: int, - session: orm.Session, rand_real: Decimal, categories: dict[str, int], period: TargetPeriod, @@ -77,7 +71,7 @@ def test_check_constraints( kwargs: dict[str, object], success: bool, ) -> None: - d = { + d: dict[str, object] = { "category_id": categories["uncategorized"], "amount": rand_real, "type_": type_, @@ -86,18 +80,15 @@ def test_check_constraints( "repeat_every": 0, } d.update(kwargs) - t = Target(**d) - session.add(t) if success: - session.commit() + Target.create(**d) else: with pytest.raises(exc.IntegrityError): - session.commit() + Target.create(**d) def test_duplicates( today_ord: int, - session: orm.Session, rand_real: Decimal, categories: dict[str, int], ) -> None: @@ -110,16 +101,12 @@ def test_duplicates( "repeat_every": 0, } - t = Target(**d) - session.add(t) - t = Target(**d) - session.add(t) + Target.create(**d) with pytest.raises(exc.IntegrityError): - session.commit() + Target.create(**d) def test_date_none( - session: orm.Session, rand_real: Decimal, categories: dict[str, int], ) -> None: @@ -132,7 +119,6 @@ def test_date_none( "repeat_every": 0, } - t = Target(**d) - session.add(t) + t = Target.create(**d) assert t.due_date is None diff --git a/tests/models/label/test_label.py b/tests/models/label/test_label.py index 99f7c5f2..ff7ab0e9 100644 --- a/tests/models/label/test_label.py +++ b/tests/models/label/test_label.py @@ -1,23 +1,15 @@ from __future__ import annotations -from typing import TYPE_CHECKING - import pytest from nummus import exceptions as exc from nummus.models.label import Label -if TYPE_CHECKING: - - from sqlalchemy import orm - -def test_init_properties(session: orm.Session, rand_str: str) -> None: +def test_init_properties(rand_str: str) -> None: d = {"name": rand_str} - label = Label(**d) - session.add(label) - session.commit() + label = Label.create(**d) assert label.name == d["name"] diff --git a/tests/models/label/test_label_link.py b/tests/models/label/test_label_link.py index 16022d89..d87effdc 100644 --- a/tests/models/label/test_label_link.py +++ b/tests/models/label/test_label_link.py @@ -2,18 +2,14 @@ from typing import TYPE_CHECKING +from nummus import sql from nummus.models.label import Label, LabelLink -from nummus.models.utils import query_count if TYPE_CHECKING: - - from sqlalchemy import orm - from nummus.models.transaction import Transaction def test_init_properties( - session: orm.Session, labels: dict[str, int], transactions: list[Transaction], ) -> None: @@ -22,47 +18,38 @@ def test_init_properties( "t_split_id": transactions[-1].splits[0].id_, } - link = LabelLink(**d) - session.add(link) - session.commit() + link = LabelLink.create(**d) assert link.label_id == d["label_id"] assert link.t_split_id == d["t_split_id"] def test_add_links_delete( - session: orm.Session, transactions: list[Transaction], labels: dict[str, int], ) -> None: new_labels: dict[int, set[str]] = {txn.splits[0].id_: set() for txn in transactions} - LabelLink.add_links(session, new_labels) + LabelLink.add_links(new_labels) - n = query_count(session.query(LabelLink)) - assert n == 0 - - n = query_count(session.query(Label)) - assert n == len(labels) + assert not sql.any_(LabelLink.query()) + assert sql.count(Label.query()) == len(labels) def test_add_links( - session: orm.Session, transactions: list[Transaction], rand_str: str, labels: dict[str, int], ) -> None: new_labels: dict[int, set[str]] = { - txn.splits[0].id_: {rand_str} for txn in transactions + txn.splits[0].id_: {rand_str, "engineer"} for txn in transactions } - LabelLink.add_links(session, new_labels) - - n = query_count(session.query(LabelLink)) - assert n == len(transactions) + LabelLink.add_links(new_labels) - n = query_count(session.query(Label)) - assert n == len(labels) + 1 + assert sql.count(LabelLink.query()) == len(transactions) * 2 + assert sql.count(Label.query()) == len(labels) + 1 - label = session.query(Label).where(Label.id_.not_in(labels.values())).one() + query = Label.query().where(Label.id_.not_in(labels.values())) + label = sql.one(query) assert label.name == rand_str diff --git a/tests/models/test_account.py b/tests/models/test_account.py index b14f9052..4f1df5d5 100644 --- a/tests/models/test_account.py +++ b/tests/models/test_account.py @@ -7,11 +7,9 @@ from nummus import exceptions as exc from nummus.models.account import Account, AccountCategory -from nummus.models.currency import DEFAULT_CURRENCY +from nummus.models.currency import Currency, DEFAULT_CURRENCY if TYPE_CHECKING: - from sqlalchemy import orm - from nummus.models.asset import Asset, AssetValuation from nummus.models.transaction import Transaction from tests.conftest import RandomStringGenerator @@ -19,7 +17,6 @@ def test_init_properties( rand_str_generator: RandomStringGenerator, - session: orm.Session, ) -> None: d = { "name": rand_str_generator(), @@ -29,10 +26,7 @@ def test_init_properties( "budgeted": False, "currency": DEFAULT_CURRENCY, } - acct = Account(**d) - - session.add(acct) - session.commit() + acct = Account.create(**d) assert acct.name == d["name"] assert acct.institution == d["institution"] @@ -47,14 +41,13 @@ def test_short(account: Account) -> None: account.name = "a" -def test_ids(session: orm.Session, account: Account) -> None: - ids = Account.ids(session, AccountCategory.CASH) +def test_ids(account: Account) -> None: + ids = Account.ids(AccountCategory.CASH) assert ids == {account.id_} -def test_ids_none(session: orm.Session, account: Account) -> None: - _ = account - ids = Account.ids(session, AccountCategory.CREDIT) +def test_ids_none(account: Account) -> None: + ids = Account.ids(AccountCategory.CREDIT) assert ids == set() @@ -63,14 +56,12 @@ def test_date_properties( account: Account, transactions: list[Transaction], ) -> None: - _ = transactions assert account.opened_on_ord == today_ord - 3 assert account.updated_on_ord == today_ord + 7 def test_get_asset_qty_empty( today_ord: int, - session: orm.Session, account: Account, ) -> None: start_ord = today_ord - 3 @@ -80,22 +71,19 @@ def test_get_asset_qty_empty( # defaultdict is correct length assert result[0] == [Decimal()] * 7 - result = Account.get_asset_qty_all(session, start_ord, end_ord) + result = Account.get_asset_qty_all(start_ord, end_ord) assert result == {} def test_get_asset_qty_none( today_ord: int, - session: orm.Session, account: Account, transactions: list[Transaction], ) -> None: - _ = account - _ = transactions start_ord = today_ord - 3 end_ord = today_ord + 3 - result = Account.get_asset_qty_all(session, start_ord, end_ord, set()) + result = Account.get_asset_qty_all(start_ord, end_ord, set()) # defaultdict is correct length assert result[0][0] == [Decimal()] * 7 @@ -106,7 +94,6 @@ def test_get_asset_qty( asset: Asset, transactions: list[Transaction], ) -> None: - _ = transactions start_ord = today_ord - 3 end_ord = today_ord + 3 result_qty = account.get_asset_qty(start_ord, end_ord) @@ -130,14 +117,12 @@ def test_get_asset_qty_today( asset: Asset, transactions: list[Transaction], ) -> None: - _ = transactions result_qty = account.get_asset_qty(today_ord, today_ord) assert result_qty == {asset.id_: [Decimal(10)]} def test_get_value_empty( today_ord: int, - session: orm.Session, account: Account, ) -> None: start_ord = today_ord - 3 @@ -149,7 +134,7 @@ def test_get_value_empty( # defaultdict is correct length assert assets[0] == [Decimal()] * 7 - values, profits, assets = Account.get_value_all(session, start_ord, end_ord) + values, profits, assets = Account.get_value_all(start_ord, end_ord) assert values == {} assert profits == {} assert assets == {} @@ -157,16 +142,13 @@ def test_get_value_empty( def test_get_value_none( today_ord: int, - session: orm.Session, account: Account, transactions: list[Transaction], ) -> None: - _ = account - _ = transactions start_ord = today_ord - 3 end_ord = today_ord + 3 - values, profits, assets = Account.get_value_all(session, start_ord, end_ord, set()) + values, profits, assets = Account.get_value_all(start_ord, end_ord, set()) assert values == {} assert profits == {} assert assets == {} @@ -182,8 +164,6 @@ def test_get_value( asset_valuation: AssetValuation, transactions: list[Transaction], ) -> None: - _ = transactions - _ = asset_valuation start_ord = today_ord - 4 end_ord = today_ord + 3 values, profits, assets = account.get_value(start_ord, end_ord) @@ -224,6 +204,55 @@ def test_get_value( assert assets == target +def test_get_value_forex( + today_ord: int, + account: Account, + asset: Asset, + asset_valuation: AssetValuation, + transactions: list[Transaction], +) -> None: + start_ord = today_ord - 4 + end_ord = today_ord + 3 + f = 2 + forex: dict[Currency, list[Decimal]] = {Currency.USD: [Decimal(f)] * 8} + values, profits, assets = account.get_value(start_ord, end_ord, forex=forex) + target = [ + Decimal(), + Decimal(100) * f, + Decimal(90) * f, + Decimal(90) * f, + Decimal(110) * f, + Decimal(150) * f, + Decimal(150) * f, + Decimal(150) * f, + ] + assert values == target + target = [ + Decimal(), + Decimal(), + Decimal(-10) * f, + Decimal(-10) * f, + Decimal(10) * f, + Decimal(50) * f, + Decimal(50) * f, + Decimal(50) * f, + ] + assert profits == target + target = { + asset.id_: [ + Decimal(), + Decimal(), + Decimal(), + Decimal(), + Decimal(20) * f, + Decimal(10) * f, + Decimal(10) * f, + Decimal(10) * f, + ], + } + assert assets == target + + def test_get_value_today( today_ord: int, account: Account, @@ -231,8 +260,6 @@ def test_get_value_today( asset_valuation: AssetValuation, transactions: list[Transaction], ) -> None: - _ = transactions - _ = asset_valuation values, profits, assets = account.get_value(today_ord, today_ord) assert values == [Decimal(110)] assert profits == [Decimal()] @@ -245,7 +272,6 @@ def test_get_value_buy_day( asset: Asset, transactions: list[Transaction], ) -> None: - _ = transactions values, profits, assets = account.get_value(today_ord - 2, today_ord - 2) assert values == [Decimal(90)] assert profits == [Decimal(-10)] @@ -257,7 +283,6 @@ def test_get_value_fund_day( account: Account, transactions: list[Transaction], ) -> None: - _ = transactions values, profits, assets = account.get_value(today_ord - 3, today_ord - 3) assert values == [Decimal(100)] assert profits == [Decimal()] @@ -266,7 +291,6 @@ def test_get_value_fund_day( def test_get_cash_flow_empty( today_ord: int, - session: orm.Session, account: Account, ) -> None: start_ord = today_ord - 3 @@ -276,7 +300,7 @@ def test_get_cash_flow_empty( # defaultdict is correct length assert result[0] == [Decimal()] * 7 - result = Account.get_cash_flow_all(session, start_ord, end_ord) + result = Account.get_cash_flow_all(start_ord, end_ord) assert result == {} @@ -286,7 +310,6 @@ def test_get_cash_flow( transactions: list[Transaction], categories: dict[str, int], ) -> None: - _ = transactions start_ord = today_ord - 3 end_ord = today_ord + 3 result = account.get_cash_flow(start_ord, end_ord) @@ -318,14 +341,12 @@ def test_get_cash_flow_today( account: Account, transactions: list[Transaction], ) -> None: - _ = transactions result = account.get_cash_flow(today_ord, today_ord) assert result == {} def test_get_profit_by_asset_empty( today_ord: int, - session: orm.Session, account: Account, ) -> None: start_ord = today_ord - 3 @@ -334,7 +355,7 @@ def test_get_profit_by_asset_empty( assert result == {} assert result[0] == Decimal() - result = Account.get_profit_by_asset_all(session, start_ord, end_ord) + result = Account.get_profit_by_asset_all(start_ord, end_ord) assert result == {} @@ -345,8 +366,6 @@ def test_get_profit_by_asset( transactions: list[Transaction], asset_valuation: AssetValuation, ) -> None: - _ = transactions - _ = asset_valuation start_ord = today_ord - 3 end_ord = today_ord + 3 result = account.get_profit_by_asset(start_ord, end_ord) @@ -362,7 +381,6 @@ def test_get_profit_by_asset_today( asset: Asset, transactions: list[Transaction], ) -> None: - _ = transactions result = account.get_profit_by_asset(today_ord, today_ord) assert result == {asset.id_: Decimal()} @@ -387,6 +405,14 @@ def test_do_include_closed( account: Account, transactions: list[Transaction], ) -> None: - _ = transactions account.closed = True assert account.do_include(today_ord) + + +def test_find( + account: Account, +) -> None: + account.number = "1234567890" + id_, name = Account.find("7890", {}) + assert id_ == account.id_ + assert name == account.name diff --git a/tests/models/test_base_uri.py b/tests/models/test_base_uri.py index 548476d7..18e6b180 100644 --- a/tests/models/test_base_uri.py +++ b/tests/models/test_base_uri.py @@ -86,7 +86,7 @@ def test_empty_uri() -> None: def test_symmetrical_unique() -> None: - uris = set() + uris: set[str] = set() n = 10000 for i in range(n): @@ -147,7 +147,8 @@ def test_table_ids_no_duplicates() -> None: table_ids: set[int] = set() for m in MODELS_URI: - t_id: int = m.__table_id__ + t_id = m.__table_id__ + assert t_id is not None assert t_id not in table_ids table_ids.add(t_id) diff --git a/tests/models/test_config.py b/tests/models/test_config.py index 074d194b..81fa7f57 100644 --- a/tests/models/test_config.py +++ b/tests/models/test_config.py @@ -6,84 +6,77 @@ from packaging.version import Version from nummus import exceptions as exc +from nummus import sql from nummus.migrations.top import MIGRATORS from nummus.models.config import Config, ConfigKey from nummus.models.currency import DEFAULT_CURRENCY from nummus.version import __version__ if TYPE_CHECKING: - from sqlalchemy import orm - from tests.conftest import RandomStringGenerator -def test_init_properties(session: orm.Session, rand_str: str) -> None: +def test_init_properties(rand_str: str) -> None: d = { "key": ConfigKey.WEB_KEY, "value": rand_str, } - c = Config(**d) - session.add(c) - session.commit() + c = Config.create(**d) assert c.key == d["key"] assert c.value == d["value"] def test_duplicate_keys( - session: orm.Session, rand_str_generator: RandomStringGenerator, ) -> None: - c = Config(key=ConfigKey.WEB_KEY, value=rand_str_generator()) - session.add(c) - c = Config(key=ConfigKey.WEB_KEY, value=rand_str_generator()) - session.add(c) + Config.create(key=ConfigKey.WEB_KEY, value=rand_str_generator()) with pytest.raises(exc.IntegrityError): - session.commit() + Config.create(key=ConfigKey.WEB_KEY, value=rand_str_generator()) -def test_empty(session: orm.Session) -> None: - c = Config(key=ConfigKey.WEB_KEY, value="") - session.add(c) +def test_empty() -> None: with pytest.raises(exc.IntegrityError): - session.commit() + Config.create(key=ConfigKey.WEB_KEY, value="") def test_short() -> None: with pytest.raises(exc.InvalidORMValueError): - Config(key=ConfigKey.WEB_KEY, value="a") + Config.create(key=ConfigKey.WEB_KEY, value="a") + +def test_set(rand_str: str) -> None: + Config.set_(ConfigKey.VERSION, rand_str) + assert Config.fetch(ConfigKey.VERSION) == rand_str -def test_set(session: orm.Session, rand_str: str) -> None: - Config.set_(session, ConfigKey.WEB_KEY, rand_str) - session.commit() - v = session.query(Config.value).where(Config.key == ConfigKey.WEB_KEY).scalar() - assert v == rand_str +def test_set_new(rand_str: str) -> None: + Config.set_(ConfigKey.WEB_KEY, rand_str) + assert Config.fetch(ConfigKey.WEB_KEY) == rand_str -def test_fetch(session: orm.Session) -> None: - target = session.query(Config.value).where(Config.key == ConfigKey.VERSION).scalar() - assert Config.fetch(session, ConfigKey.VERSION) == target +def test_fetch() -> None: + v = sql.scalar(Config.query(Config.value).where(Config.key == ConfigKey.VERSION)) + assert Config.fetch(ConfigKey.VERSION) == v -def test_fetch_missing(session: orm.Session) -> None: +def test_fetch_missing() -> None: with pytest.raises(exc.ProtectedObjectNotFoundError): - Config.fetch(session, ConfigKey.WEB_KEY) + Config.fetch(ConfigKey.WEB_KEY) -def test_fetch_missing_ok(session: orm.Session) -> None: - assert Config.fetch(session, ConfigKey.WEB_KEY, no_raise=True) is None +def test_fetch_missing_ok() -> None: + assert Config.fetch(ConfigKey.WEB_KEY, no_raise=True) is None -def test_db_version(session: orm.Session) -> None: +def test_db_version() -> None: target = max( Version(__version__), *[m.min_version() for m in MIGRATORS], ) - assert Config.db_version(session) == target + assert Config.db_version() == target -def test_base_currency(session: orm.Session) -> None: - assert Config.base_currency(session) == DEFAULT_CURRENCY +def test_base_currency() -> None: + assert Config.base_currency() == DEFAULT_CURRENCY diff --git a/tests/models/test_health_checks.py b/tests/models/test_health_checks.py index b1356f3c..7b13457a 100644 --- a/tests/models/test_health_checks.py +++ b/tests/models/test_health_checks.py @@ -8,13 +8,10 @@ from nummus.models.health_checks import HealthCheckIssue if TYPE_CHECKING: - from sqlalchemy import orm - from tests.conftest import RandomStringGenerator def test_init_properties( - session: orm.Session, rand_str_generator: RandomStringGenerator, ) -> None: d = { @@ -24,9 +21,7 @@ def test_init_properties( "ignore": False, } - i = HealthCheckIssue(**d) - session.add(i) - session.commit() + i = HealthCheckIssue.create(**d) assert i.check == d["check"] assert i.value == d["value"] @@ -35,7 +30,6 @@ def test_init_properties( def test_duplicate_keys( - session: orm.Session, rand_str_generator: RandomStringGenerator, ) -> None: d = { @@ -44,12 +38,9 @@ def test_duplicate_keys( "msg": rand_str_generator(), "ignore": False, } - i = HealthCheckIssue(**d) - session.add(i) - i = HealthCheckIssue(**d) - session.add(i) + HealthCheckIssue.create(**d) with pytest.raises(exc.IntegrityError): - session.commit() + HealthCheckIssue.create(**d) def test_short_check(rand_str_generator: RandomStringGenerator) -> None: @@ -64,7 +55,6 @@ def test_short_check(rand_str_generator: RandomStringGenerator) -> None: def test_short_value( - session: orm.Session, rand_str_generator: RandomStringGenerator, ) -> None: d = { @@ -73,6 +63,4 @@ def test_short_value( "msg": rand_str_generator(), "ignore": False, } - i = HealthCheckIssue(**d) - session.add(i) - session.commit() + HealthCheckIssue.create(**d) diff --git a/tests/models/test_imported_file.py b/tests/models/test_imported_file.py index f240c650..54d74a8a 100644 --- a/tests/models/test_imported_file.py +++ b/tests/models/test_imported_file.py @@ -1,34 +1,23 @@ from __future__ import annotations -from typing import TYPE_CHECKING - import pytest from nummus import exceptions as exc -from nummus.models import imported_file - -if TYPE_CHECKING: - from sqlalchemy import orm +from nummus.models.imported_file import ImportedFile def test_init_properties( - session: orm.Session, rand_str: str, today_ord: int, ) -> None: - f = imported_file.ImportedFile(hash_=rand_str) - session.add(f) - session.commit() + f = ImportedFile.create(hash_=rand_str) # Default date is today assert f.date_ord == today_ord assert f.hash_ == rand_str -def test_duplicates(session: orm.Session, rand_str: str) -> None: - f = imported_file.ImportedFile(hash_=rand_str) - session.add(f) - f = imported_file.ImportedFile(hash_=rand_str) - session.add(f) +def test_duplicates(rand_str: str) -> None: + ImportedFile.create(hash_=rand_str) with pytest.raises(exc.IntegrityError): - session.commit() + ImportedFile.create(hash_=rand_str) diff --git a/tests/models/test_transaction_category.py b/tests/models/test_transaction_category.py index 677adb0b..ac3a559b 100644 --- a/tests/models/test_transaction_category.py +++ b/tests/models/test_transaction_category.py @@ -11,13 +11,10 @@ ) if TYPE_CHECKING: - from sqlalchemy import orm - from nummus.models.budget import BudgetGroup def test_init_properties( - session: orm.Session, rand_str: str, budget_group: BudgetGroup, ) -> None: @@ -32,9 +29,7 @@ def test_init_properties( "budget_position": 0, } - t_cat = TransactionCategory(**d) - session.add(t_cat) - session.commit() + t_cat = TransactionCategory.create(**d) assert t_cat.name == rand_str.lower() assert t_cat.emoji_name == d["emoji_name"] @@ -62,16 +57,16 @@ def test_name_direct() -> None: TransactionCategory(name="a") -def test_name_no_position(session: orm.Session, budget_group: BudgetGroup) -> None: +def test_name_no_position(budget_group: BudgetGroup) -> None: with pytest.raises(exc.IntegrityError): - session.query(TransactionCategory).where( + TransactionCategory.query().where( TransactionCategory.name == "transfers", ).update({TransactionCategory.budget_group_id: budget_group.id_}) -def test_name_no_group(session: orm.Session) -> None: +def test_name_no_group() -> None: with pytest.raises(exc.IntegrityError): - session.query(TransactionCategory).where( + TransactionCategory.query().where( TransactionCategory.name == "transfers", ).update({TransactionCategory.budget_position: 0}) @@ -84,18 +79,17 @@ def test_essential_income() -> None: ) -def test_essential_income_update(session: orm.Session) -> None: +def test_essential_income_update() -> None: with pytest.raises(exc.IntegrityError): - session.query(TransactionCategory).where( + TransactionCategory.query().where( TransactionCategory.name == "other income", ).update({TransactionCategory.essential_spending: True}) -def test_essential_expense(session: orm.Session) -> None: - session.query(TransactionCategory).where( +def test_essential_expense() -> None: + TransactionCategory.query().where( TransactionCategory.name == "groceries", ).update({TransactionCategory.essential_spending: True}) - session.commit() def test_essential_none() -> None: @@ -103,71 +97,67 @@ def test_essential_none() -> None: TransactionCategory(essential_spending=None) -def test_emergency_fund_missing(session: orm.Session) -> None: - session.query(TransactionCategory).delete() +def test_emergency_fund_missing() -> None: + TransactionCategory.query().delete() with pytest.raises(exc.ProtectedObjectNotFoundError): - TransactionCategory.emergency_fund(session) + TransactionCategory.emergency_fund() -def test_emergency_fund(session: orm.Session, categories: dict[str, int]) -> None: - result = TransactionCategory.emergency_fund(session) +def test_emergency_fund(categories: dict[str, int]) -> None: + result = TransactionCategory.emergency_fund() t_cat_id = categories["emergency fund"] assert result == (t_cat_id, TransactionCategory.id_to_uri(t_cat_id)) -def test_uncategorized(session: orm.Session, categories: dict[str, int]) -> None: - result = TransactionCategory.uncategorized(session) +def test_uncategorized(categories: dict[str, int]) -> None: + result = TransactionCategory.uncategorized() t_cat_id = categories["uncategorized"] assert result == (t_cat_id, TransactionCategory.id_to_uri(t_cat_id)) -def test_securities_traded(session: orm.Session, categories: dict[str, int]) -> None: - result = TransactionCategory.securities_traded(session) +def test_securities_traded(categories: dict[str, int]) -> None: + result = TransactionCategory.securities_traded() t_cat_id = categories["securities traded"] assert result == (t_cat_id, TransactionCategory.id_to_uri(t_cat_id)) def test_map_name( - session: orm.Session, categories: dict[str, int], ) -> None: - result = TransactionCategory.map_name(session) + result = TransactionCategory.map_name() assert result[categories["uncategorized"]] == "uncategorized" assert result[categories["securities traded"]] == "securities traded" def test_map_name_no_asset_linked( - session: orm.Session, categories: dict[str, int], ) -> None: - result = TransactionCategory.map_name(session, no_asset_linked=True) + result = TransactionCategory.map_name(no_asset_linked=True) assert result[categories["uncategorized"]] == "uncategorized" assert categories["securities traded"] not in result def test_map_name_emoji( - session: orm.Session, categories: dict[str, int], ) -> None: - session.query(TransactionCategory).where( + TransactionCategory.query().where( TransactionCategory.name == "uncategorized", ).update( {TransactionCategory.emoji_name: "🤷 Uncategorized 🤷"}, ) - result = TransactionCategory.map_name_emoji(session) + result = TransactionCategory.map_name_emoji() assert result[categories["uncategorized"]] == "🤷 Uncategorized 🤷" assert result[categories["securities traded"]] == "Securities Traded" def test_map_name_emoji_no_asset_linked( - session: orm.Session, categories: dict[str, int], ) -> None: - session.query(TransactionCategory).where( + TransactionCategory.query().where( TransactionCategory.name == "uncategorized", ).update( {TransactionCategory.emoji_name: "🤷 Uncategorized 🤷"}, ) - result = TransactionCategory.map_name_emoji(session, no_asset_linked=True) + result = TransactionCategory.map_name_emoji(no_asset_linked=True) assert result[categories["uncategorized"]] == "🤷 Uncategorized 🤷" assert categories["securities traded"] not in result diff --git a/tests/models/test_utils.py b/tests/models/test_utils.py index 75850d56..227ed446 100644 --- a/tests/models/test_utils.py +++ b/tests/models/test_utils.py @@ -6,12 +6,10 @@ import pytest from sqlalchemy import CheckConstraint, ForeignKeyConstraint, UniqueConstraint -from nummus import exceptions as exc +from nummus import sql from nummus.models import utils from nummus.models.account import Account from nummus.models.asset import ( - Asset, - AssetCategory, AssetSplit, AssetValuation, ) @@ -20,40 +18,36 @@ if TYPE_CHECKING: import datetime - import sqlalchemy - from sqlalchemy import orm - + from nummus.models.asset import ( + Asset, + ) from tests.conftest import RandomStringGenerator @pytest.fixture def transactions( - session: orm.Session, today: datetime.date, account: Account, categories: dict[str, int], rand_str_generator: RandomStringGenerator, ) -> list[Transaction]: for _ in range(10): - txn = Transaction( + txn = Transaction.create( account_id=account.id_, date=today, amount=100, statement=rand_str_generator(), ) - t_split = TransactionSplit( + TransactionSplit.create( amount=100, parent=txn, category_id=categories["uncategorized"], ) - session.add_all((txn, t_split)) - session.commit() - return session.query(Transaction).all() + return Transaction.all() @pytest.fixture def valuations( - session: orm.Session, today_ord: int, asset: Asset, ) -> list[AssetValuation]: @@ -63,14 +57,12 @@ def valuations( today_ord: {"value": Decimal(100), "asset_id": a_id}, } - query = session.query(AssetValuation) - utils.update_rows(session, AssetValuation, query, "date_ord", updates) - session.commit() - return query.all() + utils.update_rows(AssetValuation, AssetValuation.query(), "date_ord", updates) + return AssetValuation.all() -def test_paginate_all(session: orm.Session, transactions: list[Transaction]) -> None: - page, count, next_offset = utils.paginate(session.query(Transaction), 50, 0) +def test_paginate_all(transactions: list[Transaction]) -> None: + page, count, next_offset = utils.paginate(Transaction.query(), 50, 0) assert page == transactions assert count == len(transactions) assert next_offset is None @@ -78,11 +70,10 @@ def test_paginate_all(session: orm.Session, transactions: list[Transaction]) -> @pytest.mark.parametrize("offset", range(10)) def test_paginate_three( - session: orm.Session, transactions: list[Transaction], offset: int, ) -> None: - page, count, next_offset = utils.paginate(session.query(Transaction), 3, offset) + page, count, next_offset = utils.paginate(Transaction.query(), 3, offset) assert page == transactions[offset : offset + 3] assert count == len(transactions) if offset >= (len(transactions) - 3): @@ -91,102 +82,80 @@ def test_paginate_three( assert next_offset == offset + 3 -def test_paginate_three_page_1000( - session: orm.Session, - transactions: list[Transaction], -) -> None: - page, count, next_offset = utils.paginate(session.query(Transaction), 3, 1000) +def test_paginate_three_page_1000(transactions: list[Transaction]) -> None: + page, count, next_offset = utils.paginate(Transaction.query(), 3, 1000) assert page == [] assert count == len(transactions) assert next_offset is None -def test_paginate_three_page_n1000( - session: orm.Session, - transactions: list[Transaction], -) -> None: - page, count, next_offset = utils.paginate(session.query(Transaction), 3, -1000) +def test_paginate_three_page_n1000(transactions: list[Transaction]) -> None: + page, count, next_offset = utils.paginate(Transaction.query(), 3, -1000) assert page == transactions[0:3] assert count == len(transactions) assert next_offset == 3 -def test_dump_table_configs(session: orm.Session) -> None: - result = utils.dump_table_configs(session, Account) +def test_dump_table_configs() -> None: + result = utils.dump_table_configs(Account) assert result[0] == "CREATE TABLE account (" assert result[-1] == ")" assert "\t" not in "\n".join(result) -def test_get_constraints(session: orm.Session) -> None: +def test_get_constraints() -> None: target = [ (UniqueConstraint, "asset_id, date_ord"), (CheckConstraint, "multiplier > 0"), (ForeignKeyConstraint, "asset_id"), ] - assert utils.get_constraints(session, AssetSplit) == target - - -def test_obj_session(session: orm.Session, account: Account) -> None: - result = utils.obj_session(account) - assert result == session - - -def test_obj_session_detached() -> None: - acct = Account() - with pytest.raises(exc.UnboundExecutionError): - utils.obj_session(acct) + assert utils.get_constraints(AssetSplit) == target def test_update_rows_new( - session: orm.Session, today_ord: int, valuations: list[AssetValuation], ) -> None: - query = session.query(AssetValuation) - assert utils.query_count(query) == len(valuations) + assert sql.count(AssetValuation.query()) == len(valuations) - v = query.where(AssetValuation.date_ord == today_ord).one() + v = sql.one(AssetValuation.query().where(AssetValuation.date_ord == today_ord)) assert v.value == Decimal(100) - v = query.where(AssetValuation.date_ord == (today_ord - 1)).one() + v = sql.one( + AssetValuation.query().where(AssetValuation.date_ord == (today_ord - 1)), + ) assert v.value == Decimal(10) def test_update_rows_edit( - session: orm.Session, today_ord: int, asset: Asset, valuations: list[AssetValuation], ) -> None: - query = session.query(AssetValuation) + query = AssetValuation.query() updates: dict[object, dict[str, object]] = { today_ord - 2: {"value": Decimal(5), "asset_id": asset.id_}, today_ord: {"value": Decimal(50), "asset_id": asset.id_}, } - utils.update_rows(session, AssetValuation, query, "date_ord", updates) - session.commit() - assert utils.query_count(query) == len(valuations) + utils.update_rows(AssetValuation, query, "date_ord", updates) + assert sql.count(query) == len(valuations) - v = query.where(AssetValuation.date_ord == today_ord).one() + v = sql.one(AssetValuation.query().where(AssetValuation.date_ord == today_ord)) assert v.value == Decimal(50) - v = query.where(AssetValuation.date_ord == (today_ord - 2)).one() + v = sql.one( + AssetValuation.query().where(AssetValuation.date_ord == (today_ord - 2)), + ) assert v.value == Decimal(5) -def test_update_rows_delete( - session: orm.Session, - valuations: list[AssetValuation], -) -> None: - _ = valuations - query = session.query(AssetValuation) - utils.update_rows(session, AssetValuation, query, "date_ord", {}) - assert utils.query_count(query) == 0 +def test_update_rows_delete(valuations: list[AssetValuation]) -> None: + query = AssetValuation.query() + utils.update_rows(AssetValuation, query, "date_ord", {}) + assert not sql.any_(query) def test_update_rows_list_edit( - session: orm.Session, transactions: list[Transaction], categories: dict[str, int], rand_str_generator: RandomStringGenerator, @@ -211,61 +180,31 @@ def test_update_rows_list_edit( }, ] utils.update_rows_list( - session, TransactionSplit, - session.query(TransactionSplit).where(TransactionSplit.parent_id == txn.id_), + TransactionSplit.query().where(TransactionSplit.parent_id == txn.id_), updates, ) - session.commit() assert t_split_0.parent_id == txn.id_ assert t_split_0.memo == memo_0 assert t_split_0.amount == txn.amount - new_split_amount - t_split_1 = ( - session.query(TransactionSplit) - .where( - TransactionSplit.parent_id == txn.id_, - TransactionSplit.id_ != t_split_0.id_, - ) - .one() + query = TransactionSplit.query().where( + TransactionSplit.parent_id == txn.id_, + TransactionSplit.id_ != t_split_0.id_, ) + t_split_1 = sql.one(query) assert t_split_1.parent_id == txn.id_ assert t_split_1.memo == memo_1 assert t_split_1.amount == new_split_amount def test_update_rows_list_delete( - session: orm.Session, transactions: list[Transaction], ) -> None: txn = transactions[0] utils.update_rows_list( - session, TransactionSplit, - session.query(TransactionSplit).where(TransactionSplit.parent_id == txn.id_), + TransactionSplit.query().where(TransactionSplit.parent_id == txn.id_), [], ) - session.commit() assert len(txn.splits) == 0 - - -@pytest.mark.parametrize( - ("where", "expect_asset"), - [ - ([], False), - ([Asset.category == AssetCategory.STOCKS], True), - ([Asset.category == AssetCategory.BONDS], False), - ], -) -def test_one_or_none( - session: orm.Session, - asset: Asset, - where: list[sqlalchemy.ColumnClause], - expect_asset: bool, -) -> None: - _ = asset - query = session.query(Asset).where(*where) - if expect_asset: - assert utils.one_or_none(query) == asset - else: - assert utils.one_or_none(query) is None diff --git a/tests/models/transaction/test_transaction.py b/tests/models/transaction/test_transaction.py index 900b1a4a..9d187e2d 100644 --- a/tests/models/transaction/test_transaction.py +++ b/tests/models/transaction/test_transaction.py @@ -10,15 +10,12 @@ import datetime from decimal import Decimal - from sqlalchemy import orm - from nummus.models.account import Account from tests.conftest import RandomStringGenerator def test_init_properties( today: datetime.date, - session: orm.Session, account: Account, rand_real: Decimal, rand_str_generator: RandomStringGenerator, @@ -31,9 +28,7 @@ def test_init_properties( "payee": rand_str_generator(), } - txn = Transaction(**d) - session.add(txn) - session.commit() + txn = Transaction.create(**d) assert txn.account_id == account.id_ assert txn.date_ord == today.toordinal() diff --git a/tests/models/transaction/test_transaction_split.py b/tests/models/transaction/test_transaction_split.py index 6458e152..40fa1829 100644 --- a/tests/models/transaction/test_transaction_split.py +++ b/tests/models/transaction/test_transaction_split.py @@ -21,7 +21,6 @@ def test_init_properties( today: datetime.date, - session: orm.Session, account: Account, asset: Asset, categories: dict[str, int], @@ -36,9 +35,7 @@ def test_init_properties( "payee": rand_str_generator(), } - txn = Transaction(**d) - session.add(txn) - session.commit() + txn = Transaction.create(**d) d = { "amount": d["amount"], @@ -49,10 +46,8 @@ def test_init_properties( "memo": rand_str_generator(), } - t_split_0 = TransactionSplit(**d) + t_split_0 = TransactionSplit.create(**d) - session.add(t_split_0) - session.commit() assert t_split_0.parent == txn assert t_split_0.parent_id == txn.id_ assert t_split_0.category_id == d["category_id"] @@ -71,9 +66,8 @@ def test_init_properties( def test_zero_amount(session: orm.Session, transactions: list[Transaction]) -> None: t_split = transactions[1].splits[0] - t_split.amount = Decimal() - with pytest.raises(exc.IntegrityError): - session.commit() + with pytest.raises(exc.IntegrityError), session.begin_nested(): + t_split.amount = Decimal() def test_short() -> None: @@ -104,18 +98,15 @@ def test_unset_asset_quantity( transactions: list[Transaction], ) -> None: t_split = transactions[1].splits[0] - t_split._asset_qty_unadjusted = None - with pytest.raises(exc.IntegrityError): - session.commit() + with pytest.raises(exc.IntegrityError), session.begin_nested(): + t_split._asset_qty_unadjusted = None def test_clear_asset_quantity( - session: orm.Session, transactions: list[Transaction], ) -> None: t_split = transactions[1].splits[0] t_split.asset_quantity_unadjusted = None - session.commit() assert t_split.asset_quantity is None @@ -159,9 +150,8 @@ def test_parent(transactions: list[Transaction]) -> None: assert t_split.parent == txn -def test_search_none(session: orm.Session, transactions: list[Transaction]) -> None: - _ = transactions - query = session.query(TransactionSplit) +def test_search_none(transactions: list[Transaction]) -> None: + query = TransactionSplit.query() with pytest.raises(exc.EmptySearchError): TransactionSplit.search(query, "") @@ -190,11 +180,10 @@ def test_search_none(session: orm.Session, transactions: list[Transaction]) -> N ids=conftest.id_func, ) def test_search( - session: orm.Session, transactions: list[Transaction], search_str: str, target: list[int], ) -> None: - query = session.query(TransactionSplit) + query = TransactionSplit.query() result = TransactionSplit.search(query, search_str) assert result == [transactions[i].splits[0].id_ for i in target] diff --git a/tests/portfolio/test_backup_restore.py b/tests/portfolio/test_backup_restore.py index 048bdb52..6f6bd4fa 100644 --- a/tests/portfolio/test_backup_restore.py +++ b/tests/portfolio/test_backup_restore.py @@ -156,8 +156,8 @@ def test_restore_path_traversal(tmp_path: Path) -> None: def test_restore(empty_portfolio: Portfolio) -> None: # Delete ENCRYPTION_TEST so reload fails - with empty_portfolio.begin_session() as s: - s.query(Config).where(Config.key == ConfigKey.ENCRYPTION_TEST).delete() + with empty_portfolio.begin_session(): + Config.query().where(Config.key == ConfigKey.ENCRYPTION_TEST).delete() empty_portfolio.backup() empty_portfolio.path.unlink() @@ -169,8 +169,8 @@ def test_restore(empty_portfolio: Portfolio) -> None: def test_restore_path(empty_portfolio: Portfolio) -> None: # Delete ENCRYPTION_TEST so reload fails - with empty_portfolio.begin_session() as s: - s.query(Config).where(Config.key == ConfigKey.ENCRYPTION_TEST).delete() + with empty_portfolio.begin_session(): + Config.query().where(Config.key == ConfigKey.ENCRYPTION_TEST).delete() empty_portfolio.backup() empty_portfolio.path.unlink() diff --git a/tests/portfolio/test_change_key.py b/tests/portfolio/test_change_key.py index 7cb5fd3c..ff8352ae 100644 --- a/tests/portfolio/test_change_key.py +++ b/tests/portfolio/test_change_key.py @@ -11,14 +11,14 @@ @pytest.mark.skipif(not ENCRYPTION_AVAILABLE, reason="No encryption available") @pytest.mark.encryption def test_change_db_key( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], empty_portfolio_encrypted: tuple[Portfolio, str], rand_str: str, ) -> None: new_key = rand_str p, old_key = empty_portfolio_encrypted - with p.begin_session() as s: - web_key_enc = Config.fetch(s, ConfigKey.WEB_KEY) + with p.begin_session(): + web_key_enc = Config.fetch(ConfigKey.WEB_KEY) web_key = p.decrypt_s(web_key_enc) p.change_key(new_key) @@ -28,8 +28,8 @@ def test_change_db_key( # tqdm in here assert captured.err - with p.begin_session() as s: - web_key_enc = Config.fetch(s, ConfigKey.WEB_KEY) + with p.begin_session(): + web_key_enc = Config.fetch(ConfigKey.WEB_KEY) new_web_key = p.decrypt_s(web_key_enc) assert new_web_key == web_key assert new_web_key != new_key @@ -57,8 +57,8 @@ def test_change_web_key( p, db_key = empty_portfolio_encrypted p.change_web_key(new_key) - with p.begin_session() as s: - web_key_enc = Config.fetch(s, ConfigKey.WEB_KEY) + with p.begin_session(): + web_key_enc = Config.fetch(ConfigKey.WEB_KEY) web_key = p.decrypt_s(web_key_enc) assert web_key == new_key assert web_key != db_key diff --git a/tests/portfolio/test_import.py b/tests/portfolio/test_import.py index 46e33be9..6ede35eb 100644 --- a/tests/portfolio/test_import.py +++ b/tests/portfolio/test_import.py @@ -1,24 +1,21 @@ from __future__ import annotations import datetime -import operator from typing import TYPE_CHECKING import pytest from nummus import exceptions as exc -from nummus.models.account import Account -from nummus.models.asset import Asset +from nummus import sql from nummus.models.transaction import Transaction, TransactionSplit -from nummus.models.utils import query_count -from nummus.portfolio import Portfolio from tests.importers.test_raw_csv import TRANSACTIONS_REQUIRED if TYPE_CHECKING: - from collections.abc import Callable from pathlib import Path - from nummus.models.base import Base + from nummus.models.account import Account + from nummus.models.asset import Asset + from nummus.portfolio import Portfolio def test_import_file( @@ -29,33 +26,30 @@ def test_import_file( asset: Asset, categories: dict[str, int], ) -> None: - _ = account_investments - _ = asset path = data_path / "transactions_required.csv" path_debug = empty_portfolio.path.with_suffix(".importer-debug") # Create first txn to be cleared - with empty_portfolio.begin_session() as s: + with empty_portfolio.begin_session(): d = TRANSACTIONS_REQUIRED[0] - txn = Transaction( + txn = Transaction.create( account_id=account.id_, date=d["date"], amount=d["amount"], statement="Manually imported", ) - t_split = TransactionSplit( + TransactionSplit.create( parent=txn, amount=txn.amount, category_id=categories["uncategorized"], ) - s.add_all((txn, t_split)) empty_portfolio.import_file(path, path_debug) assert not path_debug.exists() - with empty_portfolio.begin_session() as s: - assert query_count(s.query(Transaction)) == len(TRANSACTIONS_REQUIRED) + with empty_portfolio.begin_session(): + assert sql.count(Transaction.query()) == len(TRANSACTIONS_REQUIRED) def test_import_file_duplicate( @@ -65,9 +59,6 @@ def test_import_file_duplicate( account_investments: Account, asset: Asset, ) -> None: - _ = account - _ = account_investments - _ = asset path = data_path / "transactions_required.csv" path_debug = empty_portfolio.path.with_suffix(".importer-debug") @@ -86,9 +77,6 @@ def test_import_file_force( account_investments: Account, asset: Asset, ) -> None: - _ = account - _ = account_investments - _ = asset path = data_path / "transactions_required.csv" path_debug = empty_portfolio.path.with_suffix(".importer-debug") @@ -97,8 +85,8 @@ def test_import_file_force( assert not path_debug.exists() - with empty_portfolio.begin_session() as s: - assert query_count(s.query(Transaction)) == len(TRANSACTIONS_REQUIRED) * 2 + with empty_portfolio.begin_session(): + assert sql.count(Transaction.query()) == len(TRANSACTIONS_REQUIRED) * 2 @pytest.mark.parametrize( @@ -131,9 +119,6 @@ def test_import_file_error( target: type[Exception], debug_exists: bool, ) -> None: - _ = account - _ = account_investments - _ = asset path_debug = empty_portfolio.path.with_suffix(".importer-debug") for f in files[:-1]: path = data_path / f @@ -153,8 +138,6 @@ def test_import_file_investments( account_investments: Account, asset: Asset, ) -> None: - _ = account - _ = account_investments path = data_path / "transactions_investments.csv" path_debug = empty_portfolio.path.with_suffix(".importer-debug") @@ -162,14 +145,13 @@ def test_import_file_investments( assert not path_debug.exists() - with empty_portfolio.begin_session() as s: - assert query_count(s.query(Transaction)) == 4 + with empty_portfolio.begin_session(): + assert sql.count(Transaction.query()) == 4 - txn = ( - s.query(Transaction) - .where(Transaction.date_ord == datetime.date(2023, 1, 3).toordinal()) - .one() + query = Transaction.query().where( + Transaction.date_ord == datetime.date(2023, 1, 3).toordinal(), ) + txn = sql.one(query) assert txn.statement == f"Asset Transaction {asset.name}" @@ -179,7 +161,6 @@ def test_import_file_bad_category( account: Account, categories: dict[str, int], ) -> None: - _ = account path = data_path / "transactions_bad_category.csv" path_debug = empty_portfolio.path.with_suffix(".importer-debug") @@ -187,57 +168,6 @@ def test_import_file_bad_category( assert not path_debug.exists() - with empty_portfolio.begin_session() as s: - t_split = s.query(TransactionSplit).one() + with empty_portfolio.begin_session(): + t_split = TransactionSplit.one() assert t_split.category_id == categories["uncategorized"] - - -@pytest.mark.parametrize( - ("type_", "prop", "value_adjuster"), - [ - (Account, "uri", lambda s: s), - (Account, "number", lambda s: s), - (Account, "number", operator.itemgetter(slice(-4, None))), - (Account, "institution", lambda s: s), - (Account, "name", lambda s: s), - (Account, "name", lambda s: s.lower()), - (Account, "name", lambda s: s.upper()), - (Asset, "uri", lambda s: s), - (Asset, "ticker", lambda s: s), - (Asset, "name", lambda s: s), - ], -) -def test_find( - empty_portfolio: Portfolio, - account: Account, - asset: Asset, - type_: type[Base], - prop: str, - value_adjuster: Callable[[str], str], -) -> None: - obj = { - Account: account, - Asset: asset, - }[type_] - query = value_adjuster(getattr(obj, prop)) - - cache: dict[str, tuple[int, str | None]] = {} - with empty_portfolio.begin_session() as s: - a_id, a_name = Portfolio.find(s, type_, query, cache) - assert a_id == obj.id_ - assert a_name == obj.name - - assert cache == {query: (a_id, a_name)} - - -def test_find_missing( - empty_portfolio: Portfolio, - account: Account, -) -> None: - query = Account.id_to_uri(account.id_ + 1) - - cache: dict[str, tuple[int, str | None]] = {} - with empty_portfolio.begin_session() as s, pytest.raises(exc.NoResultFound): - Portfolio.find(s, Account, query, cache) - - assert not cache diff --git a/tests/portfolio/test_open.py b/tests/portfolio/test_open.py index 562dd890..5e68d412 100644 --- a/tests/portfolio/test_open.py +++ b/tests/portfolio/test_open.py @@ -6,11 +6,11 @@ import pytest from nummus import exceptions as exc +from nummus import sql from nummus.encryption.top import ENCRYPTION_AVAILABLE from nummus.models.asset import Asset from nummus.models.config import Config, ConfigKey from nummus.models.transaction_category import TransactionCategory -from nummus.models.utils import query_count from nummus.portfolio import Portfolio if TYPE_CHECKING: @@ -56,10 +56,10 @@ def test_unencrypted(tmp_path: Path) -> None: assert not p.is_encrypted assert not Portfolio.is_encrypted_path(path) - with p.begin_session() as s: - assert query_count(s.query(Config)) == 5 - assert query_count(s.query(TransactionCategory)) > 0 - assert query_count(s.query(Asset)) > 0 + with p.begin_session(): + assert sql.count(Config.query()) == 5 + assert sql.any_(TransactionCategory.query()) + assert sql.any_(Asset.query()) with pytest.raises(exc.NotEncryptedError): p.encrypt("") @@ -92,16 +92,16 @@ def test_no_encryption_test( empty_portfolio: Portfolio, key: ConfigKey, ) -> None: - with empty_portfolio.begin_session() as s: - s.query(Config).where(Config.key == key).delete() + with empty_portfolio.begin_session(): + Config.query().where(Config.key == key).delete() with pytest.raises(exc.ProtectedObjectNotFoundError): Portfolio(empty_portfolio.path, None) def test_bad_encryption_test(empty_portfolio: Portfolio) -> None: - with empty_portfolio.begin_session() as s: - Config.set_(s, ConfigKey.ENCRYPTION_TEST, "fake") + with empty_portfolio.begin_session(): + Config.set_(ConfigKey.ENCRYPTION_TEST, "fake") with pytest.raises(exc.UnlockingError): Portfolio(empty_portfolio.path, None) @@ -124,10 +124,10 @@ def test_encrypted(tmp_path: Path, rand_str: str) -> None: assert p.is_encrypted assert Portfolio.is_encrypted_path(path) - with p.begin_session() as s: - assert query_count(s.query(Config)) == 6 - assert query_count(s.query(TransactionCategory)) > 0 - assert query_count(s.query(Asset)) > 0 + with p.begin_session(): + assert sql.count(Config.query()) == 6 + assert sql.any_(TransactionCategory.query()) + assert sql.any_(Asset.query()) @pytest.mark.skipif(not ENCRYPTION_AVAILABLE, reason="No encryption available") @@ -149,8 +149,8 @@ def test_encrypted_bad_enc_test( empty_portfolio_encrypted: tuple[Portfolio, str], ) -> None: p, key = empty_portfolio_encrypted - with p.begin_session() as s: - Config.set_(s, ConfigKey.ENCRYPTION_TEST, "fake") + with p.begin_session(): + Config.set_(ConfigKey.ENCRYPTION_TEST, "fake") with pytest.raises(exc.UnlockingError): Portfolio(p.path, key) diff --git a/tests/portfolio/test_update_assets.py b/tests/portfolio/test_update_assets.py index 3b9a9832..29169cfe 100644 --- a/tests/portfolio/test_update_assets.py +++ b/tests/portfolio/test_update_assets.py @@ -17,7 +17,7 @@ from nummus.portfolio import Portfolio -def test_empty(capsys: pytest.CaptureFixture, empty_portfolio: Portfolio) -> None: +def test_empty(capsys: pytest.CaptureFixture[str], empty_portfolio: Portfolio) -> None: assert empty_portfolio.update_assets(no_bars=True) == [] captured = capsys.readouterr() @@ -26,7 +26,6 @@ def test_empty(capsys: pytest.CaptureFixture, empty_portfolio: Portfolio) -> Non def test_no_txns(empty_portfolio: Portfolio, asset: Asset) -> None: - _ = asset assert empty_portfolio.update_assets(no_bars=True) == [] @@ -38,11 +37,9 @@ def test_update_assets( asset_etf: Asset, transactions: list[Transaction], ) -> None: - _ = asset_etf - asset.interpolate = True - session.query(Asset).where(Asset.category == AssetCategory.INDEX).delete() - session.commit() - _ = transactions + with session.begin_nested(): + asset.interpolate = True + Asset.query().where(Asset.category == AssetCategory.INDEX).delete() target: list[AssetUpdate] = [ AssetUpdate( asset.name, @@ -65,10 +62,9 @@ def test_error( asset: Asset, transactions: list[Transaction], ) -> None: - asset.ticker = "FAKE" - session.query(Asset).where(Asset.category == AssetCategory.INDEX).delete() - session.commit() - _ = transactions + with session.begin_nested(): + asset.ticker = "FAKE" + Asset.query().where(Asset.category == AssetCategory.INDEX).delete() target: list[AssetUpdate] = [ AssetUpdate( asset.name, diff --git a/tests/test_main.py b/tests/test_main.py index 0fcf6400..79dea1b4 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -42,7 +42,7 @@ def test_unlock_non_existant(empty_portfolio: Portfolio) -> None: def test_unlock_successful( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], empty_portfolio: Portfolio, ) -> None: args = ["--portfolio", str(empty_portfolio.path), "unlock"] diff --git a/tests/test_sql.py b/tests/test_sql.py index 6e857acc..aed92a35 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -7,6 +7,7 @@ from nummus import sql from nummus.encryption.top import Encryption, ENCRYPTION_AVAILABLE +from nummus.models.config import Config, ConfigKey if TYPE_CHECKING: from pathlib import Path @@ -60,3 +61,116 @@ def test_escape_not_reserved() -> None: def test_escape_reserved() -> None: assert sql.escape("where") == "`where`" + + +def test_to_dict() -> None: + query = Config.query(Config.key, Config.value) + result = sql.to_dict(query) + assert isinstance(result, dict) + assert all(isinstance(k, ConfigKey) for k in result) + assert all(isinstance(v, str) for v in result.values()) + + +def test_to_dict_tuple() -> None: + query = Config.query(Config.id_, Config.key, Config.value) + result = sql.to_dict_tuple(query) + assert isinstance(result, dict) + assert all(isinstance(k, int) for k in result) + assert all(isinstance(v, tuple) for v in result.values()) + assert all(len(v) == 2 for v in result.values()) + assert all(isinstance(v[0], ConfigKey) for v in result.values()) + assert all(isinstance(v[1], str) for v in result.values()) + + +def test_count() -> None: + query = Config.query() + assert sql.count(query) == query.count() + + +def test_any() -> None: + assert sql.any_(Config.query()) + + +def test_any_none() -> None: + Config.query().delete() + assert not sql.any_(Config.query()) + + +def test_one() -> None: + query = Config.query().where( + Config.key == ConfigKey.VERSION, + ) + result = sql.one(query) + assert isinstance(result, Config) + + +def test_one_value() -> None: + query = Config.query(Config.key).where( + Config.key == ConfigKey.VERSION, + ) + result = sql.one(query) + assert isinstance(result, ConfigKey) + + +def test_one_tuple() -> None: + query = Config.query(Config.key, Config.value).where( + Config.key == ConfigKey.VERSION, + ) + result = sql.one(query) + assert isinstance(result, tuple) + assert len(result) == 2 + assert isinstance(result[0], ConfigKey) + assert isinstance(result[1], str) + + +def test_scalar() -> None: + query = Config.query().where( + Config.key == ConfigKey.VERSION, + ) + result = sql.scalar(query) + assert isinstance(result, Config) + + +def test_scalar_value() -> None: + query = Config.query(Config.key).where( + Config.key == ConfigKey.VERSION, + ) + result = sql.scalar(query) + assert isinstance(result, ConfigKey) + + +def test_scalar_tuple() -> None: + query = Config.query(Config.key, Config.value).where( + Config.key == ConfigKey.VERSION, + ) + result = sql.scalar(query) + assert isinstance(result, ConfigKey) + + +def test_yield() -> None: + query = Config.query().where() + for r in sql.yield_(query): + assert isinstance(r, Config) + + +def test_yield_value() -> None: + query = Config.query(Config.key) + for r in sql.yield_(query): + assert isinstance(r, tuple) + assert len(r) == 1 + assert isinstance(r[0], ConfigKey) + + +def test_yield_tuple() -> None: + query = Config.query(Config.key, Config.value) + for r in sql.yield_(query): + assert isinstance(r, tuple) + assert len(r) == 2 + assert isinstance(r[0], ConfigKey) + assert isinstance(r[1], str) + + +def test_col0() -> None: + query = Config.query(Config.key) + for r in sql.col0(query): + assert isinstance(r, ConfigKey) diff --git a/tests/test_utils.py b/tests/test_utils.py index 3a4bc31c..c5940480 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -32,7 +32,7 @@ def test_camel_to_snake(s: str, c: str) -> None: def test_get_input_insecure( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch, rand_str_generator: RandomStringGenerator, ) -> None: @@ -49,7 +49,7 @@ def mock_input(to_print: str) -> str | None: def test_get_input_insecure_abort( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch, rand_str_generator: RandomStringGenerator, ) -> None: @@ -66,7 +66,7 @@ def mock_input(to_print: str) -> str | None: def test_get_input_secure( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch, rand_str_generator: RandomStringGenerator, ) -> None: @@ -83,7 +83,7 @@ def mock_get_pass(to_print: str) -> str | None: def test_get_input_secure_abort( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch, rand_str: str, ) -> None: @@ -97,7 +97,7 @@ def mock_get_pass(to_print: str) -> str | None: def test_get_input_secure_with_icon( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch, rand_str_generator: RandomStringGenerator, ) -> None: @@ -149,7 +149,7 @@ def mock_input(to_print: str, *, secure: bool) -> str | None: ], ) def test_confirm( - capsys: pytest.CaptureFixture, + capsys: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch, rand_str: str, queue: list[str | None], @@ -651,8 +651,20 @@ def test_pretty_table_no_header() -> None: utils.pretty_table([None]) -def test_pretty_table_only_header(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr("shutil.get_terminal_size", lambda **_: (80, 24)) +@pytest.fixture +def fixed_terminal( + monkeypatch: pytest.MonkeyPatch, + width: int, + height: int, +) -> None: + def mock_terminal_size(**_: object) -> tuple[int, int]: + return width, height + + monkeypatch.setattr("shutil.get_terminal_size", mock_terminal_size) + + +@pytest.mark.parametrize(("width", "height"), [(80, 24)]) +def test_pretty_table_only_header(fixed_terminal: None) -> None: table: list[list[str] | None] = [ ["H1", ">H2", " None: ) assert "\n".join(utils.pretty_table(table)) == target - # Reset terminal width before verbose info is printed - monkeypatch.undo() - -def test_pretty_table_only_separator(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr("shutil.get_terminal_size", lambda **_: (80, 24)) +@pytest.mark.parametrize(("width", "height"), [(80, 24)]) +def test_pretty_table_only_separator(fixed_terminal: None) -> None: table: list[list[str] | None] = [ ["H1", ">H2", " None: ) assert "\n".join(utils.pretty_table(table)) == target - # Reset terminal width before verbose info is printed - monkeypatch.undo() - @pytest.fixture def table() -> list[list[str] | None]: @@ -698,11 +704,11 @@ def table() -> list[list[str] | None]: ] +@pytest.mark.parametrize(("width", "height"), [(80, 24)]) def test_pretty_table_width_80( - monkeypatch: pytest.MonkeyPatch, + fixed_terminal: None, table: list[list[str] | None], ) -> None: - monkeypatch.setattr("shutil.get_terminal_size", lambda **_: (80, 24)) target = textwrap.dedent( """\ ╭───────────┬───────────┬───────────┬───────────┬───────────┬───────────╮ @@ -715,16 +721,12 @@ def test_pretty_table_width_80( ) assert "\n".join(utils.pretty_table(table)) == target - # Reset terminal width before verbose info is printed - monkeypatch.undo() - +@pytest.mark.parametrize(("width", "height"), [(70, 24)]) def test_pretty_table_width_70( - monkeypatch: pytest.MonkeyPatch, + fixed_terminal: None, table: list[list[str] | None], ) -> None: - # Make terminal smaller, extra space goes first - monkeypatch.setattr("shutil.get_terminal_size", lambda **_: (70, 24)) target = textwrap.dedent( """\ ╭───────────┬───────────┬───────────┬───────────┬─────────┬─────────╮ @@ -737,16 +739,12 @@ def test_pretty_table_width_70( ) assert "\n".join(utils.pretty_table(table)) == target - # Reset terminal width before verbose info is printed - monkeypatch.undo() - +@pytest.mark.parametrize(("width", "height"), [(60, 24)]) def test_pretty_table_width_60( - monkeypatch: pytest.MonkeyPatch, + fixed_terminal: None, table: list[list[str] | None], ) -> None: - # Make terminal smaller, truncate column goes next - monkeypatch.setattr("shutil.get_terminal_size", lambda **_: (60, 24)) target = textwrap.dedent( """\ ╭─────────┬─────────┬─────────┬─────────┬───────┬─────────╮ @@ -759,16 +757,12 @@ def test_pretty_table_width_60( ) assert "\n".join(utils.pretty_table(table)) == target - # Reset terminal width before verbose info is printed - monkeypatch.undo() - +@pytest.mark.parametrize(("width", "height"), [(50, 24)]) def test_pretty_table_width_50( - monkeypatch: pytest.MonkeyPatch, + fixed_terminal: None, table: list[list[str] | None], ) -> None: - # Make terminal smaller, other columns go next - monkeypatch.setattr("shutil.get_terminal_size", lambda **_: (50, 24)) target = textwrap.dedent( """\ ╭───────┬───────┬───────┬────────┬────┬─────────╮ @@ -781,16 +775,12 @@ def test_pretty_table_width_50( ) assert "\n".join(utils.pretty_table(table)) == target - # Reset terminal width before verbose info is printed - monkeypatch.undo() - +@pytest.mark.parametrize(("width", "height"), [(10, 24)]) def test_pretty_table_width_10( - monkeypatch: pytest.MonkeyPatch, + fixed_terminal: None, table: list[list[str] | None], ) -> None: - # Make terminal tiny, other columns go next, never last - monkeypatch.setattr("shutil.get_terminal_size", lambda **_: (10, 24)) target = textwrap.dedent( """\ ╭────┬────┬────┬────┬────┬─────────╮ @@ -803,9 +793,6 @@ def test_pretty_table_width_10( ) assert "\n".join(utils.pretty_table(table)) == target - # Reset terminal width before verbose info is printed - monkeypatch.undo() - @pytest.mark.parametrize( ("items", "target"), diff --git a/tests/test_web.py b/tests/test_web.py index 23566b36..32a9ac72 100644 --- a/tests/test_web.py +++ b/tests/test_web.py @@ -19,8 +19,8 @@ def test_create_app(empty_portfolio: Portfolio, flask_app: flask.Flask) -> None: - with empty_portfolio.begin_session() as s: - secret_key = Config.fetch(s, ConfigKey.SECRET_KEY) + with empty_portfolio.begin_session(): + secret_key = Config.fetch(ConfigKey.SECRET_KEY) assert flask_app.secret_key == secret_key assert len(flask_app.before_request_funcs[None]) == 1 @@ -32,8 +32,8 @@ def test_no_secret_key( empty_portfolio: Portfolio, ) -> None: monkeypatch.setenv("NUMMUS_PORTFOLIO", str(empty_portfolio.path)) - with empty_portfolio.begin_session() as s: - s.query(Config).where(Config.key == ConfigKey.SECRET_KEY).delete() + with empty_portfolio.begin_session(): + Config.query().where(Config.key == ConfigKey.SECRET_KEY).delete() with pytest.raises(exc.ProtectedObjectNotFoundError): web.create_app() diff --git a/typings/Cryptodome/Cipher/AES.pyi b/typings/Cryptodome/Cipher/AES.pyi new file mode 100644 index 00000000..14c13447 --- /dev/null +++ b/typings/Cryptodome/Cipher/AES.pyi @@ -0,0 +1,111 @@ +# Initially generated by Pyright + +from typing import overload + +from Cryptodome.Cipher._mode_cbc import CbcMode +from typing_extensions import Literal + +Buffer = bytes | bytearray | memoryview +MODE_ECB: Literal[1] +MODE_CBC: Literal[2] +MODE_CFB: Literal[3] +MODE_OFB: Literal[5] +MODE_CTR: Literal[6] +MODE_OPENPGP: Literal[7] +MODE_CCM: Literal[8] +MODE_EAX: Literal[9] +MODE_SIV: Literal[10] +MODE_GCM: Literal[11] +MODE_OCB: Literal[12] + +@overload +def new(key: Buffer, mode: Literal[1], use_aesni: bool = ...) -> None: ... +@overload +def new( + key: Buffer, mode: Literal[2], iv: Buffer | None = ..., use_aesni: bool = ... +) -> CbcMode: ... +@overload +def new( + key: Buffer, mode: Literal[2], IV: Buffer | None = ..., use_aesni: bool = ... +) -> CbcMode: ... +@overload +def new( + key: Buffer, + mode: Literal[3], + iv: Buffer | None = ..., + segment_size: int = ..., + use_aesni: bool = ..., +) -> None: ... +@overload +def new( + key: Buffer, + mode: Literal[3], + IV: Buffer | None = ..., + segment_size: int = ..., + use_aesni: bool = ..., +) -> None: ... +@overload +def new( + key: Buffer, mode: Literal[5], iv: Buffer | None = ..., use_aesni: bool = ... +) -> None: ... +@overload +def new( + key: Buffer, mode: Literal[5], IV: Buffer | None = ..., use_aesni: bool = ... +) -> None: ... +@overload +def new( + key: Buffer, + mode: Literal[6], + nonce: Buffer | None = ..., + initial_value: Buffer | int = ..., + counter: dict[object, object] = ..., + use_aesni: bool = ..., +) -> None: ... +@overload +def new( + key: Buffer, mode: Literal[7], iv: Buffer | None = ..., use_aesni: bool = ... +) -> None: ... +@overload +def new( + key: Buffer, mode: Literal[7], IV: Buffer | None = ..., use_aesni: bool = ... +) -> None: ... +@overload +def new( + key: Buffer, + mode: Literal[8], + nonce: Buffer | None = ..., + mac_len: int = ..., + assoc_len: int = ..., + use_aesni: bool = ..., +) -> None: ... +@overload +def new( + key: Buffer, + mode: Literal[9], + nonce: Buffer | None = ..., + mac_len: int = ..., + use_aesni: bool = ..., +) -> None: ... +@overload +def new( + key: Buffer, mode: Literal[10], nonce: Buffer | None = ..., use_aesni: bool = ... +) -> None: ... +@overload +def new( + key: Buffer, + mode: Literal[11], + nonce: Buffer | None = ..., + mac_len: int = ..., + use_aesni: bool = ..., +) -> None: ... +@overload +def new( + key: Buffer, + mode: Literal[12], + nonce: Buffer | None = ..., + mac_len: int = ..., + use_aesni: bool = ..., +) -> None: ... + +block_size: int +key_size: tuple[int, int, int] diff --git a/typings/Cryptodome/Cipher/_mode_cbc.pyi b/typings/Cryptodome/Cipher/_mode_cbc.pyi new file mode 100644 index 00000000..3c379cb1 --- /dev/null +++ b/typings/Cryptodome/Cipher/_mode_cbc.pyi @@ -0,0 +1,30 @@ +# Initially generated by Pyright + +from typing import overload + +from Cryptodome.Util._raw_api import SmartPointer + +Buffer = bytes | bytearray | memoryview +__all__ = ["CbcMode"] + +class CbcMode: + block_size: int + iv: Buffer + IV: Buffer + def __init__(self, block_cipher: SmartPointer, iv: Buffer) -> None: ... + @overload + def encrypt(self, plaintext: Buffer) -> bytes: ... + @overload + def encrypt( + self, + plaintext: Buffer, + output: bytearray | memoryview, + ) -> None: ... + @overload + def decrypt(self, plaintext: Buffer) -> bytes: ... + @overload + def decrypt( + self, + plaintext: Buffer, + output: bytearray | memoryview, + ) -> None: ... diff --git a/typings/flask_assets/__init__.pyi b/typings/flask_assets/__init__.pyi new file mode 100644 index 00000000..b8c678c5 --- /dev/null +++ b/typings/flask_assets/__init__.pyi @@ -0,0 +1,32 @@ +# Initially generated by Pyright + +import io + +import flask +from webassets.bundle import Bundle as BaseBundle +from webassets.env import BaseEnvironment +from webassets.merge import FileHunk + +class Environment(BaseEnvironment): + def __init__(self, app: flask.Flask = ...) -> None: ... + def set_directory(self, directory: str) -> None: ... + def get_directory(self) -> str | None: ... + def set_url(self, url: str) -> None: ... + def get_url(self) -> str | None: ... + def init_app(self, app: flask.Flask) -> None: ... + def from_yaml(self, path: str) -> None: ... + def from_module(self, path: str) -> None: ... + def register( + self, + name: str, + *args: str | Bundle, + **kwargs: object, + ) -> Bundle | None: ... + +class Bundle(BaseBundle): + def build( + self, + force: bool | None = ..., + output: io.FileIO | None = ..., + disable_cache: bool | None = ..., + ) -> list[FileHunk]: ... diff --git a/typings/flask_login/__init__.pyi b/typings/flask_login/__init__.pyi new file mode 100644 index 00000000..8f7ab102 --- /dev/null +++ b/typings/flask_login/__init__.pyi @@ -0,0 +1,38 @@ +# Initially generated by Pyright + +from .login_manager import LoginManager +from .mixins import AnonymousUserMixin, UserMixin +from .utils import ( + confirm_login, + current_user, + decode_cookie, + encode_cookie, + fresh_login_required, + login_fresh, + login_remembered, + login_required, + login_url, + login_user, + logout_user, + make_next_param, + set_login_view, +) + +__all__ = [ + "LoginManager", + "AnonymousUserMixin", + "UserMixin", + "confirm_login", + "current_user", + "decode_cookie", + "encode_cookie", + "fresh_login_required", + "login_fresh", + "login_remembered", + "login_required", + "login_url", + "login_user", + "logout_user", + "make_next_param", + "set_login_view", +] diff --git a/typings/flask_login/login_manager.pyi b/typings/flask_login/login_manager.pyi new file mode 100644 index 00000000..d80fb8ee --- /dev/null +++ b/typings/flask_login/login_manager.pyi @@ -0,0 +1,47 @@ +# Initially generated by Pyright + +from collections.abc import Callable + +import flask + +from .mixins import AnonymousUserMixin, UserMixin + +class LoginManager: + login_view: str | None + def __init__( + self, + app: flask.Flask = ..., + add_context_processor: bool = ..., + ) -> None: ... + def setup_app( + self, + app: flask.Flask, + add_context_processor: bool = ..., + ) -> None: ... + def init_app(self, app: flask.Flask, add_context_processor: bool = ...) -> None: ... + def unauthorized(self) -> flask.Response: ... + def user_loader( + self, + callback: Callable[[str], UserMixin | AnonymousUserMixin | None], + ) -> None: ... + @property + def user_callback(self) -> None: ... + def request_loader( + self, + callback: Callable[[flask.Request], UserMixin | AnonymousUserMixin | None], + ) -> None: ... + @property + def request_callback(self) -> None: ... + def unauthorized_handler[T: Callable[[], flask.Response]]( + self, + callback: T, + ) -> T: ... + def needs_refresh_handler[T: Callable[[], flask.Response]]( + self, + callback: T, + ) -> T: ... + def needs_refresh(self) -> flask.Response: ... + def header_loader[T: Callable[[str], UserMixin | AnonymousUserMixin | None]]( + self, + callback: T, + ) -> T: ... diff --git a/typings/flask_login/mixins.pyi b/typings/flask_login/mixins.pyi new file mode 100644 index 00000000..165962ba --- /dev/null +++ b/typings/flask_login/mixins.pyi @@ -0,0 +1,23 @@ +# Initially generated by Pyright + +from typing import Literal + +class UserMixin: + @property + def is_active(self) -> Literal[True]: ... + @property + def is_authenticated(self) -> Literal[True]: ... + @property + def is_anonymous(self) -> Literal[False]: ... + def get_id(self) -> str: ... + def __eq__(self, other: UserMixin | object) -> bool: ... + def __ne__(self, other: UserMixin | object) -> bool: ... + +class AnonymousUserMixin: + @property + def is_authenticated(self) -> Literal[False]: ... + @property + def is_active(self) -> Literal[False]: ... + @property + def is_anonymous(self) -> Literal[True]: ... + def get_id(self) -> None: ... diff --git a/typings/flask_login/utils.pyi b/typings/flask_login/utils.pyi new file mode 100644 index 00000000..bc05e870 --- /dev/null +++ b/typings/flask_login/utils.pyi @@ -0,0 +1,31 @@ +# Initially generated by Pyright + +import datetime +from collections.abc import Callable +from typing import Literal + +import flask + +from .mixins import AnonymousUserMixin, UserMixin + +current_user: UserMixin | AnonymousUserMixin = ... + +def encode_cookie(payload: str, key: str = ...) -> str: ... +def decode_cookie(cookie: str, key: str = ...) -> str | None: ... +def make_next_param(login_url: str, current_url: str) -> str: ... +def expand_login_view(login_view: str) -> str: ... +def login_url(login_view: str, next_url: str = ..., next_field: str = ...) -> str: ... +def login_fresh() -> bool: ... +def login_remembered() -> bool: ... +def login_user( + user: UserMixin, + remember: bool = ..., + duration: datetime.timedelta = ..., + force: bool = ..., + fresh: bool = ..., +) -> bool: ... +def logout_user() -> Literal[True]: ... +def confirm_login() -> None: ... +def login_required[T: Callable[..., object]](func: T) -> T: ... +def fresh_login_required[T: Callable[..., object]](func: T) -> T: ... +def set_login_view(login_view: str, blueprint: flask.Blueprint = ...) -> None: ... diff --git a/typings/jsmin/__init__.pyi b/typings/jsmin/__init__.pyi new file mode 100644 index 00000000..658e62ad --- /dev/null +++ b/typings/jsmin/__init__.pyi @@ -0,0 +1,16 @@ +# Initially generated by Pyright + +import io + +def jsmin(js: str, **kwargs: object) -> str: ... + +class JavascriptMinify: + def __init__( + self, + instream: io.StringIO = ..., + outstream: io.StringIO = ..., + quote_chars: str = ..., + ) -> None: ... + def minify( + self, instream: io.StringIO = ..., outstream: io.StringIO = ... + ) -> None: ... diff --git a/typings/numpy_financial/__init__.pyi b/typings/numpy_financial/__init__.pyi new file mode 100644 index 00000000..fd47cd48 --- /dev/null +++ b/typings/numpy_financial/__init__.pyi @@ -0,0 +1,5 @@ +# Initially generated by Pyright + +from ._financial import * + +__version__ = ... diff --git a/typings/numpy_financial/_financial.pyi b/typings/numpy_financial/_financial.pyi new file mode 100644 index 00000000..be1a6f59 --- /dev/null +++ b/typings/numpy_financial/_financial.pyi @@ -0,0 +1,5 @@ +# Initially generated by Pyright + +from decimal import Decimal + +def irr(values: list[Decimal] | list[float]) -> float: ... diff --git a/typings/prometheus_flask_exporter/__init__.pyi b/typings/prometheus_flask_exporter/__init__.pyi new file mode 100644 index 00000000..7d13172e --- /dev/null +++ b/typings/prometheus_flask_exporter/__init__.pyi @@ -0,0 +1,107 @@ +# Initially generated by Pyright + +from collections.abc import Callable +from typing import Literal, Self + +import flask +import prometheus_client +from flask.typing import ResponseReturnValue +from prometheus_client import Gauge +from werkzeug.exceptions import HTTPException + +GroupBy = Literal["path", "endpoint", "url_rule"] + +Responder = Callable[..., ResponseReturnValue | object | HTTPException | flask.Response] + +class PrometheusMetrics: + def __init__[T: Callable[..., object]]( + self, + app: flask.Flask, + path: str = ..., + export_defaults: bool = ..., + defaults_prefix: str = ..., + group_by: GroupBy = ..., + buckets: tuple[float, ...] | None = ..., + default_latency_as_histogram: bool = ..., + default_labels: dict[str, object] | None = ..., + response_converter: Responder = ..., + excluded_paths: list[str] | str | None = ..., + exclude_user_defaults: bool = ..., + metrics_decorator: Callable[[T], T] = ..., + registry: prometheus_client.CollectorRegistry | None = ..., + **kwargs: object, + ) -> None: ... + @classmethod + def for_app_factory(cls, **kwargs: object) -> Self: ... + def init_app(self, app: flask.Flask) -> None: ... + def register_endpoint(self, path: str, app: flask.Flask = ...) -> None: ... + def generate_metrics( + self, + accept_header: str | None = ..., + names: list[str] | None = ..., + ) -> tuple[str, str]: ... + def start_http_server( + self, + port: int, + host: str = ..., + endpoint: str = ..., + ssl: dict[str, str] | None = ..., + ) -> None: ... + def export_defaults( + self, + buckets: tuple[float, ...] | None = ..., + group_by: GroupBy = ..., + latency_as_histogram: bool = ..., + prefix: str = ..., + app: flask.Flask = ..., + **kwargs: object, + ) -> None: ... + def register_default[T: Callable[..., object]]( + self, + *metric_wrappers: Callable[[T], T], + **kwargs: object, + ) -> None: ... + def histogram( + self, + name: str, + description: str, + labels: dict[str, object] | None = ..., + initial_value_when_only_static_labels: bool = ..., + **kwargs: str, + ) -> Callable[..., Responder]: ... + def summary( + self, + name: str, + description: str, + labels: dict[str, object] | None = ..., + initial_value_when_only_static_labels: bool = ..., + **kwargs: str, + ) -> Callable[..., Responder]: ... + def gauge( + self, + name: str, + description: str, + labels: dict[str, object] | None = ..., + initial_value_when_only_static_labels: bool = ..., + **kwargs: str, + ) -> Callable[..., Responder]: ... + def counter( + self, + name: str, + description: str, + labels: dict[str, object] | None = ..., + initial_value_when_only_static_labels: bool = ..., + **kwargs: str, + ) -> Callable[..., Responder]: ... + @staticmethod + def do_not_track[T: Callable[..., object]]() -> Callable[[T], T]: ... + @staticmethod + def exclude_all_metrics[T: Callable[..., object]]() -> Callable[[T], T]: ... + def info( + self, + name: str, + description: str, + labelnames: tuple[str, ...] | None = ..., + labelvalues: tuple[object, ...] | None = ..., + **labels: object, + ) -> Gauge: ... diff --git a/typings/prometheus_flask_exporter/multiprocess.pyi b/typings/prometheus_flask_exporter/multiprocess.pyi new file mode 100644 index 00000000..3957653d --- /dev/null +++ b/typings/prometheus_flask_exporter/multiprocess.pyi @@ -0,0 +1,19 @@ +# Initially generated by Pyright + +from abc import ABCMeta, abstractmethod +from typing import Literal + +from . import PrometheusMetrics + +class MultiprocessPrometheusMetrics(PrometheusMetrics): + __metaclass__ = ABCMeta + + @abstractmethod + def should_start_http_server(self) -> bool: ... + +class GunicornPrometheusMetrics(MultiprocessPrometheusMetrics): + def should_start_http_server(self) -> Literal[True]: ... + @classmethod + def start_http_server_when_ready(cls, port: int, host: str = ...) -> None: ... + @classmethod + def mark_process_dead_on_child_exit(cls, pid: int) -> None: ... diff --git a/typings/pytailwindcss/__init__.pyi b/typings/pytailwindcss/__init__.pyi new file mode 100644 index 00000000..180d3bdd --- /dev/null +++ b/typings/pytailwindcss/__init__.pyi @@ -0,0 +1,13 @@ +# Initially generated by Pyright +# +import pathlib + +def run( + tailwindcss_cli_args: list[str] | str = ..., + cwd: pathlib.Path | str = ..., + bin_path: pathlib.Path | str = ..., + env: dict[str, str] = ..., + live_output: bool = ..., + auto_install: bool = ..., + version: str = ..., +) -> str: ... diff --git a/typings/pytest/__init__.pyi b/typings/pytest/__init__.pyi new file mode 100644 index 00000000..07eb6452 --- /dev/null +++ b/typings/pytest/__init__.pyi @@ -0,0 +1,181 @@ +# Initially generated by Pyright + +from decimal import Decimal +from typing import overload + +from _pytest import __version__, version_tuple +from _pytest._code import ExceptionInfo +from _pytest.assertion import register_assert_rewrite +from _pytest.cacheprovider import Cache +from _pytest.capture import CaptureFixture +from _pytest.config import ( + cmdline, + Config, + console_main, + ExitCode, + hookimpl, + hookspec, + main, + PytestPluginManager, + UsageError, +) +from _pytest.config.argparsing import OptionGroup, Parser +from _pytest.doctest import DoctestItem +from _pytest.fixtures import ( + fixture, + FixtureDef, + FixtureLookupError, + FixtureRequest, +) +from _pytest.freeze_support import freeze_includes +from _pytest.legacypath import TempdirFactory, Testdir +from _pytest.logging import LogCaptureFixture +from _pytest.main import Dir, Session +from _pytest.mark import HIDDEN_PARAM, Mark +from _pytest.mark import MARK_GEN as mark +from _pytest.mark import MarkDecorator, MarkGenerator, param +from _pytest.monkeypatch import MonkeyPatch +from _pytest.nodes import Collector, Directory, File, Item +from _pytest.outcomes import exit, fail, importorskip, skip, xfail +from _pytest.pytester import ( + HookRecorder, + LineMatcher, + Pytester, + RecordedHookCall, + RunResult, +) +from _pytest.python import Class, Function, Metafunc, Module, Package +from _pytest.python_api import ApproxBase, ApproxDecimal, ApproxScalar +from _pytest.raises import raises, RaisesExc, RaisesGroup +from _pytest.recwarn import deprecated_call, WarningsRecorder, warns +from _pytest.reports import CollectReport, TestReport +from _pytest.runner import CallInfo +from _pytest.stash import Stash, StashKey +from _pytest.subtests import SubtestReport, Subtests +from _pytest.terminal import TerminalReporter, TestShortLogReport +from _pytest.tmpdir import TempPathFactory +from _pytest.warning_types import ( + PytestAssertRewriteWarning, + PytestCacheWarning, + PytestCollectionWarning, + PytestConfigWarning, + PytestDeprecationWarning, + PytestExperimentalApiWarning, + PytestFDWarning, + PytestRemovedIn9Warning, + PytestRemovedIn10Warning, + PytestReturnNotNoneWarning, + PytestUnhandledThreadExceptionWarning, + PytestUnknownMarkWarning, + PytestUnraisableExceptionWarning, + PytestWarning, +) + +@overload +def approx[T: Decimal]( + expected: T, + rel: T | None = None, + abs: T | None = None, + nan_ok: bool = False, +) -> ApproxDecimal: ... +@overload +def approx[T]( + expected: T, + rel: T | None = None, + abs: T | None = None, + nan_ok: bool = False, +) -> ApproxScalar: ... +def approx[T]( + expected: T, + rel: T | None = None, + abs: T | None = None, + nan_ok: bool = False, +) -> ApproxBase: ... + +__all__ = [ + "HIDDEN_PARAM", + "Cache", + "CallInfo", + "CaptureFixture", + "Class", + "CollectReport", + "Collector", + "Config", + "Dir", + "Directory", + "DoctestItem", + "ExceptionInfo", + "ExitCode", + "File", + "FixtureDef", + "FixtureLookupError", + "FixtureRequest", + "Function", + "HookRecorder", + "Item", + "LineMatcher", + "LogCaptureFixture", + "Mark", + "MarkDecorator", + "MarkGenerator", + "Metafunc", + "Module", + "MonkeyPatch", + "OptionGroup", + "Package", + "Parser", + "PytestAssertRewriteWarning", + "PytestCacheWarning", + "PytestCollectionWarning", + "PytestConfigWarning", + "PytestDeprecationWarning", + "PytestExperimentalApiWarning", + "PytestFDWarning", + "PytestPluginManager", + "PytestRemovedIn9Warning", + "PytestRemovedIn10Warning", + "PytestReturnNotNoneWarning", + "PytestUnhandledThreadExceptionWarning", + "PytestUnknownMarkWarning", + "PytestUnraisableExceptionWarning", + "PytestWarning", + "Pytester", + "RaisesExc", + "RaisesGroup", + "RecordedHookCall", + "RunResult", + "Session", + "Stash", + "StashKey", + "SubtestReport", + "Subtests", + "TempPathFactory", + "TempdirFactory", + "TerminalReporter", + "TestReport", + "TestShortLogReport", + "Testdir", + "UsageError", + "WarningsRecorder", + "__version__", + "approx", + "cmdline", + "console_main", + "deprecated_call", + "exit", + "fail", + "fixture", + "freeze_includes", + "hookimpl", + "hookspec", + "importorskip", + "main", + "mark", + "param", + "raises", + "register_assert_rewrite", + "skip", + "version_tuple", + "warns", + "xfail", +] diff --git a/typings/sqlcipher3/__init__.pyi b/typings/sqlcipher3/__init__.pyi new file mode 100644 index 00000000..51667194 --- /dev/null +++ b/typings/sqlcipher3/__init__.pyi @@ -0,0 +1,2 @@ +# Generated by Pyright +# Don't actually need any typing diff --git a/typings/yfinance/__init__.pyi b/typings/yfinance/__init__.pyi new file mode 100644 index 00000000..3d095e1c --- /dev/null +++ b/typings/yfinance/__init__.pyi @@ -0,0 +1,5 @@ +# Initially generated by Pyright + +from .ticker import Ticker + +__all__ = ["Ticker"] diff --git a/typings/yfinance/base.pyi b/typings/yfinance/base.pyi new file mode 100644 index 00000000..748d987c --- /dev/null +++ b/typings/yfinance/base.pyi @@ -0,0 +1,29 @@ +# Initially generated by Pyright + +import datetime + +import pandas as pd +import requests + +class TickerBase: + def __init__( + self, + ticker: str | tuple[str, str], + session: requests.Session | None = ..., + ) -> None: ... + def history( + self, + period: str | None = ..., + interval: str = ..., + start: datetime.date | str | None = ..., + end: datetime.date | str | None = ..., + prepost: bool = ..., + actions: bool = ..., + auto_adjust: bool = ..., + back_adjust: bool = ..., + repair: bool = ..., + keepna: bool = ..., + rounding: bool = ..., + timeout: float | None = ..., + raise_errors: bool = ..., + ) -> pd.DataFrame: ... diff --git a/typings/yfinance/exceptions.pyi b/typings/yfinance/exceptions.pyi new file mode 100644 index 00000000..2eec6a0f --- /dev/null +++ b/typings/yfinance/exceptions.pyi @@ -0,0 +1,4 @@ +# Initially generated by Pyright + +class YFException(Exception): ... +class YFDataException(YFException): ... diff --git a/typings/yfinance/scrapers/__init__.pyi b/typings/yfinance/scrapers/__init__.pyi new file mode 100644 index 00000000..e69de29b diff --git a/typings/yfinance/scrapers/funds.pyi b/typings/yfinance/scrapers/funds.pyi new file mode 100644 index 00000000..14c1a8a7 --- /dev/null +++ b/typings/yfinance/scrapers/funds.pyi @@ -0,0 +1,5 @@ +# Initially generated by Pyright + +class FundsData: + @property + def sector_weightings(self) -> dict[str, float]: ... diff --git a/typings/yfinance/ticker.pyi b/typings/yfinance/ticker.pyi new file mode 100644 index 00000000..bcea0081 --- /dev/null +++ b/typings/yfinance/ticker.pyi @@ -0,0 +1,24 @@ +# Initially generated by Pyright + +from typing import NotRequired, TypedDict + +import requests + +from .base import TickerBase +from .scrapers.funds import FundsData + +class Info(TypedDict): + + currency: str + sector: NotRequired[str | None] + +class Ticker(TickerBase): + def __init__( + self, + ticker: str | tuple[str, str], + session: requests.Session | None = ..., + ) -> None: ... + @property + def info(self) -> Info: ... + @property + def funds_data(self) -> FundsData: ...