From 9583d52166872aa94fc76646bbd3124e5b19bd25 Mon Sep 17 00:00:00 2001 From: Nikhil Rao Date: Thu, 12 Oct 2023 15:05:35 -0700 Subject: [PATCH] Fix custom components special props --- reflex/components/component.py | 19 ++++++++++------- reflex/components/typography/markdown.py | 26 ++++++++++++++++++++---- 2 files changed, 34 insertions(+), 11 deletions(-) diff --git a/reflex/components/component.py b/reflex/components/component.py index c79ba20595..b0a907b898 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -540,7 +540,7 @@ def validate_valid_child(child_name): if self.valid_children: validate_valid_child(name) - def _get_custom_code(self) -> Optional[str]: + def _get_custom_code(self) -> str | None: """Get custom code for the component. Returns: @@ -569,7 +569,7 @@ def get_custom_code(self) -> Set[str]: # Return the code. return code - def _get_dynamic_imports(self) -> Optional[str]: + def _get_dynamic_imports(self) -> str | None: """Get dynamic import for the component. Returns: @@ -667,7 +667,7 @@ def _get_hooks_internal(self) -> Set[str]: if hook ) - def _get_hooks(self) -> Optional[str]: + def _get_hooks(self) -> str | None: """Get the React hooks for this component. Downstream components should override this method to add their own hooks. @@ -697,7 +697,7 @@ def get_hooks(self) -> Set[str]: return code - def get_ref(self) -> Optional[str]: + def get_ref(self) -> str | None: """Get the name of the ref for the component. Returns: @@ -723,7 +723,7 @@ def get_refs(self) -> Set[str]: return refs def get_custom_components( - self, seen: Optional[Set[str]] = None + self, seen: set[str] | None = None ) -> Set[CustomComponent]: """Get all the custom components used by the component. @@ -846,7 +846,7 @@ def get_props(cls) -> Set[str]: return set() def get_custom_components( - self, seen: Optional[Set[str]] = None + self, seen: set[str] | None = None ) -> Set[CustomComponent]: """Get all the custom components used by the component. @@ -875,7 +875,10 @@ def _render(self) -> Tag: Returns: The tag to render. """ - return Tag(name=self.tag).add_props(**self.props) + return Tag( + name=self.tag if not self.alias else self.alias, + special_props=self.special_props, + ).add_props(**self.props) def get_prop_vars(self) -> List[BaseVar]: """Get the prop vars. @@ -914,6 +917,8 @@ def custom_component( @wraps(component_fn) def wrapper(*children, **props) -> CustomComponent: + # Remove the children from the props. + props.pop("children", None) return CustomComponent(component_fn=component_fn, children=children, **props) return wrapper diff --git a/reflex/components/typography/markdown.py b/reflex/components/typography/markdown.py index 9390112a8c..dc1c3712aa 100644 --- a/reflex/components/typography/markdown.py +++ b/reflex/components/typography/markdown.py @@ -6,7 +6,7 @@ from typing import Any, Callable, Dict, Union from reflex.compiler import utils -from reflex.components.component import Component +from reflex.components.component import Component, CustomComponent from reflex.components.datadisplay.list import ListItem, OrderedList, UnorderedList from reflex.components.navigation import Link from reflex.components.tags.tag import Tag @@ -19,6 +19,7 @@ # Special vars used in the component map. _CHILDREN = Var.create_safe("children", is_local=False) _PROPS = Var.create_safe("...props", is_local=False) +_MOCK_ARG = Var.create_safe("") # Special remark plugins. _REMARK_MATH = Var.create_safe("remarkMath", is_local=False) @@ -122,6 +123,25 @@ def create(cls, *children, **props) -> Component: # Create the component. return super().create(src, component_map=component_map, **props) + def get_custom_components( + self, seen: set[str] | None = None + ) -> set[CustomComponent]: + """Get all the custom components used by the component. + + Args: + seen: The tags of the components that have already been seen. + + Returns: + The set of custom components. + """ + custom_components = super().get_custom_components(seen=seen) + + # Get the custom components for each tag. + for component in self.component_map.values(): + custom_components |= component(_MOCK_ARG).get_custom_components(seen=seen) + + return custom_components + def _get_imports(self) -> imports.ImportDict: # Import here to avoid circular imports. from reflex.components.datadisplay.code import Code, CodeBlock @@ -145,9 +165,7 @@ def _get_imports(self) -> imports.ImportDict: # Get the imports for each component. for component in self.component_map.values(): - imports = utils.merge_imports( - imports, component(Var.create("")).get_imports() - ) + imports = utils.merge_imports(imports, component(_MOCK_ARG).get_imports()) # Get the imports for the code components. imports = utils.merge_imports(