Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/pandas Performance Warning Issue due to multiple frame.insert #4246

Merged
merged 8 commits into from
Jul 25, 2023
2 changes: 2 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[flake8]
max-line-length = 88
50 changes: 30 additions & 20 deletions packages/python/plotly/plotly/express/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,6 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
and args["y"]
and len(trace_data[[args["x"], args["y"]]].dropna()) > 1
):

# sorting is bad but trace_specs with "trendline" have no other attrs
sorted_trace_data = trace_data.sort_values(by=args["x"])
y = sorted_trace_data[args["y"]].values
Expand Down Expand Up @@ -562,7 +561,6 @@ def set_cartesian_axis_opts(args, axis, letter, orders):


def configure_cartesian_marginal_axes(args, fig, orders):

if "histogram" in [args["marginal_x"], args["marginal_y"]]:
fig.layout["barmode"] = "overlay"

Expand Down Expand Up @@ -1064,14 +1062,14 @@ def _escape_col_name(df_input, col_name, extra):
return col_name


def to_unindexed_series(x):
def to_unindexed_series(x, name=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm curious why this change is required?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Originally, it was creating series without a name. I set it as None in case this function was used externally to avoid breaking compatibility.
When we create a dataframe from a dict, it's safer to have named series. Also, it would be easier to debug in line to see which series was created in order to know the column that caused the issue.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The change might not be necessary, but i like the explicitness. Can be useful during debugging.

"""
assuming x is list-like or even an existing pd.Series, return a new pd.Series with
no index, without extracting the data from an existing Series via numpy, which
seems to mangle datetime columns. Stripping the index from existing pd.Series is
required to get things to match up right in the new DataFrame we're building
"""
return pd.Series(x).reset_index(drop=True)
return pd.Series(x, name=name).reset_index(drop=True)


def process_args_into_dataframe(args, wide_mode, var_name, value_name):
Expand All @@ -1086,9 +1084,12 @@ def process_args_into_dataframe(args, wide_mode, var_name, value_name):
df_input = args["data_frame"]
df_provided = df_input is not None

df_output = pd.DataFrame()
constants = dict()
ranges = list()
# we use a dict instead of a dataframe directly so that it doesn't cause
# PerformanceWarning by pandas by repeatedly setting the columns.
# a dict is used instead of a list as the columns needs to be overwritten.
df_output = {}
constants = {}
ranges = []
wide_id_vars = set()
reserved_names = _get_reserved_col_names(args) if df_provided else set()

Expand All @@ -1099,7 +1100,7 @@ def process_args_into_dataframe(args, wide_mode, var_name, value_name):
"No data were provided. Please provide data either with the `data_frame` or with the `dimensions` argument."
)
else:
df_output[df_input.columns] = df_input[df_input.columns]
df_output = {col: series for col, series in df_input.items()}

# hover_data is a dict
hover_data_is_dict = (
Expand Down Expand Up @@ -1140,7 +1141,7 @@ def process_args_into_dataframe(args, wide_mode, var_name, value_name):
# argument_list and field_list ready, iterate over them
# Core of the loop starts here
for i, (argument, field) in enumerate(zip(argument_list, field_list)):
length = len(df_output)
length = len(df_output[next(iter(df_output))]) if len(df_output) else 0
if argument is None:
continue
col_name = None
Expand Down Expand Up @@ -1181,11 +1182,11 @@ def process_args_into_dataframe(args, wide_mode, var_name, value_name):
% (
argument,
len(real_argument),
str(list(df_output.columns)),
str(list(df_output.keys())),
length,
)
)
df_output[col_name] = to_unindexed_series(real_argument)
df_output[col_name] = to_unindexed_series(real_argument, col_name)
elif not df_provided:
raise ValueError(
"String or int arguments are only possible when a "
Expand Down Expand Up @@ -1214,13 +1215,15 @@ def process_args_into_dataframe(args, wide_mode, var_name, value_name):
% (
field,
len(df_input[argument]),
str(list(df_output.columns)),
str(list(df_output.keys())),
length,
)
)
else:
col_name = str(argument)
df_output[col_name] = to_unindexed_series(df_input[argument])
df_output[col_name] = to_unindexed_series(
df_input[argument], col_name
)
# ----------------- argument is likely a column / array / list.... -------
else:
if df_provided and hasattr(argument, "name"):
Expand All @@ -1247,9 +1250,9 @@ def process_args_into_dataframe(args, wide_mode, var_name, value_name):
"All arguments should have the same length. "
"The length of argument `%s` is %d, whereas the "
"length of previously-processed arguments %s is %d"
% (field, len(argument), str(list(df_output.columns)), length)
% (field, len(argument), str(list(df_output.keys())), length)
)
df_output[str(col_name)] = to_unindexed_series(argument)
df_output[str(col_name)] = to_unindexed_series(argument, str(col_name))

# Finally, update argument with column name now that column exists
assert col_name is not None, (
Expand All @@ -1267,12 +1270,19 @@ def process_args_into_dataframe(args, wide_mode, var_name, value_name):
if field_name != "wide_variable":
wide_id_vars.add(str(col_name))

for col_name in ranges:
df_output[col_name] = range(len(df_output))

for col_name in constants:
df_output[col_name] = constants[col_name]
length = len(df_output[next(iter(df_output))]) if len(df_output) else 0
df_output.update(
{col_name: to_unindexed_series(range(length), col_name) for col_name in ranges}
)
df_output.update(
{
# constant is single value. repeat by len to avoid creating NaN on concating
col_name: to_unindexed_series([constants[col_name]] * length, col_name)
for col_name in constants
}
)

df_output = pd.DataFrame(df_output)
return df_output, wide_id_vars


Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import numpy as np
from plotly.express._core import build_dataframe, _is_col_list
from pandas.testing import assert_frame_equal
import pytest
import warnings


def test_is_col_list():
Expand Down Expand Up @@ -847,3 +849,29 @@ def test_line_group():
assert len(fig.data) == 4
fig = px.scatter(df, x="x", y=["miss", "score"], color="who")
assert len(fig.data) == 2


def test_no_pd_perf_warning():
n_cols = 1000
n_rows = 1000

columns = list(f"col_{c}" for c in range(n_cols))
index = list(f"i_{r}" for r in range(n_rows))

df = pd.DataFrame(
np.random.uniform(size=(n_rows, n_cols)), index=index, columns=columns
)

with warnings.catch_warnings(record=True) as warn_list:
_ = px.bar(
df,
x=df.index,
y=df.columns[:-2],
labels=df.columns[:-2],
)
performance_warnings = [
warn
for warn in warn_list
if issubclass(warn.category, pd.errors.PerformanceWarning)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does that mean there are other warnings emitted during this px.bar call?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there might be. but this test is for checking this warning only. we can change it look out for any pandas warning

]
assert len(performance_warnings) == 0, "PerformanceWarning(s) raised!"