Clustering allows match weight args not just match probability #2454

Merged · 18 commits · Nov 26, 2024
85 changes: 71 additions & 14 deletions splink/internals/clustering.py
@@ -1,12 +1,18 @@
 from __future__ import annotations

 import logging
+import math
 from typing import Optional

 from splink.internals.connected_components import solve_connected_components
 from splink.internals.database_api import AcceptableInputTableType, DatabaseAPISubClass
 from splink.internals.input_column import InputColumn
-from splink.internals.misc import ascii_uid
+from splink.internals.misc import (
+    ascii_uid,
+    prob_to_match_weight,
+    threshold_args_to_match_prob,
+    threshold_args_to_match_prob_list,
+)
 from splink.internals.pipeline import CTEPipeline
 from splink.internals.splink_dataframe import SplinkDataFrame

@@ -42,14 +48,15 @@ def cluster_pairwise_predictions_at_threshold(
     edge_id_column_name_left: Optional[str] = None,
     edge_id_column_name_right: Optional[str] = None,
     threshold_match_probability: Optional[float] = None,
+    threshold_match_weight: Optional[float] = None,
 ) -> SplinkDataFrame:
     """Clusters the pairwise match predictions into groups of connected records using
     the connected components graph clustering algorithm.

     Records with an estimated match probability at or above threshold_match_probability
     are considered to be a match (i.e. they represent the same entity).

-    If no match probability column is provided, it is assumed that all edges
+    If no match probability or match weight is provided, it is assumed that all edges
     (comparison) are a match.

     If your node and edge column names follow Splink naming conventions, then you can
@@ -68,6 +75,8 @@ def cluster_pairwise_predictions_at_threshold(
             right edge IDs. If not provided, assumed to be f"{node_id_column_name}_r"
         threshold_match_probability (Optional[float]): Pairwise comparisons with a
             match_probability at or above this threshold are matched
+        threshold_match_weight (Optional[float]): Pairwise comparisons with a
+            match_weight at or above this threshold are matched

     Returns:
         SplinkDataFrame: A SplinkDataFrame containing a list of all IDs, clustered
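
For context, a minimal usage sketch of the new argument (not part of this diff). It assumes the function is re-exported as splink.clustering.cluster_pairwise_predictions_at_threshold, as in the Splink 4 docs, and the nodes/edges frames are illustrative:

import pandas as pd

from splink import DuckDBAPI
from splink.clustering import cluster_pairwise_predictions_at_threshold

# Illustrative inputs: one row per record, plus pairwise predictions
nodes = pd.DataFrame({"unique_id": [1, 2, 3, 4]})
edges = pd.DataFrame(
    {
        "unique_id_l": [1, 3],
        "unique_id_r": [2, 4],
        "match_probability": [0.9, 0.5],
    }
)

clusters = cluster_pairwise_predictions_at_threshold(
    nodes,
    edges,
    db_api=DuckDBAPI(),
    node_id_column_name="unique_id",
    threshold_match_weight=2.0,  # log2 odds; same cut-off as match_probability=0.8
)
# Edge 1-2 (p=0.9 >= 0.8) survives, edge 3-4 (p=0.5) does not,
# so the expected clusters are {1, 2}, {3} and {4}.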
@@ -128,6 +137,10 @@ def cluster_pairwise_predictions_at_threshold(
         edge_id_column_name_right,
     )

+    threshold_match_probability = threshold_args_to_match_prob(
+        threshold_match_probability, threshold_match_weight
+    )
+
     cc = solve_connected_components(
         nodes_table=nodes_sdf,
         edges_table=edges_sdf,
@@ -223,23 +236,38 @@ def _calculate_stable_clusters_at_new_threshold(
     return sqls


-def _threshold_to_str(x):
-    if x == 0.0:
-        return "0_0"
-    elif x == 1.0:
-        return "1_0"
+def _threshold_to_str(match_probability: float, is_match_weight: bool = False) -> str:
+    if is_match_weight:
+        if match_probability == 0.0:
+            return "mw_minus_inf"
+        elif match_probability == 1.0:
+            return "mw_inf"
+        else:
+            weight = prob_to_match_weight(match_probability)
+            formatted = f"{weight:.6f}".rstrip("0")
+            if formatted.endswith("."):
+                formatted = formatted[:-1]
+            return f"mw_{formatted.replace('.', '_')}"
     else:
-        return f"{x:.8f}".rstrip("0").replace(".", "_")
+        if match_probability == 0.0:
+            return "0_0"
+        elif match_probability == 1.0:
+            return "1_0"
+        formatted = f"{match_probability:.6f}".rstrip("0")
+        if formatted.endswith("."):
+            formatted = formatted[:-1]
+        return f"p_{formatted.replace('.', '_')}"


 def _generate_detailed_cluster_comparison_sql(
     all_results: dict[float, SplinkDataFrame],
     unique_id_col: str = "unique_id",
+    is_match_weight: bool = False,
 ) -> str:
     thresholds = sorted(all_results.keys())

     select_columns = [f"t0.{unique_id_col}"] + [
-        f"t{i}.cluster_id AS cluster_{_threshold_to_str(threshold)}"
+        f"t{i}.cluster_id AS cluster_{_threshold_to_str(threshold, is_match_weight)}"
         for i, threshold in enumerate(thresholds)
     ]

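To illustrate the naming scheme above: a standalone re-implementation (a sketch; threshold_to_suffix is a hypothetical name, and it assumes prob_to_match_weight is the log2 odds log2(p / (1 - p))). Weight-mode thresholds arrive as probabilities and are converted back, the endpoints map to mw_minus_inf/mw_inf, and the bare 0_0/1_0 endpoints are kept from the old probability naming:

import math

def threshold_to_suffix(p: float, is_match_weight: bool = False) -> str:
    if is_match_weight:
        if p == 0.0:
            return "mw_minus_inf"  # weight would be -infinity
        if p == 1.0:
            return "mw_inf"  # weight would be +infinity
        weight = math.log2(p / (1 - p))  # prob_to_match_weight
        formatted = f"{weight:.6f}".rstrip("0").rstrip(".")
        return f"mw_{formatted.replace('.', '_')}"
    if p == 0.0:
        return "0_0"
    if p == 1.0:
        return "1_0"
    formatted = f"{p:.6f}".rstrip("0").rstrip(".")
    return f"p_{formatted.replace('.', '_')}"

assert threshold_to_suffix(0.9) == "p_0_9"  # column name: cluster_p_0_9
assert threshold_to_suffix(0.5, is_match_weight=True) == "mw_0"  # log2(1) = 0
assert threshold_to_suffix(0.8, is_match_weight=True) == "mw_2"  # log2(4) = 2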
@@ -286,14 +314,25 @@ def _get_cluster_stats_sql(cc: SplinkDataFrame) -> list[dict[str, str]]:
     return sqls


+def _threshold_to_weight_for_table(p):
+    if p == 0 or p == 1:
+        return "NULL"
+    else:
+        return str(math.log2(p / (1 - p)))
+
+
 def _generate_cluster_summary_stats_sql(
     all_results: dict[float, SplinkDataFrame],
 ) -> str:
     thresholds = sorted(all_results.keys())

     select_statements = [
         f"""
-        SELECT cast({threshold} as float) as threshold, *
+        SELECT
+            cast({threshold} as float) as threshold_match_probability,
+            cast({_threshold_to_weight_for_table(threshold)} as float)
+                as threshold_match_weight,
+            *
         FROM {all_results[threshold].physical_name}
         """
         for threshold in thresholds
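
The new helper writes a match weight alongside each probability threshold using the standard log2 odds transform; at p = 0 or p = 1 the weight is +/- infinity, hence the NULL. A quick sanity check of the values it emits (a sketch, not part of the diff):

import math

def threshold_to_weight_for_table(p: float) -> str:
    if p == 0 or p == 1:
        return "NULL"  # weight is infinite, not castable to a SQL float
    return str(math.log2(p / (1 - p)))

print(threshold_to_weight_for_table(0.5))   # 0.0
print(threshold_to_weight_for_table(0.8))   # 2.0
print(threshold_to_weight_for_table(0.99))  # ~6.6294
print(threshold_to_weight_for_table(1.0))   # NULL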
@@ -309,7 +348,8 @@ def cluster_pairwise_predictions_at_multiple_thresholds(
     edges: AcceptableInputTableType,
     db_api: DatabaseAPISubClass,
     node_id_column_name: str,
-    match_probability_thresholds: list[float],
+    match_probability_thresholds: list[float] | None = None,
+    match_weight_thresholds: list[float] | None = None,
     edge_id_column_name_left: Optional[str] = None,
     edge_id_column_name_right: Optional[str] = None,
     output_cluster_summary_stats: bool = False,
@@ -330,8 +370,10 @@ def cluster_pairwise_predictions_at_multiple_thresholds(
         edges (AcceptableInputTableType): The table containing edge information
         db_api (DatabaseAPISubClass): The database API to use for querying
         node_id_column_name (str): The name of the column containing node IDs
-        match_probability_thresholds (list[float]): List of thresholds to
-            compute clusters for
+        match_probability_thresholds (list[float] | None): List of match probability
+            thresholds to compute clusters for
+        match_weight_thresholds (list[float] | None): List of match weight thresholds
+            to compute clusters for
         edge_id_column_name_left (Optional[str]): The name of the column containing
             left edge IDs. If not provided, assumed to be f"{node_id_column_name}_l"
         edge_id_column_name_right (Optional[str]): The name of the column containing
@@ -409,7 +451,19 @@ def cluster_pairwise_predictions_at_multiple_thresholds(
     else:
         edges_sdf = edges

-    match_probability_thresholds = sorted(match_probability_thresholds)
+    is_match_weight = (
+        match_weight_thresholds is not None and match_probability_thresholds is None
+    )
+
+    match_probability_thresholds = threshold_args_to_match_prob_list(
+        match_probability_thresholds, match_weight_thresholds
+    )
+
+    if match_probability_thresholds is None or len(match_probability_thresholds) == 0:
+        raise ValueError(
+            "Must provide either match_probability_thresholds "
+            "or match_weight_thresholds"
+        )

     initial_threshold = match_probability_thresholds.pop(0)
     all_results = {}
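
A hedged usage sketch of the multi-threshold entry point with the new argument, reusing the illustrative nodes/edges frames from the earlier sketch and assuming the function is exposed from splink.clustering like its single-threshold counterpart:

from splink import DuckDBAPI
from splink.clustering import cluster_pairwise_predictions_at_multiple_thresholds

clusters_at_thresholds = cluster_pairwise_predictions_at_multiple_thresholds(
    nodes,
    edges,
    db_api=DuckDBAPI(),
    node_id_column_name="unique_id",
    match_weight_thresholds=[0, 2, 4],  # converted to p = 0.5, 0.8, ~0.941
)
# The comparison output then gets columns cluster_mw_0, cluster_mw_2 and
# cluster_mw_4; passing both threshold lists, or neither, raises ValueError.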
@@ -489,6 +543,8 @@ def cluster_pairwise_predictions_at_multiple_thresholds(
             nodes_in_play,
             edges_in_play,
             node_id_column_name=node_id_column_name,
+            edge_id_column_name_left=edge_id_column_name_left,
+            edge_id_column_name_right=edge_id_column_name_right,
             db_api=db_api,
             threshold_match_probability=new_threshold,
         )
@@ -530,6 +586,7 @@ def cluster_pairwise_predictions_at_multiple_thresholds(
         sql = _generate_detailed_cluster_comparison_sql(
             all_results,
             unique_id_col=node_id_column_name,
+            is_match_weight=is_match_weight,
         )

         pipeline = CTEPipeline()
2 changes: 1 addition & 1 deletion splink/internals/connected_components.py
@@ -438,7 +438,7 @@ def solve_connected_components(
             SELECT representative FROM __splink__representatives_stable_{iteration}
         )
         """
-        pipeline.enqueue_sql(sql, "__splink__representatives_unstable")
+        pipeline.enqueue_sql(sql, f"__splink__representatives_unstable_{iteration}")
         prev_representatives_thinned = db_api.sql_pipeline_to_splink_dataframe(pipeline)

         # 1a. Thin neighbours table - we can drop all rows that refer to
16 changes: 14 additions & 2 deletions splink/internals/linker_components/clustering.py
@@ -11,6 +11,9 @@
     _node_degree_sql,
     _size_density_centralisation_sql,
 )
+from splink.internals.misc import (
+    threshold_args_to_match_prob,
+)
 from splink.internals.pipeline import CTEPipeline
 from splink.internals.splink_dataframe import SplinkDataFrame
 from splink.internals.unique_id_concat import (
@@ -38,19 +41,24 @@ def cluster_pairwise_predictions_at_threshold(
         self,
         df_predict: SplinkDataFrame,
         threshold_match_probability: Optional[float] = None,
+        threshold_match_weight: Optional[float] = None,
     ) -> SplinkDataFrame:
         """Clusters the pairwise match predictions that result from
         `linker.inference.predict()` into groups of connected record using the connected
         components graph clustering algorithm

         Records with an estimated `match_probability` at or above
-        `threshold_match_probability` are considered to be a match (i.e. they represent
+        `threshold_match_probability` (or records with a `match_weight` at or above
+        `threshold_match_weight`) are considered to be a match (i.e. they represent
         the same entity).

         Args:
             df_predict (SplinkDataFrame): The results of `linker.predict()`
-            threshold_match_probability (float): Pairwise comparisons with a
+            threshold_match_probability (float, optional): Pairwise comparisons with a
                 `match_probability` at or above this threshold are matched
+            threshold_match_weight (float, optional): Pairwise comparisons with a
+                `match_weight` at or above this threshold are matched. Only one of
+                threshold_match_probability or threshold_match_weight should be provided

         Returns:
             SplinkDataFrame: A SplinkDataFrame containing a list of all IDs, clustered
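
At the linker level the new argument might be used as follows (a sketch assuming a trained linker, set up as in the Splink documentation; only one of the two thresholds may be supplied):

df_predict = linker.inference.predict()

clusters = linker.clustering.cluster_pairwise_predictions_at_threshold(
    df_predict,
    threshold_match_weight=3.0,  # equivalent to threshold_match_probability = 8/9
)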
@@ -92,6 +100,10 @@ def cluster_pairwise_predictions_at_threshold(
             c.unquote().name for c in df_predict.columns
         ]

+        threshold_match_probability = threshold_args_to_match_prob(
+            threshold_match_probability, threshold_match_weight
+        )
+
         if not has_match_prob_col and threshold_match_probability is not None:
             raise ValueError(
                 "df_predict must have a column called 'match_probability' if "
64 changes: 64 additions & 0 deletions splink/internals/misc.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import json
 import pkgutil
 import random
@@ -183,3 +185,65 @@ def read_resource(path: str) -> str:
     if (resource_data := pkgutil.get_data("splink", path)) is None:
         raise FileNotFoundError(f"Could not locate splink resource at: {path}")
     return resource_data.decode("utf-8")
+
+
+def threshold_args_to_match_weight(
+    threshold_match_probability: float | None, threshold_match_weight: float | None
+) -> float | None:
+    if threshold_match_probability is not None and threshold_match_weight is not None:
+        raise ValueError(
+            "Cannot provide both threshold_match_probability and "
+            "threshold_match_weight. Please specify only one."
+        )
+
+    if threshold_match_probability is not None:
+        if threshold_match_probability == 0:
+            return None
+        return prob_to_match_weight(threshold_match_probability)
+
+    if threshold_match_weight is not None:
+        return threshold_match_weight
+
+    return None
+
+
+def threshold_args_to_match_prob(
+    threshold_match_probability: float | None, threshold_match_weight: float | None
+) -> float | None:
+    if threshold_match_probability is not None and threshold_match_weight is not None:
+        raise ValueError(
+            "Cannot provide both threshold_match_probability and "
+            "threshold_match_weight. Please specify only one."
+        )
+
+    if threshold_match_probability is not None:
+        return threshold_match_probability
+
+    if threshold_match_weight is not None:
+        return bayes_factor_to_prob(
+            match_weight_to_bayes_factor(threshold_match_weight)
+        )
+
+    return None
+
+
+def threshold_args_to_match_prob_list(
+    match_probability_thresholds: list[float] | None,
+    match_weight_thresholds: list[float] | None,
+) -> list[float] | None:
+    if match_probability_thresholds is not None and match_weight_thresholds is not None:
+        raise ValueError(
+            "Cannot provide both match_probability_thresholds and "
+            "match_weight_thresholds. Please specify only one."
+        )
+
+    if match_probability_thresholds is not None:
+        return sorted(match_probability_thresholds)
+
+    if match_weight_thresholds is not None:
+        return sorted(
+            bayes_factor_to_prob(match_weight_to_bayes_factor(w))
+            for w in match_weight_thresholds
+        )
+
+    return None
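
The two scales are related by weight = log2(p / (1 - p)), i.e. p = 2**w / (1 + 2**w), so a weight of 2 is a Bayes factor of 4 and a probability of 0.8. A quick sanity check of the new helpers (runnable with this branch installed):

import math

from splink.internals.misc import (
    threshold_args_to_match_prob,
    threshold_args_to_match_weight,
)

# weight 2 -> Bayes factor 2**2 = 4 -> probability 4 / (1 + 4) = 0.8
assert math.isclose(threshold_args_to_match_prob(None, 2), 0.8)
# probability 0.8 -> weight log2(0.8 / 0.2) = 2
assert math.isclose(threshold_args_to_match_weight(0.8, None), 2)
# p = 0 is treated as "no threshold" on the weight scale
assert threshold_args_to_match_weight(0, None) is None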
27 changes: 11 additions & 16 deletions splink/internals/predict.py
@@ -6,7 +6,10 @@

 from splink.internals.comparison import Comparison
 from splink.internals.input_column import InputColumn
-from splink.internals.misc import prob_to_bayes_factor, prob_to_match_weight
+from splink.internals.misc import (
+    prob_to_bayes_factor,
+    threshold_args_to_match_weight,
+)

 from .settings import CoreModelSettings, Settings

@@ -97,21 +100,13 @@ def predict_from_comparison_vectors_sqls(
         bf_terms,
         sql_infinity_expression,
     )
-    # Add condition to treat case of 0 as None
-    if threshold_match_probability == 0:
-        threshold_match_probability = None
-    # In case user provided both, take the minimum of the two thresholds
-    if threshold_match_probability is not None:
-        thres_prob_as_weight = prob_to_match_weight(threshold_match_probability)
-    else:
-        thres_prob_as_weight = None
-    if threshold_match_probability is not None or threshold_match_weight is not None:
-        thresholds = [
-            thres_prob_as_weight,
-            threshold_match_weight,
-        ]
-        threshold = max([t for t in thresholds if t is not None])
-        threshold_expr = f" where log2({bayes_factor_expr}) >= {threshold} "
+
+    threshold_as_mw = threshold_args_to_match_weight(
+        threshold_match_probability, threshold_match_weight
+    )
+
+    if threshold_as_mw is not None:
+        threshold_expr = f" where log2({bayes_factor_expr}) >= {threshold_as_mw} "
     else:
         threshold_expr = ""

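This replaces the old inline logic: previously, supplying both thresholds silently applied the stricter of the two (despite the removed comment saying "minimum", the code took the max), whereas the shared helper now rejects that combination outright. The resulting behavior, per the helper added in misc.py:

from splink.internals.misc import threshold_args_to_match_weight

threshold_args_to_match_weight(0.8, None)  # ~2.0: probability converted to weight
threshold_args_to_match_weight(None, 3.0)  # 3.0: weight passed through unchanged
threshold_args_to_match_weight(0, None)    # None: p = 0 means no WHERE clause
threshold_args_to_match_weight(0.8, 3.0)   # raises ValueError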
3 changes: 3 additions & 0 deletions splink/internals/spark/database_api.py
@@ -235,6 +235,7 @@ def _repartition_if_needed(self, spark_df, templated_name):
         r"__splink__clusters_at_threshold",
         r"__splink__clusters_at_all_thresholds",
         r"__splink__stable_nodes_at_new_threshold",
+        r"__splink__clustering_output_final",
     ]

     num_partitions = self.num_partitions_on_repartition
@@ -263,6 +264,8 @@ def _repartition_if_needed(self, spark_df, templated_name):
             num_partitions = math.ceil(num_partitions / 10)
         elif templated_name == "__splink__stable_nodes_at_new_threshold":
             num_partitions = math.ceil(num_partitions / 10)
+        elif templated_name == "__splink__clustering_output_final":
+            num_partitions = math.ceil(num_partitions / 10)

         if re.fullmatch(r"|".join(names_to_repartition), templated_name):
             spark_df = spark_df.repartition(num_partitions)