@@ -400,10 +400,10 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
400
400
if hover_is_dict and not attr_value [col ]:
401
401
continue
402
402
if col in [
403
- args .get ("x" , None ),
404
- args .get ("y" , None ),
405
- args .get ("z" , None ),
406
- args .get ("base" , None ),
403
+ args .get ("x" ),
404
+ args .get ("y" ),
405
+ args .get ("z" ),
406
+ args .get ("base" ),
407
407
]:
408
408
continue
409
409
try :
@@ -552,8 +552,10 @@ def set_cartesian_axis_opts(args, axis, letter, orders):
552
552
axis ["categoryarray" ] = (
553
553
orders [args [letter ]]
554
554
if isinstance (axis , go .layout .XAxis )
555
- else list (reversed (orders [args [letter ]]))
555
+ else list (reversed (orders [args [letter ]])) # top down for Y axis
556
556
)
557
+ if "range" not in axis :
558
+ axis ["range" ] = [- 0.5 , len (orders [args [letter ]]) - 0.5 ]
557
559
558
560
559
561
def configure_cartesian_marginal_axes (args , fig , orders ):
@@ -1284,8 +1286,8 @@ def build_dataframe(args, constructor):
1284
1286
1285
1287
# now we handle special cases like wide-mode or x-xor-y specification
1286
1288
# by rearranging args to tee things up for process_args_into_dataframe to work
1287
- no_x = args .get ("x" , None ) is None
1288
- no_y = args .get ("y" , None ) is None
1289
+ no_x = args .get ("x" ) is None
1290
+ no_y = args .get ("y" ) is None
1289
1291
wide_x = False if no_x else _is_col_list (df_input , args ["x" ])
1290
1292
wide_y = False if no_y else _is_col_list (df_input , args ["y" ])
1291
1293
@@ -1312,9 +1314,9 @@ def build_dataframe(args, constructor):
1312
1314
if var_name in [None , "value" , "index" ] or var_name in df_input :
1313
1315
var_name = "variable"
1314
1316
if constructor == go .Funnel :
1315
- wide_orientation = args .get ("orientation" , None ) or "h"
1317
+ wide_orientation = args .get ("orientation" ) or "h"
1316
1318
else :
1317
- wide_orientation = args .get ("orientation" , None ) or "v"
1319
+ wide_orientation = args .get ("orientation" ) or "v"
1318
1320
args ["orientation" ] = wide_orientation
1319
1321
args ["wide_cross" ] = None
1320
1322
elif wide_x != wide_y :
@@ -1345,7 +1347,7 @@ def build_dataframe(args, constructor):
1345
1347
if constructor in [go .Scatter , go .Bar , go .Funnel ] + hist2d_types :
1346
1348
if not wide_mode and (no_x != no_y ):
1347
1349
for ax in ["x" , "y" ]:
1348
- if args .get (ax , None ) is None :
1350
+ if args .get (ax ) is None :
1349
1351
args [ax ] = df_input .index if df_provided else Range ()
1350
1352
if constructor == go .Bar :
1351
1353
missing_bar_dim = ax
@@ -1369,7 +1371,7 @@ def build_dataframe(args, constructor):
1369
1371
)
1370
1372
1371
1373
no_color = False
1372
- if type (args .get ("color" , None )) == str and args ["color" ] == NO_COLOR :
1374
+ if type (args .get ("color" )) == str and args ["color" ] == NO_COLOR :
1373
1375
no_color = True
1374
1376
args ["color" ] = None
1375
1377
# now that things have been prepped, we do the systematic rewriting of `args`
@@ -1777,25 +1779,25 @@ def infer_config(args, constructor, trace_patch, layout_patch):
1777
1779
else args ["geojson" ].__geo_interface__
1778
1780
)
1779
1781
1780
- # Compute marginal attribute
1782
+ # Compute marginal attribute: copy to appropriate marginal_*
1781
1783
if "marginal" in args :
1782
1784
position = "marginal_x" if args ["orientation" ] == "v" else "marginal_y"
1783
1785
other_position = "marginal_x" if args ["orientation" ] == "h" else "marginal_y"
1784
1786
args [position ] = args ["marginal" ]
1785
1787
args [other_position ] = None
1786
1788
1787
1789
# If both marginals and faceting are specified, faceting wins
1788
- if args .get ("facet_col" , None ) is not None and args .get ("marginal_y" , None ) :
1790
+ if args .get ("facet_col" ) is not None and args .get ("marginal_y" ) is not None :
1789
1791
args ["marginal_y" ] = None
1790
1792
1791
- if args .get ("facet_row" , None ) is not None and args .get ("marginal_x" , None ) :
1793
+ if args .get ("facet_row" ) is not None and args .get ("marginal_x" ) is not None :
1792
1794
args ["marginal_x" ] = None
1793
1795
1794
1796
# facet_col_wrap only works if no marginals or row faceting is used
1795
1797
if (
1796
- args .get ("marginal_x" , None ) is not None
1797
- or args .get ("marginal_y" , None ) is not None
1798
- or args .get ("facet_row" , None ) is not None
1798
+ args .get ("marginal_x" ) is not None
1799
+ or args .get ("marginal_y" ) is not None
1800
+ or args .get ("facet_row" ) is not None
1799
1801
):
1800
1802
args ["facet_col_wrap" ] = 0
1801
1803
@@ -1814,43 +1816,41 @@ def infer_config(args, constructor, trace_patch, layout_patch):
1814
1816
1815
1817
def get_orderings (args , grouper , grouped ):
1816
1818
"""
1817
- `orders` is the user-supplied ordering (with the remaining data-frame-supplied
1818
- ordering appended if the column is used for grouping). It includes anything the user
1819
- gave, for any variable, including values not present in the dataset. It is used
1820
- downstream to set e.g. `categoryarray` for cartesian axes
1821
-
1822
- `group_names` is the set of groups, ordered by the order above
1823
-
1824
- `group_values` is a subset of `orders` in both keys and values. It contains a key
1825
- for every grouped mapping and its values are the sorted *data* values for these
1826
- mappings.
1819
+ `orders` is the user-supplied ordering with the remaining data-frame-supplied
1820
+ ordering appended if the column is used for grouping. It includes anything the user
1821
+ gave, for any variable, including values not present in the dataset. It's a dict
1822
+ where the keys are e.g. "x" or "color"
1823
+
1824
+ `sorted_group_names` is the set of groups, ordered by the order above. It's a list
1825
+ of tuples like [("value1", ""), ("value2", "")] where each tuple contains the name
1826
+ of a single dimension-group
1827
1827
"""
1828
+
1828
1829
orders = {} if "category_orders" not in args else args ["category_orders" ].copy ()
1829
- group_names = []
1830
- group_values = {}
1830
+ for col in grouper :
1831
+ if col != one_group :
1832
+ uniques = args ["data_frame" ][col ].unique ()
1833
+ if col not in orders :
1834
+ orders [col ] = list (uniques )
1835
+ else :
1836
+ orders [col ] = list (orders [col ])
1837
+ for val in uniques :
1838
+ if val not in orders [col ]:
1839
+ orders [col ].append (val )
1840
+
1841
+ sorted_group_names = []
1831
1842
for group_name in grouped .groups :
1832
1843
if len (grouper ) == 1 :
1833
1844
group_name = (group_name ,)
1834
- group_names .append (group_name )
1835
- for col in grouper :
1836
- if col != one_group :
1837
- uniques = args ["data_frame" ][col ].unique ()
1838
- if col not in orders :
1839
- orders [col ] = list (uniques )
1840
- else :
1841
- for val in uniques :
1842
- if val not in orders [col ]:
1843
- orders [col ].append (val )
1844
- group_values [col ] = sorted (uniques , key = orders [col ].index )
1845
+ sorted_group_names .append (group_name )
1845
1846
1846
1847
for i , col in reversed (list (enumerate (grouper ))):
1847
1848
if col != one_group :
1848
- group_names = sorted (
1849
- group_names ,
1849
+ sorted_group_names = sorted (
1850
+ sorted_group_names ,
1850
1851
key = lambda g : orders [col ].index (g [i ]) if g [i ] in orders [col ] else - 1 ,
1851
1852
)
1852
-
1853
- return orders , group_names , group_values
1853
+ return orders , sorted_group_names
1854
1854
1855
1855
1856
1856
def make_figure (args , constructor , trace_patch = None , layout_patch = None ):
@@ -1871,32 +1871,35 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
1871
1871
grouper = [x .grouper or one_group for x in grouped_mappings ] or [one_group ]
1872
1872
grouped = args ["data_frame" ].groupby (grouper , sort = False )
1873
1873
1874
- orders , sorted_group_names , sorted_group_values = get_orderings (
1875
- args , grouper , grouped
1876
- )
1874
+ orders , sorted_group_names = get_orderings (args , grouper , grouped )
1877
1875
1878
1876
col_labels = []
1879
1877
row_labels = []
1880
-
1878
+ nrows = ncols = 1
1881
1879
for m in grouped_mappings :
1882
- if m .grouper :
1880
+ if m .grouper not in orders :
1881
+ m .val_map ["" ] = m .sequence [0 ]
1882
+ else :
1883
+ sorted_values = orders [m .grouper ]
1883
1884
if m .facet == "col" :
1884
1885
prefix = get_label (args , args ["facet_col" ]) + "="
1885
- col_labels = [prefix + str (s ) for s in sorted_group_values [m .grouper ]]
1886
+ col_labels = [prefix + str (s ) for s in sorted_values ]
1887
+ ncols = len (col_labels )
1886
1888
if m .facet == "row" :
1887
1889
prefix = get_label (args , args ["facet_row" ]) + "="
1888
- row_labels = [prefix + str (s ) for s in sorted_group_values [m .grouper ]]
1889
- for val in sorted_group_values [m .grouper ]:
1890
- if val not in m .val_map :
1890
+ row_labels = [prefix + str (s ) for s in sorted_values ]
1891
+ nrows = len (row_labels )
1892
+ for val in sorted_values :
1893
+ if val not in m .val_map : # always False if it's an IdentityMap
1891
1894
m .val_map [val ] = m .sequence [len (m .val_map ) % len (m .sequence )]
1892
1895
1893
1896
subplot_type = _subplot_type_for_trace_type (constructor ().type )
1894
1897
1895
1898
trace_names_by_frame = {}
1896
1899
frames = OrderedDict ()
1897
1900
trendline_rows = []
1898
- nrows = ncols = 1
1899
1901
trace_name_labels = None
1902
+ facet_col_wrap = args .get ("facet_col_wrap" , 0 )
1900
1903
for group_name in sorted_group_names :
1901
1904
group = grouped .get_group (group_name if len (group_name ) > 1 else group_name [0 ])
1902
1905
mapping_labels = OrderedDict ()
@@ -1943,8 +1946,6 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
1943
1946
1944
1947
for i , m in enumerate (grouped_mappings ):
1945
1948
val = group_name [i ]
1946
- if val not in m .val_map :
1947
- m .val_map [val ] = m .sequence [len (m .val_map ) % len (m .sequence )]
1948
1949
try :
1949
1950
m .updater (trace , m .val_map [val ]) # covers most cases
1950
1951
except ValueError :
@@ -1979,14 +1980,13 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
1979
1980
row = m .val_map [val ]
1980
1981
else :
1981
1982
if (
1982
- bool ( args .get ("marginal_x" , False ))
1983
- and trace_spec .marginal != "x"
1983
+ args .get ("marginal_x" ) is not None # there is a marginal
1984
+ and trace_spec .marginal != "x" # and we're not it
1984
1985
):
1985
1986
row = 2
1986
1987
else :
1987
1988
row = 1
1988
1989
1989
- facet_col_wrap = args .get ("facet_col_wrap" , 0 )
1990
1990
# Find col for trace, handling facet_col and marginal_y
1991
1991
if m .facet == "col" :
1992
1992
col = m .val_map [val ]
@@ -1999,11 +1999,9 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
1999
1999
else :
2000
2000
col = 1
2001
2001
2002
- nrows = max (nrows , row )
2003
2002
if row > 1 :
2004
2003
trace ._subplot_row = row
2005
2004
2006
- ncols = max (ncols , col )
2007
2005
if col > 1 :
2008
2006
trace ._subplot_col = col
2009
2007
if (
@@ -2062,6 +2060,16 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
2062
2060
):
2063
2061
layout_patch ["legend" ]["itemsizing" ] = "constant"
2064
2062
2063
+ if facet_col_wrap :
2064
+ nrows = math .ceil (ncols / facet_col_wrap )
2065
+ ncols = min (ncols , facet_col_wrap )
2066
+
2067
+ if args .get ("marginal_x" ) is not None :
2068
+ nrows += 1
2069
+
2070
+ if args .get ("marginal_y" ) is not None :
2071
+ ncols += 1
2072
+
2065
2073
fig = init_figure (
2066
2074
args , subplot_type , frame_list , nrows , ncols , col_labels , row_labels
2067
2075
)
@@ -2106,7 +2114,7 @@ def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_la
2106
2114
2107
2115
# Build column_widths/row_heights
2108
2116
if subplot_type == "xy" :
2109
- if bool ( args .get ("marginal_x" , False )) :
2117
+ if args .get ("marginal_x" ) is not None :
2110
2118
if args ["marginal_x" ] == "histogram" or ("color" in args and args ["color" ]):
2111
2119
main_size = 0.74
2112
2120
else :
@@ -2115,11 +2123,11 @@ def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_la
2115
2123
row_heights = [main_size ] * (nrows - 1 ) + [1 - main_size ]
2116
2124
vertical_spacing = 0.01
2117
2125
elif facet_col_wrap :
2118
- vertical_spacing = args .get ("facet_row_spacing" , None ) or 0.07
2126
+ vertical_spacing = args .get ("facet_row_spacing" ) or 0.07
2119
2127
else :
2120
- vertical_spacing = args .get ("facet_row_spacing" , None ) or 0.03
2128
+ vertical_spacing = args .get ("facet_row_spacing" ) or 0.03
2121
2129
2122
- if bool ( args .get ("marginal_y" , False )) :
2130
+ if args .get ("marginal_y" ) is not None :
2123
2131
if args ["marginal_y" ] == "histogram" or ("color" in args and args ["color" ]):
2124
2132
main_size = 0.74
2125
2133
else :
@@ -2128,18 +2136,18 @@ def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_la
2128
2136
column_widths = [main_size ] * (ncols - 1 ) + [1 - main_size ]
2129
2137
horizontal_spacing = 0.005
2130
2138
else :
2131
- horizontal_spacing = args .get ("facet_col_spacing" , None ) or 0.02
2139
+ horizontal_spacing = args .get ("facet_col_spacing" ) or 0.02
2132
2140
else :
2133
2141
# Other subplot types:
2134
2142
# 'scene', 'geo', 'polar', 'ternary', 'mapbox', 'domain', None
2135
2143
#
2136
2144
# We can customize subplot spacing per type once we enable faceting
2137
2145
# for all plot types
2138
2146
if facet_col_wrap :
2139
- vertical_spacing = args .get ("facet_row_spacing" , None ) or 0.07
2147
+ vertical_spacing = args .get ("facet_row_spacing" ) or 0.07
2140
2148
else :
2141
- vertical_spacing = args .get ("facet_row_spacing" , None ) or 0.03
2142
- horizontal_spacing = args .get ("facet_col_spacing" , None ) or 0.02
2149
+ vertical_spacing = args .get ("facet_row_spacing" ) or 0.03
2150
+ horizontal_spacing = args .get ("facet_col_spacing" ) or 0.02
2143
2151
2144
2152
if facet_col_wrap :
2145
2153
subplot_labels = [None ] * nrows * ncols
0 commit comments