Skip to content

Commit 705ec3b

Browse files
authored
Sunburst/treemap path (#2006)
1 parent 9e1f1c2 commit 705ec3b

File tree

6 files changed

+413
-2
lines changed

6 files changed

+413
-2
lines changed

doc/python/sunburst-charts.md

+47
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,53 @@ fig =px.sunburst(
6262
fig.show()
6363
```
6464

65+
### Sunburst of a rectangular DataFrame with plotly.express
66+
67+
Hierarchical data are often stored as a rectangular dataframe, with different columns corresponding to different levels of the hierarchy. `px.sunburst` can take a `path` parameter corresponding to a list of columns, ordered from the root of the hierarchy to the leaves. Note that `ids` and `parents` should not be provided if `path` is given.
68+
69+
```python
70+
import plotly.express as px
71+
df = px.data.tips()
72+
fig = px.sunburst(df, path=['day', 'time', 'sex'], values='total_bill')
73+
fig.show()
74+
```
75+
76+
### Sunburst of a rectangular DataFrame with continuous color argument in px.sunburst
77+
78+
If a `color` argument is passed, the color of a node is computed as the average of the color values of its children, weighted by their values.
79+
80+
```python
81+
import plotly.express as px
82+
import numpy as np
83+
df = px.data.gapminder().query("year == 2007")
84+
fig = px.sunburst(df, path=['continent', 'country'], values='pop',
85+
color='lifeExp', hover_data=['iso_alpha'],
86+
color_continuous_scale='RdBu',
87+
color_continuous_midpoint=np.average(df['lifeExp'], weights=df['pop']))
88+
fig.show()
89+
```
90+
91+
### Rectangular data with missing values
92+
93+
If the dataset is not fully rectangular, missing values should be supplied as `None`. Note that the parent of a `None` entry must be a leaf, i.e. it cannot have any children other than `None` (otherwise a `ValueError` is raised).
94+
95+
```python
96+
import plotly.express as px
97+
import pandas as pd
98+
vendors = ["A", "B", "C", "D", None, "E", "F", "G", "H", None]
99+
sectors = ["Tech", "Tech", "Finance", "Finance", "Other",
100+
"Tech", "Tech", "Finance", "Finance", "Other"]
101+
regions = ["North", "North", "North", "North", "North",
102+
"South", "South", "South", "South", "South"]
103+
sales = [1, 3, 2, 4, 1, 2, 2, 1, 4, 1]
104+
df = pd.DataFrame(
105+
dict(vendors=vendors, sectors=sectors, regions=regions, sales=sales)
106+
)
107+
print(df)
108+
fig = px.sunburst(df, path=['regions', 'sectors', 'vendors'], values='sales')
109+
fig.show()
110+
```
111+
65112
### Basic Sunburst Plot with go.Sunburst
66113

67114
If Plotly Express does not provide a good starting point, it is also possible to use the more generic `go.Sunburst` function from `plotly.graph_objects`.

doc/python/treemaps.md

+46
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,52 @@ fig = px.treemap(
5151
fig.show()
5252
```
5353

54+
### Treemap of a rectangular DataFrame with plotly.express
55+
56+
Hierarchical data are often stored as a rectangular dataframe, with different columns corresponding to different levels of the hierarchy. `px.treemap` can take a `path` parameter corresponding to a list of columns, ordered from the root of the hierarchy to the leaves. Note that `ids` and `parents` should not be provided if `path` is given.
57+
58+
```python
59+
import plotly.express as px
60+
df = px.data.tips()
61+
fig = px.treemap(df, path=['day', 'time', 'sex'], values='total_bill')
62+
fig.show()
63+
```
64+
65+
### Treemap of a rectangular DataFrame with continuous color argument in px.treemap
66+
67+
If a `color` argument is passed, the color of a node is computed as the average of the color values of its children, weighted by their values.
68+
69+
```python
70+
import plotly.express as px
71+
import numpy as np
72+
df = px.data.gapminder().query("year == 2007")
73+
fig = px.treemap(df, path=['continent', 'country'], values='pop',
74+
color='lifeExp', hover_data=['iso_alpha'],
75+
color_continuous_scale='RdBu',
76+
color_continuous_midpoint=np.average(df['lifeExp'], weights=df['pop']))
77+
fig.show()
78+
```
79+
80+
### Rectangular data with missing values
81+
82+
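If the dataset is not fully rectangular, missing values should be supplied as `None`. Note that the parent of a `None` entry must be a leaf, i.e. it cannot have any children other than `None` (otherwise a `ValueError` is raised).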
83+
84+
```python
85+
import plotly.express as px
86+
import pandas as pd
87+
vendors = ["A", "B", "C", "D", None, "E", "F", "G", "H", None]
88+
sectors = ["Tech", "Tech", "Finance", "Finance", "Other",
89+
"Tech", "Tech", "Finance", "Finance", "Other"]
90+
regions = ["North", "North", "North", "North", "North",
91+
"South", "South", "South", "South", "South"]
92+
sales = [1, 3, 2, 4, 1, 2, 2, 1, 4, 1]
93+
df = pd.DataFrame(
94+
dict(vendors=vendors, sectors=sectors, regions=regions, sales=sales)
95+
)
96+
print(df)
97+
fig = px.treemap(df, path=['regions', 'sectors', 'vendors'], values='sales')
98+
fig.show()
99+
```
54100
### Basic Treemap with go.Treemap
55101

56102
If Plotly Express does not provide a good starting point, it is also possible to use the more generic `go.Treemap` function from `plotly.graph_objects`.

packages/python/plotly/plotly/express/_chart_types.py

+16
Original file line numberDiff line numberDiff line change
@@ -1270,6 +1270,7 @@ def sunburst(
12701270
names=None,
12711271
values=None,
12721272
parents=None,
1273+
path=None,
12731274
ids=None,
12741275
color=None,
12751276
color_continuous_scale=None,
@@ -1296,6 +1297,13 @@ def sunburst(
12961297
layout_patch = {"sunburstcolorway": color_discrete_sequence}
12971298
else:
12981299
layout_patch = {}
1300+
if path is not None and (ids is not None or parents is not None):
1301+
raise ValueError(
1302+
"Either `path` should be provided, or `ids` and `parents`."
1303+
"These parameters are mutually exclusive and cannot be passed together."
1304+
)
1305+
if path is not None and branchvalues is None:
1306+
branchvalues = "total"
12991307
return make_figure(
13001308
args=locals(),
13011309
constructor=go.Sunburst,
@@ -1313,6 +1321,7 @@ def treemap(
13131321
values=None,
13141322
parents=None,
13151323
ids=None,
1324+
path=None,
13161325
color=None,
13171326
color_continuous_scale=None,
13181327
range_color=None,
@@ -1338,6 +1347,13 @@ def treemap(
13381347
layout_patch = {"treemapcolorway": color_discrete_sequence}
13391348
else:
13401349
layout_patch = {}
1350+
if path is not None and (ids is not None or parents is not None):
1351+
raise ValueError(
1352+
"Either `path` should be provided, or `ids` and `parents`."
1353+
"These parameters are mutually exclusive and cannot be passed together."
1354+
)
1355+
if path is not None and branchvalues is None:
1356+
branchvalues = "total"
13411357
return make_figure(
13421358
args=locals(),
13431359
constructor=go.Treemap,

packages/python/plotly/plotly/express/_core.py

+145-2
Original file line numberDiff line numberDiff line change
@@ -1009,6 +1009,147 @@ def build_dataframe(args, attrables, array_attrables):
10091009
return args
10101010

10111011

1012+
def _check_dataframe_all_leaves(df):
    """
    Check that every row of ``df`` describes a leaf of the hierarchy.

    ``df`` is expected to contain only the ``path`` columns, ordered from
    root to leaves. Missing values (``None``/``NaN``) mark rows whose branch
    stops before the deepest level.

    Two conditions are verified:

    * once a row has a missing value, all the following columns of that row
      must be missing as well (a missing node cannot have children);
    * a truncated row must not be the ancestor of another row, i.e. no other
      row may share its non-missing prefix.

    Values are compared through their string representation, so path columns
    may hold non-string data (numbers, dates, ...).

    Parameters
    ----------
    df : pandas.DataFrame
        The path columns of the user dataframe. It is not modified.

    Raises
    ------
    ValueError
        If one of the two conditions above is violated. The offending row is
        passed as an extra argument of the exception.
    """
    # Sorting on all columns (missing values are placed last by default)
    # puts a truncated row right after the rows sharing its prefix.
    df_sorted = df.sort_values(by=list(df.columns))
    null_mask = df_sorted.isnull()
    mask_values = null_mask.values
    null_positions = np.nonzero(mask_values.any(axis=1))[0]

    # Position of each row with missing values -> number of leading
    # non-missing entries (its depth in the hierarchy).
    depths = {}
    for pos in null_positions:
        row_mask = mask_values[pos]
        depth = int(np.argmax(row_mask))  # index of the first missing entry
        if not row_mask[depth:].all():
            raise ValueError(
                "None entries cannot have not-None children", df_sorted.iloc[pos],
            )
        depths[int(pos)] = depth

    # Compare string labels so that non-string columns are supported.
    labels = df_sorted.astype(str)
    labels[null_mask] = ""
    rows = labels.values.tolist()
    for pos in sorted(depths):
        depth = depths[pos]
        # Column-wise prefix comparison: concatenating the labels and testing
        # for a substring would wrongly match e.g. ("a", "b") with ("b", None).
        if pos > 0 and rows[pos - 1][:depth] == rows[pos][:depth]:
            raise ValueError(
                "Non-leaves rows are not permitted in the dataframe \n",
                labels.iloc[pos],
                "is not a leaf.",
            )
1033+
1034+
1035+
def process_dataframe_hierarchy(args):
    """
    Build dataframe for sunburst or treemap when the path argument is provided.

    The rectangular dataframe ``args["data_frame"]`` (one column per level of
    the hierarchy, listed in ``args["path"]``) is converted into a dataframe
    with one row per node of the hierarchy and the columns ``labels``,
    ``parent`` and ``id``, plus one aggregated column for every other column
    of the input dataframe.

    Parameters
    ----------
    args : dict
        The arguments of the px function. The keys ``data_frame``, ``path``,
        ``color``, ``hover_data`` and ``values`` are read. The dataframe is
        modified in place (helper columns may be added).

    Returns
    -------
    dict
        The same ``args`` dict, where ``data_frame`` is the new dataframe,
        ``path`` is reset to ``None`` and ``ids``, ``names`` and ``parents``
        point to the ``id``, ``labels`` and ``parent`` columns. ``values`` and
        ``color`` may have been redirected to helper columns.

    Raises
    ------
    ValueError
        If the ``values`` column cannot be converted to a numerical type, or
        if the path columns fail the check done by
        ``_check_dataframe_all_leaves``.
    """
    df = args["data_frame"]
    # Work from the deepest level up to the root
    path = args["path"][::-1]
    _check_dataframe_all_leaves(df[path[::-1]])
    discrete_color = False

    # A column used both in the path and for color (or hover) is duplicated
    # under a new name, so that it can be both grouped on and aggregated.
    if args["color"] and args["color"] in path:
        series_to_copy = df[args["color"]]
        args["color"] = str(args["color"]) + "additional_col_for_px"
        df[args["color"]] = series_to_copy
    if args["hover_data"]:
        for col_name in args["hover_data"]:
            if col_name == args["color"]:
                series_to_copy = df[col_name]
                new_col_name = str(args["color"]) + "additional_col_for_hover"
                df[new_col_name] = series_to_copy
                args["color"] = new_col_name
            elif col_name in path:
                series_to_copy = df[col_name]
                new_col_name = col_name + "additional_col_for_hover"
                path = [new_col_name if x == col_name else x for x in path]
                df[new_col_name] = series_to_copy

    # ------------ Define aggregation functions --------------------------------
    def aggfunc_discrete(x):
        # Keep the value when it is shared by the whole group, else "(?)"
        uniques = x.unique()
        if len(uniques) == 1:
            return uniques[0]
        return "(?)"

    agg_f = {}
    aggfunc_color = None
    if args["values"]:
        try:
            df[args["values"]] = pd.to_numeric(df[args["values"]])
        except ValueError:
            raise ValueError(
                "Column `%s` of `df` could not be converted to a numerical data type."
                % args["values"]
            )

        if args["color"] and args["color"] == args["values"]:
            aggfunc_color = "sum"
        count_colname = args["values"]
    else:
        # we need a count column for the first groupby and the weighted mean of color
        # trick to be sure the col name is unused: take the sum of existing names
        count_colname = (
            "count"
            if "count" not in df.columns
            else "".join([str(el) for el in list(df.columns)])
        )
        # we can modify df because it's a copy of the px argument
        df[count_colname] = 1
        args["values"] = count_colname
    agg_f[count_colname] = "sum"

    if args["color"]:
        if df[args["color"]].dtype.kind not in "bifc":
            # Non-numerical color column: treated as discrete
            aggfunc_color = aggfunc_discrete
            discrete_color = True
        elif not aggfunc_color:

            def aggfunc_continuous(x):
                # Average of the color, weighted by the values/count column
                return np.average(x, weights=df.loc[x.index, count_colname])

            aggfunc_color = aggfunc_continuous
        agg_f[args["color"]] = aggfunc_color

    # Other columns (for color, hover_data, custom_data etc.)
    cols = list(set(df.columns).difference(path))
    for col in cols:  # for hover_data, custom_data etc.
        if col not in agg_f:
            agg_f[col] = aggfunc_discrete
    # ----------------------------------------------------------------------------

    df_all_trees = pd.DataFrame(columns=["labels", "parent", "id"] + cols)
    # Set column type here (useful for continuous vs discrete colorscale)
    for col in cols:
        df_all_trees[col] = df_all_trees[col].astype(df[col].dtype)
    for i, level in enumerate(path):
        df_tree = pd.DataFrame(columns=df_all_trees.columns)
        # One row per node of this level: group on this level and its ancestors
        dfg = df.groupby(path[i:]).agg(agg_f)
        dfg = dfg.reset_index()
        # Path label massaging
        df_tree["labels"] = dfg[level].copy().astype(str)
        df_tree["parent"] = ""
        df_tree["id"] = dfg[level].copy().astype(str)
        # Prepend the ancestors, closest first, to build "root/.../node" ids
        for ancestor in path[i + 1 :]:
            ancestor_labels = dfg[ancestor].copy().astype(str)
            df_tree["parent"] = ancestor_labels + "/" + df_tree["parent"]
            df_tree["id"] = ancestor_labels + "/" + df_tree["id"]

        df_tree["parent"] = df_tree["parent"].str.rstrip("/")
        if cols:
            df_tree[cols] = dfg[cols]
        # DataFrame.append was removed in pandas 2.0; pd.concat is equivalent
        df_all_trees = pd.concat([df_all_trees, df_tree], ignore_index=True)

    if args["color"] and discrete_color:
        df_all_trees = df_all_trees.sort_values(by=args["color"])

    # Now modify arguments
    args["data_frame"] = df_all_trees
    args["path"] = None
    args["ids"] = "id"
    args["names"] = "labels"
    args["parents"] = "parent"
    return args
1151+
1152+
10121153
def infer_config(args, constructor, trace_patch):
10131154
# Declare all supported attributes, across all plot types
10141155
attrables = (
@@ -1017,9 +1158,9 @@ def infer_config(args, constructor, trace_patch):
10171158
+ ["names", "values", "parents", "ids"]
10181159
+ ["error_x", "error_x_minus"]
10191160
+ ["error_y", "error_y_minus", "error_z", "error_z_minus"]
1020-
+ ["lat", "lon", "locations", "animation_group"]
1161+
+ ["lat", "lon", "locations", "animation_group", "path"]
10211162
)
1022-
array_attrables = ["dimensions", "custom_data", "hover_data"]
1163+
array_attrables = ["dimensions", "custom_data", "hover_data", "path"]
10231164
group_attrables = ["animation_frame", "facet_row", "facet_col", "line_group"]
10241165
all_attrables = attrables + group_attrables + ["color"]
10251166
group_attrs = ["symbol", "line_dash"]
@@ -1028,6 +1169,8 @@ def infer_config(args, constructor, trace_patch):
10281169
all_attrables += [group_attr]
10291170

10301171
args = build_dataframe(args, all_attrables, array_attrables)
1172+
if constructor in [go.Treemap, go.Sunburst] and args["path"] is not None:
1173+
args = process_dataframe_hierarchy(args)
10311174

10321175
attrs = [k for k in attrables if k in args]
10331176
grouped_attrs = []

packages/python/plotly/plotly/express/_doc.py

+6
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,12 @@
8686
colref_desc,
8787
"Values from this column or array_like are used to set ids of sectors",
8888
],
89+
path=[
90+
colref_list_type,
91+
colref_list_desc,
92+
"List of columns names or columns of a rectangular dataframe defining the hierarchy of sectors, from root to leaves.",
93+
"An error is raised if path AND ids or parents is passed",
94+
],
8995
lat=[
9096
colref_type,
9197
colref_desc,

0 commit comments

Comments
 (0)