Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add more type annotations through the code #4401

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,14 @@ build-backend = "poetry.core.masonry.api"
target-version = "py39"
output-format = "concise"
lint.isort.split-on-trailing-comma = false
lint.select = ["B", "C4", "D", "E", "ERA", "F", "FURB", "I", "PERF", "PTH", "RUF", "SIM", "W"]
lint.select = ["ANN001", "B", "C4", "D", "E", "ERA", "F", "FURB", "I", "PERF", "PTH", "RUF", "SIM", "W"]
lint.ignore = ["B008", "D205", "E501", "F403", "SIM115", "RUF006", "RUF012"]
lint.pydocstyle.convention = "google"

[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401"]
"tests/*.py" = ["D100", "D103", "D104", "B018", "PERF"]
"tests/*.py" = ["ANN001", "D100", "D103", "D104", "B018", "PERF"]
"benchmarks/*.py" = ["ANN001"]
"reflex/.templates/*.py" = ["D100", "D103", "D104"]
"*.pyi" = ["D301", "D415", "D417", "D418", "E742"]
"*/blank.py" = ["I001"]
Expand Down
5 changes: 4 additions & 1 deletion reflex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@

from __future__ import annotations

from types import ModuleType
from typing import Any

from reflex.utils import (
compat, # for side-effects
lazy_loader,
Expand Down Expand Up @@ -366,7 +369,7 @@
)


def __getattr__(name):
def __getattr__(name: ModuleType | Any):
if name == "chakra":
from reflex.utils import console

Expand Down
29 changes: 14 additions & 15 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,8 +650,8 @@ def add_custom_404_page(
Args:
component: The component to display at the page.
title: The title of the page.
description: The description of the page.
image: The image to display on the page.
description: The description of the page.
on_load: The event handler(s) that will be called each time the page load.
meta: The metadata of the page.
"""
Expand Down Expand Up @@ -989,7 +989,7 @@ def get_compilation_time() -> str:
with executor:
result_futures = []

def _submit_work(fn, *args, **kwargs):
def _submit_work(fn: Callable, *args, **kwargs):
f = executor.submit(fn, *args, **kwargs)
result_futures.append(f)

Expand Down Expand Up @@ -1319,15 +1319,14 @@ async def process(
if app._process_background(state, event) is not None:
# `final=True` allows the frontend send more events immediately.
yield StateUpdate(final=True)
return

# Process the event synchronously.
async for update in state._process(event):
# Postprocess the event.
update = await app._postprocess(state, event, update)

# Yield the update.
yield update
else:
# Process the event synchronously.
async for update in state._process(event):
# Postprocess the event.
update = await app._postprocess(state, event, update)

# Yield the update.
yield update
except Exception as ex:
telemetry.send_error(ex, context="backend")

Expand Down Expand Up @@ -1520,7 +1519,7 @@ def __init__(self, namespace: str, app: App):
self.sid_to_token = {}
self.app = app

def on_connect(self, sid, environ):
def on_connect(self, sid: str, environ: dict):
"""Event for when the websocket is connected.

Args:
Expand All @@ -1529,7 +1528,7 @@ def on_connect(self, sid, environ):
"""
pass

def on_disconnect(self, sid):
def on_disconnect(self, sid: str):
"""Event for when the websocket disconnects.

Args:
Expand All @@ -1551,7 +1550,7 @@ async def emit_update(self, update: StateUpdate, sid: str) -> None:
self.emit(str(constants.SocketEvent.EVENT), update, to=sid)
)

async def on_event(self, sid, data):
async def on_event(self, sid: str, data: Any):
"""Event for receiving front-end websocket events.

Raises:
Expand Down Expand Up @@ -1594,7 +1593,7 @@ async def on_event(self, sid, data):
# Emit the update from processing the event.
await self.emit_update(update=update, sid=sid)

async def on_ping(self, sid):
async def on_ping(self, sid: str):
"""Event for testing the API endpoint.

Args:
Expand Down
2 changes: 1 addition & 1 deletion reflex/app_mixins/lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs):

Args:
task: The task to register.
task_kwargs: The kwargs of the task.
**task_kwargs: The kwargs of the task.

Raises:
InvalidLifespanTaskType: If the task is a generator function.
Expand Down
2 changes: 1 addition & 1 deletion reflex/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def json(self) -> str:
default=serialize,
)

def set(self, **kwargs):
def set(self, **kwargs: Any):
"""Set multiple fields and return the object.

Args:
Expand Down
2 changes: 1 addition & 1 deletion reflex/compiler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ def empty_dir(path: str | Path, keep_files: list[str] | None = None):
path_ops.rm(element)


def is_valid_url(url) -> bool:
def is_valid_url(url: str) -> bool:
"""Check if a url is valid.

Args:
Expand Down
8 changes: 4 additions & 4 deletions reflex/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,7 @@ def create(cls, *children, **props) -> Component:
# Filter out None props
props = {key: value for key, value in props.items() if value is not None}

def validate_children(children):
def validate_children(children: tuple):
for child in children:
if isinstance(child, tuple):
validate_children(child)
Expand Down Expand Up @@ -956,7 +956,7 @@ def _get_style(self) -> dict:
else {}
)

def render(self) -> Dict:
def render(self) -> dict:
"""Render the component.

Returns:
Expand All @@ -974,7 +974,7 @@ def render(self) -> Dict:
self._replace_prop_names(rendered_dict)
return rendered_dict

def _replace_prop_names(self, rendered_dict) -> None:
def _replace_prop_names(self, rendered_dict: dict) -> None:
"""Replace the prop names in the render dictionary.

Args:
Expand Down Expand Up @@ -1014,7 +1014,7 @@ def _validate_component_children(self, children: List[Component]):
comp.__name__ for comp in (Fragment, Foreach, Cond, Match)
]

def validate_child(child):
def validate_child(child: Any):
child_name = type(child).__name__

# Iterate through the immediate children of fragment
Expand Down
2 changes: 1 addition & 1 deletion reflex/components/core/client_side_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def render(self) -> str:
return ""


def wait_for_client_redirect(component) -> Component:
def wait_for_client_redirect(component: Component) -> Component:
"""Wait for a redirect to occur before rendering a component.

This prevents the 404 page from flashing while the redirect is happening.
Expand Down
2 changes: 1 addition & 1 deletion reflex/components/core/client_side_routing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class ClientSideRouting(Component):
"""
...

def wait_for_client_redirect(component) -> Component: ...
def wait_for_client_redirect(component: Component) -> Component: ...

class Default404Page(Component):
@overload
Expand Down
2 changes: 1 addition & 1 deletion reflex/components/core/cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def cond(condition: Any, c1: Any, c2: Any = None) -> Component | Var:
if c2 is None:
raise ValueError("For conditional vars, the second argument must be set.")

def create_var(cond_part):
def create_var(cond_part: Any) -> Var[Any]:
return LiteralVar.create(cond_part)

# convert the truth and false cond parts into vars so the _var_data can be obtained.
Expand Down
2 changes: 1 addition & 1 deletion reflex/components/core/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def _process_cases(
return cases, default

@classmethod
def _create_case_var_with_var_data(cls, case_element):
def _create_case_var_with_var_data(cls, case_element: Any) -> Var:
"""Convert a case element into a Var.If the case
is a Style type, we extract the var data and merge it with the
newly created Var.
Expand Down
6 changes: 2 additions & 4 deletions reflex/components/datadisplay/logo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,8 @@ def svg_logo(color: Union[str, rx.Var[str]] = rx.color_mode_cond("#110F1F", "whi
The Reflex logo SVG.
"""

def logo_path(d):
return rx.el.svg.path(
d=d,
)
def logo_path(d: str):
return rx.el.svg.path(d=d)

paths = [
"M0 11.5999V0.399902H8.96V4.8799H6.72V2.6399H2.24V4.8799H6.72V7.1199H2.24V11.5999H0ZM6.72 11.5999V7.1199H8.96V11.5999H6.72Z",
Expand Down
2 changes: 1 addition & 1 deletion reflex/components/el/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
class Element(Component):
"""The base class for all raw HTML elements."""

def __eq__(self, other):
def __eq__(self, other: object):
"""Two elements are equal if they have the same tag.

Args:
Expand Down
8 changes: 5 additions & 3 deletions reflex/components/markdown/markdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from hashlib import md5
from typing import Any, Callable, Dict, Sequence, Union

from reflex.components.component import Component, CustomComponent
from reflex.components.component import BaseComponent, Component, CustomComponent
from reflex.components.tags.tag import Tag
from reflex.utils import types
from reflex.utils.imports import ImportDict, ImportVar
Expand Down Expand Up @@ -379,7 +379,9 @@ def _get_map_fn_var_from_children(self, component: Component, tag: str) -> Var:
# fallback to the default fn Var creation if the component is not a MarkdownComponentMap.
return MarkdownComponentMap.create_map_fn_var(fn_body=formatted_component)

def _get_map_fn_custom_code_from_children(self, component) -> list[str]:
def _get_map_fn_custom_code_from_children(
self, component: BaseComponent
) -> list[str]:
"""Recursively get markdown custom code from children components.

Args:
Expand Down Expand Up @@ -409,7 +411,7 @@ def _get_map_fn_custom_code_from_children(self, component) -> list[str]:
return custom_code_list

@staticmethod
def _component_map_hash(component_map) -> str:
def _component_map_hash(component_map: dict) -> str:
inp = str(
{tag: component(_MOCK_ARG) for tag, component in component_map.items()}
).encode()
Expand Down
4 changes: 3 additions & 1 deletion reflex/components/next/image.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Image component from next/image."""

from __future__ import annotations

from typing import Any, Literal, Optional, Union

from reflex.event import EventHandler, no_args_event_spec
Expand Down Expand Up @@ -83,7 +85,7 @@ def create(
style = props.get("style", {})
DEFAULT_W_H = "100%"

def check_prop_type(prop_name, prop_value):
def check_prop_type(prop_name: str, prop_value: int | str | None):
if types.check_prop_in_allowed_types(prop_value, allowed_types=[int]):
props[prop_name] = prop_value

Expand Down
2 changes: 1 addition & 1 deletion reflex/components/props.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def dict(self, *args, **kwargs):
class NoExtrasAllowedProps(Base):
"""A class that holds props to be passed or applied to a component with no extra props allowed."""

def __init__(self, component_name=None, **kwargs):
def __init__(self, component_name: str | None = None, **kwargs):
"""Initialize the props.

Args:
Expand Down
12 changes: 7 additions & 5 deletions reflex/components/radix/themes/color_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from __future__ import annotations

from typing import Dict, List, Literal, Optional, Union, get_args
from typing import Any, Dict, List, Literal, Optional, Union, get_args

from reflex.components.component import BaseComponent
from reflex.components.core.cond import Cond, color_mode_cond, cond
Expand Down Expand Up @@ -78,17 +78,19 @@ def create(


# needed to inverse contains for find
def _find(const: List[str], var):
def _find(const: List[str], var: Any):
return LiteralArrayVar.create(const).contains(var)


def _set_var_default(props, position, prop, default1, default2=""):
def _set_var_default(
props: dict, position: Any, prop: str, default1: str, default2: str = ""
):
props.setdefault(
prop, cond(_find(position_map[prop], position), default1, default2)
)


def _set_static_default(props, position, prop, default):
def _set_static_default(props: dict, position: Any, prop: str, default: str):
if prop in position:
props.setdefault(prop, default)

Expand Down Expand Up @@ -142,7 +144,7 @@ def create(

if allow_system:

def color_mode_item(_color_mode):
def color_mode_item(_color_mode: str):
return dropdown_menu.item(
_color_mode.title(), on_click=set_color_mode(_color_mode)
)
Expand Down
2 changes: 1 addition & 1 deletion reflex/components/radix/themes/layout/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ class List(ComponentNamespace):
unordered_list = list_ns.unordered


def __getattr__(name):
def __getattr__(name: Any):
# special case for when accessing list to avoid shadowing
# python's built in list object.
if name == "list":
Expand Down
2 changes: 1 addition & 1 deletion reflex/components/recharts/charts.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def _ensure_valid_dimension(name: str, value: Any) -> None:
)

@classmethod
def create(cls, *children, **props) -> Component:
def create(cls, *children: Any, **props: Any) -> Component:
"""Create a chart component.

Args:
Expand Down
Loading
Loading