diff --git a/pandas/plotting/_core.py b/pandas/plotting/_core.py index e5b9497993172..338155633f6b3 100644 --- a/pandas/plotting/_core.py +++ b/pandas/plotting/_core.py @@ -17,7 +17,9 @@ is_integer, is_number, is_hashable, - is_iterator) + is_iterator, + is_numeric_dtype, + is_categorical_dtype) from pandas.core.dtypes.generic import ABCSeries from pandas.core.common import AbstractMethodError, _try_sort @@ -815,11 +817,27 @@ def _post_plot_logic(self, ax, data): class ScatterPlot(PlanePlot): _kind = 'scatter' - def __init__(self, data, x, y, s=None, c=None, **kwargs): + def __init__(self, data, x, y, s=None, c=None, size_factor=1, **kwargs): if s is None: - # hide the matplotlib default for size, in case we want to change - # the handling of this argument later + # Set default size if no argument is given. s = 20 + elif is_hashable(s) and s in data.columns: + # Handle the case where s is a label of a column of the df. + # The data is normalized to 200 * size_factor. + size_data = data[s] + if is_categorical_dtype(size_data): + if size_data.cat.ordered: + size_data = size_data.cat.codes + 1 + else: + raise TypeError("'s' must be numeric or ordered categorical dtype") + if is_numeric_dtype(size_data): + self.size_title = s + self.s_data_max = size_data.max() + self.size_factor = size_factor + self.bubble_points = 200 + s = self.bubble_points * size_factor * size_data / self.s_data_max + else: + raise TypeError("'s' must be numeric or ordered categorical dtype") super(ScatterPlot, self).__init__(data, x, y, s=s, **kwargs) if is_integer(c) and not self.data.columns.holds_integer(): c = self.data.columns[c] @@ -828,7 +846,6 @@ def __init__(self, data, x, y, s=None, c=None, **kwargs): def _make_plot(self): x, y, c, data = self.x, self.y, self.c, self.data ax = self.axes[0] - c_is_column = is_hashable(c) and c in self.data.columns # plot a colorbar only if a colormap is provided or necessary @@ -875,6 +892,68 @@ def _make_plot(self): ax.errorbar(data[x].values, data[y].values, linestyle='none', **err_kwds) + def _sci_notation(self, num): + ''' + Returns mantissa and exponent of the number passed in agument. + Example: + _sci_notation(89278.8924) + >>> (8.9, 5.0) + ''' + scientific_notation = '{:e}'.format(num) + expnt = float(re.search(r'e([+-]\d*)$', + scientific_notation).groups()[0]) + coef = float(re.search(r'^([+-]?\d\.\d)', + scientific_notation).groups()[0]) + return coef, expnt + + def _legend_bubbles(self, s_data_max, size_factor, bubble_points): + ''' + Computes and returns appropriate bubble sizes and labels for the + legend of a bubble plot. Creates 4 bubbles with round values for the + labels, the largest of which is close to the maximum of the data. + ''' + coef, expnt = self._sci_notation(s_data_max) + labels_catalog = { + (9, 10): [10, 5, 2.5, 1], + (7, 9): [8, 4, 2, 0.5], + (5.5, 7): [6, 3, 1.5, 0.5], + (4.5, 5.5): [5, 2, 1, 0.2], + (3.5, 4.5): [4, 2, 1, 0.2], + (2.5, 3.5): [3, 1, 0.5, 0.2], + (1.5, 2.5): [2, 1, 0.5, 0.2], + (0, 1.5): [1, 0.5, 0.25, 0.1] + } + for lower_bound, upper_bound in labels_catalog: + if (coef >= lower_bound) & (coef < upper_bound): + labels = 10**expnt * np.array(labels_catalog[lower_bound, + upper_bound]) + sizes = list(bubble_points * size_factor * labels / s_data_max) + labels = ['{:g}'.format(l) for l in labels] + return (sizes, labels) + + def _make_legend(self): + if hasattr(self, "size_title"): + ax = self.axes[0] + import matplotlib.legend as legend + from matplotlib.collections import CircleCollection + sizes, labels = self._legend_bubbles(self.s_data_max, + self.size_factor, + self.bubble_points) + color = self.plt.rcParams['axes.facecolor'], + edgecolor = self.plt.rcParams['axes.edgecolor'])) + bubbles = [] + for size in sizes: + bubbles.append(CircleCollection(sizes=[size], + color=color, + edgecolor=edgecolor)) + bubble_legend = legend.Legend(ax, + handles=bubbles, + labels=labels, + loc='lower right') + bubble_legend.set_title(self.size_title) + ax.add_artist(bubble_legend) + super()._make_legend() + class HexBinPlot(PlanePlot): _kind = 'hexbin' diff --git a/pandas/tests/plotting/test_frame.py b/pandas/tests/plotting/test_frame.py index 67098529a0111..86c93d193d593 100644 --- a/pandas/tests/plotting/test_frame.py +++ b/pandas/tests/plotting/test_frame.py @@ -1005,6 +1005,40 @@ def test_scatter_colors(self): tm.assert_numpy_array_equal(ax.collections[0].get_facecolor()[0], np.array([1, 1, 1, 1], dtype=np.float64)) + @pytest.mark.slow + def test_plot_scatter_with_s(self): + data = np.array([[3.1, 4.2, 1.9], + [1.9, 2.8, 3.1], + [5.4, 4.32, 2.0], + [0.4, 3.4, 0.46], + [4.4, 4.9, 0.8], + [2.7, 6.2, 1.49]]) + df = DataFrame(data, + columns = ['x', 'y', 'z']) + ax = df.plot.scatter(x='x', y='y', s='z', size_factor=4) + bubbles = ax.collections[0] + bubble_sizes = bubbles.get_sizes() + max_data = df['z'].max() + expected_sizes = 200 * 4 * df['z'].values / max_data + tm.assert_numpy_array_equal(bubble_sizes, expected_sizes) + + @pytest.mark.slow + def test_plot_scatter_with_categorical_s(self): + data = np.array([[3.1, 4.2], + [1.9, 2.8], + [5.4, 4.32], + [0.4, 3.4], + [4.4, 4.9], + [2.7, 6.2]]) + df = DataFrame(data, columns = ['x', 'y']) + df['z'] = pd.Categorical(['a', 'b', 'c', 'a', 'b', 'c'], ordered=True) + ax = df.plot.scatter(x='x', y='y', s='z', size_factor=4) + bubbles = ax.collections[0] + bubble_sizes = bubbles.get_sizes() + max_data = df['z'].cat.codes.max() + 1 + expected_sizes = 200 * 4 * (df['z'].cat.codes.values + 1) / max_data + tm.assert_numpy_array_equal(bubble_sizes, expected_sizes) + @pytest.mark.slow def test_plot_bar(self): df = DataFrame(randn(6, 4),