From 458e4d73578edab84c4a5004f4a68ee24c7b7152 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Fri, 10 Mar 2023 15:39:36 -0900 Subject: [PATCH] feat: Add ComparisonLevel compositions This is a spin-off of https://github.com/moj-analytical-services/splink/pull/1096 --- docs/comparison_level_composition.md | 24 ++ mkdocs.yml | 1 + .../athena/athena_comparison_level_library.py | 1 + splink/comparison_level_composition.py | 247 ++++++++++++++++++ .../duckdb/duckdb_comparison_level_library.py | 1 + .../spark/spark_comparison_level_library.py | 1 + .../sqlite/sqlite_comparison_level_library.py | 1 + tests/basic_settings.py | 28 ++ tests/test_comparison_level_composition.py | 135 ++++++++++ tests/test_full_example_duckdb.py | 5 +- tests/test_full_example_spark.py | 13 +- 11 files changed, 447 insertions(+), 10 deletions(-) create mode 100644 docs/comparison_level_composition.md create mode 100644 splink/comparison_level_composition.py create mode 100644 tests/test_comparison_level_composition.py diff --git a/docs/comparison_level_composition.md b/docs/comparison_level_composition.md new file mode 100644 index 0000000000..6a67acce7c --- /dev/null +++ b/docs/comparison_level_composition.md @@ -0,0 +1,24 @@ +--- +tags: + - API + - comparisons +--- +# Documentation for `comparison_level_composition` functions + +`comparison_composition` allows the merging of existing comparison levels by a logical SQL clause - `OR`, `AND` or `NOT`. + +This extends the functionality of our base comparison levels by allowing users to "join" existing comparisons by various SQL clauses. + +For example, `or_(null_level("first_name"), null_level("surname"))` creates a check for nulls in *either* `first_name` or `surname`, rather than restricting the user to a single column. + +The detailed API for each of these are outlined below. + +## Library comparison composition APIs + +::: splink.comparison_level_composition + handler: python + selection: + members: + - and_ + - or_ + - not_ diff --git a/mkdocs.yml b/mkdocs.yml index db349a61c6..ffc6f95831 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -84,6 +84,7 @@ nav: - Comparisons Library API: - Comparison Library: "comparison_library.md" - Comparison Level Library: "comparison_level_library.md" + - Comparison Composition: "comparison_level_composition.md" - Settings Editor: "settingseditor/editor.md" - Settings dictionary reference: "settings_dict_guide.md" - Tutorials: diff --git a/splink/athena/athena_comparison_level_library.py b/splink/athena/athena_comparison_level_library.py index ef0260c04d..67da184d79 100644 --- a/splink/athena/athena_comparison_level_library.py +++ b/splink/athena/athena_comparison_level_library.py @@ -1,3 +1,4 @@ +from ..comparison_level_composition import and_, not_, or_ # noqa: F401 from ..comparison_level_library import ( ArrayIntersectLevelBase, ColumnsReversedLevelBase, diff --git a/splink/comparison_level_composition.py b/splink/comparison_level_composition.py new file mode 100644 index 0000000000..1bf37241a5 --- /dev/null +++ b/splink/comparison_level_composition.py @@ -0,0 +1,247 @@ +from __future__ import annotations + +from typing import Iterable + +from .comparison_level import ComparisonLevel + + +def and_( + *clls: ComparisonLevel | dict, + label_for_charts=None, + m_probability=None, + is_null_level=None, +) -> ComparisonLevel: + """Merge ComparisonLevels using logical "AND". + + Merge multiple ComparisonLevels into a single ComparisonLevel by + merging their SQL conditions using a logical "AND". + + By default, we generate a new `label_for_charts` for the new ComparisonLevel. + You can override this, and any other ComparisonLevel attributes, by passing + them as keyword arguments. + + Args: + *clls (ComparisonLevel | dict): ComparisonLevels or comparison + level dictionaries to merge + label_for_charts (str, optional): A label for this comparson level, + which will appear on charts as a reminder of what the level represents. + Defaults to a composition of - `label_1 AND label_2` + m_probability (float, optional): Starting value for m probability. + Defaults to None. + is_null_level (bool, optional): If true, m and u values will not be + estimated and instead the match weight will be zero for this column. + Defaults to None. + + Examples: + >>> # Simple null level composition with an `AND` clause + >>> import splink.duckdb.duckdb_comparison_level_library as cll + >>> cll.and_(cll.null_level("first_name"), cll.null_level("surname")) + + >>> # Composing a levenshtein level with a custom `contains` level + >>> import splink.duckdb.duckdb_comparison_level_library as cll + >>> misspelling = cll.levenshtein_level("name", 1) + >>> contains = { + >>> "sql_condition": "(contains(name_l, name_r) OR " \ + >>> "contains(name_r, name_l))" + >>> } + >>> merged = cll.and_(misspelling, contains, label_for_charts="Spelling error") + >>> merged.as_dict() + >>> { + >>> 'sql_condition': '(levenshtein("name_l", "name_r") <= 1) ' \ + >>> 'AND ((contains(name_l, name_r) OR contains(name_r, name_l)))', + >>> 'label_for_charts': 'Spelling error' + >>> } + + Returns: + ComparisonLevel: A new ComparisonLevel with the merged + SQL condition + """ + return _cl_merge( + *clls, + clause="AND", + label_for_charts=label_for_charts, + m_probability=m_probability, + is_null_level=is_null_level, + ) + + +def or_( + *clls: ComparisonLevel | dict, + label_for_charts: str | None = None, + m_probability: float | None = None, + is_null_level: bool | None = None, +) -> ComparisonLevel: + """Merge ComparisonLevels using logical "OR". + + Merge multiple ComparisonLevels into a single ComparisonLevel by + merging their SQL conditions using a logical "OR". + + By default, we generate a new `label_for_charts` for the new ComparisonLevel. + You can override this, and any other ComparisonLevel attributes, by passing + them as keyword arguments. + + Args: + *clls (ComparisonLevel | dict): ComparisonLevels or comparison + level dictionaries to merge + label_for_charts (str, optional): A label for this comparson level, + which will appear on charts as a reminder of what the level represents. + Defaults to a composition of - `label_1 OR label_2` + m_probability (float, optional): Starting value for m probability. + Defaults to None. + is_null_level (bool, optional): If true, m and u values will not be + estimated and instead the match weight will be zero for this column. + Defaults to None. + + Examples: + >>> # Simple null level composition with an `OR` clause + >>> import splink.duckdb.duckdb_comparison_level_library as cll + >>> cll.or_(cll.null_level("first_name"), cll.null_level("surname")) + + >>> # Composing a levenshtein level with a custom `contains` level + >>> import splink.duckdb.duckdb_comparison_level_library as cll + >>> misspelling = cll.levenshtein_level("name", 1) + >>> contains = { + >>> "sql_condition": "(contains(name_l, name_r) OR " \ + >>> "contains(name_r, name_l))" + >>> } + >>> merged = cll.or_(misspelling, contains, label_for_charts="Spelling error") + >>> merged.as_dict() + >>> { + >>> 'sql_condition': '(levenshtein("name_l", "name_r") <= 1) ' \ + >>> 'OR ((contains(name_l, name_r) OR contains(name_r, name_l)))', + >>> 'label_for_charts': 'Spelling error' + >>> } + + Returns: + ComparisonLevel: A new ComparisonLevel with the merged + SQL condition + """ + + return _cl_merge( + *clls, + clause="OR", + label_for_charts=label_for_charts, + m_probability=m_probability, + is_null_level=is_null_level, + ) + + +def not_( + cll: ComparisonLevel | dict, + label_for_charts: str | None = None, + m_probability: float | None = None, +) -> ComparisonLevel: + """Negate a ComparisonLevel. + + Returns a ComparisonLevel with the same SQL condition as the input, + but prefixed with "NOT". + + By default, we generate a new `label_for_charts` for the new ComparisonLevel. + You can override this, and any other ComparisonLevel attributes, by passing + them as keyword arguments. + + Args: + cll (ComparisonLevel | dict): ComparisonLevel or comparison + level dictionary + label_for_charts (str, optional): A label for this comparson level, + which will appear on charts as a reminder of what the level represents. + m_probability (float, optional): Starting value for m probability. + Defaults to None. + + Examples: + >>> import splink.duckdb.duckdb_comparison_level_library as cll + >>> # *Not* a null on first name `first_name` + >>> cll.not_(cll.exact_match("first_name")) + + >>> import splink.duckdb.duckdb_comparison_level_library as cll + >>> # Find all exact matches *not* on the first of January + >>> dob_first_jan = { + >>> "sql_condition": "SUBSTR(dob_std_l, -5) = '01-01'", + >>> "label_for_charts": "Date is 1st Jan", + >>> } + >>> exact_match_not_first_jan = cll.and_( + >>> cll.exact_match_level("dob"), + >>> cll.not_(dob_first_jan), + >>> label_for_charts = "Exact match and not the 1st Jan" + >>> ) + + + Returns: + ComparisonLevel + A new ComparisonLevel with the negated SQL condition and label_for_charts + """ + cls, sql_dialect = _parse_comparison_levels(cll) + cl = cls[0] + result = {} + result["sql_condition"] = f"NOT ({cl.sql_condition})" + + # Invert if is_null_level. + # If NOT is_null_level, then we don't know if the inverted level is null or not + if not cl.is_null_level: + result["is_null_level"] = False + + result["label_for_charts"] = ( + label_for_charts if label_for_charts else f"NOT ({cl.label_for_charts})" + ) + + if m_probability: + result["m_probability"] = m_probability + + return ComparisonLevel(result, sql_dialect=sql_dialect) + + +def _cl_merge( + *clls: ComparisonLevel | dict, + clause: str, + label_for_charts: str | None = None, + m_probability: float | None = None, + is_null_level: bool | None = None, +) -> ComparisonLevel: + if len(clls) == 0: + raise ValueError("Must provide at least one ComparisonLevel") + + cls, sql_dialect = _parse_comparison_levels(*clls) + result = {} + conditions = ("(" + cl.sql_condition + ")" for cl in cls) + result["sql_condition"] = f" {clause} ".join(conditions) + + # Set to null level if all supplied levels are "null levels" + if is_null_level is None: + if all(d.is_null_level for d in cls): + result["is_null_level"] = True + + if label_for_charts: + result["label_for_charts"] = label_for_charts + else: + labels = ("(" + cl.label_for_charts + ")" for cl in cls) + result["label_for_charts"] = f" {clause} ".join(labels) + + if m_probability: + result["m_probability"] = m_probability + + return ComparisonLevel(result, sql_dialect=sql_dialect) + + +def _parse_comparison_levels( + *cls: ComparisonLevel | dict, +) -> tuple[list[ComparisonLevel], str | None]: + cls = [_to_comparison_level(cl) for cl in cls] + sql_dialect = _unify_sql_dialects(cls) + return cls, sql_dialect + + +def _to_comparison_level(cl: ComparisonLevel | dict) -> ComparisonLevel: + if isinstance(cl, ComparisonLevel): + return cl + else: + return ComparisonLevel(cl) + + +def _unify_sql_dialects(cls: Iterable[ComparisonLevel]) -> str | None: + sql_dialects = set(cl._sql_dialect for cl in cls) + sql_dialects.discard(None) + if len(sql_dialects) > 1: + raise ValueError("Cannot combine comparison levels with different SQL dialects") + elif len(sql_dialects) == 0: + return None + return sql_dialects.pop() diff --git a/splink/duckdb/duckdb_comparison_level_library.py b/splink/duckdb/duckdb_comparison_level_library.py index c369826548..7e70fa3909 100644 --- a/splink/duckdb/duckdb_comparison_level_library.py +++ b/splink/duckdb/duckdb_comparison_level_library.py @@ -1,3 +1,4 @@ +from ..comparison_level_composition import and_, not_, or_ # noqa: F401 from ..comparison_level_library import ( ArrayIntersectLevelBase, ColumnsReversedLevelBase, diff --git a/splink/spark/spark_comparison_level_library.py b/splink/spark/spark_comparison_level_library.py index ce259ed3f1..415cbae971 100644 --- a/splink/spark/spark_comparison_level_library.py +++ b/splink/spark/spark_comparison_level_library.py @@ -1,3 +1,4 @@ +from ..comparison_level_composition import and_, not_, or_ # noqa: F401 from ..comparison_level_library import ( ArrayIntersectLevelBase, ColumnsReversedLevelBase, diff --git a/splink/sqlite/sqlite_comparison_level_library.py b/splink/sqlite/sqlite_comparison_level_library.py index 46ec0670c2..feb7535a0d 100644 --- a/splink/sqlite/sqlite_comparison_level_library.py +++ b/splink/sqlite/sqlite_comparison_level_library.py @@ -1,3 +1,4 @@ +from ..comparison_level_composition import and_, not_, or_ # noqa: F401 from ..comparison_level_library import ( ColumnsReversedLevelBase, DistanceFunctionLevelBase, diff --git a/tests/basic_settings.py b/tests/basic_settings.py index 3aba99fac1..40d65b01b9 100644 --- a/tests/basic_settings.py +++ b/tests/basic_settings.py @@ -151,3 +151,31 @@ def get_settings_dict(): } return deepcopy(settings) + + +def name_comparison(cll, sn: str) -> dict: + """A comparison using first and surname levels composed together.""" + return { + "output_column_name": "first_name_and_surname", + "comparison_levels": [ + # Null level + cll.or_(cll.null_level("first_name"), cll.null_level(sn)), + # Exact match on fn and sn + cll.or_( + cll.exact_match_level("first_name"), + cll.exact_match_level(sn), + m_probability=0.8, + label_for_charts="Exact match on first name or surname", + ), + # (Levenshtein(fn) and jaro_winkler(fn)) or levenshtein(sur) + cll.and_( + cll.or_( + cll.levenshtein_level("first_name", 2), + cll.jaro_winkler_level("first_name", 0.8), + m_probability=0.8, + ), + cll.levenshtein_level(sn, 3), + ), + cll.else_level(0.1), + ], + } diff --git a/tests/test_comparison_level_composition.py b/tests/test_comparison_level_composition.py new file mode 100644 index 0000000000..ca6063e81f --- /dev/null +++ b/tests/test_comparison_level_composition.py @@ -0,0 +1,135 @@ +import pandas as pd +import pytest + +import splink.duckdb.duckdb_comparison_level_library as cll +import splink.duckdb.duckdb_comparison_level_library as scll +from splink.duckdb.duckdb_linker import DuckDBLinker + + +def test_not(): + level = cll.not_(cll.null_level("first_name")) + assert level.is_null_level is False + + # Integration test for a simple dictionary cl + dob_jan_first = {"sql_condition": "SUBSTR(dob_std_l, -5) = '01-01'"} + cll.not_(dob_jan_first) + + with pytest.raises(TypeError): + cll.not_() + + +@pytest.mark.parametrize( + ("clause", "c_fun"), + [ + pytest.param("OR", cll.or_, id="Test or_"), + pytest.param("AND", cll.and_, id="Test and_"), + pytest.param("OR", scll.or_, id="Test spark or_"), + ], +) +def test_binary_composition_internals(clause, c_fun): + # Test what happens when only one value is fed + # It should just report the regular outputs of our comparison level func + level = c_fun(cll.exact_match_level("tom", include_colname_in_charts_label=True)) + assert level.sql_condition == '("tom_l" = "tom_r")' + assert level.label_for_charts == "(Exact match tom)" + + # Two null levels composed + level = c_fun( + cll.null_level("first_name"), + cll.null_level("surname"), + label_for_charts="This is a test", + ) + + null_sql = ( + f'("first_name_l" IS NULL OR "first_name_r" IS NULL) {clause} ' + '("surname_l" IS NULL OR "surname_r" IS NULL)' + ) + assert level.sql_condition == null_sql + # Default label + assert level.label_for_charts == "This is a test" + # As both inputs are null, we're expecting this to return True + assert level.is_null_level is True + + # Exact match and null level composition + level = c_fun( + cll.exact_match_level("first_name", include_colname_in_charts_label=True), + cll.null_level("first_name"), + m_probability=0.5, + ) + assert ( + level.sql_condition == f'("first_name_l" = "first_name_r") {clause} ' + '("first_name_l" IS NULL OR "first_name_r" IS NULL)' + ) + # Default label + assert level.label_for_charts == f"(Exact match first_name) {clause} (Null)" + # should default to False + assert level.is_null_level is False + assert level._m_probability == 0.5 + + # cll.not_(or_(...)) composition + level = cll.not_( + c_fun(cll.exact_match_level("first_name"), cll.exact_match_level("surname")), + m_probability=0.5, + ) + + exact_match_sql = ( + f'("first_name_l" = "first_name_r") {clause} ("surname_l" = "surname_r")' + ) + assert level.sql_condition == f"NOT ({exact_match_sql})" + + with pytest.raises(ValueError): + c_fun() + + +def test_composition_outputs(): + # Check our compositions give expected outputs + df = pd.DataFrame( + [ + {"unique_id": 1, "forename": "Tom", "surname": "Tim"}, + {"unique_id": 2, "forename": "Tom", "surname": "Tim"}, + {"unique_id": 3, "forename": "Tom", "surname": "Timothee"}, + {"unique_id": 4, "forename": "Sam", "surname": "Tarly"}, + {"unique_id": 5, "forename": "Sam", "surname": "Tim"}, + ] + ) + + # For testing the cll version + dbl_null = cll.or_(cll.null_level("forename"), cll.null_level("surname")) + both = cll.and_(cll.exact_match_level("forename"), cll.exact_match_level("surname")) + either = cll.or_( + cll.exact_match_level("forename"), cll.exact_match_level("surname") + ) + + full_name = { + "output_column_name": "full_name", + "comparison_levels": [ + dbl_null, + both, + either, + cll.not_(both), # acts as an "else" level + cll.else_level(), + ], + } + + settings = { + "link_type": "dedupe_only", + "comparisons": [full_name], + } + + linker = DuckDBLinker(df, settings) + + pred = linker.predict() + out = pred.as_pandas_dataframe().sort_values(by=["unique_id_l", "unique_id_r"]) + + # Check individual IDs are assigned to the correct gamma values + # Dict key: {gamma_value: tuple of ID pairs} + size_gamma_lookup = { + 3: [(1, 2)], + 2: [(1, 3), (1, 5), (2, 3), (2, 5), (4, 5)], + 1: [(1, 4), (2, 4), (3, 4), (3, 5)], + } + + for gamma, id_pairs in size_gamma_lookup.items(): + for left, right in id_pairs: + row = out.loc[(out.unique_id_l == left) & (out.unique_id_r == right)] + assert row["gamma_full_name"].values[0] == gamma diff --git a/tests/test_full_example_duckdb.py b/tests/test_full_example_duckdb.py index b1a2183094..a785757acb 100644 --- a/tests/test_full_example_duckdb.py +++ b/tests/test_full_example_duckdb.py @@ -6,6 +6,7 @@ import pyarrow.parquet as pq import pytest +import splink.duckdb.duckdb_comparison_level_library as cll from splink.duckdb.duckdb_comparison_library import ( exact_match, jaccard_at_thresholds, @@ -13,12 +14,11 @@ ) from splink.duckdb.duckdb_linker import DuckDBLinker -from .basic_settings import get_settings_dict +from .basic_settings import get_settings_dict, name_comparison from .linker_utils import _test_table_registration, register_roc_data def test_full_example_duckdb(tmp_path): - df = pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv") df = df.rename(columns={"surname": "SUR name"}) settings_dict = get_settings_dict() @@ -26,6 +26,7 @@ def test_full_example_duckdb(tmp_path): # Overwrite the surname comparison to include duck-db specific syntax settings_dict["comparisons"][0] = jaro_winkler_at_thresholds("first_name") settings_dict["comparisons"][1] = jaccard_at_thresholds("SUR name") + settings_dict["comparisons"].append(name_comparison(cll, "SUR name")) settings_dict["blocking_rules_to_generate_predictions"] = [ 'l."SUR name" = r."SUR name"', ] diff --git a/tests/test_full_example_spark.py b/tests/test_full_example_spark.py index 5f7aa5d7f4..dc15977980 100644 --- a/tests/test_full_example_spark.py +++ b/tests/test_full_example_spark.py @@ -3,25 +3,22 @@ from pyspark.sql.functions import array from pyspark.sql.types import StringType, StructField, StructType +import splink.spark.spark_comparison_level_library as cll import splink.spark.spark_comparison_library as cl -from splink.spark.spark_comparison_level_library import ( - array_intersect_level, - else_level, -) from splink.spark.spark_linker import SparkLinker -from .basic_settings import get_settings_dict +from .basic_settings import get_settings_dict, name_comparison from .linker_utils import _test_table_registration, register_roc_data def test_full_example_spark(df_spark, tmp_path): - # Convert a column to an array to enable testing intersection df_spark = df_spark.withColumn("email", array("email")) settings_dict = get_settings_dict() # Only needed because the value can be overwritten by other tests settings_dict["comparisons"][1] = cl.exact_match("surname") + settings_dict["comparisons"].append(name_comparison(cll, "surname")) settings = { "probability_two_random_records_match": 0.01, @@ -35,8 +32,8 @@ def test_full_example_spark(df_spark, tmp_path): cl.exact_match("dob"), { "comparison_levels": [ - array_intersect_level("email"), - else_level(), + cll.array_intersect_level("email"), + cll.else_level(), ] }, cl.exact_match("city"),