diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index 9f389d8118..cc2d805fa0 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -28,13 +28,14 @@ from reflex.state import BaseState, Cookie, LocalStorage from reflex.style import Style from reflex.utils import console, format, imports, path_ops +from reflex.utils.imports import ImportVar, ParsedImportDict from reflex.vars import Var # To re-export this function. merge_imports = imports.merge_imports -def compile_import_statement(fields: list[imports.ImportVar]) -> tuple[str, list[str]]: +def compile_import_statement(fields: list[ImportVar]) -> tuple[str, list[str]]: """Compile an import statement. Args: @@ -59,7 +60,7 @@ def compile_import_statement(fields: list[imports.ImportVar]) -> tuple[str, list return default, list(rest) -def validate_imports(import_dict: imports.ImportDict): +def validate_imports(import_dict: ParsedImportDict): """Verify that the same Tag is not used in multiple import. Args: @@ -82,7 +83,7 @@ def validate_imports(import_dict: imports.ImportDict): used_tags[import_name] = lib -def compile_imports(import_dict: imports.ImportDict) -> list[dict]: +def compile_imports(import_dict: ParsedImportDict) -> list[dict]: """Compile an import dict. Args: @@ -91,7 +92,7 @@ def compile_imports(import_dict: imports.ImportDict) -> list[dict]: Returns: The list of import dict. """ - collapsed_import_dict = imports.collapse_imports(import_dict) + collapsed_import_dict: ParsedImportDict = imports.collapse_imports(import_dict) validate_imports(collapsed_import_dict) import_dicts = [] for lib, fields in collapsed_import_dict.items(): @@ -231,7 +232,7 @@ def compile_client_storage(state: Type[BaseState]) -> dict[str, dict]: def compile_custom_component( component: CustomComponent, -) -> tuple[dict, imports.ImportDict]: +) -> tuple[dict, ParsedImportDict]: """Compile a custom component. Args: @@ -244,7 +245,7 @@ def compile_custom_component( render = component.get_component(component) # Get the imports. - imports = { + imports: ParsedImportDict = { lib: fields for lib, fields in render._get_all_imports().items() if lib != component.library diff --git a/reflex/components/chakra/base.py b/reflex/components/chakra/base.py index b0f43b3135..a80d88cd6c 100644 --- a/reflex/components/chakra/base.py +++ b/reflex/components/chakra/base.py @@ -5,14 +5,14 @@ from typing import List, Literal from reflex.components.component import Component -from reflex.utils import imports +from reflex.utils.imports import ImportDict, ImportVar from reflex.vars import Var class ChakraComponent(Component): """A component that wraps a Chakra component.""" - library = "@chakra-ui/react@2.6.1" + library: str = "@chakra-ui/react@2.6.1" # type: ignore lib_dependencies: List[str] = [ "@chakra-ui/system@2.5.7", "framer-motion@10.16.4", @@ -35,14 +35,14 @@ def _get_style(self) -> dict: @classmethod @lru_cache(maxsize=None) - def _get_dependencies_imports(cls) -> imports.ImportDict: + def _get_dependencies_imports(cls) -> ImportDict: """Get the imports from lib_dependencies for installing. Returns: The dependencies imports of the component. """ return { - dep: [imports.ImportVar(tag=None, render=False)] + dep: [ImportVar(tag=None, render=False)] for dep in [ "@chakra-ui/system@2.5.7", "framer-motion@10.16.4", @@ -70,15 +70,16 @@ def create(cls) -> Component: ), ) - def _get_imports(self) -> imports.ImportDict: - _imports = super()._get_imports() - _imports.setdefault(self.__fields__["library"].default, []).append( - imports.ImportVar(tag="extendTheme", is_default=False), - ) - _imports.setdefault("/utils/theme.js", []).append( - imports.ImportVar(tag="theme", is_default=True), - ) - return _imports + def add_imports(self) -> ImportDict: + """Add imports for the ChakraProvider component. + + Returns: + The import dict for the component. + """ + return { + self.library: ImportVar(tag="extendTheme", is_default=False), + "/utils/theme.js": ImportVar(tag="theme", is_default=True), + } @staticmethod @lru_cache(maxsize=None) diff --git a/reflex/components/chakra/base.pyi b/reflex/components/chakra/base.pyi index be99b2f90c..d209e48fb3 100644 --- a/reflex/components/chakra/base.pyi +++ b/reflex/components/chakra/base.pyi @@ -10,7 +10,7 @@ from reflex.style import Style from functools import lru_cache from typing import List, Literal from reflex.components.component import Component -from reflex.utils import imports +from reflex.utils.imports import ImportDict, ImportVar from reflex.vars import Var class ChakraComponent(Component): @@ -155,6 +155,7 @@ class ChakraProvider(ChakraComponent): A new ChakraProvider component. """ ... + def add_imports(self) -> ImportDict: ... chakra_provider = ChakraProvider.create() diff --git a/reflex/components/chakra/forms/input.py b/reflex/components/chakra/forms/input.py index 4512a4f48c..6152b41970 100644 --- a/reflex/components/chakra/forms/input.py +++ b/reflex/components/chakra/forms/input.py @@ -11,7 +11,7 @@ from reflex.components.core.debounce import DebounceInput from reflex.components.literals import LiteralInputType from reflex.constants import EventTriggers, MemoizationMode -from reflex.utils import imports +from reflex.utils.imports import ImportDict from reflex.vars import Var @@ -59,11 +59,13 @@ class Input(ChakraComponent): # The name of the form field name: Var[str] - def _get_imports(self) -> imports.ImportDict: - return imports.merge_imports( - super()._get_imports(), - {"/utils/state": {imports.ImportVar(tag="set_val")}}, - ) + def add_imports(self) -> ImportDict: + """Add imports for the Input component. + + Returns: + The import dict. + """ + return {"/utils/state": "set_val"} def get_event_triggers(self) -> Dict[str, Any]: """Get the event triggers that pass the component's value to the handler. diff --git a/reflex/components/chakra/forms/input.pyi b/reflex/components/chakra/forms/input.pyi index 475ab8ae44..18eab5c266 100644 --- a/reflex/components/chakra/forms/input.pyi +++ b/reflex/components/chakra/forms/input.pyi @@ -17,10 +17,11 @@ from reflex.components.component import Component from reflex.components.core.debounce import DebounceInput from reflex.components.literals import LiteralInputType from reflex.constants import EventTriggers, MemoizationMode -from reflex.utils import imports +from reflex.utils.imports import ImportDict from reflex.vars import Var class Input(ChakraComponent): + def add_imports(self) -> ImportDict: ... def get_event_triggers(self) -> Dict[str, Any]: ... @overload @classmethod diff --git a/reflex/components/chakra/navigation/link.py b/reflex/components/chakra/navigation/link.py index 238473b399..c3e6e71c28 100644 --- a/reflex/components/chakra/navigation/link.py +++ b/reflex/components/chakra/navigation/link.py @@ -4,7 +4,7 @@ from reflex.components.chakra import ChakraComponent from reflex.components.component import Component from reflex.components.next.link import NextLink -from reflex.utils import imports +from reflex.utils.imports import ImportDict from reflex.vars import BaseVar, Var next_link = NextLink.create() @@ -32,8 +32,13 @@ class Link(ChakraComponent): # If true, the link will open in new tab. is_external: Var[bool] - def _get_imports(self) -> imports.ImportDict: - return {**super()._get_imports(), **next_link._get_imports()} + def add_imports(self) -> ImportDict: + """Add imports for the link component. + + Returns: + The import dict. + """ + return next_link._get_imports() # type: ignore @classmethod def create(cls, *children, **props) -> Component: diff --git a/reflex/components/chakra/navigation/link.pyi b/reflex/components/chakra/navigation/link.pyi index 583ced1e16..ea037c1555 100644 --- a/reflex/components/chakra/navigation/link.pyi +++ b/reflex/components/chakra/navigation/link.pyi @@ -10,12 +10,13 @@ from reflex.style import Style from reflex.components.chakra import ChakraComponent from reflex.components.component import Component from reflex.components.next.link import NextLink -from reflex.utils import imports +from reflex.utils.imports import ImportDict from reflex.vars import BaseVar, Var next_link = NextLink.create() class Link(ChakraComponent): + def add_imports(self) -> ImportDict: ... @overload @classmethod def create( # type: ignore diff --git a/reflex/components/component.py b/reflex/components/component.py index 31396433d8..370ed5693c 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -44,7 +44,7 @@ ) from reflex.style import Style, format_as_emotion from reflex.utils import console, format, imports, types -from reflex.utils.imports import ImportVar +from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports from reflex.utils.serializers import serializer from reflex.vars import BaseVar, Var, VarData @@ -95,7 +95,7 @@ def _get_all_hooks(self) -> dict[str, None]: """ @abstractmethod - def _get_all_imports(self) -> imports.ImportDict: + def _get_all_imports(self) -> ParsedImportDict: """Get all the libraries and fields that are used by the component. Returns: @@ -213,7 +213,7 @@ class Component(BaseComponent, ABC): # State class associated with this component instance State: Optional[Type[reflex.state.State]] = None - def add_imports(self) -> dict[str, str | ImportVar | list[str | ImportVar]]: + def add_imports(self) -> ImportDict | list[ImportDict]: """Add imports for the component. This method should be implemented by subclasses to add new imports for the component. @@ -1224,7 +1224,7 @@ def _get_all_dynamic_imports(self) -> Set[str]: # Return the dynamic imports return dynamic_imports - def _get_props_imports(self) -> List[str]: + def _get_props_imports(self) -> List[ParsedImportDict]: """Get the imports needed for components props. Returns: @@ -1250,7 +1250,7 @@ def _should_transpile(self, dep: str | None) -> bool: or format.format_library_name(dep or "") in self.transpile_packages ) - def _get_dependencies_imports(self) -> imports.ImportDict: + def _get_dependencies_imports(self) -> ParsedImportDict: """Get the imports from lib_dependencies for installing. Returns: @@ -1267,7 +1267,7 @@ def _get_dependencies_imports(self) -> imports.ImportDict: for dep in self.lib_dependencies } - def _get_hooks_imports(self) -> imports.ImportDict: + def _get_hooks_imports(self) -> ParsedImportDict: """Get the imports required by certain hooks. Returns: @@ -1308,7 +1308,7 @@ def _get_hooks_imports(self) -> imports.ImportDict: return imports.merge_imports(_imports, *other_imports) - def _get_imports(self) -> imports.ImportDict: + def _get_imports(self) -> ParsedImportDict: """Get all the libraries and fields that are used by the component. Returns: @@ -1328,25 +1328,15 @@ def _get_imports(self) -> imports.ImportDict: var._var_data.imports for var in self._get_vars() if var._var_data ] - # If any subclass implements add_imports, merge the imports. - def _make_list( - value: str | ImportVar | list[str | ImportVar], - ) -> list[str | ImportVar]: - if isinstance(value, (str, ImportVar)): - return [value] - return value - - _added_import_dicts = [] + added_import_dicts: list[ParsedImportDict] = [] for clz in self._iter_parent_classes_with_method("add_imports"): - _added_import_dicts.append( - { - package: [ - ImportVar(tag=tag) if not isinstance(tag, ImportVar) else tag - for tag in _make_list(maybe_tags) - ] - for package, maybe_tags in clz.add_imports(self).items() - } - ) + list_of_import_dict = clz.add_imports(self) + + if not isinstance(list_of_import_dict, list): + list_of_import_dict = [list_of_import_dict] + + for import_dict in list_of_import_dict: + added_import_dicts.append(parse_imports(import_dict)) return imports.merge_imports( *self._get_props_imports(), @@ -1355,10 +1345,10 @@ def _make_list( _imports, event_imports, *var_imports, - *_added_import_dicts, + *added_import_dicts, ) - def _get_all_imports(self, collapse: bool = False) -> imports.ImportDict: + def _get_all_imports(self, collapse: bool = False) -> ParsedImportDict: """Get all the libraries and fields that are used by the component and its children. Args: @@ -1453,7 +1443,7 @@ def _get_hooks_internal(self) -> dict[str, None]: **self._get_special_hooks(), } - def _get_added_hooks(self) -> dict[str, imports.ImportDict]: + def _get_added_hooks(self) -> dict[str, ImportDict]: """Get the hooks added via `add_hooks` method. Returns: @@ -1842,7 +1832,7 @@ def wrapper(*children, **props) -> CustomComponent: class NoSSRComponent(Component): """A dynamic component that is not rendered on the server.""" - def _get_imports(self) -> imports.ImportDict: + def _get_imports(self) -> ParsedImportDict: """Get the imports for the component. Returns: @@ -2185,7 +2175,7 @@ def _get_all_hooks(self) -> dict[str, None]: """ return {} - def _get_all_imports(self) -> imports.ImportDict: + def _get_all_imports(self) -> ParsedImportDict: """Get all the libraries and fields that are used by the component. Returns: diff --git a/reflex/components/core/banner.py b/reflex/components/core/banner.py index 33a6f0dee1..b634ab75a8 100644 --- a/reflex/components/core/banner.py +++ b/reflex/components/core/banner.py @@ -19,7 +19,7 @@ from reflex.components.sonner.toast import Toaster, ToastProps from reflex.constants import Dirs, Hooks, Imports from reflex.constants.compiler import CompileVars -from reflex.utils import imports +from reflex.utils.imports import ImportDict, ImportVar from reflex.utils.serializers import serialize from reflex.vars import Var, VarData @@ -65,10 +65,15 @@ class WebsocketTargetURL(Bare): """A component that renders the websocket target URL.""" - def _get_imports(self) -> imports.ImportDict: + def add_imports(self) -> ImportDict: + """Add imports for the websocket target URL component. + + Returns: + The import dict. + """ return { - f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="getBackendURL")], - "/env.json": [imports.ImportVar(tag="env", is_default=True)], + f"/{Dirs.STATE_PATH}": [ImportVar(tag="getBackendURL")], + "/env.json": [ImportVar(tag="env", is_default=True)], } @classmethod @@ -98,7 +103,7 @@ def default_connection_error() -> list[str | Var | Component]: class ConnectionToaster(Toaster): """A connection toaster component.""" - def add_hooks(self) -> list[str]: + def add_hooks(self) -> list[str | Var]: """Add the hooks for the connection toaster. Returns: @@ -116,7 +121,7 @@ def add_hooks(self) -> list[str]: duration=120000, id=toast_id, ) - hook = Var.create( + hook = Var.create_safe( f""" const toast_props = {serialize(props)}; const [userDismissed, setUserDismissed] = useState(false); @@ -135,22 +140,17 @@ def add_hooks(self) -> list[str]: }}, [{connect_errors}]);""", _var_is_string=False, ) - - hook._var_data = VarData.merge( # type: ignore + imports: ImportDict = { + "react": ["useEffect", "useState"], + **target_url._get_imports(), # type: ignore + } + hook._var_data = VarData.merge( connect_errors._var_data, - VarData( - imports={ - "react": [ - imports.ImportVar(tag="useEffect"), - imports.ImportVar(tag="useState"), - ], - **target_url._get_imports(), - } - ), + VarData(imports=imports), ) return [ Hooks.EVENTS, - hook, # type: ignore + hook, ] @@ -216,10 +216,11 @@ class WifiOffPulse(Icon): """A wifi_off icon with an animated opacity pulse.""" @classmethod - def create(cls, **props) -> Component: + def create(cls, *children, **props) -> Icon: """Create a wifi_off icon with an animated opacity pulse. Args: + *children: The children of the component. **props: The properties of the component. Returns: @@ -237,11 +238,13 @@ def create(cls, **props) -> Component: **props, ) - def _get_imports(self) -> imports.ImportDict: - return imports.merge_imports( - super()._get_imports(), - {"@emotion/react": [imports.ImportVar(tag="keyframes")]}, - ) + def add_imports(self) -> dict[str, str | ImportVar | list[str | ImportVar]]: + """Add imports for the WifiOffPulse component. + + Returns: + The import dict. + """ + return {"@emotion/react": [ImportVar(tag="keyframes")]} def _get_custom_code(self) -> str | None: return """ diff --git a/reflex/components/core/banner.pyi b/reflex/components/core/banner.pyi index 1cf4ec87ce..c957bab932 100644 --- a/reflex/components/core/banner.pyi +++ b/reflex/components/core/banner.pyi @@ -23,7 +23,7 @@ from reflex.components.radix.themes.typography.text import Text from reflex.components.sonner.toast import Toaster, ToastProps from reflex.constants import Dirs, Hooks, Imports from reflex.constants.compiler import CompileVars -from reflex.utils import imports +from reflex.utils.imports import ImportDict, ImportVar from reflex.utils.serializers import serialize from reflex.vars import Var, VarData @@ -35,6 +35,7 @@ has_connection_errors: Var has_too_many_connection_errors: Var class WebsocketTargetURL(Bare): + def add_imports(self) -> ImportDict: ... @overload @classmethod def create( # type: ignore @@ -104,7 +105,7 @@ class WebsocketTargetURL(Bare): def default_connection_error() -> list[str | Var | Component]: ... class ConnectionToaster(Toaster): - def add_hooks(self) -> list[str]: ... + def add_hooks(self) -> list[str | Var]: ... @overload @classmethod def create( # type: ignore @@ -430,6 +431,7 @@ class WifiOffPulse(Icon): """Create a wifi_off icon with an animated opacity pulse. Args: + *children: The children of the component. size: The size of the icon in pixels. style: The style of the component. key: A unique key for the component. @@ -443,6 +445,7 @@ class WifiOffPulse(Icon): The icon component with default props applied. """ ... + def add_imports(self) -> dict[str, str | ImportVar | list[str | ImportVar]]: ... class ConnectionPulser(Div): @overload diff --git a/reflex/components/core/cond.py b/reflex/components/core/cond.py index fcc12bc51c..15b56d7514 100644 --- a/reflex/components/core/cond.py +++ b/reflex/components/core/cond.py @@ -10,11 +10,12 @@ from reflex.constants import Dirs from reflex.constants.colors import Color from reflex.style import LIGHT_COLOR_MODE, color_mode -from reflex.utils import format, imports +from reflex.utils import format +from reflex.utils.imports import ImportDict, ImportVar from reflex.vars import Var, VarData -_IS_TRUE_IMPORT = { - f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="isTrue")], +_IS_TRUE_IMPORT: ImportDict = { + f"/{Dirs.STATE_PATH}": [ImportVar(tag="isTrue")], } @@ -96,12 +97,16 @@ def render(self) -> Dict: cond_state=f"isTrue({self.cond._var_full_name})", ) - def _get_imports(self) -> imports.ImportDict: - return imports.merge_imports( - super()._get_imports(), - getattr(self.cond._var_data, "imports", {}), - _IS_TRUE_IMPORT, + def add_imports(self) -> ImportDict: + """Add imports for the Cond component. + + Returns: + The import dict for the component. + """ + cond_imports: dict[str, str | ImportVar | list[str | ImportVar]] = getattr( + self.cond._var_data, "imports", {} ) + return {**cond_imports, **_IS_TRUE_IMPORT} @overload diff --git a/reflex/components/core/match.py b/reflex/components/core/match.py index 8b684678ef..e85739605b 100644 --- a/reflex/components/core/match.py +++ b/reflex/components/core/match.py @@ -8,8 +8,9 @@ from reflex.components.core.colors import Color from reflex.components.tags import MatchTag, Tag from reflex.style import Style -from reflex.utils import format, imports, types +from reflex.utils import format, types from reflex.utils.exceptions import MatchTypeError +from reflex.utils.imports import ImportDict from reflex.vars import BaseVar, Var, VarData @@ -268,11 +269,13 @@ def render(self) -> Dict: tag.name = "match" return dict(tag) - def _get_imports(self) -> imports.ImportDict: - return imports.merge_imports( - super()._get_imports(), - getattr(self.cond._var_data, "imports", {}), - ) + def add_imports(self) -> ImportDict: + """Add imports for the Match component. + + Returns: + The import dict. + """ + return getattr(self.cond._var_data, "imports", {}) match = Match.create diff --git a/reflex/components/core/upload.py b/reflex/components/core/upload.py index d499772983..f053ef7c7f 100644 --- a/reflex/components/core/upload.py +++ b/reflex/components/core/upload.py @@ -19,17 +19,15 @@ call_script, parse_args_spec, ) -from reflex.utils import imports +from reflex.utils.imports import ImportVar from reflex.vars import BaseVar, CallableVar, Var, VarData DEFAULT_UPLOAD_ID: str = "default" upload_files_context_var_data: VarData = VarData( imports={ - "react": [imports.ImportVar(tag="useContext")], - f"/{Dirs.CONTEXTS_PATH}": [ - imports.ImportVar(tag="UploadFilesContext"), - ], + "react": "useContext", + f"/{Dirs.CONTEXTS_PATH}": "UploadFilesContext", }, hooks={ "const [filesById, setFilesById] = useContext(UploadFilesContext);": None, @@ -133,8 +131,8 @@ def get_upload_dir() -> Path: _var_is_string=False, _var_data=VarData( imports={ - f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="getBackendURL")], - "/env.json": [imports.ImportVar(tag="env", is_default=True)], + f"/{Dirs.STATE_PATH}": "getBackendURL", + "/env.json": ImportVar(tag="env", is_default=True), } ), ) diff --git a/reflex/components/core/upload.pyi b/reflex/components/core/upload.pyi index b9b4ad8d0d..728f5d074a 100644 --- a/reflex/components/core/upload.pyi +++ b/reflex/components/core/upload.pyi @@ -23,7 +23,7 @@ from reflex.event import ( call_script, parse_args_spec, ) -from reflex.utils import imports +from reflex.utils.imports import ImportVar from reflex.vars import BaseVar, CallableVar, Var, VarData DEFAULT_UPLOAD_ID: str diff --git a/reflex/components/datadisplay/code.py b/reflex/components/datadisplay/code.py index d706ec6f3d..4f24c0ab19 100644 --- a/reflex/components/datadisplay/code.py +++ b/reflex/components/datadisplay/code.py @@ -12,8 +12,8 @@ from reflex.constants.colors import Color from reflex.event import set_clipboard from reflex.style import Style -from reflex.utils import format, imports -from reflex.utils.imports import ImportVar +from reflex.utils import format +from reflex.utils.imports import ImportDict, ImportVar from reflex.vars import Var LiteralCodeBlockTheme = Literal[ @@ -381,42 +381,45 @@ class CodeBlock(Component): # Props passed down to the code tag. code_tag_props: Var[Dict[str, str]] - def _get_imports(self) -> imports.ImportDict: - merged_imports = super()._get_imports() - # Get all themes from a cond literal + def add_imports(self) -> ImportDict: + """Add imports for the CodeBlock component. + + Returns: + The import dict. + """ + imports_: ImportDict = {} themes = re.findall(r"`(.*?)`", self.theme._var_name) if not themes: themes = [self.theme._var_name] - merged_imports = imports.merge_imports( - merged_imports, + + imports_.update( { - f"react-syntax-highlighter/dist/cjs/styles/prism/{self.convert_theme_name(theme)}": { + f"react-syntax-highlighter/dist/cjs/styles/prism/{self.convert_theme_name(theme)}": [ ImportVar( tag=format.to_camel_case(self.convert_theme_name(theme)), is_default=True, install=False, ) - } + ] for theme in themes - }, + } ) + if ( self.language is not None and self.language._var_name in LiteralCodeLanguage.__args__ # type: ignore ): - merged_imports = imports.merge_imports( - merged_imports, - { - f"react-syntax-highlighter/dist/cjs/languages/prism/{self.language._var_name}": { - ImportVar( - tag=format.to_camel_case(self.language._var_name), - is_default=True, - install=False, - ) - } - }, - ) - return merged_imports + imports_[ + f"react-syntax-highlighter/dist/cjs/languages/prism/{self.language._var_name}" + ] = [ + ImportVar( + tag=format.to_camel_case(self.language._var_name), + is_default=True, + install=False, + ) + ] + + return imports_ def _get_custom_code(self) -> Optional[str]: if ( diff --git a/reflex/components/datadisplay/code.pyi b/reflex/components/datadisplay/code.pyi index 238cb2fb65..c1d700bda4 100644 --- a/reflex/components/datadisplay/code.pyi +++ b/reflex/components/datadisplay/code.pyi @@ -17,8 +17,8 @@ from reflex.components.core.cond import color_mode_cond from reflex.constants.colors import Color from reflex.event import set_clipboard from reflex.style import Style -from reflex.utils import format, imports -from reflex.utils.imports import ImportVar +from reflex.utils import format +from reflex.utils.imports import ImportDict, ImportVar from reflex.vars import Var LiteralCodeBlockTheme = Literal[ @@ -351,6 +351,7 @@ LiteralCodeLanguage = Literal[ ] class CodeBlock(Component): + def add_imports(self) -> ImportDict: ... @overload @classmethod def create( # type: ignore diff --git a/reflex/components/datadisplay/dataeditor.py b/reflex/components/datadisplay/dataeditor.py index 16c289f883..ebada60473 100644 --- a/reflex/components/datadisplay/dataeditor.py +++ b/reflex/components/datadisplay/dataeditor.py @@ -2,13 +2,14 @@ from __future__ import annotations from enum import Enum -from typing import Any, Callable, Dict, List, Literal, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union from reflex.base import Base from reflex.components.component import Component, NoSSRComponent from reflex.components.literals import LiteralRowMarker -from reflex.utils import console, format, imports, types -from reflex.utils.imports import ImportVar +from reflex.event import EventHandler +from reflex.utils import console, format, types +from reflex.utils.imports import ImportDict, ImportVar from reflex.utils.serializers import serializer from reflex.vars import Var, get_unique_variable_name @@ -205,51 +206,66 @@ class DataEditor(NoSSRComponent): # global theme theme: Var[Union[DataEditorTheme, Dict]] - def _get_imports(self): - return imports.merge_imports( - super()._get_imports(), - { - "": { - ImportVar( - tag=f"{format.format_library_name(self.library)}/dist/index.css" - ) - }, - self.library: {ImportVar(tag="GridCellKind")}, - "/utils/helpers/dataeditor.js": { - ImportVar( - tag=f"formatDataEditorCells", is_default=False, install=False - ), - }, - }, - ) + # Triggered when a cell is activated. + on_cell_activated: EventHandler[lambda pos: [pos]] - def get_event_triggers(self) -> Dict[str, Callable]: - """The event triggers of the component. + # Triggered when a cell is clicked. + on_cell_clicked: EventHandler[lambda pos: [pos]] - Returns: - The dict describing the event triggers. - """ + # Triggered when a cell is right-clicked. + on_cell_context_menu: EventHandler[lambda pos: [pos]] + + # Triggered when a cell is edited. + on_cell_edited: EventHandler[lambda pos, data: [pos, data]] + + # Triggered when a group header is clicked. + on_group_header_clicked: EventHandler[lambda pos, data: [pos, data]] + + # Triggered when a group header is right-clicked. + on_group_header_context_menu: EventHandler[lambda grp_idx, data: [grp_idx, data]] + + # Triggered when a group header is renamed. + on_group_header_renamed: EventHandler[lambda idx, val: [idx, val]] + + # Triggered when a header is clicked. + on_header_clicked: EventHandler[lambda pos: [pos]] + + # Triggered when a header is right-clicked. + on_header_context_menu: EventHandler[lambda pos: [pos]] - def edit_sig(pos, data: dict[str, Any]): - return [pos, data] + # Triggered when a header menu is clicked. + on_header_menu_click: EventHandler[lambda col, pos: [col, pos]] + # Triggered when an item is hovered. + on_item_hovered: EventHandler[lambda pos: [pos]] + + # Triggered when a selection is deleted. + on_delete: EventHandler[lambda selection: [selection]] + + # Triggered when editing is finished. + on_finished_editing: EventHandler[lambda new_value, movement: [new_value, movement]] + + # Triggered when a row is appended. + on_row_appended: EventHandler[lambda: []] + + # Triggered when the selection is cleared. + on_selection_cleared: EventHandler[lambda: []] + + # Triggered when a column is resized. + on_column_resize: EventHandler[lambda col, width: [col, width]] + + def add_imports(self) -> ImportDict: + """Add imports for the component. + + Returns: + The import dict. + """ return { - "on_cell_activated": lambda pos: [pos], - "on_cell_clicked": lambda pos: [pos], - "on_cell_context_menu": lambda pos: [pos], - "on_cell_edited": edit_sig, - "on_group_header_clicked": edit_sig, - "on_group_header_context_menu": lambda grp_idx, data: [grp_idx, data], - "on_group_header_renamed": lambda idx, val: [idx, val], - "on_header_clicked": lambda pos: [pos], - "on_header_context_menu": lambda pos: [pos], - "on_header_menu_click": lambda col, pos: [col, pos], - "on_item_hovered": lambda pos: [pos], - "on_delete": lambda selection: [selection], - "on_finished_editing": lambda new_value, movement: [new_value, movement], - "on_row_appended": lambda: [], - "on_selection_cleared": lambda: [], - "on_column_resize": lambda col, width: [col, width], + "": f"{format.format_library_name(self.library)}/dist/index.css", + self.library: "GridCellKind", + "/utils/helpers/dataeditor.js": ImportVar( + tag="formatDataEditorCells", is_default=False, install=False + ), } def add_hooks(self) -> list[str]: diff --git a/reflex/components/datadisplay/dataeditor.pyi b/reflex/components/datadisplay/dataeditor.pyi index cfdd41fdfe..dab31a5e6a 100644 --- a/reflex/components/datadisplay/dataeditor.pyi +++ b/reflex/components/datadisplay/dataeditor.pyi @@ -8,12 +8,13 @@ from reflex.vars import Var, BaseVar, ComputedVar from reflex.event import EventChain, EventHandler, EventSpec from reflex.style import Style from enum import Enum -from typing import Any, Callable, Dict, List, Literal, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union from reflex.base import Base from reflex.components.component import Component, NoSSRComponent from reflex.components.literals import LiteralRowMarker -from reflex.utils import console, format, imports, types -from reflex.utils.imports import ImportVar +from reflex.event import EventHandler +from reflex.utils import console, format, types +from reflex.utils.imports import ImportDict, ImportVar from reflex.utils.serializers import serializer from reflex.vars import Var, get_unique_variable_name @@ -80,7 +81,7 @@ class DataEditorTheme(Base): text_medium: Optional[str] class DataEditor(NoSSRComponent): - def get_event_triggers(self) -> Dict[str, Callable]: ... + def add_imports(self) -> ImportDict: ... def add_hooks(self) -> list[str]: ... @overload @classmethod @@ -136,6 +137,9 @@ class DataEditor(NoSSRComponent): class_name: Optional[Any] = None, autofocus: Optional[bool] = None, custom_attrs: Optional[Dict[str, Union[Var, str]]] = None, + on_blur: Optional[ + Union[EventHandler, EventSpec, list, function, BaseVar] + ] = None, on_cell_activated: Optional[ Union[EventHandler, EventSpec, list, function, BaseVar] ] = None, @@ -148,15 +152,27 @@ class DataEditor(NoSSRComponent): on_cell_edited: Optional[ Union[EventHandler, EventSpec, list, function, BaseVar] ] = None, + on_click: Optional[ + Union[EventHandler, EventSpec, list, function, BaseVar] + ] = None, on_column_resize: Optional[ Union[EventHandler, EventSpec, list, function, BaseVar] ] = None, + on_context_menu: Optional[ + Union[EventHandler, EventSpec, list, function, BaseVar] + ] = None, on_delete: Optional[ Union[EventHandler, EventSpec, list, function, BaseVar] ] = None, + on_double_click: Optional[ + Union[EventHandler, EventSpec, list, function, BaseVar] + ] = None, on_finished_editing: Optional[ Union[EventHandler, EventSpec, list, function, BaseVar] ] = None, + on_focus: Optional[ + Union[EventHandler, EventSpec, list, function, BaseVar] + ] = None, on_group_header_clicked: Optional[ Union[EventHandler, EventSpec, list, function, BaseVar] ] = None, @@ -178,12 +194,42 @@ class DataEditor(NoSSRComponent): on_item_hovered: Optional[ Union[EventHandler, EventSpec, list, function, BaseVar] ] = None, + on_mount: Optional[ + Union[EventHandler, EventSpec, list, function, BaseVar] + ] = None, + on_mouse_down: Optional[ + Union[EventHandler, EventSpec, list, function, BaseVar] + ] = None, + on_mouse_enter: Optional[ + Union[EventHandler, EventSpec, list, function, BaseVar] + ] = None, + on_mouse_leave: Optional[ + Union[EventHandler, EventSpec, list, function, BaseVar] + ] = None, + on_mouse_move: Optional[ + Union[EventHandler, EventSpec, list, function, BaseVar] + ] = None, + on_mouse_out: Optional[ + Union[EventHandler, EventSpec, list, function, BaseVar] + ] = None, + on_mouse_over: Optional[ + Union[EventHandler, EventSpec, list, function, BaseVar] + ] = None, + on_mouse_up: Optional[ + Union[EventHandler, EventSpec, list, function, BaseVar] + ] = None, on_row_appended: Optional[ Union[EventHandler, EventSpec, list, function, BaseVar] ] = None, + on_scroll: Optional[ + Union[EventHandler, EventSpec, list, function, BaseVar] + ] = None, on_selection_cleared: Optional[ Union[EventHandler, EventSpec, list, function, BaseVar] ] = None, + on_unmount: Optional[ + Union[EventHandler, EventSpec, list, function, BaseVar] + ] = None, **props ) -> "DataEditor": """Create the DataEditor component. diff --git a/reflex/components/el/elements/forms.py b/reflex/components/el/elements/forms.py index 750c4cdb4b..4f2cdcff6f 100644 --- a/reflex/components/el/elements/forms.py +++ b/reflex/components/el/elements/forms.py @@ -11,8 +11,8 @@ from reflex.components.tags.tag import Tag from reflex.constants import Dirs, EventTriggers from reflex.event import EventChain -from reflex.utils import imports from reflex.utils.format import format_event_chain +from reflex.utils.imports import ImportDict from reflex.vars import BaseVar, Var from .base import BaseHTML @@ -169,17 +169,16 @@ def create(cls, *children, **props): ).hexdigest() return form - def _get_imports(self) -> imports.ImportDict: - return imports.merge_imports( - super()._get_imports(), - { - "react": {imports.ImportVar(tag="useCallback")}, - f"/{Dirs.STATE_PATH}": { - imports.ImportVar(tag="getRefValue"), - imports.ImportVar(tag="getRefValues"), - }, - }, - ) + def add_imports(self) -> ImportDict: + """Add imports needed by the form component. + + Returns: + The imports for the form component. + """ + return { + "react": "useCallback", + f"/{Dirs.STATE_PATH}": ["getRefValue", "getRefValues"], + } def add_hooks(self) -> list[str]: """Add hooks for the form. diff --git a/reflex/components/el/elements/forms.pyi b/reflex/components/el/elements/forms.pyi index 4655f3d0a3..e4ee74f32d 100644 --- a/reflex/components/el/elements/forms.pyi +++ b/reflex/components/el/elements/forms.pyi @@ -14,8 +14,8 @@ from reflex.components.el.element import Element from reflex.components.tags.tag import Tag from reflex.constants import Dirs, EventTriggers from reflex.event import EventChain -from reflex.utils import imports from reflex.utils.format import format_event_chain +from reflex.utils.imports import ImportDict from reflex.vars import BaseVar, Var from .base import BaseHTML @@ -581,6 +581,7 @@ class Form(BaseHTML): The form component. """ ... + def add_imports(self) -> ImportDict: ... def add_hooks(self) -> list[str]: ... class Input(BaseHTML): diff --git a/reflex/components/gridjs/datatable.py b/reflex/components/gridjs/datatable.py index fd0a220212..6d856cf454 100644 --- a/reflex/components/gridjs/datatable.py +++ b/reflex/components/gridjs/datatable.py @@ -6,7 +6,8 @@ from reflex.components.component import Component from reflex.components.tags import Tag -from reflex.utils import imports, types +from reflex.utils import types +from reflex.utils.imports import ImportDict from reflex.utils.serializers import serialize from reflex.vars import BaseVar, ComputedVar, Var @@ -102,11 +103,13 @@ def create(cls, *children, **props): **props, ) - def _get_imports(self) -> imports.ImportDict: - return imports.merge_imports( - super()._get_imports(), - {"": {imports.ImportVar(tag="gridjs/dist/theme/mermaid.css")}}, - ) + def add_imports(self) -> ImportDict: + """Add the imports for the datatable component. + + Returns: + The import dict for the component. + """ + return {"": "gridjs/dist/theme/mermaid.css"} def _render(self) -> Tag: if isinstance(self.data, Var) and types.is_dataframe(self.data._var_type): diff --git a/reflex/components/gridjs/datatable.pyi b/reflex/components/gridjs/datatable.pyi index 9a401b63f8..13a7131e71 100644 --- a/reflex/components/gridjs/datatable.pyi +++ b/reflex/components/gridjs/datatable.pyi @@ -10,7 +10,8 @@ from reflex.style import Style from typing import Any, Dict, List, Union from reflex.components.component import Component from reflex.components.tags import Tag -from reflex.utils import imports, types +from reflex.utils import types +from reflex.utils.imports import ImportDict from reflex.utils.serializers import serialize from reflex.vars import BaseVar, ComputedVar, Var @@ -180,3 +181,4 @@ class DataTable(Gridjs): ValueError: If a pandas dataframe is passed in and columns are also provided. """ ... + def add_imports(self) -> ImportDict: ... diff --git a/reflex/components/markdown/markdown.py b/reflex/components/markdown/markdown.py index 68ed6b42f5..d9a62b98da 100644 --- a/reflex/components/markdown/markdown.py +++ b/reflex/components/markdown/markdown.py @@ -7,7 +7,6 @@ from hashlib import md5 from typing import Any, Callable, Dict, Union -from reflex.compiler import utils from reflex.components.component import Component, CustomComponent from reflex.components.radix.themes.layout.list import ( ListItem, @@ -18,8 +17,8 @@ from reflex.components.radix.themes.typography.link import Link from reflex.components.radix.themes.typography.text import Text from reflex.components.tags.tag import Tag -from reflex.utils import imports, types -from reflex.utils.imports import ImportVar +from reflex.utils import types +from reflex.utils.imports import ImportDict, ImportVar from reflex.vars import Var # Special vars used in the component map. @@ -145,47 +144,41 @@ def _get_all_custom_components( return custom_components - def _get_imports(self) -> imports.ImportDict: - # Import here to avoid circular imports. + def add_imports(self) -> ImportDict | list[ImportDict]: + """Add imports for the markdown component. + + Returns: + The imports for the markdown component. + """ from reflex.components.datadisplay.code import CodeBlock from reflex.components.radix.themes.typography.code import Code - imports = super()._get_imports() - - # Special markdown imports. - imports.update( + return [ { - "": [ImportVar(tag="katex/dist/katex.min.css")], - "remark-math@5.1.1": [ - ImportVar(tag=_REMARK_MATH._var_name, is_default=True) - ], - "remark-gfm@3.0.1": [ - ImportVar(tag=_REMARK_GFM._var_name, is_default=True) - ], - "remark-unwrap-images@4.0.0": [ - ImportVar(tag=_REMARK_UNWRAP_IMAGES._var_name, is_default=True) - ], - "rehype-katex@6.0.3": [ - ImportVar(tag=_REHYPE_KATEX._var_name, is_default=True) - ], - "rehype-raw@6.1.1": [ - ImportVar(tag=_REHYPE_RAW._var_name, is_default=True) - ], - } - ) - - # Get the imports for each component. - for component in self.component_map.values(): - imports = utils.merge_imports( - imports, component(_MOCK_ARG)._get_all_imports() - ) - - # Get the imports for the code components. - imports = utils.merge_imports( - imports, CodeBlock.create(theme="light")._get_imports() - ) - imports = utils.merge_imports(imports, Code.create()._get_imports()) - return imports + "": "katex/dist/katex.min.css", + "remark-math@5.1.1": ImportVar( + tag=_REMARK_MATH._var_name, is_default=True + ), + "remark-gfm@3.0.1": ImportVar( + tag=_REMARK_GFM._var_name, is_default=True + ), + "remark-unwrap-images@4.0.0": ImportVar( + tag=_REMARK_UNWRAP_IMAGES._var_name, is_default=True + ), + "rehype-katex@6.0.3": ImportVar( + tag=_REHYPE_KATEX._var_name, is_default=True + ), + "rehype-raw@6.1.1": ImportVar( + tag=_REHYPE_RAW._var_name, is_default=True + ), + }, + *[ + component(_MOCK_ARG)._get_imports() # type: ignore + for component in self.component_map.values() + ], + CodeBlock.create(theme="light")._get_imports(), # type: ignore, + Code.create()._get_imports(), # type: ignore, + ] def get_component(self, tag: str, **props) -> Component: """Get the component for a tag and props. diff --git a/reflex/components/markdown/markdown.pyi b/reflex/components/markdown/markdown.pyi index 8dec6400ac..5b2968bf5e 100644 --- a/reflex/components/markdown/markdown.pyi +++ b/reflex/components/markdown/markdown.pyi @@ -11,7 +11,6 @@ import textwrap from functools import lru_cache from hashlib import md5 from typing import Any, Callable, Dict, Union -from reflex.compiler import utils from reflex.components.component import Component, CustomComponent from reflex.components.radix.themes.layout.list import ( ListItem, @@ -22,8 +21,8 @@ from reflex.components.radix.themes.typography.heading import Heading from reflex.components.radix.themes.typography.link import Link from reflex.components.radix.themes.typography.text import Text from reflex.components.tags.tag import Tag -from reflex.utils import imports, types -from reflex.utils.imports import ImportVar +from reflex.utils import types +from reflex.utils.imports import ImportDict, ImportVar from reflex.vars import Var _CHILDREN = Var.create_safe("children", _var_is_local=False, _var_is_string=False) @@ -124,6 +123,7 @@ class Markdown(Component): The markdown component. """ ... + def add_imports(self) -> ImportDict | list[ImportDict]: ... def get_component(self, tag: str, **props) -> Component: ... def format_component(self, tag: str, **props) -> str: ... def format_component_map(self) -> dict[str, str]: ... diff --git a/reflex/components/moment/moment.py b/reflex/components/moment/moment.py index 53e199c4e8..4672444845 100644 --- a/reflex/components/moment/moment.py +++ b/reflex/components/moment/moment.py @@ -4,7 +4,7 @@ from reflex.base import Base from reflex.components.component import Component, NoSSRComponent -from reflex.utils import imports +from reflex.utils.imports import ImportDict from reflex.vars import Var @@ -90,14 +90,15 @@ class Moment(NoSSRComponent): # Display the date in the given timezone. tz: Var[str] - def _get_imports(self) -> imports.ImportDict: - merged_imports = super()._get_imports() + def add_imports(self) -> ImportDict: + """Add the imports for the Moment component. + + Returns: + The import dict for the component. + """ if self.tz is not None: - merged_imports = imports.merge_imports( - merged_imports, - {"moment-timezone": {imports.ImportVar(tag="")}}, - ) - return merged_imports + return {"moment-timezone": ""} + return {} def get_event_triggers(self) -> Dict[str, Any]: """Get the events triggers signatures for the component. diff --git a/reflex/components/moment/moment.pyi b/reflex/components/moment/moment.pyi index 73ad8ca5d7..b456baeb0a 100644 --- a/reflex/components/moment/moment.pyi +++ b/reflex/components/moment/moment.pyi @@ -10,7 +10,7 @@ from reflex.style import Style from typing import Any, Dict, List, Optional from reflex.base import Base from reflex.components.component import Component, NoSSRComponent -from reflex.utils import imports +from reflex.utils.imports import ImportDict from reflex.vars import Var class MomentDelta(Base): @@ -25,6 +25,7 @@ class MomentDelta(Base): milliseconds: Optional[int] class Moment(NoSSRComponent): + def add_imports(self) -> ImportDict: ... def get_event_triggers(self) -> Dict[str, Any]: ... @overload @classmethod diff --git a/reflex/components/radix/primitives/accordion.py b/reflex/components/radix/primitives/accordion.py index b2ea91a260..bea3810cfd 100644 --- a/reflex/components/radix/primitives/accordion.py +++ b/reflex/components/radix/primitives/accordion.py @@ -11,7 +11,6 @@ from reflex.components.radix.primitives.base import RadixPrimitiveComponent from reflex.components.radix.themes.base import LiteralAccentColor, LiteralRadius from reflex.style import Style -from reflex.utils import imports from reflex.vars import Var, get_uuid_string_var LiteralAccordionType = Literal["single", "multiple"] @@ -413,13 +412,13 @@ class AccordionContent(AccordionComponent): alias = "RadixAccordionContent" - def add_imports(self) -> imports.ImportDict: + def add_imports(self) -> dict: """Add imports to the component. Returns: The imports of the component. """ - return {"@emotion/react": [imports.ImportVar(tag="keyframes")]} + return {"@emotion/react": "keyframes"} @classmethod def create(cls, *children, **props) -> Component: diff --git a/reflex/components/radix/primitives/accordion.pyi b/reflex/components/radix/primitives/accordion.pyi index a04073d2d0..2b9276593c 100644 --- a/reflex/components/radix/primitives/accordion.pyi +++ b/reflex/components/radix/primitives/accordion.pyi @@ -15,7 +15,6 @@ from reflex.components.lucide.icon import Icon from reflex.components.radix.primitives.base import RadixPrimitiveComponent from reflex.components.radix.themes.base import LiteralAccentColor, LiteralRadius from reflex.style import Style -from reflex.utils import imports from reflex.vars import Var, get_uuid_string_var LiteralAccordionType = Literal["single", "multiple"] @@ -899,7 +898,7 @@ class AccordionIcon(Icon): ... class AccordionContent(AccordionComponent): - def add_imports(self) -> imports.ImportDict: ... + def add_imports(self) -> dict: ... @overload @classmethod def create( # type: ignore diff --git a/reflex/components/radix/themes/base.py b/reflex/components/radix/themes/base.py index cf61b1704a..e0e05cc816 100644 --- a/reflex/components/radix/themes/base.py +++ b/reflex/components/radix/themes/base.py @@ -7,7 +7,7 @@ from reflex.components import Component from reflex.components.tags import Tag from reflex.config import get_config -from reflex.utils.imports import ImportVar +from reflex.utils.imports import ImportDict, ImportVar from reflex.vars import Var LiteralAlign = Literal["start", "center", "end", "baseline", "stretch"] @@ -209,13 +209,13 @@ def create( children = [ThemePanel.create(), *children] return super().create(*children, **props) - def add_imports(self) -> dict[str, list[ImportVar] | ImportVar]: + def add_imports(self) -> ImportDict | list[ImportDict]: """Add imports for the Theme component. Returns: The import dict. """ - _imports: dict[str, list[ImportVar] | ImportVar] = { + _imports: ImportDict = { "/utils/theme.js": [ImportVar(tag="theme", is_default=True)], } if get_config().tailwind is None: diff --git a/reflex/components/radix/themes/base.pyi b/reflex/components/radix/themes/base.pyi index ba47ffee53..14ed8f7c30 100644 --- a/reflex/components/radix/themes/base.pyi +++ b/reflex/components/radix/themes/base.pyi @@ -11,7 +11,7 @@ from typing import Any, Dict, Literal from reflex.components import Component from reflex.components.tags import Tag from reflex.config import get_config -from reflex.utils.imports import ImportVar +from reflex.utils.imports import ImportDict, ImportVar from reflex.vars import Var LiteralAlign = Literal["start", "center", "end", "baseline", "stretch"] @@ -580,7 +580,7 @@ class Theme(RadixThemesComponent): A new component instance. """ ... - def add_imports(self) -> dict[str, list[ImportVar] | ImportVar]: ... + def add_imports(self) -> ImportDict | list[ImportDict]: ... class ThemePanel(RadixThemesComponent): def add_imports(self) -> dict[str, str]: ... diff --git a/reflex/components/radix/themes/typography/link.py b/reflex/components/radix/themes/typography/link.py index 4df8a71f5a..f1b82ae913 100644 --- a/reflex/components/radix/themes/typography/link.py +++ b/reflex/components/radix/themes/typography/link.py @@ -12,7 +12,7 @@ from reflex.components.core.cond import cond from reflex.components.el.elements.inline import A from reflex.components.next.link import NextLink -from reflex.utils import imports +from reflex.utils.imports import ImportDict from reflex.vars import Var from ..base import ( @@ -59,8 +59,13 @@ class Link(RadixThemesComponent, A, MemoizationLeaf): # If True, the link will open in a new tab is_external: Var[bool] - def _get_imports(self) -> imports.ImportDict: - return {**super()._get_imports(), **next_link._get_imports()} + def add_imports(self) -> ImportDict: + """Add imports for the Link component. + + Returns: + The import dict. + """ + return next_link._get_imports() # type: ignore @classmethod def create(cls, *children, **props) -> Component: diff --git a/reflex/components/radix/themes/typography/link.pyi b/reflex/components/radix/themes/typography/link.pyi index c203d5d379..e6ae3a6325 100644 --- a/reflex/components/radix/themes/typography/link.pyi +++ b/reflex/components/radix/themes/typography/link.pyi @@ -13,7 +13,7 @@ from reflex.components.core.colors import color from reflex.components.core.cond import cond from reflex.components.el.elements.inline import A from reflex.components.next.link import NextLink -from reflex.utils import imports +from reflex.utils.imports import ImportDict from reflex.vars import Var from ..base import LiteralAccentColor, RadixThemesComponent from .base import LiteralTextSize, LiteralTextTrim, LiteralTextWeight @@ -22,6 +22,7 @@ LiteralLinkUnderline = Literal["auto", "hover", "always", "none"] next_link = NextLink.create() class Link(RadixThemesComponent, A, MemoizationLeaf): + def add_imports(self) -> ImportDict: ... @overload @classmethod def create( # type: ignore diff --git a/reflex/components/suneditor/editor.py b/reflex/components/suneditor/editor.py index 92a1e80c37..432a557e5e 100644 --- a/reflex/components/suneditor/editor.py +++ b/reflex/components/suneditor/editor.py @@ -8,7 +8,7 @@ from reflex.components.component import Component, NoSSRComponent from reflex.constants import EventTriggers from reflex.utils.format import to_camel_case -from reflex.utils.imports import ImportVar +from reflex.utils.imports import ImportDict, ImportVar from reflex.vars import Var @@ -176,12 +176,15 @@ class Editor(NoSSRComponent): # default: False disable_toolbar: Var[bool] - def _get_imports(self): - imports = super()._get_imports() - imports[""] = [ - ImportVar(tag="suneditor/dist/css/suneditor.min.css", install=False) - ] - return imports + def add_imports(self) -> ImportDict: + """Add imports for the Editor component. + + Returns: + The import dict. + """ + return { + "": ImportVar(tag="suneditor/dist/css/suneditor.min.css", install=False) + } def get_event_triggers(self) -> Dict[str, Any]: """Get the event triggers that pass the component's value to the handler. diff --git a/reflex/components/suneditor/editor.pyi b/reflex/components/suneditor/editor.pyi index f878ef538b..0f29d27e6b 100644 --- a/reflex/components/suneditor/editor.pyi +++ b/reflex/components/suneditor/editor.pyi @@ -13,7 +13,7 @@ from reflex.base import Base from reflex.components.component import Component, NoSSRComponent from reflex.constants import EventTriggers from reflex.utils.format import to_camel_case -from reflex.utils.imports import ImportVar +from reflex.utils.imports import ImportDict, ImportVar from reflex.vars import Var class EditorButtonList(list, enum.Enum): @@ -48,6 +48,7 @@ class EditorOptions(Base): button_list: Optional[List[Union[List[str], str]]] class Editor(NoSSRComponent): + def add_imports(self) -> ImportDict: ... def get_event_triggers(self) -> Dict[str, Any]: ... @overload @classmethod diff --git a/reflex/utils/imports.py b/reflex/utils/imports.py index 263de1e3da..397c305ff9 100644 --- a/reflex/utils/imports.py +++ b/reflex/utils/imports.py @@ -3,12 +3,12 @@ from __future__ import annotations from collections import defaultdict -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union from reflex.base import Base -def merge_imports(*imports) -> ImportDict: +def merge_imports(*imports: ImportDict | ParsedImportDict) -> ParsedImportDict: """Merge multiple import dicts together. Args: @@ -24,7 +24,31 @@ def merge_imports(*imports) -> ImportDict: return all_imports -def collapse_imports(imports: ImportDict) -> ImportDict: +def parse_imports(imports: ImportDict | ParsedImportDict) -> ParsedImportDict: + """Parse the import dict into a standard format. + + Args: + imports: The import dict to parse. + + Returns: + The parsed import dict. + """ + + def _make_list(value: ImportTypes) -> list[str | ImportVar] | list[ImportVar]: + if isinstance(value, (str, ImportVar)): + return [value] + return value + + return { + package: [ + ImportVar(tag=tag) if isinstance(tag, str) else tag + for tag in _make_list(maybe_tags) + ] + for package, maybe_tags in imports.items() + } + + +def collapse_imports(imports: ParsedImportDict) -> ParsedImportDict: """Remove all duplicate ImportVar within an ImportDict. Args: @@ -33,7 +57,10 @@ def collapse_imports(imports: ImportDict) -> ImportDict: Returns: The collapsed import dict. """ - return {lib: list(set(import_vars)) for lib, import_vars in imports.items()} + return { + lib: list(set(import_vars)) if isinstance(import_vars, list) else import_vars + for lib, import_vars in imports.items() + } class ImportVar(Base): @@ -90,4 +117,6 @@ def __hash__(self) -> int: ) -ImportDict = Dict[str, List[ImportVar]] +ImportTypes = Union[str, ImportVar, List[Union[str, ImportVar]], List[ImportVar]] +ImportDict = Dict[str, ImportTypes] +ParsedImportDict = Dict[str, List[ImportVar]] diff --git a/reflex/vars.py b/reflex/vars.py index 438fa6c1cc..3b1c0019e5 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -39,7 +39,12 @@ from reflex.utils.exceptions import VarAttributeError, VarTypeError, VarValueError # This module used to export ImportVar itself, so we still import it for export here -from reflex.utils.imports import ImportDict, ImportVar +from reflex.utils.imports import ( + ImportDict, + ImportVar, + ParsedImportDict, + parse_imports, +) from reflex.utils.types import override if TYPE_CHECKING: @@ -120,7 +125,7 @@ class VarData(Base): state: str = "" # Imports needed to render this var - imports: ImportDict = {} + imports: ParsedImportDict = {} # Hooks that need to be present in the component to render this var hooks: Dict[str, None] = {} @@ -130,6 +135,19 @@ class VarData(Base): # segments. interpolations: List[Tuple[int, int]] = [] + def __init__( + self, imports: Union[ImportDict, ParsedImportDict] | None = None, **kwargs: Any + ): + """Initialize the var data. + + Args: + imports: The imports needed to render this var. + **kwargs: The var data fields. + """ + if imports: + kwargs["imports"] = parse_imports(imports) + super().__init__(**kwargs) + @classmethod def merge(cls, *others: VarData | None) -> VarData | None: """Merge multiple var data objects. diff --git a/reflex/vars.pyi b/reflex/vars.pyi index e95561318c..5cd322ec9f 100644 --- a/reflex/vars.pyi +++ b/reflex/vars.pyi @@ -10,7 +10,7 @@ from reflex.base import Base as Base from reflex.state import State as State from reflex.state import BaseState as BaseState from reflex.utils import console as console, format as format, types as types -from reflex.utils.imports import ImportVar +from reflex.utils.imports import ImportVar, ImportDict, ParsedImportDict from types import FunctionType from typing import ( Any, @@ -36,7 +36,7 @@ def _extract_var_data(value: Iterable) -> list[VarData | None]: ... class VarData(Base): state: str = "" - imports: dict[str, List[ImportVar]] = {} + imports: Union[ImportDict, ParsedImportDict] = {} hooks: Dict[str, None] = {} interpolations: List[Tuple[int, int]] = [] @classmethod diff --git a/tests/compiler/test_compiler.py b/tests/compiler/test_compiler.py index 20dc7dd312..63014cf33b 100644 --- a/tests/compiler/test_compiler.py +++ b/tests/compiler/test_compiler.py @@ -4,8 +4,7 @@ import pytest from reflex.compiler import compiler, utils -from reflex.utils import imports -from reflex.utils.imports import ImportVar +from reflex.utils.imports import ImportVar, ParsedImportDict @pytest.mark.parametrize( @@ -93,7 +92,7 @@ def test_compile_import_statement( ), ], ) -def test_compile_imports(import_dict: imports.ImportDict, test_dicts: List[dict]): +def test_compile_imports(import_dict: ParsedImportDict, test_dicts: List[dict]): """Test the compile_imports function. Args: diff --git a/tests/components/test_component.py b/tests/components/test_component.py index 2e395ce37f..76a75a67a7 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -20,7 +20,7 @@ from reflex.state import BaseState from reflex.style import Style from reflex.utils import imports -from reflex.utils.imports import ImportVar +from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports from reflex.vars import BaseVar, Var, VarData @@ -56,7 +56,7 @@ class TestComponent1(Component): # A test string/number prop. text_or_number: Var[Union[int, str]] - def _get_imports(self) -> imports.ImportDict: + def _get_imports(self) -> ParsedImportDict: return {"react": [ImportVar(tag="Component")]} def _get_custom_code(self) -> str: @@ -89,7 +89,7 @@ def get_event_triggers(self) -> Dict[str, Any]: "on_close": lambda e0: [e0], } - def _get_imports(self) -> imports.ImportDict: + def _get_imports(self) -> ParsedImportDict: return {"react-redux": [ImportVar(tag="connect")]} def _get_custom_code(self) -> str: @@ -1773,21 +1773,15 @@ def get_event_triggers(self) -> Dict[str, Any]: ), ) def test_component_add_imports(tags): - def _list_to_import_vars(tags: List[str]) -> List[ImportVar]: - return [ - ImportVar(tag=tag) if not isinstance(tag, ImportVar) else tag - for tag in tags - ] - class BaseComponent(Component): - def _get_imports(self) -> imports.ImportDict: + def _get_imports(self) -> ImportDict: return {} class Reference(Component): - def _get_imports(self) -> imports.ImportDict: + def _get_imports(self) -> ParsedImportDict: return imports.merge_imports( super()._get_imports(), - {"react": _list_to_import_vars(tags)}, + parse_imports({"react": tags}), {"foo": [ImportVar(tag="bar")]}, ) @@ -1806,10 +1800,12 @@ def add_imports( baseline = Reference.create() test = Test.create() - assert baseline._get_all_imports() == { - "react": _list_to_import_vars(tags), - "foo": [ImportVar(tag="bar")], - } + assert baseline._get_all_imports() == parse_imports( + { + "react": tags, + "foo": [ImportVar(tag="bar")], + } + ) assert test._get_all_imports() == baseline._get_all_imports() diff --git a/tests/utils/test_imports.py b/tests/utils/test_imports.py index 9a55371362..e9be5c1be0 100644 --- a/tests/utils/test_imports.py +++ b/tests/utils/test_imports.py @@ -1,6 +1,12 @@ import pytest -from reflex.utils.imports import ImportVar, merge_imports +from reflex.utils.imports import ( + ImportDict, + ImportVar, + ParsedImportDict, + merge_imports, + parse_imports, +) @pytest.mark.parametrize( @@ -76,3 +82,32 @@ def test_merge_imports(input_1, input_2, output): for key in output: assert set(res[key]) == set(output[key]) + + +@pytest.mark.parametrize( + "input, output", + [ + ({}, {}), + ( + {"react": "Component"}, + {"react": [ImportVar(tag="Component")]}, + ), + ( + {"react": ["Component"]}, + {"react": [ImportVar(tag="Component")]}, + ), + ( + {"react": ["Component", ImportVar(tag="useState")]}, + {"react": [ImportVar(tag="Component"), ImportVar(tag="useState")]}, + ), + ( + {"react": ["Component"], "foo": "anotherFunction"}, + { + "react": [ImportVar(tag="Component")], + "foo": [ImportVar(tag="anotherFunction")], + }, + ), + ], +) +def test_parse_imports(input: ImportDict, output: ParsedImportDict): + assert parse_imports(input) == output