diff --git a/packages/python/plotly/plotly/__init__.py b/packages/python/plotly/plotly/__init__.py index a8fdee9158d..c75cac14006 100644 --- a/packages/python/plotly/plotly/__init__.py +++ b/packages/python/plotly/plotly/__init__.py @@ -27,10 +27,6 @@ """ from __future__ import absolute_import -# https://packaging.python.org/guides/packaging-namespace-packages/ -# pkgutil-style-namespace-packages -__path__ = __import__('pkgutil').extend_path(__path__, __name__) - from plotly import ( graph_objs, tools, @@ -38,6 +34,8 @@ offline, colors, io, + data, + colors, _docstring_gen ) diff --git a/packages/python/plotly/plotly/colors.py b/packages/python/plotly/plotly/colors/__init__.py similarity index 98% rename from packages/python/plotly/plotly/colors.py rename to packages/python/plotly/plotly/colors/__init__.py index 8ec66f78cb8..205771fd5ee 100644 --- a/packages/python/plotly/plotly/colors.py +++ b/packages/python/plotly/plotly/colors/__init__.py @@ -81,6 +81,21 @@ from plotly import exceptions + +# Built-in qualitative color sequences and sequential, +# diverging and cyclical color scales. +# +# Initially ported over from plotly_express +from . import ( # noqa: F401 + qualitative, + sequential, + diverging, + cyclical, + cmocean, + colorbrewer, + carto, +) + DEFAULT_PLOTLY_COLORS = ['rgb(31, 119, 180)', 'rgb(255, 127, 14)', 'rgb(44, 160, 44)', 'rgb(214, 39, 40)', 'rgb(148, 103, 189)', 'rgb(140, 86, 75)', diff --git a/packages/python/plotly/plotly/colors/_swatches.py b/packages/python/plotly/plotly/colors/_swatches.py new file mode 100644 index 00000000000..ff07ff42907 --- /dev/null +++ b/packages/python/plotly/plotly/colors/_swatches.py @@ -0,0 +1,37 @@ +def _swatches(module_names, module_contents): + """ + Returns: + A `Figure` object. This figure demonstrates the color scales and + sequences in this module, as stacked bar charts. + """ + import plotly.graph_objs as go + + sequences = [ + (k, v) + for k, v in module_contents.items() + if not (k.startswith("_") or k == "swatches") + ] + + return go.Figure( + data=[ + go.Bar( + orientation="h", + y=[name] * len(colors), + x=[1] * len(colors), + customdata=list(range(len(colors))), + marker=dict(color=colors), + hovertemplate="%{y}[%{customdata}] = %{marker.color}", + ) + for name, colors in reversed(sequences) + ], + layout=dict( + title=module_names, + barmode="stack", + barnorm="fraction", + template="plotly", + bargap=0.5, + showlegend=False, + xaxis=dict(range=[-0.02, 1.02], showticklabels=False, showgrid=False), + height=max(600, 40 * len(sequences)), + ), + ) diff --git a/packages/python/plotly/plotly/colors/carto.py b/packages/python/plotly/plotly/colors/carto.py new file mode 100644 index 00000000000..c20d6b5d259 --- /dev/null +++ b/packages/python/plotly/plotly/colors/carto.py @@ -0,0 +1,384 @@ +""" +Color sequences and scales from CARTO's CartoColors + +Learn more at https://github.com/CartoDB/CartoColor + +CARTOColors are made available under a Creative Commons Attribution license: https://creativecommons.org/licenses/by/3.0/us/ +""" + +from ._swatches import _swatches + + +def swatches(): + return _swatches(__name__, globals()) + + +swatches.__doc__ = _swatches.__doc__ + +Burg = [ + "rgb(255, 198, 196)", + "rgb(244, 163, 168)", + "rgb(227, 129, 145)", + "rgb(204, 96, 125)", + "rgb(173, 70, 108)", + "rgb(139, 48, 88)", + "rgb(103, 32, 68)", +] + +Burgyl = [ + "rgb(251, 230, 197)", + "rgb(245, 186, 152)", + "rgb(238, 138, 130)", + "rgb(220, 113, 118)", + "rgb(200, 88, 108)", + "rgb(156, 63, 93)", + "rgb(112, 40, 74)", +] + +Redor = [ + "rgb(246, 210, 169)", + "rgb(245, 183, 142)", + "rgb(241, 156, 124)", + "rgb(234, 129, 113)", + "rgb(221, 104, 108)", + "rgb(202, 82, 104)", + "rgb(177, 63, 100)", +] + +Oryel = [ + "rgb(236, 218, 154)", + "rgb(239, 196, 126)", + "rgb(243, 173, 106)", + "rgb(247, 148, 93)", + "rgb(249, 123, 87)", + "rgb(246, 99, 86)", + "rgb(238, 77, 90)", +] + +Peach = [ + "rgb(253, 224, 197)", + "rgb(250, 203, 166)", + "rgb(248, 181, 139)", + "rgb(245, 158, 114)", + "rgb(242, 133, 93)", + "rgb(239, 106, 76)", + "rgb(235, 74, 64)", +] + +Pinkyl = [ + "rgb(254, 246, 181)", + "rgb(255, 221, 154)", + "rgb(255, 194, 133)", + "rgb(255, 166, 121)", + "rgb(250, 138, 118)", + "rgb(241, 109, 122)", + "rgb(225, 83, 131)", +] + +Mint = [ + "rgb(228, 241, 225)", + "rgb(180, 217, 204)", + "rgb(137, 192, 182)", + "rgb(99, 166, 160)", + "rgb(68, 140, 138)", + "rgb(40, 114, 116)", + "rgb(13, 88, 95)", +] + +Blugrn = [ + "rgb(196, 230, 195)", + "rgb(150, 210, 164)", + "rgb(109, 188, 144)", + "rgb(77, 162, 132)", + "rgb(54, 135, 122)", + "rgb(38, 107, 110)", + "rgb(29, 79, 96)", +] + +Darkmint = [ + "rgb(210, 251, 212)", + "rgb(165, 219, 194)", + "rgb(123, 188, 176)", + "rgb(85, 156, 158)", + "rgb(58, 124, 137)", + "rgb(35, 93, 114)", + "rgb(18, 63, 90)", +] + +Emrld = [ + "rgb(211, 242, 163)", + "rgb(151, 225, 150)", + "rgb(108, 192, 139)", + "rgb(76, 155, 130)", + "rgb(33, 122, 121)", + "rgb(16, 89, 101)", + "rgb(7, 64, 80)", +] + +Aggrnyl = [ + "rgb(36, 86, 104)", + "rgb(15, 114, 121)", + "rgb(13, 143, 129)", + "rgb(57, 171, 126)", + "rgb(110, 197, 116)", + "rgb(169, 220, 103)", + "rgb(237, 239, 93)", +] + +Bluyl = [ + "rgb(247, 254, 174)", + "rgb(183, 230, 165)", + "rgb(124, 203, 162)", + "rgb(70, 174, 160)", + "rgb(8, 144, 153)", + "rgb(0, 113, 139)", + "rgb(4, 82, 117)", +] + +Teal = [ + "rgb(209, 238, 234)", + "rgb(168, 219, 217)", + "rgb(133, 196, 201)", + "rgb(104, 171, 184)", + "rgb(79, 144, 166)", + "rgb(59, 115, 143)", + "rgb(42, 86, 116)", +] + +Tealgrn = [ + "rgb(176, 242, 188)", + "rgb(137, 232, 172)", + "rgb(103, 219, 165)", + "rgb(76, 200, 163)", + "rgb(56, 178, 163)", + "rgb(44, 152, 160)", + "rgb(37, 125, 152)", +] + +Purp = [ + "rgb(243, 224, 247)", + "rgb(228, 199, 241)", + "rgb(209, 175, 232)", + "rgb(185, 152, 221)", + "rgb(159, 130, 206)", + "rgb(130, 109, 186)", + "rgb(99, 88, 159)", +] + +Purpor = [ + "rgb(249, 221, 218)", + "rgb(242, 185, 196)", + "rgb(229, 151, 185)", + "rgb(206, 120, 179)", + "rgb(173, 95, 173)", + "rgb(131, 75, 160)", + "rgb(87, 59, 136)", +] + +Sunset = [ + "rgb(243, 231, 155)", + "rgb(250, 196, 132)", + "rgb(248, 160, 126)", + "rgb(235, 127, 134)", + "rgb(206, 102, 147)", + "rgb(160, 89, 160)", + "rgb(92, 83, 165)", +] + +Magenta = [ + "rgb(243, 203, 211)", + "rgb(234, 169, 189)", + "rgb(221, 136, 172)", + "rgb(202, 105, 157)", + "rgb(177, 77, 142)", + "rgb(145, 53, 125)", + "rgb(108, 33, 103)", +] + +Sunsetdark = [ + "rgb(252, 222, 156)", + "rgb(250, 164, 118)", + "rgb(240, 116, 110)", + "rgb(227, 79, 111)", + "rgb(220, 57, 119)", + "rgb(185, 37, 122)", + "rgb(124, 29, 111)", +] + +Agsunset = [ + "rgb(75, 41, 145)", + "rgb(135, 44, 162)", + "rgb(192, 54, 157)", + "rgb(234, 79, 136)", + "rgb(250, 120, 118)", + "rgb(246, 169, 122)", + "rgb(237, 217, 163)", +] + +Brwnyl = [ + "rgb(237, 229, 207)", + "rgb(224, 194, 162)", + "rgb(211, 156, 131)", + "rgb(193, 118, 111)", + "rgb(166, 84, 97)", + "rgb(129, 55, 83)", + "rgb(84, 31, 63)", +] + +# Diverging schemes + +Armyrose = [ + "rgb(121, 130, 52)", + "rgb(163, 173, 98)", + "rgb(208, 211, 162)", + "rgb(253, 251, 228)", + "rgb(240, 198, 195)", + "rgb(223, 145, 163)", + "rgb(212, 103, 128)", +] + +Fall = [ + "rgb(61, 89, 65)", + "rgb(119, 136, 104)", + "rgb(181, 185, 145)", + "rgb(246, 237, 189)", + "rgb(237, 187, 138)", + "rgb(222, 138, 90)", + "rgb(202, 86, 44)", +] + +Geyser = [ + "rgb(0, 128, 128)", + "rgb(112, 164, 148)", + "rgb(180, 200, 168)", + "rgb(246, 237, 189)", + "rgb(237, 187, 138)", + "rgb(222, 138, 90)", + "rgb(202, 86, 44)", +] + +Temps = [ + "rgb(0, 147, 146)", + "rgb(57, 177, 133)", + "rgb(156, 203, 134)", + "rgb(233, 226, 156)", + "rgb(238, 180, 121)", + "rgb(232, 132, 113)", + "rgb(207, 89, 126)", +] + +Tealrose = [ + "rgb(0, 147, 146)", + "rgb(114, 170, 161)", + "rgb(177, 199, 179)", + "rgb(241, 234, 200)", + "rgb(229, 185, 173)", + "rgb(217, 137, 148)", + "rgb(208, 88, 126)", +] + +Tropic = [ + "rgb(0, 155, 158)", + "rgb(66, 183, 185)", + "rgb(167, 211, 212)", + "rgb(241, 241, 241)", + "rgb(228, 193, 217)", + "rgb(214, 145, 193)", + "rgb(199, 93, 171)", +] + +Earth = [ + "rgb(161, 105, 40)", + "rgb(189, 146, 90)", + "rgb(214, 189, 141)", + "rgb(237, 234, 194)", + "rgb(181, 200, 184)", + "rgb(121, 167, 172)", + "rgb(40, 135, 161)", +] + +# Qualitative palettes + +Antique = [ + "rgb(133, 92, 117)", + "rgb(217, 175, 107)", + "rgb(175, 100, 88)", + "rgb(115, 111, 76)", + "rgb(82, 106, 131)", + "rgb(98, 83, 119)", + "rgb(104, 133, 92)", + "rgb(156, 156, 94)", + "rgb(160, 97, 119)", + "rgb(140, 120, 93)", + "rgb(124, 124, 124)", +] + +Bold = [ + "rgb(127, 60, 141)", + "rgb(17, 165, 121)", + "rgb(57, 105, 172)", + "rgb(242, 183, 1)", + "rgb(231, 63, 116)", + "rgb(128, 186, 90)", + "rgb(230, 131, 16)", + "rgb(0, 134, 149)", + "rgb(207, 28, 144)", + "rgb(249, 123, 114)", + "rgb(165, 170, 153)", +] + +Pastel = [ + "rgb(102, 197, 204)", + "rgb(246, 207, 113)", + "rgb(248, 156, 116)", + "rgb(220, 176, 242)", + "rgb(135, 197, 95)", + "rgb(158, 185, 243)", + "rgb(254, 136, 177)", + "rgb(201, 219, 116)", + "rgb(139, 224, 164)", + "rgb(180, 151, 231)", + "rgb(179, 179, 179)", +] + +Prism = [ + "rgb(95, 70, 144)", + "rgb(29, 105, 150)", + "rgb(56, 166, 165)", + "rgb(15, 133, 84)", + "rgb(115, 175, 72)", + "rgb(237, 173, 8)", + "rgb(225, 124, 5)", + "rgb(204, 80, 62)", + "rgb(148, 52, 110)", + "rgb(111, 64, 112)", + "rgb(102, 102, 102)", +] + +Safe = [ + "rgb(136, 204, 238)", + "rgb(204, 102, 119)", + "rgb(221, 204, 119)", + "rgb(17, 119, 51)", + "rgb(51, 34, 136)", + "rgb(170, 68, 153)", + "rgb(68, 170, 153)", + "rgb(153, 153, 51)", + "rgb(136, 34, 85)", + "rgb(102, 17, 0)", + "rgb(136, 136, 136)", +] + +Vivid = [ + "rgb(229, 134, 6)", + "rgb(93, 105, 177)", + "rgb(82, 188, 163)", + "rgb(153, 201, 69)", + "rgb(204, 97, 176)", + "rgb(36, 121, 108)", + "rgb(218, 165, 27)", + "rgb(47, 138, 196)", + "rgb(118, 78, 159)", + "rgb(237, 100, 90)", + "rgb(165, 170, 153)", +] diff --git a/packages/python/plotly/plotly/colors/cmocean.py b/packages/python/plotly/plotly/colors/cmocean.py new file mode 100644 index 00000000000..c7174fd60c0 --- /dev/null +++ b/packages/python/plotly/plotly/colors/cmocean.py @@ -0,0 +1,269 @@ +""" +Color scales from the cmocean project + +Learn more at https://matplotlib.org/cmocean/ + +cmocean is made available under an MIT license: https://github.com/matplotlib/cmocean/blob/master/LICENSE.txt +""" + +from ._swatches import _swatches + + +def swatches(): + return _swatches(__name__, globals()) + + +swatches.__doc__ = _swatches.__doc__ + +turbid = [ + "rgb(232, 245, 171)", + "rgb(220, 219, 137)", + "rgb(209, 193, 107)", + "rgb(199, 168, 83)", + "rgb(186, 143, 66)", + "rgb(170, 121, 60)", + "rgb(151, 103, 58)", + "rgb(129, 87, 56)", + "rgb(104, 72, 53)", + "rgb(80, 59, 46)", + "rgb(57, 45, 37)", + "rgb(34, 30, 27)", +] +thermal = [ + "rgb(3, 35, 51)", + "rgb(13, 48, 100)", + "rgb(53, 50, 155)", + "rgb(93, 62, 153)", + "rgb(126, 77, 143)", + "rgb(158, 89, 135)", + "rgb(193, 100, 121)", + "rgb(225, 113, 97)", + "rgb(246, 139, 69)", + "rgb(251, 173, 60)", + "rgb(246, 211, 70)", + "rgb(231, 250, 90)", +] +haline = [ + "rgb(41, 24, 107)", + "rgb(42, 35, 160)", + "rgb(15, 71, 153)", + "rgb(18, 95, 142)", + "rgb(38, 116, 137)", + "rgb(53, 136, 136)", + "rgb(65, 157, 133)", + "rgb(81, 178, 124)", + "rgb(111, 198, 107)", + "rgb(160, 214, 91)", + "rgb(212, 225, 112)", + "rgb(253, 238, 153)", +] +solar = [ + "rgb(51, 19, 23)", + "rgb(79, 28, 33)", + "rgb(108, 36, 36)", + "rgb(135, 47, 32)", + "rgb(157, 66, 25)", + "rgb(174, 88, 20)", + "rgb(188, 111, 19)", + "rgb(199, 137, 22)", + "rgb(209, 164, 32)", + "rgb(217, 192, 44)", + "rgb(222, 222, 59)", + "rgb(224, 253, 74)", +] +ice = [ + "rgb(3, 5, 18)", + "rgb(25, 25, 51)", + "rgb(44, 42, 87)", + "rgb(58, 60, 125)", + "rgb(62, 83, 160)", + "rgb(62, 109, 178)", + "rgb(72, 134, 187)", + "rgb(89, 159, 196)", + "rgb(114, 184, 205)", + "rgb(149, 207, 216)", + "rgb(192, 229, 232)", + "rgb(234, 252, 253)", +] +gray = [ + "rgb(0, 0, 0)", + "rgb(16, 16, 16)", + "rgb(38, 38, 38)", + "rgb(59, 59, 59)", + "rgb(81, 80, 80)", + "rgb(102, 101, 101)", + "rgb(124, 123, 122)", + "rgb(146, 146, 145)", + "rgb(171, 171, 170)", + "rgb(197, 197, 195)", + "rgb(224, 224, 223)", + "rgb(254, 254, 253)", +] +oxy = [ + "rgb(63, 5, 5)", + "rgb(101, 6, 13)", + "rgb(138, 17, 9)", + "rgb(96, 95, 95)", + "rgb(119, 118, 118)", + "rgb(142, 141, 141)", + "rgb(166, 166, 165)", + "rgb(193, 192, 191)", + "rgb(222, 222, 220)", + "rgb(239, 248, 90)", + "rgb(230, 210, 41)", + "rgb(220, 174, 25)", +] +deep = [ + "rgb(253, 253, 204)", + "rgb(206, 236, 179)", + "rgb(156, 219, 165)", + "rgb(111, 201, 163)", + "rgb(86, 177, 163)", + "rgb(76, 153, 160)", + "rgb(68, 130, 155)", + "rgb(62, 108, 150)", + "rgb(62, 82, 143)", + "rgb(64, 60, 115)", + "rgb(54, 43, 77)", + "rgb(39, 26, 44)", +] +dense = [ + "rgb(230, 240, 240)", + "rgb(191, 221, 229)", + "rgb(156, 201, 226)", + "rgb(129, 180, 227)", + "rgb(115, 154, 228)", + "rgb(117, 127, 221)", + "rgb(120, 100, 202)", + "rgb(119, 74, 175)", + "rgb(113, 50, 141)", + "rgb(100, 31, 104)", + "rgb(80, 20, 66)", + "rgb(54, 14, 36)", +] +algae = [ + "rgb(214, 249, 207)", + "rgb(186, 228, 174)", + "rgb(156, 209, 143)", + "rgb(124, 191, 115)", + "rgb(85, 174, 91)", + "rgb(37, 157, 81)", + "rgb(7, 138, 78)", + "rgb(13, 117, 71)", + "rgb(23, 95, 61)", + "rgb(25, 75, 49)", + "rgb(23, 55, 35)", + "rgb(17, 36, 20)", +] +matter = [ + "rgb(253, 237, 176)", + "rgb(250, 205, 145)", + "rgb(246, 173, 119)", + "rgb(240, 142, 98)", + "rgb(231, 109, 84)", + "rgb(216, 80, 83)", + "rgb(195, 56, 90)", + "rgb(168, 40, 96)", + "rgb(138, 29, 99)", + "rgb(107, 24, 93)", + "rgb(76, 21, 80)", + "rgb(47, 15, 61)", +] +speed = [ + "rgb(254, 252, 205)", + "rgb(239, 225, 156)", + "rgb(221, 201, 106)", + "rgb(194, 182, 59)", + "rgb(157, 167, 21)", + "rgb(116, 153, 5)", + "rgb(75, 138, 20)", + "rgb(35, 121, 36)", + "rgb(11, 100, 44)", + "rgb(18, 78, 43)", + "rgb(25, 56, 34)", + "rgb(23, 35, 18)", +] +amp = [ + "rgb(241, 236, 236)", + "rgb(230, 209, 203)", + "rgb(221, 182, 170)", + "rgb(213, 156, 137)", + "rgb(205, 129, 103)", + "rgb(196, 102, 73)", + "rgb(186, 74, 47)", + "rgb(172, 44, 36)", + "rgb(149, 19, 39)", + "rgb(120, 14, 40)", + "rgb(89, 13, 31)", + "rgb(60, 9, 17)", +] +tempo = [ + "rgb(254, 245, 244)", + "rgb(222, 224, 210)", + "rgb(189, 206, 181)", + "rgb(153, 189, 156)", + "rgb(110, 173, 138)", + "rgb(65, 157, 129)", + "rgb(25, 137, 125)", + "rgb(18, 116, 117)", + "rgb(25, 94, 106)", + "rgb(28, 72, 93)", + "rgb(25, 51, 80)", + "rgb(20, 29, 67)", +] +phase = [ + "rgb(167, 119, 12)", + "rgb(197, 96, 51)", + "rgb(217, 67, 96)", + "rgb(221, 38, 163)", + "rgb(196, 59, 224)", + "rgb(153, 97, 244)", + "rgb(95, 127, 228)", + "rgb(40, 144, 183)", + "rgb(15, 151, 136)", + "rgb(39, 153, 79)", + "rgb(119, 141, 17)", + "rgb(167, 119, 12)", +] +balance = [ + "rgb(23, 28, 66)", + "rgb(41, 58, 143)", + "rgb(11, 102, 189)", + "rgb(69, 144, 185)", + "rgb(142, 181, 194)", + "rgb(210, 216, 219)", + "rgb(230, 210, 204)", + "rgb(213, 157, 137)", + "rgb(196, 101, 72)", + "rgb(172, 43, 36)", + "rgb(120, 14, 40)", + "rgb(60, 9, 17)", +] +delta = [ + "rgb(16, 31, 63)", + "rgb(38, 62, 144)", + "rgb(30, 110, 161)", + "rgb(60, 154, 171)", + "rgb(140, 193, 186)", + "rgb(217, 229, 218)", + "rgb(239, 226, 156)", + "rgb(195, 182, 59)", + "rgb(115, 152, 5)", + "rgb(34, 120, 36)", + "rgb(18, 78, 43)", + "rgb(23, 35, 18)", +] +curl = [ + "rgb(20, 29, 67)", + "rgb(28, 72, 93)", + "rgb(18, 115, 117)", + "rgb(63, 156, 129)", + "rgb(153, 189, 156)", + "rgb(223, 225, 211)", + "rgb(241, 218, 206)", + "rgb(224, 160, 137)", + "rgb(203, 101, 99)", + "rgb(164, 54, 96)", + "rgb(111, 23, 91)", + "rgb(51, 13, 53)", +] diff --git a/packages/python/plotly/plotly/colors/colorbrewer.py b/packages/python/plotly/plotly/colors/colorbrewer.py new file mode 100644 index 00000000000..fa2e2b9a981 --- /dev/null +++ b/packages/python/plotly/plotly/colors/colorbrewer.py @@ -0,0 +1,458 @@ +""" +Color scales and sequences from the colorbrewer 2 project + +Learn more at http://colorbrewer2.org + +colorbrewer is made available under an Apache license: http://colorbrewer2.org/export/LICENSE.txt +""" + +from ._swatches import _swatches + + +def swatches(): + return _swatches(__name__, globals()) + + +swatches.__doc__ = _swatches.__doc__ + +BrBG = [ + "rgb(84,48,5)", + "rgb(140,81,10)", + "rgb(191,129,45)", + "rgb(223,194,125)", + "rgb(246,232,195)", + "rgb(245,245,245)", + "rgb(199,234,229)", + "rgb(128,205,193)", + "rgb(53,151,143)", + "rgb(1,102,94)", + "rgb(0,60,48)", +] + +PRGn = [ + "rgb(64,0,75)", + "rgb(118,42,131)", + "rgb(153,112,171)", + "rgb(194,165,207)", + "rgb(231,212,232)", + "rgb(247,247,247)", + "rgb(217,240,211)", + "rgb(166,219,160)", + "rgb(90,174,97)", + "rgb(27,120,55)", + "rgb(0,68,27)", +] + +PiYG = [ + "rgb(142,1,82)", + "rgb(197,27,125)", + "rgb(222,119,174)", + "rgb(241,182,218)", + "rgb(253,224,239)", + "rgb(247,247,247)", + "rgb(230,245,208)", + "rgb(184,225,134)", + "rgb(127,188,65)", + "rgb(77,146,33)", + "rgb(39,100,25)", +] + +PuOr = [ + "rgb(127,59,8)", + "rgb(179,88,6)", + "rgb(224,130,20)", + "rgb(253,184,99)", + "rgb(254,224,182)", + "rgb(247,247,247)", + "rgb(216,218,235)", + "rgb(178,171,210)", + "rgb(128,115,172)", + "rgb(84,39,136)", + "rgb(45,0,75)", +] + +RdBu = [ + "rgb(103,0,31)", + "rgb(178,24,43)", + "rgb(214,96,77)", + "rgb(244,165,130)", + "rgb(253,219,199)", + "rgb(247,247,247)", + "rgb(209,229,240)", + "rgb(146,197,222)", + "rgb(67,147,195)", + "rgb(33,102,172)", + "rgb(5,48,97)", +] + +RdGy = [ + "rgb(103,0,31)", + "rgb(178,24,43)", + "rgb(214,96,77)", + "rgb(244,165,130)", + "rgb(253,219,199)", + "rgb(255,255,255)", + "rgb(224,224,224)", + "rgb(186,186,186)", + "rgb(135,135,135)", + "rgb(77,77,77)", + "rgb(26,26,26)", +] + +RdYlBu = [ + "rgb(165,0,38)", + "rgb(215,48,39)", + "rgb(244,109,67)", + "rgb(253,174,97)", + "rgb(254,224,144)", + "rgb(255,255,191)", + "rgb(224,243,248)", + "rgb(171,217,233)", + "rgb(116,173,209)", + "rgb(69,117,180)", + "rgb(49,54,149)", +] + +RdYlGn = [ + "rgb(165,0,38)", + "rgb(215,48,39)", + "rgb(244,109,67)", + "rgb(253,174,97)", + "rgb(254,224,139)", + "rgb(255,255,191)", + "rgb(217,239,139)", + "rgb(166,217,106)", + "rgb(102,189,99)", + "rgb(26,152,80)", + "rgb(0,104,55)", +] + +Spectral = [ + "rgb(158,1,66)", + "rgb(213,62,79)", + "rgb(244,109,67)", + "rgb(253,174,97)", + "rgb(254,224,139)", + "rgb(255,255,191)", + "rgb(230,245,152)", + "rgb(171,221,164)", + "rgb(102,194,165)", + "rgb(50,136,189)", + "rgb(94,79,162)", +] + +Set1 = [ + "rgb(228,26,28)", + "rgb(55,126,184)", + "rgb(77,175,74)", + "rgb(152,78,163)", + "rgb(255,127,0)", + "rgb(255,255,51)", + "rgb(166,86,40)", + "rgb(247,129,191)", + "rgb(153,153,153)", +] + + +Pastel1 = [ + "rgb(251,180,174)", + "rgb(179,205,227)", + "rgb(204,235,197)", + "rgb(222,203,228)", + "rgb(254,217,166)", + "rgb(255,255,204)", + "rgb(229,216,189)", + "rgb(253,218,236)", + "rgb(242,242,242)", +] +Dark2 = [ + "rgb(27,158,119)", + "rgb(217,95,2)", + "rgb(117,112,179)", + "rgb(231,41,138)", + "rgb(102,166,30)", + "rgb(230,171,2)", + "rgb(166,118,29)", + "rgb(102,102,102)", +] +Set2 = [ + "rgb(102,194,165)", + "rgb(252,141,98)", + "rgb(141,160,203)", + "rgb(231,138,195)", + "rgb(166,216,84)", + "rgb(255,217,47)", + "rgb(229,196,148)", + "rgb(179,179,179)", +] + + +Pastel2 = [ + "rgb(179,226,205)", + "rgb(253,205,172)", + "rgb(203,213,232)", + "rgb(244,202,228)", + "rgb(230,245,201)", + "rgb(255,242,174)", + "rgb(241,226,204)", + "rgb(204,204,204)", +] + +Set3 = [ + "rgb(141,211,199)", + "rgb(255,255,179)", + "rgb(190,186,218)", + "rgb(251,128,114)", + "rgb(128,177,211)", + "rgb(253,180,98)", + "rgb(179,222,105)", + "rgb(252,205,229)", + "rgb(217,217,217)", + "rgb(188,128,189)", + "rgb(204,235,197)", + "rgb(255,237,111)", +] + +Accent = [ + "rgb(127,201,127)", + "rgb(190,174,212)", + "rgb(253,192,134)", + "rgb(255,255,153)", + "rgb(56,108,176)", + "rgb(240,2,127)", + "rgb(191,91,23)", + "rgb(102,102,102)", +] + + +Paired = [ + "rgb(166,206,227)", + "rgb(31,120,180)", + "rgb(178,223,138)", + "rgb(51,160,44)", + "rgb(251,154,153)", + "rgb(227,26,28)", + "rgb(253,191,111)", + "rgb(255,127,0)", + "rgb(202,178,214)", + "rgb(106,61,154)", + "rgb(255,255,153)", + "rgb(177,89,40)", +] + + +Blues = [ + "rgb(247,251,255)", + "rgb(222,235,247)", + "rgb(198,219,239)", + "rgb(158,202,225)", + "rgb(107,174,214)", + "rgb(66,146,198)", + "rgb(33,113,181)", + "rgb(8,81,156)", + "rgb(8,48,107)", +] + +BuGn = [ + "rgb(247,252,253)", + "rgb(229,245,249)", + "rgb(204,236,230)", + "rgb(153,216,201)", + "rgb(102,194,164)", + "rgb(65,174,118)", + "rgb(35,139,69)", + "rgb(0,109,44)", + "rgb(0,68,27)", +] + +BuPu = [ + "rgb(247,252,253)", + "rgb(224,236,244)", + "rgb(191,211,230)", + "rgb(158,188,218)", + "rgb(140,150,198)", + "rgb(140,107,177)", + "rgb(136,65,157)", + "rgb(129,15,124)", + "rgb(77,0,75)", +] + +GnBu = [ + "rgb(247,252,240)", + "rgb(224,243,219)", + "rgb(204,235,197)", + "rgb(168,221,181)", + "rgb(123,204,196)", + "rgb(78,179,211)", + "rgb(43,140,190)", + "rgb(8,104,172)", + "rgb(8,64,129)", +] + +Greens = [ + "rgb(247,252,245)", + "rgb(229,245,224)", + "rgb(199,233,192)", + "rgb(161,217,155)", + "rgb(116,196,118)", + "rgb(65,171,93)", + "rgb(35,139,69)", + "rgb(0,109,44)", + "rgb(0,68,27)", +] + +Greys = [ + "rgb(255,255,255)", + "rgb(240,240,240)", + "rgb(217,217,217)", + "rgb(189,189,189)", + "rgb(150,150,150)", + "rgb(115,115,115)", + "rgb(82,82,82)", + "rgb(37,37,37)", + "rgb(0,0,0)", +] + +OrRd = [ + "rgb(255,247,236)", + "rgb(254,232,200)", + "rgb(253,212,158)", + "rgb(253,187,132)", + "rgb(252,141,89)", + "rgb(239,101,72)", + "rgb(215,48,31)", + "rgb(179,0,0)", + "rgb(127,0,0)", +] + +Oranges = [ + "rgb(255,245,235)", + "rgb(254,230,206)", + "rgb(253,208,162)", + "rgb(253,174,107)", + "rgb(253,141,60)", + "rgb(241,105,19)", + "rgb(217,72,1)", + "rgb(166,54,3)", + "rgb(127,39,4)", +] + +PuBu = [ + "rgb(255,247,251)", + "rgb(236,231,242)", + "rgb(208,209,230)", + "rgb(166,189,219)", + "rgb(116,169,207)", + "rgb(54,144,192)", + "rgb(5,112,176)", + "rgb(4,90,141)", + "rgb(2,56,88)", +] + +PuBuGn = [ + "rgb(255,247,251)", + "rgb(236,226,240)", + "rgb(208,209,230)", + "rgb(166,189,219)", + "rgb(103,169,207)", + "rgb(54,144,192)", + "rgb(2,129,138)", + "rgb(1,108,89)", + "rgb(1,70,54)", +] + +PuRd = [ + "rgb(247,244,249)", + "rgb(231,225,239)", + "rgb(212,185,218)", + "rgb(201,148,199)", + "rgb(223,101,176)", + "rgb(231,41,138)", + "rgb(206,18,86)", + "rgb(152,0,67)", + "rgb(103,0,31)", +] + +Purples = [ + "rgb(252,251,253)", + "rgb(239,237,245)", + "rgb(218,218,235)", + "rgb(188,189,220)", + "rgb(158,154,200)", + "rgb(128,125,186)", + "rgb(106,81,163)", + "rgb(84,39,143)", + "rgb(63,0,125)", +] + +RdPu = [ + "rgb(255,247,243)", + "rgb(253,224,221)", + "rgb(252,197,192)", + "rgb(250,159,181)", + "rgb(247,104,161)", + "rgb(221,52,151)", + "rgb(174,1,126)", + "rgb(122,1,119)", + "rgb(73,0,106)", +] + +Reds = [ + "rgb(255,245,240)", + "rgb(254,224,210)", + "rgb(252,187,161)", + "rgb(252,146,114)", + "rgb(251,106,74)", + "rgb(239,59,44)", + "rgb(203,24,29)", + "rgb(165,15,21)", + "rgb(103,0,13)", +] + +YlGn = [ + "rgb(255,255,229)", + "rgb(247,252,185)", + "rgb(217,240,163)", + "rgb(173,221,142)", + "rgb(120,198,121)", + "rgb(65,171,93)", + "rgb(35,132,67)", + "rgb(0,104,55)", + "rgb(0,69,41)", +] + +YlGnBu = [ + "rgb(255,255,217)", + "rgb(237,248,177)", + "rgb(199,233,180)", + "rgb(127,205,187)", + "rgb(65,182,196)", + "rgb(29,145,192)", + "rgb(34,94,168)", + "rgb(37,52,148)", + "rgb(8,29,88)", +] + +YlOrBr = [ + "rgb(255,255,229)", + "rgb(255,247,188)", + "rgb(254,227,145)", + "rgb(254,196,79)", + "rgb(254,153,41)", + "rgb(236,112,20)", + "rgb(204,76,2)", + "rgb(153,52,4)", + "rgb(102,37,6)", +] + +YlOrRd = [ + "rgb(255,255,204)", + "rgb(255,237,160)", + "rgb(254,217,118)", + "rgb(254,178,76)", + "rgb(253,141,60)", + "rgb(252,78,42)", + "rgb(227,26,28)", + "rgb(189,0,38)", + "rgb(128,0,38)", +] diff --git a/packages/python/plotly/plotly/colors/cyclical.py b/packages/python/plotly/plotly/colors/cyclical.py new file mode 100644 index 00000000000..ec024865df6 --- /dev/null +++ b/packages/python/plotly/plotly/colors/cyclical.py @@ -0,0 +1,127 @@ +""" +Cyclical color scales are appropriate for continuous data that has a natural cyclical \ +structure, such as temporal data (hour of day, day of week, day of year, seasons) or +complex numbers or other phase data. +""" + +from ._swatches import _swatches + + +def swatches(): + return _swatches(__name__, globals()) + + +swatches.__doc__ = _swatches.__doc__ + +Twilight = [ + "#e2d9e2", + "#9ebbc9", + "#6785be", + "#5e43a5", + "#421257", + "#471340", + "#8e2c50", + "#ba6657", + "#ceac94", + "#e2d9e2", +] +IceFire = [ + "#000000", + "#001f4d", + "#003786", + "#0e58a8", + "#217eb8", + "#30a4ca", + "#54c8df", + "#9be4ef", + "#e1e9d1", + "#f3d573", + "#e7b000", + "#da8200", + "#c65400", + "#ac2301", + "#820000", + "#4c0000", + "#040100", +] +Edge = [ + "#313131", + "#3d019d", + "#3810dc", + "#2d47f9", + "#2593ff", + "#2adef6", + "#60fdfa", + "#aefdff", + "#f3f3f1", + "#fffda9", + "#fafd5b", + "#f7da29", + "#ff8e25", + "#f8432d", + "#d90d39", + "#97023d", + "#313131", +] +Phase = [ + "rgb(167, 119, 12)", + "rgb(197, 96, 51)", + "rgb(217, 67, 96)", + "rgb(221, 38, 163)", + "rgb(196, 59, 224)", + "rgb(153, 97, 244)", + "rgb(95, 127, 228)", + "rgb(40, 144, 183)", + "rgb(15, 151, 136)", + "rgb(39, 153, 79)", + "rgb(119, 141, 17)", + "rgb(167, 119, 12)", +] +HSV = [ + "#ff0000", + "#ffa700", + "#afff00", + "#08ff00", + "#00ff9f", + "#00b7ff", + "#0010ff", + "#9700ff", + "#ff00bf", + "#ff0018", +] +mrybm = [ + "#f884f7", + "#f968c4", + "#ea4388", + "#cf244b", + "#b51a15", + "#bd4304", + "#cc6904", + "#d58f04", + "#cfaa27", + "#a19f62", + "#588a93", + "#2269c4", + "#3e3ef0", + "#6b4ef9", + "#956bfa", + "#cd7dfe", +] +mygbm = [ + "#ef55f1", + "#fb84ce", + "#fbafa1", + "#fcd471", + "#f0ed35", + "#c6e516", + "#96d310", + "#61c10b", + "#31ac28", + "#439064", + "#3d719a", + "#284ec8", + "#2e21ea", + "#6324f5", + "#9139fa", + "#c543fa", +] diff --git a/packages/python/plotly/plotly/colors/diverging.py b/packages/python/plotly/plotly/colors/diverging.py new file mode 100644 index 00000000000..66caef7e921 --- /dev/null +++ b/packages/python/plotly/plotly/colors/diverging.py @@ -0,0 +1,31 @@ +""" +Diverging color scales are appropriate for continuous data that has a natural midpoint \ +other otherwise informative special value, such as 0 altitude, or the boiling point +of a liquid. The color scales in this module are \ +mostly meant to be passed in as the `color_continuous_scale` argument to various \ +functions, and to be used with the `color_continuous_midpoint` argument. +""" + +from .colorbrewer import ( # noqa: F401 + BrBG, + PRGn, + PiYG, + PuOr, + RdBu, + RdGy, + RdYlBu, + RdYlGn, + Spectral, +) +from .cmocean import balance, delta, curl # noqa: F401 +from .carto import Armyrose, Fall, Geyser, Temps, Tealrose, Tropic, Earth # noqa: F401 + + +from ._swatches import _swatches + + +def swatches(): + return _swatches(__name__, globals()) + + +swatches.__doc__ = _swatches.__doc__ diff --git a/packages/python/plotly/plotly/colors/qualitative.py b/packages/python/plotly/plotly/colors/qualitative.py new file mode 100644 index 00000000000..ca934f0d2e2 --- /dev/null +++ b/packages/python/plotly/plotly/colors/qualitative.py @@ -0,0 +1,147 @@ +""" +Qualitative color sequences are appropriate for data that has no natural ordering, such \ +as categories, colors, names, countries etc. The color sequences in this module are \ +mostly meant to be passed in as the `color_discrete_sequence` argument to various functions. +""" + +from ._swatches import _swatches + + +def swatches(): + return _swatches(__name__, globals()) + + +swatches.__doc__ = _swatches.__doc__ + +Plotly = [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#19d3f3", + "#e763fa", + "#fecb52", + "#ffa15a", + "#ff6692", + "#b6e880", +] + +D3 = [ + "#1f77b4", + "#ff7f0e", + "#2ca02c", + "#d62728", + "#9467bd", + "#8c564b", + "#e377c2", + "#7f7f7f", + "#bcbd22", + "#17becf", +] +G10 = [ + "#3366cc", + "#dc3912", + "#ff9900", + "#109618", + "#990099", + "#0099c6", + "#dd4477", + "#66aa00", + "#b82e2e", + "#316395", +] +T10 = [ + "#4c78a8", + "#f58518", + "#e45756", + "#72b7b2", + "#54a24b", + "#eeca3b", + "#b279a2", + "#ff9da6", + "#9d755d", + "#bab0ac", +] +Alphabet = [ + "#AA0DFE", + "#3283FE", + "#85660D", + "#782AB6", + "#565656", + "#1C8356", + "#16FF32", + "#F7E1A0", + "#E2E2E2", + "#1CBE4F", + "#C4451C", + "#DEA0FD", + "#FE00FA", + "#325A9B", + "#FEAF16", + "#F8A19F", + "#90AD1C", + "#F6222E", + "#1CFFCE", + "#2ED9FF", + "#B10DA1", + "#C075A6", + "#FC1CBF", + "#B00068", + "#FBE426", + "#FA0087", +] +Dark24 = [ + "#2E91E5", + "#E15F99", + "#1CA71C", + "#FB0D0D", + "#DA16FF", + "#222A2A", + "#B68100", + "#750D86", + "#EB663B", + "#511CFB", + "#00A08B", + "#FB00D1", + "#FC0080", + "#B2828D", + "#6C7C32", + "#778AAE", + "#862A16", + "#A777F1", + "#620042", + "#1616A7", + "#DA60CA", + "#6C4516", + "#0D2A63", + "#AF0038", +] +Light24 = [ + "#FD3216", + "#00FE35", + "#6A76FC", + "#FED4C4", + "#FE00CE", + "#0DF9FF", + "#F6F926", + "#FF9616", + "#479B55", + "#EEA6FB", + "#DC587D", + "#D626FF", + "#6E899C", + "#00B5F7", + "#B68E00", + "#C9FBE5", + "#FF0092", + "#22FFA7", + "#E3EE9E", + "#86CE00", + "#BC7196", + "#7E7DCD", + "#FC6955", + "#E48F72", +] + +from .colorbrewer import Set1, Pastel1, Dark2, Set2, Pastel2, Set3 # noqa: F401 +from .carto import Antique, Bold, Pastel, Prism, Safe, Vivid # noqa: F401 diff --git a/packages/python/plotly/plotly/colors/sequential.py b/packages/python/plotly/plotly/colors/sequential.py new file mode 100644 index 00000000000..a860f3f37b4 --- /dev/null +++ b/packages/python/plotly/plotly/colors/sequential.py @@ -0,0 +1,154 @@ +""" +Sequential color scales are appropriate for most continuous data, but in some cases it \ +can be helpful to use a `plotly_express.colors.diverging` or \ +`plotly_express.colors.cyclical` scale instead. The color scales in this module are \ +mostly meant to be passed in as the `color_continuous_scale` argument to various functions. +""" + +from ._swatches import _swatches + + +def swatches(): + return _swatches(__name__, globals()) + + +swatches.__doc__ = _swatches.__doc__ + +Plotly = [ + "#0508b8", + "#1910d8", + "#3c19f0", + "#6b1cfb", + "#981cfd", + "#bf1cfd", + "#dd2bfd", + "#f246fe", + "#fc67fd", + "#fe88fc", + "#fea5fd", + "#febefe", + "#fec3fe", +] + +Viridis = [ + "#440154", + "#482878", + "#3e4989", + "#31688e", + "#26828e", + "#1f9e89", + "#35b779", + "#6ece58", + "#b5de2b", + "#fde725", +] +Cividis = [ + "#00224e", + "#123570", + "#3b496c", + "#575d6d", + "#707173", + "#8a8678", + "#a59c74", + "#c3b369", + "#e1cc55", + "#fee838", +] + +Inferno = [ + "#000004", + "#1b0c41", + "#4a0c6b", + "#781c6d", + "#a52c60", + "#cf4446", + "#ed6925", + "#fb9b06", + "#f7d13d", + "#fcffa4", +] +Magma = [ + "#000004", + "#180f3d", + "#440f76", + "#721f81", + "#9e2f7f", + "#cd4071", + "#f1605d", + "#fd9668", + "#feca8d", + "#fcfdbf", +] +Plasma = [ + "#0d0887", + "#46039f", + "#7201a8", + "#9c179e", + "#bd3786", + "#d8576b", + "#ed7953", + "#fb9f3a", + "#fdca26", + "#f0f921", +] + +from .colorbrewer import ( # noqa: F401 + Blues, + BuGn, + BuPu, + GnBu, + Greens, + Greys, + OrRd, + Oranges, + PuBu, + PuBuGn, + PuRd, + Purples, + RdPu, + Reds, + YlGn, + YlGnBu, + YlOrBr, + YlOrRd, +) + +from .cmocean import ( # noqa: F401 + turbid, + thermal, + haline, + solar, + ice, + gray, + deep, + dense, + algae, + matter, + speed, + amp, + tempo, +) + +from .carto import ( # noqa: F401 + Burg, + Burgyl, + Redor, + Oryel, + Peach, + Pinkyl, + Mint, + Blugrn, + Darkmint, + Emrld, + Aggrnyl, + Bluyl, + Teal, + Tealgrn, + Purp, + Purpor, + Sunset, + Magenta, + Sunsetdark, + Agsunset, + Brwnyl, +) diff --git a/packages/python/plotly/plotly/data/__init__.py b/packages/python/plotly/plotly/data/__init__.py new file mode 100644 index 00000000000..2a83f8719ff --- /dev/null +++ b/packages/python/plotly/plotly/data/__init__.py @@ -0,0 +1,78 @@ +""" +Built-in datasets for demonstration, educational and test purposes. +""" + + +def gapminder(): + """ + Each row represents a country on a given year. + + https://www.gapminder.org/data/ + + Returns: + A `pandas.DataFrame` with 1704 rows and the following columns: `['country', 'continent', 'year', 'lifeExp', 'pop', 'gdpPercap', + 'iso_alpha', 'iso_num']`. + """ + return _get_dataset("gapminder") + + +def tips(): + """ + Each row represents a restaurant bill. + + https://vincentarelbundock.github.io/Rdatasets/doc/reshape2/tips.html + + Returns: + A `pandas.DataFrame` with 244 rows and the following columns: `['total_bill', 'tip', 'sex', 'smoker', 'day', 'time', 'size']`. + """ + return _get_dataset("tips") + + +def iris(): + """ + Each row represents a flower. + + https://en.wikipedia.org/wiki/Iris_flower_data_set + + Returns: + A `pandas.DataFrame` with 150 rows and the following columns: `['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species', + 'species_id']`. + """ + return _get_dataset("iris") + + +def wind(): + """ + Each row represents a level of wind intensity in a cardinal direction. + + Returns: + A `pandas.DataFrame` with 128 rows and the following columns: `['direction', 'strength', 'value']`. + """ + return _get_dataset("wind") + + +def election(): + """ + Each row represents voting results for an electoral district in the 2013 Montreal mayoral election. + + Returns: + A `pandas.DataFrame` with 58 rows and the following columns: `['district', 'Coderre', 'Bergeron', 'Joly', 'total', 'winner', 'result']`. + """ + return _get_dataset("election") + + +def carshare(): + """ + Each row represents the availability of car-sharing services near the centroid of a zone in Montreal. + + Returns: + A `pandas.DataFrame` with 249 rows and the following columns: `['centroid_lat', 'centroid_lon', 'car_hours', 'peak_hour']`. + """ + return _get_dataset("carshare") + + +def _get_dataset(d): + import pandas + import os + + return pandas.read_csv(os.path.join(os.path.dirname(__file__), d + ".csv.gz")) diff --git a/packages/python/plotly/plotly/express/__init__.py b/packages/python/plotly/plotly/express/__init__.py new file mode 100644 index 00000000000..6a956e64a54 --- /dev/null +++ b/packages/python/plotly/plotly/express/__init__.py @@ -0,0 +1,73 @@ +""" +`plotly_express` is a terse, consistent, high-level wrapper around `plotly` for rapid \ +data exploration and figure generation. See the gallery at https://plotly.github.io/plotly_express +""" + +__version__ = "0.3.0" + +from ._chart_types import ( # noqa: F401 + scatter, + scatter_3d, + scatter_polar, + scatter_ternary, + scatter_mapbox, + scatter_geo, + line, + line_3d, + line_polar, + line_ternary, + line_mapbox, + line_geo, + area, + bar, + bar_polar, + violin, + box, + strip, + histogram, + scatter_matrix, + parallel_coordinates, + parallel_categories, + choropleth, + density_contour, + density_heatmap, +) + +from ._core import ( # noqa: F401 + set_mapbox_access_token, + defaults, + get_trendline_results, +) + +from . import data, colors # noqa: F401 + +__all__ = [ + "scatter", + "scatter_3d", + "scatter_polar", + "scatter_ternary", + "scatter_mapbox", + "scatter_geo", + "scatter_matrix", + "density_contour", + "density_heatmap", + "line", + "line_polar", + "line_ternary", + "line_mapbox", + "line_geo", + "parallel_coordinates", + "parallel_categories", + "area", + "bar", + "bar_polar", + "violin", + "box", + "strip", + "histogram", + "choropleth", + "data", + "colors", + "set_mapbox_access_token", + "get_trendline_results", +] diff --git a/packages/python/plotly/plotly/express/_chart_types.py b/packages/python/plotly/plotly/express/_chart_types.py new file mode 100644 index 00000000000..c6f2fc099a0 --- /dev/null +++ b/packages/python/plotly/plotly/express/_chart_types.py @@ -0,0 +1,1062 @@ +from ._core import make_figure +from ._doc import make_docstring +import plotly.graph_objs as go + + +def scatter( + data_frame, + x=None, + y=None, + color=None, + symbol=None, + size=None, + hover_name=None, + hover_data=None, + text=None, + facet_row=None, + facet_col=None, + error_x=None, + error_x_minus=None, + error_y=None, + error_y_minus=None, + animation_frame=None, + animation_group=None, + category_orders={}, + labels={}, + color_discrete_sequence=None, + color_discrete_map={}, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + symbol_sequence=None, + symbol_map={}, + opacity=None, + size_max=None, + marginal_x=None, + marginal_y=None, + trendline=None, + trendline_color_override=None, + log_x=False, + log_y=False, + range_x=None, + range_y=None, + render_mode="auto", + title=None, + template=None, + width=None, + height=None, +): + """ + In a scatter plot, each row of `data_frame` is represented by a symbol mark in 2D space. + """ + return make_figure(args=locals(), constructor=go.Scatter) + + +scatter.__doc__ = make_docstring(scatter) + + +def density_contour( + data_frame, + x=None, + y=None, + z=None, + color=None, + facet_row=None, + facet_col=None, + hover_name=None, + hover_data=None, + animation_frame=None, + animation_group=None, + category_orders={}, + labels={}, + color_discrete_sequence=None, + color_discrete_map={}, + marginal_x=None, + marginal_y=None, + trendline=None, + trendline_color_override=None, + log_x=False, + log_y=False, + range_x=None, + range_y=None, + histfunc=None, + histnorm=None, + nbinsx=None, + nbinsy=None, + title=None, + template=None, + width=None, + height=None, +): + """ + In a density contour plot, rows of `data_frame` are grouped together into contour marks to \ + visualize the 2D distribution of an aggregate function `histfunc` (e.g. the count or sum) \ + of the value `z`. + """ + return make_figure( + args=locals(), + constructor=go.Histogram2dContour, + trace_patch=dict( + contours=dict(coloring="none"), + histfunc=histfunc, + histnorm=histnorm, + nbinsx=nbinsx, + nbinsy=nbinsy, + xbingroup="x", + ybingroup="y", + ), + ) + + +density_contour.__doc__ = make_docstring(density_contour) + + +def density_heatmap( + data_frame, + x=None, + y=None, + z=None, + facet_row=None, + facet_col=None, + hover_name=None, + hover_data=None, + animation_frame=None, + animation_group=None, + category_orders={}, + labels={}, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + marginal_x=None, + marginal_y=None, + opacity=None, + log_x=False, + log_y=False, + range_x=None, + range_y=None, + histfunc=None, + histnorm=None, + nbinsx=None, + nbinsy=None, + title=None, + template=None, + width=None, + height=None, +): + """ + In a density heatmap, rows of `data_frame` are grouped together into colored \ + rectangular tiles to visualize the 2D distribution of an aggregate function \ + `histfunc` (e.g. the count or sum) of the value `z`. + """ + return make_figure( + args=locals(), + constructor=go.Histogram2d, + trace_patch=dict( + histfunc=histfunc, + histnorm=histnorm, + nbinsx=nbinsx, + nbinsy=nbinsy, + xbingroup="x", + ybingroup="y", + ), + ) + + +density_heatmap.__doc__ = make_docstring(density_heatmap) + + +def line( + data_frame, + x=None, + y=None, + line_group=None, + color=None, + line_dash=None, + hover_name=None, + hover_data=None, + text=None, + facet_row=None, + facet_col=None, + error_x=None, + error_x_minus=None, + error_y=None, + error_y_minus=None, + animation_frame=None, + animation_group=None, + category_orders={}, + labels={}, + color_discrete_sequence=None, + color_discrete_map={}, + line_dash_sequence=None, + line_dash_map={}, + log_x=False, + log_y=False, + range_x=None, + range_y=None, + line_shape=None, + render_mode="auto", + title=None, + template=None, + width=None, + height=None, +): + """ + In a 2D line plot, each row of `data_frame` is represented as vertex of a polyline mark in 2D space. + """ + return make_figure(args=locals(), constructor=go.Scatter) + + +line.__doc__ = make_docstring(line) + + +def area( + data_frame, + x=None, + y=None, + line_group=None, + color=None, + hover_name=None, + hover_data=None, + text=None, + facet_row=None, + facet_col=None, + animation_frame=None, + animation_group=None, + category_orders={}, + labels={}, + color_discrete_sequence=None, + color_discrete_map={}, + orientation="v", + groupnorm=None, + log_x=False, + log_y=False, + range_x=None, + range_y=None, + line_shape=None, + title=None, + template=None, + width=None, + height=None, +): + """ + In a stacked area plot, each row of `data_frame` is represented as vertex of a polyline mark in 2D space. The area between successive polylines is filled. + """ + return make_figure( + args=locals(), + constructor=go.Scatter, + trace_patch=dict( + stackgroup=1, mode="lines", orientation=orientation, groupnorm=groupnorm + ), + ) + + +area.__doc__ = make_docstring(area) + + +def bar( + data_frame, + x=None, + y=None, + color=None, + facet_row=None, + facet_col=None, + hover_name=None, + hover_data=None, + text=None, + error_x=None, + error_x_minus=None, + error_y=None, + error_y_minus=None, + animation_frame=None, + animation_group=None, + category_orders={}, + labels={}, + color_discrete_sequence=None, + color_discrete_map={}, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + opacity=None, + orientation="v", + barmode="relative", + log_x=False, + log_y=False, + range_x=None, + range_y=None, + title=None, + template=None, + width=None, + height=None, +): + """ + In a bar plot, each row of `data_frame` is represented as a rectangular mark. + """ + return make_figure( + args=locals(), + constructor=go.Bar, + trace_patch=dict(orientation=orientation, textposition="auto"), + layout_patch=dict(barmode=barmode), + ) + + +bar.__doc__ = make_docstring(bar) + + +def histogram( + data_frame, + x=None, + y=None, + color=None, + facet_row=None, + facet_col=None, + hover_name=None, + hover_data=None, + animation_frame=None, + animation_group=None, + category_orders={}, + labels={}, + color_discrete_sequence=None, + color_discrete_map={}, + marginal=None, + opacity=None, + orientation="v", + barmode="relative", + barnorm=None, + histnorm=None, + log_x=False, + log_y=False, + range_x=None, + range_y=None, + histfunc=None, + cumulative=None, + nbins=None, + title=None, + template=None, + width=None, + height=None, +): + """ + In a histogram, rows of `data_frame` are grouped together into a rectangular mark to \ + visualize the 1D distribution of an aggregate function `histfunc` (e.g. the count or sum) \ + of the value `y` (or `x` if `orientation` is `'h'`). + """ + return make_figure( + args=locals(), + constructor=go.Histogram, + trace_patch=dict( + orientation=orientation, + histnorm=histnorm, + histfunc=histfunc, + nbinsx=nbins if orientation == "v" else None, + nbinsy=None if orientation == "v" else nbins, + cumulative=dict(enabled=cumulative), + bingroup="x" if orientation == "v" else "y", + ), + layout_patch=dict(barmode=barmode, barnorm=barnorm), + ) + + +histogram.__doc__ = make_docstring(histogram) + + +def violin( + data_frame, + x=None, + y=None, + color=None, + facet_row=None, + facet_col=None, + hover_name=None, + hover_data=None, + animation_frame=None, + animation_group=None, + category_orders={}, + labels={}, + color_discrete_sequence=None, + color_discrete_map={}, + orientation="v", + violinmode="group", + log_x=False, + log_y=False, + range_x=None, + range_y=None, + points=None, + box=False, + title=None, + template=None, + width=None, + height=None, +): + """ + In a violin plot, rows of `data_frame` are grouped together into a curved mark to \ + visualize their distribution. + """ + return make_figure( + args=locals(), + constructor=go.Violin, + trace_patch=dict( + orientation=orientation, + points=points, + box=dict(visible=box), + scalegroup=True, + x0=" ", + y0=" ", + ), + layout_patch=dict(violinmode=violinmode), + ) + + +violin.__doc__ = make_docstring(violin) + + +def box( + data_frame, + x=None, + y=None, + color=None, + facet_row=None, + facet_col=None, + hover_name=None, + hover_data=None, + animation_frame=None, + animation_group=None, + category_orders={}, + labels={}, + color_discrete_sequence=None, + color_discrete_map={}, + orientation="v", + boxmode="group", + log_x=False, + log_y=False, + range_x=None, + range_y=None, + points=None, + notched=False, + title=None, + template=None, + width=None, + height=None, +): + """ + In a box plot, rows of `data_frame` are grouped together into a box-and-whisker mark to \ + visualize their distribution. + """ + return make_figure( + args=locals(), + constructor=go.Box, + trace_patch=dict( + orientation=orientation, boxpoints=points, notched=notched, x0=" ", y0=" " + ), + layout_patch=dict(boxmode=boxmode), + ) + + +box.__doc__ = make_docstring(box) + + +def strip( + data_frame, + x=None, + y=None, + color=None, + facet_row=None, + facet_col=None, + hover_name=None, + hover_data=None, + animation_frame=None, + animation_group=None, + category_orders={}, + labels={}, + color_discrete_sequence=None, + color_discrete_map={}, + orientation="v", + stripmode="group", + log_x=False, + log_y=False, + range_x=None, + range_y=None, + title=None, + template=None, + width=None, + height=None, +): + """ + In a strip plot each row of `data_frame` is represented as a jittered mark within categories. + """ + return make_figure( + args=locals(), + constructor=go.Box, + trace_patch=dict( + orientation=orientation, + boxpoints="all", + pointpos=0, + hoveron="points", + fillcolor="rgba(255,255,255,0)", + line={"color": "rgba(255,255,255,0)"}, + x0=" ", + y0=" ", + ), + layout_patch=dict(boxmode=stripmode), + ) + + +strip.__doc__ = make_docstring(strip) + + +def scatter_3d( + data_frame, + x=None, + y=None, + z=None, + color=None, + symbol=None, + size=None, + text=None, + hover_name=None, + hover_data=None, + error_x=None, + error_x_minus=None, + error_y=None, + error_y_minus=None, + error_z=None, + error_z_minus=None, + animation_frame=None, + animation_group=None, + category_orders={}, + labels={}, + size_max=None, + color_discrete_sequence=None, + color_discrete_map={}, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + symbol_sequence=None, + symbol_map={}, + opacity=None, + log_x=False, + log_y=False, + log_z=False, + range_x=None, + range_y=None, + range_z=None, + title=None, + template=None, + width=None, + height=None, +): + """ + In a 3D scatter plot, each row of `data_frame` is represented by a symbol mark in 3D space. + """ + return make_figure(args=locals(), constructor=go.Scatter3d) + + +scatter_3d.__doc__ = make_docstring(scatter_3d) + + +def line_3d( + data_frame, + x=None, + y=None, + z=None, + color=None, + line_dash=None, + text=None, + line_group=None, + hover_name=None, + hover_data=None, + error_x=None, + error_x_minus=None, + error_y=None, + error_y_minus=None, + error_z=None, + error_z_minus=None, + animation_frame=None, + animation_group=None, + category_orders={}, + labels={}, + color_discrete_sequence=None, + color_discrete_map={}, + line_dash_sequence=None, + line_dash_map={}, + log_x=False, + log_y=False, + log_z=False, + range_x=None, + range_y=None, + range_z=None, + title=None, + template=None, + width=None, + height=None, +): + """ + In a 3D line plot, each row of `data_frame` is represented as vertex of a polyline mark in 3D space. + """ + return make_figure(args=locals(), constructor=go.Scatter3d) + + +line_3d.__doc__ = make_docstring(line_3d) + + +def scatter_ternary( + data_frame, + a=None, + b=None, + c=None, + color=None, + symbol=None, + size=None, + text=None, + hover_name=None, + hover_data=None, + animation_frame=None, + animation_group=None, + category_orders={}, + labels={}, + color_discrete_sequence=None, + color_discrete_map={}, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + symbol_sequence=None, + symbol_map={}, + opacity=None, + size_max=None, + title=None, + template=None, + width=None, + height=None, +): + """ + In a ternary scatter plot, each row of `data_frame` is represented by a symbol mark in ternary coordinates. + """ + return make_figure(args=locals(), constructor=go.Scatterternary) + + +scatter_ternary.__doc__ = make_docstring(scatter_ternary) + + +def line_ternary( + data_frame, + a=None, + b=None, + c=None, + color=None, + line_dash=None, + line_group=None, + hover_name=None, + hover_data=None, + text=None, + animation_frame=None, + animation_group=None, + category_orders={}, + labels={}, + color_discrete_sequence=None, + color_discrete_map={}, + line_dash_sequence=None, + line_dash_map={}, + line_shape=None, + title=None, + template=None, + width=None, + height=None, +): + """ + In a ternary line plot, each row of `data_frame` is represented as vertex of a polyline mark in ternary coordinates. + """ + return make_figure(args=locals(), constructor=go.Scatterternary) + + +line_ternary.__doc__ = make_docstring(line_ternary) + + +def scatter_polar( + data_frame, + r=None, + theta=None, + color=None, + symbol=None, + size=None, + hover_name=None, + hover_data=None, + text=None, + animation_frame=None, + animation_group=None, + category_orders={}, + labels={}, + color_discrete_sequence=None, + color_discrete_map={}, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + symbol_sequence=None, + symbol_map={}, + opacity=None, + direction="clockwise", + start_angle=90, + size_max=None, + range_r=None, + log_r=False, + render_mode="auto", + title=None, + template=None, + width=None, + height=None, +): + """ + In a polar scatter plot, each row of `data_frame` is represented by a symbol mark in + polar coordinates. + """ + return make_figure(args=locals(), constructor=go.Scatterpolar) + + +scatter_polar.__doc__ = make_docstring(scatter_polar) + + +def line_polar( + data_frame, + r=None, + theta=None, + color=None, + line_dash=None, + hover_name=None, + hover_data=None, + line_group=None, + text=None, + animation_frame=None, + animation_group=None, + category_orders={}, + labels={}, + color_discrete_sequence=None, + color_discrete_map={}, + line_dash_sequence=None, + line_dash_map={}, + direction="clockwise", + start_angle=90, + line_close=False, + line_shape=None, + render_mode="auto", + range_r=None, + log_r=False, + title=None, + template=None, + width=None, + height=None, +): + """ + In a polar line plot, each row of `data_frame` is represented as vertex of a polyline mark in polar coordinates. + """ + return make_figure(args=locals(), constructor=go.Scatterpolar) + + +line_polar.__doc__ = make_docstring(line_polar) + + +def bar_polar( + data_frame, + r=None, + theta=None, + color=None, + hover_name=None, + hover_data=None, + animation_frame=None, + animation_group=None, + category_orders={}, + labels={}, + color_discrete_sequence=None, + color_discrete_map={}, + barnorm="", + barmode="relative", + direction="clockwise", + start_angle=90, + range_r=None, + log_r=False, + title=None, + template=None, + width=None, + height=None, +): + """ + In a polar bar plot, each row of `data_frame` is represented as a wedge mark in polar coordinates. + """ + return make_figure( + args=locals(), + constructor=go.Barpolar, + layout_patch=dict(barnorm=barnorm, barmode=barmode), + ) + + +bar_polar.__doc__ = make_docstring(bar_polar) + + +def choropleth( + data_frame, + lat=None, + lon=None, + locations=None, + locationmode=None, + color=None, + hover_name=None, + hover_data=None, + size=None, + animation_frame=None, + animation_group=None, + category_orders={}, + labels={}, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + size_max=None, + projection=None, + scope=None, + center=None, + title=None, + template=None, + width=None, + height=None, +): + """ + In a choropleth map, each row of `data_frame` is represented by a colored region mark on a map. + """ + return make_figure( + args=locals(), + constructor=go.Choropleth, + trace_patch=dict(locationmode=locationmode), + ) + + +choropleth.__doc__ = make_docstring(choropleth) + + +def scatter_geo( + data_frame, + lat=None, + lon=None, + locations=None, + locationmode=None, + color=None, + text=None, + hover_name=None, + hover_data=None, + size=None, + animation_frame=None, + animation_group=None, + category_orders={}, + labels={}, + color_discrete_sequence=None, + color_discrete_map={}, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + opacity=None, + size_max=None, + projection=None, + scope=None, + center=None, + title=None, + template=None, + width=None, + height=None, +): + """ + In a geographic scatter plot, each row of `data_frame` is represented by a symbol mark on a map. + """ + return make_figure( + args=locals(), + constructor=go.Scattergeo, + trace_patch=dict(locationmode=locationmode), + ) + + +scatter_geo.__doc__ = make_docstring(scatter_geo) + + +def line_geo( + data_frame, + lat=None, + lon=None, + locations=None, + locationmode=None, + color=None, + line_dash=None, + text=None, + hover_name=None, + hover_data=None, + line_group=None, + animation_frame=None, + animation_group=None, + category_orders={}, + labels={}, + color_discrete_sequence=None, + color_discrete_map={}, + line_dash_sequence=None, + line_dash_map={}, + projection=None, + scope=None, + center=None, + title=None, + template=None, + width=None, + height=None, +): + """ + In a geographic line plot, each row of `data_frame` is represented as vertex of a polyline mark on a map. + """ + return make_figure( + args=locals(), + constructor=go.Scattergeo, + trace_patch=dict(locationmode=locationmode), + ) + + +line_geo.__doc__ = make_docstring(line_geo) + + +def scatter_mapbox( + data_frame, + lat=None, + lon=None, + color=None, + text=None, + hover_name=None, + hover_data=None, + size=None, + animation_frame=None, + animation_group=None, + category_orders={}, + labels={}, + color_discrete_sequence=None, + color_discrete_map={}, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + opacity=None, + size_max=None, + zoom=8, + title=None, + template=None, + width=None, + height=None, +): + """ + In a Mapbox scatter plot, each row of `data_frame` is represented by a symbol mark on a Mapbox map. + """ + return make_figure(args=locals(), constructor=go.Scattermapbox) + + +scatter_mapbox.__doc__ = make_docstring(scatter_mapbox) + + +def line_mapbox( + data_frame, + lat=None, + lon=None, + color=None, + text=None, + hover_name=None, + hover_data=None, + line_group=None, + animation_frame=None, + animation_group=None, + category_orders={}, + labels={}, + color_discrete_sequence=None, + color_discrete_map={}, + zoom=8, + title=None, + template=None, + width=None, + height=None, +): + """ + In a Mapbox line plot, each row of `data_frame` is represented as vertex of a polyline mark on a Mapbox map. + """ + return make_figure(args=locals(), constructor=go.Scattermapbox) + + +line_mapbox.__doc__ = make_docstring(line_mapbox) + + +def scatter_matrix( + data_frame, + dimensions=None, + color=None, + symbol=None, + size=None, + hover_name=None, + hover_data=None, + category_orders={}, + labels={}, + color_discrete_sequence=None, + color_discrete_map={}, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + symbol_sequence=None, + symbol_map={}, + opacity=None, + size_max=None, + title=None, + template=None, + width=None, + height=None, +): + """ + In a scatter plot matrix (or SPLOM), each row of `data_frame` is represented \ + by a multiple symbol marks, one in each cell of a grid of 2D scatter plots, which \ + plot each pair of `dimensions` against each other. + """ + return make_figure( + args=locals(), constructor=go.Splom, layout_patch=dict(dragmode="select") + ) + + +scatter_matrix.__doc__ = make_docstring(scatter_matrix) + + +def parallel_coordinates( + data_frame, + dimensions=None, + color=None, + labels={}, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + title=None, + template=None, + width=None, + height=None, +): + """ + In a parallel coordinates plot, each row of `data_frame` is represented \ + by a polyline mark which traverses a set of parallel axes, one for each of the \ + `dimensions`. + """ + return make_figure(args=locals(), constructor=go.Parcoords) + + +parallel_coordinates.__doc__ = make_docstring(parallel_coordinates) + + +def parallel_categories( + data_frame, + dimensions=None, + color=None, + labels={}, + color_continuous_scale=None, + range_color=None, + color_continuous_midpoint=None, + title=None, + template=None, + width=None, + height=None, +): + """ + In a parallel categories (or parallel sets) plot, each row of `data_frame` is \ + grouped with other rows that share the same values of `dimensions` and then plotted \ + as a polyline mark through a set of parallel axes, one for each of the `dimensions`. + """ + return make_figure(args=locals(), constructor=go.Parcats) + + +parallel_categories.__doc__ = make_docstring(parallel_categories) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py new file mode 100644 index 00000000000..0659fa0d696 --- /dev/null +++ b/packages/python/plotly/plotly/express/_core.py @@ -0,0 +1,1116 @@ +import plotly.graph_objs as go +import plotly.io as pio +from collections import namedtuple, OrderedDict +from .colors import qualitative, sequential +import math +import pandas + +from plotly.subplots import ( + make_subplots, + _set_trace_grid_reference, + _subplot_type_for_trace_type, +) + + +class PxDefaults(object): + def __init__(self): + self.template = None + self.width = None + self.height = 600 + self.color_discrete_sequence = None + self.color_continuous_scale = None + self.symbol_sequence = ["circle", "diamond", "square", "x", "cross"] + self.line_dash_sequence = ["solid", "dot", "dash", "longdash", "dashdot"] + [ + "longdashdot" + ] + self.size_max = 20 + + +defaults = PxDefaults() +del PxDefaults + +MAPBOX_TOKEN = "" + + +def set_mapbox_access_token(token): + """ + Arguments: + token: A Mapbox token to be used in `plotly_express.scatter_mapbox` and \ + `plotly_express.line_mapbox` figures. See \ + https://docs.mapbox.com/help/how-mapbox-works/access-tokens/ for more details + """ + global MAPBOX_TOKEN + MAPBOX_TOKEN = token + + +def get_trendline_results(fig): + """ + Extracts fit statistics for trendlines (when applied to figures generated with + the `trendline` argument set to `"ols"`). + + Arguments: + fig: the output of a `plotly_express` charting call + Returns: + A `pandas.DataFrame` with a column "px_fit_results" containing the `statsmodels` + results objects, along with columns identifying the subset of the data the + trendline was fit on. + """ + return fig._px_trendlines + + +Mapping = namedtuple( + "Mapping", + [ + "show_in_trace_name", + "grouper", + "val_map", + "sequence", + "updater", + "variable", + "facet", + ], +) +TraceSpec = namedtuple("TraceSpec", ["constructor", "attrs", "trace_patch", "marginal"]) + + +def get_label(args, column): + try: + return args["labels"][column] + except Exception: + return column + + +def get_decorated_label(args, column, role): + label = get_label(args, column) + if "histfunc" in args and ( + (role == "x" and "orientation" in args and args["orientation"] == "h") + or (role == "y" and "orientation" in args and args["orientation"] == "v") + or (role == "z") + ): + if label: + return "%s of %s" % (args["histfunc"] or "count", label) + else: + return "count" + else: + return label + + +def make_mapping(args, variable): + if variable == "line_group" or variable == "animation_frame": + return Mapping( + show_in_trace_name=False, + grouper=args[variable], + val_map={}, + sequence=[""], + variable=variable, + updater=(lambda trace, v: v), + facet=None, + ) + if variable == "facet_row" or variable == "facet_col": + letter = "x" if variable == "facet_col" else "y" + return Mapping( + show_in_trace_name=False, + variable=letter, + grouper=args[variable], + val_map={}, + sequence=[i for i in range(1, 1000)], + updater=(lambda trace, v: v), + facet="row" if variable == "facet_row" else "col", + ) + (parent, variable) = variable.split(".") + vprefix = variable + arg_name = variable + if variable == "color": + vprefix = "color_discrete" + if variable == "dash": + arg_name = "line_dash" + vprefix = "line_dash" + return Mapping( + show_in_trace_name=True, + variable=variable, + grouper=args[arg_name], + val_map=args[vprefix + "_map"].copy(), + sequence=args[vprefix + "_sequence"], + updater=lambda trace, v: trace.update({parent: {variable: v}}), + facet=None, + ) + + +def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref): + + if "line_close" in args and args["line_close"]: + g = g.append(g.iloc[0]) + result = trace_spec.trace_patch.copy() or {} + fit_results = None + hover_header = "" + for k in trace_spec.attrs: + v = args[k] + v_label = get_decorated_label(args, v, k) + if k == "dimensions": + dims = [ + (name, column) + for (name, column) in g.iteritems() + if ((not v) or (name in v)) + and ( + trace_spec.constructor != go.Parcoords + or args["data_frame"][name].dtype.kind in "bifc" + ) + and ( + trace_spec.constructor != go.Parcats + or len(args["data_frame"][name].unique()) <= 20 + ) + ] + result["dimensions"] = [ + dict(label=get_label(args, name), values=column.values) + for (name, column) in dims + ] + if trace_spec.constructor == go.Splom: + for d in result["dimensions"]: + d["axis"] = dict(matches=True) + mapping_labels["%{xaxis.title.text}"] = "%{x}" + mapping_labels["%{yaxis.title.text}"] = "%{y}" + + elif ( + v is not None + or (trace_spec.constructor == go.Histogram and k in ["x", "y"]) + or ( + trace_spec.constructor in [go.Histogram2d, go.Histogram2dContour] + and k == "z" + ) + ): + if k == "size": + if "marker" not in result: + result["marker"] = dict() + result["marker"]["size"] = g[v] + result["marker"]["sizemode"] = "area" + result["marker"]["sizeref"] = sizeref + mapping_labels[v_label] = "%{marker.size}" + elif k == "marginal_x": + if trace_spec.constructor == go.Histogram: + mapping_labels["count"] = "%{y}" + elif k == "marginal_y": + if trace_spec.constructor == go.Histogram: + mapping_labels["count"] = "%{x}" + elif k == "trendline": + if v in ["ols", "lowess"] and args["x"] and args["y"] and len(g) > 1: + import statsmodels.api as sm + import numpy as np + + # sorting is bad but trace_specs with "trendline" have no other attrs + g2 = g.sort_values(by=args["x"]) + y = g2[args["y"]] + x = g2[args["x"]] + result["x"] = x + + if x.dtype.type == np.datetime64: + x = x.astype(int) / 10 ** 9 # convert to unix epoch seconds + + if v == "lowess": + trendline = sm.nonparametric.lowess(y, x) + result["y"] = trendline[:, 1] + hover_header = "LOWESS trendline

" + elif v == "ols": + fit_results = sm.OLS(y, sm.add_constant(x)).fit() + result["y"] = fit_results.predict() + hover_header = "OLS trendline
" + hover_header += "%s = %f * %s + %f
" % ( + args["y"], + fit_results.params[1], + args["x"], + fit_results.params[0], + ) + hover_header += ( + "R2=%f

" % fit_results.rsquared + ) + mapping_labels[get_label(args, args["x"])] = "%{x}" + mapping_labels[get_label(args, args["y"])] = "%{y} (trend)" + + elif k.startswith("error"): + error_xy = k[:7] + arr = "arrayminus" if k.endswith("minus") else "array" + if error_xy not in result: + result[error_xy] = {} + result[error_xy][arr] = g[v] + elif k == "hover_name": + if trace_spec.constructor not in [ + go.Histogram, + go.Histogram2d, + go.Histogram2dContour, + ]: + result["hovertext"] = g[v] + if hover_header == "": + hover_header = "%{hovertext}

" + elif k == "hover_data": + if trace_spec.constructor not in [ + go.Histogram, + go.Histogram2d, + go.Histogram2dContour, + ]: + result["customdata"] = g[v].values + for i, col in enumerate(v): + v_label_col = get_decorated_label(args, col, None) + mapping_labels[v_label_col] = "%%{customdata[%d]}" % i + elif k == "color": + if trace_spec.constructor == go.Choropleth: + result["z"] = g[v] + result["coloraxis"] = "coloraxis1" + mapping_labels[v_label] = "%{z}" + else: + colorable = "marker" + if trace_spec.constructor in [go.Parcats, go.Parcoords]: + colorable = "line" + if colorable not in result: + result[colorable] = dict() + result[colorable]["color"] = g[v] + result[colorable]["coloraxis"] = "coloraxis1" + mapping_labels[v_label] = "%%{%s.color}" % colorable + elif k == "animation_group": + result["ids"] = g[v] + elif k == "locations": + result[k] = g[v] + mapping_labels[v_label] = "%{location}" + else: + if v: + result[k] = g[v] + mapping_labels[v_label] = "%%{%s}" % k + if trace_spec.constructor not in [go.Parcoords, go.Parcats]: + hover_lines = [k + "=" + v for k, v in mapping_labels.items()] + result["hovertemplate"] = hover_header + "
".join(hover_lines) + return result, fit_results + + +def configure_axes(args, constructor, fig, orders): + configurators = { + go.Scatter: configure_cartesian_axes, + go.Scattergl: configure_cartesian_axes, + go.Bar: configure_cartesian_axes, + go.Box: configure_cartesian_axes, + go.Violin: configure_cartesian_axes, + go.Histogram: configure_cartesian_axes, + go.Histogram2dContour: configure_cartesian_axes, + go.Histogram2d: configure_cartesian_axes, + go.Scatter3d: configure_3d_axes, + go.Scatterternary: configure_ternary_axes, + go.Scatterpolar: configure_polar_axes, + go.Scatterpolargl: configure_polar_axes, + go.Barpolar: configure_polar_axes, + go.Scattermapbox: configure_mapbox, + go.Scattergeo: configure_geo, + go.Choropleth: configure_geo, + } + if constructor in configurators: + configurators[constructor](args, fig, orders) + + +def set_cartesian_axis_opts(args, axis, letter, orders): + log_key = "log_" + letter + range_key = "range_" + letter + if log_key in args and args[log_key]: + axis["type"] = "log" + if range_key in args and args[range_key]: + axis["range"] = [math.log(r, 10) for r in args[range_key]] + elif range_key in args and args[range_key]: + axis["range"] = args[range_key] + + if args[letter] in orders: + axis["categoryorder"] = "array" + axis["categoryarray"] = ( + orders[args[letter]] + if isinstance(axis, go.layout.XAxis) + else list(reversed(orders[args[letter]])) + ) + + +def configure_cartesian_marginal_axes(args, fig, orders): + + if "histogram" in [args["marginal_x"], args["marginal_y"]]: + fig.layout["barmode"] = "overlay" + + nrows = len(fig._grid_ref) + ncols = len(fig._grid_ref[0]) + + # Set y-axis titles and axis options in the left-most column + for yaxis in fig.select_yaxes(col=1): + set_cartesian_axis_opts(args, yaxis, "y", orders) + + # Set x-axis titles and axis options in the bottom-most row + for xaxis in fig.select_xaxes(row=1): + set_cartesian_axis_opts(args, xaxis, "x", orders) + + # Configure axis ticks on marginal subplots + if args["marginal_x"]: + fig.update_yaxes( + showticklabels=False, + showgrid=args["marginal_x"] == "histogram", + row=nrows, + ) + fig.update_xaxes(showgrid=True, row=nrows) + + if args["marginal_y"]: + fig.update_xaxes( + showticklabels=False, + showgrid=args["marginal_y"] == "histogram", + col=ncols, + ) + fig.update_yaxes(showgrid=True, col=ncols) + + # Add axis titles to non-marginal subplots + y_title = get_decorated_label(args, args["y"], "y") + for row in range(1, nrows): + fig.update_yaxes(title_text=y_title, row=row, col=1) + + x_title = get_decorated_label(args, args["x"], "x") + for col in range(1, ncols): + fig.update_xaxes(title_text=x_title, row=1, col=col) + + # Configure axis type across all x-axes + if "log_x" in args and args["log_x"]: + fig.update_xaxes(type="log") + + # Configure axis type across all y-axes + if "log_y" in args and args["log_y"]: + fig.update_yaxes(type="log") + + # Configure matching and axis type for marginal y-axes + matches_y = "y" + str(ncols + 1) + if args["marginal_x"]: + for row in range(2, nrows + 1, 2): + fig.update_yaxes(matches=matches_y, type=None, row=row) + + if args["marginal_y"]: + for col in range(2, ncols + 1, 2): + fig.update_xaxes(matches="x2", type=None, col=col) + + +def configure_cartesian_axes(args, fig, orders): + if ("marginal_x" in args and args["marginal_x"]) or ( + "marginal_y" in args and args["marginal_y"] + ): + configure_cartesian_marginal_axes(args, fig, orders) + return + + # Set y-axis titles and axis options in the left-most column + y_title = get_decorated_label(args, args["y"], "y") + for yaxis in fig.select_yaxes(col=1): + yaxis.update(title_text=y_title) + set_cartesian_axis_opts(args, yaxis, "y", orders) + + # Set x-axis titles and axis options in the bottom-most row + x_title = get_decorated_label(args, args["x"], "x") + for xaxis in fig.select_xaxes(row=1): + xaxis.update(title_text=x_title) + set_cartesian_axis_opts(args, xaxis, "x", orders) + + # Configure axis type across all x-axes + if "log_x" in args and args["log_x"]: + fig.update_xaxes(type="log") + + # Configure axis type across all y-axes + if "log_y" in args and args["log_y"]: + fig.update_yaxes(type="log") + + return fig.layout + + +def configure_ternary_axes(args, fig, orders): + fig.update( + layout=dict( + ternary=dict( + aaxis=dict(title=get_label(args, args["a"])), + baxis=dict(title=get_label(args, args["b"])), + caxis=dict(title=get_label(args, args["c"])), + ) + ) + ) + + +def configure_polar_axes(args, fig, orders): + layout = dict( + polar=dict( + angularaxis=dict(direction=args["direction"], rotation=args["start_angle"]), + radialaxis=dict(), + ) + ) + + for var, axis in [("r", "radialaxis"), ("theta", "angularaxis")]: + if args[var] in orders: + layout["polar"][axis]["categoryorder"] = "array" + layout["polar"][axis]["categoryarray"] = orders[args[var]] + + radialaxis = layout["polar"]["radialaxis"] + if args["log_r"]: + radialaxis["type"] = "log" + if args["range_r"]: + radialaxis["range"] = [math.log(x, 10) for x in args["range_r"]] + else: + if args["range_r"]: + radialaxis["range"] = args["range_r"] + fig.update(layout=layout) + + +def configure_3d_axes(args, fig, orders): + layout = dict( + scene=dict( + xaxis=dict(title=get_label(args, args["x"])), + yaxis=dict(title=get_label(args, args["y"])), + zaxis=dict(title=get_label(args, args["z"])), + ) + ) + + for letter in ["x", "y", "z"]: + axis = layout["scene"][letter + "axis"] + if args["log_" + letter]: + axis["type"] = "log" + if args["range_" + letter]: + axis["range"] = [math.log(x, 10) for x in args["range_" + letter]] + else: + if args["range_" + letter]: + axis["range"] = args["range_" + letter] + if args[letter] in orders: + axis["categoryorder"] = "array" + axis["categoryarray"] = orders[args[letter]] + fig.update(layout=layout) + + +def configure_mapbox(args, fig, orders): + fig.update( + layout=dict( + mapbox=dict( + accesstoken=MAPBOX_TOKEN, + center=dict( + lat=args["data_frame"][args["lat"]].mean(), + lon=args["data_frame"][args["lon"]].mean(), + ), + zoom=args["zoom"], + ) + ) + ) + + +def configure_geo(args, fig, orders): + fig.update( + layout=dict( + geo=dict( + center=args["center"], + scope=args["scope"], + projection=dict(type=args["projection"]), + ) + ) + ) + + +def configure_animation_controls(args, constructor, fig): + def frame_args(duration): + return { + "frame": {"duration": duration, "redraw": constructor != go.Scatter}, + "mode": "immediate", + "fromcurrent": True, + "transition": {"duration": duration, "easing": "linear"}, + } + + if "animation_frame" in args and args["animation_frame"] and len(fig.frames) > 1: + fig.layout.updatemenus = [ + { + "buttons": [ + { + "args": [None, frame_args(500)], + "label": "▶", + "method": "animate", + }, + { + "args": [[None], frame_args(0)], + "label": "◼", + "method": "animate", + }, + ], + "direction": "left", + "pad": {"r": 10, "t": 70}, + "showactive": False, + "type": "buttons", + "x": 0.1, + "xanchor": "right", + "y": 0, + "yanchor": "top", + } + ] + fig.layout.sliders = [ + { + "active": 0, + "yanchor": "top", + "xanchor": "left", + "currentvalue": { + "prefix": get_label(args, args["animation_frame"]) + "=" + }, + "pad": {"b": 10, "t": 60}, + "len": 0.9, + "x": 0.1, + "y": 0, + "steps": [ + { + "args": [[f.name], frame_args(0)], + "label": f.name, + "method": "animate", + } + for f in fig.frames + ], + } + ] + + +def make_trace_spec(args, constructor, attrs, trace_patch): + # Create base trace specification + result = [TraceSpec(constructor, attrs, trace_patch, None)] + + # Add marginal trace specifications + for letter in ["x", "y"]: + if "marginal_" + letter in args and args["marginal_" + letter]: + trace_spec = None + axis_map = dict( + xaxis="x1" if letter == "x" else "x2", + yaxis="y1" if letter == "y" else "y2", + ) + if args["marginal_" + letter] == "histogram": + trace_spec = TraceSpec( + constructor=go.Histogram, + attrs=[letter, "marginal_" + letter], + trace_patch=dict(opacity=0.5, bingroup=letter, **axis_map), + marginal=letter, + ) + elif args["marginal_" + letter] == "violin": + trace_spec = TraceSpec( + constructor=go.Violin, + attrs=[letter, "hover_name", "hover_data"], + trace_patch=dict(scalegroup=letter), + marginal=letter, + ) + elif args["marginal_" + letter] == "box": + trace_spec = TraceSpec( + constructor=go.Box, + attrs=[letter, "hover_name", "hover_data"], + trace_patch=dict(notched=True), + marginal=letter, + ) + elif args["marginal_" + letter] == "rug": + symbols = {"x": "line-ns-open", "y": "line-ew-open"} + trace_spec = TraceSpec( + constructor=go.Box, + attrs=[letter, "hover_name", "hover_data"], + trace_patch=dict( + fillcolor="rgba(255,255,255,0)", + line={"color": "rgba(255,255,255,0)"}, + boxpoints="all", + jitter=0, + hoveron="points", + marker={"symbol": symbols[letter]}, + ), + marginal=letter, + ) + if "color" in attrs or "color" not in args: + if "marker" not in trace_spec.trace_patch: + trace_spec.trace_patch["marker"] = dict() + first_default_color = args["color_continuous_scale"][0] + trace_spec.trace_patch["marker"]["color"] = first_default_color + result.append(trace_spec) + + # Add trendline trace specifications + if "trendline" in args and args["trendline"]: + trace_spec = TraceSpec( + constructor=go.Scatter, + attrs=["trendline"], + trace_patch=dict(mode="lines"), + marginal=None, + ) + if args["trendline_color_override"]: + trace_spec.trace_patch["line"] = dict( + color=args["trendline_color_override"] + ) + result.append(trace_spec) + return result + + +def one_group(x): + return "" + + +def apply_default_cascade(args): + # first we apply px.defaults to unspecified args + for param in ( + ["color_discrete_sequence", "color_continuous_scale"] + + ["symbol_sequence", "line_dash_sequence", "template"] + + ["width", "height", "size_max"] + ): + if param in args and args[param] is None: + args[param] = getattr(defaults, param) + + # load the default template if set, otherwise "plotly" + if args["template"] is None: + if pio.templates.default is not None: + args["template"] = pio.templates.default + else: + args["template"] = "plotly" + + # retrieve the actual template if we were given a name + try: + template = pio.templates[args["template"]] + except Exception: + template = args["template"] + + # if colors not set explicitly or in px.defaults, defer to a template + # if the template doesn't have one, we set some final fallback defaults + if "color_continuous_scale" in args: + if args["color_continuous_scale"] is None: + try: + args["color_continuous_scale"] = [ + x[1] for x in template.layout.colorscale.sequential + ] + except AttributeError: + pass + if args["color_continuous_scale"] is None: + args["color_continuous_scale"] = sequential.Plasma + + if "color_discrete_sequence" in args: + if args["color_discrete_sequence"] is None: + try: + args["color_discrete_sequence"] = template.layout.colorway + except AttributeError: + pass + if args["color_discrete_sequence"] is None: + args["color_discrete_sequence"] = qualitative.Plotly + + # If both marginals and faceting are specified, faceting wins + if args.get('facet_col', None) and args.get('marginal_y', None): + args['marginal_y'] = None + + if args.get('facet_row', None) and args.get('marginal_x', None): + args['marginal_x'] = None + + +def infer_config(args, constructor, trace_patch): + # Declare all supported attributes, across all plot types + attrables = ( + ["x", "y", "z", "a", "b", "c", "r", "theta", "size"] + + ["dimensions", "hover_name", "hover_data", "text", "error_x", "error_x_minus"] + + ["error_y", "error_y_minus", "error_z", "error_z_minus"] + + ["lat", "lon", "locations", "animation_group"] + ) + array_attrables = ["dimensions", "hover_data"] + group_attrables = ["animation_frame", "facet_row", "facet_col", "line_group"] + + # Validate that the strings provided as attribute values reference columns + # in the provided data_frame + df_columns = args["data_frame"].columns + + for attr in attrables + group_attrables + ["color"]: + if attr in args and args[attr] is not None: + maybe_col_list = [args[attr]] if attr not in array_attrables else args[attr] + for maybe_col in maybe_col_list: + try: + in_cols = maybe_col in df_columns + except TypeError: + in_cols = False + if not in_cols: + value_str = ( + "Element of value" if attr in array_attrables else "Value" + ) + raise ValueError( + "%s of '%s' is not the name of a column in 'data_frame'. " + "Expected one of %s but received: %s" + % (value_str, attr, str(list(df_columns)), str(maybe_col)) + ) + + attrs = [k for k in attrables if k in args] + grouped_attrs = [] + + # Compute sizeref + sizeref = 0 + if "size" in args and args["size"]: + sizeref = args["data_frame"][args["size"]].max() / args["size_max"] ** 2 + + # Compute color attributes and grouping attributes + if "color" in args: + if "color_continuous_scale" in args: + if "color_discrete_sequence" not in args: + attrs.append("color") + else: + if ( + args["color"] + and args["data_frame"][args["color"]].dtype.kind in "bifc" + ): + attrs.append("color") + else: + grouped_attrs.append("marker.color") + elif "line_group" in args or constructor == go.Histogram2dContour: + grouped_attrs.append("line.color") + else: + grouped_attrs.append("marker.color") + + show_colorbar = bool("color" in attrs and args["color"]) + + # Compute line_dash grouping attribute + if "line_dash" in args: + grouped_attrs.append("line.dash") + + # Compute symbol grouping attribute + if "symbol" in args: + grouped_attrs.append("marker.symbol") + + # Compute final trace patch + trace_patch = trace_patch.copy() + + if constructor == go.Histogram2d: + show_colorbar = True + trace_patch["coloraxis"] = "coloraxis1" + + if "opacity" in args: + if args["opacity"] is None: + if "barmode" in args and args["barmode"] == "overlay": + trace_patch["marker"] = dict(opacity=0.5) + else: + trace_patch["marker"] = dict(opacity=args["opacity"]) + if "line_group" in args: + trace_patch["mode"] = "lines" + ("+markers+text" if args["text"] else "") + elif constructor != go.Splom and ( + "symbol" in args or constructor == go.Scattermapbox + ): + trace_patch["mode"] = "markers" + ("+text" if args["text"] else "") + + if "line_shape" in args: + trace_patch["line"] = dict(shape=args["line_shape"]) + + # Compute marginal attribute + if "marginal" in args: + position = "marginal_x" if args["orientation"] == "v" else "marginal_y" + other_position = "marginal_x" if args["orientation"] == "h" else "marginal_y" + args[position] = args["marginal"] + args[other_position] = None + + # Compute applicable grouping attributes + for k in group_attrables: + if k in args: + grouped_attrs.append(k) + + # Create grouped mappings + grouped_mappings = [make_mapping(args, a) for a in grouped_attrs] + + # Create trace specs + trace_specs = make_trace_spec(args, constructor, attrs, trace_patch) + return trace_specs, grouped_mappings, sizeref, show_colorbar + + +def get_orderings(args, grouper, grouped): + """ + `orders` is the user-supplied ordering (with the remaining data-frame-supplied + ordering appended if the column is used for grouping) + `group_names` is the set of groups, ordered by the order above + """ + orders = {} if "category_orders" not in args else args["category_orders"].copy() + group_names = [] + for group_name in grouped.groups: + if len(grouper) == 1: + group_name = (group_name,) + group_names.append(group_name) + for col in grouper: + if col != one_group: + uniques = args["data_frame"][col].unique() + if col not in orders: + orders[col] = list(uniques) + else: + for val in uniques: + if val not in orders[col]: + orders[col].append(val) + + for i, col in reversed(list(enumerate(grouper))): + if col != one_group: + group_names = sorted( + group_names, + key=lambda g: orders[col].index(g[i]) if g[i] in orders[col] else -1, + ) + + return orders, group_names + + +def make_figure(args, constructor, trace_patch={}, layout_patch={}): + apply_default_cascade(args) + + trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config( + args, constructor, trace_patch + ) + grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group] + grouped = args["data_frame"].groupby(grouper, sort=False) + + orders, sorted_group_names = get_orderings(args, grouper, grouped) + + has_marginal_x = bool(args.get("marginal_x", False)) + has_marginal_y = bool(args.get("marginal_y", False)) + + subplot_type = _subplot_type_for_trace_type(constructor().type) + + trace_names_by_frame = {} + frames = OrderedDict() + trendline_rows = [] + nrows = ncols = 1 + for group_name in sorted_group_names: + group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0]) + mapping_labels = OrderedDict() + trace_name_labels = OrderedDict() + frame_name = "" + for col, val, m in zip(grouper, group_name, grouped_mappings): + if col != one_group: + key = get_label(args, col) + mapping_labels[key] = str(val) + if m.show_in_trace_name: + trace_name_labels[key] = str(val) + if m.variable == "animation_frame": + frame_name = val + trace_name = ", ".join(k + "=" + v for k, v in trace_name_labels.items()) + if frame_name not in trace_names_by_frame: + trace_names_by_frame[frame_name] = set() + trace_names = trace_names_by_frame[frame_name] + + for trace_spec in trace_specs: + constructor_to_use = trace_spec.constructor + if constructor_to_use in [go.Scatter, go.Scatterpolar]: + if "render_mode" in args and ( + args["render_mode"] == "webgl" + or ( + args["render_mode"] == "auto" + and len(args["data_frame"]) > 1000 + and args["animation_frame"] is None + ) + ): + constructor_to_use = ( + go.Scattergl + if constructor_to_use == go.Scatter + else go.Scatterpolargl + ) + trace = constructor_to_use(name=trace_name) + if trace_spec.constructor not in [ + go.Parcats, + go.Parcoords, + go.Choropleth, + go.Histogram2d, + ]: + trace.update( + legendgroup=trace_name, + showlegend=(trace_name != "" and trace_name not in trace_names), + ) + if trace_spec.constructor in [go.Bar, go.Violin, go.Box, go.Histogram]: + trace.update(alignmentgroup=True, offsetgroup=trace_name) + if trace_spec.constructor not in [go.Parcats, go.Parcoords]: + trace.update(hoverlabel=dict(namelength=0)) + trace_names.add(trace_name) + + # Init subplot row/col + trace._subplot_row = 1 + trace._subplot_col = 1 + + for i, m in enumerate(grouped_mappings): + val = group_name[i] + if val not in m.val_map: + m.val_map[val] = m.sequence[len(m.val_map) % len(m.sequence)] + try: + m.updater(trace, m.val_map[val]) + except ValueError: + if ( + trace_spec != trace_specs[0] + and trace_spec.constructor in [go.Violin, go.Box, go.Histogram] + and m.variable == "symbol" + ): + pass + elif ( + trace_spec != trace_specs[0] + and trace_spec.constructor in [go.Histogram] + and m.variable == "color" + ): + trace.update(marker=dict(color=m.val_map[val])) + else: + raise + + # Find row for trace, handling facet_row and marginal_x + if m.facet == "row": + row = m.val_map[val] + trace._subplot_row_val = val + else: + if trace_spec.marginal == "x": + row = 2 + else: + row = 1 + + nrows = max(nrows, row) + if row > 1: + trace._subplot_row = row + + # Find col for trace, handling facet_col and marginal_y + if m.facet == "col": + col = m.val_map[val] + trace._subplot_col_val = val + else: + if trace_spec.marginal == "y": + col = 2 + else: + col = 1 + + ncols = max(ncols, col) + if col > 1: + trace._subplot_col = col + if ( + trace_specs[0].constructor == go.Histogram2dContour + and trace_spec.constructor == go.Box + and trace.line.color + ): + trace.update(marker=dict(color=trace.line.color)) + + patch, fit_results = make_trace_kwargs( + args, trace_spec, group, mapping_labels.copy(), sizeref + ) + trace.update(patch) + if fit_results is not None: + trendline_rows.append(mapping_labels.copy()) + trendline_rows[-1]["px_fit_results"] = fit_results + if frame_name not in frames: + frames[frame_name] = dict(data=[], name=frame_name) + frames[frame_name]["data"].append(trace) + frame_list = [f for f in frames.values()] + if len(frame_list) > 1: + frame_list = sorted( + frame_list, key=lambda f: orders[args["animation_frame"]].index(f["name"]) + ) + layout_patch = layout_patch.copy() + if show_colorbar: + colorvar = "z" if constructor == go.Histogram2d else "color" + range_color = args["range_color"] or [None, None] + d = len(args["color_continuous_scale"]) - 1 + layout_patch["coloraxis1"] = dict( + colorscale=[ + [(1.0 * i) / (1.0 * d), x] + for i, x in enumerate(args["color_continuous_scale"]) + ], + cmid=args["color_continuous_midpoint"], + cmin=range_color[0], + cmax=range_color[1], + colorbar=dict(title=get_decorated_label(args, args[colorvar], colorvar)), + ) + for v in ["title", "height", "width", "template"]: + if args[v]: + layout_patch[v] = args[v] + layout_patch["legend"] = {"tracegroupgap": 0} + if "title" not in layout_patch: + layout_patch["margin"] = {"t": 60} + if "size" in args and args["size"]: + layout_patch["legend"]["itemsizing"] = "constant" + + fig = init_figure( + args, subplot_type, frame_list, ncols, nrows, has_marginal_x, has_marginal_y + ) + + # Position traces in subplots + for frame in frame_list: + for trace in frame["data"]: + if isinstance(trace, go.Splom): + # Special case that is not compatible with make_subplots + continue + + _set_trace_grid_reference( + trace, + fig.layout, + fig._grid_ref, + trace._subplot_row, + trace._subplot_col, + ) + + # Add traces, layout and frames to figure + fig.add_traces(frame_list[0]["data"] if len(frame_list) > 0 else []) + fig.layout.update(layout_patch) + fig.frames = frame_list if len(frames) > 1 else [] + + fig._px_trendlines = pandas.DataFrame(trendline_rows) + + configure_axes(args, constructor, fig, orders) + configure_animation_controls(args, constructor, fig) + return fig + + +def init_figure( + args, subplot_type, frame_list, ncols, nrows, has_marginal_x, has_marginal_y +): + # Build subplot specs + specs = [[{}] * ncols for _ in range(nrows)] + column_titles = [None] * ncols + row_titles = [None] * nrows + for frame in frame_list: + for trace in frame["data"]: + row0 = nrows - trace._subplot_row + col0 = trace._subplot_col - 1 + + if isinstance(trace, go.Splom): + # Splom not compatible with make_subplots, treat as domain + specs[row0][col0] = {"type": "domain"} + else: + specs[row0][col0] = {"type": trace.type} + if args.get("facet_row", None) and hasattr(trace, "_subplot_row_val"): + row_titles[row0] = ( + args["facet_row"] + "=" + str(trace._subplot_row_val) + ) + + if args.get("facet_col", None) and hasattr(trace, "_subplot_col_val"): + column_titles[col0] = ( + args["facet_col"] + "=" + str(trace._subplot_col_val) + ) + + # Default row/column widths uniform + column_widths = [1.0] * ncols + row_heights = [1.0] * nrows + + # Build column_widths/row_heights + if subplot_type == "xy": + if has_marginal_x: + if args["marginal_x"] == "histogram" or ("color" in args and args["color"]): + main_size = 0.74 + else: + main_size = 0.84 + + row_heights = [main_size] * (nrows - 1) + [1 - main_size] + vertical_spacing = 0.01 + else: + vertical_spacing = 0.03 + + if has_marginal_y: + if args["marginal_y"] == "histogram" or ("color" in args and args["color"]): + main_size = 0.74 + else: + main_size = 0.84 + + column_widths = [main_size] * (ncols - 1) + [1 - main_size] + horizontal_spacing = 0.005 + else: + horizontal_spacing = 0.02 + else: + # Other subplot types: + # 'scene', 'geo', 'polar', 'ternary', 'mapbox', 'domain', None + # + # We can customize subplot spacing per type once we enable faceting + # for all plot types + vertical_spacing = 0.1 + horizontal_spacing = 0.1 + + # Create figure with subplots + fig = make_subplots( + rows=nrows, + cols=ncols, + specs=specs, + shared_xaxes="all", + shared_yaxes="all", + row_titles=row_titles, + column_titles=column_titles, + horizontal_spacing=horizontal_spacing, + vertical_spacing=vertical_spacing, + row_heights=row_heights, + column_widths=column_widths, + start_cell="bottom-left", + ) + + # Remove explicit font size of row/col titles so template can take over + for annot in fig.layout.annotations: + annot.update(font=None) + + return fig diff --git a/packages/python/plotly/plotly/express/_doc.py b/packages/python/plotly/plotly/express/_doc.py new file mode 100644 index 00000000000..4b8a5e949d9 --- /dev/null +++ b/packages/python/plotly/plotly/express/_doc.py @@ -0,0 +1,362 @@ +import inspect + +colref = "(string: name of column in `data_frame`)" +colref_list = "(list of string: names of columns in `data_frame`)" + +# TODO contents of columns +# TODO explain categorical +# TODO handle color +# TODO handle details of box/violin/histogram +# TODO handle details of column selection with `dimensions` +# TODO document "or `None`, default `None`" in various places +# TODO standardize positioning and casing of 'default' + +docs = dict( + data_frame=["A 'tidy' `pandas.DataFrame`"], + x=[ + colref, + "Values from this column are used to position marks along the x axis in cartesian coordinates.", + "For horizontal `histogram`s, these values are used as inputs to `histfunc`.", + ], + y=[ + colref, + "Values from this column are used to position marks along the y axis in cartesian coordinates.", + "For vertical `histogram`s, these values are used as inputs to `histfunc`.", + ], + z=[ + colref, + "Values from this column are used to position marks along the z axis in cartesian coordinates.", + "For `density_heatmap` and `density_contour` these values are used as the inputs to `histfunc`.", + ], + a=[ + colref, + "Values from this column are used to position marks along the a axis in ternary coordinates.", + ], + b=[ + colref, + "Values from this column are used to position marks along the b axis in ternary coordinates.", + ], + c=[ + colref, + "Values from this column are used to position marks along the c axis in ternary coordinates.", + ], + r=[ + colref, + "Values from this column are used to position marks along the radial axis in polar coordinates.", + ], + theta=[ + colref, + "Values from this column are used to position marks along the angular axis in polar coordinates.", + ], + lat=[ + colref, + "Values from this column are used to position marks according to latitude on a map.", + ], + lon=[ + colref, + "Values from this column are used to position marks according to longitude on a map.", + ], + locations=[ + colref, + "Values from this column are be interpreted according to `locationmode` and mapped to longitude/latitude.", + ], + dimensions=[ + "(list of strings, names of columns in `data_frame`)", + "Columns to be used in multidimensional visualization.", + ], + error_x=[ + colref, + "Values from this column are used to size x-axis error bars.", + "If `error_x_minus` is `None`, error bars will be symmetrical, otherwise `error_x` is used for the positive direction only.", + ], + error_x_minus=[ + colref, + "Values from this column are used to size x-axis error bars in the negative direction.", + "Ignored if `error_x` is `None`.", + ], + error_y=[ + colref, + "Values from this column are used to size y-axis error bars.", + "If `error_y_minus` is `None`, error bars will be symmetrical, otherwise `error_y` is used for the positive direction only.", + ], + error_y_minus=[ + colref, + "Values from this column are used to size y-axis error bars in the negative direction.", + "Ignored if `error_y` is `None`.", + ], + error_z=[ + colref, + "Values from this column are used to size z-axis error bars.", + "If `error_z_minus` is `None`, error bars will be symmetrical, otherwise `error_z` is used for the positive direction only.", + ], + error_z_minus=[ + colref, + "Values from this column are used to size z-axis error bars in the negative direction.", + "Ignored if `error_z` is `None`.", + ], + color=[colref, "Values from this column are used to assign color to marks."], + opacity=["(number, between 0 and 1) Sets the opacity for markers."], + line_dash=[ + colref, + "Values from this column are used to assign dash-patterns to lines.", + ], + line_group=[ + colref, + "Values from this column are used to group rows of `data_frame` into lines.", + ], + symbol=[colref, "Values from this column are used to assign symbols to marks."], + size=[colref, "Values from this column are used to assign mark sizes."], + hover_name=[colref, "Values from this column appear in bold in the hover tooltip."], + hover_data=[ + colref_list, + "Values from these columns appear as extra data in the hover tooltip.", + ], + text=[colref, "Values from this column appear in the figure as text labels."], + locationmode=[ + "(string, one of 'ISO-3', 'USA-states', 'country names')", + "Determines the set of locations used to match entries in `locations` to regions on the map.", + ], + facet_row=[ + colref, + "Values from this column are used to assign marks to facetted subplots in the vertical direction.", + ], + facet_col=[ + colref, + "Values from this column are used to assign marks to facetted subplots in the horizontal direction.", + ], + animation_frame=[ + colref, + "Values from this column are used to assign marks to animation frames.", + ], + animation_group=[ + colref, + "Values from this column are used to provide object-constancy across animation frames: rows with matching `animation_group`s will be treated as if they describe the same object in each frame.", + ], + symbol_sequence=[ + "(list of strings defining plotly.js symbols)", + "When `symbol` is set, values in that column are assigned symbols by cycling through `symbol_sequence` in the order described in `category_orders`, unless the value of `symbol` is a key in `symbol_map`.", + ], + symbol_map=[ + "(dict with string keys and values that are strings defining plotly.js symbols, default `{}`)", + "Used to override `symbol_sequence` to assign a specific symbols to marks corresponding with specific values.", + "Keys in `symbol_map` should be values in the column denoted by `symbol`.", + ], + line_dash_map=[ + "(dict with string keys and values that are strings defining plotly.js dash-patterns, default `{}`)" + "Used to override `line_dash_sequences` to assign a specific dash-patterns to lines corresponding with specific values.", + "Keys in `line_dash_map` should be values in the column denoted by `line_dash`.", + ], + line_dash_sequence=[ + "(list of strings defining plotly.js dash-patterns)", + "When `line_dash` is set, values in that column are assigned dash-patterns by cycling through `line_dash_sequence` in the order described in `category_orders`, unless the value of `line_dash` is a key in `line_dash_map`.", + ], + color_discrete_sequence=[ + "(list of valid CSS-color strings)", + "When `color` is set and the values in the corresponding column are not numeric, values in that column are assigned colors by cycling through `color_discrete_sequence` in the order described in `category_orders`, unless the value of `color` is a key in `color_discrete_map`.", + "Various useful color sequences are available in the `plotly_express.colors` submodules, specifically `plotly_express.colors.qualitative`.", + ], + color_discrete_map=[ + "(dict with string keys and values that are valid CSS-color strings, default `{}`)", + "Used to override `color_discrete_sequence` to assign a specific colors to marks corresponding with specific values.", + "Keys in `color_discrete_map` should be values in the column denoted by `color`.", + ], + color_continuous_scale=[ + "(list of valid CSS-color strings)", + "This list is used to build a continuous color scale when the column denoted by `color` contains numeric data.", + "Various useful color scales are available in the `plotly_express.colors` submodules, specifically `plotly_express.colors.sequential`, `plotly_express.colors.diverging` and `plotly_express.colors.cyclical`.", + ], + color_continuous_midpoint=[ + "(number, defaults to `None`)", + "If set, computes the bounds of the continuous color scale to have the desired midpoint.", + "Setting this value is recommended when using `plotly_express.colors.diverging` color scales as the inputs to `color_continuous_scale`.", + ], + size_max=["(integer, default 20)", "Set the maximum mark size when using `size`."], + log_x=[ + "(boolean, default `False`)", + "If `True`, the x-axis is log-scaled in cartesian coordinates.", + ], + log_y=[ + "(boolean, default `False`)", + "If `True`, the y-axis is log-scaled in cartesian coordinates.", + ], + log_z=[ + "(boolean, default `False`)", + "If `True`, the z-axis is log-scaled in cartesian coordinates.", + ], + log_r=[ + "(boolean, default `False`)", + "If `True`, the radial axis is log-scaled in polar coordinates.", + ], + range_x=[ + "(2-element list of numbers)", + "If provided, overrides auto-scaling on the x-axis in cartesian coordinates.", + ], + range_y=[ + "(2-element list of numbers)", + "If provided, overrides auto-scaling on the y-axis in cartesian coordinates.", + ], + range_z=[ + "(2-element list of numbers)", + "If provided, overrides auto-scaling on the z-axis in cartesian coordinates.", + ], + range_color=[ + "(2-element list of numbers)", + "If provided, overrides auto-scaling on the continuous color scale.", + ], + range_r=[ + "(2-element list of numbers)", + "If provided, overrides auto-scaling on the radial axis in polar coordinates.", + ], + title=["(string)", "The figure title."], + template=[ + "(string or Plotly.py template object)", + "The figure template name or definition.", + ], + width=["(integer, default `None`)", "The figure width in pixels."], + height=["(integer, default `600`)", "The figure height in pixels."], + labels=[ + "(dict with string keys and string values, default `{}`)", + "By default, column names are used in the figure for axis titles, legend entries and hovers.", + "This parameter allows this to be overridden.", + "The keys of this dict should correspond to column names, and the values should correspond to the desired label to be displayed.", + ], + category_orders=[ + "(dict with string keys and list-of-string values, default `{}`)", + "By default, in Python 3.6+, the order of categorical values in axes, legends and facets depends on the order in which these values are first encountered in `data_frame` (and no order is guaranteed by default in Python below 3.6).", + "This parameter is used to force a specific ordering of values per column.", + "The keys of this dict should correspond to column names, and the values should be lists of strings corresponding to the specific display order desired.", + ], + marginal=[ + "(string, one of `'rug'`, `'box'`, `'violin'`, `'histogram'`)", + "If set, a subplot is drawn alongside the main plot, visulizing the distribution.", + ], + marginal_x=[ + "(string, one of `'rug'`, `'box'`, `'violin'`, `'histogram'`)", + "If set, a horizontal subplot is drawn above the main plot, visulizing the x-distribution.", + ], + marginal_y=[ + "(string, one of `'rug'`, `'box'`, `'violin'`, `'histogram'`)", + "If set, a vertical subplot is drawn to the right of the main plot, visulizing the y-distribution.", + ], + trendline=[ + "(string, one of `'ols'` or `'lowess'`, default `None`)", + "If `'ols'`, an Ordinary Least Squares regression line will be drawn for each discrete-color/symbol group.", + "If `'lowess`', a Locally Weighted Scatterplot Smoothing line will be drawn for each discrete-color/symbol group.", + ], + trendline_color_override=[ + "(string, valid CSS color)", + "If provided, and if `trendline` is set, all trendlines will be drawn in this color.", + ], + render_mode=[ + "(string, one of `'auto'`, `'svg'` or `'webgl'`, default `'auto'`)", + "Controls the browser API used to draw marks.", + "`'svg`' is appropriate for figures of less than 1000 data points, and will allow for fully-vectorized output.", + "`'webgl'` is likely necessary for acceptable performance above 1000 points but rasterizes part of the output. ", + "`'auto'` uses heuristics to choose the mode.", + ], + direction=[ + "(string, one of '`counterclockwise'`, `'clockwise'`. Default is `'clockwise'`)", + "Sets the direction in which increasing values of the angular axis are drawn.", + ], + start_angle=[ + "(integer, default is 90)", + "Sets start angle for the angular axis, with 0 being due east and 90 being due north.", + ], + histfunc=[ + "(string, one of `'count'`, `'sum'`, `'avg'`, `'min'`, `'max'`. Default is `'count'`)" + "Function used to aggregate values for summarization (note: can be normalized with `histnorm`).", + "The arguments to this function for `histogram` are the values of `y` if `orientation` is `'v'`,", + "otherwise the arguements are the values of `x`.", + "The arguments to this function for `density_heatmap` and `density_contour` are the values of `z`.", + ], + histnorm=[ + "(string, one of `'percent'`, `'probability'`, `'density'`, `'probability density'`, default `None`)", + "If `None`, the output of `histfunc` is used as is.", + "If `'probability'`, the output of `histfunc` for a given bin is divided by the sum of the output of `histfunc` for all bins.", + "If `'percent'`, the output of `histfunc` for a given bin is divided by the sum of the output of `histfunc` for all bins and multiplied by 100.", + "If `'density'`, the output of `histfunc` for a given bin is divided by the size of the bin.", + "If `'probability density'`, the output of `histfunc` for a given bin is normalized such that it corresponds to the probability that a random event whose distribution is described by the output of `histfunc` will fall into that bin.", + ], + barnorm=[ + "(string, one of `'fraction'` or `'percent'`, default is `None`)", + "If set to `'fraction'`, the value of each bar is divided by the sum of all values at that location coordinate.", + "`'percent'` is the same but multiplied by 100 to show percentages.", + ], + groupnorm=[ + "(string, one of `'fraction'` or `'percent'`, default is `None`)", + "If set to `'fraction'`, the value of each point is divided by the sum of all values at that location coordinate.", + "`'percent'` is the same but multiplied by 100 to show percentages.", + ], + barmode=[ + "(string, one of `'group'`, `'overlay'` or `'relative'`. Default is `'relative'`)", + "In `'relative'` mode, bars are stacked above zero for positive values and below zero for negative values.", + "In `'overlay'` mode, bars are on drawn top of one another.", + "In `'group'` mode, bars are placed beside each other.", + ], + boxmode=[ + "(string, one of `'group'` or `'overlay'`. Default is `'group'`)", + "In `'overlay'` mode, boxes are on drawn top of one another.", + "In `'group'` mode, baxes are placed beside each other.", + ], + violinmode=[ + "(string, one of `'group'` or `'overlay'`. Default is `'group'`)", + "In `'overlay'` mode, violins are on drawn top of one another.", + "In `'group'` mode, violins are placed beside each other.", + ], + stripmode=[ + "(string, one of `'group'` or `'overlay'`. Default is `'group'`)", + "In `'overlay'` mode, strips are on drawn top of one another.", + "In `'group'` mode, strips are placed beside each other.", + ], + zoom=["(integer between 0 and 20, default is 8)", "Sets map zoom level."], + orientation=[ + "(string, one of `'h'` for horizontal or `'v'` for vertical)", + "Default is `'v'`.", + ], + line_close=[ + "(boolean, default `False`)", + "If `True`, an extra line segment is drawn between the first and last point.", + ], + line_shape=["(string, one of `'linear'` or `'spline'`)", "Default is `'linear'`."], + scope=[ + "(string, one of `'world'`, `'usa'`, `'europe'`, `'asia'`, `'africa'`, `'north america'`, `'south america'`)" + "Default is `'world'` unless `projection` is set to `'albers usa'`, which forces `'usa'`." + ], + projection=[ + "(string, one of `'equirectangular'`, `'mercator'`, `'orthographic'`, `'natural earth'`, `'kavrayskiy7'`, `'miller'`, `'robinson'`, `'eckert4'`, `'azimuthal equal area'`, `'azimuthal equidistant'`, `'conic equal area'`, `'conic conformal'`, `'conic equidistant'`, `'gnomonic'`, `'stereographic'`, `'mollweide'`, `'hammer'`, `'transverse mercator'`, `'albers usa'`, `'winkel tripel'`, `'aitoff'`, `'sinusoidal'`)" + "Default depends on `scope`." + ], + center=["(dict with `lat` and `lon` keys)", "Sets the center point of the map."], + points=[ + "(string or boolean, one of `'all'`, `'outliers'`, or `False`. Default is `'outliers'`)", + "If `'outliers'`, only the sample points lying outside the whiskers are shown.", + "If `'all'`, all sample points are shown.", + "If `False`, no sample points are shown", + ], + box=[ + "(boolean, default `False`)", + "If `True`, boxes are drawn inside the violins.", + ], + notched=["(boolean, default `False`)", "If `True`, boxes are drawn with notches."], + cumulative=[ + "(boolean, default `False`)", + "If `True`, histogram values are cumulative.", + ], + nbins=["(positive integer)", "Sets the number of bins."], + nbinsx=["(positive integer)", "Sets the number of bins along the x axis."], + nbinsy=["(positive integer)", "Sets the number of bins along the y axis."], +) + + +def make_docstring(fn): + result = (fn.__doc__ or "") + "\nArguments:\n" + for arg in inspect.getargspec(fn)[0]: + d = ( + " ".join(docs[arg] or "") + if arg in docs + else "(documentation missing from map)" + ) + result += " %s: %s\n" % (arg, d) + result += "Returns:\n" + result += " A `Figure` object." + return result diff --git a/packages/python/plotly/plotly/express/colors.py b/packages/python/plotly/plotly/express/colors.py new file mode 100644 index 00000000000..134f1707e9a --- /dev/null +++ b/packages/python/plotly/plotly/express/colors.py @@ -0,0 +1,2 @@ +from __future__ import absolute_import +from plotly.colors import * diff --git a/packages/python/plotly/plotly/express/data.py b/packages/python/plotly/plotly/express/data.py new file mode 100644 index 00000000000..375de95c4fd --- /dev/null +++ b/packages/python/plotly/plotly/express/data.py @@ -0,0 +1,2 @@ +from __future__ import absolute_import +from plotly.data import * diff --git a/packages/python/plotly/setup.py b/packages/python/plotly/setup.py index 69048c63071..271e7d6035c 100644 --- a/packages/python/plotly/setup.py +++ b/packages/python/plotly/setup.py @@ -431,6 +431,9 @@ def run(self): 'plotly.matplotlylib.mplexporter', 'plotly.matplotlylib.mplexporter.renderers', 'plotly.figure_factory', + 'plotly.data', + 'plotly.colors', + 'plotly.express', '_plotly_utils', '_plotly_future_', ] + graph_objs_packages + validator_packages,