-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Overhaul the code structure around database logics
This commit introduces a new SQLModel dependency to merge Pydantic model and SQLAlchemy table declarations into one.
- Loading branch information
Showing
11 changed files
with
526 additions
and
424 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,202 @@ | ||
from enum import Enum | ||
from typing import assert_never | ||
|
||
import sqlalchemy | ||
from databases import Database | ||
from pydantic import BaseModel | ||
from sqlmodel import SQLModel | ||
|
||
from . import placed_item, placement, product | ||
from .placed_item import PlacedItem | ||
from .placement import Placement | ||
from .product import Product | ||
|
||
DATABASE_URL = "sqlite:///app.db" | ||
database = Database(DATABASE_URL) | ||
|
||
|
||
ProductTable = product.Table(database) | ||
PlacedItemTable = placed_item.Table(database) | ||
PlacementTable = placement.Table(database) | ||
|
||
|
||
class PlacementReceipt(BaseModel): | ||
class ProductEntry(BaseModel): | ||
product_id: int | ||
count: int | ||
name: str | ||
filename: str | ||
price: int | ||
|
||
placement_id: int | ||
products: list[ProductEntry] | ||
total_price: int | ||
|
||
class Result(BaseModel): | ||
placement_id: int | ||
product_id: int | ||
count: int | ||
name: str | ||
filename: str | ||
price: int | ||
|
||
|
||
async def select_placements( | ||
canceled: bool, | ||
completed: bool, | ||
) -> list[dict[str, int | list[dict[str, int | str]] | str]]: | ||
# list of products with each row alongside the number of ordered items | ||
query = f""" | ||
SELECT | ||
{PlacedItem.placement_id}, | ||
{PlacedItem.product_id}, | ||
COUNT({PlacedItem.product_id}) AS count, | ||
{Product.name}, | ||
{Product.filename}, | ||
{Product.price} | ||
FROM {PlacedItem.__tablename__} as {PlacedItem.__name__} | ||
JOIN {Product.__tablename__} as {Product.__name__} ON {PlacedItem.product_id} = {Product.product_id} | ||
JOIN {Placement.__tablename__} as {Placement.__name__} ON {PlacedItem.placement_id} = {Placement.placement_id} | ||
WHERE {Placement.canceled} = {int(canceled)} AND | ||
{Placement.completed} = {int(completed)} | ||
GROUP BY {PlacedItem.placement_id}, {PlacedItem.product_id} | ||
ORDER BY {PlacedItem.placement_id} ASC, {PlacedItem.product_id} ASC | ||
""" | ||
placements: list[dict[str, int | list[dict[str, int | str]] | str]] = [] | ||
prev_placement_id = -1 | ||
prev_products: list[dict[str, int | str]] = [] | ||
total_price = 0 | ||
|
||
async for map in database.iterate(query): | ||
placement_id = map["placement_id"] | ||
if placement_id != prev_placement_id: | ||
if prev_placement_id != -1: | ||
placements.append( | ||
{ | ||
"placement_id": prev_placement_id, | ||
"products": prev_products, | ||
"total_price": Product.to_price_str(total_price), | ||
} | ||
) | ||
prev_placement_id = placement_id | ||
prev_products = [] | ||
total_price = 0 | ||
count, price = map["count"], map["price"] | ||
prev_products.append( | ||
{ | ||
"product_id": map["product_id"], | ||
"count": count, | ||
"name": map["name"], | ||
"filename": map["filename"], | ||
"price": Product.to_price_str(price), | ||
} | ||
) | ||
total_price += count * price | ||
if prev_placement_id != -1: | ||
placements.append( | ||
{ | ||
"placement_id": prev_placement_id, | ||
"products": prev_products, | ||
"total_price": Product.to_price_str(total_price), | ||
} | ||
) | ||
return placements | ||
|
||
|
||
# NOTE:get placements by incoming order in datetime | ||
# TODO: add a datetime field to PlacedItem | ||
# | ||
# async def select_placements_by_incoming_order() -> dict[int, list[dict]]: | ||
# query = f""" | ||
# SELECT | ||
# {PlacedItem.placement_id}, | ||
# {PlacedItem.product_id}, | ||
# {Product.name}, | ||
# {Product.filename} | ||
# FROM {PlacedItem.__tablename__} | ||
# JOIN {Product.__tablename__} as {Product.__name__} ON {PlacedItem.product_id} = {Product.product_id} | ||
# ORDER BY {PlacedItem.placement_id} ASC, {PlacedItem.item_no} ASC; | ||
# """ | ||
# placements: dict[int, list[dict]] = {} | ||
# async for row in db.iterate(query): | ||
# print(dict(row)) | ||
# return placements | ||
|
||
|
||
class SortOrderedProductsBy(Enum): | ||
PRODUCT_ID = "product_id" | ||
TIME = "time" | ||
NO_ITEMS = "no_items" | ||
|
||
|
||
async def select_ordered_products( | ||
sort_by: SortOrderedProductsBy, | ||
canceled: bool, | ||
completed: bool, | ||
) -> list[dict[str, int | str | list[dict[str, int]]]]: | ||
match sort_by: | ||
case SortOrderedProductsBy.PRODUCT_ID: | ||
order_by = f"{PlacedItem.product_id} ASC" | ||
case SortOrderedProductsBy.TIME: | ||
# TODO: add datatime field to the placements table | ||
# order_by = Placement.datetime | ||
order_by = NotImplemented | ||
case SortOrderedProductsBy.NO_ITEMS: | ||
order_by = "count DESC" | ||
case _: | ||
assert_never(SortOrderedProductsBy) | ||
query = f""" | ||
SELECT | ||
{PlacedItem.placement_id}, | ||
COUNT({PlacedItem.product_id}) AS count, | ||
{PlacedItem.product_id}, | ||
{Product.name}, | ||
{Product.filename} | ||
FROM {PlacedItem.__tablename__} as {PlacedItem.__name__} | ||
JOIN {Product.__tablename__} as {Product.__name__} ON {PlacedItem.product_id} = {Product.product_id} | ||
JOIN {Placement.__tablename__} as {Placement.__name__} ON {PlacedItem.placement_id} = {Placement.placement_id} | ||
WHERE {Placement.canceled} = {int(canceled)} AND | ||
{Placement.completed} = {int(completed)} | ||
GROUP BY {PlacedItem.placement_id}, {PlacedItem.product_id} | ||
ORDER BY {order_by} | ||
""" | ||
ret: dict[int, dict[str, int | str | list[dict[str, int]]]] = {} | ||
async for map in database.iterate(query): | ||
product_id = map["product_id"] | ||
if (product_dict := ret.get(product_id)) is None: | ||
ret[product_id] = { | ||
"name": map["name"], | ||
"filename": map["filename"], | ||
"placements": [], | ||
} | ||
lst: ... = ret[product_id]["placements"] | ||
else: | ||
lst: ... = product_dict["placements"] | ||
lst.append( | ||
{ | ||
"placement_id": map["placement_id"], | ||
"count": map["count"], | ||
} | ||
) | ||
return list(ret.values()) | ||
|
||
|
||
async def _startup_db() -> None: | ||
await database.connect() | ||
|
||
# TODO: we should use a database schema migration tool like Alembic as explained in: | ||
# https://www.encode.io/databases/database_queries/#creating-tables | ||
for table in SQLModel.metadata.tables.values(): | ||
schema = sqlalchemy.schema.CreateTable(table, if_not_exists=True) | ||
query = str(schema.compile()) | ||
await database.execute(query) | ||
|
||
await ProductTable.ainit() | ||
await PlacedItemTable.ainit() | ||
|
||
|
||
async def _shutdown_db() -> None: | ||
await database.disconnect() | ||
|
||
|
||
startup_and_shutdown_db = (_startup_db, _shutdown_db) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import sqlmodel | ||
from databases import Database | ||
|
||
|
||
class PlacedItem(sqlmodel.SQLModel, table=True): | ||
# NOTE: there are no Pydantic ways to set the generated table's name, as per https://github.com/fastapi/sqlmodel/issues/159 | ||
__tablename__ = "placed_items" # type: ignore[reportAssignmentType] | ||
|
||
id: int | None = sqlmodel.Field(default=None, primary_key=True) | ||
placement_id: int | ||
item_no: int | ||
product_id: int | ||
|
||
|
||
class Table: | ||
_last_placement_id: int | None | ||
_db: Database | ||
|
||
def __init__(self, database: Database): | ||
self._db = database | ||
|
||
async def ainit(self) -> None: | ||
query = sqlmodel.func.max(PlacedItem.placement_id).select() | ||
self._last_placement_id = await self._db.fetch_val(query) | ||
|
||
async def select_all(self) -> list[PlacedItem]: | ||
query = sqlmodel.select(PlacedItem) | ||
return [PlacedItem.model_validate(m) async for m in self._db.iterate(query)] | ||
|
||
async def by_placement_id(self, placement_id: int) -> list[PlacedItem]: | ||
clause = PlacedItem.placement_id == placement_id | ||
query = sqlmodel.select(PlacedItem).where(clause) | ||
return [PlacedItem.model_validate(m) async for m in self._db.iterate(query)] | ||
|
||
async def issue(self, product_ids: list[int]) -> int: | ||
placement_id = (self._last_placement_id or 0) + 1 | ||
await self._db.execute_many( | ||
sqlmodel.insert(PlacedItem).values(placement_id=placement_id), | ||
[{"item_no": i, "product_id": pid} for i, pid in enumerate(product_ids)], | ||
) | ||
self._last_placement_id = placement_id | ||
return placement_id | ||
|
||
# NOTE: this function needs authorization since it destroys all receipts | ||
async def clear(self) -> None: | ||
await self._db.execute(sqlmodel.delete(PlacedItem)) | ||
self._last_placement_id = None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
from typing import Annotated | ||
|
||
import sqlmodel | ||
from databases import Database | ||
|
||
|
||
class Placement(sqlmodel.SQLModel, table=True): | ||
# NOTE: there are no Pydantic ways to set the generated table's name, as per https://github.com/fastapi/sqlmodel/issues/159 | ||
__tablename__ = "placements" # type: ignore[reportAssignmentType] | ||
|
||
id: int | None = sqlmodel.Field(default=None, primary_key=True) | ||
placement_id: int | ||
# Column(..., server_default=sqlalchemy.text("0")) | ||
canceled: Annotated[ | ||
bool, sqlmodel.Field(sa_column_kwargs={"server_default": sqlmodel.text("0")}) | ||
] | ||
# Column(..., server_default=sqlalchemy.text("0")) | ||
completed: Annotated[ | ||
bool, sqlmodel.Field(sa_column_kwargs={"server_default": sqlmodel.text("0")}) | ||
] | ||
|
||
|
||
class Table: | ||
def __init__(self, database: Database): | ||
self._db = database | ||
|
||
async def insert(self, placement_id: int) -> None: | ||
query = sqlmodel.insert(Placement) | ||
await self._db.execute(query, {"placement_id": placement_id}) | ||
|
||
async def update(self, placement_id: int, canceled: bool, completed: bool) -> None: | ||
clause = Placement.placement_id == placement_id | ||
# NOTE: I don't why, but this where clause argument does not typecheck | ||
query = sqlmodel.update(Placement).where(clause) # type: ignore[reportArgumentType] | ||
await self._db.execute(query, {"canceled": canceled, "completed": completed}) | ||
|
||
async def cancel(self, placement_id: int) -> None: | ||
await self.update(placement_id, canceled=True, completed=False) | ||
|
||
async def complete(self, placement_id: int) -> None: | ||
await self.update(placement_id, canceled=False, completed=True) | ||
|
||
async def reset(self, placement_id: int) -> None: | ||
await self.update(placement_id, canceled=False, completed=False) | ||
|
||
async def by_placement_id(self, placement_id: int) -> Placement | None: | ||
query = sqlmodel.select(Placement).where(Placement.placement_id == placement_id) | ||
row = await self._db.fetch_one(query) | ||
return Placement.model_validate(row) if row else None | ||
|
||
async def select_all(self) -> list[Placement]: | ||
query = sqlmodel.select(Placement) | ||
return [Placement.model_validate(m) async for m in self._db.iterate(query)] |
Oops, something went wrong.