Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add dataframe support in data_editor #2091

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion reflex/.templates/web/utils/helpers/dataeditor.js
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ export function formatCell(value, column) {
switch (column.type) {
case "int":
case "float":
case "int64":
case "float64":
return {
kind: GridCellKind.Number,
data: value,
Expand Down Expand Up @@ -64,4 +66,4 @@ export function formatDataEditorCells(col, row, columns, data) {
return formatCell(cellData, column);
}
return { kind: GridCellKind.Loading };
}
}
173 changes: 116 additions & 57 deletions reflex/components/datadisplay/dataeditor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Union

from pandas import DataFrame

from reflex.base import Base
from reflex.components.component import Component, NoSSRComponent
from reflex.components.literals import LiteralRowMarker
from reflex.components.tags.tag import Tag
from reflex.utils import console, format, imports, types
from reflex.utils.serializers import serializer
from reflex.vars import ImportVar, Var, get_unique_variable_name
from reflex.utils.serializers import serialize, serializer
from reflex.vars import BaseVar, ComputedVar, ImportVar, Var, get_unique_variable_name


# TODO: Fix the serialization issue for custom types.
Expand Down Expand Up @@ -114,10 +117,10 @@ class DataEditor(NoSSRComponent):
rows: Var[int]

# Headers of the columns for the data grid.
columns: Var[List[Dict[str, Any]]]
columns: Var[Union[List[Dict[str, Any]], DataFrame]]

# The data.
data: Var[List[List[Any]]]
data: Var[Union[List[List[Any]], List[Dict[str, Any]], DataFrame]]

# The name of the callback used to find the data to display.
get_cell_content: Var[str]
Expand Down Expand Up @@ -200,6 +203,9 @@ class DataEditor(NoSSRComponent):
# global theme
theme: Var[Union[DataEditorTheme, Dict]]

# internal_value
_editor_id: str = ""

def _get_imports(self):
return imports.merge_imports(
super()._get_imports(),
Expand Down Expand Up @@ -248,17 +254,18 @@ def edit_sig(pos, data: dict[str, Any]):
}

def _get_hooks(self) -> str | None:
# Define the id of the component in case multiple are used in the same page.
editor_id = get_unique_variable_name()

# Define the name of the getData callback associated with this component and assign to get_cell_content.
data_callback = f"getData_{editor_id}"
data_callback = f"getData_{self._editor_id}"
self.get_cell_content = Var.create(data_callback, _var_is_local=False) # type: ignore

code = [f"function {data_callback}([col, row])" "{"]

columns_path = f"{self.columns._var_full_name}"
data_path = f"{self.data._var_full_name}"
if issubclass(self.data._var_type, DataFrame):
columns_path = f"{self.data._var_full_name}.columns"
data_path = f"{self.data._var_full_name}.data"
else:
columns_path = f"{self.columns._var_full_name}"
data_path = f"{self.data._var_full_name}"

code.extend(
[
Expand All @@ -285,31 +292,48 @@ def create(cls, *children, **props) -> Component:
"""
from reflex.el.elements import Div

columns = props.get("columns", [])
columns = props.get("columns", None)
data = props.get("data", [])
rows = props.get("rows", None)

# If rows is not provided, determine from data.
if rows is None:
props["rows"] = (
data.length() # BaseVar.create(value=f"{data}.length()", is_local=False)
if isinstance(data, Var)
else len(data)
# check ComputerVar return type annotation for data and columns props.
if isinstance(data, ComputedVar) and data._var_type == Any:
raise ValueError(
f"Return type annotation for the computed var {data._var_full_name} should be provided."
)

if (
columns is not None
and isinstance(columns, ComputedVar)
and columns._var_type == Any
):
raise ValueError(
f"Return type annotation for the computed var {columns._var_full_name} should be provided."
)

if not isinstance(columns, Var) and len(columns):
if (
types.is_dataframe(type(data))
or isinstance(data, Var)
and types.is_dataframe(data._var_type)
):
if types.is_dataframe(type(columns)) or (
isinstance(columns, Var) and types.is_dataframe(columns._var_type)
):
raise ValueError("DataFrame should be passed to the `data` props instead.")

if types.is_dataframe(type(data)) or (
isinstance(data, Var) and types.is_dataframe(data._var_type)
):
if columns is not None:
raise ValueError(
"Cannot pass in both a pandas dataframe and columns to the data_editor component."
"Cannot pass in both a pandas dataframe and columns to the data_table component."
)
else:
props["columns"] = [
format.format_data_editor_column(col) for col in columns
]
props["columns"] = data

# If rows is not provided, determine from data.
if rows is None:
props["rows"] = data.length() if isinstance(data, Var) else len(data)

if isinstance(columns, List):
props["columns"] = [
format.format_data_editor_column(col) for col in columns
]

if "theme" in props:
theme = props.get("theme")
Expand All @@ -326,13 +350,38 @@ def create(cls, *children, **props) -> Component:
console.warn(
"get_cell_content is not user configurable, the provided value will be discarded"
)
grid = super().create(*children, **props)
editor = super().create(*children, **props)

# Define the id of the component in case multiple are used in the same page.
editor._editor_id = get_unique_variable_name()
return Div.create(
grid,
editor,
width=props.pop("width", "100%"),
height=props.pop("height", "100%"),
)

def _render(self) -> Tag:
if isinstance(self.data, Var) and types.is_dataframe(self.data._var_type):
self.columns = BaseVar(
_var_name=f"{self.data._var_name}.columns",
_var_type=List[Any],
_var_state=self.data._var_state,
)
self.data = BaseVar(
_var_name=f"{self.data._var_name}.data",
_var_type=List[List[Any]],
_var_state=self.data._var_state,
)
if types.is_dataframe(type(self.data)):
# If given a pandas df break up the data and columns
data = serialize(self.data)
assert isinstance(data, dict), "Serialized dataframe should be a dict."
self.columns = Var.create_safe(data["columns"])
self.data = Var.create_safe(data["data"])

# Render the table.
return super()._render()

def _get_app_wrap_components(self) -> dict[tuple[int, str], Component]:
"""Get the app wrap components for the component.

Expand All @@ -348,41 +397,51 @@ def get_ref(self):
return {(-1, "DataEditorPortal"): Portal.create(id="portal")}


# try:
# pass
def format_dataframe_columns(df: DataFrame) -> List[Dict[str, str]]:
"""Format dataframe columns to a list of dicts.

Args:
df: The dataframe to format.

Returns:
The dataframe columns as a list of dicts.
"""
formatted_columns = [
format.format_data_editor_column({"title": column_, "type": str(type_)})
for column_, type_ in df.dtypes.items()
]
return formatted_columns


# # def format_dataframe_values(df: DataFrame) -> list[list[Any]]:
# # """Format dataframe values to a list of lists.
def format_dataframe_values(df: DataFrame) -> list[list[Any]]:
"""Format dataframe values to a list of lists.

# # Args:
# # df: The dataframe to format.
Args:
df: The dataframe to format.

# # Returns:
# # The dataframe as a list of lists.
# # """
# # return [
# # [str(d) if isinstance(d, (list, tuple)) else d for d in data]
# # for data in list(df.values.tolist())
# # ]
# # ...
Returns:
The dataframe as a list of lists.
"""
return [
[str(d) if isinstance(d, (list, tuple)) else d for d in data]
for data in list(df.values.tolist())
]

# # @serializer
# # def serialize_dataframe(df: DataFrame) -> dict:
# # """Serialize a pandas dataframe.

# # Args:
# # df: The dataframe to serialize.
@serializer(override=True)
def serialize_dataframe(df: DataFrame) -> dict:
"""Serialize a pandas dataframe.

# # Returns:
# # The serialized dataframe.
# # """
# # return {
# # "columns": df.columns.tolist(),
# # "data": format_dataframe_values(df),
# # }
Args:
df: The dataframe to serialize.

# except ImportError:
# pass
Returns:
The serialized dataframe.
"""
return {
"columns": format_dataframe_columns(df),
"data": format_dataframe_values(df),
}


@serializer
Expand Down
2 changes: 1 addition & 1 deletion reflex/utils/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ def format_data_editor_column(col: str | dict):
if isinstance(col, (dict,)):
if "id" not in col:
col["id"] = col["title"].lower()
if "type" not in col:
if "type" not in col or col["type"] == "object":
col["type"] = "str"
if "overlayIcon" not in col:
col["overlayIcon"] = None
Expand Down
76 changes: 44 additions & 32 deletions reflex/utils/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any, Callable, Dict, List, Set, Tuple, Type, Union, get_type_hints

from reflex.base import Base
from reflex.utils import exceptions, format, types
from reflex.utils import console, exceptions, format, types

# Mapping from type to a serializer.
# The serializer should convert the type to a JSON object.
Expand All @@ -16,44 +16,56 @@
SERIALIZERS: dict[Type, Serializer] = {}


def serializer(fn: Serializer) -> Serializer:
def serializer(override: Any = False) -> Any: # type: ignore
"""Decorator to add a serializer for a given type.

Args:
fn: The function to decorate.
override: If the serializer can override an already defined one of the same type.

Returns:
The decorated function.

Raises:
ValueError: If the function does not take a single argument.
The serializer decorator.
"""
# Get the global serializers.
global SERIALIZERS

# Check the type hints to get the type of the argument.
type_hints = get_type_hints(fn)
args = [arg for arg in type_hints if arg != "return"]

# Make sure the function takes a single argument.
if len(args) != 1:
raise ValueError("Serializer must take a single argument.")

# Get the type of the argument.
type_ = type_hints[args[0]]

# Make sure the type is not already registered.
registered_fn = SERIALIZERS.get(type_)
if registered_fn is not None and registered_fn != fn:
raise ValueError(
f"Serializer for type {type_} is already registered as {registered_fn.__qualname__}."
)

# Register the serializer.
SERIALIZERS[type_] = fn

# Return the function.
return fn
def inner_serializer(fn: Serializer):
# Get the global serializers.
global SERIALIZERS

# Check the type hints to get the type of the argument.
type_hints = get_type_hints(fn)
args = [arg for arg in type_hints if arg != "return"]

# Make sure the function takes a single argument.
if len(args) != 1:
raise ValueError("Serializer must take a single argument.")

# Get the type of the argument.
type_ = type_hints[args[0]]

# Make sure the type is not already registered.
registered_fn = SERIALIZERS.get(type_)
if registered_fn is not None and registered_fn != fn:
if override:
console.warn(
f"Overriding serializer for type {type_}: Replacing {registered_fn.__module__}.{registered_fn.__qualname__} with {fn.__module__}.{fn.__qualname__}"
)
else:
console.warn(
f"Serializer for type {type_} is already registered as {registered_fn.__module__}.{registered_fn.__qualname__}."
)
return fn

# Register the serializer.
SERIALIZERS[type_] = fn

# Return the function.
return fn

if callable(override):
_fn = override
override = False
return inner_serializer(_fn)
else:
return inner_serializer


def serialize(value: Any) -> SerializedType | None:
Expand Down
6 changes: 5 additions & 1 deletion reflex/vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
get_type_hints,
)

from pandas import DataFrame

from reflex import constants
from reflex.base import Base
from reflex.utils import console, format, serializers, types
Expand Down Expand Up @@ -599,7 +601,9 @@ def length(self) -> Var:
Raises:
TypeError: If the var is not a list.
"""
if not types._issubclass(self._var_type, List):
if not types._issubclass(self._var_type, List) and not types._issubclass(
self._var_type, DataFrame
):
raise TypeError(f"Cannot get length of non-list var {self}.")
return BaseVar(
_var_name=f"{self._var_full_name}.length",
Expand Down
Loading