Skip to content

Commit

Permalink
Merge pull request #2454 from moj-analytical-services/clustering_allo…
Browse files Browse the repository at this point in the history
…ws_mw

Clustering allows match weight args not just match probability
  • Loading branch information
RobinL authored Nov 26, 2024
2 parents 5b85deb + 3ceddde commit 2aa78da
Show file tree
Hide file tree
Showing 7 changed files with 282 additions and 33 deletions.
85 changes: 71 additions & 14 deletions splink/internals/clustering.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
from __future__ import annotations

import logging
import math
from typing import Optional

from splink.internals.connected_components import solve_connected_components
from splink.internals.database_api import AcceptableInputTableType, DatabaseAPISubClass
from splink.internals.input_column import InputColumn
from splink.internals.misc import ascii_uid
from splink.internals.misc import (
ascii_uid,
prob_to_match_weight,
threshold_args_to_match_prob,
threshold_args_to_match_prob_list,
)
from splink.internals.pipeline import CTEPipeline
from splink.internals.splink_dataframe import SplinkDataFrame

Expand Down Expand Up @@ -42,14 +48,15 @@ def cluster_pairwise_predictions_at_threshold(
edge_id_column_name_left: Optional[str] = None,
edge_id_column_name_right: Optional[str] = None,
threshold_match_probability: Optional[float] = None,
threshold_match_weight: Optional[float] = None,
) -> SplinkDataFrame:
"""Clusters the pairwise match predictions into groups of connected records using
the connected components graph clustering algorithm.
Records with an estimated match probability at or above threshold_match_probability
are considered to be a match (i.e. they represent the same entity).
If no match probability column is provided, it is assumed that all edges
If no match probability or match weight is provided, it is assumed that all edges
(comparisons) are a match.
If your node and edge column names follow Splink naming conventions, then you can
Expand All @@ -68,6 +75,8 @@ def cluster_pairwise_predictions_at_threshold(
right edge IDs. If not provided, assumed to be f"{node_id_column_name}_r"
threshold_match_probability (Optional[float]): Pairwise comparisons with a
match_probability at or above this threshold are matched
threshold_match_weight (Optional[float]): Pairwise comparisons with a
match_weight at or above this threshold are matched
Returns:
SplinkDataFrame: A SplinkDataFrame containing a list of all IDs, clustered
Expand Down Expand Up @@ -128,6 +137,10 @@ def cluster_pairwise_predictions_at_threshold(
edge_id_column_name_right,
)

threshold_match_probability = threshold_args_to_match_prob(
threshold_match_probability, threshold_match_weight
)

cc = solve_connected_components(
nodes_table=nodes_sdf,
edges_table=edges_sdf,
Expand Down Expand Up @@ -223,23 +236,38 @@ def _calculate_stable_clusters_at_new_threshold(
return sqls


def _threshold_to_str(x):
if x == 0.0:
return "0_0"
elif x == 1.0:
return "1_0"
def _threshold_to_str(match_probability: float, is_match_weight: bool = False) -> str:
if is_match_weight:
if match_probability == 0.0:
return "mw_minus_inf"
elif match_probability == 1.0:
return "mw_inf"
else:
weight = prob_to_match_weight(match_probability)
formatted = f"{weight:.6f}".rstrip("0")
if formatted.endswith("."):
formatted = formatted[:-1]
return f"mw_{formatted.replace('.', '_')}"
else:
return f"{x:.8f}".rstrip("0").replace(".", "_")
if match_probability == 0.0:
return "0_0"
elif match_probability == 1.0:
return "1_0"
formatted = f"{match_probability:.6f}".rstrip("0")
if formatted.endswith("."):
formatted = formatted[:-1]
return f"p_{formatted.replace('.', '_')}"


def _generate_detailed_cluster_comparison_sql(
all_results: dict[float, SplinkDataFrame],
unique_id_col: str = "unique_id",
is_match_weight: bool = False,
) -> str:
thresholds = sorted(all_results.keys())

select_columns = [f"t0.{unique_id_col}"] + [
f"t{i}.cluster_id AS cluster_{_threshold_to_str(threshold)}"
f"t{i}.cluster_id AS cluster_{_threshold_to_str(threshold, is_match_weight)}"
for i, threshold in enumerate(thresholds)
]

Expand Down Expand Up @@ -286,14 +314,25 @@ def _get_cluster_stats_sql(cc: SplinkDataFrame) -> list[dict[str, str]]:
return sqls


def _threshold_to_weight_for_table(p):
if p == 0 or p == 1:
return "NULL"
else:
return str(math.log2(p / (1 - p)))


def _generate_cluster_summary_stats_sql(
all_results: dict[float, SplinkDataFrame],
) -> str:
thresholds = sorted(all_results.keys())

select_statements = [
f"""
SELECT cast({threshold} as float) as threshold, *
SELECT
cast({threshold} as float) as threshold_match_probability,
cast({_threshold_to_weight_for_table(threshold)} as float)
as threshold_match_weight,
*
FROM {all_results[threshold].physical_name}
"""
for threshold in thresholds
Expand All @@ -309,7 +348,8 @@ def cluster_pairwise_predictions_at_multiple_thresholds(
edges: AcceptableInputTableType,
db_api: DatabaseAPISubClass,
node_id_column_name: str,
match_probability_thresholds: list[float],
match_probability_thresholds: list[float] | None = None,
match_weight_thresholds: list[float] | None = None,
edge_id_column_name_left: Optional[str] = None,
edge_id_column_name_right: Optional[str] = None,
output_cluster_summary_stats: bool = False,
Expand All @@ -330,8 +370,10 @@ def cluster_pairwise_predictions_at_multiple_thresholds(
edges (AcceptableInputTableType): The table containing edge information
db_api (DatabaseAPISubClass): The database API to use for querying
node_id_column_name (str): The name of the column containing node IDs
match_probability_thresholds (list[float]): List of thresholds to
compute clusters for
match_probability_thresholds (list[float] | None): List of match probability
thresholds to compute clusters for
match_weight_thresholds (list[float] | None): List of match weight thresholds
to compute clusters for
edge_id_column_name_left (Optional[str]): The name of the column containing
left edge IDs. If not provided, assumed to be f"{node_id_column_name}_l"
edge_id_column_name_right (Optional[str]): The name of the column containing
Expand Down Expand Up @@ -409,7 +451,19 @@ def cluster_pairwise_predictions_at_multiple_thresholds(
else:
edges_sdf = edges

match_probability_thresholds = sorted(match_probability_thresholds)
is_match_weight = (
match_weight_thresholds is not None and match_probability_thresholds is None
)

match_probability_thresholds = threshold_args_to_match_prob_list(
match_probability_thresholds, match_weight_thresholds
)

if match_probability_thresholds is None or len(match_probability_thresholds) == 0:
raise ValueError(
"Must provide either match_probability_thresholds "
"or match_weight_thresholds"
)

initial_threshold = match_probability_thresholds.pop(0)
all_results = {}
Expand Down Expand Up @@ -489,6 +543,8 @@ def cluster_pairwise_predictions_at_multiple_thresholds(
nodes_in_play,
edges_in_play,
node_id_column_name=node_id_column_name,
edge_id_column_name_left=edge_id_column_name_left,
edge_id_column_name_right=edge_id_column_name_right,
db_api=db_api,
threshold_match_probability=new_threshold,
)
Expand Down Expand Up @@ -530,6 +586,7 @@ def cluster_pairwise_predictions_at_multiple_thresholds(
sql = _generate_detailed_cluster_comparison_sql(
all_results,
unique_id_col=node_id_column_name,
is_match_weight=is_match_weight,
)

pipeline = CTEPipeline()
Expand Down
16 changes: 14 additions & 2 deletions splink/internals/linker_components/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
_node_degree_sql,
_size_density_centralisation_sql,
)
from splink.internals.misc import (
threshold_args_to_match_prob,
)
from splink.internals.pipeline import CTEPipeline
from splink.internals.splink_dataframe import SplinkDataFrame
from splink.internals.unique_id_concat import (
Expand Down Expand Up @@ -38,19 +41,24 @@ def cluster_pairwise_predictions_at_threshold(
self,
df_predict: SplinkDataFrame,
threshold_match_probability: Optional[float] = None,
threshold_match_weight: Optional[float] = None,
) -> SplinkDataFrame:
"""Clusters the pairwise match predictions that result from
`linker.inference.predict()` into groups of connected record using the connected
components graph clustering algorithm
Records with an estimated `match_probability` at or above
`threshold_match_probability` are considered to be a match (i.e. they represent
`threshold_match_probability` (or records with a `match_weight` at or above
`threshold_match_weight`) are considered to be a match (i.e. they represent
the same entity).
Args:
df_predict (SplinkDataFrame): The results of `linker.predict()`
threshold_match_probability (float): Pairwise comparisons with a
threshold_match_probability (float, optional): Pairwise comparisons with a
`match_probability` at or above this threshold are matched
threshold_match_weight (float, optional): Pairwise comparisons with a
`match_weight` at or above this threshold are matched. Only one of
threshold_match_probability or threshold_match_weight should be provided
Returns:
SplinkDataFrame: A SplinkDataFrame containing a list of all IDs, clustered
Expand Down Expand Up @@ -92,6 +100,10 @@ def cluster_pairwise_predictions_at_threshold(
c.unquote().name for c in df_predict.columns
]

threshold_match_probability = threshold_args_to_match_prob(
threshold_match_probability, threshold_match_weight
)

if not has_match_prob_col and threshold_match_probability is not None:
raise ValueError(
"df_predict must have a column called 'match_probability' if "
Expand Down
64 changes: 64 additions & 0 deletions splink/internals/misc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import json
import pkgutil
import random
Expand Down Expand Up @@ -183,3 +185,65 @@ def read_resource(path: str) -> str:
if (resource_data := pkgutil.get_data("splink", path)) is None:
raise FileNotFoundError(f"Could not locate splink resource at: {path}")
return resource_data.decode("utf-8")


def threshold_args_to_match_weight(
    threshold_match_probability: float | None, threshold_match_weight: float | None
) -> float | None:
    """Normalise a probability/weight threshold pair to a single match weight.

    At most one of the two thresholds may be given; supplying both raises
    ValueError. A probability of 0 means "no threshold" and maps to None,
    as does providing neither argument.
    """
    if threshold_match_probability is not None and threshold_match_weight is not None:
        raise ValueError(
            "Cannot provide both threshold_match_probability and "
            "threshold_match_weight. Please specify only one."
        )

    # A weight threshold passes through unchanged
    if threshold_match_weight is not None:
        return threshold_match_weight

    # No threshold at all, or a probability of 0 (matches everything)
    if threshold_match_probability is None or threshold_match_probability == 0:
        return None

    return prob_to_match_weight(threshold_match_probability)


def threshold_args_to_match_prob(
    threshold_match_probability: float | None, threshold_match_weight: float | None
) -> float | None:
    """Normalise a probability/weight threshold pair to a match probability.

    At most one of the two thresholds may be given; supplying both raises
    ValueError. Returns None when neither is provided.
    """
    if threshold_match_probability is not None and threshold_match_weight is not None:
        raise ValueError(
            "Cannot provide both threshold_match_probability and "
            "threshold_match_weight. Please specify only one."
        )

    # Convert a weight threshold to the equivalent probability
    if threshold_match_weight is not None:
        return bayes_factor_to_prob(
            match_weight_to_bayes_factor(threshold_match_weight)
        )

    # Either the probability itself, or None if neither argument was given
    return threshold_match_probability


def threshold_args_to_match_prob_list(
    match_probability_thresholds: list[float] | None,
    match_weight_thresholds: list[float] | None,
) -> list[float] | None:
    """Normalise threshold lists to a sorted list of match probabilities.

    At most one of the two lists may be given; supplying both raises
    ValueError. Weight thresholds are converted to probabilities before
    sorting. Returns None when neither list is provided.
    """
    if match_probability_thresholds is not None and match_weight_thresholds is not None:
        raise ValueError(
            "Cannot provide both match_probability_thresholds and "
            "match_weight_thresholds. Please specify only one."
        )

    if match_weight_thresholds is not None:
        as_probs = [
            bayes_factor_to_prob(match_weight_to_bayes_factor(weight))
            for weight in match_weight_thresholds
        ]
        return sorted(as_probs)

    if match_probability_thresholds is not None:
        return sorted(match_probability_thresholds)

    return None
27 changes: 11 additions & 16 deletions splink/internals/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@

from splink.internals.comparison import Comparison
from splink.internals.input_column import InputColumn
from splink.internals.misc import prob_to_bayes_factor, prob_to_match_weight
from splink.internals.misc import (
prob_to_bayes_factor,
threshold_args_to_match_weight,
)

from .settings import CoreModelSettings, Settings

Expand Down Expand Up @@ -97,21 +100,13 @@ def predict_from_comparison_vectors_sqls(
bf_terms,
sql_infinity_expression,
)
# Add condition to treat case of 0 as None
if threshold_match_probability == 0:
threshold_match_probability = None
# In case user provided both, take the minimum of the two thresholds
if threshold_match_probability is not None:
thres_prob_as_weight = prob_to_match_weight(threshold_match_probability)
else:
thres_prob_as_weight = None
if threshold_match_probability is not None or threshold_match_weight is not None:
thresholds = [
thres_prob_as_weight,
threshold_match_weight,
]
threshold = max([t for t in thresholds if t is not None])
threshold_expr = f" where log2({bayes_factor_expr}) >= {threshold} "

threshold_as_mw = threshold_args_to_match_weight(
threshold_match_probability, threshold_match_weight
)

if threshold_as_mw is not None:
threshold_expr = f" where log2({bayes_factor_expr}) >= {threshold_as_mw} "
else:
threshold_expr = ""

Expand Down
3 changes: 3 additions & 0 deletions splink/internals/spark/database_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ def _repartition_if_needed(self, spark_df, templated_name):
r"__splink__clusters_at_threshold",
r"__splink__clusters_at_all_thresholds",
r"__splink__stable_nodes_at_new_threshold",
r"__splink__clustering_output_final",
]

num_partitions = self.num_partitions_on_repartition
Expand Down Expand Up @@ -263,6 +264,8 @@ def _repartition_if_needed(self, spark_df, templated_name):
num_partitions = math.ceil(num_partitions / 10)
elif templated_name == "__splink__stable_nodes_at_new_threshold":
num_partitions = math.ceil(num_partitions / 10)
elif templated_name == "__splink__clustering_output_final":
num_partitions = math.ceil(num_partitions / 10)

if re.fullmatch(r"|".join(names_to_repartition), templated_name):
spark_df = spark_df.repartition(num_partitions)
Expand Down
Loading

0 comments on commit 2aa78da

Please sign in to comment.