@@ -1136,11 +1136,19 @@ def infer_config(args, constructor, trace_patch):
1136
1136
def get_orderings (args , grouper , grouped ):
1137
1137
"""
1138
1138
`orders` is the user-supplied ordering (with the remaining data-frame-supplied
1139
- ordering appended if the column is used for grouping)
1139
+ ordering appended if the column is used for grouping). It includes anything the user
1140
+ gave, for any variable, including values not present in the dataset. It is used
1141
+ downstream to set e.g. `categoryarray` for cartesian axes
1142
+
1140
1143
`group_names` is the set of groups, ordered by the order above
1144
+
1145
+ `group_values` is a subset of `orders` in both keys and values. It contains a key
1146
+ for every grouped mapping and its values are the sorted *data* values for these
1147
+ mappings.
1141
1148
"""
1142
1149
orders = {} if "category_orders" not in args else args ["category_orders" ].copy ()
1143
1150
group_names = []
1151
+ group_values = {}
1144
1152
for group_name in grouped .groups :
1145
1153
if len (grouper ) == 1 :
1146
1154
group_name = (group_name ,)
@@ -1154,6 +1162,7 @@ def get_orderings(args, grouper, grouped):
1154
1162
for val in uniques :
1155
1163
if val not in orders [col ]:
1156
1164
orders [col ].append (val )
1165
+ group_values [col ] = sorted (uniques , key = orders [col ].index )
1157
1166
1158
1167
for i , col in reversed (list (enumerate (grouper ))):
1159
1168
if col != one_group :
@@ -1162,7 +1171,7 @@ def get_orderings(args, grouper, grouped):
1162
1171
key = lambda g : orders [col ].index (g [i ]) if g [i ] in orders [col ] else - 1 ,
1163
1172
)
1164
1173
1165
- return orders , group_names
1174
+ return orders , group_names , group_values
1166
1175
1167
1176
1168
1177
def make_figure (args , constructor , trace_patch = {}, layout_patch = {}):
@@ -1174,16 +1183,31 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
1174
1183
grouper = [x .grouper or one_group for x in grouped_mappings ] or [one_group ]
1175
1184
grouped = args ["data_frame" ].groupby (grouper , sort = False )
1176
1185
1177
- orders , sorted_group_names = get_orderings (args , grouper , grouped )
1186
+ orders , sorted_group_names , sorted_group_values = get_orderings (
1187
+ args , grouper , grouped
1188
+ )
1189
+
1190
+ col_labels = []
1191
+ row_labels = []
1192
+
1193
+ for m in grouped_mappings :
1194
+ if m .grouper :
1195
+ if m .facet == "col" :
1196
+ prefix = get_label (args , args ["facet_col" ]) + "="
1197
+ col_labels = [prefix + str (s ) for s in sorted_group_values [m .grouper ]]
1198
+ if m .facet == "row" :
1199
+ prefix = get_label (args , args ["facet_row" ]) + "="
1200
+ row_labels = [prefix + str (s ) for s in sorted_group_values [m .grouper ]]
1201
+ for val in sorted_group_values [m .grouper ]:
1202
+ if val not in m .val_map :
1203
+ m .val_map [val ] = m .sequence [len (m .val_map ) % len (m .sequence )]
1178
1204
1179
1205
subplot_type = _subplot_type_for_trace_type (constructor ().type )
1180
1206
1181
1207
trace_names_by_frame = {}
1182
1208
frames = OrderedDict ()
1183
1209
trendline_rows = []
1184
1210
nrows = ncols = 1
1185
- col_labels = []
1186
- row_labels = []
1187
1211
trace_name_labels = None
1188
1212
for group_name in sorted_group_names :
1189
1213
group = grouped .get_group (group_name if len (group_name ) > 1 else group_name [0 ])
@@ -1281,10 +1305,6 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
1281
1305
# Find row for trace, handling facet_row and marginal_x
1282
1306
if m .facet == "row" :
1283
1307
row = m .val_map [val ]
1284
- if args ["facet_row" ] and len (row_labels ) < row :
1285
- row_labels .append (
1286
- get_label (args , args ["facet_row" ]) + "=" + str (val )
1287
- )
1288
1308
else :
1289
1309
if (
1290
1310
bool (args .get ("marginal_x" , False ))
@@ -1298,10 +1318,6 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
1298
1318
# Find col for trace, handling facet_col and marginal_y
1299
1319
if m .facet == "col" :
1300
1320
col = m .val_map [val ]
1301
- if args ["facet_col" ] and len (col_labels ) < col :
1302
- col_labels .append (
1303
- get_label (args , args ["facet_col" ]) + "=" + str (val )
1304
- )
1305
1321
if facet_col_wrap : # assumes no facet_row, no marginals
1306
1322
row = 1 + ((col - 1 ) // facet_col_wrap )
1307
1323
col = 1 + ((col - 1 ) % facet_col_wrap )
0 commit comments