Skip to content

Commit 5305f74

Browse files
Kullyjonmmease
authored andcommitted
subplot titles respecting row_width and column_width (#1245)
* first pass at subplot titles respecting row_width and column_width * add test for row_width and column_width in make_subplots
1 parent 9875f66 commit 5305f74

File tree

2 files changed

+154
-16
lines changed

2 files changed

+154
-16
lines changed

Diff for: plotly/tests/test_core/test_tools/test_make_subplots.py

+114-5
Original file line numberDiff line numberDiff line change
@@ -1946,7 +1946,7 @@ def test_subplot_titles_shared_axes(self):
19461946
layout=Layout(
19471947
annotations=Annotations([
19481948
Annotation(
1949-
x=0.22499999999999998,
1949+
x=0.225,
19501950
y=1.0,
19511951
xref='paper',
19521952
yref='paper',
@@ -1957,7 +1957,7 @@ def test_subplot_titles_shared_axes(self):
19571957
yanchor='bottom'
19581958
),
19591959
Annotation(
1960-
x=0.7749999999999999,
1960+
x=0.775,
19611961
y=1.0,
19621962
xref='paper',
19631963
yref='paper',
@@ -1968,7 +1968,7 @@ def test_subplot_titles_shared_axes(self):
19681968
yanchor='bottom'
19691969
),
19701970
Annotation(
1971-
x=0.22499999999999998,
1971+
x=0.225,
19721972
y=0.375,
19731973
xref='paper',
19741974
yref='paper',
@@ -1979,7 +1979,7 @@ def test_subplot_titles_shared_axes(self):
19791979
yanchor='bottom'
19801980
),
19811981
Annotation(
1982-
x=0.7749999999999999,
1982+
x=0.775,
19831983
y=0.375,
19841984
xref='paper',
19851985
yref='paper',
@@ -2010,13 +2010,13 @@ def test_subplot_titles_shared_axes(self):
20102010
)
20112011
)
20122012
)
2013+
20132014
fig = tls.make_subplots(rows=2, cols=2,
20142015
subplot_titles=('Title 1', 'Title 2',
20152016
'Title 3', 'Title 4'),
20162017
shared_xaxes=True, shared_yaxes=True)
20172018
self.assertEqual(fig.to_plotly_json(), expected.to_plotly_json())
20182019

2019-
20202020
def test_subplot_titles_irregular_layout(self):
20212021
# make a title for each subplot when the layout is irregular:
20222022
expected = Figure(
@@ -2155,3 +2155,112 @@ def test_large_columns_no_errors(self):
21552155
fig = tls.make_subplots(100, 1,
21562156
vertical_spacing=v_space,
21572157
specs=[[{'is_3d': True}] for _ in range(100)])
2158+
2159+
def test_row_width_and_column_width(self):
2160+
2161+
expected = Figure({
2162+
'data': [],
2163+
'layout': {'annotations': [{'font': {'size': 16},
2164+
'showarrow': False,
2165+
'text': 'Title 1',
2166+
'x': 0.405,
2167+
'xanchor': 'center',
2168+
'xref': 'paper',
2169+
'y': 1.0,
2170+
'yanchor': 'bottom',
2171+
'yref': 'paper'},
2172+
{'font': {'size': 16},
2173+
'showarrow': False,
2174+
'text': 'Title 2',
2175+
'x': 0.9550000000000001,
2176+
'xanchor': 'center',
2177+
'xref': 'paper',
2178+
'y': 1.0,
2179+
'yanchor': 'bottom',
2180+
'yref': 'paper'},
2181+
{'font': {'size': 16},
2182+
'showarrow': False,
2183+
'text': 'Title 3',
2184+
'x': 0.405,
2185+
'xanchor': 'center',
2186+
'xref': 'paper',
2187+
'y': 0.1875,
2188+
'yanchor': 'bottom',
2189+
'yref': 'paper'},
2190+
{'font': {'size': 16},
2191+
'showarrow': False,
2192+
'text': 'Title 4',
2193+
'x': 0.9550000000000001,
2194+
'xanchor': 'center',
2195+
'xref': 'paper',
2196+
'y': 0.1875,
2197+
'yanchor': 'bottom',
2198+
'yref': 'paper'}],
2199+
'xaxis': {'anchor': 'y', 'domain': [0.0, 0.81]},
2200+
'xaxis2': {'anchor': 'y2', 'domain': [0.91, 1.0]},
2201+
'xaxis3': {'anchor': 'y3', 'domain': [0.0, 0.81]},
2202+
'xaxis4': {'anchor': 'y4', 'domain': [0.91, 1.0]},
2203+
'yaxis': {'anchor': 'x', 'domain': [0.4375, 1.0]},
2204+
'yaxis2': {'anchor': 'x2', 'domain': [0.4375, 1.0]},
2205+
'yaxis3': {'anchor': 'x3', 'domain': [0.0, 0.1875]},
2206+
'yaxis4': {'anchor': 'x4', 'domain': [0.0, 0.1875]}}
2207+
})
2208+
fig = tls.make_subplots(rows=2, cols=2,
2209+
subplot_titles=('Title 1', 'Title 2', 'Title 3', 'Title 4'),
2210+
row_width=[1, 3], column_width=[9, 1])
2211+
self.assertEqual(fig.to_plotly_json(), expected.to_plotly_json())
2212+
2213+
def test_row_width_and_shared_yaxes(self):
2214+
2215+
expected = Figure({
2216+
'data': [],
2217+
'layout': {'annotations': [{'font': {'size': 16},
2218+
'showarrow': False,
2219+
'text': 'Title 1',
2220+
'x': 0.225,
2221+
'xanchor': 'center',
2222+
'xref': 'paper',
2223+
'y': 1.0,
2224+
'yanchor': 'bottom',
2225+
'yref': 'paper'},
2226+
{'font': {'size': 16},
2227+
'showarrow': False,
2228+
'text': 'Title 2',
2229+
'x': 0.775,
2230+
'xanchor': 'center',
2231+
'xref': 'paper',
2232+
'y': 1.0,
2233+
'yanchor': 'bottom',
2234+
'yref': 'paper'},
2235+
{'font': {'size': 16},
2236+
'showarrow': False,
2237+
'text': 'Title 3',
2238+
'x': 0.225,
2239+
'xanchor': 'center',
2240+
'xref': 'paper',
2241+
'y': 0.1875,
2242+
'yanchor': 'bottom',
2243+
'yref': 'paper'},
2244+
{'font': {'size': 16},
2245+
'showarrow': False,
2246+
'text': 'Title 4',
2247+
'x': 0.775,
2248+
'xanchor': 'center',
2249+
'xref': 'paper',
2250+
'y': 0.1875,
2251+
'yanchor': 'bottom',
2252+
'yref': 'paper'}],
2253+
'xaxis': {'anchor': 'y', 'domain': [0.0, 0.45]},
2254+
'xaxis2': {'anchor': 'free', 'domain': [0.55, 1.0], 'position': 0.4375},
2255+
'xaxis3': {'anchor': 'y2', 'domain': [0.0, 0.45]},
2256+
'xaxis4': {'anchor': 'free', 'domain': [0.55, 1.0], 'position': 0.0},
2257+
'yaxis': {'anchor': 'x', 'domain': [0.4375, 1.0]},
2258+
'yaxis2': {'anchor': 'x3', 'domain': [0.0, 0.1875]}}
2259+
})
2260+
2261+
fig = tls.make_subplots(rows=2, cols=2, row_width=[1, 3], shared_yaxes=True,
2262+
subplot_titles=('Title 1', 'Title 2', 'Title 3', 'Title 4'))
2263+
2264+
self.assertEqual(fig.to_plotly_json(), expected.to_plotly_json())
2265+
2266+
# def test_row_width_and_shared_yaxes(self):

Diff for: plotly/tools.py

+40-11
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import six
1515
import copy
16+
import re
1617

1718
from plotly import exceptions, optional_imports, session, utils
1819
from plotly.files import (CONFIG_FILE, CREDENTIALS_FILE, FILE_CONTENT,
@@ -1001,7 +1002,6 @@ def _checks(item, defaults):
10011002
) for c in col_seq
10021003
] for r in row_seq
10031004
]
1004-
10051005
# [grid_ref] Initialize the grid and insets' axis-reference lists
10061006
grid_ref = [[None for c in range(cols)] for r in range(rows)]
10071007
insets_ref = [None for inset in range(len(insets))] if insets else None
@@ -1323,20 +1323,49 @@ def _pad(s, cell_len=cell_len):
13231323
subtitle_pos_x.append(sum(x_domains) / 2)
13241324
for y_domains in y_dom:
13251325
subtitle_pos_y.append(y_domains[1])
1326+
13261327
# If shared_axes is True the domin of each subplot is not returned so the
13271328
# title position must be calculated for each subplot
13281329
else:
1329-
subtitle_pos_x = [None] * cols
1330-
subtitle_pos_y = [None] * rows
1331-
delt_x = (x_e - x_s)
1330+
x_dom_vals = [k for k in layout.to_plotly_json().keys() if 'xaxis' in k]
1331+
y_dom_vals = [k for k in layout.to_plotly_json().keys() if 'yaxis' in k]
1332+
1333+
# sort xaxis and yaxis layout keys
1334+
r = re.compile('\d+')
1335+
1336+
def key_func(m):
1337+
try:
1338+
return int(r.search(m).group(0))
1339+
except AttributeError:
1340+
return 0
1341+
1342+
xaxies_labels_sorted = sorted(x_dom_vals, key=key_func)
1343+
yaxies_labels_sorted = sorted(y_dom_vals, key=key_func)
1344+
1345+
x_dom = [layout[k]['domain'] for k in xaxies_labels_sorted]
1346+
y_dom = [layout[k]['domain'] for k in yaxies_labels_sorted]
1347+
13321348
for index in range(cols):
1333-
subtitle_pos_x[index] = ((delt_x / 2) +
1334-
((delt_x + horizontal_spacing) * index))
1335-
subtitle_pos_x *= rows
1336-
for index in range(rows):
1337-
subtitle_pos_y[index] = (1 - ((y_e + vertical_spacing) * index))
1338-
subtitle_pos_y *= cols
1339-
subtitle_pos_y = sorted(subtitle_pos_y, reverse=True)
1349+
subtitle_pos_x = []
1350+
for x_domains in x_dom:
1351+
subtitle_pos_x.append(sum(x_domains) / 2)
1352+
subtitle_pos_x *= rows
1353+
1354+
if shared_yaxes:
1355+
for index in range(rows):
1356+
subtitle_pos_y = []
1357+
for y_domain in y_dom:
1358+
subtitle_pos_y.append(y_domain[1])
1359+
subtitle_pos_y *= cols
1360+
subtitle_pos_y = sorted(subtitle_pos_y, reverse=True)
1361+
1362+
else:
1363+
for index in range(rows):
1364+
subtitle_pos_y = []
1365+
for y_domain in y_dom:
1366+
subtitle_pos_y.append(y_domain[1])
1367+
subtitle_pos_y = sorted(subtitle_pos_y, reverse=True)
1368+
subtitle_pos_y *= cols
13401369

13411370
plot_titles = []
13421371
for index in range(len(subplot_titles)):

0 commit comments

Comments
 (0)