diff --git a/.github/workflows/deploy_development_docs.yml b/.github/workflows/deploy_development_docs.yml new file mode 100644 index 00000000..308c29d7 --- /dev/null +++ b/.github/workflows/deploy_development_docs.yml @@ -0,0 +1,22 @@ +name: deploy development docs +on: + push: + branches: [main] + +jobs: + build: + name: Deploy development docs + runs-on: ubuntu-latest + steps: + - name: Checkout main + uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + + - run: pip install hatch + - name: publish docs + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + git config user.name "Frank Hoffmann" + git config user.email "15r10nk@users.noreply.github.com" + git fetch origin gh-pages --depth=1 + hatch run docs:mike deploy --push development diff --git a/.github/workflows/test_docs.yml b/.github/workflows/test_docs.yml index 3cb7623b..1ffa1911 100644 --- a/.github/workflows/test_docs.yml +++ b/.github/workflows/test_docs.yml @@ -1,8 +1,6 @@ name: test docs on: pull_request: - push: - branches: [main] jobs: build: @@ -13,11 +11,6 @@ jobs: uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 - run: pip install hatch - - name: publish docs - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: test docs run: | - git config user.name "Frank Hoffmann" - git config user.email "15r10nk@users.noreply.github.com" - git fetch origin gh-pages --depth=1 - hatch run docs:mike deploy --push development + hatch run docs:build diff --git a/README.md b/README.md index 720d9bcc..81c9a264 100644 --- a/README.md +++ b/README.md @@ -27,12 +27,19 @@ pip install inline-snapshot ## Key Features - **Intuitive Semantics:** `snapshot(x)` mirrors `x` for easy understanding. -- **Versatile Comparison Support:** Equipped with `x == snapshot(...)`, `x <= snapshot(...)`, `x in snapshot(...)`, and `snapshot(...)[key]`. +- **Versatile Comparison Support:** Equipped with + [`x == snapshot(...)`](https://15r10nk.github.io/inline-snapshot/latest/eq_snapshot/), + [`x <= snapshot(...)`](https://15r10nk.github.io/inline-snapshot/latest/cmp_snapshot/), + [`x in snapshot(...)`](https://15r10nk.github.io/inline-snapshot/latest/in_snapshot/), and + [`snapshot(...)[key]`](https://15r10nk.github.io/inline-snapshot/latest/getitem_snapshot/). - **Enhanced Control Flags:** Utilize various [flags](https://15r10nk.github.io/inline-snapshot/latest/pytest/) for precise control of which snapshots you want to change. - **Preserved Black Formatting:** Retains formatting consistency with Black formatting. - **External File Storage:** Store snapshots externally using `outsource(data)`. - **Seamless Pytest Integration:** Integrated seamlessly with pytest for effortless testing. - **Customizable:** code generation can be customized with [@customize_repr](https://15r10nk.github.io/inline-snapshot/latest/customize_repr) +- **Nested Snapshot Support:** snapshots can contain [other snapshots](https://15r10nk.github.io/inline-snapshot/eq_snapshot/#inner-snapshots) +- **Fuzzy Matching:** Incorporate [dirty-equals](https://15r10nk.github.io/inline-snapshot/eq_snapshot/#dirty-equals) for flexible comparisons within snapshots. +- **Dynamic Snapshot Content:** snashots can contain [non-constant values](https://15r10nk.github.io/inline-snapshot/eq_snapshot/#is) - **Comprehensive Documentation:** Access detailed [documentation](https://15r10nk.github.io/inline-snapshot/latest) for complete guidance. diff --git a/changelog.d/20241113_170718_15r10nk-git_fixes.md b/changelog.d/20241113_170718_15r10nk-git_fixes.md new file mode 100644 index 00000000..010d9246 --- /dev/null +++ b/changelog.d/20241113_170718_15r10nk-git_fixes.md @@ -0,0 +1,38 @@ + + + +### Fixed + +- inline-snapshot checks now if the given command line flags (`--inline-snapshot=...`) are valid + + + + + diff --git a/changelog.d/20241113_170944_15r10nk-git_fixes.md b/changelog.d/20241113_170944_15r10nk-git_fixes.md new file mode 100644 index 00000000..cb0218e6 --- /dev/null +++ b/changelog.d/20241113_170944_15r10nk-git_fixes.md @@ -0,0 +1,39 @@ + + + + + + +### Fixed +- `Example(...).run_pytest(raise=snapshot(...))` uses now the flags from the current run and not the flags from the Example. + + diff --git a/changelog.d/20241210_063450_15r10nk-git_container.md b/changelog.d/20241210_063450_15r10nk-git_container.md new file mode 100644 index 00000000..d14ed163 --- /dev/null +++ b/changelog.d/20241210_063450_15r10nk-git_container.md @@ -0,0 +1,62 @@ + + + +### Added + +- snapshots [inside snapshots](https://15r10nk.github.io/inline-snapshot/latest/eq_snapshot/#inner-snapshots) are now supported. + + ``` python + assert get_schema() == snapshot( + [ + { + "name": "var_1", + "type": snapshot("int") if version < 2 else snapshot("string"), + } + ] + ) + ``` + +- [runtime values](https://15r10nk.github.io/inline-snapshot/latest/eq_snapshot/#is) can now be part of snapshots. + + ``` python + from inline_snapshot import snapshot, Is + + current_version = "1.5" + assert request() == snapshot( + {"data": "page data", "version": Is(current_version)} + ) + ``` + +- [f-strings](https://15r10nk.github.io/inline-snapshot/latest/eq_snapshot/#f-strings) can now also be used within snapshots, but are currently not *fixed* by inline-snapshot. + +### Changed + +- *dirty-equals* expressions are now treated like *runtime values* or *snapshots* within snapshots and are not modified by inline-snapshot. + + + + diff --git a/docs/customize_repr.md b/docs/customize_repr.md index 177548f6..ae8cc103 100644 --- a/docs/customize_repr.md +++ b/docs/customize_repr.md @@ -3,6 +3,7 @@ `repr()` can be used to convert a python object into a source code representation of the object, but this does not work for every type. Here are some examples: + ```pycon >>> repr(int) "" @@ -42,7 +43,7 @@ def test_enum(): inline-snapshot comes with a special implementation for the following types: -```python exec="1" +``` python exec="1" from inline_snapshot._code_repr import code_repr_dispatch, code_repr for name, obj in sorted( @@ -58,16 +59,14 @@ for name, obj in sorted( print(f"- `{name}`") ``` -Container types like `dict` or `dataclass` need a special implementation because it is necessary that the implementation uses `repr()` for the child elements. +!!! note + Container types like `dict`, `list`, `tuple` or `dataclass` are handled in a different way, because inline-snapshot also needs to inspect these types to implement [unmanaged](/eq_snapshot.md#unmanaged-snapshot-values) snapshot values. + + -```python exec="1" result="python" -print('--8<-- "src/inline_snapshot/_code_repr.py:list"') -``` -!!! note - using `#!python f"{obj!r}"` or `#!c PyObject_Repr()` will not work, because inline-snapshot replaces `#!python builtins.repr` during the code generation. -## customize +## customize recursive repr You can also use `repr()` inside `__repr__()`, if you want to make your own type compatible with inline-snapshot. @@ -102,6 +101,13 @@ def test_enum(): assert Pair(E.a, [E.b]) == snapshot(Pair(E.a, [E.b])) ``` +!!! note + using `#!python f"{obj!r}"` or `#!c PyObject_Repr()` will not work, because inline-snapshot replaces `#!python builtins.repr` during the code generation. The only way to use the custom repr implementation is to use the `repr()` function. + +!!! note + This implementation allows inline-snapshot to use the custom `repr()` recursively, but it does not allow you to use [unmanaged](/eq_snapshot.md#unmanaged-snapshot-parts) snapshot values like `#!python Pair(Is(some_var),5)` + + you can also customize the representation of datatypes in other libraries: ``` python diff --git a/docs/eq_snapshot.md b/docs/eq_snapshot.md index 94b02f5f..e3ff7673 100644 --- a/docs/eq_snapshot.md +++ b/docs/eq_snapshot.md @@ -33,9 +33,34 @@ Example: def test_something(): assert 2 + 40 == snapshot(42) ``` +## unmanaged snapshot values +inline-snapshots manages everything inside `snapshot(...)`, which means that the developer should not change these parts, but there are cases where it is useful to give the developer the control over the snapshot content back. -## dirty-equals +Therefor some types will be ignored by inline-snapshot and will **not be updated or fixed**, even if they cause tests to fail. + +These types are: + +* [dirty-equals](#dirty-equals) expressions, +* [dynamic code](#is) inside `Is(...)`, +* [snapshots](#inner-snapshots) inside snapshots and +* [f-strings](#f-strings). + +inline-snapshot is able to handle these types within the following containers: + +* list +* tuple +* dict +* namedtuple +* dataclass + + +Other types are converted with a [customizable](customize_repr.md) `repr()` into code. It is not possible to use unmanaged snapshot values within these objects. + +### dirty-equals It might be, that larger snapshots with many lists and dictionaries contain some values which change frequently and are not relevant for the test. They might be part of larger data structures and be difficult to normalize. @@ -80,9 +105,7 @@ Example: ) ``` -inline-snapshot tries to change only the values that it needs to change in order to pass the equality comparison. -This allows to replace parts of the snapshot with [dirty-equals](https://dirty-equals.helpmanual.io/latest/) expressions. -This expressions are preserved as long as the `==` comparison with them is `True`. +The date can be replaced with the [dirty-equals](https://dirty-equals.helpmanual.io/latest/) expression `IsDatetime()`. Example: @@ -159,8 +182,203 @@ Example: ) ``` -!!! note - The current implementation looks only into lists, dictionaries and tuples and not into the representation of other data structures. +### Is(...) + +`Is()` can be used to put runtime values inside snapshots. +It tells inline-snapshot that the developer wants control over some part of the snapshot. + + +``` python +from inline_snapshot import snapshot, Is + +current_version = "1.5" + + +def request(): + return {"data": "page data", "version": current_version} + + +def test_function(): + assert request() == snapshot( + {"data": "page data", "version": Is(current_version)} + ) +``` + +The snapshot does not need to be fixed when `current_version` changes in the future, but `"page data"` will still be fixed if it changes. + +`Is()` can also be used when the snapshot is evaluated multiple times, which is useful in loops or parametrized tests. + +=== "original code" + + ``` python + from inline_snapshot import snapshot, Is + + + def test_function(): + for c in "abc": + assert [c, "correct"] == snapshot([Is(c), "wrong"]) + ``` + +=== "--inline-snapshot=fix" + + ``` python hl_lines="6" + from inline_snapshot import snapshot, Is + + + def test_function(): + for c in "abc": + assert [c, "correct"] == snapshot([Is(c), "correct"]) + ``` + +### inner snapshots + +Snapshots can be used inside other snapshots in different use cases. + +#### conditional snapshots +It is also possible to use snapshots inside snapshots. + +This is useful to describe version specific parts of snapshots by replacing the specific part with `#!python snapshot() if some_condition else snapshot()`. +The test has to be executed in each specific condition to fill the snapshots. + +The following example shows how this can be used to run a tests with two different library versions: + +=== "my_lib.py v1" + + + ``` python + version = 1 + + + def get_schema(): + return [{"name": "var_1", "type": "int"}] + ``` + +=== "my_lib.py v2" + + + ``` python + version = 2 + + + def get_schema(): + return [{"name": "var_1", "type": "string"}] + ``` + + + +``` python +from inline_snapshot import snapshot +from my_lib import version, get_schema + + +def test_function(): + assert get_schema() == snapshot( + [ + { + "name": "var_1", + "type": snapshot("int") if version < 2 else snapshot("string"), + } + ] + ) +``` + +The advantage of this approach is that the test uses always the correct values for each library version. + +You can also extract the version logic into its own function. + +``` python +from inline_snapshot import snapshot, Snapshot +from my_lib import version, get_schema + + +def version_snapshot(v1: Snapshot, v2: Snapshot): + return v1 if version < 2 else v2 + + +def test_function(): + assert get_schema() == snapshot( + [ + { + "name": "var_1", + "type": version_snapshot( + v1=snapshot("int"), v2=snapshot("string") + ), + } + ] + ) +``` + +#### common snapshot parts + +Another use case is the extraction of common snapshot parts into an extra snapshot: + + +``` python +from inline_snapshot import snapshot + + +def some_data(name): + return {"header": "really long header\n" * 5, "your name": name} + + +def test_function(): + + header = snapshot( + """\ +really long header +really long header +really long header +really long header +really long header +""" + ) + + assert some_data("Tom") == snapshot( + { + "header": header, + "your name": "Tom", + } + ) + + assert some_data("Bob") == snapshot( + { + "header": header, + "your name": "Bob", + } + ) +``` + +This simplifies test data and allows inline-snapshot to update your values if required. +It makes also sure that the header is the same in both cases. + + +### f-strings + +*f-strings* are not generated by inline-snapshot, but they can be used in snapshots if you want to replace some dynamic part of a string value. + + +``` python +from inline_snapshot import snapshot + + +def get_error(): + # example code which generates an error message + return __file__ + ": error at line 5" + + +def test_get_error(): + assert get_error() == snapshot(f"{__file__}: error at line 5") +``` + +It is not required to wrap the changed value in `Is(f"...")`, because inline-snapshot knows that *f-strings* are only generated by the developer. + +!!! Warning "Limitation" + inline-snapshot is currently not able to fix the string constants within *f-strings*. + + `#!python f"...{var}..."` works **currently** like `#!python Is(f"...{var}...")` and issues a warning if the value changes, giving you the opportunity to fix your f-string. + + `#!python f"...{var}..."` will in the **future** work like `#!python f"...{Is(var)}"`. inline-snapshot will then be able to *fix* the string parts within the f-string. + ## pytest options diff --git a/docs/pytest.md b/docs/pytest.md index 5d825618..a7d34aca 100644 --- a/docs/pytest.md +++ b/docs/pytest.md @@ -11,7 +11,7 @@ inline-snapshot provides one pytest option with different flags (*create*, Snapshot comparisons return always `True` if you use one of the flags *create*, *fix* or *review*. This is necessary because the whole test needs to be run to fix all snapshots like in this case: -```python +``` python from inline_snapshot import snapshot @@ -30,7 +30,7 @@ def test_something(): Approve the changes of the given [category](categories.md). These flags can be combined with *report* and *review*. -```python title="test_something.py" +``` python title="test_something.py" from inline_snapshot import snapshot diff --git a/mkdocs.yml b/mkdocs.yml index 82efbc38..04fe6417 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -22,6 +22,11 @@ theme: media: '(prefers-color-scheme: dark)' primary: teal +validation: + links: + absolute_links: relative_to_docs + + watch: - CONTRIBUTING.md - CHANGELOG.md diff --git a/pyproject.toml b/pyproject.toml index 003b1f70..5e62305b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,13 +75,6 @@ omit = [ parallel = true source_pkgs = ["inline_snapshot", "tests"] -[tool.hatch.envs.coverage] -dependencies = [ - "coverage" -] -env-vars.TOP = "{root}" -scripts.report = "coverage html" - [tool.hatch.envs.docs] dependencies = [ "markdown-exec[ansi]>=1.8.0", @@ -128,6 +121,12 @@ extra-dependencies = [ ] env-vars.TOP = "{root}" +[tool.hatch.envs.hatch-test.scripts] +run = "pytest{env:HATCH_TEST_ARGS:} {args}" +run-cov = "coverage run -m pytest{env:HATCH_TEST_ARGS:} {args}" +cov-combine = "coverage combine" +cov-report=["coverage report","coverage html"] + [tool.hatch.envs.types] extra-dependencies = [ "mypy>=1.0.0", @@ -173,9 +172,5 @@ venvPath = ".nox" format = "md" version = "command: cz bump --get-next" -[tool.inline-snapshot.shortcuts] -fix=["create","fix"] -review=["create","review"] - [tool.pytest.ini_options] markers=["no_rewriting: marks tests which need no code rewriting and can be used with pypy"] diff --git a/src/inline_snapshot/__init__.py b/src/inline_snapshot/__init__.py index 813fcdbf..3766bafe 100644 --- a/src/inline_snapshot/__init__.py +++ b/src/inline_snapshot/__init__.py @@ -3,9 +3,19 @@ from ._external import external from ._external import outsource from ._inline_snapshot import snapshot +from ._is import Is from ._types import Category from ._types import Snapshot -__all__ = ["snapshot", "external", "outsource", "customize_repr", "HasRepr"] +__all__ = [ + "snapshot", + "external", + "outsource", + "customize_repr", + "HasRepr", + "Is", + "Category", + "Snapshot", +] __version__ = "0.14.2" diff --git a/src/inline_snapshot/_adapter/__init__.py b/src/inline_snapshot/_adapter/__init__.py new file mode 100644 index 00000000..2f699011 --- /dev/null +++ b/src/inline_snapshot/_adapter/__init__.py @@ -0,0 +1,3 @@ +from .adapter import get_adapter_type + +__all__ = ("get_adapter_type",) diff --git a/src/inline_snapshot/_adapter/adapter.py b/src/inline_snapshot/_adapter/adapter.py new file mode 100644 index 00000000..b915ec62 --- /dev/null +++ b/src/inline_snapshot/_adapter/adapter.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +import ast +import typing + +from inline_snapshot._source_file import SourceFile + + +def get_adapter_type(value): + from inline_snapshot._adapter.generic_call_adapter import get_adapter_for_type + + adapter = get_adapter_for_type(type(value)) + if adapter is not None: + return adapter + + if isinstance(value, list): + from .sequence_adapter import ListAdapter + + return ListAdapter + + if type(value) is tuple: + from .sequence_adapter import TupleAdapter + + return TupleAdapter + + if isinstance(value, dict): + from .dict_adapter import DictAdapter + + return DictAdapter + + from .value_adapter import ValueAdapter + + return ValueAdapter + + +class Item(typing.NamedTuple): + value: typing.Any + node: ast.expr + + +class Adapter: + context: SourceFile + + def __init__(self, context): + self.context = context + + def get_adapter(self, old_value, new_value) -> Adapter: + if type(old_value) is not type(new_value): + from .value_adapter import ValueAdapter + + return ValueAdapter(self.context) + + adapter_type = get_adapter_type(old_value) + if adapter_type is not None: + return adapter_type(self.context) + assert False + + def assign(self, old_value, old_node, new_value): + raise NotImplementedError(cls) + + @classmethod + def map(cls, value, map_function): + raise NotImplementedError(cls) + + @classmethod + def repr(cls, value): + raise NotImplementedError(cls) + + +def adapter_map(value, map_function): + return get_adapter_type(value).map(value, map_function) diff --git a/src/inline_snapshot/_adapter/dict_adapter.py b/src/inline_snapshot/_adapter/dict_adapter.py new file mode 100644 index 00000000..4e0cf940 --- /dev/null +++ b/src/inline_snapshot/_adapter/dict_adapter.py @@ -0,0 +1,137 @@ +from __future__ import annotations + +import ast +import warnings + +from .._change import Delete +from .._change import DictInsert +from ..syntax_warnings import InlineSnapshotSyntaxWarning +from .adapter import Adapter +from .adapter import adapter_map +from .adapter import Item + + +class DictAdapter(Adapter): + + @classmethod + def repr(cls, value): + result = ( + "{" + + ", ".join(f"{repr(k)}: {repr(value)}" for k, value in value.items()) + + "}" + ) + + if type(value) is not dict: + result = f"{repr(type(value))}({result})" + + return result + + @classmethod + def map(cls, value, map_function): + return {k: adapter_map(v, map_function) for k, v in value.items()} + + def items(self, value, node): + if node is None: + return [Item(value=value, node=None) for value in value.values()] + + assert isinstance(node, ast.Dict) + + result = [] + + for value_key, node_key, node_value in zip( + value.keys(), node.keys, node.values + ): + try: + # this is just a sanity check, dicts should be ordered + node_key = ast.literal_eval(node_key) + except Exception: + pass + else: + assert node_key == value_key + + result.append(Item(value=value[value_key], node=node_value)) + + return result + + def assign(self, old_value, old_node, new_value): + if old_node is not None: + assert isinstance(old_node, ast.Dict) + assert len(old_value) == len(old_node.keys) + + for key, value in zip(old_node.keys, old_node.values): + if key is None: + warnings.warn_explicit( + "star-expressions are not supported inside snapshots", + filename=self.context._source.filename, + lineno=value.lineno, + category=InlineSnapshotSyntaxWarning, + ) + return old_value + + for value, node in zip(old_value.keys(), old_node.keys): + + try: + # this is just a sanity check, dicts should be ordered + node_value = ast.literal_eval(node) + except: + continue + assert node_value == value + + result = {} + for key, node in zip( + old_value.keys(), + (old_node.values if old_node is not None else [None] * len(old_value)), + ): + if not key in new_value: + # delete entries + yield Delete("fix", self.context._source, node, old_value[key]) + + to_insert = [] + insert_pos = 0 + for key, new_value_element in new_value.items(): + if key not in old_value: + # add new values + to_insert.append((key, new_value_element)) + result[key] = new_value_element + else: + if isinstance(old_node, ast.Dict): + node = old_node.values[list(old_value.keys()).index(key)] + else: + node = None + # check values with same keys + result[key] = yield from self.get_adapter( + old_value[key], new_value[key] + ).assign(old_value[key], node, new_value[key]) + + if to_insert: + new_code = [ + (self.context._value_to_code(k), self.context._value_to_code(v)) + for k, v in to_insert + ] + yield DictInsert( + "fix", + self.context._source, + old_node, + insert_pos, + new_code, + to_insert, + ) + to_insert = [] + + insert_pos += 1 + + if to_insert: + new_code = [ + (self.context._value_to_code(k), self.context._value_to_code(v)) + for k, v in to_insert + ] + yield DictInsert( + "fix", + self.context._source, + old_node, + len(old_value), + new_code, + to_insert, + ) + + return result diff --git a/src/inline_snapshot/_adapter/generic_call_adapter.py b/src/inline_snapshot/_adapter/generic_call_adapter.py new file mode 100644 index 00000000..95142945 --- /dev/null +++ b/src/inline_snapshot/_adapter/generic_call_adapter.py @@ -0,0 +1,360 @@ +from __future__ import annotations + +import ast +import warnings +from abc import ABC +from collections import defaultdict +from dataclasses import fields +from dataclasses import is_dataclass +from dataclasses import MISSING +from typing import Any + +from inline_snapshot._adapter.value_adapter import ValueAdapter + +from .._change import CallArg +from .._change import Delete +from ..syntax_warnings import InlineSnapshotSyntaxWarning +from .adapter import Adapter +from .adapter import adapter_map +from .adapter import Item + + +def get_adapter_for_type(typ): + subclasses = GenericCallAdapter.__subclasses__() + options = [cls for cls in subclasses if cls.check_type(typ)] + + if not options: + return + + assert len(options) == 1 + return options[0] + + +class Argument: + value: Any + is_default: bool = False + + def __init__(self, value, is_default=False): + self.value = value + self.is_default = is_default + + +class GenericCallAdapter(Adapter): + + @classmethod + def check_type(cls, typ) -> bool: + raise NotImplementedError(cls) + + @classmethod + def arguments(cls, value) -> tuple[list[Argument], dict[str, Argument]]: + raise NotImplementedError(cls) + + @classmethod + def argument(cls, value, pos_or_name) -> Any: + raise NotImplementedError(cls) + + @classmethod + def repr(cls, value): + + args, kwargs = cls.arguments(value) + + arguments = [repr(value.value) for value in args] + [ + f"{key}={repr(value.value)}" + for key, value in kwargs.items() + if not value.is_default + ] + + return f"{repr(type(value))}({', '.join(arguments)})" + + @classmethod + def map(cls, value, map_function): + new_args, new_kwargs = cls.arguments(value) + return type(value)( + *[adapter_map(arg.value, map_function) for arg in new_args], + **{ + k: adapter_map(kwarg.value, map_function) + for k, kwarg in new_kwargs.items() + }, + ) + + def items(self, value, node): + assert isinstance(node, ast.Call) + assert not node.args + assert all(kw.arg for kw in node.keywords) + + return [ + Item(value=self.argument(value, kw.arg), node=kw.value) + for kw in node.keywords + if kw.arg + ] + + def assign(self, old_value, old_node, new_value): + if old_node is None: + value = yield from ValueAdapter(self.context).assign( + old_value, old_node, new_value + ) + return value + + assert isinstance(old_node, ast.Call) + + # positional arguments + for pos_arg in old_node.args: + if isinstance(pos_arg, ast.Starred): + warnings.warn_explicit( + "star-expressions are not supported inside snapshots", + filename=self.context._source.filename, + lineno=pos_arg.lineno, + category=InlineSnapshotSyntaxWarning, + ) + return old_value + + # keyword arguments + for kw in old_node.keywords: + if kw.arg is None: + warnings.warn_explicit( + "star-expressions are not supported inside snapshots", + filename=self.context._source.filename, + lineno=kw.value.lineno, + category=InlineSnapshotSyntaxWarning, + ) + return old_value + + new_args, new_kwargs = self.arguments(new_value) + + # positional arguments + + result_args = [] + + for i, (new_value_element, node) in enumerate(zip(new_args, old_node.args)): + old_value_element = self.argument(old_value, i) + result = yield from self.get_adapter( + old_value_element, new_value_element.value + ).assign(old_value_element, node, new_value_element.value) + result_args.append(result) + + if len(old_node.args) > len(new_args): + for arg_pos, node in list(enumerate(old_node.args))[len(new_args) :]: + yield Delete( + "fix", + self.context._source, + node, + self.argument(old_value, arg_pos), + ) + + if len(old_node.args) < len(new_args): + for insert_pos, value in list(enumerate(new_args))[len(old_node.args) :]: + yield CallArg( + flag="fix", + file=self.context._source, + node=old_node, + arg_pos=insert_pos, + arg_name=None, + new_code=self.context._value_to_code(value.value), + new_value=value.value, + ) + + # keyword arguments + result_kwargs = {} + for kw in old_node.keywords: + if (missing := not kw.arg in new_kwargs) or new_kwargs[kw.arg].is_default: + # delete entries + yield Delete( + "fix" if missing else "update", + self.context._source, + kw.value, + self.argument(old_value, kw.arg), + ) + + old_node_kwargs = {kw.arg: kw.value for kw in old_node.keywords} + + to_insert = [] + insert_pos = 0 + for key, new_value_element in new_kwargs.items(): + if new_value_element.is_default: + continue + if key not in old_node_kwargs: + # add new values + to_insert.append((key, new_value_element.value)) + result_kwargs[key] = new_value_element.value + else: + node = old_node_kwargs[key] + + # check values with same keys + old_value_element = self.argument(old_value, key) + result_kwargs[key] = yield from self.get_adapter( + old_value_element, new_value_element.value + ).assign(old_value_element, node, new_value_element.value) + + if to_insert: + for key, value in to_insert: + + yield CallArg( + flag="fix", + file=self.context._source, + node=old_node, + arg_pos=insert_pos, + arg_name=key, + new_code=self.context._value_to_code(value), + new_value=value, + ) + to_insert = [] + + insert_pos += 1 + + if to_insert: + + for key, value in to_insert: + + yield CallArg( + flag="fix", + file=self.context._source, + node=old_node, + arg_pos=insert_pos, + arg_name=key, + new_code=self.context._value_to_code(value), + new_value=value, + ) + + return type(old_value)(*result_args, **result_kwargs) + + +class DataclassAdapter(GenericCallAdapter): + + @classmethod + def check_type(cls, value): + return is_dataclass(value) + + @classmethod + def arguments(cls, value): + + kwargs = {} + + for field in fields(value): # type: ignore + if field.repr: + field_value = getattr(value, field.name) + is_default = False + + if field.default != MISSING and field.default == field_value: + is_default = True + + if ( + field.default_factory != MISSING + and field.default_factory() == field_value + ): + is_default = True + + kwargs[field.name] = Argument(value=field_value, is_default=is_default) + + return ([], kwargs) + + def argument(self, value, pos_or_name): + assert isinstance(pos_or_name, str) + return getattr(value, pos_or_name) + + +try: + from pydantic import BaseModel +except ImportError: # pragma: no cover + pass +else: + from pydantic_core import PydanticUndefined + + class PydanticContainer(GenericCallAdapter): + + @classmethod + def check_type(cls, value): + return issubclass(value, BaseModel) + + @classmethod + def arguments(cls, value): + + kwargs = {} + + for name, field in value.model_fields.items(): # type: ignore + if field.repr: + field_value = getattr(value, name) + is_default = False + + if ( + field.default is not PydanticUndefined + and field.default == field_value + ): + is_default = True + + if ( + field.default_factory is not None + and field.default_factory() == field_value + ): + is_default = True + + kwargs[name] = Argument(value=field_value, is_default=is_default) + + return ([], kwargs) + + @classmethod + def argument(cls, value, pos_or_name): + assert isinstance(pos_or_name, str) + return getattr(value, pos_or_name) + + +class IsNamedTuple(ABC): + _inline_snapshot_name = "namedtuple" + + _fields: tuple + _field_defaults: dict + + @classmethod + def __subclasshook__(cls, t): + b = t.__bases__ + if len(b) != 1 or b[0] != tuple: + return False + f = getattr(t, "_fields", None) + if not isinstance(f, tuple): + return False + return all(type(n) == str for n in f) + + +class NamedTupleAdapter(GenericCallAdapter): + + @classmethod + def check_type(cls, value): + return issubclass(value, IsNamedTuple) + + @classmethod + def arguments(cls, value: IsNamedTuple): + + return ( + [], + { + field: Argument(value=getattr(value, field)) + for field in value._fields + if field not in value._field_defaults + or getattr(value, field) != value._field_defaults[field] + }, + ) + + def argument(self, value, pos_or_name): + assert isinstance(pos_or_name, str) + return getattr(value, pos_or_name) + + +class DefaultDictAdapter(GenericCallAdapter): + @classmethod + def check_type(cls, value): + return issubclass(value, defaultdict) + + @classmethod + def arguments(cls, value: defaultdict): + + return ( + [Argument(value=value.default_factory), Argument(value=dict(value))], + {}, + ) + + def argument(self, value, pos_or_name): + assert isinstance(pos_or_name, int) + if pos_or_name == 0: + return value.default_factory + elif pos_or_name == 1: + return dict(value) + assert False diff --git a/src/inline_snapshot/_adapter/sequence_adapter.py b/src/inline_snapshot/_adapter/sequence_adapter.py new file mode 100644 index 00000000..b48a452c --- /dev/null +++ b/src/inline_snapshot/_adapter/sequence_adapter.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +import ast +import warnings +from collections import defaultdict + +from .._align import add_x +from .._align import align +from .._change import Delete +from .._change import ListInsert +from .._compare_context import compare_context +from ..syntax_warnings import InlineSnapshotSyntaxWarning +from .adapter import Adapter +from .adapter import adapter_map +from .adapter import Item + + +class SequenceAdapter(Adapter): + node_type: type + value_type: type + braces: str + trailing_comma: bool + + @classmethod + def repr(cls, value): + if len(value) == 1 and cls.trailing_comma: + seq = repr(value[0]) + "," + else: + seq = ", ".join(map(repr, value)) + return cls.braces[0] + seq + cls.braces[1] + + @classmethod + def map(cls, value, map_function): + result = [adapter_map(v, map_function) for v in value] + return cls.value_type(result) + + def items(self, value, node): + if node is None: + return [Item(value=v, node=None) for v in value] + + assert isinstance(node, self.node_type), (node, self) + assert len(value) == len(node.elts) + + return [Item(value=v, node=n) for v, n in zip(value, node.elts)] + + def assign(self, old_value, old_node, new_value): + if old_node is not None: + assert isinstance( + old_node, ast.List if isinstance(old_value, list) else ast.Tuple + ) + + for e in old_node.elts: + if isinstance(e, ast.Starred): + warnings.warn_explicit( + "star-expressions are not supported inside snapshots", + filename=self.context.filename, + lineno=e.lineno, + category=InlineSnapshotSyntaxWarning, + ) + return old_value + + with compare_context(): + diff = add_x(align(old_value, new_value)) + old = zip( + old_value, + old_node.elts if old_node is not None else [None] * len(old_value), + ) + new = iter(new_value) + old_position = 0 + to_insert = defaultdict(list) + result = [] + for c in diff: + if c in "mx": + old_value_element, old_node_element = next(old) + new_value_element = next(new) + v = yield from self.get_adapter( + old_value_element, new_value_element + ).assign(old_value_element, old_node_element, new_value_element) + result.append(v) + old_position += 1 + elif c == "i": + new_value_element = next(new) + new_code = self.context._value_to_code(new_value_element) + result.append(new_value_element) + to_insert[old_position].append((new_code, new_value_element)) + elif c == "d": + old_value_element, old_node_element = next(old) + yield Delete( + "fix", self.context._source, old_node_element, old_value_element + ) + old_position += 1 + else: + assert False + + for position, code_values in to_insert.items(): + yield ListInsert( + "fix", self.context._source, old_node, position, *zip(*code_values) + ) + + return self.value_type(result) + + +class ListAdapter(SequenceAdapter): + node_type = ast.List + value_type = list + braces = "[]" + trailing_comma = False + + +class TupleAdapter(SequenceAdapter): + node_type = ast.Tuple + value_type = tuple + braces = "()" + trailing_comma = True diff --git a/src/inline_snapshot/_adapter/value_adapter.py b/src/inline_snapshot/_adapter/value_adapter.py new file mode 100644 index 00000000..fc3ad351 --- /dev/null +++ b/src/inline_snapshot/_adapter/value_adapter.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +import ast +import warnings + +from inline_snapshot._code_repr import value_code_repr +from inline_snapshot._unmanaged import Unmanaged +from inline_snapshot._unmanaged import update_allowed +from inline_snapshot._utils import value_to_token +from inline_snapshot.syntax_warnings import InlineSnapshotInfo + +from .._change import Replace +from .adapter import Adapter + + +class ValueAdapter(Adapter): + + @classmethod + def repr(cls, value): + return value_code_repr(value) + + @classmethod + def map(cls, value, map_function): + return map_function(value) + + def assign(self, old_value, old_node, new_value): + # generic fallback + + # because IsStr() != IsStr() + if isinstance(old_value, Unmanaged): + return old_value + + if old_node is None: + new_token = [] + else: + new_token = value_to_token(new_value) + + if isinstance(old_node, ast.JoinedStr) and isinstance(new_value, str): + if not old_value == new_value: + warnings.warn_explicit( + f"inline-snapshot will be able to fix f-strings in the future.\nThe current string value is:\n {new_value!r}", + filename=self.context._source.filename, + lineno=old_node.lineno, + category=InlineSnapshotInfo, + ) + return old_value + + if not old_value == new_value: + flag = "fix" + elif ( + old_node is not None + and update_allowed(old_value) + and self.context._token_of_node(old_node) != new_token + ): + flag = "update" + else: + # equal and equal repr + return old_value + + new_code = self.context._token_to_code(new_token) + + yield Replace( + node=old_node, + file=self.context._source, + new_code=new_code, + flag=flag, + old_value=old_value, + new_value=new_value, + ) + + return new_value diff --git a/src/inline_snapshot/_change.py b/src/inline_snapshot/_change.py index 691d4f7a..05c888f7 100644 --- a/src/inline_snapshot/_change.py +++ b/src/inline_snapshot/_change.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from typing import Any from typing import cast +from typing import DefaultDict from typing import Dict from typing import List from typing import Optional @@ -11,7 +12,7 @@ from asttokens.util import Token from executing.executing import EnhancedAST -from executing.executing import Source +from inline_snapshot._source_file import SourceFile from ._rewrite_code import ChangeRecorder from ._rewrite_code import end_of @@ -21,11 +22,11 @@ @dataclass() class Change: flag: str - source: Source + file: SourceFile @property def filename(self): - return self.source.filename + return self.file.filename def apply(self): raise NotImplementedError() @@ -76,7 +77,7 @@ class Replace(Change): def apply(self): change = ChangeRecorder.current.new_change() - range = self.source.asttokens().get_text_positions(self.node, False) + range = self.file.asttokens().get_text_positions(self.node, False) change.replace(range, self.new_code, filename=self.filename) @@ -87,40 +88,21 @@ class CallArg(Change): arg_name: Optional[str] new_code: str - old_value: Any new_value: Any - def apply(self): - change = ChangeRecorder.current.new_change() - tokens = list(self.source.asttokens().get_tokens(self.node)) - - call = self.node - tokens = list(self.source.asttokens().get_tokens(call)) - assert isinstance(call, ast.Call) - assert len(call.args) == 0 - assert len(call.keywords) == 0 - assert tokens[-2].string == "(" - assert tokens[-1].string == ")" - - assert self.arg_pos == 0 - assert self.arg_name == None - - change = ChangeRecorder.current.new_change() - change.set_tags("inline_snapshot") - change.replace( - (end_of(tokens[-2]), start_of(tokens[-1])), - self.new_code, - filename=self.filename, - ) +TokenRange = Tuple[Token, Token] -TokenRange = Tuple[Token, Token] +def brace_tokens(source, node) -> TokenRange: + first_token, *_, end_token = source.asttokens().get_tokens(node) + return first_token, end_token def generic_sequence_update( - source: Source, - parent: Union[ast.List, ast.Tuple, ast.Dict], + source: SourceFile, + parent: Union[ast.List, ast.Tuple, ast.Dict, ast.Call], + brace_tokens: TokenRange, parent_elements: List[Union[TokenRange, None]], to_insert: Dict[int, List[str]], ): @@ -128,7 +110,7 @@ def generic_sequence_update( new_code = [] deleted = False - last_token, *_, end_token = source.asttokens().get_tokens(parent) + last_token, end_token = brace_tokens is_start = True elements = 0 @@ -169,7 +151,7 @@ def generic_sequence_update( code = ", " + code if elements == 1 and isinstance(parent, ast.Tuple): - # trailing comma for tuples (1,)i + # trailing comma for tuples (1,) code += "," rec.replace( @@ -180,21 +162,23 @@ def generic_sequence_update( def apply_all(all_changes: List[Change]): - by_parent: Dict[EnhancedAST, List[Union[Delete, DictInsert, ListInsert]]] = ( - defaultdict(list) - ) - sources: Dict[EnhancedAST, Source] = {} + by_parent: Dict[ + EnhancedAST, List[Union[Delete, DictInsert, ListInsert, CallArg]] + ] = defaultdict(list) + sources: Dict[EnhancedAST, SourceFile] = {} for change in all_changes: if isinstance(change, Delete): node = cast(EnhancedAST, change.node).parent + if isinstance(node, ast.keyword): + node = node.parent by_parent[node].append(change) - sources[node] = change.source + sources[node] = change.file - elif isinstance(change, (DictInsert, ListInsert)): + elif isinstance(change, (DictInsert, ListInsert, CallArg)): node = cast(EnhancedAST, change.node) by_parent[node].append(change) - sources[node] = change.source + sources[node] = change.file else: change.apply() @@ -218,11 +202,57 @@ def list_token_range(entry): generic_sequence_update( source, parent, + brace_tokens(source, parent), [None if e in to_delete else list_token_range(e) for e in parent.elts], to_insert, ) - elif isinstance(parent, (ast.Dict)): + elif isinstance(parent, ast.Call): + to_delete = { + change.node for change in changes if isinstance(change, Delete) + } + atok = source.asttokens() + + def arg_token_range(node): + if isinstance(node.parent, ast.keyword): + node = node.parent + r = list(atok.get_tokens(node)) + return r[0], r[-1] + + braces_left = atok.next_token(list(atok.get_tokens(parent.func))[-1]) + assert braces_left.string == "(" + braces_right = list(atok.get_tokens(parent))[-1] + assert braces_right.string == ")" + + to_insert = DefaultDict(list) + + for change in changes: + if isinstance(change, CallArg): + if change.arg_name is not None: + position = ( + change.arg_pos + if change.arg_pos is not None + else len(parent.args) + len(parent.keywords) + ) + to_insert[position].append( + f"{change.arg_name} = {change.new_code}" + ) + else: + assert change.arg_pos is not None + to_insert[change.arg_pos].append(change.new_code) + + generic_sequence_update( + source, + parent, + (braces_left, braces_right), + [ + None if e in to_delete else arg_token_range(e) + for e in parent.args + [kw.value for kw in parent.keywords] + ], + to_insert, + ) + + elif isinstance(parent, ast.Dict): to_delete = { change.node for change in changes if isinstance(change, Delete) } @@ -241,6 +271,7 @@ def dict_token_range(key, value): generic_sequence_update( source, parent, + brace_tokens(source, parent), [ None if value in to_delete else dict_token_range(key, value) for key, value in zip(parent.keys, parent.values) diff --git a/src/inline_snapshot/_code_repr.py b/src/inline_snapshot/_code_repr.py index 9a5dcd3a..0f878d77 100644 --- a/src/inline_snapshot/_code_repr.py +++ b/src/inline_snapshot/_code_repr.py @@ -1,14 +1,10 @@ import ast -from abc import ABC -from collections import defaultdict -from dataclasses import fields -from dataclasses import is_dataclass -from dataclasses import MISSING from enum import Enum from enum import Flag from functools import singledispatch from unittest import mock + real_repr = repr @@ -62,7 +58,7 @@ def customize_repr(f): """Register a funtion which should be used to get the code representation of a object. - ```python + ``` python @customize_repr def _(obj: MyCustomClass): return f"MyCustomClass(attr={repr(obj.attr)})" @@ -78,8 +74,27 @@ def _(obj: MyCustomClass): def code_repr(obj): - with mock.patch("builtins.repr", code_repr): - result = code_repr_dispatch(obj) + + with mock.patch("builtins.repr", mocked_code_repr): + return mocked_code_repr(obj) + + +def mocked_code_repr(obj): + from inline_snapshot._adapter.adapter import get_adapter_type + + adapter = get_adapter_type(obj) + assert adapter is not None + return adapter.repr(obj) + + +def value_code_repr(obj): + if not type(obj) == type(obj): + # dispatch will not work in cases like this + return ( + f"HasRepr({repr(type(obj))}, '< type(obj) can not be compared with == >')" + ) + + result = code_repr_dispatch(obj) try: ast.parse(result) @@ -104,59 +119,6 @@ def _(value: Flag): return " | ".join(f"{name}.{flag.name}" for flag in type(value) if flag in value) -# -8<- [start:list] -@customize_repr -def _(value: list): - return "[" + ", ".join(map(repr, value)) + "]" - - -# -8<- [end:list] - - -class OnlyTuple(ABC): - _inline_snapshot_name = "builtins.tuple" - - @classmethod - def __subclasshook__(cls, t): - return t is tuple - - -@customize_repr -def _(value: OnlyTuple): - assert isinstance(value, tuple) - if len(value) == 1: - return f"({repr(value[0])},)" - return "(" + ", ".join(map(repr, value)) + ")" - - -class IsNamedTuple(ABC): - _inline_snapshot_name = "namedtuple" - - _fields: tuple - _field_defaults: dict - - @classmethod - def __subclasshook__(cls, t): - b = t.__bases__ - if len(b) != 1 or b[0] != tuple: - return False - f = getattr(t, "_fields", None) - if not isinstance(f, tuple): - return False - return all(type(n) == str for n in f) - - -@customize_repr -def _(value: IsNamedTuple): - params = ", ".join( - f"{field}={repr(getattr(value,field))}" - for field in value._fields - if field not in value._field_defaults - or getattr(value, field) != value._field_defaults[field] - ) - return f"{repr(type(value))}({params})" - - @customize_repr def _(value: set): if len(value) == 0: @@ -173,71 +135,6 @@ def _(value: frozenset): return "frozenset({" + ", ".join(map(repr, value)) + "})" -@customize_repr -def _(value: dict): - result = ( - "{" + ", ".join(f"{repr(k)}: {repr(value)}" for k, value in value.items()) + "}" - ) - - if type(value) is not dict: - result = f"{repr(type(value))}({result})" - - return result - - -@customize_repr -def _(value: defaultdict): - return f"defaultdict({repr(value.default_factory)}, {repr(dict(value))})" - - @customize_repr def _(value: type): return value.__qualname__ - - -class IsDataclass(ABC): - _inline_snapshot_name = "dataclass" - - @classmethod - def __subclasshook__(cls, subclass): - return is_dataclass(subclass) - - -@customize_repr -def _(value: IsDataclass): - attrs = [] - for field in fields(value): # type: ignore - if field.repr: - field_value = getattr(value, field.name) - - if field.default != MISSING and field.default == field_value: - continue - - if ( - field.default_factory != MISSING - and field.default_factory() == field_value - ): - continue - - attrs.append(f"{field.name}={repr(field_value)}") - - return f"{repr(type(value))}({', '.join(attrs)})" - - -try: - from pydantic import BaseModel -except ImportError: # pragma: no cover - pass -else: - - @customize_repr - def _(model: BaseModel): - return ( - type(model).__qualname__ - + "(" - + ", ".join( - e + "=" + repr(getattr(model, e)) - for e in sorted(model.__pydantic_fields_set__) - ) - + ")" - ) diff --git a/src/inline_snapshot/_compare_context.py b/src/inline_snapshot/_compare_context.py new file mode 100644 index 00000000..104a235c --- /dev/null +++ b/src/inline_snapshot/_compare_context.py @@ -0,0 +1,17 @@ +from contextlib import contextmanager + + +def compare_only(): + return _eq_check_only + + +_eq_check_only = False + + +@contextmanager +def compare_context(): + global _eq_check_only + old_eq_only = _eq_check_only + _eq_check_only = True + yield + _eq_check_only = old_eq_only diff --git a/src/inline_snapshot/_inline_snapshot.py b/src/inline_snapshot/_inline_snapshot.py index 26c98478..a8741fdf 100644 --- a/src/inline_snapshot/_inline_snapshot.py +++ b/src/inline_snapshot/_inline_snapshot.py @@ -1,21 +1,20 @@ import ast import copy import inspect -import tokenize -import warnings -from collections import defaultdict -from pathlib import Path from typing import Any from typing import Dict # noqa from typing import Iterator +from typing import List from typing import Set from typing import Tuple # noqa from typing import TypeVar from executing import Source +from inline_snapshot._adapter.adapter import Adapter +from inline_snapshot._adapter.adapter import adapter_map +from inline_snapshot._source_file import SourceFile -from ._align import add_x -from ._align import align +from ._adapter import get_adapter_type from ._change import CallArg from ._change import Change from ._change import Delete @@ -23,21 +22,18 @@ from ._change import ListInsert from ._change import Replace from ._code_repr import code_repr +from ._compare_context import compare_only from ._exceptions import UsageError -from ._format import format_code from ._sentinels import undefined from ._types import Category -from ._utils import ignore_tokens -from ._utils import normalize -from ._utils import simple_token +from ._types import Snapshot +from ._unmanaged import map_unmanaged +from ._unmanaged import Unmanaged +from ._unmanaged import update_allowed from ._utils import value_to_token -class NotImplementedYet(Exception): - pass - - -snapshots = {} # type: Dict[Tuple[int, int], Snapshot] +snapshots = {} # type: Dict[Tuple[int, int], SnapshotReference] _active = False @@ -54,10 +50,6 @@ def _return(result): return result -class InlineSnapshotSyntaxWarning(Warning): - pass - - class Flags: """ fix: the value needs to be changed to pass the tests @@ -86,36 +78,44 @@ def ignore_old_value(): return _update_flags.fix or _update_flags.update -class GenericValue: +class GenericValue(Snapshot): _new_value: Any _old_value: Any _current_op = "undefined" _ast_node: ast.Expr - _source: Source + _file: SourceFile - def _token_of_node(self, node): + def get_adapter(self, value): + return get_adapter_type(value)(self._file) - return list( - normalize( - [ - simple_token(t.type, t.string) - for t in self._source.asttokens().get_tokens(node) - if t.type not in ignore_tokens - ] - ) - ) + def _re_eval(self, value): - def _format(self, text): - if self._source is None: - return text - else: - return format_code(text, Path(self._source.filename)) + def re_eval(old_value, node, value): + if isinstance(old_value, Unmanaged): + old_value.value = value + return + + assert type(old_value) is type(value) - def _token_to_code(self, tokens): - return self._format(tokenize.untokenize(tokens)).strip() + adapter = self.get_adapter(old_value) + if adapter is not None and hasattr(adapter, "items"): + old_items = adapter.items(old_value, node) + new_items = adapter.items(value, node) + assert len(old_items) == len(new_items) - def _value_to_code(self, value): - return self._token_to_code(value_to_token(value)) + for old_item, new_item in zip(old_items, new_items): + re_eval(old_item.value, old_item.node, new_item.value) + + else: + if update_allowed(old_value): + if not old_value == value: + raise UsageError( + "snapshot value should not change. Use Is(...) for dynamic snapshot parts." + ) + else: + assert False, "old_value should be converted to Unmanaged" + + re_eval(self._old_value, self._ast_node, value) def _ignore_old(self): return ( @@ -132,10 +132,10 @@ def _visible_value(self): return self._old_value def _get_changes(self) -> Iterator[Change]: - raise NotImplementedYet() + raise NotImplementedError() def _new_code(self): - raise NotImplementedYet() + raise NotImplementedError() def __repr__(self): return repr(self._visible_value()) @@ -169,10 +169,12 @@ def __getitem__(self, _item): class UndecidedValue(GenericValue): def __init__(self, old_value, ast_node, source): + + old_value = adapter_map(old_value, map_unmanaged) self._old_value = old_value self._new_value = undefined self._ast_node = ast_node - self._source = source + self._file = SourceFile(source) def _change(self, cls): self.__class__ = cls @@ -183,48 +185,28 @@ def _new_code(self): def _get_changes(self) -> Iterator[Change]: def handle(node, obj): - if isinstance(obj, list): - if not isinstance(node, ast.List): - return - for node_value, value in zip(node.elts, obj): - yield from handle(node_value, value) - elif isinstance(obj, tuple): - if not isinstance(node, ast.Tuple): - return - for node_value, value in zip(node.elts, obj): - yield from handle(node_value, value) - - elif isinstance(obj, dict): - if not isinstance(node, ast.Dict): - return - for value_key, node_key, node_value in zip( - obj.keys(), node.keys, node.values - ): - try: - # this is just a sanity check, dicts should be ordered - node_key = ast.literal_eval(node_key) - except Exception: - pass - else: - assert node_key == value_key - - yield from handle(node_value, obj[value_key]) - else: - if update_allowed(obj): - new_token = value_to_token(obj) - if self._token_of_node(node) != new_token: - new_code = self._token_to_code(new_token) - - yield Replace( - node=self._ast_node, - source=self._source, - new_code=new_code, - flag="update", - old_value=self._old_value, - new_value=self._old_value, - ) - if self._source is not None: + adapter = self.get_adapter(obj) + if adapter is not None and hasattr(adapter, "items"): + for item in adapter.items(obj, node): + yield from handle(item.node, item.value) + return + + if not isinstance(obj, Unmanaged): + new_token = value_to_token(obj) + if self._file._token_of_node(node) != new_token: + new_code = self._file._token_to_code(new_token) + + yield Replace( + node=self._ast_node, + file=self._file, + new_code=new_code, + flag="update", + old_value=self._old_value, + new_value=self._old_value, + ) + + if self._file._source is not None: yield from handle(self._ast_node, self._old_value) # functions which determine the type @@ -250,19 +232,6 @@ def __getitem__(self, item): return self[item] -try: - import dirty_equals # type: ignore -except ImportError: # pragma: no cover - - def update_allowed(value): - return True - -else: - - def update_allowed(value): - return not isinstance(value, dirty_equals.DirtyEquals) - - def clone(obj): new = copy.deepcopy(obj) if not obj == new: @@ -282,233 +251,39 @@ def clone(obj): class EqValue(GenericValue): _current_op = "x == snapshot" + _changes: List[Change] def __eq__(self, other): global _missing_values if self._old_value is undefined: _missing_values += 1 - def use_valid_old_values(old_value, new_value): - - if ( - isinstance(new_value, list) - and isinstance(old_value, list) - or isinstance(new_value, tuple) - and isinstance(old_value, tuple) - ): - diff = add_x(align(old_value, new_value)) - old = iter(old_value) - new = iter(new_value) - result = [] - for c in diff: - if c in "mx": - old_value_element = next(old) - new_value_element = next(new) - result.append( - use_valid_old_values(old_value_element, new_value_element) - ) - elif c == "i": - result.append(next(new)) - elif c == "d": - pass - else: - assert False - - return type(new_value)(result) - - elif isinstance(new_value, dict) and isinstance(old_value, dict): - result = {} - - for key, new_value_element in new_value.items(): - if key in old_value: - result[key] = use_valid_old_values( - old_value[key], new_value_element - ) - else: - result[key] = new_value_element - - return result - - if new_value == old_value: - return old_value - else: - return new_value - - if self._new_value is undefined: - self._new_value = use_valid_old_values(self._old_value, clone(other)) - if self._old_value is undefined or ignore_old_value(): - return True - return _return(self._old_value == other) - else: - return _return(self._new_value == other) + if not compare_only() and self._new_value is undefined: + adapter = Adapter(self._file).get_adapter(self._old_value, other) + it = iter(adapter.assign(self._old_value, self._ast_node, clone(other))) + self._changes = [] + while True: + try: + self._changes.append(next(it)) + except StopIteration as ex: + self._new_value = ex.value + break + + return _return(self._visible_value() == other) + + # if self._new_value is undefined: + # self._new_value = use_valid_old_values(self._old_value, clone(other)) + # if self._old_value is undefined or ignore_old_value(): + # return True + # return _return(self._old_value == other) + # else: + # return _return(self._new_value == other) def _new_code(self): - return self._value_to_code(self._new_value) + return self._file._value_to_code(self._new_value) def _get_changes(self) -> Iterator[Change]: - - assert self._old_value is not undefined - - def check(old_value, old_node, new_value): - - if ( - isinstance(old_node, ast.List) - and isinstance(new_value, list) - and isinstance(old_value, list) - or isinstance(old_node, ast.Tuple) - and isinstance(new_value, tuple) - and isinstance(old_value, tuple) - ): - for e in old_node.elts: - if isinstance(e, ast.Starred): - warnings.warn_explicit( - "star-expressions are not supported inside snapshots", - filename=self._source.filename, - lineno=e.lineno, - category=InlineSnapshotSyntaxWarning, - ) - return - diff = add_x(align(old_value, new_value)) - old = zip(old_value, old_node.elts) - new = iter(new_value) - old_position = 0 - to_insert = defaultdict(list) - for c in diff: - if c in "mx": - old_value_element, old_node_element = next(old) - new_value_element = next(new) - yield from check( - old_value_element, old_node_element, new_value_element - ) - old_position += 1 - elif c == "i": - new_value_element = next(new) - new_code = self._value_to_code(new_value_element) - to_insert[old_position].append((new_code, new_value_element)) - elif c == "d": - old_value_element, old_node_element = next(old) - yield Delete( - "fix", self._source, old_node_element, old_value_element - ) - old_position += 1 - else: - assert False - - for position, code_values in to_insert.items(): - yield ListInsert( - "fix", self._source, old_node, position, *zip(*code_values) - ) - - return - - elif ( - isinstance(old_node, ast.Dict) - and isinstance(new_value, dict) - and isinstance(old_value, dict) - and len(old_value) == len(old_node.keys) - ): - - for key, value in zip(old_node.keys, old_node.values): - if key is None: - warnings.warn_explicit( - "star-expressions are not supported inside snapshots", - filename=self._source.filename, - lineno=value.lineno, - category=InlineSnapshotSyntaxWarning, - ) - return - - for value, node in zip(old_value.keys(), old_node.keys): - assert node is not None - - try: - # this is just a sanity check, dicts should be ordered - node_value = ast.literal_eval(node) - except: - continue - assert node_value == value - - for key, node in zip(old_value.keys(), old_node.values): - if key in new_value: - # check values with same keys - yield from check(old_value[key], node, new_value[key]) - else: - # delete entries - yield Delete("fix", self._source, node, old_value[key]) - - to_insert = [] - insert_pos = 0 - for key, new_value_element in new_value.items(): - if key not in old_value: - # add new values - to_insert.append((key, new_value_element)) - else: - if to_insert: - new_code = [ - (self._value_to_code(k), self._value_to_code(v)) - for k, v in to_insert - ] - yield DictInsert( - "fix", - self._source, - old_node, - insert_pos, - new_code, - to_insert, - ) - to_insert = [] - insert_pos += 1 - - if to_insert: - new_code = [ - (self._value_to_code(k), self._value_to_code(v)) - for k, v in to_insert - ] - yield DictInsert( - "fix", - self._source, - old_node, - len(old_node.values), - new_code, - to_insert, - ) - - return - - # generic fallback - - # because IsStr() != IsStr() - if type(old_value) is type(new_value) and not update_allowed(new_value): - return - - if old_node is None: - new_token = [] - else: - new_token = value_to_token(new_value) - - if not old_value == new_value: - flag = "fix" - elif ( - self._ast_node is not None - and update_allowed(old_value) - and self._token_of_node(old_node) != new_token - ): - flag = "update" - else: - return - - new_code = self._token_to_code(new_token) - - yield Replace( - node=old_node, - source=self._source, - new_code=new_code, - flag=flag, - old_value=old_value, - new_value=new_value, - ) - - yield from check(self._old_value, self._ast_node, self._new_value) + return iter(self._changes) class MinMaxValue(GenericValue): @@ -535,7 +310,7 @@ def _generic_cmp(self, other): return _return(self.cmp(self._visible_value(), other)) def _new_code(self): - return self._value_to_code(self._new_value) + return self._file._value_to_code(self._new_value) def _get_changes(self) -> Iterator[Change]: new_token = value_to_token(self._new_value) @@ -545,17 +320,17 @@ def _get_changes(self) -> Iterator[Change]: flag = "trim" elif ( self._ast_node is not None - and self._token_of_node(self._ast_node) != new_token + and self._file._token_of_node(self._ast_node) != new_token ): flag = "update" else: return - new_code = self._token_to_code(new_token) + new_code = self._file._token_to_code(new_token) yield Replace( node=self._ast_node, - source=self._source, + file=self._file, new_code=new_code, flag=flag, old_value=self._old_value, @@ -625,7 +400,7 @@ def __contains__(self, item): return _return(item in self._old_value) def _new_code(self): - return self._value_to_code(self._new_value) + return self._file._value_to_code(self._new_value) def _get_changes(self) -> Iterator[Change]: @@ -638,19 +413,25 @@ def _get_changes(self) -> Iterator[Change]: for old_value, old_node in zip(self._old_value, elements): if old_value not in self._new_value: yield Delete( - flag="trim", source=self._source, node=old_node, old_value=old_value + flag="trim", + file=self._file, + node=old_node, + old_value=old_value, ) continue # check for update new_token = value_to_token(old_value) - if old_node is not None and self._token_of_node(old_node) != new_token: - new_code = self._token_to_code(new_token) + if ( + old_node is not None + and self._file._token_of_node(old_node) != new_token + ): + new_code = self._file._token_to_code(new_token) yield Replace( node=old_node, - source=self._source, + file=self._file, new_code=new_code, flag="update", old_value=old_value, @@ -661,10 +442,10 @@ def _get_changes(self) -> Iterator[Change]: if new_values: yield ListInsert( flag="fix", - source=self._source, + file=self._file, node=self._ast_node, position=len(self._old_value), - new_code=[self._value_to_code(v) for v in new_values], + new_code=[self._file._value_to_code(v) for v in new_values], new_values=new_values, ) @@ -678,31 +459,39 @@ def __getitem__(self, index): if self._new_value is undefined: self._new_value = {} - old_value = self._old_value - if old_value is undefined: - _missing_values += 1 - old_value = {} - - child_node = None - if self._ast_node is not None: - assert isinstance(self._ast_node, ast.Dict) - if index in old_value: - pos = list(old_value.keys()).index(index) - child_node = self._ast_node.values[pos] - if index not in self._new_value: + old_value = self._old_value + if old_value is undefined: + _missing_values += 1 + old_value = {} + + child_node = None + if self._ast_node is not None: + assert isinstance(self._ast_node, ast.Dict) + if index in old_value: + pos = list(old_value.keys()).index(index) + child_node = self._ast_node.values[pos] + self._new_value[index] = UndecidedValue( - old_value.get(index, undefined), child_node, self._source + old_value.get(index, undefined), child_node, self._file ) return self._new_value[index] + def _re_eval(self, value): + super()._re_eval(value) + + if self._new_value is not undefined and self._old_value is not undefined: + for key, s in self._new_value.items(): + if key in self._old_value: + s._re_eval(self._old_value[key]) + def _new_code(self): return ( "{" + ", ".join( [ - f"{self._value_to_code(k)}: {v._new_code()}" + f"{self._file._value_to_code(k)}: {v._new_code()}" for k, v in self._new_value.items() if not isinstance(v, UndecidedValue) ] @@ -726,7 +515,7 @@ def _get_changes(self) -> Iterator[Change]: yield from self._new_value[key]._get_changes() else: # delete entries - yield Delete("trim", self._source, node, self._old_value[key]) + yield Delete("trim", self._file, node, self._old_value[key]) to_insert = [] for key, new_value_element in self._new_value.items(): @@ -737,10 +526,10 @@ def _get_changes(self) -> Iterator[Change]: to_insert.append((key, new_value_element._new_code())) if to_insert: - new_code = [(self._value_to_code(k), v) for k, v in to_insert] + new_code = [(self._file._value_to_code(k), v) for k, v in to_insert] yield DictInsert( "create", - self._source, + self._file, self._ast_node, len(self._old_value), new_code, @@ -815,10 +604,12 @@ def snapshot(obj: Any = undefined) -> Any: node = expr.node if node is None: # we can run without knowing of the calling expression but we will not be able to fix code - snapshots[key] = Snapshot(obj, None) + snapshots[key] = SnapshotReference(obj, None) else: assert isinstance(node, ast.Call) - snapshots[key] = Snapshot(obj, expr) + snapshots[key] = SnapshotReference(obj, expr) + else: + snapshots[key]._re_eval(obj) return snapshots[key]._value @@ -835,7 +626,7 @@ def used_externals(tree): ] -class Snapshot: +class SnapshotReference: def __init__(self, value, expr): self._expr = expr node = expr.node.args[0] if expr is not None and expr.node.args else None @@ -853,16 +644,18 @@ def _changes(self): new_code = self._value._new_code() yield CallArg( - "create", - self._value._source, - self._expr.node if self._expr is not None else None, - 0, - None, - new_code, - self._value._old_value, - self._value._new_value, + flag="create", + file=self._value._file, + node=self._expr.node if self._expr is not None else None, + arg_pos=0, + arg_name=None, + new_code=new_code, + new_value=self._value._new_value, ) else: yield from self._value._get_changes() + + def _re_eval(self, obj): + self._value._re_eval(obj) diff --git a/src/inline_snapshot/_is.py b/src/inline_snapshot/_is.py new file mode 100644 index 00000000..1f695397 --- /dev/null +++ b/src/inline_snapshot/_is.py @@ -0,0 +1,6 @@ +class Is: + def __init__(self, value): + self.value = value + + def __eq__(self, other): + return self.value == other diff --git a/src/inline_snapshot/_rewrite_code.py b/src/inline_snapshot/_rewrite_code.py index 70eb9b6e..0cab4c56 100644 --- a/src/inline_snapshot/_rewrite_code.py +++ b/src/inline_snapshot/_rewrite_code.py @@ -98,12 +98,8 @@ def __init__(self, change_recorder): self.change_recorder._changes.append(self) self.change_id = self._next_change_id - self._tags = [] type(self)._next_change_id += 1 - def set_tags(self, *tags): - self._tags = tags - def replace(self, node, new_contend, *, filename): assert isinstance(new_contend, str) @@ -128,7 +124,7 @@ def _replace(self, filename, range, new_contend): class SourceFile: - def __init__(self, filename): + def __init__(self, filename: pathlib.Path): self.replacements: list[Replacement] = [] self.filename = filename self.source = self.filename.read_text("utf-8") diff --git a/src/inline_snapshot/_source_file.py b/src/inline_snapshot/_source_file.py new file mode 100644 index 00000000..ba8a94bc --- /dev/null +++ b/src/inline_snapshot/_source_file.py @@ -0,0 +1,51 @@ +import tokenize +from pathlib import Path + +from executing import Source +from inline_snapshot._format import format_code +from inline_snapshot._utils import normalize +from inline_snapshot._utils import simple_token +from inline_snapshot._utils import value_to_token + +from ._utils import ignore_tokens + + +class SourceFile: + _source = Source + + def __init__(self, source): + if isinstance(source, SourceFile): + self._source = source._source + else: + self._source = source + + @property + def filename(self): + return self._source.filename + + def _format(self, text): + if self._source is None: + return text + else: + return format_code(text, Path(self._source.filename)) + + def asttokens(self): + return self._source.asttokens() + + def _token_to_code(self, tokens): + return self._format(tokenize.untokenize(tokens)).strip() + + def _value_to_code(self, value): + return self._token_to_code(value_to_token(value)) + + def _token_of_node(self, node): + + return list( + normalize( + [ + simple_token(t.type, t.string) + for t in self._source.asttokens().get_tokens(node) + if t.type not in ignore_tokens + ] + ) + ) diff --git a/src/inline_snapshot/_unmanaged.py b/src/inline_snapshot/_unmanaged.py new file mode 100644 index 00000000..5e46b9b5 --- /dev/null +++ b/src/inline_snapshot/_unmanaged.py @@ -0,0 +1,41 @@ +from ._is import Is +from ._types import Snapshot + +try: + import dirty_equals # type: ignore +except ImportError: # pragma: no cover + + def is_dirty_equal(value): + return False + +else: + + def is_dirty_equal(value): + return isinstance(value, dirty_equals.DirtyEquals) or ( + isinstance(value, type) and issubclass(value, dirty_equals.DirtyEquals) + ) + + +def update_allowed(value): + return not (is_dirty_equal(value) or isinstance(value, (Is, Snapshot))) # type: ignore + + +def is_unmanaged(value): + return not update_allowed(value) + + +class Unmanaged: + def __init__(self, value): + self.value = value + + def __eq__(self, other): + assert not isinstance(other, Unmanaged) + + return self.value == other + + +def map_unmanaged(value): + if is_unmanaged(value): + return Unmanaged(value) + else: + return value diff --git a/src/inline_snapshot/pytest_plugin.py b/src/inline_snapshot/pytest_plugin.py index ed059217..4ac6809f 100644 --- a/src/inline_snapshot/pytest_plugin.py +++ b/src/inline_snapshot/pytest_plugin.py @@ -84,6 +84,12 @@ def pytest_configure(config): f"--inline-snapshot={','.join(flags)} can not be combined with xdist" ) + unknown_flags = flags - categories - {"disable", "review", "report", "short-report"} + if unknown_flags: + raise pytest.UsageError( + f"--inline-snapshot={','.join(sorted(unknown_flags))} is a unknown flag" + ) + if "disable" in flags and flags != {"disable"}: raise pytest.UsageError( f"--inline-snapshot=disable can not be combined with other flags ({', '.join(flags-{'disable'})})" diff --git a/src/inline_snapshot/syntax_warnings.py b/src/inline_snapshot/syntax_warnings.py new file mode 100644 index 00000000..a403ec16 --- /dev/null +++ b/src/inline_snapshot/syntax_warnings.py @@ -0,0 +1,6 @@ +class InlineSnapshotSyntaxWarning(Warning): + pass + + +class InlineSnapshotInfo(Warning): + pass diff --git a/src/inline_snapshot/testing/_example.py b/src/inline_snapshot/testing/_example.py index aab78947..d4a6f3c6 100644 --- a/src/inline_snapshot/testing/_example.py +++ b/src/inline_snapshot/testing/_example.py @@ -5,6 +5,7 @@ import platform import re import subprocess as sp +import traceback from argparse import ArgumentParser from io import StringIO from pathlib import Path @@ -88,6 +89,14 @@ def __init__(self, files: str | dict[str, str]): self.files = files + self.dump_files() + + def dump_files(self): + for name, content in self.files.items(): + print(f"file: {name}") + print(content) + print() + def _write_files(self, dir: Path): for name, content in self.files.items(): (dir / name).write_text(content) @@ -144,6 +153,7 @@ def run_inline( self._write_files(tmp_path) + raised_exception = None with snapshot_env(): with ChangeRecorder().activate() as recorder: _inline_snapshot._update_flags = Flags({*flags}) @@ -154,6 +164,7 @@ def run_inline( try: for filename in tmp_path.glob("*.py"): globals: dict[str, Any] = {} + print("run> pytest", filename) exec( compile(filename.read_text("utf-8"), filename, "exec"), globals, @@ -164,7 +175,8 @@ def run_inline( if k.startswith("test_") and callable(v): v() except Exception as e: - assert raises == f"{type(e).__name__}:\n" + str(e) + traceback.print_exc() + raised_exception = e finally: _inline_snapshot._active = False @@ -193,6 +205,11 @@ def run_inline( if reported_categories is not None: assert sorted(snapshot_flags) == reported_categories + if raised_exception is not None: + assert raises == f"{type(raised_exception).__name__}:\n" + str( + raised_exception + ) + if changed_files is not None: current_files = {} @@ -213,6 +230,7 @@ def run_pytest( env: dict[str, str] = {}, changed_files: Snapshot[dict[str, str]] | None = None, report: Snapshot[str] | None = None, + stderr: Snapshot[str] | None = None, returncode: Snapshot[int] | None = None, ) -> Example: """Run pytest with the given args and env variables in an seperate @@ -225,6 +243,7 @@ def run_pytest( env: dict of environment variables changed_files: snapshot of files which are changed by this run. report: snapshot of the report at the end of the pytest run. + stderr: pytest stderr output returncode: snapshot of the pytest returncode. Returns: @@ -259,6 +278,9 @@ def run_pytest( if returncode is not None: assert result.returncode == returncode + if stderr is not None: + assert result.stderr.decode() == stderr + if report is not None: report_list = [] diff --git a/tests/adapter/test_dataclass.py b/tests/adapter/test_dataclass.py new file mode 100644 index 00000000..4f521bda --- /dev/null +++ b/tests/adapter/test_dataclass.py @@ -0,0 +1,492 @@ +from inline_snapshot import snapshot +from inline_snapshot.testing._example import Example + +from tests.warns import warns + + +def test_unmanaged(): + + Example( + """\ +from inline_snapshot import snapshot,Is +from dataclasses import dataclass + +@dataclass +class A: + a:int + b:int + +def test_something(): + assert A(a=2,b=4) == snapshot(A(a=1,b=Is(1))), "not equal" +""" + ).run_inline( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot,Is +from dataclasses import dataclass + +@dataclass +class A: + a:int + b:int + +def test_something(): + assert A(a=2,b=4) == snapshot(A(a=2,b=Is(1))), "not equal" +""" + } + ), + raises=snapshot( + """\ +AssertionError: +not equal\ +""" + ), + ) + + +def test_reeval(): + Example( + """\ +from inline_snapshot import snapshot,Is +from dataclasses import dataclass + +@dataclass +class A: + a:int + b:int + +def test_something(): + for c in "ab": + assert A(a=1,b=c) == snapshot(A(a=2,b=Is(c))) +""" + ).run_inline( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot,Is +from dataclasses import dataclass + +@dataclass +class A: + a:int + b:int + +def test_something(): + for c in "ab": + assert A(a=1,b=c) == snapshot(A(a=1,b=Is(c))) +""" + } + ), + ) + + +def test_pydantic_default_value(): + Example( + """\ +from inline_snapshot import snapshot,Is +from dataclasses import dataclass,field +from pydantic import BaseModel,Field + +class A(BaseModel): + a:int + b:int=2 + c:list=Field(default_factory=list) + +def test_something(): + assert A(a=1) == snapshot(A(a=1,b=2,c=[])) +""" + ).run_inline( + ["--inline-snapshot=update"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot,Is +from dataclasses import dataclass,field +from pydantic import BaseModel,Field + +class A(BaseModel): + a:int + b:int=2 + c:list=Field(default_factory=list) + +def test_something(): + assert A(a=1) == snapshot(A(a=1)) +""" + } + ), + ) + + +def test_dataclass_default_value(): + Example( + """\ +from inline_snapshot import snapshot,Is +from dataclasses import dataclass,field + +@dataclass +class A: + a:int + b:int=2 + c:list=field(default_factory=list) + +def test_something(): + assert A(a=1) == snapshot(A(a=1,b=2,c=[])) +""" + ).run_inline( + ["--inline-snapshot=update"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot,Is +from dataclasses import dataclass,field + +@dataclass +class A: + a:int + b:int=2 + c:list=field(default_factory=list) + +def test_something(): + assert A(a=1) == snapshot(A(a=1)) +""" + } + ), + ) + + +def test_disabled(executing_used): + Example( + """\ +from inline_snapshot import snapshot +from dataclasses import dataclass + +@dataclass +class A: + a:int + +def test_something(): + assert A(a=3) == snapshot(A(a=5)),"not equal" +""" + ).run_inline( + changed_files=snapshot({}), + raises=snapshot( + """\ +AssertionError: +not equal\ +""" + ), + ) + + +def test_starred_warns(): + with warns( + snapshot( + [ + ( + 10, + "InlineSnapshotSyntaxWarning: star-expressions are not supported inside snapshots", + ) + ] + ), + include_line=True, + ): + Example( + """ +from inline_snapshot import snapshot +from dataclasses import dataclass + +@dataclass +class A: + a:int + +def test_something(): + assert A(a=3) == snapshot(A(**{"a":5})),"not equal" +""" + ).run_inline( + ["--inline-snapshot=fix"], + raises=snapshot( + """\ +AssertionError: +not equal\ +""" + ), + ) + + +def test_add_argument(): + Example( + """\ +from inline_snapshot import snapshot +from dataclasses import dataclass + +@dataclass +class A: + a:int=0 + b:int=0 + c:int=0 + +def test_something(): + assert A(a=3,b=3,c=3) == snapshot(A(b=3)),"not equal" +""" + ).run_inline( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot +from dataclasses import dataclass + +@dataclass +class A: + a:int=0 + b:int=0 + c:int=0 + +def test_something(): + assert A(a=3,b=3,c=3) == snapshot(A(a = 3, b=3, c = 3)),"not equal" +""" + } + ), + raises=snapshot( + """\ +AssertionError: +not equal\ +""" + ), + ) + + +def test_positional_star_args(): + + with warns( + snapshot( + [ + "InlineSnapshotSyntaxWarning: star-expressions are not supported inside snapshots" + ] + ) + ): + Example( + """\ +from inline_snapshot import snapshot +from dataclasses import dataclass + +@dataclass +class A: + a:int + +def test_something(): + assert A(a=3) == snapshot(A(*[],a=3)),"not equal" +""" + ).run_inline( + ["--inline-snapshot=report"], + ) + + +def test_remove_positional_argument(): + Example( + """\ +from inline_snapshot import snapshot + +from inline_snapshot._adapter.generic_call_adapter import GenericCallAdapter,Argument + + +class L: + def __init__(self,*l): + self.l=l + + def __eq__(self,other): + if not isinstance(other,L): + return NotImplemented + return other.l==self.l + +class LAdapter(GenericCallAdapter): + @classmethod + def check_type(cls, typ): + return issubclass(typ,L) + + @classmethod + def arguments(cls, value): + return ([Argument(x) for x in value.l],{}) + + @classmethod + def argument(cls, value, pos_or_name): + assert isinstance(pos_or_name,int) + return value.l[pos_or_name] + +def test_L1(): + assert L(1,2) == snapshot(L(1)), "not equal" + +def test_L2(): + assert L(1,2) == snapshot(L(1, 2, 3)), "not equal" +""" + ).run_pytest( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot + +from inline_snapshot._adapter.generic_call_adapter import GenericCallAdapter,Argument + + +class L: + def __init__(self,*l): + self.l=l + + def __eq__(self,other): + if not isinstance(other,L): + return NotImplemented + return other.l==self.l + +class LAdapter(GenericCallAdapter): + @classmethod + def check_type(cls, typ): + return issubclass(typ,L) + + @classmethod + def arguments(cls, value): + return ([Argument(x) for x in value.l],{}) + + @classmethod + def argument(cls, value, pos_or_name): + assert isinstance(pos_or_name,int) + return value.l[pos_or_name] + +def test_L1(): + assert L(1,2) == snapshot(L(1, 2)), "not equal" + +def test_L2(): + assert L(1,2) == snapshot(L(1, 2)), "not equal" +""" + } + ), + ) + + +def test_namedtuple(): + Example( + """\ +from inline_snapshot import snapshot +from collections import namedtuple + +T=namedtuple("T","a,b") + +def test_tuple(): + assert T(a=1,b=2) == snapshot(T(a=1, b=3)), "not equal" +""" + ).run_pytest( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot +from collections import namedtuple + +T=namedtuple("T","a,b") + +def test_tuple(): + assert T(a=1,b=2) == snapshot(T(a=1, b=2)), "not equal" +""" + } + ), + ) + + +def test_defaultdict(): + Example( + """\ +from inline_snapshot import snapshot +from collections import defaultdict + + +def test_tuple(): + d=defaultdict(list) + d[1].append(2) + assert d == snapshot(defaultdict(list, {1: [3]})), "not equal" +""" + ).run_pytest( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot +from collections import defaultdict + + +def test_tuple(): + d=defaultdict(list) + d[1].append(2) + assert d == snapshot(defaultdict(list, {1: [2]})), "not equal" +""" + } + ), + ) + + +def test_dataclass_field_repr(): + + Example( + """\ +from inline_snapshot import snapshot +from dataclasses import dataclass,field + +@dataclass +class container: + a: int + b: int = field(default=5,repr=False) + +assert container(a=1,b=5) == snapshot() +""" + ).run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot +from dataclasses import dataclass,field + +@dataclass +class container: + a: int + b: int = field(default=5,repr=False) + +assert container(a=1,b=5) == snapshot(container(a=1)) +""" + } + ), + ).run_inline() + + +def test_pydantic_field_repr(): + + Example( + """\ +from inline_snapshot import snapshot +from pydantic import BaseModel,Field + +class container(BaseModel): + a: int + b: int = Field(default=5,repr=False) + +assert container(a=1,b=5) == snapshot() +""" + ).run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot +from pydantic import BaseModel,Field + +class container(BaseModel): + a: int + b: int = Field(default=5,repr=False) + +assert container(a=1,b=5) == snapshot(container(a=1)) +""" + } + ), + ).run_inline() diff --git a/tests/adapter/test_dict.py b/tests/adapter/test_dict.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/adapter/test_general.py b/tests/adapter/test_general.py new file mode 100644 index 00000000..52d1daf9 --- /dev/null +++ b/tests/adapter/test_general.py @@ -0,0 +1,47 @@ +from inline_snapshot import snapshot +from inline_snapshot.testing import Example + + +def test_adapter_mismatch(): + + Example( + """\ +from inline_snapshot import snapshot + + +def test_thing(): + assert [1,2] == snapshot({1:2}) + + """ + ).run_inline( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot + + +def test_thing(): + assert [1,2] == snapshot([1, 2]) + + \ +""" + } + ), + ) + + +def test_reeval(): + + Example( + """\ +from inline_snapshot import snapshot,Is + + +def test_thing(): + for i in (1,2): + assert {1:i} == snapshot({1:Is(i)}) + assert [i] == [Is(i)] + assert (i,) == (Is(i),) +""" + ).run_pytest(["--inline-snapshot=short-report"], report=snapshot("")) diff --git a/tests/adapter/test_sequence.py b/tests/adapter/test_sequence.py new file mode 100644 index 00000000..77ea2853 --- /dev/null +++ b/tests/adapter/test_sequence.py @@ -0,0 +1,94 @@ +import pytest +from inline_snapshot._inline_snapshot import snapshot +from inline_snapshot.testing._example import Example + + +def test_list_adapter_create_inner_snapshot(): + + Example( + """\ +from inline_snapshot import snapshot +from dirty_equals import IsInt + +def test_list(): + + assert [1,2,3,4] == snapshot([1,IsInt(),snapshot(),4]),"not equal" +""" + ).run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot +from dirty_equals import IsInt + +def test_list(): + + assert [1,2,3,4] == snapshot([1,IsInt(),snapshot(3),4]),"not equal" +""" + } + ), + raises=snapshot( + """\ +AssertionError: +not equal\ +""" + ), + ) + + +def test_list_adapter_fix_inner_snapshot(): + + Example( + """\ +from inline_snapshot import snapshot +from dirty_equals import IsInt + +def test_list(): + + assert [1,2,3,4] == snapshot([1,IsInt(),snapshot(8),4]),"not equal" +""" + ).run_inline( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot +from dirty_equals import IsInt + +def test_list(): + + assert [1,2,3,4] == snapshot([1,IsInt(),snapshot(3),4]),"not equal" +""" + } + ), + raises=snapshot( + """\ +AssertionError: +not equal\ +""" + ), + ) + + +@pytest.mark.no_rewriting +def test_list_adapter_reeval(executing_used): + + Example( + """\ +from inline_snapshot import snapshot,Is + +def test_list(): + + for i in (1,2,3): + assert [1,i] == snapshot([1,Is(i)]),"not equal" +""" + ).run_inline( + changed_files=snapshot({}), + raises=snapshot( + """\ +AssertionError: +not equal\ +""" + ), + ) diff --git a/tests/conftest.py b/tests/conftest.py index 4aeb5730..dbef0a28 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,6 +21,7 @@ from inline_snapshot._format import format_code from inline_snapshot._inline_snapshot import Flags from inline_snapshot._rewrite_code import ChangeRecorder +from inline_snapshot._types import Category from inline_snapshot.testing._example import snapshot_env pytest_plugins = "pytester" @@ -65,7 +66,7 @@ def w(source_code, *, flags="", reported_flags=None, number=1): @pytest.fixture() -def source(tmp_path): +def source(tmp_path: Path): filecount = 1 @dataclass @@ -76,8 +77,8 @@ class Source: number_snapshots: int = 0 number_changes: int = 0 - def run(self, *flags): - flags = Flags({*flags}) + def run(self, *flags_arg: Category): + flags = Flags({*flags_arg}) nonlocal filecount filename: Path = tmp_path / f"test_{filecount}.py" @@ -311,7 +312,10 @@ def format(self): ) def pyproject(self, source): - (pytester.path / "pyproject.toml").write_text(source, "utf-8") + self.write_file("pyproject.toml", source) + + def write_file(self, filename, content): + (pytester.path / filename).write_text(content, "utf-8") def storage(self): dir = pytester.path / ".inline-snapshot" / "external" diff --git a/tests/test_change.py b/tests/test_change.py new file mode 100644 index 00000000..cbe82589 --- /dev/null +++ b/tests/test_change.py @@ -0,0 +1,90 @@ +import ast + +import pytest +from executing import Source +from inline_snapshot._change import apply_all +from inline_snapshot._change import CallArg +from inline_snapshot._change import Delete +from inline_snapshot._change import Replace +from inline_snapshot._inline_snapshot import snapshot +from inline_snapshot._rewrite_code import ChangeRecorder +from inline_snapshot._source_file import SourceFile + + +@pytest.fixture +def check_change(tmp_path): + i = 0 + + def w(source, changes, new_code): + nonlocal i + + filename = tmp_path / f"test_{i}.py" + i += 1 + + filename.write_text(source) + print(f"\ntest: {source}") + + source = Source.for_filename(filename) + module = source.tree + context = SourceFile(source) + + call = module.body[0].value + assert isinstance(call, ast.Call) + + with ChangeRecorder().activate() as cr: + apply_all(changes(context, call)) + + cr.virtual_write() + + cr.dump() + + assert list(cr.files())[0].source == new_code + + return w + + +def test_change_function_args(check_change): + + check_change( + "f(a,b=2)", + lambda source, call: [ + Replace( + flag="fix", + file=source, + node=call.args[0], + new_code="22", + old_value=0, + new_value=0, + ) + ], + snapshot("f(22,b=2)"), + ) + + check_change( + "f(a,b=2)", + lambda source, call: [ + Delete( + flag="fix", + file=source, + node=call.args[0], + old_value=0, + ) + ], + snapshot("f(b=2)"), + ) + + check_change( + "f(a,b=2)", + lambda source, call: [ + CallArg( + flag="fix", + file=source, + node=call, + arg_pos=0, + arg_name=None, + new_code="22", + new_value=22, + ) + ], + snapshot("f(22, a,b=2)"), + ) diff --git a/tests/test_code_repr.py b/tests/test_code_repr.py index 1ee4f9ea..64c80ebe 100644 --- a/tests/test_code_repr.py +++ b/tests/test_code_repr.py @@ -144,40 +144,6 @@ class container: ) -def test_dataclass_field_repr(check_update): - - Example( - """\ -from inline_snapshot import snapshot -from dataclasses import dataclass,field - -@dataclass -class container: - a: int - b: int = field(default=5,repr=False) - -assert container(a=1,b=5) == snapshot() -""" - ).run_inline( - ["--inline-snapshot=create"], - changed_files=snapshot( - { - "test_something.py": """\ -from inline_snapshot import snapshot -from dataclasses import dataclass,field - -@dataclass -class container: - a: int - b: int = field(default=5,repr=False) - -assert container(a=1,b=5) == snapshot(container(a=1)) -""" - } - ), - ).run_inline() - - def test_flag(check_update): assert ( @@ -368,3 +334,37 @@ def __repr__(self): return "FakeTuple()" assert code_repr(FakeTuple()) == snapshot("FakeTuple()") + + +def test_invalid_repr(check_update): + assert ( + check_update( + """\ +class Thing: + def __repr__(self): + return "+++" + + def __eq__(self,other): + if not isinstance(other,Thing): + return NotImplemented + return True + +assert Thing() == snapshot() +""", + flags="create", + ) + == snapshot( + """\ +class Thing: + def __repr__(self): + return "+++" + + def __eq__(self,other): + if not isinstance(other,Thing): + return NotImplemented + return True + +assert Thing() == snapshot(HasRepr(Thing, "+++")) +""" + ) + ) diff --git a/tests/test_dirty_equals.py b/tests/test_dirty_equals.py new file mode 100644 index 00000000..1bc897a3 --- /dev/null +++ b/tests/test_dirty_equals.py @@ -0,0 +1,152 @@ +from inline_snapshot._inline_snapshot import snapshot +from inline_snapshot.testing._example import Example + + +def test_dirty_equals_repr(): + Example( + """\ +from inline_snapshot import snapshot +from dirty_equals import IsStr + +def test_something(): + assert [IsStr()] == snapshot() + """ + ).run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot({}), + raises=snapshot( + """\ +UsageError: +inline-snapshot uses `copy.deepcopy` to copy objects, +but the copied object is not equal to the original one: + +original: [HasRepr(IsStr, '< type(obj) can not be compared with == >')] +copied: [HasRepr(IsStr, '< type(obj) can not be compared with == >')] + +Please fix the way your object is copied or your __eq__ implementation. +""" + ), + ) + + +def test_compare_dirty_equals_twice() -> None: + + Example( + """ +from dirty_equals import IsStr +from inline_snapshot import snapshot + +for x in 'ab': + assert x == snapshot(IsStr()) + assert [x,5] == snapshot([IsStr(),3]) + assert {'a':x,'b':5} == snapshot({'a':IsStr(),'b':3}) + +""" + ).run_inline( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "test_something.py": """\ + +from dirty_equals import IsStr +from inline_snapshot import snapshot + +for x in 'ab': + assert x == snapshot(IsStr()) + assert [x,5] == snapshot([IsStr(),5]) + assert {'a':x,'b':5} == snapshot({'a':IsStr(),'b':5}) + +""" + } + ), + ) + + +def test_dirty_equals_in_unused_snapshot() -> None: + + Example( + """ +from dirty_equals import IsStr +from inline_snapshot import snapshot,Is + +snapshot([IsStr(),3]) +snapshot((IsStr(),3)) +snapshot({1:IsStr(),2:3}) +snapshot({1+1:2}) + +t=(1,2) +d={1:2} +l=[1,2] +snapshot([Is(t),Is(d),Is(l)]) +""" + ).run_inline( + ["--inline-snapshot=fix"], + changed_files=snapshot({}), + ) + + +def test_now_like_dirty_equals(): + # test for cases like https://github.com/15r10nk/inline-snapshot/issues/116 + + Example( + """ +from dirty_equals import DirtyEquals +from inline_snapshot import snapshot + + +def test_time(): + + now = 5 + + class Now(DirtyEquals): + def equals(self, other): + return other == now + + assert 5 == snapshot(Now()) + + now = 6 + + assert 5 == snapshot(Now()), "different time" +""" + ).run_inline( + ["--inline-snapshot=fix"], + changed_files=snapshot({}), + raises=snapshot( + """\ +AssertionError: +different time\ +""" + ), + ) + + +def test_dirty_equals_with_changing_args() -> None: + + Example( + """\ +from dirty_equals import IsInt +from inline_snapshot import snapshot + +def test_number(): + + for i in range(5): + assert ["a",i] == snapshot(["e",IsInt(gt=i-1,lt=i+1)]) + +""" + ).run_inline( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "test_something.py": """\ +from dirty_equals import IsInt +from inline_snapshot import snapshot + +def test_number(): + + for i in range(5): + assert ["a",i] == snapshot(["a",IsInt(gt=i-1,lt=i+1)]) + +""" + } + ), + ) diff --git a/tests/test_docs.py b/tests/test_docs.py index 14a4046f..d1543afd 100644 --- a/tests/test_docs.py +++ b/tests/test_docs.py @@ -1,11 +1,233 @@ +import itertools import platform import re import sys import textwrap +from collections import defaultdict +from dataclasses import dataclass from pathlib import Path +from typing import Optional import inline_snapshot._inline_snapshot import pytest +from inline_snapshot import snapshot +from inline_snapshot.extra import raises + + +@dataclass +class Block: + code: str + code_header: Optional[str] + block_options: str + line: int + + +def map_code_blocks(file, func, fix=False): + + block_start = re.compile("( *)``` *python(.*)") + block_end = re.compile("```.*") + + header = re.compile("") + + current_code = file.read_text("utf-8") + new_lines = [] + block_lines = [] + options = set() + is_block = False + code = None + indent = "" + block_start_linenum = None + block_options = None + code_header = None + header_line = "" + + for linenumber, line in enumerate(current_code.splitlines(), start=1): + m = block_start.fullmatch(line) + if m and not is_block: + # ``` python + block_start_linenum = linenumber + indent = m[1] + block_options = m[2] + block_lines = [] + is_block = True + continue + + if block_end.fullmatch(line.strip()) and is_block: + # ``` + is_block = False + + code = "\n".join(block_lines) + "\n" + code = textwrap.dedent(code) + if file.suffix == ".py": + code = code.replace("\\\\", "\\") + + try: + new_block = func( + Block( + code=code, + code_header=code_header, + block_options=block_options, + line=block_start_linenum, + ) + ) + except Exception: + print(f"error at block at line {block_start_linenum}") + print(f"{code_header=}") + print(f"{block_options=}") + print(code) + raise + + if new_block.code_header is not None: + new_lines.append(f"{indent}") + + new_lines.append( + f"{indent}``` {('python '+new_block.block_options.strip()).strip()}" + ) + + new_code = new_block.code.rstrip() + if file.suffix == ".py": + new_code = new_code.replace("\\", "\\\\") + new_code = textwrap.indent(new_code, indent) + + new_lines.append(new_code) + + new_lines.append(f"{indent}```") + + header_line = "" + code_header = None + + continue + + if is_block: + block_lines.append(line) + continue + + m = header.fullmatch(line.strip()) + if m: + # comment + header_line = line + code_header = m[1].strip() + continue + else: + if header_line: + new_lines.append(header_line) + code_header = None + header_line = "" + + new_lines.append(line) + + new_code = "\n".join(new_lines) + "\n" + + if fix: + file.write_text(new_code) + else: + assert current_code.splitlines() == new_code.splitlines() + assert current_code == new_code + + +def test_map_code_blocks(tmp_path): + + file = tmp_path / "example.md" + + def test_doc( + markdown_code, + handle_block=lambda block: exec(block.code), + blocks=[], + exception="", + new_markdown_code=None, + ): + + file.write_text(markdown_code) + + recorded_blocks = [] + + with raises(exception): + + def test_block(block): + handle_block(block) + recorded_blocks.append(block) + return block + + map_code_blocks(file, test_block, True) + assert recorded_blocks == blocks + map_code_blocks(file, test_block, False) + + recorded_markdown_code = file.read_text() + if recorded_markdown_code != markdown_code: + assert new_markdown_code == recorded_markdown_code + else: + assert new_markdown_code == None + + test_doc( + """ +``` python +1 / 0 +``` +""", + exception=snapshot("ZeroDivisionError: division by zero"), + ) + + test_doc( + """\ +text +``` python +print(1 + 1) +``` +text + +``` python hl_lines="1 2 3" +print(1 - 1) +``` +text +""", + blocks=snapshot( + [ + Block( + code="print(1 + 1)\n", code_header=None, block_options="", line=2 + ), + Block( + code="print(1 - 1)\n", + code_header="inline-snapshot: create test", + block_options=' hl_lines="1 2 3"', + line=7, + ), + ] + ), + ) + + def change_block(block): + block.code = "# removed" + block.code_header = "header" + block.block_options = "option a b c" + + test_doc( + """\ +text +``` python +print(1 + 1) +``` +""", + handle_block=change_block, + blocks=snapshot( + [ + Block( + code="# removed", + code_header="header", + block_options="option a b c", + line=2, + ) + ] + ), + new_markdown_code=snapshot( + """\ +text + +``` python option a b c +# removed +``` +""" + ), + ) @pytest.mark.skipif( @@ -14,7 +236,7 @@ ) @pytest.mark.skipif( sys.version_info[:2] != (3, 12), - reason="\\r in stdout can cause problems in snapshot strings", + reason="there is no reason to test the doc with different python versions", ) @pytest.mark.parametrize( "file", @@ -36,19 +258,7 @@ def test_docs(project, file, subtests): * `outcome-passed=2` to check for the pytest test outcome """ - block_start = re.compile("( *)``` *python.*") - block_end = re.compile("```.*") - - header = re.compile("") - - text = file.read_text("utf-8") - new_lines = [] - block_lines = [] - options = set() - is_block = False - code = None - indent = "" - first_block = True + last_code = None project.pyproject( """ @@ -57,132 +267,104 @@ def test_docs(project, file, subtests): """ ) - for linenumber, line in enumerate(text.splitlines(), start=1): - m = block_start.fullmatch(line) - if m and is_block == True: - block_start_line = line - indent = m[1] - block_lines = [] - continue + extra_files = defaultdict(list) - if block_end.fullmatch(line.strip()) and is_block: - with subtests.test(line=linenumber): - is_block = False + def test_block(block: Block): + if block.code_header is None: + return block - last_code = code - code = "\n".join(block_lines) + "\n" - code = textwrap.dedent(code) - if file.suffix == ".py": - code = code.replace("\\\\", "\\") + if block.code_header.startswith("inline-snapshot-lib:"): + extra_files[block.code_header.split()[1]].append(block.code) + return block - flags = options & {"fix", "update", "create", "trim"} + if block.code_header.startswith("todo-inline-snapshot:"): + return block + assert False - args = ["--inline-snapshot", ",".join(flags)] if flags else [] + nonlocal last_code + with subtests.test(line=block.line): - if flags and "first_block" not in options: - project.setup(last_code) - else: - project.setup(code) + code = block.code - result = project.run(*args) + options = set(block.code_header.split()) - print("flags:", flags) + flags = options & {"fix", "update", "create", "trim"} - new_code = code - if flags: - new_code = project.source + args = ["--inline-snapshot", ",".join(flags)] if flags else [] - if "show_error" in options: - new_code = new_code.split("# Error:")[0] - new_code += "# Error:\n" + textwrap.indent( - result.errorLines(), "# " - ) + if flags and "first_block" not in options: + project.setup(last_code) + else: + project.setup(code) - print("new code:") - print(new_code) - print("expected code:") - print(code) + if extra_files: + all_files = [ + [(key, file) for file in files] + for key, files in extra_files.items() + ] + for files in itertools.product(*all_files): + for filename, content in files: + project.write_file(filename, content) + result = project.run(*args) - if ( - inline_snapshot._inline_snapshot._update_flags.fix - ): # pragma: no cover - flags_str = " ".join( - sorted(flags) - + sorted(options & {"first_block", "show_error"}) - + [ - f"outcome-{k}={v}" - for k, v in result.parseoutcomes().items() - if k in ("failed", "errors", "passed") - ] - ) - header_line = f"{indent}" + else: - new_lines.append(header_line) + result = project.run(*args) - from inline_snapshot._align import align - - linenum = 1 - hl_lines = "" - if last_code is not None and "first_block" not in options: - changed_lines = [] - alignment = align(last_code.split("\n"), new_code.split("\n")) - for c in alignment: - if c == "d": - continue - elif c == "m": - linenum += 1 - else: - changed_lines.append(str(linenum)) - linenum += 1 - if changed_lines: - hl_lines = f' hl_lines="{" ".join(changed_lines)}"' + print("flags:", flags, repr(block.block_options)) + + new_code = code + if flags: + new_code = project.source + + if "show_error" in options: + new_code = new_code.split("# Error:")[0] + new_code += "# Error:\n" + textwrap.indent(result.errorLines(), "# ") + + print("new code:") + print(new_code) + print("expected code:") + print(code) + + block.code_header = "inline-snapshot: " + " ".join( + sorted(flags) + + sorted(options & {"first_block", "show_error"}) + + [ + f"outcome-{k}={v}" + for k, v in result.parseoutcomes().items() + if k in ("failed", "errors", "passed") + ] + ) + + from inline_snapshot._align import align + + linenum = 1 + hl_lines = "" + if last_code is not None and "first_block" not in options: + changed_lines = [] + alignment = align(last_code.split("\n"), new_code.split("\n")) + for c in alignment: + if c == "d": + continue + elif c == "m": + linenum += 1 else: - assert False, "no lines changed" - - new_lines.append(f"{indent}``` python{hl_lines}") - - if ( - inline_snapshot._inline_snapshot._update_flags.fix - ): # pragma: no cover - new_code = new_code.rstrip("\n") - if file.suffix == ".py": - new_code = new_code.replace("\\", "\\\\") - new_code = textwrap.indent(new_code, indent) - - new_lines.append(new_code) + changed_lines.append(str(linenum)) + linenum += 1 + if changed_lines: + hl_lines = f'hl_lines="{" ".join(changed_lines)}"' else: - new_lines += block_lines + assert False, "no lines changed" + block.block_options = hl_lines - new_lines.append(line) + block.code = new_code - if not inline_snapshot._inline_snapshot._update_flags.fix: - if flags: - assert result.ret == 0 - else: - assert { - f"outcome-{k}={v}" - for k, v in result.parseoutcomes().items() - if k in ("failed", "errors", "passed") - } == {flag for flag in options if flag.startswith("outcome-")} - assert code == new_code - else: # pragma: no cover - pass - - continue - - m = header.fullmatch(line.strip()) - if m: - options = set(m.group(1).split()) - if first_block: - options.add("first_block") - first_block = False - header_line = line - is_block = True + if flags: + assert result.ret == 0 - if is_block: - block_lines.append(line) - else: - new_lines.append(line) + last_code = code + return block - if inline_snapshot._inline_snapshot._update_flags.fix: # pragma: no cover - file.write_text("\n".join(new_lines) + "\n", "utf-8") + map_code_blocks( + file, test_block, inline_snapshot._inline_snapshot._update_flags.fix + ) diff --git a/tests/test_fstring.py b/tests/test_fstring.py new file mode 100644 index 00000000..565ce941 --- /dev/null +++ b/tests/test_fstring.py @@ -0,0 +1,46 @@ +from inline_snapshot import snapshot +from inline_snapshot.testing import Example + +from .warns import warns + + +def test_fstring(): + Example( + """ +from inline_snapshot import snapshot + +def test_a(): + assert "a 1" == snapshot(f"a {1}") + """ + ).run_inline(reported_categories=snapshot([])) + + +def test_fstring_fix(): + + with warns( + snapshot( + [ + """\ +InlineSnapshotInfo: inline-snapshot will be able to fix f-strings in the future. +The current string value is: + 'a 1'\ +""" + ] + ) + ): + Example( + """ +from inline_snapshot import snapshot + +def test_a(): + assert "a 1" == snapshot(f"b {1}"), "not equal" + """ + ).run_inline( + ["--inline-snapshot=fix"], + raises=snapshot( + """\ +AssertionError: +not equal\ +""" + ), + ) diff --git a/tests/test_inline_snapshot.py b/tests/test_inline_snapshot.py index 82947e59..7b5c57ef 100644 --- a/tests/test_inline_snapshot.py +++ b/tests/test_inline_snapshot.py @@ -1,4 +1,3 @@ -import ast import contextlib import itertools import warnings @@ -8,12 +7,9 @@ from typing import Union import pytest -from hypothesis import given -from hypothesis.strategies import text from inline_snapshot import _inline_snapshot from inline_snapshot import snapshot from inline_snapshot._inline_snapshot import Flags -from inline_snapshot._utils import triple_quote from inline_snapshot.testing import Example from inline_snapshot.testing._example import snapshot_env @@ -584,194 +580,6 @@ def test_plain(check_update, executing_used): assert check_update("s = snapshot()", flags="") == snapshot("s = snapshot()") -def test_string_update(check_update): - # black --preview wraps strings to keep the line length. - # string concatenation should produce updates. - assert ( - check_update( - 'assert "ab" == snapshot("a" "b")', reported_flags="", flags="update" - ) - == 'assert "ab" == snapshot("a" "b")' - ) - - assert ( - check_update( - 'assert "ab" == snapshot("a"\n "b")', reported_flags="", flags="update" - ) - == 'assert "ab" == snapshot("a"\n "b")' - ) - - assert check_update( - 'assert "ab\\nc" == snapshot("a"\n "b\\nc")', flags="update" - ) == snapshot( - '''\ -assert "ab\\nc" == snapshot("""\\ -ab -c\\ -""")\ -''' - ) - - assert ( - check_update( - 'assert b"ab" == snapshot(b"a"\n b"b")', reported_flags="", flags="update" - ) - == 'assert b"ab" == snapshot(b"a"\n b"b")' - ) - - -def test_string_newline(check_update): - assert check_update('s = snapshot("a\\nb")', flags="update") == snapshot( - '''\ -s = snapshot("""\\ -a -b\\ -""")\ -''' - ) - - assert check_update('s = snapshot("a\\"\\"\\"\\nb")', flags="update") == snapshot( - """\ -s = snapshot('''\\ -a\"\"\" -b\\ -''')\ -""" - ) - - assert check_update( - 's = snapshot("a\\"\\"\\"\\n\\\'\\\'\\\'b")', flags="update" - ) == snapshot( - '''\ -s = snapshot("""\\ -a\\"\\"\\" -\'\'\'b\\ -""")\ -''' - ) - - assert check_update('s = snapshot(b"a\\nb")') == snapshot('s = snapshot(b"a\\nb")') - - assert check_update('s = snapshot("\\n\\\'")', flags="update") == snapshot( - '''\ -s = snapshot("""\\ - -'\\ -""")\ -''' - ) - - assert check_update('s = snapshot("\\n\\"")', flags="update") == snapshot( - '''\ -s = snapshot("""\\ - -"\\ -""")\ -''' - ) - - assert check_update("s = snapshot(\"'''\\n\\\"\")", flags="update") == snapshot( - '''\ -s = snapshot("""\\ -\'\'\' -\\"\\ -""")\ -''' - ) - - assert check_update('s = snapshot("\\n\b")', flags="update") == snapshot( - '''\ -s = snapshot("""\\ - -\\x08\\ -""")\ -''' - ) - - -def test_string_quote_choice(check_update): - assert check_update( - "s = snapshot(\" \\'\\'\\' \\'\\'\\' \\\"\\\"\\\"\\nother_line\")", - flags="update", - ) == snapshot( - '''\ -s = snapshot("""\\ - \'\'\' \'\'\' \\"\\"\\" -other_line\\ -""")\ -''' - ) - - assert check_update( - 's = snapshot(" \\\'\\\'\\\' \\"\\"\\" \\"\\"\\"\\nother_line")', flags="update" - ) == snapshot( - """\ -s = snapshot('''\\ - \\'\\'\\' \"\"\" \"\"\" -other_line\\ -''')\ -""" - ) - - assert check_update('s = snapshot("\\n\\"")', flags="update") == snapshot( - '''\ -s = snapshot("""\\ - -"\\ -""")\ -''' - ) - - assert check_update( - "s=snapshot('\\n')", flags="update", reported_flags="" - ) == snapshot("s=snapshot('\\n')") - assert check_update( - "s=snapshot('abc\\n')", flags="update", reported_flags="" - ) == snapshot("s=snapshot('abc\\n')") - assert check_update("s=snapshot('abc\\nabc')", flags="update") == snapshot( - '''\ -s=snapshot("""\\ -abc -abc\\ -""")\ -''' - ) - assert check_update("s=snapshot('\\nabc')", flags="update") == snapshot( - '''\ -s=snapshot("""\\ - -abc\\ -""")\ -''' - ) - assert check_update("s=snapshot('a\\na\\n')", flags="update") == snapshot( - '''\ -s=snapshot("""\\ -a -a -""")\ -''' - ) - - assert ( - check_update( - '''\ -s=snapshot("""\\ -a -""")\ -''', - flags="update", - ) - == snapshot('s=snapshot("a\\n")') - ) - - -@given(s=text()) -def test_string_convert(s): - print(s) - assert ast.literal_eval(triple_quote(s)) == s - - def test_flags_repr(): assert repr(Flags({"update"})) == "Flags({'update'})" @@ -834,40 +642,6 @@ def test_type_error(check_update): assert test1 == test2 -def test_invalid_repr(check_update): - assert ( - check_update( - """\ -class Thing: - def __repr__(self): - return "+++" - - def __eq__(self,other): - if not isinstance(other,Thing): - return NotImplemented - return True - -assert Thing() == snapshot() -""", - flags="create", - ) - == snapshot( - """\ -class Thing: - def __repr__(self): - return "+++" - - def __eq__(self,other): - if not isinstance(other,Thing): - return NotImplemented - return True - -assert Thing() == snapshot(HasRepr(Thing, "+++")) -""" - ) - ) - - def test_sub_snapshot_create(check_update): assert ( @@ -1058,62 +832,6 @@ def test_thing(): assert result.report == snapshot("") -def test_compare_dirty_equals_twice() -> None: - - Example( - """ -from dirty_equals import IsStr -from inline_snapshot import snapshot - -for x in 'ab': - assert x == snapshot(IsStr()) - assert [x,5] == snapshot([IsStr(),3]) - assert {'a':x,'b':5} == snapshot({'a':IsStr(),'b':3}) - -""" - ).run_inline( - ["--inline-snapshot=fix"], - changed_files=snapshot( - { - "test_something.py": """\ - -from dirty_equals import IsStr -from inline_snapshot import snapshot - -for x in 'ab': - assert x == snapshot(IsStr()) - assert [x,5] == snapshot([IsStr(),5]) - assert {'a':x,'b':5} == snapshot({'a':IsStr(),'b':5}) - -""" - } - ), - ) - - -def test_dirty_equals_in_unused_snapshot() -> None: - - Example( - """ -from dirty_equals import IsStr -from inline_snapshot import snapshot - -snapshot([IsStr(),3]) -snapshot((IsStr(),3)) -snapshot({1:IsStr(),2:3}) -snapshot({1+1:2}) - -t=(1,2) -d={1:2} -l=[1,2] -snapshot([t,d,l]) -""" - ).run_inline( - ["--inline-snapshot=fix"], - changed_files=snapshot({}), - ) - - @dataclass class Warning: message: str @@ -1154,7 +872,7 @@ def test_starred_warns_list(): """ from inline_snapshot import snapshot -assert [5] == snapshot([*[4]]) +assert [5] == snapshot([*[5]]) """ ).run_inline(["--inline-snapshot=fix"]) @@ -1175,57 +893,49 @@ def test_starred_warns_dict(): """ from inline_snapshot import snapshot -assert {1:3} == snapshot({**{1:2}}) +assert {1:3} == snapshot({**{1:3}}) """ ).run_inline(["--inline-snapshot=fix"]) -def test_now_like_dirty_equals(): - # test for cases like https://github.com/15r10nk/inline-snapshot/issues/116 +def test_is(): Example( """ -from dirty_equals import DirtyEquals -from inline_snapshot import snapshot +from inline_snapshot import snapshot,Is +def test_Is(): + for i in range(3): + assert ["hello",i] == snapshot(["hi",Is(i)]) + assert ["hello",i] == snapshot({1:["hi",Is(i)]})[i] +""" + ).run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + "test_something.py": """\ -def test_time(): - - now = 5 - - class Now(DirtyEquals): - def equals(self, other): - return other == now - - assert now == snapshot(Now()) - - now = 6 +from inline_snapshot import snapshot,Is - assert 5 == snapshot(Now()) +def test_Is(): + for i in range(3): + assert ["hello",i] == snapshot(["hi",Is(i)]) + assert ["hello",i] == snapshot({1:["hi",Is(i)], 0: ["hello", 0], 2: ["hello", 2]})[i] """ + } + ), ).run_inline( ["--inline-snapshot=fix"], changed_files=snapshot( { "test_something.py": """\ -from dirty_equals import DirtyEquals -from inline_snapshot import snapshot - - -def test_time(): - - now = 5 - - class Now(DirtyEquals): - def equals(self, other): - return other == now - - assert now == snapshot(Now()) - - now = 6 +from inline_snapshot import snapshot,Is - assert 5 == snapshot(5) +def test_Is(): + for i in range(3): + assert ["hello",i] == snapshot(["hello",Is(i)]) + assert ["hello",i] == snapshot({1:["hello",Is(i)], 0: ["hello", 0], 2: ["hello", 2]})[i] """ } ), diff --git a/tests/test_is.py b/tests/test_is.py new file mode 100644 index 00000000..cbbfa850 --- /dev/null +++ b/tests/test_is.py @@ -0,0 +1,22 @@ +from inline_snapshot._inline_snapshot import snapshot +from inline_snapshot.testing._example import Example + + +def test_missing_is(): + + Example( + """\ +from inline_snapshot import snapshot + +def test_is(): + for i in (1,2): + assert i == snapshot(i) + """ + ).run_inline( + raises=snapshot( + """\ +UsageError: +snapshot value should not change. Use Is(...) for dynamic snapshot parts.\ +""" + ) + ) diff --git a/tests/test_pydantic.py b/tests/test_pydantic.py index 0165be50..c4a16e63 100644 --- a/tests/test_pydantic.py +++ b/tests/test_pydantic.py @@ -33,7 +33,7 @@ class M(BaseModel): age:int=4 def test_pydantic(): - assert M(size=5,name="Tom")==snapshot(M(name="Tom", size=5)) + assert M(size=5,name="Tom")==snapshot(M(size=5, name="Tom")) \ """ diff --git a/tests/test_pytest_plugin.py b/tests/test_pytest_plugin.py index 28444900..39ed1ef7 100644 --- a/tests/test_pytest_plugin.py +++ b/tests/test_pytest_plugin.py @@ -705,3 +705,23 @@ def test_a(): """ ), ) + + +def test_unknown_flag(): + + Example( + """\ +def test_a(): + assert 1==1 +""" + ).run_pytest( + ["--inline-snapshot=creaigflen"], + report=snapshot(""), + returncode=snapshot(4), + stderr=snapshot( + """\ +ERROR: --inline-snapshot=creaigflen is a unknown flag + +""" + ), + ) diff --git a/tests/test_string.py b/tests/test_string.py new file mode 100644 index 00000000..301b6230 --- /dev/null +++ b/tests/test_string.py @@ -0,0 +1,194 @@ +import ast + +from hypothesis import given +from hypothesis.strategies import text +from inline_snapshot import snapshot +from inline_snapshot._utils import triple_quote + + +def test_string_update(check_update): + # black --preview wraps strings to keep the line length. + # string concatenation should produce updates. + assert ( + check_update( + 'assert "ab" == snapshot("a" "b")', reported_flags="", flags="update" + ) + == 'assert "ab" == snapshot("a" "b")' + ) + + assert ( + check_update( + 'assert "ab" == snapshot("a"\n "b")', reported_flags="", flags="update" + ) + == 'assert "ab" == snapshot("a"\n "b")' + ) + + assert check_update( + 'assert "ab\\nc" == snapshot("a"\n "b\\nc")', flags="update" + ) == snapshot( + '''\ +assert "ab\\nc" == snapshot("""\\ +ab +c\\ +""")\ +''' + ) + + assert ( + check_update( + 'assert b"ab" == snapshot(b"a"\n b"b")', reported_flags="", flags="update" + ) + == 'assert b"ab" == snapshot(b"a"\n b"b")' + ) + + +def test_string_newline(check_update): + assert check_update('s = snapshot("a\\nb")', flags="update") == snapshot( + '''\ +s = snapshot("""\\ +a +b\\ +""")\ +''' + ) + + assert check_update('s = snapshot("a\\"\\"\\"\\nb")', flags="update") == snapshot( + """\ +s = snapshot('''\\ +a\"\"\" +b\\ +''')\ +""" + ) + + assert check_update( + 's = snapshot("a\\"\\"\\"\\n\\\'\\\'\\\'b")', flags="update" + ) == snapshot( + '''\ +s = snapshot("""\\ +a\\"\\"\\" +\'\'\'b\\ +""")\ +''' + ) + + assert check_update('s = snapshot(b"a\\nb")') == snapshot('s = snapshot(b"a\\nb")') + + assert check_update('s = snapshot("\\n\\\'")', flags="update") == snapshot( + '''\ +s = snapshot("""\\ + +'\\ +""")\ +''' + ) + + assert check_update('s = snapshot("\\n\\"")', flags="update") == snapshot( + '''\ +s = snapshot("""\\ + +"\\ +""")\ +''' + ) + + assert check_update("s = snapshot(\"'''\\n\\\"\")", flags="update") == snapshot( + '''\ +s = snapshot("""\\ +\'\'\' +\\"\\ +""")\ +''' + ) + + assert check_update('s = snapshot("\\n\b")', flags="update") == snapshot( + '''\ +s = snapshot("""\\ + +\\x08\\ +""")\ +''' + ) + + +def test_string_quote_choice(check_update): + assert check_update( + "s = snapshot(\" \\'\\'\\' \\'\\'\\' \\\"\\\"\\\"\\nother_line\")", + flags="update", + ) == snapshot( + '''\ +s = snapshot("""\\ + \'\'\' \'\'\' \\"\\"\\" +other_line\\ +""")\ +''' + ) + + assert check_update( + 's = snapshot(" \\\'\\\'\\\' \\"\\"\\" \\"\\"\\"\\nother_line")', flags="update" + ) == snapshot( + """\ +s = snapshot('''\\ + \\'\\'\\' \"\"\" \"\"\" +other_line\\ +''')\ +""" + ) + + assert check_update('s = snapshot("\\n\\"")', flags="update") == snapshot( + '''\ +s = snapshot("""\\ + +"\\ +""")\ +''' + ) + + assert check_update( + "s=snapshot('\\n')", flags="update", reported_flags="" + ) == snapshot("s=snapshot('\\n')") + assert check_update( + "s=snapshot('abc\\n')", flags="update", reported_flags="" + ) == snapshot("s=snapshot('abc\\n')") + assert check_update("s=snapshot('abc\\nabc')", flags="update") == snapshot( + '''\ +s=snapshot("""\\ +abc +abc\\ +""")\ +''' + ) + assert check_update("s=snapshot('\\nabc')", flags="update") == snapshot( + '''\ +s=snapshot("""\\ + +abc\\ +""")\ +''' + ) + assert check_update("s=snapshot('a\\na\\n')", flags="update") == snapshot( + '''\ +s=snapshot("""\\ +a +a +""")\ +''' + ) + + assert ( + check_update( + '''\ +s=snapshot("""\\ +a +""")\ +''', + flags="update", + ) + == snapshot('s=snapshot("a\\n")') + ) + + +@given(s=text()) +def test_string_convert(s): + print(s) + assert ast.literal_eval(triple_quote(s)) == s diff --git a/tests/test_warns.py b/tests/test_warns.py new file mode 100644 index 00000000..327971de --- /dev/null +++ b/tests/test_warns.py @@ -0,0 +1,34 @@ +import warnings + +from inline_snapshot import snapshot + +from tests.warns import warns + + +def test_warns(): + + def warning(): + warnings.warn_explicit( + message="bad things happen", + category=SyntaxWarning, + filename="file.py", + lineno=5, + ) + + with warns( + snapshot([("file.py", 5, "SyntaxWarning: bad things happen")]), + include_line=True, + include_file=True, + ): + warning() + + with warns( + snapshot([("file.py", "SyntaxWarning: bad things happen")]), + include_file=True, + ): + warning() + + with warns( + snapshot(["SyntaxWarning: bad things happen"]), + ): + warning() diff --git a/tests/warns.py b/tests/warns.py new file mode 100644 index 00000000..6cf62557 --- /dev/null +++ b/tests/warns.py @@ -0,0 +1,24 @@ +import contextlib +import warnings + + +@contextlib.contextmanager +def warns(expected_warnings=[], include_line=False, include_file=False): + with warnings.catch_warnings(record=True) as result: + warnings.simplefilter("always") + yield + + def make_warning(w): + message = f"{w.category.__name__}: {w.message}" + if not include_line and not include_file: + return message + message = (message,) + + if include_line: + message = (w.lineno, *message) + if include_file: + message = (w.filename, *message) + + return message + + assert [make_warning(w) for w in result] == expected_warnings