Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove clustering pairwise output format #2264

Merged
merged 1 commit into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 8 additions & 42 deletions splink/internals/connected_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,60 +359,29 @@ def _cc_create_unique_id_cols(


def _exit_query(
pairwise_mode: bool,
df_predict: SplinkDataFrame,
representatives: SplinkDataFrame,
concat_with_tf: SplinkDataFrame,
uid_cols: list[InputColumn],
pairwise_filter: bool,
) -> str:
representatives_name = representatives.physical_name
concat_with_tf_name = concat_with_tf.physical_name

if pairwise_mode:
df_predict_name = df_predict.physical_name
uid_concat_l = _composite_unique_id_from_edges_sql(uid_cols, "l", "n")
uid_concat_r = _composite_unique_id_from_edges_sql(uid_cols, "r", "n")

filter_cond = "where cluster_id_l = cluster_id_r" if pairwise_filter else ""

return f"""
select
n.*,
repr_l.representative as cluster_id_l,
repr_r.representative as cluster_id_r
from {df_predict_name} as n
left join
{representatives_name} as repr_l
on {uid_concat_l} = repr_l.node_id
left join
{representatives_name} as repr_r
on {uid_concat_r} = repr_r.node_id
{filter_cond}
order by
cluster_id_l, cluster_id_r
"""

else:
uid_concat = _composite_unique_id_from_nodes_sql(uid_cols, "n")
uid_concat = _composite_unique_id_from_nodes_sql(uid_cols, "n")

return f"""
select
c.representative as cluster_id, n.*
from {representatives_name} as c
return f"""
select
c.representative as cluster_id, n.*
from {representatives_name} as c

left join {concat_with_tf_name} as n
on {uid_concat} = c.node_id
"""
left join {concat_with_tf_name} as n
on {uid_concat} = c.node_id
"""


def solve_connected_components(
linker: "Linker",
edges_table: SplinkDataFrame,
df_predict: SplinkDataFrame,
concat_with_tf: SplinkDataFrame,
pairwise_output: bool = False,
filter_pairwise_format_for_clusters: bool = False,
_generated_graph: bool = False,
) -> SplinkDataFrame:
"""Connected Components main algorithm.
Expand Down Expand Up @@ -531,12 +500,9 @@ def solve_connected_components(
uid_cols = linker._settings_obj.column_info_settings.unique_id_input_columns

exit_query = _exit_query(
pairwise_mode=pairwise_output,
df_predict=df_predict,
representatives=representatives,
concat_with_tf=concat_with_tf,
uid_cols=uid_cols,
pairwise_filter=filter_pairwise_format_for_clusters,
)
pipeline = CTEPipeline([representatives])
pipeline.enqueue_sql(exit_query, "__splink__df_representatives")
Expand Down
10 changes: 0 additions & 10 deletions splink/internals/linker_components/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ def cluster_pairwise_predictions_at_threshold(
self,
df_predict: SplinkDataFrame,
threshold_match_probability: Optional[float] = None,
pairwise_formatting: bool = False,
filter_pairwise_format_for_clusters: bool = True,
) -> SplinkDataFrame:
"""Clusters the pairwise match predictions that result from
`linker.inference.predict()` into groups of connected records using the connected
Expand All @@ -53,11 +51,6 @@ def cluster_pairwise_predictions_at_threshold(
df_predict (SplinkDataFrame): The results of `linker.inference.predict()`
threshold_match_probability (float): Pairwise comparisons with a
`match_probability` at or above this threshold are matched
pairwise_formatting (bool): Whether to output the pairwise match predictions
from linker.predict() with cluster IDs.
filter_pairwise_format_for_clusters (bool): If pairwise formatting has been
selected, whether to output all pairs, or only those belonging to a
cluster of size 2 or greater.

Returns:
SplinkDataFrame: A SplinkDataFrame containing a list of all IDs, clustered
Expand All @@ -79,10 +72,7 @@ def cluster_pairwise_predictions_at_threshold(
cc = solve_connected_components(
self._linker,
edges_table,
df_predict,
nodes_with_tf,
pairwise_formatting,
filter_pairwise_format_for_clusters,
)
cc.metadata["threshold_match_probability"] = threshold_match_probability

Expand Down
1 change: 0 additions & 1 deletion tests/cc_testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def run_cc_implementation(linker, predict_df):
cc = solve_connected_components(
linker,
predict_df,
df_predict=None,
concat_with_tf=concat_with_tf,
_generated_graph=True,
).as_pandas_dataframe()
Expand Down
Loading