-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
Feat/scatter by size #17582
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat/scatter by size #17582
Changes from all commits
55733a3
d2d42e5
895afd8
9a86ce1
bc5adb4
84de8ab
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,7 +17,9 @@ | |
is_integer, | ||
is_number, | ||
is_hashable, | ||
is_iterator) | ||
is_iterator, | ||
is_numeric_dtype, | ||
is_categorical_dtype) | ||
from pandas.core.dtypes.generic import ABCSeries | ||
|
||
from pandas.core.common import AbstractMethodError, _try_sort | ||
|
@@ -815,11 +817,27 @@ def _post_plot_logic(self, ax, data): | |
class ScatterPlot(PlanePlot): | ||
_kind = 'scatter' | ||
|
||
def __init__(self, data, x, y, s=None, c=None, **kwargs): | ||
def __init__(self, data, x, y, s=None, c=None, size_factor=1, **kwargs): | ||
if s is None: | ||
# hide the matplotlib default for size, in case we want to change | ||
# the handling of this argument later | ||
# Set default size if no argument is given. | ||
s = 20 | ||
elif is_hashable(s) and s in data.columns: | ||
# Handle the case where s is a label of a column of the df. | ||
# The data is normalized to 200 * size_factor. | ||
size_data = data[s] | ||
if is_categorical_dtype(size_data): | ||
if size_data.cat.ordered: | ||
size_data = size_data.cat.codes + 1 | ||
else: | ||
raise TypeError("'s' must be numeric or ordered categorical dtype") | ||
if is_numeric_dtype(size_data): | ||
self.size_title = s | ||
self.s_data_max = size_data.max() | ||
self.size_factor = size_factor | ||
self.bubble_points = 200 | ||
s = self.bubble_points * size_factor * size_data / self.s_data_max | ||
else: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd like to support Categorical data here as well. How much work would it be to do that? Perhaps if is_categorical_dtype(size_data):
if size_data.ordered:
size_data = size_data.codes
# else raise with a nice error message and then the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @TomAugspurger I'm not sure I understand what it means to plot categorical data as sizes. Could you give me a use case example? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @VincentAntoine this basically, but with all the nice stuff from this PR df = pd.DataFrame(np.random.randn(100, 2), columns=['a', 'b'])
categories = list('abcd')
df['c'] = pd.Categorical(np.random.choice(categories, size=(100,)),
categories=categories, ordered=True)
df.head()
fig, ax = plt.subplots()
s = (10 + df.c.cat.codes * 10)
ax.scatter(x='a', y='b', data=df, s=s); So the idea is to automatically know to use
before There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, it works just as you said with no additionnal modification :) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We just need to do size_data = df[s].cat.codes + 1, as codes start at 0 and the resulting bubbles will have an area of 0. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @TomAugspurger I added this in the last commit, and added a test for scatter plot with categorical data as well. I'll write the release note now. Let me know if the code needs any more changes. |
||
raise TypeError("'s' must be numeric or ordered categorical dtype") | ||
super(ScatterPlot, self).__init__(data, x, y, s=s, **kwargs) | ||
if is_integer(c) and not self.data.columns.holds_integer(): | ||
c = self.data.columns[c] | ||
|
@@ -828,7 +846,6 @@ def __init__(self, data, x, y, s=None, c=None, **kwargs): | |
def _make_plot(self): | ||
x, y, c, data = self.x, self.y, self.c, self.data | ||
ax = self.axes[0] | ||
|
||
c_is_column = is_hashable(c) and c in self.data.columns | ||
|
||
# plot a colorbar only if a colormap is provided or necessary | ||
|
@@ -875,6 +892,68 @@ def _make_plot(self): | |
ax.errorbar(data[x].values, data[y].values, | ||
linestyle='none', **err_kwds) | ||
|
||
def _sci_notation(self, num): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Brief doc-string here. |
||
''' | ||
Returns mantissa and exponent of the number passed in agument. | ||
Example: | ||
_sci_notation(89278.8924) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this can just be a module level function There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does matplotlib have anything that does this? (cc @tacaswell) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, typically the docstring formatting is like >>> _sci_notation(89278.8924)
(8.9, 5.0) identical to what you get from the regular python REPL. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, but buried in the formatter code. I would not try to re-use it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we place this function as a module level function within plotting/_core as suggested by @jreback or is there another module that would be more appropriate ? Maybe plotting/_converter.py ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. module level in |
||
>>> (8.9, 5.0) | ||
''' | ||
scientific_notation = '{:e}'.format(num) | ||
expnt = float(re.search(r'e([+-]\d*)$', | ||
scientific_notation).groups()[0]) | ||
coef = float(re.search(r'^([+-]?\d\.\d)', | ||
scientific_notation).groups()[0]) | ||
return coef, expnt | ||
|
||
def _legend_bubbles(self, s_data_max, size_factor, bubble_points): | ||
''' | ||
Computes and returns appropriate bubble sizes and labels for the | ||
legend of a bubble plot. Creates 4 bubbles with round values for the | ||
labels, the largest of which is close to the maximum of the data. | ||
''' | ||
coef, expnt = self._sci_notation(s_data_max) | ||
labels_catalog = { | ||
(9, 10): [10, 5, 2.5, 1], | ||
(7, 9): [8, 4, 2, 0.5], | ||
(5.5, 7): [6, 3, 1.5, 0.5], | ||
(4.5, 5.5): [5, 2, 1, 0.2], | ||
(3.5, 4.5): [4, 2, 1, 0.2], | ||
(2.5, 3.5): [3, 1, 0.5, 0.2], | ||
(1.5, 2.5): [2, 1, 0.5, 0.2], | ||
(0, 1.5): [1, 0.5, 0.25, 0.1] | ||
} | ||
for lower_bound, upper_bound in labels_catalog: | ||
if (coef >= lower_bound) & (coef < upper_bound): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I did not know that worked :) |
||
labels = 10**expnt * np.array(labels_catalog[lower_bound, | ||
upper_bound]) | ||
sizes = list(bubble_points * size_factor * labels / s_data_max) | ||
labels = ['{:g}'.format(l) for l in labels] | ||
return (sizes, labels) | ||
|
||
def _make_legend(self): | ||
if hasattr(self, "size_title"): | ||
ax = self.axes[0] | ||
import matplotlib.legend as legend | ||
from matplotlib.collections import CircleCollection | ||
sizes, labels = self._legend_bubbles(self.s_data_max, | ||
self.size_factor, | ||
self.bubble_points) | ||
color = self.plt.rcParams['axes.facecolor'], | ||
edgecolor = self.plt.rcParams['axes.edgecolor'])) | ||
bubbles = [] | ||
for size in sizes: | ||
bubbles.append(CircleCollection(sizes=[size], | ||
color=color, | ||
edgecolor=edgecolor)) | ||
bubble_legend = legend.Legend(ax, | ||
handles=bubbles, | ||
labels=labels, | ||
loc='lower right') | ||
bubble_legend.set_title(self.size_title) | ||
ax.add_artist(bubble_legend) | ||
super()._make_legend() | ||
|
||
|
||
class HexBinPlot(PlanePlot): | ||
_kind = 'hexbin' | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you don't need to check hasability, other ops will fail if this is the case