Skip to content

Commit

Permalink
Merge pull request #168 from Ed-XCF/feature/hybrid-property-to-dict
Browse files Browse the repository at this point in the history
Feature/hybrid property to dict
  • Loading branch information
JoshYuJump authored Feb 27, 2023
2 parents 59fe451 + c500dc1 commit cf61b9b
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 11 deletions.
65 changes: 65 additions & 0 deletions bali/db/session_control.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import functools
from typing import Tuple, TypeVar

from bali.db import db
from sqlalchemy import inspect
from sqlalchemy.exc import InvalidRequestError

__all__ = ["merge_session", "remove_session", "automatic_session_property"]

T = TypeVar("T")


def merge_session(obj: T) -> Tuple[T, bool]:
if inspect(obj).transient:
return obj, False

if inspect(obj).session is not None:
return obj, False

try:
return db.session.merge(obj, load=False), True
except InvalidRequestError:
return db.session.merge(obj), True


def remove_session(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
finally:
db.session.remove()

return wrapper


class automatic_session_property: # noqa
def __init__(self, func):
self._func = func
self._copy_func_info()

def _copy_func_info(self):
if self._func is None:
return

for member_name in [
"__doc__",
"__name__",
"__module__",
]:
value = getattr(self._func, member_name)
setattr(self, member_name, value)

def __get__(self, instance, _):
if instance is None:
return self

this, is_merged = merge_session(instance)
try:
return self._func(this)
finally:
is_merged and db.session.remove()

def __call__(self, func):
return self.__get__(func, None)
50 changes: 50 additions & 0 deletions bali/tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import asyncio
import logging
from functools import wraps
from importlib import import_module
from typing import Callable

from fastapi import FastAPI

__all__ = ["add_background_tasks", "run_every"]

logger = logging.getLogger('bali')
background_tasks = set()


def add_background_tasks(app: FastAPI, module="tasks"):
import_module(module)
on_startup = app.on_event("startup")
for i in background_tasks:
logger.info("Find Task %s", i.__name__)
on_startup(i)


def run_every(seconds: float):
def decorator(func: Callable):
@wraps(func)
def wrapped():
async def task():
task_name = func.__name__
is_coroutine = asyncio.iscoroutinefunction(func)

while True:
logger.info("Start task %s", task_name)
try:
if is_coroutine:
await func()
else:
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, func)
except Exception as e:
logger.error("Task %s raise an error: %s", task_name, repr(e))
else:
logger.info("Task %s done", task_name)
finally:
await asyncio.sleep(seconds)

asyncio.ensure_future(task())

background_tasks.add(wrapped)

return decorator
52 changes: 41 additions & 11 deletions bali/utils/timezone.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import calendar
import os
from datetime import datetime, date, timedelta
from datetime import datetime, date, timedelta, time
from typing import Union
from typing_extensions import Literal

import pytz
from dateutil.relativedelta import relativedelta

TzInfoType = Union[type(pytz.utc), pytz.tzinfo.DstTzInfo]
StrTzInfoType = Union[TzInfoType, str]
Expand Down Expand Up @@ -80,22 +82,50 @@ def localdate(value: datetime = None, timezone: StrTzInfoType = None) -> date:
return localtime(value, timezone).date()


GRANULARITY = Literal["year", "month", "week", "day"]


def start_of(
granularity: str,
value: datetime = None,
*,
timezone: StrTzInfoType = None,
granularity: GRANULARITY,
value: datetime = None,
*,
timezone: StrTzInfoType = None,
) -> datetime:
value = localtime(value, timezone)
value = localtime(value, timezone=timezone)
if granularity == "year":
value = value.replace(month=1, day=1)
result = value.replace(month=1, day=1)
elif granularity == "month":
value = value.replace(day=1)
result = value.replace(day=1)
elif granularity == "week":
value = value - timedelta(days=calendar.weekday(value.year, value.month, value.day))
result = value - timedelta(
days=calendar.weekday(value.year, value.month, value.day)
)
elif granularity == "day":
pass
result = value
else:
raise ValueError("Granularity must be year, month, week or day")

return make_aware(datetime.combine(result, time.min))


def end_of(
granularity: GRANULARITY,
value: datetime = None,
*,
timezone: StrTzInfoType = None,
) -> datetime:
value = localtime(value, timezone=timezone)
if granularity == "year":
result = value.replace(month=12, day=31)
elif granularity == "month":
result = value + relativedelta(day=1, months=1, days=-1)
elif granularity == "week":
result = value + timedelta(
days=6 - calendar.weekday(value.year, value.month, value.day)
)
elif granularity == "day":
result = value
else:
raise ValueError("Granularity must be year, month, week or day")

return value.replace(hour=0, minute=0, second=0, microsecond=0)
return make_aware(datetime.combine(result, time.max))

0 comments on commit cf61b9b

Please sign in to comment.