diff --git a/packages/python/plotly/plotly/express/_chart_types.py b/packages/python/plotly/plotly/express/_chart_types.py
index 2ccd496a6c3..24264bf3b41 100644
--- a/packages/python/plotly/plotly/express/_chart_types.py
+++ b/packages/python/plotly/plotly/express/_chart_types.py
@@ -16,6 +16,7 @@ def scatter(
text=None,
facet_row=None,
facet_col=None,
+ facet_col_wrap=0,
error_x=None,
error_x_minus=None,
error_y=None,
@@ -65,6 +66,7 @@ def density_contour(
color=None,
facet_row=None,
facet_col=None,
+ facet_col_wrap=0,
hover_name=None,
hover_data=None,
animation_frame=None,
@@ -120,6 +122,7 @@ def density_heatmap(
z=None,
facet_row=None,
facet_col=None,
+ facet_col_wrap=0,
hover_name=None,
hover_data=None,
animation_frame=None,
@@ -180,6 +183,7 @@ def line(
text=None,
facet_row=None,
facet_col=None,
+ facet_col_wrap=0,
error_x=None,
error_x_minus=None,
error_y=None,
@@ -225,6 +229,7 @@ def area(
text=None,
facet_row=None,
facet_col=None,
+ facet_col_wrap=0,
animation_frame=None,
animation_group=None,
category_orders={},
@@ -267,6 +272,7 @@ def bar(
color=None,
facet_row=None,
facet_col=None,
+ facet_col_wrap=0,
hover_name=None,
hover_data=None,
custom_data=None,
@@ -318,6 +324,7 @@ def histogram(
color=None,
facet_row=None,
facet_col=None,
+ facet_col_wrap=0,
hover_name=None,
hover_data=None,
animation_frame=None,
@@ -376,6 +383,7 @@ def violin(
color=None,
facet_row=None,
facet_col=None,
+ facet_col_wrap=0,
hover_name=None,
hover_data=None,
custom_data=None,
@@ -427,6 +435,7 @@ def box(
color=None,
facet_row=None,
facet_col=None,
+ facet_col_wrap=0,
hover_name=None,
hover_data=None,
custom_data=None,
@@ -473,6 +482,7 @@ def strip(
color=None,
facet_row=None,
facet_col=None,
+ facet_col_wrap=0,
hover_name=None,
hover_data=None,
custom_data=None,
diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py
index e5846be0668..3e01f445c26 100644
--- a/packages/python/plotly/plotly/express/_core.py
+++ b/packages/python/plotly/plotly/express/_core.py
@@ -233,7 +233,7 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref):
result["y"] = trendline[:, 1]
hover_header = "LOWESS trendline
"
elif v == "ols":
- fit_results = sm.OLS(y, sm.add_constant(x)).fit()
+ fit_results = sm.OLS(y.values, sm.add_constant(x.values)).fit()
result["y"] = fit_results.predict()
hover_header = "OLS trendline
"
hover_header += "%s = %f * %s + %f
" % (
@@ -747,10 +747,10 @@ def apply_default_cascade(args):
]
# If both marginals and faceting are specified, faceting wins
- if args.get("facet_col", None) and args.get("marginal_y", None):
+ if args.get("facet_col", None) is not None and args.get("marginal_y", None):
args["marginal_y"] = None
- if args.get("facet_row", None) and args.get("marginal_x", None):
+ if args.get("facet_row", None) is not None and args.get("marginal_x", None):
args["marginal_x"] = None
@@ -874,7 +874,7 @@ def build_dataframe(args, attrables, array_attrables):
"pandas MultiIndex is not supported by plotly express "
"at the moment." % field
)
- ## ----------------- argument is a col name ----------------------
+ # ----------------- argument is a col name ----------------------
if isinstance(argument, str) or isinstance(
argument, int
): # just a column name given as str or int
@@ -1042,6 +1042,13 @@ def infer_config(args, constructor, trace_patch):
args[position] = args["marginal"]
args[other_position] = None
+ if (
+ args.get("marginal_x", None) is not None
+ or args.get("marginal_y", None) is not None
+ or args.get("facet_row", None) is not None
+ ):
+ args["facet_col_wrap"] = 0
+
# Compute applicable grouping attributes
for k in group_attrables:
if k in args:
@@ -1098,15 +1105,14 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
orders, sorted_group_names = get_orderings(args, grouper, grouped)
- has_marginal_x = bool(args.get("marginal_x", False))
- has_marginal_y = bool(args.get("marginal_y", False))
-
subplot_type = _subplot_type_for_trace_type(constructor().type)
trace_names_by_frame = {}
frames = OrderedDict()
trendline_rows = []
nrows = ncols = 1
+ col_labels = []
+ row_labels = []
for group_name in sorted_group_names:
group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0])
mapping_labels = OrderedDict()
@@ -1188,27 +1194,36 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
# Find row for trace, handling facet_row and marginal_x
if m.facet == "row":
row = m.val_map[val]
- trace._subplot_row_val = val
+ if args["facet_row"] and len(row_labels) < row:
+ row_labels.append(args["facet_row"] + "=" + str(val))
else:
- if has_marginal_x and trace_spec.marginal != "x":
+ if (
+ bool(args.get("marginal_x", False))
+ and trace_spec.marginal != "x"
+ ):
row = 2
else:
row = 1
- nrows = max(nrows, row)
- if row > 1:
- trace._subplot_row = row
-
+ facet_col_wrap = args.get("facet_col_wrap", 0)
# Find col for trace, handling facet_col and marginal_y
if m.facet == "col":
col = m.val_map[val]
- trace._subplot_col_val = val
+ if args["facet_col"] and len(col_labels) < col:
+ col_labels.append(args["facet_col"] + "=" + str(val))
+ if facet_col_wrap: # assumes no facet_row, no marginals
+ row = 1 + ((col - 1) // facet_col_wrap)
+ col = 1 + ((col - 1) % facet_col_wrap)
else:
if trace_spec.marginal == "y":
col = 2
else:
col = 1
+ nrows = max(nrows, row)
+ if row > 1:
+ trace._subplot_row = row
+
ncols = max(ncols, col)
if col > 1:
trace._subplot_col = col
@@ -1238,7 +1253,6 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
if show_colorbar:
colorvar = "z" if constructor == go.Histogram2d else "color"
range_color = args["range_color"] or [None, None]
- d = len(args["color_continuous_scale"]) - 1
colorscale_validator = ColorscaleValidator("colorscale", "make_figure")
layout_patch["coloraxis1"] = dict(
@@ -1260,7 +1274,7 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
layout_patch["legend"]["itemsizing"] = "constant"
fig = init_figure(
- args, subplot_type, frame_list, ncols, nrows, has_marginal_x, has_marginal_y
+ args, subplot_type, frame_list, nrows, ncols, col_labels, row_labels
)
# Position traces in subplots
@@ -1290,30 +1304,18 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
return fig
-def init_figure(
- args, subplot_type, frame_list, ncols, nrows, has_marginal_x, has_marginal_y
-):
+def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_labels):
# Build subplot specs
specs = [[{}] * ncols for _ in range(nrows)]
- column_titles = [None] * ncols
- row_titles = [None] * nrows
for frame in frame_list:
for trace in frame["data"]:
row0 = trace._subplot_row - 1
col0 = trace._subplot_col - 1
-
if isinstance(trace, go.Splom):
# Splom not compatible with make_subplots, treat as domain
specs[row0][col0] = {"type": "domain"}
else:
specs[row0][col0] = {"type": trace.type}
- if args.get("facet_row", None) and hasattr(trace, "_subplot_row_val"):
- row_titles[row0] = args["facet_row"] + "=" + str(trace._subplot_row_val)
-
- if args.get("facet_col", None) and hasattr(trace, "_subplot_col_val"):
- column_titles[col0] = (
- args["facet_col"] + "=" + str(trace._subplot_col_val)
- )
# Default row/column widths uniform
column_widths = [1.0] * ncols
@@ -1321,7 +1323,7 @@ def init_figure(
# Build column_widths/row_heights
if subplot_type == "xy":
- if has_marginal_x:
+ if bool(args.get("marginal_x", False)):
if args["marginal_x"] == "histogram" or ("color" in args and args["color"]):
main_size = 0.74
else:
@@ -1329,10 +1331,12 @@ def init_figure(
row_heights = [main_size] * (nrows - 1) + [1 - main_size]
vertical_spacing = 0.01
+ elif args.get("facet_col_wrap", 0):
+ vertical_spacing = 0.07
else:
vertical_spacing = 0.03
- if has_marginal_y:
+ if bool(args.get("marginal_y", False)):
if args["marginal_y"] == "histogram" or ("color" in args and args["color"]):
main_size = 0.74
else:
@@ -1351,6 +1355,15 @@ def init_figure(
vertical_spacing = 0.1
horizontal_spacing = 0.1
+ facet_col_wrap = args.get("facet_col_wrap", 0)
+ if facet_col_wrap:
+ subplot_labels = [None] * nrows * ncols
+ while len(col_labels) < nrows * ncols:
+ col_labels.append(None)
+ for i in range(nrows):
+ for j in range(ncols):
+ subplot_labels[i * ncols + j] = col_labels[(nrows - 1 - i) * ncols + j]
+
# Create figure with subplots
fig = make_subplots(
rows=nrows,
@@ -1358,8 +1371,9 @@ def init_figure(
specs=specs,
shared_xaxes="all",
shared_yaxes="all",
- row_titles=list(reversed(row_titles)),
- column_titles=column_titles,
+ row_titles=[] if facet_col_wrap else list(reversed(row_labels)),
+ column_titles=[] if facet_col_wrap else col_labels,
+ subplot_titles=subplot_labels if facet_col_wrap else [],
horizontal_spacing=horizontal_spacing,
vertical_spacing=vertical_spacing,
row_heights=row_heights,
diff --git a/packages/python/plotly/plotly/express/_doc.py b/packages/python/plotly/plotly/express/_doc.py
index 1bf75f17e49..653c71b886b 100644
--- a/packages/python/plotly/plotly/express/_doc.py
+++ b/packages/python/plotly/plotly/express/_doc.py
@@ -183,6 +183,12 @@
colref_desc,
"Values from this column or array_like are used to assign marks to facetted subplots in the horizontal direction.",
],
+ facet_col_wrap=[
+ "int",
+ "Maximum number of facet columns.",
+ "Wraps the column variable at this width, so that the column facets span multiple rows.",
+ "Ignored if 0, and forced to 0 if `facet_row` or a `marginal` is set.",
+ ],
animation_frame=[
colref_type,
colref_desc,
diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py
index 08bb1a9cc95..89e3a921c9e 100644
--- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py
+++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py
@@ -61,6 +61,9 @@ def test_pandas_series():
assert fig.data[0].hovertemplate == "day=%{x}
y=%{y}"
fig = px.bar(tips, x="day", y=before_tip, labels={"y": "bill"})
assert fig.data[0].hovertemplate == "day=%{x}
bill=%{y}"
+ # lock down that we can pass df.col to facet_*
+ fig = px.bar(tips, x="day", y="tip", facet_row=tips.day, facet_col=tips.day)
+ assert fig.data[0].hovertemplate == "day=%{x}
tip=%{y}"
def test_several_dataframes():
diff --git a/test/percy/plotly-express.py b/test/percy/plotly-express.py
index 7ca527efbb8..01b9bb63386 100644
--- a/test/percy/plotly-express.py
+++ b/test/percy/plotly-express.py
@@ -184,6 +184,33 @@
import plotly.express as px
+tips = px.data.tips()
+fig = px.scatter(
+ tips,
+ x="day",
+ y="tip",
+ facet_col="day",
+ facet_col_wrap=2,
+ category_orders={"day": ["Thur", "Fri", "Sat", "Sun"]},
+)
+fig.write_html(os.path.join(dir_name, "facet_wrap_neat.html"))
+
+import plotly.express as px
+
+tips = px.data.tips()
+fig = px.scatter(
+ tips,
+ x="day",
+ y="tip",
+ color="sex",
+ facet_col="day",
+ facet_col_wrap=3,
+ category_orders={"day": ["Thur", "Fri", "Sat", "Sun"]},
+)
+fig.write_html(os.path.join(dir_name, "facet_wrap_ragged.html"))
+
+import plotly.express as px
+
gapminder = px.data.gapminder()
fig = px.area(gapminder, x="year", y="pop", color="continent", line_group="country")
fig.write_html(os.path.join(dir_name, "area.html"))