[BUG] Fix source dataset issue when running link jobs #1193

Merged · 3 commits · Apr 20, 2023
2 changes: 1 addition & 1 deletion splink/block_from_labels.py
@@ -32,7 +32,7 @@ def block_from_labels(

unique_id_col = linker._settings_obj._unique_id_column_name

- source_dataset_col = linker._settings_obj._source_dataset_column_name
+ source_dataset_col = linker._settings_obj._source_dataset_input_column

sql = lower_id_to_left_hand_side(df, source_dataset_col, unique_id_col)

2 changes: 1 addition & 1 deletion splink/blocking.py
@@ -119,7 +119,7 @@ def block_using_rules_sql(linker: Linker):
and not linker._find_new_matches_mode
and not linker._compare_two_records_mode
):
- source_dataset_col = linker._settings_obj._source_dataset_column_name
+ source_dataset_col = linker._source_dataset_column_name
# Need df_l to be the one with the lowest id to preserve the property
# that the left dataset is the one with the lowest concatenated id
keys = linker._input_tables_dict.keys()
38 changes: 34 additions & 4 deletions splink/linker.py
@@ -189,6 +189,9 @@ def __init__(
homogenised_tables, homogenised_aliases
)

+ self._names_of_tables_created_by_splink: set = set()
+ self._intermediate_table_cache: dict = CacheDictWithLogging()

if not isinstance(settings_dict, (dict, type(None))):
self._setup_settings_objs(None) # feed it a blank settings dictionary
self.load_settings(settings_dict)
@@ -199,9 +202,6 @@ def __init__(
self._validate_input_dfs()
self._em_training_sessions = []

- self._names_of_tables_created_by_splink: set = set()
- self._intermediate_table_cache: dict = CacheDictWithLogging()

self._find_new_matches_mode = False
self._train_u_using_random_sample_mode = False
self._compare_two_records_mode = False
@@ -278,6 +278,22 @@ def _input_tablename_r(self):
return "__splink_df_concat_with_tf_right"
return "__splink__df_concat_with_tf"

@property
def _source_dataset_column_name(self):
if self._settings_obj_ is None:
return None

# Used throughout the scripts to feed our SQL
if self._settings_obj._source_dataset_column_name_is_required:
df_obj = next(iter(self._input_tables_dict.values()))
columns = df_obj.columns_escaped

input_column = self._settings_obj._source_dataset_input_column
src_ds_col = InputColumn(input_column, self).name()
return "__splink_source_dataset" if src_ds_col in columns else input_column
else:
return None

@property
def _two_dataset_link_only(self):
# Two dataset link only join is a special case where an inner join of the
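For readers of this diff: the new property above decides which column name downstream SQL generation should use. A minimal standalone sketch of that resolution logic, using plain strings instead of Splink's InputColumn escaping (the function name and example columns here are illustrative, not part of this PR):

def resolve_source_dataset_column(configured_name: str, input_columns: list[str]) -> str:
    # If the input data already carries a column with the configured name, the
    # internally generated column falls back to a reserved name so the two do
    # not collide in the concatenated table; otherwise the configured name is
    # used directly.
    if configured_name in input_columns:
        return "__splink_source_dataset"
    return configured_name

# Input tables that already ship their own source_dataset column:
resolve_source_dataset_column("source_dataset", ["unique_id", "source_dataset", "email"])
# -> "__splink_source_dataset"

# Input tables without one:
resolve_source_dataset_column("source_dataset", ["unique_id", "email"])
# -> "source_dataset"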
@@ -904,8 +920,22 @@ def load_settings(self, settings_dict: dict | str | Path):
)
settings_dict = json.loads(p.read_text())

+ # Store the cache ID so it can be reloaded after cache invalidation
+ cache_id = self._cache_uid
+ # So we don't run into any issues with generated tables having
+ # invalid columns as settings have been tweaked, invalidate
+ # the cache and allow these tables to be recomputed.
+
+ # This is less efficient, but triggers infrequently and ensures we don't
+ # run into issues where the defaults used conflict with the actual values
+ # supplied in settings.
+
+ # This is particularly relevant with `source_dataset`, which appears within
+ # concat_with_tf.
+ self.invalidate_cache()
+
# If a uid already exists in your settings object, prioritise this
- settings_dict["linker_uid"] = settings_dict.get("linker_uid", self._cache_uid)
+ settings_dict["linker_uid"] = settings_dict.get("linker_uid", cache_id)
settings_dict["sql_dialect"] = settings_dict.get(
"sql_dialect", self._sql_dialect
)
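As a usage illustration of why the cache is invalidated here, a hypothetical sketch (not code from this PR; the settings file name and constructing the linker before loading settings are assumptions, and the import path is assumed from the Splink 3.x era of this PR):

import pandas as pd
from splink.duckdb.duckdb_linker import DuckDBLinker

df = pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv")

linker = DuckDBLinker(df)            # constructed with default settings
linker.load_settings("model.json")   # hypothetical file; may set source_dataset_column_name
# Any __splink__df_concat_with_tf built under the defaults is now stale, so
# invalidate_cache() forces it to be recomputed under the loaded settings,
# while the stored cache_id keeps the uid if the loaded settings lack one.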
2 changes: 1 addition & 1 deletion splink/missingness.py
@@ -68,7 +68,7 @@ def completeness_data(linker, input_tablename=None, cols=None):
columns = linker._settings_obj._columns_used_by_comparisons

if linker._settings_obj._source_dataset_column_name_is_required:
- source_name = linker._settings_obj._source_dataset_column_name
+ source_name = linker._source_dataset_column_name
else:
# Set source dataset to a literal string if dedupe_only
source_name = "'_a'"
8 changes: 4 additions & 4 deletions splink/settings.py
@@ -153,7 +153,7 @@ def _source_dataset_column_name_is_required(self):
]

@property
- def _source_dataset_column_name(self):
+ def _source_dataset_input_column(self):
if self._source_dataset_column_name_is_required:
s_else_d = self._from_settings_dict_else_default
return s_else_d("source_dataset_column_name")
@@ -166,7 +166,7 @@ def _unique_id_input_columns(self) -> list[InputColumn]:

if self._source_dataset_column_name_is_required:
col = InputColumn(
- self._source_dataset_column_name,
+ self._source_dataset_input_column,
settings_obj=self,
)
cols.append(col)
@@ -198,7 +198,7 @@ def _needs_matchkey_column(self) -> bool:
def _columns_used_by_comparisons(self):
cols_used = []
if self._source_dataset_column_name_is_required:
- cols_used.append(self._source_dataset_column_name)
+ cols_used.append(self._source_dataset_input_column)
cols_used.append(self._unique_id_column_name)
for cc in self.comparisons:
cols = cc._input_columns_used_by_case_statement
@@ -428,7 +428,7 @@ def _as_completed_dict(self):
"comparisons": [cc._as_completed_dict() for cc in self.comparisons],
"probability_two_random_records_match": rr_match,
"unique_id_column_name": self._unique_id_column_name,
"source_dataset_column_name": self._source_dataset_column_name,
"source_dataset_column_name": self._source_dataset_input_column,
}
return {**self._settings_dict, **current_settings}

4 changes: 3 additions & 1 deletion splink/vertically_concatenate.py
@@ -51,8 +51,10 @@ def vertically_concatenate_sql(linker: Linker) -> str:
if source_dataset_col_req:
sqls_to_union = []
for df_obj in linker._input_tables_dict.values():
+ source_ds_col = linker._source_dataset_column_name
sql = f"""
- select '{df_obj.templated_name}' as source_dataset, {select_columns_sql}
+ select '{df_obj.templated_name}' as {source_ds_col},
+ {select_columns_sql}
{salt_sql}
from {df_obj.physical_name}
"""
70 changes: 52 additions & 18 deletions tests/test_full_example_duckdb.py
@@ -113,6 +113,56 @@ def test_full_example_duckdb(tmp_path):
DuckDBLinker(df, settings_dict=path)


# Create some dummy dataframes for the link only test
df = pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv")
df_l = df.copy()
df_r = df.copy()
df_l["source_dataset"] = "my_left_ds"
df_r["source_dataset"] = "my_right_ds"
df_final = df_l.append(df_r)
# Tests link only jobs under different inputs:
# * A single dataframe with a `source_dataset` column
# * Two input dataframes with no specified `source_dataset` column
# * Two input dataframes with a specified `source_dataset` column
@pytest.mark.parametrize(
("input", "source_l", "source_r"),
[
pytest.param(
[df, df], # no source_dataset col
{"__splink__input_table_0"},
{"__splink__input_table_1"},
id="No source dataset column",
),
pytest.param(
df_final, # source_dataset col
{"my_left_ds"},
{"my_right_ds"},
id="Source dataset column in a single df",
),
pytest.param(
[df_l, df_r], # source_dataset col
{"my_left_ds"},
{"my_right_ds"},
id="Source dataset column in two dfs",
),
],
)
def test_link_only(input, source_l, source_r):
Contributor: Nice test! Really helps understand the change

settings = get_settings_dict()
settings["link_type"] = "link_only"
settings["source_dataset_column_name"] = "source_dataset"

linker = DuckDBLinker(
input,
settings,
)
df_predict = linker.predict().as_pandas_dataframe()

assert len(df_predict) == 7257
assert set(df_predict.source_dataset_l.values) == source_l
assert set(df_predict.source_dataset_r.values) == source_r


@pytest.mark.parametrize(
("df"),
[
@@ -238,29 +288,13 @@ def test_small_example_duckdb(tmp_path):
"retain_intermediate_calculation_columns": True,
}

linker = DuckDBLinker(
df,
connection=os.path.join(tmp_path, "duckdb.db"),
output_schema="splink_in_duckdb",
)
linker.load_settings(settings_dict)
linker = DuckDBLinker(df, settings_dict)

linker.estimate_u_using_random_sampling(max_pairs=1e6)
linker.estimate_probability_two_random_records_match(
["l.email = r.email"], recall=0.3
)
blocking_rule = "l.full_name = r.full_name"
linker.estimate_parameters_using_expectation_maximisation(blocking_rule)

blocking_rule = "l.dob = r.dob"
linker.estimate_parameters_using_expectation_maximisation(blocking_rule)

df_predict = linker.predict()
linker.cluster_pairwise_predictions_at_threshold(df_predict, 0.1)

path = os.path.join(tmp_path, "model.json")
linker.save_settings_to_json(path)

linker_2 = DuckDBLinker(df, connection=":memory:")
linker_2.load_settings(path)
DuckDBLinker(df, settings_dict=path)
linker.predict()
25 changes: 23 additions & 2 deletions tests/test_full_example_spark.py
@@ -1,6 +1,6 @@
import os

- from pyspark.sql.functions import array
+ import pyspark.sql.functions as f
from pyspark.sql.types import StringType, StructField, StructType

import splink.spark.spark_comparison_level_library as cll
@@ -13,7 +13,7 @@

def test_full_example_spark(df_spark, tmp_path):
# Convert a column to an array to enable testing intersection
df_spark = df_spark.withColumn("email", array("email"))
df_spark = df_spark.withColumn("email", f.array("email"))
settings_dict = get_settings_dict()

# Only needed because the value can be overwritten by other tests
@@ -128,3 +128,24 @@ def test_full_example_spark(df_spark, tmp_path):
break_lineage_method="checkpoint",
num_partitions_on_repartition=2,
)


def test_link_only(df_spark):
settings = get_settings_dict()
settings["link_type"] = "link_only"
settings["source_dataset_column_name"] = "source_dataset"

df_spark_a = df_spark.withColumn("source_dataset", f.lit("my_left_ds"))
df_spark_b = df_spark.withColumn("source_dataset", f.lit("my_right_ds"))

linker = SparkLinker(
[df_spark_a, df_spark_b],
settings,
break_lineage_method="checkpoint",
num_partitions_on_repartition=2,
)
df_predict = linker.predict().as_pandas_dataframe()

assert len(df_predict) == 7257
assert set(df_predict.source_dataset_l.values) == {"my_left_ds"}
assert set(df_predict.source_dataset_r.values) == {"my_right_ds"}