@@ -400,10 +400,10 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
400400 if hover_is_dict and not attr_value [col ]:
401401 continue
402402 if col in [
403- args .get ("x" , None ),
404- args .get ("y" , None ),
405- args .get ("z" , None ),
406- args .get ("base" , None ),
403+ args .get ("x" ),
404+ args .get ("y" ),
405+ args .get ("z" ),
406+ args .get ("base" ),
407407 ]:
408408 continue
409409 try :
@@ -552,8 +552,10 @@ def set_cartesian_axis_opts(args, axis, letter, orders):
552552 axis ["categoryarray" ] = (
553553 orders [args [letter ]]
554554 if isinstance (axis , go .layout .XAxis )
555- else list (reversed (orders [args [letter ]]))
555+ else list (reversed (orders [args [letter ]])) # top down for Y axis
556556 )
557+ if "range" not in axis :
558+ axis ["range" ] = [- 0.5 , len (orders [args [letter ]]) - 0.5 ]
557559
558560
559561def configure_cartesian_marginal_axes (args , fig , orders ):
@@ -1284,8 +1286,8 @@ def build_dataframe(args, constructor):
12841286
12851287 # now we handle special cases like wide-mode or x-xor-y specification
12861288 # by rearranging args to tee things up for process_args_into_dataframe to work
1287- no_x = args .get ("x" , None ) is None
1288- no_y = args .get ("y" , None ) is None
1289+ no_x = args .get ("x" ) is None
1290+ no_y = args .get ("y" ) is None
12891291 wide_x = False if no_x else _is_col_list (df_input , args ["x" ])
12901292 wide_y = False if no_y else _is_col_list (df_input , args ["y" ])
12911293
@@ -1312,9 +1314,9 @@ def build_dataframe(args, constructor):
13121314 if var_name in [None , "value" , "index" ] or var_name in df_input :
13131315 var_name = "variable"
13141316 if constructor == go .Funnel :
1315- wide_orientation = args .get ("orientation" , None ) or "h"
1317+ wide_orientation = args .get ("orientation" ) or "h"
13161318 else :
1317- wide_orientation = args .get ("orientation" , None ) or "v"
1319+ wide_orientation = args .get ("orientation" ) or "v"
13181320 args ["orientation" ] = wide_orientation
13191321 args ["wide_cross" ] = None
13201322 elif wide_x != wide_y :
@@ -1345,7 +1347,7 @@ def build_dataframe(args, constructor):
13451347 if constructor in [go .Scatter , go .Bar , go .Funnel ] + hist2d_types :
13461348 if not wide_mode and (no_x != no_y ):
13471349 for ax in ["x" , "y" ]:
1348- if args .get (ax , None ) is None :
1350+ if args .get (ax ) is None :
13491351 args [ax ] = df_input .index if df_provided else Range ()
13501352 if constructor == go .Bar :
13511353 missing_bar_dim = ax
@@ -1369,7 +1371,7 @@ def build_dataframe(args, constructor):
13691371 )
13701372
13711373 no_color = False
1372- if type (args .get ("color" , None )) == str and args ["color" ] == NO_COLOR :
1374+ if type (args .get ("color" )) == str and args ["color" ] == NO_COLOR :
13731375 no_color = True
13741376 args ["color" ] = None
13751377 # now that things have been prepped, we do the systematic rewriting of `args`
@@ -1777,25 +1779,25 @@ def infer_config(args, constructor, trace_patch, layout_patch):
17771779 else args ["geojson" ].__geo_interface__
17781780 )
17791781
1780- # Compute marginal attribute
1782+ # Compute marginal attribute: copy to appropriate marginal_*
17811783 if "marginal" in args :
17821784 position = "marginal_x" if args ["orientation" ] == "v" else "marginal_y"
17831785 other_position = "marginal_x" if args ["orientation" ] == "h" else "marginal_y"
17841786 args [position ] = args ["marginal" ]
17851787 args [other_position ] = None
17861788
17871789 # If both marginals and faceting are specified, faceting wins
1788- if args .get ("facet_col" , None ) is not None and args .get ("marginal_y" , None ) :
1790+ if args .get ("facet_col" ) is not None and args .get ("marginal_y" ) is not None :
17891791 args ["marginal_y" ] = None
17901792
1791- if args .get ("facet_row" , None ) is not None and args .get ("marginal_x" , None ) :
1793+ if args .get ("facet_row" ) is not None and args .get ("marginal_x" ) is not None :
17921794 args ["marginal_x" ] = None
17931795
17941796 # facet_col_wrap only works if no marginals or row faceting is used
17951797 if (
1796- args .get ("marginal_x" , None ) is not None
1797- or args .get ("marginal_y" , None ) is not None
1798- or args .get ("facet_row" , None ) is not None
1798+ args .get ("marginal_x" ) is not None
1799+ or args .get ("marginal_y" ) is not None
1800+ or args .get ("facet_row" ) is not None
17991801 ):
18001802 args ["facet_col_wrap" ] = 0
18011803
@@ -1814,43 +1816,41 @@ def infer_config(args, constructor, trace_patch, layout_patch):
18141816
18151817def get_orderings (args , grouper , grouped ):
18161818 """
1817- `orders` is the user-supplied ordering (with the remaining data-frame-supplied
1818- ordering appended if the column is used for grouping). It includes anything the user
1819- gave, for any variable, including values not present in the dataset. It is used
1820- downstream to set e.g. `categoryarray` for cartesian axes
1821-
1822- `group_names` is the set of groups, ordered by the order above
1823-
1824- `group_values` is a subset of `orders` in both keys and values. It contains a key
1825- for every grouped mapping and its values are the sorted *data* values for these
1826- mappings.
1819+ `orders` is the user-supplied ordering with the remaining data-frame-supplied
1820+ ordering appended if the column is used for grouping. It includes anything the user
1821+ gave, for any variable, including values not present in the dataset. It's a dict
1822+ where the keys are e.g. "x" or "color"
1823+
1824+ `sorted_group_names` is the set of groups, ordered by the order above. It's a list
1825+ of tuples like [("value1", ""), ("value2", "")] where each tuple contains the name
1826+ of a single dimension-group
18271827 """
1828+
18281829 orders = {} if "category_orders" not in args else args ["category_orders" ].copy ()
1829- group_names = []
1830- group_values = {}
1830+ for col in grouper :
1831+ if col != one_group :
1832+ uniques = args ["data_frame" ][col ].unique ()
1833+ if col not in orders :
1834+ orders [col ] = list (uniques )
1835+ else :
1836+ orders [col ] = list (orders [col ])
1837+ for val in uniques :
1838+ if val not in orders [col ]:
1839+ orders [col ].append (val )
1840+
1841+ sorted_group_names = []
18311842 for group_name in grouped .groups :
18321843 if len (grouper ) == 1 :
18331844 group_name = (group_name ,)
1834- group_names .append (group_name )
1835- for col in grouper :
1836- if col != one_group :
1837- uniques = args ["data_frame" ][col ].unique ()
1838- if col not in orders :
1839- orders [col ] = list (uniques )
1840- else :
1841- for val in uniques :
1842- if val not in orders [col ]:
1843- orders [col ].append (val )
1844- group_values [col ] = sorted (uniques , key = orders [col ].index )
1845+ sorted_group_names .append (group_name )
18451846
18461847 for i , col in reversed (list (enumerate (grouper ))):
18471848 if col != one_group :
1848- group_names = sorted (
1849- group_names ,
1849+ sorted_group_names = sorted (
1850+ sorted_group_names ,
18501851 key = lambda g : orders [col ].index (g [i ]) if g [i ] in orders [col ] else - 1 ,
18511852 )
1852-
1853- return orders , group_names , group_values
1853+ return orders , sorted_group_names
18541854
18551855
18561856def make_figure (args , constructor , trace_patch = None , layout_patch = None ):
@@ -1871,32 +1871,35 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
18711871 grouper = [x .grouper or one_group for x in grouped_mappings ] or [one_group ]
18721872 grouped = args ["data_frame" ].groupby (grouper , sort = False )
18731873
1874- orders , sorted_group_names , sorted_group_values = get_orderings (
1875- args , grouper , grouped
1876- )
1874+ orders , sorted_group_names = get_orderings (args , grouper , grouped )
18771875
18781876 col_labels = []
18791877 row_labels = []
1880-
1878+ nrows = ncols = 1
18811879 for m in grouped_mappings :
1882- if m .grouper :
1880+ if m .grouper not in orders :
1881+ m .val_map ["" ] = m .sequence [0 ]
1882+ else :
1883+ sorted_values = orders [m .grouper ]
18831884 if m .facet == "col" :
18841885 prefix = get_label (args , args ["facet_col" ]) + "="
1885- col_labels = [prefix + str (s ) for s in sorted_group_values [m .grouper ]]
1886+ col_labels = [prefix + str (s ) for s in sorted_values ]
1887+ ncols = len (col_labels )
18861888 if m .facet == "row" :
18871889 prefix = get_label (args , args ["facet_row" ]) + "="
1888- row_labels = [prefix + str (s ) for s in sorted_group_values [m .grouper ]]
1889- for val in sorted_group_values [m .grouper ]:
1890- if val not in m .val_map :
1890+ row_labels = [prefix + str (s ) for s in sorted_values ]
1891+ nrows = len (row_labels )
1892+ for val in sorted_values :
1893+ if val not in m .val_map : # always False if it's an IdentityMap
18911894 m .val_map [val ] = m .sequence [len (m .val_map ) % len (m .sequence )]
18921895
18931896 subplot_type = _subplot_type_for_trace_type (constructor ().type )
18941897
18951898 trace_names_by_frame = {}
18961899 frames = OrderedDict ()
18971900 trendline_rows = []
1898- nrows = ncols = 1
18991901 trace_name_labels = None
1902+ facet_col_wrap = args .get ("facet_col_wrap" , 0 )
19001903 for group_name in sorted_group_names :
19011904 group = grouped .get_group (group_name if len (group_name ) > 1 else group_name [0 ])
19021905 mapping_labels = OrderedDict ()
@@ -1943,8 +1946,6 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
19431946
19441947 for i , m in enumerate (grouped_mappings ):
19451948 val = group_name [i ]
1946- if val not in m .val_map :
1947- m .val_map [val ] = m .sequence [len (m .val_map ) % len (m .sequence )]
19481949 try :
19491950 m .updater (trace , m .val_map [val ]) # covers most cases
19501951 except ValueError :
@@ -1979,14 +1980,13 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
19791980 row = m .val_map [val ]
19801981 else :
19811982 if (
1982- bool ( args .get ("marginal_x" , False ))
1983- and trace_spec .marginal != "x"
1983+ args .get ("marginal_x" ) is not None # there is a marginal
1984+ and trace_spec .marginal != "x" # and we're not it
19841985 ):
19851986 row = 2
19861987 else :
19871988 row = 1
19881989
1989- facet_col_wrap = args .get ("facet_col_wrap" , 0 )
19901990 # Find col for trace, handling facet_col and marginal_y
19911991 if m .facet == "col" :
19921992 col = m .val_map [val ]
@@ -1999,11 +1999,9 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
19991999 else :
20002000 col = 1
20012001
2002- nrows = max (nrows , row )
20032002 if row > 1 :
20042003 trace ._subplot_row = row
20052004
2006- ncols = max (ncols , col )
20072005 if col > 1 :
20082006 trace ._subplot_col = col
20092007 if (
@@ -2062,6 +2060,16 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
20622060 ):
20632061 layout_patch ["legend" ]["itemsizing" ] = "constant"
20642062
2063+ if facet_col_wrap :
2064+ nrows = math .ceil (ncols / facet_col_wrap )
2065+ ncols = min (ncols , facet_col_wrap )
2066+
2067+ if args .get ("marginal_x" ) is not None :
2068+ nrows += 1
2069+
2070+ if args .get ("marginal_y" ) is not None :
2071+ ncols += 1
2072+
20652073 fig = init_figure (
20662074 args , subplot_type , frame_list , nrows , ncols , col_labels , row_labels
20672075 )
@@ -2106,7 +2114,7 @@ def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_la
21062114
21072115 # Build column_widths/row_heights
21082116 if subplot_type == "xy" :
2109- if bool ( args .get ("marginal_x" , False )) :
2117+ if args .get ("marginal_x" ) is not None :
21102118 if args ["marginal_x" ] == "histogram" or ("color" in args and args ["color" ]):
21112119 main_size = 0.74
21122120 else :
@@ -2115,11 +2123,11 @@ def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_la
21152123 row_heights = [main_size ] * (nrows - 1 ) + [1 - main_size ]
21162124 vertical_spacing = 0.01
21172125 elif facet_col_wrap :
2118- vertical_spacing = args .get ("facet_row_spacing" , None ) or 0.07
2126+ vertical_spacing = args .get ("facet_row_spacing" ) or 0.07
21192127 else :
2120- vertical_spacing = args .get ("facet_row_spacing" , None ) or 0.03
2128+ vertical_spacing = args .get ("facet_row_spacing" ) or 0.03
21212129
2122- if bool ( args .get ("marginal_y" , False )) :
2130+ if args .get ("marginal_y" ) is not None :
21232131 if args ["marginal_y" ] == "histogram" or ("color" in args and args ["color" ]):
21242132 main_size = 0.74
21252133 else :
@@ -2128,18 +2136,18 @@ def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_la
21282136 column_widths = [main_size ] * (ncols - 1 ) + [1 - main_size ]
21292137 horizontal_spacing = 0.005
21302138 else :
2131- horizontal_spacing = args .get ("facet_col_spacing" , None ) or 0.02
2139+ horizontal_spacing = args .get ("facet_col_spacing" ) or 0.02
21322140 else :
21332141 # Other subplot types:
21342142 # 'scene', 'geo', 'polar', 'ternary', 'mapbox', 'domain', None
21352143 #
21362144 # We can customize subplot spacing per type once we enable faceting
21372145 # for all plot types
21382146 if facet_col_wrap :
2139- vertical_spacing = args .get ("facet_row_spacing" , None ) or 0.07
2147+ vertical_spacing = args .get ("facet_row_spacing" ) or 0.07
21402148 else :
2141- vertical_spacing = args .get ("facet_row_spacing" , None ) or 0.03
2142- horizontal_spacing = args .get ("facet_col_spacing" , None ) or 0.02
2149+ vertical_spacing = args .get ("facet_row_spacing" ) or 0.03
2150+ horizontal_spacing = args .get ("facet_col_spacing" ) or 0.02
21432151
21442152 if facet_col_wrap :
21452153 subplot_labels = [None ] * nrows * ncols
0 commit comments