Skip to content

Commit

Permalink
use add_imports everywhere (#3448)
Browse files Browse the repository at this point in the history
  • Loading branch information
Lendemor authored Jun 12, 2024
1 parent 991f6e0 commit 462b023
Show file tree
Hide file tree
Showing 40 changed files with 469 additions and 304 deletions.
13 changes: 7 additions & 6 deletions reflex/compiler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,14 @@
from reflex.state import BaseState, Cookie, LocalStorage
from reflex.style import Style
from reflex.utils import console, format, imports, path_ops
from reflex.utils.imports import ImportVar, ParsedImportDict
from reflex.vars import Var

# To re-export this function.
merge_imports = imports.merge_imports


def compile_import_statement(fields: list[imports.ImportVar]) -> tuple[str, list[str]]:
def compile_import_statement(fields: list[ImportVar]) -> tuple[str, list[str]]:
"""Compile an import statement.
Args:
Expand All @@ -59,7 +60,7 @@ def compile_import_statement(fields: list[imports.ImportVar]) -> tuple[str, list
return default, list(rest)


def validate_imports(import_dict: imports.ImportDict):
def validate_imports(import_dict: ParsedImportDict):
"""Verify that the same Tag is not used in multiple import.
Args:
Expand All @@ -82,7 +83,7 @@ def validate_imports(import_dict: imports.ImportDict):
used_tags[import_name] = lib


def compile_imports(import_dict: imports.ImportDict) -> list[dict]:
def compile_imports(import_dict: ParsedImportDict) -> list[dict]:
"""Compile an import dict.
Args:
Expand All @@ -91,7 +92,7 @@ def compile_imports(import_dict: imports.ImportDict) -> list[dict]:
Returns:
The list of import dict.
"""
collapsed_import_dict = imports.collapse_imports(import_dict)
collapsed_import_dict: ParsedImportDict = imports.collapse_imports(import_dict)
validate_imports(collapsed_import_dict)
import_dicts = []
for lib, fields in collapsed_import_dict.items():
Expand Down Expand Up @@ -231,7 +232,7 @@ def compile_client_storage(state: Type[BaseState]) -> dict[str, dict]:

def compile_custom_component(
component: CustomComponent,
) -> tuple[dict, imports.ImportDict]:
) -> tuple[dict, ParsedImportDict]:
"""Compile a custom component.
Args:
Expand All @@ -244,7 +245,7 @@ def compile_custom_component(
render = component.get_component(component)

# Get the imports.
imports = {
imports: ParsedImportDict = {
lib: fields
for lib, fields in render._get_all_imports().items()
if lib != component.library
Expand Down
27 changes: 14 additions & 13 deletions reflex/components/chakra/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
from typing import List, Literal

from reflex.components.component import Component
from reflex.utils import imports
from reflex.utils.imports import ImportDict, ImportVar
from reflex.vars import Var


class ChakraComponent(Component):
"""A component that wraps a Chakra component."""

library = "@chakra-ui/react@2.6.1"
library: str = "@chakra-ui/react@2.6.1" # type: ignore
lib_dependencies: List[str] = [
"@chakra-ui/system@2.5.7",
"framer-motion@10.16.4",
Expand All @@ -35,14 +35,14 @@ def _get_style(self) -> dict:

@classmethod
@lru_cache(maxsize=None)
def _get_dependencies_imports(cls) -> imports.ImportDict:
def _get_dependencies_imports(cls) -> ImportDict:
"""Get the imports from lib_dependencies for installing.
Returns:
The dependencies imports of the component.
"""
return {
dep: [imports.ImportVar(tag=None, render=False)]
dep: [ImportVar(tag=None, render=False)]
for dep in [
"@chakra-ui/system@2.5.7",
"framer-motion@10.16.4",
Expand Down Expand Up @@ -70,15 +70,16 @@ def create(cls) -> Component:
),
)

def _get_imports(self) -> imports.ImportDict:
_imports = super()._get_imports()
_imports.setdefault(self.__fields__["library"].default, []).append(
imports.ImportVar(tag="extendTheme", is_default=False),
)
_imports.setdefault("/utils/theme.js", []).append(
imports.ImportVar(tag="theme", is_default=True),
)
return _imports
def add_imports(self) -> ImportDict:
"""Add imports for the ChakraProvider component.
Returns:
The import dict for the component.
"""
return {
self.library: ImportVar(tag="extendTheme", is_default=False),
"/utils/theme.js": ImportVar(tag="theme", is_default=True),
}

@staticmethod
@lru_cache(maxsize=None)
Expand Down
3 changes: 2 additions & 1 deletion reflex/components/chakra/base.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ from reflex.style import Style
from functools import lru_cache
from typing import List, Literal
from reflex.components.component import Component
from reflex.utils import imports
from reflex.utils.imports import ImportDict, ImportVar
from reflex.vars import Var

class ChakraComponent(Component):
Expand Down Expand Up @@ -155,6 +155,7 @@ class ChakraProvider(ChakraComponent):
A new ChakraProvider component.
"""
...
def add_imports(self) -> ImportDict: ...

chakra_provider = ChakraProvider.create()

Expand Down
14 changes: 8 additions & 6 deletions reflex/components/chakra/forms/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from reflex.components.core.debounce import DebounceInput
from reflex.components.literals import LiteralInputType
from reflex.constants import EventTriggers, MemoizationMode
from reflex.utils import imports
from reflex.utils.imports import ImportDict
from reflex.vars import Var


Expand Down Expand Up @@ -59,11 +59,13 @@ class Input(ChakraComponent):
# The name of the form field
name: Var[str]

def _get_imports(self) -> imports.ImportDict:
return imports.merge_imports(
super()._get_imports(),
{"/utils/state": {imports.ImportVar(tag="set_val")}},
)
def add_imports(self) -> ImportDict:
"""Add imports for the Input component.
Returns:
The import dict.
"""
return {"/utils/state": "set_val"}

def get_event_triggers(self) -> Dict[str, Any]:
"""Get the event triggers that pass the component's value to the handler.
Expand Down
3 changes: 2 additions & 1 deletion reflex/components/chakra/forms/input.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ from reflex.components.component import Component
from reflex.components.core.debounce import DebounceInput
from reflex.components.literals import LiteralInputType
from reflex.constants import EventTriggers, MemoizationMode
from reflex.utils import imports
from reflex.utils.imports import ImportDict
from reflex.vars import Var

class Input(ChakraComponent):
def add_imports(self) -> ImportDict: ...
def get_event_triggers(self) -> Dict[str, Any]: ...
@overload
@classmethod
Expand Down
11 changes: 8 additions & 3 deletions reflex/components/chakra/navigation/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from reflex.components.chakra import ChakraComponent
from reflex.components.component import Component
from reflex.components.next.link import NextLink
from reflex.utils import imports
from reflex.utils.imports import ImportDict
from reflex.vars import BaseVar, Var

next_link = NextLink.create()
Expand Down Expand Up @@ -32,8 +32,13 @@ class Link(ChakraComponent):
# If true, the link will open in new tab.
is_external: Var[bool]

def _get_imports(self) -> imports.ImportDict:
return {**super()._get_imports(), **next_link._get_imports()}
def add_imports(self) -> ImportDict:
"""Add imports for the link component.
Returns:
The import dict.
"""
return next_link._get_imports() # type: ignore

@classmethod
def create(cls, *children, **props) -> Component:
Expand Down
3 changes: 2 additions & 1 deletion reflex/components/chakra/navigation/link.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@ from reflex.style import Style
from reflex.components.chakra import ChakraComponent
from reflex.components.component import Component
from reflex.components.next.link import NextLink
from reflex.utils import imports
from reflex.utils.imports import ImportDict
from reflex.vars import BaseVar, Var

next_link = NextLink.create()

class Link(ChakraComponent):
def add_imports(self) -> ImportDict: ...
@overload
@classmethod
def create( # type: ignore
Expand Down
50 changes: 20 additions & 30 deletions reflex/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
)
from reflex.style import Style, format_as_emotion
from reflex.utils import console, format, imports, types
from reflex.utils.imports import ImportVar
from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
from reflex.utils.serializers import serializer
from reflex.vars import BaseVar, Var, VarData

Expand Down Expand Up @@ -95,7 +95,7 @@ def _get_all_hooks(self) -> dict[str, None]:
"""

@abstractmethod
def _get_all_imports(self) -> imports.ImportDict:
def _get_all_imports(self) -> ParsedImportDict:
"""Get all the libraries and fields that are used by the component.
Returns:
Expand Down Expand Up @@ -213,7 +213,7 @@ class Component(BaseComponent, ABC):
# State class associated with this component instance
State: Optional[Type[reflex.state.State]] = None

def add_imports(self) -> dict[str, str | ImportVar | list[str | ImportVar]]:
def add_imports(self) -> ImportDict | list[ImportDict]:
"""Add imports for the component.
This method should be implemented by subclasses to add new imports for the component.
Expand Down Expand Up @@ -1224,7 +1224,7 @@ def _get_all_dynamic_imports(self) -> Set[str]:
# Return the dynamic imports
return dynamic_imports

def _get_props_imports(self) -> List[str]:
def _get_props_imports(self) -> List[ParsedImportDict]:
"""Get the imports needed for components props.
Returns:
Expand All @@ -1250,7 +1250,7 @@ def _should_transpile(self, dep: str | None) -> bool:
or format.format_library_name(dep or "") in self.transpile_packages
)

def _get_dependencies_imports(self) -> imports.ImportDict:
def _get_dependencies_imports(self) -> ParsedImportDict:
"""Get the imports from lib_dependencies for installing.
Returns:
Expand All @@ -1267,7 +1267,7 @@ def _get_dependencies_imports(self) -> imports.ImportDict:
for dep in self.lib_dependencies
}

def _get_hooks_imports(self) -> imports.ImportDict:
def _get_hooks_imports(self) -> ParsedImportDict:
"""Get the imports required by certain hooks.
Returns:
Expand Down Expand Up @@ -1308,7 +1308,7 @@ def _get_hooks_imports(self) -> imports.ImportDict:

return imports.merge_imports(_imports, *other_imports)

def _get_imports(self) -> imports.ImportDict:
def _get_imports(self) -> ParsedImportDict:
"""Get all the libraries and fields that are used by the component.
Returns:
Expand All @@ -1328,25 +1328,15 @@ def _get_imports(self) -> imports.ImportDict:
var._var_data.imports for var in self._get_vars() if var._var_data
]

# If any subclass implements add_imports, merge the imports.
def _make_list(
value: str | ImportVar | list[str | ImportVar],
) -> list[str | ImportVar]:
if isinstance(value, (str, ImportVar)):
return [value]
return value

_added_import_dicts = []
added_import_dicts: list[ParsedImportDict] = []
for clz in self._iter_parent_classes_with_method("add_imports"):
_added_import_dicts.append(
{
package: [
ImportVar(tag=tag) if not isinstance(tag, ImportVar) else tag
for tag in _make_list(maybe_tags)
]
for package, maybe_tags in clz.add_imports(self).items()
}
)
list_of_import_dict = clz.add_imports(self)

if not isinstance(list_of_import_dict, list):
list_of_import_dict = [list_of_import_dict]

for import_dict in list_of_import_dict:
added_import_dicts.append(parse_imports(import_dict))

return imports.merge_imports(
*self._get_props_imports(),
Expand All @@ -1355,10 +1345,10 @@ def _make_list(
_imports,
event_imports,
*var_imports,
*_added_import_dicts,
*added_import_dicts,
)

def _get_all_imports(self, collapse: bool = False) -> imports.ImportDict:
def _get_all_imports(self, collapse: bool = False) -> ParsedImportDict:
"""Get all the libraries and fields that are used by the component and its children.
Args:
Expand Down Expand Up @@ -1453,7 +1443,7 @@ def _get_hooks_internal(self) -> dict[str, None]:
**self._get_special_hooks(),
}

def _get_added_hooks(self) -> dict[str, imports.ImportDict]:
def _get_added_hooks(self) -> dict[str, ImportDict]:
"""Get the hooks added via `add_hooks` method.
Returns:
Expand Down Expand Up @@ -1842,7 +1832,7 @@ def wrapper(*children, **props) -> CustomComponent:
class NoSSRComponent(Component):
"""A dynamic component that is not rendered on the server."""

def _get_imports(self) -> imports.ImportDict:
def _get_imports(self) -> ParsedImportDict:
"""Get the imports for the component.
Returns:
Expand Down Expand Up @@ -2185,7 +2175,7 @@ def _get_all_hooks(self) -> dict[str, None]:
"""
return {}

def _get_all_imports(self) -> imports.ImportDict:
def _get_all_imports(self) -> ParsedImportDict:
"""Get all the libraries and fields that are used by the component.
Returns:
Expand Down
Loading

0 comments on commit 462b023

Please sign in to comment.