Skip to content

Commit a394c58

Browse files
Merge pull request #2923 from plotly/facet_geo
enable faceting for geo, geojson everywhere possible, text/symbols fo…
2 parents 326932b + a5520f9 commit a394c58

File tree

5 files changed

+104
-73
lines changed

5 files changed

+104
-73
lines changed

Diff for: CHANGELOG.md

+4-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ This project adheres to [Semantic Versioning](http://semver.org/).
1111
## [4.13.0] - UNRELEASED
1212

1313
### Added
14-
14+
- `px.choropleth`, `px.scatter_geo` and `px.line_geo` now support faceting as well as `fitbounds` and `basemap_visible` [2923](https://github.com/plotly/plotly.py/pull/2923)
15+
- `px.scatter_geo` and `px.line_geo` now support `geojson`/`featureidkey` input [2923](https://github.com/plotly/plotly.py/pull/2923)
16+
- `px.scatter_geo` now supports `symbol` [2923](https://github.com/plotly/plotly.py/pull/2923)
1517
- `go.Figure` now has a `set_subplots` method to set subplots on an already
1618
existing figure. [2866](https://github.com/plotly/plotly.py/pull/2866)
1719
- Added `Turbo` colorscale and fancier swatch display functions
@@ -37,6 +39,7 @@ This project adheres to [Semantic Versioning](http://semver.org/).
3739

3840
### Fixed
3941

42+
- `px.scatter_geo` support for `text` is fixed [2923](https://github.com/plotly/plotly.py/pull/2923)
4043
- the `x` and `y` parameters of `px.imshow` are now used also in the case where
4144
an Image trace is used (for RGB data or with `binary_string=True`). However,
4245
only numerical values are accepted (while the Heatmap trace allows date or

Diff for: doc/python/facet-plots.md

+22-1
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,26 @@ fig = px.histogram(df, x="total_bill", y="tip", color="sex", facet_row="time", f
8383
fig.show()
8484
```
8585

86+
### Choropleth Column Facets
87+
88+
*new in version 4.13*
89+
90+
```python
91+
import plotly.express as px
92+
93+
df = px.data.election()
94+
df = df.melt(id_vars="district", value_vars=["Coderre", "Bergeron", "Joly"],
95+
var_name="candidate", value_name="votes")
96+
geojson = px.data.election_geojson()
97+
98+
fig = px.choropleth(df, geojson=geojson, color="votes", facet_col="candidate",
99+
locations="district", featureidkey="properties.district",
100+
projection="mercator"
101+
)
102+
fig.update_geos(fitbounds="locations", visible=False)
103+
fig.show()
104+
```
105+
86106
### Adding Lines and Rectangles to Facet Plots
87107

88108
*introduced in plotly 4.12*
@@ -133,7 +153,8 @@ trace.update(legendgroup="trendline", showlegend=False)
133153
fig.add_trace(trace, row="all", col="all", exclude_empty_subplots=True)
134154

135155
# set only the last trace added to appear in the legend
136-
fig.data[-1].update(showlegend=True)
156+
# `selector=-1` introduced in plotly v4.13
157+
fig.update_traces(selector=-1, showlegend=True)
137158
fig.show()
138159
```
139160

Diff for: packages/python/plotly/plotly/express/_chart_types.py

+30-17
Original file line numberDiff line numberDiff line change
@@ -940,6 +940,11 @@ def choropleth(
940940
geojson=None,
941941
featureidkey=None,
942942
color=None,
943+
facet_row=None,
944+
facet_col=None,
945+
facet_col_wrap=0,
946+
facet_row_spacing=None,
947+
facet_col_spacing=None,
943948
hover_name=None,
944949
hover_data=None,
945950
custom_data=None,
@@ -955,6 +960,8 @@ def choropleth(
955960
projection=None,
956961
scope=None,
957962
center=None,
963+
fitbounds=None,
964+
basemap_visible=None,
958965
title=None,
959966
template=None,
960967
width=None,
@@ -967,13 +974,7 @@ def choropleth(
967974
return make_figure(
968975
args=locals(),
969976
constructor=go.Choropleth,
970-
trace_patch=dict(
971-
locationmode=locationmode,
972-
featureidkey=featureidkey,
973-
geojson=geojson
974-
if not hasattr(geojson, "__geo_interface__") # for geopandas
975-
else geojson.__geo_interface__,
976-
),
977+
trace_patch=dict(locationmode=locationmode),
977978
)
978979

979980

@@ -986,8 +987,16 @@ def scatter_geo(
986987
lon=None,
987988
locations=None,
988989
locationmode=None,
990+
geojson=None,
991+
featureidkey=None,
989992
color=None,
990993
text=None,
994+
symbol=None,
995+
facet_row=None,
996+
facet_col=None,
997+
facet_col_wrap=0,
998+
facet_row_spacing=None,
999+
facet_col_spacing=None,
9911000
hover_name=None,
9921001
hover_data=None,
9931002
custom_data=None,
@@ -1001,11 +1010,15 @@ def scatter_geo(
10011010
color_continuous_scale=None,
10021011
range_color=None,
10031012
color_continuous_midpoint=None,
1013+
symbol_sequence=None,
1014+
symbol_map={},
10041015
opacity=None,
10051016
size_max=None,
10061017
projection=None,
10071018
scope=None,
10081019
center=None,
1020+
fitbounds=None,
1021+
basemap_visible=None,
10091022
title=None,
10101023
template=None,
10111024
width=None,
@@ -1031,9 +1044,16 @@ def line_geo(
10311044
lon=None,
10321045
locations=None,
10331046
locationmode=None,
1047+
geojson=None,
1048+
featureidkey=None,
10341049
color=None,
10351050
line_dash=None,
10361051
text=None,
1052+
facet_row=None,
1053+
facet_col=None,
1054+
facet_col_wrap=0,
1055+
facet_row_spacing=None,
1056+
facet_col_spacing=None,
10371057
hover_name=None,
10381058
hover_data=None,
10391059
custom_data=None,
@@ -1049,6 +1069,8 @@ def line_geo(
10491069
projection=None,
10501070
scope=None,
10511071
center=None,
1072+
fitbounds=None,
1073+
basemap_visible=None,
10521074
title=None,
10531075
template=None,
10541076
width=None,
@@ -1138,16 +1160,7 @@ def choropleth_mapbox(
11381160
In a Mapbox choropleth map, each row of `data_frame` is represented by a
11391161
colored region on a Mapbox map.
11401162
"""
1141-
return make_figure(
1142-
args=locals(),
1143-
constructor=go.Choroplethmapbox,
1144-
trace_patch=dict(
1145-
featureidkey=featureidkey,
1146-
geojson=geojson
1147-
if not hasattr(geojson, "__geo_interface__") # for geopandas
1148-
else geojson.__geo_interface__,
1149-
),
1150-
)
1163+
return make_figure(args=locals(), constructor=go.Choroplethmapbox)
11511164

11521165

11531166
choropleth_mapbox.__doc__ = make_docstring(choropleth_mapbox)

Diff for: packages/python/plotly/plotly/express/_core.py

+45-53
Original file line numberDiff line numberDiff line change
@@ -616,33 +616,27 @@ def configure_cartesian_axes(args, fig, orders):
616616
if "is_timeline" in args:
617617
fig.update_xaxes(type="date")
618618

619-
return fig.layout
620-
621619

622620
def configure_ternary_axes(args, fig, orders):
623-
fig.update_layout(
624-
ternary=dict(
625-
aaxis=dict(title_text=get_label(args, args["a"])),
626-
baxis=dict(title_text=get_label(args, args["b"])),
627-
caxis=dict(title_text=get_label(args, args["c"])),
628-
)
621+
fig.update_ternaries(
622+
aaxis=dict(title_text=get_label(args, args["a"])),
623+
baxis=dict(title_text=get_label(args, args["b"])),
624+
caxis=dict(title_text=get_label(args, args["c"])),
629625
)
630626

631627

632628
def configure_polar_axes(args, fig, orders):
633-
layout = dict(
634-
polar=dict(
635-
angularaxis=dict(direction=args["direction"], rotation=args["start_angle"]),
636-
radialaxis=dict(),
637-
)
629+
patch = dict(
630+
angularaxis=dict(direction=args["direction"], rotation=args["start_angle"]),
631+
radialaxis=dict(),
638632
)
639633

640634
for var, axis in [("r", "radialaxis"), ("theta", "angularaxis")]:
641635
if args[var] in orders:
642-
layout["polar"][axis]["categoryorder"] = "array"
643-
layout["polar"][axis]["categoryarray"] = orders[args[var]]
636+
patch[axis]["categoryorder"] = "array"
637+
patch[axis]["categoryarray"] = orders[args[var]]
644638

645-
radialaxis = layout["polar"]["radialaxis"]
639+
radialaxis = patch["radialaxis"]
646640
if args["log_r"]:
647641
radialaxis["type"] = "log"
648642
if args["range_r"]:
@@ -652,21 +646,19 @@ def configure_polar_axes(args, fig, orders):
652646
radialaxis["range"] = args["range_r"]
653647

654648
if args["range_theta"]:
655-
layout["polar"]["sector"] = args["range_theta"]
656-
fig.update(layout=layout)
649+
patch["sector"] = args["range_theta"]
650+
fig.update_polars(patch)
657651

658652

659653
def configure_3d_axes(args, fig, orders):
660-
layout = dict(
661-
scene=dict(
662-
xaxis=dict(title_text=get_label(args, args["x"])),
663-
yaxis=dict(title_text=get_label(args, args["y"])),
664-
zaxis=dict(title_text=get_label(args, args["z"])),
665-
)
654+
patch = dict(
655+
xaxis=dict(title_text=get_label(args, args["x"])),
656+
yaxis=dict(title_text=get_label(args, args["y"])),
657+
zaxis=dict(title_text=get_label(args, args["z"])),
666658
)
667659

668660
for letter in ["x", "y", "z"]:
669-
axis = layout["scene"][letter + "axis"]
661+
axis = patch[letter + "axis"]
670662
if args["log_" + letter]:
671663
axis["type"] = "log"
672664
if args["range_" + letter]:
@@ -677,7 +669,7 @@ def configure_3d_axes(args, fig, orders):
677669
if args[letter] in orders:
678670
axis["categoryorder"] = "array"
679671
axis["categoryarray"] = orders[args[letter]]
680-
fig.update(layout=layout)
672+
fig.update_scenes(patch)
681673

682674

683675
def configure_mapbox(args, fig, orders):
@@ -687,23 +679,21 @@ def configure_mapbox(args, fig, orders):
687679
lat=args["data_frame"][args["lat"]].mean(),
688680
lon=args["data_frame"][args["lon"]].mean(),
689681
)
690-
fig.update_layout(
691-
mapbox=dict(
692-
accesstoken=MAPBOX_TOKEN,
693-
center=center,
694-
zoom=args["zoom"],
695-
style=args["mapbox_style"],
696-
)
682+
fig.update_mapboxes(
683+
accesstoken=MAPBOX_TOKEN,
684+
center=center,
685+
zoom=args["zoom"],
686+
style=args["mapbox_style"],
697687
)
698688

699689

700690
def configure_geo(args, fig, orders):
701-
fig.update_layout(
702-
geo=dict(
703-
center=args["center"],
704-
scope=args["scope"],
705-
projection=dict(type=args["projection"]),
706-
)
691+
fig.update_geos(
692+
center=args["center"],
693+
scope=args["scope"],
694+
fitbounds=args["fitbounds"],
695+
visible=args["basemap_visible"],
696+
projection=dict(type=args["projection"]),
707697
)
708698

709699

@@ -1750,6 +1740,14 @@ def infer_config(args, constructor, trace_patch, layout_patch):
17501740
if "line_shape" in args:
17511741
trace_patch["line"] = dict(shape=args["line_shape"])
17521742

1743+
if "geojson" in args:
1744+
trace_patch["featureidkey"] = args["featureidkey"]
1745+
trace_patch["geojson"] = (
1746+
args["geojson"]
1747+
if not hasattr(args["geojson"], "__geo_interface__") # for geopandas
1748+
else args["geojson"].__geo_interface__
1749+
)
1750+
17531751
# Compute marginal attribute
17541752
if "marginal" in args:
17551753
position = "marginal_x" if args["orientation"] == "v" else "marginal_y"
@@ -2062,20 +2060,12 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
20622060

20632061
def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_labels):
20642062
# Build subplot specs
2065-
specs = [[{}] * ncols for _ in range(nrows)]
2066-
for frame in frame_list:
2067-
for trace in frame["data"]:
2068-
row0 = trace._subplot_row - 1
2069-
col0 = trace._subplot_col - 1
2070-
if isinstance(trace, go.Splom):
2071-
# Splom not compatible with make_subplots, treat as domain
2072-
specs[row0][col0] = {"type": "domain"}
2073-
else:
2074-
specs[row0][col0] = {"type": trace.type}
2063+
specs = [[dict(type=subplot_type or "domain")] * ncols for _ in range(nrows)]
20752064

20762065
# Default row/column widths uniform
20772066
column_widths = [1.0] * ncols
20782067
row_heights = [1.0] * nrows
2068+
facet_col_wrap = args.get("facet_col_wrap", 0)
20792069

20802070
# Build column_widths/row_heights
20812071
if subplot_type == "xy":
@@ -2087,7 +2077,7 @@ def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_la
20872077

20882078
row_heights = [main_size] * (nrows - 1) + [1 - main_size]
20892079
vertical_spacing = 0.01
2090-
elif args.get("facet_col_wrap", 0):
2080+
elif facet_col_wrap:
20912081
vertical_spacing = args.get("facet_row_spacing", None) or 0.07
20922082
else:
20932083
vertical_spacing = args.get("facet_row_spacing", None) or 0.03
@@ -2108,10 +2098,12 @@ def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_la
21082098
#
21092099
# We can customize subplot spacing per type once we enable faceting
21102100
# for all plot types
2111-
vertical_spacing = 0.1
2112-
horizontal_spacing = 0.1
2101+
if facet_col_wrap:
2102+
vertical_spacing = args.get("facet_row_spacing", None) or 0.07
2103+
else:
2104+
vertical_spacing = args.get("facet_row_spacing", None) or 0.03
2105+
horizontal_spacing = args.get("facet_col_spacing", None) or 0.02
21132106

2114-
facet_col_wrap = args.get("facet_col_wrap", 0)
21152107
if facet_col_wrap:
21162108
subplot_labels = [None] * nrows * ncols
21172109
while len(col_labels) < nrows * ncols:

Diff for: packages/python/plotly/plotly/express/_doc.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -475,9 +475,11 @@
475475
"If `True`, an extra line segment is drawn between the first and last point.",
476476
],
477477
line_shape=["str (default `'linear'`)", "One of `'linear'` or `'spline'`."],
478+
fitbounds=["str (default `False`).", "One of `False`, `locations` or `geojson`."],
479+
basemap_visible=["bool", "Force the basemap visibility."],
478480
scope=[
479481
"str (default `'world'`).",
480-
"One of `'world'`, `'usa'`, `'europe'`, `'asia'`, `'africa'`, `'north america'`, or `'south america'`)"
482+
"One of `'world'`, `'usa'`, `'europe'`, `'asia'`, `'africa'`, `'north america'`, or `'south america'`"
481483
"Default is `'world'` unless `projection` is set to `'albers usa'`, which forces `'usa'`.",
482484
],
483485
projection=[

0 commit comments

Comments
 (0)