diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7450f785..eed612eb 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -5,7 +5,7 @@ repos:
- id: ruff-check
args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/RobertCraigie/pyright-python
- rev: v1.1.402
+ rev: v1.1.408
hooks:
- id: pyright
- repo: https://github.com/pycqa/isort
diff --git a/nummus/commands/export.py b/nummus/commands/export.py
index 0fa6162e..2fe0251a 100644
--- a/nummus/commands/export.py
+++ b/nummus/commands/export.py
@@ -15,11 +15,6 @@
import argparse
import io
- from sqlalchemy import orm
-
- from nummus.models.currency import Currency
- from nummus.models.transaction import TransactionSplit
-
class Export(Command):
"""Export transactions."""
@@ -85,35 +80,20 @@ def setup_args(cls, parser: argparse.ArgumentParser) -> None:
@override
def run(self) -> int:
- # Defer for faster time to main
- from nummus.models.transaction import TransactionSplit
-
- with self._p.begin_session() as s:
- query = (
- s.query(TransactionSplit)
- .where(
- TransactionSplit.asset_id.is_(None),
- )
- .with_entities(TransactionSplit.amount)
- )
- if self._start is not None:
- query = query.where(
- TransactionSplit.date_ord >= self._start.toordinal(),
- )
- if self._end is not None:
- query = query.where(
- TransactionSplit.date_ord <= self._end.toordinal(),
- )
-
- with self._csv_path.open("w", encoding="utf-8") as file:
- n = write_csv(file, query, no_bars=self._no_bars)
+
+ with (
+ self._p.begin_session(),
+ self._csv_path.open("w", encoding="utf-8") as file,
+ ):
+ n = write_csv(file, self._start, self._end, no_bars=self._no_bars)
print(f"{Fore.GREEN}{n} transactions exported to {self._csv_path}")
return 0
def write_csv(
file: io.TextIOBase,
- transactions_query: orm.Query[TransactionSplit],
+ start: datetime.date | None,
+ end: datetime.date | None,
*,
no_bars: bool,
) -> int:
@@ -121,7 +101,8 @@ def write_csv(
Args:
file: Destination file to write to
- transactions_query: ORM query to obtain TransactionSplits
+ start: Start date to filter transactions
+ end: End date to filter transactions
no_bars: True will disable progress bars
Returns:
@@ -131,34 +112,43 @@ def write_csv(
# Defer for faster time to main
import tqdm
+ from nummus import sql
from nummus.models.account import Account
- from nummus.models.base import YIELD_PER
from nummus.models.currency import CURRENCY_FORMATS
from nummus.models.transaction import TransactionCategory, TransactionSplit
- from nummus.models.utils import query_count
- s = transactions_query.session
-
- query = s.query(Account).with_entities(
+ query = Account.query(
Account.id_,
Account.name,
Account.currency,
)
- accounts: dict[int, tuple[str, Currency]] = {
- r[0]: (r[1], r[2]) for r in query.yield_per(YIELD_PER)
- }
-
- categories = TransactionCategory.map_name_emoji(s)
-
- query = transactions_query.with_entities(
- TransactionSplit.date_ord,
- TransactionSplit.account_id,
- TransactionSplit.payee,
- TransactionSplit.memo,
- TransactionSplit.category_id,
- TransactionSplit.amount,
- ).order_by(TransactionSplit.date_ord)
- n = query_count(query)
+ accounts = sql.to_dict_tuple(query)
+
+ categories = TransactionCategory.map_name_emoji()
+
+ query = (
+ TransactionSplit.query(
+ TransactionSplit.date_ord,
+ TransactionSplit.account_id,
+ TransactionSplit.payee,
+ TransactionSplit.memo,
+ TransactionSplit.category_id,
+ TransactionSplit.amount,
+ )
+ .order_by(TransactionSplit.date_ord)
+ .where(
+ TransactionSplit.asset_id.is_(None),
+ )
+ )
+ if start is not None:
+ query = query.where(
+ TransactionSplit.date_ord >= start.toordinal(),
+ )
+ if end is not None:
+ query = query.where(
+ TransactionSplit.date_ord <= end.toordinal(),
+ )
+ n = sql.count(query)
header = [
"Date",
@@ -177,7 +167,7 @@ def write_csv(
t_cat_id,
amount,
) in tqdm.tqdm(
- query.yield_per(YIELD_PER),
+ sql.yield_(query),
total=n,
desc="Exporting",
disable=no_bars,
@@ -189,8 +179,8 @@ def write_csv(
[
datetime.date.fromordinal(date).isoformat(),
acct_name,
- payee,
- memo,
+ payee or "",
+ memo or "",
categories[t_cat_id],
cf(amount),
],
diff --git a/nummus/commands/health.py b/nummus/commands/health.py
index daa38e97..1912f81e 100644
--- a/nummus/commands/health.py
+++ b/nummus/commands/health.py
@@ -113,13 +113,15 @@ def run(self) -> int:
p = self._p
- with p.begin_session() as s:
+ with p.begin_session():
if self._clear_ignores:
- s.query(HealthCheckIssue).delete()
+ HealthCheckIssue.query().delete()
elif self._ignores:
# Set ignore for all specified issues
+ print(self._ignores)
+ print(HealthCheckIssue.uri_to_id(self._ignores[0]))
ids = {HealthCheckIssue.uri_to_id(uri) for uri in self._ignores}
- s.query(HealthCheckIssue).where(HealthCheckIssue.id_.in_(ids)).update(
+ HealthCheckIssue.query().where(HealthCheckIssue.id_.in_(ids)).update(
{HealthCheckIssue.ignore: True},
)
@@ -141,8 +143,8 @@ def run(self) -> int:
# Update LAST_HEALTH_CHECK_TS
utc_now = datetime.datetime.now(datetime.UTC)
- with p.begin_session() as s:
- Config.set_(s, ConfigKey.LAST_HEALTH_CHECK_TS, utc_now.isoformat())
+ with p.begin_session():
+ Config.set_(ConfigKey.LAST_HEALTH_CHECK_TS, utc_now.isoformat())
if any_severe_issues:
return -2
if any_issues:
@@ -164,8 +166,8 @@ def _test_check(self, check_type: type[HealthCheck]) -> str | None:
no_ignores=self._no_ignores,
no_description_typos=self._no_description_typos,
)
- with self._p.begin_session() as s:
- c.test(s)
+ with self._p.begin_session():
+ c.test()
n_issues = len(c.issues)
if n_issues == 0:
print(f"{Fore.GREEN}Check '{c.name()}' has no issues")
diff --git a/nummus/commands/migrate.py b/nummus/commands/migrate.py
index f18fcec3..2821c94d 100644
--- a/nummus/commands/migrate.py
+++ b/nummus/commands/migrate.py
@@ -54,8 +54,8 @@ def run(self) -> int:
# Back up Portfolio
_, tar_ver = p.backup()
- with p.begin_session() as s:
- v_db = Config.db_version(s)
+ with p.begin_session():
+ v_db = Config.db_version()
any_migrated = False
try:
@@ -78,13 +78,13 @@ def run(self) -> int:
m.migrate(p) # no comments
print(f"{Fore.GREEN}Portfolio model schemas updated")
- with p.begin_session() as s:
+ with p.begin_session():
v = max(
Version(__version__),
*[m.min_version() for m in MIGRATORS],
)
- Config.set_(s, ConfigKey.VERSION, str(v))
+ Config.set_(ConfigKey.VERSION, str(v))
except Exception: # pragma: no cover
# No immediate exception thrown, can't easily test
portfolio.Portfolio.restore(p, tar_ver=tar_ver)
diff --git a/nummus/commands/summarize.py b/nummus/commands/summarize.py
index 999e049a..7bd49fc3 100644
--- a/nummus/commands/summarize.py
+++ b/nummus/commands/summarize.py
@@ -108,41 +108,33 @@ def _get_summary(
# Defer for faster time to main
from sqlalchemy import func
- from nummus import utils
+ from nummus import sql, utils
from nummus.models.account import Account
from nummus.models.asset import Asset, AssetCategory, AssetValuation
from nummus.models.config import Config
from nummus.models.currency import CURRENCY_FORMATS
from nummus.models.transaction import TransactionSplit
- from nummus.models.utils import query_count
today = datetime.datetime.now().astimezone().date()
today_ord = today.toordinal()
- with self._p.begin_session() as s:
- accts = {acct.id_: acct for acct in s.query(Account).all()}
- assets = {
- a.id_: a
- for a in (
- s.query(Asset).where(Asset.category != AssetCategory.INDEX).all()
- )
- }
+ with self._p.begin_session():
+ accts = {acct.id_: acct for acct in Account.all()}
+ query = Asset.query().where(Asset.category != AssetCategory.INDEX)
+ assets = {a.id_: a for a in sql.yield_(query)}
# Get the inception date
- start_date_ord: int = (
- s.query(
- func.min(TransactionSplit.date_ord),
- ).scalar()
- or datetime.date(1970, 1, 1).toordinal()
+ query = TransactionSplit.query(
+ func.min(TransactionSplit.date_ord),
)
+ start_date_ord = sql.scalar(query) or datetime.date(1970, 1, 1).toordinal()
n_accounts = len(accts)
- n_transactions = query_count(s.query(TransactionSplit))
+ n_transactions = TransactionSplit.count()
n_assets = len(assets)
- n_valuations = query_count(s.query(AssetValuation))
+ n_valuations = AssetValuation.count()
value_accts, profit_accts, value_assets = Account.get_value_all(
- s,
start_date_ord,
today_ord,
)
@@ -180,7 +172,6 @@ def _get_summary(
)
profit_assets = Account.get_profit_by_asset_all(
- s,
start_date_ord,
today_ord,
)
@@ -224,7 +215,7 @@ def _get_summary(
"total_asset_value": total_asset_value,
"assets": summary_assets,
"db_size": self._p.path.stat().st_size,
- "cf": CURRENCY_FORMATS[Config.base_currency(s)],
+ "cf": CURRENCY_FORMATS[Config.base_currency()],
}
@classmethod
diff --git a/nummus/controllers/accounts.py b/nummus/controllers/accounts.py
index d7eca261..b0cbd786 100644
--- a/nummus/controllers/accounts.py
+++ b/nummus/controllers/accounts.py
@@ -11,11 +11,10 @@
from sqlalchemy import func
from nummus import exceptions as exc
-from nummus import utils, web
+from nummus import sql, utils, web
from nummus.controllers import base, transactions
from nummus.models.account import Account, AccountCategory
from nummus.models.asset import Asset, AssetCategory
-from nummus.models.base import YIELD_PER
from nummus.models.config import Config
from nummus.models.currency import (
Currency,
@@ -24,13 +23,11 @@
)
from nummus.models.transaction import Transaction, TransactionSplit
from nummus.models.transaction_category import TransactionCategory
-from nummus.models.utils import query_to_dict
if TYPE_CHECKING:
import datetime
import werkzeug
- from sqlalchemy import orm
from nummus.models.currency import CurrencyFormat
@@ -123,12 +120,12 @@ def page_all() -> flask.Response:
"""
p = web.portfolio
- with p.begin_session() as s:
+ with p.begin_session():
include_closed = "include-closed" in flask.request.args
return base.page(
"accounts/page-all.jinja",
"Accounts",
- ctx=ctx_accounts(s, base.today_client(), include_closed=include_closed),
+ ctx=ctx_accounts(base.today_client(), include_closed=include_closed),
)
@@ -144,11 +141,10 @@ def page(uri: str) -> flask.Response:
"""
p = web.portfolio
today = base.today_client()
- with p.begin_session() as s:
- acct = base.find(s, Account, uri)
+ with p.begin_session():
+ acct = base.find(Account, uri)
args = flask.request.args
txn_table, title = transactions.ctx_table(
- s,
today,
args.get("search"),
args.get("account"),
@@ -163,16 +159,15 @@ def page(uri: str) -> flask.Response:
title = title.removeprefix("Transactions").strip()
title = f"{acct.name}, {title}" if title else f"{acct.name}"
- ctx = ctx_account(s, acct, today)
+ ctx = ctx_account(acct, today)
if acct.category == AccountCategory.INVESTMENT:
ctx["performance"] = ctx_performance(
- s,
acct,
today,
args.get("chart-period"),
CURRENCY_FORMATS[acct.currency],
)
- ctx["assets"] = ctx_assets(s, acct, today)
+ ctx["assets"] = ctx_assets(acct, today)
return base.page(
"accounts/page.jinja",
title=title,
@@ -192,7 +187,7 @@ def new() -> str | flask.Response:
"""
p = web.portfolio
with p.begin_session() as s:
- base_currency = Config.base_currency(s)
+ base_currency = Config.base_currency()
if flask.request.method == "GET":
ctx: AccountContext = {
"uri": None,
@@ -236,7 +231,7 @@ def new() -> str | flask.Response:
try:
with s.begin_nested():
- acct = Account(
+ Account.create(
institution=institution,
name=name,
number=number,
@@ -245,7 +240,6 @@ def new() -> str | flask.Response:
budgeted=budgeted,
currency=currency,
)
- s.add(acct)
except (exc.IntegrityError, exc.InvalidORMValueError) as e:
return base.error(e)
@@ -267,18 +261,18 @@ def account(uri: str) -> str | werkzeug.Response:
today_ord = today.toordinal()
with p.begin_session() as s:
- base_currency = Config.base_currency(s)
+ base_currency = Config.base_currency()
- acct = base.find(s, Account, uri)
+ acct = base.find(Account, uri)
if flask.request.method == "GET":
return flask.render_template(
"accounts/edit.jinja",
- acct=ctx_account(s, acct, today),
+ acct=ctx_account(acct, today),
)
if flask.request.method == "DELETE":
with s.begin_nested():
- s.delete(acct)
+ acct.delete()
return flask.redirect(flask.url_for("accounts.page_all"))
values, _, _ = acct.get_value(today_ord, today_ord)
@@ -326,14 +320,13 @@ def performance(uri: str) -> flask.Response:
"""
p = web.portfolio
args = flask.request.args
- with p.begin_session() as s:
- acct = base.find(s, Account, uri)
+ with p.begin_session():
+ acct = base.find(Account, uri)
html = flask.render_template(
"accounts/performance.jinja",
acct={
"uri": uri,
"performance": ctx_performance(
- s,
acct,
base.today_client(),
args.get("chart-period"),
@@ -365,13 +358,13 @@ def validation() -> str:
p = web.portfolio
# dict{key: (required, prop if unique required)}
- properties: dict[str, tuple[bool, orm.QueryableAttribute | None]] = {
+ properties: dict[str, tuple[bool, sql.Column | None]] = {
"name": (True, Account.name),
"institution": (True, None),
"number": (False, Account.number),
}
- with p.begin_session() as s:
+ with p.begin_session():
args = flask.request.args
uri = args.get("uri")
for key, (required, prop) in properties.items():
@@ -380,7 +373,7 @@ def validation() -> str:
return base.validate_string(
args[key],
is_required=required,
- session=s,
+ cls=Account,
no_duplicates=prop,
no_duplicate_wheres=(
None if uri is None else [Account.id_ != Account.uri_to_id(uri)]
@@ -391,7 +384,6 @@ def validation() -> str:
def ctx_account(
- s: orm.Session,
acct: Account,
today: datetime.date,
*,
@@ -400,7 +392,6 @@ def ctx_account(
"""Get the context to build the account details.
Args:
- s: SQL session to use
acct: Account to generate context for
today: Today's date
skip_today: True will skip fetching today's value
@@ -423,32 +414,24 @@ def ctx_account(
None if updated_on_ord is None else today_ord - updated_on_ord
)
- query = (
- s.query(Transaction)
- .with_entities(
- func.count(Transaction.id_),
- func.sum(Transaction.amount),
- )
- .where(
- Transaction.date_ord == today_ord,
- Transaction.account_id == acct.id_,
- )
+ query = Transaction.query(
+ func.count(Transaction.id_),
+ func.sum(Transaction.amount),
+ ).where(
+ Transaction.date_ord == today_ord,
+ Transaction.account_id == acct.id_,
)
- n_today, change_today = query.one()
+ n_today, change_today = sql.one(query)
change_today: Decimal = change_today or Decimal()
- query = (
- s.query(Transaction)
- .with_entities(
- func.count(Transaction.id_),
- func.sum(Transaction.amount),
- )
- .where(
- Transaction.date_ord > today_ord,
- Transaction.account_id == acct.id_,
- )
+ query = Transaction.query(
+ func.count(Transaction.id_),
+ func.sum(Transaction.amount),
+ ).where(
+ Transaction.date_ord > today_ord,
+ Transaction.account_id == acct.id_,
)
- n_future, change_future = query.one()
+ n_future, change_future = sql.one(query)
values, _, _ = acct.get_value(today_ord, today_ord)
current_value = values[0]
@@ -478,7 +461,6 @@ def ctx_account(
def ctx_performance(
- s: orm.Session,
acct: Account,
today: datetime.date,
period: str | None,
@@ -487,7 +469,6 @@ def ctx_performance(
"""Get the context to build the account performance details.
Args:
- s: SQL session to use
acct: Account to generate context for
today: Today's date
period: Period string to get data for
@@ -502,25 +483,30 @@ def ctx_performance(
end_ord = end.toordinal()
start_ord = acct.opened_on_ord or end_ord if start is None else start.toordinal()
- query = s.query(TransactionCategory.id_, TransactionCategory.name).where(
+ query = TransactionCategory.query(
+ TransactionCategory.id_,
+ TransactionCategory.name,
+ ).where(
TransactionCategory.is_profit_loss.is_(True),
)
- pnl_categories: dict[int, str] = query_to_dict(query)
+ pnl_categories: dict[int, str] = sql.to_dict(query)
# Calculate total cost basis
total_cost_basis = Decimal()
dividends = Decimal()
fees = Decimal()
query = (
- s.query(TransactionSplit)
- .with_entities(TransactionSplit.category_id, func.sum(TransactionSplit.amount))
+ TransactionSplit.query(
+ TransactionSplit.category_id,
+ func.sum(TransactionSplit.amount),
+ )
.where(
TransactionSplit.date_ord <= end_ord,
TransactionSplit.account_id == acct.id_,
)
.group_by(TransactionSplit.category_id)
)
- for cat_id, value in query.yield_per(YIELD_PER):
+ for cat_id, value in sql.yield_(query):
name = pnl_categories.get(cat_id)
if name is None:
total_cost_basis += value
@@ -566,14 +552,12 @@ def ctx_performance(
def ctx_assets(
- s: orm.Session,
acct: Account,
today: datetime.date,
) -> list[AssetContext] | None:
"""Get the context to build the account assets.
Args:
- s: SQL session to use
acct: Account to generate context for
today: Today's date
@@ -591,32 +575,32 @@ def ctx_assets(
return None # Not an investment account
# Include all assets every held
- query = s.query(TransactionSplit.asset_id).where(
- TransactionSplit.account_id == acct.id_,
- TransactionSplit.asset_id.is_not(None),
+ query = (
+ TransactionSplit.query(TransactionSplit.asset_id)
+ .where(
+ TransactionSplit.account_id == acct.id_,
+ TransactionSplit.asset_id.is_not(None),
+ )
+ .distinct()
)
- a_ids = {a_id for a_id, in query.distinct()}
+ a_ids = {a_id for a_id in sql.col0(query) if a_id}
- end_prices = Asset.get_value_all(s, today_ord, today_ord, ids=a_ids)
+ end_prices = Asset.get_value_all(today_ord, today_ord, ids=a_ids)
asset_profits = acct.get_profit_by_asset(start_ord, today_ord)
# Sum of profits should match final profit value, add any mismatch to cash
- query = (
- s.query(Asset)
- .with_entities(
- Asset.id_,
- Asset.name,
- Asset.ticker,
- Asset.category,
- )
- .where(Asset.id_.in_(a_ids))
- )
+ query = Asset.query(
+ Asset.id_,
+ Asset.name,
+ Asset.ticker,
+ Asset.category,
+ ).where(Asset.id_.in_(a_ids))
assets: list[AssetContext] = []
total_value = Decimal()
total_profit = Decimal()
- for a_id, name, ticker, category in query.yield_per(YIELD_PER):
+ for a_id, name, ticker, category in sql.yield_(query):
end_qty = asset_qtys[a_id]
end_price = end_prices[a_id][0]
end_value = end_qty * end_price
@@ -639,12 +623,11 @@ def ctx_assets(
assets.append(ctx_asset)
# Add in cash too
- cash: Decimal = (
- s.query(func.sum(TransactionSplit.amount))
- .where(TransactionSplit.account_id == acct.id_)
- .where(TransactionSplit.date_ord <= today_ord)
- .one()[0]
+ query = TransactionSplit.query(func.sum(TransactionSplit.amount)).where(
+ TransactionSplit.account_id == acct.id_,
+ TransactionSplit.date_ord <= today_ord,
)
+ cash: Decimal = sql.one(query)
total_value += cash
ctx_asset = {
"uri": None,
@@ -676,7 +659,6 @@ def ctx_assets(
def ctx_accounts(
- s: orm.Session,
today: datetime.date,
*,
include_closed: bool = False,
@@ -684,7 +666,6 @@ def ctx_accounts(
"""Get the context to build the accounts table.
Args:
- s: SQL session to use
today: Today's date
include_closed: True will include Accounts marked closed, False will exclude
@@ -705,71 +686,59 @@ def ctx_accounts(
# Get basic info
accounts: dict[int, AccountContext] = {}
currencies: dict[int, Currency] = {}
- query = s.query(Account).order_by(Account.category)
+ query = Account.query().order_by(Account.category)
if not include_closed:
query = query.where(Account.closed.is_(False))
- for acct in query.all():
- accounts[acct.id_] = ctx_account(s, acct, today, skip_today=True)
+ for acct in sql.yield_(query):
+ accounts[acct.id_] = ctx_account(acct, today, skip_today=True)
currencies[acct.id_] = acct.currency
if acct.closed:
n_closed += 1
# Get updated_on
query = (
- s.query(Transaction)
- .with_entities(
+ Transaction.query(
Transaction.account_id,
func.max(Transaction.date_ord),
)
.group_by(Transaction.account_id)
.where(Transaction.account_id.in_(accounts))
)
- for acct_id, updated_on_ord in query.all():
+ for acct_id, updated_on_ord in sql.yield_(query):
acct_id: int
updated_on_ord: int
accounts[acct_id]["updated_days_ago"] = today_ord - updated_on_ord
# Get n_today
query = (
- s.query(Transaction)
- .with_entities(
+ Transaction.query(
Transaction.account_id,
func.count(Transaction.id_),
func.sum(Transaction.amount),
)
- .where(Transaction.date_ord == today_ord)
+ .where(Transaction.date_ord == today_ord, Transaction.account_id.in_(accounts))
.group_by(Transaction.account_id)
- .where(Transaction.account_id.in_(accounts))
)
- for acct_id, n_today, change_today in query.all():
- acct_id: int
- n_today: int
- change_today: Decimal | None
+ for acct_id, n_today, change_today in sql.yield_(query):
accounts[acct_id]["n_today"] = n_today
accounts[acct_id]["change_today"] = change_today or Decimal()
# Get n_future
query = (
- s.query(Transaction)
- .with_entities(
+ Transaction.query(
Transaction.account_id,
func.count(Transaction.id_),
func.sum(Transaction.amount),
)
- .where(Transaction.date_ord > today_ord)
+ .where(Transaction.date_ord > today_ord, Transaction.account_id.in_(accounts))
.group_by(Transaction.account_id)
- .where(Transaction.account_id.in_(accounts))
)
- for acct_id, n_future, change_future in query.all():
- acct_id: int
- n_future: int
- change_future: Decimal
+ for acct_id, n_future, change_future in sql.yield_(query):
accounts[acct_id]["n_future"] = n_future
accounts[acct_id]["change_future"] = change_future
- base_currency = Config.base_currency(s)
+ base_currency = Config.base_currency()
forex = Asset.get_forex(
- s,
today_ord,
today_ord,
base_currency,
@@ -777,7 +746,7 @@ def ctx_accounts(
)
# Get all Account values
- acct_values, _, _ = Account.get_value_all(s, today_ord, today_ord, ids=accounts)
+ acct_values, _, _ = Account.get_value_all(today_ord, today_ord, ids=accounts)
for acct_id, ctx in accounts.items():
v = acct_values[acct_id][0]
ctx["value"] = v
@@ -822,7 +791,7 @@ def ctx_accounts(
},
"include_closed": include_closed,
"n_closed": n_closed,
- "currency_format": CURRENCY_FORMATS[Config.base_currency(s)],
+ "currency_format": CURRENCY_FORMATS[Config.base_currency()],
}
@@ -840,9 +809,8 @@ def txns(uri: str) -> str | flask.Response:
args = flask.request.args
first_page = "page" not in args
- with p.begin_session() as s:
+ with p.begin_session():
txn_table, title = transactions.ctx_table(
- s,
base.today_client(),
args.get("search"),
args.get("account"),
@@ -855,7 +823,7 @@ def txns(uri: str) -> str | flask.Response:
acct_uri=uri,
)
title = title.removeprefix("Transactions").strip()
- acct = base.find(s, Account, uri)
+ acct = base.find(Account, uri)
title = f"{acct.name}, {title}" if title else f"{acct.name}"
html_title = f"
{title} - nummus\n"
html = html_title + flask.render_template(
@@ -893,8 +861,8 @@ def txns_options(uri: str) -> str:
"""
p = web.portfolio
- with p.begin_session() as s:
- accounts = Account.map_name(s)
+ with p.begin_session():
+ accounts = Account.map_name()
args = flask.request.args
uncleared = "uncleared" in args
@@ -905,7 +873,6 @@ def txns_options(uri: str) -> str:
selected_end = args.get("end")
tbl_query = transactions.table_query(
- s,
None,
selected_account,
selected_period,
@@ -918,7 +885,7 @@ def txns_options(uri: str) -> str:
tbl_query,
base.today_client(),
accounts,
- base.tranaction_category_groups(s),
+ base.tranaction_category_groups(),
selected_account,
selected_category,
)
diff --git a/nummus/controllers/allocation.py b/nummus/controllers/allocation.py
index a1e2d161..2f68e4f9 100644
--- a/nummus/controllers/allocation.py
+++ b/nummus/controllers/allocation.py
@@ -7,11 +7,10 @@
from decimal import Decimal
from typing import TYPE_CHECKING, TypedDict
-from nummus import web
+from nummus import sql, web
from nummus.controllers import base
from nummus.models.account import Account
from nummus.models.asset import Asset, AssetSector
-from nummus.models.base import YIELD_PER
from nummus.models.config import Config
from nummus.models.currency import CURRENCY_FORMATS
@@ -19,7 +18,6 @@
import datetime
import flask
- from sqlalchemy import orm
from nummus.models.asset import AssetCategory, USSector
from nummus.models.currency import CurrencyFormat
@@ -76,19 +74,18 @@ def page() -> flask.Response:
"""
p = web.portfolio
- with p.begin_session() as s:
+ with p.begin_session():
return base.page(
"allocation/page.jinja",
title="Asset allocation",
- allocation=ctx_allocation(s, base.today_client()),
+ allocation=ctx_allocation(base.today_client()),
)
-def ctx_allocation(s: orm.Session, today: datetime.date) -> AllocationContext:
+def ctx_allocation(today: datetime.date) -> AllocationContext:
"""Get the context to build the allocation chart.
Args:
- s: SQL session to use
today: Today's date
Returns:
@@ -98,7 +95,7 @@ def ctx_allocation(s: orm.Session, today: datetime.date) -> AllocationContext:
today_ord = today.toordinal()
asset_qtys: dict[int, Decimal] = defaultdict(Decimal)
- acct_qtys = Account.get_asset_qty_all(s, today_ord, today_ord)
+ acct_qtys = Account.get_asset_qty_all(today_ord, today_ord)
for acct_qty in acct_qtys.values():
for a_id, values in acct_qty.items():
asset_qtys[a_id] += values[0]
@@ -107,7 +104,6 @@ def ctx_allocation(s: orm.Session, today: datetime.date) -> AllocationContext:
asset_prices = {
a_id: values[0]
for a_id, values in Asset.get_value_all(
- s,
today_ord,
today_ord,
ids=set(asset_qtys),
@@ -117,13 +113,13 @@ def ctx_allocation(s: orm.Session, today: datetime.date) -> AllocationContext:
asset_values = {a_id: qty * asset_prices[a_id] for a_id, qty in asset_qtys.items()}
asset_sectors: dict[int, dict[USSector, Decimal]] = defaultdict(dict)
- for a_sector in s.query(AssetSector).yield_per(YIELD_PER):
+ for a_sector in sql.yield_(AssetSector.query()):
asset_sectors[a_sector.asset_id][a_sector.sector] = a_sector.weight
assets_by_category: dict[AssetCategory, list[AssetContext]] = defaultdict(list)
assets_by_sector: dict[USSector, list[AssetContext]] = defaultdict(list)
- query = s.query(Asset).where(Asset.id_.in_(asset_qtys)).order_by(Asset.name)
- for asset in query.yield_per(YIELD_PER):
+ query = Asset.query().where(Asset.id_.in_(asset_qtys)).order_by(Asset.name)
+ for asset in sql.yield_(query):
qty = asset_qtys[asset.id_]
value = asset_values[asset.id_]
@@ -174,7 +170,7 @@ def chart_assets(assets: list[AssetContext]) -> list[ChartAssetContext]:
for a in assets
]
- cf = CURRENCY_FORMATS[Config.base_currency(s)]
+ cf = CURRENCY_FORMATS[Config.base_currency()]
return {
"categories": sorted(categories, key=operator.itemgetter("name")),
diff --git a/nummus/controllers/assets.py b/nummus/controllers/assets.py
index 975344a2..560b5ba6 100644
--- a/nummus/controllers/assets.py
+++ b/nummus/controllers/assets.py
@@ -12,7 +12,7 @@
from sqlalchemy import func
from nummus import exceptions as exc
-from nummus import utils, web
+from nummus import sql, utils, web
from nummus.controllers import base
from nummus.models.account import Account
from nummus.models.asset import (
@@ -22,19 +22,15 @@
AssetSplit,
AssetValuation,
)
-from nummus.models.base import YIELD_PER
from nummus.models.config import Config
from nummus.models.currency import (
Currency,
CURRENCY_FORMATS,
)
from nummus.models.transaction import TransactionSplit
-from nummus.models.utils import query_count
if TYPE_CHECKING:
- import sqlalchemy
import werkzeug
- from sqlalchemy import orm
from nummus.models.currency import CurrencyFormat
@@ -132,8 +128,8 @@ def page_all() -> flask.Response:
p = web.portfolio
include_unheld = "include-unheld" in flask.request.args
- with p.begin_session() as s:
- categories = ctx_rows(s, base.today_client(), include_unheld=include_unheld)
+ with p.begin_session():
+ categories = ctx_rows(base.today_client(), include_unheld=include_unheld)
return base.page(
"assets/page-all.jinja",
@@ -160,10 +156,9 @@ def page(uri: str) -> flask.Response:
"""
p = web.portfolio
args = flask.request.args
- with p.begin_session() as s:
- a = base.find(s, Asset, uri)
+ with p.begin_session():
+ a = base.find(Asset, uri)
ctx = ctx_asset(
- s,
a,
base.today_client(),
args.get("period"),
@@ -190,7 +185,7 @@ def new() -> str | flask.Response:
with p.begin_session() as s:
if flask.request.method == "GET":
- currency = Config.base_currency(s)
+ currency = Config.base_currency()
ctx: AssetContext = {
"uri": None,
"name": "",
@@ -222,14 +217,13 @@ def new() -> str | flask.Response:
try:
with s.begin_nested():
- a = Asset(
+ Asset.create(
name=name,
description=description,
category=category,
ticker=ticker,
currency=currency,
)
- s.add(a)
except (exc.IntegrityError, exc.InvalidORMValueError) as e:
return base.error(e)
@@ -248,14 +242,13 @@ def asset(uri: str) -> str | werkzeug.Response:
"""
p = web.portfolio
with p.begin_session() as s:
- a = base.find(s, Asset, uri)
+ a = base.find(Asset, uri)
if flask.request.method == "GET":
args = flask.request.args
return flask.render_template(
"assets/edit.jinja",
asset=ctx_asset(
- s,
a,
base.today_client(),
args.get("period"),
@@ -267,10 +260,10 @@ def asset(uri: str) -> str | werkzeug.Response:
)
if flask.request.method == "DELETE":
with s.begin_nested():
- s.query(AssetSector).where(AssetSector.asset_id == a.id_).delete()
- s.query(AssetSplit).where(AssetSplit.asset_id == a.id_).delete()
- s.query(AssetValuation).where(AssetValuation.asset_id == a.id_).delete()
- s.delete(a)
+ AssetSector.query().where(AssetSector.asset_id == a.id_).delete()
+ AssetSplit.query().where(AssetSplit.asset_id == a.id_).delete()
+ AssetValuation.query().where(AssetValuation.asset_id == a.id_).delete()
+ a.delete()
return flask.redirect(flask.url_for("assets.page_all"))
form = flask.request.form
@@ -302,14 +295,14 @@ def performance(uri: str) -> flask.Response:
"""
p = web.portfolio
- with p.begin_session() as s:
- a = base.find(s, Asset, uri)
+ with p.begin_session():
+ a = base.find(Asset, uri)
period = flask.request.args.get("chart-period")
html = flask.render_template(
"assets/performance.jinja",
asset={
"uri": uri,
- "performance": ctx_performance(s, a, base.today_client(), period),
+ "performance": ctx_performance(a, base.today_client(), period),
},
)
response = flask.make_response(html)
@@ -337,11 +330,10 @@ def table(uri: str) -> str | flask.Response:
"""
p = web.portfolio
args = flask.request.args
- with p.begin_session() as s:
- a = base.find(s, Asset, uri)
+ with p.begin_session():
+ a = base.find(Asset, uri)
cf = CURRENCY_FORMATS[a.currency]
val_table = ctx_table(
- s,
a,
base.today_client(),
args.get("period"),
@@ -382,13 +374,13 @@ def validation() -> str:
"""
p = web.portfolio
# dict{key: (required, prop if unique required)}
- properties: dict[str, tuple[bool, orm.QueryableAttribute | None]] = {
+ properties: dict[str, tuple[bool, sql.Column | None]] = {
"name": (True, Asset.name),
"description": (False, None),
"ticker": (False, Asset.ticker),
}
- with p.begin_session() as s:
+ with p.begin_session():
args = flask.request.args
uri = args.get("uri")
for key, (required, prop) in properties.items():
@@ -398,7 +390,7 @@ def validation() -> str:
args[key],
is_required=required,
check_length=key != "ticker",
- session=s,
+ cls=Asset,
no_duplicates=prop,
no_duplicate_wheres=(
None if uri is None else [Asset.id_ != Asset.uri_to_id(uri)]
@@ -406,7 +398,7 @@ def validation() -> str:
)
if "date" in args:
- wheres: list[sqlalchemy.ColumnExpressionArgument] = []
+ wheres: list[sql.ColumnClause] = []
if uri:
wheres.append(
AssetValuation.asset_id == Asset.uri_to_id(uri),
@@ -420,7 +412,7 @@ def validation() -> str:
args["date"],
base.today_client(),
is_required=True,
- session=s,
+ cls=AssetValuation,
no_duplicates=AssetValuation.date_ord,
no_duplicate_wheres=wheres,
)
@@ -470,14 +462,13 @@ def new_valuation(uri: str) -> str | flask.Response:
try:
p = web.portfolio
- with p.begin_session() as s:
- a = base.find(s, Asset, uri)
- v = AssetValuation(
+ with p.begin_session():
+ a = base.find(Asset, uri)
+ AssetValuation.create(
asset_id=a.id_,
date_ord=date.toordinal(),
value=value,
)
- s.add(v)
except exc.IntegrityError as e:
# Get the line that starts with (...IntegrityError)
orig = str(e.orig)
@@ -505,7 +496,7 @@ def valuation(uri: str) -> str | flask.Response:
today = base.today_client()
with p.begin_session() as s:
- v = base.find(s, AssetValuation, uri)
+ v = base.find(AssetValuation, uri)
date_max = today + datetime.timedelta(days=utils.DAYS_IN_WEEK)
if flask.request.method == "GET":
@@ -521,7 +512,7 @@ def valuation(uri: str) -> str | flask.Response:
)
if flask.request.method == "DELETE":
date = v.date
- s.delete(v)
+ v.delete()
return base.dialog_swap(
event="valuation",
snackbar=f"{date} valuation deleted",
@@ -564,8 +555,8 @@ def update() -> str | flask.Response:
"""
p = web.portfolio
- with p.begin_session() as s:
- n = query_count(s.query(Asset).where(Asset.ticker.is_not(None)))
+ with p.begin_session():
+ n = sql.count(Asset.query().where(Asset.ticker.is_not(None)))
if flask.request.method == "GET":
return flask.render_template(
"assets/update.jinja",
@@ -600,7 +591,6 @@ def update() -> str | flask.Response:
def ctx_rows(
- s: orm.Session,
today: datetime.date,
*,
include_unheld: bool,
@@ -608,7 +598,6 @@ def ctx_rows(
"""Get the context to build the page all rows.
Args:
- s: SQL session to use
today: Today's date
include_unheld: True will include assets with zero current quantity
@@ -620,7 +609,7 @@ def ctx_rows(
today_ord = today.toordinal()
- accounts = Account.get_asset_qty_all(s, today_ord, today_ord)
+ accounts = Account.get_asset_qty_all(today_ord, today_ord)
qtys: dict[int, Decimal] = defaultdict(Decimal)
for acct_qtys in accounts.values():
for a_id, values in acct_qtys.items():
@@ -628,14 +617,14 @@ def ctx_rows(
held_ids = {a_id for a_id, qty in qtys.items() if qty}
query = (
- s.query(Asset)
+ Asset.query()
.where(Asset.category != AssetCategory.INDEX)
.order_by(Asset.category)
)
if not include_unheld:
query = query.where(Asset.id_.in_(held_ids))
- prices = Asset.get_value_all(s, today_ord, today_ord, held_ids)
- for asset in query.yield_per(YIELD_PER):
+ prices = Asset.get_value_all(today_ord, today_ord, held_ids)
+ for asset in sql.yield_(query):
qty = qtys[asset.id_]
price = prices[asset.id_][0]
value = qty * price
@@ -655,7 +644,6 @@ def ctx_rows(
def ctx_asset(
- s: orm.Session,
a: Asset,
today: datetime.date,
period: str | None,
@@ -667,7 +655,6 @@ def ctx_asset(
"""Get the context to build the asset details.
Args:
- s: SQL session to use
a: Asset to generate context for
today: Today's date
period: Period to get table for
@@ -681,7 +668,7 @@ def ctx_asset(
"""
valuation = (
- s.query(AssetValuation)
+ AssetValuation.query()
.where(AssetValuation.asset_id == a.id_)
.order_by(AssetValuation.date_ord.desc())
.first()
@@ -692,18 +679,14 @@ def ctx_asset(
else:
current_value = valuation.value
current_date = valuation.date
- deletable = (
- s.query(TransactionSplit.id_)
- .where(TransactionSplit.asset_id == a.id_)
- .limit(1)
- .scalar()
- is None
+ query = TransactionSplit.query(TransactionSplit.id_).where(
+ TransactionSplit.asset_id == a.id_,
)
+ deletable = not sql.any_(query)
- accounts = Account.map_name(s)
+ accounts = Account.map_name()
query = (
- s.query(TransactionSplit)
- .with_entities(
+ TransactionSplit.query(
TransactionSplit.account_id,
func.sum(TransactionSplit.asset_quantity),
)
@@ -722,7 +705,7 @@ def ctx_asset(
qty,
qty * current_value,
)
- for acct_id, qty in query.yield_per(YIELD_PER)
+ for acct_id, qty in sql.yield_(query)
if qty
]
@@ -738,15 +721,14 @@ def ctx_asset(
"currency_format": CURRENCY_FORMATS[a.currency],
"value": current_value,
"value_date": current_date,
- "performance": ctx_performance(s, a, today, period_chart),
- "table": ctx_table(s, a, today, period, start, end, page),
+ "performance": ctx_performance(a, today, period_chart),
+ "table": ctx_table(a, today, period, start, end, page),
"holdings": sorted(holdings, key=operator.itemgetter(2), reverse=True),
"deletable": deletable,
}
def ctx_performance(
- s: orm.Session,
a: Asset,
today: datetime.date,
period: str | None,
@@ -754,7 +736,6 @@ def ctx_performance(
"""Get the context to build the asset performance details.
Args:
- s: SQL session to use
a: Asset to generate context for
today: Today's date
period: Chart-period to fetch performance for
@@ -767,12 +748,10 @@ def ctx_performance(
start, end = base.parse_period(period, today)
end_ord = end.toordinal()
if start is None:
- start_ord = (
- s.query(func.min(AssetValuation.date_ord))
- .where(AssetValuation.asset_id == a.id_)
- .scalar()
- or end_ord
+ query = AssetValuation.query(func.min(AssetValuation.date_ord)).where(
+ AssetValuation.asset_id == a.id_,
)
+ start_ord = sql.one(query) or end_ord
else:
start_ord = start.toordinal()
@@ -782,12 +761,11 @@ def ctx_performance(
**base.chart_data(start_ord, end_ord, values),
"period": period,
"period_options": base.PERIOD_OPTIONS,
- "currency_format": CURRENCY_FORMATS[Config.base_currency(s)]._asdict(),
+ "currency_format": CURRENCY_FORMATS[Config.base_currency()]._asdict(),
}
def ctx_table(
- s: orm.Session,
a: Asset,
today: datetime.date,
period: str | None,
@@ -798,7 +776,6 @@ def ctx_table(
"""Get the context to build the valuations table.
Args:
- s: SQL session to use
a: Asset to get valuations for
today: Today's date
period: Period to get table for
@@ -813,7 +790,7 @@ def ctx_table(
page_start = None if page is None else datetime.date.fromisoformat(page).toordinal()
query = (
- s.query(AssetValuation)
+ AssetValuation.query()
.where(AssetValuation.asset_id == a.id_)
.order_by(AssetValuation.date_ord.desc())
)
@@ -854,7 +831,7 @@ def ctx_table(
"date_max": None,
"value": v.value,
}
- for v in query.limit(PAGE_LEN).yield_per(YIELD_PER)
+ for v in sql.yield_(query.limit(PAGE_LEN))
]
next_page = (
diff --git a/nummus/controllers/auth.py b/nummus/controllers/auth.py
index 6b33e8c9..792b0af4 100644
--- a/nummus/controllers/auth.py
+++ b/nummus/controllers/auth.py
@@ -124,18 +124,18 @@ def login() -> str | werkzeug.Response:
if not password:
return base.error("Password must not be blank")
- with p.begin_session() as s:
- expected_encoded = Config.fetch(s, ConfigKey.WEB_KEY)
+ with p.begin_session():
+ expected_encoded = Config.fetch(ConfigKey.WEB_KEY)
- expected = p.decrypt(expected_encoded)
- if password.encode() != expected:
- return base.error("Bad password")
+ expected = p.decrypt(expected_encoded)
+ if password.encode() != expected:
+ return base.error("Bad password")
- web_user = WebUser()
- flask_login.login_user(web_user, remember=True)
+ web_user = WebUser()
+ flask_login.login_user(web_user, remember=True)
- next_url = form.get("next")
- return flask.redirect(next_url or flask.url_for("common.page_dashboard"))
+ next_url = form.get("next")
+ return flask.redirect(next_url or flask.url_for("common.page_dashboard"))
def logout() -> str | werkzeug.Response:
diff --git a/nummus/controllers/base.py b/nummus/controllers/base.py
index 8bd2d913..9f367fd1 100644
--- a/nummus/controllers/base.py
+++ b/nummus/controllers/base.py
@@ -8,29 +8,23 @@
import textwrap
from decimal import Decimal
from pathlib import Path
-from typing import NamedTuple, overload, TYPE_CHECKING, TypedDict
+from typing import NamedTuple, overload, TypedDict
import flask
import flask.typing
from nummus import exceptions as exc
-from nummus import utils, web
+from nummus import sql, utils, web
from nummus.models.base import (
Base,
BaseEnum,
- YIELD_PER,
)
from nummus.models.transaction_category import (
TransactionCategory,
TransactionCategoryGroup,
)
-from nummus.models.utils import query_count
from nummus.version import __version__
-if TYPE_CHECKING:
- import sqlalchemy
- from sqlalchemy import orm
-
type Routes = dict[str, tuple[flask.typing.RouteCallable, list[str]]]
@@ -445,11 +439,10 @@ def change_redirect_to_htmx(response: flask.Response) -> flask.Response:
return response
-def find[T: Base](s: orm.Session, cls: type[T], uri: str) -> T:
+def find[T: Base](cls: type[T], uri: str) -> T:
"""Find the matching object by URI.
Args:
- s: SQL session to search
cls: Type of object to find
uri: URI to find
@@ -466,7 +459,7 @@ def find[T: Base](s: orm.Session, cls: type[T], uri: str) -> T:
except (exc.InvalidURIError, exc.WrongURITypeError) as e:
raise exc.http.BadRequest(str(e)) from e
try:
- obj = s.query(cls).where(cls.id_ == id_).one()
+ obj = sql.one(cls.query().where(cls.id_ == id_))
except exc.NoResultFound as e:
msg = f"{cls.__name__} {uri} not found in Portfolio"
raise exc.http.NotFound(msg) from e
@@ -558,9 +551,9 @@ def validate_string(
*,
is_required: bool = False,
check_length: bool = True,
- session: orm.Session | None = None,
- no_duplicates: orm.QueryableAttribute | None = None,
- no_duplicate_wheres: list[sqlalchemy.ColumnExpressionArgument] | None = None,
+ cls: type[Base] | None = None,
+ no_duplicates: sql.Column | None = None,
+ no_duplicate_wheres: list[sql.ColumnClause] | None = None,
) -> str:
"""Validate a string matches requirements.
@@ -568,7 +561,7 @@ def validate_string(
value: String to test
is_required: True will require the value be non-empty
check_length: True will require value to be MIN_STR_LEN long
- session: SQL session to use for no_duplicates
+ cls: Model class to test for duplicate values
no_duplicates: Property to test for duplicate values
no_duplicate_wheres: Additional where clauses to add to no_duplicates
@@ -582,11 +575,11 @@ def validate_string(
if check_length and len(value) < utils.MIN_STR_LEN:
# Ticker can be short
return f"{utils.MIN_STR_LEN} characters required"
- if no_duplicates is None:
+ if no_duplicates is None or cls is None:
return ""
return _test_duplicates(
value,
- session,
+ cls,
no_duplicates,
no_duplicate_wheres,
)
@@ -594,19 +587,15 @@ def validate_string(
def _test_duplicates(
value: object,
- session: orm.Session | None,
- no_duplicates: orm.QueryableAttribute,
- no_duplicate_wheres: list[sqlalchemy.ColumnExpressionArgument] | None,
+ cls: type[Base],
+ no_duplicates: sql.Column,
+ no_duplicate_wheres: list[sql.ColumnClause] | None,
) -> str:
- if session is None:
- msg = "Cannot test no_duplicates without a session"
- raise TypeError(msg)
- query = session.query(no_duplicates.parent).where(
+ query = cls.query().where(
no_duplicates == value,
*(no_duplicate_wheres or []),
)
- n = query_count(query)
- if n != 0:
+ if sql.any_(query):
return "Must be unique"
return ""
@@ -617,9 +606,9 @@ def validate_date(
*,
is_required: bool = False,
max_future: int | None = utils.DAYS_IN_WEEK,
- session: orm.Session | None = None,
- no_duplicates: orm.QueryableAttribute | None = None,
- no_duplicate_wheres: list[sqlalchemy.ColumnExpressionArgument] | None = None,
+ cls: type[Base] | None = None,
+ no_duplicates: sql.Column | None = None,
+ no_duplicate_wheres: list[sql.ColumnClause] | None = None,
) -> str:
"""Validate a date string matches requirements.
@@ -628,7 +617,7 @@ def validate_date(
today: Today's date
is_required: True will require the value be non-empty
max_future: Maximum number of days date is allowed in the future
- session: SQL session to use for no_duplicates
+ cls: Model class to test for duplicate values
no_duplicates: Property to test for duplicate values
no_duplicate_wheres: Additional where clauses to add to no_duplicates
@@ -652,11 +641,11 @@ def validate_date(
):
return f"Only up to {utils.format_days(max_future)} in advance"
- if no_duplicates is None:
+ if no_duplicates is None or cls is None:
return ""
return _test_duplicates(
date.toordinal(),
- session,
+ cls,
no_duplicates,
no_duplicate_wheres,
)
@@ -760,34 +749,27 @@ def parse_date(
return date
-def tranaction_category_groups(s: orm.Session) -> CategoryGroups:
+def tranaction_category_groups() -> CategoryGroups:
"""Get TransactionCategory by groups.
- Args:
- s: SQL session to use
-
Returns:
dict{group: list[CategoryContext]}
"""
- query = (
- s.query(TransactionCategory)
- .with_entities(
- TransactionCategory.id_,
- TransactionCategory.name,
- TransactionCategory.emoji_name,
- TransactionCategory.asset_linked,
- TransactionCategory.group,
- )
- .order_by(TransactionCategory.name)
- )
+ query = TransactionCategory.query(
+ TransactionCategory.id_,
+ TransactionCategory.name,
+ TransactionCategory.emoji_name,
+ TransactionCategory.asset_linked,
+ TransactionCategory.group,
+ ).order_by(TransactionCategory.name)
category_groups: CategoryGroups = {
TransactionCategoryGroup.INCOME: [],
TransactionCategoryGroup.EXPENSE: [],
TransactionCategoryGroup.TRANSFER: [],
TransactionCategoryGroup.OTHER: [],
}
- for t_cat_id, name, emoji_name, asset_linked, group in query.yield_per(YIELD_PER):
+ for t_cat_id, name, emoji_name, asset_linked, group in sql.yield_(query):
category_groups[group].append(
CategoryContext(
TransactionCategory.id_to_uri(t_cat_id),
diff --git a/nummus/controllers/budgeting.py b/nummus/controllers/budgeting.py
index e2eda4cc..d77607ef 100644
--- a/nummus/controllers/budgeting.py
+++ b/nummus/controllers/budgeting.py
@@ -10,12 +10,10 @@
from typing import NamedTuple, NotRequired, TYPE_CHECKING, TypedDict
import flask
-from sqlalchemy import sql
from nummus import exceptions as exc
-from nummus import utils, web
+from nummus import sql, utils, web
from nummus.controllers import base
-from nummus.models.base import YIELD_PER
from nummus.models.budget import (
BudgetAssignment,
BudgetGroup,
@@ -29,11 +27,9 @@
TransactionCategory,
TransactionCategoryGroup,
)
-from nummus.models.utils import query_count
if TYPE_CHECKING:
import werkzeug.datastructures
- from sqlalchemy import orm
from nummus.models.budget import BudgetAvailableCategory
from nummus.models.currency import CurrencyFormat
@@ -158,10 +154,9 @@ def page() -> flask.Response:
)
sidebar_uri = args.get("sidebar") or None
- with p.begin_session() as s:
- data = BudgetAssignment.get_monthly_available(s, month)
+ with p.begin_session():
+ data = BudgetAssignment.get_monthly_available(month)
budget, title = ctx_budget(
- s,
today,
month,
data.categories,
@@ -169,7 +164,6 @@ def page() -> flask.Response:
flask.session.get("groups_open", []),
)
sidebar = ctx_sidebar(
- s,
today,
month,
data.categories,
@@ -253,21 +247,21 @@ def assign(uri: str) -> str:
form = flask.request.form
amount = utils.evaluate_real_statement(form["amount"]) or Decimal()
- with p.begin_session() as s:
- cat = base.find(s, TransactionCategory, uri)
+ with p.begin_session():
+ cat = base.find(TransactionCategory, uri)
group_uri = (
None
if cat.budget_group_id is None
else BudgetGroup.id_to_uri(cat.budget_group_id)
)
if amount == 0:
- s.query(BudgetAssignment).where(
+ BudgetAssignment.query().where(
BudgetAssignment.month_ord == month_ord,
BudgetAssignment.category_id == cat.id_,
).delete()
else:
a = (
- s.query(BudgetAssignment)
+ BudgetAssignment.query()
.where(
BudgetAssignment.month_ord == month_ord,
BudgetAssignment.category_id == cat.id_,
@@ -275,18 +269,16 @@ def assign(uri: str) -> str:
.one_or_none()
)
if a is None:
- a = BudgetAssignment(
+ a = BudgetAssignment.create(
month_ord=month_ord,
category_id=cat.id_,
amount=amount,
)
- s.add(a)
else:
a.amount = amount
- data = BudgetAssignment.get_monthly_available(s, month)
+ data = BudgetAssignment.get_monthly_available(month)
budget, _ = ctx_budget(
- s,
today,
month,
data.categories,
@@ -295,7 +287,6 @@ def assign(uri: str) -> str:
)
sidebar_uri = form.get("sidebar") or None
sidebar = ctx_sidebar(
- s,
today,
month,
data.categories,
@@ -328,15 +319,15 @@ def move(uri: str) -> str | flask.Response:
month = datetime.date.fromisoformat(month_str + "-01")
month_ord = month.toordinal()
- with p.begin_session() as s:
- cf = CURRENCY_FORMATS[Config.base_currency(s)]
- data = BudgetAssignment.get_monthly_available(s, month)
+ with p.begin_session():
+ cf = CURRENCY_FORMATS[Config.base_currency()]
+ data = BudgetAssignment.get_monthly_available(month)
if uri == "income":
src_cat = None
src_cat_id = None
src_available = data.assignable
else:
- src_cat = base.find(s, TransactionCategory, uri)
+ src_cat = base.find(TransactionCategory, uri)
src_cat_id = src_cat.id_
src_available = data.categories[src_cat_id].available
@@ -358,7 +349,7 @@ def move(uri: str) -> str | flask.Response:
# Max of the negative number is min of the positive/abs
to_move = max(src_available, -dest_available)
- BudgetAssignment.move(s, month_ord, src_cat_id, dest_cat_id, to_move)
+ BudgetAssignment.move(month_ord, src_cat_id, dest_cat_id, to_move)
return base.dialog_swap(
event="budget",
@@ -373,8 +364,7 @@ def move(uri: str) -> str | flask.Response:
destination = args.get("destination")
query = (
- s.query(TransactionCategory)
- .with_entities(
+ TransactionCategory.query(
TransactionCategory.id_,
TransactionCategory.emoji_name,
TransactionCategory.group,
@@ -386,11 +376,7 @@ def move(uri: str) -> str | flask.Response:
)
.order_by(TransactionCategory.group, TransactionCategory.name)
)
- for t_cat_id, name, group in query.yield_per(YIELD_PER):
- t_cat_id: int
- name: str
- group: TransactionCategoryGroup
-
+ for t_cat_id, name, group in sql.yield_(query):
t_cat_uri = TransactionCategory.id_to_uri(t_cat_id)
available = data.categories[t_cat_id].available
if destination or src_available > 0 or available > 0:
@@ -425,7 +411,7 @@ def reorder() -> str:
t_cat_uris = form.getlist("category-uri")
groups = form.getlist("group")
- with p.begin_session() as s:
+ with p.begin_session():
g_positions = {
BudgetGroup.uri_to_id(g_uri): i for i, g_uri in enumerate(group_uris)
}
@@ -452,7 +438,7 @@ def reorder() -> str:
last_group = g_uri
# Set all to None first so swapping can occur without unique violations
- s.query(TransactionCategory).update(
+ TransactionCategory.query().update(
{
TransactionCategory.budget_group_id: None,
TransactionCategory.budget_position: None,
@@ -460,11 +446,11 @@ def reorder() -> str:
)
# Delete any groups
- s.query(BudgetGroup).where(BudgetGroup.id_.not_in(g_positions)).delete()
+ BudgetGroup.query().where(BudgetGroup.id_.not_in(g_positions)).delete()
if g_positions:
# Set all to -index first so swapping can occur without unique violations
- s.query(BudgetGroup).update(
+ BudgetGroup.query().update(
{
BudgetGroup.position: sql.case(
{g_id: -i - 1 for i, g_id in enumerate(g_positions)},
@@ -474,7 +460,7 @@ def reorder() -> str:
)
# Set new group positions
- s.query(BudgetGroup).update(
+ BudgetGroup.query().update(
{
BudgetGroup.position: sql.case(
g_positions,
@@ -485,7 +471,7 @@ def reorder() -> str:
if t_cat_positions:
# Set new category positions
- s.query(TransactionCategory).update(
+ TransactionCategory.query().update(
{
TransactionCategory.budget_group_id: sql.case(
t_cat_groups,
@@ -525,8 +511,8 @@ def group(uri: str) -> str:
flask.session["groups_open"] = groups_open
elif uri != "ungrouped":
try:
- with p.begin_session() as s:
- g = base.find(s, BudgetGroup, uri)
+ with p.begin_session():
+ g = base.find(BudgetGroup, uri)
g.name = name
except (exc.IntegrityError, exc.InvalidORMValueError) as e:
return base.error(e)
@@ -547,27 +533,26 @@ def new_group() -> str:
"""
p = web.portfolio
name = "New group"
- with p.begin_session() as s:
- cf = CURRENCY_FORMATS[Config.base_currency(s)]
+ with p.begin_session():
+ cf = CURRENCY_FORMATS[Config.base_currency()]
# Ensure the name isn't a duplicate
i = 1
- n = query_count(s.query(BudgetGroup).where(BudgetGroup.name == name))
- while n != 0:
+
+ query = BudgetGroup.query().where(BudgetGroup.name == name)
+ while sql.any_(query):
i += 1
name = f"New group {i}"
- n = query_count(s.query(BudgetGroup).where(BudgetGroup.name == name))
+ query = BudgetGroup.query().where(BudgetGroup.name == name)
# Move existing groups down one
- n = query_count(s.query(BudgetGroup))
+ n = sql.count(BudgetGroup.query())
for i in range(n, -1, -1):
# Do one at a time in reverse order to prevent duplicate value
- s.query(BudgetGroup).where(BudgetGroup.position == i).update(
+ BudgetGroup.query().where(BudgetGroup.position == i).update(
{BudgetGroup.position: i + 1},
)
- g = BudgetGroup(name=name, position=0)
- s.add(g)
- s.flush()
+ g = BudgetGroup.create(name=name, position=0)
g_uri = g.uri
ctx: GroupContext = {
"position": 0,
@@ -607,17 +592,16 @@ def target(uri: str) -> str | flask.Response:
with p.begin_session() as s:
try:
- tar = base.find(s, Target, uri)
+ tar = base.find(Target, uri)
t_cat_id = tar.category_id
except exc.http.BadRequest:
t_cat_id = TransactionCategory.uri_to_id(uri)
- tar = s.query(Target).where(Target.category_id == t_cat_id).one_or_none()
+ tar = Target.query().where(Target.category_id == t_cat_id).one_or_none()
- emoji_name = (
- s.query(TransactionCategory.emoji_name)
- .where(TransactionCategory.id_ == t_cat_id)
- .one()[0]
+ query = TransactionCategory.query(TransactionCategory.emoji_name).where(
+ TransactionCategory.id_ == t_cat_id,
)
+ emoji_name = sql.one(query)
new_target = tar is None
if tar is None:
@@ -631,7 +615,7 @@ def target(uri: str) -> str | flask.Response:
repeat_every=1,
)
elif flask.request.method == "DELETE":
- s.delete(tar)
+ tar.delete()
return base.dialog_swap(
event="budget",
snackbar=f"{emoji_name} target deleted",
@@ -679,7 +663,7 @@ def target(uri: str) -> str | flask.Response:
"from_amount": (
flask.request.headers.get("HX-Trigger") == "budgeting-amount"
),
- "currency_format": CURRENCY_FORMATS[Config.base_currency(s)],
+ "currency_format": CURRENCY_FORMATS[Config.base_currency()],
}
# Don't make the changes
s.rollback()
@@ -775,13 +759,9 @@ def sidebar() -> flask.Response:
)
uri = args.get("uri")
- with p.begin_session() as s:
- data = BudgetAssignment.get_monthly_available(
- s,
- month,
- )
+ with p.begin_session():
+ data = BudgetAssignment.get_monthly_available(month)
sidebar = ctx_sidebar(
- s,
today,
month,
data.categories,
@@ -792,7 +772,7 @@ def sidebar() -> flask.Response:
"budgeting/sidebar.jinja",
ctx={
"month": month_str,
- "currency_format": CURRENCY_FORMATS[Config.base_currency(s)],
+ "currency_format": CURRENCY_FORMATS[Config.base_currency()],
},
budget_sidebar=sidebar,
)
@@ -810,7 +790,6 @@ def sidebar() -> flask.Response:
def ctx_sidebar(
- s: orm.Session,
today: datetime.date,
month: datetime.date,
categories: dict[int, BudgetAvailableCategory],
@@ -820,7 +799,6 @@ def ctx_sidebar(
"""Get the context to build the budgeting sidebar.
Args:
- s: SQL session to use
today: Today's date
month: Month of table
categories: Dict of categories from Budget.get_monthly_available
@@ -839,13 +817,13 @@ def ctx_sidebar(
total_assigned = Decimal()
total_activity = Decimal()
- query = s.query(TransactionCategory.id_).where(
+ query = TransactionCategory.query(TransactionCategory.id_).where(
TransactionCategory.group == TransactionCategoryGroup.INCOME,
)
- income_ids = {row[0] for row in query.all()}
+ income_ids = set(sql.col0(query))
targets: dict[int, Target] = {
- t.category_id: t for t in s.query(Target).yield_per(YIELD_PER)
+ t.category_id: t for t in sql.yield_(Target.query())
}
no_target: set[int] = set()
@@ -873,8 +851,7 @@ def ctx_sidebar(
total_to_go += target_ctx["to_go"]
query = (
- s.query(TransactionCategory)
- .with_entities(
+ TransactionCategory.query(
TransactionCategory.id_,
TransactionCategory.emoji_name,
)
@@ -883,7 +860,7 @@ def ctx_sidebar(
)
no_target_names: dict[str, str] = {
TransactionCategory.id_to_uri(t_cat_id): name
- for t_cat_id, name in query.all()
+ for t_cat_id, name in sql.yield_(query)
}
return {
@@ -901,11 +878,11 @@ def ctx_sidebar(
"no_target": no_target_names,
"target": None,
}
- t_cat = base.find(s, TransactionCategory, uri)
+ t_cat = base.find(TransactionCategory, uri)
t_cat_id = t_cat.id_
assigned, activity, available, leftover = categories[t_cat_id]
- tar = s.query(Target).where(Target.category_id == t_cat_id).one_or_none()
+ tar = Target.query().where(Target.category_id == t_cat_id).one_or_none()
if tar is None:
return {
"uri": uri,
@@ -1082,7 +1059,6 @@ def ctx_target(
def ctx_budget(
- s: orm.Session,
today: datetime.date,
month: datetime.date,
categories: dict[int, BudgetAvailableCategory],
@@ -1092,7 +1068,6 @@ def ctx_budget(
"""Get the context to build the budgeting table.
Args:
- s: SQL session to use
today: Today's date
month: Month of table
categories: Dict of categories from Budget.get_monthly_available
@@ -1105,13 +1080,10 @@ def ctx_budget(
"""
n_overspent = 0
- targets: dict[int, Target] = {
- t.category_id: t for t in s.query(Target).yield_per(YIELD_PER)
- }
+ targets: dict[int, Target] = {t.category_id: t for t in sql.yield_(Target.query())}
groups: dict[int | None, GroupContext] = {}
- query = s.query(BudgetGroup)
- for g in query.all():
+ for g in sql.yield_(BudgetGroup.query()):
groups[g.id_] = {
"position": g.position,
"name": g.name,
@@ -1135,8 +1107,7 @@ def ctx_budget(
"has_error": False,
}
- query = s.query(TransactionCategory)
- for t_cat in query.yield_per(YIELD_PER):
+ for t_cat in sql.yield_(TransactionCategory.query()):
assigned, activity, available, leftover = categories[t_cat.id_]
tar = targets.get(t_cat.id_)
# Skip category if all numbers are 0 and not grouped
@@ -1209,7 +1180,7 @@ def ctx_budget(
"assignable": assignable,
"groups": groups_list,
"n_overspent": n_overspent,
- "currency_format": CURRENCY_FORMATS[Config.base_currency(s)],
+ "currency_format": CURRENCY_FORMATS[Config.base_currency()],
},
title,
)
diff --git a/nummus/controllers/emergency_fund.py b/nummus/controllers/emergency_fund.py
index ce28204e..8a6942a0 100644
--- a/nummus/controllers/emergency_fund.py
+++ b/nummus/controllers/emergency_fund.py
@@ -16,8 +16,6 @@
if TYPE_CHECKING:
from decimal import Decimal
- from sqlalchemy import orm
-
from nummus.models.currency import CurrencyFormat
@@ -62,11 +60,11 @@ def page() -> flask.Response:
"""
p = web.portfolio
- with p.begin_session() as s:
+ with p.begin_session():
return base.page(
"emergency-fund/page.jinja",
"Emergency fund",
- ctx=ctx_page(s, base.today_client()),
+ ctx=ctx_page(base.today_client()),
)
@@ -78,18 +76,17 @@ def dashboard() -> str:
"""
p = web.portfolio
- with p.begin_session() as s:
+ with p.begin_session():
return flask.render_template(
"emergency-fund/dashboard.jinja",
- ctx=ctx_page(s, base.today_client()),
+ ctx=ctx_page(base.today_client()),
)
-def ctx_page(s: orm.Session, today: datetime.date) -> EFundContext:
+def ctx_page(today: datetime.date) -> EFundContext:
"""Get the context to build the emergency fund page.
Args:
- s: SQL session to use
today: Today's date
Returns:
@@ -103,7 +100,6 @@ def ctx_page(s: orm.Session, today: datetime.date) -> EFundContext:
t_lowers, t_uppers, balances, categories, categories_total = (
BudgetAssignment.get_emergency_fund(
- s,
start_ord,
today_ord,
utils.DAYS_IN_QUARTER,
@@ -144,7 +140,7 @@ def ctx_page(s: orm.Session, today: datetime.date) -> EFundContext:
key=lambda item: (-round(item["monthly"], 2), item["name"]),
)
- cf = CURRENCY_FORMATS[Config.base_currency(s)]
+ cf = CURRENCY_FORMATS[Config.base_currency()]
return {
"chart": {
"labels": [d.isoformat() for d in dates],
diff --git a/nummus/controllers/health.py b/nummus/controllers/health.py
index 107a15c6..de25a212 100644
--- a/nummus/controllers/health.py
+++ b/nummus/controllers/health.py
@@ -5,20 +5,16 @@
import datetime
import operator
from collections import defaultdict
-from typing import TYPE_CHECKING, TypedDict
+from typing import TypedDict
import flask
-from nummus import web
+from nummus import sql, web
from nummus.controllers import base
from nummus.health_checks.top import HEALTH_CHECKS
-from nummus.models.base import YIELD_PER
from nummus.models.config import Config, ConfigKey
from nummus.models.health_checks import HealthCheckIssue
-if TYPE_CHECKING:
- from sqlalchemy import orm
-
class HealthContext(TypedDict):
"""Type definition for health page context."""
@@ -45,11 +41,11 @@ def page() -> flask.Response:
"""
p = web.portfolio
- with p.begin_session() as s:
+ with p.begin_session():
return base.page(
"health/page.jinja",
title="Health",
- ctx=ctx_checks(s, run=False),
+ ctx=ctx_checks(run=False),
)
@@ -61,10 +57,10 @@ def refresh() -> str:
"""
p = web.portfolio
- with p.begin_session() as s:
+ with p.begin_session():
return flask.render_template(
"health/checks.jinja",
- ctx=ctx_checks(s, run=True),
+ ctx=ctx_checks(run=True),
include_oob=True,
)
@@ -80,12 +76,12 @@ def ignore(uri: str) -> str:
"""
p = web.portfolio
- with p.begin_session() as s:
- c = base.find(s, HealthCheckIssue, uri)
+ with p.begin_session():
+ c = base.find(HealthCheckIssue, uri)
c.ignore = True
name = c.check
- checks = ctx_checks(s, run=False)["checks"]
+ checks = ctx_checks(run=False)["checks"]
return flask.render_template(
"health/check-row.jinja",
@@ -94,11 +90,10 @@ def ignore(uri: str) -> str:
)
-def ctx_checks(s: orm.Session, *, run: bool) -> HealthContext:
+def ctx_checks(*, run: bool) -> HealthContext:
"""Get the context to build the health checks.
Args:
- s: SQL session to use
run: True will rerun health checks
Returns:
@@ -109,17 +104,17 @@ def ctx_checks(s: orm.Session, *, run: bool) -> HealthContext:
issues: dict[str, dict[str, str]] = defaultdict(dict)
if run:
- Config.set_(s, ConfigKey.LAST_HEALTH_CHECK_TS, utc_now.isoformat())
+ Config.set_(ConfigKey.LAST_HEALTH_CHECK_TS, utc_now.isoformat())
last_update = utc_now
else:
- last_update_str = Config.fetch(s, ConfigKey.LAST_HEALTH_CHECK_TS, no_raise=True)
+ last_update_str = Config.fetch(ConfigKey.LAST_HEALTH_CHECK_TS, no_raise=True)
last_update = (
None
if last_update_str is None
else datetime.datetime.fromisoformat(last_update_str)
)
- query = s.query(HealthCheckIssue).where(HealthCheckIssue.ignore.is_(False))
- for i in query.yield_per(YIELD_PER):
+ query = HealthCheckIssue.query().where(HealthCheckIssue.ignore.is_(False))
+ for i in sql.yield_(query):
issues[i.check][i.uri] = i.msg
checks: list[HealthCheckContext] = []
@@ -128,7 +123,7 @@ def ctx_checks(s: orm.Session, *, run: bool) -> HealthContext:
if run:
c = check_type()
- c.test(s)
+ c.test()
c_issues = c.issues
else:
c_issues = issues[name]
diff --git a/nummus/controllers/income.py b/nummus/controllers/income.py
index 38076bb7..ba42a2d1 100644
--- a/nummus/controllers/income.py
+++ b/nummus/controllers/income.py
@@ -17,10 +17,9 @@ def page() -> flask.Response:
"""
args = flask.request.args
p = web.portfolio
- with p.begin_session() as s:
+ with p.begin_session():
today = base.today_client()
ctx, title = spending.ctx_chart(
- s,
today,
args.get("account"),
args.get("category"),
@@ -48,10 +47,9 @@ def chart() -> flask.Response:
"""
args = flask.request.args
p = web.portfolio
- with p.begin_session() as s:
+ with p.begin_session():
today = base.today_client()
ctx, title = spending.ctx_chart(
- s,
today,
args.get("account"),
args.get("category"),
@@ -89,10 +87,9 @@ def dashboard() -> str:
"""
p = web.portfolio
- with p.begin_session() as s:
+ with p.begin_session():
today = base.today_client()
ctx, _ = spending.ctx_chart(
- s,
today,
None,
None,
diff --git a/nummus/controllers/labels.py b/nummus/controllers/labels.py
index 477b9835..110caeb2 100644
--- a/nummus/controllers/labels.py
+++ b/nummus/controllers/labels.py
@@ -2,19 +2,13 @@
from __future__ import annotations
-from typing import TYPE_CHECKING
-
import flask
from nummus import exceptions as exc
-from nummus import web
+from nummus import sql, web
from nummus.controllers import base
-from nummus.models.base import YIELD_PER
from nummus.models.label import Label, LabelLink
-if TYPE_CHECKING:
- from sqlalchemy import orm
-
def page() -> flask.Response:
"""GET /labels.
@@ -25,11 +19,11 @@ def page() -> flask.Response:
"""
p = web.portfolio
- with p.begin_session() as s:
+ with p.begin_session():
return base.page(
"labels/page.jinja",
"Labels",
- labels=ctx_labels(s),
+ labels=ctx_labels(),
)
@@ -45,7 +39,7 @@ def label(uri: str) -> str | flask.Response:
"""
p = web.portfolio
with p.begin_session() as s:
- label = base.find(s, Label, uri)
+ label = base.find(Label, uri)
if flask.request.method == "GET":
ctx: dict[str, object] = {
@@ -59,11 +53,10 @@ def label(uri: str) -> str | flask.Response:
)
if flask.request.method == "DELETE":
-
- s.query(LabelLink).where(
+ LabelLink.query().where(
LabelLink.label_id == label.id_,
).delete()
- s.delete(label)
+ label.delete()
return base.dialog_swap(
event="label",
@@ -96,11 +89,11 @@ def validation() -> str:
args = flask.request.args
uri = args["uri"]
if "name" in args:
- with p.begin_session() as s:
+ with p.begin_session():
return base.validate_string(
args["name"],
is_required=True,
- session=s,
+ cls=Label,
no_duplicates=Label.name,
no_duplicate_wheres=([Label.id_ != Label.uri_to_id(uri)]),
)
@@ -108,20 +101,15 @@ def validation() -> str:
raise NotImplementedError
-def ctx_labels(s: orm.Session) -> list[base.NamePair]:
+def ctx_labels() -> list[base.NamePair]:
"""Get the context required to build the labels table.
- Args:
- s: SQL session to use
-
Returns:
List of HTML context
"""
- query = s.query(Label).order_by(Label.name)
- return [
- base.NamePair(label.uri, label.name) for label in query.yield_per(YIELD_PER)
- ]
+ query = Label.query().order_by(Label.name)
+ return [base.NamePair(label.uri, label.name) for label in sql.yield_(query)]
ROUTES: base.Routes = {
diff --git a/nummus/controllers/net_worth.py b/nummus/controllers/net_worth.py
index fe0f42a0..901d2e3e 100644
--- a/nummus/controllers/net_worth.py
+++ b/nummus/controllers/net_worth.py
@@ -9,7 +9,7 @@
import flask
from sqlalchemy import func
-from nummus import utils, web
+from nummus import sql, utils, web
from nummus.controllers import base
from nummus.models.account import Account
from nummus.models.asset import Asset
@@ -21,7 +21,6 @@
from nummus.models.transaction import TransactionSplit
if TYPE_CHECKING:
- from sqlalchemy import orm
from nummus.controllers.base import Routes
from nummus.models.currency import Currency, CurrencyFormat
@@ -63,9 +62,8 @@ def page() -> flask.Response:
"""
args = flask.request.args
p = web.portfolio
- with p.begin_session() as s:
+ with p.begin_session():
ctx = ctx_chart(
- s,
base.today_client(),
args.get("period", base.DEFAULT_PERIOD),
)
@@ -86,8 +84,8 @@ def chart() -> flask.Response:
args = flask.request.args
period = args.get("period", base.DEFAULT_PERIOD)
p = web.portfolio
- with p.begin_session() as s:
- ctx = ctx_chart(s, base.today_client(), period)
+ with p.begin_session():
+ ctx = ctx_chart(base.today_client(), period)
html = flask.render_template(
"net-worth/chart-data.jinja",
ctx=ctx,
@@ -113,9 +111,8 @@ def dashboard() -> str:
"""
p = web.portfolio
- with p.begin_session() as s:
+ with p.begin_session():
ctx = ctx_chart(
- s,
base.today_client(),
base.DEFAULT_PERIOD,
)
@@ -126,14 +123,12 @@ def dashboard() -> str:
def ctx_chart(
- s: orm.Session,
today: datetime.date,
period: str,
) -> Context:
"""Get the context to build the net worth chart.
Args:
- s: SQL session to use
today: Today's date
period: Selected chart period
@@ -144,22 +139,22 @@ def ctx_chart(
start, end = base.parse_period(period, today)
if start is None:
- query = s.query(func.min(TransactionSplit.date_ord)).where(
+ query = TransactionSplit.query(func.min(TransactionSplit.date_ord)).where(
TransactionSplit.asset_id.is_(None),
)
- start_ord = query.scalar()
+ start_ord = sql.scalar(query)
start = datetime.date.fromordinal(start_ord) if start_ord else end
start_ord = start.toordinal()
end_ord = end.toordinal()
- query = s.query(Account)
account_currencies: dict[int, Currency] = {
- acct.id_: acct.currency for acct in query.all() if acct.do_include(start_ord)
+ acct.id_: acct.currency
+ for acct in sql.yield_(Account.query())
+ if acct.do_include(start_ord)
}
- base_currency = Config.base_currency(s)
+ base_currency = Config.base_currency()
forex = Asset.get_forex(
- s,
start_ord,
end_ord,
base_currency,
@@ -167,7 +162,6 @@ def ctx_chart(
)
acct_values, _, _ = Account.get_value_all(
- s,
start_ord,
end_ord,
ids=account_currencies.keys(),
@@ -182,7 +176,7 @@ def ctx_chart(
] or [Decimal()] * (end_ord - start_ord + 1)
data_tuple = base.chart_data(start_ord, end_ord, (total, *acct_values.values()))
- mapping = Account.map_name(s)
+ mapping = Account.map_name()
ctx_accounts: list[AccountContext] = [
{
diff --git a/nummus/controllers/performance.py b/nummus/controllers/performance.py
index 2d743e3a..e18a8c2d 100644
--- a/nummus/controllers/performance.py
+++ b/nummus/controllers/performance.py
@@ -11,7 +11,7 @@
import flask
from sqlalchemy import func
-from nummus import utils, web
+from nummus import sql, utils, web
from nummus.controllers import base
from nummus.models.account import Account, AccountCategory
from nummus.models.asset import (
@@ -24,9 +24,9 @@
CURRENCY_FORMATS,
)
from nummus.models.transaction import TransactionSplit
+from nummus.sql import yield_
if TYPE_CHECKING:
- from sqlalchemy import orm
from nummus.models.currency import Currency, CurrencyFormat
@@ -92,9 +92,8 @@ def page() -> flask.Response:
"""
args = flask.request.args
p = web.portfolio
- with p.begin_session() as s:
+ with p.begin_session():
ctx = ctx_chart(
- s,
base.today_client(),
args.get("period", base.DEFAULT_PERIOD),
args.get("index", _DEFAULT_INDEX),
@@ -119,9 +118,8 @@ def chart() -> flask.Response:
index = args.get("index", _DEFAULT_INDEX)
excluded_accounts = args.getlist("exclude")
p = web.portfolio
- with p.begin_session() as s:
+ with p.begin_session():
ctx = ctx_chart(
- s,
base.today_client(),
period,
index,
@@ -154,8 +152,8 @@ def dashboard() -> str:
"""
p = web.portfolio
- with p.begin_session() as s:
- acct_ids = Account.ids(s, AccountCategory.INVESTMENT)
+ with p.begin_session():
+ acct_ids = Account.ids(AccountCategory.INVESTMENT)
end = base.today_client()
start = end - datetime.timedelta(days=90)
start_ord = start.toordinal()
@@ -164,16 +162,15 @@ def dashboard() -> str:
indices: dict[str, Decimal] = {}
query = (
- s.query(Asset.name)
+ Asset.query(Asset.name)
.where(Asset.category == AssetCategory.INDEX)
.order_by(Asset.name)
)
- for (name,) in query.all():
- twrr = Asset.index_twrr(s, name, start_ord, end_ord)
+ for (name,) in yield_(query):
+ twrr = Asset.index_twrr(name, start_ord, end_ord)
indices[name] = twrr[-1]
acct_values, acct_profits, _ = Account.get_value_all(
- s,
start_ord,
end_ord,
ids=acct_ids,
@@ -191,7 +188,7 @@ def dashboard() -> str:
"pnl": total_profit[-1],
"twrr": twrr[-1],
"indices": indices,
- "currency_format": CURRENCY_FORMATS[Config.base_currency(s)],
+ "currency_format": CURRENCY_FORMATS[Config.base_currency()],
}
return flask.render_template(
"performance/dashboard.jinja",
@@ -200,7 +197,6 @@ def dashboard() -> str:
def ctx_chart(
- s: orm.Session,
today: datetime.date,
period: str,
index: str,
@@ -209,7 +205,6 @@ def ctx_chart(
"""Get the context to build the performance chart.
Args:
- s: SQL session to use
today: Today's date
period: Selected chart period
index: Selected index to compare against
@@ -223,34 +218,34 @@ def ctx_chart(
ctx_accounts: list[AccountContext] = []
- acct_ids = Account.ids(s, AccountCategory.INVESTMENT)
+ acct_ids = Account.ids(AccountCategory.INVESTMENT)
if start is None:
- query = s.query(func.min(TransactionSplit.date_ord)).where(
+ query = TransactionSplit.query(func.min(TransactionSplit.date_ord)).where(
TransactionSplit.asset_id.is_(None),
TransactionSplit.account_id.in_(acct_ids),
)
- start_ord = query.scalar()
+ start_ord = sql.scalar(query)
start = datetime.date.fromordinal(start_ord) if start_ord else end
start_ord = start.toordinal()
end_ord = end.toordinal()
n = end_ord - start_ord + 1
query = (
- s.query(Asset.name)
+ Asset.query(Asset.name)
.where(Asset.category == AssetCategory.INDEX)
.order_by(Asset.name)
)
- indices: list[str] = [name for (name,) in query.all()]
- index_description: str | None = (
- s.query(Asset.description).where(Asset.name == index).scalar()
+ indices: list[str] = list(sql.col0(query))
+ index_description: str | None = sql.scalar(
+ Asset.query(Asset.description).where(Asset.name == index),
)
- query = s.query(Account).where(Account.id_.in_(acct_ids))
+ query = Account.query().where(Account.id_.in_(acct_ids))
mapping: dict[int, str] = {}
currencies: dict[int, Currency] = {}
acct_ids.clear()
account_options: list[base.NamePairState] = []
- for acct in query.all():
+ for acct in sql.yield_(query):
if acct.do_include(start_ord):
excluded = acct.id_ in excluded_accounts
account_options.append(
@@ -261,9 +256,8 @@ def ctx_chart(
currencies[acct.id_] = acct.currency
acct_ids.add(acct.id_)
- base_currency = Config.base_currency(s)
+ base_currency = Config.base_currency()
forex = Asset.get_forex(
- s,
start_ord,
end_ord,
base_currency,
@@ -271,7 +265,6 @@ def ctx_chart(
)
acct_values, acct_profits, _ = Account.get_value_all(
- s,
start_ord,
end_ord,
ids=acct_ids,
@@ -287,7 +280,7 @@ def ctx_chart(
twrr = utils.twrr(total, total_profit)
mwrr = utils.mwrr(total, total_profit)
- index_twrr = Asset.index_twrr(s, index, start_ord, end_ord)
+ index_twrr = Asset.index_twrr(index, start_ord, end_ord)
sum_cash_flow = Decimal(0)
diff --git a/nummus/controllers/settings.py b/nummus/controllers/settings.py
index c97d6b15..df08276b 100644
--- a/nummus/controllers/settings.py
+++ b/nummus/controllers/settings.py
@@ -2,7 +2,7 @@
from __future__ import annotations
-from typing import TYPE_CHECKING, TypedDict
+from typing import TypedDict
import flask
@@ -11,9 +11,6 @@
from nummus.models.config import Config, ConfigKey
from nummus.models.currency import Currency
-if TYPE_CHECKING:
- from sqlalchemy import orm
-
class SettingsContext(TypedDict):
"""Type definition for settings context."""
@@ -30,11 +27,11 @@ def page() -> flask.Response:
"""
p = web.portfolio
- with p.begin_session() as s:
+ with p.begin_session():
return base.page(
"settings/page.jinja",
"Settings",
- ctx=ctx_settings(s),
+ ctx=ctx_settings(),
)
@@ -48,26 +45,23 @@ def edit() -> flask.Response:
p = web.portfolio
currency = flask.request.form.get("currency", type=Currency)
if currency:
- with p.begin_session() as s:
- Config.set_(s, ConfigKey.BASE_CURRENCY, str(currency.value))
+ with p.begin_session():
+ Config.set_(ConfigKey.BASE_CURRENCY, str(currency.value))
else:
raise NotImplementedError
return base.dialog_swap(event="config", snackbar="All changes saved")
-def ctx_settings(s: orm.Session) -> SettingsContext:
+def ctx_settings() -> SettingsContext:
"""Get the context to build the settings page.
- Args:
- s: SQL session to use
-
Returns:
SettingsContext
"""
return {
- "currency": Config.base_currency(s),
+ "currency": Config.base_currency(),
"currency_type": Currency,
}
diff --git a/nummus/controllers/spending.py b/nummus/controllers/spending.py
index b314da5b..bdd7372f 100644
--- a/nummus/controllers/spending.py
+++ b/nummus/controllers/spending.py
@@ -9,10 +9,9 @@
import flask
from sqlalchemy import func
-from nummus import utils, web
+from nummus import sql, utils, web
from nummus.controllers import base
from nummus.models.account import Account
-from nummus.models.base import YIELD_PER
from nummus.models.config import Config
from nummus.models.currency import (
Currency,
@@ -24,12 +23,10 @@
TransactionCategory,
TransactionCategoryGroup,
)
-from nummus.models.utils import query_count
if TYPE_CHECKING:
from decimal import Decimal
- import sqlalchemy
from sqlalchemy import orm
from nummus.models.currency import Currency, CurrencyFormat
@@ -55,7 +52,7 @@ class Context(OptionsContext):
start: str | None
end: str | None
by_account: list[tuple[str, Decimal]]
- by_payee: list[tuple[str, Decimal]]
+ by_payee: list[tuple[str | None, Decimal]]
by_category: list[tuple[str, Decimal]]
by_label: list[tuple[str | None, Decimal]]
currency_format: CurrencyFormat
@@ -65,7 +62,7 @@ class DataQuery(NamedTuple):
"""Type definition for result of data_query()."""
query: orm.Query[TransactionSplit]
- clauses: dict[str, sqlalchemy.ColumnElement]
+ clauses: dict[str, sql.ColumnClause]
any_filters: bool
@property
@@ -83,10 +80,9 @@ def page() -> flask.Response:
"""
args = flask.request.args
p = web.portfolio
- with p.begin_session() as s:
+ with p.begin_session():
today = base.today_client()
ctx, title = ctx_chart(
- s,
today,
args.get("account"),
args.get("category"),
@@ -114,10 +110,9 @@ def chart() -> flask.Response:
"""
args = flask.request.args
p = web.portfolio
- with p.begin_session() as s:
+ with p.begin_session():
today = base.today_client()
ctx, title = ctx_chart(
- s,
today,
args.get("account"),
args.get("category"),
@@ -155,10 +150,9 @@ def dashboard() -> str:
"""
p = web.portfolio
- with p.begin_session() as s:
+ with p.begin_session():
today = base.today_client()
ctx, _ = ctx_chart(
- s,
today,
None,
None,
@@ -177,7 +171,6 @@ def dashboard() -> str:
def data_query(
- s: orm.Session,
selected_currency: Currency,
selected_account: str | None = None,
selected_period: str | None = None,
@@ -191,7 +184,6 @@ def data_query(
"""Create transactions data query.
Args:
- s: SQL session to use
selected_currency: Currency to filter by
selected_account: URI of account from args
selected_period: Name of period from args
@@ -214,18 +206,18 @@ def data_query(
),
TransactionCategoryGroup.TRANSFER,
}
- query = s.query(TransactionCategory.id_).where(
+ query = TransactionCategory.query(TransactionCategory.id_).where(
(TransactionCategory.name == "securities traded")
| TransactionCategory.group.in_(skip_groups),
)
- skip_ids = {r[0] for r in query.yield_per(YIELD_PER)}
- query = s.query(Account.id_).where(Account.currency != selected_currency)
- skip_acct_ids = {r[0] for r in query.yield_per(YIELD_PER)}
- query = s.query(TransactionSplit).where(
+ skip_ids = set(sql.col0(query))
+ query = Account.query(Account.id_).where(Account.currency != selected_currency)
+ skip_acct_ids = set(sql.col0(query))
+ query = TransactionSplit.query().where(
TransactionSplit.category_id.not_in(skip_ids),
TransactionSplit.account_id.not_in(skip_acct_ids),
)
- clauses: dict[str, sqlalchemy.ColumnElement] = {}
+ clauses: dict[str, sql.ColumnClause] = {}
any_filters = False
@@ -264,11 +256,11 @@ def data_query(
label_id = Label.uri_to_id(selected_label)
label_query = (
query.join(LabelLink)
- .with_entities(TransactionSplit.id_)
+ .with_entities(TransactionSplit.id_) # nummus: ignore
.where(LabelLink.label_id == label_id)
.distinct()
)
- t_split_ids: set[int] = {r[0] for r in label_query.yield_per(YIELD_PER)}
+ t_split_ids: set[int] = {r[0] for r in sql.yield_(label_query)}
clauses["label"] = TransactionSplit.id_.in_(t_split_ids)
return DataQuery(query, clauses, any_filters)
@@ -315,14 +307,14 @@ def ctx_options(
clauses = dat_query.clauses.copy()
clauses.pop("account", None)
query_options = (
- query.with_entities(TransactionSplit.account_id)
+ query.with_entities(TransactionSplit.account_id) # nummus: ignore
.where(*clauses.values())
.distinct()
)
options_account = sorted(
[
base.NamePair(Account.id_to_uri(acct_id), accounts[acct_id])
- for acct_id, in query_options.yield_per(YIELD_PER)
+ for acct_id, in sql.yield_(query_options)
],
key=operator.itemgetter(0),
)
@@ -333,12 +325,12 @@ def ctx_options(
clauses = dat_query.clauses.copy()
clauses.pop("category", None)
query_options = (
- query.with_entities(TransactionSplit.category_id)
+ query.with_entities(TransactionSplit.category_id) # nummus: ignore
.where(*clauses.values())
.distinct()
)
options_uris = {
- TransactionCategory.id_to_uri(r[0]) for r in query_options.yield_per(YIELD_PER)
+ TransactionCategory.id_to_uri(r[0]) for r in sql.yield_(query_options)
}
if selected_category:
options_uris.add(selected_category)
@@ -356,14 +348,14 @@ def ctx_options(
clauses.pop("label", None)
query_options = (
query.join(LabelLink)
- .with_entities(LabelLink.label_id)
+ .with_entities(LabelLink.label_id) # nummus: ignore
.where(*clauses.values())
.distinct()
)
options_label = sorted(
[
base.NamePair(Label.id_to_uri(label_id), labels[label_id])
- for label_id, in query_options.yield_per(YIELD_PER)
+ for label_id, in sql.yield_(query_options)
],
key=operator.itemgetter(1),
)
@@ -380,7 +372,6 @@ def ctx_options(
def ctx_chart(
- s: orm.Session,
today: datetime.date,
selected_account: str | None,
selected_category: str | None,
@@ -394,7 +385,6 @@ def ctx_chart(
"""Get the context to build the chart data.
Args:
- s: SQL session to use
today: Today's date
selected_account: Selected account for filtering
selected_category: Selected category for filtering
@@ -409,13 +399,12 @@ def ctx_chart(
tuple(Context, title)
"""
- accounts = Account.map_name(s)
- base_currency = Config.base_currency(s)
- categories_emoji = TransactionCategory.map_name_emoji(s)
- labels = Label.map_name(s)
+ accounts = Account.map_name()
+ base_currency = Config.base_currency()
+ categories_emoji = TransactionCategory.map_name_emoji()
+ labels = Label.map_name()
dat_query = data_query(
- s,
base_currency,
selected_account,
selected_period,
@@ -429,7 +418,7 @@ def ctx_chart(
dat_query,
today,
accounts,
- base.tranaction_category_groups(s),
+ base.tranaction_category_groups(),
labels,
selected_account,
selected_category,
@@ -437,51 +426,51 @@ def ctx_chart(
)
final_query = dat_query.final_query
- n_matches = query_count(final_query)
+ n_matches = sql.count(final_query)
if not n_matches:
# If no matches, reset period to all
selected_period = None
dat_query.clauses.pop("start", None)
dat_query.clauses.pop("end", None)
final_query = dat_query.final_query
- n_matches = query_count(final_query)
+ n_matches = sql.count(final_query)
- query = final_query.with_entities(
+ query = final_query.with_entities( # nummus: ignore
TransactionSplit.account_id,
func.sum(TransactionSplit.amount),
).group_by(TransactionSplit.account_id)
by_account: list[tuple[str, Decimal]] = [
(accounts[account_id], amount if is_income else -amount)
- for account_id, amount in query.yield_per(YIELD_PER)
+ for account_id, amount in sql.yield_(query)
if amount
]
by_account = sorted(by_account, key=operator.itemgetter(1), reverse=True)
- query = final_query.with_entities(
+ query = final_query.with_entities( # nummus: ignore
TransactionSplit.payee,
func.sum(TransactionSplit.amount),
).group_by(TransactionSplit.payee)
- by_payee: list[tuple[str, Decimal]] = [
+ by_payee: list[tuple[str | None, Decimal]] = [
(payee, amount if is_income else -amount)
- for payee, amount in query.yield_per(YIELD_PER)
+ for payee, amount in sql.yield_(query)
if amount
]
by_payee = sorted(by_payee, key=operator.itemgetter(1), reverse=True)
- query = final_query.with_entities(
+ query = final_query.with_entities( # nummus: ignore
TransactionSplit.category_id,
func.sum(TransactionSplit.amount),
).group_by(TransactionSplit.category_id)
by_category: list[tuple[str, Decimal]] = [
(categories_emoji[cat_id], amount if is_income else -amount)
- for cat_id, amount in query.yield_per(YIELD_PER)
+ for cat_id, amount in sql.yield_(query)
if amount
]
by_category = sorted(by_category, key=operator.itemgetter(1), reverse=True)
query = (
final_query.join(LabelLink, full=True)
- .with_entities(
+ .with_entities( # nummus: ignore
LabelLink.label_id,
func.sum(TransactionSplit.amount),
)
@@ -490,11 +479,11 @@ def ctx_chart(
selected_label_id = selected_label and Label.uri_to_id(selected_label)
by_label: list[tuple[str | None, Decimal, bool]] = [
(
- label_id and labels[label_id],
+ labels.get(label_id),
amount if is_income else -amount,
- label_id and label_id == selected_label_id,
+ label_id == selected_label_id if label_id else False,
)
- for label_id, amount in query.yield_per(YIELD_PER)
+ for label_id, amount in sql.yield_(query)
if amount
]
by_label = sorted(by_label, key=operator.itemgetter(1), reverse=True)
diff --git a/nummus/controllers/transaction_categories.py b/nummus/controllers/transaction_categories.py
index 2816dee6..af1ec7c3 100644
--- a/nummus/controllers/transaction_categories.py
+++ b/nummus/controllers/transaction_categories.py
@@ -2,23 +2,16 @@
from __future__ import annotations
-from typing import TYPE_CHECKING
-
import flask
from nummus import exceptions as exc
-from nummus import utils, web
+from nummus import sql, utils, web
from nummus.controllers import base
-from nummus.models.base import YIELD_PER
from nummus.models.transaction import TransactionSplit
from nummus.models.transaction_category import (
TransactionCategory,
TransactionCategoryGroup,
)
-from nummus.models.utils import query_count
-
-if TYPE_CHECKING:
- from sqlalchemy import orm
def page() -> flask.Response:
@@ -30,11 +23,11 @@ def page() -> flask.Response:
"""
p = web.portfolio
- with p.begin_session() as s:
+ with p.begin_session():
return base.page(
"transaction-categories/page.jinja",
"Transaction categories",
- groups=ctx_categories(s),
+ groups=ctx_categories(),
)
@@ -65,8 +58,8 @@ def new() -> str | flask.Response:
try:
p = web.portfolio
- with p.begin_session() as s:
- cat = TransactionCategory(
+ with p.begin_session():
+ TransactionCategory.create(
emoji_name=name,
group=group,
locked=False,
@@ -74,7 +67,6 @@ def new() -> str | flask.Response:
asset_linked=False,
essential_spending=essential_spending,
)
- s.add(cat)
except (exc.IntegrityError, exc.InvalidORMValueError) as e:
return base.error(e)
@@ -99,7 +91,7 @@ def category(uri: str) -> str | flask.Response:
"""
p = web.portfolio
with p.begin_session() as s:
- cat = base.find(s, TransactionCategory, uri)
+ cat = base.find(TransactionCategory, uri)
if flask.request.method == "GET":
ctx: dict[str, object] = {
@@ -122,12 +114,12 @@ def category(uri: str) -> str | flask.Response:
msg = f"Locked category {cat.name} cannot be modified"
raise exc.http.Forbidden(msg)
# Move all transactions to uncategorized
- uncategorized_id, _ = TransactionCategory.uncategorized(s)
+ uncategorized_id, _ = TransactionCategory.uncategorized()
- s.query(TransactionSplit).where(
+ TransactionSplit.query().where(
TransactionSplit.category_id == cat.id_,
).update({"category_id": uncategorized_id})
- s.delete(cat)
+ cat.delete()
return base.dialog_swap(
event="category",
@@ -177,39 +169,30 @@ def validation() -> str:
return "Required"
if len(name) < utils.MIN_STR_LEN:
return f"{utils.MIN_STR_LEN} characters required"
- with p.begin_session() as s:
+ with p.begin_session():
# Only get original name if locked
- locked_name = (
- s.query(TransactionCategory.name)
- .where(
+ locked_name = sql.scalar(
+ TransactionCategory.query(TransactionCategory.name).where(
TransactionCategory.id_ == category_id,
TransactionCategory.locked,
- )
- .scalar()
+ ),
)
if locked_name and locked_name != name:
return "May only add/remove emojis"
- n = query_count(
- s.query(TransactionCategory).where(
- TransactionCategory.name == name,
- TransactionCategory.id_ != category_id,
- ),
+ query = TransactionCategory.query().where(
+ TransactionCategory.name == name,
+ TransactionCategory.id_ != category_id,
)
- if n != 0:
+ if sql.any_(query):
return "Must be unique"
return ""
raise NotImplementedError
-def ctx_categories(
- s: orm.Session,
-) -> dict[TransactionCategoryGroup, list[base.NamePair]]:
+def ctx_categories() -> dict[TransactionCategoryGroup, list[base.NamePair]]:
"""Get the context required to build the categories table.
- Args:
- s: SQL session to use
-
Returns:
List of HTML context
@@ -220,8 +203,8 @@ def ctx_categories(
TransactionCategoryGroup.TRANSFER: [],
TransactionCategoryGroup.OTHER: [],
}
- query = s.query(TransactionCategory).order_by(TransactionCategory.name)
- for cat in query.yield_per(YIELD_PER):
+ query = TransactionCategory.query().order_by(TransactionCategory.name)
+ for cat in sql.yield_(query):
cat_d = base.NamePair(cat.uri, cat.emoji_name)
if cat.group != TransactionCategoryGroup.OTHER or cat.name == "uncategorized":
groups[cat.group].append(cat_d)
diff --git a/nummus/controllers/transactions.py b/nummus/controllers/transactions.py
index 62bb4f62..42e45c47 100644
--- a/nummus/controllers/transactions.py
+++ b/nummus/controllers/transactions.py
@@ -6,17 +6,17 @@
import operator
from collections import defaultdict
from decimal import Decimal
+from itertools import starmap
from typing import NamedTuple, NotRequired, TYPE_CHECKING, TypedDict
import flask
from sqlalchemy import func
from nummus import exceptions as exc
-from nummus import utils, web
+from nummus import sql, utils, web
from nummus.controllers import base
from nummus.models.account import Account
from nummus.models.asset import Asset
-from nummus.models.base import YIELD_PER
from nummus.models.config import Config
from nummus.models.currency import (
Currency,
@@ -25,15 +25,9 @@
from nummus.models.label import Label, LabelLink
from nummus.models.transaction import Transaction, TransactionSplit
from nummus.models.transaction_category import TransactionCategory
-from nummus.models.utils import (
- obj_session,
- query_count,
- query_to_dict,
- update_rows_list,
-)
+from nummus.models.utils import update_rows_list
if TYPE_CHECKING:
- import sqlalchemy
from sqlalchemy import orm
from nummus.models.currency import Currency, CurrencyFormat
@@ -122,7 +116,7 @@ class TableQuery(NamedTuple):
"""Type definition for result of table_query()."""
query: orm.Query[TransactionSplit]
- clauses: dict[str, sqlalchemy.ColumnElement]
+ clauses: dict[str, sql.ColumnClause]
any_filters: bool
@property
@@ -130,7 +124,7 @@ def final_query(self) -> orm.Query[TransactionSplit]:
"""Build the final query with clauses."""
return self.query.where(*self.clauses.values())
- def where(self, **clauses: sqlalchemy.ColumnElement) -> TableQuery:
+ def where(self, **clauses: sql.ColumnClause) -> TableQuery:
"""Add clauses to query.
Args:
@@ -155,9 +149,8 @@ def page_all() -> flask.Response:
args = flask.request.args
p = web.portfolio
- with p.begin_session() as s:
+ with p.begin_session():
txn_table, title = ctx_table(
- s,
base.today_client(),
args.get("search"),
args.get("account"),
@@ -186,9 +179,8 @@ def table() -> str | flask.Response:
args = flask.request.args
first_page = "page" not in args
p = web.portfolio
- with p.begin_session() as s:
+ with p.begin_session():
txn_table, title = ctx_table(
- s,
base.today_client(),
args.get("search"),
args.get("account"),
@@ -229,8 +221,8 @@ def table_options() -> str:
"""
p = web.portfolio
- with p.begin_session() as s:
- accounts = Account.map_name(s)
+ with p.begin_session():
+ accounts = Account.map_name()
args = flask.request.args
uncleared = "uncleared" in args
@@ -241,7 +233,6 @@ def table_options() -> str:
selected_end = args.get("end")
tbl_query = table_query(
- s,
None,
selected_account,
selected_period,
@@ -254,7 +245,7 @@ def table_options() -> str:
tbl_query,
base.today_client(),
accounts,
- base.tranaction_category_groups(s),
+ base.tranaction_category_groups(),
selected_account,
selected_category,
)
@@ -287,25 +278,27 @@ def new() -> str | flask.Response:
with p.begin_session() as s:
query = (
- s.query(Account)
- .with_entities(Account.id_, Account.name, Account.currency)
+ Account.query(Account.id_, Account.name, Account.currency)
.where(Account.closed.is_(False))
.order_by(Account.name)
)
accounts: dict[int, tuple[str, Currency]] = {
- r[0]: (r[1], r[2]) for r in query.yield_per(YIELD_PER)
+ r[0]: (r[1], r[2]) for r in sql.yield_(query)
}
- uncategorized_id, uncategorized_uri = TransactionCategory.uncategorized(s)
+ uncategorized_id, uncategorized_uri = TransactionCategory.uncategorized()
- query = s.query(Transaction.payee)
+ query = Transaction.query(Transaction.payee).distinct()
payees = sorted(
- filter(None, (item for item, in query.distinct())),
+ filter(None, (item for item, in sql.yield_(query))),
key=lambda item: item.lower(),
)
- query = s.query(Label.name)
- labels = sorted(item for item, in query.distinct())
+ query = Label.query(Label.name).distinct()
+ labels = sorted(
+ filter(None, (item for item, in sql.yield_(query))),
+ key=lambda item: item.lower(),
+ )
acct_uri = (
flask.request.form.get("account") or flask.request.args.get("account") or ""
@@ -313,7 +306,7 @@ def new() -> str | flask.Response:
if acct_uri:
cf = CURRENCY_FORMATS[accounts[Account.uri_to_id(acct_uri)][1]]
else:
- cf = CURRENCY_FORMATS[Config.base_currency(s)]
+ cf = CURRENCY_FORMATS[Config.base_currency()]
empty_split: SplitContext = {
"parent_uri": "",
@@ -340,7 +333,7 @@ def new() -> str | flask.Response:
"statement": "Manually created",
"payee": None,
"splits": [empty_split],
- "category_groups": base.tranaction_category_groups(s),
+ "category_groups": base.tranaction_category_groups(),
"payees": payees,
"labels": labels,
"similar_uri": None,
@@ -427,7 +420,7 @@ def new() -> str | flask.Response:
return base.error(err)
s.add(txn)
s.flush()
- if err := _transaction_split_edit(s, txn):
+ if err := _transaction_split_edit(txn):
return base.error(err)
except (exc.IntegrityError, exc.InvalidORMValueError) as e:
return base.error(e)
@@ -452,7 +445,7 @@ def transaction(uri: str) -> str | flask.Response:
p = web.portfolio
today = base.today_client()
with p.begin_session() as s:
- txn = base.find(s, Transaction, uri)
+ txn = base.find(Transaction, uri)
if flask.request.method == "GET":
return flask.render_template(
@@ -461,7 +454,7 @@ def transaction(uri: str) -> str | flask.Response:
)
if flask.request.method == "PATCH":
txn.cleared = True
- s.query(TransactionSplit).where(
+ TransactionSplit.query().where(
TransactionSplit.parent_id == txn.id_,
).update({"cleared": True})
return base.dialog_swap(
@@ -472,15 +465,15 @@ def transaction(uri: str) -> str | flask.Response:
if txn.cleared:
return base.error("Cannot delete cleared transaction")
date = txn.date
- query = s.query(TransactionSplit.id_).where(
+ query = TransactionSplit.query(TransactionSplit.id_).where(
TransactionSplit.parent_id == txn.id_,
)
- t_split_ids = {r[0] for r in query.yield_per(YIELD_PER)}
- s.query(LabelLink).where(LabelLink.t_split_id.in_(t_split_ids)).delete()
- s.query(TransactionSplit).where(
+ t_split_ids = set(sql.col0(query))
+ LabelLink.query().where(LabelLink.t_split_id.in_(t_split_ids)).delete()
+ TransactionSplit.query().where(
TransactionSplit.id_.in_(t_split_ids),
).delete()
- s.delete(txn)
+ txn.delete()
return base.dialog_swap(
# update-account since transaction was deleted
event="account",
@@ -493,7 +486,7 @@ def transaction(uri: str) -> str | flask.Response:
if err := _transaction_edit(txn, today):
return base.error(err)
s.flush()
- if err := _transaction_split_edit(s, txn):
+ if err := _transaction_split_edit(txn):
return base.error(err)
except (exc.IntegrityError, exc.InvalidORMValueError) as e:
return base.error(e)
@@ -536,11 +529,10 @@ def _transaction_edit(txn: Transaction, today: datetime.date) -> str:
return ""
-def _transaction_split_edit(s: orm.Session, txn: Transaction) -> str:
+def _transaction_split_edit(txn: Transaction) -> str:
"""Edit transaction from form.
Args:
- s: SQL session to use
txn: Transaction to edit
Returns:
@@ -564,16 +556,15 @@ def _transaction_split_edit(s: orm.Session, txn: Transaction) -> str:
remaining = txn.amount - sum(filter(None, split_amounts))
if remaining != 0:
- currency = (
- s.query(Account.currency).where(Account.id_ == txn.account_id).one()[0]
- )
+ query = Account.query(Account.currency).where(Account.id_ == txn.account_id)
+ currency = sql.one(query)
cf = CURRENCY_FORMATS[currency]
if remaining < 0:
return f"Remove {cf(-remaining)} from splits"
return f"Assign {cf(remaining)} to splits"
- splits = [
+ splits: list[dict[str, object]] = [
{
"parent": txn,
"category_id": cat_id,
@@ -589,19 +580,16 @@ def _transaction_split_edit(s: orm.Session, txn: Transaction) -> str:
if amount
]
query = (
- s.query(TransactionSplit)
+ TransactionSplit.query()
.where(TransactionSplit.parent_id == txn.id_)
.order_by(TransactionSplit.id_)
)
t_split_ids = update_rows_list(
- s,
TransactionSplit,
query,
splits,
)
- s.flush()
LabelLink.add_links(
- s,
{
t_split_id: set(form.getlist(f"label-{i}"))
for i, t_split_id in enumerate(t_split_ids)
@@ -624,12 +612,13 @@ def split(uri: str) -> str:
p = web.portfolio
form = flask.request.form
- with p.begin_session() as s:
- txn = base.find(s, Transaction, uri)
+ with p.begin_session():
+ txn = base.find(Transaction, uri)
parent_amount = utils.parse_real(form["amount"]) or Decimal()
account_id = Account.uri_to_id(form["account"])
- currency = s.query(Account.currency).where(Account.id_ == account_id).one()[0]
+ query = Account.query(Account.currency).where(Account.id_ == account_id)
+ currency = sql.one(query)
payee = form["payee"]
date = utils.parse_date(form["date"])
@@ -650,7 +639,7 @@ def split(uri: str) -> str:
split_labels.append(set())
split_amounts.append(None)
- _, uncategorized_uri = TransactionCategory.uncategorized(s)
+ _, uncategorized_uri = TransactionCategory.uncategorized()
cf = CURRENCY_FORMATS[currency]
@@ -781,14 +770,14 @@ def _validate_splits() -> str:
else:
uri = args.get("account")
p = web.portfolio
- with p.begin_session() as s:
- currency = (
- s.query(Account.currency)
- .where(Account.id_ == Account.uri_to_id(uri))
- .one()[0]
- if uri
- else Config.base_currency(s)
- )
+ with p.begin_session():
+ if uri:
+ query = Account.query(Account.currency).where(
+ Account.id_ == Account.uri_to_id(uri),
+ )
+ currency = sql.one(query)
+ else:
+ currency = Config.base_currency()
cf = CURRENCY_FORMATS[currency]
msg = (
f"Assign {cf(remaining)} to splits"
@@ -805,7 +794,6 @@ def _validate_splits() -> str:
def table_query(
- s: orm.Session,
acct_uri: str | None = None,
selected_account: str | None = None,
selected_period: str | None = None,
@@ -818,7 +806,6 @@ def table_query(
"""Create transactions table query.
Args:
- s: SQL session to use
acct_uri: Account URI to filter to
selected_account: URI of account from args
selected_period: Name of period from args
@@ -832,14 +819,14 @@ def table_query(
"""
selected_account = acct_uri or selected_account
- query = s.query(TransactionSplit).order_by(
+ query = TransactionSplit.query().order_by(
TransactionSplit.date_ord.desc(),
TransactionSplit.account_id,
TransactionSplit.payee,
TransactionSplit.category_id,
TransactionSplit.memo,
)
- clauses: dict[str, sqlalchemy.ColumnElement] = {}
+ clauses: dict[str, sql.ColumnClause] = {}
any_filters = False
@@ -906,38 +893,26 @@ def ctx_txn(
Dictionary HTML context
"""
- s = obj_session(txn)
-
account_id = txn.account_id if account_id is None else account_id
- query = (
- s.query(Account)
- .with_entities(
- Account.id_,
- Account.name,
- Account.closed,
- Account.currency,
- )
- .order_by(Account.name)
- )
- accounts: dict[int, tuple[str, bool, Currency]] = {
- r[0]: (r[1], r[2], r[3]) for r in query.yield_per(YIELD_PER)
- }
- query = s.query(Asset).with_entities(Asset.id_, Asset.name, Asset.ticker)
- assets: dict[int, tuple[str, str | None]] = {
- r[0]: (r[1], r[2]) for r in query.yield_per(YIELD_PER)
- }
- query = s.query(Label.id_, Label.name)
- labels: dict[int, str] = query_to_dict(query)
+ query = Account.query(
+ Account.id_,
+ Account.name,
+ Account.closed,
+ Account.currency,
+ ).order_by(Account.name)
+ accounts: dict[int, tuple[str, bool, Currency]] = sql.to_dict_tuple(query)
+ query = Asset.query(Asset.id_, Asset.name, Asset.ticker)
+ assets = sql.to_dict_tuple(query)
+ query = Label.query(Label.id_, Label.name)
+ labels: dict[int, str] = sql.to_dict(query)
cf = CURRENCY_FORMATS[accounts[account_id][2]]
- query = (
- s.query(LabelLink)
- .with_entities(LabelLink.t_split_id, LabelLink.label_id)
- .where(LabelLink.t_split_id.in_(t_split.id_ for t_split in txn.splits))
+ query = LabelLink.query(LabelLink.t_split_id, LabelLink.label_id).where(
+ LabelLink.t_split_id.in_(t_split.id_ for t_split in txn.splits),
)
label_links: dict[int, set[int]] = defaultdict(set)
- for t_split_id, label_id in query.yield_per(YIELD_PER):
+ for t_split_id, label_id in sql.yield_(query):
label_links[t_split_id].add(label_id)
ctx_splits: list[SplitContext] = (
@@ -955,9 +930,9 @@ def ctx_txn(
)
any_asset_splits = any(split.get("asset_name") for split in ctx_splits)
- query = s.query(Transaction.payee)
+ query = Transaction.query(Transaction.payee).distinct()
payees = sorted(
- filter(None, (item for item, in query.distinct())),
+ filter(None, sql.col0(query)),
key=lambda item: item.lower(),
)
@@ -980,7 +955,7 @@ def ctx_txn(
"statement": txn.statement,
"payee": txn.payee if payee is None else payee,
"splits": ctx_splits,
- "category_groups": base.tranaction_category_groups(s),
+ "category_groups": base.tranaction_category_groups(),
"payees": payees,
"labels": sorted(labels.values()),
"similar_uri": similar_uri,
@@ -1105,14 +1080,14 @@ def ctx_options(
clauses = tbl_query.clauses.copy()
clauses.pop("account", None)
query_options = (
- query.with_entities(TransactionSplit.account_id)
+ query.with_entities(TransactionSplit.account_id) # nummus: ignore
.where(*clauses.values())
.distinct()
)
options_account = sorted(
[
base.NamePair(Account.id_to_uri(acct_id), accounts[acct_id])
- for acct_id, in query_options.yield_per(YIELD_PER)
+ for acct_id, in sql.yield_(query_options)
],
key=operator.itemgetter(0),
)
@@ -1123,13 +1098,13 @@ def ctx_options(
clauses = tbl_query.clauses.copy()
clauses.pop("category", None)
query_options = (
- query.with_entities(TransactionSplit.category_id)
+ query.with_entities(TransactionSplit.category_id) # nummus: ignore
.where(*clauses.values())
.distinct()
)
- options_uris = {
- TransactionCategory.id_to_uri(r[0]) for r in query_options.yield_per(YIELD_PER)
- }
+ options_uris = set(
+ starmap(TransactionCategory.id_to_uri, sql.yield_(query_options)),
+ )
if selected_category:
options_uris.add(selected_category)
options_category = {
@@ -1150,7 +1125,6 @@ def ctx_options(
def ctx_table(
- s: orm.Session,
today: datetime.date,
search_str: str | None,
selected_account: str | None,
@@ -1166,7 +1140,6 @@ def ctx_table(
"""Get the context to build the transaction table.
Args:
- s: SQL session to use
today: Today's date
search_str: String to search for
selected_account: Selected account for filtering
@@ -1182,24 +1155,22 @@ def ctx_table(
tuple(TableContext, title)
"""
- query = s.query(Account).with_entities(Account.id_, Account.name, Account.currency)
+ query = Account.query(Account.id_, Account.name, Account.currency)
accounts: dict[int, str] = {}
currency_formats: dict[int, CurrencyFormat] = {}
- for acct_id, name, currency in query.yield_per(YIELD_PER):
+ for acct_id, name, currency in sql.yield_(query):
accounts[acct_id] = name
currency_formats[acct_id] = CURRENCY_FORMATS[currency]
- categories_emoji = TransactionCategory.map_name_emoji(s)
+ categories_emoji = TransactionCategory.map_name_emoji()
categories = {
cat_id: TransactionCategory.clean_emoji_name(name)
for cat_id, name in categories_emoji.items()
}
- query = s.query(Asset).with_entities(Asset.id_, Asset.name, Asset.ticker)
- assets: dict[int, tuple[str, str | None]] = {
- r[0]: (r[1], r[2]) for r in query.yield_per(YIELD_PER)
- }
- labels = Label.map_name(s)
+ query = Asset.query(Asset.id_, Asset.name, Asset.ticker)
+ assets = sql.to_dict_tuple(query)
+ labels = Label.map_name()
if page_start is None:
page_start_int = None
@@ -1210,7 +1181,6 @@ def ctx_table(
page_start_int = datetime.date.fromisoformat(page_start).toordinal()
tbl_query = table_query(
- s,
acct_uri,
selected_account,
selected_period,
@@ -1223,7 +1193,7 @@ def ctx_table(
tbl_query,
today,
accounts,
- base.tranaction_category_groups(s),
+ base.tranaction_category_groups(),
selected_account,
selected_category,
)
@@ -1245,7 +1215,9 @@ def ctx_table(
t_split_order = {}
final_query = tbl_query.final_query
- query_total = final_query.with_entities(func.sum(TransactionSplit.amount))
+ query_total = final_query.with_entities( # nummus: ignore
+ func.sum(TransactionSplit.amount),
+ )
if matches is not None:
i_start = page_start_int or 0
@@ -1256,7 +1228,7 @@ def ctx_table(
# Find the fewest dates to include that will make page at least
# PAGE_LEN long
included_date_ords: set[int] = set()
- query_page_count = final_query.with_entities(
+ query_page_count = final_query.with_entities( # nummus: ignore
TransactionSplit.date_ord,
func.count(),
).group_by(TransactionSplit.date_ord)
@@ -1266,9 +1238,7 @@ def ctx_table(
)
page_count = 0
# Limit to PAGE_LEN since at most there is one txn per day
- for date_ord, count in query_page_count.limit(PAGE_LEN).yield_per(
- YIELD_PER,
- ):
+ for date_ord, count in sql.yield_(query_page_count.limit(PAGE_LEN)):
included_date_ords.add(date_ord)
page_count += count
if page_count >= PAGE_LEN:
@@ -1284,7 +1254,7 @@ def ctx_table(
else datetime.date.fromordinal(min(included_date_ords) - 1)
)
- n_matches = query_count(final_query)
+ n_matches = sql.count(final_query)
groups = _table_results(
final_query,
assets,
@@ -1306,7 +1276,7 @@ def ctx_table(
return {
"uri": acct_uri,
"transactions": groups,
- "query_total": query_total.scalar() or Decimal(),
+ "query_total": sql.scalar(query_total) or Decimal(),
"no_matches": n_matches == 0 and page_start_int is None,
"next_page": None if n_matches < PAGE_LEN else str(next_page),
"any_filters": tbl_query.any_filters,
@@ -1318,7 +1288,7 @@ def ctx_table(
"uncleared": uncleared,
"start": selected_start,
"end": selected_end,
- "currency_format": CURRENCY_FORMATS[Config.base_currency(s)],
+ "currency_format": CURRENCY_FORMATS[Config.base_currency()],
}, title
@@ -1350,19 +1320,17 @@ def _table_results(
)]
"""
- s = query.session
-
# Iterate first to get required second query
t_splits: list[TransactionSplit] = []
parent_ids: set[int] = set()
- for t_split in query.yield_per(YIELD_PER):
+ for t_split in sql.yield_(query):
t_splits.append(t_split)
parent_ids.add(t_split.parent_id)
# There are no more if there wasn't enough for a full page
query_has_splits = (
- s.query(Transaction.id_)
+ Transaction.query(Transaction.id_)
.join(TransactionSplit)
.where(
Transaction.id_.in_(parent_ids),
@@ -1370,15 +1338,13 @@ def _table_results(
.group_by(Transaction.id_)
.having(func.count() > 1)
)
- has_splits = {r[0] for r in query_has_splits.yield_per(YIELD_PER)}
+ has_splits = set(sql.col0(query_has_splits))
- query_labels = (
- s.query(LabelLink)
- .with_entities(LabelLink.t_split_id, LabelLink.label_id)
- .where(LabelLink.t_split_id.in_(t_split.id_ for t_split in t_splits))
+ query_labels = LabelLink.query(LabelLink.t_split_id, LabelLink.label_id).where(
+ LabelLink.t_split_id.in_(t_split.id_ for t_split in t_splits),
)
label_links: dict[int, set[int]] = defaultdict(set)
- for t_split_id, label_id in query_labels.yield_per(YIELD_PER):
+ for t_split_id, label_id in sql.yield_(query_labels):
label_links[t_split_id].add(label_id)
t_splits_flat: list[tuple[RowContext, int]] = []
diff --git a/nummus/encryption/top.py b/nummus/encryption/top.py
index 8c05f092..4a745629 100644
--- a/nummus/encryption/top.py
+++ b/nummus/encryption/top.py
@@ -14,7 +14,8 @@
logger.warning("Install libsqlcipher: apt install libsqlcipher-dev")
logger.warning("Install encrypt extra: pip install nummus[encrypt]")
Encryption = NoEncryption
- ENCRYPTION_AVAILABLE = False
+ encryption_available = False
else:
Encryption = EncryptionAES
- ENCRYPTION_AVAILABLE = True
+ encryption_available = True
+ENCRYPTION_AVAILABLE = encryption_available
diff --git a/nummus/exceptions.py b/nummus/exceptions.py
index b720632e..592db581 100644
--- a/nummus/exceptions.py
+++ b/nummus/exceptions.py
@@ -39,8 +39,8 @@
"MissingAssetError",
"MultipleResultsFound",
"NoAssetWebSourceError",
- "NoIDError",
"NoImporterBufferError",
+ "NoKeywordArgumentsError",
"NoResultFound",
"NoURIError",
"NonAssetTransactionError",
@@ -175,10 +175,6 @@ class ProtectedObjectNotFoundError(Exception):
"""Error when a protected object (non-deletable) could not be found."""
-class NoIDError(Exception):
- """Error when model does not have id_ yet, likely a flush is needed."""
-
-
class NoURIError(Exception):
"""Error when a URI is requested for a model without one."""
@@ -260,3 +256,7 @@ class InvalidAssetTransactionCategoryError(Exception):
class InvalidKeyError(Exception):
"""Error when a key does not meet minimum requirements."""
+
+
+class NoKeywordArgumentsError(Exception):
+ """Error when function is given kwargs when not expected."""
diff --git a/nummus/health_checks/base.py b/nummus/health_checks/base.py
index 972b5fc0..f92d08e2 100644
--- a/nummus/health_checks/base.py
+++ b/nummus/health_checks/base.py
@@ -3,16 +3,12 @@
from __future__ import annotations
from abc import ABC, abstractmethod
-from typing import ClassVar, TYPE_CHECKING
+from typing import ClassVar
-from nummus import utils
-from nummus.models.base import YIELD_PER
+from nummus import sql, utils
from nummus.models.health_checks import HealthCheckIssue
from nummus.models.utils import update_rows
-if TYPE_CHECKING:
- from sqlalchemy import orm
-
class HealthCheck(ABC):
"""Base health check class."""
@@ -24,7 +20,7 @@ def __init__(
self,
*,
no_ignores: bool = False,
- **_,
+ **_: object,
) -> None:
"""Initialize Base health check.
@@ -79,26 +75,20 @@ def is_severe(cls) -> bool:
return cls._SEVERE
@abstractmethod
- def test(self, s: orm.Session) -> None:
- """Run the health check on a portfolio.
-
- Args:
- s: SQL session to use
-
- """
+ def test(self) -> None:
+ """Run the health check on a portfolio."""
raise NotImplementedError
@classmethod
- def ignore(cls, s: orm.Session, values: list[str] | set[str]) -> None:
+ def ignore(cls, values: list[str] | set[str]) -> None:
"""Ignore false positive issues.
Args:
- s: SQL session to use
values: List of issues to ignore
"""
(
- s.query(HealthCheckIssue)
+ HealthCheckIssue.query()
.where(
HealthCheckIssue.check == cls.name(),
HealthCheckIssue.value.in_(values),
@@ -106,40 +96,36 @@ def ignore(cls, s: orm.Session, values: list[str] | set[str]) -> None:
.update({"ignore": True})
)
- def _commit_issues(self, s: orm.Session, issues: dict[str, str]) -> None:
+ def _commit_issues(self, issues: dict[str, str]) -> None:
"""Commit issues to Portfolio.
Args:
- s: SQL session to use
issues: dict{value: message}
"""
- query = s.query(HealthCheckIssue.value).where(
+ query = HealthCheckIssue.query(HealthCheckIssue.value).where(
HealthCheckIssue.check == self.name(),
HealthCheckIssue.ignore.is_(True),
)
- ignored = {r[0] for r in query.yield_per(YIELD_PER)}
+ ignored = set(sql.col0(query))
updates: dict[object, dict[str, object]] = {
value: {"check": self.name(), "ignore": value in ignored, "msg": msg}
for value, msg in issues.items()
}
- query = s.query(HealthCheckIssue).where(
+ query = HealthCheckIssue.query().where(
HealthCheckIssue.check == self.name(),
)
- update_rows(s, HealthCheckIssue, query, "value", updates)
- s.flush()
+ update_rows(HealthCheckIssue, query, "value", updates)
- query = (
- s.query(HealthCheckIssue)
- .with_entities(HealthCheckIssue.id_, HealthCheckIssue.msg)
- .where(
- HealthCheckIssue.check == self.name(),
- )
+ query = HealthCheckIssue.query(
+ HealthCheckIssue.id_,
+ HealthCheckIssue.msg,
+ ).where(
+ HealthCheckIssue.check == self.name(),
)
if not self._no_ignores:
query = query.where(HealthCheckIssue.ignore.is_(False))
self._issues = {
- HealthCheckIssue.id_to_uri(id_): msg
- for id_, msg in query.yield_per(YIELD_PER)
+ HealthCheckIssue.id_to_uri(id_): msg for id_, msg in sql.yield_(query)
}
diff --git a/nummus/health_checks/category_direction.py b/nummus/health_checks/category_direction.py
index c56f095f..d62c734f 100644
--- a/nummus/health_checks/category_direction.py
+++ b/nummus/health_checks/category_direction.py
@@ -4,25 +4,17 @@
import datetime
import textwrap
-from typing import override, TYPE_CHECKING
+from typing import override
+from nummus import sql
from nummus.health_checks.base import HealthCheck
from nummus.models.account import Account
-from nummus.models.base import YIELD_PER
from nummus.models.currency import CURRENCY_FORMATS
from nummus.models.transaction import TransactionSplit
from nummus.models.transaction_category import (
TransactionCategory,
TransactionCategoryGroup,
)
-from nummus.models.utils import query_to_dict
-
-if TYPE_CHECKING:
- from decimal import Decimal
-
- from sqlalchemy import orm
-
- from nummus.models.currency import Currency
class CategoryDirection(HealthCheck):
@@ -36,39 +28,36 @@ class CategoryDirection(HealthCheck):
_SEVERE = True
@override
- def test(self, s: orm.Session) -> None:
- query = s.query(Account).with_entities(
+ def test(self) -> None:
+ query = Account.query(
Account.id_,
Account.name,
Account.currency,
)
- accounts: dict[int, tuple[str, Currency]] = {
- r[0]: (r[1], r[2]) for r in query.yield_per(YIELD_PER)
- }
+ accounts = sql.to_dict_tuple(query)
if len(accounts) == 0:
- self._commit_issues(s, {})
+ self._commit_issues({})
return
acct_len = max(len(acct[0]) for acct in accounts.values())
issues: dict[str, str] = {}
- query = s.query(
+ query = TransactionCategory.query(
TransactionCategory.id_,
TransactionCategory.emoji_name,
).where(
TransactionCategory.group == TransactionCategoryGroup.INCOME,
)
- cat_income_ids: dict[int, str] = query_to_dict(query)
- query = s.query(
+ cat_income_ids: dict[int, str] = sql.to_dict(query)
+ query = TransactionCategory.query(
TransactionCategory.id_,
TransactionCategory.emoji_name,
).where(
TransactionCategory.group == TransactionCategoryGroup.EXPENSE,
)
- cat_expense_ids: dict[int, str] = query_to_dict(query)
+ cat_expense_ids: dict[int, str] = sql.to_dict(query)
query = (
- s.query(TransactionSplit)
- .with_entities(
+ TransactionSplit.query(
TransactionSplit.id_,
TransactionSplit.account_id,
TransactionSplit.date_ord,
@@ -82,13 +71,7 @@ def test(self, s: orm.Session) -> None:
)
.order_by(TransactionSplit.date_ord)
)
- for t_id, acct_id, date_ord, payee, amount, t_cat_id in query.yield_per(
- YIELD_PER,
- ):
- acct_id: int
- date_ord: int
- payee: str
- amount: Decimal
+ for t_id, acct_id, date_ord, payee, amount, t_cat_id in sql.yield_(query):
uri = TransactionSplit.id_to_uri(t_id)
acct_name, currency = accounts[acct_id]
@@ -104,8 +87,7 @@ def test(self, s: orm.Session) -> None:
issues[uri] = msg
query = (
- s.query(TransactionSplit)
- .with_entities(
+ TransactionSplit.query(
TransactionSplit.id_,
TransactionSplit.account_id,
TransactionSplit.date_ord,
@@ -119,13 +101,7 @@ def test(self, s: orm.Session) -> None:
)
.order_by(TransactionSplit.date_ord)
)
- for t_id, acct_id, date_ord, payee, amount, t_cat_id in query.yield_per(
- YIELD_PER,
- ):
- acct_id: int
- date_ord: int
- payee: str
- amount: Decimal
+ for t_id, acct_id, date_ord, payee, amount, t_cat_id in sql.yield_(query):
uri = TransactionSplit.id_to_uri(t_id)
acct_name, currency = accounts[acct_id]
@@ -140,4 +116,4 @@ def test(self, s: orm.Session) -> None:
)
issues[uri] = msg
- self._commit_issues(s, issues)
+ self._commit_issues(issues)
diff --git a/nummus/health_checks/database_integrity.py b/nummus/health_checks/database_integrity.py
index 9cf9ba10..f147064b 100644
--- a/nummus/health_checks/database_integrity.py
+++ b/nummus/health_checks/database_integrity.py
@@ -6,7 +6,9 @@
import sqlalchemy
+from nummus import sql
from nummus.health_checks.base import HealthCheck
+from nummus.models.base import Base
if TYPE_CHECKING:
from sqlalchemy import orm
@@ -19,11 +21,13 @@ class DatabaseIntegrity(HealthCheck):
_SEVERE = True
@override
- def test(self, s: orm.Session) -> None:
- result = s.execute(sqlalchemy.text("PRAGMA integrity_check"))
- rows = [row for row, in result.all()]
+ def test(self) -> None:
+ query: orm.query.RowReturningQuery[tuple[str]] = Base.session().execute( # type: ignore[attr-defined]
+ sqlalchemy.text("PRAGMA integrity_check"),
+ )
+ rows = list(sql.col0(query))
if len(rows) != 1 or rows[0] != "ok":
issues = {str(i): row for i, row in enumerate(rows)}
else:
issues = {}
- self._commit_issues(s, issues)
+ self._commit_issues(issues)
diff --git a/nummus/health_checks/duplicate_transactions.py b/nummus/health_checks/duplicate_transactions.py
index 353c494c..cbc50666 100644
--- a/nummus/health_checks/duplicate_transactions.py
+++ b/nummus/health_checks/duplicate_transactions.py
@@ -3,23 +3,16 @@
from __future__ import annotations
import datetime
-from typing import override, TYPE_CHECKING
+from typing import override
from sqlalchemy import func
+from nummus import sql
from nummus.health_checks.base import HealthCheck
from nummus.models.account import Account
-from nummus.models.base import YIELD_PER
from nummus.models.currency import CURRENCY_FORMATS
from nummus.models.transaction import Transaction
-if TYPE_CHECKING:
- from decimal import Decimal
-
- from sqlalchemy import orm
-
- from nummus.models.currency import Currency
-
class DuplicateTransactions(HealthCheck):
"""Checks for transactions with same amount, date, and statement."""
@@ -28,21 +21,18 @@ class DuplicateTransactions(HealthCheck):
_SEVERE = True
@override
- def test(self, s: orm.Session) -> None:
- query = s.query(Account).with_entities(
+ def test(self) -> None:
+ query = Account.query(
Account.id_,
Account.name,
Account.currency,
)
- accounts: dict[int, tuple[str, Currency]] = {
- r[0]: (r[1], r[2]) for r in query.yield_per(YIELD_PER)
- }
+ accounts = sql.to_dict_tuple(query)
issues: list[tuple[str, str, str]] = []
query = (
- s.query(Transaction)
- .with_entities(
+ Transaction.query(
Transaction.date_ord,
Transaction.account_id,
Transaction.amount,
@@ -58,10 +48,7 @@ def test(self, s: orm.Session) -> None:
.order_by(Transaction.date_ord)
.having(func.count() > 1)
)
- for date_ord, acct_id, amount in query.yield_per(YIELD_PER):
- date_ord: int
- acct_id: int
- amount: Decimal
+ for date_ord, acct_id, amount in sql.yield_(query):
amount_raw = Transaction.amount.type.process_bind_param(amount, None)
# Create a robust uri for this duplicate
uri = f"{acct_id}.{date_ord}.{amount_raw}"
@@ -81,7 +68,6 @@ def test(self, s: orm.Session) -> None:
amount_len = 0
self._commit_issues(
- s,
{
uri: f"{source:{source_len}}: {amount_str:>{amount_len}}"
for uri, source, amount_str in issues
diff --git a/nummus/health_checks/empty_fields.py b/nummus/health_checks/empty_fields.py
index 1a04b082..8fe72c9c 100644
--- a/nummus/health_checks/empty_fields.py
+++ b/nummus/health_checks/empty_fields.py
@@ -3,18 +3,15 @@
from __future__ import annotations
import datetime
-from typing import override, TYPE_CHECKING
+from typing import override
+from nummus import sql
from nummus.health_checks.base import HealthCheck
from nummus.models.account import Account
from nummus.models.asset import Asset
-from nummus.models.base import YIELD_PER
from nummus.models.transaction import Transaction, TransactionSplit
from nummus.models.transaction_category import TransactionCategory
-if TYPE_CHECKING:
- from sqlalchemy import orm
-
class EmptyFields(HealthCheck):
"""Checks for empty fields that are better when populated."""
@@ -23,53 +20,36 @@ class EmptyFields(HealthCheck):
_SEVERE = False
@override
- def test(self, s: orm.Session) -> None:
- accounts = Account.map_name(s)
+ def test(self) -> None:
+ accounts = Account.map_name()
# List of (uri, source, field)
issues: list[tuple[str, str, str]] = []
- query = (
- s.query(Account)
- .with_entities(Account.id_, Account.name)
- .where(Account.number.is_(None))
- )
- for acct_id, name in query.yield_per(YIELD_PER):
- acct_id: int
- name: str
+ query = Account.query(Account.id_, Account.name).where(Account.number.is_(None))
+ for acct_id, name in sql.yield_(query):
uri = Account.id_to_uri(acct_id)
issues.append(
(f"{uri}.number", f"Account {name}", "has an empty number"),
)
- query = (
- s.query(Asset)
- .with_entities(Asset.id_, Asset.name)
- .where(Asset.description.is_(None))
+ query = Asset.query(Asset.id_, Asset.name).where(
+ Asset.description.is_(None),
)
- for a_id, name in query.yield_per(YIELD_PER):
- a_id: int
- name: str
+ for a_id, name in sql.yield_(query):
uri = Asset.id_to_uri(a_id)
issues.append(
(f"{uri}.description", f"Asset {name}", "has an empty description"),
)
- query = (
- s.query(Transaction)
- .with_entities(
- Transaction.id_,
- Transaction.date_ord,
- Transaction.account_id,
- )
- .where(
- Transaction.payee.is_(None),
- )
+ query = Transaction.query(
+ Transaction.id_,
+ Transaction.date_ord,
+ Transaction.account_id,
+ ).where(
+ Transaction.payee.is_(None),
)
- for t_id, date_ord, acct_id in query.yield_per(YIELD_PER):
- t_id: int
- date_ord: int
- acct_id: int
+ for t_id, date_ord, acct_id in sql.yield_(query):
uri = Transaction.id_to_uri(t_id)
date = datetime.date.fromordinal(date_ord)
@@ -78,20 +58,13 @@ def test(self, s: orm.Session) -> None:
(f"{uri}.payee", source, "has an empty payee"),
)
- uncategorized_id, _ = TransactionCategory.uncategorized(s)
- query = (
- s.query(TransactionSplit)
- .with_entities(
- TransactionSplit.id_,
- TransactionSplit.date_ord,
- TransactionSplit.account_id,
- )
- .where(TransactionSplit.category_id == uncategorized_id)
- )
- for t_id, date_ord, acct_id in query.yield_per(YIELD_PER):
- t_id: int
- date_ord: int
- acct_id: int
+ uncategorized_id, _ = TransactionCategory.uncategorized()
+ query = TransactionSplit.query(
+ TransactionSplit.id_,
+ TransactionSplit.date_ord,
+ TransactionSplit.account_id,
+ ).where(TransactionSplit.category_id == uncategorized_id)
+ for t_id, date_ord, acct_id in sql.yield_(query):
uri = TransactionSplit.id_to_uri(t_id)
date = datetime.date.fromordinal(date_ord)
@@ -101,6 +74,5 @@ def test(self, s: orm.Session) -> None:
source_len = max(len(item[1]) for item in issues) if issues else 0
self._commit_issues(
- s,
{uri: f"{source:{source_len}} {field}" for uri, source, field in issues},
)
diff --git a/nummus/health_checks/missing_asset_link.py b/nummus/health_checks/missing_asset_link.py
index f58330e7..2a5f48d6 100644
--- a/nummus/health_checks/missing_asset_link.py
+++ b/nummus/health_checks/missing_asset_link.py
@@ -5,9 +5,9 @@
import datetime
from typing import override, TYPE_CHECKING
+from nummus import sql
from nummus.health_checks.base import HealthCheck
from nummus.models.account import Account
-from nummus.models.base import YIELD_PER
from nummus.models.currency import CURRENCY_FORMATS
from nummus.models.transaction import TransactionSplit
from nummus.models.transaction_category import TransactionCategory
@@ -15,10 +15,6 @@
if TYPE_CHECKING:
from decimal import Decimal
- from sqlalchemy import orm
-
- from nummus.models.currency import Currency
-
class MissingAssetLink(HealthCheck):
"""Checks for transactions that should be linked to an asset that aren't."""
@@ -27,50 +23,39 @@ class MissingAssetLink(HealthCheck):
_SEVERE = False
@override
- def test(self, s: orm.Session) -> None:
- query = s.query(Account).with_entities(
+ def test(self) -> None:
+ query = Account.query(
Account.id_,
Account.name,
Account.currency,
)
- accounts: dict[int, tuple[str, Currency]] = {
- r[0]: (r[1], r[2]) for r in query.yield_per(YIELD_PER)
- }
+ accounts = sql.to_dict_tuple(query)
if len(accounts) == 0:
- self._commit_issues(s, {})
+ self._commit_issues({})
return
acct_len = max(len(acct[0]) for acct in accounts.values())
issues: dict[str, str] = {}
- categories = TransactionCategory.map_name_emoji(s)
+ categories = TransactionCategory.map_name_emoji()
# These categories should be linked to an asset
- query = s.query(TransactionCategory.id_).where(
+ query = TransactionCategory.query(TransactionCategory.id_).where(
TransactionCategory.asset_linked.is_(True),
)
- categories_assets_id = {r for r, in query.all()}
+ categories_assets_id = set(sql.col0(query))
# Get transactions in these categories that do not have an asset
- query = (
- s.query(TransactionSplit)
- .with_entities(
- TransactionSplit.id_,
- TransactionSplit.date_ord,
- TransactionSplit.account_id,
- TransactionSplit.category_id,
- TransactionSplit.amount,
- )
- .where(
- TransactionSplit.category_id.in_(categories_assets_id),
- TransactionSplit.asset_id.is_(None),
- )
+ query = TransactionSplit.query(
+ TransactionSplit.id_,
+ TransactionSplit.date_ord,
+ TransactionSplit.account_id,
+ TransactionSplit.category_id,
+ TransactionSplit.amount,
+ ).where(
+ TransactionSplit.category_id.in_(categories_assets_id),
+ TransactionSplit.asset_id.is_(None),
)
- for t_id, date_ord, acct_id, cat_id, amount in query.yield_per(YIELD_PER):
- t_id: int
- date_ord: int
- acct_id: int
- cat_id: int
- amount: Decimal
+ for t_id, date_ord, acct_id, cat_id, amount in sql.yield_(query):
uri = TransactionSplit.id_to_uri(t_id)
acct_name, currency = accounts[acct_id]
@@ -85,21 +70,17 @@ def test(self, s: orm.Session) -> None:
issues[uri] = msg
# Get transactions not in these categories that do have an asset
- query = (
- s.query(TransactionSplit)
- .with_entities(
- TransactionSplit.id_,
- TransactionSplit.date_ord,
- TransactionSplit.account_id,
- TransactionSplit.category_id,
- TransactionSplit.amount,
- )
- .where(
- TransactionSplit.category_id.not_in(categories_assets_id),
- TransactionSplit.asset_id.is_not(None),
- )
+ query = TransactionSplit.query(
+ TransactionSplit.id_,
+ TransactionSplit.date_ord,
+ TransactionSplit.account_id,
+ TransactionSplit.category_id,
+ TransactionSplit.amount,
+ ).where(
+ TransactionSplit.category_id.not_in(categories_assets_id),
+ TransactionSplit.asset_id.is_not(None),
)
- for t_id, date_ord, acct_id, cat_id, amount in query.yield_per(YIELD_PER):
+ for t_id, date_ord, acct_id, cat_id, amount in sql.yield_(query):
t_id: int
date_ord: int
acct_id: int
@@ -118,4 +99,4 @@ def test(self, s: orm.Session) -> None:
)
issues[uri] = msg
- self._commit_issues(s, issues)
+ self._commit_issues(issues)
diff --git a/nummus/health_checks/missing_valuations.py b/nummus/health_checks/missing_valuations.py
index 2131ec49..46933d4a 100644
--- a/nummus/health_checks/missing_valuations.py
+++ b/nummus/health_checks/missing_valuations.py
@@ -3,20 +3,17 @@
from __future__ import annotations
import datetime
-from typing import override, TYPE_CHECKING
+from typing import override
from sqlalchemy import func
+from nummus import sql
from nummus.health_checks.base import HealthCheck
from nummus.models.asset import (
Asset,
AssetValuation,
)
from nummus.models.transaction import TransactionSplit
-from nummus.models.utils import query_to_dict
-
-if TYPE_CHECKING:
- from sqlalchemy import orm
class MissingAssetValuations(HealthCheck):
@@ -26,30 +23,25 @@ class MissingAssetValuations(HealthCheck):
_SEVERE = True
@override
- def test(self, s: orm.Session) -> None:
- assets = Asset.map_name(s)
+ def test(self) -> None:
+ assets = Asset.map_name()
issues: dict[str, str] = {}
query = (
- s.query(TransactionSplit)
- .with_entities(
+ TransactionSplit.query(
TransactionSplit.asset_id,
func.min(TransactionSplit.date_ord),
)
.where(TransactionSplit.asset_id.isnot(None))
.group_by(TransactionSplit.asset_id)
)
- first_date_ords: dict[int | None, int] = query_to_dict(query)
+ first_date_ords: dict[int | None, int] = sql.to_dict(query)
- query = (
- s.query(AssetValuation)
- .with_entities(
- AssetValuation.asset_id,
- func.min(AssetValuation.date_ord),
- )
- .group_by(AssetValuation.asset_id)
- )
- first_valuations: dict[int, int] = query_to_dict(query)
+ query = AssetValuation.query(
+ AssetValuation.asset_id,
+ func.min(AssetValuation.date_ord),
+ ).group_by(AssetValuation.asset_id)
+ first_valuations: dict[int, int] = sql.to_dict(query)
for a_id, date_ord in first_date_ords.items():
if a_id is None:
@@ -68,4 +60,4 @@ def test(self, s: orm.Session) -> None:
)
issues[uri] = msg
- self._commit_issues(s, issues)
+ self._commit_issues(issues)
diff --git a/nummus/health_checks/outlier_asset_price.py b/nummus/health_checks/outlier_asset_price.py
index eecb7e57..3f9a2363 100644
--- a/nummus/health_checks/outlier_asset_price.py
+++ b/nummus/health_checks/outlier_asset_price.py
@@ -9,17 +9,14 @@
from sqlalchemy import func
-from nummus import utils
+from nummus import sql, utils
from nummus.health_checks.base import HealthCheck
from nummus.models.account import Account
from nummus.models.asset import Asset
-from nummus.models.base import YIELD_PER
from nummus.models.currency import CURRENCY_FORMATS
from nummus.models.transaction import TransactionSplit
-from nummus.models.utils import query_to_dict
if TYPE_CHECKING:
- from sqlalchemy import orm
from nummus.models.currency import Currency
@@ -38,39 +35,37 @@ class OutlierAssetPrice(HealthCheck):
_RANGE = Decimal("0.4")
@override
- def test(self, s: orm.Session) -> None:
+ def test(self) -> None:
today = datetime.datetime.now(datetime.UTC).date()
today_ord = today.toordinal()
- start_ord = (
- s.query(func.min(TransactionSplit.date_ord))
- .where(TransactionSplit.asset_id.isnot(None))
- .scalar()
+ start_ord = sql.scalar(
+ TransactionSplit.query(func.min(TransactionSplit.date_ord)).where(
+ TransactionSplit.asset_id.isnot(None),
+ ),
)
if start_ord is None:
# No asset transactions at all
- self._commit_issues(s, {})
+ self._commit_issues({})
return
# List of (uri, source, field)
issues: list[tuple[str, str, str]] = []
- assets = Asset.map_name(s)
+ assets = Asset.map_name()
asset_valuations = Asset.get_value_all(
- s,
start_ord,
today_ord + utils.DAYS_IN_WEEK,
)
- query = s.query(Account).with_entities(
+ query = Account.query(
Account.id_,
Account.currency,
)
- accounts: dict[int, Currency] = query_to_dict(query)
+ accounts: dict[int, Currency] = sql.to_dict(query)
query = (
- s.query(TransactionSplit)
- .with_entities(
+ TransactionSplit.query(
TransactionSplit.id_,
TransactionSplit.account_id,
TransactionSplit.date_ord,
@@ -84,19 +79,14 @@ def test(self, s: orm.Session) -> None:
)
.where(TransactionSplit.asset_id.isnot(None))
)
- for t_id, acct_id, date_ord, a_id, amount, qty in query.yield_per(
- YIELD_PER,
- ):
- t_id: int
- acct_id: int
- date_ord: int
- a_id: int
- amount: Decimal
- qty: Decimal
+ for t_id, acct_id, date_ord, a_id, amount, qty in sql.yield_(query):
uri = TransactionSplit.id_to_uri(t_id)
- if qty == 0:
+ if not qty:
continue
+ if TYPE_CHECKING:
+ # Enforced by query and SQL constraints
+ assert a_id is not None
# Transaction asset price
t_price = -amount / qty
@@ -131,6 +121,5 @@ def test(self, s: orm.Session) -> None:
source_len = max(len(item[1]) for item in issues) if issues else 0
self._commit_issues(
- s,
{uri: f"{source:{source_len}} {field}" for uri, source, field in issues},
)
diff --git a/nummus/health_checks/overdrawn_accounts.py b/nummus/health_checks/overdrawn_accounts.py
index e6533bc1..86dc7881 100644
--- a/nummus/health_checks/overdrawn_accounts.py
+++ b/nummus/health_checks/overdrawn_accounts.py
@@ -4,21 +4,16 @@
import datetime
from decimal import Decimal
-from typing import override, TYPE_CHECKING
+from typing import override
from sqlalchemy import func
+from nummus import sql
from nummus.health_checks.base import HealthCheck
from nummus.models.account import Account, AccountCategory
-from nummus.models.base import YIELD_PER
from nummus.models.currency import CURRENCY_FORMATS
from nummus.models.transaction import TransactionSplit
-if TYPE_CHECKING:
- from sqlalchemy import orm
-
- from nummus.models.currency import Currency
-
class OverdrawnAccounts(HealthCheck):
"""Checks for accounts that had a negative cash balance when they shouldn't."""
@@ -27,38 +22,30 @@ class OverdrawnAccounts(HealthCheck):
_SEVERE = True
@override
- def test(self, s: orm.Session) -> None:
+ def test(self) -> None:
# Get a list of accounts subject to overdrawn so not credit and loans
categories_exclude = [
AccountCategory.CREDIT,
AccountCategory.LOAN,
AccountCategory.MORTGAGE,
]
- query = (
- s.query(Account)
- .with_entities(Account.id_, Account.name, Account.currency)
- .where(Account.category.not_in(categories_exclude))
+ query = Account.query(Account.id_, Account.name, Account.currency).where(
+ Account.category.not_in(categories_exclude),
)
- accounts: dict[int, tuple[str, Currency]] = {
- r[0]: (r[1], r[2]) for r in query.yield_per(YIELD_PER)
- }
+ accounts = sql.to_dict_tuple(query)
acct_ids = set(accounts)
issues: list[tuple[str, str, str]] = []
- start_ord, end_ord = (
- s.query(
+ start_ord, end_ord = sql.one(
+ TransactionSplit.query(
func.min(TransactionSplit.date_ord),
func.max(TransactionSplit.date_ord),
- )
- .where(TransactionSplit.account_id.in_(acct_ids))
- .one()
+ ).where(TransactionSplit.account_id.in_(acct_ids)),
)
- start_ord: int | None
- end_ord: int | None
- if start_ord is None or end_ord is None:
- # No asset transactions at all
- self._commit_issues(s, {})
+ if start_ord is None or end_ord is None: # type: ignore[attr-defined]
+ # No transactions at all
+ self._commit_issues({})
return
n = end_ord - start_ord + 1
@@ -66,20 +53,13 @@ def test(self, s: orm.Session) -> None:
cf = CURRENCY_FORMATS[currency]
# Get cash holdings across all time
cash_flow: list[Decimal | None] = [None] * n
- query = (
- s.query(TransactionSplit)
- .with_entities(
- TransactionSplit.date_ord,
- TransactionSplit.amount,
- )
- .where(
- TransactionSplit.account_id == acct_id,
- )
+ query = TransactionSplit.query(
+ TransactionSplit.date_ord,
+ TransactionSplit.amount,
+ ).where(
+ TransactionSplit.account_id == acct_id,
)
- for date_ord, amount in query.yield_per(YIELD_PER):
- date_ord: int
- amount: Decimal
-
+ for date_ord, amount in sql.yield_(query):
i = date_ord - start_ord
v = cash_flow[i]
@@ -111,7 +91,6 @@ def test(self, s: orm.Session) -> None:
amount_len = 0
self._commit_issues(
- s,
{
uri: f"{source:{source_len}}: {amount_str:>{amount_len}}"
for uri, source, amount_str in issues
diff --git a/nummus/health_checks/typos.py b/nummus/health_checks/typos.py
index 30083751..e180fc19 100644
--- a/nummus/health_checks/typos.py
+++ b/nummus/health_checks/typos.py
@@ -9,23 +9,18 @@
from typing import override, TYPE_CHECKING
import spellchecker
-from sqlalchemy import func, orm
+from sqlalchemy import func
-from nummus import utils
+from nummus import sql, utils
from nummus.health_checks.base import HealthCheck
from nummus.models.account import Account
from nummus.models.asset import (
Asset,
AssetCategory,
)
-from nummus.models.base import YIELD_PER
from nummus.models.label import Label
from nummus.models.transaction import TransactionSplit
-if TYPE_CHECKING:
- from sqlalchemy import orm
-
-
_LIMIT_FREQUENCY = 10
@@ -60,18 +55,18 @@ def __init__(
self._proper_nouns: set[str] = set()
@override
- def test(self, s: orm.Session) -> None:
+ def test(self) -> None:
spell = spellchecker.SpellChecker()
- accounts = Account.map_name(s)
- assets = Asset.map_name(s)
+ accounts = Account.map_name()
+ assets = Asset.map_name()
issues: dict[str, tuple[str, str, str]] = {}
self._proper_nouns.update(accounts.values())
self._proper_nouns.update(assets.values())
- issues.update(self._test_accounts(s, accounts))
- issues.update(self._test_labels(s))
- issues.update(self._test_transaction_nouns(s, accounts))
+ issues.update(self._test_accounts(accounts))
+ issues.update(self._test_labels())
+ issues.update(self._test_transaction_nouns(accounts))
# Escape words and sort to replace longest words first
# So long words aren't partially replaced if they contain a short word
@@ -82,8 +77,8 @@ def test(self, s: orm.Session) -> None:
# Remove proper nouns indicated by word boundary or space at end
re_cleaner = re.compile(rf"\b(?:{'|'.join(proper_nouns_re)})(?:\b|(?= |$))")
- issues.update(self._test_transaction_texts(s, accounts, re_cleaner, spell))
- issues.update(self._test_assets(s, assets, re_cleaner, spell))
+ issues.update(self._test_transaction_texts(accounts, re_cleaner, spell))
+ issues.update(self._test_assets(assets, re_cleaner, spell))
source_len = 0
field_len = 0
@@ -95,7 +90,6 @@ def test(self, s: orm.Session) -> None:
# Getting a suggested correction is slow and error prone,
# Just say if a word is outside of the dictionary
self._commit_issues(
- s,
{
uri: f"{source:{source_len}} {field:{field_len}}: {word}"
for uri, (word, source, field) in issues.items()
@@ -138,29 +132,22 @@ def _create_issues(self) -> dict[str, tuple[str, str, str]]:
def _test_accounts(
self,
- s: orm.Session,
accounts: dict[int, str],
) -> dict[str, tuple[str, str, str]]:
- query = s.query(Account).with_entities(
+ query = Account.query(
Account.id_,
Account.institution,
)
- for acct_id, institution in query.yield_per(YIELD_PER):
- acct_id: int
- institution: str
+ for acct_id, institution in sql.yield_(query):
name = accounts[acct_id]
source = f"Account {name}"
self._add(institution, source, "institution", 1)
self._proper_nouns.add(institution)
return self._create_issues()
- def _test_labels(
- self,
- s: orm.Session,
- ) -> dict[str, tuple[str, str, str]]:
- query = s.query(Label.name)
- for (name,) in query.yield_per(YIELD_PER):
- name: str
+ def _test_labels(self) -> dict[str, tuple[str, str, str]]:
+ query = Label.query(Label.name)
+ for name in sql.col0(query):
source = f"Label {name}"
self._add(name, source, "name", 1)
self._proper_nouns.add(name)
@@ -168,7 +155,6 @@ def _test_labels(
def _test_transaction_nouns(
self,
- s: orm.Session,
accounts: dict[int, str],
) -> dict[str, tuple[str, str, str]]:
issues: dict[str, tuple[str, str, str]] = {}
@@ -177,8 +163,7 @@ def _test_transaction_nouns(
]
for field in txn_fields:
query = (
- s.query(TransactionSplit)
- .with_entities(
+ TransactionSplit.query(
TransactionSplit.date_ord,
TransactionSplit.account_id,
field,
@@ -187,10 +172,10 @@ def _test_transaction_nouns(
.where(field.is_not(None))
.group_by(field)
)
- for date_ord, acct_id, value, count in query.yield_per(YIELD_PER):
- date_ord: int
- acct_id: int
- value: str
+ for date_ord, acct_id, value, count in sql.yield_(query):
+ if TYPE_CHECKING:
+ # Enforced by query and SQL constraints
+ assert value is not None
date = datetime.date.fromordinal(date_ord)
source = f"{date} - {accounts[acct_id]}"
self._add(value, source, field.key, count)
@@ -200,27 +185,24 @@ def _test_transaction_nouns(
def _test_transaction_texts(
self,
- s: orm.Session,
accounts: dict[int, str],
- re_cleaner: re.Pattern,
+ re_cleaner: re.Pattern[str],
spell: spellchecker.SpellChecker,
) -> dict[str, tuple[str, str, str]]:
query = (
- s.query(TransactionSplit)
- .with_entities(
+ TransactionSplit.query(
TransactionSplit.date_ord,
TransactionSplit.account_id,
TransactionSplit.memo,
func.count(),
)
.group_by(TransactionSplit.memo)
+ .where(TransactionSplit.memo.is_not(None))
)
- for date_ord, acct_id, value, count in query.yield_per(YIELD_PER):
- date_ord: int
- acct_id: int
- value: str | None
- if value is None:
- continue
+ for date_ord, acct_id, value, count in sql.yield_(query):
+ if TYPE_CHECKING:
+ # Enforced by query and SQL constraints
+ assert value is not None
date = datetime.date.fromordinal(date_ord)
source = f"{date} - {accounts[acct_id]}"
cleaned = re_cleaner.sub("", value).lower()
@@ -237,25 +219,21 @@ def _test_transaction_texts(
def _test_assets(
self,
- s: orm.Session,
assets: dict[int, str],
- re_cleaner: re.Pattern,
+ re_cleaner: re.Pattern[str],
spell: spellchecker.SpellChecker,
) -> dict[str, tuple[str, str, str]]:
- query = (
- s.query(Asset)
- .with_entities(
- Asset.id_,
- Asset.description,
- )
- .where(
- Asset.category != AssetCategory.INDEX,
- Asset.description.is_not(None),
- )
+ query = Asset.query(
+ Asset.id_,
+ Asset.description,
+ ).where(
+ Asset.category != AssetCategory.INDEX,
+ Asset.description.is_not(None),
)
- for a_id, value in query.yield_per(YIELD_PER):
- a_id: int
- value: str
+ for a_id, value in sql.yield_(query):
+ if TYPE_CHECKING:
+ # Enforced by query and SQL constraints
+ assert value is not None
source = f"Asset {assets[a_id]}"
cleaned = re_cleaner.sub("", value).lower()
for word in self._RE_WORDS.split(cleaned):
diff --git a/nummus/health_checks/unbalanced_transfers.py b/nummus/health_checks/unbalanced_transfers.py
index b0f60e78..136cec9e 100644
--- a/nummus/health_checks/unbalanced_transfers.py
+++ b/nummus/health_checks/unbalanced_transfers.py
@@ -9,19 +9,17 @@
from decimal import Decimal
from typing import override, TYPE_CHECKING
+from nummus import sql
from nummus.health_checks.base import HealthCheck
from nummus.models.account import Account
-from nummus.models.base import YIELD_PER
from nummus.models.currency import CURRENCY_FORMATS
from nummus.models.transaction import TransactionSplit
from nummus.models.transaction_category import (
TransactionCategory,
TransactionCategoryGroup,
)
-from nummus.models.utils import query_to_dict
if TYPE_CHECKING:
- from sqlalchemy import orm
from nummus.models.currency import Currency
@@ -37,28 +35,25 @@ class UnbalancedTransfers(HealthCheck):
_SEVERE = True
@override
- def test(self, s: orm.Session) -> None:
+ def test(self) -> None:
issues: dict[str, str] = {}
- query = s.query(
+ query = TransactionCategory.query(
TransactionCategory.id_,
TransactionCategory.emoji_name,
).where(
TransactionCategory.group == TransactionCategoryGroup.TRANSFER,
)
- cat_transfers_ids: dict[int, str] = query_to_dict(query)
+ cat_transfers_ids = sql.to_dict(query)
- query = s.query(Account).with_entities(
+ query = Account.query(
Account.id_,
Account.name,
Account.currency,
)
- accounts: dict[int, tuple[str, Currency]] = {
- r[0]: (r[1], r[2]) for r in query.yield_per(YIELD_PER)
- }
+ accounts = sql.to_dict_tuple(query)
query = (
- s.query(TransactionSplit)
- .with_entities(
+ TransactionSplit.query(
TransactionSplit.account_id,
TransactionSplit.date_ord,
TransactionSplit.amount,
@@ -68,12 +63,9 @@ def test(self, s: orm.Session) -> None:
.order_by(TransactionSplit.date_ord, TransactionSplit.amount)
)
current_date_ord: int | None = None
- total = defaultdict(Decimal)
+ total: dict[int, Decimal] = defaultdict(Decimal)
current_splits: dict[int, list[tuple[int, Decimal]]] = defaultdict(list)
- for acct_id, date_ord, amount, t_cat_id in query.yield_per(YIELD_PER):
- acct_id: int
- date_ord: int
- amount: Decimal
+ for acct_id, date_ord, amount, t_cat_id in sql.yield_(query):
if current_date_ord is None:
current_date_ord = date_ord
if date_ord != current_date_ord:
@@ -101,7 +93,7 @@ def test(self, s: orm.Session) -> None:
)
issues[uri] = msg
- self._commit_issues(s, issues)
+ self._commit_issues(issues)
@classmethod
def _create_issue(
diff --git a/nummus/health_checks/uncleared_transactions.py b/nummus/health_checks/uncleared_transactions.py
index 45c39462..cc955f28 100644
--- a/nummus/health_checks/uncleared_transactions.py
+++ b/nummus/health_checks/uncleared_transactions.py
@@ -4,21 +4,14 @@
import datetime
import textwrap
-from typing import override, TYPE_CHECKING
+from typing import override
+from nummus import sql
from nummus.health_checks.base import HealthCheck
from nummus.models.account import Account
-from nummus.models.base import YIELD_PER
from nummus.models.currency import CURRENCY_FORMATS
from nummus.models.transaction import TransactionSplit
-if TYPE_CHECKING:
- from decimal import Decimal
-
- from sqlalchemy import orm
-
- from nummus.models.currency import Currency
-
class UnclearedTransactions(HealthCheck):
"""Checks for uncleared transactions."""
@@ -31,38 +24,27 @@ class UnclearedTransactions(HealthCheck):
_SEVERE = False
@override
- def test(self, s: orm.Session) -> None:
- query = s.query(Account).with_entities(
+ def test(self) -> None:
+ query = Account.query(
Account.id_,
Account.name,
Account.currency,
)
- accounts: dict[int, tuple[str, Currency]] = {
- r[0]: (r[1], r[2]) for r in query.yield_per(YIELD_PER)
- }
+ accounts = sql.to_dict_tuple(query)
if len(accounts) == 0:
- self._commit_issues(s, {})
+ self._commit_issues({})
return
acct_len = max(len(acct[0]) for acct in accounts.values())
issues: dict[str, str] = {}
- query = (
- s.query(TransactionSplit)
- .with_entities(
- TransactionSplit.id_,
- TransactionSplit.date_ord,
- TransactionSplit.account_id,
- TransactionSplit.payee,
- TransactionSplit.amount,
- )
- .where(TransactionSplit.cleared.is_(False))
- )
- for t_id, date_ord, acct_id, payee, amount in query.yield_per(YIELD_PER):
- t_id: int
- date_ord: int
- acct_id: int
- payee: str
- amount: Decimal
+ query = TransactionSplit.query(
+ TransactionSplit.id_,
+ TransactionSplit.date_ord,
+ TransactionSplit.account_id,
+ TransactionSplit.payee,
+ TransactionSplit.amount,
+ ).where(TransactionSplit.cleared.is_(False))
+ for t_id, date_ord, acct_id, payee, amount in sql.yield_(query):
uri = TransactionSplit.id_to_uri(t_id)
acct_name, currency = accounts[acct_id]
@@ -76,4 +58,4 @@ def test(self, s: orm.Session) -> None:
)
issues[uri] = msg
- self._commit_issues(s, issues)
+ self._commit_issues(issues)
diff --git a/nummus/health_checks/unnecessary_slits.py b/nummus/health_checks/unnecessary_slits.py
index 55f73725..2ae9f634 100644
--- a/nummus/health_checks/unnecessary_slits.py
+++ b/nummus/health_checks/unnecessary_slits.py
@@ -3,19 +3,16 @@
from __future__ import annotations
import datetime
-from typing import NamedTuple, override, TYPE_CHECKING
+from typing import NamedTuple, override
from sqlalchemy import func
+from nummus import sql
from nummus.health_checks.base import HealthCheck
from nummus.models.account import Account
-from nummus.models.base import YIELD_PER
from nummus.models.transaction import TransactionSplit
from nummus.models.transaction_category import TransactionCategory
-if TYPE_CHECKING:
- from sqlalchemy import orm
-
class RawIssue(NamedTuple):
"""Type definition for a raw issue."""
@@ -33,15 +30,14 @@ class UnnecessarySplits(HealthCheck):
_SEVERE = False
@override
- def test(self, s: orm.Session) -> None:
- accounts = Account.map_name(s)
- categories = TransactionCategory.map_name_emoji(s)
+ def test(self) -> None:
+ accounts = Account.map_name()
+ categories = TransactionCategory.map_name_emoji()
issues: list[RawIssue] = []
query = (
- s.query(TransactionSplit)
- .with_entities(
+ TransactionSplit.query(
TransactionSplit.date_ord,
TransactionSplit.account_id,
TransactionSplit.parent_id,
@@ -55,14 +51,7 @@ def test(self, s: orm.Session) -> None:
.order_by(TransactionSplit.date_ord)
.having(func.count() > 1)
)
- for date_ord, acct_id, t_id, payee, t_cat_id in query.yield_per(
- YIELD_PER,
- ):
- date_ord: int
- acct_id: int
- t_id: int
- payee: str | None
- t_cat_id: int
+ for date_ord, acct_id, t_id, payee, t_cat_id in sql.yield_(query):
# Create a robust uri for this duplicate
uri = f"{t_id}.{payee}.{t_cat_id}"
@@ -82,7 +71,6 @@ def test(self, s: orm.Session) -> None:
t_cat_len = 0
self._commit_issues(
- s,
{
issue.uri: (
f"{issue.source:{source_len}}: "
diff --git a/nummus/health_checks/unused_categories.py b/nummus/health_checks/unused_categories.py
index 4d45017d..92e83eb5 100644
--- a/nummus/health_checks/unused_categories.py
+++ b/nummus/health_checks/unused_categories.py
@@ -2,16 +2,13 @@
from __future__ import annotations
-from typing import override, TYPE_CHECKING
+from typing import override
+from nummus import sql
from nummus.health_checks.base import HealthCheck
from nummus.models.budget import BudgetAssignment
from nummus.models.transaction import TransactionSplit
from nummus.models.transaction_category import TransactionCategory
-from nummus.models.utils import query_to_dict
-
-if TYPE_CHECKING:
- from sqlalchemy import orm
class UnusedCategories(HealthCheck):
@@ -21,23 +18,22 @@ class UnusedCategories(HealthCheck):
_SEVERE = False
@override
- def test(self, s: orm.Session) -> None:
+ def test(self) -> None:
# Only check unlocked categories
- query = (
- s.query(TransactionCategory)
- .with_entities(TransactionCategory.id_, TransactionCategory.emoji_name)
- .where(TransactionCategory.locked.is_(False))
- )
- categories: dict[int, str] = query_to_dict(query)
+ query = TransactionCategory.query(
+ TransactionCategory.id_,
+ TransactionCategory.emoji_name,
+ ).where(TransactionCategory.locked.is_(False))
+ categories: dict[int, str] = sql.to_dict(query)
if len(categories) == 0:
- self._commit_issues(s, {})
+ self._commit_issues({})
return
- query = s.query(TransactionSplit.category_id)
- used_categories = {r[0] for r in query.distinct()}
+ query = TransactionSplit.query(TransactionSplit.category_id)
+ used_categories = set(sql.col0(query))
- query = s.query(BudgetAssignment.category_id)
- used_categories.update(r[0] for r in query.distinct())
+ query = BudgetAssignment.query(BudgetAssignment.category_id)
+ used_categories.update(sql.col0(query))
categories = {
t_cat_id: name
@@ -49,7 +45,6 @@ def test(self, s: orm.Session) -> None:
)
self._commit_issues(
- s,
{
TransactionCategory.id_to_uri(t_cat_id): (
f"{name:{category_len}} has no "
diff --git a/nummus/migrations/base.py b/nummus/migrations/base.py
index 1992d53f..ced1eecf 100644
--- a/nummus/migrations/base.py
+++ b/nummus/migrations/base.py
@@ -12,13 +12,11 @@
from sqlalchemy.schema import CreateTable
from nummus import sql
+from nummus.models.base import Base
from nummus.models.utils import dump_table_configs, get_constraints
if TYPE_CHECKING:
- from sqlalchemy import orm
-
from nummus import portfolio
- from nummus.models.base import Base
class Migrator(ABC):
@@ -55,20 +53,19 @@ def min_version(cls) -> Version:
def add_column(
self,
- s: orm.Session,
model: type[Base],
- column: orm.QueryableAttribute,
+ column: sql.Column,
initial_value: object | None = None,
) -> None:
"""Add a column to a table.
Args:
- s: SQL session to use
model: Table to modify
column: Column to add
initial_value: Value to set all rows to
"""
+ s = model.session()
engine = s.get_bind().engine
col_name = sql.escape(column.name)
@@ -77,13 +74,12 @@ def add_column(
s.execute(sqlalchemy.text(stmt))
if initial_value is not None:
- s.query(model).update({column: initial_value})
+ model.query().update({column: initial_value})
self.pending_schema_updates.add(model)
def rename_column(
self,
- s: orm.Session,
model: type[Base],
old_name: str,
new_name: str,
@@ -91,7 +87,6 @@ def rename_column(
"""Rename a column in a table.
Args:
- s: SQL session to use
model: Table to modify
old_name: Current name of column
new_name: New name of column
@@ -100,7 +95,7 @@ def rename_column(
old_name = sql.escape(old_name)
new_name = sql.escape(new_name)
stmt = f'ALTER TABLE "{model.__tablename__}" RENAME {old_name} TO {new_name}'
- s.execute(sqlalchemy.text(stmt))
+ model.session().execute(sqlalchemy.text(stmt))
# RENAME modifies column references but not constraint names
# Need to update schema to update those
@@ -108,32 +103,29 @@ def rename_column(
def drop_column(
self,
- s: orm.Session,
model: type[Base],
col_name: str,
) -> None:
"""Rename a column in a table.
Args:
- s: SQL session to use
model: Table to modify
col_name: Name of column to drop
"""
- constraints = get_constraints(s, model)
+ constraints = get_constraints(model)
if any(col_name in sql_text for _, sql_text in constraints):
- self.recreate_table(s, model, drop={col_name})
+ self.recreate_table(model, drop={col_name})
else:
# Able to drop directly
col_name = sql.escape(col_name)
stmt = f'ALTER TABLE "{model.__tablename__}" DROP {col_name}'
- s.execute(sqlalchemy.text(stmt))
+ model.session().execute(sqlalchemy.text(stmt))
# DROP does not need updated schema
def recreate_table(
self,
- s: orm.Session,
model: type[Base],
*,
drop: set[str] | None = None,
@@ -142,13 +134,13 @@ def recreate_table(
"""Rebuild table, optionally dropping columns.
Args:
- s: SQL session to use
model: Table to modify
drop: Set of column names to drop
create_stmt: Statement to execute to create new table,
None will modify existing config
"""
+ s = model.session()
drop = drop or set()
# In SQLite we can do the hacky way or recreate the table
# Opt for recreate
@@ -159,7 +151,7 @@ def recreate_table(
new_config = create_stmt.splitlines()
else:
# Edit table config, dropping any columns
- old_config = dump_table_configs(s, model)
+ old_config = dump_table_configs(model)
new_config: list[str] = []
re_column = re.compile(r" +([a-z_]+) [A-Z ]+")
re_constraint = re.compile(r' +[A-Z ]+(?:"[^\"]+" [A-Z ]+)?\(([^\)]+)\)')
@@ -191,7 +183,7 @@ def recreate_table(
s.execute(sqlalchemy.text(stmt))
# Drop old table
- self.drop_table(s, name)
+ self.drop_table(name)
# Rename new into old
stmt = f'ALTER TABLE "migration_temp" RENAME TO "{name}"'
@@ -204,16 +196,15 @@ def recreate_table(
self.pending_schema_updates.add(model)
@staticmethod
- def drop_table(s: orm.Session, table_name: str) -> None:
+ def drop_table(table_name: str) -> None:
"""Drop a table.
Args:
- s: SQL session to use
table_name: Name of table to drop
"""
stmt = f'DROP TABLE "{table_name}"'
- s.execute(sqlalchemy.text(stmt))
+ Base.session().execute(sqlalchemy.text(stmt))
class SchemaMigrator(Migrator):
@@ -235,5 +226,5 @@ def migrate(self, p: portfolio.Portfolio) -> list[str]:
with p.begin_session() as s:
table: sqlalchemy.Table = model.sql_table()
create_stmt = CreateTable(table).compile(s.get_bind()).string.strip()
- self.recreate_table(s, model, create_stmt=create_stmt)
+ self.recreate_table(model, create_stmt=create_stmt)
return []
diff --git a/nummus/migrations/v0_10.py b/nummus/migrations/v0_10.py
index 4e7bc61b..c0734c76 100644
--- a/nummus/migrations/v0_10.py
+++ b/nummus/migrations/v0_10.py
@@ -5,7 +5,6 @@
from typing import override, TYPE_CHECKING
from nummus.migrations.base import Migrator
-from nummus.models.base import YIELD_PER
from nummus.models.health_checks import HealthCheckIssue
from nummus.models.transaction_category import TransactionCategory
@@ -23,16 +22,14 @@ def migrate(self, p: portfolio.Portfolio) -> list[str]:
comments: list[str] = []
- with p.begin_session() as s:
+ with p.begin_session():
self.rename_column(
- s,
TransactionCategory,
"essential",
"essential_spending",
)
- query = s.query(HealthCheckIssue)
- for issue in query.yield_per(YIELD_PER):
+ for issue in HealthCheckIssue.all():
issue.check = issue.check.capitalize()
return comments
diff --git a/nummus/migrations/v0_13.py b/nummus/migrations/v0_13.py
index 3c7c8f18..3bb0a348 100644
--- a/nummus/migrations/v0_13.py
+++ b/nummus/migrations/v0_13.py
@@ -40,15 +40,13 @@ def migrate(self, p: portfolio.Portfolio) -> list[str]:
for t_split_id, name in s.execute(sqlalchemy.text(stmt)):
tag_mapping[name].add(t_split_id)
- labels = [Label(name=name) for name in tag_mapping]
- s.add_all(labels)
- s.flush()
+ labels = [Label.create(name=name) for name in tag_mapping]
for label in labels:
for t_split_id in tag_mapping[label.name]:
- s.add(LabelLink(label_id=label.id_, t_split_id=t_split_id))
+ LabelLink.create(label_id=label.id_, t_split_id=t_split_id)
- with p.begin_session() as s:
- self.drop_column(s, TransactionSplit, "tag")
+ with p.begin_session():
+ self.drop_column(TransactionSplit, "tag")
return comments
diff --git a/nummus/migrations/v0_15.py b/nummus/migrations/v0_15.py
index 616eabb4..a719e8b8 100644
--- a/nummus/migrations/v0_15.py
+++ b/nummus/migrations/v0_15.py
@@ -7,12 +7,15 @@
import sqlalchemy
from nummus import exceptions as exc
+from nummus import sql
from nummus.migrations.base import Migrator
from nummus.models.base import Base
from nummus.models.label import Label, LabelLink
from nummus.models.utils import dump_table_configs
if TYPE_CHECKING:
+ from sqlalchemy import orm
+
from nummus import portfolio
@@ -30,7 +33,7 @@ def migrate(self, p: portfolio.Portfolio) -> list[str]:
with p.begin_session() as s:
# Already have Label from updated v0.13 migrator, skip this one
try:
- dump_table_configs(s, Label)
+ dump_table_configs(Label)
except exc.NoResultFound:
pass
else:
@@ -43,19 +46,20 @@ def migrate(self, p: portfolio.Portfolio) -> list[str]:
# Move existing tags to labels
with p.begin_session() as s:
stmt = "SELECT id_, name FROM tag"
- # Hand crafted SQL statement can't use query_to_dict
- tags: dict[int, str] = dict(s.execute(sqlalchemy.text(stmt)).all()) # type: ignore[attr-defined]
+ query: orm.query.RowReturningQuery[tuple[int, str]] = s.execute( # type: ignore[attr-defined]
+ sqlalchemy.text(stmt),
+ )
+ tags: dict[int, str] = sql.to_dict(query)
- labels = [Label(id_=tag_id, name=name) for tag_id, name in tags.items()]
- s.add_all(labels)
- s.flush()
+ for tag_id, name in tags.items():
+ Label.create(id_=tag_id, name=name)
stmt = "SELECT tag_id, t_split_id FROM tag_link"
for tag_id, t_split_id in s.execute(sqlalchemy.text(stmt)):
- s.add(LabelLink(label_id=tag_id, t_split_id=t_split_id))
+ LabelLink.create(label_id=tag_id, t_split_id=t_split_id)
- with p.begin_session() as s:
- self.drop_table(s, "tag_link")
- self.drop_table(s, "tag")
+ with p.begin_session():
+ self.drop_table("tag_link")
+ self.drop_table("tag")
return comments
diff --git a/nummus/migrations/v0_16.py b/nummus/migrations/v0_16.py
index 09b76d25..76cc6077 100644
--- a/nummus/migrations/v0_16.py
+++ b/nummus/migrations/v0_16.py
@@ -27,12 +27,13 @@ def migrate(self, p: portfolio.Portfolio) -> list[str]:
f"Portfolio currency set to {DEFAULT_CURRENCY.pretty}, use web to edit",
]
- with p.begin_session() as s:
- s.add(
- Config(key=ConfigKey.BASE_CURRENCY, value=str(DEFAULT_CURRENCY.value)),
+ with p.begin_session():
+ Config.create(
+ key=ConfigKey.BASE_CURRENCY,
+ value=str(DEFAULT_CURRENCY.value),
)
- self.add_column(s, Account, Account.currency, DEFAULT_CURRENCY)
- self.add_column(s, Asset, Asset.currency, DEFAULT_CURRENCY)
+ self.add_column(Account, Account.currency, DEFAULT_CURRENCY)
+ self.add_column(Asset, Asset.currency, DEFAULT_CURRENCY)
return comments
diff --git a/nummus/migrations/v0_2.py b/nummus/migrations/v0_2.py
index 8bb11716..284046bb 100644
--- a/nummus/migrations/v0_2.py
+++ b/nummus/migrations/v0_2.py
@@ -4,12 +4,12 @@
from typing import override, TYPE_CHECKING
-from sqlalchemy import func, sql
+from sqlalchemy import func
+from nummus import sql
from nummus.migrations.base import Migrator
from nummus.models.account import Account
from nummus.models.asset import Asset
-from nummus.models.base import YIELD_PER
from nummus.models.budget import BudgetGroup
from nummus.models.config import Config
from nummus.models.health_checks import HealthCheckIssue
@@ -32,29 +32,29 @@ def migrate(self, p: portfolio.Portfolio) -> list[str]:
comments: list[str] = []
- with p.begin_session() as s:
+ with p.begin_session():
# Update TransactionSplit to add text_fields
- self.add_column(s, TransactionSplit, TransactionSplit.text_fields)
- self.rename_column(s, TransactionSplit, "description", "memo")
- self.rename_column(s, TransactionSplit, "linked", "cleared")
- self.drop_column(s, TransactionSplit, "locked")
+ self.add_column(TransactionSplit, TransactionSplit.text_fields)
+ self.rename_column(TransactionSplit, "description", "memo")
+ self.rename_column(TransactionSplit, "linked", "cleared")
+ self.drop_column(TransactionSplit, "locked")
- with p.begin_session() as s:
+ with p.begin_session():
# Update Transaction to add payee
- self.add_column(s, Transaction, Transaction.payee)
- self.rename_column(s, Transaction, "linked", "cleared")
- self.drop_column(s, Transaction, "locked")
+ self.add_column(Transaction, Transaction.payee)
+ self.rename_column(Transaction, "linked", "cleared")
+ self.drop_column(Transaction, "locked")
- with p.begin_session() as s:
+ with p.begin_session():
# Check which ones have more than one payee
- accounts = Account.map_name(s)
+ accounts = Account.map_name()
query = (
- s.query(TransactionSplit)
+ TransactionSplit.query()
.group_by(TransactionSplit.parent_id)
.having(func.count(TransactionSplit.payee.distinct()) > 1)
.order_by(TransactionSplit.date_ord)
)
- for t_split in query.yield_per(YIELD_PER):
+ for t_split in sql.yield_(query):
msg = (
"This transaction had multiple payees, only one allowed: "
f"{t_split.date} {accounts[t_split.account_id]}, please validate"
@@ -62,21 +62,20 @@ def migrate(self, p: portfolio.Portfolio) -> list[str]:
comments.append(msg)
sub_query = (
- s.query(TransactionSplit.payee)
+ TransactionSplit.query(TransactionSplit.payee)
.where(
TransactionSplit.parent_id == Transaction.id_,
)
.scalar_subquery()
)
- s.query(Transaction).update(
+ Transaction.query().update(
{Transaction.payee: sub_query},
)
- with p.begin_session() as s:
+ with p.begin_session():
n_batch = 100
# Update text_fields after payee is set
- query = s.query(TransactionSplit)
- for t_split in query.yield_per(YIELD_PER):
+ for t_split in TransactionSplit.all():
t_split.parent = t_split.parent
t_split.memo = t_split.memo
@@ -85,8 +84,7 @@ def migrate(self, p: portfolio.Portfolio) -> list[str]:
offset = 0
while has_more:
query = (
- s.query(TransactionCategory)
- .with_entities(
+ TransactionCategory.query(
TransactionCategory.id_,
TransactionCategory.emoji_name,
)
@@ -96,9 +94,9 @@ def migrate(self, p: portfolio.Portfolio) -> list[str]:
)
values = {
id_: TransactionCategory.clean_emoji_name(v)
- for id_, v in query.yield_per(YIELD_PER)
+ for id_, v in sql.yield_(query)
}
- s.query(TransactionCategory).where(
+ TransactionCategory.query().where(
TransactionCategory.id_.in_(values),
).update(
{
diff --git a/nummus/models/account.py b/nummus/models/account.py
index c62c1eea..af167511 100644
--- a/nummus/models/account.py
+++ b/nummus/models/account.py
@@ -8,7 +8,7 @@
from sqlalchemy import func, orm, UniqueConstraint
-from nummus import utils
+from nummus import sql, utils
from nummus.models.asset import Asset
from nummus.models.base import (
Base,
@@ -18,12 +18,10 @@
ORMStrOpt,
SQLEnum,
string_column_args,
- YIELD_PER,
)
from nummus.models.currency import Currency
from nummus.models.transaction import Transaction, TransactionSplit
from nummus.models.transaction_category import TransactionCategory
-from nummus.models.utils import obj_session
if TYPE_CHECKING:
from collections.abc import Iterable
@@ -92,6 +90,8 @@ class Account(Base):
*string_column_args("institution"),
)
+ _SEARCH_PROPERTIES = ("number", "institution", "name")
+
@orm.validates("name", "number", "institution")
def validate_strings(self, key: str, field: str | None) -> str | None:
"""Validate string fields satisfy constraints.
@@ -109,25 +109,22 @@ def validate_strings(self, key: str, field: str | None) -> str | None:
@property
def opened_on_ord(self) -> int | None:
"""Date ordinal of first Transaction."""
- s = obj_session(self)
- query = s.query(func.min(Transaction.date_ord)).where(
+ query = Transaction.query(func.min(Transaction.date_ord)).where(
Transaction.account_id == self.id_,
)
- return query.scalar()
+ return sql.scalar(query)
@property
def updated_on_ord(self) -> int | None:
"""Date ordinal of latest Transaction."""
- s = obj_session(self)
- query = s.query(func.max(Transaction.date_ord)).where(
+ query = Transaction.query(func.max(Transaction.date_ord)).where(
Transaction.account_id == self.id_,
)
- return query.scalar()
+ return sql.scalar(query)
@classmethod
def get_value_all(
cls,
- s: orm.Session,
start_ord: int,
end_ord: int,
ids: Iterable[int] | None = None,
@@ -136,7 +133,6 @@ def get_value_all(
"""Get the value of all Accounts from start to end date.
Args:
- s: SQL session to use
start_ord: First date ordinal to evaluate
end_ord: Last date ordinal to evaluate (inclusive)
ids: Limit results to specific Accounts by ID
@@ -151,10 +147,11 @@ def get_value_all(
n = end_ord - start_ord + 1
if not ids and ids is not None:
- acct_values = defaultdict(lambda: [Decimal()] * n)
- acct_profit = defaultdict(lambda: [Decimal()] * n)
- asset_values = defaultdict(lambda: [Decimal()] * n)
- return ValueResultAll(acct_values, acct_profit, asset_values)
+ return ValueResultAll(
+ defaultdict(lambda: [Decimal()] * n),
+ defaultdict(lambda: [Decimal()] * n),
+ defaultdict(lambda: [Decimal()] * n),
+ )
cash_flow_accounts: dict[int, list[Decimal | None]] = defaultdict(
lambda: [None] * n,
@@ -162,21 +159,20 @@ def get_value_all(
cost_basis_accounts: dict[int, list[Decimal | None]] = defaultdict(
lambda: [None] * n,
)
- ids = ids or {r[0] for r in s.query(Account.id_).all()}
+ ids = ids or set(sql.col0(Account.query(Account.id_)))
# Profit = Interest + dividends + rewards + change in asset value - fees
# Dividends, fees, and change in value can be assigned to an asset
# Change in value = current value - basis
# Get list of transaction categories not included in cost basis
- query = s.query(TransactionCategory.id_).where(
+ query = TransactionCategory.query(TransactionCategory.id_).where(
TransactionCategory.is_profit_loss.is_(True),
)
- cost_basis_skip_ids = {t_cat_id for t_cat_id, in query.all()}
+ cost_basis_skip_ids = set(sql.col0(query))
# Get Account cash value on start date
query = (
- s.query(TransactionSplit)
- .with_entities(
+ TransactionSplit.query(
TransactionSplit.account_id,
func.sum(TransactionSplit.amount),
)
@@ -186,15 +182,12 @@ def get_value_all(
)
.group_by(TransactionSplit.account_id)
)
- for acct_id, iv in query.all():
- acct_id: int
- iv: Decimal
+ for acct_id, iv in sql.yield_(query):
cash_flow_accounts[acct_id][0] = iv
# Calculate cost basis on first day
query = (
- s.query(TransactionSplit)
- .with_entities(
+ TransactionSplit.query(
TransactionSplit.account_id,
func.sum(TransactionSplit.amount),
)
@@ -205,36 +198,24 @@ def get_value_all(
)
.group_by(TransactionSplit.account_id)
)
- for acct_id, iv in query.all():
- acct_id: int
- iv: Decimal
+ for acct_id, iv in sql.yield_(query):
cost_basis_accounts[acct_id][0] = -iv
if start_ord != end_ord:
# Get cash_flow on each day between start and end
# Not Account.get_cash_flow because being categorized doesn't matter and
# slows it down
- query = (
- s.query(TransactionSplit)
- .with_entities(
- TransactionSplit.account_id,
- TransactionSplit.date_ord,
- TransactionSplit.amount,
- TransactionSplit.category_id,
- )
- .where(
- TransactionSplit.date_ord <= end_ord,
- TransactionSplit.date_ord > start_ord,
- TransactionSplit.account_id.in_(ids),
- )
+ query = TransactionSplit.query(
+ TransactionSplit.account_id,
+ TransactionSplit.date_ord,
+ TransactionSplit.amount,
+ TransactionSplit.category_id,
+ ).where(
+ TransactionSplit.date_ord <= end_ord,
+ TransactionSplit.date_ord > start_ord,
+ TransactionSplit.account_id.in_(ids),
)
-
- for acct_id, date_ord, amount, t_cat_id in query.yield_per(YIELD_PER):
- acct_id: int
- date_ord: int
- amount: Decimal
- t_cat_id: int
-
+ for acct_id, date_ord, amount, t_cat_id in sql.yield_(query):
i = date_ord - start_ord
v = cash_flow_accounts[acct_id][i]
@@ -248,33 +229,29 @@ def get_value_all(
# Get assets for all Accounts
assets_accounts = cls.get_asset_qty_all(
- s,
start_ord,
end_ord,
list(cash_flow_accounts.keys()),
)
# Get day one asset transactions to add to profit & loss
- query = (
- s.query(TransactionSplit)
- .with_entities(
- TransactionSplit.account_id,
- TransactionSplit.asset_id,
- TransactionSplit.asset_quantity,
- )
- .where(
- TransactionSplit.asset_id.isnot(None),
- TransactionSplit.date_ord == start_ord,
- TransactionSplit.account_id.in_(ids),
- )
+ query = TransactionSplit.query(
+ TransactionSplit.account_id,
+ TransactionSplit.asset_id,
+ TransactionSplit.asset_quantity,
+ ).where(
+ TransactionSplit.asset_id.isnot(None),
+ TransactionSplit.date_ord == start_ord,
+ TransactionSplit.account_id.in_(ids),
)
assets_day_zero: dict[int, dict[int, Decimal]] = defaultdict(
lambda: defaultdict(Decimal),
)
- for acct_id, a_id, qty in query.yield_per(YIELD_PER):
- acct_id: int
- a_id: int
- qty: Decimal
+ for acct_id, a_id, qty in sql.yield_(query):
+ if TYPE_CHECKING:
+ # Enforced by query and SQL constraints
+ assert a_id is not None
+ assert qty is not None
assets_day_zero[acct_id][a_id] += qty
# Remove zeros
@@ -295,17 +272,15 @@ def get_value_all(
a_ids: set[int] = utils.set_sub_keys(assets_accounts)
a_ids.update(utils.set_sub_keys(assets_day_zero))
- asset_prices = Asset.get_value_all(s, start_ord, end_ord, a_ids)
+ asset_prices = Asset.get_value_all(start_ord, end_ord, a_ids)
forex_by_account: dict[int, list[Decimal]] | None = None
if forex is not None:
- query = (
- s.query(Account)
- .with_entities(Account.id_, Account.currency)
- .where(Account.id_.in_(ids))
+ query = Account.query(Account.id_, Account.currency).where(
+ Account.id_.in_(ids),
)
forex_by_account = {
- acct_id: forex[currency] for acct_id, currency in query.all()
+ acct_id: forex[currency] for acct_id, currency in sql.yield_(query)
}
return cls._merge_value_data(
@@ -384,23 +359,23 @@ def get_value(
self,
start_ord: int,
end_ord: int,
+ forex: dict[Currency, list[Decimal]] | None = None,
) -> ValueResult:
"""Get the value of Account from start to end date.
Args:
start_ord: First date ordinal to evaluate
end_ord: Last date ordinal to evaluate (inclusive)
+ forex: Currency exchange rates, None will not normalize
Returns:
ValueResult
"""
- s = obj_session(self)
-
# Not reusing get_value_all is faster by ~2ms,
# not worth maintaining two almost identical implementations
- r = self.get_value_all(s, start_ord, end_ord, [self.id_])
+ r = self.get_value_all(start_ord, end_ord, [self.id_], forex=forex)
return ValueResult(
r.values_by_account[self.id_],
r.profits[self.id_],
@@ -410,7 +385,6 @@ def get_value(
@classmethod
def get_cash_flow_all(
cls,
- s: orm.Session,
start_ord: int,
end_ord: int,
ids: Iterable[int] | None = None,
@@ -420,7 +394,6 @@ def get_cash_flow_all(
Does not separate results by account.
Args:
- s: SQL session to use
start_ord: First date ordinal to evaluate
end_ord: Last date ordinal to evaluate (inclusive)
ids: Limit results to specific Accounts by ID
@@ -435,26 +408,18 @@ def get_cash_flow_all(
categories: dict[int, list[Decimal]] = defaultdict(lambda: [Decimal()] * n)
# Transactions between start and end
- query = (
- s.query(TransactionSplit)
- .with_entities(
- TransactionSplit.date_ord,
- TransactionSplit.amount,
- TransactionSplit.category_id,
- )
- .where(
- TransactionSplit.date_ord <= end_ord,
- TransactionSplit.date_ord >= start_ord,
- )
+ query = TransactionSplit.query(
+ TransactionSplit.date_ord,
+ TransactionSplit.amount,
+ TransactionSplit.category_id,
+ ).where(
+ TransactionSplit.date_ord <= end_ord,
+ TransactionSplit.date_ord >= start_ord,
)
if ids is not None:
query = query.where(TransactionSplit.account_id.in_(ids))
- for t_date_ord, amount, category_id in query.yield_per(YIELD_PER):
- t_date_ord: int
- amount: Decimal
- category_id: int
-
+ for t_date_ord, amount, category_id in sql.yield_(query):
categories[category_id][t_date_ord - start_ord] += amount
return categories
@@ -478,13 +443,11 @@ def get_cash_flow(
Includes None in categories
"""
- s = obj_session(self)
- return self.get_cash_flow_all(s, start_ord, end_ord, [self.id_])
+ return self.get_cash_flow_all(start_ord, end_ord, [self.id_])
@classmethod
def get_asset_qty_all(
cls,
- s: orm.Session,
start_ord: int,
end_ord: int,
ids: Iterable[int] | None = None,
@@ -492,7 +455,6 @@ def get_asset_qty_all(
"""Get the quantity of Assets held from start to end date.
Args:
- s: SQL session to use
start_ord: First date ordinal to evaluate
end_ord: Last date ordinal to evaluate (inclusive)
ids: Limit results to specific Accounts by ID
@@ -509,13 +471,15 @@ def get_asset_qty_all(
lambda: defaultdict(lambda: [Decimal()] * n),
)
- iv_accounts: dict[int, dict[int, Decimal]] = defaultdict(dict)
- ids = ids or {r[0] for r in s.query(Account.id_).all()}
+ # Daily delta in qty
+ deltas_accounts: dict[int, dict[int, list[Decimal | None]]] = defaultdict(
+ lambda: defaultdict(lambda: [None] * n),
+ )
+ ids = ids or set(sql.col0(Account.query(Account.id_)))
# Get Asset quantities on start date
query = (
- s.query(TransactionSplit)
- .with_entities(
+ TransactionSplit.query(
TransactionSplit.account_id,
TransactionSplit.asset_id,
func.sum(TransactionSplit.asset_quantity),
@@ -530,27 +494,17 @@ def get_asset_qty_all(
TransactionSplit.asset_id,
)
)
-
- for acct_id, a_id, qty in query.yield_per(YIELD_PER):
- acct_id: int
- a_id: int
- qty: Decimal
- iv_accounts[acct_id][a_id] = qty
-
- # Daily delta in qty
- deltas_accounts: dict[int, dict[int, list[Decimal | None]]] = defaultdict(
- lambda: defaultdict(lambda: [None] * n),
- )
- for acct_id, iv in iv_accounts.items():
- deltas = deltas_accounts[acct_id]
- for a_id, v in iv.items():
- deltas[a_id][0] = v
+ for acct_id, a_id, qty in sql.yield_(query):
+ if TYPE_CHECKING:
+ # Enforced by query and SQL constraints
+ assert a_id is not None
+ assert qty is not None
+ deltas_accounts[acct_id][a_id][0] = qty
if start_ord != end_ord:
# Transactions between start and end
query = (
- s.query(TransactionSplit)
- .with_entities(
+ TransactionSplit.query(
TransactionSplit.date_ord,
TransactionSplit.account_id,
TransactionSplit.asset_id,
@@ -565,15 +519,14 @@ def get_asset_qty_all(
.order_by(TransactionSplit.account_id)
)
- current_acct_id = None
+ current_acct_id: int | None = None
deltas = {}
- for date_ord, acct_id, a_id, qty in query.yield_per(YIELD_PER):
- date_ord: int
- acct_id: int
- a_id: int
- qty: Decimal
-
+ for date_ord, acct_id, a_id, qty in sql.yield_(query):
+ if TYPE_CHECKING:
+ # Enforced by query and SQL constraints
+ assert a_id is not None
+ assert qty is not None
i = date_ord - start_ord
if acct_id != current_acct_id:
@@ -608,13 +561,11 @@ def get_asset_qty(
dict{Asset.id_: list[values]}
"""
- s = obj_session(self)
- return self.get_asset_qty_all(s, start_ord, end_ord, [self.id_])[self.id_]
+ return self.get_asset_qty_all(start_ord, end_ord, [self.id_])[self.id_]
@classmethod
def get_profit_by_asset_all(
cls,
- s: orm.Session,
start_ord: int,
end_ord: int,
ids: Iterable[int] | None = None,
@@ -622,7 +573,6 @@ def get_profit_by_asset_all(
"""Get the profit of Assets on end_date since start_ord.
Args:
- s: SQL session to use
start_ord: First date ordinal to evaluate
end_ord: Last date ordinal to evaluate (inclusive)
ids: Limit results to specific Accounts by ID
@@ -633,9 +583,9 @@ def get_profit_by_asset_all(
"""
# Get Asset quantities on start date
+ initial_qty: dict[int, Decimal] = defaultdict(Decimal)
query = (
- s.query(TransactionSplit)
- .with_entities(
+ TransactionSplit.query(
TransactionSplit.asset_id,
func.sum(TransactionSplit.asset_quantity),
)
@@ -647,40 +597,38 @@ def get_profit_by_asset_all(
)
if ids is not None:
query = query.where(TransactionSplit.account_id.in_(ids))
-
- initial_qty: dict[int, Decimal] = defaultdict(
- Decimal,
- {a_id: qty for a_id, qty in query.yield_per(YIELD_PER) if qty != 0},
- )
-
- query = (
- s.query(TransactionSplit)
- .with_entities(
- TransactionSplit.asset_id,
- TransactionSplit.asset_quantity,
- TransactionSplit.amount,
- )
- .where(
- TransactionSplit.asset_id.is_not(None),
- TransactionSplit.date_ord >= start_ord,
- TransactionSplit.date_ord <= end_ord,
- )
+ for a_id, qty in sql.yield_(query):
+ if TYPE_CHECKING:
+ # Enforced by query and SQL constraints
+ assert a_id is not None
+ assert qty is not None
+ initial_qty[a_id] = qty
+
+ query = TransactionSplit.query(
+ TransactionSplit.asset_id,
+ TransactionSplit.asset_quantity,
+ TransactionSplit.amount,
+ ).where(
+ TransactionSplit.asset_id.is_not(None),
+ TransactionSplit.date_ord >= start_ord,
+ TransactionSplit.date_ord <= end_ord,
)
if ids is not None:
query = query.where(TransactionSplit.account_id.in_(ids))
cost_basis: dict[int, Decimal] = defaultdict(Decimal)
end_qty: dict[int, Decimal] = initial_qty.copy()
- for a_id, qty, amount in query.yield_per(YIELD_PER):
- a_id: int
- qty: Decimal
- amount: Decimal
+ for a_id, qty, amount in sql.yield_(query):
+ if TYPE_CHECKING:
+ # Enforced by query and SQL constraints
+ assert a_id is not None
+ assert qty is not None
end_qty[a_id] += qty
cost_basis[a_id] += amount
a_ids = set(end_qty)
- initial_price = Asset.get_value_all(s, start_ord, start_ord, ids=a_ids)
- end_price = Asset.get_value_all(s, end_ord, end_ord, ids=a_ids)
+ initial_price = Asset.get_value_all(start_ord, start_ord, ids=a_ids)
+ end_price = Asset.get_value_all(end_ord, end_ord, ids=a_ids)
profits: dict[int, Decimal] = defaultdict(Decimal)
for a_id in a_ids:
@@ -707,23 +655,21 @@ def get_profit_by_asset(
dict{Asset.id_: profit}
"""
- s = obj_session(self)
- return self.get_profit_by_asset_all(s, start_ord, end_ord, [self.id_])
+ return self.get_profit_by_asset_all(start_ord, end_ord, [self.id_])
@classmethod
- def ids(cls, s: orm.Session, category: AccountCategory) -> set[int]:
+ def ids(cls, category: AccountCategory) -> set[int]:
"""Get Account ids for a specific category.
Args:
- s: SQL session to use
category: AccountCategory to filter
Returns:
set{Account.id_}
"""
- query = s.query(Account.id_).where(Account.category == category)
- return {acct_id for acct_id, in query.all()}
+ query = Account.query(Account.id_).where(Account.category == category)
+ return {acct_id for acct_id, in sql.yield_(query)}
def do_include(self, date_ord: int) -> bool:
"""Test if account should be included for data.
diff --git a/nummus/models/asset.py b/nummus/models/asset.py
index 170d47ae..8ba6ea6e 100644
--- a/nummus/models/asset.py
+++ b/nummus/models/asset.py
@@ -8,13 +8,12 @@
from decimal import Decimal
from typing import override, TYPE_CHECKING
-import pandas as pd
import yfinance
import yfinance.exceptions
from sqlalchemy import CheckConstraint, ForeignKey, func, Index, orm, UniqueConstraint
from nummus import exceptions as exc
-from nummus import utils
+from nummus import sql, utils
from nummus.models.base import (
Base,
BaseEnum,
@@ -26,11 +25,10 @@
ORMStrOpt,
SQLEnum,
string_column_args,
- YIELD_PER,
)
from nummus.models.currency import Currency, DEFAULT_CURRENCY
from nummus.models.transaction import TransactionSplit
-from nummus.models.utils import obj_session, query_count, query_to_dict, update_rows
+from nummus.models.utils import update_rows
if TYPE_CHECKING:
from collections.abc import Iterable, Mapping
@@ -230,6 +228,8 @@ class Asset(Base):
*string_column_args("ticker", short_check=False),
)
+ _SEARCH_PROPERTIES = ("ticker", "name")
+
@orm.validates("name", "description", "ticker")
def validate_strings(self, key: str, field: str | None) -> str | None:
"""Validate string fields satisfy constraints.
@@ -247,7 +247,6 @@ def validate_strings(self, key: str, field: str | None) -> str | None:
@classmethod
def get_value_all(
cls,
- s: orm.Session,
start_ord: int,
end_ord: int,
ids: Iterable[int] | None = None,
@@ -255,7 +254,6 @@ def get_value_all(
"""Get the value of all Assets from start to end date.
Args:
- s: SQL session to use
start_ord: First date ordinal to evaluate
end_ord: Last date ordinal to evaluate (inclusive)
ids: Limit results to specific Assets by ID
@@ -269,15 +267,14 @@ def get_value_all(
# Get a list of valuations (date offset, value) for each Asset
valuations_assets: dict[int, list[tuple[int, Decimal]]] = defaultdict(list)
- query = s.query(Asset.id_).where(Asset.interpolate)
+ query = Asset.query(Asset.id_).where(Asset.interpolate)
if ids is not None:
query = query.where(Asset.id_.in_(ids))
- interpolated_assets: set[int] = {r[0] for r in query.all()}
+ interpolated_assets: set[int] = {r[0] for r in sql.yield_(query)}
# Get latest Valuation before or including start date
query = (
- s.query(AssetValuation)
- .with_entities(
+ AssetValuation.query(
AssetValuation.asset_id,
func.max(AssetValuation.date_ord),
AssetValuation.value,
@@ -287,41 +284,30 @@ def get_value_all(
)
if ids is not None:
query = query.where(AssetValuation.asset_id.in_(ids))
- for a_id, date_ord, v in query.all():
- a_id: int
- date_ord: int
- v: Decimal
+ for a_id, date_ord, v in sql.yield_(query):
i = date_ord - start_ord
valuations_assets[a_id] = [(i, v)]
if start_ord != end_ord:
# Transactions between start and end
- query = (
- s.query(AssetValuation)
- .with_entities(
- AssetValuation.asset_id,
- AssetValuation.date_ord,
- AssetValuation.value,
- )
- .where(
- AssetValuation.date_ord <= end_ord,
- AssetValuation.date_ord > start_ord,
- )
+ query = AssetValuation.query(
+ AssetValuation.asset_id,
+ AssetValuation.date_ord,
+ AssetValuation.value,
+ ).where(
+ AssetValuation.date_ord <= end_ord,
+ AssetValuation.date_ord > start_ord,
)
if ids is not None:
query = query.where(AssetValuation.asset_id.in_(ids))
- for a_id, date_ord, v in query.yield_per(YIELD_PER):
- a_id: int
- date_ord: int
- v: Decimal
+ for a_id, date_ord, v in sql.yield_(query):
i = date_ord - start_ord
valuations_assets[a_id].append((i, v))
# Get interpolation point for assets with interpolation
query = (
- s.query(AssetValuation)
- .with_entities(
+ AssetValuation.query(
AssetValuation.asset_id,
func.min(AssetValuation.date_ord),
AssetValuation.value,
@@ -332,10 +318,7 @@ def get_value_all(
)
.group_by(AssetValuation.asset_id)
)
- for a_id, date_ord, v in query.all():
- a_id: int
- date_ord: int
- v: Decimal
+ for a_id, date_ord, v in sql.yield_(query):
i = date_ord - start_ord
valuations_assets[a_id].append((i, v))
@@ -360,12 +343,9 @@ def get_value(self, start_ord: int, end_ord: int) -> list[Decimal]:
list[values]
"""
- s = obj_session(self)
-
# Not reusing get_value_all is faster by ~2ms,
# not worth maintaining two almost identical implementations
-
- return self.get_value_all(s, start_ord, end_ord, [self.id_])[self.id_]
+ return self.get_value_all(start_ord, end_ord, [self.id_])[self.id_]
def update_splits(self) -> None:
"""Recalculate adjusted TransactionSplit.asset_quantity based on all splits.
@@ -375,21 +355,16 @@ def update_splits(self) -> None:
# This function is best here but need to avoid circular imports
from nummus.models.transaction import TransactionSplit # noqa: PLC0415
- s = obj_session(self)
-
multiplier = Decimal(1)
splits: list[tuple[int, Decimal]] = []
query = (
- s.query(AssetSplit)
- .with_entities(AssetSplit.date_ord, AssetSplit.multiplier)
+ AssetSplit.query(AssetSplit.date_ord, AssetSplit.multiplier)
.where(AssetSplit.asset_id == self.id_)
.order_by(AssetSplit.date_ord.desc())
)
- for s_date_ord, s_multiplier in query.yield_per(YIELD_PER):
- s_date_ord: int
- s_multiplier: Decimal
+ for s_date_ord, s_multiplier in sql.yield_(query):
# Compound splits as we go
multiplier *= s_multiplier
splits.append((s_date_ord, multiplier))
@@ -400,7 +375,7 @@ def update_splits(self) -> None:
multiplier = Decimal(1)
query = (
- s.query(TransactionSplit)
+ TransactionSplit.query()
.where(
TransactionSplit.asset_id == self.id_,
TransactionSplit._asset_qty_unadjusted.isnot(None), # noqa: SLF001
@@ -410,9 +385,7 @@ def update_splits(self) -> None:
sum_unadjusted = Decimal()
sum_adjusted = Decimal()
- for t_split in query.yield_per(YIELD_PER):
- # Query whole object okay, need to set things
- t_split: TransactionSplit
+ for t_split in sql.yield_(query):
# If txn is on/after the split, update the multiplier
while len(splits) >= 1 and t_split.date_ord >= splits[0][0]:
splits.pop(0)
@@ -441,7 +414,6 @@ def prune_valuations(self) -> int:
if self.category == AssetCategory.INDEX:
# If asset is an INDEX, do not prune
return 0
- s = obj_session(self)
# Date when quantity is zero
date_ord_zero: int | None = None
@@ -451,26 +423,25 @@ def prune_valuations(self) -> int:
periods_zero: list[tuple[int | None, int | None]] = []
query = (
- s.query(TransactionSplit)
- .with_entities(
+ TransactionSplit.query(
TransactionSplit.date_ord,
TransactionSplit.asset_quantity,
)
.where(TransactionSplit.asset_id == self.id_)
.order_by(TransactionSplit.date_ord)
)
- if query_count(query) == 0:
+ if not sql.any_(query):
# No transactions, prune all
return (
- s.query(AssetValuation)
+ AssetValuation.query()
.where(AssetValuation.asset_id == self.id_)
.delete()
)
- for date_ord, qty in query.yield_per(YIELD_PER):
- date_ord: int
- qty: Decimal
-
+ for date_ord, qty in sql.yield_(query):
+ if TYPE_CHECKING:
+ # Ensured by query and constraints
+ assert qty is not None
if current_qty == 0:
# Bought some, record the period when zero
date_ord_non_zero = date_ord
@@ -500,32 +471,31 @@ def _delete_valuations(
Number of valuations deleted
"""
- s = obj_session(self)
n_deleted = 0
for date_ord_sell, date_ord_buy in periods_zero:
trim_start: int | None = None
trim_end: int | None = None
if date_ord_sell is not None:
# Get date of oldest valuation after or on the sell
- query = s.query(func.min(AssetValuation.date_ord)).where(
+ query = AssetValuation.query(func.min(AssetValuation.date_ord)).where(
AssetValuation.asset_id == self.id_,
AssetValuation.date_ord >= date_ord_sell,
)
- trim_start = query.scalar()
+ trim_start = sql.scalar(query)
if date_ord_buy is not None:
# Get date of most recent valuation or on before the buy
- query = s.query(func.max(AssetValuation.date_ord)).where(
+ query = AssetValuation.query(func.max(AssetValuation.date_ord)).where(
AssetValuation.asset_id == self.id_,
AssetValuation.date_ord <= date_ord_buy,
)
- trim_end = query.scalar()
+ trim_end = sql.scalar(query)
if trim_start is None and trim_end is None:
# Can happen if no valuations exist before/after a transaction
continue
- query = s.query(AssetValuation).where(AssetValuation.asset_id == self.id_)
+ query = AssetValuation.query().where(AssetValuation.asset_id == self.id_)
if trim_start:
query = query.where(AssetValuation.date_ord > trim_start)
if trim_end:
@@ -552,18 +522,15 @@ def update_valuations(
Raises:
NoAssetWebSourceError: If Asset has no ticker
AssetWebError: If failed to download data
- TypeError: If returned DataFrame indices aren't Timestamps
"""
if self.ticker is None:
raise exc.NoAssetWebSourceError
- s = obj_session(self)
-
today = datetime.datetime.now(datetime.UTC).date()
today_ord = today.toordinal()
- query = s.query(TransactionSplit).with_entities(
+ query = TransactionSplit.query(
func.min(TransactionSplit.date_ord),
func.max(TransactionSplit.date_ord),
)
@@ -575,10 +542,8 @@ def update_valuations(
through_today = True
else:
query = query.where(TransactionSplit.asset_id == self.id_)
- start_ord, end_ord = query.one()
- start_ord: int | None
- end_ord: int | None
- if start_ord is None or end_ord is None:
+ start_ord, end_ord = sql.one(query)
+ if not start_ord or not end_ord:
return None, None
start_ord -= utils.DAYS_IN_WEEK
@@ -601,15 +566,12 @@ def update_valuations(
# yfinance raises Exception if no data found
raise exc.AssetWebError(e) from e
- valuations: dict[int, float] = {}
- for k, v in raw["Close"].items():
- if not isinstance(k, pd.Timestamp):
- raise TypeError
- valuations[k.to_pydatetime().date().toordinal()] = float(v)
+ valuations: dict[int, float] = utils.pd_series_to_dict(
+ raw["Close"], # type: ignore[attr-defined]
+ )
- query = s.query(AssetValuation).where(AssetValuation.asset_id == self.id_)
+ query = AssetValuation.query().where(AssetValuation.asset_id == self.id_)
update_rows(
- s,
AssetValuation,
query,
"date_ord",
@@ -620,13 +582,12 @@ def update_valuations(
},
)
- raw_splits = raw.loc[raw["Stock Splits"] != 0]["Stock Splits"]
- splits: dict[int, float] = {
- k.to_pydatetime().date().toordinal(): v for k, v in raw_splits.items()
- }
- query = s.query(AssetSplit).where(AssetSplit.asset_id == self.id_)
+ splits: dict[int, float] = utils.pd_series_to_dict(
+ raw.loc[raw["Stock Splits"] != 0]["Stock Splits"], # type: ignore[attr-defined]
+ )
+
+ query = AssetSplit.query().where(AssetSplit.asset_id == self.id_)
update_rows(
- s,
AssetSplit,
query,
"date_ord",
@@ -651,8 +612,6 @@ def update_sectors(self) -> None:
if self.ticker is None:
raise exc.NoAssetWebSourceError
- s = obj_session(self)
-
yf_ticker = yfinance.Ticker(self.ticker)
funds = yf_ticker.funds_data
try:
@@ -665,13 +624,12 @@ def update_sectors(self) -> None:
# Not a fund
sector = yf_ticker.info.get("sector")
if sector is None:
- s.query(AssetSector).where(AssetSector.asset_id == self.id_).delete()
+ AssetSector.query().where(AssetSector.asset_id == self.id_).delete()
return
weights = {USSector(sector): Decimal(1)}
- query = s.query(AssetSector).where(AssetSector.asset_id == self.id_)
+ query = AssetSector.query().where(AssetSector.asset_id == self.id_)
update_rows(
- s,
AssetSector,
query,
"sector",
@@ -684,7 +642,6 @@ def update_sectors(self) -> None:
@classmethod
def index_twrr(
cls,
- s: orm.Session,
name: str,
start_ord: int,
end_ord: int,
@@ -692,7 +649,6 @@ def index_twrr(
"""Get the TWRR for an index from start to end date.
Args:
- s: SQL session to use
name: Name of index
start_ord: First date ordinal to evaluate
end_ord: Last date ordinal to evaluate (inclusive)
@@ -705,22 +661,17 @@ def index_twrr(
"""
try:
- a_id = s.query(Asset.id_).where(Asset.name == name).one()[0]
+ a_id = sql.one(Asset.query(Asset.id_).where(Asset.name == name))
except exc.NoResultFound as e:
msg = f"Could not find asset index {name}"
raise exc.ProtectedObjectNotFoundError(msg) from e
- values = cls.get_value_all(s, start_ord, end_ord, ids=[a_id])[a_id]
+ values = cls.get_value_all(start_ord, end_ord, ids=[a_id])[a_id]
cost_basis = values[0]
return utils.twrr(values, [v - cost_basis for v in values])
@classmethod
- def add_indices(cls, s: orm.Session) -> None:
- """Add Asset indices used for performance comparison.
-
- Args:
- s: SQL session to use
-
- """
+ def add_indices(cls) -> None:
+ """Add Asset indices used for performance comparison."""
indices: dict[str, dict[str, str]] = {
"^GSPC": {
"name": "S&P 500",
@@ -766,7 +717,7 @@ def add_indices(cls, s: orm.Session) -> None:
},
}
for ticker, item in indices.items():
- a = Asset(
+ cls.create(
name=item["name"],
description=item["description"],
category=AssetCategory.INDEX,
@@ -774,21 +725,18 @@ def add_indices(cls, s: orm.Session) -> None:
ticker=ticker,
currency=DEFAULT_CURRENCY,
)
- s.add(a)
def autodetect_interpolate(self) -> None:
"""Autodetect if Asset needs interpolation.
Does not commit changes, call s.commit() afterwards.
"""
- s = obj_session(self)
-
query = (
- s.query(AssetValuation.date_ord)
+ AssetValuation.query(AssetValuation.date_ord)
.where(AssetValuation.asset_id == self.id_)
.order_by(AssetValuation.date_ord)
)
- date_ords = [r[0] for r in query.yield_per(YIELD_PER)]
+ date_ords = [r[0] for r in sql.yield_(query)]
has_dailys = any(
(date_ords[i] - date_ords[i - 1]) == 1 for i in range(1, len(date_ords))
)
@@ -798,14 +746,12 @@ def autodetect_interpolate(self) -> None:
@classmethod
def create_forex(
cls,
- s: orm.Session,
base: Currency,
others: set[Currency],
) -> None:
"""Create foreign exchange rate assets.
Args:
- s: SQL session to use
base: Base currency to get FOREX referenced to
others: Other currencys to get
@@ -813,8 +759,16 @@ def create_forex(
if base in others:
others.discard(base)
- query = s.query(Asset.ticker).where(Asset.category == AssetCategory.FOREX)
- existing: set[str] = {r[0] for r in query.all()}
+ query = Asset.query(Asset.ticker).where(
+ Asset.category == AssetCategory.FOREX,
+ Asset.ticker.is_not(None),
+ )
+ existing: set[str] = set()
+ for (ticker,) in sql.yield_(query):
+ if TYPE_CHECKING:
+ # Ensured by query
+ assert ticker is not None
+ existing.add(ticker)
for other in others:
ticker = f"{other.name}{base.name}=X"
@@ -822,31 +776,28 @@ def create_forex(
existing.discard(ticker)
continue
- asset = Asset(
+ cls.create(
name=f"{other.name} to {base.name}",
description=f"Exchange rate from {other.pretty} to {base.pretty}",
category=AssetCategory.FOREX,
ticker=ticker,
currency=base,
)
- s.add(asset)
existing.discard(ticker)
# existing has unused FOREX assets
# TODO (WattsUp): #463 Handle when accounts hold FOREX
- to_delete = {
- r[0] for r in s.query(Asset.id_).where(Asset.ticker.in_(existing)).all()
- }
- s.query(AssetValuation).where(AssetValuation.asset_id.in_(to_delete)).delete()
- s.query(AssetSplit).where(AssetSplit.asset_id.in_(to_delete)).delete()
- s.query(AssetSector).where(AssetSector.asset_id.in_(to_delete)).delete()
- s.query(Asset).where(Asset.id_.in_(to_delete)).delete()
+ query = Asset.query(Asset.id_).where(Asset.ticker.in_(existing))
+ to_delete = {r[0] for r in sql.yield_(query)}
+ AssetValuation.query().where(AssetValuation.asset_id.in_(to_delete)).delete()
+ AssetSplit.query().where(AssetSplit.asset_id.in_(to_delete)).delete()
+ AssetSector.query().where(AssetSector.asset_id.in_(to_delete)).delete()
+ Asset.query().where(Asset.id_.in_(to_delete)).delete()
@classmethod
def get_forex(
cls,
- s: orm.Session,
start_ord: int,
end_ord: int,
base: Currency,
@@ -855,7 +806,6 @@ def get_forex(
"""Get foreign exchange rate over time.
Args:
- s: SQL session to use
start_ord: First date ordinal to evaluate
end_ord: Last date ordinal to evaluate (inclusive)
base: Base currency to exchange to
@@ -875,14 +825,14 @@ def get_forex(
f"{other.name}{base.name}=X": other for other in currencies
}
# null ticker filtered out by query
- query = s.query(Asset.id_, Asset.ticker).where(
+ query = Asset.query(Asset.id_, Asset.ticker).where(
Asset.category == AssetCategory.FOREX,
Asset.currency == base,
)
query = query.where(Asset.ticker.in_(currencies_by_ticker))
- assets = query_to_dict(query)
+ assets = sql.to_dict(query)
- values = cls.get_value_all(s, start_ord, end_ord, assets.keys())
+ values = cls.get_value_all(start_ord, end_ord, assets.keys())
forex: dict[Currency, list[Decimal]] = defaultdict(
lambda: [Decimal(1)] * (end_ord - start_ord + 1),
diff --git a/nummus/models/base.py b/nummus/models/base.py
index 3805b72b..e3f0f485 100644
--- a/nummus/models/base.py
+++ b/nummus/models/base.py
@@ -2,23 +2,23 @@
from __future__ import annotations
+import contextlib
import enum
from decimal import Decimal
-from typing import ClassVar, override, TYPE_CHECKING
+from typing import ClassVar, NamedTuple, overload, override, Self, TYPE_CHECKING
import sqlalchemy
-from sqlalchemy import CheckConstraint, orm, sql, types
+from sqlalchemy import CheckConstraint, orm, types
from nummus import exceptions as exc
-from nummus import utils
+from nummus import sql, utils
from nummus.models import base_uri
if TYPE_CHECKING:
- from collections.abc import Iterable, Mapping
+ from collections.abc import Generator, Iterable, Mapping
+ import sqlalchemy.sql.roles as sql_roles
-# Yield per instead of fetch all is faster
-YIELD_PER = 100
ORMBool = orm.Mapped[bool]
ORMBoolOpt = orm.Mapped[bool | None]
@@ -30,7 +30,202 @@
ORMRealOpt = orm.Mapped[Decimal | None]
-class Base(orm.DeclarativeBase):
+class NamePair(NamedTuple):
+ """ID & name pair."""
+
+ id_: int
+ name: str | None
+
+
+class SessionMixIn:
+ """Mix-in that provides a session reference to the type."""
+
+ _sessions: ClassVar[list[orm.Session]] = []
+
+ @classmethod
+ @contextlib.contextmanager
+ def set_session(cls, s: orm.Session) -> Generator[None]:
+ """Set session used by active record.
+
+ Yields:
+ SQL session
+
+ """
+ cls._sessions.append(s)
+ try:
+ yield
+ finally:
+ cls._sessions.pop()
+
+ @classmethod
+ def session(cls) -> orm.Session:
+ """Get scoped session.
+
+ Returns:
+ SQL session
+
+ Raises:
+ UnboundExecutionError: set_session has not been called yet
+
+ """
+ if not cls._sessions:
+ raise exc.UnboundExecutionError
+ return cls._sessions[-1]
+
+
+class QueryMixIn(SessionMixIn):
+ """Mix-in that provides a query interface to the type."""
+
+ @classmethod
+ def create(cls, **kwargs: object) -> Self:
+ """Create a new instance.
+
+ Args:
+ kwargs: Passed to init
+
+ Returns:
+ New instance
+
+ """
+ i = cls(**kwargs)
+ s = cls.session()
+ s.add(i) # nummus: ignore
+ s.flush()
+ return i
+
+ def delete(self) -> None:
+ """Delete an instance."""
+ self.session().delete(self)
+
+ def refresh(self) -> None:
+ """Refresh an instance."""
+ self.session().refresh(self)
+
+ @overload
+ @classmethod
+ def query(cls) -> orm.Query[Self]: ...
+
+ @overload
+ @classmethod
+ def query[T0](
+ cls,
+ c0: sql_roles.TypedColumnsClauseRole[T0],
+ ) -> orm.query.RowReturningQuery[tuple[T0]]: ...
+
+ @overload
+ @classmethod
+ def query[T0, T1](
+ cls,
+ c0: sql_roles.TypedColumnsClauseRole[T0],
+ c1: sql_roles.TypedColumnsClauseRole[T1],
+ ) -> orm.query.RowReturningQuery[tuple[T0, T1]]: ...
+
+ @overload
+ @classmethod
+ def query[T0, T1, T2](
+ cls,
+ c0: sql_roles.TypedColumnsClauseRole[T0],
+ c1: sql_roles.TypedColumnsClauseRole[T1],
+ c2: sql_roles.TypedColumnsClauseRole[T2],
+ ) -> orm.query.RowReturningQuery[tuple[T0, T1, T2]]: ...
+
+ @overload
+ @classmethod
+ def query[T0, T1, T2, T3](
+ cls,
+ c0: sql_roles.TypedColumnsClauseRole[T0],
+ c1: sql_roles.TypedColumnsClauseRole[T1],
+ c2: sql_roles.TypedColumnsClauseRole[T2],
+ c3: sql_roles.TypedColumnsClauseRole[T3],
+ ) -> orm.query.RowReturningQuery[tuple[T0, T1, T2, T3]]: ...
+
+ @overload
+ @classmethod
+ def query[T0, T1, T2, T3, T4](
+ cls,
+ c0: sql_roles.TypedColumnsClauseRole[T0],
+ c1: sql_roles.TypedColumnsClauseRole[T1],
+ c2: sql_roles.TypedColumnsClauseRole[T2],
+ c3: sql_roles.TypedColumnsClauseRole[T3],
+ c4: sql_roles.TypedColumnsClauseRole[T4],
+ ) -> orm.query.RowReturningQuery[tuple[T0, T1, T2, T3, T4]]: ...
+
+ @overload
+ @classmethod
+ def query[T0, T1, T2, T3, T4, T5](
+ cls,
+ c0: sql_roles.TypedColumnsClauseRole[T0],
+ c1: sql_roles.TypedColumnsClauseRole[T1],
+ c2: sql_roles.TypedColumnsClauseRole[T2],
+ c3: sql_roles.TypedColumnsClauseRole[T3],
+ c4: sql_roles.TypedColumnsClauseRole[T4],
+ c5: sql_roles.TypedColumnsClauseRole[T5],
+ ) -> orm.query.RowReturningQuery[tuple[T0, T1, T2, T3, T4, T5]]: ...
+
+ @classmethod
+ def query[T](
+ cls,
+ *columns: sql_roles.TypedColumnsClauseRole[object],
+ **kwargs: T,
+ ) -> orm.Query[Self] | orm.Query[T]:
+ """Create a new query.
+
+ Returns:
+ Query of table
+
+ Raises:
+ NoKeywordArgumentsError: if kwargs are provided
+
+ """
+ if kwargs:
+ raise exc.NoKeywordArgumentsError
+ query: orm.Query[Self] = cls.session().query(cls)
+ if columns:
+ return query.with_entities(*columns) # nummus: ignore
+ return query
+
+ @classmethod
+ def all(cls) -> list[Self]:
+ """Fetch all rows.
+
+ Returns:
+ List of each row object
+
+ """
+ return list(sql.yield_(cls.query()))
+
+ @classmethod
+ def one(cls) -> Self:
+ """Fetch one rows.
+
+ Returns:
+ Only row
+
+ """
+ return sql.one(cls.query())
+
+ @classmethod
+ def first(cls) -> Self | None:
+ """Fetch first rows.
+
+ Returns:
+ First row
+
+ """
+ return cls.query().first()
+
+ @classmethod
+ def count(cls) -> int:
+ """Count number of rows.
+
+ Returns:
+ Number of rows
+
+ """
+ return sql.count(cls.query())
+
+
+class Base(orm.DeclarativeBase, QueryMixIn):
"""Base ORM model.
Attributes:
@@ -43,6 +238,10 @@ class Base(orm.DeclarativeBase):
__table_id__: int | None
+ id_: ORMInt = orm.mapped_column(primary_key=True, autoincrement=True)
+
+ _SEARCH_PROPERTIES: tuple[str, ...] = ()
+
@override
def __init_subclass__(cls, *, skip_register: bool = False, **kw: object) -> None:
super().__init_subclass__(**kw)
@@ -61,15 +260,13 @@ def __init_subclass__(cls, *, skip_register: bool = False, **kw: object) -> None
i += 1
@classmethod
- def metadata_create_all(cls, s: orm.Session) -> None:
+ def metadata_create_all(cls) -> None:
"""Create all tables for nummus models.
Creates tables then commits
- Args:
- s: Session to create tables for
-
"""
+ s = cls.session()
cls.metadata.create_all(s.get_bind(), [m.sql_table() for m in cls._MODELS])
s.commit()
@@ -88,8 +285,6 @@ def sql_table(cls) -> sqlalchemy.Table:
return cls.__table__
raise TypeError
- id_: ORMInt = orm.mapped_column(primary_key=True, autoincrement=True)
-
@classmethod
def id_to_uri(cls, id_: int) -> str:
"""Uniform Resource Identifier derived from id_ and __table_id__.
@@ -132,15 +327,7 @@ def uri_to_id(cls, uri: str) -> int:
@property
def uri(self) -> str:
- """Uniform Resource Identifier derived from id_ and __table_id__.
-
- Raises:
- NoIDError: If object does not have id_
-
- """
- if self.id_ is None:
- msg = f"{self.__class__.__name__} does not have an id_, maybe flush"
- raise exc.NoIDError(msg)
+ """Uniform Resource Identifier derived from id_ and __table_id__."""
return self.id_to_uri(self.id_)
@override
@@ -163,12 +350,9 @@ def __ne__(self, other: Base | object) -> bool:
return not isinstance(other, Base) or self.uri != other.uri
@classmethod
- def map_name(cls, s: orm.Session) -> dict[int, str]:
+ def map_name(cls) -> dict[int, str]:
"""Get mapping between id and names.
- Args:
- s: SQL session to use
-
Returns:
Dictionary {id: name}
@@ -176,13 +360,12 @@ def map_name(cls, s: orm.Session) -> dict[int, str]:
KeyError: if model does not have name property
"""
- attr = getattr(cls, "name", None)
+ attr: orm.QueryableAttribute[str] | None = getattr(cls, "name", None)
if not attr:
msg = f"{cls.__name__} does not have name column"
raise KeyError(msg)
- query = s.query(cls).with_entities(cls.id_, attr)
- return dict(query.all())
+ return sql.to_dict(cls.query(cls.id_, attr))
@classmethod
def clean_strings(
@@ -246,6 +429,66 @@ def clean_emoji_name(cls, s: str) -> str:
"""
return utils.strip_emojis(s).strip().lower()
+ @classmethod
+ def find(
+ cls,
+ search: str,
+ cache: dict[str, NamePair],
+ ) -> NamePair:
+ """Find a matching object by uri, or field value.
+
+ Args:
+ search: Search query
+ cache: Cache results to speed up look ups
+
+ Returns:
+ tuple(id_, name)
+
+ Raises:
+ NoResultFound: if object not found
+
+ """
+ pair = cache.get(search)
+ if pair is not None:
+ return pair
+
+ def cache_and_return(m: Self) -> NamePair:
+ id_ = m.id_
+ name: str | None = getattr(m, "name", None)
+ pair = NamePair(id_, name)
+ cache[search] = pair
+ return pair
+
+ try:
+ # See if query is an URI
+ id_ = cls.uri_to_id(search)
+ except (exc.InvalidURIError, exc.WrongURITypeError):
+ pass
+ else:
+ query = cls.query().where(cls.id_ == id_)
+ if m := sql.scalar(query):
+ return cache_and_return(m)
+
+ for name in cls._SEARCH_PROPERTIES:
+ prop: sql.Column = getattr(cls, name)
+ # Exact?
+ query = cls.query().where(prop == search)
+ if m := sql.scalar(query):
+ return cache_and_return(m)
+
+ # Exact lower case?
+ query = cls.query().where(prop.ilike(search))
+ if m := sql.scalar(query):
+ return cache_and_return(m)
+
+ # For account number, see if there is one ending in the search
+ query = cls.query().where(prop.ilike(f"%{search}"))
+ if name == "number" and (m := sql.scalar(query)):
+ return cache_and_return(m)
+
+ msg = f"{cls.__name__} matching '{search}' could not be found"
+ raise exc.NoResultFound(msg)
+
class BaseEnum(enum.IntEnum):
"""Enum class with a parser."""
@@ -295,7 +538,7 @@ def pretty(self) -> str:
return self.name.replace("_", " ").title()
-class SQLEnum(types.TypeDecorator):
+class SQLEnum(types.TypeDecorator[BaseEnum]):
"""SQL type for enumeration, stores as integer."""
impl = types.Integer
@@ -361,7 +604,7 @@ def process_result_value(
return self._enum_type(value)
-class Decimal6(types.TypeDecorator):
+class Decimal6(types.TypeDecorator[Decimal]):
"""SQL type for fixed point numbers, stores as micro-integer."""
impl = types.BigInteger
@@ -453,7 +696,7 @@ def string_column_args(
Tuple of constraints
"""
- name_col = f"`{name}`" if name in sql.compiler.RESERVED_WORDS else name
+ name_col = sql.escape(name)
checks = [
(
CheckConstraint(
diff --git a/nummus/models/base_uri.py b/nummus/models/base_uri.py
index 48abe232..07453d4f 100644
--- a/nummus/models/base_uri.py
+++ b/nummus/models/base_uri.py
@@ -22,7 +22,7 @@
_ROUNDS = 3
-_CIPHER: Cipher
+_cipher: Cipher
class Cipher:
@@ -203,7 +203,7 @@ def from_bytes(buf: bytes) -> Cipher:
msg = f"Buf is {len(buf)}B long, expected {n}B"
raise ValueError(msg)
- keys = []
+ keys: list[int] = []
for _ in range(_ROUNDS):
buf_next = buf[:ID_BYTES]
buf = buf[ID_BYTES:]
@@ -216,14 +216,14 @@ def from_bytes(buf: bytes) -> Cipher:
def load_cipher(buf: bytes) -> None:
- """Load a Cipher from bytes into _CIPHER.
+ """Load a Cipher from bytes into _cipher.
Args:
buf: Bytes to load
"""
- global _CIPHER
- _CIPHER = Cipher.from_bytes(buf)
+ global _cipher
+ _cipher = Cipher.from_bytes(buf)
def id_to_uri(id_: int) -> str:
@@ -236,7 +236,7 @@ def id_to_uri(id_: int) -> str:
URI, hex encoded, 1:1 mapping
"""
- return _CIPHER.encode(id_).to_bytes(ID_BYTES, _ORDER).hex()
+ return _cipher.encode(id_).to_bytes(ID_BYTES, _ORDER).hex()
def uri_to_id(uri: str) -> int:
@@ -261,4 +261,4 @@ def uri_to_id(uri: str) -> int:
msg = f"URI is not a hex number: {uri}"
raise exc.InvalidURIError(msg) from e
else:
- return _CIPHER.decode(uri_int)
+ return _cipher.decode(uri_int)
diff --git a/nummus/models/budget.py b/nummus/models/budget.py
index 2627e474..28583fa8 100644
--- a/nummus/models/budget.py
+++ b/nummus/models/budget.py
@@ -9,7 +9,7 @@
from sqlalchemy import CheckConstraint, ForeignKey, func, Index, orm, UniqueConstraint
-from nummus import utils
+from nummus import sql, utils
from nummus.models.account import Account
from nummus.models.base import (
Base,
@@ -21,14 +21,12 @@
ORMStr,
SQLEnum,
string_column_args,
- YIELD_PER,
)
from nummus.models.transaction import TransactionSplit
from nummus.models.transaction_category import (
TransactionCategory,
TransactionCategoryGroup,
)
-from nummus.models.utils import query_to_dict
class BudgetAvailableCategory(NamedTuple):
@@ -129,13 +127,11 @@ def validate_decimals(self, key: str, field: Decimal | None) -> Decimal | None:
@classmethod
def get_monthly_available(
cls,
- s: orm.Session,
month: datetime.date,
) -> BudgetAvailable:
"""Get available budget for a month.
Args:
- s: SQL session to use
month: Month to compute budget during
Returns:
@@ -147,44 +143,37 @@ def get_monthly_available(
"""
month_ord = month.toordinal()
- query = s.query(Account).where(Account.budgeted)
+ query = Account.query().where(Account.budgeted)
accounts = {
- acct.id_: acct.name for acct in query.all() if acct.do_include(month_ord)
+ acct.id_: acct.name
+ for acct in sql.yield_(query)
+ if acct.do_include(month_ord)
}
# Starting balance
- query = (
- s.query(TransactionSplit)
- .with_entities(
- func.sum(TransactionSplit.amount),
- )
- .where(
- TransactionSplit.account_id.in_(accounts),
- TransactionSplit.date_ord < month_ord,
- )
+ query = TransactionSplit.query(
+ func.sum(TransactionSplit.amount),
+ ).where(
+ TransactionSplit.account_id.in_(accounts),
+ TransactionSplit.date_ord < month_ord,
)
- starting_balance = query.scalar() or Decimal()
+ starting_balance = sql.scalar(query) or Decimal()
ending_balance = starting_balance
total_available = Decimal()
# Check all categories not INCOME
- budget_categories = {
- t_cat_id
- for t_cat_id, in (
- s.query(TransactionCategory.id_)
- .where(TransactionCategory.group != TransactionCategoryGroup.INCOME)
- .all()
- )
- }
+ query = TransactionCategory.query(TransactionCategory.id_).where(
+ TransactionCategory.group != TransactionCategoryGroup.INCOME,
+ )
+ budget_categories = {t_cat_id for t_cat_id, in sql.yield_(query)}
# Current month's assignment
- query = (
- s.query(BudgetAssignment)
- .with_entities(BudgetAssignment.category_id, BudgetAssignment.amount)
- .where(BudgetAssignment.month_ord == month_ord)
- )
- categories_assigned: dict[int, Decimal] = query_to_dict(query)
+ query = BudgetAssignment.query(
+ BudgetAssignment.category_id,
+ BudgetAssignment.amount,
+ ).where(BudgetAssignment.month_ord == month_ord)
+ categories_assigned: dict[int, Decimal] = sql.to_dict(query)
# Prior months' assignment
min_month_ord = month_ord
@@ -192,8 +181,7 @@ def get_monthly_available(
t_cat_id: {} for t_cat_id in budget_categories
}
query = (
- s.query(BudgetAssignment)
- .with_entities(
+ BudgetAssignment.query(
BudgetAssignment.category_id,
BudgetAssignment.amount,
BudgetAssignment.month_ord,
@@ -201,7 +189,7 @@ def get_monthly_available(
.where(BudgetAssignment.month_ord < month_ord)
.order_by(BudgetAssignment.month_ord)
)
- for cat_id, amount, m_ord in query.yield_per(YIELD_PER):
+ for cat_id, amount, m_ord in sql.yield_(query):
prior_assigned[cat_id][m_ord] = amount
min_month_ord = min(min_month_ord, m_ord)
@@ -210,8 +198,7 @@ def get_monthly_available(
t_cat_id: {} for t_cat_id in budget_categories
}
query = (
- s.query(TransactionSplit)
- .with_entities(
+ TransactionSplit.query(
TransactionSplit.category_id,
func.sum(TransactionSplit.amount),
TransactionSplit.month_ord,
@@ -227,7 +214,7 @@ def get_monthly_available(
TransactionSplit.month_ord,
)
)
- for cat_id, amount, m_ord in query.yield_per(YIELD_PER):
+ for cat_id, amount, m_ord in sql.yield_(query):
prior_activity[cat_id][m_ord] = amount
# Carry over leftover to next months to get current month's leftover amounts
@@ -246,17 +233,14 @@ def get_monthly_available(
date = utils.date_add_months(date, 1)
# Future months' assignment
- query = (
- s.query(BudgetAssignment)
- .with_entities(func.sum(BudgetAssignment.amount))
- .where(BudgetAssignment.month_ord > month_ord)
+ query = BudgetAssignment.query(func.sum(BudgetAssignment.amount)).where(
+ BudgetAssignment.month_ord > month_ord,
)
- future_assigned = query.scalar() or Decimal()
+ future_assigned = sql.scalar(query) or Decimal()
# Current month's activity
query = (
- s.query(TransactionSplit)
- .with_entities(
+ TransactionSplit.query(
TransactionSplit.category_id,
func.sum(TransactionSplit.amount),
)
@@ -266,14 +250,14 @@ def get_monthly_available(
)
.group_by(TransactionSplit.category_id)
)
- categories_activity: dict[int, Decimal] = query_to_dict(query)
+ categories_activity: dict[int, Decimal] = sql.to_dict(query)
categories: dict[int, BudgetAvailableCategory] = {}
- query = s.query(TransactionCategory).with_entities(
+ query = TransactionCategory.query(
TransactionCategory.id_,
TransactionCategory.group,
)
- for t_cat_id, group in query.yield_per(YIELD_PER):
+ for t_cat_id, group in sql.yield_(query):
activity = categories_activity.get(t_cat_id, Decimal())
assigned = categories_assigned.get(t_cat_id, Decimal())
leftover = categories_leftover.get(t_cat_id, Decimal())
@@ -299,7 +283,6 @@ def get_monthly_available(
@classmethod
def get_emergency_fund(
cls,
- s: orm.Session,
start_ord: int,
end_ord: int,
n_lower: int,
@@ -308,7 +291,6 @@ def get_emergency_fund(
"""Get the emergency fund target range and assigned balance.
Args:
- s: SQL session to use
start_ord: First day of calculated range
end_ord: Last day of calculated range
n_lower: Number of days in sliding lower period
@@ -321,30 +303,21 @@ def get_emergency_fund(
n = end_ord - start_ord + 1
n_smoothing = 15
- query = (
- s.query(Account)
- .with_entities(Account.id_, Account.name)
- .where(Account.budgeted)
- )
- accounts: dict[int, str] = query_to_dict(query)
+ query = Account.query(Account.id_, Account.name).where(Account.budgeted)
+ accounts: dict[int, str] = sql.to_dict(query)
- t_cat_id, _ = TransactionCategory.emergency_fund(s)
+ t_cat_id, _ = TransactionCategory.emergency_fund()
- balance = (
- s.query(func.sum(BudgetAssignment.amount))
- .where(
- BudgetAssignment.category_id == t_cat_id,
- BudgetAssignment.month_ord <= start_ord,
- )
- .scalar()
- or Decimal()
+ query = BudgetAssignment.query(func.sum(BudgetAssignment.amount)).where(
+ BudgetAssignment.category_id == t_cat_id,
+ BudgetAssignment.month_ord <= start_ord,
)
+ balance = sql.scalar(query) or Decimal()
balances: list[Decimal] = []
query = (
- s.query(BudgetAssignment)
- .with_entities(BudgetAssignment.month_ord, BudgetAssignment.amount)
+ BudgetAssignment.query(BudgetAssignment.month_ord, BudgetAssignment.amount)
.where(
BudgetAssignment.category_id == t_cat_id,
BudgetAssignment.month_ord > start_ord,
@@ -353,7 +326,7 @@ def get_emergency_fund(
.order_by(BudgetAssignment.month_ord)
)
date_ord = start_ord
- for b_ord, amount in query.all():
+ for b_ord, amount in sql.yield_(query):
while date_ord < b_ord:
balances.append(balance)
date_ord += 1
@@ -368,23 +341,18 @@ def get_emergency_fund(
daily = Decimal()
dailys: list[Decimal] = []
- query = (
- s.query(TransactionCategory)
- .with_entities(
- TransactionCategory.id_,
- TransactionCategory.name,
- TransactionCategory.emoji_name,
- )
- .where(TransactionCategory.essential_spending)
- )
- for t_cat_id, name, emoji_name in query.all():
+ query = TransactionCategory.query(
+ TransactionCategory.id_,
+ TransactionCategory.name,
+ TransactionCategory.emoji_name,
+ ).where(TransactionCategory.essential_spending)
+ for t_cat_id, name, emoji_name in sql.yield_(query):
categories[t_cat_id] = name, emoji_name
categories_total[t_cat_id] = Decimal()
start_ord_dailys = start_ord - n_upper - n_smoothing
query = (
- s.query(TransactionSplit)
- .with_entities(
+ TransactionSplit.query(
TransactionSplit.date_ord,
TransactionSplit.category_id,
func.sum(TransactionSplit.amount),
@@ -397,7 +365,7 @@ def get_emergency_fund(
.group_by(TransactionSplit.date_ord, TransactionSplit.category_id)
)
date_ord = start_ord_dailys
- for t_ord, t_cat_id, amount in query.yield_per(YIELD_PER):
+ for t_ord, t_cat_id, amount in sql.yield_(query):
while date_ord < t_ord:
dailys.append(daily)
date_ord += 1
@@ -438,7 +406,6 @@ def get_emergency_fund(
@classmethod
def move(
cls,
- s: orm.Session,
month_ord: int,
src_cat_id: int | None,
dest_cat_id: int | None,
@@ -447,7 +414,6 @@ def move(
"""Move funds between budget assignments.
Args:
- s: SQL session to use
month_ord: Month of BudgetAssignment
src_cat_id: Source category ID, or None
dest_cat_id: Destination category ID, or None
@@ -457,7 +423,7 @@ def move(
if src_cat_id is not None:
# Remove to_move from src_cat_id
a = (
- s.query(BudgetAssignment)
+ BudgetAssignment.query()
.where(
BudgetAssignment.category_id == src_cat_id,
BudgetAssignment.month_ord == month_ord,
@@ -465,20 +431,19 @@ def move(
.one_or_none()
)
if a is None:
- a = BudgetAssignment(
+ BudgetAssignment.create(
month_ord=month_ord,
amount=-to_move,
category_id=src_cat_id,
)
- s.add(a)
elif a.amount == to_move:
- s.delete(a)
+ a.delete()
else:
a.amount -= to_move
if dest_cat_id is not None:
a = (
- s.query(BudgetAssignment)
+ BudgetAssignment.query()
.where(
BudgetAssignment.category_id == dest_cat_id,
BudgetAssignment.month_ord == month_ord,
@@ -486,12 +451,11 @@ def move(
.one_or_none()
)
if a is None:
- a = BudgetAssignment(
+ BudgetAssignment.create(
month_ord=month_ord,
amount=to_move,
category_id=dest_cat_id,
)
- s.add(a)
else:
a.amount += to_move
diff --git a/nummus/models/config.py b/nummus/models/config.py
index 8fe926cd..4d92983c 100644
--- a/nummus/models/config.py
+++ b/nummus/models/config.py
@@ -8,6 +8,7 @@
from sqlalchemy import orm
from nummus import exceptions as exc
+from nummus import sql
from nummus.models.base import Base, BaseEnum, ORMStr, SQLEnum, string_column_args
from nummus.models.currency import Currency
@@ -56,24 +57,22 @@ def validate_strings(self, key: str, field: str | None) -> str | None:
return self.clean_strings(key, field)
@classmethod
- def set_(cls, s: orm.Session, key: ConfigKey, value: str) -> None:
+ def set_(cls, key: ConfigKey, value: str) -> None:
"""Set a Configuration value.
Args:
- s: SQL session to use
key: ConfigKey to query
value: Value to set
"""
- if s.query(Config).where(Config.key == key).update({"value": value}):
+ if Config.query().where(Config.key == key).update({"value": value}):
return
- s.add(Config(key=key, value=value))
+ Config.create(key=key, value=value)
@overload
@classmethod
def fetch(
cls,
- s: orm.Session,
key: ConfigKey,
*,
no_raise: Literal[False] = False,
@@ -83,7 +82,6 @@ def fetch(
@classmethod
def fetch(
cls,
- s: orm.Session,
key: ConfigKey,
*,
no_raise: Literal[True],
@@ -92,7 +90,6 @@ def fetch(
@classmethod
def fetch(
cls,
- s: orm.Session,
key: ConfigKey,
*,
no_raise: bool = False,
@@ -100,7 +97,6 @@ def fetch(
"""Fetch a Configuration value.
Args:
- s: SQL session to use
key: ConfigKey to query
no_raise: True will return None if missing
@@ -112,7 +108,7 @@ def fetch(
"""
try:
- return s.query(Config.value).where(Config.key == key).one()[0]
+ return sql.one(cls.query(cls.value).where(cls.key == key))
except exc.NoResultFound as e:
if no_raise:
return None
@@ -120,27 +116,21 @@ def fetch(
raise exc.ProtectedObjectNotFoundError(msg) from e
@classmethod
- def db_version(cls, s: orm.Session) -> Version:
+ def db_version(cls) -> Version:
"""Query the database version.
- Args:
- s: SQL session to use
-
Returns:
Version of database
"""
- return Version(Config.fetch(s, ConfigKey.VERSION))
+ return Version(Config.fetch(ConfigKey.VERSION))
@classmethod
- def base_currency(cls, s: orm.Session) -> Currency:
+ def base_currency(cls) -> Currency:
"""Query the basse currency.
- Args:
- s: SQL session to use
-
Returns:
Base currency all accounts are converted into
"""
- return Currency(int(Config.fetch(s, ConfigKey.BASE_CURRENCY)))
+ return Currency(int(Config.fetch(ConfigKey.BASE_CURRENCY)))
diff --git a/nummus/models/label.py b/nummus/models/label.py
index b2131708..39b10a93 100644
--- a/nummus/models/label.py
+++ b/nummus/models/label.py
@@ -4,13 +4,14 @@
from sqlalchemy import ForeignKey, Index, orm, UniqueConstraint
+from nummus import sql
from nummus.models.base import (
Base,
ORMInt,
ORMStr,
string_column_args,
)
-from nummus.models.utils import query_to_dict, update_rows
+from nummus.models.utils import update_rows
class LabelLink(Base):
@@ -34,12 +35,11 @@ class LabelLink(Base):
Index("label_link_t_split_id", "t_split_id"),
)
- @staticmethod
- def add_links(s: orm.Session, split_labels: dict[int, set[str]]) -> None:
+ @classmethod
+ def add_links(cls, split_labels: dict[int, set[str]]) -> None:
"""Add links between TransactionSplits and Labels.
Args:
- s: SQL session to use
split_labels: dict {TransactionSplit: {label names to link}
"""
@@ -51,23 +51,16 @@ def add_links(s: orm.Session, split_labels: dict[int, set[str]]) -> None:
for labels in split_labels.values():
label_names.update(labels)
- query = (
- s.query(Label)
- .with_entities(Label.name, Label.id_)
- .where(Label.name.in_(label_names))
- )
- mapping: dict[str, int] = query_to_dict(query)
+ query = Label.query(Label.name, Label.id_).where(Label.name.in_(label_names))
+ mapping: dict[str, int] = sql.to_dict(query)
- to_add = [Label(name=name) for name in label_names if name not in mapping]
- if to_add:
- s.add_all(to_add)
- s.flush()
- mapping.update({label.name: label.id_ for label in to_add})
+ for name in label_names:
+ if name not in mapping:
+ mapping[name] = Label.create(name=name).id_
for t_split_id, labels in split_labels.items():
- query = s.query(LabelLink).where(LabelLink.t_split_id == t_split_id)
+ query = LabelLink.query().where(LabelLink.t_split_id == t_split_id)
update_rows(
- s,
LabelLink,
query,
"label_id",
diff --git a/nummus/models/transaction.py b/nummus/models/transaction.py
index 74b417b4..b856f5f5 100644
--- a/nummus/models/transaction.py
+++ b/nummus/models/transaction.py
@@ -11,10 +11,10 @@
import sqlalchemy
from rapidfuzz import process
-from sqlalchemy import CheckConstraint, event, ForeignKey, Index, orm
+from sqlalchemy import CheckConstraint, ForeignKey, Index, orm
from nummus import exceptions as exc
-from nummus import utils
+from nummus import sql, utils
from nummus.models.base import (
Base,
Decimal6,
@@ -27,11 +27,9 @@
ORMStr,
ORMStrOpt,
string_column_args,
- YIELD_PER,
)
from nummus.models.label import Label, LabelLink
from nummus.models.transaction_category import TransactionCategory
-from nummus.models.utils import obj_session
if TYPE_CHECKING:
from decimal import Decimal
@@ -92,6 +90,10 @@ class TransactionSplit(Base):
"(asset_quantity IS NOT NULL) == (_asset_qty_unadjusted IS NOT NULL)",
name="asset_quantity and unadjusted must be same null state",
),
+ CheckConstraint(
+ "(asset_id IS NOT NULL) == (_asset_qty_unadjusted IS NOT NULL)",
+ name="asset_id and asset quantity must be same null state",
+ ),
CheckConstraint(
"amount != 0",
"transaction_split.amount must be non-zero",
@@ -225,15 +227,10 @@ def adjust_asset_quantity_residual(self, residual: Decimal) -> None:
@property
def parent(self) -> Transaction:
"""Parent Transaction."""
- s = obj_session(self)
- query = s.query(Transaction).where(Transaction.id_ == self.parent_id)
- return query.one()
+ return sql.one(Transaction.query().where(Transaction.id_ == self.parent_id))
@parent.setter
def parent(self, parent: Transaction) -> None:
- if parent.id_ is None:
- self.parent_tmp = parent
- return
super().__setattr__("parent_id", parent.id_)
super().__setattr__("date_ord", parent.date_ord)
super().__setattr__("month_ord", parent.month_ord)
@@ -270,9 +267,9 @@ def search(
query = query.join(LabelLink, full=True)
tokens_must, tokens_can, tokens_not = utils.tokenize_search_str(search_str)
- category_names = category_names or TransactionCategory.map_name(query.session)
+ category_names = category_names or TransactionCategory.map_name()
category_names_rev = {v: k for k, v in category_names.items()}
- label_names = label_names or Label.map_name(query.session)
+ label_names = label_names or Label.map_name()
label_names_rev = {v.lower(): k for k, v in label_names.items()}
query = cls._search_must(
@@ -283,17 +280,18 @@ def search(
)
query = cls._search_not(query, tokens_not, category_names_rev, label_names_rev)
- sub_query = query.with_entities(TransactionSplit.id_).scalar_subquery()
- query_modified = (
- query.session.query(LabelLink)
- .with_entities(LabelLink.t_split_id, LabelLink.label_id)
- .where(LabelLink.t_split_id.in_(sub_query))
- )
+ sub_query = query.with_entities( # nummus: ignore
+ TransactionSplit.id_,
+ ).scalar_subquery()
+ query_modified = LabelLink.query(
+ LabelLink.t_split_id,
+ LabelLink.label_id,
+ ).where(LabelLink.t_split_id.in_(sub_query))
split_labels: dict[int, set[int]] = defaultdict(set)
- for t_split_id, label_id in query_modified.yield_per(YIELD_PER):
+ for t_split_id, label_id in sql.yield_(query_modified):
split_labels[t_split_id].add(label_id)
- query_modified = query.with_entities(
+ query_modified = query.with_entities( # nummus: ignore
TransactionSplit.id_,
TransactionSplit.date_ord,
TransactionSplit.category_id,
@@ -306,12 +304,7 @@ def search(
date_ord,
cat_id,
text_fields,
- ) in query_modified.yield_per(YIELD_PER):
- t_id: int
- date_ord: int
- cat_id: int
- text_fields: str | None
-
+ ) in sql.yield_(query_modified):
full_text = f"{category_names[cat_id]} {text_fields or ''} " + " ".join(
label_names[label_id] for label_id in split_labels[t_id]
)
@@ -356,7 +349,7 @@ def _search_must(
continue
- clauses_or: list[sqlalchemy.ColumnExpressionArgument] = []
+ clauses_or: list[sql.ColumnClause] = []
categories = {
cat_id
for cat_name, cat_id in category_names.items()
@@ -418,24 +411,6 @@ def _search_not(
return query
-@event.listens_for(TransactionSplit, "before_insert")
-def before_insert_transaction_split(
- _: orm.Mapper,
- __: sqlalchemy.Connection,
- target: TransactionSplit,
-) -> None:
- """Handle event before insert of TransactionSplit.
-
- Args:
- target: TransactionSplit being inserted
-
- """
- # If TransactionSplit has parent_tmp set, move it to real parent
- if hasattr(target, "parent_tmp"):
- target.parent = target.parent_tmp
- delattr(target, "parent_tmp")
-
-
class Transaction(Base):
"""Transaction model for storing an exchange of cash for an asset (or none).
@@ -534,8 +509,6 @@ def find_similar(
Most similar Transaction.id_
"""
- s = obj_session(self)
-
if cache_ok and self.similar_txn_id is not None:
return self.similar_txn_id
@@ -553,26 +526,21 @@ def set_match(matching_row: int | Row[tuple[int]]) -> int:
id_ = matching_row if isinstance(matching_row, int) else matching_row[0]
if set_property:
self.similar_txn_id = id_
- s.flush()
return id_
# Convert txn.amount to the raw SQL value to make a raw query
amount_raw = Transaction.amount.type.process_bind_param(self.amount, None)
sort_closest_amount = sqlalchemy.text(f"abs({amount_raw} - amount)")
- cat_asset_linked = {
- t_cat_id
- for t_cat_id, in (
- s.query(TransactionCategory.id_)
- .where(TransactionCategory.asset_linked.is_(True))
- .all()
- )
- }
+ query = TransactionCategory.query(TransactionCategory.id_).where(
+ TransactionCategory.asset_linked.is_(True),
+ )
+ cat_asset_linked = {t_cat_id for t_cat_id, in sql.yield_(query)}
# Check within Account first, exact matches
# If this matches, great, no post filtering needed
query = (
- s.query(Transaction.id_)
+ Transaction.query(Transaction.id_)
.where(
Transaction.account_id == self.account_id,
Transaction.id_ != self.id_,
@@ -588,7 +556,7 @@ def set_match(matching_row: int | Row[tuple[int]]) -> int:
# Maybe exact statement but different account
query = (
- s.query(Transaction.id_)
+ Transaction.query(Transaction.id_)
.where(
Transaction.id_ != self.id_,
Transaction.amount >= amount_min,
@@ -603,7 +571,7 @@ def set_match(matching_row: int | Row[tuple[int]]) -> int:
# Maybe exact statement but different amount
query = (
- s.query(Transaction.id_)
+ Transaction.query(Transaction.id_)
.where(
Transaction.id_ != self.id_,
Transaction.statement == self.statement,
@@ -616,8 +584,7 @@ def set_match(matching_row: int | Row[tuple[int]]) -> int:
# No statements match, choose highest fuzzy matching statement
query = (
- s.query(Transaction)
- .with_entities(
+ Transaction.query(
Transaction.id_,
Transaction.statement,
)
@@ -630,22 +597,20 @@ def set_match(matching_row: int | Row[tuple[int]]) -> int:
)
statements: dict[int, str] = {
t_id: re.sub(r"[0-9]+", "", statement).lower()
- for t_id, statement in query.yield_per(YIELD_PER)
+ for t_id, statement in sql.yield_(query)
}
if len(statements) == 0:
return None
# Don't match a Transaction if it has a Securities Traded split
- has_asset_linked = {
- id_
- for id_, in (
- s.query(TransactionSplit.parent_id)
- .where(
- TransactionSplit.parent_id.in_(statements),
- TransactionSplit.category_id.in_(cat_asset_linked),
- )
- .distinct()
+ query = (
+ Transaction.query(TransactionSplit.parent_id)
+ .where(
+ TransactionSplit.parent_id.in_(statements),
+ TransactionSplit.category_id.in_(cat_asset_linked),
)
- }
+ .distinct()
+ )
+ has_asset_linked = {id_ for id_, in sql.yield_(query)}
statements = {
t_id: statement
for t_id, statement in statements.items()
@@ -669,17 +634,13 @@ def set_match(matching_row: int | Row[tuple[int]]) -> int:
)
# Add a bonuse points for closeness in price and same account
- query = (
- s.query(Transaction)
- .with_entities(
- Transaction.id_,
- Transaction.account_id,
- Transaction.amount,
- )
- .where(Transaction.id_.in_(matches))
- )
+ query = Transaction.query(
+ Transaction.id_,
+ Transaction.account_id,
+ Transaction.amount,
+ ).where(Transaction.id_.in_(matches))
matches_bonus: dict[int, float] = {}
- for t_id, acct_id, amount in query.yield_per(YIELD_PER):
+ for t_id, acct_id, amount in sql.yield_(query):
# 5% off will reduce score by 5%
amount_diff_percent = abs(amount - self.amount) / self.amount
# Extra 10 points for same account
diff --git a/nummus/models/transaction_category.py b/nummus/models/transaction_category.py
index 03892808..a63e7621 100644
--- a/nummus/models/transaction_category.py
+++ b/nummus/models/transaction_category.py
@@ -7,6 +7,7 @@
from sqlalchemy import CheckConstraint, ForeignKey, orm, UniqueConstraint
from nummus import exceptions as exc
+from nummus import sql
from nummus.models.base import (
Base,
BaseEnum,
@@ -16,7 +17,6 @@
SQLEnum,
string_column_args,
)
-from nummus.models.utils import query_to_dict
class TransactionCategoryGroup(BaseEnum):
@@ -127,12 +127,9 @@ def validate_essential_spending(self, _: str, field: object) -> bool:
return field
@staticmethod
- def add_default(s: orm.Session) -> dict[str, TransactionCategory]:
+ def add_default() -> dict[str, TransactionCategory]:
"""Create default transaction categories.
- Args:
- s: SQL session to use
-
Returns:
Dictionary {name: category}
@@ -232,7 +229,7 @@ class Spec(NamedTuple):
for group, categories in groups.items():
for name, spec in categories.items():
- cat = TransactionCategory(
+ cat = TransactionCategory.create(
emoji_name=name,
group=group,
locked=spec.locked,
@@ -240,7 +237,6 @@ class Spec(NamedTuple):
asset_linked=spec.asset_linked,
essential_spending=spec.essential_spending,
)
- s.add(cat)
d[cat.name] = cat
return d
@@ -248,14 +244,12 @@ class Spec(NamedTuple):
@classmethod
def map_name(
cls,
- s: orm.Session,
*,
no_asset_linked: bool = False,
) -> dict[int, str]:
"""Get mapping between id and names.
Args:
- s: SQL session to use
no_asset_linked: True will not include asset_linked categories
Returns:
@@ -265,29 +259,20 @@ def map_name(
KeyError if model does not have name property
"""
- query = (
- s.query(TransactionCategory)
- .with_entities(
- TransactionCategory.id_,
- TransactionCategory.name,
- )
- .order_by(TransactionCategory.name)
- )
+ query = cls.query(cls.id_, cls.name).order_by(cls.name)
if no_asset_linked:
- query = query.where(TransactionCategory.asset_linked.is_(False))
- return query_to_dict(query)
+ query = query.where(cls.asset_linked.is_(False))
+ return sql.to_dict(query)
@classmethod
def map_name_emoji(
cls,
- s: orm.Session,
*,
no_asset_linked: bool = False,
) -> dict[int, str]:
"""Get mapping between id and names with emojis.
Args:
- s: SQL session to use
no_asset_linked: True will not include asset_linked categories
Returns:
@@ -297,24 +282,16 @@ def map_name_emoji(
KeyError if model does not have name property
"""
- query = (
- s.query(TransactionCategory)
- .with_entities(
- TransactionCategory.id_,
- TransactionCategory.emoji_name,
- )
- .order_by(TransactionCategory.name)
- )
+ query = cls.query(cls.id_, cls.emoji_name).order_by(cls.name)
if no_asset_linked:
- query = query.where(TransactionCategory.asset_linked.is_(False))
- return query_to_dict(query)
+ query = query.where(cls.asset_linked.is_(False))
+ return sql.to_dict(query)
@classmethod
- def _get_protected_id(cls, s: orm.Session, name: str) -> tuple[int, str]:
+ def _get_protected_id(cls, name: str) -> tuple[int, str]:
"""Get the ID and URI of a protected category.
Args:
- s: SQL session to use
name: Name of protected category to fetch
Returns:
@@ -325,23 +302,16 @@ def _get_protected_id(cls, s: orm.Session, name: str) -> tuple[int, str]:
"""
try:
- id_ = (
- s.query(TransactionCategory.id_)
- .where(TransactionCategory.name == name)
- .one()[0]
- )
+ id_ = sql.one(cls.query(cls.id_).where(TransactionCategory.name == name))
except exc.NoResultFound as e:
msg = f"Category {name} not found"
raise exc.ProtectedObjectNotFoundError(msg) from e
return id_, cls.id_to_uri(id_)
@classmethod
- def uncategorized(cls, s: orm.Session) -> tuple[int, str]:
+ def uncategorized(cls) -> tuple[int, str]:
"""Get the ID and URI of the uncategorized category.
- Args:
- s: SQL session to use
-
Returns:
tuple(id_, URI)
@@ -349,15 +319,12 @@ def uncategorized(cls, s: orm.Session) -> tuple[int, str]:
ProtectedObjectNotFound if not found
"""
- return cls._get_protected_id(s, "uncategorized")
+ return cls._get_protected_id("uncategorized")
@classmethod
- def emergency_fund(cls, s: orm.Session) -> tuple[int, str]:
+ def emergency_fund(cls) -> tuple[int, str]:
"""Get the ID and URI of the emergency fund category.
- Args:
- s: SQL session to use
-
Returns:
tuple(id_, URI)
@@ -365,15 +332,12 @@ def emergency_fund(cls, s: orm.Session) -> tuple[int, str]:
ProtectedObjectNotFound if not found
"""
- return cls._get_protected_id(s, "emergency fund")
+ return cls._get_protected_id("emergency fund")
@classmethod
- def securities_traded(cls, s: orm.Session) -> tuple[int, str]:
+ def securities_traded(cls) -> tuple[int, str]:
"""Get the ID and URI of the securities traded category.
- Args:
- s: SQL session to use
-
Returns:
tuple(id_, URI)
@@ -381,4 +345,4 @@ def securities_traded(cls, s: orm.Session) -> tuple[int, str]:
ProtectedObjectNotFound if not found
"""
- return cls._get_protected_id(s, "securities traded")
+ return cls._get_protected_id("securities traded")
diff --git a/nummus/models/utils.py b/nummus/models/utils.py
index 952c6388..88704c9b 100644
--- a/nummus/models/utils.py
+++ b/nummus/models/utils.py
@@ -9,69 +9,25 @@
from sqlalchemy import (
CheckConstraint,
ForeignKeyConstraint,
- func,
- orm,
UniqueConstraint,
)
-from nummus import exceptions as exc
-from nummus.models.base import YIELD_PER
+from nummus import sql
if TYPE_CHECKING:
from sqlalchemy import (
Constraint,
+ orm,
)
from nummus.models.base import Base
-def query_to_dict[K, V](query: orm.query.RowReturningQuery[tuple[K, V]]) -> dict[K, V]:
- """Fetch results from query and return a dict.
-
- Args:
- query: Query that returns 2 columns
-
- Returns:
- dict{first column: second column}
-
- """
- # pyright is happier with comprehension
- # ruff is happier with dict()
- return dict(query.yield_per(YIELD_PER)) # type: ignore[attr-defined]
-
-
-def query_count(query: orm.Query) -> int:
- """Count the number of result a query will return.
-
- Args:
- query: Session query to execute
-
- Returns:
- Number of instances query will return upon execution
-
- Raises:
- TypeError: if query.statement is not a Select
-
- """
- # From here:
- # https://datawookie.dev/blog/2021/01/sqlalchemy-efficient-counting/
- col_one = sqlalchemy.literal_column("1")
- stmt = query.statement
- if not isinstance(stmt, sqlalchemy.Select):
- raise TypeError
- counter = stmt.with_only_columns(
- func.count(col_one),
- maintain_column_froms=True,
- )
- counter = counter.order_by(None)
- return query.session.execute(counter).scalar() or 0
-
-
-def paginate(
- query: orm.Query[Base],
+def paginate[T: Base](
+ query: orm.Query[T],
limit: int,
offset: int,
-) -> tuple[list[Base], int, int | None]:
+) -> tuple[list[T], int, int | None]:
"""Paginate query response for smaller results.
Args:
@@ -87,12 +43,12 @@ def paginate(
offset = max(0, offset)
# Get amount number from filters
- count = query_count(query)
+ count = sql.count(query)
# Apply limiting, and offset
query = query.limit(limit).offset(offset)
- results = query.all()
+ results = list(sql.yield_(query))
# Compute next_offset
n_current = len(results)
@@ -102,14 +58,10 @@ def paginate(
return results, count, next_offset
-def dump_table_configs(
- s: orm.Session,
- model: type[Base],
-) -> list[str]:
+def dump_table_configs(model: type[Base]) -> list[str]:
"""Get the table configs (columns and constraints) and print.
Args:
- s: SQL session to use
model: Filter to specific table
Returns:
@@ -123,26 +75,26 @@ def dump_table_configs(
type='table'
AND name='{model.__tablename__}'
""".strip() # noqa: S608
- result = s.execute(sqlalchemy.text(stmt)).one()[0]
- result: str
+ query: orm.query.RowReturningQuery[tuple[str]] = model.session().execute( # type: ignore[attr-defined]
+ sqlalchemy.text(stmt),
+ )
+ result: str = sql.one(query)
return [s.replace("\t", " ") for s in result.splitlines()]
def get_constraints(
- s: orm.Session,
model: type[Base],
) -> list[tuple[type[Constraint], str]]:
"""Get constraints of a table.
Args:
- s: SQL session to use
model: Filter to specific table
Returns:
list[(Constraint type, construction text)]
"""
- config = "\n".join(dump_table_configs(s, model))
+ config = "\n".join(dump_table_configs(model))
constraints: list[tuple[type[Constraint], str]] = []
re_unique = re.compile(r"UNIQUE \(([^\)]+)\)")
@@ -163,36 +115,15 @@ def get_constraints(
return constraints
-def obj_session(m: Base) -> orm.Session:
- """Get the SQL session for an object.
-
- Args:
- m: Model to get from
-
- Returns:
- Session
-
- Raises:
- UnboundExecutionError: if model is unbound
-
- """
- s = orm.object_session(m)
- if s is None:
- raise exc.UnboundExecutionError
- return s
-
-
-def update_rows(
- s: orm.Session,
- cls: type[Base],
- query: orm.Query,
+def update_rows[T: Base](
+ cls: type[T],
+ query: orm.Query[T],
id_key: str,
updates: dict[object, dict[str, object]],
) -> None:
"""Update many rows, reusing leftovers when possible.
Args:
- s: SQL session to use
cls: Type of model to update
query: Query to fetch all applicable models
id_key: Name of property used for identification
@@ -202,7 +133,7 @@ def update_rows(
updates = updates.copy()
leftovers: list[Base] = []
- for m in query.yield_per(YIELD_PER):
+ for m in sql.yield_(query):
update = updates.pop(getattr(m, id_key), None)
if update is None:
# No longer needed
@@ -211,6 +142,8 @@ def update_rows(
for k, v in update.items():
setattr(m, k, v)
+ s = cls.session()
+
# Add any missing ones
for id_, update in updates.items():
if leftovers:
@@ -219,24 +152,21 @@ def update_rows(
for k, v in update.items():
setattr(m, k, v)
else:
- m = cls(**{id_key: id_, **update})
- s.add(m)
+ cls.create(**{id_key: id_, **update})
# Delete any leftovers
for m in leftovers:
s.delete(m)
-def update_rows_list(
- s: orm.Session,
- cls: type[Base],
- query: orm.Query,
+def update_rows_list[T: Base](
+ cls: type[T],
+ query: orm.Query[T],
updates: list[dict[str, object]],
) -> list[int]:
"""Update many rows, reusing leftovers when possible.
Args:
- s: SQL session to use
cls: Type of model to update
query: Query to fetch all applicable models
updates: list[{parameter: value}]
@@ -250,7 +180,7 @@ def update_rows_list(
updates = updates.copy()
leftovers: list[Base] = []
- for m in query.yield_per(YIELD_PER):
+ for m in sql.yield_(query):
if len(updates) == 0:
# No longer needed
leftovers.append(m)
@@ -260,6 +190,8 @@ def update_rows_list(
setattr(m, k, v)
ids.append(m.id_)
+ s = cls.session()
+
to_add = [cls(**update) for update in updates]
s.add_all(to_add)
@@ -271,17 +203,3 @@ def update_rows_list(
ids.extend(m.id_ for m in to_add)
return ids
-
-
-def one_or_none[T](query: orm.Query[T]) -> T | None:
- """Return one result.
-
- Returns:
- One result
- If no results or multiple, return None
-
- """
- try:
- return query.one_or_none()
- except (exc.NoResultFound, exc.MultipleResultsFound):
- return None
diff --git a/nummus/portfolio.py b/nummus/portfolio.py
index cb43e63c..18c40a90 100644
--- a/nummus/portfolio.py
+++ b/nummus/portfolio.py
@@ -3,6 +3,7 @@
from __future__ import annotations
import base64
+import contextlib
import datetime
import hashlib
import io
@@ -27,27 +28,20 @@
from nummus.migrations.top import MIGRATORS
from nummus.models.account import Account
from nummus.models.asset import Asset
-from nummus.models.base import Base, YIELD_PER
+from nummus.models.base import Base
from nummus.models.base_uri import Cipher, load_cipher
from nummus.models.config import Config, ConfigKey
-from nummus.models.currency import (
- Currency,
- DEFAULT_CURRENCY,
-)
+from nummus.models.currency import DEFAULT_CURRENCY
from nummus.models.imported_file import ImportedFile
from nummus.models.transaction import Transaction, TransactionSplit
from nummus.models.transaction_category import TransactionCategory
-from nummus.models.utils import (
- one_or_none,
- query_to_dict,
-)
from nummus.version import __version__
if TYPE_CHECKING:
- import contextlib
+ from collections.abc import Iterator
from nummus.importers.base import TxnDict
- from nummus.models.currency import Currency
+ from nummus.models.base import NamePair
class AssetUpdate(NamedTuple):
@@ -200,9 +194,9 @@ def create(cls, path: str | Path, key: str | None = None) -> Portfolio:
test_value = enc.encrypt(Portfolio._ENCRYPTION_TEST_VALUE)
engine = sql.get_engine(path_db, enc)
- with orm.Session(engine) as s:
+ with orm.Session(engine) as s, Base.set_session(s):
with s.begin():
- Base.metadata_create_all(s)
+ Base.metadata_create_all()
with s.begin():
# If developing a migration, current version will be less
@@ -212,20 +206,20 @@ def create(cls, path: str | Path, key: str | None = None) -> Portfolio:
*[m.min_version() for m in MIGRATORS],
)
- Config.set_(s, ConfigKey.VERSION, str(v))
- Config.set_(s, ConfigKey.ENCRYPTION_TEST, test_value)
- Config.set_(s, ConfigKey.CIPHER, cipher_b64)
- Config.set_(s, ConfigKey.SECRET_KEY, secrets.token_hex())
- Config.set_(s, ConfigKey.BASE_CURRENCY, str(DEFAULT_CURRENCY.value))
+ Config.set_(ConfigKey.VERSION, str(v))
+ Config.set_(ConfigKey.ENCRYPTION_TEST, test_value)
+ Config.set_(ConfigKey.CIPHER, cipher_b64)
+ Config.set_(ConfigKey.SECRET_KEY, secrets.token_hex())
+ Config.set_(ConfigKey.BASE_CURRENCY, str(DEFAULT_CURRENCY.value))
if enc is not None and key is not None:
- Config.set_(s, ConfigKey.WEB_KEY, enc.encrypt(key))
+ Config.set_(ConfigKey.WEB_KEY, enc.encrypt(key))
path_db.chmod(0o600) # Only owner can read/write
p = Portfolio(path_db, key)
- with p.begin_session() as s:
- TransactionCategory.add_default(s)
- Asset.add_indices(s)
+ with p.begin_session():
+ TransactionCategory.add_default()
+ Asset.add_indices()
return p
def _unlock(self) -> dict[ConfigKey, str]:
@@ -240,9 +234,9 @@ def _unlock(self) -> dict[ConfigKey, str]:
"""
try:
- with self.begin_session() as s:
- query = s.query(Config).with_entities(Config.key, Config.value)
- configs: dict[ConfigKey, str] = query_to_dict(query)
+ with self.begin_session():
+ query = Config.query(Config.key, Config.value)
+ configs: dict[ConfigKey, str] = sql.to_dict(query)
except exc.DatabaseError as e:
msg = f"Failed to open database {self._path_db}"
raise exc.UnlockingError(msg) from e
@@ -280,14 +274,17 @@ def get_engine(self) -> sqlalchemy.Engine:
"""
return sql.get_engine(self._path_db, self._enc)
- def begin_session(self) -> contextlib.AbstractContextManager[orm.Session]:
+ @contextlib.contextmanager
+ def begin_session(self) -> Iterator[orm.Session]:
"""Get SQL Session to the database.
- Returns:
+ Yields:
Open Session
"""
- return self._session_maker.begin()
+ s = self._session_maker()
+ with s, s.begin(), Base.set_session(s):
+ yield s
def encrypt(self, secret: bytes | str) -> str:
"""Encrypt a secret using the key.
@@ -346,8 +343,8 @@ def migration_required(self, version_str: str | None) -> Version | None:
"""
if version_str is None:
- with self.begin_session() as s:
- v_db = Config.db_version(s)
+ with self.begin_session():
+ v_db = Config.db_version()
else:
v_db = Version(version_str)
for m in MIGRATORS[::-1]:
@@ -375,11 +372,13 @@ def import_file(self, path: Path, path_debug: Path, *, force: bool = False) -> N
sha = hashlib.sha256()
sha.update(path.read_bytes())
h = sha.hexdigest()
- with self.begin_session() as s:
+ with self.begin_session():
if force:
- s.query(ImportedFile).where(ImportedFile.hash_ == h).delete()
- existing_date_ord: int | None = (
- s.query(ImportedFile.date_ord).where(ImportedFile.hash_ == h).scalar()
+ ImportedFile.query().where(ImportedFile.hash_ == h).delete()
+ existing_date_ord: int | None = sql.scalar(
+ ImportedFile.query(ImportedFile.date_ord).where(
+ ImportedFile.hash_ == h,
+ ),
)
if existing_date_ord is not None:
date = datetime.date.fromordinal(existing_date_ord)
@@ -388,12 +387,12 @@ def import_file(self, path: Path, path_debug: Path, *, force: bool = False) -> N
i = get_importer(path, path_debug, self._importers)
today = datetime.datetime.now(datetime.UTC).date()
- categories = TransactionCategory.map_name(s)
+ categories = TransactionCategory.map_name()
# Reverse categories for LUT
categories = {v: k for k, v in categories.items()}
# Cache a mapping from account/asset name to the ID
- acct_mapping: dict[str, tuple[int, str | None]] = {}
- asset_mapping: dict[str, tuple[int, str | None]] = {}
+ acct_mapping: dict[str, NamePair] = {}
+ asset_mapping: dict[str, NamePair] = {}
try:
txns_raw = i.run()
except Exception as e:
@@ -406,25 +405,25 @@ def import_file(self, path: Path, path_debug: Path, *, force: bool = False) -> N
# Create a single split for each transaction
acct_raw = d["account"]
- acct_id, _ = self.find(s, Account, acct_raw, acct_mapping)
+ acct_id, _ = Account.find(acct_raw, acct_mapping)
asset_raw = d["asset"]
asset_id: int | None = None
if asset_raw:
- asset_id, asset_name = self.find(s, Asset, asset_raw, asset_mapping)
+ asset_id, asset_name = Asset.find(asset_raw, asset_mapping)
if not d["statement"]:
d["statement"] = f"Asset Transaction {asset_name}"
- self._import_transaction(s, d, acct_id, asset_id, categories)
+ self._import_transaction(d, acct_id, asset_id, categories)
# Add file hash to prevent importing again
- s.add(ImportedFile(hash_=h))
+ ImportedFile.create(hash_=h)
# Update splits on each touched
- query = s.query(Asset).where(
+ query = Asset.query().where(
Asset.id_.in_(a_id for a_id, _ in asset_mapping.values()),
)
- for asset in query.all():
+ for asset in sql.yield_(query):
asset.update_splits()
# If successful, delete the temp file
@@ -432,42 +431,36 @@ def import_file(self, path: Path, path_debug: Path, *, force: bool = False) -> N
def _import_transaction(
self,
- s: orm.Session,
d: TxnDict,
acct_id: int,
asset_id: int | None,
categories: dict[str, int],
) -> None:
if asset_id is not None:
- self._import_asset_transaction(s, d, acct_id, asset_id, categories)
+ self._import_asset_transaction(d, acct_id, asset_id, categories)
return
# See if anything matches
date_ord = d["date"].toordinal()
- matches = list(
- s.query(Transaction)
- .with_entities(Transaction.id_, Transaction.date_ord)
- .where(
- Transaction.account_id == acct_id,
- Transaction.amount == d["amount"],
- Transaction.date_ord >= date_ord - 5,
- Transaction.date_ord <= date_ord + 5,
- Transaction.cleared.is_(False),
- )
- .all(),
+ query = Transaction.query(Transaction.id_, Transaction.date_ord).where(
+ Transaction.account_id == acct_id,
+ Transaction.amount == d["amount"],
+ Transaction.date_ord >= date_ord - 5,
+ Transaction.date_ord <= date_ord + 5,
+ Transaction.cleared.is_(False),
)
- matches = sorted(matches, key=lambda x: abs(x[1] - date_ord))
+ matches = sorted(sql.yield_(query), key=lambda x: abs(x[1] - date_ord))
# If only one match on closest day, link transaction
if len(matches) == 1 or (len(matches) > 1 and matches[0][1] != matches[1][1]):
match_id = matches[0][0]
statement_clean = Transaction.clean_strings("statement", d["statement"])
- s.query(Transaction).where(Transaction.id_ == match_id).update(
+ Transaction.query().where(Transaction.id_ == match_id).update(
{
"cleared": True,
"statement": statement_clean,
},
)
- s.query(TransactionSplit).where(
+ TransactionSplit.query().where(
TransactionSplit.parent_id == match_id,
).update({"cleared": True})
return
@@ -478,7 +471,7 @@ def _import_transaction(
except KeyError:
category_id = categories["uncategorized"]
- txn = Transaction(
+ txn = Transaction.create(
account_id=acct_id,
amount=d["amount"],
date=d["date"],
@@ -486,21 +479,19 @@ def _import_transaction(
payee=d["payee"],
cleared=True,
)
- t_split = TransactionSplit(
+ TransactionSplit.create(
+ parent=txn,
amount=d["amount"],
memo=d["memo"],
category_id=category_id,
)
- t_split.parent = txn
- s.add_all((txn, t_split))
@classmethod
def _import_asset_transaction(
cls,
- s: orm.Session,
d: TxnDict,
acct_id: int,
- asset_id: int,
+ asset_id: int | None,
categories: dict[str, int],
) -> None:
category_name = (d["category"] or "uncategorized").lower()
@@ -513,7 +504,7 @@ def _import_asset_transaction(
raise exc.MissingAssetError(msg)
qty = abs(qty)
- txn = Transaction(
+ txn = Transaction.create(
account_id=acct_id,
amount=0,
date=d["date"],
@@ -521,7 +512,7 @@ def _import_asset_transaction(
payee=d["payee"],
cleared=True,
)
- t_split_0 = TransactionSplit(
+ TransactionSplit.create(
parent=txn,
amount=amount,
memo=d["memo"],
@@ -529,7 +520,7 @@ def _import_asset_transaction(
asset_id=asset_id,
asset_quantity_unadjusted=-qty,
)
- t_split_1 = TransactionSplit(
+ TransactionSplit.create(
parent=txn,
amount=-amount,
memo=d["memo"],
@@ -537,7 +528,6 @@ def _import_asset_transaction(
asset_id=asset_id,
asset_quantity_unadjusted=0,
)
- s.add_all((txn, t_split_0, t_split_1))
return
if category_name == "dividends received":
# Associate dividends with asset
@@ -548,7 +538,7 @@ def _import_asset_transaction(
raise exc.MissingAssetError(msg)
qty = abs(qty)
- txn = Transaction(
+ txn = Transaction.create(
account_id=acct_id,
amount=0,
date=d["date"],
@@ -556,7 +546,7 @@ def _import_asset_transaction(
payee=d["payee"],
cleared=True,
)
- t_split_0 = TransactionSplit(
+ TransactionSplit.create(
parent=txn,
amount=amount,
memo=d["memo"],
@@ -564,21 +554,19 @@ def _import_asset_transaction(
asset_id=asset_id,
asset_quantity_unadjusted=0,
)
- t_split_1 = TransactionSplit(
- parent=txn,
- amount=-amount,
- memo=d["memo"],
- category_id=categories["securities traded"],
- asset_id=asset_id,
- asset_quantity_unadjusted=qty,
- )
- s.add_all((txn, t_split_0))
if qty != 0:
# Zero quantity means cash dividends, not reinvested
- s.add(t_split_1)
+ TransactionSplit.create(
+ parent=txn,
+ amount=-amount,
+ memo=d["memo"],
+ category_id=categories["securities traded"],
+ asset_id=asset_id,
+ asset_quantity_unadjusted=qty,
+ )
return
if category_name == "securities traded":
- txn = Transaction(
+ txn = Transaction.create(
account_id=acct_id,
amount=d["amount"],
date=d["date"],
@@ -586,86 +574,19 @@ def _import_asset_transaction(
payee=d["payee"],
cleared=True,
)
- t_split = TransactionSplit(
+ TransactionSplit(
+ parent=txn,
amount=d["amount"],
memo=d["memo"],
category_id=categories[category_name],
asset_id=asset_id,
asset_quantity_unadjusted=d["asset_quantity"],
)
- t_split.parent = txn
- s.add_all((txn, t_split))
return
msg = f"'{category_name}' is not a valid category for asset transaction"
raise exc.InvalidAssetTransactionCategoryError(msg)
- @classmethod
- def find(
- cls,
- s: orm.Session,
- model: type[Base],
- search: str,
- cache: dict[str, tuple[int, str | None]],
- ) -> tuple[int, str | None]:
- """Find a matching object by uri, or field value.
-
- Args:
- s: Session to use
- model: Type of model to search for
- search: Search query
- cache: Cache results to speed up look ups
-
- Returns:
- tuple(id_, name)
-
- Raises:
- NoResultFound: if object not found
-
- """
- id_, name = cache.get(search, (None, None))
- if id_ is not None:
- return id_, name
-
- def cache_and_return(m: Base) -> tuple[int, str | None]:
- id_ = m.id_
- name: str | None = getattr(m, "name", None)
- cache[search] = id_, name
- return id_, name
-
- try:
- # See if query is an URI
- id_ = model.uri_to_id(search)
- except (exc.InvalidURIError, exc.WrongURITypeError):
- pass
- else:
- query = s.query(model).where(model.id_ == id_)
- if m := one_or_none(query):
- return cache_and_return(m)
-
- properties: list[orm.QueryableAttribute] = {
- Account: [Account.number, Account.institution, Account.name],
- Asset: [Asset.ticker, Asset.name],
- }[model]
- for prop in properties:
- # Exact?
- query = s.query(model).where(prop == search)
- if m := one_or_none(query):
- return cache_and_return(m)
-
- # Exact lower case?
- query = s.query(model).where(prop.ilike(search))
- if m := one_or_none(query):
- return cache_and_return(m)
-
- # For account number, see if there is one ending in the search
- query = s.query(model).where(prop.ilike(f"%{search}"))
- if prop is Account.number and (m := one_or_none(query)):
- return cache_and_return(m)
-
- msg = f"{model.__name__} matching '{search}' could not be found"
- raise exc.NoResultFound(msg)
-
def backup(self) -> tuple[Path, int]:
"""Back up database, duplicates files.
@@ -765,7 +686,7 @@ def clean(self) -> tuple[int, int]:
# Prune unused AssetValuations
with self.begin_session() as s:
- for asset in s.query(Asset).yield_per(YIELD_PER):
+ for asset in Asset.all():
asset.prune_valuations()
asset.autodetect_interpolate()
@@ -924,17 +845,18 @@ def update_assets(self, *, no_bars: bool = False) -> list[AssetUpdate]:
today_ord = today.toordinal()
updated: list[AssetUpdate] = []
- with self.begin_session() as s:
+ with self.begin_session():
# Get FOREXes, add if need be
- currencies: set[Currency] = {r[0] for r in s.query(Account.currency).all()}
- base_currency = Config.base_currency(s)
- Asset.create_forex(s, base_currency, currencies)
+ currencies = set(sql.col0(Account.query(Account.currency)))
+ base_currency = Config.base_currency()
+ Asset.create_forex(base_currency, currencies)
- assets = s.query(Asset).where(Asset.ticker.isnot(None)).all()
+ query = Asset.query().where(Asset.ticker.isnot(None))
+ assets = list(sql.yield_(query))
ids = [asset.id_ for asset in assets]
# Get currently held assets
- asset_qty = Account.get_asset_qty_all(s, today_ord, today_ord)
+ asset_qty = Account.get_asset_qty_all(today_ord, today_ord)
currently_held_assets: set[int] = set()
for acct_assets in asset_qty.values():
for a_id in ids:
@@ -958,7 +880,7 @@ def update_assets(self, *, no_bars: bool = False) -> list[AssetUpdate]:
# start & end are None if there are no transactions for the Asset
# Auto update if asset needs interpolation
- for asset in s.query(Asset).all():
+ for asset in Asset.all():
asset.autodetect_interpolate()
return updated
@@ -1030,8 +952,8 @@ def filter_(tables: list[sqlalchemy.Table]) -> list[sqlalchemy.Table]:
conn_dst.commit()
# Use new encryption key
- with self.begin_session() as s:
- value_encrypted = Config.fetch(s, ConfigKey.WEB_KEY)
+ with self.begin_session():
+ value_encrypted = Config.fetch(ConfigKey.WEB_KEY, no_raise=True)
value = key if value_encrypted is None else self.decrypt_s(value_encrypted)
dst.change_web_key(value)
@@ -1062,5 +984,5 @@ def change_web_key(self, key: str) -> None:
raise exc.InvalidKeyError(msg)
key_encrypted = self.encrypt(key)
- with self.begin_session() as s:
- Config.set_(s, ConfigKey.WEB_KEY, key_encrypted)
+ with self.begin_session():
+ Config.set_(ConfigKey.WEB_KEY, key_encrypted)
diff --git a/nummus/sql.py b/nummus/sql.py
index 15672ca5..f1cb8a34 100644
--- a/nummus/sql.py
+++ b/nummus/sql.py
@@ -4,13 +4,17 @@
import base64
import sys
-from typing import TYPE_CHECKING
+from collections.abc import Sequence
+from typing import overload, TYPE_CHECKING
import sqlalchemy
import sqlalchemy.event
+from sqlalchemy import func, orm
+from sqlalchemy.sql import case
if TYPE_CHECKING:
import sqlite3
+ from collections.abc import Generator, Iterable
from pathlib import Path
from nummus.encryption.base import EncryptionInterface
@@ -23,6 +27,16 @@
_ENGINE_ARGS: dict[str, object] = {}
+Column = (
+ orm.InstrumentedAttribute[str]
+ | orm.InstrumentedAttribute[str | None]
+ | orm.InstrumentedAttribute[int]
+ | orm.InstrumentedAttribute[int | None]
+)
+ColumnClause = sqlalchemy.ColumnElement[bool]
+
+__all__ = ["case"]
+
@sqlalchemy.event.listens_for(sqlalchemy.engine.Engine, "connect")
def set_sqlite_pragma(db_connection: sqlite3.Connection, *_) -> None:
@@ -80,3 +94,201 @@ def escape(s: str) -> str:
"""
return f"`{s}`" if s in sqlalchemy.sql.compiler.RESERVED_WORDS else s
+
+
+@overload
+def to_dict_tuple[K, T0, T1](
+ query: orm.query.RowReturningQuery[tuple[K, T0, T1]],
+) -> dict[K, tuple[T0, T1]]: ...
+
+
+@overload
+def to_dict_tuple[K, T0, T1, T2](
+ query: orm.query.RowReturningQuery[tuple[K, T0, T1, T2]],
+) -> dict[K, tuple[T0, T1, T2]]: ...
+
+
+@overload
+def to_dict_tuple[K, T0, T1, T2, T3](
+ query: orm.query.RowReturningQuery[tuple[K, T0, T1, T2, T3]],
+) -> dict[K, tuple[T0, T1, T2, T3]]: ...
+
+
+def to_dict_tuple[T: tuple[object, ...]]( # type: ignore[attr-defined]
+ query: orm.query.RowReturningQuery[T],
+) -> dict[object, tuple[object, ...]]:
+ """Fetch results from query and return a dict.
+
+ Args:
+ query: Query that returns 2 columns
+
+ Returns:
+ dict{first column: second column}
+ or
+ dict{first column: tuple(other columns)}
+
+ """
+ return {r[0]: r[1:] for r in yield_(query)}
+
+
+def to_dict[K, V](
+ query: orm.query.RowReturningQuery[tuple[K, V]],
+) -> dict[K, V]:
+ """Fetch results from query and return a dict.
+
+ Args:
+ query: Query that returns 2 columns
+
+ Returns:
+ dict{first column: second column}
+
+ """
+ return {r[0]: r[1] for r in yield_(query)}
+
+
+def count[T](query: orm.Query[T]) -> int:
+ """Count the number of result a query will return.
+
+ Args:
+ query: Session query to execute
+
+ Returns:
+ Number of instances query will return upon execution
+
+ Raises:
+ TypeError: if query.statement is not a Select
+
+ """
+ # From here:
+ # https://datawookie.dev/blog/2021/01/sqlalchemy-efficient-counting/
+ col_one: sqlalchemy.ColumnClause[object] = sqlalchemy.literal_column("1")
+ stmt = query.statement
+ if not isinstance(stmt, sqlalchemy.Select):
+ raise TypeError
+ counter = stmt.with_only_columns(
+ func.count(col_one),
+ maintain_column_froms=True,
+ )
+ counter = counter.order_by(None)
+ return query.session.execute(counter).scalar() or 0 # nummus: ignore
+
+
+def any_[T](query: orm.Query[T]) -> bool:
+ """Check if any rows exists in query.
+
+ Args:
+ query: Session query to execute
+
+ Returns:
+ True if any results
+
+ """
+ return count(query.limit(1)) != 0
+
+
+@overload
+def one[T0](
+ query: orm.query.RowReturningQuery[tuple[T0]],
+) -> T0: ...
+
+
+@overload
+def one[T0, T1](
+ query: orm.query.RowReturningQuery[tuple[T0, T1]],
+) -> tuple[T0, T1]: ...
+
+
+@overload
+def one[T](query: orm.Query[T]) -> T: ...
+
+
+def one[T](query: orm.Query[T]) -> object:
+ """Check if any rows exists in query.
+
+ Args:
+ query: Session query to execute
+
+ Returns:
+ One result
+
+ """
+ ret: T | Sequence[T] = query.one() # nummus: ignore
+ if not isinstance(ret, Sequence):
+ return ret
+ if len(ret) == 1: # type: ignore[attr-defined]
+ return ret[0] # type: ignore[attr-defined]
+ return ret[0:] # type: ignore[attr-defined]
+
+
+@overload
+def scalar[T0](
+ query: orm.query.RowReturningQuery[tuple[T0]],
+) -> T0 | None: ...
+
+
+@overload
+def scalar[T0, T1](
+ query: orm.query.RowReturningQuery[tuple[T0, T1]],
+) -> T0 | None: ...
+
+
+@overload
+def scalar[T0, T1, T2](
+ query: orm.query.RowReturningQuery[tuple[T0, T1, T2]],
+) -> T0 | None: ...
+
+
+@overload
+def scalar[T](query: orm.Query[T]) -> T | None: ...
+
+
+def scalar[T](query: orm.Query[T]) -> object | None:
+ """Check if any rows exists in query.
+
+ Args:
+ query: Session query to execute
+
+ Returns:
+ One result
+
+ """
+ return query.scalar() # nummus: ignore
+
+
+@overload
+def yield_[T: tuple[object, ...]](
+ query: orm.query.RowReturningQuery[T],
+) -> Iterable[T]: ...
+
+
+@overload
+def yield_[T](query: orm.Query[T]) -> Iterable[T]: ...
+
+
+def yield_[T](query: orm.Query[T]) -> Iterable[object]:
+ """Yield a query.
+
+ Args:
+ query: Query to yield
+
+ Yields:
+ Rows
+
+ """
+ # Yield per instead of fetch all is faster
+ for r in query.yield_per(100): # nummus: ignore
+ yield r[0:] if isinstance(r, Sequence) else r
+
+
+def col0[T](query: orm.query.RowReturningQuery[tuple[T]]) -> Generator[T]:
+ """Yield a query into a list.
+
+ Args:
+ query: Query to yield
+
+ Yields:
+ first column
+
+ """
+ for (r,) in yield_(query):
+ yield r
diff --git a/nummus/utils.py b/nummus/utils.py
index b586bc43..8c139da0 100644
--- a/nummus/utils.py
+++ b/nummus/utils.py
@@ -16,6 +16,7 @@
from typing import NamedTuple, overload, TYPE_CHECKING
import emoji as emoji_mod
+import pandas as pd
from colorama import Fore
from rapidfuzz import process
from scipy import optimize
@@ -822,7 +823,7 @@ def pretty_table(table: list[list[str] | None]) -> list[str]:
col_widths[i] = n
margin += n_trim
extra = False
- excess.append(0 if cell[-1] == "/" else col_widths[i] - n_label)
+ excess.append(0 if cell[-1] == "/" else max(0, col_widths[i] - n_label))
# Distribute excess
while margin < 0 and any(excess):
@@ -832,7 +833,7 @@ def pretty_table(table: list[list[str] | None]) -> list[str]:
excess[i] -= 1
margin += 1
- formats = []
+ formats: list[str] = []
for cell, n in zip(header_raw, col_widths, strict=True):
align = cell[0]
align = align if align in "<>^" else ""
@@ -1073,3 +1074,26 @@ def set_sub_keys[_, T, V](dicts: dict[_, dict[T, V]]) -> set[T]:
for d in dicts.values():
keys.update(d.keys())
return keys
+
+
+def pd_series_to_dict(s: pd.Series[float]) -> dict[int, float]:
+ """Convert pandas series to dict.
+
+ Args:
+ s: pandas series
+
+ Returns:
+ dict{date ordinal: value}
+
+ Raises:
+ TypeError: if columns are not date, floatable
+
+ """
+ d: dict[int, float] = {}
+ for k, v in s.items():
+ if not isinstance(k, pd.Timestamp):
+ raise TypeError
+ if not isinstance(v, float):
+ raise TypeError
+ d[k.to_pydatetime().date().toordinal()] = float(v)
+ return d
diff --git a/nummus/web.py b/nummus/web.py
index cc81e8fb..f2b19e33 100644
--- a/nummus/web.py
+++ b/nummus/web.py
@@ -3,7 +3,9 @@
from __future__ import annotations
import datetime
+import functools
import os
+from decimal import Decimal
from pathlib import Path
from typing import TYPE_CHECKING
@@ -68,28 +70,30 @@ def init_app(self, app: flask.Flask) -> None:
self._init_metrics(app)
# Inject common variables into templates
- app.context_processor(
- lambda: {
- "url_args": {},
- },
- )
+ args: dict[str, dict[str, object]] = {
+ "url_args": {},
+ }
+ app.context_processor(lambda: args)
@classmethod
- def _open_portfolio(cls, config: flask.Config) -> Portfolio:
- path = (
- Path(
- config.get("PORTFOLIO", "~/.nummus/portfolio.db"),
- )
- .expanduser()
- .absolute()
- )
+ def _open_portfolio(cls, config: dict[str, object]) -> Portfolio:
+ s = config.get("PORTFOLIO", "~/.nummus/portfolio.db")
+ if not isinstance(s, str):
+ raise TypeError
+ path = Path(s).expanduser().absolute()
key = config.get("KEY")
if key is None:
path_key = config.get("KEY_PATH")
- path_key = Path(path_key).expanduser().absolute() if path_key else None
+ path_key = (
+ Path(path_key).expanduser().absolute()
+ if isinstance(path_key, str)
+ else None
+ )
if path_key and path_key.exists():
key = path_key.read_text("utf-8").strip()
+ elif not isinstance(key, str):
+ raise TypeError
return Portfolio(path, key)
@@ -129,27 +133,25 @@ def _add_routes(cls, app: flask.Flask) -> None:
@classmethod
def _init_auth(cls, app: flask.Flask, p: Portfolio) -> None:
- with p.begin_session() as s:
- secret_key = Config.fetch(s, ConfigKey.SECRET_KEY)
+ with p.begin_session():
+ secret_key = Config.fetch(ConfigKey.SECRET_KEY)
app.secret_key = secret_key
- app.config.update(
- SESSION_COOKIE_SECURE=True,
- SESSION_COOKIE_HTTPONLY=True,
- SESSION_COOKIE_SAMESITE="Lax",
- REMEMBER_COOKIE_SECURE=True,
- REMEMBER_COOKIE_HTTPONLY=True,
- REMEMBER_COOKIE_SAMESITE="Lax",
- REMEMBER_COOKIE_DURATION=datetime.timedelta(days=28),
- )
+ config: dict[str, object] = app.config
+ config["SESSION_COOKIE_SECURE"] = True
+ config["SESSION_COOKIE_HTTPONLY"] = True
+ config["SESSION_COOKIE_SAMESITE"] = "Lax"
+ config["REMEMBER_COOKIE_SECURE"] = True
+ config["REMEMBER_COOKIE_HTTPONLY"] = True
+ config["REMEMBER_COOKIE_SAMESITE"] = "Lax"
+ config["REMEMBER_COOKIE_DURATION"] = datetime.timedelta(days=28)
app.after_request(base.update_client_timezone)
app.after_request(base.change_redirect_to_htmx)
login_manager = flask_login.LoginManager()
login_manager.init_app(app)
login_manager.user_loader(auth.get_user)
- # LoginManager.login_view not typed to str | None
- login_manager.login_view = "auth.page_login" # type: ignore[attr-defined]
+ login_manager.login_view = "auth.page_login"
if p.is_encrypted:
# Only can have authentiation with encrypted portfolio
@@ -159,25 +161,41 @@ def _init_auth(cls, app: flask.Flask, p: Portfolio) -> None:
def _init_jinja_env(cls, env: jinja2.Environment) -> None:
env.filters["seconds"] = utils.format_seconds
env.filters["days"] = utils.format_days
- env.filters["days_abv"] = lambda x: utils.format_days(
- x,
- ["days", "wks", "mos", "yrs"],
+ env.filters["days_abv"] = functools.partial(
+ utils.format_days,
+ labels=["days", "wks", "mos", "yrs"],
)
env.filters["comma"] = lambda x: f"{x:,.2f}"
env.filters["qty"] = lambda x: f"{x:,.6f}"
- env.filters["percent"] = lambda x: f"{x * 100:5.2f}%"
- env.filters["pnl_color"] = lambda x: (
- "" if x is None or x == 0 else ("text-primary" if x > 0 else "text-error")
- )
- env.filters["pnl_arrow"] = lambda x: (
- ""
- if x is None or x == 0
- else ("arrow_upward" if x > 0 else "arrow_downward")
- )
env.filters["no_emojis"] = utils.strip_emojis
env.filters["tojson"] = base.ctx_to_json
env.filters["input_value"] = lambda x: str(x or "").rstrip("0").rstrip(".")
+ def percent(x: Decimal | float | object) -> str:
+ if not isinstance(x, Decimal | float):
+ raise TypeError
+ return f"{x * 100:5.2f}%"
+
+ env.filters["percent"] = percent
+
+ def pnl_color(x: Decimal | float | object) -> str:
+ if not x:
+ return ""
+ if not isinstance(x, Decimal | float):
+ raise TypeError
+ return "text-primary" if x > 0 else "text-error"
+
+ env.filters["pnl_color"] = pnl_color
+
+ def pnl_arrow(x: Decimal | float | object) -> str:
+ if not x:
+ return ""
+ if not isinstance(x, Decimal | float):
+ raise TypeError
+ return "arrow_upward" if x > 0 else "arrow_downward"
+
+ env.filters["pnl_arrow"] = pnl_arrow
+
@classmethod
def _init_metrics(cls, app: flask.Flask) -> None:
multiproc = "PROMETHEUS_MULTIPROC_DIR" in os.environ
diff --git a/nummus/web_assets.py b/nummus/web_assets.py
index e9ca943b..2536681f 100644
--- a/nummus/web_assets.py
+++ b/nummus/web_assets.py
@@ -29,7 +29,7 @@ class TailwindCSSFilter(webassets.filter.Filter):
DEBUG = False
@override
- def output(self, _in: io.StringIO, out: io.StringIO, **_) -> None:
+ def output(self, _in: io.StringIO, out: io.StringIO, **_: object) -> None:
if pytailwindcss is None:
raise NotImplementedError
path_root = Path(__file__).parent.resolve()
@@ -54,7 +54,7 @@ class JSMinFilter(webassets.filter.Filter):
"""webassets Filter for running jsmin over."""
@override
- def output(self, _in: io.StringIO, out: io.StringIO, **_) -> None:
+ def output(self, _in: io.StringIO, out: io.StringIO, **_: object) -> None:
if jsmin is None:
raise NotImplementedError
# Add back tick to quote_chars for template strings
diff --git a/pyproject.toml b/pyproject.toml
index 60e0be11..c76360d4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -76,8 +76,10 @@ dev = [
"isort",
"pre-commit",
"djlint>=1.36.4",
- "pyright>=1.1.402",
+ "pyright>=1.1.408",
"taplo",
+ "scipy-stubs",
+ "pandas-stubs",
]
build = ["build"]
build-docker = ["nummus-financial[build,encrypt]", "setuptools-scm>=8"]
@@ -148,9 +150,20 @@ force_alphabetical_sort_within_sections = true
[tool.pyright]
include = ["nummus", "tests", "tools"]
-exclude = ["**/__pycache__"]
+exclude = ["**/__pycache__", "typing"]
venvPath = "."
venv = ".venv"
+typeCheckingMode = "strict"
+
+reportUnknownParameterType = "warning"
+reportUnknownArgumentType = "warning"
+reportUnknownLambdaType = "warning"
+reportUnknownVariableType = "warning"
+reportUnknownMemberType = "warning"
+reportUnnecessaryTypeIgnoreComment = "warning"
+
+# Checked by ruff
+reportPrivateUsage = "none"
[tool.pytest.ini_options]
markers = ["encryption: tests that require encryption"]
@@ -158,6 +171,7 @@ addopts = ["--durations=10", "--import-mode=importlib"]
[tool.ruff]
target-version = "py312"
+exclude = ["typings"]
[tool.ruff.lint]
select = ["ALL"]
@@ -212,6 +226,7 @@ ignore = [
"FBT001", # Boolean positional arguments
"T201", # Allow printing for extra context
"SLF001", # Allow access to privates
+ "ARG001", # Allow unused arguments for fixtures
]
"tests/controllers/*.py" = [
"N802", # Allow CAPS methods
diff --git a/tests/commands/test_backup.py b/tests/commands/test_backup.py
index d2853064..e6108f44 100644
--- a/tests/commands/test_backup.py
+++ b/tests/commands/test_backup.py
@@ -14,7 +14,7 @@
from nummus.portfolio import Portfolio
-def test_backup(capsys: pytest.CaptureFixture, empty_portfolio: Portfolio) -> None:
+def test_backup(capsys: pytest.CaptureFixture[str], empty_portfolio: Portfolio) -> None:
c = Backup(empty_portfolio.path, None)
assert c.run() == 0
@@ -30,7 +30,10 @@ def test_backup(capsys: pytest.CaptureFixture, empty_portfolio: Portfolio) -> No
assert not captured.err
-def test_restore(capsys: pytest.CaptureFixture, empty_portfolio: Portfolio) -> None:
+def test_restore(
+ capsys: pytest.CaptureFixture[str],
+ empty_portfolio: Portfolio,
+) -> None:
empty_portfolio.backup()
c = Restore(empty_portfolio.path, None, tar_ver=None, list_ver=False)
assert c.run() == 0
@@ -45,7 +48,7 @@ def test_restore(capsys: pytest.CaptureFixture, empty_portfolio: Portfolio) -> N
def test_restore_missing(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
empty_portfolio: Portfolio,
) -> None:
c = Restore(empty_portfolio.path, None, tar_ver=None, list_ver=False)
@@ -58,7 +61,7 @@ def test_restore_missing(
def test_restore_list_empty(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
empty_portfolio: Portfolio,
) -> None:
c = Restore(empty_portfolio.path, None, tar_ver=None, list_ver=True)
@@ -71,7 +74,7 @@ def test_restore_list_empty(
def test_restore_list(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
empty_portfolio: Portfolio,
utc_frozen: datetime.datetime,
) -> None:
diff --git a/tests/commands/test_base.py b/tests/commands/test_base.py
index a7cd28ae..050a0b27 100644
--- a/tests/commands/test_base.py
+++ b/tests/commands/test_base.py
@@ -32,7 +32,7 @@ class MockCommand(Command):
@classmethod
def setup_args(cls, parser: argparse.ArgumentParser) -> None:
- _ = parser
+ pass
@override
def run(self) -> int:
@@ -43,7 +43,7 @@ def test_no_unlock(tmp_path: Path) -> None:
MockCommand(tmp_path / "fake.db", None, do_unlock=False)
-def test_no_file(capsys: pytest.CaptureFixture, tmp_path: Path) -> None:
+def test_no_file(capsys: pytest.CaptureFixture[str], tmp_path: Path) -> None:
path = tmp_path / "fake.db"
with pytest.raises(SystemExit):
MockCommand(path, None)
@@ -54,7 +54,7 @@ def test_no_file(capsys: pytest.CaptureFixture, tmp_path: Path) -> None:
assert captured.err == target
-def test_unlock(capsys: pytest.CaptureFixture, empty_portfolio: Portfolio) -> None:
+def test_unlock(capsys: pytest.CaptureFixture[str], empty_portfolio: Portfolio) -> None:
MockCommand(empty_portfolio.path, None)
captured = capsys.readouterr()
@@ -63,7 +63,10 @@ def test_unlock(capsys: pytest.CaptureFixture, empty_portfolio: Portfolio) -> No
assert not captured.err
-def test_migration_required(capsys: pytest.CaptureFixture, data_path: Path) -> None:
+def test_migration_required(
+ capsys: pytest.CaptureFixture[str],
+ data_path: Path,
+) -> None:
with pytest.raises(SystemExit):
MockCommand(data_path / "old_versions" / "v0.1.16.db", None)
@@ -80,7 +83,7 @@ def test_migration_required(capsys: pytest.CaptureFixture, data_path: Path) -> N
@pytest.mark.skipif(not ENCRYPTION_AVAILABLE, reason="Encryption is not installed")
@pytest.mark.encryption
def test_unlock_encrypted_path(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
empty_portfolio_encrypted: tuple[Portfolio, str],
tmp_path: Path,
) -> None:
@@ -99,7 +102,7 @@ def test_unlock_encrypted_path(
@pytest.mark.skipif(not ENCRYPTION_AVAILABLE, reason="Encryption is not installed")
@pytest.mark.encryption
def test_unlock_encrypted_path_bad_key(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
empty_portfolio_encrypted: tuple[Portfolio, str],
tmp_path: Path,
) -> None:
@@ -119,7 +122,7 @@ def test_unlock_encrypted_path_bad_key(
@pytest.mark.skipif(not ENCRYPTION_AVAILABLE, reason="Encryption is not installed")
@pytest.mark.encryption
def test_unlock_encrypted(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
empty_portfolio_encrypted: tuple[Portfolio, str],
) -> None:
@@ -147,7 +150,7 @@ def mock_get_pass(to_print: str) -> str | None:
@pytest.mark.skipif(not ENCRYPTION_AVAILABLE, reason="Encryption is not installed")
@pytest.mark.encryption
def test_unlock_encrypted_cancel(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
empty_portfolio_encrypted: tuple[Portfolio, str],
) -> None:
@@ -171,7 +174,7 @@ def mock_get_pass(to_print: str) -> str | None:
@pytest.mark.skipif(not ENCRYPTION_AVAILABLE, reason="Encryption is not installed")
@pytest.mark.encryption
def test_unlock_encrypted_failed(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
empty_portfolio_encrypted: tuple[Portfolio, str],
) -> None:
diff --git a/tests/commands/test_change_password.py b/tests/commands/test_change_password.py
index 7e26ec8d..2ac0b1b3 100644
--- a/tests/commands/test_change_password.py
+++ b/tests/commands/test_change_password.py
@@ -26,7 +26,7 @@ def change_web_key(self, key: str) -> None:
def test_no_change_unencrypted(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
empty_portfolio: Portfolio,
tmp_path: Path,
) -> None:
@@ -49,7 +49,7 @@ def test_no_change_unencrypted(
],
)
def test_change(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
empty_portfolio: Portfolio,
tmp_path: Path,
diff --git a/tests/commands/test_clean.py b/tests/commands/test_clean.py
index 563ee5dc..8d6cd964 100644
--- a/tests/commands/test_clean.py
+++ b/tests/commands/test_clean.py
@@ -12,7 +12,7 @@
from nummus.portfolio import Portfolio
-def test_clean(capsys: pytest.CaptureFixture, empty_portfolio: Portfolio) -> None:
+def test_clean(capsys: pytest.CaptureFixture[str], empty_portfolio: Portfolio) -> None:
c = Clean(empty_portfolio.path, None)
assert c.run() == 0
diff --git a/tests/commands/test_create.py b/tests/commands/test_create.py
index 0ca4dd38..49a0c2d4 100644
--- a/tests/commands/test_create.py
+++ b/tests/commands/test_create.py
@@ -28,7 +28,7 @@ def create(cls, path: str | Path, key: str | None = None) -> Portfolio:
def test_create_existing(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
tmp_path: Path,
) -> None:
path = tmp_path / "new.db"
@@ -44,7 +44,7 @@ def test_create_existing(
def test_create_unencrypted_forced(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
) -> None:
@@ -63,7 +63,7 @@ def test_create_unencrypted_forced(
def test_create_unencrypted(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
) -> None:
@@ -81,7 +81,7 @@ def test_create_unencrypted(
def test_create_encrypted(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
rand_str: str,
@@ -108,7 +108,7 @@ def mock_get_pass(_: str) -> str | None:
def test_create_encrypted_pass_file(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
rand_str: str,
@@ -130,7 +130,7 @@ def test_create_encrypted_pass_file(
def test_create_encrypted_cancelled(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
) -> None:
diff --git a/tests/commands/test_export.py b/tests/commands/test_export.py
index 0b44833b..d84e46bc 100644
--- a/tests/commands/test_export.py
+++ b/tests/commands/test_export.py
@@ -18,7 +18,7 @@
def test_export_empty(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
empty_portfolio: Portfolio,
tmp_path: Path,
) -> None:
@@ -43,7 +43,7 @@ def test_export_empty(
def test_export(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
empty_portfolio: Portfolio,
account: Account,
transactions: list[Transaction],
diff --git a/tests/commands/test_health.py b/tests/commands/test_health.py
index 5f34c9ce..5a33d96a 100644
--- a/tests/commands/test_health.py
+++ b/tests/commands/test_health.py
@@ -9,7 +9,6 @@
from nummus.health_checks.top import HEALTH_CHECKS
from nummus.models.config import Config, ConfigKey
from nummus.models.health_checks import HealthCheckIssue
-from nummus.models.utils import query_count
from nummus.portfolio import Portfolio
if TYPE_CHECKING:
@@ -20,7 +19,7 @@
def test_issues(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
empty_portfolio: Portfolio,
utc_frozen: datetime.datetime,
) -> None:
@@ -47,13 +46,13 @@ def test_issues(
assert f"{Fore.MAGENTA}Use web interface to fix issues" in captured.out
assert not captured.err
- with empty_portfolio.begin_session() as s:
- v = Config.fetch(s, ConfigKey.LAST_HEALTH_CHECK_TS)
+ with empty_portfolio.begin_session():
+ v = Config.fetch(ConfigKey.LAST_HEALTH_CHECK_TS)
assert v == utc_frozen.isoformat()
def test_no_limit_severe(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
empty_portfolio: Portfolio,
utc_frozen: datetime.datetime,
@@ -85,25 +84,24 @@ def test_no_limit_severe(
assert f"{Fore.MAGENTA}Use web interface to fix issues" in captured.out
assert not captured.err
- with empty_portfolio.begin_session() as s:
- v = Config.fetch(s, ConfigKey.LAST_HEALTH_CHECK_TS)
+ with empty_portfolio.begin_session():
+ v = Config.fetch(ConfigKey.LAST_HEALTH_CHECK_TS)
assert v == utc_frozen.isoformat()
def test_ignore_all(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
empty_portfolio: Portfolio,
utc_frozen: datetime.datetime,
) -> None:
ignores: list[str] = []
for check_type in HEALTH_CHECKS:
c = check_type()
- with empty_portfolio.begin_session() as s:
- c.test(s)
+ with empty_portfolio.begin_session():
+ c.test()
ignores.extend(c.issues.keys())
- with empty_portfolio.begin_session() as s:
+ with empty_portfolio.begin_session():
Config.set_(
- s,
ConfigKey.LAST_HEALTH_CHECK_TS,
(utc_frozen - datetime.timedelta(days=1)).isoformat(),
)
@@ -131,22 +129,22 @@ def test_ignore_all(
assert f"{Fore.MAGENTA}Use web interface to fix issues" not in captured.out
assert not captured.err
- with empty_portfolio.begin_session() as s:
- v = Config.fetch(s, ConfigKey.LAST_HEALTH_CHECK_TS)
+ with empty_portfolio.begin_session():
+ v = Config.fetch(ConfigKey.LAST_HEALTH_CHECK_TS)
assert v == utc_frozen.isoformat()
- assert query_count(s.query(HealthCheckIssue)) == len(ignores)
+ assert HealthCheckIssue.count() == len(ignores)
def test_clear_ignores(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
empty_portfolio: Portfolio,
) -> None:
ignores: list[str] = []
for check_type in HEALTH_CHECKS:
c = check_type()
- with empty_portfolio.begin_session() as s:
- c.test(s)
+ with empty_portfolio.begin_session():
+ c.test()
ignores.extend(c.issues.keys())
c = Health(
@@ -168,5 +166,5 @@ def test_clear_ignores(
assert "has no transactions nor budget assignments" in captured.out
assert not captured.err
- with empty_portfolio.begin_session() as s:
- assert query_count(s.query(HealthCheckIssue)) == len(ignores)
+ with empty_portfolio.begin_session():
+ assert HealthCheckIssue.count() == len(ignores)
diff --git a/tests/commands/test_import_files.py b/tests/commands/test_import_files.py
index 63463484..c7b4e8ac 100644
--- a/tests/commands/test_import_files.py
+++ b/tests/commands/test_import_files.py
@@ -17,7 +17,7 @@
from nummus.portfolio import Portfolio
-def test_empty(capsys: pytest.CaptureFixture, empty_portfolio: Portfolio) -> None:
+def test_empty(capsys: pytest.CaptureFixture[str], empty_portfolio: Portfolio) -> None:
path_debug = empty_portfolio.path.with_suffix(".importer_debug")
c = Import(empty_portfolio.path, None, [], force=False)
@@ -32,7 +32,7 @@ def test_empty(capsys: pytest.CaptureFixture, empty_portfolio: Portfolio) -> Non
def test_non_existant(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
empty_portfolio: Portfolio,
tmp_path: Path,
) -> None:
@@ -55,7 +55,7 @@ def test_non_existant(
def test_data_dir(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
empty_portfolio: Portfolio,
account: Account,
account_investments: Account,
@@ -63,9 +63,6 @@ def test_data_dir(
tmp_path: Path,
data_path: Path,
) -> None:
- _ = account
- _ = account_investments
- _ = asset
files = [
"transactions_required.csv",
"transactions_extras.csv",
@@ -87,7 +84,7 @@ def test_data_dir(
def test_unknown_importer(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
empty_portfolio: Portfolio,
account: Account,
account_investments: Account,
@@ -95,9 +92,6 @@ def test_unknown_importer(
tmp_path: Path,
data_path: Path,
) -> None:
- _ = account
- _ = account_investments
- _ = asset
file = "transactions_lacking.csv"
path = tmp_path / file
shutil.copyfile(data_path / file, path)
@@ -120,7 +114,7 @@ def test_unknown_importer(
def test_duplicate(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
today: datetime.date,
empty_portfolio: Portfolio,
account: Account,
@@ -129,9 +123,6 @@ def test_duplicate(
tmp_path: Path,
data_path: Path,
) -> None:
- _ = account
- _ = account_investments
- _ = asset
file = "transactions_required.csv"
path = tmp_path / file
shutil.copyfile(data_path / file, path)
@@ -155,7 +146,7 @@ def test_duplicate(
@pytest.mark.xfail
def test_data_dir_no_account(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
empty_portfolio: Portfolio,
data_path: Path,
) -> None:
diff --git a/tests/commands/test_migrate.py b/tests/commands/test_migrate.py
index b9259419..b4322ff3 100644
--- a/tests/commands/test_migrate.py
+++ b/tests/commands/test_migrate.py
@@ -16,7 +16,7 @@
def test_not_required(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
empty_portfolio: Portfolio,
) -> None:
@@ -33,7 +33,7 @@ def test_not_required(
def test_v0_1_migration(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
tmp_path: Path,
data_path: Path,
) -> None:
diff --git a/tests/commands/test_summarize.py b/tests/commands/test_summarize.py
index 1d142895..b191d765 100644
--- a/tests/commands/test_summarize.py
+++ b/tests/commands/test_summarize.py
@@ -52,9 +52,6 @@ def test_non_empty_summary(
transactions: list[Transaction],
asset_valuation: AssetValuation,
) -> None:
- _ = transactions
- _ = asset_valuation
-
c = Summarize(empty_portfolio.path, None, include_all=False)
utc = utc.replace(tzinfo=zoneinfo.ZoneInfo("UTC"))
@@ -102,10 +99,8 @@ def test_exclude_empty(
account: Account,
asset: Asset,
) -> None:
- account.closed = True
- session.commit()
- _ = asset
-
+ with session.begin_nested():
+ account.closed = True
c = Summarize(empty_portfolio.path, None, include_all=False)
result = c._get_summary()
@@ -125,10 +120,9 @@ def test_exclude_empty(
def test_empty_print(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
empty_portfolio: Portfolio,
) -> None:
-
c = Summarize(empty_portfolio.path, None, include_all=False)
assert c.run() == 0
@@ -142,15 +136,12 @@ def test_empty_print(
def test_non_empty_print(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
utc: datetime.datetime,
empty_portfolio: Portfolio,
transactions: list[Transaction],
asset_valuation: AssetValuation,
) -> None:
- _ = transactions
- _ = asset_valuation
-
c = Summarize(empty_portfolio.path, None, include_all=False)
utc = utc.replace(tzinfo=zoneinfo.ZoneInfo("UTC"))
diff --git a/tests/commands/test_unlock.py b/tests/commands/test_unlock.py
index 8f17c199..03844c52 100644
--- a/tests/commands/test_unlock.py
+++ b/tests/commands/test_unlock.py
@@ -14,7 +14,7 @@
def test_empty(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
empty_portfolio: Portfolio,
) -> None:
diff --git a/tests/commands/test_update_assets.py b/tests/commands/test_update_assets.py
index e396bd2a..e5afeef5 100644
--- a/tests/commands/test_update_assets.py
+++ b/tests/commands/test_update_assets.py
@@ -20,10 +20,9 @@
def test_empty(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
empty_portfolio: Portfolio,
) -> None:
-
c = UpdateAssets(empty_portfolio.path, None, no_bars=True)
assert c.run() == 0
@@ -38,15 +37,14 @@ def test_empty(
def test_one(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
today: datetime.date,
empty_portfolio: Portfolio,
transactions: list[Transaction],
asset: Asset,
) -> None:
- _ = transactions
- with empty_portfolio.begin_session() as s:
- s.query(Asset).where(Asset.category == AssetCategory.INDEX).delete()
+ with empty_portfolio.begin_session():
+ Asset.query().where(Asset.category == AssetCategory.INDEX).delete()
c = UpdateAssets(empty_portfolio.path, None, no_bars=True)
assert c.run() == 0
@@ -62,15 +60,15 @@ def test_one(
def test_failed(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
empty_portfolio: Portfolio,
transactions: list[Transaction],
asset: Asset,
) -> None:
- _ = transactions
- with empty_portfolio.begin_session() as s:
- s.query(Asset).where(Asset.category == AssetCategory.INDEX).delete()
- s.query(Asset).update({"ticker": "FAKE"})
+ with empty_portfolio.begin_session():
+ Asset.query().where(Asset.category == AssetCategory.INDEX).delete()
+ Asset.query().update({"ticker": "FAKE"})
+ asset.refresh()
c = UpdateAssets(empty_portfolio.path, None, no_bars=True)
assert c.run() != 0
diff --git a/tests/conftest.py b/tests/conftest.py
index cad6d7db..4c324fe0 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -26,6 +26,7 @@
AssetValuation,
USSector,
)
+from nummus.models.base import Base
from nummus.models.budget import (
BudgetAssignment,
BudgetGroup,
@@ -41,6 +42,8 @@
from tests.mock_yfinance import MockTicker
if TYPE_CHECKING:
+ from collections.abc import Generator
+
import time_machine
@@ -48,7 +51,7 @@ def id_func(val: object) -> str | None:
if isinstance(val, datetime.date):
return val.isoformat()
if isinstance(val, Iterable | Decimal | Path):
- return str(val)
+ return str(val) # type: ignore[attr-defined]
if callable(val):
return val.__name__
return None
@@ -211,21 +214,23 @@ def empty_portfolio_encrypted(
return p, key
-@pytest.fixture
-def session(empty_portfolio: Portfolio) -> orm.Session:
+@pytest.fixture(autouse=True)
+def session(empty_portfolio: Portfolio) -> Generator[orm.Session]:
"""Create SQL session.
- Returns:
- Session generator
+ Yields:
+ Session
"""
- return orm.Session(sql.get_engine(empty_portfolio.path, None))
+ s = orm.Session(sql.get_engine(empty_portfolio.path, None))
+ with Base.set_session(s):
+ yield s
-@pytest.fixture(autouse=True)
+@pytest.fixture(autouse=True, scope="session")
def uri_cipher() -> None:
"""Generate a URI cipher."""
- base_uri._CIPHER = base_uri.Cipher.generate()
+ base_uri._cipher = base_uri.Cipher.generate()
@pytest.fixture(autouse=True)
@@ -308,18 +313,16 @@ def account(session: orm.Session, rand_str_generator: RandomStringGenerator) ->
Checking Account, not closed, budgeted
"""
- acct = Account(
- name="Monkey bank checking",
- institution="Monkey bank",
- category=AccountCategory.CASH,
- closed=False,
- budgeted=True,
- currency=DEFAULT_CURRENCY,
- number=rand_str_generator(),
- )
- session.add(acct)
- session.commit()
- return acct
+ with session.begin_nested():
+ return Account.create(
+ name="Monkey bank checking",
+ institution="Monkey bank",
+ category=AccountCategory.CASH,
+ closed=False,
+ budgeted=True,
+ currency=DEFAULT_CURRENCY,
+ number=rand_str_generator(),
+ )
@pytest.fixture
@@ -330,19 +333,17 @@ def account_savings(session: orm.Session) -> Account:
Savings Account, not closed, not budgeted
"""
- acct = Account(
- # capital case for HTML header check
- name="Monkey bank savings",
- institution="Monkey bank",
- category=AccountCategory.CASH,
- closed=False,
- budgeted=False,
- currency=DEFAULT_CURRENCY,
- number="1234",
- )
- session.add(acct)
- session.commit()
- return acct
+ with session.begin_nested():
+ return Account.create(
+ # capital case for HTML header check
+ name="Monkey bank savings",
+ institution="Monkey bank",
+ category=AccountCategory.CASH,
+ closed=False,
+ budgeted=False,
+ currency=DEFAULT_CURRENCY,
+ number="1234",
+ )
@pytest.fixture
@@ -353,18 +354,16 @@ def account_investments(session: orm.Session) -> Account:
Investments Account, not closed, not budgeted
"""
- acct = Account(
- name="Monkey bank investments",
- institution="Monkey bank",
- category=AccountCategory.INVESTMENT,
- closed=False,
- budgeted=False,
- currency=DEFAULT_CURRENCY,
- number="1235",
- )
- session.add(acct)
- session.commit()
- return acct
+ with session.begin_nested():
+ return Account.create(
+ name="Monkey bank investments",
+ institution="Monkey bank",
+ category=AccountCategory.INVESTMENT,
+ closed=False,
+ budgeted=False,
+ currency=DEFAULT_CURRENCY,
+ number="1235",
+ )
@pytest.fixture
@@ -375,7 +374,7 @@ def categories(session: orm.Session) -> dict[str, int]:
dict{name: category id}
"""
- return {name: id_ for id_, name in TransactionCategory.map_name(session).items()}
+ return {name: id_ for id_, name in TransactionCategory.map_name().items()}
@pytest.fixture
@@ -386,10 +385,11 @@ def labels(session: orm.Session) -> dict[str, int]:
dict{name: label id}
"""
- labels = {"engineer", "fruit", "apartments 4 U"}
- session.add_all(Label(name=name) for name in labels)
- session.commit()
- return {name: id_ for id_, name in Label.map_name(session).items()}
+ with session.begin_nested():
+ labels = {"engineer", "fruit", "apartments 4 U"}
+ for name in labels:
+ Label.create(name=name)
+ return {name: id_ for id_, name in Label.map_name().items()}
@pytest.fixture
@@ -400,16 +400,14 @@ def asset(session: orm.Session) -> Asset:
Banana Incorporated, STOCKS
"""
- asset = Asset(
- name="Banana incorporated",
- category=AssetCategory.STOCKS,
- ticker="BANANA",
- description="Banana Incorporated makes bananas",
- currency=DEFAULT_CURRENCY,
- )
- session.add(asset)
- session.commit()
- return asset
+ with session.begin_nested():
+ return Asset.create(
+ name="Banana incorporated",
+ category=AssetCategory.STOCKS,
+ ticker="BANANA",
+ description="Banana Incorporated makes bananas",
+ currency=DEFAULT_CURRENCY,
+ )
@pytest.fixture
@@ -420,16 +418,14 @@ def asset_etf(session: orm.Session) -> Asset:
Banana ETF, STOCKS
"""
- asset = Asset(
- name="Banana ETF",
- category=AssetCategory.STOCKS,
- ticker="BANANA_ETF",
- description="Banana ETF",
- currency=DEFAULT_CURRENCY,
- )
- session.add(asset)
- session.commit()
- return asset
+ with session.begin_nested():
+ return Asset.create(
+ name="Banana ETF",
+ category=AssetCategory.STOCKS,
+ ticker="BANANA_ETF",
+ description="Banana ETF",
+ currency=DEFAULT_CURRENCY,
+ )
@pytest.fixture
@@ -444,10 +440,8 @@ def asset_valuation(
AssetValuation on today of $10
"""
- v = AssetValuation(asset_id=asset.id_, date_ord=today_ord, value=2)
- session.add(v)
- session.commit()
- return v
+ with session.begin_nested():
+ return AssetValuation.create(asset_id=asset.id_, date_ord=today_ord, value=2)
@pytest.fixture
@@ -462,10 +456,8 @@ def asset_split(
AssetSplit on today of 10:1
"""
- v = AssetSplit(asset_id=asset.id_, date_ord=today_ord, multiplier=10)
- session.add(v)
- session.commit()
- return v
+ with session.begin_nested():
+ return AssetSplit.create(asset_id=asset.id_, date_ord=today_ord, multiplier=10)
@pytest.fixture
@@ -479,19 +471,18 @@ def asset_sectors(
20% BASIC_MATERIALS, 80% TECHNOLOGY
"""
- s0 = AssetSector(
- asset_id=asset.id_,
- sector=USSector.BASIC_MATERIALS,
- weight=Decimal("0.2"),
- )
- s1 = AssetSector(
- asset_id=asset.id_,
- sector=USSector.TECHNOLOGY,
- weight=Decimal("0.8"),
- )
- session.add_all((s0, s1))
- session.commit()
- return s0, s1
+ with session.begin_nested():
+ s0 = AssetSector.create(
+ asset_id=asset.id_,
+ sector=USSector.BASIC_MATERIALS,
+ weight=Decimal("0.2"),
+ )
+ s1 = AssetSector.create(
+ asset_id=asset.id_,
+ sector=USSector.TECHNOLOGY,
+ weight=Decimal("0.8"),
+ )
+ return s0, s1
@pytest.fixture
@@ -505,10 +496,8 @@ def budget_group(
BudgetGroup with position 0
"""
- g = BudgetGroup(name=rand_str_generator(), position=0)
- session.add(g)
- session.commit()
- return g
+ with session.begin_nested():
+ return BudgetGroup.create(name=rand_str_generator(), position=0)
@pytest.fixture
@@ -521,84 +510,80 @@ def transactions(
categories: dict[str, int],
labels: dict[str, int],
) -> list[Transaction]:
- # Fund account on 3 days before today
- txn = Transaction(
- account_id=account.id_,
- date=today - datetime.timedelta(days=3),
- amount=100,
- statement=rand_str_generator(),
- payee="Monkey Bank",
- cleared=True,
- )
- t_split_0 = TransactionSplit(
- parent=txn,
- amount=txn.amount,
- category_id=categories["other income"],
- )
- session.add_all((txn, t_split_0))
-
- # Buy asset on 2 days before today
- txn = Transaction(
- account_id=account.id_,
- date=today - datetime.timedelta(days=2),
- amount=-10,
- statement=rand_str_generator(),
- payee="Monkey Bank",
- cleared=True,
- )
- t_split_1 = TransactionSplit(
- parent=txn,
- amount=txn.amount,
- asset_id=asset.id_,
- asset_quantity_unadjusted=10,
- category_id=categories["securities traded"],
- )
- session.add_all((txn, t_split_1))
-
- # Sell asset tomorrow
- txn = Transaction(
- account_id=account.id_,
- date=today + datetime.timedelta(days=1),
- amount=50,
- statement=rand_str_generator(),
- payee="Monkey Bank",
- cleared=True,
- )
- t_split = TransactionSplit(
- parent=txn,
- amount=txn.amount,
- asset_id=asset.id_,
- asset_quantity_unadjusted=-5,
- category_id=categories["securities traded"],
- memo="for rent",
- )
- session.add_all((txn, t_split))
-
- # Sell remaining next week
- txn = Transaction(
- account_id=account.id_,
- date=today + datetime.timedelta(days=7),
- amount=50,
- statement=rand_str_generator(),
- payee="Monkey Bank",
- cleared=True,
- )
- t_split = TransactionSplit(
- parent=txn,
- amount=txn.amount,
- asset_id=asset.id_,
- asset_quantity_unadjusted=-5,
- category_id=categories["securities traded"],
- memo="rent transfer",
- )
- session.add_all((txn, t_split))
+ with session.begin_nested():
+ # Fund account on 3 days before today
+ txn = Transaction.create(
+ account_id=account.id_,
+ date=today - datetime.timedelta(days=3),
+ amount=100,
+ statement=rand_str_generator(),
+ payee="Monkey Bank",
+ cleared=True,
+ )
+ t_split_0 = TransactionSplit.create(
+ parent=txn,
+ amount=txn.amount,
+ category_id=categories["other income"],
+ )
- session.commit()
+ # Buy asset on 2 days before today
+ txn = Transaction.create(
+ account_id=account.id_,
+ date=today - datetime.timedelta(days=2),
+ amount=-10,
+ statement=rand_str_generator(),
+ payee="Monkey Bank",
+ cleared=True,
+ )
+ t_split_1 = TransactionSplit.create(
+ parent=txn,
+ amount=txn.amount,
+ asset_id=asset.id_,
+ asset_quantity_unadjusted=10,
+ category_id=categories["securities traded"],
+ )
+
+ # Sell asset tomorrow
+ txn = Transaction.create(
+ account_id=account.id_,
+ date=today + datetime.timedelta(days=1),
+ amount=50,
+ statement=rand_str_generator(),
+ payee="Monkey Bank",
+ cleared=True,
+ )
+ TransactionSplit.create(
+ parent=txn,
+ amount=txn.amount,
+ asset_id=asset.id_,
+ asset_quantity_unadjusted=-5,
+ category_id=categories["securities traded"],
+ memo="for rent",
+ )
- session.add(LabelLink(label_id=labels["engineer"], t_split_id=t_split_0.id_))
- session.add(LabelLink(label_id=labels["engineer"], t_split_id=t_split_1.id_))
- session.commit()
- return session.query(Transaction).order_by(Transaction.date_ord).all()
+ # Sell remaining next week
+ txn = Transaction.create(
+ account_id=account.id_,
+ date=today + datetime.timedelta(days=7),
+ amount=50,
+ statement=rand_str_generator(),
+ payee="Monkey Bank",
+ cleared=True,
+ )
+ TransactionSplit.create(
+ parent=txn,
+ amount=txn.amount,
+ asset_id=asset.id_,
+ asset_quantity_unadjusted=-5,
+ category_id=categories["securities traded"],
+ memo="rent transfer",
+ )
+
+ LabelLink.create(label_id=labels["engineer"], t_split_id=t_split_0.id_)
+ LabelLink.create(label_id=labels["engineer"], t_split_id=t_split_1.id_)
+ return (
+ Transaction.query().order_by(Transaction.date_ord).all() # nummus: ignore
+ )
@pytest.fixture
@@ -612,59 +597,113 @@ def transactions_spending(
categories: dict[str, int],
labels: dict[str, int],
) -> list[Transaction]:
- statement_income = rand_str_generator()
- statement_groceries = rand_str_generator()
- statement_rent = rand_str_generator()
- specs = [
- (account, Decimal(100), statement_income, "other income"),
- (account, Decimal(100), statement_income, "other income"),
- (account, Decimal(120), statement_income, "other income"),
- (account, Decimal(-10), statement_groceries, "groceries"),
- (account, Decimal(-10), statement_groceries + " other word", "groceries"),
- (account, Decimal(-50), statement_rent, "rent"),
- (account, Decimal(1000), rand_str_generator(), "other income"),
- (account_savings, Decimal(100), statement_income, "other income"),
- ]
- for acct, amount, statement, category in specs:
- txn = Transaction(
- account_id=acct.id_,
+ with session.begin_nested():
+ statement_income = rand_str_generator()
+ statement_groceries = rand_str_generator()
+ statement_rent = rand_str_generator()
+ specs = [
+ (account, Decimal(100), statement_income, "other income"),
+ (account, Decimal(100), statement_income, "other income"),
+ (account, Decimal(120), statement_income, "other income"),
+ (account, Decimal(-10), statement_groceries, "groceries"),
+ (account, Decimal(-10), statement_groceries + " other word", "groceries"),
+ (account, Decimal(-50), statement_rent, "rent"),
+ (account, Decimal(1000), rand_str_generator(), "other income"),
+ (account_savings, Decimal(100), statement_income, "other income"),
+ ]
+ for acct, amount, statement, category in specs:
+ txn = Transaction.create(
+ account_id=acct.id_,
+ date=today,
+ amount=amount,
+ statement=statement,
+ )
+ TransactionSplit.create(
+ parent=txn,
+ amount=txn.amount,
+ category_id=categories[category],
+ )
+
+ txn = Transaction.create(
+ account_id=account.id_,
date=today,
- amount=amount,
- statement=statement,
+ amount=-50,
+ statement=statement_rent + " other word",
)
- t_split = TransactionSplit(
+ TransactionSplit.create(
parent=txn,
amount=txn.amount,
- category_id=categories[category],
+ asset_id=asset.id_,
+ asset_quantity_unadjusted=10,
+ category_id=categories["securities traded"],
)
- session.add_all((txn, t_split))
- txn = Transaction(
- account_id=account.id_,
- date=today,
- amount=-50,
- statement=statement_rent + " other word",
- )
- t_split = TransactionSplit(
- parent=txn,
- amount=txn.amount,
- asset_id=asset.id_,
- asset_quantity_unadjusted=10,
- category_id=categories["securities traded"],
- )
- session.add_all((txn, t_split))
+ query = TransactionSplit.query(TransactionSplit.id_).where(
+ TransactionSplit.category_id == categories["rent"],
+ )
+ t_split_id = sql.one(query)
+ LabelLink.create(label_id=labels["apartments 4 U"], t_split_id=t_split_id)
- session.commit()
+ return (
+ Transaction.query().order_by(Transaction.date_ord).all() # nummus: ignore
+ )
- t_split_id = (
- session.query(TransactionSplit.id_)
- .where(TransactionSplit.category_id == categories["rent"])
- .one()[0]
- )
- session.add(LabelLink(label_id=labels["apartments 4 U"], t_split_id=t_split_id))
- session.commit()
- return session.query(Transaction).order_by(Transaction.date_ord).all()
+@pytest.fixture
+def budget_assignments(
+ month: datetime.date,
+ month_ord: int,
+ session: orm.Session,
+ categories: dict[str, int],
+) -> list[BudgetAssignment]:
+ """Create BudgetAssignments.
+
+ Returns:
+ [
+ BudgetAssignment this month for $50 of groceries,
+ BudgetAssignment this month for $100 of emergency fund,
+ BudgetAssignment next month for $2000 of rent,
+ ]
+
+ """
+ with session.begin_nested():
+ BudgetAssignment.create(
+ month_ord=month_ord,
+ amount=Decimal(50),
+ category_id=categories["groceries"],
+ )
+ BudgetAssignment.create(
+ month_ord=month_ord,
+ amount=Decimal(100),
+ category_id=categories["emergency fund"],
+ )
+ BudgetAssignment.create(
+ month_ord=utils.date_add_months(month, 1).toordinal(),
+ amount=Decimal(2000),
+ category_id=categories["rent"],
+ )
+ return BudgetAssignment.all()
+
+
+@pytest.fixture
+def budget_target(
+ session: orm.Session,
+ categories: dict[str, int],
+) -> Target:
+ """Create a budget target.
+
+ Returns:
+ Target for Emergency Fund, $1000, no due date
+
+ """
+ with session.begin_nested():
+ return Target.create(
+ category_id=categories["emergency fund"],
+ amount=Decimal(1000),
+ type_=TargetType.BALANCE,
+ period=TargetPeriod.ONCE,
+ repeat_every=0,
+ )
@pytest.fixture(autouse=True)
@@ -719,8 +758,7 @@ def __init__(
class MockExtension(web.FlaskExtension):
@override
@classmethod
- def _open_portfolio(cls, config: flask.Config) -> Portfolio:
- _ = config
+ def _open_portfolio(cls, config: dict[str, object]) -> Portfolio:
return generator()[0]
self._ext = MockExtension()
@@ -797,65 +835,3 @@ def flask_app_encrypted(
"""
return flask_app_encrypted_generator(empty_portfolio_encrypted[0])
-
-
-@pytest.fixture
-def budget_assignments(
- month: datetime.date,
- month_ord: int,
- session: orm.Session,
- categories: dict[str, int],
-) -> list[BudgetAssignment]:
- """Create BudgetAssignments.
-
- Returns:
- [
- BudgetAssignment this month for $50 of groceries,
- BudgetAssignment this month for $100 of emergency fund,
- BudgetAssignment next month for $2000 of rent,
- ]
-
- """
- b = BudgetAssignment(
- month_ord=month_ord,
- amount=Decimal(50),
- category_id=categories["groceries"],
- )
- session.add(b)
- b = BudgetAssignment(
- month_ord=month_ord,
- amount=Decimal(100),
- category_id=categories["emergency fund"],
- )
- session.add(b)
- b = BudgetAssignment(
- month_ord=utils.date_add_months(month, 1).toordinal(),
- amount=Decimal(2000),
- category_id=categories["rent"],
- )
- session.add(b)
- session.commit()
- return list(session.query(BudgetAssignment).all())
-
-
-@pytest.fixture
-def budget_target(
- session: orm.Session,
- categories: dict[str, int],
-) -> Target:
- """Create a budget target.
-
- Returns:
- Target for Emergency Fund, $1000, no due date
-
- """
- target = Target(
- category_id=categories["emergency fund"],
- amount=Decimal(1000),
- type_=TargetType.BALANCE,
- period=TargetPeriod.ONCE,
- repeat_every=0,
- )
- session.add(target)
- session.commit()
- return target
diff --git a/tests/controllers/accounts/conftest.py b/tests/controllers/accounts/conftest.py
index dc1ef3ce..35efc12f 100644
--- a/tests/controllers/accounts/conftest.py
+++ b/tests/controllers/accounts/conftest.py
@@ -25,56 +25,55 @@ def transactions(
categories: dict[str, int],
transactions: list[Transaction],
) -> list[Transaction]:
- _ = transactions
- # Add dividends yesterday
- txn = Transaction(
- account_id=account.id_,
- date=today - datetime.timedelta(days=1),
- amount=0,
- statement=rand_str_generator(),
- payee="Monkey Bank",
- cleared=True,
- )
- t_split_0 = TransactionSplit(
- parent=txn,
- amount=-1,
- asset_id=asset.id_,
- asset_quantity_unadjusted=1,
- category_id=categories["securities traded"],
- )
- t_split_1 = TransactionSplit(
- parent=txn,
- amount=1,
- asset_id=asset.id_,
- asset_quantity_unadjusted=0,
- category_id=categories["dividends received"],
- )
- session.add_all((txn, t_split_0, t_split_1))
+ with session.begin_nested():
+ # Add dividends yesterday
+ txn = Transaction.create(
+ account_id=account.id_,
+ date=today - datetime.timedelta(days=1),
+ amount=0,
+ statement=rand_str_generator(),
+ payee="Monkey Bank",
+ cleared=True,
+ )
+ TransactionSplit.create(
+ parent=txn,
+ amount=-1,
+ asset_id=asset.id_,
+ asset_quantity_unadjusted=1,
+ category_id=categories["securities traded"],
+ )
+ TransactionSplit.create(
+ parent=txn,
+ amount=1,
+ asset_id=asset.id_,
+ asset_quantity_unadjusted=0,
+ category_id=categories["dividends received"],
+ )
- # Add fee today
- txn = Transaction(
- account_id=account.id_,
- date=today,
- amount=0,
- statement=rand_str_generator(),
- payee="Monkey Bank",
- cleared=True,
- )
- t_split_0 = TransactionSplit(
- parent=txn,
- amount=2,
- asset_id=asset.id_,
- asset_quantity_unadjusted=-2,
- category_id=categories["securities traded"],
- )
- t_split_1 = TransactionSplit(
- parent=txn,
- amount=-2,
- asset_id=asset.id_,
- asset_quantity_unadjusted=0,
- category_id=categories["investment fees"],
- )
- session.add_all((txn, t_split_0, t_split_1))
+ # Add fee today
+ txn = Transaction.create(
+ account_id=account.id_,
+ date=today,
+ amount=0,
+ statement=rand_str_generator(),
+ payee="Monkey Bank",
+ cleared=True,
+ )
+ TransactionSplit.create(
+ parent=txn,
+ amount=2,
+ asset_id=asset.id_,
+ asset_quantity_unadjusted=-2,
+ category_id=categories["securities traded"],
+ )
+ TransactionSplit.create(
+ parent=txn,
+ amount=-2,
+ asset_id=asset.id_,
+ asset_quantity_unadjusted=0,
+ category_id=categories["investment fees"],
+ )
- session.commit()
- return session.query(Transaction).order_by(Transaction.date_ord).all()
+ return (
+ Transaction.query().order_by(Transaction.date_ord).all() # nummus: ignore
+ )
diff --git a/tests/controllers/accounts/test_contexts.py b/tests/controllers/accounts/test_contexts.py
index 4be565f6..21c1f3d0 100644
--- a/tests/controllers/accounts/test_contexts.py
+++ b/tests/controllers/accounts/test_contexts.py
@@ -27,11 +27,10 @@
@pytest.mark.parametrize("skip_today", [False, True])
def test_ctx_account_empty(
today: datetime.date,
- session: orm.Session,
account: Account,
skip_today: bool,
) -> None:
- ctx = accounts.ctx_account(session, account, today, skip_today=skip_today)
+ ctx = accounts.ctx_account(account, today, skip_today=skip_today)
target: accounts.AccountContext = {
"uri": account.uri,
@@ -60,11 +59,10 @@ def test_ctx_account_empty(
def test_ctx_account(
today: datetime.date,
- session: orm.Session,
account: Account,
transactions: list[Transaction],
) -> None:
- ctx = accounts.ctx_account(session, account, today)
+ ctx = accounts.ctx_account(account, today)
target: accounts.AccountContext = {
"uri": account.uri,
@@ -93,14 +91,12 @@ def test_ctx_account(
def test_ctx_performance_empty(
today: datetime.date,
- session: orm.Session,
account: Account,
) -> None:
start = utils.date_add_months(today, -12)
labels, mode = base.date_labels(start.toordinal(), today.toordinal())
ctx = accounts.ctx_performance(
- session,
account,
today,
"1yr",
@@ -134,12 +130,11 @@ def test_ctx_performance(
asset_valuation: AssetValuation,
transactions: list[Transaction],
) -> None:
- asset_valuation.date_ord -= 7
- session.commit()
+ with session.begin_nested():
+ asset_valuation.date_ord -= 7
labels, mode = base.date_labels(transactions[0].date_ord, today.toordinal())
ctx = accounts.ctx_performance(
- session,
account,
today,
"max",
@@ -172,10 +167,9 @@ def test_ctx_performance(
def test_ctx_assets_empty(
today: datetime.date,
- session: orm.Session,
account: Account,
) -> None:
- assert accounts.ctx_assets(session, account, today) is None
+ assert accounts.ctx_assets(account, today) is None
def test_ctx_assets(
@@ -186,11 +180,10 @@ def test_ctx_assets(
asset_valuation: AssetValuation,
transactions: list[Transaction],
) -> None:
- _ = transactions
- asset_valuation.date_ord -= 7
- session.commit()
+ with session.begin_nested():
+ asset_valuation.date_ord -= 7
- ctx = accounts.ctx_assets(session, account, today)
+ ctx = accounts.ctx_assets(account, today)
target: list[accounts.AssetContext] = [
{
@@ -219,8 +212,8 @@ def test_ctx_assets(
assert ctx == target
-def test_ctx_accounts_empty(today: datetime.date, session: orm.Session) -> None:
- ctx = accounts.ctx_accounts(session, today)
+def test_ctx_accounts_empty(today: datetime.date) -> None:
+ ctx = accounts.ctx_accounts(today)
target: accounts.AllAccountsContext = {
"net_worth": Decimal(),
@@ -244,12 +237,10 @@ def test_ctx_accounts(
transactions: list[Transaction],
asset_valuation: AssetValuation,
) -> None:
- _ = transactions
- _ = asset_valuation
- account_investments.closed = True
- session.commit()
+ with session.begin_nested():
+ account_investments.closed = True
- ctx = accounts.ctx_accounts(session, today, include_closed=True)
+ ctx = accounts.ctx_accounts(today, include_closed=True)
target: accounts.AllAccountsContext = {
"net_worth": Decimal(108),
@@ -260,11 +251,11 @@ def test_ctx_accounts(
"categories": {
account.category: (
Decimal(108),
- [accounts.ctx_account(session, account, today)],
+ [accounts.ctx_account(account, today)],
),
account_investments.category: (
Decimal(),
- [accounts.ctx_account(session, account_investments, today)],
+ [accounts.ctx_account(account_investments, today)],
),
},
"include_closed": True,
diff --git a/tests/controllers/accounts/test_endpoints.py b/tests/controllers/accounts/test_endpoints.py
index fbcdd8a2..b189c0e6 100644
--- a/tests/controllers/accounts/test_endpoints.py
+++ b/tests/controllers/accounts/test_endpoints.py
@@ -45,7 +45,6 @@ def test_txns_options(
account: Account,
transactions: list[Transaction],
) -> None:
- _ = transactions
result, _ = web_client.GET(("accounts.txns_options", {"uri": account.uri}))
assert 'name="period"' in result
assert 'name="category"' in result
@@ -78,7 +77,6 @@ def test_page(
account: Account,
transactions: list[Transaction],
) -> None:
- _ = transactions
result, _ = web_client.GET(("accounts.page", {"uri": account.uri}))
assert "Transactions" in result
assert "Balance" in result
@@ -94,9 +92,8 @@ def test_page_performance(
account: Account,
transactions: list[Transaction],
) -> None:
- _ = transactions
- account.category = AccountCategory.INVESTMENT
- session.commit()
+ with session.begin_nested():
+ account.category = AccountCategory.INVESTMENT
result, _ = web_client.GET(("accounts.page", {"uri": account.uri}))
assert "Transactions" in result
@@ -133,7 +130,7 @@ def test_new(
assert "All changes saved" in result
assert "account" in headers["HX-Trigger"]
- account = session.query(Account).one()
+ account = Account.one()
assert account.name == "New name"
assert account.category == AccountCategory.INVESTMENT
assert account.currency == Currency.USD
@@ -186,7 +183,6 @@ def test_account_get(
account: Account,
transactions: list[Transaction],
) -> None:
- _ = transactions
result, _ = web_client.GET(("accounts.account", {"uri": account.uri}))
assert account.name in result
assert account.institution in result
@@ -242,7 +238,6 @@ def test_account_edit_error(
target: str,
transactions: list[Transaction],
) -> None:
- _ = transactions
form = {
"name": name,
"category": "INVESTMENT",
@@ -268,7 +263,6 @@ def test_performance(
account: Account,
transactions: list[Transaction],
) -> None:
- _ = transactions
result, headers = web_client.GET(("accounts.performance", {"uri": account.uri}))
assert headers["HX-Push-URL"] == web_client.url_for(
"accounts.page",
@@ -300,7 +294,6 @@ def test_validation(
value: str,
target: str,
) -> None:
- _ = account_investments
result, _ = web_client.GET(
(
"accounts.validation",
diff --git a/tests/controllers/allocation/test_contexts.py b/tests/controllers/allocation/test_contexts.py
index c537a592..a455dda1 100644
--- a/tests/controllers/allocation/test_contexts.py
+++ b/tests/controllers/allocation/test_contexts.py
@@ -9,8 +9,6 @@
if TYPE_CHECKING:
import datetime
- from sqlalchemy import orm
-
from nummus.models.asset import (
Asset,
AssetSector,
@@ -19,8 +17,8 @@
from nummus.models.transaction import Transaction
-def test_ctx_empty(today: datetime.date, session: orm.Session) -> None:
- ctx = allocation.ctx_allocation(session, today)
+def test_ctx_empty(today: datetime.date) -> None:
+ ctx = allocation.ctx_allocation(today)
target: allocation.AllocationContext = {
"chart": {
@@ -37,17 +35,12 @@ def test_ctx_empty(today: datetime.date, session: orm.Session) -> None:
def test_ctx(
today: datetime.date,
- session: orm.Session,
asset: Asset,
transactions: list[Transaction],
asset_valuation: AssetValuation,
asset_sectors: tuple[AssetSector, AssetSector],
) -> None:
- _ = transactions
- _ = asset_valuation
- _ = asset_sectors
-
- ctx = allocation.ctx_allocation(session, today)
+ ctx = allocation.ctx_allocation(today)
target: allocation.AllocationContext = {
"chart": {
diff --git a/tests/controllers/assets/test_contexts.py b/tests/controllers/assets/test_contexts.py
index 8c7ba68d..968aa74d 100644
--- a/tests/controllers/assets/test_contexts.py
+++ b/tests/controllers/assets/test_contexts.py
@@ -12,8 +12,6 @@
if TYPE_CHECKING:
import datetime
- from sqlalchemy import orm
-
from nummus.models.account import Account
from nummus.models.asset import (
Asset,
@@ -25,11 +23,10 @@
def test_ctx_performance_empty(
today: datetime.date,
- session: orm.Session,
asset: Asset,
) -> None:
start = utils.date_add_months(today, -12)
- ctx = assets.ctx_performance(session, asset, today, "1yr")
+ ctx = assets.ctx_performance(asset, today, "1yr")
labels, mode = base.date_labels(start.toordinal(), today.toordinal())
target: assets.PerformanceContext = {
"mode": mode,
@@ -46,11 +43,10 @@ def test_ctx_performance_empty(
def test_ctx_performance(
today: datetime.date,
- session: orm.Session,
asset: Asset,
asset_valuation: AssetValuation,
) -> None:
- ctx = assets.ctx_performance(session, asset, today, "max")
+ ctx = assets.ctx_performance(asset, today, "max")
labels, mode = base.date_labels(asset_valuation.date_ord, today.toordinal())
target: assets.PerformanceContext = {
"mode": mode,
@@ -68,10 +64,9 @@ def test_ctx_performance(
def test_ctx_table_empty(
today: datetime.date,
month: datetime.date,
- session: orm.Session,
asset: Asset,
) -> None:
- ctx = assets.ctx_table(session, asset, today, None, None, None, None)
+ ctx = assets.ctx_table(asset, today, None, None, None, None)
last_months = [utils.date_add_months(month, i) for i in range(0, -3, -1)]
options_period = [
@@ -111,7 +106,6 @@ def test_ctx_table_empty(
)
def test_ctx_table(
today: datetime.date,
- session: orm.Session,
asset: Asset,
asset_valuation: AssetValuation,
period: str | None,
@@ -121,7 +115,7 @@ def test_ctx_table(
any_filters: bool,
has_valuation: bool,
) -> None:
- ctx = assets.ctx_table(session, asset, today, period, start, end, page)
+ ctx = assets.ctx_table(asset, today, period, start, end, page)
if page is None:
assert ctx["first_page"]
@@ -151,10 +145,9 @@ def test_ctx_table(
def test_ctx_asset_empty(
today: datetime.date,
- session: orm.Session,
asset: Asset,
) -> None:
- ctx = assets.ctx_asset(session, asset, today, None, None, None, None, None)
+ ctx = assets.ctx_asset(asset, today, None, None, None, None, None)
assert ctx["uri"] == asset.uri
assert ctx["name"] == asset.name
assert ctx["category"] == asset.category
@@ -166,14 +159,12 @@ def test_ctx_asset_empty(
def test_ctx_asset(
today: datetime.date,
- session: orm.Session,
account: Account,
asset: Asset,
asset_valuation: AssetValuation,
transactions: list[Transaction],
) -> None:
- _ = transactions
- ctx = assets.ctx_asset(session, asset, today, None, None, None, None, None)
+ ctx = assets.ctx_asset(asset, today, None, None, None, None, None)
assert ctx["uri"] == asset.uri
assert ctx["name"] == asset.name
assert ctx["category"] == asset.category
@@ -185,17 +176,16 @@ def test_ctx_asset(
]
-def test_ctx_rows_empty(today: datetime.date, session: orm.Session) -> None:
- ctx = assets.ctx_rows(session, today, include_unheld=True)
+def test_ctx_rows_empty(today: datetime.date) -> None:
+ ctx = assets.ctx_rows(today, include_unheld=True)
assert ctx == {}
def test_ctx_rows_unheld(
today: datetime.date,
- session: orm.Session,
asset: Asset,
) -> None:
- ctx = assets.ctx_rows(session, today, include_unheld=True)
+ ctx = assets.ctx_rows(today, include_unheld=True)
target: dict[AssetCategory, list[assets.RowContext]] = {
asset.category: [
{
@@ -214,13 +204,11 @@ def test_ctx_rows_unheld(
def test_ctx_rows(
today: datetime.date,
- session: orm.Session,
asset: Asset,
asset_valuation: AssetValuation,
transactions: list[Transaction],
) -> None:
- _ = transactions
- ctx = assets.ctx_rows(session, today, include_unheld=False)
+ ctx = assets.ctx_rows(today, include_unheld=False)
target: dict[AssetCategory, list[assets.RowContext]] = {
asset.category: [
{
diff --git a/tests/controllers/assets/test_endpoints.py b/tests/controllers/assets/test_endpoints.py
index a36c785f..ae412313 100644
--- a/tests/controllers/assets/test_endpoints.py
+++ b/tests/controllers/assets/test_endpoints.py
@@ -4,6 +4,7 @@
import pytest
+from nummus import sql
from nummus.controllers import base
from nummus.models.asset import (
Asset,
@@ -11,7 +12,6 @@
AssetValuation,
)
from nummus.models.currency import Currency, CURRENCY_FORMATS, DEFAULT_CURRENCY
-from nummus.models.utils import query_count
if TYPE_CHECKING:
import datetime
@@ -56,8 +56,8 @@ def test_new(
web_client: WebClient,
session: orm.Session,
) -> None:
- session.query(Asset).delete()
- session.commit()
+ with session.begin_nested():
+ Asset.query().delete()
result, headers = web_client.POST(
"assets.new",
@@ -73,7 +73,7 @@ def test_new(
assert "All changes saved" in result
assert "asset" in headers["HX-Trigger"]
- a = session.query(Asset).one()
+ a = Asset.one()
assert a.name == "New name"
assert a.category == AssetCategory.STOCKS
assert a.currency == Currency.USD
@@ -112,7 +112,6 @@ def test_asset_get(
asset: Asset,
transactions: list[Transaction],
) -> None:
- _ = transactions
result, _ = web_client.GET(("assets.asset", {"uri": asset.uri}))
assert asset.name in result
assert asset.ticker is not None
@@ -124,7 +123,7 @@ def test_asset_get(
assert "Delete" not in result
-def test_asset_edit(web_client: WebClient, session: orm.Session, asset: Asset) -> None:
+def test_asset_edit(web_client: WebClient, asset: Asset) -> None:
result, headers = web_client.PUT(
("assets.asset", {"uri": asset.uri}),
data={
@@ -139,7 +138,7 @@ def test_asset_edit(web_client: WebClient, session: orm.Session, asset: Asset) -
assert "All changes saved" in result
assert "asset" in headers["HX-Trigger"]
- session.refresh(asset)
+ asset.refresh()
assert asset.name == "New name"
assert asset.category == AssetCategory.BONDS
assert asset.currency == Currency.EUR
@@ -166,7 +165,6 @@ def test_account_delete(
asset: Asset,
asset_valuation: AssetValuation,
) -> None:
- _ = asset_valuation
result, headers = web_client.DELETE(("assets.asset", {"uri": asset.uri}))
assert not result
assert headers["HX-Redirect"] == web_client.url_for("assets.page_all")
@@ -225,7 +223,6 @@ def test_validation(
value: str,
target: str,
) -> None:
- _ = asset_etf
args = {prop: value}
if include_asset:
args["uri"] = asset.uri
@@ -247,7 +244,6 @@ def test_new_valuation_get(
def test_new_valuation(
today: datetime.date,
- session: orm.Session,
web_client: WebClient,
asset: Asset,
rand_real: Decimal,
@@ -263,7 +259,7 @@ def test_new_valuation(
assert "All changes saved" in result
assert "valuation" in headers["HX-Trigger"]
- v = session.query(AssetValuation).one()
+ v = AssetValuation.one()
assert v.asset_id == asset.id_
assert v.date == today
assert v.value == rand_real
@@ -323,7 +319,6 @@ def test_valuation_get(
def test_valuation_delete(
- session: orm.Session,
web_client: WebClient,
asset_valuation: AssetValuation,
) -> None:
@@ -334,13 +329,12 @@ def test_valuation_delete(
assert f"{asset_valuation.date} valuation deleted" in result
assert "valuation" in headers["HX-Trigger"]
- v = session.query(AssetValuation).one_or_none()
+ v = AssetValuation.query().one_or_none()
assert v is None
def test_valuation_edit(
tomorrow: datetime.date,
- session: orm.Session,
web_client: WebClient,
asset_valuation: AssetValuation,
rand_real: Decimal,
@@ -356,7 +350,7 @@ def test_valuation_edit(
assert "All changes saved" in result
assert "valuation" in headers["HX-Trigger"]
- session.refresh(asset_valuation)
+ asset_valuation.refresh()
assert asset_valuation.date == tomorrow
assert asset_valuation.value == rand_real
@@ -396,13 +390,12 @@ def test_valuation_duplicate(
asset: Asset,
asset_valuation: AssetValuation,
) -> None:
- v = AssetValuation(
- asset_id=asset.id_,
- date_ord=tomorrow.toordinal(),
- value=asset_valuation.value,
- )
- session.add(v)
- session.commit()
+ with session.begin_nested():
+ AssetValuation.create(
+ asset_id=asset.id_,
+ date_ord=tomorrow.toordinal(),
+ value=asset_valuation.value,
+ )
result, _ = web_client.PUT(
("assets.valuation", {"uri": asset_valuation.uri}),
@@ -415,8 +408,8 @@ def test_valuation_duplicate(
def test_update_get_empty(session: orm.Session, web_client: WebClient) -> None:
- session.query(Asset).delete()
- session.commit()
+ with session.begin_nested():
+ Asset.query().delete()
result, _ = web_client.GET("assets.update")
assert "Update assets" in result
@@ -428,17 +421,16 @@ def test_update_get_one(
web_client: WebClient,
asset: Asset,
) -> None:
- _ = asset
- session.query(Asset).where(Asset.category == AssetCategory.INDEX).delete()
- session.commit()
+ with session.begin_nested():
+ Asset.query().where(Asset.category == AssetCategory.INDEX).delete()
result, _ = web_client.GET("assets.update")
assert "Update assets" in result
assert "There is one asset with ticker to update" in result
-def test_update_get(session: orm.Session, web_client: WebClient) -> None:
- n = query_count(session.query(Asset))
+def test_update_get(web_client: WebClient) -> None:
+ n = sql.count(Asset.query())
result, _ = web_client.GET("assets.update")
assert "Update assets" in result
@@ -456,10 +448,8 @@ def test_update(
asset: Asset,
transactions: list[Transaction],
) -> None:
- _ = asset
- _ = transactions
- session.query(Asset).where(Asset.category == AssetCategory.INDEX).delete()
- session.commit()
+ with session.begin_nested():
+ Asset.query().where(Asset.category == AssetCategory.INDEX).delete()
result, headers = web_client.POST("assets.update")
assert "snackbar.show" in result
@@ -471,6 +461,5 @@ def test_update_error(
web_client: WebClient,
transactions: list[Transaction],
) -> None:
- _ = transactions
result, _ = web_client.POST("assets.update")
assert "No timezone found, symbol may be delisted" in result
diff --git a/tests/controllers/budgeting/test_contexts.py b/tests/controllers/budgeting/test_contexts.py
index 93bc5d02..bd7fbfe8 100644
--- a/tests/controllers/budgeting/test_contexts.py
+++ b/tests/controllers/budgeting/test_contexts.py
@@ -6,7 +6,7 @@
import pytest
-from nummus import utils
+from nummus import sql, utils
from nummus.controllers import budgeting
from nummus.models.budget import (
BudgetAssignment,
@@ -33,18 +33,13 @@
def test_ctx_sidebar_global(
today: datetime.date,
month: datetime.date,
- session: orm.Session,
transactions_spending: list[Transaction],
budget_assignments: list[BudgetAssignment],
budget_target: Target,
) -> None:
- _ = transactions_spending
- _ = budget_assignments
- _ = budget_target
- data = BudgetAssignment.get_monthly_available(session, month)
+ data = BudgetAssignment.get_monthly_available(month)
ctx = budgeting.ctx_sidebar(
- session,
today,
month,
data.categories,
@@ -52,11 +47,11 @@ def test_ctx_sidebar_global(
None,
)
- query = session.query(TransactionCategory).where(
+ query = TransactionCategory.query().where(
TransactionCategory.name.not_in({"emergency fund"}),
TransactionCategory.group != TransactionCategoryGroup.INCOME,
)
- categories = {t_cat.uri: t_cat.emoji_name for t_cat in query.all()}
+ categories = {t_cat.uri: t_cat.emoji_name for t_cat in sql.yield_(query)}
target: budgeting.SidebarContext = {
"uri": None,
@@ -77,18 +72,14 @@ def test_ctx_sidebar_global(
def test_ctx_sidebar_no_target(
today: datetime.date,
month: datetime.date,
- session: orm.Session,
transactions_spending: list[Transaction],
budget_assignments: list[BudgetAssignment],
categories: dict[str, int],
) -> None:
- _ = transactions_spending
- _ = budget_assignments
- data = BudgetAssignment.get_monthly_available(session, month)
+ data = BudgetAssignment.get_monthly_available(month)
uri = TransactionCategory.id_to_uri(categories["emergency fund"])
ctx = budgeting.ctx_sidebar(
- session,
today,
month,
data.categories,
@@ -113,19 +104,15 @@ def test_ctx_sidebar_no_target(
def test_ctx_sidebar(
today: datetime.date,
month: datetime.date,
- session: orm.Session,
transactions_spending: list[Transaction],
budget_assignments: list[BudgetAssignment],
budget_target: Target,
) -> None:
- _ = transactions_spending
- _ = budget_assignments
- data = BudgetAssignment.get_monthly_available(session, month)
+ data = BudgetAssignment.get_monthly_available(month)
data_cat = data.categories[budget_target.category_id]
uri = TransactionCategory.id_to_uri(budget_target.category_id)
ctx = budgeting.ctx_sidebar(
- session,
today,
month,
data.categories,
@@ -191,8 +178,8 @@ def test_ctx_target_once(
budget_target: Target,
) -> None:
due_date = utils.date_add_months(month, 8)
- budget_target.due_date_ord = due_date.toordinal()
- session.commit()
+ with session.begin_nested():
+ budget_target.due_date_ord = due_date.toordinal()
ctx = budgeting.ctx_target(
budget_target,
@@ -226,11 +213,11 @@ def test_ctx_target_weekly_refil(
budget_target: Target,
) -> None:
month = utils.date_add_months(month, 1)
- budget_target.period = TargetPeriod.WEEK
- budget_target.type_ = TargetType.REFILL
- budget_target.due_date_ord = today.toordinal()
- budget_target.repeat_every = 1
- session.commit()
+ with session.begin_nested():
+ budget_target.period = TargetPeriod.WEEK
+ budget_target.type_ = TargetType.REFILL
+ budget_target.due_date_ord = today.toordinal()
+ budget_target.repeat_every = 1
n_weekdays = utils.weekdays_in_month(today.weekday(), month)
ctx = budgeting.ctx_target(
@@ -264,11 +251,11 @@ def test_ctx_target_weekly_accumulate(
session: orm.Session,
budget_target: Target,
) -> None:
- budget_target.period = TargetPeriod.WEEK
- budget_target.type_ = TargetType.ACCUMULATE
- budget_target.due_date_ord = today.toordinal()
- budget_target.repeat_every = 1
- session.commit()
+ with session.begin_nested():
+ budget_target.period = TargetPeriod.WEEK
+ budget_target.type_ = TargetType.ACCUMULATE
+ budget_target.due_date_ord = today.toordinal()
+ budget_target.repeat_every = 1
n_weekdays = utils.weekdays_in_month(today.weekday(), month)
ctx = budgeting.ctx_target(
@@ -303,11 +290,11 @@ def test_ctx_target_monthly_refil(
budget_target: Target,
) -> None:
month = utils.date_add_months(month, 1)
- budget_target.period = TargetPeriod.MONTH
- budget_target.type_ = TargetType.REFILL
- budget_target.due_date_ord = today.toordinal()
- budget_target.repeat_every = 2
- session.commit()
+ with session.begin_nested():
+ budget_target.period = TargetPeriod.MONTH
+ budget_target.type_ = TargetType.REFILL
+ budget_target.due_date_ord = today.toordinal()
+ budget_target.repeat_every = 2
ctx = budgeting.ctx_target(
budget_target,
@@ -340,11 +327,11 @@ def test_ctx_target_monthly_accumulate(
session: orm.Session,
budget_target: Target,
) -> None:
- budget_target.period = TargetPeriod.MONTH
- budget_target.type_ = TargetType.ACCUMULATE
- budget_target.due_date_ord = today.toordinal()
- budget_target.repeat_every = 1
- session.commit()
+ with session.begin_nested():
+ budget_target.period = TargetPeriod.MONTH
+ budget_target.type_ = TargetType.ACCUMULATE
+ budget_target.due_date_ord = today.toordinal()
+ budget_target.repeat_every = 1
ctx = budgeting.ctx_target(
budget_target,
@@ -374,11 +361,9 @@ def test_ctx_target_monthly_accumulate(
def test_ctx_budget_empty(
today: datetime.date,
month: datetime.date,
- session: orm.Session,
) -> None:
- data = BudgetAssignment.get_monthly_available(session, month)
+ data = BudgetAssignment.get_monthly_available(month)
ctx, title = budgeting.ctx_budget(
- session,
today,
month,
data.categories,
@@ -419,21 +404,18 @@ def test_ctx_budget(
budget_target: Target,
categories: dict[str, int],
) -> None:
- session.query(TransactionCategory).where(
- TransactionCategory.name == "groceries",
- ).update(
- {
- TransactionCategory.budget_group_id: budget_group.id_,
- TransactionCategory.budget_position: 0,
- },
- )
- session.commit()
- _ = transactions_spending
- _ = budget_assignments
- data = BudgetAssignment.get_monthly_available(session, month)
+ with session.begin_nested():
+ TransactionCategory.query().where(
+ TransactionCategory.name == "groceries",
+ ).update(
+ {
+ TransactionCategory.budget_group_id: budget_group.id_,
+ TransactionCategory.budget_position: 0,
+ },
+ )
+ data = BudgetAssignment.get_monthly_available(month)
ctx, title = budgeting.ctx_budget(
- session,
today,
month,
data.categories,
diff --git a/tests/controllers/budgeting/test_endpoints.py b/tests/controllers/budgeting/test_endpoints.py
index 1c2e568c..0aedd8da 100644
--- a/tests/controllers/budgeting/test_endpoints.py
+++ b/tests/controllers/budgeting/test_endpoints.py
@@ -7,6 +7,7 @@
import pytest
import werkzeug.datastructures
+from nummus import sql
from nummus.controllers import base, budgeting
from nummus.models.budget import (
BudgetAssignment,
@@ -16,7 +17,6 @@
TargetType,
)
from nummus.models.transaction_category import TransactionCategory
-from nummus.models.utils import query_count
if TYPE_CHECKING:
import flask
@@ -53,7 +53,6 @@ def test_validation(
def test_assign_new(
month: datetime.date,
- session: orm.Session,
web_client: WebClient,
categories: dict[str, int],
rand_real: Decimal,
@@ -68,7 +67,7 @@ def test_assign_new(
assert "Ungrouped" in result
- a = session.query(BudgetAssignment).one()
+ a = BudgetAssignment.one()
assert a.category_id == t_cat_id
assert a.month_ord == month.toordinal()
assert a.amount == round(rand_real, 2)
@@ -76,7 +75,6 @@ def test_assign_new(
def test_assign_edit(
month: datetime.date,
- session: orm.Session,
web_client: WebClient,
budget_assignments: list[BudgetAssignment],
) -> None:
@@ -90,14 +88,13 @@ def test_assign_edit(
assert "Ungrouped" in result
- session.refresh(a)
+ a.refresh()
assert a.month_ord == month.toordinal()
assert a.amount == Decimal(10)
def test_assign_remove(
month: datetime.date,
- session: orm.Session,
web_client: WebClient,
budget_assignments: list[BudgetAssignment],
) -> None:
@@ -111,11 +108,7 @@ def test_assign_remove(
assert "Ungrouped" in result
- a = (
- session.query(BudgetAssignment)
- .where(BudgetAssignment.id_ == a.id_)
- .one_or_none()
- )
+ a = BudgetAssignment.query().where(BudgetAssignment.id_ == a.id_).one_or_none()
assert a is None
@@ -125,8 +118,6 @@ def test_move_get_income(
transactions_spending: list[Transaction],
budget_assignments: list[BudgetAssignment],
) -> None:
- _ = transactions_spending
- _ = budget_assignments
result, _ = web_client.GET(
("budgeting.move", {"uri": "income", "month": month.isoformat()[:7]}),
)
@@ -143,8 +134,6 @@ def test_move_get(
categories: dict[str, int],
budget_assignments: list[BudgetAssignment],
) -> None:
- _ = transactions_spending
- _ = budget_assignments
t_cat_uri = TransactionCategory.id_to_uri(categories["groceries"])
result, _ = web_client.GET(
@@ -163,8 +152,6 @@ def test_move_get_destination(
categories: dict[str, int],
budget_assignments: list[BudgetAssignment],
) -> None:
- _ = transactions_spending
- _ = budget_assignments
t_cat_uri = TransactionCategory.id_to_uri(categories["groceries"])
result, _ = web_client.GET(
@@ -185,7 +172,6 @@ def test_move_get_overspending(
transactions_spending: list[Transaction],
categories: dict[str, int],
) -> None:
- _ = transactions_spending
t_cat_uri = TransactionCategory.id_to_uri(categories["groceries"])
result, _ = web_client.GET(
@@ -199,13 +185,11 @@ def test_move_get_overspending(
def test_move_overspending(
month: datetime.date,
- session: orm.Session,
web_client: WebClient,
transactions_spending: list[Transaction],
budget_assignments: list[BudgetAssignment],
categories: dict[str, int],
) -> None:
- _ = transactions_spending
a = budget_assignments[0]
uri = TransactionCategory.id_to_uri(categories["securities traded"])
dest_uri = TransactionCategory.id_to_uri(a.category_id)
@@ -218,19 +202,17 @@ def test_move_overspending(
assert "$30.00 reallocated" in result
assert "budget" in headers["HX-Trigger"]
- session.refresh(a)
+ a.refresh()
assert a.month_ord == month.toordinal()
assert a.amount == Decimal(20)
def test_move_to_income(
month: datetime.date,
- session: orm.Session,
web_client: WebClient,
transactions_spending: list[Transaction],
budget_assignments: list[BudgetAssignment],
) -> None:
- _ = transactions_spending
a = budget_assignments[0]
t_cat_uri = TransactionCategory.id_to_uri(a.category_id)
@@ -242,7 +224,7 @@ def test_move_to_income(
assert "$10.00 reallocated" in result
assert "budget" in headers["HX-Trigger"]
- session.refresh(a)
+ a.refresh()
assert a.month_ord == month.toordinal()
assert a.amount == Decimal(40)
@@ -263,11 +245,9 @@ def test_move_error(
def test_reorder_empty(
- session: orm.Session,
web_client: WebClient,
budget_group: BudgetGroup,
) -> None:
- _ = budget_group
result, _ = web_client.PUT(
"budgeting.reorder",
data={
@@ -278,37 +258,29 @@ def test_reorder_empty(
)
assert not result
- query = session.query(BudgetGroup)
- assert query_count(query) == 0
- query = session.query(TransactionCategory).where(
+ assert not sql.any_(BudgetGroup.query())
+ query = TransactionCategory.query().where(
TransactionCategory.budget_group_id.is_not(None),
)
- assert query_count(query) == 0
- query = session.query(TransactionCategory).where(
+ assert not sql.any_(query)
+ query = TransactionCategory.query().where(
TransactionCategory.budget_position.is_not(None),
)
- assert query_count(query) == 0
+ assert not sql.any_(query)
def test_reorder(
- session: orm.Session,
web_client: WebClient,
budget_group: BudgetGroup,
) -> None:
- t_cat_0 = (
- session.query(TransactionCategory)
- .where(TransactionCategory.name == "groceries")
- .one()
+ t_cat_0 = sql.one(
+ TransactionCategory.query().where(TransactionCategory.name == "groceries"),
)
- t_cat_1 = (
- session.query(TransactionCategory)
- .where(TransactionCategory.name == "rent")
- .one()
+ t_cat_1 = sql.one(
+ TransactionCategory.query().where(TransactionCategory.name == "rent"),
)
- t_cat_2 = (
- session.query(TransactionCategory)
- .where(TransactionCategory.name == "transfers")
- .one()
+ t_cat_2 = sql.one(
+ TransactionCategory.query().where(TransactionCategory.name == "transfers"),
)
result, _ = web_client.PUT(
@@ -321,15 +293,15 @@ def test_reorder(
)
assert not result
- session.refresh(t_cat_0)
+ t_cat_0.refresh()
assert t_cat_0.budget_group_id == budget_group.id_
assert t_cat_0.budget_position == 0
- session.refresh(t_cat_1)
+ t_cat_1.refresh()
assert t_cat_1.budget_group_id == budget_group.id_
assert t_cat_1.budget_position == 1
- session.refresh(t_cat_2)
+ t_cat_2.refresh()
assert t_cat_2.budget_group_id is None
assert t_cat_2.budget_position is None
@@ -500,10 +472,10 @@ def test_target_get_once(
web_client: WebClient,
budget_target: Target,
) -> None:
- budget_target.type_ = TargetType.BALANCE
- budget_target.period = TargetPeriod.ONCE
- budget_target.due_date_ord = today_ord
- session.commit()
+ with session.begin_nested():
+ budget_target.type_ = TargetType.BALANCE
+ budget_target.period = TargetPeriod.ONCE
+ budget_target.due_date_ord = today_ord
result, _ = web_client.GET(("budgeting.target", {"uri": budget_target.uri}))
assert "Edit target" in result
@@ -511,7 +483,6 @@ def test_target_get_once(
def test_target_new(
today_ord: int,
- session: orm.Session,
web_client: WebClient,
categories: dict[str, int],
) -> None:
@@ -527,7 +498,7 @@ def test_target_new(
assert "Groceries target created" in result
assert "budget" in headers["HX-Trigger"]
- tar = session.query(Target).one()
+ tar = Target.one()
assert tar.category_id == t_cat_id
assert tar.amount == Decimal(10)
assert tar.type_ == TargetType.ACCUMULATE
@@ -554,7 +525,6 @@ def test_target_new_error(
def test_target_put(
- session: orm.Session,
web_client: WebClient,
budget_target: Target,
) -> None:
@@ -566,12 +536,11 @@ def test_target_put(
assert "All changes saved" in result
assert "budget" in headers["HX-Trigger"]
- session.refresh(budget_target)
+ budget_target.refresh()
assert budget_target.amount == Decimal(10)
def test_target_delete(
- session: orm.Session,
web_client: WebClient,
budget_target: Target,
) -> None:
@@ -582,7 +551,7 @@ def test_target_delete(
assert "Emergency Fund target deleted" in result
assert "budget" in headers["HX-Trigger"]
- tar = session.query(Target).one_or_none()
+ tar = Target.query().one_or_none()
assert tar is None
diff --git a/tests/controllers/conftest.py b/tests/controllers/conftest.py
index 7558cd91..d918e5d0 100644
--- a/tests/controllers/conftest.py
+++ b/tests/controllers/conftest.py
@@ -254,7 +254,7 @@ def has_valid_classes(self, inner_html: str) -> bool:
ResultType = dict[str, object] | str | bytes
-Queries = dict[str, str] | dict[str, str | bool | list[str | bool]]
+Queries = dict[str, str] | dict[str, str | bool | list[str] | list[str | bool]]
class HTMLValidator:
diff --git a/tests/controllers/emergency_fund/test_contexts.py b/tests/controllers/emergency_fund/test_contexts.py
index e31280aa..95d1f921 100644
--- a/tests/controllers/emergency_fund/test_contexts.py
+++ b/tests/controllers/emergency_fund/test_contexts.py
@@ -19,12 +19,12 @@
from nummus.models.budget import BudgetAssignment
-def test_empty(today: datetime.date, session: orm.Session) -> None:
+def test_empty(today: datetime.date) -> None:
start = today - datetime.timedelta(days=utils.DAYS_IN_QUARTER * 2)
dates = utils.range_date(start.toordinal(), today.toordinal())
n = len(dates)
- ctx = emergency_fund.ctx_page(session, today)
+ ctx = emergency_fund.ctx_page(today)
target: emergency_fund.EFundContext = {
"chart": {
@@ -56,26 +56,23 @@ def test_ctx_underfunded(
budget_assignments: list[BudgetAssignment],
rand_str: str,
) -> None:
- _ = transactions_spending
- _ = budget_assignments
- session.query(TransactionCategory).where(
- TransactionCategory.name == "groceries",
- ).update({"essential_spending": True})
- txn = Transaction(
- account_id=account.id_,
- date=today - datetime.timedelta(days=100),
- amount=-1000,
- statement=rand_str,
- )
- t_split = TransactionSplit(
- parent=txn,
- amount=txn.amount,
- category_id=categories["groceries"],
- )
- session.add_all((txn, t_split))
- session.commit()
-
- ctx = emergency_fund.ctx_page(session, today)
+ with session.begin_nested():
+ TransactionCategory.query().where(
+ TransactionCategory.name == "groceries",
+ ).update({"essential_spending": True})
+ txn = Transaction.create(
+ account_id=account.id_,
+ date=today - datetime.timedelta(days=100),
+ amount=-1000,
+ statement=rand_str,
+ )
+ TransactionSplit.create(
+ parent=txn,
+ amount=txn.amount,
+ category_id=categories["groceries"],
+ )
+
+ ctx = emergency_fund.ctx_page(today)
assert ctx["current"] == Decimal(100)
assert ctx["days"] == pytest.approx(Decimal(34), abs=Decimal(1))
@@ -98,26 +95,23 @@ def test_ctx_overfunded(
budget_assignments: list[BudgetAssignment],
rand_str: str,
) -> None:
- _ = transactions_spending
- _ = budget_assignments
- session.query(TransactionCategory).where(
- TransactionCategory.name == "groceries",
- ).update({"essential_spending": True})
- txn = Transaction(
- account_id=account.id_,
- date=today - datetime.timedelta(days=100),
- amount=-50,
- statement=rand_str,
- )
- t_split = TransactionSplit(
- parent=txn,
- amount=txn.amount,
- category_id=categories["groceries"],
- )
- session.add_all((txn, t_split))
- session.commit()
-
- ctx = emergency_fund.ctx_page(session, today)
+ with session.begin_nested():
+ TransactionCategory.query().where(
+ TransactionCategory.name == "groceries",
+ ).update({"essential_spending": True})
+ txn = Transaction.create(
+ account_id=account.id_,
+ date=today - datetime.timedelta(days=100),
+ amount=-50,
+ statement=rand_str,
+ )
+ TransactionSplit.create(
+ parent=txn,
+ amount=txn.amount,
+ category_id=categories["groceries"],
+ )
+
+ ctx = emergency_fund.ctx_page(today)
assert ctx["current"] == Decimal(100)
assert ctx["days"] == pytest.approx(Decimal(347), abs=Decimal(1))
@@ -140,26 +134,23 @@ def test_ctx(
budget_assignments: list[BudgetAssignment],
rand_str: str,
) -> None:
- _ = transactions_spending
- _ = budget_assignments
- session.query(TransactionCategory).where(
- TransactionCategory.name == "groceries",
- ).update({"essential_spending": True})
- txn = Transaction(
- account_id=account.id_,
- date=today - datetime.timedelta(days=100),
- amount=-200,
- statement=rand_str,
- )
- t_split = TransactionSplit(
- parent=txn,
- amount=txn.amount,
- category_id=categories["groceries"],
- )
- session.add_all((txn, t_split))
- session.commit()
-
- ctx = emergency_fund.ctx_page(session, today)
+ with session.begin_nested():
+ TransactionCategory.query().where(
+ TransactionCategory.name == "groceries",
+ ).update({"essential_spending": True})
+ txn = Transaction.create(
+ account_id=account.id_,
+ date=today - datetime.timedelta(days=100),
+ amount=-200,
+ statement=rand_str,
+ )
+ TransactionSplit.create(
+ parent=txn,
+ amount=txn.amount,
+ category_id=categories["groceries"],
+ )
+
+ ctx = emergency_fund.ctx_page(today)
assert ctx["current"] == Decimal(100)
assert ctx["days"] == pytest.approx(Decimal(119), abs=Decimal(1))
diff --git a/tests/controllers/health/test_contexts.py b/tests/controllers/health/test_contexts.py
index 9ee0f0ab..46857d02 100644
--- a/tests/controllers/health/test_contexts.py
+++ b/tests/controllers/health/test_contexts.py
@@ -1,18 +1,13 @@
from __future__ import annotations
-from typing import TYPE_CHECKING
-
+from nummus import sql
from nummus.controllers import health
from nummus.health_checks.top import HEALTH_CHECKS
from nummus.models.transaction_category import TransactionCategory
-from nummus.models.utils import query_count
-
-if TYPE_CHECKING:
- from sqlalchemy import orm
-def test_ctx_empty(session: orm.Session) -> None:
- ctx = health.ctx_checks(session, run=False)
+def test_ctx_empty() -> None:
+ ctx = health.ctx_checks(run=False)
assert ctx["last_update_ago"] is None
checks = ctx["checks"]
@@ -21,8 +16,8 @@ def test_ctx_empty(session: orm.Session) -> None:
assert not has_issues
-def test_ctx_empty_run(session: orm.Session) -> None:
- ctx = health.ctx_checks(session, run=True)
+def test_ctx_empty_run() -> None:
+ ctx = health.ctx_checks(run=True)
assert ctx["last_update_ago"] == 0
checks = ctx["checks"]
@@ -33,7 +28,7 @@ def test_ctx_empty_run(session: orm.Session) -> None:
assert c["name"] == "Unused categories"
# All unused
- query = session.query(TransactionCategory).where(
+ query = TransactionCategory.query().where(
TransactionCategory.locked.is_(False),
)
- assert len(c["issues"]) == query_count(query)
+ assert len(c["issues"]) == sql.count(query)
diff --git a/tests/controllers/health/test_endpoints.py b/tests/controllers/health/test_endpoints.py
index b36571f1..24844fae 100644
--- a/tests/controllers/health/test_endpoints.py
+++ b/tests/controllers/health/test_endpoints.py
@@ -35,9 +35,9 @@ def test_refresh(web_client: WebClient, n_runs: int) -> None:
def test_ignore(web_client: WebClient, session: orm.Session) -> None:
- c = UnusedCategories()
- c.test(session)
- session.commit()
+ with session.begin_nested():
+ c = UnusedCategories()
+ c.test()
uri = next(iter(c.issues.keys()))
diff --git a/tests/controllers/income/test_endpoints.py b/tests/controllers/income/test_endpoints.py
index 09285e0a..47f59bc1 100644
--- a/tests/controllers/income/test_endpoints.py
+++ b/tests/controllers/income/test_endpoints.py
@@ -17,7 +17,6 @@ def test_page(
account: Account,
transactions: list[Transaction],
) -> None:
- _ = transactions
result, _ = web_client.GET(("income.page", {"period": "all"}))
assert "Income" in result
assert account.name in result
@@ -30,7 +29,6 @@ def test_chart(
account: Account,
transactions: list[Transaction],
) -> None:
- _ = transactions
result, _ = web_client.GET(("income.chart", {"period": "all"}))
assert "Income" in result
assert account.name in result
@@ -45,9 +43,9 @@ def test_dashboard(
account: Account,
transactions: list[Transaction],
) -> None:
- transactions[0].date = today
- transactions[0].splits[0].parent = transactions[0]
- session.commit()
+ with session.begin_nested():
+ transactions[0].date = today
+ transactions[0].splits[0].parent = transactions[0]
result, _ = web_client.GET("income.dashboard")
assert "Income" in result
diff --git a/tests/controllers/labels/test_contexts.py b/tests/controllers/labels/test_contexts.py
index 545e19c4..d6967247 100644
--- a/tests/controllers/labels/test_contexts.py
+++ b/tests/controllers/labels/test_contexts.py
@@ -1,17 +1,12 @@
from __future__ import annotations
-from typing import TYPE_CHECKING
-
from nummus.controllers import base
from nummus.controllers import labels as label_controller
from nummus.models.label import Label
-if TYPE_CHECKING:
- from sqlalchemy import orm
-
-def test_ctx(session: orm.Session, labels: dict[str, int]) -> None:
- ctx = label_controller.ctx_labels(session)
+def test_ctx(labels: dict[str, int]) -> None:
+ ctx = label_controller.ctx_labels()
target: list[base.NamePair] = [
base.NamePair(Label.id_to_uri(label_id), name)
diff --git a/tests/controllers/labels/test_endpoints.py b/tests/controllers/labels/test_endpoints.py
index 69d9e40f..b8413511 100644
--- a/tests/controllers/labels/test_endpoints.py
+++ b/tests/controllers/labels/test_endpoints.py
@@ -4,12 +4,11 @@
import pytest
+from nummus import sql
from nummus.controllers import base
from nummus.models.label import Label, LabelLink
-from nummus.models.utils import query_count
if TYPE_CHECKING:
- from sqlalchemy import orm
from nummus.models.transaction import Transaction
from tests.controllers.conftest import WebClient
@@ -53,10 +52,8 @@ def test_label_get(web_client: WebClient, labels: dict[str, int]) -> None:
def test_label_delete(
web_client: WebClient,
labels: dict[str, int],
- session: orm.Session,
transactions: list[Transaction],
) -> None:
- _ = transactions
uri = Label.id_to_uri(labels["engineer"])
result, headers = web_client.DELETE(
@@ -66,14 +63,12 @@ def test_label_delete(
assert "Deleted label engineer" in result
assert "label" in headers["HX-Trigger"]
- n = query_count(session.query(LabelLink))
- assert n == 0
+ assert not sql.any_(LabelLink.query())
def test_label_edit(
web_client: WebClient,
labels: dict[str, int],
- session: orm.Session,
) -> None:
uri = Label.id_to_uri(labels["engineer"])
@@ -85,7 +80,7 @@ def test_label_edit(
assert "All changes saved" in result
assert "label" in headers["HX-Trigger"]
- label = session.query(Label).where(Label.name == "new label").one()
+ label = sql.one(Label.query().where(Label.name == "new label"))
assert label.id_ == labels["engineer"]
diff --git a/tests/controllers/net_worth/test_contexts.py b/tests/controllers/net_worth/test_contexts.py
index 5b06b2c6..f7ffdfa8 100644
--- a/tests/controllers/net_worth/test_contexts.py
+++ b/tests/controllers/net_worth/test_contexts.py
@@ -19,10 +19,8 @@
def test_ctx_chart_empty(
today: datetime.date,
account: Account,
- session: orm.Session,
) -> None:
- _ = account
- ctx = net_worth.ctx_chart(session, today, "max")
+ ctx = net_worth.ctx_chart(today, "max")
chart: base.ChartData = {
"labels": [today.isoformat()],
@@ -51,9 +49,8 @@ def test_ctx_chart_empty(
def test_ctx_chart_this_year(
today: datetime.date,
- session: orm.Session,
) -> None:
- ctx = net_worth.ctx_chart(session, today, "ytd")
+ ctx = net_worth.ctx_chart(today, "ytd")
assert ctx["start"] == today.replace(month=1, day=1)
assert ctx["end"] == today
@@ -69,28 +66,25 @@ def test_ctx_chart(
session: orm.Session,
categories: dict[str, int],
) -> None:
- _ = asset_valuation
- _ = transactions
- # Make account_investments negative
- txn = Transaction(
- account_id=account_investments.id_,
- date=today,
- amount=-100,
- statement=rand_str_generator(),
- payee="Monkey Bank",
- cleared=True,
- )
- t_split = TransactionSplit(
- parent=txn,
- amount=txn.amount,
- category_id=categories["groceries"],
- )
- session.add_all((txn, t_split))
- session.commit()
+ with session.begin_nested():
+ # Make account_investments negative
+ txn = Transaction.create(
+ account_id=account_investments.id_,
+ date=today,
+ amount=-100,
+ statement=rand_str_generator(),
+ payee="Monkey Bank",
+ cleared=True,
+ )
+ TransactionSplit.create(
+ parent=txn,
+ amount=txn.amount,
+ category_id=categories["groceries"],
+ )
start = today - datetime.timedelta(days=3)
end = today + datetime.timedelta(days=3)
- ctx = net_worth.ctx_chart(session, end, "max")
+ ctx = net_worth.ctx_chart(end, "max")
chart: base.ChartData = {
"labels": base.date_labels(start.toordinal(), end.toordinal())[0],
diff --git a/tests/controllers/net_worth/test_endpoints.py b/tests/controllers/net_worth/test_endpoints.py
index 55d99ae5..f1668cf9 100644
--- a/tests/controllers/net_worth/test_endpoints.py
+++ b/tests/controllers/net_worth/test_endpoints.py
@@ -23,9 +23,6 @@ def test_page(
asset_valuation: AssetValuation,
transactions: list[Transaction],
) -> None:
- _ = asset_valuation
- _ = transactions
-
result, _ = web_client.GET("net_worth.page")
assert "Net worth" in result
assert "Assets" in result
@@ -39,10 +36,6 @@ def test_chart(
asset_valuation: AssetValuation,
transactions: list[Transaction],
) -> None:
- _ = account
- _ = asset_valuation
- _ = transactions
-
result, headers = web_client.GET("net_worth.chart")
assert headers["HX-Push-URL"] == web_client.url_for(
"net_worth.page",
@@ -57,10 +50,6 @@ def test_dashboard(
asset_valuation: AssetValuation,
transactions: list[Transaction],
) -> None:
- _ = account
- _ = asset_valuation
- _ = transactions
-
result, _ = web_client.GET("net_worth.dashboard")
assert "Net worth" in result
assert "JSON.parse" in result
diff --git a/tests/controllers/performance/test_contexts.py b/tests/controllers/performance/test_contexts.py
index eac08cf2..195f3035 100644
--- a/tests/controllers/performance/test_contexts.py
+++ b/tests/controllers/performance/test_contexts.py
@@ -4,13 +4,13 @@
from decimal import Decimal
from typing import TYPE_CHECKING
+from nummus import sql
from nummus.controllers import base, performance
from nummus.models.account import AccountCategory
from nummus.models.asset import (
Asset,
AssetCategory,
)
-from nummus.models.base import YIELD_PER
from nummus.models.currency import CURRENCY_FORMATS, DEFAULT_CURRENCY
if TYPE_CHECKING:
@@ -25,10 +25,8 @@
def test_ctx_chart_empty(
today: datetime.date,
account: Account,
- session: orm.Session,
) -> None:
- _ = account
- ctx = performance.ctx_chart(session, today, "max", "S&P 500", set())
+ ctx = performance.ctx_chart(today, "max", "S&P 500", set())
chart: performance.ChartData = {
"labels": [today.isoformat()],
@@ -55,10 +53,10 @@ def test_ctx_chart_empty(
"currency_format": CURRENCY_FORMATS[DEFAULT_CURRENCY],
}
- query = session.query(Asset.name).order_by(Asset.name)
- indices: list[str] = [r[0] for r in query.yield_per(YIELD_PER)]
+ query = Asset.query(Asset.name).order_by(Asset.name)
+ indices: list[str] = list(sql.col0(query))
- desc = session.query(Asset.description).where(Asset.name == "S&P 500").one()[0]
+ desc = sql.one(Asset.query(Asset.description).where(Asset.name == "S&P 500"))
target: performance.Context = {
"start": today,
@@ -79,11 +77,11 @@ def test_ctx_chart_this_year(
session: orm.Session,
account: Account,
) -> None:
- account.category = AccountCategory.INVESTMENT
- account.closed = True
- session.commit()
+ with session.begin_nested():
+ account.category = AccountCategory.INVESTMENT
+ account.closed = True
- ctx = performance.ctx_chart(session, today, "ytd", "S&P 500", set())
+ ctx = performance.ctx_chart(today, "ytd", "S&P 500", set())
assert ctx["start"] == today.replace(month=1, day=1)
assert ctx["end"] == today
@@ -97,14 +95,12 @@ def test_ctx_chart(
transactions: list[Transaction],
session: orm.Session,
) -> None:
- account.category = AccountCategory.INVESTMENT
- session.commit()
- _ = asset_valuation
- _ = transactions
+ with session.begin_nested():
+ account.category = AccountCategory.INVESTMENT
start = today - datetime.timedelta(days=3)
end = today + datetime.timedelta(days=3)
- ctx = performance.ctx_chart(session, end, "max", "S&P 500", set())
+ ctx = performance.ctx_chart(end, "max", "S&P 500", set())
chart: performance.ChartData = {
"labels": base.date_labels(start.toordinal(), end.toordinal())[0],
@@ -152,13 +148,13 @@ def test_ctx_chart(
}
query = (
- session.query(Asset.name)
+ Asset.query(Asset.name)
.where(Asset.category == AssetCategory.INDEX)
.order_by(Asset.name)
)
- indices: list[str] = [r[0] for r in query.yield_per(YIELD_PER)]
+ indices: list[str] = list(sql.col0(query))
- desc = session.query(Asset.description).where(Asset.name == "S&P 500").one()[0]
+ desc = sql.one(Asset.query(Asset.description).where(Asset.name == "S&P 500"))
target: performance.Context = {
"start": start,
@@ -181,14 +177,12 @@ def test_ctx_chart_exclude(
transactions: list[Transaction],
session: orm.Session,
) -> None:
- account.category = AccountCategory.INVESTMENT
- session.commit()
- _ = asset_valuation
- _ = transactions
+ with session.begin_nested():
+ account.category = AccountCategory.INVESTMENT
start = today - datetime.timedelta(days=3)
end = today + datetime.timedelta(days=3)
- ctx = performance.ctx_chart(session, end, "max", "S&P 500", {account.id_})
+ ctx = performance.ctx_chart(end, "max", "S&P 500", {account.id_})
chart: performance.ChartData = {
"labels": base.date_labels(start.toordinal(), end.toordinal())[0],
@@ -218,13 +212,14 @@ def test_ctx_chart_exclude(
}
query = (
- session.query(Asset.name)
+ Asset.query(Asset.name)
.where(Asset.category == AssetCategory.INDEX)
.order_by(Asset.name)
)
- indices: list[str] = [r[0] for r in query.yield_per(YIELD_PER)]
- desc = session.query(Asset.description).where(Asset.name == "S&P 500").one()[0]
+ indices: list[str] = list(sql.col0(query))
+
+ desc = sql.one(Asset.query(Asset.description).where(Asset.name == "S&P 500"))
target: performance.Context = {
"start": start,
diff --git a/tests/controllers/performance/test_endpoints.py b/tests/controllers/performance/test_endpoints.py
index c6f1efe8..21e49572 100644
--- a/tests/controllers/performance/test_endpoints.py
+++ b/tests/controllers/performance/test_endpoints.py
@@ -28,10 +28,8 @@ def test_page(
asset_valuation: AssetValuation,
transactions: list[Transaction],
) -> None:
- account.category = AccountCategory.INVESTMENT
- session.commit()
- _ = asset_valuation
- _ = transactions
+ with session.begin_nested():
+ account.category = AccountCategory.INVESTMENT
result, _ = web_client.GET(
("performance.page", {"index": "Dow Jones Industrial Average"}),
@@ -49,10 +47,8 @@ def test_chart(
asset_valuation: AssetValuation,
transactions: list[Transaction],
) -> None:
- account.category = AccountCategory.INVESTMENT
- session.commit()
- _ = asset_valuation
- _ = transactions
+ with session.begin_nested():
+ account.category = AccountCategory.INVESTMENT
result, headers = web_client.GET("performance.chart")
assert headers["HX-Push-URL"] == web_client.url_for(
@@ -70,10 +66,8 @@ def test_dashboard(
asset_valuation: AssetValuation,
transactions: list[Transaction],
) -> None:
- account.category = AccountCategory.INVESTMENT
- session.commit()
- _ = asset_valuation
- _ = transactions
+ with session.begin_nested():
+ account.category = AccountCategory.INVESTMENT
result, _ = web_client.GET("performance.dashboard")
assert "Investing performance" in result
diff --git a/tests/controllers/settings/test_contexts.py b/tests/controllers/settings/test_contexts.py
index cdaf90d5..42f727a1 100644
--- a/tests/controllers/settings/test_contexts.py
+++ b/tests/controllers/settings/test_contexts.py
@@ -1,16 +1,11 @@
from __future__ import annotations
-from typing import TYPE_CHECKING
-
from nummus.controllers import settings
from nummus.models.currency import Currency, DEFAULT_CURRENCY
-if TYPE_CHECKING:
- from sqlalchemy import orm
-
-def test_ctx(session: orm.Session) -> None:
- ctx = settings.ctx_settings(session)
+def test_ctx() -> None:
+ ctx = settings.ctx_settings()
target: settings.SettingsContext = {
"currency": DEFAULT_CURRENCY,
diff --git a/tests/controllers/settings/test_endpoints.py b/tests/controllers/settings/test_endpoints.py
index f09ca4f5..b22e0a9c 100644
--- a/tests/controllers/settings/test_endpoints.py
+++ b/tests/controllers/settings/test_endpoints.py
@@ -6,8 +6,6 @@
from nummus.models.currency import Currency
if TYPE_CHECKING:
- from sqlalchemy import orm
-
from tests.controllers.conftest import WebClient
@@ -16,10 +14,10 @@ def test_page(web_client: WebClient) -> None:
assert "Base currency" in result
-def test_edit_currency(web_client: WebClient, session: orm.Session) -> None:
+def test_edit_currency(web_client: WebClient) -> None:
result, headers = web_client.PATCH("settings.edit", data={"currency": "CHF"})
assert "snackbar.show" in result
assert "All changes saved" in result
assert "config" in headers["HX-Trigger"]
- assert Config.base_currency(session) == Currency.CHF
+ assert Config.base_currency() == Currency.CHF
diff --git a/tests/controllers/spending/test_contexts.py b/tests/controllers/spending/test_contexts.py
index d56727bf..b2f69f13 100644
--- a/tests/controllers/spending/test_contexts.py
+++ b/tests/controllers/spending/test_contexts.py
@@ -4,6 +4,7 @@
import pytest
+from nummus import sql
from nummus.controllers import base, spending
from nummus.models.account import Account
from nummus.models.currency import DEFAULT_CURRENCY
@@ -13,13 +14,10 @@
TransactionCategory,
TransactionCategoryGroup,
)
-from nummus.models.utils import query_count
if TYPE_CHECKING:
import datetime
- from sqlalchemy import orm
-
from nummus.models.transaction import Transaction
@@ -48,7 +46,6 @@
],
)
def test_data_query(
- session: orm.Session,
account: Account,
transactions_spending: list[Transaction],
categories: dict[str, int],
@@ -62,9 +59,7 @@ def test_data_query(
is_income: bool,
target: tuple[int, bool],
) -> None:
- _ = transactions_spending
dat_query = spending.data_query(
- session,
DEFAULT_CURRENCY,
account.uri if include_account else None,
period,
@@ -75,20 +70,18 @@ def test_data_query(
is_income=is_income,
)
assert dat_query.any_filters == target[1]
- assert query_count(dat_query.final_query) == target[0]
+ assert sql.count(dat_query.final_query) == target[0]
def test_ctx_options(
today: datetime.date,
- session: orm.Session,
account: Account,
transactions: list[Transaction],
categories: dict[str, int],
labels: dict[str, int],
) -> None:
- _ = transactions
dat_query = spending.DataQuery(
- session.query(TransactionSplit),
+ TransactionSplit.query(),
{},
any_filters=False,
)
@@ -96,9 +89,9 @@ def test_ctx_options(
ctx = spending.ctx_options(
dat_query,
today,
- Account.map_name(session),
- base.tranaction_category_groups(session),
- Label.map_name(session),
+ Account.map_name(),
+ base.tranaction_category_groups(),
+ Label.map_name(),
)
assert ctx["options_account"] == [base.NamePair(account.uri, account.name)]
@@ -130,13 +123,12 @@ def test_ctx_options(
def test_ctx_options_selected(
today: datetime.date,
- session: orm.Session,
account: Account,
categories: dict[str, int],
labels: dict[str, int],
) -> None:
dat_query = spending.DataQuery(
- session.query(TransactionSplit),
+ TransactionSplit.query(),
{},
any_filters=False,
)
@@ -144,9 +136,9 @@ def test_ctx_options_selected(
ctx = spending.ctx_options(
dat_query,
today,
- Account.map_name(session),
- base.tranaction_category_groups(session),
- Label.map_name(session),
+ Account.map_name(),
+ base.tranaction_category_groups(),
+ Label.map_name(),
account.uri,
TransactionCategory.id_to_uri(categories["other income"]),
Label.id_to_uri(labels["engineer"]),
@@ -173,11 +165,8 @@ def test_ctx_options_selected(
def test_ctx_chart_empty(
today: datetime.date,
account: Account,
- session: orm.Session,
) -> None:
- _ = account
ctx, title = spending.ctx_chart(
- session,
today,
selected_account=None,
selected_category=None,
@@ -205,12 +194,9 @@ def test_ctx_chart_empty(
def test_ctx_chart(
today: datetime.date,
- session: orm.Session,
transactions_spending: list[Transaction],
) -> None:
- _ = transactions_spending
ctx, title = spending.ctx_chart(
- session,
today,
None,
None,
diff --git a/tests/controllers/spending/test_endpoints.py b/tests/controllers/spending/test_endpoints.py
index a57af1eb..a3edd185 100644
--- a/tests/controllers/spending/test_endpoints.py
+++ b/tests/controllers/spending/test_endpoints.py
@@ -13,7 +13,6 @@ def test_page(
account: Account,
transactions_spending: list[Transaction],
) -> None:
- _ = transactions_spending
result, _ = web_client.GET("spending.page")
assert "Spending" in result
assert account.name in result
@@ -27,7 +26,6 @@ def test_chart(
account: Account,
transactions_spending: list[Transaction],
) -> None:
- _ = transactions_spending
result, _ = web_client.GET("spending.chart")
assert "Spending" in result
assert account.name in result
@@ -41,7 +39,6 @@ def test_dashboard(
account: Account,
transactions_spending: list[Transaction],
) -> None:
- _ = transactions_spending
result, _ = web_client.GET("spending.dashboard")
assert "Spending" in result
assert account.name in result
diff --git a/tests/controllers/test_base.py b/tests/controllers/test_base.py
index 3258410c..9bf4d906 100644
--- a/tests/controllers/test_base.py
+++ b/tests/controllers/test_base.py
@@ -24,25 +24,24 @@
from collections.abc import Callable
import werkzeug.test
- from sqlalchemy import orm
from nummus.models.asset import Asset
from tests.conftest import RandomStringGenerator
from tests.controllers.conftest import HTMLValidator, WebClient
-def test_find(session: orm.Session, account: Account) -> None:
- assert base.find(session, Account, account.uri) == account
+def test_find(account: Account) -> None:
+ assert base.find(Account, account.uri) == account
-def test_find_404(session: orm.Session) -> None:
+def test_find_404() -> None:
with pytest.raises(exc.http.NotFound):
- base.find(session, Account, Account.id_to_uri(0))
+ base.find(Account, Account.id_to_uri(0))
-def test_find_400(session: orm.Session) -> None:
+def test_find_400() -> None:
with pytest.raises(exc.http.BadRequest):
- base.find(session, Account, "fake")
+ base.find(Account, "fake")
@pytest.mark.parametrize(
@@ -115,7 +114,7 @@ class Fake:
],
ids=conftest.id_func,
)
-def test_validate_required(func: Callable) -> None:
+def test_validate_required(func: Callable[..., object]) -> None:
assert func("", is_required=True) == "Required"
@@ -132,24 +131,19 @@ def test_validate_string_short() -> None:
assert base.validate_string("a", check_length=True) == "2 characters required"
-def test_validate_string_no_session() -> None:
- with pytest.raises(TypeError):
- base.validate_string("abc", no_duplicates=Account.name)
-
-
-def test_validate_string_duplicate(session: orm.Session, account: Account) -> None:
+def test_validate_string_duplicate(account: Account) -> None:
err = base.validate_string(
account.name,
- session=session,
+ cls=Account,
no_duplicates=Account.name,
)
assert err == "Must be unique"
-def test_validate_string_duplicate_self(session: orm.Session, account: Account) -> None:
+def test_validate_string_duplicate_self(account: Account) -> None:
err = base.validate_string(
account.name,
- session=session,
+ cls=Account,
no_duplicates=Account.name,
no_duplicate_wheres=[Account.id_ != account.id_],
)
@@ -170,7 +164,7 @@ def test_validate_date(today: datetime.date, s: str, max_future: int | None) ->
],
ids=conftest.id_func,
)
-def test_validate_unable_to_parse(func: Callable) -> None:
+def test_validate_unable_to_parse(func: Callable[..., object]) -> None:
assert func("a") == "Unable to parse"
@@ -222,13 +216,12 @@ def test_parse_date(
def test_validate_date_duplicate(
today: datetime.date,
- session: orm.Session,
asset_valuation: AssetValuation,
) -> None:
err = base.validate_date(
asset_valuation.date.isoformat(),
today,
- session=session,
+ cls=AssetValuation,
no_duplicates=AssetValuation.date_ord,
)
assert err == "Must be unique"
@@ -327,12 +320,10 @@ def test_error_str(
def test_error_empty_field(
- session: orm.Session,
valid_html: HTMLValidator,
) -> None:
- session.add(Account())
try:
- session.commit()
+ Account.create()
except exc.IntegrityError as e:
html = base.error(e)
assert valid_html(html)
@@ -342,21 +333,18 @@ def test_error_empty_field(
def test_error_unique(
- session: orm.Session,
account: Account,
valid_html: HTMLValidator,
) -> None:
- new_account = Account(
- name=account.name,
- institution=account.institution,
- category=account.category,
- closed=False,
- budgeted=False,
- currency=DEFAULT_CURRENCY,
- )
- session.add(new_account)
try:
- session.commit()
+ Account.create(
+ name=account.name,
+ institution=account.institution,
+ category=account.category,
+ closed=False,
+ budgeted=False,
+ currency=DEFAULT_CURRENCY,
+ )
except exc.IntegrityError as e:
html = base.error(e)
assert valid_html(html)
@@ -366,13 +354,11 @@ def test_error_unique(
def test_error_check(
- session: orm.Session,
account: Account,
valid_html: HTMLValidator,
) -> None:
- _ = account
try:
- session.query(Account).update({"name": "a"})
+ Account.query().update({"name": "a"})
except exc.IntegrityError as e:
html = base.error(e)
assert valid_html(html)
@@ -496,10 +482,9 @@ def test_change_redirect_no_htmx(web_client: WebClient) -> None:
def test_tranaction_category_groups(
- session: orm.Session,
categories: dict[str, int],
) -> None:
- groups = base.tranaction_category_groups(session)
+ groups = base.tranaction_category_groups()
assert len(groups) == len(TransactionCategoryGroup)
assert sum(len(group) for group in groups.values()) == len(categories)
diff --git a/tests/controllers/test_import_file.py b/tests/controllers/test_import_file.py
index 5ffe4082..e42af863 100644
--- a/tests/controllers/test_import_file.py
+++ b/tests/controllers/test_import_file.py
@@ -48,7 +48,7 @@ def test_no_file(web_client: WebClient) -> None:
],
)
def test_error(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
web_client: WebClient,
data_path: Path,
file: str,
@@ -56,7 +56,6 @@ def test_error(
traceback: bool,
account: Account,
) -> None:
- _ = account
path = data_path / file
result, _ = web_client.POST(
"import_file.import_file",
@@ -78,8 +77,6 @@ def test_import_file(
account: Account,
account_investments: Account,
) -> None:
- _ = account
- _ = account_investments
path = data_path / "transactions_required.csv"
result, headers = web_client.POST(
"import_file.import_file",
@@ -96,8 +93,6 @@ def test_duplicate(
account: Account,
account_investments: Account,
) -> None:
- _ = account
- _ = account_investments
path = data_path / "transactions_required.csv"
web_client.POST(
"import_file.import_file",
@@ -120,8 +115,6 @@ def test_duplicate_force(
account: Account,
account_investments: Account,
) -> None:
- _ = account
- _ = account_investments
path = data_path / "transactions_required.csv"
web_client.POST(
"import_file.import_file",
diff --git a/tests/controllers/transaction_categories/test_contexts.py b/tests/controllers/transaction_categories/test_contexts.py
index 7157e198..eaea7a74 100644
--- a/tests/controllers/transaction_categories/test_contexts.py
+++ b/tests/controllers/transaction_categories/test_contexts.py
@@ -1,26 +1,21 @@
from __future__ import annotations
-from typing import TYPE_CHECKING
-
+from nummus import sql
from nummus.controllers import base, transaction_categories
-from nummus.models.base import YIELD_PER
from nummus.models.transaction_category import (
TransactionCategory,
TransactionCategoryGroup,
)
-if TYPE_CHECKING:
- from sqlalchemy import orm
-
-def test_ctx(session: orm.Session) -> None:
- groups = transaction_categories.ctx_categories(session)
+def test_ctx() -> None:
+ groups = transaction_categories.ctx_categories()
exclude = {"securities traded"}
for g in TransactionCategoryGroup:
query = (
- session.query(TransactionCategory)
+ TransactionCategory.query()
.where(
TransactionCategory.group == g,
TransactionCategory.name.not_in(exclude),
@@ -28,7 +23,6 @@ def test_ctx(session: orm.Session) -> None:
.order_by(TransactionCategory.name)
)
target: list[base.NamePair] = [
- base.NamePair(t_cat.uri, t_cat.emoji_name)
- for t_cat in query.yield_per(YIELD_PER)
+ base.NamePair(t_cat.uri, t_cat.emoji_name) for t_cat in sql.yield_(query)
]
assert groups[g] == target
diff --git a/tests/controllers/transaction_categories/test_endpoints.py b/tests/controllers/transaction_categories/test_endpoints.py
index 05baa14b..b7c52b13 100644
--- a/tests/controllers/transaction_categories/test_endpoints.py
+++ b/tests/controllers/transaction_categories/test_endpoints.py
@@ -4,6 +4,7 @@
import pytest
+from nummus import sql
from nummus.controllers import base
from nummus.models.transaction import TransactionSplit
from nummus.models.transaction_category import (
@@ -12,7 +13,6 @@
)
if TYPE_CHECKING:
- from sqlalchemy import orm
from nummus.models.transaction import Transaction
from tests.controllers.conftest import WebClient
@@ -62,7 +62,6 @@ def test_new_get(web_client: WebClient) -> None:
def test_new(
web_client: WebClient,
rand_str: str,
- session: orm.Session,
) -> None:
form = {
"name": rand_str,
@@ -75,11 +74,10 @@ def test_new(
assert f"Created category {rand_str}" in result
assert "category" in headers["HX-Trigger"]
- t_cat = (
- session.query(TransactionCategory)
- .where(TransactionCategory.emoji_name == rand_str)
- .one()
+ query = TransactionCategory.query().where(
+ TransactionCategory.emoji_name == rand_str,
)
+ t_cat = sql.one(query)
assert t_cat.group == TransactionCategoryGroup.EXPENSE
assert t_cat.is_profit_loss
assert t_cat.essential_spending
@@ -133,12 +131,10 @@ def test_category_delete_locked(
def test_category_delete_unlocked(
web_client: WebClient,
categories: dict[str, int],
- session: orm.Session,
transactions_spending: list[Transaction],
) -> None:
- _ = transactions_spending
t_split = (
- session.query(TransactionSplit)
+ TransactionSplit.query()
.where(TransactionSplit.category_id == categories["groceries"])
.first()
)
@@ -153,20 +149,19 @@ def test_category_delete_unlocked(
assert "category" in headers["HX-Trigger"]
t_cat = (
- session.query(TransactionCategory)
+ TransactionCategory.query()
.where(TransactionCategory.name == "groceries")
.one_or_none()
)
assert t_cat is None
- session.refresh(t_split)
+ t_split.refresh()
assert t_split.category_id == categories["uncategorized"]
def test_category_edit_unlocked(
web_client: WebClient,
categories: dict[str, int],
- session: orm.Session,
) -> None:
uri = TransactionCategory.id_to_uri(categories["groceries"])
@@ -178,11 +173,8 @@ def test_category_edit_unlocked(
assert "All changes saved" in result
assert "category" in headers["HX-Trigger"]
- t_cat = (
- session.query(TransactionCategory)
- .where(TransactionCategory.name == "food")
- .one()
- )
+ query = TransactionCategory.query().where(TransactionCategory.name == "food")
+ t_cat = sql.one(query)
assert t_cat.emoji_name == "Food"
assert t_cat.group == TransactionCategoryGroup.EXPENSE
assert not t_cat.is_profit_loss
@@ -192,7 +184,6 @@ def test_category_edit_unlocked(
def test_category_edit_locked(
web_client: WebClient,
categories: dict[str, int],
- session: orm.Session,
) -> None:
uri = TransactionCategory.id_to_uri(categories["uncategorized"])
@@ -204,11 +195,10 @@ def test_category_edit_locked(
assert "All changes saved" in result
assert "category" in headers["HX-Trigger"]
- t_cat = (
- session.query(TransactionCategory)
- .where(TransactionCategory.name == "uncategorized")
- .one()
+ query = TransactionCategory.query().where(
+ TransactionCategory.name == "uncategorized",
)
+ t_cat = sql.one(query)
assert t_cat.emoji_name == "Uncategorized 🤷"
assert t_cat.group == TransactionCategoryGroup.OTHER
assert not t_cat.is_profit_loss
diff --git a/tests/controllers/transactions/test_contexts.py b/tests/controllers/transactions/test_contexts.py
index f3279152..25fcdcb8 100644
--- a/tests/controllers/transactions/test_contexts.py
+++ b/tests/controllers/transactions/test_contexts.py
@@ -6,12 +6,11 @@
import pytest
-from nummus import utils
+from nummus import sql, utils
from nummus.controllers import base
from nummus.controllers import transactions as txn_controller
from nummus.models.account import Account
from nummus.models.asset import Asset
-from nummus.models.base import YIELD_PER
from nummus.models.currency import CURRENCY_FORMATS, DEFAULT_CURRENCY
from nummus.models.label import Label, LabelLink
from nummus.models.transaction import TransactionSplit
@@ -19,7 +18,6 @@
TransactionCategory,
TransactionCategoryGroup,
)
-from nummus.models.utils import query_count
if TYPE_CHECKING:
from sqlalchemy import orm
@@ -43,7 +41,6 @@
],
)
def test_table_query(
- session: orm.Session,
account: Account,
transactions: list[Transaction],
categories: dict[str, int],
@@ -55,9 +52,7 @@ def test_table_query(
uncleared: bool,
target: tuple[int, bool],
) -> None:
- _ = transactions
tbl_query = txn_controller.table_query(
- session,
None,
account.uri if include_account else None,
period,
@@ -67,7 +62,7 @@ def test_table_query(
uncleared=uncleared,
)
assert tbl_query.any_filters == target[1]
- assert query_count(tbl_query.final_query) == target[0]
+ assert sql.count(tbl_query.final_query) == target[0]
def test_ctx_txn(
@@ -91,14 +86,11 @@ def test_ctx_txn(
def test_ctx_split(
- session: orm.Session,
transactions: list[Transaction],
labels: dict[str, int],
) -> None:
- query = session.query(Asset).with_entities(Asset.id_, Asset.name, Asset.ticker)
- assets: dict[int, tuple[str, str | None]] = {
- r[0]: (r[1], r[2]) for r in query.yield_per(YIELD_PER)
- }
+ query = Asset.query(Asset.id_, Asset.name, Asset.ticker)
+ assets = sql.to_dict_tuple(query)
txn = transactions[0]
t_split = txn.splits[0]
@@ -123,15 +115,12 @@ def test_ctx_split(
def test_ctx_split_asset(
- session: orm.Session,
asset: Asset,
transactions: list[Transaction],
labels: dict[str, int],
) -> None:
- query = session.query(Asset).with_entities(Asset.id_, Asset.name, Asset.ticker)
- assets: dict[int, tuple[str, str | None]] = {
- r[0]: (r[1], r[2]) for r in query.yield_per(YIELD_PER)
- }
+ query = Asset.query(Asset.id_, Asset.name, Asset.ticker)
+ assets = sql.to_dict_tuple(query)
txn = transactions[1]
t_split = txn.splits[0]
@@ -156,23 +145,20 @@ def test_ctx_split_asset(
def test_ctx_row(
- session: orm.Session,
account: Account,
transactions: list[Transaction],
labels: dict[str, int],
) -> None:
- query = session.query(Asset).with_entities(Asset.id_, Asset.name, Asset.ticker)
- assets: dict[int, tuple[str, str | None]] = {
- r[0]: (r[1], r[2]) for r in query.yield_per(YIELD_PER)
- }
+ query = Asset.query(Asset.id_, Asset.name, Asset.ticker)
+ assets = sql.to_dict_tuple(query)
txn = transactions[0]
t_split = txn.splits[0]
ctx = txn_controller.ctx_row(
t_split,
assets,
- Account.map_name(session),
- TransactionCategory.map_name_emoji(session),
+ Account.map_name(),
+ TransactionCategory.map_name_emoji(),
{labels["engineer"]: "engineer"},
set(),
CURRENCY_FORMATS[DEFAULT_CURRENCY],
@@ -199,14 +185,12 @@ def test_ctx_row(
def test_ctx_options(
today: datetime.date,
- session: orm.Session,
account: Account,
transactions: list[Transaction],
categories: dict[str, int],
) -> None:
- _ = transactions
tbl_query = txn_controller.TableQuery(
- session.query(TransactionSplit),
+ TransactionSplit.query(),
{},
any_filters=False,
)
@@ -214,8 +198,8 @@ def test_ctx_options(
ctx = txn_controller.ctx_options(
tbl_query,
today,
- Account.map_name(session),
- base.tranaction_category_groups(session),
+ Account.map_name(),
+ base.tranaction_category_groups(),
None,
None,
)
@@ -251,7 +235,7 @@ def test_ctx_options_selected(
categories: dict[str, int],
) -> None:
tbl_query = txn_controller.TableQuery(
- session.query(TransactionSplit),
+ TransactionSplit.query(),
{},
any_filters=False,
)
@@ -259,8 +243,8 @@ def test_ctx_options_selected(
ctx = txn_controller.ctx_options(
tbl_query,
today,
- Account.map_name(session),
- base.tranaction_category_groups(session),
+ Account.map_name(),
+ base.tranaction_category_groups(),
account.uri,
TransactionCategory.id_to_uri(categories["other income"]),
)
@@ -332,21 +316,17 @@ def test_table_title(
assert title == target
-def test_table_results_empty(
- session: orm.Session,
-) -> None:
- query = session.query(Asset).with_entities(Asset.id_, Asset.name, Asset.ticker)
- assets: dict[int, tuple[str, str | None]] = {
- r[0]: (r[1], r[2]) for r in query.yield_per(YIELD_PER)
- }
+def test_table_results_empty() -> None:
+ query = Asset.query(Asset.id_, Asset.name, Asset.ticker)
+ assets = sql.to_dict_tuple(query)
- accounts = Account.map_name(session)
+ accounts = Account.map_name()
result = txn_controller._table_results(
- session.query(TransactionSplit),
+ TransactionSplit.query(),
assets,
accounts,
- TransactionCategory.map_name_emoji(session),
- Label.map_name(session),
+ TransactionCategory.map_name_emoji(),
+ Label.map_name(),
{},
dict.fromkeys(accounts, CURRENCY_FORMATS[DEFAULT_CURRENCY]),
)
@@ -357,16 +337,14 @@ def test_table_results(
session: orm.Session,
transactions: list[Transaction],
) -> None:
- query = session.query(Asset).with_entities(Asset.id_, Asset.name, Asset.ticker)
- assets: dict[int, tuple[str, str | None]] = {
- r[0]: (r[1], r[2]) for r in query.yield_per(YIELD_PER)
- }
- accounts = Account.map_name(session)
- labels = Label.map_name(session)
- categories = TransactionCategory.map_name_emoji(session)
+ query = Asset.query(Asset.id_, Asset.name, Asset.ticker)
+ assets = sql.to_dict_tuple(query)
+ accounts = Account.map_name()
+ labels = Label.map_name()
+ categories = TransactionCategory.map_name_emoji()
result = txn_controller._table_results(
- session.query(TransactionSplit).order_by(TransactionSplit.date_ord),
+ TransactionSplit.query().order_by(TransactionSplit.date_ord),
assets,
accounts,
categories,
@@ -385,7 +363,7 @@ def test_table_results(
categories,
{
label_id: labels[label_id]
- for label_id, in session.query(LabelLink.label_id).where(
+ for label_id, in LabelLink.query(LabelLink.label_id).where(
LabelLink.t_split_id == txn.splits[0].id_,
)
},
@@ -399,9 +377,8 @@ def test_table_results(
assert result == target
-def test_ctx_table_empty(today: datetime.date, session: orm.Session) -> None:
+def test_ctx_table_empty(today: datetime.date) -> None:
ctx, title = txn_controller.ctx_table(
- session,
today,
None,
None,
@@ -431,11 +408,9 @@ def test_ctx_table_empty(today: datetime.date, session: orm.Session) -> None:
def test_ctx_table(
today: datetime.date,
- session: orm.Session,
transactions: list[Transaction],
) -> None:
ctx, title = txn_controller.ctx_table(
- session,
today,
None,
None,
@@ -466,12 +441,10 @@ def test_ctx_table(
def test_ctx_table_paging(
today: datetime.date,
monkeypatch: pytest.MonkeyPatch,
- session: orm.Session,
transactions: list[Transaction],
) -> None:
monkeypatch.setattr(txn_controller, "PAGE_LEN", 2)
ctx, _ = txn_controller.ctx_table(
- session,
today,
None,
None,
@@ -489,12 +462,9 @@ def test_ctx_table_paging(
def test_ctx_table_search(
today: datetime.date,
- session: orm.Session,
transactions: list[Transaction],
) -> None:
- _ = transactions
ctx, _ = txn_controller.ctx_table(
- session,
today,
"rent",
None,
@@ -512,12 +482,9 @@ def test_ctx_table_search(
def test_ctx_table_search_paging(
today: datetime.date,
- session: orm.Session,
transactions: list[Transaction],
) -> None:
- _ = transactions
ctx, _ = txn_controller.ctx_table(
- session,
today,
"rent",
None,
diff --git a/tests/controllers/transactions/test_endpoints.py b/tests/controllers/transactions/test_endpoints.py
index 6ed51b22..ecd7db78 100644
--- a/tests/controllers/transactions/test_endpoints.py
+++ b/tests/controllers/transactions/test_endpoints.py
@@ -6,9 +6,9 @@
import flask
import pytest
+from nummus import sql
from nummus.controllers import base
from nummus.models.account import Account
-from nummus.models.base import YIELD_PER
from nummus.models.label import Label, LabelLink
from nummus.models.transaction import Transaction, TransactionSplit
from nummus.models.transaction_category import TransactionCategory
@@ -50,7 +50,6 @@ def test_table_options(
account: Account,
transactions: list[Transaction],
) -> None:
- _ = transactions
result, _ = web_client.GET("transactions.table_options")
assert 'name="period"' in result
assert 'name="category"' in result
@@ -151,7 +150,6 @@ def test_new_put_bad_date(
def test_new(
today: datetime.date,
- session: orm.Session,
web_client: WebClient,
account: Account,
categories: dict[str, int],
@@ -173,7 +171,7 @@ def test_new(
assert "Transaction created" in result
assert "account" in headers["HX-Trigger"]
- txn = session.query(Transaction).one()
+ txn = Transaction.one()
assert txn.account_id == account.id_
assert txn.date == today
assert txn.amount == round(rand_real, 2)
@@ -186,13 +184,12 @@ def test_new(
assert t_split.category_id == categories["other income"]
assert t_split.memo is None
- labels = Label.map_name(session)
+ labels = Label.map_name()
assert not labels
def test_new_split(
today: datetime.date,
- session: orm.Session,
web_client: WebClient,
account: Account,
categories: dict[str, int],
@@ -221,7 +218,7 @@ def test_new_split(
assert "Transaction created" in result
assert "account" in headers["HX-Trigger"]
- txn = session.query(Transaction).one()
+ txn = Transaction.one()
assert txn.account_id == account.id_
assert txn.date == today
assert txn.amount == round(rand_real, 2)
@@ -230,23 +227,27 @@ def test_new_split(
splits = txn.splits
assert len(splits) == 2
- labels = Label.map_name(session)
+ labels = Label.map_name()
assert len(labels) == 2
t_split = splits[0]
assert t_split.amount == Decimal(10)
assert t_split.category_id == categories["other income"]
assert t_split.memo is None
- query = session.query(LabelLink.label_id).where(LabelLink.t_split_id == t_split.id_)
- split_labels = {labels[label_id] for label_id, in query.yield_per(YIELD_PER)}
+ query = LabelLink.query(LabelLink.label_id).where(
+ LabelLink.t_split_id == t_split.id_,
+ )
+ split_labels = {labels[label_id] for label_id in sql.col0(query)}
assert split_labels == {"Engineer", "Salary"}
t_split = splits[1]
assert t_split.amount == round(rand_real - 10, 2)
assert t_split.category_id == categories["groceries"]
assert t_split.memo == "bananas"
- query = session.query(LabelLink.label_id).where(LabelLink.t_split_id == t_split.id_)
- split_labels = {labels[label_id] for label_id, in query.yield_per(YIELD_PER)}
+ query = LabelLink.query(LabelLink.label_id).where(
+ LabelLink.t_split_id == t_split.id_,
+ )
+ split_labels = {labels[label_id] for label_id in sql.col0(query)}
assert not split_labels
@@ -345,8 +346,8 @@ def test_transaction_get_uncleared(
transactions: list[Transaction],
) -> None:
txn = transactions[0]
- txn.cleared = False
- session.commit()
+ with session.begin_nested():
+ txn.cleared = False
result, _ = web_client.GET(("transactions.transaction", {"uri": txn.uri}))
assert "Edit transaction" in result
@@ -375,8 +376,8 @@ def test_transaction_clear(
transactions: list[Transaction],
) -> None:
txn = transactions[0]
- txn.cleared = False
- session.commit()
+ with session.begin_nested():
+ txn.cleared = False
result, headers = web_client.PATCH(("transactions.transaction", {"uri": txn.uri}))
assert "snackbar.show" in result
@@ -386,12 +387,9 @@ def test_transaction_clear(
session.refresh(txn)
assert txn.cleared
- t = (
- session.query(TransactionSplit)
- .where(TransactionSplit.parent_id == txn.id_)
- .one()
- )
- assert t.cleared
+ query = TransactionSplit.query().where(TransactionSplit.parent_id == txn.id_)
+ t_split = sql.one(query)
+ assert t_split.cleared
def test_transaction_delete_uncleared(
@@ -400,23 +398,23 @@ def test_transaction_delete_uncleared(
transactions: list[Transaction],
) -> None:
txn = transactions[0]
- txn.cleared = False
- session.commit()
+ with session.begin_nested():
+ txn.cleared = False
result, headers = web_client.DELETE(("transactions.transaction", {"uri": txn.uri}))
assert "snackbar.show" in result
assert f"Transaction on {txn.date} deleted" in result
assert "account" in headers["HX-Trigger"]
- t = session.query(Transaction).where(Transaction.id_ == txn.id_).one_or_none()
+ t = Transaction.query().where(Transaction.id_ == txn.id_).one_or_none()
assert t is None
- t = (
- session.query(TransactionSplit)
+ t_split = (
+ TransactionSplit.query()
.where(TransactionSplit.parent_id == txn.id_)
.one_or_none()
)
- assert t is None
+ assert t_split is None
def test_transaction_delete(
@@ -601,21 +599,23 @@ def test_validation(
@pytest.mark.parametrize(
- ("split_amount", "split", "target"),
+ ("split_amount", "split", "include_account", "target"),
[
# Just amount with a single split is okay
- ([], False, ""),
- (["10"], False, ""),
- (["11"], False, "Remove $1.00 from splits"),
- (["9"], False, "Assign $1.00 to splits"),
- (["9"], True, "Assign $1.00 to splits"),
+ ([], False, False, ""),
+ (["10"], False, False, ""),
+ (["11"], False, False, "Remove $1.00 from splits"),
+ (["9"], False, False, "Assign $1.00 to splits"),
+ (["9"], True, True, "Assign $1.00 to splits"),
],
)
def test_validation_amounts(
flask_app: flask.Flask,
web_client: WebClient,
+ account: Account,
split_amount: list[str],
split: bool,
+ include_account: bool,
target: str,
) -> None:
result, _ = web_client.GET(
@@ -625,6 +625,7 @@ def test_validation_amounts(
"amount": "10",
"split-amount": split_amount,
"split": split,
+ "account": account.uri if include_account else "",
},
),
)
diff --git a/tests/encryption/test_aes.py b/tests/encryption/test_aes.py
index 422ada09..9eb912ad 100644
--- a/tests/encryption/test_aes.py
+++ b/tests/encryption/test_aes.py
@@ -15,10 +15,12 @@
try:
from nummus.encryption.aes import EncryptionAES as Encryption
except ImportError:
- NO_ENCRYPTION = True
+ no_encryption = True
from nummus.encryption.base import NoEncryption as Encryption
else:
- NO_ENCRYPTION = False
+ no_encryption = False
+
+NO_ENCRYPTION = no_encryption
@pytest.fixture
diff --git a/tests/health_checks/test_base.py b/tests/health_checks/test_base.py
index a857b642..8dddc7c6 100644
--- a/tests/health_checks/test_base.py
+++ b/tests/health_checks/test_base.py
@@ -4,13 +4,12 @@
import pytest
+from nummus import sql
from nummus.health_checks.base import HealthCheck
from nummus.health_checks.top import HEALTH_CHECKS
from nummus.models.health_checks import HealthCheckIssue
-from nummus.models.utils import query_count
if TYPE_CHECKING:
- from sqlalchemy import orm
from tests.conftest import RandomStringGenerator
@@ -20,23 +19,22 @@ class MockCheck(HealthCheck):
_SEVERE = True
@override
- def test(self, s: orm.Session) -> None:
- self._commit_issues(s, {})
+ def test(self) -> None:
+ self._commit_issues({})
@pytest.fixture
def issues(
- session: orm.Session,
rand_str_generator: RandomStringGenerator,
) -> list[tuple[str, int]]:
value_0 = rand_str_generator()
value_1 = rand_str_generator()
c = MockCheck()
d = {value_0: "msg 0", value_1: "msg 1"}
- c._commit_issues(session, d)
- c.ignore(session, [value_0])
+ c._commit_issues(d)
+ c.ignore([value_0])
- return [(i.value, i.id_) for i in session.query(HealthCheckIssue).all()]
+ return [(i.value, i.id_) for i in HealthCheckIssue.all()]
def test_init_properties() -> None:
@@ -57,7 +55,6 @@ def test_any_issues(rand_str: str) -> None:
@pytest.mark.parametrize("no_ignores", [False, True])
def test_commit_issues(
- session: orm.Session,
rand_str_generator: RandomStringGenerator,
no_ignores: bool,
) -> None:
@@ -65,17 +62,19 @@ def test_commit_issues(
value_1 = rand_str_generator()
c = MockCheck(no_ignores=no_ignores)
d = {value_0: "msg 0", value_1: "msg 1"}
- c._commit_issues(session, d)
- c.ignore(session, [value_0])
+ c._commit_issues(d)
+ c.ignore([value_0])
# Refresh c.issues
- c._commit_issues(session, d)
+ c._commit_issues(d)
- i_0 = session.query(HealthCheckIssue).where(HealthCheckIssue.value == value_0).one()
+ query = HealthCheckIssue.query().where(HealthCheckIssue.value == value_0)
+ i_0 = sql.one(query)
assert i_0.check == MockCheck.name()
assert i_0.msg == "msg 0"
assert i_0.ignore
- i_1 = session.query(HealthCheckIssue).where(HealthCheckIssue.value == value_1).one()
+ query = HealthCheckIssue.query().where(HealthCheckIssue.value == value_1)
+ i_1 = sql.one(query)
assert i_1.check == MockCheck.name()
assert i_1.msg == "msg 1"
assert not i_1.ignore
@@ -88,31 +87,24 @@ def test_commit_issues(
assert c.issues == target
-def test_ignore_empty(session: orm.Session, rand_str: str) -> None:
- MockCheck.ignore(session, {rand_str})
- assert query_count(session.query(HealthCheckIssue)) == 0
+def test_ignore_empty(rand_str: str) -> None:
+ MockCheck.ignore({rand_str})
+ assert not sql.any_(HealthCheckIssue.query())
def test_ignore(
- session: orm.Session,
issues: list[tuple[str, int]],
) -> None:
- MockCheck.ignore(session, [issues[0][0]])
- i = (
- session.query(HealthCheckIssue)
- .where(HealthCheckIssue.id_ == issues[0][1])
- .one()
- )
+ MockCheck.ignore([issues[0][0]])
+ query = HealthCheckIssue.query().where(HealthCheckIssue.id_ == issues[0][1])
+ i = sql.one(query)
assert i.check == MockCheck.name()
assert i.value == issues[0][0]
assert i.msg == "msg 0"
assert i.ignore
- i = (
- session.query(HealthCheckIssue)
- .where(HealthCheckIssue.id_ == issues[1][1])
- .one()
- )
+ query = HealthCheckIssue.query().where(HealthCheckIssue.id_ == issues[1][1])
+ i = sql.one(query)
assert i.check == MockCheck.name()
assert i.value == issues[1][0]
assert i.msg == "msg 1"
diff --git a/tests/health_checks/test_category_direction.py b/tests/health_checks/test_category_direction.py
index 6f95d2bf..af612fc0 100644
--- a/tests/health_checks/test_category_direction.py
+++ b/tests/health_checks/test_category_direction.py
@@ -5,11 +5,11 @@
import pytest
+from nummus import sql
from nummus.health_checks.category_direction import CategoryDirection
from nummus.models.currency import CURRENCY_FORMATS, DEFAULT_CURRENCY
from nummus.models.health_checks import HealthCheckIssue
from nummus.models.transaction import Transaction, TransactionSplit
-from nummus.models.utils import query_count
if TYPE_CHECKING:
import datetime
@@ -19,9 +19,9 @@
from nummus.models.account import Account
-def test_empty(session: orm.Session) -> None:
+def test_empty() -> None:
c = CategoryDirection()
- c.test(session)
+ c.test()
assert c.issues == {}
@@ -42,27 +42,26 @@ def test_check(
amount: Decimal,
rand_str: str,
) -> None:
- txn = Transaction(
- account_id=account.id_,
- date=today,
- amount=amount,
- statement=rand_str,
- )
- t_split = TransactionSplit(
- parent=txn,
- amount=txn.amount,
- category_id=categories[category],
- )
- session.add_all((txn, t_split))
- session.commit()
+ with session.begin_nested():
+ txn = Transaction.create(
+ account_id=account.id_,
+ date=today,
+ amount=amount,
+ statement=rand_str,
+ )
+ t_split = TransactionSplit.create(
+ parent=txn,
+ amount=txn.amount,
+ category_id=categories[category],
+ )
t_uri = t_split.uri
c = CategoryDirection()
- c.test(session)
+ c.test()
- assert query_count(session.query(HealthCheckIssue)) == 1
+ assert sql.count(HealthCheckIssue.query()) == 1
- i = session.query(HealthCheckIssue).one()
+ i = HealthCheckIssue.one()
assert i.check == c.name()
assert i.value == t_uri
uri = i.uri
diff --git a/tests/health_checks/test_database_integrity.py b/tests/health_checks/test_database_integrity.py
index a96f7501..e950eda8 100644
--- a/tests/health_checks/test_database_integrity.py
+++ b/tests/health_checks/test_database_integrity.py
@@ -1,27 +1,27 @@
from __future__ import annotations
+import textwrap
from typing import TYPE_CHECKING
import sqlalchemy
-from pandas.core.indexes.api import textwrap
+from nummus import sql
from nummus.health_checks.database_integrity import DatabaseIntegrity
from nummus.models.config import Config
from nummus.models.health_checks import HealthCheckIssue
-from nummus.models.utils import query_count
if TYPE_CHECKING:
from sqlalchemy import orm
-def test_empty(session: orm.Session) -> None:
+def test_empty() -> None:
c = DatabaseIntegrity()
- c.test(session)
+ c.test()
assert c.issues == {}
def test_corrupt(session: orm.Session) -> None:
- n = session.query(Config).update({"value": "abc"})
+ n = Config.query().update({"value": "abc"})
# Simulating a corrupt database is difficult
# Instead update the Account schema to have unique constraint
# PRAGMA integrity_check should catch duplicates
@@ -45,18 +45,17 @@ def test_corrupt(session: orm.Session) -> None:
session.commit()
c = DatabaseIntegrity()
- c.test(session)
+ c.test()
- assert query_count(session.query(HealthCheckIssue)) == n - 1
+ assert sql.count(HealthCheckIssue.query()) == n - 1
- i = session.query(HealthCheckIssue).first()
+ i = HealthCheckIssue.first()
assert i is not None
assert i.check == c.name()
assert i.value == "0"
# The balanced $100 transfer also on this day will not show up
target = {
- i.uri: f"non-unique entry in index {index_name}"
- for i in session.query(HealthCheckIssue).all()
+ i.uri: f"non-unique entry in index {index_name}" for i in HealthCheckIssue.all()
}
assert c.issues == target
diff --git a/tests/health_checks/test_duplicate_transactions.py b/tests/health_checks/test_duplicate_transactions.py
index 1077d3e8..9de6b4c5 100644
--- a/tests/health_checks/test_duplicate_transactions.py
+++ b/tests/health_checks/test_duplicate_transactions.py
@@ -2,6 +2,7 @@
from typing import TYPE_CHECKING
+from nummus import sql
from nummus.health_checks.duplicate_transactions import DuplicateTransactions
from nummus.models.currency import (
CURRENCY_FORMATS,
@@ -9,57 +10,51 @@
)
from nummus.models.health_checks import HealthCheckIssue
from nummus.models.transaction import Transaction, TransactionSplit
-from nummus.models.utils import query_count
if TYPE_CHECKING:
from sqlalchemy import orm
-def test_empty(session: orm.Session) -> None:
+def test_empty() -> None:
c = DuplicateTransactions()
- c.test(session)
+ c.test()
assert c.issues == {}
def test_no_issues(
- session: orm.Session,
transactions: list[Transaction],
) -> None:
- _ = transactions
c = DuplicateTransactions()
- c.test(session)
- assert query_count(session.query(HealthCheckIssue)) == 0
+ c.test()
+ assert not sql.any_(HealthCheckIssue.query())
def test_duplicate(
session: orm.Session,
transactions: list[Transaction],
) -> None:
- _ = transactions
-
txn_to_copy = transactions[0]
# Fund account on 3 days before today
- txn = Transaction(
- account_id=txn_to_copy.account_id,
- date=txn_to_copy.date,
- amount=txn_to_copy.amount,
- statement=txn_to_copy.statement,
- )
- t_split = TransactionSplit(
- parent=txn,
- amount=txn.amount,
- category_id=txn_to_copy.splits[0].category_id,
- )
- session.add_all((txn, t_split))
- session.commit()
+ with session.begin_nested():
+ txn = Transaction.create(
+ account_id=txn_to_copy.account_id,
+ date=txn_to_copy.date,
+ amount=txn_to_copy.amount,
+ statement=txn_to_copy.statement,
+ )
+ TransactionSplit.create(
+ parent=txn,
+ amount=txn.amount,
+ category_id=txn_to_copy.splits[0].category_id,
+ )
c = DuplicateTransactions()
- c.test(session)
- assert query_count(session.query(HealthCheckIssue)) == 1
+ c.test()
+ assert HealthCheckIssue.count() == 1
- i = session.query(HealthCheckIssue).one()
+ i = HealthCheckIssue.one()
assert i.check == c.name()
amount_raw = Transaction.amount.type.process_bind_param(txn.amount, None)
assert i.value == f"{txn.account_id}.{txn.date_ord}.{amount_raw}"
diff --git a/tests/health_checks/test_empty_fields.py b/tests/health_checks/test_empty_fields.py
index a6132252..537832ec 100644
--- a/tests/health_checks/test_empty_fields.py
+++ b/tests/health_checks/test_empty_fields.py
@@ -2,10 +2,10 @@
from typing import TYPE_CHECKING
+from nummus import sql
from nummus.health_checks.empty_fields import EmptyFields
from nummus.models.health_checks import HealthCheckIssue
from nummus.models.transaction_category import TransactionCategory
-from nummus.models.utils import query_count
if TYPE_CHECKING:
from sqlalchemy import orm
@@ -15,20 +15,18 @@
from nummus.models.transaction import Transaction
-def test_empty(session: orm.Session) -> None:
+def test_empty() -> None:
c = EmptyFields()
- c.test(session)
+ c.test()
assert c.issues == {}
def test_no_issues(
- session: orm.Session,
transactions: list[Transaction],
) -> None:
- _ = transactions
c = EmptyFields()
- c.test(session)
- assert query_count(session.query(HealthCheckIssue)) == 0
+ c.test()
+ assert not sql.any_(HealthCheckIssue.query())
def test_no_account_number(
@@ -36,14 +34,13 @@ def test_no_account_number(
account: Account,
transactions: list[Transaction],
) -> None:
- account.number = None
- session.commit()
- _ = transactions
+ with session.begin_nested():
+ account.number = None
c = EmptyFields()
- c.test(session)
- assert query_count(session.query(HealthCheckIssue)) == 1
+ c.test()
+ assert HealthCheckIssue.count() == 1
- i = session.query(HealthCheckIssue).one()
+ i = HealthCheckIssue.one()
assert i.check == c.name()
assert i.value == f"{account.uri}.number"
uri = i.uri
@@ -57,14 +54,13 @@ def test_no_asset_description(
asset: Asset,
transactions: list[Transaction],
) -> None:
- asset.description = None
- session.commit()
- _ = transactions
+ with session.begin_nested():
+ asset.description = None
c = EmptyFields()
- c.test(session)
- assert query_count(session.query(HealthCheckIssue)) == 1
+ c.test()
+ assert HealthCheckIssue.count() == 1
- i = session.query(HealthCheckIssue).one()
+ i = HealthCheckIssue.one()
assert i.check == c.name()
assert i.value == f"{asset.uri}.description"
uri = i.uri
@@ -78,14 +74,14 @@ def test_no_txn_payee(
account: Account,
transactions: list[Transaction],
) -> None:
- txn = transactions[0]
- txn.payee = None
- session.commit()
+ with session.begin_nested():
+ txn = transactions[0]
+ txn.payee = None
c = EmptyFields()
- c.test(session)
- assert query_count(session.query(HealthCheckIssue)) == 1
+ c.test()
+ assert HealthCheckIssue.count() == 1
- i = session.query(HealthCheckIssue).one()
+ i = HealthCheckIssue.one()
assert i.check == c.name()
assert i.value == f"{txn.uri}.payee"
uri = i.uri
@@ -99,14 +95,14 @@ def test_uncategorized(
account: Account,
transactions: list[Transaction],
) -> None:
- t_split = transactions[0].splits[0]
- t_split.category_id = TransactionCategory.uncategorized(session)[0]
- session.commit()
+ with session.begin_nested():
+ t_split = transactions[0].splits[0]
+ t_split.category_id = TransactionCategory.uncategorized()[0]
c = EmptyFields()
- c.test(session)
- assert query_count(session.query(HealthCheckIssue)) == 1
+ c.test()
+ assert HealthCheckIssue.count() == 1
- i = session.query(HealthCheckIssue).one()
+ i = HealthCheckIssue.one()
assert i.check == c.name()
assert i.value == f"{t_split.uri}.category"
uri = i.uri
diff --git a/tests/health_checks/test_missing_asset_link.py b/tests/health_checks/test_missing_asset_link.py
index e9bb4a53..42a3c332 100644
--- a/tests/health_checks/test_missing_asset_link.py
+++ b/tests/health_checks/test_missing_asset_link.py
@@ -6,7 +6,6 @@
from nummus.health_checks.missing_asset_link import MissingAssetLink
from nummus.models.currency import CURRENCY_FORMATS, DEFAULT_CURRENCY
from nummus.models.health_checks import HealthCheckIssue
-from nummus.models.utils import query_count
if TYPE_CHECKING:
from sqlalchemy import orm
@@ -16,20 +15,18 @@
from nummus.models.transaction import Transaction
-def test_empty(session: orm.Session) -> None:
+def test_empty() -> None:
c = MissingAssetLink()
- c.test(session)
+ c.test()
assert c.issues == {}
def test_no_issues(
- session: orm.Session,
transactions: list[Transaction],
) -> None:
- _ = transactions
c = MissingAssetLink()
- c.test(session)
- assert query_count(session.query(HealthCheckIssue)) == 0
+ c.test()
+ assert HealthCheckIssue.count() == 0
def test_missing_link(
@@ -37,16 +34,15 @@ def test_missing_link(
account: Account,
transactions: list[Transaction],
) -> None:
- t_split = transactions[-1].splits[0]
- t_split.asset_id = None
- t_split.asset_quantity_unadjusted = None
- session.commit()
- _ = transactions
+ with session.begin_nested():
+ t_split = transactions[-1].splits[0]
+ t_split.asset_id = None
+ t_split.asset_quantity_unadjusted = None
c = MissingAssetLink()
- c.test(session)
- assert query_count(session.query(HealthCheckIssue)) == 1
+ c.test()
+ assert HealthCheckIssue.count() == 1
- i = session.query(HealthCheckIssue).one()
+ i = HealthCheckIssue.one()
assert i.check == c.name()
assert i.value == t_split.uri
uri = i.uri
@@ -65,16 +61,15 @@ def test_extra_link(
asset: Asset,
transactions: list[Transaction],
) -> None:
- t_split = transactions[0].splits[0]
- t_split.asset_id = asset.id_
- t_split.asset_quantity_unadjusted = Decimal()
- session.commit()
- _ = transactions
+ with session.begin_nested():
+ t_split = transactions[0].splits[0]
+ t_split.asset_id = asset.id_
+ t_split.asset_quantity_unadjusted = Decimal()
c = MissingAssetLink()
- c.test(session)
- assert query_count(session.query(HealthCheckIssue)) == 1
+ c.test()
+ assert HealthCheckIssue.count() == 1
- i = session.query(HealthCheckIssue).one()
+ i = HealthCheckIssue.one()
assert i.check == c.name()
assert i.value == t_split.uri
uri = i.uri
diff --git a/tests/health_checks/test_missing_valuations.py b/tests/health_checks/test_missing_valuations.py
index 58240ad2..065f1207 100644
--- a/tests/health_checks/test_missing_valuations.py
+++ b/tests/health_checks/test_missing_valuations.py
@@ -2,11 +2,9 @@
from typing import TYPE_CHECKING
+from nummus import sql
from nummus.health_checks.missing_valuations import MissingAssetValuations
from nummus.models.health_checks import HealthCheckIssue
-from nummus.models.utils import (
- query_count,
-)
if TYPE_CHECKING:
from sqlalchemy import orm
@@ -18,9 +16,9 @@
from nummus.models.transaction import Transaction
-def test_empty(session: orm.Session) -> None:
+def test_empty() -> None:
c = MissingAssetValuations()
- c.test(session)
+ c.test()
assert c.issues == {}
@@ -29,25 +27,23 @@ def test_no_issues(
transactions: list[Transaction],
asset_valuation: AssetValuation,
) -> None:
- txn = transactions[1]
- asset_valuation.date_ord = txn.date_ord
- session.commit()
+ with session.begin_nested():
+ txn = transactions[1]
+ asset_valuation.date_ord = txn.date_ord
c = MissingAssetValuations()
- c.test(session)
- assert query_count(session.query(HealthCheckIssue)) == 0
+ c.test()
+ assert not sql.any_(HealthCheckIssue.query())
def test_no_valuations(
- session: orm.Session,
asset: Asset,
transactions: list[Transaction],
) -> None:
- _ = transactions
c = MissingAssetValuations()
- c.test(session)
- assert query_count(session.query(HealthCheckIssue)) == 1
+ c.test()
+ assert HealthCheckIssue.count() == 1
- i = session.query(HealthCheckIssue).one()
+ i = HealthCheckIssue.one()
assert i.check == c.name()
assert i.value == asset.uri
uri = i.uri
@@ -57,17 +53,16 @@ def test_no_valuations(
def test_no_valuations_before_txn(
- session: orm.Session,
asset: Asset,
transactions: list[Transaction],
asset_valuation: AssetValuation,
) -> None:
txn = transactions[1]
c = MissingAssetValuations()
- c.test(session)
- assert query_count(session.query(HealthCheckIssue)) == 1
+ c.test()
+ assert HealthCheckIssue.count() == 1
- i = session.query(HealthCheckIssue).one()
+ i = HealthCheckIssue.one()
assert i.check == c.name()
assert i.value == asset.uri
uri = i.uri
diff --git a/tests/health_checks/test_outlier_asset_price.py b/tests/health_checks/test_outlier_asset_price.py
index a29e0d4b..a6757d7f 100644
--- a/tests/health_checks/test_outlier_asset_price.py
+++ b/tests/health_checks/test_outlier_asset_price.py
@@ -5,10 +5,10 @@
import pytest
+from nummus import sql
from nummus.health_checks.outlier_asset_price import OutlierAssetPrice
from nummus.models.currency import CURRENCY_FORMATS, DEFAULT_CURRENCY
from nummus.models.health_checks import HealthCheckIssue
-from nummus.models.utils import query_count
if TYPE_CHECKING:
from sqlalchemy import orm
@@ -20,9 +20,9 @@
from nummus.models.transaction import Transaction
-def test_empty(session: orm.Session) -> None:
+def test_empty() -> None:
c = OutlierAssetPrice()
- c.test(session)
+ c.test()
assert c.issues == {}
@@ -31,15 +31,15 @@ def test_zero_quantity(
transactions: list[Transaction],
asset_valuation: AssetValuation,
) -> None:
- t_split = transactions[1].splits[0]
- t_split.asset_quantity_unadjusted = Decimal()
- asset_valuation.date_ord = t_split.date_ord
- asset_valuation.value = Decimal(10)
- session.commit()
+ with session.begin_nested():
+ t_split = transactions[1].splits[0]
+ t_split.asset_quantity_unadjusted = Decimal()
+ asset_valuation.date_ord = t_split.date_ord
+ asset_valuation.value = Decimal(10)
c = OutlierAssetPrice()
- c.test(session)
- assert query_count(session.query(HealthCheckIssue)) == 0
+ c.test()
+ assert not sql.any_(HealthCheckIssue.query())
@pytest.mark.parametrize(
@@ -58,21 +58,21 @@ def test_check(
amount: Decimal,
target_word: str | None,
) -> None:
- t_split = transactions[1].splits[0]
- asset_valuation.date_ord = t_split.date_ord
- asset_valuation.value = Decimal(10)
- t_split.amount = amount
- session.commit()
+ with session.begin_nested():
+ t_split = transactions[1].splits[0]
+ asset_valuation.date_ord = t_split.date_ord
+ asset_valuation.value = Decimal(10)
+ t_split.amount = amount
c = OutlierAssetPrice()
- c.test(session)
+ c.test()
if target_word is None:
- assert query_count(session.query(HealthCheckIssue)) == 0
+ assert not sql.any_(HealthCheckIssue.query())
return
- assert query_count(session.query(HealthCheckIssue)) == 1
+ assert HealthCheckIssue.count() == 1
- i = session.query(HealthCheckIssue).one()
+ i = HealthCheckIssue.one()
assert i.check == c.name()
assert i.value == t_split.uri
uri = i.uri
diff --git a/tests/health_checks/test_overdrawn_accounts.py b/tests/health_checks/test_overdrawn_accounts.py
index 71ad9c95..a5c4fa4d 100644
--- a/tests/health_checks/test_overdrawn_accounts.py
+++ b/tests/health_checks/test_overdrawn_accounts.py
@@ -3,10 +3,10 @@
from decimal import Decimal
from typing import TYPE_CHECKING
+from nummus import sql
from nummus.health_checks.overdrawn_accounts import OverdrawnAccounts
from nummus.models.currency import CURRENCY_FORMATS, DEFAULT_CURRENCY
from nummus.models.health_checks import HealthCheckIssue
-from nummus.models.utils import query_count
if TYPE_CHECKING:
from sqlalchemy import orm
@@ -15,21 +15,18 @@
from nummus.models.transaction import Transaction
-def test_empty(session: orm.Session) -> None:
+def test_empty() -> None:
c = OverdrawnAccounts()
- c.test(session)
+ c.test()
assert c.issues == {}
def test_no_issues(
- session: orm.Session,
transactions: list[Transaction],
) -> None:
- _ = transactions
- session.commit()
c = OverdrawnAccounts()
- c.test(session)
- assert query_count(session.query(HealthCheckIssue)) == 0
+ c.test()
+ assert not sql.any_(HealthCheckIssue.query())
def test_check(
@@ -37,13 +34,14 @@ def test_check(
account: Account,
transactions: list[Transaction],
) -> None:
- t_split = transactions[0].splits[0]
- t_split.amount = Decimal(-1)
+ with session.begin_nested():
+ t_split = transactions[0].splits[0]
+ t_split.amount = Decimal(-1)
c = OverdrawnAccounts()
- c.test(session)
- assert query_count(session.query(HealthCheckIssue)) == 1
+ c.test()
+ assert HealthCheckIssue.count() == 1
- i = session.query(HealthCheckIssue).one()
+ i = HealthCheckIssue.one()
assert i.check == c.name()
assert i.value == f"{account.id_}.{t_split.date_ord}"
uri = i.uri
diff --git a/tests/health_checks/test_typos.py b/tests/health_checks/test_typos.py
index 62a7c2e4..8095eee9 100644
--- a/tests/health_checks/test_typos.py
+++ b/tests/health_checks/test_typos.py
@@ -4,9 +4,9 @@
import pytest
+from nummus import sql
from nummus.health_checks.typos import Typos
from nummus.models.health_checks import HealthCheckIssue
-from nummus.models.utils import query_count
if TYPE_CHECKING:
from sqlalchemy import orm
@@ -16,20 +16,18 @@
from nummus.models.transaction import Transaction
-def test_empty(session: orm.Session) -> None:
+def test_empty() -> None:
c = Typos()
- c.test(session)
+ c.test()
assert c.issues == {}
def test_no_issues(
- session: orm.Session,
transactions: list[Transaction],
) -> None:
- _ = transactions
c = Typos()
- c.test(session)
- assert query_count(session.query(HealthCheckIssue)) == 0
+ c.test()
+ assert not sql.any_(HealthCheckIssue.query())
def test_mispelled_proper_noun(
@@ -37,14 +35,14 @@ def test_mispelled_proper_noun(
account: Account,
account_savings: Account,
) -> None:
- # institution is proper noun so make a almost the same
- account.institution = account_savings.institution + "a"
- session.commit()
+ with session.begin_nested():
+ # institution is proper noun so make a almost the same
+ account.institution = account_savings.institution + "a"
c = Typos()
- c.test(session)
- assert query_count(session.query(HealthCheckIssue)) == 1
+ c.test()
+ assert HealthCheckIssue.count() == 1
- i = session.query(HealthCheckIssue).one()
+ i = HealthCheckIssue.one()
assert i.check == c.name()
assert i.value == account.institution
uri = i.uri
@@ -59,14 +57,14 @@ def test_mispelled(
asset: Asset,
no_description_typos: bool,
) -> None:
- # asset description is checked for dictionary spelling
- asset.description = "Banana mispel & 1234 bananas"
- session.commit()
+ with session.begin_nested():
+ # asset description is checked for dictionary spelling
+ asset.description = "Banana mispel & 1234 bananas"
c = Typos(no_description_typos=no_description_typos)
- c.test(session)
- assert query_count(session.query(HealthCheckIssue)) == 1
+ c.test()
+ assert HealthCheckIssue.count() == 1
- i = session.query(HealthCheckIssue).one()
+ i = HealthCheckIssue.one()
assert i.check == c.name()
assert i.value == "mispel"
uri = i.uri
diff --git a/tests/health_checks/test_unbalanced_transfers.py b/tests/health_checks/test_unbalanced_transfers.py
index 35d996f9..6cbd39cc 100644
--- a/tests/health_checks/test_unbalanced_transfers.py
+++ b/tests/health_checks/test_unbalanced_transfers.py
@@ -4,10 +4,10 @@
from decimal import Decimal
from typing import TYPE_CHECKING
+from nummus import sql
from nummus.health_checks.unbalanced_transfers import UnbalancedTransfers
from nummus.models.currency import CURRENCY_FORMATS, DEFAULT_CURRENCY
from nummus.models.health_checks import HealthCheckIssue
-from nummus.models.utils import query_count
if TYPE_CHECKING:
from sqlalchemy import orm
@@ -16,20 +16,18 @@
from nummus.models.transaction import Transaction
-def test_empty(session: orm.Session) -> None:
+def test_empty() -> None:
c = UnbalancedTransfers()
- c.test(session)
+ c.test()
assert c.issues == {}
def test_no_transfers(
- session: orm.Session,
transactions: list[Transaction],
) -> None:
- _ = transactions
c = UnbalancedTransfers()
- c.test(session)
- assert query_count(session.query(HealthCheckIssue)) == 0
+ c.test()
+ assert not sql.any_(HealthCheckIssue.query())
def test_no_issues(
@@ -38,25 +36,25 @@ def test_no_issues(
transactions_spending: list[Transaction],
categories: dict[str, int],
) -> None:
- amount = Decimal(100)
- spec = [
- (0, amount),
- (0, -amount),
- (1, amount),
- (1, -amount),
- ]
- for i, (dt, a) in enumerate(spec):
- txn = transactions_spending[i]
- txn.date = today + datetime.timedelta(days=dt)
- t_split = txn.splits[0]
- t_split.category_id = categories["transfers"]
- t_split.amount = a
- t_split.parent = txn
- session.commit()
+ with session.begin_nested():
+ amount = Decimal(100)
+ spec = [
+ (0, amount),
+ (0, -amount),
+ (1, amount),
+ (1, -amount),
+ ]
+ for i, (dt, a) in enumerate(spec):
+ txn = transactions_spending[i]
+ txn.date = today + datetime.timedelta(days=dt)
+ t_split = txn.splits[0]
+ t_split.category_id = categories["transfers"]
+ t_split.amount = a
+ t_split.parent = txn
c = UnbalancedTransfers()
- c.test(session)
- assert query_count(session.query(HealthCheckIssue)) == 0
+ c.test()
+ assert not sql.any_(HealthCheckIssue.query())
def test_wrong_amount(
@@ -66,19 +64,19 @@ def test_wrong_amount(
transactions_spending: list[Transaction],
categories: dict[str, int],
) -> None:
- amount = Decimal(100)
- spec = [amount, -amount * 2]
- for i, a in enumerate(spec):
- t_split = transactions_spending[i].splits[0]
- t_split.category_id = categories["transfers"]
- t_split.amount = a
- session.commit()
+ with session.begin_nested():
+ amount = Decimal(100)
+ spec = [amount, -amount * 2]
+ for i, a in enumerate(spec):
+ t_split = transactions_spending[i].splits[0]
+ t_split.category_id = categories["transfers"]
+ t_split.amount = a
c = UnbalancedTransfers()
- c.test(session)
- assert query_count(session.query(HealthCheckIssue)) == 1
+ c.test()
+ assert HealthCheckIssue.count() == 1
- i = session.query(HealthCheckIssue).one()
+ i = HealthCheckIssue.one()
assert i.check == c.name()
assert i.value == today.isoformat()
uri = i.uri
@@ -99,19 +97,19 @@ def test_one_pair(
transactions_spending: list[Transaction],
categories: dict[str, int],
) -> None:
- amount = Decimal(100)
- spec = [amount, -amount, -amount]
- for i, a in enumerate(spec):
- t_split = transactions_spending[i].splits[0]
- t_split.category_id = categories["transfers"]
- t_split.amount = a
- session.commit()
+ with session.begin_nested():
+ amount = Decimal(100)
+ spec = [amount, -amount, -amount]
+ for i, a in enumerate(spec):
+ t_split = transactions_spending[i].splits[0]
+ t_split.category_id = categories["transfers"]
+ t_split.amount = a
c = UnbalancedTransfers()
- c.test(session)
- assert query_count(session.query(HealthCheckIssue)) == 1
+ c.test()
+ assert HealthCheckIssue.count() == 1
- i = session.query(HealthCheckIssue).one()
+ i = HealthCheckIssue.one()
assert i.check == c.name()
assert i.value == today.isoformat()
uri = i.uri
@@ -131,33 +129,30 @@ def test_wrong_date(
transactions_spending: list[Transaction],
categories: dict[str, int],
) -> None:
- amount = Decimal(100)
- spec = [
- (0, amount),
- (0, -amount),
- (0, amount),
- (1, -amount),
- ]
- for i, (dt, a) in enumerate(spec):
- txn = transactions_spending[i]
- txn.date = today + datetime.timedelta(days=dt)
- t_split = txn.splits[0]
- t_split.category_id = categories["transfers"]
- t_split.amount = a
- t_split.parent = txn
- amount = Decimal(100)
- session.commit()
+ with session.begin_nested():
+ amount = Decimal(100)
+ spec = [
+ (0, amount),
+ (0, -amount),
+ (0, amount),
+ (1, -amount),
+ ]
+ for i, (dt, a) in enumerate(spec):
+ txn = transactions_spending[i]
+ txn.date = today + datetime.timedelta(days=dt)
+ t_split = txn.splits[0]
+ t_split.category_id = categories["transfers"]
+ t_split.amount = a
+ t_split.parent = txn
+ amount = Decimal(100)
tomorrow = today + datetime.timedelta(days=1)
c = UnbalancedTransfers()
- c.test(session)
- assert query_count(session.query(HealthCheckIssue)) == 2
+ c.test()
+ assert HealthCheckIssue.count() == 2
- i = (
- session.query(HealthCheckIssue)
- .where(HealthCheckIssue.value == today.isoformat())
- .one()
- )
+ query = HealthCheckIssue.query().where(HealthCheckIssue.value == today.isoformat())
+ i = sql.one(query)
assert i.check == c.name()
cf = CURRENCY_FORMATS[DEFAULT_CURRENCY]
lines = (
@@ -166,11 +161,10 @@ def test_wrong_date(
)
assert i.msg == "\n".join(lines)
- i = (
- session.query(HealthCheckIssue)
- .where(HealthCheckIssue.value == tomorrow.isoformat())
- .one()
+ query = HealthCheckIssue.query().where(
+ HealthCheckIssue.value == tomorrow.isoformat(),
)
+ i = sql.one(query)
assert i.check == c.name()
assert i.value == tomorrow.isoformat()
lines = (
diff --git a/tests/health_checks/test_uncleared_transactions.py b/tests/health_checks/test_uncleared_transactions.py
index 67f755cc..8dfcfd79 100644
--- a/tests/health_checks/test_uncleared_transactions.py
+++ b/tests/health_checks/test_uncleared_transactions.py
@@ -2,10 +2,10 @@
from typing import TYPE_CHECKING
+from nummus import sql
from nummus.health_checks.uncleared_transactions import UnclearedTransactions
from nummus.models.currency import CURRENCY_FORMATS, DEFAULT_CURRENCY
from nummus.models.health_checks import HealthCheckIssue
-from nummus.models.utils import query_count
if TYPE_CHECKING:
from sqlalchemy import orm
@@ -14,20 +14,18 @@
from nummus.models.transaction import Transaction
-def test_empty(session: orm.Session) -> None:
+def test_empty() -> None:
c = UnclearedTransactions()
- c.test(session)
+ c.test()
assert c.issues == {}
def test_no_issues(
- session: orm.Session,
transactions: list[Transaction],
) -> None:
- _ = transactions
c = UnclearedTransactions()
- c.test(session)
- assert query_count(session.query(HealthCheckIssue)) == 0
+ c.test()
+ assert not sql.any_(HealthCheckIssue.query())
def test_check(
@@ -35,17 +33,17 @@ def test_check(
account: Account,
transactions: list[Transaction],
) -> None:
- txn = transactions[0]
- txn.cleared = False
- t_split = txn.splits[0]
- t_split.parent = txn
- session.commit()
+ with session.begin_nested():
+ txn = transactions[0]
+ txn.cleared = False
+ t_split = txn.splits[0]
+ t_split.parent = txn
c = UnclearedTransactions()
- c.test(session)
- assert query_count(session.query(HealthCheckIssue)) == 1
+ c.test()
+ assert HealthCheckIssue.count() == 1
- i = session.query(HealthCheckIssue).one()
+ i = HealthCheckIssue.one()
assert i.check == c.name()
assert i.value == t_split.uri
uri = i.uri
diff --git a/tests/health_checks/test_unnecessary_splits.py b/tests/health_checks/test_unnecessary_splits.py
index 90e2d1a8..5b7ef548 100644
--- a/tests/health_checks/test_unnecessary_splits.py
+++ b/tests/health_checks/test_unnecessary_splits.py
@@ -2,10 +2,10 @@
from typing import TYPE_CHECKING
+from nummus import sql
from nummus.health_checks.unnecessary_slits import UnnecessarySplits
from nummus.models.health_checks import HealthCheckIssue
from nummus.models.transaction import TransactionSplit
-from nummus.models.utils import query_count
if TYPE_CHECKING:
from sqlalchemy import orm
@@ -14,20 +14,18 @@
from nummus.models.transaction import Transaction
-def test_empty(session: orm.Session) -> None:
+def test_empty() -> None:
c = UnnecessarySplits()
- c.test(session)
+ c.test()
assert c.issues == {}
def test_no_issues(
- session: orm.Session,
transactions: list[Transaction],
) -> None:
- _ = transactions
c = UnnecessarySplits()
- c.test(session)
- assert query_count(session.query(HealthCheckIssue)) == 0
+ c.test()
+ assert not sql.any_(HealthCheckIssue.query())
def test_check(
@@ -35,21 +33,20 @@ def test_check(
account: Account,
transactions: list[Transaction],
) -> None:
- txn = transactions[0]
- t_split = txn.splits[0]
- t_split = TransactionSplit(
- parent=txn,
- amount=t_split.amount,
- category_id=t_split.category_id,
- )
- session.add(t_split)
- session.commit()
+ with session.begin_nested():
+ txn = transactions[0]
+ t_split = txn.splits[0]
+ t_split = TransactionSplit.create(
+ parent=txn,
+ amount=t_split.amount,
+ category_id=t_split.category_id,
+ )
c = UnnecessarySplits()
- c.test(session)
- assert query_count(session.query(HealthCheckIssue)) == 1
+ c.test()
+ assert HealthCheckIssue.count() == 1
- i = session.query(HealthCheckIssue).one()
+ i = HealthCheckIssue.one()
assert i.check == c.name()
assert i.value == f"{txn.id_}.{t_split.payee}.{t_split.category_id}"
uri = i.uri
diff --git a/tests/health_checks/test_unused_categories.py b/tests/health_checks/test_unused_categories.py
index 1992a9cf..e2a64525 100644
--- a/tests/health_checks/test_unused_categories.py
+++ b/tests/health_checks/test_unused_categories.py
@@ -5,7 +5,6 @@
from nummus.health_checks.unused_categories import UnusedCategories
from nummus.models.health_checks import HealthCheckIssue
from nummus.models.transaction_category import TransactionCategory
-from nummus.models.utils import query_count
if TYPE_CHECKING:
from sqlalchemy import orm
@@ -13,13 +12,12 @@
from nummus.models.transaction import Transaction
-def test_empty(session: orm.Session) -> None:
+def test_empty() -> None:
# Mark all locked since those are excluded
- session.query(TransactionCategory).update({"locked": True})
- session.commit()
+ TransactionCategory.query().update({"locked": True})
c = UnusedCategories()
- c.test(session)
+ c.test()
assert c.issues == {}
@@ -28,18 +26,16 @@ def test_one(
transactions: list[Transaction],
categories: dict[str, int],
) -> None:
- _ = transactions
# Mark all but groceries and other income locked since those are excluded
- session.query(TransactionCategory).where(
+ TransactionCategory.query().where(
TransactionCategory.name.not_in({"groceries", "other income"}),
).update({"locked": True})
- session.commit()
c = UnusedCategories()
- c.test(session)
- assert query_count(session.query(HealthCheckIssue)) == 1
+ c.test()
+ assert HealthCheckIssue.count() == 1
- i = session.query(HealthCheckIssue).one()
+ i = HealthCheckIssue.one()
assert i.check == c.name()
assert i.value == TransactionCategory.id_to_uri(categories["groceries"])
uri = i.uri
diff --git a/tests/migrations/test_base.py b/tests/migrations/test_base.py
index 195dd76f..58227ed5 100644
--- a/tests/migrations/test_base.py
+++ b/tests/migrations/test_base.py
@@ -27,7 +27,6 @@ class MockMigrator(Migrator):
@override
def migrate(self, p: Portfolio) -> list[str]:
- _ = p
return ["Comments"]
@@ -38,85 +37,85 @@ def test_version() -> None:
def test_drop_column(session: orm.Session) -> None:
m = MockMigrator()
- m.drop_column(session, Asset, "category")
- session.commit()
+ with session.begin_nested():
+ m.drop_column(Asset, "category")
assert m.pending_schema_updates == set()
- result = "\n".join(dump_table_configs(session, Asset))
+ result = "\n".join(dump_table_configs(Asset))
assert "category" not in result
def test_drop_column_with_constraints(session: orm.Session) -> None:
m = MockMigrator()
- m.drop_column(session, AssetValuation, "value")
- session.commit()
+ with session.begin_nested():
+ m.drop_column(AssetValuation, "value")
assert m.pending_schema_updates == {AssetValuation}
- result = "\n".join(dump_table_configs(session, AssetValuation))
+ result = "\n".join(dump_table_configs(AssetValuation))
assert "value" not in result
def test_add_column_no_value_set(session: orm.Session, asset: Asset) -> None:
m = MockMigrator()
- m.drop_column(session, Asset, "category")
- session.commit()
+ with session.begin_nested():
+ m.drop_column(Asset, "category")
m.pending_schema_updates.clear()
- m.add_column(session, Asset, Asset.category)
- session.commit()
+ with session.begin_nested():
+ m.add_column(Asset, Asset.category)
assert m.pending_schema_updates == {Asset}
- result = "\n".join(dump_table_configs(session, Asset))
+ result = "\n".join(dump_table_configs(Asset))
assert "category" in result
+ asset.refresh()
assert asset.category is None
def test_add_column_value_set(session: orm.Session, asset: Asset) -> None:
m = MockMigrator()
- m.drop_column(session, Asset, "category")
- session.commit()
+ with session.begin_nested():
+ m.drop_column(Asset, "category")
m.pending_schema_updates.clear()
- m.add_column(session, Asset, Asset.category, AssetCategory.STOCKS)
- session.commit()
+ with session.begin_nested():
+ m.add_column(Asset, Asset.category, AssetCategory.STOCKS)
assert m.pending_schema_updates == {Asset}
- result = "\n".join(dump_table_configs(session, Asset))
+ result = "\n".join(dump_table_configs(Asset))
assert "category" in result
+ asset.refresh()
assert asset.category == AssetCategory.STOCKS
def test_rename_column(session: orm.Session) -> None:
m = MockMigrator()
- m.rename_column(session, Asset, "category", "class")
- session.commit()
+ with session.begin_nested():
+ m.rename_column(Asset, "category", "class")
assert m.pending_schema_updates == {Asset}
- result = "\n".join(dump_table_configs(session, Asset))
+ result = "\n".join(dump_table_configs(Asset))
assert "category" not in result
assert "class" in result
def test_migrate_schemas_no_value_set(empty_portfolio: Portfolio, asset: Asset) -> None:
- _ = asset
m = SchemaMigrator(set())
- with empty_portfolio.begin_session() as s:
- m.drop_column(s, Asset, "category")
- with empty_portfolio.begin_session() as s:
- m.add_column(s, Asset, Asset.category)
+ with empty_portfolio.begin_session():
+ m.drop_column(Asset, "category")
+ with empty_portfolio.begin_session():
+ m.add_column(Asset, Asset.category)
with pytest.raises(exc.IntegrityError):
m.migrate(empty_portfolio)
def test_migrate_schemas_value_set(empty_portfolio: Portfolio, asset: Asset) -> None:
- _ = asset
m = SchemaMigrator(set())
- with empty_portfolio.begin_session() as s:
- m.drop_column(s, Asset, "category")
- with empty_portfolio.begin_session() as s:
- m.add_column(s, Asset, Asset.category, AssetCategory.STOCKS)
+ with empty_portfolio.begin_session():
+ m.drop_column(Asset, "category")
+ with empty_portfolio.begin_session():
+ m.add_column(Asset, Asset.category, AssetCategory.STOCKS)
assert m.migrate(empty_portfolio) == []
diff --git a/tests/migrations/test_v0_10.py b/tests/migrations/test_v0_10.py
index 2af46874..163cb3ee 100644
--- a/tests/migrations/test_v0_10.py
+++ b/tests/migrations/test_v0_10.py
@@ -23,7 +23,7 @@ def test_migrate(tmp_path: Path, data_path: Path) -> None:
target = []
assert result == target
- with p.begin_session() as s:
- result = "\n".join(dump_table_configs(s, TransactionCategory))
+ with p.begin_session():
+ result = "\n".join(dump_table_configs(TransactionCategory))
assert "essential_spending" in result
assert TransactionCategory in m.pending_schema_updates
diff --git a/tests/migrations/test_v0_13.py b/tests/migrations/test_v0_13.py
index bca8b5c1..0025d6be 100644
--- a/tests/migrations/test_v0_13.py
+++ b/tests/migrations/test_v0_13.py
@@ -3,13 +3,11 @@
import shutil
from typing import TYPE_CHECKING
+from nummus import sql
from nummus.migrations.v0_13 import MigratorV0_13
from nummus.models.label import Label, LabelLink
from nummus.models.transaction import TransactionSplit
-from nummus.models.utils import (
- dump_table_configs,
- query_count,
-)
+from nummus.models.utils import dump_table_configs
from nummus.portfolio import Portfolio
if TYPE_CHECKING:
@@ -27,17 +25,15 @@ def test_migrate(tmp_path: Path, data_path: Path) -> None:
target = []
assert result == target
- with p.begin_session() as s:
- result = "\n".join(dump_table_configs(s, TransactionSplit))
+ with p.begin_session():
+ result = "\n".join(dump_table_configs(TransactionSplit))
assert "label" not in result
- result = "\n".join(dump_table_configs(s, Label))
+ result = "\n".join(dump_table_configs(Label))
assert "name" in result
- n = query_count(s.query(Label))
- assert n == 1
+ assert sql.count(Label.query()) == 1
- result = "\n".join(dump_table_configs(s, LabelLink))
+ result = "\n".join(dump_table_configs(LabelLink))
assert "label_id" in result
- n = query_count(s.query(LabelLink))
- assert n == 100
+ assert sql.count(LabelLink.query()) == 100
diff --git a/tests/migrations/test_v0_15.py b/tests/migrations/test_v0_15.py
index d7c28d69..39c9b6ea 100644
--- a/tests/migrations/test_v0_15.py
+++ b/tests/migrations/test_v0_15.py
@@ -3,12 +3,10 @@
import shutil
from typing import TYPE_CHECKING
+from nummus import sql
from nummus.migrations.v0_15 import MigratorV0_15
from nummus.models.label import Label, LabelLink
-from nummus.models.utils import (
- dump_table_configs,
- query_count,
-)
+from nummus.models.utils import dump_table_configs
from nummus.portfolio import Portfolio
if TYPE_CHECKING:
@@ -26,14 +24,12 @@ def test_migrate(tmp_path: Path, data_path: Path) -> None:
target = []
assert result == target
- with p.begin_session() as s:
- result = "\n".join(dump_table_configs(s, Label))
+ with p.begin_session():
+ result = "\n".join(dump_table_configs(Label))
assert "name" in result
- n = query_count(s.query(Label))
- assert n == 1
+ assert sql.count(Label.query()) == 1
- result = "\n".join(dump_table_configs(s, LabelLink))
+ result = "\n".join(dump_table_configs(LabelLink))
assert "label_id" in result
- n = query_count(s.query(LabelLink))
- assert n == 100
+ assert sql.count(LabelLink.query()) == 100
diff --git a/tests/migrations/test_v0_16.py b/tests/migrations/test_v0_16.py
index 41fe0f42..0798889d 100644
--- a/tests/migrations/test_v0_16.py
+++ b/tests/migrations/test_v0_16.py
@@ -28,11 +28,11 @@ def test_migrate(tmp_path: Path, data_path: Path) -> None:
]
assert result == target
- with p.begin_session() as s:
- result = "\n".join(dump_table_configs(s, Account))
+ with p.begin_session():
+ result = "\n".join(dump_table_configs(Account))
assert "currency" in result
- result = "\n".join(dump_table_configs(s, Asset))
+ result = "\n".join(dump_table_configs(Asset))
assert "currency" in result
- assert Config.base_currency(s) == Currency.USD
+ assert Config.base_currency() == Currency.USD
diff --git a/tests/migrations/test_v0_2.py b/tests/migrations/test_v0_2.py
index 1eeb038e..5d1e6367 100644
--- a/tests/migrations/test_v0_2.py
+++ b/tests/migrations/test_v0_2.py
@@ -28,14 +28,14 @@ def test_migrate(tmp_path: Path, data_path: Path) -> None:
]
assert result == target
- with p.begin_session() as s:
- result = "\n".join(dump_table_configs(s, Transaction))
+ with p.begin_session():
+ result = "\n".join(dump_table_configs(Transaction))
assert "linked" not in result
assert "locked" not in result
assert "cleared" in result
assert "payee" in result
- result = "\n".join(dump_table_configs(s, TransactionSplit))
+ result = "\n".join(dump_table_configs(TransactionSplit))
assert "linked" not in result
assert "locked" not in result
assert "cleared" in result
diff --git a/tests/mock_yfinance.py b/tests/mock_yfinance.py
index 28f10069..046ec150 100644
--- a/tests/mock_yfinance.py
+++ b/tests/mock_yfinance.py
@@ -84,6 +84,6 @@ def history(
dt += datetime.timedelta(days=1)
return pd.DataFrame(
- index=pd.to_datetime(dates),
+ index=pd.to_datetime(dates), # type: ignore[attr-defined]
data={"Close": close, "Stock Splits": split},
)
diff --git a/tests/models/asset/test_asset.py b/tests/models/asset/test_asset.py
index 3b4f9d2a..8adf82d5 100644
--- a/tests/models/asset/test_asset.py
+++ b/tests/models/asset/test_asset.py
@@ -7,6 +7,7 @@
import pytest
from nummus import exceptions as exc
+from nummus import sql
from nummus.models.asset import (
Asset,
AssetCategory,
@@ -16,27 +17,18 @@
)
from nummus.models.currency import Currency, DEFAULT_CURRENCY
from nummus.models.label import LabelLink
-from nummus.models.utils import (
- query_count,
- query_to_dict,
- update_rows,
-)
+from nummus.models.utils import update_rows
from tests import conftest
if TYPE_CHECKING:
- from sqlalchemy import orm
-
from nummus.models.account import Account
- from nummus.models.asset import (
- AssetSplit,
- )
+ from nummus.models.asset import AssetSplit
from nummus.models.transaction import Transaction
from tests.conftest import RandomStringGenerator
@pytest.fixture
def valuations(
- session: orm.Session,
today_ord: int,
asset: Asset,
) -> list[AssetValuation]:
@@ -47,15 +39,12 @@ def valuations(
today_ord + 3: {"value": Decimal(10), "asset_id": a_id},
}
- query = session.query(AssetValuation)
- update_rows(session, AssetValuation, query, "date_ord", updates)
- session.commit()
- return query.all()
+ update_rows(AssetValuation, AssetValuation.query(), "date_ord", updates)
+ return AssetValuation.all()
@pytest.fixture
def valuations_five(
- session: orm.Session,
today_ord: int,
asset: Asset,
) -> list[AssetValuation]:
@@ -68,14 +57,11 @@ def valuations_five(
today_ord + 7: {"value": Decimal(10), "asset_id": a_id},
}
- query = session.query(AssetValuation)
- update_rows(session, AssetValuation, query, "date_ord", updates)
- session.commit()
- return query.all()
+ update_rows(AssetValuation, AssetValuation.query(), "date_ord", updates)
+ return AssetValuation.all()
def test_init_properties(
- session: orm.Session,
rand_str_generator: RandomStringGenerator,
) -> None:
d = {
@@ -86,9 +72,7 @@ def test_init_properties(
"currency": DEFAULT_CURRENCY,
}
- a = Asset(**d)
- session.add(a)
- session.commit()
+ a = Asset.create(**d)
assert a.name == d["name"]
assert a.description == d["description"]
@@ -103,7 +87,6 @@ def test_short() -> None:
def test_get_value_empty(
today_ord: int,
- session: orm.Session,
asset: Asset,
) -> None:
start_ord = today_ord - 3
@@ -111,7 +94,7 @@ def test_get_value_empty(
result = asset.get_value(start_ord, end_ord)
assert result == [Decimal(0)] * 7
- result = Asset.get_value_all(session, start_ord, end_ord)
+ result = Asset.get_value_all(start_ord, end_ord)
assert result == {}
@@ -120,7 +103,6 @@ def test_get_value(
asset: Asset,
valuations: list[AssetValuation],
) -> None:
- _ = valuations
start_ord = today_ord - 3
end_ord = today_ord + 3
result = asset.get_value(start_ord, end_ord)
@@ -142,7 +124,6 @@ def test_get_value_interpolate(
valuations: list[AssetValuation],
) -> None:
asset.interpolate = True
- _ = valuations
start_ord = today_ord - 3
end_ord = today_ord + 3
result = asset.get_value(start_ord, end_ord)
@@ -163,7 +144,6 @@ def test_get_value_today(
asset: Asset,
valuations: list[AssetValuation],
) -> None:
- _ = valuations
result = asset.get_value(today_ord, today_ord)
assert result == [Decimal(100)]
@@ -174,7 +154,6 @@ def test_get_value_tomorrow_interpolate(
valuations: list[AssetValuation],
) -> None:
asset.interpolate = True
- _ = valuations
result = asset.get_value(today_ord + 1, today_ord + 1)
assert result == [Decimal(70)]
@@ -185,7 +164,6 @@ def test_update_splits_empty(
asset: Asset,
transactions: list[Transaction],
) -> None:
- _ = transactions
asset.update_splits()
assets = account.get_asset_qty(today_ord, today_ord)
assert assets == {asset.id_: [Decimal(10)]}
@@ -198,8 +176,6 @@ def test_update_splits(
asset_split: AssetSplit,
transactions: list[Transaction],
) -> None:
- _ = transactions
- _ = asset_split
asset.update_splits()
assets = account.get_asset_qty(today_ord, today_ord)
assert assets == {asset.id_: [Decimal(100)]}
@@ -216,8 +192,6 @@ def test_prune_valuations_none(
valuations: list[AssetValuation],
transactions: list[Transaction],
) -> None:
- _ = valuations
- _ = transactions
assert asset.prune_valuations() == 0
@@ -236,7 +210,6 @@ def test_prune_valuations_none(
ids=conftest.id_func,
)
def test_prune_valuations_first_txn(
- session: orm.Session,
asset: Asset,
valuations_five: list[AssetValuation],
transactions: list[Transaction],
@@ -246,31 +219,28 @@ def test_prune_valuations_first_txn(
for i in to_delete:
txn = transactions[i]
for t_split in txn.splits:
- session.query(LabelLink).where(LabelLink.t_split_id == t_split.id_).delete()
- session.delete(t_split)
- session.delete(txn)
- _ = valuations_five
+ LabelLink.query().where(LabelLink.t_split_id == t_split.id_).delete()
+ t_split.delete()
+ txn.delete()
assert asset.prune_valuations() == target
def test_prune_valuations_index(asset: Asset, valuations: list[AssetValuation]) -> None:
- _ = valuations
asset.category = AssetCategory.INDEX
assert asset.prune_valuations() == 0
-def test_update_valuations_none(session: orm.Session, asset: Asset) -> None:
+def test_update_valuations_none(asset: Asset) -> None:
asset.ticker = None
- session.commit()
with pytest.raises(exc.NoAssetWebSourceError):
asset.update_valuations(through_today=True)
-def test_update_valuations_empty(session: orm.Session, asset: Asset) -> None:
+def test_update_valuations_empty(asset: Asset) -> None:
start, end = asset.update_valuations(through_today=True)
assert start is None
assert end is None
- assert query_count(session.query(AssetValuation)) == 0
+ assert not sql.any_(AssetValuation.query())
@pytest.mark.parametrize(
@@ -283,7 +253,6 @@ def test_update_valuations_empty(session: orm.Session, asset: Asset) -> None:
)
def test_update_valuations(
today: datetime.date,
- session: orm.Session,
transactions: list[Transaction],
category: AssetCategory,
asset: Asset,
@@ -304,22 +273,20 @@ def test_update_valuations(
while start <= end:
n += 0 if start.weekday() in {5, 6} else 1
start += datetime.timedelta(days=1)
- assert query_count(session.query(AssetValuation)) == n
+ assert sql.count(AssetValuation.query()) == n
def test_update_valuations_delisted(
asset: Asset,
transactions: list[Transaction],
) -> None:
- _ = transactions
asset.ticker = "APPLE"
with pytest.raises(exc.AssetWebError):
asset.update_valuations(through_today=True)
-def test_update_sectors_none(session: orm.Session, asset: Asset) -> None:
+def test_update_sectors_none(asset: Asset) -> None:
asset.ticker = None
- session.commit()
with pytest.raises(exc.NoAssetWebSourceError):
asset.update_sectors()
@@ -347,43 +314,39 @@ def test_update_sectors_none(session: orm.Session, asset: Asset) -> None:
],
)
def test_update_sectors(
- session: orm.Session,
asset: Asset,
ticker: str,
target: dict[USSector, Decimal],
) -> None:
asset.ticker = ticker
asset.update_sectors()
- session.commit()
- query = (
- session.query(AssetSector)
- .with_entities(AssetSector.sector, AssetSector.weight)
- .where(AssetSector.asset_id == asset.id_)
+ query = AssetSector.query(AssetSector.sector, AssetSector.weight).where(
+ AssetSector.asset_id == asset.id_,
)
- sectors: dict[USSector, Decimal] = query_to_dict(query)
+ sectors: dict[USSector, Decimal] = sql.to_dict(query)
assert sectors == target
-def test_index_twrr_none(today_ord: int, session: orm.Session) -> None:
+def test_index_twrr_none(today_ord: int) -> None:
with pytest.raises(exc.ProtectedObjectNotFoundError):
- Asset.index_twrr(session, "Fake Index", today_ord, today_ord)
+ Asset.index_twrr("Fake Index", today_ord, today_ord)
-def test_index_twrr(today_ord: int, session: orm.Session, asset: Asset) -> None:
+def test_index_twrr(today_ord: int, asset: Asset) -> None:
asset.category = AssetCategory.INDEX
- result = Asset.index_twrr(session, asset.name, today_ord - 3, today_ord + 3)
+ result = Asset.index_twrr(asset.name, today_ord - 3, today_ord + 3)
# utils.twrr and Asset.get_value already tested, just check they connect well
assert result == [Decimal(0)] * 7
-def test_index_twrr_today(today_ord: int, session: orm.Session, asset: Asset) -> None:
+def test_index_twrr_today(today_ord: int, asset: Asset) -> None:
asset.category = AssetCategory.INDEX
- result = Asset.index_twrr(session, asset.name, today_ord, today_ord)
+ result = Asset.index_twrr(asset.name, today_ord, today_ord)
assert result == [Decimal(0)]
-def test_add_indices(session: orm.Session) -> None:
- for asset in session.query(Asset).all():
+def test_add_indices() -> None:
+ for asset in Asset.all():
assert asset.name is not None
assert asset.description is not None
assert not asset.interpolate
@@ -399,7 +362,6 @@ def test_autodetect_interpolate_sparse(
asset: Asset,
valuations: list[AssetValuation],
) -> None:
- _ = valuations
asset.autodetect_interpolate()
assert asset.interpolate
@@ -410,36 +372,34 @@ def test_autodetect_interpolate_daily(
) -> None:
for i, v in enumerate(valuations):
v.date_ord = valuations[0].date_ord + i
- _ = valuations
asset.autodetect_interpolate()
assert not asset.interpolate
-def test_create_forex(session: orm.Session, asset: Asset) -> None:
+def test_create_forex(asset: Asset) -> None:
asset.ticker = "EURUSD=X"
asset.category = AssetCategory.FOREX
- Asset.create_forex(session, Currency.USD, {*Currency})
+ Asset.create_forex(Currency.USD, {*Currency})
- query = session.query(Asset).where(Asset.category == AssetCategory.FOREX)
+ query = Asset.query().where(Asset.category == AssetCategory.FOREX)
# -1 since don't need USDUSD=x
- assert query_count(query) == len(Currency) - 1
+ assert sql.count(query) == len(Currency) - 1
-def test_create_forex_none(session: orm.Session) -> None:
- Asset.create_forex(session, Currency.USD, set())
+def test_create_forex_none() -> None:
+ Asset.create_forex(Currency.USD, set())
- query = session.query(Asset).where(Asset.category == AssetCategory.FOREX)
- assert query_count(query) == 0
+ query = Asset.query().where(Asset.category == AssetCategory.FOREX)
+ assert not sql.any_(query)
-def test_get_forex_empty(session: orm.Session, today_ord: int) -> None:
- result = Asset.get_forex(session, today_ord, today_ord, DEFAULT_CURRENCY)
+def test_get_forex_empty(today_ord: int) -> None:
+ result = Asset.get_forex(today_ord, today_ord, DEFAULT_CURRENCY)
assert result[DEFAULT_CURRENCY] == [Decimal(1)]
def test_get_forex(
- session: orm.Session,
today_ord: int,
asset: Asset,
asset_valuation: AssetValuation,
@@ -449,7 +409,6 @@ def test_get_forex(
asset.currency = Currency.USD
result = Asset.get_forex(
- session,
today_ord,
today_ord,
Currency.USD,
diff --git a/tests/models/asset/test_asset_sector.py b/tests/models/asset/test_asset_sector.py
index 54eec082..35695e14 100644
--- a/tests/models/asset/test_asset_sector.py
+++ b/tests/models/asset/test_asset_sector.py
@@ -8,14 +8,11 @@
from nummus.models.asset import AssetSector, USSector
if TYPE_CHECKING:
- from sqlalchemy import orm
-
from nummus.models.asset import Asset
from tests.conftest import RandomRealGenerator
def test_init_properties(
- session: orm.Session,
asset: Asset,
rand_real_generator: RandomRealGenerator,
) -> None:
@@ -25,45 +22,35 @@ def test_init_properties(
"weight": rand_real_generator(1, 10),
}
- v = AssetSector(**d)
- session.add(v)
- session.commit()
+ v = AssetSector.create(**d)
assert v.asset_id == d["asset_id"]
assert v.sector == d["sector"]
assert v.weight == d["weight"]
-def test_weight_negative(session: orm.Session, asset: Asset) -> None:
- v = AssetSector(asset_id=asset.id_, sector=USSector.REAL_ESTATE, weight=-1)
- session.add(v)
+def test_weight_negative(asset: Asset) -> None:
with pytest.raises(exc.IntegrityError):
- session.commit()
+ AssetSector.create(asset_id=asset.id_, sector=USSector.REAL_ESTATE, weight=-1)
-def test_weight_zero(session: orm.Session, asset: Asset) -> None:
- v = AssetSector(asset_id=asset.id_, sector=USSector.REAL_ESTATE, weight=0)
- session.add(v)
+def test_weight_zero(asset: Asset) -> None:
with pytest.raises(exc.IntegrityError):
- session.commit()
+ AssetSector.create(asset_id=asset.id_, sector=USSector.REAL_ESTATE, weight=0)
def test_duplicate_sectors(
- session: orm.Session,
asset: Asset,
rand_real_generator: RandomRealGenerator,
) -> None:
- v = AssetSector(
- asset_id=asset.id_,
- sector=USSector.REAL_ESTATE,
- weight=rand_real_generator(1, 10),
- )
- session.add(v)
- v = AssetSector(
+ AssetSector.create(
asset_id=asset.id_,
sector=USSector.REAL_ESTATE,
weight=rand_real_generator(1, 10),
)
- session.add(v)
with pytest.raises(exc.IntegrityError):
- session.commit()
+ AssetSector.create(
+ asset_id=asset.id_,
+ sector=USSector.REAL_ESTATE,
+ weight=rand_real_generator(1, 10),
+ )
diff --git a/tests/models/asset/test_asset_split.py b/tests/models/asset/test_asset_split.py
index 5b177099..85e3df39 100644
--- a/tests/models/asset/test_asset_split.py
+++ b/tests/models/asset/test_asset_split.py
@@ -10,8 +10,6 @@
if TYPE_CHECKING:
import datetime
- from sqlalchemy import orm
-
from nummus.models.asset import Asset
from tests.conftest import RandomRealGenerator
@@ -19,7 +17,6 @@
def test_init_properties(
today: datetime.date,
today_ord: int,
- session: orm.Session,
asset: Asset,
rand_real_generator: RandomRealGenerator,
) -> None:
@@ -29,9 +26,7 @@ def test_init_properties(
"date_ord": today_ord,
}
- v = AssetSplit(**d)
- session.add(v)
- session.commit()
+ v = AssetSplit.create(**d)
assert v.asset_id == d["asset_id"]
assert v.multiplier == d["multiplier"]
@@ -41,39 +36,30 @@ def test_init_properties(
def test_multiplier_negative(
today_ord: int,
- session: orm.Session,
asset: Asset,
) -> None:
- v = AssetSplit(asset_id=asset.id_, date_ord=today_ord, multiplier=-1)
- session.add(v)
with pytest.raises(exc.IntegrityError):
- session.commit()
+ AssetSplit.create(asset_id=asset.id_, date_ord=today_ord, multiplier=-1)
-def test_multiplier_zero(today_ord: int, session: orm.Session, asset: Asset) -> None:
- v = AssetSplit(asset_id=asset.id_, date_ord=today_ord, multiplier=0)
- session.add(v)
+def test_multiplier_zero(today_ord: int, asset: Asset) -> None:
with pytest.raises(exc.IntegrityError):
- session.commit()
+ AssetSplit.create(asset_id=asset.id_, date_ord=today_ord, multiplier=0)
def test_duplicate_dates(
today_ord: int,
- session: orm.Session,
asset: Asset,
rand_real_generator: RandomRealGenerator,
) -> None:
- v = AssetSplit(
- asset_id=asset.id_,
- date_ord=today_ord,
- multiplier=rand_real_generator(1, 10),
- )
- session.add(v)
- v = AssetSplit(
+ AssetSplit.create(
asset_id=asset.id_,
date_ord=today_ord,
multiplier=rand_real_generator(1, 10),
)
- session.add(v)
with pytest.raises(exc.IntegrityError):
- session.commit()
+ AssetSplit.create(
+ asset_id=asset.id_,
+ date_ord=today_ord,
+ multiplier=rand_real_generator(1, 10),
+ )
diff --git a/tests/models/asset/test_asset_valuation.py b/tests/models/asset/test_asset_valuation.py
index 94bfc57a..15fe81cc 100644
--- a/tests/models/asset/test_asset_valuation.py
+++ b/tests/models/asset/test_asset_valuation.py
@@ -11,8 +11,6 @@
import datetime
from decimal import Decimal
- from sqlalchemy import orm
-
from nummus.models.asset import Asset
from tests.conftest import RandomRealGenerator
@@ -20,7 +18,6 @@
def test_init_properties(
today: datetime.date,
today_ord: int,
- session: orm.Session,
asset: Asset,
rand_real: Decimal,
) -> None:
@@ -30,9 +27,7 @@ def test_init_properties(
"value": rand_real,
}
- v = AssetValuation(**d)
- session.add(v)
- session.commit()
+ v = AssetValuation.create(**d)
assert v.asset_id == d["asset_id"]
assert v.value == d["value"]
@@ -42,32 +37,25 @@ def test_init_properties(
def test_multiplier_negative(
today_ord: int,
- session: orm.Session,
asset: Asset,
) -> None:
- v = AssetValuation(asset_id=asset.id_, date_ord=today_ord, value=-1)
- session.add(v)
with pytest.raises(exc.IntegrityError):
- session.commit()
+ AssetValuation.create(asset_id=asset.id_, date_ord=today_ord, value=-1)
def test_duplicate_dates(
today_ord: int,
- session: orm.Session,
asset: Asset,
rand_real_generator: RandomRealGenerator,
) -> None:
- v = AssetValuation(
- asset_id=asset.id_,
- date_ord=today_ord,
- value=rand_real_generator(),
- )
- session.add(v)
- v = AssetValuation(
+ AssetValuation.create(
asset_id=asset.id_,
date_ord=today_ord,
value=rand_real_generator(),
)
- session.add(v)
with pytest.raises(exc.IntegrityError):
- session.commit()
+ AssetValuation.create(
+ asset_id=asset.id_,
+ date_ord=today_ord,
+ value=rand_real_generator(),
+ )
diff --git a/tests/models/base/test_orm_base.py b/tests/models/base/test_orm_base.py
index 177da751..0aea5fea 100644
--- a/tests/models/base/test_orm_base.py
+++ b/tests/models/base/test_orm_base.py
@@ -1,19 +1,36 @@
from __future__ import annotations
+import re
from decimal import Decimal
+from pathlib import Path
from typing import TYPE_CHECKING
import pytest
from sqlalchemy import ForeignKey, orm
+import nummus
+import tests
from nummus import exceptions as exc
from nummus import sql
-from nummus.models import base
+from nummus.models.base import (
+ Base,
+ BaseEnum,
+ Decimal6,
+ ORMInt,
+ ORMIntOpt,
+ ORMRealOpt,
+ ORMStrOpt,
+ SQLEnum,
+ string_column_args,
+)
+from tests import conftest
if TYPE_CHECKING:
- from collections.abc import Mapping
- from pathlib import Path
+ from collections.abc import Callable, Generator, Mapping
+ from nummus.models.base import (
+ NamePair,
+ )
from tests.conftest import RandomStringGenerator
@@ -28,7 +45,7 @@ def __hash__(self) -> int:
return hash(self._data)
-class Derived(base.BaseEnum):
+class Derived(BaseEnum):
RED = 1
BLUE = 2
SEAFOAM_GREEN = 3
@@ -38,15 +55,17 @@ def lut(cls) -> Mapping[str, Derived]:
return {"r": cls.RED, "b": cls.BLUE}
-class Parent(base.Base, skip_register=True):
+class Parent(Base, skip_register=True):
__tablename__ = "parent"
__table_id__ = 0xF0000000
- generic_column: base.ORMIntOpt
- name: base.ORMStrOpt
+ generic_column: ORMIntOpt
+ name: ORMStrOpt
children: orm.Mapped[list[Child]] = orm.relationship(back_populates="parent")
- __table_args__ = (*base.string_column_args("name"),)
+ __table_args__ = (*string_column_args("name"),)
+
+ _SEARCH_PROPERTIES = ("name",)
@orm.validates("name")
def validate_strings(self, key: str, field: str | None) -> str | None:
@@ -63,46 +82,47 @@ def uri_bytes(self) -> Bytes:
return Bytes(self.uri)
-class Child(base.Base, skip_register=True):
+class Child(Base, skip_register=True):
__tablename__ = "child"
__table_id__ = 0xE0000000
- parent_id: base.ORMInt = orm.mapped_column(ForeignKey("parent.id_"))
+ parent_id: ORMInt = orm.mapped_column(ForeignKey("parent.id_"))
parent: orm.Mapped[Parent] = orm.relationship(back_populates="children")
- height: base.ORMRealOpt = orm.mapped_column(base.Decimal6)
+ height: ORMRealOpt = orm.mapped_column(Decimal6)
- color: orm.Mapped[Derived | None] = orm.mapped_column(base.SQLEnum(Derived))
+ color: orm.Mapped[Derived | None] = orm.mapped_column(SQLEnum(Derived))
@orm.validates("height")
def validate_decimals(self, key: str, field: Decimal | None) -> Decimal | None:
return self.clean_decimals(key, field)
-class NoURI(base.Base, skip_register=True):
+class NoURI(Base, skip_register=True):
__tablename__ = "no_uri"
__table_id__ = None
@pytest.fixture
-def session(tmp_path: Path) -> orm.Session:
+def session(tmp_path: Path) -> Generator[orm.Session]:
"""Create SQL session.
Args:
tmp_path: Temp path to create DB in
- Returns:
+ Yields:
Session generator
"""
path = tmp_path / "sql.db"
s = orm.Session(sql.get_engine(path, None))
- base.Base.metadata.create_all(
- s.get_bind(),
- tables=[Parent.sql_table(), Child.sql_table()],
- )
- s.commit()
- return s
+ with s.begin_nested():
+ Base.metadata.create_all(
+ s.get_bind(),
+ tables=[Parent.sql_table(), Child.sql_table()],
+ )
+ with Base.set_session(s):
+ yield s
@pytest.fixture
@@ -113,10 +133,8 @@ def parent(session: orm.Session) -> Parent:
Parent
"""
- p = Parent()
- session.add(p)
- session.commit()
- return p
+ with session.begin_nested():
+ return Parent.create()
@pytest.fixture
@@ -127,17 +145,8 @@ def child(session: orm.Session, parent: Parent) -> Child:
Child
"""
- c = Child(parent=parent)
- session.add(c)
- session.commit()
- return c
-
-
-def test_detached() -> None:
- parent = Parent()
- assert parent.id_ is None
- with pytest.raises(exc.NoIDError):
- _ = parent.uri
+ with session.begin_nested():
+ return Child.create(parent=parent)
def test_init_properties(parent: Parent) -> None:
@@ -147,13 +156,6 @@ def test_init_properties(parent: Parent) -> None:
assert hash(parent) == parent.id_
-def test_detached_child() -> None:
- child = Child()
- assert child.id_ is None
- assert child.parent is None
- assert child.parent_id is None
-
-
def test_link_child(parent: Parent, child: Child) -> None:
assert child.id_ is not None
assert child.parent == parent
@@ -165,23 +167,20 @@ def test_wrong_uri_type(parent: Parent) -> None:
Child.uri_to_id(parent.uri)
-def test_set_decimal_none(session: orm.Session, child: Child) -> None:
+def test_set_decimal_none(child: Child) -> None:
child.height = None
- session.commit()
assert child.height is None
-def test_set_decimal_value(session: orm.Session, child: Child) -> None:
+def test_set_decimal_value(child: Child) -> None:
height = Decimal("1.2")
child.height = height
- session.commit()
assert isinstance(child.height, Decimal)
assert child.height == height
-def test_set_enum(session: orm.Session, child: Child) -> None:
+def test_set_enum(child: Child) -> None:
child.color = Derived.RED
- session.commit()
assert isinstance(child.color, Derived)
assert child.color == Derived.RED
@@ -192,11 +191,9 @@ def test_no_uri() -> None:
_ = no_uri.uri
-def test_comparators_same_session(session: orm.Session) -> None:
- parent_a = Parent()
- parent_b = Parent()
- session.add_all([parent_a, parent_b])
- session.commit()
+def test_comparators_same_session() -> None:
+ parent_a = Parent.create()
+ parent_b = Parent.create()
assert parent_a == parent_a # noqa: PLR0124
assert parent_a != parent_b
@@ -213,25 +210,20 @@ def test_comparators_different_session(session: orm.Session, parent: Parent) ->
assert parent == parent_a_queried
-def test_map_name_none(session: orm.Session) -> None:
+def test_map_name_none() -> None:
with pytest.raises(KeyError, match="Base does not have name column"):
- base.Base.map_name(session)
+ Base.map_name()
-def test_map_name_parent(
- session: orm.Session,
- rand_str_generator: RandomStringGenerator,
-) -> None:
- parent_a = Parent(name=rand_str_generator())
- parent_b = Parent(name=rand_str_generator())
- session.add_all([parent_a, parent_b])
- session.commit()
+def test_map_name_parent(rand_str_generator: RandomStringGenerator) -> None:
+ parent_a = Parent.create(name=rand_str_generator())
+ parent_b = Parent.create(name=rand_str_generator())
target = {
parent_a.id_: parent_a.name,
parent_b.id_: parent_b.name,
}
- assert Parent.map_name(session) == target
+ assert Parent.map_name() == target
def test_clean_strings_none(parent: Parent) -> None:
@@ -258,28 +250,28 @@ def test_clean_strings_short(parent: Parent) -> None:
parent.name = "a"
-def test_string_check_none(session: orm.Session, parent: Parent) -> None:
+def test_string_check_none(parent: Parent) -> None:
with pytest.raises(exc.IntegrityError):
- session.query(Parent).where(Parent.id_ == parent.id_).update({Parent.name: ""})
+ Parent.query().where(Parent.id_ == parent.id_).update({Parent.name: ""})
-def test_string_check_leading(session: orm.Session, parent: Parent) -> None:
+def test_string_check_leading(parent: Parent) -> None:
with pytest.raises(exc.IntegrityError):
- session.query(Parent).where(Parent.id_ == parent.id_).update(
+ Parent.query().where(Parent.id_ == parent.id_).update(
{Parent.name: " leading"},
)
-def test_string_check_trailing(session: orm.Session, parent: Parent) -> None:
+def test_string_check_trailing(parent: Parent) -> None:
with pytest.raises(exc.IntegrityError):
- session.query(Parent).where(Parent.id_ == parent.id_).update(
+ Parent.query().where(Parent.id_ == parent.id_).update(
{Parent.name: "trailing "},
)
-def test_string_check_short(session: orm.Session, parent: Parent) -> None:
+def test_string_check_short(parent: Parent) -> None:
with pytest.raises(exc.IntegrityError):
- session.query(Parent).where(Parent.id_ == parent.id_).update({Parent.name: "a"})
+ Parent.query().where(Parent.id_ == parent.id_).update({Parent.name: "a"})
def test_clean_decimals() -> None:
@@ -293,9 +285,160 @@ def test_clean_decimals() -> None:
def test_clean_emoji_name(rand_str: str) -> None:
text = rand_str.lower()
- assert base.Base.clean_emoji_name(text + " 😀 ") == text
+ assert Base.clean_emoji_name(text + " 😀 ") == text
def test_clean_emoji_name_upper(rand_str: str) -> None:
text = rand_str.lower()
- assert base.Base.clean_emoji_name(text.upper() + " 😀 ") == text
+ assert Base.clean_emoji_name(text.upper() + " 😀 ") == text
+
+
+def test_query_kwargs() -> None:
+ with pytest.raises(exc.NoKeywordArgumentsError):
+ # Intentional bad argument
+ Parent.query(kw=None) # type: ignore[attr-defined]
+
+
+def test_unbound_error() -> None:
+ s = Base._sessions.pop()
+ with pytest.raises(exc.UnboundExecutionError):
+ Base.session()
+ Base._sessions.append(s)
+
+
+def noop[T](x: T) -> T:
+ return x
+
+
+def lower(s: str) -> str:
+ return s.lower()
+
+
+def upper(s: str) -> str:
+ return s.upper()
+
+
+@pytest.mark.parametrize(
+ ("prop", "value_adjuster"),
+ [
+ ("uri", noop),
+ ("name", noop),
+ ("name", lower),
+ ("name", upper),
+ ],
+)
+def test_find(
+ parent: Parent,
+ prop: str,
+ value_adjuster: Callable[[str], str],
+) -> None:
+ parent.name = "Fake"
+ query = value_adjuster(getattr(parent, prop))
+
+ cache: dict[str, NamePair] = {}
+
+ result = Parent.find(query, cache)
+ assert result.id_ == parent.id_
+ assert result.name == parent.name
+
+ assert cache == {query: result}
+
+
+def test_find_missing(parent: Parent) -> None:
+ query = Parent.id_to_uri(parent.id_ + 1)
+
+ cache: dict[str, NamePair] = {}
+ with pytest.raises(exc.NoResultFound):
+ Parent.find(query, cache)
+
+ assert not cache
+
+
+def check_no_session_add(line: str) -> str:
+ if re.match(r"^ *(s|session)\.add\(\w\)", line):
+ return "Use of session.add found, use Model.create()"
+ return ""
+
+
+def check_no_session_query(line: str) -> str:
+ if re.search(r"[( ](s|session)\.query\(", line):
+ return "Use of session.query found, use Model.query()"
+ return ""
+
+
+def check_no_query_with_entities(line: str) -> str:
+ if ".with_entities" in line: # nummus: ignore
+ return "Use of with_entities found, use Model.query(col, ...)"
+ return ""
+
+
+def check_no_query_scalar(line: str) -> str:
+ if (m := re.search(r"(\w*)\.scalar\(", line)) and m.group(1) != "sql":
+ return "Use of query.scalar found, use sql.scalar()"
+ return ""
+
+
+def check_no_query_one(line: str) -> str:
+ if not (m := re.search(r"(\w*)\.one\(", line)):
+ return ""
+ g = m.group(1)
+ if (g and g[0] == g[0].upper()) or g == "sql":
+ # use of Model.one()
+ return ""
+ return "Use of query.one found, use sql.one()"
+
+
+def check_no_query_all(line: str) -> str:
+ if not (m := re.search(r"(\w*)\.all\(", line)):
+ return ""
+ g = m.group(1)
+ if (g and g[0] == g[0].upper()) or g == "sql":
+ # use of Model.all()
+ return ""
+ return "Use of query.all found, use sql.yield_()"
+
+
+def check_no_query_col0(line: str) -> str:
+ if re.search(r"for \w+,? in query", line):
+ return "Use of first column iterator found, use sql.col0()"
+ return ""
+
+
+@pytest.mark.parametrize(
+ "path",
+ sorted(
+ [
+ *Path(nummus.__file__).parent.glob("**/*.py"),
+ *Path(tests.__file__).parent.glob("**/*.py"),
+ ],
+ ),
+ ids=conftest.id_func,
+)
+def test_use_of_mixins(path: Path) -> None:
+ lines = path.read_text("utf-8").splitlines()
+
+ ignore = "# nummus: ignore"
+
+ errors: list[str] = []
+
+ for i, line in enumerate(lines):
+ checks = [
+ check_no_session_add(line),
+ check_no_session_query(line),
+ check_no_query_with_entities(line),
+ check_no_query_scalar(line),
+ check_no_query_one(line),
+ check_no_query_all(line),
+ check_no_query_col0(line),
+ ]
+ checks = [f"{path:}:{i + 1}: {c}" for c in checks if c]
+ if checks:
+ if not line.endswith(ignore):
+ errors.extend(checks)
+ elif line.endswith(ignore):
+ errors.append(
+ f"{path}:{i + 1}: Use of unnecessary 'nummus: ignore'",
+ )
+
+ print("\n".join(errors))
+ assert not errors
diff --git a/tests/models/budget/test_budget_assignment.py b/tests/models/budget/test_budget_assignment.py
index f662e27c..00e02ca2 100644
--- a/tests/models/budget/test_budget_assignment.py
+++ b/tests/models/budget/test_budget_assignment.py
@@ -16,14 +16,11 @@
from nummus.models.transaction_category import TransactionCategory
if TYPE_CHECKING:
- from sqlalchemy import orm
-
from nummus.models.account import Account
def test_init_properties(
month_ord: int,
- session: orm.Session,
categories: dict[str, int],
rand_real: Decimal,
) -> None:
@@ -33,9 +30,7 @@ def test_init_properties(
"category_id": categories["uncategorized"],
}
- b = BudgetAssignment(**d)
- session.add(b)
- session.commit()
+ b = BudgetAssignment.create(**d)
assert b.month_ord == d["month_ord"]
assert b.amount == d["amount"]
@@ -44,33 +39,27 @@ def test_init_properties(
def test_duplicate_months(
month_ord: int,
- session: orm.Session,
categories: dict[str, int],
rand_real: Decimal,
) -> None:
- b = BudgetAssignment(
- month_ord=month_ord,
- amount=rand_real,
- category_id=categories["uncategorized"],
- )
- session.add(b)
- b = BudgetAssignment(
+ BudgetAssignment.create(
month_ord=month_ord,
amount=rand_real,
category_id=categories["uncategorized"],
)
- session.add(b)
with pytest.raises(exc.IntegrityError):
- session.commit()
+ BudgetAssignment.create(
+ month_ord=month_ord,
+ amount=rand_real,
+ category_id=categories["uncategorized"],
+ )
def test_get_monthly_available_empty(
month: datetime.date,
- session: orm.Session,
categories: dict[str, int],
) -> None:
availables, assignable, future_assigned = BudgetAssignment.get_monthly_available(
- session,
month,
)
assert set(availables.keys()) == set(categories.values())
@@ -82,15 +71,11 @@ def test_get_monthly_available_empty(
def test_get_monthly_available(
month: datetime.date,
- session: orm.Session,
categories: dict[str, int],
transactions_spending: list[Transaction],
budget_assignments: list[BudgetAssignment],
) -> None:
- _ = transactions_spending
- _ = budget_assignments
availables, assignable, future_assigned = BudgetAssignment.get_monthly_available(
- session,
month,
)
availables.pop(categories["other income"])
@@ -131,15 +116,11 @@ def test_get_monthly_available(
def test_get_monthly_available_next_month(
month: datetime.date,
- session: orm.Session,
categories: dict[str, int],
transactions_spending: list[Transaction],
budget_assignments: list[BudgetAssignment],
) -> None:
- _ = transactions_spending
- _ = budget_assignments
availables, assignable, future_assigned = BudgetAssignment.get_monthly_available(
- session,
utils.date_add_months(month, 1),
)
availables.pop(categories["other income"])
@@ -173,7 +154,6 @@ def test_get_monthly_available_next_month(
def test_get_emergency_fund_empty(
today_ord: int,
- session: orm.Session,
) -> None:
start_ord = today_ord - 3
end_ord = today_ord + 3
@@ -181,7 +161,6 @@ def test_get_emergency_fund_empty(
n_lower = 20
n_upper = 40
result = BudgetAssignment.get_emergency_fund(
- session,
start_ord,
end_ord,
n_lower,
@@ -197,33 +176,28 @@ def test_get_emergency_fund_empty(
def test_get_emergency_fund(
today: datetime.date,
today_ord: int,
- session: orm.Session,
account: Account,
categories: dict[str, int],
transactions_spending: list[Transaction],
budget_assignments: list[BudgetAssignment],
rand_str: str,
) -> None:
- session.query(TransactionCategory).where(
+ TransactionCategory.query().where(
TransactionCategory.name == "groceries",
).update({"essential_spending": True})
# Add a transaction 30 days ago
- txn = Transaction(
+ txn = Transaction.create(
account_id=account.id_,
date=today - datetime.timedelta(days=30),
amount=-50,
statement=rand_str,
)
- t_split = TransactionSplit(
+ TransactionSplit.create(
parent=txn,
amount=txn.amount,
category_id=categories["groceries"],
)
- session.add_all((txn, t_split))
- session.commit()
- _ = transactions_spending
- _ = budget_assignments
start_ord = today_ord - 3
end_ord = today_ord + 3
@@ -231,7 +205,6 @@ def test_get_emergency_fund(
n_lower = 20
n_upper = 40
result = BudgetAssignment.get_emergency_fund(
- session,
start_ord,
end_ord,
n_lower,
@@ -272,17 +245,14 @@ def test_get_emergency_fund(
def test_get_emergency_fund_balance(
month_ord: int,
- session: orm.Session,
budget_assignments: list[BudgetAssignment],
) -> None:
- _ = budget_assignments
start_ord = month_ord - 3
end_ord = month_ord + 3
n_lower = 20
n_upper = 40
result = BudgetAssignment.get_emergency_fund(
- session,
start_ord,
end_ord,
n_lower,
@@ -297,15 +267,13 @@ def test_get_emergency_fund_balance(
def test_move_from_income(
month_ord: int,
- session: orm.Session,
categories: dict[str, int],
) -> None:
src_cat_id = None
dest_cat_id = categories["groceries"]
- BudgetAssignment.move(session, month_ord, src_cat_id, dest_cat_id, Decimal(100))
- session.commit()
+ BudgetAssignment.move(month_ord, src_cat_id, dest_cat_id, Decimal(100))
- a = session.query(BudgetAssignment).one()
+ a = BudgetAssignment.one()
assert a.category_id == dest_cat_id
assert a.month_ord == month_ord
assert a.amount == 100
@@ -324,7 +292,6 @@ def test_move_from_income(
)
def test_move_to_income_partial(
month_ord: int,
- session: orm.Session,
categories: dict[str, int],
budget_assignments: list[BudgetAssignment],
src: str | None,
@@ -333,14 +300,12 @@ def test_move_to_income_partial(
target_src: Decimal | None,
target_dest: Decimal | None,
) -> None:
- _ = budget_assignments
src_cat_id = None if src is None else categories[src]
dest_cat_id = None if dest is None else categories[dest]
- BudgetAssignment.move(session, month_ord, src_cat_id, dest_cat_id, to_move)
- session.commit()
+ BudgetAssignment.move(month_ord, src_cat_id, dest_cat_id, to_move)
a = (
- session.query(BudgetAssignment)
+ BudgetAssignment.query()
.where(
BudgetAssignment.category_id == src_cat_id,
BudgetAssignment.month_ord == month_ord,
@@ -355,7 +320,7 @@ def test_move_to_income_partial(
assert a.amount == target_src
a = (
- session.query(BudgetAssignment)
+ BudgetAssignment.query()
.where(
BudgetAssignment.category_id == dest_cat_id,
BudgetAssignment.month_ord == month_ord,
diff --git a/tests/models/budget/test_budget_group.py b/tests/models/budget/test_budget_group.py
index 32f81737..36bef2e1 100644
--- a/tests/models/budget/test_budget_group.py
+++ b/tests/models/budget/test_budget_group.py
@@ -8,55 +8,41 @@
from nummus.models.budget import BudgetGroup
if TYPE_CHECKING:
- from sqlalchemy import orm
-
from tests.conftest import RandomStringGenerator
-def test_init_properties(session: orm.Session, rand_str: str) -> None:
+def test_init_properties(rand_str: str) -> None:
d = {
"name": rand_str,
"position": 0,
}
- g = BudgetGroup(**d)
- session.add(g)
- session.commit()
+ g = BudgetGroup.create(**d)
assert g.name == d["name"]
assert g.position == d["position"]
def test_duplicate_names(
- session: orm.Session,
rand_str: str,
) -> None:
- g = BudgetGroup(name=rand_str, position=0)
- session.add(g)
- g = BudgetGroup(name=rand_str, position=1)
- session.add(g)
+ BudgetGroup.create(name=rand_str, position=0)
with pytest.raises(exc.IntegrityError):
- session.commit()
+ BudgetGroup.create(name=rand_str, position=0)
def test_duplicate_position(
- session: orm.Session,
rand_str_generator: RandomStringGenerator,
) -> None:
- g = BudgetGroup(name=rand_str_generator(), position=0)
- session.add(g)
- g = BudgetGroup(name=rand_str_generator(), position=0)
- session.add(g)
+ BudgetGroup.create(name=rand_str_generator(), position=0)
with pytest.raises(exc.IntegrityError):
- session.commit()
+ BudgetGroup.create(name=rand_str_generator(), position=0)
-def test_empty(session: orm.Session) -> None:
- g = BudgetGroup(name="", position=0)
- session.add(g)
+def test_empty() -> None:
with pytest.raises(exc.IntegrityError):
- session.commit()
+ BudgetGroup.create(name="", position=0)
def test_short() -> None:
with pytest.raises(exc.InvalidORMValueError):
- BudgetGroup(name="a", position=0)
+ BudgetGroup.create(name="a", position=0)
diff --git a/tests/models/budget/test_target.py b/tests/models/budget/test_target.py
index 439a1023..897a6d64 100644
--- a/tests/models/budget/test_target.py
+++ b/tests/models/budget/test_target.py
@@ -12,13 +12,10 @@
import datetime
from decimal import Decimal
- from sqlalchemy import orm
-
def test_init_properties(
today: datetime.date,
today_ord: int,
- session: orm.Session,
rand_real: Decimal,
categories: dict[str, int],
) -> None:
@@ -31,9 +28,7 @@ def test_init_properties(
"repeat_every": 0,
}
- t = Target(**d)
- session.add(t)
- session.commit()
+ t = Target.create(**d)
assert t.category_id == d["category_id"]
assert t.amount == d["amount"]
@@ -69,7 +64,6 @@ def test_init_properties(
)
def test_check_constraints(
today_ord: int,
- session: orm.Session,
rand_real: Decimal,
categories: dict[str, int],
period: TargetPeriod,
@@ -77,7 +71,7 @@ def test_check_constraints(
kwargs: dict[str, object],
success: bool,
) -> None:
- d = {
+ d: dict[str, object] = {
"category_id": categories["uncategorized"],
"amount": rand_real,
"type_": type_,
@@ -86,18 +80,15 @@ def test_check_constraints(
"repeat_every": 0,
}
d.update(kwargs)
- t = Target(**d)
- session.add(t)
if success:
- session.commit()
+ Target.create(**d)
else:
with pytest.raises(exc.IntegrityError):
- session.commit()
+ Target.create(**d)
def test_duplicates(
today_ord: int,
- session: orm.Session,
rand_real: Decimal,
categories: dict[str, int],
) -> None:
@@ -110,16 +101,12 @@ def test_duplicates(
"repeat_every": 0,
}
- t = Target(**d)
- session.add(t)
- t = Target(**d)
- session.add(t)
+ Target.create(**d)
with pytest.raises(exc.IntegrityError):
- session.commit()
+ Target.create(**d)
def test_date_none(
- session: orm.Session,
rand_real: Decimal,
categories: dict[str, int],
) -> None:
@@ -132,7 +119,6 @@ def test_date_none(
"repeat_every": 0,
}
- t = Target(**d)
- session.add(t)
+ t = Target.create(**d)
assert t.due_date is None
diff --git a/tests/models/label/test_label.py b/tests/models/label/test_label.py
index 99f7c5f2..ff7ab0e9 100644
--- a/tests/models/label/test_label.py
+++ b/tests/models/label/test_label.py
@@ -1,23 +1,15 @@
from __future__ import annotations
-from typing import TYPE_CHECKING
-
import pytest
from nummus import exceptions as exc
from nummus.models.label import Label
-if TYPE_CHECKING:
-
- from sqlalchemy import orm
-
-def test_init_properties(session: orm.Session, rand_str: str) -> None:
+def test_init_properties(rand_str: str) -> None:
d = {"name": rand_str}
- label = Label(**d)
- session.add(label)
- session.commit()
+ label = Label.create(**d)
assert label.name == d["name"]
diff --git a/tests/models/label/test_label_link.py b/tests/models/label/test_label_link.py
index 16022d89..d87effdc 100644
--- a/tests/models/label/test_label_link.py
+++ b/tests/models/label/test_label_link.py
@@ -2,18 +2,14 @@
from typing import TYPE_CHECKING
+from nummus import sql
from nummus.models.label import Label, LabelLink
-from nummus.models.utils import query_count
if TYPE_CHECKING:
-
- from sqlalchemy import orm
-
from nummus.models.transaction import Transaction
def test_init_properties(
- session: orm.Session,
labels: dict[str, int],
transactions: list[Transaction],
) -> None:
@@ -22,47 +18,38 @@ def test_init_properties(
"t_split_id": transactions[-1].splits[0].id_,
}
- link = LabelLink(**d)
- session.add(link)
- session.commit()
+ link = LabelLink.create(**d)
assert link.label_id == d["label_id"]
assert link.t_split_id == d["t_split_id"]
def test_add_links_delete(
- session: orm.Session,
transactions: list[Transaction],
labels: dict[str, int],
) -> None:
new_labels: dict[int, set[str]] = {txn.splits[0].id_: set() for txn in transactions}
- LabelLink.add_links(session, new_labels)
+ LabelLink.add_links(new_labels)
- n = query_count(session.query(LabelLink))
- assert n == 0
-
- n = query_count(session.query(Label))
- assert n == len(labels)
+ assert not sql.any_(LabelLink.query())
+ assert sql.count(Label.query()) == len(labels)
def test_add_links(
- session: orm.Session,
transactions: list[Transaction],
rand_str: str,
labels: dict[str, int],
) -> None:
new_labels: dict[int, set[str]] = {
- txn.splits[0].id_: {rand_str} for txn in transactions
+ txn.splits[0].id_: {rand_str, "engineer"} for txn in transactions
}
- LabelLink.add_links(session, new_labels)
-
- n = query_count(session.query(LabelLink))
- assert n == len(transactions)
+ LabelLink.add_links(new_labels)
- n = query_count(session.query(Label))
- assert n == len(labels) + 1
+ assert sql.count(LabelLink.query()) == len(transactions) * 2
+ assert sql.count(Label.query()) == len(labels) + 1
- label = session.query(Label).where(Label.id_.not_in(labels.values())).one()
+ query = Label.query().where(Label.id_.not_in(labels.values()))
+ label = sql.one(query)
assert label.name == rand_str
diff --git a/tests/models/test_account.py b/tests/models/test_account.py
index b14f9052..4f1df5d5 100644
--- a/tests/models/test_account.py
+++ b/tests/models/test_account.py
@@ -7,11 +7,9 @@
from nummus import exceptions as exc
from nummus.models.account import Account, AccountCategory
-from nummus.models.currency import DEFAULT_CURRENCY
+from nummus.models.currency import Currency, DEFAULT_CURRENCY
if TYPE_CHECKING:
- from sqlalchemy import orm
-
from nummus.models.asset import Asset, AssetValuation
from nummus.models.transaction import Transaction
from tests.conftest import RandomStringGenerator
@@ -19,7 +17,6 @@
def test_init_properties(
rand_str_generator: RandomStringGenerator,
- session: orm.Session,
) -> None:
d = {
"name": rand_str_generator(),
@@ -29,10 +26,7 @@ def test_init_properties(
"budgeted": False,
"currency": DEFAULT_CURRENCY,
}
- acct = Account(**d)
-
- session.add(acct)
- session.commit()
+ acct = Account.create(**d)
assert acct.name == d["name"]
assert acct.institution == d["institution"]
@@ -47,14 +41,13 @@ def test_short(account: Account) -> None:
account.name = "a"
-def test_ids(session: orm.Session, account: Account) -> None:
- ids = Account.ids(session, AccountCategory.CASH)
+def test_ids(account: Account) -> None:
+ ids = Account.ids(AccountCategory.CASH)
assert ids == {account.id_}
-def test_ids_none(session: orm.Session, account: Account) -> None:
- _ = account
- ids = Account.ids(session, AccountCategory.CREDIT)
+def test_ids_none(account: Account) -> None:
+ ids = Account.ids(AccountCategory.CREDIT)
assert ids == set()
@@ -63,14 +56,12 @@ def test_date_properties(
account: Account,
transactions: list[Transaction],
) -> None:
- _ = transactions
assert account.opened_on_ord == today_ord - 3
assert account.updated_on_ord == today_ord + 7
def test_get_asset_qty_empty(
today_ord: int,
- session: orm.Session,
account: Account,
) -> None:
start_ord = today_ord - 3
@@ -80,22 +71,19 @@ def test_get_asset_qty_empty(
# defaultdict is correct length
assert result[0] == [Decimal()] * 7
- result = Account.get_asset_qty_all(session, start_ord, end_ord)
+ result = Account.get_asset_qty_all(start_ord, end_ord)
assert result == {}
def test_get_asset_qty_none(
today_ord: int,
- session: orm.Session,
account: Account,
transactions: list[Transaction],
) -> None:
- _ = account
- _ = transactions
start_ord = today_ord - 3
end_ord = today_ord + 3
- result = Account.get_asset_qty_all(session, start_ord, end_ord, set())
+ result = Account.get_asset_qty_all(start_ord, end_ord, set())
# defaultdict is correct length
assert result[0][0] == [Decimal()] * 7
@@ -106,7 +94,6 @@ def test_get_asset_qty(
asset: Asset,
transactions: list[Transaction],
) -> None:
- _ = transactions
start_ord = today_ord - 3
end_ord = today_ord + 3
result_qty = account.get_asset_qty(start_ord, end_ord)
@@ -130,14 +117,12 @@ def test_get_asset_qty_today(
asset: Asset,
transactions: list[Transaction],
) -> None:
- _ = transactions
result_qty = account.get_asset_qty(today_ord, today_ord)
assert result_qty == {asset.id_: [Decimal(10)]}
def test_get_value_empty(
today_ord: int,
- session: orm.Session,
account: Account,
) -> None:
start_ord = today_ord - 3
@@ -149,7 +134,7 @@ def test_get_value_empty(
# defaultdict is correct length
assert assets[0] == [Decimal()] * 7
- values, profits, assets = Account.get_value_all(session, start_ord, end_ord)
+ values, profits, assets = Account.get_value_all(start_ord, end_ord)
assert values == {}
assert profits == {}
assert assets == {}
@@ -157,16 +142,13 @@ def test_get_value_empty(
def test_get_value_none(
today_ord: int,
- session: orm.Session,
account: Account,
transactions: list[Transaction],
) -> None:
- _ = account
- _ = transactions
start_ord = today_ord - 3
end_ord = today_ord + 3
- values, profits, assets = Account.get_value_all(session, start_ord, end_ord, set())
+ values, profits, assets = Account.get_value_all(start_ord, end_ord, set())
assert values == {}
assert profits == {}
assert assets == {}
@@ -182,8 +164,6 @@ def test_get_value(
asset_valuation: AssetValuation,
transactions: list[Transaction],
) -> None:
- _ = transactions
- _ = asset_valuation
start_ord = today_ord - 4
end_ord = today_ord + 3
values, profits, assets = account.get_value(start_ord, end_ord)
@@ -224,6 +204,55 @@ def test_get_value(
assert assets == target
+def test_get_value_forex(
+ today_ord: int,
+ account: Account,
+ asset: Asset,
+ asset_valuation: AssetValuation,
+ transactions: list[Transaction],
+) -> None:
+ start_ord = today_ord - 4
+ end_ord = today_ord + 3
+ f = 2
+ forex: dict[Currency, list[Decimal]] = {Currency.USD: [Decimal(f)] * 8}
+ values, profits, assets = account.get_value(start_ord, end_ord, forex=forex)
+ target = [
+ Decimal(),
+ Decimal(100) * f,
+ Decimal(90) * f,
+ Decimal(90) * f,
+ Decimal(110) * f,
+ Decimal(150) * f,
+ Decimal(150) * f,
+ Decimal(150) * f,
+ ]
+ assert values == target
+ target = [
+ Decimal(),
+ Decimal(),
+ Decimal(-10) * f,
+ Decimal(-10) * f,
+ Decimal(10) * f,
+ Decimal(50) * f,
+ Decimal(50) * f,
+ Decimal(50) * f,
+ ]
+ assert profits == target
+ target = {
+ asset.id_: [
+ Decimal(),
+ Decimal(),
+ Decimal(),
+ Decimal(),
+ Decimal(20) * f,
+ Decimal(10) * f,
+ Decimal(10) * f,
+ Decimal(10) * f,
+ ],
+ }
+ assert assets == target
+
+
def test_get_value_today(
today_ord: int,
account: Account,
@@ -231,8 +260,6 @@ def test_get_value_today(
asset_valuation: AssetValuation,
transactions: list[Transaction],
) -> None:
- _ = transactions
- _ = asset_valuation
values, profits, assets = account.get_value(today_ord, today_ord)
assert values == [Decimal(110)]
assert profits == [Decimal()]
@@ -245,7 +272,6 @@ def test_get_value_buy_day(
asset: Asset,
transactions: list[Transaction],
) -> None:
- _ = transactions
values, profits, assets = account.get_value(today_ord - 2, today_ord - 2)
assert values == [Decimal(90)]
assert profits == [Decimal(-10)]
@@ -257,7 +283,6 @@ def test_get_value_fund_day(
account: Account,
transactions: list[Transaction],
) -> None:
- _ = transactions
values, profits, assets = account.get_value(today_ord - 3, today_ord - 3)
assert values == [Decimal(100)]
assert profits == [Decimal()]
@@ -266,7 +291,6 @@ def test_get_value_fund_day(
def test_get_cash_flow_empty(
today_ord: int,
- session: orm.Session,
account: Account,
) -> None:
start_ord = today_ord - 3
@@ -276,7 +300,7 @@ def test_get_cash_flow_empty(
# defaultdict is correct length
assert result[0] == [Decimal()] * 7
- result = Account.get_cash_flow_all(session, start_ord, end_ord)
+ result = Account.get_cash_flow_all(start_ord, end_ord)
assert result == {}
@@ -286,7 +310,6 @@ def test_get_cash_flow(
transactions: list[Transaction],
categories: dict[str, int],
) -> None:
- _ = transactions
start_ord = today_ord - 3
end_ord = today_ord + 3
result = account.get_cash_flow(start_ord, end_ord)
@@ -318,14 +341,12 @@ def test_get_cash_flow_today(
account: Account,
transactions: list[Transaction],
) -> None:
- _ = transactions
result = account.get_cash_flow(today_ord, today_ord)
assert result == {}
def test_get_profit_by_asset_empty(
today_ord: int,
- session: orm.Session,
account: Account,
) -> None:
start_ord = today_ord - 3
@@ -334,7 +355,7 @@ def test_get_profit_by_asset_empty(
assert result == {}
assert result[0] == Decimal()
- result = Account.get_profit_by_asset_all(session, start_ord, end_ord)
+ result = Account.get_profit_by_asset_all(start_ord, end_ord)
assert result == {}
@@ -345,8 +366,6 @@ def test_get_profit_by_asset(
transactions: list[Transaction],
asset_valuation: AssetValuation,
) -> None:
- _ = transactions
- _ = asset_valuation
start_ord = today_ord - 3
end_ord = today_ord + 3
result = account.get_profit_by_asset(start_ord, end_ord)
@@ -362,7 +381,6 @@ def test_get_profit_by_asset_today(
asset: Asset,
transactions: list[Transaction],
) -> None:
- _ = transactions
result = account.get_profit_by_asset(today_ord, today_ord)
assert result == {asset.id_: Decimal()}
@@ -387,6 +405,14 @@ def test_do_include_closed(
account: Account,
transactions: list[Transaction],
) -> None:
- _ = transactions
account.closed = True
assert account.do_include(today_ord)
+
+
+def test_find(
+ account: Account,
+) -> None:
+ account.number = "1234567890"
+ id_, name = Account.find("7890", {})
+ assert id_ == account.id_
+ assert name == account.name
diff --git a/tests/models/test_base_uri.py b/tests/models/test_base_uri.py
index 548476d7..18e6b180 100644
--- a/tests/models/test_base_uri.py
+++ b/tests/models/test_base_uri.py
@@ -86,7 +86,7 @@ def test_empty_uri() -> None:
def test_symmetrical_unique() -> None:
- uris = set()
+ uris: set[str] = set()
n = 10000
for i in range(n):
@@ -147,7 +147,8 @@ def test_table_ids_no_duplicates() -> None:
table_ids: set[int] = set()
for m in MODELS_URI:
- t_id: int = m.__table_id__
+ t_id = m.__table_id__
+ assert t_id is not None
assert t_id not in table_ids
table_ids.add(t_id)
diff --git a/tests/models/test_config.py b/tests/models/test_config.py
index 074d194b..81fa7f57 100644
--- a/tests/models/test_config.py
+++ b/tests/models/test_config.py
@@ -6,84 +6,77 @@
from packaging.version import Version
from nummus import exceptions as exc
+from nummus import sql
from nummus.migrations.top import MIGRATORS
from nummus.models.config import Config, ConfigKey
from nummus.models.currency import DEFAULT_CURRENCY
from nummus.version import __version__
if TYPE_CHECKING:
- from sqlalchemy import orm
-
from tests.conftest import RandomStringGenerator
-def test_init_properties(session: orm.Session, rand_str: str) -> None:
+def test_init_properties(rand_str: str) -> None:
d = {
"key": ConfigKey.WEB_KEY,
"value": rand_str,
}
- c = Config(**d)
- session.add(c)
- session.commit()
+ c = Config.create(**d)
assert c.key == d["key"]
assert c.value == d["value"]
def test_duplicate_keys(
- session: orm.Session,
rand_str_generator: RandomStringGenerator,
) -> None:
- c = Config(key=ConfigKey.WEB_KEY, value=rand_str_generator())
- session.add(c)
- c = Config(key=ConfigKey.WEB_KEY, value=rand_str_generator())
- session.add(c)
+ Config.create(key=ConfigKey.WEB_KEY, value=rand_str_generator())
with pytest.raises(exc.IntegrityError):
- session.commit()
+ Config.create(key=ConfigKey.WEB_KEY, value=rand_str_generator())
-def test_empty(session: orm.Session) -> None:
- c = Config(key=ConfigKey.WEB_KEY, value="")
- session.add(c)
+def test_empty() -> None:
with pytest.raises(exc.IntegrityError):
- session.commit()
+ Config.create(key=ConfigKey.WEB_KEY, value="")
def test_short() -> None:
with pytest.raises(exc.InvalidORMValueError):
- Config(key=ConfigKey.WEB_KEY, value="a")
+ Config.create(key=ConfigKey.WEB_KEY, value="a")
+
+def test_set(rand_str: str) -> None:
+ Config.set_(ConfigKey.VERSION, rand_str)
+ assert Config.fetch(ConfigKey.VERSION) == rand_str
-def test_set(session: orm.Session, rand_str: str) -> None:
- Config.set_(session, ConfigKey.WEB_KEY, rand_str)
- session.commit()
- v = session.query(Config.value).where(Config.key == ConfigKey.WEB_KEY).scalar()
- assert v == rand_str
+def test_set_new(rand_str: str) -> None:
+ Config.set_(ConfigKey.WEB_KEY, rand_str)
+ assert Config.fetch(ConfigKey.WEB_KEY) == rand_str
-def test_fetch(session: orm.Session) -> None:
- target = session.query(Config.value).where(Config.key == ConfigKey.VERSION).scalar()
- assert Config.fetch(session, ConfigKey.VERSION) == target
+def test_fetch() -> None:
+ v = sql.scalar(Config.query(Config.value).where(Config.key == ConfigKey.VERSION))
+ assert Config.fetch(ConfigKey.VERSION) == v
-def test_fetch_missing(session: orm.Session) -> None:
+def test_fetch_missing() -> None:
with pytest.raises(exc.ProtectedObjectNotFoundError):
- Config.fetch(session, ConfigKey.WEB_KEY)
+ Config.fetch(ConfigKey.WEB_KEY)
-def test_fetch_missing_ok(session: orm.Session) -> None:
- assert Config.fetch(session, ConfigKey.WEB_KEY, no_raise=True) is None
+def test_fetch_missing_ok() -> None:
+ assert Config.fetch(ConfigKey.WEB_KEY, no_raise=True) is None
-def test_db_version(session: orm.Session) -> None:
+def test_db_version() -> None:
target = max(
Version(__version__),
*[m.min_version() for m in MIGRATORS],
)
- assert Config.db_version(session) == target
+ assert Config.db_version() == target
-def test_base_currency(session: orm.Session) -> None:
- assert Config.base_currency(session) == DEFAULT_CURRENCY
+def test_base_currency() -> None:
+ assert Config.base_currency() == DEFAULT_CURRENCY
diff --git a/tests/models/test_health_checks.py b/tests/models/test_health_checks.py
index b1356f3c..7b13457a 100644
--- a/tests/models/test_health_checks.py
+++ b/tests/models/test_health_checks.py
@@ -8,13 +8,10 @@
from nummus.models.health_checks import HealthCheckIssue
if TYPE_CHECKING:
- from sqlalchemy import orm
-
from tests.conftest import RandomStringGenerator
def test_init_properties(
- session: orm.Session,
rand_str_generator: RandomStringGenerator,
) -> None:
d = {
@@ -24,9 +21,7 @@ def test_init_properties(
"ignore": False,
}
- i = HealthCheckIssue(**d)
- session.add(i)
- session.commit()
+ i = HealthCheckIssue.create(**d)
assert i.check == d["check"]
assert i.value == d["value"]
@@ -35,7 +30,6 @@ def test_init_properties(
def test_duplicate_keys(
- session: orm.Session,
rand_str_generator: RandomStringGenerator,
) -> None:
d = {
@@ -44,12 +38,9 @@ def test_duplicate_keys(
"msg": rand_str_generator(),
"ignore": False,
}
- i = HealthCheckIssue(**d)
- session.add(i)
- i = HealthCheckIssue(**d)
- session.add(i)
+ HealthCheckIssue.create(**d)
with pytest.raises(exc.IntegrityError):
- session.commit()
+ HealthCheckIssue.create(**d)
def test_short_check(rand_str_generator: RandomStringGenerator) -> None:
@@ -64,7 +55,6 @@ def test_short_check(rand_str_generator: RandomStringGenerator) -> None:
def test_short_value(
- session: orm.Session,
rand_str_generator: RandomStringGenerator,
) -> None:
d = {
@@ -73,6 +63,4 @@ def test_short_value(
"msg": rand_str_generator(),
"ignore": False,
}
- i = HealthCheckIssue(**d)
- session.add(i)
- session.commit()
+ HealthCheckIssue.create(**d)
diff --git a/tests/models/test_imported_file.py b/tests/models/test_imported_file.py
index f240c650..54d74a8a 100644
--- a/tests/models/test_imported_file.py
+++ b/tests/models/test_imported_file.py
@@ -1,34 +1,23 @@
from __future__ import annotations
-from typing import TYPE_CHECKING
-
import pytest
from nummus import exceptions as exc
-from nummus.models import imported_file
-
-if TYPE_CHECKING:
- from sqlalchemy import orm
+from nummus.models.imported_file import ImportedFile
def test_init_properties(
- session: orm.Session,
rand_str: str,
today_ord: int,
) -> None:
- f = imported_file.ImportedFile(hash_=rand_str)
- session.add(f)
- session.commit()
+ f = ImportedFile.create(hash_=rand_str)
# Default date is today
assert f.date_ord == today_ord
assert f.hash_ == rand_str
-def test_duplicates(session: orm.Session, rand_str: str) -> None:
- f = imported_file.ImportedFile(hash_=rand_str)
- session.add(f)
- f = imported_file.ImportedFile(hash_=rand_str)
- session.add(f)
+def test_duplicates(rand_str: str) -> None:
+ ImportedFile.create(hash_=rand_str)
with pytest.raises(exc.IntegrityError):
- session.commit()
+ ImportedFile.create(hash_=rand_str)
diff --git a/tests/models/test_transaction_category.py b/tests/models/test_transaction_category.py
index 677adb0b..ac3a559b 100644
--- a/tests/models/test_transaction_category.py
+++ b/tests/models/test_transaction_category.py
@@ -11,13 +11,10 @@
)
if TYPE_CHECKING:
- from sqlalchemy import orm
-
from nummus.models.budget import BudgetGroup
def test_init_properties(
- session: orm.Session,
rand_str: str,
budget_group: BudgetGroup,
) -> None:
@@ -32,9 +29,7 @@ def test_init_properties(
"budget_position": 0,
}
- t_cat = TransactionCategory(**d)
- session.add(t_cat)
- session.commit()
+ t_cat = TransactionCategory.create(**d)
assert t_cat.name == rand_str.lower()
assert t_cat.emoji_name == d["emoji_name"]
@@ -62,16 +57,16 @@ def test_name_direct() -> None:
TransactionCategory(name="a")
-def test_name_no_position(session: orm.Session, budget_group: BudgetGroup) -> None:
+def test_name_no_position(budget_group: BudgetGroup) -> None:
with pytest.raises(exc.IntegrityError):
- session.query(TransactionCategory).where(
+ TransactionCategory.query().where(
TransactionCategory.name == "transfers",
).update({TransactionCategory.budget_group_id: budget_group.id_})
-def test_name_no_group(session: orm.Session) -> None:
+def test_name_no_group() -> None:
with pytest.raises(exc.IntegrityError):
- session.query(TransactionCategory).where(
+ TransactionCategory.query().where(
TransactionCategory.name == "transfers",
).update({TransactionCategory.budget_position: 0})
@@ -84,18 +79,17 @@ def test_essential_income() -> None:
)
-def test_essential_income_update(session: orm.Session) -> None:
+def test_essential_income_update() -> None:
with pytest.raises(exc.IntegrityError):
- session.query(TransactionCategory).where(
+ TransactionCategory.query().where(
TransactionCategory.name == "other income",
).update({TransactionCategory.essential_spending: True})
-def test_essential_expense(session: orm.Session) -> None:
- session.query(TransactionCategory).where(
+def test_essential_expense() -> None:
+ TransactionCategory.query().where(
TransactionCategory.name == "groceries",
).update({TransactionCategory.essential_spending: True})
- session.commit()
def test_essential_none() -> None:
@@ -103,71 +97,67 @@ def test_essential_none() -> None:
TransactionCategory(essential_spending=None)
-def test_emergency_fund_missing(session: orm.Session) -> None:
- session.query(TransactionCategory).delete()
+def test_emergency_fund_missing() -> None:
+ TransactionCategory.query().delete()
with pytest.raises(exc.ProtectedObjectNotFoundError):
- TransactionCategory.emergency_fund(session)
+ TransactionCategory.emergency_fund()
-def test_emergency_fund(session: orm.Session, categories: dict[str, int]) -> None:
- result = TransactionCategory.emergency_fund(session)
+def test_emergency_fund(categories: dict[str, int]) -> None:
+ result = TransactionCategory.emergency_fund()
t_cat_id = categories["emergency fund"]
assert result == (t_cat_id, TransactionCategory.id_to_uri(t_cat_id))
-def test_uncategorized(session: orm.Session, categories: dict[str, int]) -> None:
- result = TransactionCategory.uncategorized(session)
+def test_uncategorized(categories: dict[str, int]) -> None:
+ result = TransactionCategory.uncategorized()
t_cat_id = categories["uncategorized"]
assert result == (t_cat_id, TransactionCategory.id_to_uri(t_cat_id))
-def test_securities_traded(session: orm.Session, categories: dict[str, int]) -> None:
- result = TransactionCategory.securities_traded(session)
+def test_securities_traded(categories: dict[str, int]) -> None:
+ result = TransactionCategory.securities_traded()
t_cat_id = categories["securities traded"]
assert result == (t_cat_id, TransactionCategory.id_to_uri(t_cat_id))
def test_map_name(
- session: orm.Session,
categories: dict[str, int],
) -> None:
- result = TransactionCategory.map_name(session)
+ result = TransactionCategory.map_name()
assert result[categories["uncategorized"]] == "uncategorized"
assert result[categories["securities traded"]] == "securities traded"
def test_map_name_no_asset_linked(
- session: orm.Session,
categories: dict[str, int],
) -> None:
- result = TransactionCategory.map_name(session, no_asset_linked=True)
+ result = TransactionCategory.map_name(no_asset_linked=True)
assert result[categories["uncategorized"]] == "uncategorized"
assert categories["securities traded"] not in result
def test_map_name_emoji(
- session: orm.Session,
categories: dict[str, int],
) -> None:
- session.query(TransactionCategory).where(
+ TransactionCategory.query().where(
TransactionCategory.name == "uncategorized",
).update(
{TransactionCategory.emoji_name: "🤷 Uncategorized 🤷"},
)
- result = TransactionCategory.map_name_emoji(session)
+ result = TransactionCategory.map_name_emoji()
assert result[categories["uncategorized"]] == "🤷 Uncategorized 🤷"
assert result[categories["securities traded"]] == "Securities Traded"
def test_map_name_emoji_no_asset_linked(
- session: orm.Session,
categories: dict[str, int],
) -> None:
- session.query(TransactionCategory).where(
+ TransactionCategory.query().where(
TransactionCategory.name == "uncategorized",
).update(
{TransactionCategory.emoji_name: "🤷 Uncategorized 🤷"},
)
- result = TransactionCategory.map_name_emoji(session, no_asset_linked=True)
+ result = TransactionCategory.map_name_emoji(no_asset_linked=True)
assert result[categories["uncategorized"]] == "🤷 Uncategorized 🤷"
assert categories["securities traded"] not in result
diff --git a/tests/models/test_utils.py b/tests/models/test_utils.py
index 75850d56..227ed446 100644
--- a/tests/models/test_utils.py
+++ b/tests/models/test_utils.py
@@ -6,12 +6,10 @@
import pytest
from sqlalchemy import CheckConstraint, ForeignKeyConstraint, UniqueConstraint
-from nummus import exceptions as exc
+from nummus import sql
from nummus.models import utils
from nummus.models.account import Account
from nummus.models.asset import (
- Asset,
- AssetCategory,
AssetSplit,
AssetValuation,
)
@@ -20,40 +18,36 @@
if TYPE_CHECKING:
import datetime
- import sqlalchemy
- from sqlalchemy import orm
-
+ from nummus.models.asset import (
+ Asset,
+ )
from tests.conftest import RandomStringGenerator
@pytest.fixture
def transactions(
- session: orm.Session,
today: datetime.date,
account: Account,
categories: dict[str, int],
rand_str_generator: RandomStringGenerator,
) -> list[Transaction]:
for _ in range(10):
- txn = Transaction(
+ txn = Transaction.create(
account_id=account.id_,
date=today,
amount=100,
statement=rand_str_generator(),
)
- t_split = TransactionSplit(
+ TransactionSplit.create(
amount=100,
parent=txn,
category_id=categories["uncategorized"],
)
- session.add_all((txn, t_split))
- session.commit()
- return session.query(Transaction).all()
+ return Transaction.all()
@pytest.fixture
def valuations(
- session: orm.Session,
today_ord: int,
asset: Asset,
) -> list[AssetValuation]:
@@ -63,14 +57,12 @@ def valuations(
today_ord: {"value": Decimal(100), "asset_id": a_id},
}
- query = session.query(AssetValuation)
- utils.update_rows(session, AssetValuation, query, "date_ord", updates)
- session.commit()
- return query.all()
+ utils.update_rows(AssetValuation, AssetValuation.query(), "date_ord", updates)
+ return AssetValuation.all()
-def test_paginate_all(session: orm.Session, transactions: list[Transaction]) -> None:
- page, count, next_offset = utils.paginate(session.query(Transaction), 50, 0)
+def test_paginate_all(transactions: list[Transaction]) -> None:
+ page, count, next_offset = utils.paginate(Transaction.query(), 50, 0)
assert page == transactions
assert count == len(transactions)
assert next_offset is None
@@ -78,11 +70,10 @@ def test_paginate_all(session: orm.Session, transactions: list[Transaction]) ->
@pytest.mark.parametrize("offset", range(10))
def test_paginate_three(
- session: orm.Session,
transactions: list[Transaction],
offset: int,
) -> None:
- page, count, next_offset = utils.paginate(session.query(Transaction), 3, offset)
+ page, count, next_offset = utils.paginate(Transaction.query(), 3, offset)
assert page == transactions[offset : offset + 3]
assert count == len(transactions)
if offset >= (len(transactions) - 3):
@@ -91,102 +82,80 @@ def test_paginate_three(
assert next_offset == offset + 3
-def test_paginate_three_page_1000(
- session: orm.Session,
- transactions: list[Transaction],
-) -> None:
- page, count, next_offset = utils.paginate(session.query(Transaction), 3, 1000)
+def test_paginate_three_page_1000(transactions: list[Transaction]) -> None:
+ page, count, next_offset = utils.paginate(Transaction.query(), 3, 1000)
assert page == []
assert count == len(transactions)
assert next_offset is None
-def test_paginate_three_page_n1000(
- session: orm.Session,
- transactions: list[Transaction],
-) -> None:
- page, count, next_offset = utils.paginate(session.query(Transaction), 3, -1000)
+def test_paginate_three_page_n1000(transactions: list[Transaction]) -> None:
+ page, count, next_offset = utils.paginate(Transaction.query(), 3, -1000)
assert page == transactions[0:3]
assert count == len(transactions)
assert next_offset == 3
-def test_dump_table_configs(session: orm.Session) -> None:
- result = utils.dump_table_configs(session, Account)
+def test_dump_table_configs() -> None:
+ result = utils.dump_table_configs(Account)
assert result[0] == "CREATE TABLE account ("
assert result[-1] == ")"
assert "\t" not in "\n".join(result)
-def test_get_constraints(session: orm.Session) -> None:
+def test_get_constraints() -> None:
target = [
(UniqueConstraint, "asset_id, date_ord"),
(CheckConstraint, "multiplier > 0"),
(ForeignKeyConstraint, "asset_id"),
]
- assert utils.get_constraints(session, AssetSplit) == target
-
-
-def test_obj_session(session: orm.Session, account: Account) -> None:
- result = utils.obj_session(account)
- assert result == session
-
-
-def test_obj_session_detached() -> None:
- acct = Account()
- with pytest.raises(exc.UnboundExecutionError):
- utils.obj_session(acct)
+ assert utils.get_constraints(AssetSplit) == target
def test_update_rows_new(
- session: orm.Session,
today_ord: int,
valuations: list[AssetValuation],
) -> None:
- query = session.query(AssetValuation)
- assert utils.query_count(query) == len(valuations)
+ assert sql.count(AssetValuation.query()) == len(valuations)
- v = query.where(AssetValuation.date_ord == today_ord).one()
+ v = sql.one(AssetValuation.query().where(AssetValuation.date_ord == today_ord))
assert v.value == Decimal(100)
- v = query.where(AssetValuation.date_ord == (today_ord - 1)).one()
+ v = sql.one(
+ AssetValuation.query().where(AssetValuation.date_ord == (today_ord - 1)),
+ )
assert v.value == Decimal(10)
def test_update_rows_edit(
- session: orm.Session,
today_ord: int,
asset: Asset,
valuations: list[AssetValuation],
) -> None:
- query = session.query(AssetValuation)
+ query = AssetValuation.query()
updates: dict[object, dict[str, object]] = {
today_ord - 2: {"value": Decimal(5), "asset_id": asset.id_},
today_ord: {"value": Decimal(50), "asset_id": asset.id_},
}
- utils.update_rows(session, AssetValuation, query, "date_ord", updates)
- session.commit()
- assert utils.query_count(query) == len(valuations)
+ utils.update_rows(AssetValuation, query, "date_ord", updates)
+ assert sql.count(query) == len(valuations)
- v = query.where(AssetValuation.date_ord == today_ord).one()
+ v = sql.one(AssetValuation.query().where(AssetValuation.date_ord == today_ord))
assert v.value == Decimal(50)
- v = query.where(AssetValuation.date_ord == (today_ord - 2)).one()
+ v = sql.one(
+ AssetValuation.query().where(AssetValuation.date_ord == (today_ord - 2)),
+ )
assert v.value == Decimal(5)
-def test_update_rows_delete(
- session: orm.Session,
- valuations: list[AssetValuation],
-) -> None:
- _ = valuations
- query = session.query(AssetValuation)
- utils.update_rows(session, AssetValuation, query, "date_ord", {})
- assert utils.query_count(query) == 0
+def test_update_rows_delete(valuations: list[AssetValuation]) -> None:
+ query = AssetValuation.query()
+ utils.update_rows(AssetValuation, query, "date_ord", {})
+ assert not sql.any_(query)
def test_update_rows_list_edit(
- session: orm.Session,
transactions: list[Transaction],
categories: dict[str, int],
rand_str_generator: RandomStringGenerator,
@@ -211,61 +180,31 @@ def test_update_rows_list_edit(
},
]
utils.update_rows_list(
- session,
TransactionSplit,
- session.query(TransactionSplit).where(TransactionSplit.parent_id == txn.id_),
+ TransactionSplit.query().where(TransactionSplit.parent_id == txn.id_),
updates,
)
- session.commit()
assert t_split_0.parent_id == txn.id_
assert t_split_0.memo == memo_0
assert t_split_0.amount == txn.amount - new_split_amount
- t_split_1 = (
- session.query(TransactionSplit)
- .where(
- TransactionSplit.parent_id == txn.id_,
- TransactionSplit.id_ != t_split_0.id_,
- )
- .one()
+ query = TransactionSplit.query().where(
+ TransactionSplit.parent_id == txn.id_,
+ TransactionSplit.id_ != t_split_0.id_,
)
+ t_split_1 = sql.one(query)
assert t_split_1.parent_id == txn.id_
assert t_split_1.memo == memo_1
assert t_split_1.amount == new_split_amount
def test_update_rows_list_delete(
- session: orm.Session,
transactions: list[Transaction],
) -> None:
txn = transactions[0]
utils.update_rows_list(
- session,
TransactionSplit,
- session.query(TransactionSplit).where(TransactionSplit.parent_id == txn.id_),
+ TransactionSplit.query().where(TransactionSplit.parent_id == txn.id_),
[],
)
- session.commit()
assert len(txn.splits) == 0
-
-
-@pytest.mark.parametrize(
- ("where", "expect_asset"),
- [
- ([], False),
- ([Asset.category == AssetCategory.STOCKS], True),
- ([Asset.category == AssetCategory.BONDS], False),
- ],
-)
-def test_one_or_none(
- session: orm.Session,
- asset: Asset,
- where: list[sqlalchemy.ColumnClause],
- expect_asset: bool,
-) -> None:
- _ = asset
- query = session.query(Asset).where(*where)
- if expect_asset:
- assert utils.one_or_none(query) == asset
- else:
- assert utils.one_or_none(query) is None
diff --git a/tests/models/transaction/test_transaction.py b/tests/models/transaction/test_transaction.py
index 900b1a4a..9d187e2d 100644
--- a/tests/models/transaction/test_transaction.py
+++ b/tests/models/transaction/test_transaction.py
@@ -10,15 +10,12 @@
import datetime
from decimal import Decimal
- from sqlalchemy import orm
-
from nummus.models.account import Account
from tests.conftest import RandomStringGenerator
def test_init_properties(
today: datetime.date,
- session: orm.Session,
account: Account,
rand_real: Decimal,
rand_str_generator: RandomStringGenerator,
@@ -31,9 +28,7 @@ def test_init_properties(
"payee": rand_str_generator(),
}
- txn = Transaction(**d)
- session.add(txn)
- session.commit()
+ txn = Transaction.create(**d)
assert txn.account_id == account.id_
assert txn.date_ord == today.toordinal()
diff --git a/tests/models/transaction/test_transaction_split.py b/tests/models/transaction/test_transaction_split.py
index 6458e152..40fa1829 100644
--- a/tests/models/transaction/test_transaction_split.py
+++ b/tests/models/transaction/test_transaction_split.py
@@ -21,7 +21,6 @@
def test_init_properties(
today: datetime.date,
- session: orm.Session,
account: Account,
asset: Asset,
categories: dict[str, int],
@@ -36,9 +35,7 @@ def test_init_properties(
"payee": rand_str_generator(),
}
- txn = Transaction(**d)
- session.add(txn)
- session.commit()
+ txn = Transaction.create(**d)
d = {
"amount": d["amount"],
@@ -49,10 +46,8 @@ def test_init_properties(
"memo": rand_str_generator(),
}
- t_split_0 = TransactionSplit(**d)
+ t_split_0 = TransactionSplit.create(**d)
- session.add(t_split_0)
- session.commit()
assert t_split_0.parent == txn
assert t_split_0.parent_id == txn.id_
assert t_split_0.category_id == d["category_id"]
@@ -71,9 +66,8 @@ def test_init_properties(
def test_zero_amount(session: orm.Session, transactions: list[Transaction]) -> None:
t_split = transactions[1].splits[0]
- t_split.amount = Decimal()
- with pytest.raises(exc.IntegrityError):
- session.commit()
+ with pytest.raises(exc.IntegrityError), session.begin_nested():
+ t_split.amount = Decimal()
def test_short() -> None:
@@ -104,18 +98,15 @@ def test_unset_asset_quantity(
transactions: list[Transaction],
) -> None:
t_split = transactions[1].splits[0]
- t_split._asset_qty_unadjusted = None
- with pytest.raises(exc.IntegrityError):
- session.commit()
+ with pytest.raises(exc.IntegrityError), session.begin_nested():
+ t_split._asset_qty_unadjusted = None
def test_clear_asset_quantity(
- session: orm.Session,
transactions: list[Transaction],
) -> None:
t_split = transactions[1].splits[0]
t_split.asset_quantity_unadjusted = None
- session.commit()
assert t_split.asset_quantity is None
@@ -159,9 +150,8 @@ def test_parent(transactions: list[Transaction]) -> None:
assert t_split.parent == txn
-def test_search_none(session: orm.Session, transactions: list[Transaction]) -> None:
- _ = transactions
- query = session.query(TransactionSplit)
+def test_search_none(transactions: list[Transaction]) -> None:
+ query = TransactionSplit.query()
with pytest.raises(exc.EmptySearchError):
TransactionSplit.search(query, "")
@@ -190,11 +180,10 @@ def test_search_none(session: orm.Session, transactions: list[Transaction]) -> N
ids=conftest.id_func,
)
def test_search(
- session: orm.Session,
transactions: list[Transaction],
search_str: str,
target: list[int],
) -> None:
- query = session.query(TransactionSplit)
+ query = TransactionSplit.query()
result = TransactionSplit.search(query, search_str)
assert result == [transactions[i].splits[0].id_ for i in target]
diff --git a/tests/portfolio/test_backup_restore.py b/tests/portfolio/test_backup_restore.py
index 048bdb52..6f6bd4fa 100644
--- a/tests/portfolio/test_backup_restore.py
+++ b/tests/portfolio/test_backup_restore.py
@@ -156,8 +156,8 @@ def test_restore_path_traversal(tmp_path: Path) -> None:
def test_restore(empty_portfolio: Portfolio) -> None:
# Delete ENCRYPTION_TEST so reload fails
- with empty_portfolio.begin_session() as s:
- s.query(Config).where(Config.key == ConfigKey.ENCRYPTION_TEST).delete()
+ with empty_portfolio.begin_session():
+ Config.query().where(Config.key == ConfigKey.ENCRYPTION_TEST).delete()
empty_portfolio.backup()
empty_portfolio.path.unlink()
@@ -169,8 +169,8 @@ def test_restore(empty_portfolio: Portfolio) -> None:
def test_restore_path(empty_portfolio: Portfolio) -> None:
# Delete ENCRYPTION_TEST so reload fails
- with empty_portfolio.begin_session() as s:
- s.query(Config).where(Config.key == ConfigKey.ENCRYPTION_TEST).delete()
+ with empty_portfolio.begin_session():
+ Config.query().where(Config.key == ConfigKey.ENCRYPTION_TEST).delete()
empty_portfolio.backup()
empty_portfolio.path.unlink()
diff --git a/tests/portfolio/test_change_key.py b/tests/portfolio/test_change_key.py
index 7cb5fd3c..ff8352ae 100644
--- a/tests/portfolio/test_change_key.py
+++ b/tests/portfolio/test_change_key.py
@@ -11,14 +11,14 @@
@pytest.mark.skipif(not ENCRYPTION_AVAILABLE, reason="No encryption available")
@pytest.mark.encryption
def test_change_db_key(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
empty_portfolio_encrypted: tuple[Portfolio, str],
rand_str: str,
) -> None:
new_key = rand_str
p, old_key = empty_portfolio_encrypted
- with p.begin_session() as s:
- web_key_enc = Config.fetch(s, ConfigKey.WEB_KEY)
+ with p.begin_session():
+ web_key_enc = Config.fetch(ConfigKey.WEB_KEY)
web_key = p.decrypt_s(web_key_enc)
p.change_key(new_key)
@@ -28,8 +28,8 @@ def test_change_db_key(
# tqdm in here
assert captured.err
- with p.begin_session() as s:
- web_key_enc = Config.fetch(s, ConfigKey.WEB_KEY)
+ with p.begin_session():
+ web_key_enc = Config.fetch(ConfigKey.WEB_KEY)
new_web_key = p.decrypt_s(web_key_enc)
assert new_web_key == web_key
assert new_web_key != new_key
@@ -57,8 +57,8 @@ def test_change_web_key(
p, db_key = empty_portfolio_encrypted
p.change_web_key(new_key)
- with p.begin_session() as s:
- web_key_enc = Config.fetch(s, ConfigKey.WEB_KEY)
+ with p.begin_session():
+ web_key_enc = Config.fetch(ConfigKey.WEB_KEY)
web_key = p.decrypt_s(web_key_enc)
assert web_key == new_key
assert web_key != db_key
diff --git a/tests/portfolio/test_import.py b/tests/portfolio/test_import.py
index 46e33be9..6ede35eb 100644
--- a/tests/portfolio/test_import.py
+++ b/tests/portfolio/test_import.py
@@ -1,24 +1,21 @@
from __future__ import annotations
import datetime
-import operator
from typing import TYPE_CHECKING
import pytest
from nummus import exceptions as exc
-from nummus.models.account import Account
-from nummus.models.asset import Asset
+from nummus import sql
from nummus.models.transaction import Transaction, TransactionSplit
-from nummus.models.utils import query_count
-from nummus.portfolio import Portfolio
from tests.importers.test_raw_csv import TRANSACTIONS_REQUIRED
if TYPE_CHECKING:
- from collections.abc import Callable
from pathlib import Path
- from nummus.models.base import Base
+ from nummus.models.account import Account
+ from nummus.models.asset import Asset
+ from nummus.portfolio import Portfolio
def test_import_file(
@@ -29,33 +26,30 @@ def test_import_file(
asset: Asset,
categories: dict[str, int],
) -> None:
- _ = account_investments
- _ = asset
path = data_path / "transactions_required.csv"
path_debug = empty_portfolio.path.with_suffix(".importer-debug")
# Create first txn to be cleared
- with empty_portfolio.begin_session() as s:
+ with empty_portfolio.begin_session():
d = TRANSACTIONS_REQUIRED[0]
- txn = Transaction(
+ txn = Transaction.create(
account_id=account.id_,
date=d["date"],
amount=d["amount"],
statement="Manually imported",
)
- t_split = TransactionSplit(
+ TransactionSplit.create(
parent=txn,
amount=txn.amount,
category_id=categories["uncategorized"],
)
- s.add_all((txn, t_split))
empty_portfolio.import_file(path, path_debug)
assert not path_debug.exists()
- with empty_portfolio.begin_session() as s:
- assert query_count(s.query(Transaction)) == len(TRANSACTIONS_REQUIRED)
+ with empty_portfolio.begin_session():
+ assert sql.count(Transaction.query()) == len(TRANSACTIONS_REQUIRED)
def test_import_file_duplicate(
@@ -65,9 +59,6 @@ def test_import_file_duplicate(
account_investments: Account,
asset: Asset,
) -> None:
- _ = account
- _ = account_investments
- _ = asset
path = data_path / "transactions_required.csv"
path_debug = empty_portfolio.path.with_suffix(".importer-debug")
@@ -86,9 +77,6 @@ def test_import_file_force(
account_investments: Account,
asset: Asset,
) -> None:
- _ = account
- _ = account_investments
- _ = asset
path = data_path / "transactions_required.csv"
path_debug = empty_portfolio.path.with_suffix(".importer-debug")
@@ -97,8 +85,8 @@ def test_import_file_force(
assert not path_debug.exists()
- with empty_portfolio.begin_session() as s:
- assert query_count(s.query(Transaction)) == len(TRANSACTIONS_REQUIRED) * 2
+ with empty_portfolio.begin_session():
+ assert sql.count(Transaction.query()) == len(TRANSACTIONS_REQUIRED) * 2
@pytest.mark.parametrize(
@@ -131,9 +119,6 @@ def test_import_file_error(
target: type[Exception],
debug_exists: bool,
) -> None:
- _ = account
- _ = account_investments
- _ = asset
path_debug = empty_portfolio.path.with_suffix(".importer-debug")
for f in files[:-1]:
path = data_path / f
@@ -153,8 +138,6 @@ def test_import_file_investments(
account_investments: Account,
asset: Asset,
) -> None:
- _ = account
- _ = account_investments
path = data_path / "transactions_investments.csv"
path_debug = empty_portfolio.path.with_suffix(".importer-debug")
@@ -162,14 +145,13 @@ def test_import_file_investments(
assert not path_debug.exists()
- with empty_portfolio.begin_session() as s:
- assert query_count(s.query(Transaction)) == 4
+ with empty_portfolio.begin_session():
+ assert sql.count(Transaction.query()) == 4
- txn = (
- s.query(Transaction)
- .where(Transaction.date_ord == datetime.date(2023, 1, 3).toordinal())
- .one()
+ query = Transaction.query().where(
+ Transaction.date_ord == datetime.date(2023, 1, 3).toordinal(),
)
+ txn = sql.one(query)
assert txn.statement == f"Asset Transaction {asset.name}"
@@ -179,7 +161,6 @@ def test_import_file_bad_category(
account: Account,
categories: dict[str, int],
) -> None:
- _ = account
path = data_path / "transactions_bad_category.csv"
path_debug = empty_portfolio.path.with_suffix(".importer-debug")
@@ -187,57 +168,6 @@ def test_import_file_bad_category(
assert not path_debug.exists()
- with empty_portfolio.begin_session() as s:
- t_split = s.query(TransactionSplit).one()
+ with empty_portfolio.begin_session():
+ t_split = TransactionSplit.one()
assert t_split.category_id == categories["uncategorized"]
-
-
-@pytest.mark.parametrize(
- ("type_", "prop", "value_adjuster"),
- [
- (Account, "uri", lambda s: s),
- (Account, "number", lambda s: s),
- (Account, "number", operator.itemgetter(slice(-4, None))),
- (Account, "institution", lambda s: s),
- (Account, "name", lambda s: s),
- (Account, "name", lambda s: s.lower()),
- (Account, "name", lambda s: s.upper()),
- (Asset, "uri", lambda s: s),
- (Asset, "ticker", lambda s: s),
- (Asset, "name", lambda s: s),
- ],
-)
-def test_find(
- empty_portfolio: Portfolio,
- account: Account,
- asset: Asset,
- type_: type[Base],
- prop: str,
- value_adjuster: Callable[[str], str],
-) -> None:
- obj = {
- Account: account,
- Asset: asset,
- }[type_]
- query = value_adjuster(getattr(obj, prop))
-
- cache: dict[str, tuple[int, str | None]] = {}
- with empty_portfolio.begin_session() as s:
- a_id, a_name = Portfolio.find(s, type_, query, cache)
- assert a_id == obj.id_
- assert a_name == obj.name
-
- assert cache == {query: (a_id, a_name)}
-
-
-def test_find_missing(
- empty_portfolio: Portfolio,
- account: Account,
-) -> None:
- query = Account.id_to_uri(account.id_ + 1)
-
- cache: dict[str, tuple[int, str | None]] = {}
- with empty_portfolio.begin_session() as s, pytest.raises(exc.NoResultFound):
- Portfolio.find(s, Account, query, cache)
-
- assert not cache
diff --git a/tests/portfolio/test_open.py b/tests/portfolio/test_open.py
index 562dd890..5e68d412 100644
--- a/tests/portfolio/test_open.py
+++ b/tests/portfolio/test_open.py
@@ -6,11 +6,11 @@
import pytest
from nummus import exceptions as exc
+from nummus import sql
from nummus.encryption.top import ENCRYPTION_AVAILABLE
from nummus.models.asset import Asset
from nummus.models.config import Config, ConfigKey
from nummus.models.transaction_category import TransactionCategory
-from nummus.models.utils import query_count
from nummus.portfolio import Portfolio
if TYPE_CHECKING:
@@ -56,10 +56,10 @@ def test_unencrypted(tmp_path: Path) -> None:
assert not p.is_encrypted
assert not Portfolio.is_encrypted_path(path)
- with p.begin_session() as s:
- assert query_count(s.query(Config)) == 5
- assert query_count(s.query(TransactionCategory)) > 0
- assert query_count(s.query(Asset)) > 0
+ with p.begin_session():
+ assert sql.count(Config.query()) == 5
+ assert sql.any_(TransactionCategory.query())
+ assert sql.any_(Asset.query())
with pytest.raises(exc.NotEncryptedError):
p.encrypt("")
@@ -92,16 +92,16 @@ def test_no_encryption_test(
empty_portfolio: Portfolio,
key: ConfigKey,
) -> None:
- with empty_portfolio.begin_session() as s:
- s.query(Config).where(Config.key == key).delete()
+ with empty_portfolio.begin_session():
+ Config.query().where(Config.key == key).delete()
with pytest.raises(exc.ProtectedObjectNotFoundError):
Portfolio(empty_portfolio.path, None)
def test_bad_encryption_test(empty_portfolio: Portfolio) -> None:
- with empty_portfolio.begin_session() as s:
- Config.set_(s, ConfigKey.ENCRYPTION_TEST, "fake")
+ with empty_portfolio.begin_session():
+ Config.set_(ConfigKey.ENCRYPTION_TEST, "fake")
with pytest.raises(exc.UnlockingError):
Portfolio(empty_portfolio.path, None)
@@ -124,10 +124,10 @@ def test_encrypted(tmp_path: Path, rand_str: str) -> None:
assert p.is_encrypted
assert Portfolio.is_encrypted_path(path)
- with p.begin_session() as s:
- assert query_count(s.query(Config)) == 6
- assert query_count(s.query(TransactionCategory)) > 0
- assert query_count(s.query(Asset)) > 0
+ with p.begin_session():
+ assert sql.count(Config.query()) == 6
+ assert sql.any_(TransactionCategory.query())
+ assert sql.any_(Asset.query())
@pytest.mark.skipif(not ENCRYPTION_AVAILABLE, reason="No encryption available")
@@ -149,8 +149,8 @@ def test_encrypted_bad_enc_test(
empty_portfolio_encrypted: tuple[Portfolio, str],
) -> None:
p, key = empty_portfolio_encrypted
- with p.begin_session() as s:
- Config.set_(s, ConfigKey.ENCRYPTION_TEST, "fake")
+ with p.begin_session():
+ Config.set_(ConfigKey.ENCRYPTION_TEST, "fake")
with pytest.raises(exc.UnlockingError):
Portfolio(p.path, key)
diff --git a/tests/portfolio/test_update_assets.py b/tests/portfolio/test_update_assets.py
index 3b9a9832..29169cfe 100644
--- a/tests/portfolio/test_update_assets.py
+++ b/tests/portfolio/test_update_assets.py
@@ -17,7 +17,7 @@
from nummus.portfolio import Portfolio
-def test_empty(capsys: pytest.CaptureFixture, empty_portfolio: Portfolio) -> None:
+def test_empty(capsys: pytest.CaptureFixture[str], empty_portfolio: Portfolio) -> None:
assert empty_portfolio.update_assets(no_bars=True) == []
captured = capsys.readouterr()
@@ -26,7 +26,6 @@ def test_empty(capsys: pytest.CaptureFixture, empty_portfolio: Portfolio) -> Non
def test_no_txns(empty_portfolio: Portfolio, asset: Asset) -> None:
- _ = asset
assert empty_portfolio.update_assets(no_bars=True) == []
@@ -38,11 +37,9 @@ def test_update_assets(
asset_etf: Asset,
transactions: list[Transaction],
) -> None:
- _ = asset_etf
- asset.interpolate = True
- session.query(Asset).where(Asset.category == AssetCategory.INDEX).delete()
- session.commit()
- _ = transactions
+ with session.begin_nested():
+ asset.interpolate = True
+ Asset.query().where(Asset.category == AssetCategory.INDEX).delete()
target: list[AssetUpdate] = [
AssetUpdate(
asset.name,
@@ -65,10 +62,9 @@ def test_error(
asset: Asset,
transactions: list[Transaction],
) -> None:
- asset.ticker = "FAKE"
- session.query(Asset).where(Asset.category == AssetCategory.INDEX).delete()
- session.commit()
- _ = transactions
+ with session.begin_nested():
+ asset.ticker = "FAKE"
+ Asset.query().where(Asset.category == AssetCategory.INDEX).delete()
target: list[AssetUpdate] = [
AssetUpdate(
asset.name,
diff --git a/tests/test_main.py b/tests/test_main.py
index 0fcf6400..79dea1b4 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -42,7 +42,7 @@ def test_unlock_non_existant(empty_portfolio: Portfolio) -> None:
def test_unlock_successful(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
empty_portfolio: Portfolio,
) -> None:
args = ["--portfolio", str(empty_portfolio.path), "unlock"]
diff --git a/tests/test_sql.py b/tests/test_sql.py
index 6e857acc..aed92a35 100644
--- a/tests/test_sql.py
+++ b/tests/test_sql.py
@@ -7,6 +7,7 @@
from nummus import sql
from nummus.encryption.top import Encryption, ENCRYPTION_AVAILABLE
+from nummus.models.config import Config, ConfigKey
if TYPE_CHECKING:
from pathlib import Path
@@ -60,3 +61,116 @@ def test_escape_not_reserved() -> None:
def test_escape_reserved() -> None:
assert sql.escape("where") == "`where`"
+
+
+def test_to_dict() -> None:
+ query = Config.query(Config.key, Config.value)
+ result = sql.to_dict(query)
+ assert isinstance(result, dict)
+ assert all(isinstance(k, ConfigKey) for k in result)
+ assert all(isinstance(v, str) for v in result.values())
+
+
+def test_to_dict_tuple() -> None:
+ query = Config.query(Config.id_, Config.key, Config.value)
+ result = sql.to_dict_tuple(query)
+ assert isinstance(result, dict)
+ assert all(isinstance(k, int) for k in result)
+ assert all(isinstance(v, tuple) for v in result.values())
+ assert all(len(v) == 2 for v in result.values())
+ assert all(isinstance(v[0], ConfigKey) for v in result.values())
+ assert all(isinstance(v[1], str) for v in result.values())
+
+
+def test_count() -> None:
+ query = Config.query()
+ assert sql.count(query) == query.count()
+
+
+def test_any() -> None:
+ assert sql.any_(Config.query())
+
+
+def test_any_none() -> None:
+ Config.query().delete()
+ assert not sql.any_(Config.query())
+
+
+def test_one() -> None:
+ query = Config.query().where(
+ Config.key == ConfigKey.VERSION,
+ )
+ result = sql.one(query)
+ assert isinstance(result, Config)
+
+
+def test_one_value() -> None:
+ query = Config.query(Config.key).where(
+ Config.key == ConfigKey.VERSION,
+ )
+ result = sql.one(query)
+ assert isinstance(result, ConfigKey)
+
+
+def test_one_tuple() -> None:
+ query = Config.query(Config.key, Config.value).where(
+ Config.key == ConfigKey.VERSION,
+ )
+ result = sql.one(query)
+ assert isinstance(result, tuple)
+ assert len(result) == 2
+ assert isinstance(result[0], ConfigKey)
+ assert isinstance(result[1], str)
+
+
+def test_scalar() -> None:
+ query = Config.query().where(
+ Config.key == ConfigKey.VERSION,
+ )
+ result = sql.scalar(query)
+ assert isinstance(result, Config)
+
+
+def test_scalar_value() -> None:
+ query = Config.query(Config.key).where(
+ Config.key == ConfigKey.VERSION,
+ )
+ result = sql.scalar(query)
+ assert isinstance(result, ConfigKey)
+
+
+def test_scalar_tuple() -> None:
+ query = Config.query(Config.key, Config.value).where(
+ Config.key == ConfigKey.VERSION,
+ )
+ result = sql.scalar(query)
+ assert isinstance(result, ConfigKey)
+
+
+def test_yield() -> None:
+ query = Config.query().where()
+ for r in sql.yield_(query):
+ assert isinstance(r, Config)
+
+
+def test_yield_value() -> None:
+ query = Config.query(Config.key)
+ for r in sql.yield_(query):
+ assert isinstance(r, tuple)
+ assert len(r) == 1
+ assert isinstance(r[0], ConfigKey)
+
+
+def test_yield_tuple() -> None:
+ query = Config.query(Config.key, Config.value)
+ for r in sql.yield_(query):
+ assert isinstance(r, tuple)
+ assert len(r) == 2
+ assert isinstance(r[0], ConfigKey)
+ assert isinstance(r[1], str)
+
+
+def test_col0() -> None:
+ query = Config.query(Config.key)
+ for r in sql.col0(query):
+ assert isinstance(r, ConfigKey)
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 3a4bc31c..c5940480 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -32,7 +32,7 @@ def test_camel_to_snake(s: str, c: str) -> None:
def test_get_input_insecure(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
rand_str_generator: RandomStringGenerator,
) -> None:
@@ -49,7 +49,7 @@ def mock_input(to_print: str) -> str | None:
def test_get_input_insecure_abort(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
rand_str_generator: RandomStringGenerator,
) -> None:
@@ -66,7 +66,7 @@ def mock_input(to_print: str) -> str | None:
def test_get_input_secure(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
rand_str_generator: RandomStringGenerator,
) -> None:
@@ -83,7 +83,7 @@ def mock_get_pass(to_print: str) -> str | None:
def test_get_input_secure_abort(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
rand_str: str,
) -> None:
@@ -97,7 +97,7 @@ def mock_get_pass(to_print: str) -> str | None:
def test_get_input_secure_with_icon(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
rand_str_generator: RandomStringGenerator,
) -> None:
@@ -149,7 +149,7 @@ def mock_input(to_print: str, *, secure: bool) -> str | None:
],
)
def test_confirm(
- capsys: pytest.CaptureFixture,
+ capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
rand_str: str,
queue: list[str | None],
@@ -651,8 +651,20 @@ def test_pretty_table_no_header() -> None:
utils.pretty_table([None])
-def test_pretty_table_only_header(monkeypatch: pytest.MonkeyPatch) -> None:
- monkeypatch.setattr("shutil.get_terminal_size", lambda **_: (80, 24))
+@pytest.fixture
+def fixed_terminal(
+ monkeypatch: pytest.MonkeyPatch,
+ width: int,
+ height: int,
+) -> None:
+ def mock_terminal_size(**_: object) -> tuple[int, int]:
+ return width, height
+
+ monkeypatch.setattr("shutil.get_terminal_size", mock_terminal_size)
+
+
+@pytest.mark.parametrize(("width", "height"), [(80, 24)])
+def test_pretty_table_only_header(fixed_terminal: None) -> None:
table: list[list[str] | None] = [
["H1", ">H2", " None:
)
assert "\n".join(utils.pretty_table(table)) == target
- # Reset terminal width before verbose info is printed
- monkeypatch.undo()
-
-def test_pretty_table_only_separator(monkeypatch: pytest.MonkeyPatch) -> None:
- monkeypatch.setattr("shutil.get_terminal_size", lambda **_: (80, 24))
+@pytest.mark.parametrize(("width", "height"), [(80, 24)])
+def test_pretty_table_only_separator(fixed_terminal: None) -> None:
table: list[list[str] | None] = [
["H1", ">H2", " None:
)
assert "\n".join(utils.pretty_table(table)) == target
- # Reset terminal width before verbose info is printed
- monkeypatch.undo()
-
@pytest.fixture
def table() -> list[list[str] | None]:
@@ -698,11 +704,11 @@ def table() -> list[list[str] | None]:
]
+@pytest.mark.parametrize(("width", "height"), [(80, 24)])
def test_pretty_table_width_80(
- monkeypatch: pytest.MonkeyPatch,
+ fixed_terminal: None,
table: list[list[str] | None],
) -> None:
- monkeypatch.setattr("shutil.get_terminal_size", lambda **_: (80, 24))
target = textwrap.dedent(
"""\
â•───────────┬───────────┬───────────┬───────────┬───────────┬───────────╮
@@ -715,16 +721,12 @@ def test_pretty_table_width_80(
)
assert "\n".join(utils.pretty_table(table)) == target
- # Reset terminal width before verbose info is printed
- monkeypatch.undo()
-
+@pytest.mark.parametrize(("width", "height"), [(70, 24)])
def test_pretty_table_width_70(
- monkeypatch: pytest.MonkeyPatch,
+ fixed_terminal: None,
table: list[list[str] | None],
) -> None:
- # Make terminal smaller, extra space goes first
- monkeypatch.setattr("shutil.get_terminal_size", lambda **_: (70, 24))
target = textwrap.dedent(
"""\
â•───────────┬───────────┬───────────┬───────────┬─────────┬─────────╮
@@ -737,16 +739,12 @@ def test_pretty_table_width_70(
)
assert "\n".join(utils.pretty_table(table)) == target
- # Reset terminal width before verbose info is printed
- monkeypatch.undo()
-
+@pytest.mark.parametrize(("width", "height"), [(60, 24)])
def test_pretty_table_width_60(
- monkeypatch: pytest.MonkeyPatch,
+ fixed_terminal: None,
table: list[list[str] | None],
) -> None:
- # Make terminal smaller, truncate column goes next
- monkeypatch.setattr("shutil.get_terminal_size", lambda **_: (60, 24))
target = textwrap.dedent(
"""\
â•─────────┬─────────┬─────────┬─────────┬───────┬─────────╮
@@ -759,16 +757,12 @@ def test_pretty_table_width_60(
)
assert "\n".join(utils.pretty_table(table)) == target
- # Reset terminal width before verbose info is printed
- monkeypatch.undo()
-
+@pytest.mark.parametrize(("width", "height"), [(50, 24)])
def test_pretty_table_width_50(
- monkeypatch: pytest.MonkeyPatch,
+ fixed_terminal: None,
table: list[list[str] | None],
) -> None:
- # Make terminal smaller, other columns go next
- monkeypatch.setattr("shutil.get_terminal_size", lambda **_: (50, 24))
target = textwrap.dedent(
"""\
â•───────┬───────┬───────┬────────┬────┬─────────╮
@@ -781,16 +775,12 @@ def test_pretty_table_width_50(
)
assert "\n".join(utils.pretty_table(table)) == target
- # Reset terminal width before verbose info is printed
- monkeypatch.undo()
-
+@pytest.mark.parametrize(("width", "height"), [(10, 24)])
def test_pretty_table_width_10(
- monkeypatch: pytest.MonkeyPatch,
+ fixed_terminal: None,
table: list[list[str] | None],
) -> None:
- # Make terminal tiny, other columns go next, never last
- monkeypatch.setattr("shutil.get_terminal_size", lambda **_: (10, 24))
target = textwrap.dedent(
"""\
â•────┬────┬────┬────┬────┬─────────╮
@@ -803,9 +793,6 @@ def test_pretty_table_width_10(
)
assert "\n".join(utils.pretty_table(table)) == target
- # Reset terminal width before verbose info is printed
- monkeypatch.undo()
-
@pytest.mark.parametrize(
("items", "target"),
diff --git a/tests/test_web.py b/tests/test_web.py
index 23566b36..32a9ac72 100644
--- a/tests/test_web.py
+++ b/tests/test_web.py
@@ -19,8 +19,8 @@
def test_create_app(empty_portfolio: Portfolio, flask_app: flask.Flask) -> None:
- with empty_portfolio.begin_session() as s:
- secret_key = Config.fetch(s, ConfigKey.SECRET_KEY)
+ with empty_portfolio.begin_session():
+ secret_key = Config.fetch(ConfigKey.SECRET_KEY)
assert flask_app.secret_key == secret_key
assert len(flask_app.before_request_funcs[None]) == 1
@@ -32,8 +32,8 @@ def test_no_secret_key(
empty_portfolio: Portfolio,
) -> None:
monkeypatch.setenv("NUMMUS_PORTFOLIO", str(empty_portfolio.path))
- with empty_portfolio.begin_session() as s:
- s.query(Config).where(Config.key == ConfigKey.SECRET_KEY).delete()
+ with empty_portfolio.begin_session():
+ Config.query().where(Config.key == ConfigKey.SECRET_KEY).delete()
with pytest.raises(exc.ProtectedObjectNotFoundError):
web.create_app()
diff --git a/typings/Cryptodome/Cipher/AES.pyi b/typings/Cryptodome/Cipher/AES.pyi
new file mode 100644
index 00000000..14c13447
--- /dev/null
+++ b/typings/Cryptodome/Cipher/AES.pyi
@@ -0,0 +1,111 @@
+# Initially generated by Pyright
+
+from typing import overload
+
+from Cryptodome.Cipher._mode_cbc import CbcMode
+from typing_extensions import Literal
+
+Buffer = bytes | bytearray | memoryview
+MODE_ECB: Literal[1]
+MODE_CBC: Literal[2]
+MODE_CFB: Literal[3]
+MODE_OFB: Literal[5]
+MODE_CTR: Literal[6]
+MODE_OPENPGP: Literal[7]
+MODE_CCM: Literal[8]
+MODE_EAX: Literal[9]
+MODE_SIV: Literal[10]
+MODE_GCM: Literal[11]
+MODE_OCB: Literal[12]
+
+@overload
+def new(key: Buffer, mode: Literal[1], use_aesni: bool = ...) -> None: ...
+@overload
+def new(
+ key: Buffer, mode: Literal[2], iv: Buffer | None = ..., use_aesni: bool = ...
+) -> CbcMode: ...
+@overload
+def new(
+ key: Buffer, mode: Literal[2], IV: Buffer | None = ..., use_aesni: bool = ...
+) -> CbcMode: ...
+@overload
+def new(
+ key: Buffer,
+ mode: Literal[3],
+ iv: Buffer | None = ...,
+ segment_size: int = ...,
+ use_aesni: bool = ...,
+) -> None: ...
+@overload
+def new(
+ key: Buffer,
+ mode: Literal[3],
+ IV: Buffer | None = ...,
+ segment_size: int = ...,
+ use_aesni: bool = ...,
+) -> None: ...
+@overload
+def new(
+ key: Buffer, mode: Literal[5], iv: Buffer | None = ..., use_aesni: bool = ...
+) -> None: ...
+@overload
+def new(
+ key: Buffer, mode: Literal[5], IV: Buffer | None = ..., use_aesni: bool = ...
+) -> None: ...
+@overload
+def new(
+ key: Buffer,
+ mode: Literal[6],
+ nonce: Buffer | None = ...,
+ initial_value: Buffer | int = ...,
+ counter: dict[object, object] = ...,
+ use_aesni: bool = ...,
+) -> None: ...
+@overload
+def new(
+ key: Buffer, mode: Literal[7], iv: Buffer | None = ..., use_aesni: bool = ...
+) -> None: ...
+@overload
+def new(
+ key: Buffer, mode: Literal[7], IV: Buffer | None = ..., use_aesni: bool = ...
+) -> None: ...
+@overload
+def new(
+ key: Buffer,
+ mode: Literal[8],
+ nonce: Buffer | None = ...,
+ mac_len: int = ...,
+ assoc_len: int = ...,
+ use_aesni: bool = ...,
+) -> None: ...
+@overload
+def new(
+ key: Buffer,
+ mode: Literal[9],
+ nonce: Buffer | None = ...,
+ mac_len: int = ...,
+ use_aesni: bool = ...,
+) -> None: ...
+@overload
+def new(
+ key: Buffer, mode: Literal[10], nonce: Buffer | None = ..., use_aesni: bool = ...
+) -> None: ...
+@overload
+def new(
+ key: Buffer,
+ mode: Literal[11],
+ nonce: Buffer | None = ...,
+ mac_len: int = ...,
+ use_aesni: bool = ...,
+) -> None: ...
+@overload
+def new(
+ key: Buffer,
+ mode: Literal[12],
+ nonce: Buffer | None = ...,
+ mac_len: int = ...,
+ use_aesni: bool = ...,
+) -> None: ...
+
+block_size: int
+key_size: tuple[int, int, int]
diff --git a/typings/Cryptodome/Cipher/_mode_cbc.pyi b/typings/Cryptodome/Cipher/_mode_cbc.pyi
new file mode 100644
index 00000000..3c379cb1
--- /dev/null
+++ b/typings/Cryptodome/Cipher/_mode_cbc.pyi
@@ -0,0 +1,30 @@
+# Initially generated by Pyright
+
+from typing import overload
+
+from Cryptodome.Util._raw_api import SmartPointer
+
+Buffer = bytes | bytearray | memoryview
+__all__ = ["CbcMode"]
+
+class CbcMode:
+ block_size: int
+ iv: Buffer
+ IV: Buffer
+ def __init__(self, block_cipher: SmartPointer, iv: Buffer) -> None: ...
+ @overload
+ def encrypt(self, plaintext: Buffer) -> bytes: ...
+ @overload
+ def encrypt(
+ self,
+ plaintext: Buffer,
+ output: bytearray | memoryview,
+ ) -> None: ...
+ @overload
+ def decrypt(self, plaintext: Buffer) -> bytes: ...
+ @overload
+ def decrypt(
+ self,
+ plaintext: Buffer,
+ output: bytearray | memoryview,
+ ) -> None: ...
diff --git a/typings/flask_assets/__init__.pyi b/typings/flask_assets/__init__.pyi
new file mode 100644
index 00000000..b8c678c5
--- /dev/null
+++ b/typings/flask_assets/__init__.pyi
@@ -0,0 +1,32 @@
+# Initially generated by Pyright
+
+import io
+
+import flask
+from webassets.bundle import Bundle as BaseBundle
+from webassets.env import BaseEnvironment
+from webassets.merge import FileHunk
+
+class Environment(BaseEnvironment):
+ def __init__(self, app: flask.Flask = ...) -> None: ...
+ def set_directory(self, directory: str) -> None: ...
+ def get_directory(self) -> str | None: ...
+ def set_url(self, url: str) -> None: ...
+ def get_url(self) -> str | None: ...
+ def init_app(self, app: flask.Flask) -> None: ...
+ def from_yaml(self, path: str) -> None: ...
+ def from_module(self, path: str) -> None: ...
+ def register(
+ self,
+ name: str,
+ *args: str | Bundle,
+ **kwargs: object,
+ ) -> Bundle | None: ...
+
+class Bundle(BaseBundle):
+ def build(
+ self,
+ force: bool | None = ...,
+ output: io.FileIO | None = ...,
+ disable_cache: bool | None = ...,
+ ) -> list[FileHunk]: ...
diff --git a/typings/flask_login/__init__.pyi b/typings/flask_login/__init__.pyi
new file mode 100644
index 00000000..8f7ab102
--- /dev/null
+++ b/typings/flask_login/__init__.pyi
@@ -0,0 +1,38 @@
+# Initially generated by Pyright
+
+from .login_manager import LoginManager
+from .mixins import AnonymousUserMixin, UserMixin
+from .utils import (
+ confirm_login,
+ current_user,
+ decode_cookie,
+ encode_cookie,
+ fresh_login_required,
+ login_fresh,
+ login_remembered,
+ login_required,
+ login_url,
+ login_user,
+ logout_user,
+ make_next_param,
+ set_login_view,
+)
+
+__all__ = [
+ "LoginManager",
+ "AnonymousUserMixin",
+ "UserMixin",
+ "confirm_login",
+ "current_user",
+ "decode_cookie",
+ "encode_cookie",
+ "fresh_login_required",
+ "login_fresh",
+ "login_remembered",
+ "login_required",
+ "login_url",
+ "login_user",
+ "logout_user",
+ "make_next_param",
+ "set_login_view",
+]
diff --git a/typings/flask_login/login_manager.pyi b/typings/flask_login/login_manager.pyi
new file mode 100644
index 00000000..d80fb8ee
--- /dev/null
+++ b/typings/flask_login/login_manager.pyi
@@ -0,0 +1,47 @@
+# Initially generated by Pyright
+
+from collections.abc import Callable
+
+import flask
+
+from .mixins import AnonymousUserMixin, UserMixin
+
+class LoginManager:
+ login_view: str | None
+ def __init__(
+ self,
+ app: flask.Flask = ...,
+ add_context_processor: bool = ...,
+ ) -> None: ...
+ def setup_app(
+ self,
+ app: flask.Flask,
+ add_context_processor: bool = ...,
+ ) -> None: ...
+ def init_app(self, app: flask.Flask, add_context_processor: bool = ...) -> None: ...
+ def unauthorized(self) -> flask.Response: ...
+ def user_loader(
+ self,
+ callback: Callable[[str], UserMixin | AnonymousUserMixin | None],
+ ) -> None: ...
+ @property
+ def user_callback(self) -> None: ...
+ def request_loader(
+ self,
+ callback: Callable[[flask.Request], UserMixin | AnonymousUserMixin | None],
+ ) -> None: ...
+ @property
+ def request_callback(self) -> None: ...
+ def unauthorized_handler[T: Callable[[], flask.Response]](
+ self,
+ callback: T,
+ ) -> T: ...
+ def needs_refresh_handler[T: Callable[[], flask.Response]](
+ self,
+ callback: T,
+ ) -> T: ...
+ def needs_refresh(self) -> flask.Response: ...
+ def header_loader[T: Callable[[str], UserMixin | AnonymousUserMixin | None]](
+ self,
+ callback: T,
+ ) -> T: ...
diff --git a/typings/flask_login/mixins.pyi b/typings/flask_login/mixins.pyi
new file mode 100644
index 00000000..165962ba
--- /dev/null
+++ b/typings/flask_login/mixins.pyi
@@ -0,0 +1,23 @@
+# Initially generated by Pyright
+
+from typing import Literal
+
+class UserMixin:
+ @property
+ def is_active(self) -> Literal[True]: ...
+ @property
+ def is_authenticated(self) -> Literal[True]: ...
+ @property
+ def is_anonymous(self) -> Literal[False]: ...
+ def get_id(self) -> str: ...
+ def __eq__(self, other: UserMixin | object) -> bool: ...
+ def __ne__(self, other: UserMixin | object) -> bool: ...
+
+class AnonymousUserMixin:
+ @property
+ def is_authenticated(self) -> Literal[False]: ...
+ @property
+ def is_active(self) -> Literal[False]: ...
+ @property
+ def is_anonymous(self) -> Literal[True]: ...
+ def get_id(self) -> None: ...
diff --git a/typings/flask_login/utils.pyi b/typings/flask_login/utils.pyi
new file mode 100644
index 00000000..bc05e870
--- /dev/null
+++ b/typings/flask_login/utils.pyi
@@ -0,0 +1,31 @@
+# Initially generated by Pyright
+
+import datetime
+from collections.abc import Callable
+from typing import Literal
+
+import flask
+
+from .mixins import AnonymousUserMixin, UserMixin
+
+current_user: UserMixin | AnonymousUserMixin = ...
+
+def encode_cookie(payload: str, key: str = ...) -> str: ...
+def decode_cookie(cookie: str, key: str = ...) -> str | None: ...
+def make_next_param(login_url: str, current_url: str) -> str: ...
+def expand_login_view(login_view: str) -> str: ...
+def login_url(login_view: str, next_url: str = ..., next_field: str = ...) -> str: ...
+def login_fresh() -> bool: ...
+def login_remembered() -> bool: ...
+def login_user(
+ user: UserMixin,
+ remember: bool = ...,
+ duration: datetime.timedelta = ...,
+ force: bool = ...,
+ fresh: bool = ...,
+) -> bool: ...
+def logout_user() -> Literal[True]: ...
+def confirm_login() -> None: ...
+def login_required[T: Callable[..., object]](func: T) -> T: ...
+def fresh_login_required[T: Callable[..., object]](func: T) -> T: ...
+def set_login_view(login_view: str, blueprint: flask.Blueprint = ...) -> None: ...
diff --git a/typings/jsmin/__init__.pyi b/typings/jsmin/__init__.pyi
new file mode 100644
index 00000000..658e62ad
--- /dev/null
+++ b/typings/jsmin/__init__.pyi
@@ -0,0 +1,16 @@
+# Initially generated by Pyright
+
+import io
+
+def jsmin(js: str, **kwargs: object) -> str: ...
+
+class JavascriptMinify:
+ def __init__(
+ self,
+ instream: io.StringIO = ...,
+ outstream: io.StringIO = ...,
+ quote_chars: str = ...,
+ ) -> None: ...
+ def minify(
+ self, instream: io.StringIO = ..., outstream: io.StringIO = ...
+ ) -> None: ...
diff --git a/typings/numpy_financial/__init__.pyi b/typings/numpy_financial/__init__.pyi
new file mode 100644
index 00000000..fd47cd48
--- /dev/null
+++ b/typings/numpy_financial/__init__.pyi
@@ -0,0 +1,5 @@
+# Initially generated by Pyright
+
+from ._financial import *
+
+__version__ = ...
diff --git a/typings/numpy_financial/_financial.pyi b/typings/numpy_financial/_financial.pyi
new file mode 100644
index 00000000..be1a6f59
--- /dev/null
+++ b/typings/numpy_financial/_financial.pyi
@@ -0,0 +1,5 @@
+# Initially generated by Pyright
+
+from decimal import Decimal
+
+def irr(values: list[Decimal] | list[float]) -> float: ...
diff --git a/typings/prometheus_flask_exporter/__init__.pyi b/typings/prometheus_flask_exporter/__init__.pyi
new file mode 100644
index 00000000..7d13172e
--- /dev/null
+++ b/typings/prometheus_flask_exporter/__init__.pyi
@@ -0,0 +1,107 @@
+# Initially generated by Pyright
+
+from collections.abc import Callable
+from typing import Literal, Self
+
+import flask
+import prometheus_client
+from flask.typing import ResponseReturnValue
+from prometheus_client import Gauge
+from werkzeug.exceptions import HTTPException
+
+GroupBy = Literal["path", "endpoint", "url_rule"]
+
+Responder = Callable[..., ResponseReturnValue | object | HTTPException | flask.Response]
+
+class PrometheusMetrics:
+ def __init__[T: Callable[..., object]](
+ self,
+ app: flask.Flask,
+ path: str = ...,
+ export_defaults: bool = ...,
+ defaults_prefix: str = ...,
+ group_by: GroupBy = ...,
+ buckets: tuple[float, ...] | None = ...,
+ default_latency_as_histogram: bool = ...,
+ default_labels: dict[str, object] | None = ...,
+ response_converter: Responder = ...,
+ excluded_paths: list[str] | str | None = ...,
+ exclude_user_defaults: bool = ...,
+ metrics_decorator: Callable[[T], T] = ...,
+ registry: prometheus_client.CollectorRegistry | None = ...,
+ **kwargs: object,
+ ) -> None: ...
+ @classmethod
+ def for_app_factory(cls, **kwargs: object) -> Self: ...
+ def init_app(self, app: flask.Flask) -> None: ...
+ def register_endpoint(self, path: str, app: flask.Flask = ...) -> None: ...
+ def generate_metrics(
+ self,
+ accept_header: str | None = ...,
+ names: list[str] | None = ...,
+ ) -> tuple[str, str]: ...
+ def start_http_server(
+ self,
+ port: int,
+ host: str = ...,
+ endpoint: str = ...,
+ ssl: dict[str, str] | None = ...,
+ ) -> None: ...
+ def export_defaults(
+ self,
+ buckets: tuple[float, ...] | None = ...,
+ group_by: GroupBy = ...,
+ latency_as_histogram: bool = ...,
+ prefix: str = ...,
+ app: flask.Flask = ...,
+ **kwargs: object,
+ ) -> None: ...
+ def register_default[T: Callable[..., object]](
+ self,
+ *metric_wrappers: Callable[[T], T],
+ **kwargs: object,
+ ) -> None: ...
+ def histogram(
+ self,
+ name: str,
+ description: str,
+ labels: dict[str, object] | None = ...,
+ initial_value_when_only_static_labels: bool = ...,
+ **kwargs: str,
+ ) -> Callable[..., Responder]: ...
+ def summary(
+ self,
+ name: str,
+ description: str,
+ labels: dict[str, object] | None = ...,
+ initial_value_when_only_static_labels: bool = ...,
+ **kwargs: str,
+ ) -> Callable[..., Responder]: ...
+ def gauge(
+ self,
+ name: str,
+ description: str,
+ labels: dict[str, object] | None = ...,
+ initial_value_when_only_static_labels: bool = ...,
+ **kwargs: str,
+ ) -> Callable[..., Responder]: ...
+ def counter(
+ self,
+ name: str,
+ description: str,
+ labels: dict[str, object] | None = ...,
+ initial_value_when_only_static_labels: bool = ...,
+ **kwargs: str,
+ ) -> Callable[..., Responder]: ...
+ @staticmethod
+ def do_not_track[T: Callable[..., object]]() -> Callable[[T], T]: ...
+ @staticmethod
+ def exclude_all_metrics[T: Callable[..., object]]() -> Callable[[T], T]: ...
+ def info(
+ self,
+ name: str,
+ description: str,
+ labelnames: tuple[str, ...] | None = ...,
+ labelvalues: tuple[object, ...] | None = ...,
+ **labels: object,
+ ) -> Gauge: ...
diff --git a/typings/prometheus_flask_exporter/multiprocess.pyi b/typings/prometheus_flask_exporter/multiprocess.pyi
new file mode 100644
index 00000000..3957653d
--- /dev/null
+++ b/typings/prometheus_flask_exporter/multiprocess.pyi
@@ -0,0 +1,19 @@
+# Initially generated by Pyright
+
+from abc import ABCMeta, abstractmethod
+from typing import Literal
+
+from . import PrometheusMetrics
+
+class MultiprocessPrometheusMetrics(PrometheusMetrics):
+ __metaclass__ = ABCMeta
+
+ @abstractmethod
+ def should_start_http_server(self) -> bool: ...
+
+class GunicornPrometheusMetrics(MultiprocessPrometheusMetrics):
+ def should_start_http_server(self) -> Literal[True]: ...
+ @classmethod
+ def start_http_server_when_ready(cls, port: int, host: str = ...) -> None: ...
+ @classmethod
+ def mark_process_dead_on_child_exit(cls, pid: int) -> None: ...
diff --git a/typings/pytailwindcss/__init__.pyi b/typings/pytailwindcss/__init__.pyi
new file mode 100644
index 00000000..180d3bdd
--- /dev/null
+++ b/typings/pytailwindcss/__init__.pyi
@@ -0,0 +1,13 @@
+# Initially generated by Pyright
+#
+import pathlib
+
+def run(
+ tailwindcss_cli_args: list[str] | str = ...,
+ cwd: pathlib.Path | str = ...,
+ bin_path: pathlib.Path | str = ...,
+ env: dict[str, str] = ...,
+ live_output: bool = ...,
+ auto_install: bool = ...,
+ version: str = ...,
+) -> str: ...
diff --git a/typings/pytest/__init__.pyi b/typings/pytest/__init__.pyi
new file mode 100644
index 00000000..07eb6452
--- /dev/null
+++ b/typings/pytest/__init__.pyi
@@ -0,0 +1,181 @@
+# Initially generated by Pyright
+
+from decimal import Decimal
+from typing import overload
+
+from _pytest import __version__, version_tuple
+from _pytest._code import ExceptionInfo
+from _pytest.assertion import register_assert_rewrite
+from _pytest.cacheprovider import Cache
+from _pytest.capture import CaptureFixture
+from _pytest.config import (
+ cmdline,
+ Config,
+ console_main,
+ ExitCode,
+ hookimpl,
+ hookspec,
+ main,
+ PytestPluginManager,
+ UsageError,
+)
+from _pytest.config.argparsing import OptionGroup, Parser
+from _pytest.doctest import DoctestItem
+from _pytest.fixtures import (
+ fixture,
+ FixtureDef,
+ FixtureLookupError,
+ FixtureRequest,
+)
+from _pytest.freeze_support import freeze_includes
+from _pytest.legacypath import TempdirFactory, Testdir
+from _pytest.logging import LogCaptureFixture
+from _pytest.main import Dir, Session
+from _pytest.mark import HIDDEN_PARAM, Mark
+from _pytest.mark import MARK_GEN as mark
+from _pytest.mark import MarkDecorator, MarkGenerator, param
+from _pytest.monkeypatch import MonkeyPatch
+from _pytest.nodes import Collector, Directory, File, Item
+from _pytest.outcomes import exit, fail, importorskip, skip, xfail
+from _pytest.pytester import (
+ HookRecorder,
+ LineMatcher,
+ Pytester,
+ RecordedHookCall,
+ RunResult,
+)
+from _pytest.python import Class, Function, Metafunc, Module, Package
+from _pytest.python_api import ApproxBase, ApproxDecimal, ApproxScalar
+from _pytest.raises import raises, RaisesExc, RaisesGroup
+from _pytest.recwarn import deprecated_call, WarningsRecorder, warns
+from _pytest.reports import CollectReport, TestReport
+from _pytest.runner import CallInfo
+from _pytest.stash import Stash, StashKey
+from _pytest.subtests import SubtestReport, Subtests
+from _pytest.terminal import TerminalReporter, TestShortLogReport
+from _pytest.tmpdir import TempPathFactory
+from _pytest.warning_types import (
+ PytestAssertRewriteWarning,
+ PytestCacheWarning,
+ PytestCollectionWarning,
+ PytestConfigWarning,
+ PytestDeprecationWarning,
+ PytestExperimentalApiWarning,
+ PytestFDWarning,
+ PytestRemovedIn9Warning,
+ PytestRemovedIn10Warning,
+ PytestReturnNotNoneWarning,
+ PytestUnhandledThreadExceptionWarning,
+ PytestUnknownMarkWarning,
+ PytestUnraisableExceptionWarning,
+ PytestWarning,
+)
+
+@overload
+def approx[T: Decimal](
+ expected: T,
+ rel: T | None = None,
+ abs: T | None = None,
+ nan_ok: bool = False,
+) -> ApproxDecimal: ...
+@overload
+def approx[T](
+ expected: T,
+ rel: T | None = None,
+ abs: T | None = None,
+ nan_ok: bool = False,
+) -> ApproxScalar: ...
+def approx[T](
+ expected: T,
+ rel: T | None = None,
+ abs: T | None = None,
+ nan_ok: bool = False,
+) -> ApproxBase: ...
+
+__all__ = [
+ "HIDDEN_PARAM",
+ "Cache",
+ "CallInfo",
+ "CaptureFixture",
+ "Class",
+ "CollectReport",
+ "Collector",
+ "Config",
+ "Dir",
+ "Directory",
+ "DoctestItem",
+ "ExceptionInfo",
+ "ExitCode",
+ "File",
+ "FixtureDef",
+ "FixtureLookupError",
+ "FixtureRequest",
+ "Function",
+ "HookRecorder",
+ "Item",
+ "LineMatcher",
+ "LogCaptureFixture",
+ "Mark",
+ "MarkDecorator",
+ "MarkGenerator",
+ "Metafunc",
+ "Module",
+ "MonkeyPatch",
+ "OptionGroup",
+ "Package",
+ "Parser",
+ "PytestAssertRewriteWarning",
+ "PytestCacheWarning",
+ "PytestCollectionWarning",
+ "PytestConfigWarning",
+ "PytestDeprecationWarning",
+ "PytestExperimentalApiWarning",
+ "PytestFDWarning",
+ "PytestPluginManager",
+ "PytestRemovedIn9Warning",
+ "PytestRemovedIn10Warning",
+ "PytestReturnNotNoneWarning",
+ "PytestUnhandledThreadExceptionWarning",
+ "PytestUnknownMarkWarning",
+ "PytestUnraisableExceptionWarning",
+ "PytestWarning",
+ "Pytester",
+ "RaisesExc",
+ "RaisesGroup",
+ "RecordedHookCall",
+ "RunResult",
+ "Session",
+ "Stash",
+ "StashKey",
+ "SubtestReport",
+ "Subtests",
+ "TempPathFactory",
+ "TempdirFactory",
+ "TerminalReporter",
+ "TestReport",
+ "TestShortLogReport",
+ "Testdir",
+ "UsageError",
+ "WarningsRecorder",
+ "__version__",
+ "approx",
+ "cmdline",
+ "console_main",
+ "deprecated_call",
+ "exit",
+ "fail",
+ "fixture",
+ "freeze_includes",
+ "hookimpl",
+ "hookspec",
+ "importorskip",
+ "main",
+ "mark",
+ "param",
+ "raises",
+ "register_assert_rewrite",
+ "skip",
+ "version_tuple",
+ "warns",
+ "xfail",
+]
diff --git a/typings/sqlcipher3/__init__.pyi b/typings/sqlcipher3/__init__.pyi
new file mode 100644
index 00000000..51667194
--- /dev/null
+++ b/typings/sqlcipher3/__init__.pyi
@@ -0,0 +1,2 @@
+# Generated by Pyright
+# Don't actually need any typing
diff --git a/typings/yfinance/__init__.pyi b/typings/yfinance/__init__.pyi
new file mode 100644
index 00000000..3d095e1c
--- /dev/null
+++ b/typings/yfinance/__init__.pyi
@@ -0,0 +1,5 @@
+# Initially generated by Pyright
+
+from .ticker import Ticker
+
+__all__ = ["Ticker"]
diff --git a/typings/yfinance/base.pyi b/typings/yfinance/base.pyi
new file mode 100644
index 00000000..748d987c
--- /dev/null
+++ b/typings/yfinance/base.pyi
@@ -0,0 +1,29 @@
+# Initially generated by Pyright
+
+import datetime
+
+import pandas as pd
+import requests
+
+class TickerBase:
+ def __init__(
+ self,
+ ticker: str | tuple[str, str],
+ session: requests.Session | None = ...,
+ ) -> None: ...
+ def history(
+ self,
+ period: str | None = ...,
+ interval: str = ...,
+ start: datetime.date | str | None = ...,
+ end: datetime.date | str | None = ...,
+ prepost: bool = ...,
+ actions: bool = ...,
+ auto_adjust: bool = ...,
+ back_adjust: bool = ...,
+ repair: bool = ...,
+ keepna: bool = ...,
+ rounding: bool = ...,
+ timeout: float | None = ...,
+ raise_errors: bool = ...,
+ ) -> pd.DataFrame: ...
diff --git a/typings/yfinance/exceptions.pyi b/typings/yfinance/exceptions.pyi
new file mode 100644
index 00000000..2eec6a0f
--- /dev/null
+++ b/typings/yfinance/exceptions.pyi
@@ -0,0 +1,4 @@
+# Initially generated by Pyright
+
+class YFException(Exception): ...
+class YFDataException(YFException): ...
diff --git a/typings/yfinance/scrapers/__init__.pyi b/typings/yfinance/scrapers/__init__.pyi
new file mode 100644
index 00000000..e69de29b
diff --git a/typings/yfinance/scrapers/funds.pyi b/typings/yfinance/scrapers/funds.pyi
new file mode 100644
index 00000000..14c1a8a7
--- /dev/null
+++ b/typings/yfinance/scrapers/funds.pyi
@@ -0,0 +1,5 @@
+# Initially generated by Pyright
+
+class FundsData:
+ @property
+ def sector_weightings(self) -> dict[str, float]: ...
diff --git a/typings/yfinance/ticker.pyi b/typings/yfinance/ticker.pyi
new file mode 100644
index 00000000..bcea0081
--- /dev/null
+++ b/typings/yfinance/ticker.pyi
@@ -0,0 +1,24 @@
+# Initially generated by Pyright
+
+from typing import NotRequired, TypedDict
+
+import requests
+
+from .base import TickerBase
+from .scrapers.funds import FundsData
+
+class Info(TypedDict):
+
+ currency: str
+ sector: NotRequired[str | None]
+
+class Ticker(TickerBase):
+ def __init__(
+ self,
+ ticker: str | tuple[str, str],
+ session: requests.Session | None = ...,
+ ) -> None: ...
+ @property
+ def info(self) -> Info: ...
+ @property
+ def funds_data(self) -> FundsData: ...