only interchange necessary columns #4286

Merged: 5 commits, Jul 21, 2023
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -23,7 +23,7 @@ This project adheres to [Semantic Versioning](http://semver.org/).
 this feature was anonymously sponsored: thank you to our sponsor!
 - Add `legend.xref` and `legend.yref` to enable container-referenced positioning of legends [[#6589](https://github.com/plotly/plotly.js/pull/6589)], with thanks to [Gamma Technologies](https://www.gtisoft.com/) for sponsoring the related development.
 - Add `colorbar.xref` and `colorbar.yref` to enable container-referenced positioning of colorbars [[#6593](https://github.com/plotly/plotly.js/pull/6593)], with thanks to [Gamma Technologies](https://www.gtisoft.com/) for sponsoring the related development.
-- `px` methods now accept data-frame-like objects that support a `to_pandas()` method, such as polars, cudf, vaex etc
+- `px` methods now accept data-frame-like objects that support a `to_pandas()` method, such as polars, cudf, vaex etc [[#4244](https://github.com/plotly/plotly.py/pull/4244)], [[#4286](https://github.com/plotly/plotly.py/pull/4286)]

 ### Fixed
 - Fixed another compatibility issue with Pandas 2.0, just affecting `px.*(line_close=True)` [[#4190](https://github.com/plotly/plotly.py/pull/4190)]
83 changes: 56 additions & 27 deletions packages/python/plotly/plotly/express/_core.py
@@ -1018,7 +1018,7 @@ def _get_reserved_col_names(args):
     return reserved_names


-def _is_col_list(df_input, arg):
+def _is_col_list(columns, arg):
     """Returns True if arg looks like it's a list of columns or references to columns
     in df_input, and False otherwise (in which case it's assumed to be a single column
     or reference to a column).
@@ -1033,7 +1033,7 @@ def _is_col_list(df_input, arg):
         return False  # not iterable
     for c in arg:
         if isinstance(c, str) or isinstance(c, int):
-            if df_input is None or c not in df_input.columns:
+            if columns is None or c not in columns:
                 return False
         else:
             try:
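
Note on the change above: after the rename, `_is_col_list` only needs a container that supports membership tests, so whatever `column_names()` returns can be used as well as a pandas `Index`. A minimal standalone sketch of that check, assuming a plain list of names and omitting the `pd.MultiIndex` guard. The name `is_col_list` is illustrative and is not the library function.

```python
def is_col_list(columns, arg):
    """Simplified stand-in for px's `_is_col_list`; `columns` is any container of names."""
    if arg is None or isinstance(arg, (str, int)):
        return False  # a single column reference, not a list of them
    try:
        iter(arg)
    except TypeError:
        return False  # not iterable
    for c in arg:
        if isinstance(c, (str, int)):
            # a name only counts if it refers to a known column
            if columns is None or c not in columns:
                return False
        else:
            try:
                iter(c)  # otherwise each element must itself be iterable data
            except TypeError:
                return False
    return True


# Assumed example: a plain list, as an interchange object might report its columns
columns = ["petal_width", "sepal_length", "sepal_width"]
wide = is_col_list(columns, ["petal_width", "sepal_length"])
single = is_col_list(columns, "petal_width")
unknown = is_col_list(columns, ["petal_width", "not_a_column"])
```

Only the first call is treated as a list of columns; a single name or a list containing an unknown name is not.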
@@ -1059,8 +1059,8 @@ def _isinstance_listlike(x):
         return True


-def _escape_col_name(df_input, col_name, extra):
-    while df_input is not None and (col_name in df_input.columns or col_name in extra):
+def _escape_col_name(columns, col_name, extra):
+    while columns is not None and (col_name in columns or col_name in extra):
         col_name = "_" + col_name
     return col_name
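
For reference, the renamed helper behaves the same on a plain list of names as it did on `df_input.columns`. A small standalone copy (the un-prefixed name `escape_col_name` and the sample column lists are assumptions for illustration):

```python
def escape_col_name(columns, col_name, extra):
    """Prefix underscores until the name clashes with neither `columns` nor `extra`."""
    while columns is not None and (col_name in columns or col_name in extra):
        col_name = "_" + col_name
    return col_name


plain = escape_col_name(["a", "b"], "value", [])
escaped = escape_col_name(["value", "_value"], "value", [])
# With no data frame (columns is None) the loop never runs, even if `extra` clashes.
no_frame = escape_col_name(None, "value", ["value"])
```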

@@ -1307,37 +1307,36 @@ def build_dataframe(args, constructor):

     # Cast data_frame argument to DataFrame (it could be a numpy array, dict etc.)
     df_provided = args["data_frame"] is not None
+    needs_interchanging = False
     if df_provided and not isinstance(args["data_frame"], pd.DataFrame):
         if hasattr(args["data_frame"], "__dataframe__") and version.parse(
             pd.__version__
         ) >= version.parse("2.0.2"):
             import pandas.api.interchange

             df_not_pandas = args["data_frame"]
-            try:
-                df_pandas = pandas.api.interchange.from_dataframe(df_not_pandas)
-            except (ImportError, NotImplementedError) as exc:
-                # temporary workaround; developers of third-party libraries themselves
-                # should try a different implementation, if available. For example:
-                # def __dataframe__(self, ...):
-                #   if not some_condition:
-                #     self.to_pandas(...)
-                if not hasattr(df_not_pandas, "to_pandas"):
-                    raise exc
-                df_pandas = df_not_pandas.to_pandas()
-            args["data_frame"] = df_pandas
+            args["data_frame"] = df_not_pandas.__dataframe__()
+            columns = args["data_frame"].column_names()

Collaborator:

Can we be 100% sure anything returned by `__dataframe__()` will have `column_names` and `select_columns_by_name` methods? If there's any chance an object will come in with either of these missing we should fall back on interchanging the whole thing up front.

Contributor Author: [reply not preserved]

Collaborator:

Oh I know they're in the spec, but I also know not everyone follows a spec to the letter 😉

Contributor Author:

this can help highlight shortcomings in their implementation then 😉 I tried it out with polars and it works fine there

Collaborator:

I guess we'll know how to respond when we see:
AttributeError: 'MyDataFrame' object has no attribute 'select_columns_by_name'
And in principle you're right that it's not our problem, but we'll be the ones responding to the issue and having to tell our users "don't use this dataframe directly until they fix it." Whereas if we caught this case explicitly we could emit a warning like "This dataframe only partially implements the dataframe interchange protocol. Falling back on a slower full-copy algorithm" so it wouldn't affect usage in px, only performance, and it would be clear where the issue needs to be raised.

Contributor Author:

thanks for explaining - OK I've added a condition so it'll only use select_columns_by_name if that attribute is present

Collaborator:

Nice. I took a look at also adding a fallback for missing column_names and that would be pretty awkward... but if someone has a partial implementation of the protocol presumably column_names is an easy piece so would get included early, whereas select_columns_by_name could be trickier. So let's leave it as you have it now. Thanks!

Contributor Author:

yeah if they don't have column_names then from_dataframe wouldn't work either, as it uses that internally

https://github.com/pandas-dev/pandas/blob/92792ec063031ae41443dabeb9d12f8daaac3ef1/pandas/core/interchange/from_dataframe.py#L112
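
The thread settles on a `hasattr` check for `select_columns_by_name`; the warning suggested earlier does not appear in the diff shown here. As a rough sketch of how the two ideas could be combined, assuming hypothetical names throughout (`pick_interchange_object`, `PartialFrame` and `FullFrame` are illustrative and not plotly code):

```python
import warnings


def pick_interchange_object(xdf, wanted, wide_mode=False):
    """Return the object to hand to the converter: a column subset when possible.

    `xdf` stands for whatever `__dataframe__()` returned.
    """
    if wide_mode:
        return xdf  # wide mode plots every column, so there is nothing to trim
    if not hasattr(xdf, "select_columns_by_name"):
        # Hypothetical warning, using the wording proposed in the review thread
        warnings.warn(
            "This dataframe only partially implements the dataframe interchange "
            "protocol. Falling back on a slower full-copy algorithm.",
            stacklevel=2,
        )
        return xdf
    return xdf.select_columns_by_name(wanted)


class PartialFrame:
    """Hypothetical partial implementation: names only, no column selection."""

    def __init__(self, names):
        self._names = list(names)

    def column_names(self):
        return self._names


class FullFrame(PartialFrame):
    """Hypothetical fuller implementation that can select columns by name."""

    def select_columns_by_name(self, names):
        return FullFrame([n for n in self._names if n in names])


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    partial_result = pick_interchange_object(PartialFrame(["a", "b", "c"]), ["a", "b"])
    full_result = pick_interchange_object(FullFrame(["a", "b", "c"]), ["a", "b"])
```

With this shape, a partial implementation still plots (at the cost of a full copy) and the warning points users at the library that needs fixing.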

+            needs_interchanging = True
         elif hasattr(args["data_frame"], "to_pandas"):
             args["data_frame"] = args["data_frame"].to_pandas()
+            columns = args["data_frame"].columns
         else:
             args["data_frame"] = pd.DataFrame(args["data_frame"])
+            columns = args["data_frame"].columns
+    elif df_provided:
+        columns = args["data_frame"].columns
+    else:
+        columns = None
+
     df_input = args["data_frame"]

     # now we handle special cases like wide-mode or x-xor-y specification
     # by rearranging args to tee things up for process_args_into_dataframe to work
     no_x = args.get("x") is None
     no_y = args.get("y") is None
-    wide_x = False if no_x else _is_col_list(df_input, args["x"])
-    wide_y = False if no_y else _is_col_list(df_input, args["y"])
+    wide_x = False if no_x else _is_col_list(columns, args["x"])
+    wide_y = False if no_y else _is_col_list(columns, args["y"])

     wide_mode = False
     var_name = None  # will likely be "variable" in wide_mode
@@ -1352,15 +1351,18 @@ def build_dataframe(args, constructor):
             )
         if df_provided and no_x and no_y:
             wide_mode = True
-            if isinstance(df_input.columns, pd.MultiIndex):
+            if isinstance(columns, pd.MultiIndex):
                 raise TypeError(
                     "Data frame columns is a pandas MultiIndex. "
                     "pandas MultiIndex is not supported by plotly express "
                     "at the moment."
                 )
-            args["wide_variable"] = list(df_input.columns)
-            var_name = df_input.columns.name
-            if var_name in [None, "value", "index"] or var_name in df_input:
+            args["wide_variable"] = list(columns)
+            if isinstance(columns, pd.Index):
+                var_name = columns.name
+            else:
+                var_name = None
+            if var_name in [None, "value", "index"] or var_name in columns:
                 var_name = "variable"
             if constructor == go.Funnel:
                 wide_orientation = args.get("orientation") or "h"
@@ -1371,12 +1373,12 @@ def build_dataframe(args, constructor):
         elif wide_x != wide_y:
             wide_mode = True
             args["wide_variable"] = args["y"] if wide_y else args["x"]
-            if df_provided and args["wide_variable"] is df_input.columns:
-                var_name = df_input.columns.name
+            if df_provided and args["wide_variable"] is columns:
+                var_name = columns.name
             if isinstance(args["wide_variable"], pd.Index):
                 args["wide_variable"] = list(args["wide_variable"])
             if var_name in [None, "value", "index"] or (
-                df_provided and var_name in df_input
+                df_provided and var_name in columns
             ):
                 var_name = "variable"
             if hist1d_orientation:
@@ -1389,8 +1391,35 @@ def build_dataframe(args, constructor):
                 wide_cross_name = "__x__" if wide_y else "__y__"

     if wide_mode:
-        value_name = _escape_col_name(df_input, "value", [])
-        var_name = _escape_col_name(df_input, var_name, [])
+        value_name = _escape_col_name(columns, "value", [])
+        var_name = _escape_col_name(columns, var_name, [])
+
+    if needs_interchanging:
+        try:
+            if wide_mode or not hasattr(args["data_frame"], "select_columns_by_name"):
+                args["data_frame"] = pd.api.interchange.from_dataframe(
+                    args["data_frame"]
+                )
+            else:
+                # Save precious resources by only interchanging columns that are
+                # actually going to be plotted.
+                columns = [
+                    i for i in args.values() if isinstance(i, str) and i in columns
+                ]
+                args["data_frame"] = pd.api.interchange.from_dataframe(
+                    args["data_frame"].select_columns_by_name(columns)
+                )
+        except (ImportError, NotImplementedError) as exc:
+            # temporary workaround; developers of third-party libraries themselves
+            # should try a different implementation, if available. For example:
+            # def __dataframe__(self, ...):
+            #   if not some_condition:
+            #     self.to_pandas(...)
+            if not hasattr(df_not_pandas, "to_pandas"):
+                raise exc
+            args["data_frame"] = df_not_pandas.to_pandas()
+
+    df_input = args["data_frame"]

     missing_bar_dim = None
     if (
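
To make the column selection in the hunk above concrete, here is the same comprehension pulled out into a standalone function and run on a few assumed argument dicts. The function name `columns_to_interchange` and the sample arguments are illustrative only.

```python
def columns_to_interchange(args, columns):
    """Same comprehension as in the diff: keep string-valued args that name a column."""
    return [i for i in args.values() if isinstance(i, str) and i in columns]


columns = ["petal_width", "sepal_length", "sepal_width", "species"]

simple = columns_to_interchange(
    dict(x="petal_width", y="sepal_length", color=None), columns
)
# A list-valued argument is not a `str`, so its entries are not picked up here.
with_list = columns_to_interchange(
    dict(x="petal_width", y="sepal_length", hover_data=["species"]), columns
)
# The same name used for two arguments appears twice in the result.
repeated = columns_to_interchange(dict(x="petal_width", color="petal_width"), columns)
```

Only plain string arguments contribute to the selection. Whether list- or dict-valued arguments such as `hover_data` need separate handling is not addressed by this diff, so that case may be worth a follow-up test.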
@@ -252,21 +252,58 @@ def test_build_df_with_index():
 def test_build_df_using_interchange_protocol_mock(
     add_interchange_module_for_old_pandas,
 ):
+    class InterchangeDataFrame:
+        def __init__(self, columns):
+            self._columns = columns
+
+        def column_names(self):
+            return self._columns
+
+    interchange_dataframe = InterchangeDataFrame(
+        ["petal_width", "sepal_length", "sepal_width"]
+    )
+    interchange_dataframe_reduced = InterchangeDataFrame(
+        ["petal_width", "sepal_length"]
+    )
+    interchange_dataframe.select_columns_by_name = mock.MagicMock(
+        return_value=interchange_dataframe_reduced
+    )
+    interchange_dataframe_reduced.select_columns_by_name = mock.MagicMock(
+        return_value=interchange_dataframe_reduced
+    )
+
     class CustomDataFrame:
         def __dataframe__(self):
-            pass
+            return interchange_dataframe
+
+    class CustomDataFrameReduced:
+        def __dataframe__(self):
+            return interchange_dataframe_reduced

     input_dataframe = CustomDataFrame()
-    args = dict(data_frame=input_dataframe, x="petal_width", y="sepal_length")
+    input_dataframe_reduced = CustomDataFrameReduced()

     iris_pandas = px.data.iris()

     with mock.patch("pandas.__version__", "2.0.2"):
+        args = dict(data_frame=input_dataframe, x="petal_width", y="sepal_length")
         with mock.patch(
             "pandas.api.interchange.from_dataframe", return_value=iris_pandas
         ) as mock_from_dataframe:
             build_dataframe(args, go.Scatter)
-            mock_from_dataframe.assert_called_once_with(input_dataframe)
+            mock_from_dataframe.assert_called_once_with(interchange_dataframe_reduced)
+            interchange_dataframe.select_columns_by_name.assert_called_with(
+                ["petal_width", "sepal_length"]
+            )
+
+        args = dict(data_frame=input_dataframe_reduced, color=None)
+        with mock.patch(
+            "pandas.api.interchange.from_dataframe",
+            return_value=iris_pandas[["petal_width", "sepal_length"]],
+        ) as mock_from_dataframe:
+            build_dataframe(args, go.Scatter)
+            mock_from_dataframe.assert_called_once_with(interchange_dataframe_reduced)
+            interchange_dataframe_reduced.select_columns_by_name.assert_not_called()


 @pytest.mark.skipif(
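
The mocking pattern used in the test above can be tried in isolation with only the standard library. This is a reduced sketch, not the project's test: `convert` is a hypothetical stand-in for the code under test, and `fake_from_dataframe` replaces the patched `pandas.api.interchange.from_dataframe`.

```python
from unittest import mock


class InterchangeDataFrame:
    """Tiny stand-in exposing only `column_names`, as in the test above."""

    def __init__(self, columns):
        self._columns = columns

    def column_names(self):
        return self._columns


full = InterchangeDataFrame(["petal_width", "sepal_length", "sepal_width"])
reduced = InterchangeDataFrame(["petal_width", "sepal_length"])
# Attach a mock method so the call and its arguments can be asserted afterwards.
full.select_columns_by_name = mock.MagicMock(return_value=reduced)


def convert(xdf, wanted, from_dataframe):
    """Hypothetical code under test: trim columns first, then convert."""
    return from_dataframe(xdf.select_columns_by_name(wanted))


fake_from_dataframe = mock.MagicMock(return_value="converted")
result = convert(full, ["petal_width", "sepal_length"], fake_from_dataframe)

# Both assertions raise AssertionError if the calls did not happen as expected.
full.select_columns_by_name.assert_called_with(["petal_width", "sepal_length"])
fake_from_dataframe.assert_called_once_with(reduced)
```

Asserting on the object passed to the converter, rather than on the converted result, is what lets the test confirm that only the reduced frame is interchanged.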