Skip to content

Commit 68f1674

Browse files
Merge pull request #3247 from plotly/val_map5.0
PX val_map now respects category_orders
2 parents bca2925 + b076a38 commit 68f1674

File tree

2 files changed

+84
-79
lines changed

2 files changed

+84
-79
lines changed

Diff for: packages/python/plotly/plotly/express/_core.py

+77-69
Original file line numberDiff line numberDiff line change
@@ -400,10 +400,10 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
400400
if hover_is_dict and not attr_value[col]:
401401
continue
402402
if col in [
403-
args.get("x", None),
404-
args.get("y", None),
405-
args.get("z", None),
406-
args.get("base", None),
403+
args.get("x"),
404+
args.get("y"),
405+
args.get("z"),
406+
args.get("base"),
407407
]:
408408
continue
409409
try:
@@ -552,8 +552,10 @@ def set_cartesian_axis_opts(args, axis, letter, orders):
552552
axis["categoryarray"] = (
553553
orders[args[letter]]
554554
if isinstance(axis, go.layout.XAxis)
555-
else list(reversed(orders[args[letter]]))
555+
else list(reversed(orders[args[letter]])) # top down for Y axis
556556
)
557+
if "range" not in axis:
558+
axis["range"] = [-0.5, len(orders[args[letter]]) - 0.5]
557559

558560

559561
def configure_cartesian_marginal_axes(args, fig, orders):
@@ -1284,8 +1286,8 @@ def build_dataframe(args, constructor):
12841286

12851287
# now we handle special cases like wide-mode or x-xor-y specification
12861288
# by rearranging args to tee things up for process_args_into_dataframe to work
1287-
no_x = args.get("x", None) is None
1288-
no_y = args.get("y", None) is None
1289+
no_x = args.get("x") is None
1290+
no_y = args.get("y") is None
12891291
wide_x = False if no_x else _is_col_list(df_input, args["x"])
12901292
wide_y = False if no_y else _is_col_list(df_input, args["y"])
12911293

@@ -1312,9 +1314,9 @@ def build_dataframe(args, constructor):
13121314
if var_name in [None, "value", "index"] or var_name in df_input:
13131315
var_name = "variable"
13141316
if constructor == go.Funnel:
1315-
wide_orientation = args.get("orientation", None) or "h"
1317+
wide_orientation = args.get("orientation") or "h"
13161318
else:
1317-
wide_orientation = args.get("orientation", None) or "v"
1319+
wide_orientation = args.get("orientation") or "v"
13181320
args["orientation"] = wide_orientation
13191321
args["wide_cross"] = None
13201322
elif wide_x != wide_y:
@@ -1345,7 +1347,7 @@ def build_dataframe(args, constructor):
13451347
if constructor in [go.Scatter, go.Bar, go.Funnel] + hist2d_types:
13461348
if not wide_mode and (no_x != no_y):
13471349
for ax in ["x", "y"]:
1348-
if args.get(ax, None) is None:
1350+
if args.get(ax) is None:
13491351
args[ax] = df_input.index if df_provided else Range()
13501352
if constructor == go.Bar:
13511353
missing_bar_dim = ax
@@ -1369,7 +1371,7 @@ def build_dataframe(args, constructor):
13691371
)
13701372

13711373
no_color = False
1372-
if type(args.get("color", None)) == str and args["color"] == NO_COLOR:
1374+
if type(args.get("color")) == str and args["color"] == NO_COLOR:
13731375
no_color = True
13741376
args["color"] = None
13751377
# now that things have been prepped, we do the systematic rewriting of `args`
@@ -1777,25 +1779,25 @@ def infer_config(args, constructor, trace_patch, layout_patch):
17771779
else args["geojson"].__geo_interface__
17781780
)
17791781

1780-
# Compute marginal attribute
1782+
# Compute marginal attribute: copy to appropriate marginal_*
17811783
if "marginal" in args:
17821784
position = "marginal_x" if args["orientation"] == "v" else "marginal_y"
17831785
other_position = "marginal_x" if args["orientation"] == "h" else "marginal_y"
17841786
args[position] = args["marginal"]
17851787
args[other_position] = None
17861788

17871789
# If both marginals and faceting are specified, faceting wins
1788-
if args.get("facet_col", None) is not None and args.get("marginal_y", None):
1790+
if args.get("facet_col") is not None and args.get("marginal_y") is not None:
17891791
args["marginal_y"] = None
17901792

1791-
if args.get("facet_row", None) is not None and args.get("marginal_x", None):
1793+
if args.get("facet_row") is not None and args.get("marginal_x") is not None:
17921794
args["marginal_x"] = None
17931795

17941796
# facet_col_wrap only works if no marginals or row faceting is used
17951797
if (
1796-
args.get("marginal_x", None) is not None
1797-
or args.get("marginal_y", None) is not None
1798-
or args.get("facet_row", None) is not None
1798+
args.get("marginal_x") is not None
1799+
or args.get("marginal_y") is not None
1800+
or args.get("facet_row") is not None
17991801
):
18001802
args["facet_col_wrap"] = 0
18011803

@@ -1814,43 +1816,41 @@ def infer_config(args, constructor, trace_patch, layout_patch):
18141816

18151817
def get_orderings(args, grouper, grouped):
18161818
"""
1817-
`orders` is the user-supplied ordering (with the remaining data-frame-supplied
1818-
ordering appended if the column is used for grouping). It includes anything the user
1819-
gave, for any variable, including values not present in the dataset. It is used
1820-
downstream to set e.g. `categoryarray` for cartesian axes
1821-
1822-
`group_names` is the set of groups, ordered by the order above
1823-
1824-
`group_values` is a subset of `orders` in both keys and values. It contains a key
1825-
for every grouped mapping and its values are the sorted *data* values for these
1826-
mappings.
1819+
`orders` is the user-supplied ordering with the remaining data-frame-supplied
1820+
ordering appended if the column is used for grouping. It includes anything the user
1821+
gave, for any variable, including values not present in the dataset. It's a dict
1822+
where the keys are e.g. "x" or "color"
1823+
1824+
`sorted_group_names` is the set of groups, ordered by the order above. It's a list
1825+
of tuples like [("value1", ""), ("value2", "")] where each tuple contains the name
1826+
of a single dimension-group
18271827
"""
1828+
18281829
orders = {} if "category_orders" not in args else args["category_orders"].copy()
1829-
group_names = []
1830-
group_values = {}
1830+
for col in grouper:
1831+
if col != one_group:
1832+
uniques = args["data_frame"][col].unique()
1833+
if col not in orders:
1834+
orders[col] = list(uniques)
1835+
else:
1836+
orders[col] = list(orders[col])
1837+
for val in uniques:
1838+
if val not in orders[col]:
1839+
orders[col].append(val)
1840+
1841+
sorted_group_names = []
18311842
for group_name in grouped.groups:
18321843
if len(grouper) == 1:
18331844
group_name = (group_name,)
1834-
group_names.append(group_name)
1835-
for col in grouper:
1836-
if col != one_group:
1837-
uniques = args["data_frame"][col].unique()
1838-
if col not in orders:
1839-
orders[col] = list(uniques)
1840-
else:
1841-
for val in uniques:
1842-
if val not in orders[col]:
1843-
orders[col].append(val)
1844-
group_values[col] = sorted(uniques, key=orders[col].index)
1845+
sorted_group_names.append(group_name)
18451846

18461847
for i, col in reversed(list(enumerate(grouper))):
18471848
if col != one_group:
1848-
group_names = sorted(
1849-
group_names,
1849+
sorted_group_names = sorted(
1850+
sorted_group_names,
18501851
key=lambda g: orders[col].index(g[i]) if g[i] in orders[col] else -1,
18511852
)
1852-
1853-
return orders, group_names, group_values
1853+
return orders, sorted_group_names
18541854

18551855

18561856
def make_figure(args, constructor, trace_patch=None, layout_patch=None):
@@ -1871,32 +1871,35 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
18711871
grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group]
18721872
grouped = args["data_frame"].groupby(grouper, sort=False)
18731873

1874-
orders, sorted_group_names, sorted_group_values = get_orderings(
1875-
args, grouper, grouped
1876-
)
1874+
orders, sorted_group_names = get_orderings(args, grouper, grouped)
18771875

18781876
col_labels = []
18791877
row_labels = []
1880-
1878+
nrows = ncols = 1
18811879
for m in grouped_mappings:
1882-
if m.grouper:
1880+
if m.grouper not in orders:
1881+
m.val_map[""] = m.sequence[0]
1882+
else:
1883+
sorted_values = orders[m.grouper]
18831884
if m.facet == "col":
18841885
prefix = get_label(args, args["facet_col"]) + "="
1885-
col_labels = [prefix + str(s) for s in sorted_group_values[m.grouper]]
1886+
col_labels = [prefix + str(s) for s in sorted_values]
1887+
ncols = len(col_labels)
18861888
if m.facet == "row":
18871889
prefix = get_label(args, args["facet_row"]) + "="
1888-
row_labels = [prefix + str(s) for s in sorted_group_values[m.grouper]]
1889-
for val in sorted_group_values[m.grouper]:
1890-
if val not in m.val_map:
1890+
row_labels = [prefix + str(s) for s in sorted_values]
1891+
nrows = len(row_labels)
1892+
for val in sorted_values:
1893+
if val not in m.val_map: # always False if it's an IdentityMap
18911894
m.val_map[val] = m.sequence[len(m.val_map) % len(m.sequence)]
18921895

18931896
subplot_type = _subplot_type_for_trace_type(constructor().type)
18941897

18951898
trace_names_by_frame = {}
18961899
frames = OrderedDict()
18971900
trendline_rows = []
1898-
nrows = ncols = 1
18991901
trace_name_labels = None
1902+
facet_col_wrap = args.get("facet_col_wrap", 0)
19001903
for group_name in sorted_group_names:
19011904
group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0])
19021905
mapping_labels = OrderedDict()
@@ -1943,8 +1946,6 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
19431946

19441947
for i, m in enumerate(grouped_mappings):
19451948
val = group_name[i]
1946-
if val not in m.val_map:
1947-
m.val_map[val] = m.sequence[len(m.val_map) % len(m.sequence)]
19481949
try:
19491950
m.updater(trace, m.val_map[val]) # covers most cases
19501951
except ValueError:
@@ -1979,14 +1980,13 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
19791980
row = m.val_map[val]
19801981
else:
19811982
if (
1982-
bool(args.get("marginal_x", False))
1983-
and trace_spec.marginal != "x"
1983+
args.get("marginal_x") is not None # there is a marginal
1984+
and trace_spec.marginal != "x" # and we're not it
19841985
):
19851986
row = 2
19861987
else:
19871988
row = 1
19881989

1989-
facet_col_wrap = args.get("facet_col_wrap", 0)
19901990
# Find col for trace, handling facet_col and marginal_y
19911991
if m.facet == "col":
19921992
col = m.val_map[val]
@@ -1999,11 +1999,9 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
19991999
else:
20002000
col = 1
20012001

2002-
nrows = max(nrows, row)
20032002
if row > 1:
20042003
trace._subplot_row = row
20052004

2006-
ncols = max(ncols, col)
20072005
if col > 1:
20082006
trace._subplot_col = col
20092007
if (
@@ -2062,6 +2060,16 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
20622060
):
20632061
layout_patch["legend"]["itemsizing"] = "constant"
20642062

2063+
if facet_col_wrap:
2064+
nrows = math.ceil(ncols / facet_col_wrap)
2065+
ncols = min(ncols, facet_col_wrap)
2066+
2067+
if args.get("marginal_x") is not None:
2068+
nrows += 1
2069+
2070+
if args.get("marginal_y") is not None:
2071+
ncols += 1
2072+
20652073
fig = init_figure(
20662074
args, subplot_type, frame_list, nrows, ncols, col_labels, row_labels
20672075
)
@@ -2106,7 +2114,7 @@ def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_la
21062114

21072115
# Build column_widths/row_heights
21082116
if subplot_type == "xy":
2109-
if bool(args.get("marginal_x", False)):
2117+
if args.get("marginal_x") is not None:
21102118
if args["marginal_x"] == "histogram" or ("color" in args and args["color"]):
21112119
main_size = 0.74
21122120
else:
@@ -2115,11 +2123,11 @@ def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_la
21152123
row_heights = [main_size] * (nrows - 1) + [1 - main_size]
21162124
vertical_spacing = 0.01
21172125
elif facet_col_wrap:
2118-
vertical_spacing = args.get("facet_row_spacing", None) or 0.07
2126+
vertical_spacing = args.get("facet_row_spacing") or 0.07
21192127
else:
2120-
vertical_spacing = args.get("facet_row_spacing", None) or 0.03
2128+
vertical_spacing = args.get("facet_row_spacing") or 0.03
21212129

2122-
if bool(args.get("marginal_y", False)):
2130+
if args.get("marginal_y") is not None:
21232131
if args["marginal_y"] == "histogram" or ("color" in args and args["color"]):
21242132
main_size = 0.74
21252133
else:
@@ -2128,18 +2136,18 @@ def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_la
21282136
column_widths = [main_size] * (ncols - 1) + [1 - main_size]
21292137
horizontal_spacing = 0.005
21302138
else:
2131-
horizontal_spacing = args.get("facet_col_spacing", None) or 0.02
2139+
horizontal_spacing = args.get("facet_col_spacing") or 0.02
21322140
else:
21332141
# Other subplot types:
21342142
# 'scene', 'geo', 'polar', 'ternary', 'mapbox', 'domain', None
21352143
#
21362144
# We can customize subplot spacing per type once we enable faceting
21372145
# for all plot types
21382146
if facet_col_wrap:
2139-
vertical_spacing = args.get("facet_row_spacing", None) or 0.07
2147+
vertical_spacing = args.get("facet_row_spacing") or 0.07
21402148
else:
2141-
vertical_spacing = args.get("facet_row_spacing", None) or 0.03
2142-
horizontal_spacing = args.get("facet_col_spacing", None) or 0.02
2149+
vertical_spacing = args.get("facet_row_spacing") or 0.03
2150+
horizontal_spacing = args.get("facet_col_spacing") or 0.02
21432151

21442152
if facet_col_wrap:
21452153
subplot_labels = [None] * nrows * ncols

Diff for: packages/python/plotly/plotly/tests/test_optional/test_px/test_px.py

+7-10
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,8 @@ def test_px_defaults():
209209

210210

211211
def assert_orderings(days_order, days_check, times_order, times_check):
212-
symbol_sequence = ["circle", "diamond", "square", "cross"]
213-
color_sequence = ["red", "blue"]
212+
symbol_sequence = ["circle", "diamond", "square", "cross", "circle", "diamond"]
213+
color_sequence = ["red", "blue", "red", "blue", "red", "blue", "red", "blue"]
214214
fig = px.scatter(
215215
px.data.tips(),
216216
x="total_bill",
@@ -229,7 +229,7 @@ def assert_orderings(days_order, days_check, times_order, times_check):
229229
assert days_check[col] in trace.hovertemplate
230230

231231
for row in range(len(times_check)):
232-
for trace in fig.select_traces(row=2 - row):
232+
for trace in fig.select_traces(row=len(times_check) - row):
233233
assert times_check[row] in trace.hovertemplate
234234

235235
for trace in fig.data:
@@ -241,13 +241,10 @@ def assert_orderings(days_order, days_check, times_order, times_check):
241241
assert trace.marker.color == color_sequence[i]
242242

243243

244-
def test_noisy_orthogonal_orderings():
245-
assert_orderings(
246-
["x", "Sun", "Sat", "y", "Fri", "z"], # add extra noise, missing Thur
247-
["Sun", "Sat", "Fri", "Thur"], # Thur is at the back
248-
["a", "Lunch", "b"], # add extra noise, missing Dinner
249-
["Lunch", "Dinner"], # Dinner is at the back
250-
)
244+
@pytest.mark.parametrize("days", permutations(["Sun", "Sat", "Fri", "x"]))
245+
@pytest.mark.parametrize("times", permutations(["Lunch", "x"]))
246+
def test_orthogonal_and_missing_orderings(days, times):
247+
assert_orderings(days, list(days) + ["Thur"], times, list(times) + ["Dinner"])
251248

252249

253250
@pytest.mark.parametrize("days", permutations(["Sun", "Sat", "Fri", "Thur"]))

0 commit comments

Comments
 (0)