Skip to content

Commit

Permalink
chore(typing): add some more typing to frappe.__init__ (frappe#28215)
Browse files Browse the repository at this point in the history
  • Loading branch information
blaggacao authored Oct 21, 2024
1 parent 91a737d commit 8d6f8bc
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 62 deletions.
121 changes: 61 additions & 60 deletions frappe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,24 @@
import traceback
import warnings
from collections import defaultdict
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Literal, Optional, TypeAlias, overload
from collections.abc import Callable, Iterable
from typing import (
TYPE_CHECKING,
Any,
Generic,
Literal,
Optional,
TypeAlias,
TypeVar,
Union,
overload,
)

import click
from werkzeug.local import Local, release_local
from werkzeug.local import Local, LocalProxy, release_local

import frappe
from frappe.query_builder import (
from frappe.query_builder.utils import (
get_query,
get_query_builder,
patch_query_aggregation,
Expand All @@ -55,41 +65,26 @@
__version__ = "16.0.0-dev"
__title__ = "Frappe Framework"

# This if block is never executed when running the code. It is only used for
# telling static code analyzer where to find dynamically defined attributes.
if TYPE_CHECKING: # pragma: no cover
from logging import Logger
from types import ModuleType

from werkzeug.wrappers import Request

from frappe.database.mariadb.database import MariaDBDatabase
from frappe.database.postgres.database import PostgresDatabase
from frappe.email.doctype.email_queue.email_queue import EmailQueue
from frappe.model.document import Document
from frappe.query_builder.builder import MariaDB, Postgres
from frappe.types.lazytranslatedstring import _LazyTranslate
from frappe.utils.redis_wrapper import RedisWrapper

db: MariaDBDatabase | PostgresDatabase
qb: MariaDB | Postgres
cache: RedisWrapper
response: _dict
conf: _dict
form_dict: _dict
flags: _dict
request: Request
session: _dict
user: str
flags: _dict
lang: str


# end: static analysis hack


controllers = {}
controllers: dict[str, "Document"] = {}
local = Local()
cache = None
cache: Optional["RedisWrapper"] = None
STANDARD_USERS = ("Guest", "Administrator")

_qb_patched = {}
_qb_patched: dict[str, bool] = {}
_dev_server = int(sbool(os.environ.get("DEV_SERVER", False)))
_tune_gc = bool(sbool(os.environ.get("FRAPPE_TUNE_GC", True)))

Expand Down Expand Up @@ -134,7 +129,7 @@ def _(msg: str, lang: str | None = None, context: str | None = None) -> str:
return translated_string or non_translated_string


def _lt(msg: str, lang: str | None = None, context: str | None = None):
def _lt(msg: str, lang: str | None = None, context: str | None = None) -> "_LazyTranslate":
"""Lazily translate a string.
Expand Down Expand Up @@ -171,26 +166,30 @@ def set_user_lang(user: str, user_language: str | None = None) -> None:


# local-globals

db = local("db")
qb = local("qb")
conf = local("conf")
form = form_dict = local("form_dict")
request = local("request")
db: LocalProxy[Union["MariaDBDatabase", "PostgresDatabase"]] = local("db")
qb: LocalProxy[Union["MariaDB", "Postgres"]] = local("qb")
conf: LocalProxy[_dict[str, Any]] = local("conf") # type: ignore[no-any-explicit]
form_dict: LocalProxy[_dict[str, str]] = local("form_dict")
form = form_dict
request: LocalProxy["Request"] = local("request")
job = local("job")
response = local("response")
session = local("session")
user = local("user")
flags = local("flags")
response: LocalProxy[_dict[str, Any]] = local("response") # type: ignore[no-any-explicit]
# TODO: make session a dataclass instead of undtyped _dict
SettionType = _dict[str, Any]
session: LocalProxy[SettionType] = local("session") # type: ignore[no-any-explicit]
user: LocalProxy[str] = local("user")
flags: LocalProxy[_dict[str, Any]] = local("flags") # type: ignore[no-any-explicit]

error_log = local("error_log")
debug_log = local("debug_log")
message_log = local("message_log")
error_log: LocalProxy[list[dict[str, str]]] = local("error_log")
debug_log: LocalProxy[list[str]] = local("debug_log")
# TODO: implement dataclass
LogMessageType = _dict[str, Any]
message_log: LocalProxy[list[LogMessageType]] = local("message_log")

lang = local("lang")
lang: LocalProxy[str] = local("lang")


def init(site: str, sites_path: str = ".", new_site: bool = False, force=False) -> None:
def init(site: str, sites_path: str = ".", new_site: bool = False, force: bool = False) -> None:
"""Initialize frappe for the current site. Reset thread locals `frappe.local`"""
if getattr(local, "initialised", None) and not force:
return
Expand All @@ -214,7 +213,7 @@ def init(site: str, sites_path: str = ".", new_site: bool = False, force=False)
"read_only": False,
}
)
local.locked_documents = []
local.locked_documents: list["Document"] = []
local.test_objects = defaultdict(list)

local.site = site
Expand Down Expand Up @@ -308,7 +307,7 @@ def connect(site: str | None = None, db_name: str | None = None, set_admin_as_us
def connect_replica() -> bool:
from frappe.database import get_db

if local and hasattr(local, "replica_db") and hasattr(local, "primary_db"):
if hasattr(local, "replica_db") and hasattr(local, "primary_db"):
return False

user = local.conf.db_user
Expand All @@ -335,10 +334,10 @@ def connect_replica() -> bool:
return True


def get_site_config(sites_path: str | None = None, site_path: str | None = None) -> dict[str, Any]:
def get_site_config(sites_path: str | None = None, site_path: str | None = None) -> _dict[str, Any]:
"""Return `site_config.json` combined with `sites/common_site_config.json`.
`site_config` is a set of site wide settings like database name, password, email etc."""
config = _dict()
config: _dict[str, Any] = _dict()

sites_path = sites_path or getattr(local, "sites_path", None)
site_path = site_path or getattr(local, "site_path", None)
Expand Down Expand Up @@ -417,7 +416,7 @@ def db_default_ports(db_type):
return config


def get_common_site_config(sites_path: str | None = None) -> dict[str, Any]:
def get_common_site_config(sites_path: str | None = None) -> _dict[str, Any]:
"""Return common site config as dictionary.
This is useful for:
Expand All @@ -436,7 +435,7 @@ def get_common_site_config(sites_path: str | None = None) -> dict[str, Any]:
return _dict()


def get_conf(site: str | None = None) -> dict[str, Any]:
def get_conf(site: str | None = None) -> _dict[str, Any]:
if hasattr(local, "conf"):
return local.conf

Expand Down Expand Up @@ -838,10 +837,10 @@ def sendmail(
return builder.process(send_now=now)


whitelisted = set()
guest_methods = set()
xss_safe_methods = set()
allowed_http_methods_for_whitelisted_func = {}
whitelisted: set[Callable] = set()
guest_methods: set[Callable] = set()
xss_safe_methods: set[Callable] = set()
allowed_http_methods_for_whitelisted_func: dict[Callable, list[str]] = {}


def whitelist(allow_guest=False, xss_safe=False, methods=None):
Expand Down Expand Up @@ -924,7 +923,7 @@ def wrapper_fn(*args, **kwargs):
try:
retval = fn(*args, **get_newargs(fn, kwargs))
finally:
if switched_connection and local and hasattr(local, "primary_db"):
if switched_connection and hasattr(local, "primary_db"):
local.db.close()
local.db = local.primary_db

Expand Down Expand Up @@ -1208,7 +1207,7 @@ def set_value(doctype, docname, fieldname, value=None):
return frappe.client.set_value(doctype, docname, fieldname, value)


def get_cached_doc(*args, **kwargs) -> "Document":
def get_cached_doc(*args: Any, **kwargs: Any) -> "Document":
"""Identical to `frappe.get_doc`, but return from cache if available."""
if (key := can_cache_doc(args)) and (doc := cache.get_value(key)):
return doc
Expand Down Expand Up @@ -1269,7 +1268,9 @@ def clear_in_redis():
delattr(local, "website_settings")


def get_cached_value(doctype: str, name: str, fieldname: str = "name", as_dict: bool = False) -> Any:
def get_cached_value(
doctype: str, name: str, fieldname: str | Iterable[str] = "name", as_dict: bool = False
) -> Any:
try:
doc = get_cached_doc(doctype, name)
except DoesNotExistError:
Expand Down Expand Up @@ -1322,7 +1323,7 @@ def get_doc(documentdict: dict) -> "_NewDocument":
pass


def get_doc(*args, **kwargs):
def get_doc(*args: Any, **kwargs: Any) -> "Document":
"""Return a `frappe.model.document.Document` object of the given type and name.
:param arg1: DocType name as string **or** document JSON.
Expand Down Expand Up @@ -1481,7 +1482,7 @@ def rename_doc(
)


def get_module(modulename):
def get_module(modulename: str) -> "ModuleType":
"""Return a module object for given Python module name using `importlib.import_module`."""
return importlib.import_module(modulename)

Expand Down Expand Up @@ -1570,7 +1571,7 @@ def get_all_apps(with_internal_apps=True, sites_path=None):


@request_cache
def get_installed_apps(*, _ensure_on_bench=False) -> list[str]:
def get_installed_apps(*, _ensure_on_bench: bool = False) -> list[str]:
"""
Get list of installed apps in current site.
Expand Down Expand Up @@ -2342,8 +2343,8 @@ def _get_doctype_app():
return local_cache("doctype_app", doctype, generator=_get_doctype_app)


loggers = {}
log_level = None
loggers: dict[str, "Logger"] = {}
log_level: int | None = None


def logger(module=None, with_more_info=False, allow_site=True, filter=None, max_size=100_000, file_count=20):
Expand Down
2 changes: 1 addition & 1 deletion frappe/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def remove_blanks(d: dict) -> dict:
return d


def strip_html_tags(text):
def strip_html_tags(text: str) -> str:
"""Remove html tags from the given `text`."""
return HTML_TAGS_PATTERN.sub("", text)

Expand Down
2 changes: 1 addition & 1 deletion frappe/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1160,7 +1160,7 @@ def cstr(s, encoding="utf-8") -> str:
return frappe.as_unicode(s, encoding)


def sbool(x: str) -> bool | Any:
def sbool(x: str | Any) -> bool | str | Any:
"""Convert str object to Boolean if possible.
Example:
Expand Down
Loading

0 comments on commit 8d6f8bc

Please sign in to comment.