Skip to content

Commit 9bfd172

Browse files
initial pass at PX auto-orientation
1 parent f622085 commit 9bfd172

File tree

2 files changed

+48
-35
lines changed

2 files changed

+48
-35
lines changed

packages/python/plotly/plotly/express/_chart_types.py

+11-29
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def area(
236236
labels={},
237237
color_discrete_sequence=None,
238238
color_discrete_map={},
239-
orientation="v",
239+
orientation=None,
240240
groupnorm=None,
241241
log_x=False,
242242
log_y=False,
@@ -256,9 +256,7 @@ def area(
256256
return make_figure(
257257
args=locals(),
258258
constructor=go.Scatter,
259-
trace_patch=dict(
260-
stackgroup=1, mode="lines", orientation=orientation, groupnorm=groupnorm
261-
),
259+
trace_patch=dict(stackgroup=1, mode="lines", groupnorm=groupnorm),
262260
)
263261

264262

@@ -291,7 +289,7 @@ def bar(
291289
range_color=None,
292290
color_continuous_midpoint=None,
293291
opacity=None,
294-
orientation="v",
292+
orientation=None,
295293
barmode="relative",
296294
log_x=False,
297295
log_y=False,
@@ -335,7 +333,7 @@ def histogram(
335333
color_discrete_map={},
336334
marginal=None,
337335
opacity=None,
338-
orientation="v",
336+
orientation=None,
339337
barmode="relative",
340338
barnorm=None,
341339
histnorm=None,
@@ -361,13 +359,7 @@ def histogram(
361359
args=locals(),
362360
constructor=go.Histogram,
363361
trace_patch=dict(
364-
orientation=orientation,
365-
histnorm=histnorm,
366-
histfunc=histfunc,
367-
nbinsx=nbins if orientation == "v" else None,
368-
nbinsy=None if orientation == "v" else nbins,
369-
cumulative=dict(enabled=cumulative),
370-
bingroup="x" if orientation == "v" else "y",
362+
histnorm=histnorm, histfunc=histfunc, cumulative=dict(enabled=cumulative),
371363
),
372364
layout_patch=dict(barmode=barmode, barnorm=barnorm),
373365
)
@@ -393,7 +385,7 @@ def violin(
393385
labels={},
394386
color_discrete_sequence=None,
395387
color_discrete_map={},
396-
orientation="v",
388+
orientation=None,
397389
violinmode="group",
398390
log_x=False,
399391
log_y=False,
@@ -414,12 +406,7 @@ def violin(
414406
args=locals(),
415407
constructor=go.Violin,
416408
trace_patch=dict(
417-
orientation=orientation,
418-
points=points,
419-
box=dict(visible=box),
420-
scalegroup=True,
421-
x0=" ",
422-
y0=" ",
409+
points=points, box=dict(visible=box), scalegroup=True, x0=" ", y0=" ",
423410
),
424411
layout_patch=dict(violinmode=violinmode),
425412
)
@@ -445,7 +432,7 @@ def box(
445432
labels={},
446433
color_discrete_sequence=None,
447434
color_discrete_map={},
448-
orientation="v",
435+
orientation=None,
449436
boxmode="group",
450437
log_x=False,
451438
log_y=False,
@@ -470,9 +457,7 @@ def box(
470457
return make_figure(
471458
args=locals(),
472459
constructor=go.Box,
473-
trace_patch=dict(
474-
orientation=orientation, boxpoints=points, notched=notched, x0=" ", y0=" "
475-
),
460+
trace_patch=dict(boxpoints=points, notched=notched, x0=" ", y0=" "),
476461
layout_patch=dict(boxmode=boxmode),
477462
)
478463

@@ -497,7 +482,7 @@ def strip(
497482
labels={},
498483
color_discrete_sequence=None,
499484
color_discrete_map={},
500-
orientation="v",
485+
orientation=None,
501486
stripmode="group",
502487
log_x=False,
503488
log_y=False,
@@ -516,7 +501,6 @@ def strip(
516501
args=locals(),
517502
constructor=go.Box,
518503
trace_patch=dict(
519-
orientation=orientation,
520504
boxpoints="all",
521505
pointpos=0,
522506
hoveron="points",
@@ -1398,9 +1382,7 @@ def funnel(
13981382
In a funnel plot, each row of `data_frame` is represented as a
13991383
rectangular sector of a funnel.
14001384
"""
1401-
return make_figure(
1402-
args=locals(), constructor=go.Funnel, trace_patch=dict(orientation=orientation),
1403-
)
1385+
return make_figure(args=locals(), constructor=go.Funnel,)
14041386

14051387

14061388
funnel.__doc__ = make_docstring(funnel)

packages/python/plotly/plotly/express/_core.py

+37-6
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ def get_label(args, column):
9292
return column
9393

9494

95+
def _is_continuous(df, col_name):
96+
return df[col_name].dtype.kind in "ifc"
97+
98+
9599
def get_decorated_label(args, column, role):
96100
label = get_label(args, column)
97101
if "histfunc" in args and (
@@ -188,7 +192,7 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
188192
if ((not attr_value) or (name in attr_value))
189193
and (
190194
trace_spec.constructor != go.Parcoords
191-
or args["data_frame"][name].dtype.kind in "ifc"
195+
or _is_continuous(args["data_frame"], name)
192196
)
193197
and (
194198
trace_spec.constructor != go.Parcats
@@ -1124,7 +1128,7 @@ def aggfunc_discrete(x):
11241128
agg_f[count_colname] = "sum"
11251129

11261130
if args["color"]:
1127-
if df[args["color"]].dtype.kind not in "ifc":
1131+
if not _is_continuous(df, args["color"]):
11281132
aggfunc_color = aggfunc_discrete
11291133
discrete_color = True
11301134
elif not aggfunc_color:
@@ -1212,6 +1216,36 @@ def infer_config(args, constructor, trace_patch):
12121216
if constructor in [go.Treemap, go.Sunburst] and args["path"] is not None:
12131217
args = process_dataframe_hierarchy(args)
12141218

1219+
if "orientation" in args:
1220+
has_x = args["x"] is not None
1221+
has_y = args["y"] is not None
1222+
if args["orientation"] is None:
1223+
if constructor in [go.Histogram, go.Scatter]:
1224+
if has_y and not has_x:
1225+
args["orientation"] = "h"
1226+
elif constructor in [go.Violin, go.Box, go.Bar, go.Funnel]:
1227+
if has_x and not has_y:
1228+
args["orientation"] = "h"
1229+
1230+
if args["orientation"] is None and has_x and has_y:
1231+
x_is_continuous = _is_continuous(args["data_frame"], args["x"])
1232+
y_is_continuous = _is_continuous(args["data_frame"], args["y"])
1233+
if x_is_continuous and not y_is_continuous:
1234+
args["orientation"] = "h"
1235+
if y_is_continuous and not x_is_continuous:
1236+
args["orientation"] = "v"
1237+
1238+
if args["orientation"] is None:
1239+
args["orientation"] = "v"
1240+
1241+
if constructor == go.Histogram:
1242+
orientation = args["orientation"]
1243+
nbins = args["nbins"]
1244+
trace_patch["nbinsx"] = nbins if orientation == "v" else None
1245+
trace_patch["nbinsy"] = None if orientation == "v" else nbins
1246+
trace_patch["bingroup"] = "x" if orientation == "v" else "y"
1247+
trace_patch["orientation"] = args["orientation"]
1248+
12151249
attrs = [k for k in attrables if k in args]
12161250
grouped_attrs = []
12171251

@@ -1226,10 +1260,7 @@ def infer_config(args, constructor, trace_patch):
12261260
if "color_discrete_sequence" not in args:
12271261
attrs.append("color")
12281262
else:
1229-
if (
1230-
args["color"]
1231-
and args["data_frame"][args["color"]].dtype.kind in "ifc"
1232-
):
1263+
if args["color"] and _is_continuous(args["data_frame"], args["color"]):
12331264
attrs.append("color")
12341265
args["color_is_continuous"] = True
12351266
elif constructor in [go.Sunburst, go.Treemap]:

0 commit comments

Comments
 (0)