Add regex extract functionality to comparisons #1203

Merged
merged 37 commits into master from regex_extract
May 10, 2023
Commits
a0c8c68
Added option to include a regex match to exact match level base
zslade Apr 24, 2023
3e27cff
added function to produce sql string with accepted duckdb regex extra…
zslade Apr 24, 2023
5e2410f
added function to produce sql string with accepted spark regex extrac…
zslade Apr 24, 2023
387c414
updated with error message to user if regex extract function doesn't …
zslade Apr 24, 2023
c819cf5
0 not required for duckdb backend
zslade Apr 25, 2023
86a871c
Updated regex function so that can be used with string comparitors an…
zslade Apr 25, 2023
03c6214
updated string comparitor levels and reverse match level to accept re…
zslade Apr 25, 2023
8582f11
Added option to perform a regex match
zslade Apr 26, 2023
bcc98aa
required regex parameter to be str
zslade Apr 26, 2023
acf9e95
Regex extract sql added to Athena backend
zslade Apr 27, 2023
ff2d7a6
Updated error message
zslade Apr 27, 2023
2738ab3
added error message for spark when user enters python escape characte…
zslade Apr 27, 2023
c7e985d
Updated doc strings with examples for using regex_extract option
zslade Apr 27, 2023
e57ab69
removed erroneous bracket
zslade May 2, 2023
2f7730d
updated scripts to latest versions to work with Athena linker
zslade May 2, 2023
205c6b8
Added tests for regex levels for duckdb and spark linker
zslade May 4, 2023
930f221
Added test for invalid regex
zslade May 4, 2023
3105f51
Linted and formatted script test script
zslade May 4, 2023
08c622e
linted and formatted
zslade May 5, 2023
2348bc1
Added option `valid_string_regex` to treat col as null if no match on…
zslade May 7, 2023
93f1489
Updated comparisons with valid_string_regex option
zslade May 7, 2023
18d6041
removed comments
zslade May 7, 2023
c8eb64b
linted
zslade May 7, 2023
e584e2b
linted
zslade May 7, 2023
3d0faca
add `valid_string_regex` option to date comparison
zslade May 7, 2023
a14137d
Added a check for regex_extract and valid_string_regegx parameters
zslade May 7, 2023
ad293b5
linted scripts
zslade May 7, 2023
622acd1
Merge branch 'master' into regex_extract
RossKen May 9, 2023
8c1471a
Update docs and add null level for dates
RossKen May 9, 2023
efc2f27
Merge branch 'master' into regex_extract
RossKen May 9, 2023
787b72c
revert athena linker rebase error
RossKen May 9, 2023
9cc5ddd
fix up template library and null level
RossKen May 9, 2023
5afe4e6
add tests for date format parsing
RossKen May 10, 2023
dd6803b
linting
RossKen May 10, 2023
9f664b3
lint with black
RossKen May 10, 2023
e44e321
Merge branch 'regex_extract' of github.com:moj-analytical-services/sp…
RossKen May 10, 2023
9501ede
reinstate cll testing
RossKen May 10, 2023
10 changes: 10 additions & 0 deletions splink/athena/athena_base.py
@@ -7,6 +7,12 @@ def size_array_intersect_sql(col_name_l, col_name_r):
return f"cardinality(array_intersect({col_name_l}, {col_name_r}))"


def regex_extract_sql(col_name, regex):
return f"""
regexp_extract({col_name}, '{regex}')
"""


class AthenaBase(DialectBase):
@property
def _sql_dialect(self):
@@ -19,3 +25,7 @@ def _levenshtein_name(self):
@property
def _size_array_intersect_function(self):
return size_array_intersect_sql

@property
def _regex_extract_function(self):
return regex_extract_sql
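
For reference, the helper above only builds a SQL fragment that the comparison levels embed on both sides of a level condition. A minimal standalone sketch of the same idea (the column name and postcode pattern are illustrative, not taken from this PR):

```python
# Standalone sketch mirroring the Athena helper above; column and pattern
# are illustrative only.
def regex_extract_sql(col_name: str, regex: str) -> str:
    # Athena/Presto regexp_extract returns the first substring matching the pattern
    return f"regexp_extract({col_name}, '{regex}')"

# e.g. compare only the outward part of a UK postcode
print(regex_extract_sql("postcode_l", "^[A-Z]{1,2}[0-9][A-Z0-9]?"))
# regexp_extract(postcode_l, '^[A-Z]{1,2}[0-9][A-Z0-9]?')
```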
207 changes: 197 additions & 10 deletions splink/comparison_level_library.py

Large diffs are not rendered by default.
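
The full diff is collapsed here, but based on the docstrings and commit messages in this PR, an exact-match level restricted to a regex extract would look roughly like the sketch below (module path and keyword name are assumed from elsewhere in this PR, so treat it as illustrative):

```python
# Sketch only: assumes the DuckDB level library exposes
# exact_match_level(col_name, regex_extract=...) after this PR.
import splink.duckdb.duckdb_comparison_level_library as cll

# Exact match on the outward part of a UK postcode, e.g. "SW1A" from "SW1A 1AA"
level = cll.exact_match_level("postcode", regex_extract="^[A-Z]{1,2}[0-9][A-Z0-9]?")

# The level's SQL wraps both sides in regexp_extract, roughly:
#   regexp_extract("postcode_l", '...') = regexp_extract("postcode_r", '...')
```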

224 changes: 206 additions & 18 deletions splink/comparison_library.py

Large diffs are not rendered by default.
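
Similarly collapsed; a hedged sketch of how the `regex_extract` option threads through the comparison library's distance-based comparisons (the exact signature is assumed rather than shown in this diff, mirroring `distance_threshold_comparison_levels` below):

```python
# Sketch only: assumes levenshtein_at_thresholds accepts regex_extract
# after this PR; the pattern is illustrative.
import splink.duckdb.duckdb_comparison_library as cl

# Levenshtein comparison evaluated on the leading alphabetic block of the name
first_name_comparison = cl.levenshtein_at_thresholds(
    "first_name",
    2,                        # distance threshold
    regex_extract="^[A-Z]+",  # illustrative pattern
)
```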

8 changes: 7 additions & 1 deletion splink/comparison_library_utils.py
@@ -67,6 +67,7 @@ def distance_threshold_comparison_levels(
col_name: str,
distance_function_name: str,
distance_threshold_or_thresholds,
regex_extract: str = None,
higher_is_more_similar: bool = True,
m_probability_or_probabilities_thres: list = None,
):
@@ -94,7 +95,12 @@
higher_is_more_similar = True

# these function arguments hold for all cases.
kwargs = dict(col_name=col_name, distance_threshold=thres, m_probability=m_prob)
kwargs = dict(
col_name=col_name,
distance_threshold=thres,
regex_extract=regex_extract,
m_probability=m_prob,
)
# separate out the two that are only used
# when we have a user-supplied function, rather than a predefined subclass
# feels a bit hacky, but will do at least for time being
103 changes: 68 additions & 35 deletions splink/comparison_template_library.py
@@ -21,6 +21,9 @@ class DateComparisonBase(Comparison):
def __init__(
self,
col_name: str,
cast_strings_to_date: bool = False,
date_format: str = None,
invalid_dates_as_null: bool = False,
include_exact_match_level: bool = True,
term_frequency_adjustments: bool = False,
separate_1st_january: bool = False,
@@ -36,8 +39,6 @@ def __init__(
m_probability_or_probabilities_jw: float | list = None,
m_probability_or_probabilities_datediff: float | list = None,
m_probability_else: float = None,
cast_strings_to_date: bool = False,
date_format: str = None,
) -> Comparison:
"""A wrapper to generate a comparison for a date column the data in
`col_name` with preselected defaults.
@@ -51,7 +52,19 @@ def __init__(
- Anything else

Args:
col_name (str): The name of the column to compare
col_name (str): The name of the column to compare.
cast_strings_to_date (bool, optional): Set to True to
enable date-casting when input dates are strings. Also adjust
date_format if date-strings are not in (yyyy-mm-dd) format.
Defaults to False.
date_format (str, optional): Format of input dates if date-strings
are given. Must be consistent across record pairs. If None
(the default), downstream functions for each backend assign
date_format to ISO 8601 format (yyyy-mm-dd).
Set to "yyyy-MM-dd" for Spark and "%Y-%m-%d" for DuckDB
when invalid_dates_as_null=True.
invalid_dates_as_null (bool, optional): Assign any dates that do not adhere
to date_format to the null level. Defaults to False.
include_exact_match_level (bool, optional): If True, include an exact match
level. Defaults to True.
term_frequency_adjustments (bool, optional): If True, apply term frequency
@@ -95,14 +108,6 @@ def __init__(
for the datediff thresholds specified. Defaults to None.
m_probability_else (_type_, optional): If provided, overrides the
default m probability for the 'anything else' level. Defaults to None.
cast_strings_to_date (bool, optional): Set to True to
enable date-casting when input dates are strings. Also adjust
date_format if date-strings are not in (yyyy-mm-dd) format.
Defaults to False.
date_format(str, optional): Format of input dates if date-strings
are given. Must be consistent across record pairs. If None
(the default), downstream functions for each backend assign
date_format to ISO 8601 format (yyyy-mm-dd).

Examples:
=== "DuckDB"
@@ -120,6 +125,15 @@ def __init__(
datediff_thresholds=[1, 1],
datediff_metrics=["month", "year"])
```
Date Comparison casting the column to a date and assigning values that do
not match the date_format to the null level
``` python
import splink.duckdb.duckdb_comparison_template_library as ctl
ctl.date_comparison("date_of_birth",
cast_strings_to_date=True,
date_format='%d/%m/%Y',
invalid_dates_as_null=True)
```
=== "Spark"
Basic Date Comparison
``` python
@@ -135,14 +149,25 @@ def __init__(
datediff_thresholds=[1, 1],
datediff_metrics=["month", "year"])
```

Date Comparison casting the column to a date and assigning values that do
not match the date_format to the null level
``` python
import splink.spark.spark_comparison_template_library as ctl
ctl.date_comparison("date_of_birth",
cast_strings_to_date=True,
date_format='dd/mm/yyyy',
invalid_dates_as_null=True)
```
Returns:
Comparison: A comparison that can be included in the Splink settings
dictionary.
"""
# Construct Comparison
comparison_levels = []
comparison_levels.append(self._null_level(col_name))
if invalid_dates_as_null:
comparison_levels.append(self._null_level(col_name, date_format))
else:
comparison_levels.append(self._null_level(col_name))

# Validate user inputs
datediff_error_logger(thresholds=datediff_thresholds, metrics=datediff_metrics)
@@ -172,9 +197,9 @@ def __init__(
threshold_comparison_levels = distance_threshold_comparison_levels(
self,
col_name,
"levenshtein",
levenshtein_thresholds,
m_probability_or_probabilities_lev,
distance_function_name="levenshtein",
distance_threshold_or_thresholds=levenshtein_thresholds,
m_probability_or_probabilities_thres=m_probability_or_probabilities_lev,
)
comparison_levels = comparison_levels + threshold_comparison_levels

@@ -183,9 +208,9 @@ def __init__(
threshold_comparison_levels = distance_threshold_comparison_levels(
self,
col_name,
"jaro",
jaro_thresholds,
m_probability_or_probabilities_jar,
distance_function_name="jaro",
distance_threshold_or_thresholds=jaro_thresholds,
m_probability_or_probabilities_thres=m_probability_or_probabilities_jar,
)
comparison_levels = comparison_levels + threshold_comparison_levels

@@ -194,9 +219,9 @@ def __init__(
threshold_comparison_levels = distance_threshold_comparison_levels(
self,
col_name,
"jaro-winkler",
jaro_winkler_thresholds,
m_probability_or_probabilities_jw,
distance_function_name="jaro-winkler",
distance_threshold_or_thresholds=jaro_winkler_thresholds,
m_probability_or_probabilities_thres=m_probability_or_probabilities_jw,
)
comparison_levels = comparison_levels + threshold_comparison_levels

@@ -292,6 +317,7 @@ class NameComparisonBase(Comparison):
def __init__(
self,
col_name: str,
regex_extract: str = None,
include_exact_match_level: bool = True,
phonetic_col_name: str = None,
term_frequency_adjustments_name: bool = False,
@@ -318,7 +344,8 @@ def __init__(
- Anything else

Args:
col_name (str): The name of the column to compare
col_name (str): The name of the column to compare.
regex_extract (str): Regular expression pattern to evaluate a match on.
include_exact_match_level (bool, optional): If True, include an exact match
level for col_name. Defaults to True.
phonetic_col_name (str): The name of the column with phonetic reduction
@@ -419,6 +446,7 @@ def __init__(
term_frequency_adjustments=term_frequency_adjustments_name,
m_probability=m_probability_exact_match_name,
include_colname_in_charts_label=True,
regex_extract=regex_extract,
)
comparison_levels.append(comparison_level)

@@ -428,6 +456,7 @@ def __init__(
term_frequency_adjustments=term_frequency_adjustments_phonetic_name,
m_probability=m_probability_exact_match_phonetic_name,
include_colname_in_charts_label=True,
regex_extract=regex_extract,
)
comparison_levels.append(comparison_level)

@@ -436,9 +465,10 @@ def __init__(
threshold_comparison_levels = distance_threshold_comparison_levels(
self,
col_name,
"levenshtein",
levenshtein_thresholds,
m_probability_or_probabilities_lev,
distance_function_name="levenshtein",
distance_threshold_or_thresholds=levenshtein_thresholds,
regex_extract=regex_extract,
m_probability_or_probabilities_thres=m_probability_or_probabilities_lev,
)
comparison_levels = comparison_levels + threshold_comparison_levels

@@ -447,9 +477,10 @@ def __init__(
threshold_comparison_levels = distance_threshold_comparison_levels(
self,
col_name,
"jaro",
jaro_thresholds,
m_probability_or_probabilities_jar,
distance_function_name="jaro",
distance_threshold_or_thresholds=jaro_thresholds,
regex_extract=regex_extract,
m_probability_or_probabilities_thres=m_probability_or_probabilities_jar,
)
comparison_levels = comparison_levels + threshold_comparison_levels

@@ -458,9 +489,10 @@ def __init__(
threshold_comparison_levels = distance_threshold_comparison_levels(
self,
col_name,
"jaro-winkler",
jaro_winkler_thresholds,
m_probability_or_probabilities_jw,
distance_function_name="jaro-winkler",
distance_threshold_or_thresholds=jaro_winkler_thresholds,
regex_extract=regex_extract,
m_probability_or_probabilities_thres=m_probability_or_probabilities_jw,
)
comparison_levels = comparison_levels + threshold_comparison_levels

@@ -469,9 +501,10 @@ def __init__(
threshold_comparison_levels = distance_threshold_comparison_levels(
self,
col_name,
"jaccard",
jaccard_thresholds,
m_probability_or_probabilities_jar,
distance_function_name="jaccard",
distance_threshold_or_thresholds=jaccard_thresholds,
regex_extract=regex_extract,
m_probability_or_probabilities_thres=m_probability_or_probabilities_jar,
)
comparison_levels = comparison_levels + threshold_comparison_levels

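
Taken together, a hedged end-to-end usage sketch of the template-library options added above (module path, formats, and patterns follow the docstring examples in this diff; the settings keys are the standard Splink ones):

```python
# Sketch based on the docstring examples above; patterns and formats are
# illustrative, not prescriptive.
import splink.duckdb.duckdb_comparison_template_library as ctl

# Cast date strings and send anything not matching the format to the null level
dob_comparison = ctl.date_comparison(
    "date_of_birth",
    cast_strings_to_date=True,
    date_format="%d/%m/%Y",
    invalid_dates_as_null=True,
)

# Name comparison evaluated on a regex extract of the column
name_comparison = ctl.name_comparison("first_name", regex_extract="^[A-Z]+")

settings = {
    "link_type": "dedupe_only",
    "comparisons": [dob_comparison, name_comparison],
}
```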
6 changes: 6 additions & 0 deletions splink/dialect_base.py
@@ -26,6 +26,12 @@ def _datediff_function(self):
f"comparisons/comparison levels?"
)

@property
def _regex_extract_function(self):
raise NotImplementedError(
"Regex extract option not defined for " "the SQL backend being used. "
)

@property
def _levenshtein_name(self):
return "levenshtein"
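
The base property above raising `NotImplementedError` is the extension point each backend overrides; a minimal sketch of how a hypothetical new dialect would opt in (the class, function, and SQL below are illustrative, not part of Splink):

```python
# Hypothetical backend sketch illustrating the dispatch pattern used by
# AthenaBase / DuckDBBase / SparkBase in this PR.
from splink.dialect_base import DialectBase


def postgres_regex_extract_sql(col_name, regex):
    # substring(... from ...) is Postgres's POSIX-regex extraction; chosen
    # here purely to illustrate a backend-specific fragment.
    return f"substring({col_name} from '{regex}')"


class HypotheticalPostgresBase(DialectBase):
    @property
    def _sql_dialect(self):
        return "postgres"

    @property
    def _regex_extract_function(self):
        return postgres_regex_extract_sql
```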
10 changes: 10 additions & 0 deletions splink/duckdb/duckdb_base.py
@@ -35,6 +35,12 @@ def datediff_sql(
"""


def regex_extract_sql(col_name, regex):
return f"""
regexp_extract({col_name}, '{regex}')
"""


class DuckDBBase(DialectBase):
@property
def _sql_dialect(self):
@@ -48,6 +54,10 @@ def _size_array_intersect_function(self):
def _datediff_function(self):
return datediff_sql

@property
def _regex_extract_function(self):
return regex_extract_sql

@property
def _jaro_name(self):
return "jaro_similarity"
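
As a quick check of the fragment DuckDB receives (assumes the `duckdb` Python package is installed; the postcode value and pattern are illustrative):

```python
import duckdb

# regexp_extract(string, pattern) returns the first substring matching the
# pattern, or the empty string if nothing matches.
con = duckdb.connect()
row = con.execute(
    "SELECT regexp_extract('SW1A 1AA', '^[A-Z]{1,2}[0-9][A-Z0-9]?')"
).fetchone()
print(row[0])  # 'SW1A'
```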
18 changes: 18 additions & 0 deletions splink/spark/spark_base.py
@@ -44,6 +44,20 @@ def datediff_sql(
"""


def regex_extract_sql(col_name, regex):
if "\\" in regex:
raise SyntaxError(
"Regular expressions containing “\\” (the python escape character) "
"are not compatible with Splink’s Spark linker. "
"Please consider using alternative syntax, "
"for example replacing “\\d” with “[0-9]”."
)
else:
return f"""
regexp_extract({col_name}, '{regex}', 0)
"""


class SparkBase(DialectBase):
@property
def _sql_dialect(self):
@@ -57,6 +71,10 @@ def _datediff_function(self):
def _size_array_intersect_function(self):
return size_array_intersect_sql

@property
def _regex_extract_function(self):
return regex_extract_sql

@property
def _jaro_name(self):
return "jaro_sim"
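
A standalone sketch of the escape-character guard above, reproduced here (rather than imported) so it runs without Spark installed; the error message is paraphrased:

```python
# Copy of the guard logic from spark_base.py above, for illustration only.
def spark_regex_extract_sql(col_name: str, regex: str) -> str:
    if "\\" in regex:
        raise SyntaxError(
            "Regular expressions containing '\\' are not compatible with "
            "Splink's Spark linker; replace e.g. '\\d' with '[0-9]'."
        )
    return f"regexp_extract({col_name}, '{regex}', 0)"


print(spark_regex_extract_sql("postcode_l", "^[A-Z]{1,2}[0-9]"))  # accepted
try:
    spark_regex_extract_sql("postcode_l", "^\\d{3}")  # backslash -> rejected
except SyntaxError as exc:
    print(exc)
```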