Skip to content

Commit

Permalink
allow classname to be state vars (#3991)
Browse files Browse the repository at this point in the history
* allow classname to be state vars

* simplify join with all literal string vars

* add test case and avoid concat var operation if it's not necessary

* remove silly print statement

* simplify case where there's no var

* don't automatically do class name str to literal var
  • Loading branch information
adhami3310 authored Sep 25, 2024
1 parent 982c43d commit 74d1c47
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 2 deletions.
8 changes: 7 additions & 1 deletion reflex/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
)
from reflex.vars import VarData
from reflex.vars.base import LiteralVar, Var
from reflex.vars.sequence import LiteralArrayVar


class BaseComponent(Base, ABC):
Expand Down Expand Up @@ -496,7 +497,12 @@ def __init__(self, *args, **kwargs):
# Convert class_name to str if it's list
class_name = kwargs.get("class_name", "")
if isinstance(class_name, (List, tuple)):
kwargs["class_name"] = " ".join(class_name)
if any(isinstance(c, Var) for c in class_name):
kwargs["class_name"] = LiteralArrayVar.create(
class_name, _var_type=List[str]
).join(" ")
else:
kwargs["class_name"] = " ".join(class_name)

# Construct the component.
super().__init__(*args, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion reflex/components/radix/themes/layout/stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def create(
"""
# Apply the default classname
given_class_name = props.pop("class_name", [])
if isinstance(given_class_name, str):
if not isinstance(given_class_name, list):
given_class_name = [given_class_name]
props["class_name"] = ["rx-Stack", *given_class_name]

Expand Down
43 changes: 43 additions & 0 deletions reflex/vars/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,29 @@ def create(
else:
return only_string.to(StringVar, only_string._var_type)

if len(
literal_strings := [
s
for s in filtered_strings_and_vals
if isinstance(s, (str, LiteralStringVar))
]
) == len(filtered_strings_and_vals):
return LiteralStringVar.create(
"".join(
s._var_value if isinstance(s, LiteralStringVar) else s
for s in literal_strings
),
_var_type=_var_type,
_var_data=VarData.merge(
_var_data,
*(
s._get_all_var_data()
for s in filtered_strings_and_vals
if isinstance(s, Var)
),
),
)

concat_result = ConcatVarOperation.create(
*filtered_strings_and_vals,
_var_data=_var_data,
Expand Down Expand Up @@ -736,6 +759,26 @@ def join(self, sep: Any = "") -> StringVar:
"""
if not isinstance(sep, (StringVar, str)):
raise_unsupported_operand_types("join", (type(self), type(sep)))
if (
isinstance(self, LiteralArrayVar)
and (
len(
args := [
x
for x in self._var_value
if isinstance(x, (LiteralStringVar, str))
]
)
== len(self._var_value)
)
and isinstance(sep, (LiteralStringVar, str))
):
sep_str = sep._var_value if isinstance(sep, LiteralStringVar) else sep
return LiteralStringVar.create(
sep_str.join(
i._var_value if isinstance(i, LiteralStringVar) else i for i in args
)
)
return array_join_operation(self, sep)

def reverse(self) -> ArrayVar[ARRAY_VAR_TYPE]:
Expand Down
10 changes: 10 additions & 0 deletions tests/components/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -1288,6 +1288,16 @@ def handler2(self, arg):
[FORMATTED_TEST_VAR],
id="fstring-class_name",
),
pytest.param(
rx.fragment(class_name=f"foo{TEST_VAR}bar other-class"),
[LiteralVar.create(f"{FORMATTED_TEST_VAR} other-class")],
id="fstring-dual-class_name",
),
pytest.param(
rx.fragment(class_name=[TEST_VAR, "other-class"]),
[LiteralVar.create([TEST_VAR, "other-class"]).join(" ")],
id="fstring-dual-class_name",
),
pytest.param(
rx.fragment(special_props=[TEST_VAR]),
[TEST_VAR],
Expand Down

0 comments on commit 74d1c47

Please sign in to comment.