Skip to content

Commit

Permalink
repeat
Browse files Browse the repository at this point in the history
  • Loading branch information
WSL0809 committed Apr 24, 2024
1 parent dfb42d3 commit 54a21de
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 5 deletions.
2 changes: 2 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ class RoomStatus(Enum):


class ClientStatus(Enum):
# 已入住
in_there = 0
# 出院
out = 1
# 手动创建
manual_create = 2
Expand Down
13 changes: 10 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from fastapi.security import OAuth2PasswordRequestForm
from jose import jwt, JWTError
from sqlalchemy.orm import Session
from schedules import repeat_every

import crud
import utils
Expand Down Expand Up @@ -74,7 +75,7 @@ def create_access_token(data: dict, expires_delta: Union[timedelta, None] = None


def get_current_user(
token: str = Depends(utils.oauth2_scheme), db: Session = Depends(get_db)
token: str = Depends(utils.oauth2_scheme), db: Session = Depends(get_db)
):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
Expand Down Expand Up @@ -105,7 +106,7 @@ async def get_current_active_user(current_user: User = Depends(get_current_user)

@app.post("/token")
async def login_for_access_token(
form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)
form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)
):
user = get_user(db, form_data.username)
access_token_expires = timedelta(minutes=utils.ACCESS_TOKEN_EXPIRE_MINUTES)
Expand Down Expand Up @@ -135,7 +136,7 @@ async def register_user(user: UserCreate, db: Session = Depends(get_db)):

@app.get("/get_users")
async def get_users(
current_user: User = Depends(get_current_active_user), db: Session = Depends(get_db)
current_user: User = Depends(get_current_active_user), db: Session = Depends(get_db)
):
if current_user.role == "admin" and current_user.is_active:
return crud.get_users(db)
Expand All @@ -150,3 +151,9 @@ async def get_users(
@app.get("/hello")
async def hello():
return {"hello"}


@app.on_event("startup")
@repeat_every(seconds=10)
async def test_hello():
print('hello')
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ cryptography==42.0.5
decorator==5.1.1
defusedxml==0.7.1
docopt==0.6.2
ecdsa==0.18.0
ecdsa==0.19.0
exceptiongroup==1.2.0
executing==2.0.1
fastapi==0.109.2
Expand Down Expand Up @@ -84,4 +84,4 @@ watchfiles==0.21.0
wcwidth==0.2.13
webencodings==0.5.1
websockets==12.0
yarg==0.1.9
yarg==0.1.9
1 change: 1 addition & 0 deletions schedules/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .decorator_timer import repeat_every
57 changes: 57 additions & 0 deletions schedules/decorator_timer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import asyncio
import logging
from asyncio import ensure_future
from functools import wraps
from traceback import format_exception
from typing import Any, Callable, Coroutine, Optional, Union

from starlette.concurrency import run_in_threadpool

NoArgsNoReturnFuncT = Callable[[], None]
NoArgsNoReturnAsyncFuncT = Callable[[], Coroutine[Any, Any, None]]
NoArgsNoReturnDecorator = Callable[[Union[NoArgsNoReturnFuncT, NoArgsNoReturnAsyncFuncT]], NoArgsNoReturnAsyncFuncT]


def repeat_every(
*,
seconds: float,
wait_first: bool = False,
logger: Optional[logging.Logger] = None,
raise_exceptions: bool = False,
max_repetitions: Optional[int] = None,
) -> NoArgsNoReturnDecorator:

def decorator(func: Union[NoArgsNoReturnAsyncFuncT, NoArgsNoReturnFuncT]) -> NoArgsNoReturnAsyncFuncT:
"""
Converts the decorated function into a repeated, periodically-called version of itself.
"""
is_coroutine = asyncio.iscoroutinefunction(func)

@wraps(func)
async def wrapped() -> None:
repetitions = 0

async def loop() -> None:
nonlocal repetitions
if wait_first:
await asyncio.sleep(seconds)
while max_repetitions is None or repetitions < max_repetitions:
try:
if is_coroutine:
await func() # type: ignore
else:
await run_in_threadpool(func)
repetitions += 1
except Exception as exc:
if logger is not None:
formatted_exception = "".join(format_exception(type(exc), exc, exc.__traceback__))
logger.error(formatted_exception)
if raise_exceptions:
raise exc
await asyncio.sleep(seconds)

ensure_future(loop())

return wrapped

return decorator

0 comments on commit 54a21de

Please sign in to comment.