Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use add_imports everywhere #3448

Merged
merged 10 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions reflex/compiler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,14 @@
from reflex.state import BaseState, Cookie, LocalStorage
from reflex.style import Style
from reflex.utils import console, format, imports, path_ops
from reflex.utils.imports import ImportVar, ParsedImportDict
from reflex.vars import Var

# To re-export this function.
merge_imports = imports.merge_imports


def compile_import_statement(fields: list[imports.ImportVar]) -> tuple[str, list[str]]:
def compile_import_statement(fields: list[ImportVar]) -> tuple[str, list[str]]:
"""Compile an import statement.

Args:
Expand All @@ -59,7 +60,7 @@ def compile_import_statement(fields: list[imports.ImportVar]) -> tuple[str, list
return default, list(rest)


def validate_imports(import_dict: imports.ImportDict):
def validate_imports(import_dict: ParsedImportDict):
"""Verify that the same Tag is not used in multiple import.

Args:
Expand All @@ -82,7 +83,7 @@ def validate_imports(import_dict: imports.ImportDict):
used_tags[import_name] = lib


def compile_imports(import_dict: imports.ImportDict) -> list[dict]:
def compile_imports(import_dict: ParsedImportDict) -> list[dict]:
"""Compile an import dict.

Args:
Expand All @@ -91,7 +92,7 @@ def compile_imports(import_dict: imports.ImportDict) -> list[dict]:
Returns:
The list of import dict.
"""
collapsed_import_dict = imports.collapse_imports(import_dict)
collapsed_import_dict: ParsedImportDict = imports.collapse_imports(import_dict)
validate_imports(collapsed_import_dict)
import_dicts = []
for lib, fields in collapsed_import_dict.items():
Expand Down Expand Up @@ -231,7 +232,7 @@ def compile_client_storage(state: Type[BaseState]) -> dict[str, dict]:

def compile_custom_component(
component: CustomComponent,
) -> tuple[dict, imports.ImportDict]:
) -> tuple[dict, ParsedImportDict]:
"""Compile a custom component.

Args:
Expand All @@ -244,7 +245,7 @@ def compile_custom_component(
render = component.get_component(component)

# Get the imports.
imports = {
imports: ParsedImportDict = {
lib: fields
for lib, fields in render._get_all_imports().items()
if lib != component.library
Expand Down
27 changes: 14 additions & 13 deletions reflex/components/chakra/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
from typing import List, Literal

from reflex.components.component import Component
from reflex.utils import imports
from reflex.utils.imports import ImportDict, ImportVar
from reflex.vars import Var


class ChakraComponent(Component):
"""A component that wraps a Chakra component."""

library = "@chakra-ui/react@2.6.1"
library: str = "@chakra-ui/react@2.6.1" # type: ignore
lib_dependencies: List[str] = [
"@chakra-ui/system@2.5.7",
"framer-motion@10.16.4",
Expand All @@ -35,14 +35,14 @@ def _get_style(self) -> dict:

@classmethod
@lru_cache(maxsize=None)
def _get_dependencies_imports(cls) -> imports.ImportDict:
def _get_dependencies_imports(cls) -> ImportDict:
"""Get the imports from lib_dependencies for installing.

Returns:
The dependencies imports of the component.
"""
return {
dep: [imports.ImportVar(tag=None, render=False)]
dep: [ImportVar(tag=None, render=False)]
for dep in [
"@chakra-ui/system@2.5.7",
"framer-motion@10.16.4",
Expand All @@ -68,15 +68,16 @@ def create(cls) -> Component:
theme=Var.create("extendTheme(theme)", _var_is_local=False),
)

def _get_imports(self) -> imports.ImportDict:
_imports = super()._get_imports()
_imports.setdefault(self.__fields__["library"].default, []).append(
imports.ImportVar(tag="extendTheme", is_default=False),
)
_imports.setdefault("/utils/theme.js", []).append(
imports.ImportVar(tag="theme", is_default=True),
)
return _imports
def add_imports(self) -> ImportDict:
"""Add imports for the ChakraProvider component.

Returns:
The import dict for the component.
"""
return {
self.library: ImportVar(tag="extendTheme", is_default=False),
"/utils/theme.js": ImportVar(tag="theme", is_default=True),
}

@staticmethod
@lru_cache(maxsize=None)
Expand Down
3 changes: 2 additions & 1 deletion reflex/components/chakra/base.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ from reflex.style import Style
from functools import lru_cache
from typing import List, Literal
from reflex.components.component import Component
from reflex.utils import imports
from reflex.utils.imports import ImportDict, ImportVar
from reflex.vars import Var

class ChakraComponent(Component):
Expand Down Expand Up @@ -155,6 +155,7 @@ class ChakraProvider(ChakraComponent):
A new ChakraProvider component.
"""
...
def add_imports(self) -> ImportDict: ...

chakra_provider = ChakraProvider.create()

Expand Down
14 changes: 8 additions & 6 deletions reflex/components/chakra/forms/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from reflex.components.core.debounce import DebounceInput
from reflex.components.literals import LiteralInputType
from reflex.constants import EventTriggers, MemoizationMode
from reflex.utils import imports
from reflex.utils.imports import ImportDict
from reflex.vars import Var


Expand Down Expand Up @@ -59,11 +59,13 @@ class Input(ChakraComponent):
# The name of the form field
name: Var[str]

def _get_imports(self) -> imports.ImportDict:
return imports.merge_imports(
super()._get_imports(),
{"/utils/state": {imports.ImportVar(tag="set_val")}},
)
def add_imports(self) -> ImportDict:
"""Add imports for the Input component.

Returns:
The import dict.
"""
return {"/utils/state": "set_val"}

def get_event_triggers(self) -> Dict[str, Any]:
"""Get the event triggers that pass the component's value to the handler.
Expand Down
3 changes: 2 additions & 1 deletion reflex/components/chakra/forms/input.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ from reflex.components.component import Component
from reflex.components.core.debounce import DebounceInput
from reflex.components.literals import LiteralInputType
from reflex.constants import EventTriggers, MemoizationMode
from reflex.utils import imports
from reflex.utils.imports import ImportDict
from reflex.vars import Var

class Input(ChakraComponent):
def add_imports(self) -> ImportDict: ...
def get_event_triggers(self) -> Dict[str, Any]: ...
@overload
@classmethod
Expand Down
11 changes: 8 additions & 3 deletions reflex/components/chakra/navigation/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from reflex.components.chakra import ChakraComponent
from reflex.components.component import Component
from reflex.components.next.link import NextLink
from reflex.utils import imports
from reflex.utils.imports import ImportDict
from reflex.vars import BaseVar, Var

next_link = NextLink.create()
Expand All @@ -30,8 +30,13 @@ class Link(ChakraComponent):
# If true, the link will open in new tab.
is_external: Var[bool]

def _get_imports(self) -> imports.ImportDict:
return {**super()._get_imports(), **next_link._get_imports()}
def add_imports(self) -> ImportDict:
"""Add imports for the link component.

Returns:
The import dict.
"""
return next_link._get_imports() # type: ignore

@classmethod
def create(cls, *children, **props) -> Component:
Expand Down
3 changes: 2 additions & 1 deletion reflex/components/chakra/navigation/link.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@ from reflex.style import Style
from reflex.components.chakra import ChakraComponent
from reflex.components.component import Component
from reflex.components.next.link import NextLink
from reflex.utils import imports
from reflex.utils.imports import ImportDict
from reflex.vars import BaseVar, Var

next_link = NextLink.create()

class Link(ChakraComponent):
def add_imports(self) -> ImportDict: ...
@overload
@classmethod
def create( # type: ignore
Expand Down
50 changes: 20 additions & 30 deletions reflex/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
)
from reflex.style import Style, format_as_emotion
from reflex.utils import console, format, imports, types
from reflex.utils.imports import ImportVar
from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
from reflex.utils.serializers import serializer
from reflex.vars import BaseVar, Var, VarData

Expand Down Expand Up @@ -95,7 +95,7 @@ def _get_all_hooks(self) -> dict[str, None]:
"""

@abstractmethod
def _get_all_imports(self) -> imports.ImportDict:
def _get_all_imports(self) -> ParsedImportDict:
"""Get all the libraries and fields that are used by the component.

Returns:
Expand Down Expand Up @@ -213,7 +213,7 @@ class Component(BaseComponent, ABC):
# State class associated with this component instance
State: Optional[Type[reflex.state.State]] = None

def add_imports(self) -> dict[str, str | ImportVar | list[str | ImportVar]]:
def add_imports(self) -> ImportDict | list[ImportDict]:
"""Add imports for the component.

This method should be implemented by subclasses to add new imports for the component.
Expand Down Expand Up @@ -1196,7 +1196,7 @@ def _get_all_dynamic_imports(self) -> Set[str]:
# Return the dynamic imports
return dynamic_imports

def _get_props_imports(self) -> List[str]:
def _get_props_imports(self) -> List[ParsedImportDict]:
"""Get the imports needed for components props.

Returns:
Expand All @@ -1222,7 +1222,7 @@ def _should_transpile(self, dep: str | None) -> bool:
or format.format_library_name(dep or "") in self.transpile_packages
)

def _get_dependencies_imports(self) -> imports.ImportDict:
def _get_dependencies_imports(self) -> ParsedImportDict:
"""Get the imports from lib_dependencies for installing.

Returns:
Expand All @@ -1239,7 +1239,7 @@ def _get_dependencies_imports(self) -> imports.ImportDict:
for dep in self.lib_dependencies
}

def _get_hooks_imports(self) -> imports.ImportDict:
def _get_hooks_imports(self) -> ParsedImportDict:
"""Get the imports required by certain hooks.

Returns:
Expand Down Expand Up @@ -1280,7 +1280,7 @@ def _get_hooks_imports(self) -> imports.ImportDict:

return imports.merge_imports(_imports, *other_imports)

def _get_imports(self) -> imports.ImportDict:
def _get_imports(self) -> ParsedImportDict:
"""Get all the libraries and fields that are used by the component.

Returns:
Expand All @@ -1300,25 +1300,15 @@ def _get_imports(self) -> imports.ImportDict:
var._var_data.imports for var in self._get_vars() if var._var_data
]

# If any subclass implements add_imports, merge the imports.
def _make_list(
value: str | ImportVar | list[str | ImportVar],
) -> list[str | ImportVar]:
if isinstance(value, (str, ImportVar)):
return [value]
return value

_added_import_dicts = []
added_import_dicts: list[ParsedImportDict] = []
for clz in self._iter_parent_classes_with_method("add_imports"):
_added_import_dicts.append(
{
package: [
ImportVar(tag=tag) if not isinstance(tag, ImportVar) else tag
for tag in _make_list(maybe_tags)
]
for package, maybe_tags in clz.add_imports(self).items()
}
)
list_of_import_dict = clz.add_imports(self)

if not isinstance(list_of_import_dict, list):
list_of_import_dict = [list_of_import_dict]

for import_dict in list_of_import_dict:
added_import_dicts.append(parse_imports(import_dict))

return imports.merge_imports(
*self._get_props_imports(),
Expand All @@ -1327,10 +1317,10 @@ def _make_list(
_imports,
event_imports,
*var_imports,
*_added_import_dicts,
*added_import_dicts,
)

def _get_all_imports(self, collapse: bool = False) -> imports.ImportDict:
def _get_all_imports(self, collapse: bool = False) -> ParsedImportDict:
"""Get all the libraries and fields that are used by the component and its children.

Args:
Expand Down Expand Up @@ -1425,7 +1415,7 @@ def _get_hooks_internal(self) -> dict[str, None]:
**self._get_special_hooks(),
}

def _get_added_hooks(self) -> dict[str, imports.ImportDict]:
def _get_added_hooks(self) -> dict[str, ImportDict]:
"""Get the hooks added via `add_hooks` method.

Returns:
Expand Down Expand Up @@ -1814,7 +1804,7 @@ def wrapper(*children, **props) -> CustomComponent:
class NoSSRComponent(Component):
"""A dynamic component that is not rendered on the server."""

def _get_imports(self) -> imports.ImportDict:
def _get_imports(self) -> ParsedImportDict:
"""Get the imports for the component.

Returns:
Expand Down Expand Up @@ -2157,7 +2147,7 @@ def _get_all_hooks(self) -> dict[str, None]:
"""
return {}

def _get_all_imports(self) -> imports.ImportDict:
def _get_all_imports(self) -> ParsedImportDict:
"""Get all the libraries and fields that are used by the component.

Returns:
Expand Down
Loading
Loading