Add option to pass seed into estimate_u_using_random_sampling #1161

Merged · 18 commits · Apr 12, 2023
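In brief: this PR threads an optional `seed` argument through `estimate_u_using_random_sampling` and each backend's `_random_sample_sql`, so that u probabilities can be reproduced across runs on the DuckDB and Spark backends. A minimal usage sketch, assuming the Splink 3 import layout and the demo dataset used in the tests:

```python
import pandas as pd

from splink.duckdb.duckdb_comparison_library import levenshtein_at_thresholds
from splink.duckdb.duckdb_linker import DuckDBLinker

# Illustrative data and settings; any valid Splink settings dict works here.
df = pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv")
settings = {
    "link_type": "dedupe_only",
    "comparisons": [levenshtein_at_thresholds("first_name", 2)],
    "blocking_rules_to_generate_predictions": [],
}

linker = DuckDBLinker(df, settings)

# New in this PR: a fixed seed gives reproducible u probabilities.
# Seeded sampling is only supported on DuckDB and Spark.
linker.estimate_u_using_random_sampling(max_pairs=1e6, seed=1)
```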
2 changes: 1 addition & 1 deletion docs/comparison_template_library.md
@@ -7,7 +7,7 @@ tags:

# Documentation for `comparison_template_library`

The `comparison_template_library` contains pre-made comparisons with pre-defined parameters available for use directly [as described in this topic guide](../topic_guides/customising_comparisons.html#method-2-using-the-comparisontemplatelibrary).
The `comparison_template_library` contains pre-made comparisons with pre-defined parameters available for use directly [as described in this topic guide](./topic_guides/customising_comparisons.html#method-2-using-the-comparisontemplatelibrary).
However, not every comparison is available for every [Splink-compatible SQL backend](./topic_guides/backends.html). More detail on creating comparisons for specific data types is also [included in the topic guide.](./topic_guides/customising_comparisons.html#creating-comparisons-for-specific-data-types)

The detailed API for each of these are outlined below.
1 change: 0 additions & 1 deletion scripts/make_test_datasets_smaller.py
@@ -21,7 +21,6 @@ def process_directory(directory):
print("Starting truncating files")
for root, _, files in os.walk(directory):
for file in files:

file_path = os.path.join(root, file)
file_ext = os.path.splitext(file)[-1].lower()

7 changes: 6 additions & 1 deletion splink/athena/athena_linker.py
@@ -344,9 +344,14 @@ def register_table(self, input, table_name, overwrite=False):
self.register_data_on_s3(input, table_name)
return self._table_to_splink_dataframe(table_name, table_name)

def _random_sample_sql(self, proportion, sample_size):
def _random_sample_sql(self, proportion, sample_size, seed=None):
if proportion == 1.0:
return ""
if seed:
raise NotImplementedError(
"Athena does not support seeds in random "
"samples. Please remove the `seed` parameter."
)
percent = proportion * 100
return f" TABLESAMPLE BERNOULLI ({percent})"

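To make the Athena change concrete, a sketch of the two branches (the method is private and the values are illustrative):

```python
# `linker` is assumed to be an AthenaLinker; proportion < 1.0.
linker._random_sample_sql(0.5, 1000)          # -> " TABLESAMPLE BERNOULLI (50.0)"
linker._random_sample_sql(0.5, 1000, seed=1)  # raises NotImplementedError
```

Unseeded sampling is unchanged; only a non-null seed triggers the new error.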
7 changes: 4 additions & 3 deletions splink/cluster_studio.py
@@ -117,7 +117,7 @@ def df_edges_as_records(


def _get_random_cluster_ids(
linker: "Linker", connected_components: SplinkDataFrame, sample_size: int
linker: "Linker", connected_components: SplinkDataFrame, sample_size: int, seed=None
):
sql = f"""
select count(distinct cluster_id) as count
@@ -137,7 +137,7 @@ def _get_random_cluster_ids(
from {connected_components.physical_name}
)
select cluster_id from distinct_clusters
{linker._random_sample_sql(proportion, sample_size)}
{linker._random_sample_sql(proportion, sample_size, seed)}
"""

df_sample = linker._sql_to_splink_dataframe_checking_cache(
@@ -188,6 +188,7 @@ def render_splink_cluster_studio_html(
out_path: str,
sampling_method="random",
sample_size=10,
sample_seed=None,
cluster_ids: list = None,
cluster_names: list = None,
overwrite: bool = False,
@@ -202,7 +203,7 @@
if cluster_ids is None:
if sampling_method == "random":
cluster_ids = _get_random_cluster_ids(
linker, df_clustered_nodes, sample_size
linker, df_clustered_nodes, sample_size, sample_seed
)
if sampling_method == "by_cluster_size":
cluster_ids = _get_cluster_id_of_each_size(linker, df_clustered_nodes, 1)
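For illustration, a hedged sketch of the helper's new signature in use (`linker` and `df_clustered_nodes` are assumed to exist already):

```python
from splink.cluster_studio import _get_random_cluster_ids

# With sampling_method="random", render_splink_cluster_studio_html now
# forwards its new `sample_seed` argument here. Fixing the seed fixes
# which cluster_ids come back, on backends that support seeded sampling.
cluster_ids = _get_random_cluster_ids(
    linker, df_clustered_nodes, sample_size=10, seed=1
)
```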
1 change: 0 additions & 1 deletion splink/comparison_library.py
@@ -595,7 +595,6 @@ def __init__(
comparison_levels.append(null_level)

if include_exact_match_level:

label_suffix = f" {lat_col}, {long_col}"
level = {
"sql_condition": f"({lat_col}_l = {lat_col}_r) \n"
3 changes: 0 additions & 3 deletions splink/comparison_library_utils.py
@@ -1,5 +1,4 @@
def comparison_at_thresholds_error_logger(comparison, thresholds):

error_logger = []

if len(thresholds) == 0:
@@ -12,7 +11,6 @@ def comparison_at_thresholds_error_logger(comparison, thresholds):
error_logger.append("All entries of `thresholds` must be positive")

if len(error_logger) > 0:

error_logger.insert(
0,
f"The following error(s) were identified while validating "
@@ -25,7 +23,6 @@ def comparison_at_thresholds_error_logger(comparison, thresholds):


def datediff_error_logger(thresholds, metrics):

# Extracted from the DateDiffAtThresholdsComparisonBase class as that was overly
# verbose and failing the lint.

7 changes: 5 additions & 2 deletions splink/duckdb/duckdb_linker.py
@@ -209,11 +209,14 @@ def register_table(self, input, table_name, overwrite=False):
self._con.register(table_name, input)
return self._table_to_splink_dataframe(table_name, table_name)

def _random_sample_sql(self, proportion, sample_size):
def _random_sample_sql(self, proportion, sample_size, seed=None):
if proportion == 1.0:
return ""
percent = proportion * 100
return f"USING SAMPLE {percent}% (bernoulli)"
if seed:
return f"USING SAMPLE bernoulli({percent}%) REPEATABLE({seed})"
else:
return f"USING SAMPLE {percent}% (bernoulli)"

@property
def _infinity_expression(self):
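To see what the DuckDB method now emits, a sketch of the generated fragments (values illustrative):

```python
# `linker` is assumed to be a DuckDBLinker; proportion < 1.0.
linker._random_sample_sql(0.5, 1000)
# -> "USING SAMPLE 50.0% (bernoulli)"
linker._random_sample_sql(0.5, 1000, seed=42)
# -> "USING SAMPLE bernoulli(50.0%) REPEATABLE(42)"
```

DuckDB's `REPEATABLE` clause is what makes the Bernoulli sample deterministic for a given seed.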
4 changes: 2 additions & 2 deletions splink/estimate_u.py
@@ -31,7 +31,7 @@ def _rows_needed_for_n_pairs(n_pairs):
return sample_rows


def estimate_u_values(linker: Linker, max_pairs):
def estimate_u_values(linker: Linker, max_pairs, seed=None):
logger.info("----- Estimating u probabilities using random sampling -----")

nodes_with_tf = linker._initialise_df_concat_with_tf()
@@ -93,7 +93,7 @@ def estimate_u_values(linker: Linker, max_pairs):
sql = f"""
select *
from __splink__df_concat_with_tf
{training_linker._random_sample_sql(proportion, sample_size)}
{training_linker._random_sample_sql(proportion, sample_size, seed)}
"""
training_linker._enqueue_sql(sql, "__splink__df_concat_with_tf_sample")
df_sample = training_linker._execute_sql_pipeline([nodes_with_tf])
7 changes: 5 additions & 2 deletions splink/linker.py
@@ -1000,7 +1000,7 @@ def deterministic_link(self) -> SplinkDataFrame:
return self._execute_sql_pipeline([concat_with_tf])

def estimate_u_using_random_sampling(
self, max_pairs: int = None, *, target_rows=None
self, max_pairs: int = None, seed: int = None, *, target_rows=None
):
"""Estimate the u parameters of the linkage model using random sampling.

@@ -1020,6 +1020,9 @@ def estimate_u_using_random_sampling(
gives best results but can take a long time to compute. 1e7 (ten million)
is often adequate whilst testing different model specifications, before
the final model is estimated.
seed (int): Seed for random sampling. Assign to get reproducible u
probabilities. Note that a seed is only supported for the DuckDB
and Spark backends; for Athena and SQLite, leave this as None.

Examples:
>>> linker.estimate_u_using_random_sampling(1e8)
@@ -1047,7 +1050,7 @@
else:
raise TypeError("Missing argument max_pairs")

estimate_u_values(self, max_pairs)
estimate_u_values(self, max_pairs, seed)
self._populate_m_u_from_trained_values()

self._settings_obj._columns_without_estimated_parameters_message()
7 changes: 5 additions & 2 deletions splink/spark/spark_linker.py
@@ -463,11 +463,14 @@ def register_table(self, input, table_name, overwrite=False):
input.createOrReplaceTempView(table_name)
return self._table_to_splink_dataframe(table_name, table_name)

def _random_sample_sql(self, proportion, sample_size):
def _random_sample_sql(self, proportion, sample_size, seed=None):
if proportion == 1.0:
return ""
percent = proportion * 100
return f" TABLESAMPLE ({percent} PERCENT) "
if seed:
return f" ORDER BY rand({seed}) LIMIT {round(sample_size)}"
else:
return f" TABLESAMPLE ({percent} PERCENT) "

def _table_exists_in_database(self, table_name):
query_result = self.spark.sql(
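Likewise for Spark, a sketch of the two branches (values illustrative). Note the design trade-off: the seeded path swaps `TABLESAMPLE` for `ORDER BY rand(seed)` plus a `LIMIT`, which is deterministic and returns an exact row count, but forces a full sort on large tables:

```python
# `linker` is assumed to be a SparkLinker; proportion < 1.0.
linker._random_sample_sql(0.5, 1000)
# -> " TABLESAMPLE (50.0 PERCENT) "
linker._random_sample_sql(0.5, 1000, seed=42)
# -> " ORDER BY rand(42) LIMIT 1000"
```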
8 changes: 6 additions & 2 deletions splink/sqlite/sqlite_linker.py
@@ -147,12 +147,16 @@ def register_table(self, input, table_name, overwrite=False):
input.to_sql(table_name, self.con, index=False)
return self._table_to_splink_dataframe(table_name, table_name)

def _random_sample_sql(self, proportion, sample_size):
def _random_sample_sql(self, proportion, sample_size, seed=None):
if proportion == 1.0:
return ""
if seed:
raise NotImplementedError(
"SQLite does not support seeds in random "
"samples. Please remove the `seed` parameter."
)

sample_size = int(sample_size)

return (
"where unique_id IN (SELECT unique_id FROM __splink__df_concat_with_tf"
f" ORDER BY RANDOM() LIMIT {sample_size})"
2 changes: 0 additions & 2 deletions tests/test_comparison_template_lib.py
@@ -38,7 +38,6 @@ def test_date_comparison_jw_run(ctl):
],
)
def test_datediff_levels(spark, ctl, Linker):

df = pd.DataFrame(
[
{
@@ -187,7 +186,6 @@ def test_name_comparison_run(ctl):
],
)
def test_name_comparison_levels(spark, ctl, Linker):

df = pd.DataFrame(
[
{
3 changes: 0 additions & 3 deletions tests/test_datediff_level.py
@@ -17,7 +17,6 @@
],
)
def test_datediff_levels(spark, cl, cll, Linker):

# Capture differing comparison levels to allow unique settings generation
df = pd.DataFrame(
[
@@ -142,7 +141,6 @@ def test_datediff_levels(spark, cl, cll, Linker):
for gamma, id_pairs in gamma_lookup.items():
for left, right in id_pairs:
for linker_name, linker_pred in linker_outputs.items():

print(f"Checking IDs: {left}, {right} for {linker_name}")

assert (
@@ -162,7 +160,6 @@ def test_datediff_levels(spark, cl, cll, Linker):
],
)
def test_datediff_error_logger(cl):

# Differing lengths between thresholds and units
with pytest.raises(ValueError):
cl.datediff_at_thresholds("dob", [1], ["day", "month", "year", "year"])
2 changes: 1 addition & 1 deletion tests/test_full_example_athena.py
@@ -134,7 +134,7 @@ def test_full_example_athena(tmp_path):
linker.compute_tf_table("city")
linker.compute_tf_table("first_name")

linker.estimate_u_using_random_sampling(max_pairs=1e6)
linker.estimate_u_using_random_sampling(max_pairs=1e6, seed=1)

blocking_rule = "l.first_name = r.first_name and l.surname = r.surname"
linker.estimate_parameters_using_expectation_maximisation(blocking_rule)
2 changes: 1 addition & 1 deletion tests/test_full_example_duckdb.py
@@ -50,7 +50,7 @@ def test_full_example_duckdb(tmp_path):
linker.compute_tf_table("city")
linker.compute_tf_table("first_name")

linker.estimate_u_using_random_sampling(max_pairs=1e6)
linker.estimate_u_using_random_sampling(max_pairs=1e6, seed=1)
linker.estimate_probability_two_random_records_match(
["l.email = r.email"], recall=0.3
)
2 changes: 1 addition & 1 deletion tests/test_full_example_spark.py
@@ -61,7 +61,7 @@ def test_full_example_spark(df_spark, tmp_path):
linker.estimate_probability_two_random_records_match(
["l.email = r.email"], recall=0.3
)
linker.estimate_u_using_random_sampling(max_pairs=1e5)
linker.estimate_u_using_random_sampling(max_pairs=1e5, seed=1)

blocking_rule = "l.first_name = r.first_name and l.surname = r.surname"
linker.estimate_parameters_using_expectation_maximisation(blocking_rule)
2 changes: 1 addition & 1 deletion tests/test_full_example_sqlite.py
@@ -40,7 +40,7 @@ def power(val, exp):
["l.email = r.email"], recall=0.3
)

linker.estimate_u_using_random_sampling(max_pairs=1e6)
linker.estimate_u_using_random_sampling(max_pairs=1e6, seed=1)

blocking_rule = "l.first_name = r.first_name and l.surname = r.surname"
linker.estimate_parameters_using_expectation_maximisation(blocking_rule)
2 changes: 0 additions & 2 deletions tests/test_km_distance_level.py
@@ -17,7 +17,6 @@
],
)
def test_simple_run(cl):

print(
cl.distance_in_km_at_thresholds(
lat_col="lat", long_col="long", km_thresholds=[1, 5, 10]
@@ -190,7 +189,6 @@ def test_km_distance_levels(spark, cl, cll, Linker):
for gamma, id_pairs in gamma_lookup.items():
for left, right in id_pairs:
for linker_name, linker_pred in linker_outputs.items():

print(f"Checking IDs: {left}, {right} for {linker_name}")

gamma_column_name_options = [
29 changes: 29 additions & 0 deletions tests/test_u_train.py
@@ -235,3 +235,32 @@ def test_u_train_multilink():
assert cl_lev.u_probability == 1 / denom
cl_no = cc_name._get_comparison_level_by_comparison_vector_value(0)
assert cl_no.u_probability == (denom - 10) / denom


def test_seed_u_outputs():
df = pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv")

settings = {
"link_type": "dedupe_only",
"comparisons": [levenshtein_at_thresholds("first_name", 2)],
"blocking_rules_to_generate_predictions": [],
}

linker_1 = DuckDBLinker(df, settings)
linker_2 = DuckDBLinker(df, settings)
linker_3 = DuckDBLinker(df, settings)

linker_1.estimate_u_using_random_sampling(max_pairs=1e3, seed=1)
linker_2.estimate_u_using_random_sampling(max_pairs=1e3, seed=1)
linker_3.estimate_u_using_random_sampling(max_pairs=1e3, seed=2)

print(linker_1._settings_obj._parameter_estimates_as_records)

assert (
linker_1._settings_obj._parameter_estimates_as_records
== linker_2._settings_obj._parameter_estimates_as_records
)
assert (
linker_1._settings_obj._parameter_estimates_as_records
!= linker_3._settings_obj._parameter_estimates_as_records
)