Skip to content

Commit

Permalink
feature: use pydantic validators for message checks (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
gounux authored Jul 29, 2024
1 parent bbd26cb commit 8c59d07
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 80 deletions.
19 changes: 5 additions & 14 deletions gischat/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
from fastapi.encoders import jsonable_encoder
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from pydantic import ValidationError
from starlette.websockets import WebSocketDisconnect

from gischat import INTERNAL_MESSAGE_AUTHOR
from gischat.models import (
InternalMessageModel,
MessageErrorModel,
MessageModel,
RulesModel,
StatusModel,
Expand Down Expand Up @@ -151,17 +151,10 @@ async def get_rules() -> RulesModel:
@app.put(
"/room/{room}/message",
response_model=MessageModel,
responses={420: {"model": MessageErrorModel}},
)
async def put_message(room: str, message: MessageModel) -> MessageModel:
if room not in notifier.connections.keys():
raise HTTPException(status_code=404, detail=f"Room '{room}' not registered")
ok, errors = message.check_validity()
if not ok:
logger.warning(f"Uncompliant message in room '{room}': {','.join(errors)}")
raise HTTPException(
status_code=420, detail={"message": "Uncompliant message", "errors": errors}
)
logger.info(f"Message in room '{room}': {message}")
await notifier.notify(room, json.dumps(jsonable_encoder(message)))
return message
Expand All @@ -181,12 +174,10 @@ async def websocket_endpoint(websocket: WebSocket, room: str) -> None:
try:
while True:
data = await websocket.receive_text()
message = MessageModel(**json.loads(data))
ok, errors = message.check_validity()
if not ok:
logger.warning(
f"Uncompliant message in room '{room}': {','.join(errors)}"
)
try:
message = MessageModel(**json.loads(data))
except ValidationError:
logger.error("Invalid message in websocket")
continue
logger.info(f"Message in room '{room}': {message}")
await notifier.notify(room, json.dumps(jsonable_encoder(message)))
Expand Down
56 changes: 10 additions & 46 deletions gischat/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from typing import Optional

from pydantic import BaseModel
from pydantic import BaseModel, Field


class VersionModel(BaseModel):
Expand All @@ -28,56 +28,20 @@ class RulesModel(BaseModel):


class MessageModel(BaseModel):
message: str
author: str
message: str = Field(
None, max_length=int(os.environ.get("MAX_MESSAGE_LENGTH", 255))
)
author: str = Field(
None,
min_length=int(os.environ.get("MIN_AUTHOR_LENGTH", 3)),
max_length=int(os.environ.get("MAX_AUTHOR_LENGTH", 32)),
pattern=r"^[a-z-A-Z-0-9-_]+$",
)
avatar: Optional[str] = None

def __str__(self) -> str:
return f"[{self.author}]: '{self.message}'"

def check_validity(self) -> tuple[bool, list[str]]:
"""
Checks if a message is compliant with the rules.
Rules:
- author must be alphanumeric
- author must have min length of 3
- message length must be max 255
:return: tuple with first element if message is ok,
second element with errors if any
"""
ok, errors = True, []

# check alphanum author
for c in self.author:
if not c.isalnum() and c not in ["-", "_"]:
ok = False
errors.append(f"Character not alphanumeric found in author: {c}")

# check author min length
min_author_length = int(os.environ.get("MIN_AUTHOR_LENGTH", 3))
if len(self.author) < min_author_length:
ok = False
errors.append(f"Author must have at least {min_author_length} characters")

# check author max length
max_author_length = int(os.environ.get("MAX_AUTHOR_LENGTH", 32))
if len(self.author) > max_author_length:
ok = False
errors.append(f"Author too long: max {max_author_length} characters")

# check message max length
max_message_length = int(os.environ.get("MAX_MESSAGE_LENGTH", 255))
if len(self.message) > max_message_length:
ok = False
errors.append(f"Message too long: max {max_message_length} characters")

return ok, errors


class MessageErrorModel(BaseModel):
message: str
errors: list[str]


class InternalMessageModel(BaseModel):
author: str
Expand Down
24 changes: 4 additions & 20 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,7 @@ def test_put_message_author_not_alphanum(client: TestClient, room: str):
f"/room/{room}/message",
json={"message": "fromage", "author": "<darth_chri$tian>"},
)
assert response.status_code == 420
payload = response.json()["detail"]
assert payload["message"] == "Uncompliant message"
assert "Character not alphanumeric found in author: <" in payload["errors"]
assert "Character not alphanumeric found in author: >" in payload["errors"]
assert "Character not alphanumeric found in author: $" in payload["errors"]
assert response.status_code == 422


@pytest.mark.parametrize("room", test_rooms())
Expand All @@ -90,12 +85,7 @@ def test_put_message_author_too_short(client: TestClient, room: str):
f"/room/{room}/message",
json={"message": "fromage", "author": "ch", "avatar": "postgis"},
)
assert response.status_code == 420
payload = response.json()["detail"]
assert payload["message"] == "Uncompliant message"
assert (
f"Author must have at least {MIN_AUTHOR_LENGTH} characters" in payload["errors"]
)
assert response.status_code == 422


@pytest.mark.parametrize("room", test_rooms())
Expand All @@ -105,10 +95,7 @@ def test_put_message_author_too_long(client: TestClient, room: str):
f"/room/{room}/message",
json={"message": "fromage", "author": author},
)
assert response.status_code == 420
payload = response.json()["detail"]
assert payload["message"] == "Uncompliant message"
assert f"Author too long: max {MAX_AUTHOR_LENGTH} characters" in payload["errors"]
assert response.status_code == 422


@pytest.mark.parametrize("room", test_rooms())
Expand All @@ -118,7 +105,4 @@ def test_put_message_too_long(client: TestClient, room: str):
f"/room/{room}/message",
json={"message": message, "author": "stephanie"},
)
assert response.status_code == 420
payload = response.json()["detail"]
assert payload["message"] == "Uncompliant message"
assert f"Message too long: max {MAX_MESSAGE_LENGTH} characters" in payload["errors"]
assert response.status_code == 422

0 comments on commit 8c59d07

Please sign in to comment.