Skip to content

Commit

Permalink
MAINT: theme_seaborn
Browse files Browse the repository at this point in the history
  • Loading branch information
has2k1 committed Mar 25, 2022
1 parent dddab55 commit c87a2df
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 44 deletions.
Binary file modified plotnine/tests/baseline_images/test_theme/theme_seaborn.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
84 changes: 43 additions & 41 deletions plotnine/themes/seaborn_rcmod.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@
import matplotlib as _mpl
import functools

# https://github.com/mwaskom/seaborn/seaborn/rcmod.py
# commit: d19fff8
# https://github.com/mwaskom/seaborn/blob/master/seaborn/rcmod.py
# License: BSD-3-Clause License
#
# Modifications
# ---------------
# modified set()
# modified set_theme()
# removed set_palette(), reset_defaults(), reset_orig()
# set mpl_ge_150, mpl_ge_2 for MPL > 3
#
# We (plotnine) do not want to modify the rcParams
# on the matplotlib instance, so we create a dummy object
Expand All @@ -24,9 +23,7 @@ class dummy:

mpl = dummy()
mpl.__version__ = _mpl.__version__

mpl_ge_150 = False
mpl_ge_2 = True
mpl.rcParams = {}


_style_keys = [
Expand All @@ -51,30 +48,23 @@ class dummy:
"lines.solid_capstyle",

"patch.edgecolor",
"patch.force_edgecolor",

"image.cmap",
"font.family",
"font.sans-serif",

]

if mpl_ge_2:

_style_keys.extend([

"patch.force_edgecolor",
"xtick.bottom",
"xtick.top",
"ytick.left",
"ytick.right",

"xtick.bottom",
"xtick.top",
"ytick.left",
"ytick.right",
"axes.spines.left",
"axes.spines.bottom",
"axes.spines.right",
"axes.spines.top",

"axes.spines.left",
"axes.spines.bottom",
"axes.spines.right",
"axes.spines.top",

])
]

_context_keys = [

Expand All @@ -84,6 +74,7 @@ class dummy:
"xtick.labelsize",
"ytick.labelsize",
"legend.fontsize",
"legend.title_fontsize",

"axes.linewidth",
"grid.linewidth",
Expand All @@ -101,11 +92,11 @@ class dummy:
"xtick.minor.size",
"ytick.minor.size",

]
]


def set(context="notebook", style="darkgrid", palette="deep",
font="sans-serif", font_scale=1, color_codes=False, rc=None):
def set_theme(context="notebook", style="darkgrid", palette="deep",
font="sans-serif", font_scale=1, color_codes=False, rc=None):
"""Set aesthetic parameters in one step.
Each set of parameters can be set directly or temporarily, see the
referenced functions below for more information.
Expand All @@ -130,14 +121,21 @@ def set(context="notebook", style="darkgrid", palette="deep",
Dictionary of rc parameter mappings to override the above.
"""
mpl.rcParams = {}
set_context(context, font_scale)
set_style(style, rc={"font.family": font})
if rc is not None:
mpl.rcParams.update(rc)
return mpl.rcParams


def set(*args, **kwargs):
"""
Alias for :func:`set_theme`, which is the preferred interface.
This function may be removed in the future.
"""
set_theme(*args, **kwargs)


def axes_style(style=None, rc=None):
"""Return a parameter dict for the aesthetic style of the plots.
Expand Down Expand Up @@ -219,17 +217,17 @@ def axes_style(style=None, rc=None):
"xtick.top": False,
"ytick.right": False,

}
}

# Set grid on or off
if "grid" in style:
style_dict.update({
"axes.grid": True,
})
})
else:
style_dict.update({
"axes.grid": False,
})
})

# Set the color of the background, spines, and grids
if style.startswith("dark"):
Expand All @@ -244,7 +242,7 @@ def axes_style(style=None, rc=None):
"axes.spines.right": True,
"axes.spines.top": True,

})
})

elif style == "whitegrid":
style_dict.update({
Expand All @@ -258,7 +256,7 @@ def axes_style(style=None, rc=None):
"axes.spines.right": True,
"axes.spines.top": True,

})
})

elif style in ["white", "ticks"]:
style_dict.update({
Expand All @@ -272,19 +270,19 @@ def axes_style(style=None, rc=None):
"axes.spines.right": True,
"axes.spines.top": True,

})
})

# Show or hide the axes ticks
if style == "ticks":
style_dict.update({
"xtick.bottom": True,
"ytick.left": True,
})
})
else:
style_dict.update({
"xtick.bottom": False,
"ytick.left": False,
})
})

# Remove entries that are not defined in the base list of valid keys
# This lets us handle matplotlib <=/> 2.0
Expand Down Expand Up @@ -376,7 +374,6 @@ def plotting_context(context=None, font_scale=1, rc=None):
set_context : set the matplotlib parameters to scale plot elements
axes_style : return a dict of parameters defining a figure style
color_palette : define the color palette for a plot
"""
if context is None:
context_dict = {k: mpl.rcParams[k] for k in _context_keys}
Expand All @@ -391,14 +388,19 @@ def plotting_context(context=None, font_scale=1, rc=None):
raise ValueError("context must be in %s" % ", ".join(contexts))

# Set up dictionary of default parameters
base_context = {
texts_base_context = {

"font.size": 12,
"axes.labelsize": 12,
"axes.titlesize": 12,
"xtick.labelsize": 11,
"ytick.labelsize": 11,
"legend.fontsize": 11,
"legend.title_fontsize": 12,

}

base_context = {

"axes.linewidth": 1.25,
"grid.linewidth": 1,
Expand All @@ -416,15 +418,15 @@ def plotting_context(context=None, font_scale=1, rc=None):
"xtick.minor.size": 4,
"ytick.minor.size": 4,

}
}
base_context.update(texts_base_context)

# Scale all the parameters by the same factor depending on the context
scaling = dict(paper=.8, notebook=1, talk=1.5, poster=2)[context]
context_dict = {k: v * scaling for k, v in base_context.items()}

# Now independently scale the fonts
font_keys = ["axes.labelsize", "axes.titlesize", "legend.fontsize",
"xtick.labelsize", "ytick.labelsize", "font.size"]
font_keys = texts_base_context.keys()
font_dict = {k: context_dict[k] * font_scale for k in font_keys}
context_dict.update(font_dict)

Expand Down
6 changes: 3 additions & 3 deletions plotnine/themes/theme_seaborn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from ..options import get_option
from .theme import theme
from .seaborn_rcmod import set as seaborn_set
from .seaborn_rcmod import set_theme


class theme_seaborn(theme):
Expand Down Expand Up @@ -33,6 +33,6 @@ def __init__(self, style='darkgrid', context='notebook',
figure_size=get_option('figure_size'),
panel_spacing=0.1,
complete=True)
d = seaborn_set(context=context, style=style,
font=font, font_scale=font_scale)
d = set_theme(context=context, style=style,
font=font, font_scale=font_scale)
self._rcParams.update(d)

0 comments on commit c87a2df

Please sign in to comment.