Skip to content

Commit f7dc2be

Browse files
authored
Sunburst improvements (#2133)
* color column now appears in hover * corrected bug: path column can be numeric
1 parent 51fa1ee commit f7dc2be

File tree

2 files changed

+20
-5
lines changed

2 files changed

+20
-5
lines changed

packages/python/plotly/plotly/express/_core.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -1024,6 +1024,7 @@ def build_dataframe(args, attrables, array_attrables):
10241024
def _check_dataframe_all_leaves(df):
10251025
df_sorted = df.sort_values(by=list(df.columns))
10261026
null_mask = df_sorted.isnull()
1027+
df_sorted = df_sorted.astype(str)
10271028
null_indices = np.nonzero(null_mask.any(axis=1).values)[0]
10281029
for null_row_index in null_indices:
10291030
row = null_mask.iloc[null_row_index]
@@ -1055,8 +1056,9 @@ def process_dataframe_hierarchy(args):
10551056

10561057
if args["color"] and args["color"] in path:
10571058
series_to_copy = df[args["color"]]
1058-
args["color"] = str(args["color"]) + "additional_col_for_px"
1059-
df[args["color"]] = series_to_copy
1059+
new_col_name = args["color"] + "additional_col_for_color"
1060+
path = [new_col_name if x == args["color"] else x for x in path]
1061+
df[new_col_name] = series_to_copy
10601062
if args["hover_data"]:
10611063
for col_name in args["hover_data"]:
10621064
if col_name == args["color"]:
@@ -1160,6 +1162,11 @@ def aggfunc_continuous(x):
11601162
args["ids"] = "id"
11611163
args["names"] = "labels"
11621164
args["parents"] = "parent"
1165+
if args["color"]:
1166+
if not args["hover_data"]:
1167+
args["hover_data"] = [args["color"]]
1168+
else:
1169+
args["hover_data"].append(args["color"])
11631170
return args
11641171

11651172

packages/python/plotly/plotly/tests/test_core/test_px/test_px_functions.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -209,14 +209,22 @@ def test_sunburst_treemap_with_path_color():
209209
# Hover info
210210
df["hover"] = [el.lower() for el in vendors]
211211
fig = px.sunburst(df, path=path, color="calls", hover_data=["hover"])
212-
custom = fig.data[0].customdata.ravel()
213-
assert np.all(custom[:8] == df["hover"])
214-
assert np.all(custom[8:] == "(?)")
212+
custom = fig.data[0].customdata
213+
assert np.all(custom[:8, 0] == df["hover"])
214+
assert np.all(custom[8:, 0] == "(?)")
215+
assert np.all(custom[:8, 1] == df["calls"])
215216

216217
# Discrete color
217218
fig = px.sunburst(df, path=path, color="vendors")
218219
assert len(np.unique(fig.data[0].marker.colors)) == 9
219220

221+
# Numerical column in path
222+
df["regions"] = df["regions"].map({"North": 1, "South": 2})
223+
path = ["total", "regions", "sectors", "vendors"]
224+
fig = px.sunburst(df, path=path, values="values", color="calls")
225+
colors = fig.data[0].marker.colors
226+
assert np.all(np.array(colors[:8]) == np.array(calls))
227+
220228

221229
def test_sunburst_treemap_with_path_non_rectangular():
222230
vendors = ["A", "B", "C", "D", None, "E", "F", "G", "H", None]

0 commit comments

Comments
 (0)