diff --git a/.gitignore b/.gitignore index 986dd221d..d1a5f0c77 100644 --- a/.gitignore +++ b/.gitignore @@ -62,8 +62,9 @@ nosetests.xml .venv -# Do not include test output -/result_images/ + +# Do not include test output of matplotlib +result_images # Ignore Github codespace build artifact -oryx-build-commands.txt \ No newline at end of file +oryx-build-commands.txt diff --git a/CHANGELOG.md b/CHANGELOG.md index 21a910d88..c6abd903b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,8 +16,8 @@ * [Feature] Adds `--persist-replace` argument to replace existing tables when persisting data frames ([#440](https://github.com/ploomber/jupysql/issues/440)) * [Fix] Fix error when checking if custom connection was PEP 249 Compliant ([#517](https://github.com/ploomber/jupysql/issues/517)) * [Doc] documenting how to manage connections with `Connection` object ([#282](https://github.com/ploomber/jupysql/issues/282)) - * [Feature] Github Codespace (Devcontainer) support for development (by [@jorisroovers](https://github.com/jorisroovers)) ([#484](https://github.com/ploomber/jupysql/issues/484)) +* [Feature] Added bar plot and pie charts to %sqlplot ([#417](https://github.com/ploomber/jupysql/issues/417)) ## 0.7.5 (2023-05-24) diff --git a/doc/api/magic-plot.md b/doc/api/magic-plot.md index 4e03c425a..0ba67cb2c 100644 --- a/doc/api/magic-plot.md +++ b/doc/api/magic-plot.md @@ -160,3 +160,92 @@ ax = %sqlplot histogram --table no_nulls --column body_mass_g --with no_nulls ax.set_title("Body mass (grams)") _ = ax.grid() ``` +## `%sqlplot bar` + +```{versionadded} 0.7.6 +``` + +Shortcut: `%sqlplot bar` + +`-t`/`--table` Table to use (if using DuckDB: path to the file to query) + +`-c`/`--column` Column to plot. 
+ +`-o`/`--orient` Barplot orientation (`h` for horizontal, `v` for vertical) + +`-w`/`--with` Use a previously saved query as input data + +`-S`/`--show-numbers` Show numbers on top of the bar + +Bar plot does not support NULL values, so we automatically remove them when plotting. + +```{code-cell} ipython3 +%sqlplot bar --table penguins.csv --column species +``` + +You can additionally pass two columns to the bar plot, i.e. the `x` and `height` columns. + +```{code-cell} ipython3 +%%sql --save add_col --no-execute +SELECT species, count(species) as cnt +FROM penguins.csv +group by species +``` + +```{code-cell} ipython3 +%sqlplot bar --table add_col --column species cnt --with add_col +``` + +You can also pass the orientation using the `-o`/`--orient` argument. + +```{code-cell} ipython3 +%sqlplot bar --table add_col --column species cnt --with add_col --orient h +``` + +You can also show the number on top of the bar using the `-S`/`--show-numbers` argument. + +```{code-cell} ipython3 +%sqlplot bar --table penguins.csv --column species -S +``` + +## `%sqlplot pie` + +```{versionadded} 0.7.6 +``` + +Shortcut: `%sqlplot pie` + +`-t`/`--table` Table to use (if using DuckDB: path to the file to query) + +`-c`/`--column` Column to plot + +`-w`/`--with` Use a previously saved query as input data + +`-S`/`--show-numbers` Show the percentage on top of the pie + +Pie chart does not support NULL values, so we automatically remove them when plotting the pie chart. + +```{code-cell} ipython3 +%sqlplot pie --table penguins.csv --column species +``` + +You can additionally pass two columns to the pie chart, i.e. the `labels` and `x` columns. + +```{code-cell} ipython3 +%%sql --save add_col --no-execute +SELECT species, count(species) as cnt +FROM penguins.csv +group by species +``` + +```{code-cell} ipython3 +%sqlplot pie --table add_col --column species cnt --with add_col +``` +Here, `species` is the `labels` column and `cnt` is the `x` column. 
+ + +You can also show the percentage on top of the pie using the `S`/`show-numbers` argument. + +```{code-cell} ipython3 +%sqlplot pie --table penguins.csv --column species -S +``` \ No newline at end of file diff --git a/src/sql/magic_plot.py b/src/sql/magic_plot.py index 03fcae30e..c2c43e774 100644 --- a/src/sql/magic_plot.py +++ b/src/sql/magic_plot.py @@ -17,6 +17,8 @@ from sql import exceptions from sql import util +SUPPORTED_PLOTS = ["histogram", "boxplot", "bar", "pie"] + @magics_class class SqlPlotMagic(Magics, Configurable): @@ -51,6 +53,12 @@ class SqlPlotMagic(Magics, Configurable): action="append", dest="with_", ) + @argument( + "-S", + "--show-numbers", + action="store_true", + help="Show number of observations", + ) @modify_exceptions def execute(self, line="", cell="", local_ns=None): """ @@ -65,8 +73,10 @@ def execute(self, line="", cell="", local_ns=None): column = cmd.args.column if not cmd.args.line: + plot_str = util.pretty_print(SUPPORTED_PLOTS, last_delimiter="or") raise exceptions.UsageError( - "Missing the first argument, must be: 'histogram' or 'boxplot'. " + "Missing the first argument, must be any of: " + f"{plot_str}\n" "Example: %sqlplot histogram" ) @@ -92,7 +102,29 @@ def execute(self, line="", cell="", local_ns=None): with_=cmd.args.with_, conn=None, ) + elif cmd.args.line[0] in {"bar"}: + util.is_table_exists(table, with_=cmd.args.with_) + + return plot.bar( + table=table, + column=column, + with_=cmd.args.with_, + orient=cmd.args.orient, + show_num=cmd.args.show_numbers, + conn=None, + ) + elif cmd.args.line[0] in {"pie"}: + util.is_table_exists(table, with_=cmd.args.with_) + + return plot.pie( + table=table, + column=column, + with_=cmd.args.with_, + show_num=cmd.args.show_numbers, + conn=None, + ) else: + plot_str = util.pretty_print(SUPPORTED_PLOTS, last_delimiter="or") raise exceptions.UsageError( - f"Unknown plot {cmd.args.line[0]!r}. Must be: 'histogram' or 'boxplot'" + f"Unknown plot {cmd.args.line[0]!r}. 
Must be any of: " f"{plot_str}"
             )
diff --git a/src/sql/plot.py b/src/sql/plot.py
index e6c55e842..ac572ae6b 100644
--- a/src/sql/plot.py
+++ b/src/sql/plot.py
@@ -605,3 +605,431 @@ def _histogram_stacked(
     data = conn.execute(query, with_).fetchall()
 
     return data
+
+
+@modify_exceptions
+def _bar(table, column, with_=None, conn=None):
+    """Query the x values and bar heights for a bar plot
+
+    Parameters
+    ----------
+    table : str
+        Table name where the data is located
+
+    column : str or list
+        A single column (rows are grouped by it and counted) or a list of
+        two columns: x and height
+
+    with_ : list, default=None
+        Names of previously saved queries to use as input data
+
+    conn : connection, default=None
+        Database connection. If None, it uses the current connection
+
+    Returns
+    -------
+    x, height, xlabel, ylabel
+        Tuples with the bar positions and heights, plus the axis labels
+
+    Raises
+    ------
+    UsageError
+        If a list is passed that does not have one or two columns
+    ValueError
+        If the query returns no rows, or if the first x value is NULL
+    """
+    if not conn:
+        conn = sql.connection.Connection.current
+    use_backticks = conn.is_use_backtick_template()
+
+    # a one-element list means "single column"; unwrap it so that indexing
+    # column[1] below cannot fail
+    if isinstance(column, list) and len(column) == 1:
+        column = column[0]
+
+    if isinstance(column, list):
+        if len(column) != 2:
+            raise exceptions.UsageError(
+                f"Passed columns: {column}\n"
+                "Bar chart currently supports, either a single column"
+                " on which group by and count is applied or"
+                " two columns: x and height"
+            )
+
+        x_ = column[0]
+        height_ = column[1]
+
+        print(f"Removing NULLs, if there exists any from {x_} and {height_}")
+        template_ = """
+            select "{{x_}}" as x,
+            "{{height_}}" as height
+            from "{{table}}"
+            where "{{x_}}" is not null
+            and "{{height_}}" is not null;
+            """
+
+        xlabel = x_
+        ylabel = height_
+
+        if use_backticks:
+            template_ = template_.replace('"', "`")
+
+        template = Template(template_)
+        query = template.render(table=table, x_=x_, height_=height_)
+
+    else:
+        print(f"Removing NULLs, if there exists any from {column}")
+        template_ = """
+            select "{{column}}" as x,
+            count("{{column}}") as height
+            from "{{table}}"
+            where "{{column}}" is not null
+            group by "{{column}}";
+            """
+
+        xlabel = column
+        ylabel = "Count"
+
+        if use_backticks:
+            template_ = template_.replace('"', "`")
+
+        template = Template(template_)
+        query = template.render(table=table, column=column)
+
+    data = conn.execute(query, with_).fetchall()
+
+    # zip(*[]) yields nothing, so unpacking an empty result would fail with
+    # an unhelpful "not enough values to unpack" error
+    if not data:
+        raise ValueError("Query returned no rows, there is nothing to plot")
+
+    x, height = zip(*data)
+
+    if x[0] is None:
+        raise ValueError("Data contains NULLs")
+
+    return x, height, xlabel, ylabel
+
+
+@requires(["matplotlib"])
+@telemetry.log_call("bar", payload=True)
+def bar(
+    payload,
+    table,
+    column,
+    show_num=False,
+    orient="v",
+    with_=None,
+    conn=None,
+    cmap=None,
+    color=None,
+    edgecolor=None,
+    ax=None,
+):
+    """Plot Bar Chart
+
+    Parameters
+    ----------
+    table : str
+        Table name where the data is located
+
+    column : str or list
+        Column(s) to plot: a single column (its values are counted) or a
+        list of two columns (x and height)
+
+    show_num: bool
+        Show numbers on top of plot
+
+    orient : str, default='v'
+        Orientation of the plot. 'v' for vertical and 'h' for horizontal
+
+    with_ : list, default=None
+        Names of previously saved queries to use as input data
+
+    conn : connection, default=None
+        Database connection. If None, it uses the current connection
+
+    cmap : default=None
+        Colormap (anything accepted by ``plt.get_cmap``) used to pick one
+        color per bar. Ignored (with a warning) if color is also passed
+
+    color : default=None
+        Bar color(s), passed to matplotlib as is
+
+    edgecolor : default=None
+        Bar edge color, passed to matplotlib as is
+
+    ax : matplotlib.Axes, default=None
+        Axes to draw on. If None, it uses the current Axes
+
+    Notes
+    -----
+
+    .. versionadded:: 0.7.6
+
+    Returns
+    -------
+    ax : matplotlib.Axes
+        Generated plot
+
+    """
+
+    if not conn:
+        conn = sql.connection.Connection.current
+
+    ax = ax or plt.gca()
+    payload["connection_info"] = conn._get_curr_sqlalchemy_connection_info()
+
+    if column is None:
+        raise exceptions.UsageError("Column name has not been specified")
+
+    x, height_, xlabel, ylabel = _bar(table, column, with_=with_, conn=conn)
+
+    if color and cmap:
+        # raise a userwarning
+        warnings.warn(
+            "Both color and cmap are given. cmap will be ignored", UserWarning
+        )
+
+    if (not color) and cmap:
+        cmap = plt.get_cmap(cmap)
+        norm = Normalize(vmin=0, vmax=len(x))
+        color = [cmap(norm(i)) for i in range(len(x))]
+
+    # decide the orientation once so that the bars and the numbers agree
+    horizontal = orient == "h"
+
+    if horizontal:
+        bars = ax.barh(
+            x,
+            height_,
+            align="center",
+            edgecolor=edgecolor,
+            color=color,
+        )
+        ax.set_xlabel(ylabel)
+        ax.set_ylabel(xlabel)
+    else:
+        bars = ax.bar(
+            x,
+            height_,
+            align="center",
+            edgecolor=edgecolor,
+            color=color,
+        )
+        ax.set_ylabel(ylabel)
+        ax.set_xlabel(xlabel)
+
+    if show_num:
+        # anchor each number on its bar's rectangle; the loop index only
+        # matches the bar position for categorical x values
+        for rect, v in zip(bars, height_):
+            if horizontal:
+                ax.text(
+                    v,
+                    rect.get_y() + rect.get_height() / 2,
+                    str(v),
+                    color="black",
+                    fontweight="bold",
+                    ha="left",
+                    va="center",
+                )
+            else:
+                ax.text(
+                    rect.get_x() + rect.get_width() / 2,
+                    v,
+                    str(v),
+                    color="black",
+                    fontweight="bold",
+                    ha="center",
+                    va="bottom",
+                )
+
+    ax.set_title(table)
+
+    return ax
+
+
+@modify_exceptions
+def _pie(table, column, with_=None, conn=None):
+    """Query the labels and wedge sizes for a pie chart
+
+    Parameters
+    ----------
+    table : str
+        Table name where the data is located
+
+    column : str or list
+        A single column (rows are grouped by it and counted) or a list of
+        two columns: labels and size
+
+    with_ : list, default=None
+        Names of previously saved queries to use as input data
+
+    conn : connection, default=None
+        Database connection. If None, it uses the current connection
+
+    Returns
+    -------
+    labels, size
+        Tuples with the wedge labels and the wedge sizes
+
+    Raises
+    ------
+    UsageError
+        If a list is passed that does not have one or two columns
+    ValueError
+        If the query returns no rows, or if the first label is NULL
+    """
+    if not conn:
+        conn = sql.connection.Connection.current
+    use_backticks = conn.is_use_backtick_template()
+
+    # a one-element list means "single column"; unwrap it so that indexing
+    # column[1] below cannot fail
+    if isinstance(column, list) and len(column) == 1:
+        column = column[0]
+
+    if isinstance(column, list):
+        if len(column) != 2:
+            raise exceptions.UsageError(
+                f"Passed columns: {column}\n"
+                "Pie chart currently supports, either a single column"
+                " on which group by and count is applied or"
+                " two columns: labels and size"
+            )
+
+        labels_ = column[0]
+        size_ = column[1]
+
+        print(f"Removing NULLs, if there exists any from {labels_} and {size_}")
+        template_ = """
+            select "{{labels_}}" as labels,
+            "{{size_}}" as size
+            from "{{table}}"
+            where "{{labels_}}" is not null
+            and "{{size_}}" is not null;
+            """
+        if use_backticks:
+            template_ = template_.replace('"', "`")
+
+        template = Template(template_)
+        query = template.render(table=table, labels_=labels_, size_=size_)
+
+    else:
+        print(f"Removing NULLs, if there exists any from {column}")
+        template_ = """
+            select "{{column}}" as x,
+            count("{{column}}") as height
+            from "{{table}}"
+            where "{{column}}" is not null
+            group by "{{column}}";
+            """
+        if use_backticks:
+            template_ = template_.replace('"', "`")
+
+        template = Template(template_)
+        query = template.render(table=table, column=column)
+
+    data = conn.execute(query, with_).fetchall()
+
+    # zip(*[]) yields nothing, so unpacking an empty result would fail with
+    # an unhelpful "not enough values to unpack" error
+    if not data:
+        raise ValueError("Query returned no rows, there is nothing to plot")
+
+    labels, size = zip(*data)
+
+    if labels[0] is None:
+        raise ValueError("Data contains NULLs")
+
+    return labels, size
+
+
+@requires(["matplotlib"])
+@telemetry.log_call("pie", payload=True)
+def pie(
+    payload,
+    table,
+    column,
+    show_num=False,
+    with_=None,
+    conn=None,
+    cmap=None,
+    color=None,
+    ax=None,
+):
+    """Plot Pie Chart
+
+    Parameters
+    ----------
+    table : str
+        Table name where the data is located
+
+    column : str or list
+        Column(s) to plot: a single column (its values are counted) or a
+        list of two columns (labels and size)
+
+    show_num: bool
+        Show the percentage on each wedge
+
+    with_ : list, default=None
+        Names of previously saved queries to use as input data
+
+    conn : connection, default=None
+        Database connection. If None, it uses the current connection
+
+    cmap : default=None
+        Colormap (anything accepted by ``plt.get_cmap``) used to pick one
+        color per wedge. Ignored (with a warning) if color is also passed
+
+    color : default=None
+        Wedge colors, passed to matplotlib as is
+
+    ax : matplotlib.Axes, default=None
+        Axes to draw on. If None, it uses the current Axes
+
+    Notes
+    -----
+
+    .. versionadded:: 0.7.6
+
+    Returns
+    -------
+    ax : matplotlib.Axes
+        Generated plot
+    """
+
+    if not conn:
+        conn = sql.connection.Connection.current
+
+    ax = ax or plt.gca()
+    payload["connection_info"] = conn._get_curr_sqlalchemy_connection_info()
+
+    if column is None:
+        raise exceptions.UsageError("Column name has not been specified")
+
+    labels, size_ = _pie(table, column, with_=with_, conn=conn)
+
+    if color and cmap:
+        # raise a userwarning
+        warnings.warn(
+            "Both color and cmap are given. cmap will be ignored", UserWarning
+        )
+
+    if (not color) and cmap:
+        cmap = plt.get_cmap(cmap)
+        norm = Normalize(vmin=0, vmax=len(labels))
+        color = [cmap(norm(i)) for i in range(len(labels))]
+
+    ax.pie(
+        size_,
+        labels=labels,
+        colors=color,
+        # autopct=None (matplotlib's default) draws no percentages
+        autopct="%1.2f%%" if show_num else None,
+    )
+
+    ax.set_title(table)
+
+    return ax
diff --git a/src/tests/baseline_images/test_magic_plot/bar_one_col.png b/src/tests/baseline_images/test_magic_plot/bar_one_col.png
new file mode 100644
index 000000000..9096669d6
Binary files /dev/null and b/src/tests/baseline_images/test_magic_plot/bar_one_col.png differ
diff --git a/src/tests/baseline_images/test_magic_plot/bar_one_col_h.png b/src/tests/baseline_images/test_magic_plot/bar_one_col_h.png
new file mode 100644
index 000000000..1fb31d680
Binary files /dev/null and b/src/tests/baseline_images/test_magic_plot/bar_one_col_h.png differ
diff --git a/src/tests/baseline_images/test_magic_plot/bar_one_col_null.png b/src/tests/baseline_images/test_magic_plot/bar_one_col_null.png
new file mode 100644
index 000000000..9096669d6
Binary files /dev/null and b/src/tests/baseline_images/test_magic_plot/bar_one_col_null.png differ
diff --git a/src/tests/baseline_images/test_magic_plot/bar_one_col_num_h.png b/src/tests/baseline_images/test_magic_plot/bar_one_col_num_h.png
new file mode 100644
index 000000000..c09a93f02
Binary files /dev/null and 
b/src/tests/baseline_images/test_magic_plot/bar_one_col_num_h.png differ diff --git a/src/tests/baseline_images/test_magic_plot/bar_one_col_num_v.png b/src/tests/baseline_images/test_magic_plot/bar_one_col_num_v.png new file mode 100644 index 000000000..4b482d7c5 Binary files /dev/null and b/src/tests/baseline_images/test_magic_plot/bar_one_col_num_v.png differ diff --git a/src/tests/baseline_images/test_magic_plot/bar_two_col.png b/src/tests/baseline_images/test_magic_plot/bar_two_col.png new file mode 100644 index 000000000..2798537e2 Binary files /dev/null and b/src/tests/baseline_images/test_magic_plot/bar_two_col.png differ diff --git a/src/tests/baseline_images/test_magic_plot/pie_one_col.png b/src/tests/baseline_images/test_magic_plot/pie_one_col.png new file mode 100644 index 000000000..75d1a2669 Binary files /dev/null and b/src/tests/baseline_images/test_magic_plot/pie_one_col.png differ diff --git a/src/tests/baseline_images/test_magic_plot/pie_one_col_null.png b/src/tests/baseline_images/test_magic_plot/pie_one_col_null.png new file mode 100644 index 000000000..75d1a2669 Binary files /dev/null and b/src/tests/baseline_images/test_magic_plot/pie_one_col_null.png differ diff --git a/src/tests/baseline_images/test_magic_plot/pie_one_col_num.png b/src/tests/baseline_images/test_magic_plot/pie_one_col_num.png new file mode 100644 index 000000000..c3ce3f31a Binary files /dev/null and b/src/tests/baseline_images/test_magic_plot/pie_one_col_num.png differ diff --git a/src/tests/baseline_images/test_magic_plot/pie_two_col.png b/src/tests/baseline_images/test_magic_plot/pie_two_col.png new file mode 100644 index 000000000..fd65c2dc1 Binary files /dev/null and b/src/tests/baseline_images/test_magic_plot/pie_two_col.png differ diff --git a/src/tests/test_magic_plot.py b/src/tests/test_magic_plot.py index 780512092..5f78373c4 100644 --- a/src/tests/test_magic_plot.py +++ b/src/tests/test_magic_plot.py @@ -3,6 +3,12 @@ import pytest from IPython.core.error import 
UsageError import matplotlib.pyplot as plt +from sql import util + +from matplotlib.testing.decorators import image_comparison, _cleanup_cm + +SUPPORTED_PLOTS = ["bar", "boxplot", "histogram", "pie"] +plot_str = util.pretty_print(SUPPORTED_PLOTS, last_delimiter="or") @pytest.mark.parametrize( @@ -11,12 +17,12 @@ [ "%sqlplot someplot -t a -c b", UsageError, - "Unknown plot 'someplot'. Must be: 'histogram' or 'boxplot'", + f"Unknown plot 'someplot'. Must be any of: {plot_str}", ], [ "%sqlplot -t a -c b", UsageError, - "Missing the first argument, must be: 'histogram' or 'boxplot'", + f"Missing the first argument, must be any of: {plot_str}", ], ], ) @@ -49,7 +55,7 @@ def test_validate_arguments(tmp_empty, ip, cell, error_type, error_message): assert str(out.error_in_exec) == (error_message) -@pytest.mark.xfail() +@_cleanup_cm() @pytest.mark.parametrize( "cell", [ @@ -66,29 +72,40 @@ def test_validate_arguments(tmp_empty, ip, cell, error_type, error_message): "%sqlplot boxplot --table subset --column x --with subset", "%sqlplot boxplot -t subset -c x -w subset -o h", "%sqlplot boxplot --table nas.csv --column x", + "%sqlplot bar -t data.csv -c x", + "%sqlplot bar -t data.csv -c x -S", + "%sqlplot bar -t data.csv -c x -o h", + "%sqlplot bar -t data.csv -c x y", + "%sqlplot pie -t data.csv -c x", + "%sqlplot pie -t data.csv -c x -S", + "%sqlplot pie -t data.csv -c x y", + '%sqlplot boxplot --table spaces.csv --column "some column"', + '%sqlplot histogram --table spaces.csv --column "some column"', + '%sqlplot bar --table spaces.csv --column "some column"', + '%sqlplot pie --table spaces.csv --column "some column"', pytest.param( - "%sqlplot boxplot --table spaces.csv --column 'some column'", + "%sqlplot boxplot --table 'file with spaces.csv' --column x", marks=pytest.mark.xfail( sys.platform == "win32", reason="problem in IPython.core.magic_arguments.parse_argstring", ), ), pytest.param( - "%sqlplot histogram --table spaces.csv --column 'some column'", + "%sqlplot 
histogram --table 'file with spaces.csv' --column x", marks=pytest.mark.xfail( sys.platform == "win32", reason="problem in IPython.core.magic_arguments.parse_argstring", ), ), pytest.param( - "%sqlplot boxplot --table 'file with spaces.csv' --column x", + "%sqlplot bar --table 'file with spaces.csv' --column x", marks=pytest.mark.xfail( sys.platform == "win32", reason="problem in IPython.core.magic_arguments.parse_argstring", ), ), pytest.param( - "%sqlplot histogram --table 'file with spaces.csv' --column x", + "%sqlplot pie --table 'file with spaces.csv' --column x", marks=pytest.mark.xfail( sys.platform == "win32", reason="problem in IPython.core.magic_arguments.parse_argstring", @@ -106,10 +123,21 @@ def test_validate_arguments(tmp_empty, ip, cell, error_type, error_message): "boxplot-with", "boxplot-shortcuts", "boxplot-nas", + "bar-1-col", + "bar-1-col-show_num", + "bar-1-col-horizontal", + "bar-2-col", + "pie-1-col", + "pie-1-col-show_num", + "pie-2-col", "boxplot-column-name-with-spaces", "histogram-column-name-with-spaces", + "bar-column-name-with-spaces", + "pie-column-name-with-spaces", "boxplot-table-name-with-spaces", "histogram-table-name-with-spaces", + "bar-table-name-with-spaces", + "pie-table-name-with-spaces", ], ) def test_sqlplot(tmp_empty, ip, cell): @@ -166,3 +194,127 @@ def test_sqlplot(tmp_empty, ip, cell): # maptlotlib >= 3.7 has Axes but earlier Python # versions are not compatible assert type(out.result).__name__ in {"Axes", "AxesSubplot"} + + +@pytest.fixture +def load_data_two_col(ip): + if not Path("data_two.csv").is_file(): + Path("data_two.csv").write_text( + """\ +x, y +0, 0 +1, 1 +2, 2 +5, 7""" + ) + + ip.run_cell("%sql duckdb://") + + +@pytest.fixture +def load_data_one_col(ip): + if not Path("data_one.csv").is_file(): + Path("data_one.csv").write_text( + """\ +x +0 +0 +1 +1 +1 +2 +""" + ) + ip.run_cell("%sql duckdb://") + + +@pytest.fixture +def load_data_one_col_null(ip): + if not Path("data_one_null.csv").is_file(): + 
Path("data_one_null.csv").write_text( + """\ +x + +0 + +0 +1 + +1 +1 +2 +""" + ) + ip.run_cell("%sql duckdb://") + + +@_cleanup_cm() +@image_comparison(baseline_images=["bar_one_col"], extensions=["png"], remove_text=True) +def test_bar_one_col(load_data_one_col, ip): + ip.run_cell("%sqlplot bar -t data_one.csv -c x") + + +@_cleanup_cm() +@image_comparison( + baseline_images=["bar_one_col_null"], extensions=["png"], remove_text=True +) +def test_bar_one_col_null(load_data_one_col_null, ip): + ip.run_cell("%sqlplot bar -t data_one_null.csv -c x") + + +@_cleanup_cm() +@image_comparison( + baseline_images=["bar_one_col_h"], extensions=["png"], remove_text=True +) +def test_bar_one_col_h(load_data_one_col, ip): + ip.run_cell("%sqlplot bar -t data_one.csv -c x -o h") + + +@_cleanup_cm() +@image_comparison( + baseline_images=["bar_one_col_num_h"], extensions=["png"], remove_text=True +) +def test_bar_one_col_num_h(load_data_one_col, ip): + ip.run_cell("%sqlplot bar -t data_one.csv -c x -o h -S") + + +@_cleanup_cm() +@image_comparison( + baseline_images=["bar_one_col_num_v"], extensions=["png"], remove_text=True +) +def test_bar_one_col_num_v(load_data_one_col, ip): + ip.run_cell("%sqlplot bar -t data_one.csv -c x -S") + + +@_cleanup_cm() +@image_comparison(baseline_images=["bar_two_col"], extensions=["png"], remove_text=True) +def test_bar_two_col(load_data_two_col, ip): + ip.run_cell("%sqlplot bar -t data_two.csv -c x y") + + +@_cleanup_cm() +@image_comparison(baseline_images=["pie_one_col"], extensions=["png"], remove_text=True) +def test_pie_one_col(load_data_one_col, ip): + ip.run_cell("%sqlplot pie -t data_one.csv -c x") + + +@_cleanup_cm() +@image_comparison( + baseline_images=["pie_one_col_null"], extensions=["png"], remove_text=True +) +def test_pie_one_col_null(load_data_one_col_null, ip): + ip.run_cell("%sqlplot pie -t data_one_null.csv -c x") + + +@_cleanup_cm() +@image_comparison( + baseline_images=["pie_one_col_num"], extensions=["png"], remove_text=True +) 
+def test_pie_one_col_num(load_data_one_col, ip): + ip.run_cell("%sqlplot pie -t data_one.csv -c x -S") + + +@_cleanup_cm() +@image_comparison(baseline_images=["pie_two_col"], extensions=["png"], remove_text=True) +def test_pie_two_col(load_data_two_col, ip): + ip.run_cell("%sqlplot pie -t data_two.csv -c x y")