diff --git a/packages/python/plotly/plotly/express/_chart_types.py b/packages/python/plotly/plotly/express/_chart_types.py index 2ccd496a6c3..24264bf3b41 100644 --- a/packages/python/plotly/plotly/express/_chart_types.py +++ b/packages/python/plotly/plotly/express/_chart_types.py @@ -16,6 +16,7 @@ def scatter( text=None, facet_row=None, facet_col=None, + facet_col_wrap=0, error_x=None, error_x_minus=None, error_y=None, @@ -65,6 +66,7 @@ def density_contour( color=None, facet_row=None, facet_col=None, + facet_col_wrap=0, hover_name=None, hover_data=None, animation_frame=None, @@ -120,6 +122,7 @@ def density_heatmap( z=None, facet_row=None, facet_col=None, + facet_col_wrap=0, hover_name=None, hover_data=None, animation_frame=None, @@ -180,6 +183,7 @@ def line( text=None, facet_row=None, facet_col=None, + facet_col_wrap=0, error_x=None, error_x_minus=None, error_y=None, @@ -225,6 +229,7 @@ def area( text=None, facet_row=None, facet_col=None, + facet_col_wrap=0, animation_frame=None, animation_group=None, category_orders={}, @@ -267,6 +272,7 @@ def bar( color=None, facet_row=None, facet_col=None, + facet_col_wrap=0, hover_name=None, hover_data=None, custom_data=None, @@ -318,6 +324,7 @@ def histogram( color=None, facet_row=None, facet_col=None, + facet_col_wrap=0, hover_name=None, hover_data=None, animation_frame=None, @@ -376,6 +383,7 @@ def violin( color=None, facet_row=None, facet_col=None, + facet_col_wrap=0, hover_name=None, hover_data=None, custom_data=None, @@ -427,6 +435,7 @@ def box( color=None, facet_row=None, facet_col=None, + facet_col_wrap=0, hover_name=None, hover_data=None, custom_data=None, @@ -473,6 +482,7 @@ def strip( color=None, facet_row=None, facet_col=None, + facet_col_wrap=0, hover_name=None, hover_data=None, custom_data=None, diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index e5846be0668..3e01f445c26 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -233,7 +233,7 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref): result["y"] = trendline[:, 1] hover_header = "LOWESS trendline

" elif v == "ols": - fit_results = sm.OLS(y, sm.add_constant(x)).fit() + fit_results = sm.OLS(y.values, sm.add_constant(x.values)).fit() result["y"] = fit_results.predict() hover_header = "OLS trendline
" hover_header += "%s = %f * %s + %f
" % ( @@ -747,10 +747,10 @@ def apply_default_cascade(args): ] # If both marginals and faceting are specified, faceting wins - if args.get("facet_col", None) and args.get("marginal_y", None): + if args.get("facet_col", None) is not None and args.get("marginal_y", None): args["marginal_y"] = None - if args.get("facet_row", None) and args.get("marginal_x", None): + if args.get("facet_row", None) is not None and args.get("marginal_x", None): args["marginal_x"] = None @@ -874,7 +874,7 @@ def build_dataframe(args, attrables, array_attrables): "pandas MultiIndex is not supported by plotly express " "at the moment." % field ) - ## ----------------- argument is a col name ---------------------- + # ----------------- argument is a col name ---------------------- if isinstance(argument, str) or isinstance( argument, int ): # just a column name given as str or int @@ -1042,6 +1042,13 @@ def infer_config(args, constructor, trace_patch): args[position] = args["marginal"] args[other_position] = None + if ( + args.get("marginal_x", None) is not None + or args.get("marginal_y", None) is not None + or args.get("facet_row", None) is not None + ): + args["facet_col_wrap"] = 0 + # Compute applicable grouping attributes for k in group_attrables: if k in args: @@ -1098,15 +1105,14 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): orders, sorted_group_names = get_orderings(args, grouper, grouped) - has_marginal_x = bool(args.get("marginal_x", False)) - has_marginal_y = bool(args.get("marginal_y", False)) - subplot_type = _subplot_type_for_trace_type(constructor().type) trace_names_by_frame = {} frames = OrderedDict() trendline_rows = [] nrows = ncols = 1 + col_labels = [] + row_labels = [] for group_name in sorted_group_names: group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0]) mapping_labels = OrderedDict() @@ -1188,27 +1194,36 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): # Find row for trace, handling facet_row and marginal_x if m.facet == "row": row = m.val_map[val] - trace._subplot_row_val = val + if args["facet_row"] and len(row_labels) < row: + row_labels.append(args["facet_row"] + "=" + str(val)) else: - if has_marginal_x and trace_spec.marginal != "x": + if ( + bool(args.get("marginal_x", False)) + and trace_spec.marginal != "x" + ): row = 2 else: row = 1 - nrows = max(nrows, row) - if row > 1: - trace._subplot_row = row - + facet_col_wrap = args.get("facet_col_wrap", 0) # Find col for trace, handling facet_col and marginal_y if m.facet == "col": col = m.val_map[val] - trace._subplot_col_val = val + if args["facet_col"] and len(col_labels) < col: + col_labels.append(args["facet_col"] + "=" + str(val)) + if facet_col_wrap: # assumes no facet_row, no marginals + row = 1 + ((col - 1) // facet_col_wrap) + col = 1 + ((col - 1) % facet_col_wrap) else: if trace_spec.marginal == "y": col = 2 else: col = 1 + nrows = max(nrows, row) + if row > 1: + trace._subplot_row = row + ncols = max(ncols, col) if col > 1: trace._subplot_col = col @@ -1238,7 +1253,6 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): if show_colorbar: colorvar = "z" if constructor == go.Histogram2d else "color" range_color = args["range_color"] or [None, None] - d = len(args["color_continuous_scale"]) - 1 colorscale_validator = ColorscaleValidator("colorscale", "make_figure") layout_patch["coloraxis1"] = dict( @@ -1260,7 +1274,7 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): layout_patch["legend"]["itemsizing"] = "constant" fig = init_figure( - args, subplot_type, frame_list, ncols, nrows, has_marginal_x, has_marginal_y + args, subplot_type, frame_list, nrows, ncols, col_labels, row_labels ) # Position traces in subplots @@ -1290,30 +1304,18 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): return fig -def init_figure( - args, subplot_type, frame_list, ncols, nrows, has_marginal_x, has_marginal_y -): +def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_labels): # Build subplot specs specs = [[{}] * ncols for _ in range(nrows)] - column_titles = [None] * ncols - row_titles = [None] * nrows for frame in frame_list: for trace in frame["data"]: row0 = trace._subplot_row - 1 col0 = trace._subplot_col - 1 - if isinstance(trace, go.Splom): # Splom not compatible with make_subplots, treat as domain specs[row0][col0] = {"type": "domain"} else: specs[row0][col0] = {"type": trace.type} - if args.get("facet_row", None) and hasattr(trace, "_subplot_row_val"): - row_titles[row0] = args["facet_row"] + "=" + str(trace._subplot_row_val) - - if args.get("facet_col", None) and hasattr(trace, "_subplot_col_val"): - column_titles[col0] = ( - args["facet_col"] + "=" + str(trace._subplot_col_val) - ) # Default row/column widths uniform column_widths = [1.0] * ncols @@ -1321,7 +1323,7 @@ def init_figure( # Build column_widths/row_heights if subplot_type == "xy": - if has_marginal_x: + if bool(args.get("marginal_x", False)): if args["marginal_x"] == "histogram" or ("color" in args and args["color"]): main_size = 0.74 else: @@ -1329,10 +1331,12 @@ def init_figure( row_heights = [main_size] * (nrows - 1) + [1 - main_size] vertical_spacing = 0.01 + elif args.get("facet_col_wrap", 0): + vertical_spacing = 0.07 else: vertical_spacing = 0.03 - if has_marginal_y: + if bool(args.get("marginal_y", False)): if args["marginal_y"] == "histogram" or ("color" in args and args["color"]): main_size = 0.74 else: @@ -1351,6 +1355,15 @@ def init_figure( vertical_spacing = 0.1 horizontal_spacing = 0.1 + facet_col_wrap = args.get("facet_col_wrap", 0) + if facet_col_wrap: + subplot_labels = [None] * nrows * ncols + while len(col_labels) < nrows * ncols: + col_labels.append(None) + for i in range(nrows): + for j in range(ncols): + subplot_labels[i * ncols + j] = col_labels[(nrows - 1 - i) * ncols + j] + # Create figure with subplots fig = make_subplots( rows=nrows, @@ -1358,8 +1371,9 @@ def init_figure( specs=specs, shared_xaxes="all", shared_yaxes="all", - row_titles=list(reversed(row_titles)), - column_titles=column_titles, + row_titles=[] if facet_col_wrap else list(reversed(row_labels)), + column_titles=[] if facet_col_wrap else col_labels, + subplot_titles=subplot_labels if facet_col_wrap else [], horizontal_spacing=horizontal_spacing, vertical_spacing=vertical_spacing, row_heights=row_heights, diff --git a/packages/python/plotly/plotly/express/_doc.py b/packages/python/plotly/plotly/express/_doc.py index 1bf75f17e49..653c71b886b 100644 --- a/packages/python/plotly/plotly/express/_doc.py +++ b/packages/python/plotly/plotly/express/_doc.py @@ -183,6 +183,12 @@ colref_desc, "Values from this column or array_like are used to assign marks to facetted subplots in the horizontal direction.", ], + facet_col_wrap=[ + "int", + "Maximum number of facet columns.", + "Wraps the column variable at this width, so that the column facets span multiple rows.", + "Ignored if 0, and forced to 0 if `facet_row` or a `marginal` is set.", + ], animation_frame=[ colref_type, colref_desc, diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index 08bb1a9cc95..89e3a921c9e 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -61,6 +61,9 @@ def test_pandas_series(): assert fig.data[0].hovertemplate == "day=%{x}
y=%{y}" fig = px.bar(tips, x="day", y=before_tip, labels={"y": "bill"}) assert fig.data[0].hovertemplate == "day=%{x}
bill=%{y}" + # lock down that we can pass df.col to facet_* + fig = px.bar(tips, x="day", y="tip", facet_row=tips.day, facet_col=tips.day) + assert fig.data[0].hovertemplate == "day=%{x}
tip=%{y}" def test_several_dataframes(): diff --git a/test/percy/plotly-express.py b/test/percy/plotly-express.py index 7ca527efbb8..01b9bb63386 100644 --- a/test/percy/plotly-express.py +++ b/test/percy/plotly-express.py @@ -184,6 +184,33 @@ import plotly.express as px +tips = px.data.tips() +fig = px.scatter( + tips, + x="day", + y="tip", + facet_col="day", + facet_col_wrap=2, + category_orders={"day": ["Thur", "Fri", "Sat", "Sun"]}, +) +fig.write_html(os.path.join(dir_name, "facet_wrap_neat.html")) + +import plotly.express as px + +tips = px.data.tips() +fig = px.scatter( + tips, + x="day", + y="tip", + color="sex", + facet_col="day", + facet_col_wrap=3, + category_orders={"day": ["Thur", "Fri", "Sat", "Sun"]}, +) +fig.write_html(os.path.join(dir_name, "facet_wrap_ragged.html")) + +import plotly.express as px + gapminder = px.data.gapminder() fig = px.area(gapminder, x="year", y="pop", color="continent", line_group="country") fig.write_html(os.path.join(dir_name, "area.html"))