Skip to content

Fix for DataFrame.hist() with by- and weights-keyword #11441

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions pandas/tests/test_graphics_others.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,30 @@ def test_boxplot_empty_column(self):
df.loc[:, 0] = np.nan
_check_plot_works(df.boxplot, return_type='axes')

@slow
def test_hist_df_nan_and_weights(self):
    """Check ``DataFrame.hist`` with NaN data and the ``weights`` keyword.

    The frame holds NaNs in both the plotted column (``items``) and the
    weight column (``val``). Every combination of ``by`` and ``weights``
    (given as a column name or as an array) is plotted, and afterwards
    the frame must still have its original number of rows and columns.
    """
    frame = DataFrame({'category': ['A', 'A', 'B', 'B', 'C'],
                       'items': [4., 3., 2., np.nan, 1],
                       'val': [10., 8., np.nan, 5, 7.]})
    n_columns_before = len(frame.columns)
    n_rows_before = len(frame)

    def check_hist(**extra):
        # Every case plots the same column with the same bin edges.
        _check_plot_works(frame.hist, column='items', bins=range(0, 10),
                          **extra)

    # Grouped histograms: weights by column name, then as a raw array.
    check_hist(by='category', weights='val')
    check_hist(by='category', weights=frame.val.values)
    # Grouped histogram without any weights.
    check_hist(by='category')
    # Ungrouped histograms: weights by column name, then as a raw array.
    check_hist(weights='val')
    check_hist(weights=frame.val.values)

    # Plotting must not have altered the caller's frame, which still
    # contains the NaN values.
    self.assertEqual(n_columns_before, len(frame.columns))
    self.assertEqual(n_rows_before, len(frame))

@slow
def test_hist_df_legacy(self):
from matplotlib.patches import Rectangle
Expand Down
75 changes: 59 additions & 16 deletions pandas/tools/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2772,7 +2772,8 @@ def plot_group(group, ax):

def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None,
xrot=None, ylabelsize=None, yrot=None, ax=None, sharex=False,
sharey=False, figsize=None, layout=None, bins=10, **kwds):
sharey=False, figsize=None, layout=None, bins=10, weights=None,
**kwds):
"""
Draw histogram of the DataFrame's series using matplotlib / pylab.

Expand Down Expand Up @@ -2807,17 +2808,37 @@ def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None,
layout: (optional) a tuple (rows, columns) for the layout of the histograms
bins: integer, default 10
Number of histogram bins to be used
weights : string or sequence
If passed, will be used to weight the data
kwds : other plotting keyword arguments
To be passed to hist function
"""
subset_cols_drop_nan = []
if weights is not None:
if isinstance(weights, np.ndarray):
# weights supplied as an array instead of a part of the dataframe
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, this will not work if weights is an ndarray with more than one dimension... I need to think about this.

if 'weights' in data.columns:
raise NameError('weights already in data.columns. Could not ' +
'add dummy column')
data = data.copy()
data['weights'] = weights
weights = 'weights'
subset_cols_drop_nan.append(weights)
if column is not None:
subset_cols_drop_nan.append(column)
data = data.dropna(subset=subset_cols_drop_nan)

if by is not None:
axes = grouped_hist(data, column=column, by=by, ax=ax, grid=grid, figsize=figsize,
sharex=sharex, sharey=sharey, layout=layout, bins=bins,
xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot,
**kwds)
weights=weights, **kwds)
return axes

if weights is not None:
weights = data[weights]
weights = weights._get_numeric_data()

if column is not None:
if not isinstance(column, (list, np.ndarray, Index)):
column = [column]
Expand All @@ -2832,7 +2853,7 @@ def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None,

for i, col in enumerate(com._try_sort(data.columns)):
ax = _axes[i]
ax.hist(data[col].dropna().values, bins=bins, **kwds)
ax.hist(data[col].values, bins=bins, weights=weights, **kwds)
ax.set_title(col)
ax.grid(grid)

Expand Down Expand Up @@ -2916,10 +2937,10 @@ def hist_series(self, by=None, ax=None, grid=True, xlabelsize=None,
return axes


def grouped_hist(data, column=None, by=None, ax=None, bins=50, figsize=None,
layout=None, sharex=False, sharey=False, rot=90, grid=True,
xlabelsize=None, xrot=None, ylabelsize=None, yrot=None,
**kwargs):
def grouped_hist(data, column=None, by=None, ax=None, bins=50,
figsize=None, layout=None, sharex=False, sharey=False, rot=90,
grid=True, xlabelsize=None, xrot=None, ylabelsize=None,
yrot=None, weights=None, **kwargs):
"""
Grouped histogram

Expand All @@ -2936,20 +2957,30 @@ def grouped_hist(data, column=None, by=None, ax=None, bins=50, figsize=None,
sharey: boolean, default False
rot: int, default 90
grid: bool, default True
weights: object, optional
kwargs: dict, keyword arguments passed to matplotlib.Axes.hist

Returns
-------
axes: collection of Matplotlib Axes
"""
def plot_group(group, ax):
ax.hist(group.dropna().values, bins=bins, **kwargs)
def plot_group(group, ax, weights=None):
    """Draw the histogram of a single group onto ``ax``.

    ``group`` and ``weights`` may each be a NumPy array or an object
    exposing ``.values``; the latter is unwrapped before plotting.
    ``bins`` and ``kwargs`` are taken from the enclosing function and
    forwarded to ``ax.hist``.
    """
    if not isinstance(group, np.ndarray):
        group = group.values
    if weights is not None and not isinstance(weights, np.ndarray):
        weights = weights.values
    # NOTE(review): NaN values are not filtered here; presumably the
    # caller has already dropped them -- confirm.
    if len(group) > 0:
        # A group of length zero (presumably one that held only NaNs
        # before they were dropped) has nothing to draw, so its axes
        # are left empty.
        ax.hist(group, weights=weights, bins=bins, **kwargs)

xrot = xrot or rot

fig, axes = _grouped_plot(plot_group, data, column=column,
by=by, sharex=sharex, sharey=sharey, ax=ax,
figsize=figsize, layout=layout, rot=rot)
fig, axes = _grouped_plot(plot_group, data, column=column, by=by,
sharex=sharex, sharey=sharey, ax=ax,
figsize=figsize, layout=layout, rot=rot,
weights=weights)

_set_ticks_props(axes, xlabelsize=xlabelsize, xrot=xrot,
ylabelsize=ylabelsize, yrot=yrot)
Expand Down Expand Up @@ -3034,9 +3065,9 @@ def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None,
return ret


def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,
figsize=None, sharex=True, sharey=True, layout=None,
rot=0, ax=None, **kwargs):
def _grouped_plot(plotf, data, column=None, by=None,
numeric_only=True, figsize=None, sharex=True, sharey=True,
layout=None, rot=0, ax=None, weights=None, **kwargs):
from pandas import DataFrame

if figsize == 'default':
Expand All @@ -3046,6 +3077,9 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,
figsize = None

grouped = data.groupby(by)

if weights is not None:
weights = grouped[weights]
if column is not None:
grouped = grouped[column]

Expand All @@ -3056,11 +3090,20 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,

_axes = _flatten(axes)

weight = None
for i, (key, group) in enumerate(grouped):
ax = _axes[i]
if weights is not None:
weight = weights.get_group(key)
if numeric_only and isinstance(group, DataFrame):
group = group._get_numeric_data()
plotf(group, ax, **kwargs)
if weight is not None:
weight = weight._get_numeric_data()
if weight is not None:
plotf(group, ax, weight, **kwargs)
else:
# scatterplot etc has not the weight implemented in plotf
plotf(group, ax, **kwargs)
ax.set_title(com.pprint_thing(key))

return fig, axes
Expand Down