Skip to content

Commit

Permalink
Use model_validate built into pydantic to parse DB models
Browse files Browse the repository at this point in the history
`pydantic` models have a deprecated `from_orm` method, which is designed to parse DB models into the schema.
This was superseded by `model_validate`, which does the same thing when the schema config `from_attributes` value is true.
Migrating to using this simplifies the method of parsing DB models, and will eventually allow the schemas to be used internally instead of the DB models.
  • Loading branch information
UpstreamData committed Sep 17, 2024
1 parent 18abbb3 commit a740559
Show file tree
Hide file tree
Showing 17 changed files with 139 additions and 144 deletions.
3 changes: 2 additions & 1 deletion goosebit/api/v1/devices/device/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ async def device_get(_: Request, updater: UpdateManager = Depends(get_update_man
device = await updater.get_device()
if device is None:
raise HTTPException(404)
return await DeviceResponse.convert(device)
await device.fetch_related("assigned_software", "hardware")
return DeviceResponse.model_validate(device)


@router.get(
Expand Down
7 changes: 0 additions & 7 deletions goosebit/api/v1/devices/responses.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
from __future__ import annotations

import asyncio

from pydantic import BaseModel

from goosebit.db.models import Device
from goosebit.schema.devices import DeviceSchema


class DevicesResponse(BaseModel):
devices: list[DeviceSchema]

@classmethod
async def convert(cls, devices: list[Device]):
return cls(devices=await asyncio.gather(*[DeviceSchema.convert(d) for d in devices]))
3 changes: 2 additions & 1 deletion goosebit/api/v1/devices/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
dependencies=[Security(validate_user_permissions, scopes=["home.read"])],
)
async def devices_get(_: Request) -> DevicesResponse:
return await DevicesResponse.convert(await Device.all().prefetch_related("assigned_software", "hardware"))
devices = await Device.all().prefetch_related("assigned_software", "hardware")
return DevicesResponse(devices=devices)


@router.delete(
Expand Down
6 changes: 4 additions & 2 deletions goosebit/api/v1/download/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@
from fastapi.responses import FileResponse, RedirectResponse

from goosebit.db.models import Software
from goosebit.schema.software import SoftwareSchema

router = APIRouter(prefix="/download", tags=["download"])


@router.get("/{file_id}")
async def download_file(_: Request, file_id: int):
software = await Software.get_or_none(id=file_id)
if software is None:
software_model = await Software.get_or_none(id=file_id).prefetch_related("compatibility")
if software_model is None:
raise HTTPException(404)
software = SoftwareSchema.model_validate(software_model)
if software.local:
return FileResponse(
software.path,
Expand Down
9 changes: 2 additions & 7 deletions goosebit/api/v1/rollouts/responses.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
import asyncio
from __future__ import annotations

from pydantic import BaseModel

from goosebit.api.responses import StatusResponse
from goosebit.db.models import Rollout
from goosebit.schema.rollouts import RolloutSchema


class RolloutsPutResponse(StatusResponse):
id: int
id: int | None = None


class RolloutsResponse(BaseModel):
rollouts: list[RolloutSchema]

@classmethod
async def convert(cls, devices: list[Rollout]):
return cls(rollouts=await asyncio.gather(*[RolloutSchema.convert(d) for d in devices]))
10 changes: 7 additions & 3 deletions goosebit/api/v1/rollouts/routes.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from fastapi import APIRouter, Security
from fastapi import APIRouter, HTTPException, Security
from fastapi.requests import Request

from goosebit.api.responses import StatusResponse
from goosebit.auth import validate_user_permissions
from goosebit.db.models import Rollout
from goosebit.db.models import Rollout, Software

from .requests import RolloutsDeleteRequest, RolloutsPatchRequest, RolloutsPutRequest
from .responses import RolloutsPutResponse, RolloutsResponse
Expand All @@ -16,14 +16,18 @@
dependencies=[Security(validate_user_permissions, scopes=["rollout.read"])],
)
async def rollouts_get(_: Request) -> RolloutsResponse:
return await RolloutsResponse.convert(await Rollout.all().prefetch_related("software"))
rollouts = await Rollout.all().prefetch_related("software", "software__compatibility")
return RolloutsResponse(rollouts=rollouts)


@router.post(
"",
dependencies=[Security(validate_user_permissions, scopes=["rollout.write"])],
)
async def rollouts_put(_: Request, rollout: RolloutsPutRequest) -> RolloutsPutResponse:
software = await Software.filter(id=rollout.software_id)
if len(software) == 0:
raise HTTPException(404, f"No software with ID {rollout.software_id} found")
rollout = await Rollout.create(
name=rollout.name,
feed=rollout.feed,
Expand Down
7 changes: 0 additions & 7 deletions goosebit/api/v1/software/responses.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
from __future__ import annotations

import asyncio

from pydantic import BaseModel

from goosebit.db.models import Software
from goosebit.schema.software import SoftwareSchema


class SoftwareResponse(BaseModel):
software: list[SoftwareSchema]

@classmethod
async def convert(cls, software: list[Software]):
return cls(software=await asyncio.gather(*[SoftwareSchema.convert(f) for f in software]))
13 changes: 9 additions & 4 deletions goosebit/api/v1/software/routes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import random
import string

Expand All @@ -8,6 +10,7 @@
from goosebit.api.responses import StatusResponse
from goosebit.auth import validate_user_permissions
from goosebit.db.models import Rollout, Software
from goosebit.schema.software import SoftwareSchema
from goosebit.settings import config
from goosebit.updates import create_software_update

Expand All @@ -22,7 +25,7 @@
dependencies=[Security(validate_user_permissions, scopes=["software.read"])],
)
async def software_get(_: Request) -> SoftwareResponse:
return await SoftwareResponse.convert(await Software.all().prefetch_related("compatibility"))
return SoftwareResponse(software=await Software.all().prefetch_related("compatibility"))


@router.delete(
Expand All @@ -32,11 +35,13 @@ async def software_get(_: Request) -> SoftwareResponse:
async def software_delete(_: Request, delete_req: SoftwareDeleteRequest) -> StatusResponse:
success = False
for software_id in delete_req.software_ids:
software = await Software.get_or_none(id=software_id)
software_model = await Software.get_or_none(id=software_id)

if software is None:
if software_model is None:
continue

software = SoftwareSchema.model_validate(software_model)

rollout_count = await Rollout.filter(software=software).count()
if rollout_count > 0:
raise HTTPException(409, "Software is referenced by rollout")
Expand All @@ -46,7 +51,7 @@ async def software_delete(_: Request, delete_req: SoftwareDeleteRequest) -> Stat
if await path.exists():
await path.unlink()

await software.delete()
await software_model.delete()
success = True
return StatusResponse(success=success)

Expand Down
20 changes: 2 additions & 18 deletions goosebit/db/models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from __future__ import annotations

from enum import IntEnum
from typing import Self
from urllib.parse import unquote, urlparse
from urllib.request import url2pathname

import semver
from anyio import Path
from tortoise import Model, fields

from goosebit.api.telemetry.metrics import devices_count
Expand Down Expand Up @@ -130,18 +129,3 @@ async def latest(cls, device: Device) -> Self | None:
key=lambda x: semver.Version.parse(x.version),
reverse=True,
)[0]

@property
def path(self) -> Path:
return Path(url2pathname(unquote(urlparse(self.uri).path)))

@property
def local(self) -> bool:
return urlparse(self.uri).scheme == "file"

@property
def path_user(self) -> str:
if self.local:
return self.path.name
else:
return self.uri
74 changes: 39 additions & 35 deletions goosebit/schema/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
from enum import Enum, IntEnum, StrEnum
from typing import Annotated

from pydantic import BaseModel, BeforeValidator, computed_field
from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, computed_field

from goosebit.db.models import Device, UpdateModeEnum, UpdateStateEnum
from goosebit.updater.manager import get_update_manager
from goosebit.db.models import UpdateModeEnum, UpdateStateEnum
from goosebit.schema.software import HardwareSchema, SoftwareSchema
from goosebit.updater.manager import DeviceUpdateManager


class ConvertableEnum(StrEnum):
Expand All @@ -26,48 +27,51 @@ def enum_factory(name: str, base: type[Enum]) -> type[ConvertableEnum]:


class DeviceSchema(BaseModel):
model_config = ConfigDict(from_attributes=True)

uuid: str
name: str | None
sw_version: str | None
sw_target_version: str | None
sw_assigned: int | None
hw_model: str
hw_revision: str

assigned_software: SoftwareSchema | None = Field(exclude=True)
hardware: HardwareSchema | None = Field(exclude=True)

feed: str
progress: int | None
last_state: Annotated[UpdateStateSchema, BeforeValidator(UpdateStateSchema.convert)] # type: ignore[valid-type]
update_mode: Annotated[UpdateModeSchema, BeforeValidator(UpdateModeSchema.convert)] # type: ignore[valid-type]
force_update: bool
last_ip: str | None
last_seen: int | None
poll_seconds: int
last_seen: Annotated[
int | None, BeforeValidator(lambda last_seen: round(time.time() - last_seen) if last_seen is not None else None)
]

@computed_field
@computed_field # type: ignore[misc]
@property
def online(self) -> bool | None:
return self.last_seen < self.poll_seconds if self.last_seen is not None else None

@classmethod
async def convert(cls, device: Device):
manager = await get_update_manager(device.uuid)
_, target_software = await manager.get_update()
last_seen = device.last_seen
if last_seen is not None:
last_seen = round(time.time() - device.last_seen)

return cls(
uuid=device.uuid,
name=device.name,
sw_version=device.sw_version,
sw_target_version=(target_software.version if target_software is not None else None),
sw_assigned=(device.assigned_software.id if device.assigned_software is not None else None),
hw_model=device.hardware.model,
hw_revision=device.hardware.revision,
feed=device.feed,
progress=device.progress,
last_state=device.last_state,
update_mode=device.update_mode,
force_update=device.force_update,
last_ip=device.last_ip,
last_seen=last_seen,
poll_seconds=manager.poll_seconds,
)
@computed_field # type: ignore[misc]
@property
def sw_target_version(self) -> str | None:
return self.assigned_software.version if self.assigned_software is not None else None

@computed_field # type: ignore[misc]
@property
def sw_assigned(self) -> int | None:
return self.assigned_software.id if self.assigned_software is not None else None

@computed_field # type: ignore[misc]
@property
def hw_model(self) -> str | None:
return self.hardware.model if self.hardware is not None else None

@computed_field # type: ignore[misc]
@property
def hw_revision(self) -> str | None:
return self.hardware.revision if self.hardware is not None else None

@computed_field # type: ignore[misc]
@property
def poll_seconds(self) -> int:
return DeviceUpdateManager(self.uuid).poll_seconds
39 changes: 21 additions & 18 deletions goosebit/schema/rollouts.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,34 @@
from __future__ import annotations

from pydantic import BaseModel
from datetime import datetime

from goosebit.db.models import Rollout
from pydantic import BaseModel, ConfigDict, Field, computed_field, field_serializer

from goosebit.schema.software import SoftwareSchema


class RolloutSchema(BaseModel):
model_config = ConfigDict(from_attributes=True)

id: int
created_at: int
created_at: datetime
name: str | None
feed: str
sw_file: str
sw_version: str
software: SoftwareSchema = Field(exclude=True)
paused: bool
success_count: int
failure_count: int

@classmethod
async def convert(cls, rollout: Rollout):
return cls(
id=rollout.id,
created_at=int(rollout.created_at.timestamp() * 1000),
name=rollout.name,
feed=rollout.feed,
sw_file=rollout.software.path.name,
sw_version=rollout.software.version,
paused=rollout.paused,
success_count=rollout.success_count,
failure_count=rollout.failure_count,
)
@computed_field # type: ignore[misc]
@property
def sw_version(self) -> str:
return self.software.version

@computed_field # type: ignore[misc]
@property
def sw_file(self) -> str:
return self.software.path.name

@field_serializer("created_at")
def serialize_created_at(self, created_at: datetime, _info):
return int(created_at.timestamp() * 1000)
Loading

0 comments on commit a740559

Please sign in to comment.