diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000000..50e4653608
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,9 @@
+*
+
+!tests/
+!splink/
+!pyproject.toml
+!poetry.lock
+!README.md
+
+**/*.pyc
diff --git a/.github/workflows/run_demos_examples.yml b/.github/workflows/run_demos_examples.yml
index b93cbc74f4..2cd4795635 100644
--- a/.github/workflows/run_demos_examples.yml
+++ b/.github/workflows/run_demos_examples.yml
@@ -36,6 +36,7 @@ jobs:
       - name: Install environment and check notebooks
         run: |
           cd splink_demos
+          cp ../benchmarking/conftest.py conftest.py
           python3 -m venv venv
           source venv/bin/activate
           pip install --upgrade pip
diff --git a/.github/workflows/run_demos_tutorials.yml b/.github/workflows/run_demos_tutorials.yml
index 6310ce0926..e1e704d5a0 100644
--- a/.github/workflows/run_demos_tutorials.yml
+++ b/.github/workflows/run_demos_tutorials.yml
@@ -28,6 +28,7 @@ jobs:
       - name: Install environment and check notebooks
         run: |
          cd splink_demos
+          cp ../benchmarking/conftest.py conftest.py
           python3 -m venv venv
           source venv/bin/activate
           pip install --upgrade pip
diff --git a/.gitignore b/.gitignore
index f35ea3aa27..2bde4279c2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -172,4 +172,7 @@ cython_debug/
 
 *.parquet
 *.csv
-.DS_Store
\ No newline at end of file
+.DS_Store
+
+# vscode local settings
+.vscode
diff --git a/benchmarking/conftest.py b/benchmarking/conftest.py
new file mode 100644
index 0000000000..7e334ec7e6
--- /dev/null
+++ b/benchmarking/conftest.py
@@ -0,0 +1,5 @@
+# add the "default" marker to all tests collected here - pytest is configured
+# (via addopts in pyproject.toml) to run only "default"-marked tests, to aid testing tests/
+def pytest_collection_modifyitems(items, config):
+    for item in items:
+        item.add_marker("default")
diff --git a/docs/dev_guides/changing_splink/running_tests_locally.md b/docs/dev_guides/changing_splink/running_tests_locally.md
deleted file mode 100644
index 74e9f97d65..0000000000
--- a/docs/dev_guides/changing_splink/running_tests_locally.md
+++ /dev/null
@@ -1,32 +0,0 @@
-## Running tests locally
-
-To run tests locally, simply run:
-```sh
-python3 -m pytest tests/ -s --disable-pytest-warnings
-```
-
-To run a single test, append the filename to the `tests/` folder call:
-```sh
-python3 -m pytest tests/my_test_file.py -s --disable-pytest-warnings
-```
-
-## Running tests with docker 🐳
-
-If you want to test Splink against a specific version of python, the easiest method is to utilise docker 🐳.
-
-Docker allows you to more quickly and easily install a specific version of python and run the existing test library against it.
-
-This is particularly useful if you're using py > 3.9.10 (which is currently in use in our tests github action) and need to run a secondary set of tests.
-
-A pre-built Dockerfile for running tests against python version 3.9.10 can be located within [scripts/run_tests.Dockerfile](https://github.com/moj-analytical-services/splink/scripts/scripts/run_tests.Dockerfile).
-
-To run, simply use the following docker command from within a terminal and the root folder of a splink clone:
-```sh
-docker build -t run_tests:testing -f scripts/run_tests.Dockerfile . && docker run run_tests:testing
-```
-
-This will both build and run the tests library.
-
-Feel free to replace `run_tests:testing` with an image name and tag you're happy with.
-
-Reusing the same image and tag will overwrite your existing image.
\ No newline at end of file
diff --git a/docs/dev_guides/changing_splink/testing.md b/docs/dev_guides/changing_splink/testing.md
new file mode 100644
index 0000000000..6768e0d28f
--- /dev/null
+++ b/docs/dev_guides/changing_splink/testing.md
@@ -0,0 +1,366 @@
+---
+tags:
+  - Testing
+  - Pytest
+  - Backends
+---
+# Testing in Splink
+
+Tests in Splink make use of the [pytest](https://docs.pytest.org) framework. You can find the tests themselves in [the tests folder](https://github.com/moj-analytical-services/splink/tree/master/tests).
+
+Splink tests can be broadly categorised into three sets:
+
+* **'Core' tests** - these are tests which test some specific bit of functionality which does not depend on any specific SQL dialect. They are usually unit tests - examples are testing [`InputColumn`](https://github.com/moj-analytical-services/splink/blob/master/tests/test_input_column.py) and testing the [latitude-longitude distance calculation](https://github.com/moj-analytical-services/splink/blob/master/tests/test_lat_long_distance.py).
+* **Backend-agnostic tests** - these are tests which run against some SQL backend, but which are written in such a way that they can run against many backends by making use of the [backend-agnostic testing framework](#backend-agnostic-testing). The majority of tests are of this type.
+* **Backend-specific tests** - these are tests which run against a specific SQL backend, and test some feature particular to this backend. There are not many of these, as Splink is designed to run very similarly independent of the backend used.
+
+!!! info
+    We currently do not have support for testing the `athena` backend, due to the complication of needing a connection to an AWS account. All other backends have testing available.
+
+## Running tests
+
+### Running tests locally
+
+To run tests locally, simply run:
+```sh
+python3 -m pytest tests/
+```
+or alternatively
+```sh
+pytest tests/
+```
+
+To run a single test file, append the filename to the `tests/` folder call, for example:
+```sh
+pytest tests/test_u_train.py
+```
+or for a single test, additionally append the test name after a pair of colons, as:
+```sh
+pytest tests/test_u_train.py::test_u_train_multilink
+```
+
+??? tip "Further useful pytest options"
+    There may be many warnings emitted, for instance by library dependencies, which clutter your output - in this case you can use `--disable-pytest-warnings` or `-W ignore` so that these are not displayed. Some additional command-line options that may be useful:
+
+    * `-s` to disable output capture, so that test output is displayed in the terminal in all cases
+    * `-v` for verbose mode, where each test instance will be displayed on a separate line with status
+    * `-q` for quiet mode, where output is extremely minimal
+    * `-x` to fail on first error/failure rather than continuing to run all selected tests
+    * `-m some_mark` to run only those tests marked with `some_mark` - see [below](#running-tests-for-specific-backends-or-backend-groups) for useful options here
+
+    For instance, usage might be:
+    ```sh
+    # ignore warnings, display output
+    pytest -W ignore -s tests/
+    ```
+
+    or
+    ```sh
+    # ignore warnings, verbose output, fail on first error/failure
+    pytest -W ignore -v -x tests/
+    ```
+
+    You can find a host of other available options using pytest's in-built help:
+    ```sh
+    pytest -h
+    ```
+
+#### Running tests for specific backends or backend groups
+
+You may wish to run tests relating to specific backends, tests which are backend-independent, or any combination of these.
+Splink allows for various combinations by making use of `pytest`'s [`mark` feature](https://docs.pytest.org/en/latest/example/markers.html).
+
+If you pass no marks explicitly when you invoke pytest, there will be an implicit mark of `default`, as per the [pyproject.toml pytest.ini configuration](https://github.com/moj-analytical-services/splink/blob/master/pyproject.toml).
+
+The available options are:
+
+##### Run core tests
+Option for running only the backend-independent 'core' tests:
+
+* `pytest tests/ -m core` - run only the 'core' tests, meaning those without dialect-dependence. In practice this means any test that hasn't been decorated using `mark_with_dialects_excluding` or `mark_with_dialects_including`.
+
+##### Run tests on a specific backend
+Options for running tests on one backend only - this includes tests written [specifically for that backend](#backend-specific-tests), as well as [backend-agnostic tests](#backend-agnostic-testing) supported for that backend.
+
+* `pytest tests/ -m duckdb` - run all `duckdb` tests, and all `core` tests
+    * & similarly for other dialects
+* `pytest tests/ -m duckdb_only` - run all `duckdb` tests only, and _not_ the `core` tests
+    * & similarly for other dialects
+
+##### Run tests across multiple backends
+Options for running tests on multiple backends (including all backends) - this includes tests written [specifically for those backends](#backend-specific-tests), as well as [backend-agnostic tests](#backend-agnostic-testing) supported for those backends.
+
+* `pytest tests/ -m default` or equivalently `pytest tests/` - run all tests in the `default` group. This consists of the `core` tests, plus all tests for the dialects in the `default` group - currently `spark` and `duckdb`.
+    * Other groups of dialects can be added and will similarly run with `pytest tests/ -m new_dialect_group`. Dialects within the current scope of testing and the groups they belong to are defined in the `dialect_groups` dictionary in [tests/decorator.py](https://github.com/moj-analytical-services/splink/blob/master/tests/decorator.py)
+* `pytest tests/ -m all` - run all tests for all available dialects
+
+These all work alongside all the other pytest options, so for instance to run the tests for training `probability_two_random_records_match` for only `duckdb`, ignoring warnings, with quiet output, and exiting on the first failure/error:
+```sh
+pytest -W ignore -q -x -m duckdb tests/test_estimate_prob_two_rr_match.py
+```
+
+??? tip "Running tests with docker 🐳"
+
+    If you want to test Splink against a specific version of python, the easiest method is to utilise docker 🐳.
+
+    Docker allows you to more quickly and easily install a specific version of python and run the existing test library against it.
+
+    This is particularly useful if you're using py > 3.9.10 (which is currently in use in our tests github action) and need to run a secondary set of tests.
+
+    A pre-built Dockerfile for running tests against python version 3.9.10 can be located within [scripts/run_tests.Dockerfile](https://github.com/moj-analytical-services/splink/blob/master/scripts/run_tests.Dockerfile).
+
+    To run, simply use the following docker command from within a terminal and the root folder of a splink clone:
+    ```sh
+    docker build -t run_tests:testing -f scripts/run_tests.Dockerfile . && docker run --rm --name splink-test run_tests:testing
+    ```
+
+    This will both build the image and run the test library.
+
+    Feel free to replace `run_tests:testing` with an image name and tag you're happy with.
+
+    Reusing the same image and tag will overwrite your existing image.
+
+    You can also overwrite the default `CMD` if you want a different set of `pytest` command-line options, for example
+    ```sh
+    docker run --rm --name splink-test run_tests:testing pytest -W ignore -m spark tests/test_u_train.py
+    ```
+
+### Tests in CI
+
+Splink utilises [github actions](https://docs.github.com/en/actions) to run tests for each pull request. This consists of a few independent checks:
+
+* The full test suite is run separately against several different python versions
+* The [example notebooks](./examples_index.html) are checked to ensure they run without error
+* The [tutorial notebooks](./demos/00_Tutorial_Introduction.html) are checked to ensure they run without error
+
+## Writing tests
+
+### Core tests
+
+Core tests are treated the same way as ordinary pytest tests. Any test is marked as `core` by default, and will only be excluded from being a core test if it is decorated using either:
+
+* `@mark_with_dialects_excluding` for [backend-agnostic tests](#backend-agnostic-testing), or
+* `@mark_with_dialects_including` for [backend-specific tests](#backend-specific-tests)
+
+from the [test decorator file](https://github.com/moj-analytical-services/splink/blob/master/tests/decorator.py).
+
+### Backend-agnostic testing
+
+The majority of tests should be written using the backend-agnostic testing framework. This just provides some small tools which allow tests to be written in a backend-independent way. This means the tests can then be run against _all_ available SQL backends (or a subset, if some lack _necessary_ features for the test).
+
+As an example, let's consider a test that will run on all dialects, and then break down the various parts to see what each is doing.
+
+```py linenums="1"
+from tests.decorator import mark_with_dialects_excluding
+
+@mark_with_dialects_excluding()
+def test_feature_that_works_for_all_backends(test_helpers, dialect, some_other_test_fixture):
+    helper = test_helpers[dialect]
+
+    df = helper.load_frame_from_csv("./tests/datasets/fake_1000_from_splink_demos.csv")
+
+    settings_dict = {
+        "link_type": "dedupe_only",
+        "blocking_rules_to_generate_predictions": ["l.city = r.city", "l.surname = r.surname", "l.dob = r.dob"],
+        "comparisons": [
+            helper.cl.exact_match("city"),
+            helper.cl.levenshtein_at_thresholds("first_name", [1, 2]),
+            helper.cl.levenshtein_at_thresholds("surname"),
+            {
+                "output_column_name": "email",
+                "comparison_description": "Email",
+                "comparison_levels": [
+                    helper.cll.null_level("email"),
+                    helper.cll.exact_match_level("email"),
+                    helper.cll.levenshtein_level("email", 2),
+                    {
+                        "sql_condition": "substr(email_l, 1, 1) = substr(email_r, 1, 1)",
+                        "label_for_charts": "email first character matches",
+                    },
+                    helper.cll.else_level(),
+                ]
+            }
+        ]
+    }
+
+    linker = helper.Linker(df, settings_dict, **helper.extra_linker_args())
+
+    # and then some actual testing logic
+```
+
+Firstly you should import the decorator-factory `mark_with_dialects_excluding`, which will decorate each test function:
+
+```py linenums="1"
+from tests.decorator import mark_with_dialects_excluding
+```
+
+Then we define the function, and pass parameters:
+
+```py linenums="3" hl_lines="1"
+@mark_with_dialects_excluding()
+def test_feature_that_works_for_all_backends(test_helpers, dialect, some_other_test_fixture):
+```
+
+The decorator `@mark_with_dialects_excluding()` will do two things:
+
+* marks the test it decorates with the appropriate custom `pytest` marks. This ensures that it will be run with tests for each dialect, excluding any that are passed as arguments; in this case it will be run for all dialects, as we have passed no arguments.
+* parameterises the test with a string parameter `dialect`, which will be used to configure the test for that dialect. The test will run for each value of `dialect` possible, excluding any passed to the decorator (none in this case).
+
+You should aim to exclude as _few_ dialects as possible - consider if you really need to exclude any. Dialects should only be excluded if the test doesn't make sense for them due to features they lack. The default choice should be the decorator with no arguments `@mark_with_dialects_excluding()`, meaning the test runs for _all_ dialects.
+
+```py linenums="3" hl_lines="2"
+@mark_with_dialects_excluding()
+def test_feature_that_works_for_all_backends(test_helpers, dialect, some_other_test_fixture):
+```
+
+As well as the parameter `dialect` (which is provided by the decorator), we must also pass the helper-factory fixture `test_helpers`. We can additionally pass further [fixtures](https://docs.pytest.org/en/latest/how-to/fixtures.html) if needed - in this case `some_other_test_fixture`.
+We could similarly provide an _explicit_ parameterisation to the test, in which case we would also pass these parameters - see [the pytest docs on parameterisation](https://doc.pytest.org/en/latest/example/parametrize.html#set-marks-or-test-id-for-individual-parametrized-test) for more information.
+
+
+```py linenums="5"
+    helper = test_helpers[dialect]
+```
+
+The fixture `test_helpers` is simply a dictionary of the specific-dialect test helpers - here we pick the appropriate one for our test.
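+
+??? note "Where `test_helpers` comes from"
+    For reference, this fixture is defined in [tests/conftest.py](https://github.com/moj-analytical-services/splink/blob/master/tests/conftest.py), and is simply a dictionary mapping each dialect name to an instance of its test helper. The `spark` fixture is passed in so that the Spark helper can reuse the existing session (a workaround, as fixtures can't be passed as `parametrize` arguments in base pytest):
+
+    ```py
+    @pytest.fixture
+    def test_helpers(spark):
+        return {
+            "duckdb": DuckDBTestHelper(),
+            "spark": SparkTestHelper(spark),
+            "sqlite": SQLiteTestHelper(),
+        }
+    ```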
+
+Each helper has the same set of methods and properties, which encapsulate _all_ of the dialect-dependencies. You can find the full set of properties and methods by examining the source for the [base class `TestHelper`](https://github.com/moj-analytical-services/splink/blob/master/tests/helpers.py).
+
+```py linenums="7"
+    df = helper.load_frame_from_csv("./tests/datasets/fake_1000_from_splink_demos.csv")
+```
+
+Here we are now actually using a method of the test helper - in this case we are loading a table from a csv to the database and returning it in a form suitable for passing to a Splink linker.
+
+
+```py linenums="12" hl_lines="2 3 4"
+        "comparisons": [
+            helper.cl.exact_match("city"),
+            helper.cl.levenshtein_at_thresholds("first_name", [1, 2]),
+            helper.cl.levenshtein_at_thresholds("surname"),
+            {
+                "output_column_name": "email",
+```
+We reference the dialect-specific [comparison library](../../comparison_library.html) as `helper.cl`,
+
+```py linenums="16" hl_lines="5 6 7 12"
+            {
+                "output_column_name": "email",
+                "comparison_description": "Email",
+                "comparison_levels": [
+                    helper.cll.null_level("email"),
+                    helper.cll.exact_match_level("email"),
+                    helper.cll.levenshtein_level("email", 2),
+                    {
+                        "sql_condition": "substr(email_l, 1, 1) = substr(email_r, 1, 1)",
+                        "label_for_charts": "email first character matches",
+                    },
+                    helper.cll.else_level(),
+                ]
+            }
+```
+and the dialect-specific [comparison level library](../../comparison_level_library.html) as `helper.cll`.
+
+```py linenums="23" hl_lines="2"
+                    {
+                        "sql_condition": "substr(email_l, 1, 1) = substr(email_r, 1, 1)",
+                        "label_for_charts": "email first character matches",
+                    },
+```
+We can include raw SQL statements, but we must ensure they are valid for all dialects we are considering, so we should avoid any unusual functions that are not likely to be universal.
+
+```py linenums="33"
+    linker = helper.Linker(df, settings_dict, **helper.extra_linker_args())
+```
+Finally we instantiate the linker, passing any default set of extra arguments provided by the helper, which some dialects require.
+
+From this point onwards we will be working with the instantiated `linker`, and so will not need to refer to `helper` any more - the rest of the test can be written as usual.
+
+#### Excluding some backends
+
+Now let's have a look at a similar example - only this time we are going to exclude the `sqlite` backend, as the test relies on features not directly available for that backend. In this example that will be the SQL function `split_part`, which does not exist in the `sqlite` dialect.
+
+!!! warning "Reminder"
+    Tests should be made available to the widest range of backends possible. Only exclude backends if features not shared by all backends are crucial to the test logic - otherwise consider rewriting things so that all backends are covered.
+
+```py linenums="1"
+from tests.decorator import mark_with_dialects_excluding
+
+@mark_with_dialects_excluding("sqlite")
+def test_feature_that_doesnt_work_with_sqlite(test_helpers, dialect, some_other_test_fixture):
+    helper = test_helpers[dialect]
+
+    df = helper.load_frame_from_csv("./tests/datasets/fake_1000_from_splink_demos.csv")
+
+    settings_dict = {
+        "link_type": "dedupe_only",
+        "blocking_rules_to_generate_predictions": ["l.city = r.city", "l.surname = r.surname", "l.dob = r.dob"],
+        "comparisons": [
+            helper.cl.exact_match("city"),
+            helper.cl.levenshtein_at_thresholds("first_name", [1, 2]),
+            helper.cl.levenshtein_at_thresholds("surname"),
+            {
+                "output_column_name": "email",
+                "comparison_description": "Email",
+                "comparison_levels": [
+                    helper.cll.null_level("email"),
+                    helper.cll.exact_match_level("email"),
+                    helper.cll.levenshtein_level("email", 2),
+                    {
+                        "sql_condition": "split_part(email_l, '@', 1) = split_part(email_r, '@', 1)",
+                        "label_for_charts": "email local-part matches",
+                    },
+                    helper.cll.else_level(),
+                ]
+            }
+        ]
+    }
+
+    linker = helper.Linker(df, settings_dict, **helper.extra_linker_args())
+
+    # and then some actual testing logic
+```
+
+The key difference is the argument we pass to the decorator:
+```py linenums="3" hl_lines="1"
+@mark_with_dialects_excluding("sqlite")
+def test_feature_that_doesnt_work_with_sqlite(test_helpers, dialect, some_other_test_fixture):
+```
+As above, this marks the test it decorates with the appropriate custom `pytest` marks, but in this case it ensures that it will be run with tests for each dialect **excluding sqlite**. Again `dialect` is passed as a parameter, and the test will run in turn for each value of `dialect` **except for 'sqlite'**.
+
+```py linenums="23" hl_lines="2"
+                    {
+                        "sql_condition": "split_part(email_l, '@', 1) = split_part(email_r, '@', 1)",
+                        "label_for_charts": "email local-part matches",
+                    },
+```
+This line is why we cannot allow `sqlite` for this test - we make use of the function `split_part`, which is not available in the `sqlite` dialect, hence its exclusion. We suppose that this particular comparison level is crucial for the test to make sense, otherwise we would rewrite this line to make it run universally. When you come to [run the tests](#running-tests-locally), this test will not run on the `sqlite` backend.
+
+If you need to exclude _multiple_ dialects this is also possible - just pass each as an argument. For example, to decorate a test that is not supported on `spark` _or_ `sqlite`, use the decorator `@mark_with_dialects_excluding("sqlite", "spark")`.
+
+### Backend-specific tests
+
+If you intend to write a test for a specific backend, first consider whether it is definitely specific to that backend - if not then a [backend-agnostic test](#backend-agnostic-testing) would be preferable, as then your test will be run against _many_ backends.
+If you really do need to test features peculiar to one backend, then you can write it simply as you would an ordinary `pytest` test. The only difference is that you should decorate it with `@mark_with_dialects_including` (from [tests/decorator.py](https://github.com/moj-analytical-services/splink/blob/master/tests/decorator.py)) - for example:
+
+=== "DuckDB"
+    ```py
+    @mark_with_dialects_including("duckdb")
+    def test_some_specific_duckdb_feature():
+        ...
+    ```
+=== "Spark"
+    ```py
+    @mark_with_dialects_including("spark")
+    def test_some_specific_spark_feature():
+        ...
+ ``` +=== "SQLite" + ```py + @mark_with_dialects_including("sqlite") + def test_some_specific_sqlite_feature(): + ... + ``` + +This ensures that the test gets marked appropriately for running when the `Spark` tests should be run, and excludes it from the set of `core` tests. + +Note that unlike the exclusive `mark_with_dialects_excluding`, this decorator will _not_ paramaterise the test with the `dialect` argument. This is because usage of the _inclusive_ form is largely designed for single-dialect tests. If you wish to override this behaviour and parameterise the test you can use the argument `pass_dialect`, for example `@mark_with_dialects_including("spark", "sqlite", pass_dialect=True)`, in which case you would need to write the test in a [backend-independent manner](#backend-agnostic-testing). diff --git a/mkdocs.yml b/mkdocs.yml index 3781ee9e2d..01abfa5008 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -127,7 +127,7 @@ nav: - Building a Virtual Environment: "dev_guides/changing_splink/building_env_locally.md" - Linting: "dev_guides/changing_splink/lint.md" - Building Docs: "dev_guides/changing_splink/build_docs_locally.md" - - Running Tests: "dev_guides/changing_splink/running_tests_locally.md" + - Testing: "dev_guides/changing_splink/testing.md" - Releasing a Package Version: "dev_guides/changing_splink/releases.md" - Caching and pipelining: "dev_guides/caching.md" - Understanding and debugging Splink: "dev_guides/debug_modes.md" diff --git a/pyproject.toml b/pyproject.toml index 84bd6eb9e9..3bea7a7c61 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,4 +61,21 @@ select = [ ignore = [ "B905", # `zip()` without an explicit `strict=` parameter "B006", # Do not use mutable data structures for argument defaults" -] \ No newline at end of file +] + +[tool.pytest.ini_options] +addopts = ["-m default"] +markers = [ +# only tests where backend is irrelevant: + "core", +# see tests/decorator.py::dialect_groups for group details: + "default", + "all", +# backend-specific sets + "duckdb", + "duckdb_only", + "spark", + "spark_only", + "sqlite", + "sqlite_only", +] diff --git a/splink/sqlite/sqlite_comparison_level_library.py b/splink/sqlite/sqlite_comparison_level_library.py index 08f371444b..250ea46bce 100644 --- a/splink/sqlite/sqlite_comparison_level_library.py +++ b/splink/sqlite/sqlite_comparison_level_library.py @@ -4,6 +4,7 @@ DistanceFunctionLevelBase, ElseLevelBase, ExactMatchLevelBase, + LevenshteinLevelBase, NullLevelBase, PercentageDifferenceLevelBase, ) @@ -30,6 +31,9 @@ def _distance_function_level(self): return distance_function_level @property + def _levenshtein_level(self): + return levenshtein_level + def _columns_reversed_level(self): return columns_reversed_level @@ -46,6 +50,10 @@ class else_level(SqliteBase, ElseLevelBase): pass +class levenshtein_level(SqliteBase, LevenshteinLevelBase): + pass + + class columns_reversed_level(SqliteBase, ColumnsReversedLevelBase): pass diff --git a/splink/sqlite/sqlite_comparison_library.py b/splink/sqlite/sqlite_comparison_library.py index 5f6ea00382..8369df5661 100644 --- a/splink/sqlite/sqlite_comparison_library.py +++ b/splink/sqlite/sqlite_comparison_library.py @@ -1,6 +1,7 @@ from ..comparison_library import ( DistanceFunctionAtThresholdsComparisonBase, ExactMatchBase, + LevenshteinAtThresholdsComparisonBase, ) from .sqlite_comparison_level_library import SqliteComparisonProperties @@ -13,3 +14,11 @@ class distance_function_at_thresholds( SqliteComparisonProperties, DistanceFunctionAtThresholdsComparisonBase ): pass + + +class 
levenshtein_at_thresholds( + SqliteComparisonProperties, LevenshteinAtThresholdsComparisonBase +): + @property + def _distance_level(self): + return self._levenshtein_level diff --git a/tests/conftest.py b/tests/conftest.py index 3e29930e5f..53dc89af65 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,10 +3,25 @@ import pytest from splink.spark.jar_location import similarity_jar_location +from tests.decorator import dialect_groups +from tests.helpers import DuckDBTestHelper, SparkTestHelper, SQLiteTestHelper logger = logging.getLogger(__name__) +def pytest_collection_modifyitems(items, config): + # any tests without backend-group markers will always run + marks = {gp for groups in dialect_groups.values() for gp in groups} + # any mark we've added, but excluding e.g. parametrize + our_marks = {*marks, *dialect_groups.keys()} + + for item in items: + if not any(marker.name in our_marks for marker in item.iter_markers()): + item.add_marker("core") + for mark in our_marks: + item.add_marker(mark) + + @pytest.fixture(scope="module") def spark(): from pyspark import SparkConf, SparkContext @@ -35,3 +50,14 @@ def df_spark(spark): df = spark.read.csv("./tests/datasets/fake_1000_from_splink_demos.csv", header=True) df.persist() yield df + + +# workaround as you can't pass fixtures as param arguments in base pytest +# see e.g. https://stackoverflow.com/a/42400786/11811947 +@pytest.fixture +def test_helpers(spark): + return { + "duckdb": DuckDBTestHelper(), + "spark": SparkTestHelper(spark), + "sqlite": SQLiteTestHelper(), + } diff --git a/tests/decorator.py b/tests/decorator.py new file mode 100644 index 0000000000..336907dd05 --- /dev/null +++ b/tests/decorator.py @@ -0,0 +1,48 @@ +import pytest + +dialect_groups = { + "duckdb": ["default"], + "spark": ["default"], + "sqlite": [], +} +for groups in dialect_groups.values(): + groups.append("all") + + +def invert(sql_dialects_missing): + return ( + sql_d for sql_d in dialect_groups.keys() if sql_d not in sql_dialects_missing + ) + + +def mark_with_dialects_excluding(*sql_dialects_missing): + sql_dialects = invert(sql_dialects_missing) + return mark_with_dialects_including(*sql_dialects, pass_dialect=True) + + +def mark_with_dialects_including(*sql_dialects, pass_dialect=False): + def mark_decorator(test_fn): + params = [] + all_marks = [] + for sql_d in sql_dialects: + # marks for whatever groups the dialect is in + marks = [ + getattr(pytest.mark, dialect_group) + for dialect_group in dialect_groups[sql_d] + ] + # plus the basic dialect mark + dialect_mark = getattr(pytest.mark, sql_d) + dialect_only_mark = getattr(pytest.mark, f"{sql_d}_only") + marks += [dialect_mark, dialect_only_mark] + params.append(pytest.param(sql_d, marks=marks)) + # will end up with duplicates, but think that's okay. for now at least. 
+            all_marks += marks
+
+        if pass_dialect:
+            test_fn = pytest.mark.parametrize("dialect", params)(test_fn)
+        else:
+            for mark in all_marks:
+                test_fn = mark(test_fn)
+        return test_fn
+
+    return mark_decorator
diff --git a/tests/helpers.py b/tests/helpers.py
new file mode 100644
index 0000000000..02fac44d19
--- /dev/null
+++ b/tests/helpers.py
@@ -0,0 +1,160 @@
+import sqlite3
+from abc import ABC, abstractmethod
+
+import pandas as pd
+
+import splink.duckdb.duckdb_comparison_level_library as cll_duckdb
+import splink.duckdb.duckdb_comparison_library as cl_duckdb
+import splink.duckdb.duckdb_comparison_template_library as ctl_duckdb
+import splink.spark.spark_comparison_level_library as cll_spark
+import splink.spark.spark_comparison_library as cl_spark
+import splink.spark.spark_comparison_template_library as ctl_spark
+import splink.sqlite.sqlite_comparison_level_library as cll_sqlite
+import splink.sqlite.sqlite_comparison_library as cl_sqlite
+import splink.sqlite.sqlite_comparison_template_library as ctl_sqlite
+from splink.duckdb.duckdb_linker import DuckDBLinker
+from splink.spark.spark_linker import SparkLinker
+from splink.sqlite.sqlite_linker import SQLiteLinker
+
+
+class TestHelper(ABC):
+    @property
+    @abstractmethod
+    def Linker(self):
+        pass
+
+    def extra_linker_args(self):
+        return {}
+
+    @abstractmethod
+    def convert_frame(self, df):
+        pass
+
+    def load_frame_from_csv(self, path):
+        return pd.read_csv(path)
+
+    def load_frame_from_parquet(self, path):
+        return pd.read_parquet(path)
+
+    @property
+    @abstractmethod
+    def cll(self):
+        pass
+
+    @property
+    @abstractmethod
+    def cl(self):
+        pass
+
+    @property
+    @abstractmethod
+    def ctl(self):
+        pass
+
+
+class DuckDBTestHelper(TestHelper):
+    @property
+    def Linker(self):
+        return DuckDBLinker
+
+    def convert_frame(self, df):
+        return df
+
+    @property
+    def cll(self):
+        return cll_duckdb
+
+    @property
+    def cl(self):
+        return cl_duckdb
+
+    @property
+    def ctl(self):
+        return ctl_duckdb
+
+
+class SparkTestHelper(TestHelper):
+    def __init__(self, spark):
+        self.spark = spark
+
+    @property
+    def Linker(self):
+        return SparkLinker
+
+    def extra_linker_args(self):
+        return {"spark": self.spark}
+
+    def convert_frame(self, df):
+        spark_frame = self.spark.createDataFrame(df)
+        spark_frame.persist()
+        return spark_frame
+
+    def load_frame_from_csv(self, path):
+        df = self.spark.read.csv(path, header=True)
+        df.persist()
+        return df
+
+    def load_frame_from_parquet(self, path):
+        df = self.spark.read.parquet(path)
+        df.persist()
+        return df
+
+    @property
+    def cll(self):
+        return cll_spark
+
+    @property
+    def cl(self):
+        return cl_spark
+
+    @property
+    def ctl(self):
+        return ctl_spark
+
+
+class SQLiteTestHelper(TestHelper):
+    def __init__(self):
+        from rapidfuzz.distance.Levenshtein import distance
+
+        def lev_wrap(str_l, str_r):
+            return distance(str(str_l), str(str_r))
+
+        con = sqlite3.connect(":memory:")
+        con.create_function("levenshtein", 2, lev_wrap)
+        self.con = con
+        self._frame_counter = 0
+
+    @property
+    def Linker(self):
+        return SQLiteLinker
+
+    def extra_linker_args(self):
+        return {"connection": self.con}
+
+    def _get_input_name(self):
+        name = f"input_alias_{self._frame_counter}"
+        self._frame_counter += 1
+        return name
+
+    def convert_frame(self, df):
+        name = self._get_input_name()
+        df.to_sql(name, self.con, if_exists="replace")
+        return name
+
+    def load_frame_from_csv(self, path):
+        return self.convert_frame(super().load_frame_from_csv(path))
+
+    def load_frame_from_parquet(self, path):
+        return self.convert_frame(super().load_frame_from_parquet(path))
+
+    @property
+    def cll(self):
+        return cll_sqlite
+
+    @property
+    def cl(self):
+        return cl_sqlite
+
+    @property
+    def ctl(self):
+        return ctl_sqlite
diff --git a/tests/test_array_columns.py b/tests/test_array_columns.py
index 4f9ca7e102..ffd438c088 100644
--- a/tests/test_array_columns.py
+++ b/tests/test_array_columns.py
@@ -1,15 +1,17 @@
 import pandas as pd
 import pytest
 
-import splink.duckdb.duckdb_comparison_library as cl
-from splink.duckdb.duckdb_linker import DuckDBLinker
+from tests.decorator import mark_with_dialects_excluding
 
 
 def postcode(num):
     return f"XX{num} {num}YZ"
 
 
-def test_array_comparisons():
+# No SQLite - no array comparisons in library
+@mark_with_dialects_excluding("sqlite")
+def test_array_comparisons(test_helpers, dialect):
+    helper = test_helpers[dialect]
     df = pd.DataFrame(
         [
             {
@@ -49,15 +51,16 @@
             },
         ]
     )
+    df = helper.convert_frame(df)
 
     settings = {
         "link_type": "dedupe_only",
         "comparisons": [
-            cl.array_intersect_at_sizes("postcode", [4, 3, 2, 1]),
-            cl.exact_match("first_name"),
+            helper.cl.array_intersect_at_sizes("postcode", [4, 3, 2, 1]),
+            helper.cl.exact_match("first_name"),
         ],
     }
-    linker = DuckDBLinker(df, settings)
+    linker = helper.Linker(df, settings, **helper.extra_linker_args())
     df_e = linker.predict().as_pandas_dataframe()
 
     # ID pairs with various sizes of intersections
@@ -92,11 +95,11 @@
     settings = {
         "link_type": "dedupe_only",
         "comparisons": [
-            cl.array_intersect_at_sizes("postcode", [3, 1]),
-            cl.exact_match("first_name"),
+            helper.cl.array_intersect_at_sizes("postcode", [3, 1]),
+            helper.cl.exact_match("first_name"),
         ],
     }
-    linker = DuckDBLinker(df, settings)
+    linker = helper.Linker(df, settings, **helper.extra_linker_args())
     df_e = linker.predict().as_pandas_dataframe()
 
     # now levels encompass multiple size intersections
@@ -130,7 +133,7 @@
     settings = {
         "link_type": "dedupe_only",
         "comparisons": [
-            cl.array_intersect_at_sizes("postcode", [-1, 2]),
-            cl.exact_match("first_name"),
+            helper.cl.array_intersect_at_sizes("postcode", [-1, 2]),
+            helper.cl.exact_match("first_name"),
         ],
     }
diff --git a/tests/test_estimate_prob_two_rr_match.py b/tests/test_estimate_prob_two_rr_match.py
index aad5abfdc0..37e0f919bb 100644
--- a/tests/test_estimate_prob_two_rr_match.py
+++ b/tests/test_estimate_prob_two_rr_match.py
@@ -3,10 +3,12 @@
 import pandas as pd
 import pytest
 
-from splink.duckdb.duckdb_linker import DuckDBLinker
+from tests.decorator import mark_with_dialects_excluding
 
 
-def test_prob_rr_match_dedupe():
+@mark_with_dialects_excluding()
+def test_prob_rr_match_dedupe(test_helpers, dialect):
+    helper = test_helpers[dialect]
     df = pd.DataFrame(
         [
             {"unique_id": 1, "first_name": "John", "surname": "Smith"},
@@ -17,6 +19,7 @@
             {"unique_id": 6, "first_name": "Jane", "surname": "Taylor"},
         ]
     )
+    df = helper.convert_frame(df)
 
     settings = {
         "link_type": "dedupe_only",
@@ -30,7 +33,7 @@
     deterministic_rules = ["l.first_name = r.first_name", "l.surname = r.surname"]
 
     # Test dedupe only
-    linker = DuckDBLinker(df, settings)
+    linker = helper.Linker(df, settings, **helper.extra_linker_args())
     linker.estimate_probability_two_random_records_match(
         deterministic_rules, recall=1.0
     )
@@ -50,7 +53,9 @@
     assert pytest.approx(prob) == 4 / 15 * (1 / 0.9)
 
 
-def test_prob_rr_match_link_only():
+@mark_with_dialects_excluding()
+def test_prob_rr_match_link_only(test_helpers, dialect):
+    helper = test_helpers[dialect]
     df_1 = pd.DataFrame(
         [
             {"unique_id": 1, "first_name": "John", "surname": "Smith"},
@@ -66,6 +71,8 @@
             {"unique_id": 4, "first_name": "Alice", "surname": "Williams"},
         ]
     )
+    df_1 = helper.convert_frame(df_1)
+    df_2 = helper.convert_frame(df_2)
 
     settings = {
         "link_type": "link_only",
@@ -79,7 +86,7 @@
     deterministic_rules = ["l.first_name = r.first_name", "l.surname = r.surname"]
 
     # Test dedupe only
-    linker = DuckDBLinker([df_1, df_2], settings)
+    linker = helper.Linker([df_1, df_2], settings, **helper.extra_linker_args())
     linker.estimate_probability_two_random_records_match(
         deterministic_rules, recall=1.0
     )
@@ -89,7 +96,9 @@
     assert pytest.approx(prob) == 2 / 8
 
 
-def test_prob_rr_match_link_and_dedupe():
+@mark_with_dialects_excluding()
+def test_prob_rr_match_link_and_dedupe(test_helpers, dialect):
+    helper = test_helpers[dialect]
     df_1 = pd.DataFrame(
         [
             {"unique_id": 1, "first_name": "John", "surname": "Smith"},
@@ -105,6 +114,8 @@
             {"unique_id": 3, "first_name": "Jane", "surname": "Taylor"},
         ]
     )
+    df_1 = helper.convert_frame(df_1)
+    df_2 = helper.convert_frame(df_2)
 
     settings = {
         "link_type": "link_and_dedupe",
@@ -115,7 +126,7 @@
     deterministic_rules = ["l.first_name = r.first_name", "l.surname = r.surname"]
 
     # Test dedupe only
-    linker = DuckDBLinker([df_1, df_2], settings)
+    linker = helper.Linker([df_1, df_2], settings, **helper.extra_linker_args())
     linker.estimate_probability_two_random_records_match(
         deterministic_rules, recall=1.0
     )
@@ -125,7 +136,9 @@
     assert pytest.approx(prob) == 3 / 15
 
 
-def test_prob_rr_match_link_only_multitable():
+@mark_with_dialects_excluding()
+def test_prob_rr_match_link_only_multitable(test_helpers, dialect):
+    helper = test_helpers[dialect]
     df_1 = pd.DataFrame(
         [
             {"unique_id": 1, "first_name": "John", "surname": "Smith"},
@@ -164,6 +177,15 @@
             {"unique_id": 7, "first_name": "Brian", "surname": "Johnson"},
         ]
     )
+    (df_1, df_2, df_3, df_4) = list(
+        map(lambda df: df.assign(city="Brighton"), (df_1, df_2, df_3, df_4))
+    )
+
+    df_1 = helper.convert_frame(df_1)
+    df_2 = helper.convert_frame(df_2)
+    df_3 = helper.convert_frame(df_3)
+    df_4 = helper.convert_frame(df_4)
+    dfs = [df_1, df_2, df_3, df_4]
 
     settings = {
         "link_type": "link_only",
@@ -173,8 +195,7 @@
 
     deterministic_rules = ["l.first_name = r.first_name", "l.surname = r.surname"]
 
-    dfs = [df_1, df_2, df_3, df_4]
-    linker = DuckDBLinker(dfs, settings)
+    linker = helper.Linker(dfs, settings, **helper.extra_linker_args())
     linker.estimate_probability_two_random_records_match(
         deterministic_rules, recall=1.0
     )
@@ -185,8 +206,7 @@
     assert pytest.approx(prob) == 6 / 131
 
     # if we define all record pairs to be a match, then the probability should be 1
-    dfs = list(map(lambda df: df.assign(city="Brighton"), dfs))
-    linker = DuckDBLinker(dfs, settings)
+    linker = helper.Linker(dfs, settings, **helper.extra_linker_args())
     linker.estimate_probability_two_random_records_match(
         ["l.city = r.city"], recall=1.0
     )
@@ -194,7 +214,9 @@
     assert prob == 1
 
 
-def test_prob_rr_match_link_and_dedupe_multitable():
+@mark_with_dialects_excluding()
+def test_prob_rr_match_link_and_dedupe_multitable(test_helpers, dialect):
+    helper = test_helpers[dialect]
     df_1 = pd.DataFrame(
         [
             {"unique_id": 1, "first_name": "John", "surname": "Smith"},
@@ -233,6 +255,15 @@
             {"unique_id": 7, "first_name": "Brian", "surname": "Johnson"},
         ]
     )
+    (df_1, df_2, df_3, df_4) = list(
+        map(lambda df: df.assign(city="Brighton"), (df_1, df_2, df_3, df_4))
+    )
+
+    df_1 = helper.convert_frame(df_1)
+    df_2 = helper.convert_frame(df_2)
+    df_3 = helper.convert_frame(df_3)
+    df_4 = helper.convert_frame(df_4)
+    dfs = [df_1, df_2, df_3, df_4]
 
     settings = {
         "link_type": "link_and_dedupe",
@@ -242,8 +273,7 @@
 
     deterministic_rules = ["l.first_name = r.first_name", "l.surname = r.surname"]
 
-    dfs = [df_1, df_2, df_3, df_4]
-    linker = DuckDBLinker(dfs, settings)
+    linker = helper.Linker(dfs, settings, **helper.extra_linker_args())
     linker.estimate_probability_two_random_records_match(
         deterministic_rules, recall=1.0
     )
@@ -254,8 +284,7 @@
     # (3 + 4 + 5 + 7)(3 + 4 + 5 + 7 - 1)/2 = 171 comparisons
     assert pytest.approx(prob) == 10 / 171
 
-    dfs = list(map(lambda df: df.assign(city="Brighton"), dfs))
-    linker = DuckDBLinker(dfs, settings)
+    linker = helper.Linker(dfs, settings, **helper.extra_linker_args())
     linker.estimate_probability_two_random_records_match(
         ["l.city = r.city"], recall=1.0
     )
@@ -263,7 +292,10 @@
     assert prob == 1
 
 
-def test_prob_rr_valid_range(caplog):
+@mark_with_dialects_excluding()
+def test_prob_rr_valid_range(test_helpers, dialect, caplog):
+    helper = test_helpers[dialect]
+
     def check_range(p):
         assert p <= 1
         assert p >= 0
@@ -308,6 +340,7 @@
             },
         ]
     )
+    df = helper.convert_frame(df)
 
     settings = {
         "link_type": "dedupe_only",
@@ -315,7 +348,7 @@
     }
 
     # Test dedupe only
-    linker = DuckDBLinker(df, settings)
+    linker = helper.Linker(df, settings, **helper.extra_linker_args())
     with pytest.raises(ValueError):
         # all comparisons matches using this rule, so we must have perfect recall
         # using recall = 80% is inconsistent, so should get an error
diff --git a/tests/test_full_example_duckdb.py b/tests/test_full_example_duckdb.py
index 76d6e56841..0285e77a84 100644
--- a/tests/test_full_example_duckdb.py
+++ b/tests/test_full_example_duckdb.py
@@ -11,6 +11,7 @@
 from splink.duckdb.duckdb_linker import DuckDBLinker
 
 from .basic_settings import get_settings_dict, name_comparison
+from .decorator import mark_with_dialects_including
 from .linker_utils import (
     _test_table_registration,
     _test_write_functionality,
 )
 
 
+@mark_with_dialects_including("duckdb")
 def test_full_example_duckdb(tmp_path):
     df = pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv")
     df = df.rename(columns={"surname": "SUR name"})
@@ -156,6 +158,7 @@
         ),
     ],
 )
+@mark_with_dialects_including("duckdb")
 def test_link_only(input, source_l, source_r):
     settings = get_settings_dict()
     settings["link_type"] = "link_only"
@@ -200,6 +203,7 @@
         ),
     ],
 )
+@mark_with_dialects_including("duckdb")
 def test_duckdb_load_from_file(df):
     settings = get_settings_dict()
 
@@ -221,6 +225,7 @@
     assert len(linker.predict().as_pandas_dataframe()) == 7257
 
 
+@mark_with_dialects_including("duckdb")
 def test_duckdb_arrow_array():
     # Checking array fixes problem identified here:
     # https://github.com/moj-analytical-services/splink/issues/680
@@ -249,6 +254,7 @@
     assert len(df) == 2
 
 
+@mark_with_dialects_including("duckdb")
 def test_cast_error():
     from duckdb import InvalidInputException
 
@@ -264,6 +270,7 @@
     DuckDBLinker(df)
 
 
+@mark_with_dialects_including("duckdb")
 def test_small_example_duckdb(tmp_path):
     df = pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv")
     df["full_name"] = df["first_name"] + df["surname"]
diff --git a/tests/test_train_vs_predict.py b/tests/test_train_vs_predict.py
index 72fc197f3f..a24390830a 100644
--- a/tests/test_train_vs_predict.py
+++ b/tests/test_train_vs_predict.py
@@ -1,12 +1,11 @@
-import pandas as pd
 import pytest
 
-from splink.duckdb.duckdb_linker import DuckDBLinker
-
 from .basic_settings import get_settings_dict
+from .decorator import mark_with_dialects_excluding
 
 
-def test_train_vs_predict():
+@mark_with_dialects_excluding()
+def test_train_vs_predict(test_helpers, dialect):
     """
     If you train parameters blocking on a column (say first_name)
     and then predict() using blocking_rules_to_generate_predictions
@@ -16,11 +15,12 @@
 
     The global version has the param estimate of first_name 'reveresed out'
     """
+    helper = test_helpers[dialect]
 
-    df = pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv")
+    df = helper.load_frame_from_csv("./tests/datasets/fake_1000_from_splink_demos.csv")
     settings_dict = get_settings_dict()
     settings_dict["blocking_rules_to_generate_predictions"] = ["l.surname = r.surname"]
-    linker = DuckDBLinker(df, settings_dict)
+    linker = helper.Linker(df, settings_dict, **helper.extra_linker_args())
 
     training_session = linker.estimate_parameters_using_expectation_maximisation(
         "l.surname = r.surname", fix_u_probabilities=False
diff --git a/tests/test_u_train.py b/tests/test_u_train.py
index b1239e2223..4bdb9e6263 100644
--- a/tests/test_u_train.py
+++ b/tests/test_u_train.py
@@ -2,13 +2,12 @@
 import pandas as pd
 import pytest
 
-import splink.duckdb.duckdb_comparison_library as cld
-import splink.spark.spark_comparison_library as clsp
-from splink.duckdb.duckdb_linker import DuckDBLinker
-from splink.spark.spark_linker import SparkLinker
+from tests.decorator import mark_with_dialects_excluding
 
 
-def test_u_train():
+@mark_with_dialects_excluding()
+def test_u_train(test_helpers, dialect):
+    helper = test_helpers[dialect]
     data = [
         {"unique_id": 1, "name": "Amanda"},
         {"unique_id": 2, "name": "Robin"},
@@ -21,11 +20,12 @@
 
     settings = {
         "link_type": "dedupe_only",
-        "comparisons": [cld.levenshtein_at_thresholds("name", 2)],
+        "comparisons": [helper.cl.levenshtein_at_thresholds("name", 2)],
         "blocking_rules_to_generate_predictions": ["l.name = r.name"],
     }
+    df_linker = helper.convert_frame(df)
 
-    linker = DuckDBLinker(df, settings)
+    linker = helper.Linker(df_linker, settings, **helper.extra_linker_args())
     linker.debug_mode = True
     linker.estimate_u_using_random_sampling(max_pairs=1e6)
     cc_name = linker._settings_obj.comparisons[0]
@@ -42,7 +42,9 @@
     assert br.blocking_rule == "l.name = r.name"
 
 
-def test_u_train_link_only():
+@mark_with_dialects_excluding()
+def test_u_train_link_only(test_helpers, dialect):
+    helper = test_helpers[dialect]
     data_l = [
         {"unique_id": 1, "name": "Amanda"},
         {"unique_id": 2, "name": "Robin"},
@@ -65,11 +67,14 @@
 
     settings = {
         "link_type": "link_only",
-        "comparisons": [cld.levenshtein_at_thresholds("name", 2)],
+        "comparisons": [helper.cl.levenshtein_at_thresholds("name", 2)],
         "blocking_rules_to_generate_predictions": [],
     }
 
-    linker = DuckDBLinker([df_l, df_r], settings)
+    df_l = helper.convert_frame(df_l)
+    df_r = helper.convert_frame(df_r)
+
+    linker = helper.Linker([df_l, df_r], settings, **helper.extra_linker_args())
     linker.debug_mode = True
     linker.estimate_u_using_random_sampling(max_pairs=1e6)
     cc_name = linker._settings_obj.comparisons[0]
@@ -97,7 +102,9 @@
     assert cl_no.u_probability == (denom - 3) / denom
 
 
-def test_u_train_link_only_sample():
+@mark_with_dialects_excluding()
+def test_u_train_link_only_sample(test_helpers, dialect):
+    helper = test_helpers[dialect]
     df_l = (
         pd.DataFrame(np.random.randint(0, 3000, size=(3000, 1)), columns=["name"])
         .reset_index()
@@ -113,11 +120,14 @@
 
     settings = {
         "link_type": "link_only",
-        "comparisons": [cld.levenshtein_at_thresholds("name", 2)],
+        "comparisons": [helper.cl.levenshtein_at_thresholds("name", 2)],
         "blocking_rules_to_generate_predictions": [],
     }
 
-    linker = DuckDBLinker([df_l, df_r], settings)
+    df_l = helper.convert_frame(df_l)
+    df_r = helper.convert_frame(df_r)
+
+    linker = helper.Linker([df_l, df_r], settings, **helper.extra_linker_args())
     linker.debug_mode = True
     linker.estimate_u_using_random_sampling(max_pairs=max_pairs)
     linker._settings_obj.comparisons[0]
@@ -135,10 +145,12 @@
     max_pairs_proportion = result[0]["count"] / max_pairs
     # equality only holds probabilistically
    # chance of failure is approximately 1e-06
-    assert pytest.approx(max_pairs_proportion, 0.15) == 1.0
+    assert pytest.approx(max_pairs_proportion, rel=0.15) == 1.0
 
 
-def test_u_train_multilink():
+@mark_with_dialects_excluding()
+def test_u_train_multilink(test_helpers, dialect):
+    helper = test_helpers[dialect]
     datas = [
         [
             {"unique_id": 1, "name": "John"},
@@ -165,18 +177,18 @@
             {"unique_id": 7, "name": "Adil"},
         ],
     ]
-    dfs = list(map(pd.DataFrame, datas))
+    dfs = list(map(lambda x: helper.convert_frame(pd.DataFrame(x)), datas))
 
     expected_total_links = 2 * 3 + 2 * 4 + 2 * 7 + 3 * 4 + 3 * 7 + 4 * 7
     expected_total_links_with_dedupes = (2 + 3 + 4 + 7) * (2 + 3 + 4 + 7 - 1) / 2
 
     settings = {
         "link_type": "link_only",
-        "comparisons": [cld.levenshtein_at_thresholds("name", 2)],
+        "comparisons": [helper.cl.levenshtein_at_thresholds("name", 2)],
         "blocking_rules_to_generate_predictions": [],
     }
 
-    linker = DuckDBLinker(dfs, settings)
+    linker = helper.Linker(dfs, settings, **helper.extra_linker_args())
     linker.debug_mode = True
     linker.estimate_u_using_random_sampling(max_pairs=1e6)
     cc_name = linker._settings_obj.comparisons[0]
@@ -208,7 +220,7 @@
 
     # also check the numbers on a link + dedupe with same inputs
     settings["link_type"] = "link_and_dedupe"
-    linker = DuckDBLinker(dfs, settings)
+    linker = helper.Linker(dfs, settings, **helper.extra_linker_args())
     linker.debug_mode = True
     linker.estimate_u_using_random_sampling(max_pairs=1e6)
     cc_name = linker._settings_obj.comparisons[0]
@@ -239,27 +251,20 @@
     assert cl_no.u_probability == (denom - 10) / denom
 
 
-@pytest.mark.parametrize(
-    ("Linker", "cll"),
-    [
-        pytest.param(DuckDBLinker, cld, id="Test DuckDB random seeds"),
-        pytest.param(SparkLinker, clsp, id="Test Spark random seeds"),
-    ],
-)
-def test_seed_u_outputs(df_spark, Linker, cll):
-    if Linker == SparkLinker:
-        df = df_spark
-    else:
-        df = pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv")
+# No SQLite - doesn't support random seed
+@mark_with_dialects_excluding("sqlite")
+def test_seed_u_outputs(test_helpers, dialect):
+    helper = test_helpers[dialect]
+    df = helper.load_frame_from_csv("./tests/datasets/fake_1000_from_splink_demos.csv")
 
     settings = {
         "link_type": "dedupe_only",
-        "comparisons": [cll.levenshtein_at_thresholds("first_name", 2)],
+        "comparisons": [helper.cl.levenshtein_at_thresholds("first_name", 2)],
     }
 
-    linker_1 = Linker(df, settings)
-    linker_2 = Linker(df, settings)
-    linker_3 = Linker(df, settings)
+    linker_1 = helper.Linker(df, settings, **helper.extra_linker_args())
+    linker_2 = helper.Linker(df, settings, **helper.extra_linker_args())
+    linker_3 = helper.Linker(df, settings, **helper.extra_linker_args())
 
     linker_1.estimate_u_using_random_sampling(max_pairs=1e3, seed=1)
     linker_2.estimate_u_using_random_sampling(max_pairs=1e3, seed=1)