diff --git a/doc/source/release.rst b/doc/source/release.rst index b5a11091779ec..ab89cbfe2929b 100644 --- a/doc/source/release.rst +++ b/doc/source/release.rst @@ -475,6 +475,7 @@ Bug Fixes caused possible color/class mismatch (:issue:`6956`) - Bug in ``radviz`` and ``andrews_curves`` where multiple values of 'color' were being passed to plotting method (:issue:`6956`) +- Bug in ``DataFrame.boxplot`` where it failed to use the axis passed as the ``ax`` argument (:issue:`3578`) pandas 0.13.1 ------------- diff --git a/pandas/tests/test_graphics.py b/pandas/tests/test_graphics.py index e3f49e14400d1..ca17e74d5eb07 100644 --- a/pandas/tests/test_graphics.py +++ b/pandas/tests/test_graphics.py @@ -1083,6 +1083,25 @@ def test_boxplot(self): df['X'] = Series(['A', 'A', 'A', 'A', 'A', 'B', 'B', 'B', 'B', 'B']) _check_plot_works(df.boxplot, by='X') + # When ax is supplied, existing axes should be used: + import matplotlib.pyplot as plt + fig, ax = plt.subplots() + axes = df.boxplot('Col1', by='X', ax=ax) + self.assertIs(ax.get_axes(), axes) + + # Multiple columns with an ax argument is not supported + fig, ax = plt.subplots() + self.assertRaisesRegexp( + ValueError, 'existing axis', df.boxplot, + column=['Col1', 'Col2'], by='X', ax=ax + ) + + # When by is None, check that all relevant lines are present in the dict + fig, ax = plt.subplots() + d = df.boxplot(ax=ax) + lines = list(itertools.chain.from_iterable(d.values())) + self.assertEqual(len(ax.get_lines()), len(lines)) + @slow def test_kde(self): _skip_if_no_scipy() diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index b11d71f48baf2..33f6d4464bc43 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -2760,10 +2760,16 @@ def _grouped_plot_by_column(plotf, data, columns=None, by=None, columns = data._get_numeric_data().columns - by ngroups = len(columns) - nrows, ncols = _get_layout(ngroups) - fig, axes = _subplots(nrows=nrows, ncols=ncols, - sharex=True, sharey=True, - figsize=figsize, ax=ax) + if ax is None: + nrows, ncols = _get_layout(ngroups) + fig, axes = _subplots(nrows=nrows, ncols=ncols, + sharex=True, sharey=True, + figsize=figsize, ax=ax) + else: + if ngroups > 1: + raise ValueError("Using an existing axis is not supported when plotting multiple columns.") + fig = ax.get_figure() + axes = ax.get_axes() if isinstance(axes, plt.Axes): ravel_axes = [axes]