Skip to content

Commit 37c8c81

Browse files
Merge pull request #2105 from plotly/px_orthogonal_ordering
preload val_map from orders
2 parents 8e5bbad + 249440b commit 37c8c81

File tree

2 files changed

+79
-13
lines changed

2 files changed

+79
-13
lines changed

Diff for: packages/python/plotly/plotly/express/_core.py

+29-13
Original file line numberDiff line numberDiff line change
@@ -1136,11 +1136,19 @@ def infer_config(args, constructor, trace_patch):
11361136
def get_orderings(args, grouper, grouped):
11371137
"""
11381138
`orders` is the user-supplied ordering (with the remaining data-frame-supplied
1139-
ordering appended if the column is used for grouping)
1139+
ordering appended if the column is used for grouping). It includes anything the user
1140+
gave, for any variable, including values not present in the dataset. It is used
1141+
downstream to set e.g. `categoryarray` for cartesian axes
1142+
11401143
`group_names` is the set of groups, ordered by the order above
1144+
1145+
`group_values` is a subset of `orders` in both keys and values. It contains a key
1146+
for every grouped mapping and its values are the sorted *data* values for these
1147+
mappings.
11411148
"""
11421149
orders = {} if "category_orders" not in args else args["category_orders"].copy()
11431150
group_names = []
1151+
group_values = {}
11441152
for group_name in grouped.groups:
11451153
if len(grouper) == 1:
11461154
group_name = (group_name,)
@@ -1154,6 +1162,7 @@ def get_orderings(args, grouper, grouped):
11541162
for val in uniques:
11551163
if val not in orders[col]:
11561164
orders[col].append(val)
1165+
group_values[col] = sorted(uniques, key=orders[col].index)
11571166

11581167
for i, col in reversed(list(enumerate(grouper))):
11591168
if col != one_group:
@@ -1162,7 +1171,7 @@ def get_orderings(args, grouper, grouped):
11621171
key=lambda g: orders[col].index(g[i]) if g[i] in orders[col] else -1,
11631172
)
11641173

1165-
return orders, group_names
1174+
return orders, group_names, group_values
11661175

11671176

11681177
def make_figure(args, constructor, trace_patch={}, layout_patch={}):
@@ -1174,16 +1183,31 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
11741183
grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group]
11751184
grouped = args["data_frame"].groupby(grouper, sort=False)
11761185

1177-
orders, sorted_group_names = get_orderings(args, grouper, grouped)
1186+
orders, sorted_group_names, sorted_group_values = get_orderings(
1187+
args, grouper, grouped
1188+
)
1189+
1190+
col_labels = []
1191+
row_labels = []
1192+
1193+
for m in grouped_mappings:
1194+
if m.grouper:
1195+
if m.facet == "col":
1196+
prefix = get_label(args, args["facet_col"]) + "="
1197+
col_labels = [prefix + str(s) for s in sorted_group_values[m.grouper]]
1198+
if m.facet == "row":
1199+
prefix = get_label(args, args["facet_row"]) + "="
1200+
row_labels = [prefix + str(s) for s in sorted_group_values[m.grouper]]
1201+
for val in sorted_group_values[m.grouper]:
1202+
if val not in m.val_map:
1203+
m.val_map[val] = m.sequence[len(m.val_map) % len(m.sequence)]
11781204

11791205
subplot_type = _subplot_type_for_trace_type(constructor().type)
11801206

11811207
trace_names_by_frame = {}
11821208
frames = OrderedDict()
11831209
trendline_rows = []
11841210
nrows = ncols = 1
1185-
col_labels = []
1186-
row_labels = []
11871211
trace_name_labels = None
11881212
for group_name in sorted_group_names:
11891213
group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0])
@@ -1281,10 +1305,6 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
12811305
# Find row for trace, handling facet_row and marginal_x
12821306
if m.facet == "row":
12831307
row = m.val_map[val]
1284-
if args["facet_row"] and len(row_labels) < row:
1285-
row_labels.append(
1286-
get_label(args, args["facet_row"]) + "=" + str(val)
1287-
)
12881308
else:
12891309
if (
12901310
bool(args.get("marginal_x", False))
@@ -1298,10 +1318,6 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
12981318
# Find col for trace, handling facet_col and marginal_y
12991319
if m.facet == "col":
13001320
col = m.val_map[val]
1301-
if args["facet_col"] and len(col_labels) < col:
1302-
col_labels.append(
1303-
get_label(args, args["facet_col"]) + "=" + str(val)
1304-
)
13051321
if facet_col_wrap: # assumes no facet_row, no marginals
13061322
row = 1 + ((col - 1) // facet_col_wrap)
13071323
col = 1 + ((col - 1) % facet_col_wrap)

Diff for: packages/python/plotly/plotly/tests/test_core/test_px/test_px.py

+50
Original file line numberDiff line numberDiff line change
@@ -182,3 +182,53 @@ def test_px_templates():
182182
assert fig.layout.xaxis3.showgrid is None
183183
assert fig.layout.yaxis2.showgrid
184184
assert fig.layout.yaxis3.showgrid
185+
186+
187+
def test_orthogonal_orderings():
188+
from itertools import permutations
189+
190+
df = px.data.tips()
191+
192+
symbol_sequence = ["circle", "diamond", "square", "cross"]
193+
color_sequence = ["red", "blue"]
194+
195+
def assert_orderings(days_order, days_check, times_order, times_check):
196+
fig = px.scatter(
197+
df,
198+
x="total_bill",
199+
y="tip",
200+
facet_row="time",
201+
facet_col="day",
202+
color="time",
203+
symbol="day",
204+
symbol_sequence=symbol_sequence,
205+
color_discrete_sequence=color_sequence,
206+
category_orders=dict(day=days_order, time=times_order),
207+
)
208+
209+
for col in range(len(days_check)):
210+
for trace in fig.select_traces(col=col + 1):
211+
assert days_check[col] in trace.hovertemplate
212+
213+
for row in range(len(times_check)):
214+
for trace in fig.select_traces(row=2 - row):
215+
assert times_check[row] in trace.hovertemplate
216+
217+
for trace in fig.data:
218+
for i, day in enumerate(days_check):
219+
if day in trace.name:
220+
assert trace.marker.symbol == symbol_sequence[i]
221+
for i, time in enumerate(times_check):
222+
if time in trace.name:
223+
assert trace.marker.color == color_sequence[i]
224+
225+
assert_orderings(
226+
["x", "Sun", "Sat", "y", "Fri", "z"], # add extra noise, missing Thur
227+
["Sun", "Sat", "Fri", "Thur"], # Thur is at the back
228+
["a", "Lunch", "b"], # add extra noise, missing Dinner
229+
["Lunch", "Dinner"], # Dinner is at the back
230+
)
231+
232+
for days in permutations(df["day"].unique()):
233+
for times in permutations(df["time"].unique()):
234+
assert_orderings(days, days, times, times)

0 commit comments

Comments
 (0)