Skip to content

Commit

Permalink
Preliminary implementation of the "preserve comments" feature
Browse files Browse the repository at this point in the history
This implements the basic behaviour requested in issue #23: it isn't neither perfect nor
beautiful, but... what is?
  • Loading branch information
lelit committed May 3, 2021
1 parent bdf573a commit 649ecb5
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 19 deletions.
38 changes: 27 additions & 11 deletions docs/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -346,8 +346,12 @@ In this case, you can use a variant that uses the lexical *scanner* instead:
select 1 from
select 2

Reformat a ``SQL`` statement from the command line
==================================================
------------
Command line
------------

Reformat a ``SQL`` statement
============================

.. code-block:: shell
Expand All @@ -357,8 +361,7 @@ Reformat a ``SQL`` statement from the command line
, c
FROM sometable
$ echo "select a, case when a=1 then 'singular' else 'plural' end from test" > /tmp/q.sql
$ pgpp /tmp/q.sql
$ pgpp -S "select a, case when a=1 then 'singular' else 'plural' end from test"
SELECT a
, CASE
WHEN (a = 1)
Expand Down Expand Up @@ -386,27 +389,25 @@ Get a more compact representation

.. code-block:: shell
$ echo "select a,b,c from st where a='longvalue1' and b='longvalue2'" \
| pgpp --compact 30
$ pgpp --compact 30 -S "select a,b,c from st where a='longvalue1' and b='longvalue2'"
SELECT a, b, c
FROM st
WHERE (a = 'longvalue1')
AND (b = 'longvalue2')
.. code-block:: shell
$ echo "select a,b,c from st where a='longvalue1' and b='longvalue2'" \
| pgpp --compact 60
$ pgpp --compact 60 -S "select a,b,c from st where a='longvalue1' and b='longvalue2'"
SELECT a, b, c
FROM st
WHERE (a = 'longvalue1') AND (b = 'longvalue2')
Obtain the *parse tree* of a ``SQL`` statement from the command line
====================================================================
Obtain the *parse tree* of a ``SQL`` statement
==============================================

.. code-block:: shell
$ echo "select 1" | pgpp --parse-tree
$ pgpp --parse-tree --statement "select 1"
[{'@': 'RawStmt',
'stmt': {'@': 'SelectStmt',
'all': False,
Expand All @@ -420,6 +421,21 @@ Obtain the *parse tree* of a ``SQL`` statement from the command line
'stmt_len': 0,
'stmt_location': 0}]
Preserve comments
=================

.. code-block:: shell
$ pgpp --preserve-comments -S "/* Header */ select 1"
/* Header */
SELECT 1
.. code-block:: shell
$ echo -e "--what?\nselect foo\n--where?\nfrom bar" | pgpp -C
/* what? */
SELECT foo
FROM /* where? */ bar
---

Expand Down
11 changes: 9 additions & 2 deletions pglast/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .error import Error
from .node import Missing, Node
try:
from .parser import fingerprint, get_postgresql_version, parse_sql, split
from .parser import fingerprint, get_postgresql_version, parse_sql, scan, split
except ModuleNotFoundError:
# bootstrap
pass
Expand All @@ -31,12 +31,13 @@ def parse_plpgsql(statement):
return loads(parse_plpgsql_json(statement))


def prettify(statement, safety_belt=True, **options):
def prettify(statement, safety_belt=True, preserve_comments=False, **options):
r"""Render given `statement` into a prettified format.
:param str statement: the SQL statement(s)
:param bool safety_belt: whether to perform a safe check against bugs in pglast's
serialization
:param bool preserve_comments: whether comments shall be preserved, defaults to not
:param \*\*options: any keyword option accepted by :class:`~.printer.IndentedStream`
constructor
:returns: a string with the equivalent prettified statement(s)
Expand All @@ -53,6 +54,12 @@ def prettify(statement, safety_belt=True, **options):
from .printer import IndentedStream
from . import printers # noqa

if preserve_comments:
comments = options['comments'] = []
for token in scan(statement):
if token.name in ('C_COMMENT', 'SQL_COMMENT'):
comments.append((token.start, statement[token.start:token.end+1]))

orig_pt = parse_sql(statement)
prettified = IndentedStream(**options)(Node(orig_pt))
if safety_belt:
Expand Down
3 changes: 3 additions & 0 deletions pglast/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def workhorse(args):
try:
prettified = prettify(
statement,
preserve_comments=args.preserve_comments,
compact_lists_margin=args.compact_lists_margin,
split_string_literals_threshold=args.split_string_literals,
special_functions=args.special_functions,
Expand Down Expand Up @@ -77,6 +78,8 @@ def main(options=None):
' after each item')
parser.add_argument('-e', '--semicolon-after-last-statement', action='store_true',
default=False, help='end the last statement with a semicolon')
parser.add_argument('-C', '--preserve-comments', action='store_true',
default=False, help="preserve comments in the statement")
parser.add_argument('-S', '--statement',
help='the SQL statement')
parser.add_argument('infile', nargs='?', type=argparse.FileType(),
Expand Down
60 changes: 56 additions & 4 deletions pglast/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from . import parse_plpgsql, parse_sql
from .error import Error
from .node import List, Node, Scalar
from .node import List, Missing, Node, Scalar
from .keywords import RESERVED_KEYWORDS, TYPE_FUNC_NAME_KEYWORDS


Expand Down Expand Up @@ -197,21 +197,24 @@ class RawStream(OutputStream):
:param bool semicolon_after_last_statement:
``False`` by default, when ``True`` add a semicolon after the last statement,
otherwise it is emitted only as a separator between multiple statements
:param comments: optional sequence of tuples with the comments extracted from the statement
This augments :class:`OutputStream` and implements the basic machinery needed to serialize
the *parse tree* produced by :func:`~.parser.parse_sql()` back to a textual representation,
without any adornment.
"""

def __init__(self, expression_level=0, separate_statements=1, special_functions=False,
comma_at_eoln=False, semicolon_after_last_statement=False):
comma_at_eoln=False, semicolon_after_last_statement=False,
comments=None):
super().__init__()
self.current_column = 0
self.expression_level = expression_level
self.separate_statements = separate_statements
self.special_functions = special_functions
self.comma_at_eoln = comma_at_eoln
self.semicolon_after_last_statement = semicolon_after_last_statement
self.current_column = 0
self.comments = comments

def show(self, where=stderr): # pragma: no cover
"""Emit also current expression_level and a "pointer" showing current_column."""
Expand Down Expand Up @@ -248,8 +251,14 @@ def __call__(self, sql, plpgsql=False):
for _ in range(self.separate_statements):
self.newline()
self.print_node(statement)

if self.semicolon_after_last_statement:
self.write(';')

if self.comments:
while self.comments:
self.print_comment(self.comments.pop(0)[1])

return self.getvalue()

def dedent(self):
Expand Down Expand Up @@ -342,6 +351,33 @@ def _print_scalar(self, node, is_name, is_symbol):
else:
self.write(str(value))

def print_comment(self, comment):
if comment.startswith('--'):
comment = comment[2:].strip()
else:
comment = comment[2:-2].strip()
if comment:
cc = self.current_column
is_before_anything = self.tell() == 0
# if not is_before_anything:
# self.newline()
self.write('/* ')
lines = comment.splitlines()
if len(lines) > 1:
with self.push_indent():
for line in lines:
self.write(line)
self.newline()
else:
self.write(lines[0])
self.write(' ')
self.write('*/')
if is_before_anything:
self.newline()
else:
self.space()
self.current_column = cc

def print_name(self, nodes, sep='.'):
"Helper method, execute :meth:`print_node` or :meth:`print_list` as needed."

Expand All @@ -368,6 +404,17 @@ def print_node(self, node, is_name=False, is_symbol=False):
whether this is the name of an *operator*, should not be double quoted
"""

if self.comments:
if hasattr(node, 'location'):
node_location = getattr(node, 'location')
elif hasattr(node, 'stmt_location'):
node_location = getattr(node, 'stmt_location')
else:
node_location = Missing
if node_location is not Missing:
if self.comments[0][0] <= node_location.value:
self.print_comment(self.comments.pop(0)[1])

if isinstance(node, Scalar):
self._print_scalar(node, is_name, is_symbol)
elif is_name and isinstance(node, (List, list)):
Expand Down Expand Up @@ -526,7 +573,7 @@ def indent(self, amount=0, relative=True):

self.indentation_stack.append(self.current_indent)
base_indent = (self.current_column if relative else self.current_indent)
assert base_indent + amount >= 0
assert base_indent + amount >= 0, f'base_indent={base_indent} amount={amount}'
self.current_indent = base_indent + amount

@contextmanager
Expand Down Expand Up @@ -560,6 +607,11 @@ def space(self, count=1):

self.write(' '*count)

def print_comment(self, comment):
ci = self.current_indent
super().print_comment(comment)
self.current_indent = ci

def print_list(self, nodes, sep=',', relative_indent=None, standalone_items=None,
are_names=False, is_symbol=False):
"""Execute :meth:`print_node` on all the `nodes`, separating them with `sep`.
Expand Down
6 changes: 6 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,9 @@ def test_cli_workhorse():
SELECT 'abc'
'def'
"""

with StringIO("Select /* one */ 1") as input:
with UnclosableStream() as output:
with redirect_stdin(input), redirect_stdout(output):
main(['--preserve-comments'])
assert output.getvalue() == "SELECT /* one */ 1\n"
10 changes: 8 additions & 2 deletions tests/test_printers_prettification.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
# :Created: dom 17 mar 2019 10:46:03 CET
# :Author: Lele Gaifax <lele@metapensiero.it>
# :License: GNU General Public License version 3 or later
# :Copyright: © 2019, 2020 Lele Gaifax
# :Copyright: © 2019, 2020, 2021 Lele Gaifax
#

from ast import literal_eval
from pathlib import Path

import pytest

from pglast import scan
from pglast.printer import IndentedStream
import pglast.printers

Expand Down Expand Up @@ -82,10 +83,15 @@ def test_prettification(src, lineno, case):
parts = case.split('\n=\n')
original = parts[0].strip()
parts = parts[1].split('\n:\n')
expected = parts[0].strip().replace('\\n\\\n', '\n')
expected = parts[0].strip().replace('\\n\\\n', '\n').replace('\\s', ' ')
if len(parts) == 2:
options = literal_eval(parts[1])
else:
options = {}
if options.pop('preserve_comments', False):
comments = options['comments'] = []
for token in scan(original):
if token.name in ('C_COMMENT', 'SQL_COMMENT'):
comments.append((token.start, original[token.start:token.end+1]))
prettified = IndentedStream(**options)(original)
assert expected == prettified, "%s:%d:%r != %r" % (src, lineno, expected, prettified)
10 changes: 10 additions & 0 deletions tests/test_printers_prettification/dml/select.sql
Original file line number Diff line number Diff line change
Expand Up @@ -369,3 +369,13 @@ INTERSECT
GROUP BY y
LIMIT 3)
LIMIT 2

/*
header
*/ select /*one*/ 1
/*footer*/
=
/* header */
SELECT /* one */ 1 /* footer */\s
:
{'preserve_comments': True}

0 comments on commit 649ecb5

Please sign in to comment.