diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index d3ea809b79b76..6124da58995d8 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -755,9 +755,9 @@ class MPLPlot(object): _default_rot = 0 _pop_attributes = ['label', 'style', 'logy', 'logx', 'loglog', - 'mark_right'] + 'mark_right', 'stacked'] _attr_defaults = {'logy': False, 'logx': False, 'loglog': False, - 'mark_right': True} + 'mark_right': True, 'stacked': False} def __init__(self, data, kind=None, by=None, subplots=False, sharex=True, sharey=False, use_index=True, @@ -1080,7 +1080,6 @@ def _make_legend(self): for ax in self.axes: ax.legend(loc='best') - def _get_ax_legend(self, ax): leg = ax.get_legend() other_ax = (getattr(ax, 'right_ax', None) or @@ -1139,12 +1138,22 @@ def _get_plot_function(self): Returns the matplotlib plotting function (plot or errorbar) based on the presence of errorbar keywords. ''' - - if all(e is None for e in self.errors.values()): - plotf = self.plt.Axes.plot - else: - plotf = self.plt.Axes.errorbar - + errorbar = any(e is not None for e in self.errors.values()) + def plotf(ax, x, y, style=None, **kwds): + mask = com.isnull(y) + if mask.any(): + y = np.ma.array(y) + y = np.ma.masked_where(mask, y) + + if errorbar: + return self.plt.Axes.errorbar(ax, x, y, **kwds) + else: + # prevent style kwarg from going to errorbar, where it is unsupported + if style is not None: + args = (ax, x, y, style) + else: + args = (ax, x, y) + return self.plt.Axes.plot(*args, **kwds) return plotf def _get_index_name(self): @@ -1472,11 +1481,9 @@ def _post_plot_logic(self): class LinePlot(MPLPlot): def __init__(self, data, **kwargs): - self.stacked = kwargs.pop('stacked', False) - if self.stacked: - data = data.fillna(value=0) - MPLPlot.__init__(self, data, **kwargs) + if self.stacked: + self.data = self.data.fillna(value=0) self.x_compat = plot_params['x_compat'] if 'x_compat' in self.kwds: self.x_compat = bool(self.kwds.pop('x_compat')) @@ -1533,56 +1540,39 @@ def _is_ts_plot(self): return not self.x_compat and self.use_index and self._use_dynamic_x() def _make_plot(self): - self._pos_prior = np.zeros(len(self.data)) - self._neg_prior = np.zeros(len(self.data)) + self._initialize_prior(len(self.data)) if self._is_ts_plot(): data = self._maybe_convert_index(self.data) - self._make_ts_plot(data) + x = data.index # dummy, not used + plotf = self._get_ts_plot_function() + it = self._iter_data(data=data, keep_index=True) else: x = self._get_xticks(convert_period=True) - plotf = self._get_plot_function() - colors = self._get_colors() - - for i, (label, y) in enumerate(self._iter_data()): - ax = self._get_ax(i) - style = self._get_style(i, label) - kwds = self.kwds.copy() - self._maybe_add_color(colors, kwds, style, i) + it = self._iter_data() - errors = self._get_errorbars(label=label, index=i) - kwds = dict(kwds, **errors) - - label = com.pprint_thing(label) # .encode('utf-8') - kwds['label'] = label - - y_values = self._get_stacked_values(y, label) - - if not self.stacked: - mask = com.isnull(y_values) - if mask.any(): - y_values = np.ma.array(y_values) - y_values = np.ma.masked_where(mask, y_values) + colors = self._get_colors() + for i, (label, y) in enumerate(it): + ax = self._get_ax(i) + style = self._get_style(i, label) + kwds = self.kwds.copy() + self._maybe_add_color(colors, kwds, style, i) - # prevent style kwarg from going to errorbar, where it is unsupported - if style is not None and plotf.__name__ != 'errorbar': - args = (ax, x, y_values, style) - else: - args = (ax, x, y_values) + errors = self._get_errorbars(label=label, index=i) + kwds = dict(kwds, **errors) - newlines = plotf(*args, **kwds) - self._add_legend_handle(newlines[0], label, index=i) + label = com.pprint_thing(label) # .encode('utf-8') + kwds['label'] = label + y_values = self._get_stacked_values(y, label) - if self.stacked and not self.subplots: - if (y >= 0).all(): - self._pos_prior += y - elif (y <= 0).all(): - self._neg_prior += y + newlines = plotf(ax, x, y_values, style=style, **kwds) + self._update_prior(y) + self._add_legend_handle(newlines[0], label, index=i) - lines = _get_all_lines(ax) - left, right = _get_xlim(lines) - ax.set_xlim(left, right) + lines = _get_all_lines(ax) + left, right = _get_xlim(lines) + ax.set_xlim(left, right) def _get_stacked_values(self, y, label): if self.stacked: @@ -1599,46 +1589,26 @@ def _get_stacked_values(self, y, label): def _get_ts_plot_function(self): from pandas.tseries.plotting import tsplot plotf = self._get_plot_function() - - def _plot(data, ax, label, style, **kwds): - # errorbar function does not support style argument - if plotf.__name__ == 'errorbar': - lines = tsplot(data, plotf, ax=ax, label=label, - **kwds) - return lines - else: - lines = tsplot(data, plotf, ax=ax, label=label, - style=style, **kwds) - return lines + def _plot(ax, x, data, style=None, **kwds): + # accept x to be consistent with normal plot func, + # x is not passed to tsplot as it uses data.index as x coordinate + lines = tsplot(data, plotf, ax=ax, style=style, **kwds) + return lines return _plot - def _make_ts_plot(self, data, **kwargs): - colors = self._get_colors() - plotf = self._get_ts_plot_function() - - it = self._iter_data(data=data, keep_index=True) - for i, (label, y) in enumerate(it): - ax = self._get_ax(i) - style = self._get_style(i, label) - kwds = self.kwds.copy() - - self._maybe_add_color(colors, kwds, style, i) - - errors = self._get_errorbars(label=label, index=i, xerr=False) - kwds = dict(kwds, **errors) - - label = com.pprint_thing(label) - - y_values = self._get_stacked_values(y, label) - - newlines = plotf(y_values, ax, label, style, **kwds) - self._add_legend_handle(newlines[0], label, index=i) + def _initialize_prior(self, n): + self._pos_prior = np.zeros(n) + self._neg_prior = np.zeros(n) - if self.stacked and not self.subplots: - if (y >= 0).all(): - self._pos_prior += y - elif (y <= 0).all(): - self._neg_prior += y + def _update_prior(self, y): + if self.stacked and not self.subplots: + # tsplot resample may changedata length + if len(self._pos_prior) != len(y): + self._initialize_prior(len(y)) + if (y >= 0).all(): + self._pos_prior += y + elif (y <= 0).all(): + self._neg_prior += y def _maybe_convert_index(self, data): # tsplot converts automatically, but don't want to convert index @@ -1707,13 +1677,14 @@ def _get_plot_function(self): if self.logy or self.loglog: raise ValueError("Log-y scales are not supported in area plot") else: - f = LinePlot._get_plot_function(self) - - def plotf(*args, **kwds): - lines = f(*args, **kwds) + f = MPLPlot._get_plot_function(self) + def plotf(ax, x, y, style=None, **kwds): + lines = f(ax, x, y, style=style, **kwds) + # get data from the line # insert fill_between starting point - y = args[2] + xdata, y_values = lines[0].get_data(orig=False) + if (y >= 0).all(): start = self._pos_prior elif (y <= 0).all(): @@ -1721,16 +1692,10 @@ def plotf(*args, **kwds): else: start = np.zeros(len(y)) - # get x data from the line - # to retrieve x coodinates of tsplot - xdata = lines[0].get_data()[0] - # remove style - args = (args[0], xdata, start, y) - if not 'color' in kwds: kwds['color'] = lines[0].get_color() - self.plt.Axes.fill_between(*args, **kwds) + self.plt.Axes.fill_between(ax, xdata, start, y_values, **kwds) return lines return plotf @@ -1746,15 +1711,6 @@ def _add_legend_handle(self, handle, label, index=None): def _post_plot_logic(self): LinePlot._post_plot_logic(self) - if self._is_ts_plot(): - pass - else: - if self.xlim is None: - for ax in self.axes: - lines = _get_all_lines(ax) - left, right = _get_xlim(lines) - ax.set_xlim(left, right) - if self.ylim is None: if (self.data >= 0).all().all(): for ax in self.axes: @@ -1769,12 +1725,8 @@ class BarPlot(MPLPlot): _default_rot = {'bar': 90, 'barh': 0} def __init__(self, data, **kwargs): - self.stacked = kwargs.pop('stacked', False) - self.bar_width = kwargs.pop('width', 0.5) - pos = kwargs.pop('position', 0.5) - kwargs.setdefault('align', 'center') self.tick_pos = np.arange(len(data)) diff --git a/pandas/tseries/plotting.py b/pandas/tseries/plotting.py index 6031482fd9927..33a14403b0f08 100644 --- a/pandas/tseries/plotting.py +++ b/pandas/tseries/plotting.py @@ -18,8 +18,6 @@ from pandas.tseries.converter import (PeriodConverter, TimeSeries_DateLocator, TimeSeries_DateFormatter) -from pandas.tools.plotting import _get_all_lines, _get_xlim - #---------------------------------------------------------------------- # Plotting functions and monkey patches @@ -59,25 +57,15 @@ def tsplot(series, plotf, **kwargs): # Set ax with freq info _decorate_axes(ax, freq, kwargs) - # mask missing values - args = _maybe_mask(series) - # how to make sure ax.clear() flows through? if not hasattr(ax, '_plot_data'): ax._plot_data = [] ax._plot_data.append((series, kwargs)) - # styles - style = kwargs.pop('style', None) - if style is not None: - args.append(style) - - lines = plotf(ax, *args, **kwargs) + lines = plotf(ax, series.index, series.values, **kwargs) # set date formatter, locators and rescale limits format_dateaxis(ax, ax.freq) - left, right = _get_xlim(_get_all_lines(ax)) - ax.set_xlim(left, right) # x and y coord info ax.format_coord = lambda t, y: ("t = {0} " @@ -165,8 +153,7 @@ def _replot_ax(ax, freq, plotf, kwargs): idx = series.index.asfreq(freq, how='S') series.index = idx ax._plot_data.append(series) - args = _maybe_mask(series) - lines.append(plotf(ax, *args, **kwds)[0]) + lines.append(plotf(ax, series.index, series.values, **kwds)[0]) labels.append(com.pprint_thing(series.name)) return lines, labels @@ -184,17 +171,6 @@ def _decorate_axes(ax, freq, kwargs): ax.date_axis_info = None -def _maybe_mask(series): - mask = isnull(series) - if mask.any(): - masked_array = np.ma.array(series.values) - masked_array = np.ma.masked_where(mask, masked_array) - args = [series.index, masked_array] - else: - args = [series.index, series.values] - return args - - def _get_freq(ax, series): # get frequency from data freq = getattr(series.index, 'freq', None)