Skip to content

Commit

Permalink
feat: Add ComparisonLevel compositions
Browse files Browse the repository at this point in the history
This is a spin-off of
#1096
  • Loading branch information
NickCrews committed Mar 11, 2023
1 parent e0a2d5e commit 458e4d7
Show file tree
Hide file tree
Showing 11 changed files with 447 additions and 10 deletions.
24 changes: 24 additions & 0 deletions docs/comparison_level_composition.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
---
tags:
- API
- comparisons
---
# Documentation for `comparison_level_composition` functions

`comparison_composition` allows the merging of existing comparison levels by a logical SQL clause - `OR`, `AND` or `NOT`.

This extends the functionality of our base comparison levels by allowing users to "join" existing comparisons by various SQL clauses.

For example, `or_(null_level("first_name"), null_level("surname"))` creates a check for nulls in *either* `first_name` or `surname`, rather than restricting the user to a single column.

The detailed API for each of these are outlined below.

## Library comparison composition APIs

::: splink.comparison_level_composition
handler: python
selection:
members:
- and_
- or_
- not_
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ nav:
- Comparisons Library API:
- Comparison Library: "comparison_library.md"
- Comparison Level Library: "comparison_level_library.md"
- Comparison Composition: "comparison_level_composition.md"
- Settings Editor: "settingseditor/editor.md"
- Settings dictionary reference: "settings_dict_guide.md"
- Tutorials:
Expand Down
1 change: 1 addition & 0 deletions splink/athena/athena_comparison_level_library.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ..comparison_level_composition import and_, not_, or_ # noqa: F401
from ..comparison_level_library import (
ArrayIntersectLevelBase,
ColumnsReversedLevelBase,
Expand Down
247 changes: 247 additions & 0 deletions splink/comparison_level_composition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
from __future__ import annotations

from typing import Iterable

from .comparison_level import ComparisonLevel


def and_(
*clls: ComparisonLevel | dict,
label_for_charts=None,
m_probability=None,
is_null_level=None,
) -> ComparisonLevel:
"""Merge ComparisonLevels using logical "AND".
Merge multiple ComparisonLevels into a single ComparisonLevel by
merging their SQL conditions using a logical "AND".
By default, we generate a new `label_for_charts` for the new ComparisonLevel.
You can override this, and any other ComparisonLevel attributes, by passing
them as keyword arguments.
Args:
*clls (ComparisonLevel | dict): ComparisonLevels or comparison
level dictionaries to merge
label_for_charts (str, optional): A label for this comparson level,
which will appear on charts as a reminder of what the level represents.
Defaults to a composition of - `label_1 AND label_2`
m_probability (float, optional): Starting value for m probability.
Defaults to None.
is_null_level (bool, optional): If true, m and u values will not be
estimated and instead the match weight will be zero for this column.
Defaults to None.
Examples:
>>> # Simple null level composition with an `AND` clause
>>> import splink.duckdb.duckdb_comparison_level_library as cll
>>> cll.and_(cll.null_level("first_name"), cll.null_level("surname"))
>>> # Composing a levenshtein level with a custom `contains` level
>>> import splink.duckdb.duckdb_comparison_level_library as cll
>>> misspelling = cll.levenshtein_level("name", 1)
>>> contains = {
>>> "sql_condition": "(contains(name_l, name_r) OR " \
>>> "contains(name_r, name_l))"
>>> }
>>> merged = cll.and_(misspelling, contains, label_for_charts="Spelling error")
>>> merged.as_dict()
>>> {
>>> 'sql_condition': '(levenshtein("name_l", "name_r") <= 1) ' \
>>> 'AND ((contains(name_l, name_r) OR contains(name_r, name_l)))',
>>> 'label_for_charts': 'Spelling error'
>>> }
Returns:
ComparisonLevel: A new ComparisonLevel with the merged
SQL condition
"""
return _cl_merge(
*clls,
clause="AND",
label_for_charts=label_for_charts,
m_probability=m_probability,
is_null_level=is_null_level,
)


def or_(
*clls: ComparisonLevel | dict,
label_for_charts: str | None = None,
m_probability: float | None = None,
is_null_level: bool | None = None,
) -> ComparisonLevel:
"""Merge ComparisonLevels using logical "OR".
Merge multiple ComparisonLevels into a single ComparisonLevel by
merging their SQL conditions using a logical "OR".
By default, we generate a new `label_for_charts` for the new ComparisonLevel.
You can override this, and any other ComparisonLevel attributes, by passing
them as keyword arguments.
Args:
*clls (ComparisonLevel | dict): ComparisonLevels or comparison
level dictionaries to merge
label_for_charts (str, optional): A label for this comparson level,
which will appear on charts as a reminder of what the level represents.
Defaults to a composition of - `label_1 OR label_2`
m_probability (float, optional): Starting value for m probability.
Defaults to None.
is_null_level (bool, optional): If true, m and u values will not be
estimated and instead the match weight will be zero for this column.
Defaults to None.
Examples:
>>> # Simple null level composition with an `OR` clause
>>> import splink.duckdb.duckdb_comparison_level_library as cll
>>> cll.or_(cll.null_level("first_name"), cll.null_level("surname"))
>>> # Composing a levenshtein level with a custom `contains` level
>>> import splink.duckdb.duckdb_comparison_level_library as cll
>>> misspelling = cll.levenshtein_level("name", 1)
>>> contains = {
>>> "sql_condition": "(contains(name_l, name_r) OR " \
>>> "contains(name_r, name_l))"
>>> }
>>> merged = cll.or_(misspelling, contains, label_for_charts="Spelling error")
>>> merged.as_dict()
>>> {
>>> 'sql_condition': '(levenshtein("name_l", "name_r") <= 1) ' \
>>> 'OR ((contains(name_l, name_r) OR contains(name_r, name_l)))',
>>> 'label_for_charts': 'Spelling error'
>>> }
Returns:
ComparisonLevel: A new ComparisonLevel with the merged
SQL condition
"""

return _cl_merge(
*clls,
clause="OR",
label_for_charts=label_for_charts,
m_probability=m_probability,
is_null_level=is_null_level,
)


def not_(
cll: ComparisonLevel | dict,
label_for_charts: str | None = None,
m_probability: float | None = None,
) -> ComparisonLevel:
"""Negate a ComparisonLevel.
Returns a ComparisonLevel with the same SQL condition as the input,
but prefixed with "NOT".
By default, we generate a new `label_for_charts` for the new ComparisonLevel.
You can override this, and any other ComparisonLevel attributes, by passing
them as keyword arguments.
Args:
cll (ComparisonLevel | dict): ComparisonLevel or comparison
level dictionary
label_for_charts (str, optional): A label for this comparson level,
which will appear on charts as a reminder of what the level represents.
m_probability (float, optional): Starting value for m probability.
Defaults to None.
Examples:
>>> import splink.duckdb.duckdb_comparison_level_library as cll
>>> # *Not* a null on first name `first_name`
>>> cll.not_(cll.exact_match("first_name"))
>>> import splink.duckdb.duckdb_comparison_level_library as cll
>>> # Find all exact matches *not* on the first of January
>>> dob_first_jan = {
>>> "sql_condition": "SUBSTR(dob_std_l, -5) = '01-01'",
>>> "label_for_charts": "Date is 1st Jan",
>>> }
>>> exact_match_not_first_jan = cll.and_(
>>> cll.exact_match_level("dob"),
>>> cll.not_(dob_first_jan),
>>> label_for_charts = "Exact match and not the 1st Jan"
>>> )
Returns:
ComparisonLevel
A new ComparisonLevel with the negated SQL condition and label_for_charts
"""
cls, sql_dialect = _parse_comparison_levels(cll)
cl = cls[0]
result = {}
result["sql_condition"] = f"NOT ({cl.sql_condition})"

# Invert if is_null_level.
# If NOT is_null_level, then we don't know if the inverted level is null or not
if not cl.is_null_level:
result["is_null_level"] = False

result["label_for_charts"] = (
label_for_charts if label_for_charts else f"NOT ({cl.label_for_charts})"
)

if m_probability:
result["m_probability"] = m_probability

return ComparisonLevel(result, sql_dialect=sql_dialect)


def _cl_merge(
*clls: ComparisonLevel | dict,
clause: str,
label_for_charts: str | None = None,
m_probability: float | None = None,
is_null_level: bool | None = None,
) -> ComparisonLevel:
if len(clls) == 0:
raise ValueError("Must provide at least one ComparisonLevel")

cls, sql_dialect = _parse_comparison_levels(*clls)
result = {}
conditions = ("(" + cl.sql_condition + ")" for cl in cls)
result["sql_condition"] = f" {clause} ".join(conditions)

# Set to null level if all supplied levels are "null levels"
if is_null_level is None:
if all(d.is_null_level for d in cls):
result["is_null_level"] = True

if label_for_charts:
result["label_for_charts"] = label_for_charts
else:
labels = ("(" + cl.label_for_charts + ")" for cl in cls)
result["label_for_charts"] = f" {clause} ".join(labels)

if m_probability:
result["m_probability"] = m_probability

return ComparisonLevel(result, sql_dialect=sql_dialect)


def _parse_comparison_levels(
*cls: ComparisonLevel | dict,
) -> tuple[list[ComparisonLevel], str | None]:
cls = [_to_comparison_level(cl) for cl in cls]
sql_dialect = _unify_sql_dialects(cls)
return cls, sql_dialect


def _to_comparison_level(cl: ComparisonLevel | dict) -> ComparisonLevel:
if isinstance(cl, ComparisonLevel):
return cl
else:
return ComparisonLevel(cl)


def _unify_sql_dialects(cls: Iterable[ComparisonLevel]) -> str | None:
sql_dialects = set(cl._sql_dialect for cl in cls)
sql_dialects.discard(None)
if len(sql_dialects) > 1:
raise ValueError("Cannot combine comparison levels with different SQL dialects")
elif len(sql_dialects) == 0:
return None
return sql_dialects.pop()
1 change: 1 addition & 0 deletions splink/duckdb/duckdb_comparison_level_library.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ..comparison_level_composition import and_, not_, or_ # noqa: F401
from ..comparison_level_library import (
ArrayIntersectLevelBase,
ColumnsReversedLevelBase,
Expand Down
1 change: 1 addition & 0 deletions splink/spark/spark_comparison_level_library.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ..comparison_level_composition import and_, not_, or_ # noqa: F401
from ..comparison_level_library import (
ArrayIntersectLevelBase,
ColumnsReversedLevelBase,
Expand Down
1 change: 1 addition & 0 deletions splink/sqlite/sqlite_comparison_level_library.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ..comparison_level_composition import and_, not_, or_ # noqa: F401
from ..comparison_level_library import (
ColumnsReversedLevelBase,
DistanceFunctionLevelBase,
Expand Down
28 changes: 28 additions & 0 deletions tests/basic_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,31 @@ def get_settings_dict():
}

return deepcopy(settings)


def name_comparison(cll, sn: str) -> dict:
"""A comparison using first and surname levels composed together."""
return {
"output_column_name": "first_name_and_surname",
"comparison_levels": [
# Null level
cll.or_(cll.null_level("first_name"), cll.null_level(sn)),
# Exact match on fn and sn
cll.or_(
cll.exact_match_level("first_name"),
cll.exact_match_level(sn),
m_probability=0.8,
label_for_charts="Exact match on first name or surname",
),
# (Levenshtein(fn) and jaro_winkler(fn)) or levenshtein(sur)
cll.and_(
cll.or_(
cll.levenshtein_level("first_name", 2),
cll.jaro_winkler_level("first_name", 0.8),
m_probability=0.8,
),
cll.levenshtein_level(sn, 3),
),
cll.else_level(0.1),
],
}
Loading

0 comments on commit 458e4d7

Please sign in to comment.