From 88bbac1a76d41f35099d530b6898d35225dfcc29 Mon Sep 17 00:00:00 2001 From: sinhrks Date: Mon, 5 May 2014 07:18:54 +0900 Subject: [PATCH] ENH/BUG: boxplot now supports layout --- doc/source/release.rst | 2 + doc/source/v0.14.0.txt | 2 + pandas/tests/test_graphics.py | 148 ++++++++++++++++++----- pandas/tools/plotting.py | 219 +++++++++++++++++----------------- 4 files changed, 229 insertions(+), 142 deletions(-) diff --git a/doc/source/release.rst b/doc/source/release.rst index 3e6f7bb232156..009f7df3f9c03 100644 --- a/doc/source/release.rst +++ b/doc/source/release.rst @@ -332,6 +332,7 @@ Improvements to existing features - Arrays of strings can be wrapped to a specified width (``str.wrap``) (:issue:`6999`) - ``GroupBy.count()`` is now implemented in Cython and is much faster for large numbers of groups (:issue:`7016`). +- ``boxplot`` now supports ``layout`` keyword (:issue:`6769`) .. _release.bug_fixes-0.14.0: @@ -485,6 +486,7 @@ Bug Fixes - Bug in cache coherence with chained indexing and slicing; add ``_is_view`` property to ``NDFrame`` to correctly predict views; mark ``is_copy`` on ``xs` only if its an actual copy (and not a view) (:issue:`7084`) - Bug in DatetimeIndex creation from string ndarray with ``dayfirst=True`` (:issue:`5917`) +- Bug in ``boxplot`` and ``hist`` drawing unnecessary axes (:issue:`6769`) pandas 0.13.1 ------------- diff --git a/doc/source/v0.14.0.txt b/doc/source/v0.14.0.txt index cde6bf3bfd670..3b3570eb38675 100644 --- a/doc/source/v0.14.0.txt +++ b/doc/source/v0.14.0.txt @@ -390,6 +390,8 @@ Plotting positional argument ``frame`` instead of ``data``. A ``FutureWarning`` is raised if the old ``data`` argument is used by name. (:issue:`6956`) +- ``boxplot`` now supports ``layout`` keyword (:issue:`6769`) + .. 
_whatsnew_0140.prior_deprecations: Prior Version Deprecations/Changes diff --git a/pandas/tests/test_graphics.py b/pandas/tests/test_graphics.py index ca17e74d5eb07..cb3f9183beb81 100644 --- a/pandas/tests/test_graphics.py +++ b/pandas/tests/test_graphics.py @@ -13,6 +13,7 @@ from pandas.compat import range, lrange, StringIO, lmap, lzip, u, zip import pandas.util.testing as tm from pandas.util.testing import ensure_clean +import pandas.core.common as com from pandas.core.config import set_option @@ -1837,6 +1838,19 @@ def test_errorbar_scatter(self): @tm.mplskip class TestDataFrameGroupByPlots(tm.TestCase): + + def setUp(self): + n = 100 + with tm.RNGContext(42): + gender = tm.choice(['Male', 'Female'], size=n) + classroom = tm.choice(['A', 'B', 'C'], size=n) + + self.hist_df = DataFrame({'gender': gender, + 'classroom': classroom, + 'height': random.normal(66, 4, size=n), + 'weight': random.normal(161, 32, size=n), + 'category': random.randint(4, size=n)}) + def tearDown(self): tm.close() @@ -1924,39 +1938,117 @@ def test_grouped_hist(self): with tm.assertRaises(AttributeError): plotting.grouped_hist(df.A, by=df.C, foo='bar') + def _check_axes_shape(self, axes, axes_num=None, layout=None, figsize=(8.0, 6.0)): + """ + Check expected number of axes is drawn in expected layout + + Parameters + ---------- + axes : matplotlib Axes object, or its list-like + axes_num : number + expected number of axes. Unnecessary axes should be set to invisible. + layout : tuple + expected layout + figsize : tuple + expected figsize. 
default is matplotlib default + """ + visible_axes = self._flatten_visible(axes) + + if axes_num is not None: + self.assertEqual(len(visible_axes), axes_num) + for ax in visible_axes: + # check something drawn on visible axes + self.assert_(len(ax.get_children()) > 0) + + if layout is not None: + if isinstance(axes, list): + self.assertEqual((len(axes), ), layout) + elif isinstance(axes, np.ndarray): + self.assertEqual(axes.shape, layout) + else: + # in case of AxesSubplot + self.assertEqual((1, ), layout) + + self.assert_numpy_array_equal(np.round(visible_axes[0].figure.get_size_inches()), + np.array(figsize)) + + def _flatten_visible(self, axes): + axes = plotting._flatten(axes) + axes = [ax for ax in axes if ax.get_visible()] + return axes + @slow - def test_grouped_hist_layout(self): + def test_grouped_box_layout(self): import matplotlib.pyplot as plt - n = 100 - gender = tm.choice(['Male', 'Female'], size=n) - df = DataFrame({'gender': gender, - 'height': random.normal(66, 4, size=n), - 'weight': random.normal(161, 32, size=n), - 'category': random.randint(4, size=n)}) - self.assertRaises(ValueError, df.hist, column='weight', by=df.gender, + df = self.hist_df + + self.assertRaises(ValueError, df.boxplot, column=['weight', 'height'], by=df.gender, layout=(1, 1)) + self.assertRaises(ValueError, df.boxplot, column=['height', 'weight', 'category'], + layout=(2, 1)) + + box = _check_plot_works(df.groupby('gender').boxplot, column='height') + self._check_axes_shape(plt.gcf().axes, axes_num=2) + + box = _check_plot_works(df.groupby('category').boxplot, column='height') + self._check_axes_shape(plt.gcf().axes, axes_num=4) + + # GH 6769 + box = _check_plot_works(df.groupby('classroom').boxplot, column='height') + self._check_axes_shape(plt.gcf().axes, axes_num=3) + + box = df.boxplot(column=['height', 'weight', 'category'], by='gender') + self._check_axes_shape(plt.gcf().axes, axes_num=3) + + box = df.groupby('classroom').boxplot(column=['height', 'weight', 
'category']) + self._check_axes_shape(plt.gcf().axes, axes_num=3) + + box = _check_plot_works(df.groupby('category').boxplot, column='height', layout=(3, 2)) + self._check_axes_shape(plt.gcf().axes, axes_num=4) + + box = df.boxplot(column=['height', 'weight', 'category'], by='gender', layout=(4, 1)) + self._check_axes_shape(plt.gcf().axes, axes_num=3) + + box = df.groupby('classroom').boxplot(column=['height', 'weight', 'category'], layout=(1, 4)) + self._check_axes_shape(plt.gcf().axes, axes_num=3) + + @slow + def test_grouped_hist_layout(self): + + df = self.hist_df self.assertRaises(ValueError, df.hist, column='weight', by=df.gender, - layout=(1,)) + layout=(1, 1)) self.assertRaises(ValueError, df.hist, column='height', by=df.category, layout=(1, 3)) - self.assertRaises(ValueError, df.hist, column='height', by=df.category, - layout=(2, 1)) - self.assertEqual(df.hist(column='height', by=df.gender, - layout=(2, 1)).shape, (2,)) - tm.close() - self.assertEqual(df.hist(column='height', by=df.category, - layout=(4, 1)).shape, (4,)) - tm.close() - self.assertEqual(df.hist(column='height', by=df.category, - layout=(4, 2)).shape, (4, 2)) + + axes = _check_plot_works(df.hist, column='height', by=df.gender, layout=(2, 1)) + self._check_axes_shape(axes, axes_num=2, layout=(2, ), figsize=(10, 5)) + + axes = _check_plot_works(df.hist, column='height', by=df.category, layout=(4, 1)) + self._check_axes_shape(axes, axes_num=4, layout=(4, ), figsize=(10, 5)) + + axes = _check_plot_works(df.hist, column='height', by=df.category, + layout=(4, 2), figsize=(12, 8)) + self._check_axes_shape(axes, axes_num=4, layout=(4, 2), figsize=(12, 8)) + + # GH 6769 + axes = _check_plot_works(df.hist, column='height', by='classroom', layout=(2, 2)) + self._check_axes_shape(axes, axes_num=3, layout=(2, 2), figsize=(10, 5)) + + # without column + axes = _check_plot_works(df.hist, by='classroom') + self._check_axes_shape(axes, axes_num=3, layout=(2, 2), figsize=(10, 5)) + + axes = 
_check_plot_works(df.hist, by='gender', layout=(3, 5)) + self._check_axes_shape(axes, axes_num=2, layout=(3, 5), figsize=(10, 5)) + + axes = _check_plot_works(df.hist, column=['height', 'weight', 'category']) + self._check_axes_shape(axes, axes_num=3, layout=(2, 2), figsize=(10, 5)) @slow def test_axis_share_x(self): + df = self.hist_df # GH4089 - n = 100 - df = DataFrame({'gender': tm.choice(['Male', 'Female'], size=n), - 'height': random.normal(66, 4, size=n), - 'weight': random.normal(161, 32, size=n)}) ax1, ax2 = df.hist(column='height', by=df.gender, sharex=True) # share x @@ -1969,10 +2061,7 @@ def test_axis_share_x(self): @slow def test_axis_share_y(self): - n = 100 - df = DataFrame({'gender': tm.choice(['Male', 'Female'], size=n), - 'height': random.normal(66, 4, size=n), - 'weight': random.normal(161, 32, size=n)}) + df = self.hist_df ax1, ax2 = df.hist(column='height', by=df.gender, sharey=True) # share y @@ -1985,10 +2074,7 @@ def test_axis_share_y(self): @slow def test_axis_share_xy(self): - n = 100 - df = DataFrame({'gender': tm.choice(['Male', 'Female'], size=n), - 'height': random.normal(66, 4, size=n), - 'weight': random.normal(161, 32, size=n)}) + df = self.hist_df ax1, ax2 = df.hist(column='height', by=df.gender, sharex=True, sharey=True) diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index 33f6d4464bc43..e9dca5d91c8fc 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -740,42 +740,6 @@ def r(h): return ax -def grouped_hist(data, column=None, by=None, ax=None, bins=50, figsize=None, - layout=None, sharex=False, sharey=False, rot=90, grid=True, - **kwargs): - """ - Grouped histogram - - Parameters - ---------- - data: Series/DataFrame - column: object, optional - by: object, optional - ax: axes, optional - bins: int, default 50 - figsize: tuple, optional - layout: optional - sharex: boolean, default False - sharey: boolean, default False - rot: int, default 90 - grid: bool, default True - kwargs: dict, keyword 
arguments passed to matplotlib.Axes.hist - - Returns - ------- - axes: collection of Matplotlib Axes - """ - def plot_group(group, ax): - ax.hist(group.dropna().values, bins=bins, **kwargs) - - fig, axes = _grouped_plot(plot_group, data, column=column, - by=by, sharex=sharex, sharey=sharey, - figsize=figsize, layout=layout, rot=rot) - fig.subplots_adjust(bottom=0.15, top=0.9, left=0.1, right=0.9, - hspace=0.5, wspace=0.3) - return axes - - class MPLPlot(object): """ Base class for assembling a pandas plot using matplotlib @@ -2294,7 +2258,7 @@ def plot_series(series, label=None, kind='line', use_index=True, rot=None, def boxplot(data, column=None, by=None, ax=None, fontsize=None, - rot=0, grid=True, figsize=None, **kwds): + rot=0, grid=True, figsize=None, layout=None, **kwds): """ Make a box plot from DataFrame column optionally grouped by some columns or other inputs @@ -2311,6 +2275,8 @@ def boxplot(data, column=None, by=None, ax=None, fontsize=None, rot : label rotation angle figsize : A tuple (width, height) in inches grid : Setting this to True will show the grid + layout : tuple (optional) + (rows, columns) for the layout of the plot kwds : other plotting keyword arguments to be passed to matplotlib boxplot function @@ -2355,16 +2321,16 @@ def plot_group(grouped, ax): columns = [column] if by is not None: - if not isinstance(by, (list, tuple)): - by = [by] - fig, axes = _grouped_plot_by_column(plot_group, data, columns=columns, by=by, grid=grid, figsize=figsize, - ax=ax) + ax=ax, layout=layout) # Return axes in multiplot case, maybe revisit later # 985 ret = axes else: + if layout is not None: + raise ValueError("The 'layout' keyword is not supported when " + "'by' is None") if ax is None: ax = _gca() fig = ax.get_figure() @@ -2489,13 +2455,8 @@ def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None, """ import matplotlib.pyplot as plt - if column is not None: - if not isinstance(column, (list, np.ndarray)): - column = [column] - data = 
data[column] - if by is not None: - axes = grouped_hist(data, by=by, ax=ax, grid=grid, figsize=figsize, + axes = grouped_hist(data, column=column, by=by, ax=ax, grid=grid, figsize=figsize, sharex=sharex, sharey=sharey, layout=layout, bins=bins, **kwds) @@ -2511,27 +2472,18 @@ def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None, return axes - n = len(data.columns) - - if layout is not None: - if not isinstance(layout, (tuple, list)) or len(layout) != 2: - raise ValueError('Layout must be a tuple of (rows, columns)') + if column is not None: + if not isinstance(column, (list, np.ndarray)): + column = [column] + data = data[column] + naxes = len(data.columns) - rows, cols = layout - if rows * cols < n: - raise ValueError('Layout of %sx%s is incompatible with %s columns' % (rows, cols, n)) - else: - rows, cols = 1, 1 - while rows * cols < n: - if cols > rows: - rows += 1 - else: - cols += 1 - fig, axes = _subplots(nrows=rows, ncols=cols, ax=ax, squeeze=False, + nrows, ncols = _get_layout(naxes, layout=layout) + fig, axes = _subplots(nrows=nrows, ncols=ncols, naxes=naxes, ax=ax, squeeze=False, sharex=sharex, sharey=sharey, figsize=figsize) for i, col in enumerate(com._try_sort(data.columns)): - ax = axes[i / cols, i % cols] + ax = axes[i / ncols, i % ncols] ax.xaxis.set_visible(True) ax.yaxis.set_visible(True) ax.hist(data[col].dropna().values, bins=bins, **kwds) @@ -2547,10 +2499,6 @@ def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None, if yrot is not None: plt.setp(ax.get_yticklabels(), rotation=yrot) - for j in range(i + 1, rows * cols): - ax = axes[j / cols, j % cols] - ax.set_visible(False) - fig.subplots_adjust(wspace=0.3, hspace=0.3) return axes @@ -2633,8 +2581,44 @@ def hist_series(self, by=None, ax=None, grid=True, xlabelsize=None, return axes +def grouped_hist(data, column=None, by=None, ax=None, bins=50, figsize=None, + layout=None, sharex=False, sharey=False, rot=90, grid=True, + **kwargs): + """ + Grouped histogram + + 
Parameters + ---------- + data: Series/DataFrame + column: object, optional + by: object, optional + ax: axes, optional + bins: int, default 50 + figsize: tuple, optional + layout: optional + sharex: boolean, default False + sharey: boolean, default False + rot: int, default 90 + grid: bool, default True + kwargs: dict, keyword arguments passed to matplotlib.Axes.hist + + Returns + ------- + axes: collection of Matplotlib Axes + """ + def plot_group(group, ax): + ax.hist(group.dropna().values, bins=bins, **kwargs) + + fig, axes = _grouped_plot(plot_group, data, column=column, + by=by, sharex=sharex, sharey=sharey, + figsize=figsize, layout=layout, rot=rot) + fig.subplots_adjust(bottom=0.15, top=0.9, left=0.1, right=0.9, + hspace=0.5, wspace=0.3) + return axes + + def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None, - rot=0, grid=True, figsize=None, **kwds): + rot=0, grid=True, figsize=None, layout=None, **kwds): """ Make box plots from DataFrameGroupBy data. @@ -2650,6 +2634,8 @@ def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None, rot : label rotation angle grid : Setting this to True will show the grid figsize : A tuple (width, height) in inches + layout : tuple (optional) + (rows, columns) for the layout of the plot kwds : other plotting keyword arguments to be passed to matplotlib boxplot function @@ -2676,15 +2662,16 @@ def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None, >>> boxplot_frame_groupby(grouped, subplots=False) """ if subplots is True: - nrows, ncols = _get_layout(len(grouped)) - _, axes = _subplots(nrows=nrows, ncols=ncols, squeeze=False, + naxes = len(grouped) + nrows, ncols = _get_layout(naxes, layout=layout) + _, axes = _subplots(nrows=nrows, ncols=ncols, naxes=naxes, squeeze=False, sharex=False, sharey=True) - axes = axes.reshape(-1) if len(grouped) > 1 else axes + axes = _flatten(axes) ret = {} for (key, group), ax in zip(grouped, axes): d = group.boxplot(ax=ax, 
column=column, fontsize=fontsize, - rot=rot, grid=grid, figsize=figsize, **kwds) + rot=rot, grid=grid, **kwds) ax.set_title(com.pprint_thing(key)) ret[key] = d else: @@ -2698,7 +2685,7 @@ def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None, else: df = frames[0] ret = df.boxplot(column=column, fontsize=fontsize, rot=rot, - grid=grid, figsize=figsize, **kwds) + grid=grid, figsize=figsize, layout=layout, **kwds) return ret @@ -2706,7 +2693,6 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True, figsize=None, sharex=True, sharey=True, layout=None, rot=0, ax=None, **kwargs): from pandas.core.frame import DataFrame - import matplotlib.pyplot as plt # allow to specify mpl default with 'default' if figsize is None or figsize == 'default': @@ -2716,29 +2702,16 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True, if column is not None: grouped = grouped[column] - ngroups = len(grouped) - nrows, ncols = layout or _get_layout(ngroups) - - if nrows * ncols < ngroups: - raise ValueError("Number of plots in 'layout' must greater than or " - "equal to the number " "of groups in 'by'") - + naxes = len(grouped) + nrows, ncols = _get_layout(naxes, layout=layout) if figsize is None: # our favorite default beating matplotlib's idea of the # default size figsize = (10, 5) - fig, axes = _subplots(nrows=nrows, ncols=ncols, figsize=figsize, - sharex=sharex, sharey=sharey, ax=ax) + fig, axes = _subplots(nrows=nrows, ncols=ncols, naxes=naxes, + figsize=figsize, sharex=sharex, sharey=sharey, ax=ax) - if isinstance(axes, plt.Axes): - ravel_axes = [axes] - else: - ravel_axes = [] - for row in axes: - if isinstance(row, plt.Axes): - ravel_axes.append(row) - else: - ravel_axes.extend(row) + ravel_axes = _flatten(axes) for i, (key, group) in enumerate(grouped): ax = ravel_axes[i] @@ -2752,34 +2725,26 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True, def _grouped_plot_by_column(plotf, data, columns=None, 
by=None, numeric_only=True, grid=False, - figsize=None, ax=None, **kwargs): - import matplotlib.pyplot as plt - + figsize=None, ax=None, layout=None, **kwargs): grouped = data.groupby(by) if columns is None: + if not isinstance(by, (list, tuple)): + by = [by] columns = data._get_numeric_data().columns - by - ngroups = len(columns) + naxes = len(columns) if ax is None: - nrows, ncols = _get_layout(ngroups) - fig, axes = _subplots(nrows=nrows, ncols=ncols, + nrows, ncols = _get_layout(naxes, layout=layout) + fig, axes = _subplots(nrows=nrows, ncols=ncols, naxes=naxes, sharex=True, sharey=True, figsize=figsize, ax=ax) else: - if ngroups > 1: + if naxes > 1: raise ValueError("Using an existing axis is not supported when plotting multiple columns.") fig = ax.get_figure() axes = ax.get_axes() - if isinstance(axes, plt.Axes): - ravel_axes = [axes] - else: - ravel_axes = [] - for row in axes: - if isinstance(row, plt.Axes): - ravel_axes.append(row) - else: - ravel_axes.extend(row) + ravel_axes = _flatten(axes) for i, col in enumerate(columns): ax = ravel_axes[i] @@ -2836,7 +2801,18 @@ def table(ax, data, rowLabels=None, colLabels=None, return table -def _get_layout(nplots): +def _get_layout(nplots, layout=None): + if layout is not None: + if not isinstance(layout, (tuple, list)) or len(layout) != 2: + raise ValueError('Layout must be a tuple of (rows, columns)') + + nrows, ncols = layout + if nrows * ncols < nplots: + raise ValueError('Layout of %sx%s must be larger than required size %s' % + (nrows, ncols, nplots)) + + return layout + if nplots == 1: return (1, 1) elif nplots == 2: @@ -2856,7 +2832,7 @@ def _get_layout(nplots): # copied from matplotlib/pyplot.py for compatibility with matplotlib < 1.0 -def _subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True, +def _subplots(nrows=1, ncols=1, naxes=None, sharex=False, sharey=False, squeeze=True, subplot_kw=None, ax=None, secondary_y=False, data=None, **fig_kw): """Create a figure with a set of subplots 
already made. @@ -2872,6 +2848,9 @@ def _subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True, ncols : int Number of columns of the subplot grid. Defaults to 1. + naxes : int + Number of required axes. Exceeded axes are set invisible. Default is nrows * ncols. + sharex : bool If True, the X axis will be shared amongst all subplots. @@ -2949,6 +2928,12 @@ def _subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True, # Create empty object array to hold all axes. It's easiest to make it 1-d # so we can just append subplots upon creation, and then nplots = nrows * ncols + + if naxes is None: + naxes = nrows * ncols + elif nplots < naxes: + raise ValueError("naxes {0} is larger than layour size defined by nrows * ncols".format(naxes)) + axarr = np.empty(nplots, dtype=object) def on_right(i): @@ -2998,6 +2983,10 @@ def on_right(i): [label.set_visible( False) for label in ax.get_yticklabels()] + if naxes != nplots: + for ax in axarr[naxes:]: + ax.set_visible(False) + if squeeze: # Reshape the array to have the final desired dimension (nrow,ncol), # though discarding unneeded dimensions that equal 1. If we only have @@ -3013,6 +3002,14 @@ def on_right(i): return fig, axes +def _flatten(axes): + if not com.is_list_like(axes): + axes = [axes] + elif isinstance(axes, np.ndarray): + axes = axes.ravel() + return axes + + def _get_xlim(lines): left, right = np.inf, -np.inf for l in lines: