Merge a347153 into 28c88a3
ThomasHepworth authored Apr 27, 2023
2 parents 28c88a3 + a347153 commit ace12fd
Showing 5 changed files with 99 additions and 3 deletions.
3 changes: 2 additions & 1 deletion splink/analyse_blocking.py
@@ -64,7 +64,8 @@ def cumulative_comparisons_generated_by_blocking_rules(
    # Calculate the Cartesian Product
    if output_chart:
        # We only need the cartesian product if we want to output the chart view
-        if len(linker._input_tables_dict) == 1:
+
+        if settings_obj._link_type == "dedupe_only":
            group_by_statement = ""
        else:
            group_by_statement = "group by source_dataset"
37 changes: 35 additions & 2 deletions splink/linker.py
@@ -60,6 +60,7 @@
    bayes_factor_to_prob,
    ensure_is_list,
    ensure_is_tuple,
+    find_unique_source_dataset,
    prob_to_bayes_factor,
)
from .missingness import completeness_data, missingness_data
@@ -291,8 +292,7 @@ def _source_dataset_column_name(self):
            df_obj = next(iter(self._input_tables_dict.values()))
            columns = df_obj.columns_escaped

-            input_column = self._settings_obj._source_dataset_input_column
-            src_ds_col = InputColumn(input_column, self).name()
+            input_column, src_ds_col = self._settings_obj_._source_dataset_col
            return "__splink_source_dataset" if src_ds_col in columns else input_column
        else:
            return None
@@ -340,6 +340,35 @@ def _infinity_expression(self):
            f"infinity sql expression not available for {type(self)}"
        )

+    @property
+    def _verify_link_only_job(self):
+
+        cache = self._intermediate_table_cache
+        if "__splink__df_concat_with_tf" not in cache:
+            return
+
+        if self._settings_obj._link_type == "link_only":
+            # if input datasets > 1 then skip
+            if len(self._input_tables_dict) > 1:
+                return
+
+            # else, check if source dataset column is populated...
+            src_ds = self._source_dataset_column_name
+            if src_ds == "__splink_source_dataset":
+                _, src_ds = self._settings_obj_._source_dataset_col
+
+            sql = find_unique_source_dataset(src_ds)
+            self._enqueue_sql(sql, "source_ds_distinct")
+            src_ds_distinct = self._execute_sql_pipeline(
+                [cache["__splink__df_concat_with_tf"]]
+            )
+            if len(src_ds_distinct.as_record_dict()) == 1:
+                raise SplinkException(
+                    "if `link_type` is `link_only`, it should have at least two "
+                    "input dataframes, or one dataframe with a `source_dataset` "
+                    "column outlining which dataset each record belongs to."
+                )
+
    def _register_input_tables(self, input_tables, input_aliases, accepted_df_dtypes):
        # 'homogenised' means all entries are strings representing tables
        homogenised_tables = []
@@ -427,6 +456,10 @@ def _initialise_df_concat_with_tf(self, materialise=True):
        nodes_with_tf = self._execute_sql_pipeline()
        cache["__splink__df_concat_with_tf"] = nodes_with_tf

+        # verify the link job
+        if self._settings_obj_ is not None:
+            self._verify_link_only_job
+
        return nodes_with_tf

    def _table_to_splink_dataframe(
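Taken together, the `linker.py` changes mean the new check runs the first time `__splink__df_concat_with_tf` is materialised for a `link_only` job. A minimal sketch of the failure mode it guards against, assuming the DuckDB backend, the demo CSV used by the new test below, and the `get_settings_dict` helper from `tests/basic_settings.py`:

import pandas as pd

from splink.duckdb.duckdb_linker import DuckDBLinker
from splink.exceptions import SplinkException
from tests.basic_settings import get_settings_dict

settings = get_settings_dict()
settings["link_type"] = "link_only"

# A single dataframe with no source_dataset column, submitted as a link_only job
df = pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv")
linker = DuckDBLinker(df, settings)

try:
    # Materialising __splink__df_concat_with_tf triggers _verify_link_only_job
    linker._initialise_df_concat_with_tf()
except SplinkException as exc:
    print(exc)  # only one distinct source dataset was found

With two input dataframes, or a single dataframe carrying a populated source_dataset column, the same call returns normally, as the new test demonstrates.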
9 changes: 9 additions & 0 deletions splink/misc.py
@@ -162,3 +162,12 @@ def major_minor_version_greater_equal_than(this_version, base_comparison_version

def ascii_uid(len):
    return "".join(random.choices(string.ascii_letters + string.digits, k=len))
+
+
+def find_unique_source_dataset(src_ds):
+    sql = f"""
+    select distinct {src_ds} as src
+    from __splink__df_concat_with_tf
+    """
+
+    return sql
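For illustration, the helper simply interpolates the supplied (already escaped) column name into a fixed query against the concatenated input table, so, assuming the column is literally named source_dataset:

sql = find_unique_source_dataset("source_dataset")

# sql is equivalent, up to whitespace, to:
#   select distinct source_dataset as src
#   from __splink__df_concat_with_tf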
5 changes: 5 additions & 0 deletions splink/settings.py
@@ -160,6 +160,11 @@ def _source_dataset_input_column(self):
        else:
            return None

+    @property
+    def _source_dataset_col(self):
+        input_column = self._source_dataset_input_column
+        return (input_column, InputColumn(input_column, self).name())
+
    @property
    def _unique_id_input_columns(self) -> list[InputColumn]:
        cols = []
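A rough illustration of the new tuple, assuming the settings use the default `source_dataset` column name; exactly how the second element is escaped depends on `InputColumn` and the active SQL dialect, so the values below are indicative only:

input_column, src_ds_col = settings_obj._source_dataset_col

# input_column -> the raw name from the settings, e.g. "source_dataset"
# src_ds_col   -> the same column as escaped by InputColumn(input_column, settings_obj).name()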
48 changes: 48 additions & 0 deletions tests/test_link_only_verification.py
@@ -0,0 +1,48 @@
import pandas as pd
import pytest

from splink.duckdb.duckdb_linker import DuckDBLinker
from splink.exceptions import SplinkException
from tests.basic_settings import get_settings_dict

df = pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv")
df_l = df.copy()
df_r = df.copy()
df_l["source_dataset"] = "my_left_ds"
df_r["source_dataset"] = "my_right_ds"
df_final = df_l.append(df_r)

settings = get_settings_dict()
settings["link_type"] = "link_only"


def test_link_only_verification():
    # As `_initialise_df_concat_with_tf()` cannot be run without
    # a settings object, we don't need to test that case.

    # Two input dataframes + link_only settings
    linker = DuckDBLinker(
        [df_l, df_r],
        settings,
    )
    linker._initialise_df_concat_with_tf()

    # A single dataframe with a source_dataset col
    linker = DuckDBLinker(
        df_final,
        settings,
    )
    linker._initialise_df_concat_with_tf()

    # A single df with no source_dataset col, despite
    # calling link_only. Should fail w/ SplinkException
    linker = DuckDBLinker(
        df,
        settings,
    )
    # This should pass as concat_with_tf doesn't yet exist
    linker._verify_link_only_job
    with pytest.raises(SplinkException):
        # Fails as only one df w/ no source_dataset col has
        # been passed
        linker._initialise_df_concat_with_tf()
