From 5d61ca1386fd4b176059ae832797a5cd16ddbef6 Mon Sep 17 00:00:00 2001 From: Ross Kennedy Date: Fri, 31 Mar 2023 09:06:45 +0100 Subject: [PATCH 01/17] fix docs path link break --- docs/comparison_template_library.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/comparison_template_library.md b/docs/comparison_template_library.md index a5cadfc86f..5d0f4c3e88 100644 --- a/docs/comparison_template_library.md +++ b/docs/comparison_template_library.md @@ -7,7 +7,7 @@ tags: # Documentation for `comparison_template_library` -The `comparison_template_library` contains pre-made comparisons with pre-defined parameters available for use directly [as described in this topic guide](../topic_guides/customising_comparisons.html#method-2-using-the-comparisontemplatelibrary). +The `comparison_template_library` contains pre-made comparisons with pre-defined parameters available for use directly [as described in this topic guide](./topic_guides/customising_comparisons.html#method-2-using-the-comparisontemplatelibrary). However, not every comparison is available for every [Splink-compatible SQL backend](./topic_guides/backends.html). More detail on creating comparisons for specific data types is also [included in the topic guide.](./topic_guides/customising_comparisons.html#creating-comparisons-for-specific-data-types) The detailed API for each of these are outlined below. From 68be1b6f2c3a6a815e4c0f10cd40d6db7b038522 Mon Sep 17 00:00:00 2001 From: Ross Kennedy Date: Fri, 31 Mar 2023 10:25:08 +0100 Subject: [PATCH 02/17] Working for duckdb --- splink/duckdb/duckdb_linker.py | 7 +++++-- splink/estimate_u.py | 4 ++-- splink/linker.py | 4 ++-- tests/test_u_train.py | 25 +++++++++++++++++++++++++ 4 files changed, 34 insertions(+), 6 deletions(-) diff --git a/splink/duckdb/duckdb_linker.py b/splink/duckdb/duckdb_linker.py index 8c425f87af..08671e2933 100644 --- a/splink/duckdb/duckdb_linker.py +++ b/splink/duckdb/duckdb_linker.py @@ -209,11 +209,14 @@ def register_table(self, input, table_name, overwrite=False): self._con.register(table_name, input) return self._table_to_splink_dataframe(table_name, table_name) - def _random_sample_sql(self, proportion, sample_size): + def _random_sample_sql(self, proportion, sample_size, seed=None): if proportion == 1.0: return "" percent = proportion * 100 - return f"USING SAMPLE {percent}% (bernoulli)" + if seed: + return f"USING SAMPLE bernoulli({percent}%) REPEATABLE({seed})" + else: + return f"USING SAMPLE {percent}% (bernoulli)" @property def _infinity_expression(self): diff --git a/splink/estimate_u.py b/splink/estimate_u.py index 34771f4f2d..c62edba028 100644 --- a/splink/estimate_u.py +++ b/splink/estimate_u.py @@ -31,7 +31,7 @@ def _rows_needed_for_n_pairs(n_pairs): return sample_rows -def estimate_u_values(linker: Linker, max_pairs): +def estimate_u_values(linker: Linker, max_pairs, seed=None): logger.info("----- Estimating u probabilities using random sampling -----") nodes_with_tf = linker._initialise_df_concat_with_tf() @@ -93,7 +93,7 @@ def estimate_u_values(linker: Linker, max_pairs): sql = f""" select * from __splink__df_concat_with_tf - {training_linker._random_sample_sql(proportion, sample_size)} + {training_linker._random_sample_sql(proportion, sample_size, seed)} """ training_linker._enqueue_sql(sql, "__splink__df_concat_with_tf_sample") df_sample = training_linker._execute_sql_pipeline([nodes_with_tf]) diff --git a/splink/linker.py b/splink/linker.py index 603c932fe0..6b50270219 100644 --- a/splink/linker.py +++ b/splink/linker.py @@ 
-1000,7 +1000,7 @@ def deterministic_link(self) -> SplinkDataFrame: return self._execute_sql_pipeline([concat_with_tf]) def estimate_u_using_random_sampling( - self, max_pairs: int = None, *, target_rows=None + self, max_pairs: int = None, seed = None, *, target_rows=None ): """Estimate the u parameters of the linkage model using random sampling. @@ -1047,7 +1047,7 @@ def estimate_u_using_random_sampling( else: raise TypeError("Missing argument max_pairs") - estimate_u_values(self, max_pairs) + estimate_u_values(self, max_pairs, seed) self._populate_m_u_from_trained_values() self._settings_obj._columns_without_estimated_parameters_message() diff --git a/tests/test_u_train.py b/tests/test_u_train.py index 57c83ae688..aec4775b08 100644 --- a/tests/test_u_train.py +++ b/tests/test_u_train.py @@ -4,6 +4,7 @@ from splink.duckdb.duckdb_comparison_library import levenshtein_at_thresholds from splink.duckdb.duckdb_linker import DuckDBLinker +from splink.estimate_u import _rows_needed_for_n_pairs def test_u_train(): @@ -235,3 +236,27 @@ def test_u_train_multilink(): assert cl_lev.u_probability == 1 / denom cl_no = cc_name._get_comparison_level_by_comparison_vector_value(0) assert cl_no.u_probability == (denom - 10) / denom + + +def test_seed_u_outputs(): + + df = pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv") + + settings = { + "link_type": "dedupe_only", + "comparisons": [levenshtein_at_thresholds("first_name", 2)], + "blocking_rules_to_generate_predictions": [], + } + + linker_1 = DuckDBLinker(df, settings) + linker_2 = DuckDBLinker(df, settings) + linker_3 = DuckDBLinker(df, settings) + + linker_1.estimate_u_using_random_sampling(max_pairs=1e3, seed=1) + linker_2.estimate_u_using_random_sampling(max_pairs=1e3, seed=1) + linker_3.estimate_u_using_random_sampling(max_pairs=1e3, seed=2) + + print(linker_1._settings_obj._parameter_estimates_as_records) + + assert linker_1._settings_obj._parameter_estimates_as_records == linker_2._settings_obj._parameter_estimates_as_records + assert linker_1._settings_obj._parameter_estimates_as_records != linker_3._settings_obj._parameter_estimates_as_records From 1d6b1c06edbfda58354aceb7894b34623b40466a Mon Sep 17 00:00:00 2001 From: Ross Kennedy Date: Fri, 31 Mar 2023 10:47:40 +0100 Subject: [PATCH 03/17] linting --- splink/linker.py | 2 +- tests/test_u_train.py | 15 ++++++++++----- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/splink/linker.py b/splink/linker.py index 6b50270219..ff9e44832e 100644 --- a/splink/linker.py +++ b/splink/linker.py @@ -1000,7 +1000,7 @@ def deterministic_link(self) -> SplinkDataFrame: return self._execute_sql_pipeline([concat_with_tf]) def estimate_u_using_random_sampling( - self, max_pairs: int = None, seed = None, *, target_rows=None + self, max_pairs: int = None, seed=None, *, target_rows=None ): """Estimate the u parameters of the linkage model using random sampling. 
diff --git a/tests/test_u_train.py b/tests/test_u_train.py index aec4775b08..d17a08d9d5 100644 --- a/tests/test_u_train.py +++ b/tests/test_u_train.py @@ -4,7 +4,6 @@ from splink.duckdb.duckdb_comparison_library import levenshtein_at_thresholds from splink.duckdb.duckdb_linker import DuckDBLinker -from splink.estimate_u import _rows_needed_for_n_pairs def test_u_train(): @@ -239,7 +238,7 @@ def test_u_train_multilink(): def test_seed_u_outputs(): - + df = pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv") settings = { @@ -247,7 +246,7 @@ def test_seed_u_outputs(): "comparisons": [levenshtein_at_thresholds("first_name", 2)], "blocking_rules_to_generate_predictions": [], } - + linker_1 = DuckDBLinker(df, settings) linker_2 = DuckDBLinker(df, settings) linker_3 = DuckDBLinker(df, settings) @@ -258,5 +257,11 @@ def test_seed_u_outputs(): print(linker_1._settings_obj._parameter_estimates_as_records) - assert linker_1._settings_obj._parameter_estimates_as_records == linker_2._settings_obj._parameter_estimates_as_records - assert linker_1._settings_obj._parameter_estimates_as_records != linker_3._settings_obj._parameter_estimates_as_records + assert ( + linker_1._settings_obj._parameter_estimates_as_records + == linker_2._settings_obj._parameter_estimates_as_records + ) + assert ( + linker_1._settings_obj._parameter_estimates_as_records + != linker_3._settings_obj._parameter_estimates_as_records + ) From 61286b01c79fba8c618a5b7566f43d3c3a7fc03a Mon Sep 17 00:00:00 2001 From: Ross Kennedy Date: Fri, 31 Mar 2023 10:52:20 +0100 Subject: [PATCH 04/17] add seed parameter to all backend integration test --- tests/test_full_example_athena.py | 2 +- tests/test_full_example_duckdb.py | 2 +- tests/test_full_example_spark.py | 2 +- tests/test_full_example_sqlite.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_full_example_athena.py b/tests/test_full_example_athena.py index 0f7d28fe9d..11ff1aa2b2 100644 --- a/tests/test_full_example_athena.py +++ b/tests/test_full_example_athena.py @@ -134,7 +134,7 @@ def test_full_example_athena(tmp_path): linker.compute_tf_table("city") linker.compute_tf_table("first_name") - linker.estimate_u_using_random_sampling(max_pairs=1e6) + linker.estimate_u_using_random_sampling(max_pairs=1e6, seed=1) blocking_rule = "l.first_name = r.first_name and l.surname = r.surname" linker.estimate_parameters_using_expectation_maximisation(blocking_rule) diff --git a/tests/test_full_example_duckdb.py b/tests/test_full_example_duckdb.py index a35b440879..d7fa250845 100644 --- a/tests/test_full_example_duckdb.py +++ b/tests/test_full_example_duckdb.py @@ -50,7 +50,7 @@ def test_full_example_duckdb(tmp_path): linker.compute_tf_table("city") linker.compute_tf_table("first_name") - linker.estimate_u_using_random_sampling(max_pairs=1e6) + linker.estimate_u_using_random_sampling(max_pairs=1e6, seed=1) linker.estimate_probability_two_random_records_match( ["l.email = r.email"], recall=0.3 ) diff --git a/tests/test_full_example_spark.py b/tests/test_full_example_spark.py index dafa00c83a..163140e3c6 100644 --- a/tests/test_full_example_spark.py +++ b/tests/test_full_example_spark.py @@ -61,7 +61,7 @@ def test_full_example_spark(df_spark, tmp_path): linker.estimate_probability_two_random_records_match( ["l.email = r.email"], recall=0.3 ) - linker.estimate_u_using_random_sampling(max_pairs=1e5) + linker.estimate_u_using_random_sampling(max_pairs=1e5, seed=1) blocking_rule = "l.first_name = r.first_name and l.surname = r.surname" 
linker.estimate_parameters_using_expectation_maximisation(blocking_rule) diff --git a/tests/test_full_example_sqlite.py b/tests/test_full_example_sqlite.py index d9546741e1..f8098cf4fc 100644 --- a/tests/test_full_example_sqlite.py +++ b/tests/test_full_example_sqlite.py @@ -40,7 +40,7 @@ def power(val, exp): ["l.email = r.email"], recall=0.3 ) - linker.estimate_u_using_random_sampling(max_pairs=1e6) + linker.estimate_u_using_random_sampling(max_pairs=1e6, seed=1) blocking_rule = "l.first_name = r.first_name and l.surname = r.surname" linker.estimate_parameters_using_expectation_maximisation(blocking_rule) From f149d5467d7fea0ee5a69d52a80a08118b75f0bb Mon Sep 17 00:00:00 2001 From: Ross Kennedy Date: Fri, 31 Mar 2023 15:27:47 +0100 Subject: [PATCH 05/17] add exception handling to athena and sqlite --- splink/athena/athena_linker.py | 15 ++++++++++++++- splink/cluster_studio.py | 7 ++++--- splink/duckdb/duckdb_linker.py | 8 ++++++++ splink/estimate_u.py | 6 +----- splink/linker.py | 4 +++- splink/sqlite/sqlite_linker.py | 16 ++++++++++++++-- 6 files changed, 44 insertions(+), 12 deletions(-) diff --git a/splink/athena/athena_linker.py b/splink/athena/athena_linker.py index b5463e3930..0ad0d6b973 100644 --- a/splink/athena/athena_linker.py +++ b/splink/athena/athena_linker.py @@ -344,12 +344,25 @@ def register_table(self, input, table_name, overwrite=False): self.register_data_on_s3(input, table_name) return self._table_to_splink_dataframe(table_name, table_name) - def _random_sample_sql(self, proportion, sample_size): + def _random_sample_sql(self, proportion, sample_size, seed=None): if proportion == 1.0: return "" + if seed: + raise NotImplementedError( + "Athena does not support seeds in random ", + "samples. Please remove the `seed` parameter.", + ) percent = proportion * 100 return f" TABLESAMPLE BERNOULLI ({percent})" + def _u_random_sample_sql(self, proportion, sample_size, seed=None): + sql = f""" + select * + from __splink__df_concat_with_tf + {self._random_sample_sql(proportion, sample_size, seed)} + """ + return sql + @property def _infinity_expression(self): return "infinity()" diff --git a/splink/cluster_studio.py b/splink/cluster_studio.py index 21112b4727..b9199adc11 100644 --- a/splink/cluster_studio.py +++ b/splink/cluster_studio.py @@ -117,7 +117,7 @@ def df_edges_as_records( def _get_random_cluster_ids( - linker: "Linker", connected_components: SplinkDataFrame, sample_size: int + linker: "Linker", connected_components: SplinkDataFrame, sample_size: int, seed=None ): sql = f""" select count(distinct cluster_id) as count @@ -137,7 +137,7 @@ def _get_random_cluster_ids( from {connected_components.physical_name} ) select cluster_id from distinct_clusters - {linker._random_sample_sql(proportion, sample_size)} + {linker._random_sample_sql(proportion, sample_size, seed)} """ df_sample = linker._sql_to_splink_dataframe_checking_cache( @@ -188,6 +188,7 @@ def render_splink_cluster_studio_html( out_path: str, sampling_method="random", sample_size=10, + sample_seed=None, cluster_ids: list = None, cluster_names: list = None, overwrite: bool = False, @@ -202,7 +203,7 @@ def render_splink_cluster_studio_html( if cluster_ids is None: if sampling_method == "random": cluster_ids = _get_random_cluster_ids( - linker, df_clustered_nodes, sample_size + linker, df_clustered_nodes, sample_size, sample_seed ) if sampling_method == "by_cluster_size": cluster_ids = _get_cluster_id_of_each_size(linker, df_clustered_nodes, 1) diff --git a/splink/duckdb/duckdb_linker.py 
b/splink/duckdb/duckdb_linker.py index 08671e2933..9fc8a5a7fe 100644 --- a/splink/duckdb/duckdb_linker.py +++ b/splink/duckdb/duckdb_linker.py @@ -218,6 +218,14 @@ def _random_sample_sql(self, proportion, sample_size, seed=None): else: return f"USING SAMPLE {percent}% (bernoulli)" + def _u_random_sample_sql(self, proportion, sample_size, seed=None): + sql = f""" + select * + from __splink__df_concat_with_tf + {self._random_sample_sql(proportion, sample_size, seed)} + """ + return sql + @property def _infinity_expression(self): return "cast('infinity' as double)" diff --git a/splink/estimate_u.py b/splink/estimate_u.py index c62edba028..b51e7e4bed 100644 --- a/splink/estimate_u.py +++ b/splink/estimate_u.py @@ -90,11 +90,7 @@ def estimate_u_values(linker: Linker, max_pairs, seed=None): if sample_size > count_rows: sample_size = count_rows - sql = f""" - select * - from __splink__df_concat_with_tf - {training_linker._random_sample_sql(proportion, sample_size, seed)} - """ + sql = training_linker._u_random_sample_sql(proportion, sample_size, seed) training_linker._enqueue_sql(sql, "__splink__df_concat_with_tf_sample") df_sample = training_linker._execute_sql_pipeline([nodes_with_tf]) diff --git a/splink/linker.py b/splink/linker.py index ff9e44832e..f1a830acdd 100644 --- a/splink/linker.py +++ b/splink/linker.py @@ -1000,7 +1000,7 @@ def deterministic_link(self) -> SplinkDataFrame: return self._execute_sql_pipeline([concat_with_tf]) def estimate_u_using_random_sampling( - self, max_pairs: int = None, seed=None, *, target_rows=None + self, max_pairs: int = None, seed: int = None, *, target_rows=None ): """Estimate the u parameters of the linkage model using random sampling. @@ -1020,6 +1020,8 @@ def estimate_u_using_random_sampling( gives best results but can take a long time to compute. 1e7 (ten million) is often adequate whilst testing different model specifications, before the final model is estimated. + seed (int): Seed for random sampling. Assign to get reproducible u + probabilities Examples: >>> linker.estimate_u_using_random_sampling(1e8) diff --git a/splink/sqlite/sqlite_linker.py b/splink/sqlite/sqlite_linker.py index f5f696e493..173e9cab78 100644 --- a/splink/sqlite/sqlite_linker.py +++ b/splink/sqlite/sqlite_linker.py @@ -147,17 +147,29 @@ def register_table(self, input, table_name, overwrite=False): input.to_sql(table_name, self.con, index=False) return self._table_to_splink_dataframe(table_name, table_name) - def _random_sample_sql(self, proportion, sample_size): + def _random_sample_sql(self, proportion, sample_size, seed=None): if proportion == 1.0: return "" + if seed: + raise NotImplementedError( + "SQLite does not support seeds in random ", + "samples. 
Please remove the `seed` parameter.", + ) sample_size = int(sample_size) - return ( "where unique_id IN (SELECT unique_id FROM __splink__df_concat_with_tf" f" ORDER BY RANDOM() LIMIT {sample_size})" ) + def _u_random_sample_sql(self, proportion, sample_size, seed=None): + sql = f""" + select * + from __splink__df_concat_with_tf + {self._random_sample_sql(proportion, sample_size, seed)} + """ + return sql + @property def _infinity_expression(self): return "'infinity'" From 7cf0e12ef42a67f4d00bda7eb4641e78ffebf026 Mon Sep 17 00:00:00 2001 From: Ross Kennedy Date: Fri, 31 Mar 2023 15:37:08 +0100 Subject: [PATCH 06/17] spark not working --- splink/spark/spark_linker.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/splink/spark/spark_linker.py b/splink/spark/spark_linker.py index 1d27fc7b55..7e89542e9d 100644 --- a/splink/spark/spark_linker.py +++ b/splink/spark/spark_linker.py @@ -463,12 +463,32 @@ def register_table(self, input, table_name, overwrite=False): input.createOrReplaceTempView(table_name) return self._table_to_splink_dataframe(table_name, table_name) - def _random_sample_sql(self, proportion, sample_size): + def _random_sample_sql(self, proportion, sample_size, seed=None): if proportion == 1.0: return "" percent = proportion * 100 return f" TABLESAMPLE ({percent} PERCENT) " + def _u_random_sample_sql(self, proportion, sample_size, seed=None): + """from pyspark.sql.functions import udf, struct, col, map_values + from pyspark.sql.types import MapType, StringType + from pyspark.sql.functions import udf + + # Define the PySpark function to sample a map of columns + def sample_columns(cols, with_replacement=True, fraction=0.5, seed=None): + return {k: list(set(v).sample(with_replacement, fraction, seed)) for k, v in cols.items()} + + # Register the PySpark function as a Spark SQL UDF + sample_udf = udf(sample_columns, MapType(StringType(), ArrayType(StringType()))) + """ + + sql = f""" + select * + from __splink__df_concat_with_tf + {self._random_sample_sql(proportion, sample_size, seed)} + """ + return sql + def _table_exists_in_database(self, table_name): query_result = self.spark.sql( f"show tables from {self.splink_data_store} like '{table_name}'" From bfacc5903663ade95fe435d211c94ef7235df2d6 Mon Sep 17 00:00:00 2001 From: Ross Kennedy Date: Fri, 31 Mar 2023 16:31:25 +0100 Subject: [PATCH 07/17] docs --- splink/linker.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/splink/linker.py b/splink/linker.py index f1a830acdd..4d87b67238 100644 --- a/splink/linker.py +++ b/splink/linker.py @@ -1021,7 +1021,8 @@ def estimate_u_using_random_sampling( is often adequate whilst testing different model specifications, before the final model is estimated. seed (int): Seed for random sampling. Assign to get reproducible u - probabilities + probabilities. Note, seed for random sampling is only supported for + DuckDB and Spark, for Athena and SQLite set to None. 
Examples: >>> linker.estimate_u_using_random_sampling(1e8) From 49cd395f4d8e750c1d68e35ad07a74dc6ca3c12c Mon Sep 17 00:00:00 2001 From: Ross Kennedy Date: Mon, 10 Apr 2023 00:42:13 +0100 Subject: [PATCH 08/17] test repeatable for spark --- splink/spark/spark_linker.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/splink/spark/spark_linker.py b/splink/spark/spark_linker.py index 7e89542e9d..4db8d0fc88 100644 --- a/splink/spark/spark_linker.py +++ b/splink/spark/spark_linker.py @@ -467,7 +467,10 @@ def _random_sample_sql(self, proportion, sample_size, seed=None): if proportion == 1.0: return "" percent = proportion * 100 - return f" TABLESAMPLE ({percent} PERCENT) " + if seed: + return f" TABLESAMPLE ({percent} PERCENT) REPEATABLE({seed})" + else: + return f" TABLESAMPLE ({percent} PERCENT) " def _u_random_sample_sql(self, proportion, sample_size, seed=None): """from pyspark.sql.functions import udf, struct, col, map_values From 9f5f1640d08e3f3065159ed8103f4c8e7da6c7ec Mon Sep 17 00:00:00 2001 From: Ross Kennedy Date: Mon, 10 Apr 2023 08:23:33 +0100 Subject: [PATCH 09/17] test order by rand and limit --- splink/spark/spark_linker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/splink/spark/spark_linker.py b/splink/spark/spark_linker.py index 4db8d0fc88..24c5c92194 100644 --- a/splink/spark/spark_linker.py +++ b/splink/spark/spark_linker.py @@ -468,7 +468,7 @@ def _random_sample_sql(self, proportion, sample_size, seed=None): return "" percent = proportion * 100 if seed: - return f" TABLESAMPLE ({percent} PERCENT) REPEATABLE({seed})" + return f" ORDER BY rand({seed}) LIMIT {round(sample_size)})" else: return f" TABLESAMPLE ({percent} PERCENT) " From 6d1ba01cc05c9c12472e3565b71c7646072cd220 Mon Sep 17 00:00:00 2001 From: Ross Kennedy Date: Mon, 10 Apr 2023 08:28:35 +0100 Subject: [PATCH 10/17] lint --- scripts/make_test_datasets_smaller.py | 1 - splink/comparison_library.py | 1 - splink/comparison_library_utils.py | 3 --- splink/linker.py | 2 +- splink/spark/spark_linker.py | 5 +++-- tests/test_comparison_template_lib.py | 2 -- tests/test_datediff_level.py | 3 --- tests/test_km_distance_level.py | 2 -- tests/test_u_train.py | 1 - 9 files changed, 4 insertions(+), 16 deletions(-) diff --git a/scripts/make_test_datasets_smaller.py b/scripts/make_test_datasets_smaller.py index 633bddd135..e7fd608828 100644 --- a/scripts/make_test_datasets_smaller.py +++ b/scripts/make_test_datasets_smaller.py @@ -21,7 +21,6 @@ def process_directory(directory): print("Starting truncating files") for root, _, files in os.walk(directory): for file in files: - file_path = os.path.join(root, file) file_ext = os.path.splitext(file)[-1].lower() diff --git a/splink/comparison_library.py b/splink/comparison_library.py index cf932325cb..4832ca2cac 100644 --- a/splink/comparison_library.py +++ b/splink/comparison_library.py @@ -595,7 +595,6 @@ def __init__( comparison_levels.append(null_level) if include_exact_match_level: - label_suffix = f" {lat_col}, {long_col}" level = { "sql_condition": f"({lat_col}_l = {lat_col}_r) \n" diff --git a/splink/comparison_library_utils.py b/splink/comparison_library_utils.py index 465b41e46c..46dfcbce48 100644 --- a/splink/comparison_library_utils.py +++ b/splink/comparison_library_utils.py @@ -1,5 +1,4 @@ def comparison_at_thresholds_error_logger(comparison, thresholds): - error_logger = [] if len(thresholds) == 0: @@ -12,7 +11,6 @@ def comparison_at_thresholds_error_logger(comparison, thresholds): error_logger.append("All entries of `thresholds` 
must be postive") if len(error_logger) > 0: - error_logger.insert( 0, f"The following error(s) were identified while validating " @@ -25,7 +23,6 @@ def comparison_at_thresholds_error_logger(comparison, thresholds): def datediff_error_logger(thresholds, metrics): - # Extracted from the DateDiffAtThresholdsComparisonBase class as that was overly # verbose and failing the lint. diff --git a/splink/linker.py b/splink/linker.py index 4d87b67238..5fb6d68dd6 100644 --- a/splink/linker.py +++ b/splink/linker.py @@ -1021,7 +1021,7 @@ def estimate_u_using_random_sampling( is often adequate whilst testing different model specifications, before the final model is estimated. seed (int): Seed for random sampling. Assign to get reproducible u - probabilities. Note, seed for random sampling is only supported for + probabilities. Note, seed for random sampling is only supported for DuckDB and Spark, for Athena and SQLite set to None. Examples: diff --git a/splink/spark/spark_linker.py b/splink/spark/spark_linker.py index 24c5c92194..a6a1e65481 100644 --- a/splink/spark/spark_linker.py +++ b/splink/spark/spark_linker.py @@ -479,12 +479,13 @@ def _u_random_sample_sql(self, proportion, sample_size, seed=None): # Define the PySpark function to sample a map of columns def sample_columns(cols, with_replacement=True, fraction=0.5, seed=None): - return {k: list(set(v).sample(with_replacement, fraction, seed)) for k, v in cols.items()} + return {k: list(set(v).sample(with_replacement, fraction, seed)) for + k, v in cols.items()} # Register the PySpark function as a Spark SQL UDF sample_udf = udf(sample_columns, MapType(StringType(), ArrayType(StringType()))) """ - + sql = f""" select * from __splink__df_concat_with_tf diff --git a/tests/test_comparison_template_lib.py b/tests/test_comparison_template_lib.py index 031d5d244f..60762ded80 100644 --- a/tests/test_comparison_template_lib.py +++ b/tests/test_comparison_template_lib.py @@ -38,7 +38,6 @@ def test_date_comparison_jw_run(ctl): ], ) def test_datediff_levels(spark, ctl, Linker): - df = pd.DataFrame( [ { @@ -187,7 +186,6 @@ def test_name_comparison_run(ctl): ], ) def test_name_comparison_levels(spark, ctl, Linker): - df = pd.DataFrame( [ { diff --git a/tests/test_datediff_level.py b/tests/test_datediff_level.py index 1af4ea6774..b0c38a1436 100644 --- a/tests/test_datediff_level.py +++ b/tests/test_datediff_level.py @@ -17,7 +17,6 @@ ], ) def test_datediff_levels(spark, cl, cll, Linker): - # Capture differing comparison levels to allow unique settings generation df = pd.DataFrame( [ @@ -142,7 +141,6 @@ def test_datediff_levels(spark, cl, cll, Linker): for gamma, id_pairs in gamma_lookup.items(): for left, right in id_pairs: for linker_name, linker_pred in linker_outputs.items(): - print(f"Checking IDs: {left}, {right} for {linker_name}") assert ( @@ -162,7 +160,6 @@ def test_datediff_levels(spark, cl, cll, Linker): ], ) def test_datediff_error_logger(cl): - # Differing lengths between thresholds and units with pytest.raises(ValueError): cl.datediff_at_thresholds("dob", [1], ["day", "month", "year", "year"]) diff --git a/tests/test_km_distance_level.py b/tests/test_km_distance_level.py index 0cce70744b..2aaffa2ac9 100644 --- a/tests/test_km_distance_level.py +++ b/tests/test_km_distance_level.py @@ -17,7 +17,6 @@ ], ) def test_simple_run(cl): - print( cl.distance_in_km_at_thresholds( lat_col="lat", long_col="long", km_thresholds=[1, 5, 10] @@ -190,7 +189,6 @@ def test_km_distance_levels(spark, cl, cll, Linker): for gamma, id_pairs in gamma_lookup.items(): for 
left, right in id_pairs: for linker_name, linker_pred in linker_outputs.items(): - print(f"Checking IDs: {left}, {right} for {linker_name}") gamma_column_name_options = [ diff --git a/tests/test_u_train.py b/tests/test_u_train.py index d17a08d9d5..634fb57935 100644 --- a/tests/test_u_train.py +++ b/tests/test_u_train.py @@ -238,7 +238,6 @@ def test_u_train_multilink(): def test_seed_u_outputs(): - df = pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv") settings = { From d785f5629d091305244f88951d199ea875df4d8d Mon Sep 17 00:00:00 2001 From: Ross Kennedy Date: Mon, 10 Apr 2023 08:40:57 +0100 Subject: [PATCH 11/17] fix typo --- splink/spark/spark_linker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/splink/spark/spark_linker.py b/splink/spark/spark_linker.py index a6a1e65481..9196e40d2e 100644 --- a/splink/spark/spark_linker.py +++ b/splink/spark/spark_linker.py @@ -468,7 +468,7 @@ def _random_sample_sql(self, proportion, sample_size, seed=None): return "" percent = proportion * 100 if seed: - return f" ORDER BY rand({seed}) LIMIT {round(sample_size)})" + return f" ORDER BY rand({seed}) LIMIT {round(sample_size)}" else: return f" TABLESAMPLE ({percent} PERCENT) " From 50de05f5cc9f9ae425b6937d623823da33e8c9ba Mon Sep 17 00:00:00 2001 From: Ross Kennedy Date: Mon, 10 Apr 2023 08:51:49 +0100 Subject: [PATCH 12/17] remove unused import --- splink/misc.py | 1 - 1 file changed, 1 deletion(-) diff --git a/splink/misc.py b/splink/misc.py index 9cd62610c6..3bd4f7a559 100644 --- a/splink/misc.py +++ b/splink/misc.py @@ -1,4 +1,3 @@ -import itertools import json import random import string From 6bce59a05e63ac373bf0a2c771fb4ae03ef3d436 Mon Sep 17 00:00:00 2001 From: Ross Kennedy Date: Mon, 10 Apr 2023 09:02:19 +0100 Subject: [PATCH 13/17] re-simplify functions --- splink/athena/athena_linker.py | 8 -------- splink/duckdb/duckdb_linker.py | 8 -------- splink/estimate_u.py | 6 +++++- splink/misc.py | 1 + splink/spark/spark_linker.py | 21 --------------------- splink/sqlite/sqlite_linker.py | 8 -------- 6 files changed, 6 insertions(+), 46 deletions(-) diff --git a/splink/athena/athena_linker.py b/splink/athena/athena_linker.py index 0ad0d6b973..0d57b2646a 100644 --- a/splink/athena/athena_linker.py +++ b/splink/athena/athena_linker.py @@ -355,14 +355,6 @@ def _random_sample_sql(self, proportion, sample_size, seed=None): percent = proportion * 100 return f" TABLESAMPLE BERNOULLI ({percent})" - def _u_random_sample_sql(self, proportion, sample_size, seed=None): - sql = f""" - select * - from __splink__df_concat_with_tf - {self._random_sample_sql(proportion, sample_size, seed)} - """ - return sql - @property def _infinity_expression(self): return "infinity()" diff --git a/splink/duckdb/duckdb_linker.py b/splink/duckdb/duckdb_linker.py index 9fc8a5a7fe..08671e2933 100644 --- a/splink/duckdb/duckdb_linker.py +++ b/splink/duckdb/duckdb_linker.py @@ -218,14 +218,6 @@ def _random_sample_sql(self, proportion, sample_size, seed=None): else: return f"USING SAMPLE {percent}% (bernoulli)" - def _u_random_sample_sql(self, proportion, sample_size, seed=None): - sql = f""" - select * - from __splink__df_concat_with_tf - {self._random_sample_sql(proportion, sample_size, seed)} - """ - return sql - @property def _infinity_expression(self): return "cast('infinity' as double)" diff --git a/splink/estimate_u.py b/splink/estimate_u.py index b51e7e4bed..4aa7b92200 100644 --- a/splink/estimate_u.py +++ b/splink/estimate_u.py @@ -90,7 +90,11 @@ def estimate_u_values(linker: Linker, max_pairs, 
seed=None): if sample_size > count_rows: sample_size = count_rows - sql = training_linker._u_random_sample_sql(proportion, sample_size, seed) + sql = f""" + select * + from __splink__df_concat_with_tf + {training_linker._random_sample_sql(proportion, sample_size)} + """ training_linker._enqueue_sql(sql, "__splink__df_concat_with_tf_sample") df_sample = training_linker._execute_sql_pipeline([nodes_with_tf]) diff --git a/splink/misc.py b/splink/misc.py index 3bd4f7a559..9cd62610c6 100644 --- a/splink/misc.py +++ b/splink/misc.py @@ -1,3 +1,4 @@ +import itertools import json import random import string diff --git a/splink/spark/spark_linker.py b/splink/spark/spark_linker.py index 9196e40d2e..ccc3f77e08 100644 --- a/splink/spark/spark_linker.py +++ b/splink/spark/spark_linker.py @@ -472,27 +472,6 @@ def _random_sample_sql(self, proportion, sample_size, seed=None): else: return f" TABLESAMPLE ({percent} PERCENT) " - def _u_random_sample_sql(self, proportion, sample_size, seed=None): - """from pyspark.sql.functions import udf, struct, col, map_values - from pyspark.sql.types import MapType, StringType - from pyspark.sql.functions import udf - - # Define the PySpark function to sample a map of columns - def sample_columns(cols, with_replacement=True, fraction=0.5, seed=None): - return {k: list(set(v).sample(with_replacement, fraction, seed)) for - k, v in cols.items()} - - # Register the PySpark function as a Spark SQL UDF - sample_udf = udf(sample_columns, MapType(StringType(), ArrayType(StringType()))) - """ - - sql = f""" - select * - from __splink__df_concat_with_tf - {self._random_sample_sql(proportion, sample_size, seed)} - """ - return sql - def _table_exists_in_database(self, table_name): query_result = self.spark.sql( f"show tables from {self.splink_data_store} like '{table_name}'" diff --git a/splink/sqlite/sqlite_linker.py b/splink/sqlite/sqlite_linker.py index 173e9cab78..b9da998fb8 100644 --- a/splink/sqlite/sqlite_linker.py +++ b/splink/sqlite/sqlite_linker.py @@ -162,14 +162,6 @@ def _random_sample_sql(self, proportion, sample_size, seed=None): f" ORDER BY RANDOM() LIMIT {sample_size})" ) - def _u_random_sample_sql(self, proportion, sample_size, seed=None): - sql = f""" - select * - from __splink__df_concat_with_tf - {self._random_sample_sql(proportion, sample_size, seed)} - """ - return sql - @property def _infinity_expression(self): return "'infinity'" From e3532c88f08513dd72c55d2ad7008d0418114408 Mon Sep 17 00:00:00 2001 From: Ross Kennedy Date: Mon, 10 Apr 2023 10:24:53 +0100 Subject: [PATCH 14/17] readd omitted seed parameter --- splink/estimate_u.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/splink/estimate_u.py b/splink/estimate_u.py index 4aa7b92200..c62edba028 100644 --- a/splink/estimate_u.py +++ b/splink/estimate_u.py @@ -93,7 +93,7 @@ def estimate_u_values(linker: Linker, max_pairs, seed=None): sql = f""" select * from __splink__df_concat_with_tf - {training_linker._random_sample_sql(proportion, sample_size)} + {training_linker._random_sample_sql(proportion, sample_size, seed)} """ training_linker._enqueue_sql(sql, "__splink__df_concat_with_tf_sample") df_sample = training_linker._execute_sql_pipeline([nodes_with_tf]) From 9ca96d3f328343b8a38d71c35c72ba9e31eed9d4 Mon Sep 17 00:00:00 2001 From: Ross Kennedy Date: Tue, 11 Apr 2023 18:06:04 +0100 Subject: [PATCH 15/17] remove print statement and parametrize for spark --- tests/test_u_train.py | 41 +++++++++++++++++++++++++---------------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git 
a/tests/test_u_train.py b/tests/test_u_train.py index 634fb57935..b1239e2223 100644 --- a/tests/test_u_train.py +++ b/tests/test_u_train.py @@ -1,9 +1,11 @@ import numpy as np import pandas as pd -from pytest import approx +import pytest -from splink.duckdb.duckdb_comparison_library import levenshtein_at_thresholds +import splink.duckdb.duckdb_comparison_library as cld +import splink.spark.spark_comparison_library as clsp from splink.duckdb.duckdb_linker import DuckDBLinker +from splink.spark.spark_linker import SparkLinker def test_u_train(): @@ -19,7 +21,7 @@ def test_u_train(): settings = { "link_type": "dedupe_only", - "comparisons": [levenshtein_at_thresholds("name", 2)], + "comparisons": [cld.levenshtein_at_thresholds("name", 2)], "blocking_rules_to_generate_predictions": ["l.name = r.name"], } @@ -63,7 +65,7 @@ def test_u_train_link_only(): settings = { "link_type": "link_only", - "comparisons": [levenshtein_at_thresholds("name", 2)], + "comparisons": [cld.levenshtein_at_thresholds("name", 2)], "blocking_rules_to_generate_predictions": [], } @@ -111,7 +113,7 @@ def test_u_train_link_only_sample(): settings = { "link_type": "link_only", - "comparisons": [levenshtein_at_thresholds("name", 2)], + "comparisons": [cld.levenshtein_at_thresholds("name", 2)], "blocking_rules_to_generate_predictions": [], } @@ -133,7 +135,7 @@ def test_u_train_link_only_sample(): max_pairs_proportion = result[0]["count"] / max_pairs # equality only holds probabilistically # chance of failure is approximately 1e-06 - assert approx(max_pairs_proportion, 0.15) == 1.0 + assert pytest.approx(max_pairs_proportion, 0.15) == 1.0 def test_u_train_multilink(): @@ -170,7 +172,7 @@ def test_u_train_multilink(): settings = { "link_type": "link_only", - "comparisons": [levenshtein_at_thresholds("name", 2)], + "comparisons": [cld.levenshtein_at_thresholds("name", 2)], "blocking_rules_to_generate_predictions": [], } @@ -237,25 +239,32 @@ def test_u_train_multilink(): assert cl_no.u_probability == (denom - 10) / denom -def test_seed_u_outputs(): - df = pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv") +@pytest.mark.parametrize( + ("Linker", "cll"), + [ + pytest.param(DuckDBLinker, cld, id="Test DuckDB random seeds"), + pytest.param(SparkLinker, clsp, id="Test Spark random seeds"), + ], +) +def test_seed_u_outputs(df_spark, Linker, cll): + if Linker == SparkLinker: + df = df_spark + else: + df = pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv") settings = { "link_type": "dedupe_only", - "comparisons": [levenshtein_at_thresholds("first_name", 2)], - "blocking_rules_to_generate_predictions": [], + "comparisons": [cll.levenshtein_at_thresholds("first_name", 2)], } - linker_1 = DuckDBLinker(df, settings) - linker_2 = DuckDBLinker(df, settings) - linker_3 = DuckDBLinker(df, settings) + linker_1 = Linker(df, settings) + linker_2 = Linker(df, settings) + linker_3 = Linker(df, settings) linker_1.estimate_u_using_random_sampling(max_pairs=1e3, seed=1) linker_2.estimate_u_using_random_sampling(max_pairs=1e3, seed=1) linker_3.estimate_u_using_random_sampling(max_pairs=1e3, seed=2) - print(linker_1._settings_obj._parameter_estimates_as_records) - assert ( linker_1._settings_obj._parameter_estimates_as_records == linker_2._settings_obj._parameter_estimates_as_records From e4b0afd9c27e5874b185b035dede9c29c321ec00 Mon Sep 17 00:00:00 2001 From: Ross Kennedy Date: Tue, 11 Apr 2023 18:18:53 +0100 Subject: [PATCH 16/17] add performance caveat and remove itertools import --- splink/linker.py | 4 ++++ splink/misc.py | 
1 - 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/splink/linker.py b/splink/linker.py index 990597af16..0a04e8d10a 100644 --- a/splink/linker.py +++ b/splink/linker.py @@ -1045,6 +1045,10 @@ def estimate_u_using_random_sampling( pairwise comparisons are non-matches (or at least, they are very unlikely to be matches). For large datasets, this is typically true. + The results of estimate_u_using_random_sampling, and therefore an entire splink + model, can be made reproducible by setting the seed parameter. Setting the seed + will have performance implications as additional processing is required. + Args: max_pairs (int): The maximum number of pairwise record comparisons to sample. Larger will give more accurate estimates diff --git a/splink/misc.py b/splink/misc.py index ccef200451..b9f8d8d86a 100644 --- a/splink/misc.py +++ b/splink/misc.py @@ -1,4 +1,3 @@ -import itertools import json import random import string From 8502d8d9bbc724bf896a7e837801c92a12c84fa7 Mon Sep 17 00:00:00 2001 From: Ross Kennedy Date: Tue, 11 Apr 2023 17:19:35 +0000 Subject: [PATCH 17/17] lint with black --- splink/linker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/splink/linker.py b/splink/linker.py index 0a04e8d10a..764249c18c 100644 --- a/splink/linker.py +++ b/splink/linker.py @@ -1045,8 +1045,8 @@ def estimate_u_using_random_sampling( pairwise comparisons are non-matches (or at least, they are very unlikely to be matches). For large datasets, this is typically true. - The results of estimate_u_using_random_sampling, and therefore an entire splink - model, can be made reproducible by setting the seed parameter. Setting the seed + The results of estimate_u_using_random_sampling, and therefore an entire splink + model, can be made reproducible by setting the seed parameter. Setting the seed will have performance implications as additional processing is required. Args:
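
Taken together, these patches thread an optional seed from estimate_u_using_random_sampling through estimate_u_values into each backend's _random_sample_sql: DuckDB emits USING SAMPLE bernoulli(x%) REPEATABLE(seed), Spark falls back to ORDER BY rand(seed) LIMIT n, and Athena/SQLite raise NotImplementedError if a seed is supplied. Below is a minimal usage sketch, not part of the patch series itself, mirroring the new test_seed_u_outputs test on the DuckDB backend and assuming the fake_1000_from_splink_demos.csv demo file used by the tests is available locally.

import pandas as pd

from splink.duckdb.duckdb_comparison_library import levenshtein_at_thresholds
from splink.duckdb.duckdb_linker import DuckDBLinker

df = pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv")

settings = {
    "link_type": "dedupe_only",
    "comparisons": [levenshtein_at_thresholds("first_name", 2)],
}

linker_a = DuckDBLinker(df, settings)
linker_b = DuckDBLinker(df, settings)

# With the same seed, random sampling of pairwise comparisons is repeatable,
# so the estimated u probabilities should be identical across runs.
# Seeded sampling is only implemented for DuckDB and Spark; Athena and
# SQLite raise NotImplementedError when a seed is passed.
linker_a.estimate_u_using_random_sampling(max_pairs=1e5, seed=1)
linker_b.estimate_u_using_random_sampling(max_pairs=1e5, seed=1)

assert (
    linker_a._settings_obj._parameter_estimates_as_records
    == linker_b._settings_obj._parameter_estimates_as_records
)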