Skip to content

initial build-out of facet wrapping #1838

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 22, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions packages/python/plotly/plotly/express/_chart_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def scatter(
text=None,
facet_row=None,
facet_col=None,
facet_col_wrap=0,
error_x=None,
error_x_minus=None,
error_y=None,
Expand Down Expand Up @@ -65,6 +66,7 @@ def density_contour(
color=None,
facet_row=None,
facet_col=None,
facet_col_wrap=0,
hover_name=None,
hover_data=None,
animation_frame=None,
Expand Down Expand Up @@ -120,6 +122,7 @@ def density_heatmap(
z=None,
facet_row=None,
facet_col=None,
facet_col_wrap=0,
hover_name=None,
hover_data=None,
animation_frame=None,
Expand Down Expand Up @@ -180,6 +183,7 @@ def line(
text=None,
facet_row=None,
facet_col=None,
facet_col_wrap=0,
error_x=None,
error_x_minus=None,
error_y=None,
Expand Down Expand Up @@ -225,6 +229,7 @@ def area(
text=None,
facet_row=None,
facet_col=None,
facet_col_wrap=0,
animation_frame=None,
animation_group=None,
category_orders={},
Expand Down Expand Up @@ -267,6 +272,7 @@ def bar(
color=None,
facet_row=None,
facet_col=None,
facet_col_wrap=0,
hover_name=None,
hover_data=None,
custom_data=None,
Expand Down Expand Up @@ -318,6 +324,7 @@ def histogram(
color=None,
facet_row=None,
facet_col=None,
facet_col_wrap=0,
hover_name=None,
hover_data=None,
animation_frame=None,
Expand Down Expand Up @@ -376,6 +383,7 @@ def violin(
color=None,
facet_row=None,
facet_col=None,
facet_col_wrap=0,
hover_name=None,
hover_data=None,
custom_data=None,
Expand Down Expand Up @@ -427,6 +435,7 @@ def box(
color=None,
facet_row=None,
facet_col=None,
facet_col_wrap=0,
hover_name=None,
hover_data=None,
custom_data=None,
Expand Down Expand Up @@ -473,6 +482,7 @@ def strip(
color=None,
facet_row=None,
facet_col=None,
facet_col_wrap=0,
hover_name=None,
hover_data=None,
custom_data=None,
Expand Down
80 changes: 47 additions & 33 deletions packages/python/plotly/plotly/express/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref):
result["y"] = trendline[:, 1]
hover_header = "<b>LOWESS trendline</b><br><br>"
elif v == "ols":
fit_results = sm.OLS(y, sm.add_constant(x)).fit()
fit_results = sm.OLS(y.values, sm.add_constant(x.values)).fit()
result["y"] = fit_results.predict()
hover_header = "<b>OLS trendline</b><br>"
hover_header += "%s = %f * %s + %f<br>" % (
Expand Down Expand Up @@ -747,10 +747,10 @@ def apply_default_cascade(args):
]

# If both marginals and faceting are specified, faceting wins
if args.get("facet_col", None) and args.get("marginal_y", None):
if args.get("facet_col", None) is not None and args.get("marginal_y", None):
args["marginal_y"] = None

if args.get("facet_row", None) and args.get("marginal_x", None):
if args.get("facet_row", None) is not None and args.get("marginal_x", None):
args["marginal_x"] = None


Expand Down Expand Up @@ -874,7 +874,7 @@ def build_dataframe(args, attrables, array_attrables):
"pandas MultiIndex is not supported by plotly express "
"at the moment." % field
)
## ----------------- argument is a col name ----------------------
# ----------------- argument is a col name ----------------------
if isinstance(argument, str) or isinstance(
argument, int
): # just a column name given as str or int
Expand Down Expand Up @@ -1042,6 +1042,13 @@ def infer_config(args, constructor, trace_patch):
args[position] = args["marginal"]
args[other_position] = None

if (
args.get("marginal_x", None) is not None
or args.get("marginal_y", None) is not None
or args.get("facet_row", None) is not None
):
args["facet_col_wrap"] = 0

# Compute applicable grouping attributes
for k in group_attrables:
if k in args:
Expand Down Expand Up @@ -1098,15 +1105,14 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):

orders, sorted_group_names = get_orderings(args, grouper, grouped)

has_marginal_x = bool(args.get("marginal_x", False))
has_marginal_y = bool(args.get("marginal_y", False))

subplot_type = _subplot_type_for_trace_type(constructor().type)

trace_names_by_frame = {}
frames = OrderedDict()
trendline_rows = []
nrows = ncols = 1
col_labels = []
row_labels = []
for group_name in sorted_group_names:
group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0])
mapping_labels = OrderedDict()
Expand Down Expand Up @@ -1188,27 +1194,36 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
# Find row for trace, handling facet_row and marginal_x
if m.facet == "row":
row = m.val_map[val]
trace._subplot_row_val = val
if args["facet_row"] and len(row_labels) < row:
row_labels.append(args["facet_row"] + "=" + str(val))
else:
if has_marginal_x and trace_spec.marginal != "x":
if (
bool(args.get("marginal_x", False))
and trace_spec.marginal != "x"
):
row = 2
else:
row = 1

nrows = max(nrows, row)
if row > 1:
trace._subplot_row = row

facet_col_wrap = args.get("facet_col_wrap", 0)
# Find col for trace, handling facet_col and marginal_y
if m.facet == "col":
col = m.val_map[val]
trace._subplot_col_val = val
if args["facet_col"] and len(col_labels) < col:
col_labels.append(args["facet_col"] + "=" + str(val))
if facet_col_wrap: # assumes no facet_row, no marginals
row = 1 + ((col - 1) // facet_col_wrap)
col = 1 + ((col - 1) % facet_col_wrap)
else:
if trace_spec.marginal == "y":
col = 2
else:
col = 1

nrows = max(nrows, row)
if row > 1:
trace._subplot_row = row

ncols = max(ncols, col)
if col > 1:
trace._subplot_col = col
Expand Down Expand Up @@ -1238,7 +1253,6 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
if show_colorbar:
colorvar = "z" if constructor == go.Histogram2d else "color"
range_color = args["range_color"] or [None, None]
d = len(args["color_continuous_scale"]) - 1

colorscale_validator = ColorscaleValidator("colorscale", "make_figure")
layout_patch["coloraxis1"] = dict(
Expand All @@ -1260,7 +1274,7 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
layout_patch["legend"]["itemsizing"] = "constant"

fig = init_figure(
args, subplot_type, frame_list, ncols, nrows, has_marginal_x, has_marginal_y
args, subplot_type, frame_list, nrows, ncols, col_labels, row_labels
)

# Position traces in subplots
Expand Down Expand Up @@ -1290,49 +1304,39 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
return fig


def init_figure(
args, subplot_type, frame_list, ncols, nrows, has_marginal_x, has_marginal_y
):
def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_labels):
# Build subplot specs
specs = [[{}] * ncols for _ in range(nrows)]
column_titles = [None] * ncols
row_titles = [None] * nrows
for frame in frame_list:
for trace in frame["data"]:
row0 = trace._subplot_row - 1
col0 = trace._subplot_col - 1

if isinstance(trace, go.Splom):
# Splom not compatible with make_subplots, treat as domain
specs[row0][col0] = {"type": "domain"}
else:
specs[row0][col0] = {"type": trace.type}
if args.get("facet_row", None) and hasattr(trace, "_subplot_row_val"):
row_titles[row0] = args["facet_row"] + "=" + str(trace._subplot_row_val)

if args.get("facet_col", None) and hasattr(trace, "_subplot_col_val"):
column_titles[col0] = (
args["facet_col"] + "=" + str(trace._subplot_col_val)
)

# Default row/column widths uniform
column_widths = [1.0] * ncols
row_heights = [1.0] * nrows

# Build column_widths/row_heights
if subplot_type == "xy":
if has_marginal_x:
if bool(args.get("marginal_x", False)):
if args["marginal_x"] == "histogram" or ("color" in args and args["color"]):
main_size = 0.74
else:
main_size = 0.84

row_heights = [main_size] * (nrows - 1) + [1 - main_size]
vertical_spacing = 0.01
elif args.get("facet_col_wrap", 0):
vertical_spacing = 0.07
else:
vertical_spacing = 0.03

if has_marginal_y:
if bool(args.get("marginal_y", False)):
if args["marginal_y"] == "histogram" or ("color" in args and args["color"]):
main_size = 0.74
else:
Expand All @@ -1351,15 +1355,25 @@ def init_figure(
vertical_spacing = 0.1
horizontal_spacing = 0.1

facet_col_wrap = args.get("facet_col_wrap", 0)
if facet_col_wrap:
subplot_labels = [None] * nrows * ncols
while len(col_labels) < nrows * ncols:
col_labels.append(None)
for i in range(nrows):
for j in range(ncols):
subplot_labels[i * ncols + j] = col_labels[(nrows - 1 - i) * ncols + j]

# Create figure with subplots
fig = make_subplots(
rows=nrows,
cols=ncols,
specs=specs,
shared_xaxes="all",
shared_yaxes="all",
row_titles=list(reversed(row_titles)),
column_titles=column_titles,
row_titles=[] if facet_col_wrap else list(reversed(row_labels)),
column_titles=[] if facet_col_wrap else col_labels,
subplot_titles=subplot_labels if facet_col_wrap else [],
horizontal_spacing=horizontal_spacing,
vertical_spacing=vertical_spacing,
row_heights=row_heights,
Expand Down
6 changes: 6 additions & 0 deletions packages/python/plotly/plotly/express/_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,12 @@
colref_desc,
"Values from this column or array_like are used to assign marks to facetted subplots in the horizontal direction.",
],
facet_col_wrap=[
"int",
"Maximum number of facet columns.",
"Wraps the column variable at this width, so that the column facets span multiple rows.",
"Ignored if 0, and forced to 0 if `facet_row` or a `marginal` is set.",
],
animation_frame=[
colref_type,
colref_desc,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ def test_pandas_series():
assert fig.data[0].hovertemplate == "day=%{x}<br>y=%{y}"
fig = px.bar(tips, x="day", y=before_tip, labels={"y": "bill"})
assert fig.data[0].hovertemplate == "day=%{x}<br>bill=%{y}"
# lock down that we can pass df.col to facet_*
fig = px.bar(tips, x="day", y="tip", facet_row=tips.day, facet_col=tips.day)
assert fig.data[0].hovertemplate == "day=%{x}<br>tip=%{y}"


def test_several_dataframes():
Expand Down
27 changes: 27 additions & 0 deletions test/percy/plotly-express.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,33 @@

import plotly.express as px

tips = px.data.tips()
fig = px.scatter(
tips,
x="day",
y="tip",
facet_col="day",
facet_col_wrap=2,
category_orders={"day": ["Thur", "Fri", "Sat", "Sun"]},
)
fig.write_html(os.path.join(dir_name, "facet_wrap_neat.html"))

import plotly.express as px

tips = px.data.tips()
fig = px.scatter(
tips,
x="day",
y="tip",
color="sex",
facet_col="day",
facet_col_wrap=3,
category_orders={"day": ["Thur", "Fri", "Sat", "Sun"]},
)
fig.write_html(os.path.join(dir_name, "facet_wrap_ragged.html"))

import plotly.express as px

gapminder = px.data.gapminder()
fig = px.area(gapminder, x="year", y="pop", color="continent", line_group="country")
fig.write_html(os.path.join(dir_name, "area.html"))
Expand Down