diff --git a/gseapy/plot.py b/gseapy/plot.py index c9cb264..9280935 100644 --- a/gseapy/plot.py +++ b/gseapy/plot.py @@ -938,6 +938,22 @@ def add_colorbar(self, sc): for key, spine in cbar.ax.spines.items(): spine.set_visible(False) + def _parse_colors(self, color=None): + """ + parse colors for groups + """ + # map color to group + if isinstance(color, dict): + return list(color.values()) + # get default color cycle + if (not isinstance(color, str)) and hasattr(color, "__len__"): + _colors = list(color) + else: + # get current matplotlib color cycle + prop_cycle = plt.rcParams["axes.prop_cycle"] + _colors = prop_cycle.by_key()["color"] + return _colors + def barh(self, color=None, group=None, ax=None): """ Barplot @@ -956,31 +972,31 @@ def barh(self, color=None, group=None, ax=None): bar.set_ylabel("") bar.set_title(self.title, fontsize=24, fontweight="bold") bar.xaxis.set_major_locator(MaxNLocator(nbins=5, integer=True)) - - # get default color cycle - if (not isinstance(color, str)) and hasattr(color, "__len__"): - _colors = list(color) - else: - prop_cycle = plt.rcParams["axes.prop_cycle"] - _colors = prop_cycle.by_key()["color"] - colors = _colors + # + _colors = self._parse_colors(color=color) # remove old legend first bar.legend_.remove() if (group is not None) and (group in self.data.columns): num_grp = self.data[group].value_counts(sort=False) - # set colors for each bar (groupby hue) + # set colors for each bar (groupby hue) using full length colors = [] legend_elements = [] for i, n in enumerate(num_grp): # cycle _colors if num_grp > len(_colors) c = _colors[i % len(_colors)] + # group_label + label = num_grp.index[i] + # if input color is a dict with keys in group + if isinstance(color, dict) and label in color: + c = color[label] + # expand the length to match bars colors += [c] * n ele = Line2D( xdata=[0], ydata=[0], marker="o", color="w", - label=num_grp.index[i], + label=label, markerfacecolor=c, markersize=8, ) @@ -993,6 +1009,7 @@ def barh(self, color=None, group=None, ax=None): bbox_to_anchor=(1.02, 0.5), frameon=False, ) + # update color of bars for j, b in enumerate(ax.patches): c = colors[j % len(colors)] @@ -1210,7 +1227,7 @@ def barplot( cutoff: float = 0.05, top_term: int = 10, figsize: Tuple[float, float] = (4, 6), - color: Union[str, List[str]] = "salmon", + color: Union[str, List[str], Dict[str, str]] = "salmon", ofname: Optional[str] = None, **kwargs, ): @@ -1225,7 +1242,8 @@ def barplot( ("Adjusted P-value", "P-value", "NOM p-val", "FDR q-val") :param top_term: number of top enriched terms grouped by `hue` are shown. :param figsize: tuple, matplotlib figsize. - :param color: color or list of matplotlib.colors. Must be reconigzed by matplotlib. + :param color: color or list or dict of matplotlib.colors. Must be reconigzed by matplotlib. + if dict input, dict keys must be found in the `group` :param ofname: output file name. If None, don't save figure :return: matplotlib.Axes. return None if given ofname.