diff --git a/pandas/tests/test_graphics_others.py b/pandas/tests/test_graphics_others.py index b18cbae600b43..54f6cf50ea5ec 100644 --- a/pandas/tests/test_graphics_others.py +++ b/pandas/tests/test_graphics_others.py @@ -302,6 +302,30 @@ def test_boxplot_empty_column(self): df.loc[:, 0] = np.nan _check_plot_works(df.boxplot, return_type='axes') + @slow + def test_hist_df_nan_and_weights(self): + d = {'category' : ['A', 'A', 'B', 'B', 'C'], + 'items' : [4., 3., 2., np.nan, 1], + 'val' : [10., 8., np.nan, 5, 7.]} + df = DataFrame(d) + orig_columns = df.columns + orig_rows = len(df) + _check_plot_works(df.hist, column='items', by='category', + weights='val', bins=range(0, 10)) + _check_plot_works(df.hist, column='items', by='category', + weights=df.val.values, bins=range(0, 10)) + # check grouping without weights, then weights without grouping + _check_plot_works(df.hist, column='items', by='category', + bins=range(0, 10)) + _check_plot_works(df.hist, column='items', weights='val', + bins=range(0, 10)) + _check_plot_works(df.hist, column='items', weights=df.val.values, + bins=range(0, 10)) + # also check that we have not changed the original df that had + # nan values in it. + self.assertEqual(len(orig_columns), len(df.columns)) + self.assertEqual(orig_rows, len(df)) + @slow def test_hist_df_legacy(self): from matplotlib.patches import Rectangle diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index 98d6f5e8eb797..3acf2f39b8864 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -2772,7 +2772,8 @@ def plot_group(group, ax): def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None, xrot=None, ylabelsize=None, yrot=None, ax=None, sharex=False, - sharey=False, figsize=None, layout=None, bins=10, **kwds): + sharey=False, figsize=None, layout=None, bins=10, weights=None, + **kwds): """ Draw histogram of the DataFrame's series using matplotlib / pylab. 
@@ -2807,17 +2808,37 @@ def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None, layout: (optional) a tuple (rows, columns) for the layout of the histograms bins: integer, default 10 Number of histogram bins to be used + weights : string or sequence + If passed, will be used to weight the data kwds : other plotting keyword arguments To be passed to hist function """ + subset_cols_drop_nan = [] + if weights is not None: + if isinstance(weights, np.ndarray): + # weights supplied as an array rather than as a column name of data + if 'weights' in data.columns: + raise NameError('weights already in data.columns. Could not ' + + 'add dummy column') + data = data.copy() + data['weights'] = weights + weights = 'weights' + subset_cols_drop_nan.append(weights) + if column is not None: + subset_cols_drop_nan.append(column) + data = data.dropna(subset=subset_cols_drop_nan) if by is not None: axes = grouped_hist(data, column=column, by=by, ax=ax, grid=grid, figsize=figsize, sharex=sharex, sharey=sharey, layout=layout, bins=bins, xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot, - **kwds) + weights=weights, **kwds) return axes + if weights is not None: + weights = data[weights] + weights = weights._get_numeric_data() + if column is not None: if not isinstance(column, (list, np.ndarray, Index)): column = [column] @@ -2832,7 +2853,7 @@ def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None, for i, col in enumerate(com._try_sort(data.columns)): ax = _axes[i] - ax.hist(data[col].dropna().values, bins=bins, **kwds) + ax.hist(data[col].values, bins=bins, weights=weights, **kwds) ax.set_title(col) ax.grid(grid) @@ -2916,10 +2937,10 @@ def hist_series(self, by=None, ax=None, grid=True, xlabelsize=None, return axes -def grouped_hist(data, column=None, by=None, ax=None, bins=50, figsize=None, - layout=None, sharex=False, sharey=False, rot=90, grid=True, - xlabelsize=None, xrot=None, ylabelsize=None, yrot=None, - **kwargs): +def 
grouped_hist(data, column=None, by=None, ax=None, bins=50, + figsize=None, layout=None, sharex=False, sharey=False, rot=90, + grid=True, xlabelsize=None, xrot=None, ylabelsize=None, + yrot=None, weights=None, **kwargs): """ Grouped histogram @@ -2936,20 +2957,30 @@ def grouped_hist(data, column=None, by=None, ax=None, bins=50, figsize=None, sharey: boolean, default False rot: int, default 90 grid: bool, default True + weights: object, optional kwargs: dict, keyword arguments passed to matplotlib.Axes.hist Returns ------- axes: collection of Matplotlib Axes """ - def plot_group(group, ax): - ax.hist(group.dropna().values, bins=bins, **kwargs) + def plot_group(group, ax, weights=None): + if isinstance(group, np.ndarray) == False: + group = group.values + if weights is not None: + if isinstance(weights, np.ndarray) == False: + weights = weights.values + if len(group) > 0: + # if the length is 0, we had only NaN's for this group + # nothing to plot! + ax.hist(group, weights=weights, bins=bins, **kwargs) xrot = xrot or rot - fig, axes = _grouped_plot(plot_group, data, column=column, - by=by, sharex=sharex, sharey=sharey, ax=ax, - figsize=figsize, layout=layout, rot=rot) + fig, axes = _grouped_plot(plot_group, data, column=column, by=by, + sharex=sharex, sharey=sharey, ax=ax, + figsize=figsize, layout=layout, rot=rot, + weights=weights) _set_ticks_props(axes, xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot) @@ -3034,9 +3065,9 @@ def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None, return ret -def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True, - figsize=None, sharex=True, sharey=True, layout=None, - rot=0, ax=None, **kwargs): +def _grouped_plot(plotf, data, column=None, by=None, + numeric_only=True, figsize=None, sharex=True, sharey=True, + layout=None, rot=0, ax=None, weights=None, **kwargs): from pandas import DataFrame if figsize == 'default': @@ -3046,6 +3077,9 @@ def _grouped_plot(plotf, data, 
column=None, by=None, numeric_only=True, figsize = None grouped = data.groupby(by) + + if weights is not None: + weights = grouped[weights] if column is not None: grouped = grouped[column] @@ -3056,11 +3090,20 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True, _axes = _flatten(axes) + weight = None for i, (key, group) in enumerate(grouped): ax = _axes[i] + if weights is not None: + weight = weights.get_group(key) if numeric_only and isinstance(group, DataFrame): group = group._get_numeric_data() - plotf(group, ax, **kwargs) + if weight is not None: + weight = weight._get_numeric_data() + if weight is not None: + plotf(group, ax, weight, **kwargs) + else: + # no weight: not every plotf (e.g. scatterplot) accepts one + plotf(group, ax, **kwargs) ax.set_title(com.pprint_thing(key)) return fig, axes