Skip to content

Feat/scatter by size #17582

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 84 additions & 5 deletions pandas/plotting/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
is_integer,
is_number,
is_hashable,
is_iterator)
is_iterator,
is_numeric_dtype,
is_categorical_dtype)
from pandas.core.dtypes.generic import ABCSeries

from pandas.core.common import AbstractMethodError, _try_sort
Expand Down Expand Up @@ -815,11 +817,27 @@ def _post_plot_logic(self, ax, data):
class ScatterPlot(PlanePlot):
_kind = 'scatter'

def __init__(self, data, x, y, s=None, c=None, **kwargs):
def __init__(self, data, x, y, s=None, c=None, size_factor=1, **kwargs):
if s is None:
# hide the matplotlib default for size, in case we want to change
# the handling of this argument later
# Set default size if no argument is given.
s = 20
elif is_hashable(s) and s in data.columns:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you don't need to check hasability, other ops will fail if this is the case

# Handle the case where s is a label of a column of the df.
# The data is normalized to 200 * size_factor.
size_data = data[s]
if is_categorical_dtype(size_data):
if size_data.cat.ordered:
size_data = size_data.cat.codes + 1
else:
raise TypeError("'s' must be numeric or ordered categorical dtype")
if is_numeric_dtype(size_data):
self.size_title = s
self.s_data_max = size_data.max()
self.size_factor = size_factor
self.bubble_points = 200
s = self.bubble_points * size_factor * size_data / self.s_data_max
else:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like to support Categorical data here as well. How much work would it be to do that? Perhaps

if is_categorical_dtype(size_data):
    if size_data.ordered:
        size_data = size_data.codes
    # else raise with a nice error message

and then the if is_numeric_dtype(size_data)? Does that work?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@TomAugspurger I'm not sure I understand what it means to plot categorical data as sizes. Could you give me a use case example?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@VincentAntoine this basically, but with all the nice stuff from this PR

df = pd.DataFrame(np.random.randn(100, 2), columns=['a', 'b'])

categories = list('abcd')
df['c'] = pd.Categorical(np.random.choice(categories, size=(100,)),
                         categories=categories, ordered=True)
df.head()

fig, ax = plt.subplots()
s = (10 + df.c.cat.codes * 10)

ax.scatter(x='a', y='b', data=df, s=s);

gh

So the idea is to automatically know to use .categories.codes for categorical dtype data, instead of just the categories (which may not be numeric). I think if you do the

  • check for categorical dtype
  • size_data = df[s].cat.codes

before if is_numeric_dtype(size_data), then everything else should be the same

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it works just as you said with no additionnal modification :)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We just need to do size_data = df[s].cat.codes + 1, as codes start at 0 and the resulting bubbles will have an area of 0.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@TomAugspurger I added this in the last commit, and added a test for scatter plot with categorical data as well. I'll write the release note now. Let me know if the code needs any more changes.

raise TypeError("'s' must be numeric or ordered categorical dtype")
super(ScatterPlot, self).__init__(data, x, y, s=s, **kwargs)
if is_integer(c) and not self.data.columns.holds_integer():
c = self.data.columns[c]
Expand All @@ -828,7 +846,6 @@ def __init__(self, data, x, y, s=None, c=None, **kwargs):
def _make_plot(self):
x, y, c, data = self.x, self.y, self.c, self.data
ax = self.axes[0]

c_is_column = is_hashable(c) and c in self.data.columns

# plot a colorbar only if a colormap is provided or necessary
Expand Down Expand Up @@ -875,6 +892,68 @@ def _make_plot(self):
ax.errorbar(data[x].values, data[y].values,
linestyle='none', **err_kwds)

def _sci_notation(self, num):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Brief doc-string here.

'''
Returns mantissa and exponent of the number passed in agument.
Example:
_sci_notation(89278.8924)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this can just be a module level function

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does matplotlib have anything that does this? (cc @tacaswell)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, typically the docstring formatting is like

>>> _sci_notation(89278.8924)
(8.9, 5.0)

identical to what you get from the regular python REPL.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but buried in the formatter code. I would not try to re-use it.

Copy link
Author

@VincentAntoine VincentAntoine Oct 3, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we place this function as a module level function within plotting/_core as suggested by @jreback or is there another module that would be more appropriate ? Maybe plotting/_converter.py ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

module level in plotting/_core.py is fine I think.

>>> (8.9, 5.0)
'''
scientific_notation = '{:e}'.format(num)
expnt = float(re.search(r'e([+-]\d*)$',
scientific_notation).groups()[0])
coef = float(re.search(r'^([+-]?\d\.\d)',
scientific_notation).groups()[0])
return coef, expnt

def _legend_bubbles(self, s_data_max, size_factor, bubble_points):
'''
Computes and returns appropriate bubble sizes and labels for the
legend of a bubble plot. Creates 4 bubbles with round values for the
labels, the largest of which is close to the maximum of the data.
'''
coef, expnt = self._sci_notation(s_data_max)
labels_catalog = {
(9, 10): [10, 5, 2.5, 1],
(7, 9): [8, 4, 2, 0.5],
(5.5, 7): [6, 3, 1.5, 0.5],
(4.5, 5.5): [5, 2, 1, 0.2],
(3.5, 4.5): [4, 2, 1, 0.2],
(2.5, 3.5): [3, 1, 0.5, 0.2],
(1.5, 2.5): [2, 1, 0.5, 0.2],
(0, 1.5): [1, 0.5, 0.25, 0.1]
}
for lower_bound, upper_bound in labels_catalog:
if (coef >= lower_bound) & (coef < upper_bound):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lower_bound <= coef < upper_bound is a bit more readable here

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not know that worked :)

labels = 10**expnt * np.array(labels_catalog[lower_bound,
upper_bound])
sizes = list(bubble_points * size_factor * labels / s_data_max)
labels = ['{:g}'.format(l) for l in labels]
return (sizes, labels)

def _make_legend(self):
if hasattr(self, "size_title"):
ax = self.axes[0]
import matplotlib.legend as legend
from matplotlib.collections import CircleCollection
sizes, labels = self._legend_bubbles(self.s_data_max,
self.size_factor,
self.bubble_points)
color = self.plt.rcParams['axes.facecolor'],
edgecolor = self.plt.rcParams['axes.edgecolor']))
bubbles = []
for size in sizes:
bubbles.append(CircleCollection(sizes=[size],
color=color,
edgecolor=edgecolor))
bubble_legend = legend.Legend(ax,
handles=bubbles,
labels=labels,
loc='lower right')
bubble_legend.set_title(self.size_title)
ax.add_artist(bubble_legend)
super()._make_legend()


class HexBinPlot(PlanePlot):
_kind = 'hexbin'
Expand Down
34 changes: 34 additions & 0 deletions pandas/tests/plotting/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,6 +1005,40 @@ def test_scatter_colors(self):
tm.assert_numpy_array_equal(ax.collections[0].get_facecolor()[0],
np.array([1, 1, 1, 1], dtype=np.float64))

@pytest.mark.slow
def test_plot_scatter_with_s(self):
data = np.array([[3.1, 4.2, 1.9],
[1.9, 2.8, 3.1],
[5.4, 4.32, 2.0],
[0.4, 3.4, 0.46],
[4.4, 4.9, 0.8],
[2.7, 6.2, 1.49]])
df = DataFrame(data,
columns = ['x', 'y', 'z'])
ax = df.plot.scatter(x='x', y='y', s='z', size_factor=4)
bubbles = ax.collections[0]
bubble_sizes = bubbles.get_sizes()
max_data = df['z'].max()
expected_sizes = 200 * 4 * df['z'].values / max_data
tm.assert_numpy_array_equal(bubble_sizes, expected_sizes)

@pytest.mark.slow
def test_plot_scatter_with_categorical_s(self):
data = np.array([[3.1, 4.2],
[1.9, 2.8],
[5.4, 4.32],
[0.4, 3.4],
[4.4, 4.9],
[2.7, 6.2]])
df = DataFrame(data, columns = ['x', 'y'])
df['z'] = pd.Categorical(['a', 'b', 'c', 'a', 'b', 'c'], ordered=True)
ax = df.plot.scatter(x='x', y='y', s='z', size_factor=4)
bubbles = ax.collections[0]
bubble_sizes = bubbles.get_sizes()
max_data = df['z'].cat.codes.max() + 1
expected_sizes = 200 * 4 * (df['z'].cat.codes.values + 1) / max_data
tm.assert_numpy_array_equal(bubble_sizes, expected_sizes)

@pytest.mark.slow
def test_plot_bar(self):
df = DataFrame(randn(6, 4),
Expand Down