Skip to content

Commit

Permalink
Overhaul the code structure around database logics
Browse files Browse the repository at this point in the history
This commit introduces a new SQLModel dependency to merge Pydantic model and SQLAlchemy table declarations into one.
  • Loading branch information
exflikt committed Sep 12, 2024
1 parent 3a28189 commit 745f473
Show file tree
Hide file tree
Showing 11 changed files with 526 additions and 424 deletions.
410 changes: 0 additions & 410 deletions app/db.py

This file was deleted.

202 changes: 202 additions & 0 deletions app/db/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
from enum import Enum
from typing import assert_never

import sqlalchemy
from databases import Database
from pydantic import BaseModel
from sqlmodel import SQLModel

from . import placed_item, placement, product
from .placed_item import PlacedItem
from .placement import Placement
from .product import Product

DATABASE_URL = "sqlite:///app.db"
database = Database(DATABASE_URL)


ProductTable = product.Table(database)
PlacedItemTable = placed_item.Table(database)
PlacementTable = placement.Table(database)


class PlacementReceipt(BaseModel):
class ProductEntry(BaseModel):
product_id: int
count: int
name: str
filename: str
price: int

placement_id: int
products: list[ProductEntry]
total_price: int

class Result(BaseModel):
placement_id: int
product_id: int
count: int
name: str
filename: str
price: int


async def select_placements(
canceled: bool,
completed: bool,
) -> list[dict[str, int | list[dict[str, int | str]] | str]]:
# list of products with each row alongside the number of ordered items
query = f"""
SELECT
{PlacedItem.placement_id},
{PlacedItem.product_id},
COUNT({PlacedItem.product_id}) AS count,
{Product.name},
{Product.filename},
{Product.price}
FROM {PlacedItem.__tablename__} as {PlacedItem.__name__}
JOIN {Product.__tablename__} as {Product.__name__} ON {PlacedItem.product_id} = {Product.product_id}
JOIN {Placement.__tablename__} as {Placement.__name__} ON {PlacedItem.placement_id} = {Placement.placement_id}
WHERE {Placement.canceled} = {int(canceled)} AND
{Placement.completed} = {int(completed)}
GROUP BY {PlacedItem.placement_id}, {PlacedItem.product_id}
ORDER BY {PlacedItem.placement_id} ASC, {PlacedItem.product_id} ASC
"""
placements: list[dict[str, int | list[dict[str, int | str]] | str]] = []
prev_placement_id = -1
prev_products: list[dict[str, int | str]] = []
total_price = 0

async for map in database.iterate(query):
placement_id = map["placement_id"]
if placement_id != prev_placement_id:
if prev_placement_id != -1:
placements.append(
{
"placement_id": prev_placement_id,
"products": prev_products,
"total_price": Product.to_price_str(total_price),
}
)
prev_placement_id = placement_id
prev_products = []
total_price = 0
count, price = map["count"], map["price"]
prev_products.append(
{
"product_id": map["product_id"],
"count": count,
"name": map["name"],
"filename": map["filename"],
"price": Product.to_price_str(price),
}
)
total_price += count * price
if prev_placement_id != -1:
placements.append(
{
"placement_id": prev_placement_id,
"products": prev_products,
"total_price": Product.to_price_str(total_price),
}
)
return placements


# NOTE:get placements by incoming order in datetime
# TODO: add a datetime field to PlacedItem
#
# async def select_placements_by_incoming_order() -> dict[int, list[dict]]:
# query = f"""
# SELECT
# {PlacedItem.placement_id},
# {PlacedItem.product_id},
# {Product.name},
# {Product.filename}
# FROM {PlacedItem.__tablename__}
# JOIN {Product.__tablename__} as {Product.__name__} ON {PlacedItem.product_id} = {Product.product_id}
# ORDER BY {PlacedItem.placement_id} ASC, {PlacedItem.item_no} ASC;
# """
# placements: dict[int, list[dict]] = {}
# async for row in db.iterate(query):
# print(dict(row))
# return placements


class SortOrderedProductsBy(Enum):
PRODUCT_ID = "product_id"
TIME = "time"
NO_ITEMS = "no_items"


async def select_ordered_products(
sort_by: SortOrderedProductsBy,
canceled: bool,
completed: bool,
) -> list[dict[str, int | str | list[dict[str, int]]]]:
match sort_by:
case SortOrderedProductsBy.PRODUCT_ID:
order_by = f"{PlacedItem.product_id} ASC"
case SortOrderedProductsBy.TIME:
# TODO: add datatime field to the placements table
# order_by = Placement.datetime
order_by = NotImplemented
case SortOrderedProductsBy.NO_ITEMS:
order_by = "count DESC"
case _:
assert_never(SortOrderedProductsBy)
query = f"""
SELECT
{PlacedItem.placement_id},
COUNT({PlacedItem.product_id}) AS count,
{PlacedItem.product_id},
{Product.name},
{Product.filename}
FROM {PlacedItem.__tablename__} as {PlacedItem.__name__}
JOIN {Product.__tablename__} as {Product.__name__} ON {PlacedItem.product_id} = {Product.product_id}
JOIN {Placement.__tablename__} as {Placement.__name__} ON {PlacedItem.placement_id} = {Placement.placement_id}
WHERE {Placement.canceled} = {int(canceled)} AND
{Placement.completed} = {int(completed)}
GROUP BY {PlacedItem.placement_id}, {PlacedItem.product_id}
ORDER BY {order_by}
"""
ret: dict[int, dict[str, int | str | list[dict[str, int]]]] = {}
async for map in database.iterate(query):
product_id = map["product_id"]
if (product_dict := ret.get(product_id)) is None:
ret[product_id] = {
"name": map["name"],
"filename": map["filename"],
"placements": [],
}
lst: ... = ret[product_id]["placements"]
else:
lst: ... = product_dict["placements"]
lst.append(
{
"placement_id": map["placement_id"],
"count": map["count"],
}
)
return list(ret.values())


async def _startup_db() -> None:
await database.connect()

# TODO: we should use a database schema migration tool like Alembic as explained in:
# https://www.encode.io/databases/database_queries/#creating-tables
for table in SQLModel.metadata.tables.values():
schema = sqlalchemy.schema.CreateTable(table, if_not_exists=True)
query = str(schema.compile())
await database.execute(query)

await ProductTable.ainit()
await PlacedItemTable.ainit()


async def _shutdown_db() -> None:
await database.disconnect()


startup_and_shutdown_db = (_startup_db, _shutdown_db)
47 changes: 47 additions & 0 deletions app/db/placed_item.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import sqlmodel
from databases import Database


class PlacedItem(sqlmodel.SQLModel, table=True):
# NOTE: there are no Pydantic ways to set the generated table's name, as per https://github.com/fastapi/sqlmodel/issues/159
__tablename__ = "placed_items" # type: ignore[reportAssignmentType]

id: int | None = sqlmodel.Field(default=None, primary_key=True)
placement_id: int
item_no: int
product_id: int


class Table:
_last_placement_id: int | None
_db: Database

def __init__(self, database: Database):
self._db = database

async def ainit(self) -> None:
query = sqlmodel.func.max(PlacedItem.placement_id).select()
self._last_placement_id = await self._db.fetch_val(query)

async def select_all(self) -> list[PlacedItem]:
query = sqlmodel.select(PlacedItem)
return [PlacedItem.model_validate(m) async for m in self._db.iterate(query)]

async def by_placement_id(self, placement_id: int) -> list[PlacedItem]:
clause = PlacedItem.placement_id == placement_id
query = sqlmodel.select(PlacedItem).where(clause)
return [PlacedItem.model_validate(m) async for m in self._db.iterate(query)]

async def issue(self, product_ids: list[int]) -> int:
placement_id = (self._last_placement_id or 0) + 1
await self._db.execute_many(
sqlmodel.insert(PlacedItem).values(placement_id=placement_id),
[{"item_no": i, "product_id": pid} for i, pid in enumerate(product_ids)],
)
self._last_placement_id = placement_id
return placement_id

# NOTE: this function needs authorization since it destroys all receipts
async def clear(self) -> None:
await self._db.execute(sqlmodel.delete(PlacedItem))
self._last_placement_id = None
53 changes: 53 additions & 0 deletions app/db/placement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import Annotated

import sqlmodel
from databases import Database


class Placement(sqlmodel.SQLModel, table=True):
# NOTE: there are no Pydantic ways to set the generated table's name, as per https://github.com/fastapi/sqlmodel/issues/159
__tablename__ = "placements" # type: ignore[reportAssignmentType]

id: int | None = sqlmodel.Field(default=None, primary_key=True)
placement_id: int
# Column(..., server_default=sqlalchemy.text("0"))
canceled: Annotated[
bool, sqlmodel.Field(sa_column_kwargs={"server_default": sqlmodel.text("0")})
]
# Column(..., server_default=sqlalchemy.text("0"))
completed: Annotated[
bool, sqlmodel.Field(sa_column_kwargs={"server_default": sqlmodel.text("0")})
]


class Table:
def __init__(self, database: Database):
self._db = database

async def insert(self, placement_id: int) -> None:
query = sqlmodel.insert(Placement)
await self._db.execute(query, {"placement_id": placement_id})

async def update(self, placement_id: int, canceled: bool, completed: bool) -> None:
clause = Placement.placement_id == placement_id
# NOTE: I don't why, but this where clause argument does not typecheck
query = sqlmodel.update(Placement).where(clause) # type: ignore[reportArgumentType]
await self._db.execute(query, {"canceled": canceled, "completed": completed})

async def cancel(self, placement_id: int) -> None:
await self.update(placement_id, canceled=True, completed=False)

async def complete(self, placement_id: int) -> None:
await self.update(placement_id, canceled=False, completed=True)

async def reset(self, placement_id: int) -> None:
await self.update(placement_id, canceled=False, completed=False)

async def by_placement_id(self, placement_id: int) -> Placement | None:
query = sqlmodel.select(Placement).where(Placement.placement_id == placement_id)
row = await self._db.fetch_one(query)
return Placement.model_validate(row) if row else None

async def select_all(self) -> list[Placement]:
query = sqlmodel.select(Placement)
return [Placement.model_validate(m) async for m in self._db.iterate(query)]
Loading

0 comments on commit 745f473

Please sign in to comment.