Skip to content

Commit

Permalink
Merge branch 'topic/bbannier/snapshot-testing'
Browse files Browse the repository at this point in the history
  • Loading branch information
bbannier committed Jul 16, 2024
2 parents 3a43e64 + fbb835c commit f8b6570
Show file tree
Hide file tree
Showing 15 changed files with 158 additions and 92 deletions.
18 changes: 10 additions & 8 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,35 +1,37 @@
exclude: tests/data
exclude: tests/(data|__snapshots__|samples)

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v4.6.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files

- repo: https://github.com/PyCQA/pylint
rev: v3.0.1
rev: v3.2.5
hooks:
- id: pylint
additional_dependencies:
- "pytest==8.2.2"
- "syrupy==4.6.1"
- "setuptools"
- "tree-sitter>=0.21.3"
- "tree-sitter-zeek"
- "tree-sitter==0.22.3"
- "tree-sitter-zeek==0.1.1"

- repo: https://github.com/psf/black
rev: 23.10.1
rev: 24.4.2
hooks:
- id: black

- repo: https://github.com/asottile/pyupgrade
rev: v3.15.0
rev: v3.16.0
hooks:
- id: pyupgrade
args: ["--py37-plus"]

- repo: https://github.com/igorshubovych/markdownlint-cli
rev: v0.37.0
rev: v0.41.0
hooks:
- id: markdownlint-fix
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,14 @@ keywords = [
requires-python = ">=3.9"

dependencies = [
"tree-sitter == 0.22.3",
"tree-sitter-zeek == 0.1.1",
"tree-sitter==0.22.3",
"tree-sitter-zeek==0.1.1",
]

[project.optional-dependencies]
dev = [
"pytest>=8.1.1",
"pytest==8.2.2",
"syrupy==4.6.1",
]

[[project.maintainers]]
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Installation setup."""

from setuptools import setup


Expand Down
File renamed without changes.
File renamed without changes.
73 changes: 41 additions & 32 deletions tests/test_dir_recursion.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,67 +22,74 @@ def setUp(self):
shutil.rmtree("a", ignore_errors=True)
os.makedirs(join("a", "b", "c"))

shutil.copy(join(tu.DATA, "test1.zeek"), join("a", "test1.zeek"))
shutil.copy(join(tu.DATA, "test1.zeek"), join("a", "test2.zeek"))
shutil.copy(join(tu.DATA, "test1.zeek"), join("a", "b", "test3.txt"))
shutil.copy(join(tu.DATA, "test1.zeek"), join("a", "b", "test4.zeek"))
shutil.copy(join(tu.DATA, "test1.zeek"), join("a", "b", "c", "test5.zeek"))
for f in (
join("a", "test1.zeek"),
join("a", "test2.zeek"),
join("a", "b", "test3.txt"),
join("a", "b", "test4.zeek"),
join("a", "b", "c", "test5.zeek"),
):
with open(f, "w", encoding="utf-8") as h:
h.write(tu.SAMPLE_UNFORMATTED)

def tearDown(self):
shutil.rmtree("a", ignore_errors=True)

# pylint: disable-next=invalid-name
def assertEqualContent(self, file1, file2):
with open(file1, encoding="utf-8") as hdl1, open(
file2, encoding="utf-8"
) as hdl2:
self.assertEqual(hdl1.read(), hdl2.read())
def assertEqualContent(self, file1, content_expected):
with open(file1, encoding="utf-8") as hdl1:
self.assertEqual(hdl1.read(), content_expected)

# pylint: disable-next=invalid-name
def assertNotEqualContent(self, file1, file2):
with open(file1, encoding="utf-8") as hdl1, open(
file2, encoding="utf-8"
) as hdl2:
self.assertNotEqual(hdl1.read(), hdl2.read())
def assertNotEqualContent(self, file1, content_expected):
with open(file1, encoding="utf-8") as hdl1:
self.assertNotEqual(hdl1.read(), content_expected)

def test_recursive_formatting(self):
parser = argparse.ArgumentParser()
zeekscript.add_format_cmd(parser)
args = parser.parse_args(["-i", "-r", "a"])

# Python < 3.10 does not yet support parenthesized context managers:
with unittest.mock.patch(
"sys.stdout", new=io.StringIO()
) as out, unittest.mock.patch("sys.stderr", new=io.StringIO()):
with (
unittest.mock.patch("sys.stdout", new=io.StringIO()) as out,
unittest.mock.patch("sys.stderr", new=io.StringIO()),
):
ret = args.run_cmd(args)
self.assertEqual(ret, 0)
self.assertEqual(out.getvalue(), "4 files processed, 0 errors\n")

self.assertEqualContent(
join(tu.DATA, "test1.zeek.out"), join("a", "test1.zeek")
join("a", "test1.zeek"),
tu.SAMPLE_FORMATTED,
)
self.assertEqualContent(
join(tu.DATA, "test1.zeek.out"), join("a", "test2.zeek")
join("a", "test2.zeek"),
tu.SAMPLE_FORMATTED,
)
self.assertEqualContent(
join(tu.DATA, "test1.zeek.out"), join("a", "b", "test4.zeek")
join("a", "b", "test4.zeek"),
tu.SAMPLE_FORMATTED,
)
self.assertEqualContent(
join(tu.DATA, "test1.zeek.out"), join("a", "b", "c", "test5.zeek")
join("a", "b", "c", "test5.zeek"),
tu.SAMPLE_FORMATTED,
)

self.assertNotEqualContent(
join(tu.DATA, "test1.zeek.out"), join("a", "b", "test3.txt")
join("a", "b", "test3.txt"),
tu.SAMPLE_FORMATTED,
)

def test_recurse_inplace(self):
parser = argparse.ArgumentParser()
zeekscript.add_format_cmd(parser)
args = parser.parse_args(["-ir"])

with unittest.mock.patch("sys.stdout", new=io.StringIO()), unittest.mock.patch(
"sys.stderr", new=io.StringIO()
) as err:
with (
unittest.mock.patch("sys.stdout", new=io.StringIO()),
unittest.mock.patch("sys.stderr", new=io.StringIO()) as err,
):
ret = args.run_cmd(args)
self.assertEqual(ret, 0)
self.assertEqual(
Expand All @@ -95,9 +102,10 @@ def test_dir_without_recurse(self):
zeekscript.add_format_cmd(parser)
args = parser.parse_args(["-i", "a"])

with unittest.mock.patch("sys.stdout", new=io.StringIO()), unittest.mock.patch(
"sys.stderr", new=io.StringIO()
) as err:
with (
unittest.mock.patch("sys.stdout", new=io.StringIO()),
unittest.mock.patch("sys.stderr", new=io.StringIO()) as err,
):
ret = args.run_cmd(args)
self.assertEqual(ret, 0)
self.assertEqual(
Expand All @@ -110,9 +118,10 @@ def test_recurse_without_inplace(self):
zeekscript.add_format_cmd(parser)
args = parser.parse_args(["-r", "a"])

with unittest.mock.patch("sys.stdout", new=io.StringIO()), unittest.mock.patch(
"sys.stderr", new=io.StringIO()
) as err:
with (
unittest.mock.patch("sys.stdout", new=io.StringIO()),
unittest.mock.patch("sys.stderr", new=io.StringIO()) as err,
):
ret = args.run_cmd(args)
self.assertEqual(ret, 1)
self.assertEqual(
Expand Down
56 changes: 15 additions & 41 deletions tests/test_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,6 @@


class TestFormatting(unittest.TestCase):
def _get_input_and_baseline(self, filename):
with open(os.path.join(tu.DATA, filename), "rb") as hdl:
input_ = hdl.read()

with open(os.path.join(tu.DATA, filename + ".out"), "rb") as hdl:
output = hdl.read()

return input_, output

def _format(self, content):
script = zeekscript.Script(io.BytesIO(content))

Expand All @@ -35,17 +26,6 @@ def _format(self, content):

return buf.getvalue()

def test_file_formatting(self):
input_, baseline = self._get_input_and_baseline("test1.zeek")

# Format the input data and compare to baseline:
result1 = self._format(input_)
self.assertEqual(baseline, result1)

# Format the result again. There should be no change.
result2 = self._format(result1)
self.assertEqual(baseline, result2)

def test_interval(self):
self.assertEqual(self._format(b"1 sec;").rstrip(), b"1sec;")
self.assertEqual(self._format(b"1min;").rstrip(), b"1min;")
Expand Down Expand Up @@ -94,7 +74,8 @@ def test_format_comment_separator(self):

# We split out lines here to work around different line endings on Windows.
self.assertEqual(
self._format(code.encode()).decode().splitlines(), expected.splitlines()
self._format(code.encode("UTF-8")).decode().splitlines(),
expected.splitlines(),
)


Expand Down Expand Up @@ -213,34 +194,27 @@ class TestNewlineFormatting(unittest.TestCase):
# This test verifies correct processing when line endings in the input
# differ from that normally used by the platform.

def _get_formatted_and_baseline(self, filename):
def test_file_formatting(self):
given = tu.SAMPLE_UNFORMATTED.encode("utf-8")
# Swap line endings for something not native to the platform:
with open(os.path.join(tu.DATA, filename), "rb") as hdl:
data = hdl.read()
if zeekscript.Formatter.NL == b"\n":
# Turn everything to \r\n, even if mixed
data = data.replace(b"\r\n", b"\n")
data = data.replace(b"\n", b"\r\n")
else:
data = data.replace(b"\r\n", b"\n")

buf = io.BytesIO(data)

if zeekscript.Formatter.NL == b"\n":
# Turn everything to \r\n, even if mixed
given = given.replace(b"\r\n", b"\n")
given = given.replace(b"\n", b"\r\n")
else:
given = given.replace(b"\r\n", b"\n")

buf = io.BytesIO(given)
script = zeekscript.Script(buf)
script.parse()

buf = io.BytesIO()
script.format(buf)

with open(os.path.join(tu.DATA, filename + ".out"), "rb") as hdl:
result_wanted = hdl.read()

result_is = buf.getvalue()
return result_wanted, result_is
given = buf.getvalue()

def test_file_formatting(self):
result_wanted, result_is = self._get_formatted_and_baseline("test1.zeek")
self.assertEqual(result_wanted, result_is)
expected = tu.SAMPLE_FORMATTED.encode("utf-8")
self.assertEqual(expected, given)


class TestScriptConstruction(unittest.TestCase):
Expand Down
65 changes: 65 additions & 0 deletions tests/test_samples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""Snapshot test of formatting of sample files.
Code samples are located in tests/samples. To add a new sample create a new
file and create a snapshot of the formatting result with
$ pytest --snapshot-update
Commit both the sample as well as the updated snapshot.
"""

import io
from pathlib import Path

import pytest
from syrupy.assertion import SnapshotAssertion
from syrupy.extensions.single_file import SingleFileSnapshotExtension

from testutils import zeekscript

SAMPLES_DIR = Path(__file__).parent / "samples"


# Use a custom snapshot fixture so we emit one file per generated test case
# instead of one per module.
@pytest.fixture
def snapshot(snapshot: SnapshotAssertion): # pylint: disable=redefined-outer-name
return snapshot.use_extension(SingleFileSnapshotExtension)


def _format(script: zeekscript.Script):
"""Formats a given `Script`"""
buf = io.BytesIO()
script.format(buf)
return buf.getvalue()


def _get_samples():
"""Helper to enumerate samples"""

# We exclude directories since we store snapshots along with the samples.
# This assumes that there are no tests in subdirectories of `SAMPLES_DIR`.
try:
return [sample for sample in SAMPLES_DIR.iterdir() if sample.is_file()]
except FileNotFoundError:
return []


# For each file in `SAMPLES_DIR` test formatting of the file.
@pytest.mark.parametrize("sample", _get_samples())
# pylint: disable-next=redefined-outer-name
def test_samples(sample: Path, snapshot: SnapshotAssertion):
input_ = zeekscript.Script(sample)

assert input_.parse(), f"failed to parse input {sample}"
assert not input_.has_error(), f"parse result for {sample} has parse errors"

name = str(sample.relative_to(SAMPLES_DIR.parent.parent))

output = _format(input_)
assert output == snapshot(
name=name
), f"formatted {sample} inconsistent with snapshot"

output2 = _format(input_)
assert output2 == output, f"idempotency violation for {sample}"
14 changes: 13 additions & 1 deletion tests/testutils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Helpers for the various test_*.py files."""

import os
import sys

TESTS = os.path.dirname(os.path.realpath(__file__))
ROOT = os.path.normpath(os.path.join(TESTS, ".."))
DATA = os.path.normpath(os.path.join(TESTS, "data"))

# Prepend the tree's root folder to the module searchpath so we find zeekscript
# via it. This allows tests to run without package installation. (We do need a
Expand Down Expand Up @@ -36,3 +36,15 @@ def normalize(content):
out = content

return fix_lineseps(out)


# A small unformatted source sample for general testing.
SAMPLE_UNFORMATTED = """\
global foo=1 +2 ;
"""


# A small formatted source sample for general testing.
SAMPLE_FORMATTED = """\
global foo = 1 + 2;
"""
1 change: 1 addition & 0 deletions zeekscript/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Wrapper around more low-level tests."""

__version__ = "1.2.9-15"
__all__ = ["cli", "error", "formatter", "node", "output", "script"]

Expand Down
1 change: 1 addition & 0 deletions zeekscript/cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""This module provides reusable command line parsers and tooling."""

import argparse
import io
import os
Expand Down
Loading

0 comments on commit f8b6570

Please sign in to comment.