diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6fb4ad3af4..7fded8b3f5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## Unreleased
 
+### Added
+
+- When using DuckDB, you can now pass `duckdb.DuckDBPyRelation`s as input tables to the `Linker` ([#2375](https://github.com/moj-analytical-services/splink/pull/2375))
+
 ### Fixed
 
 - Completeness chart now works correctly with indexed columns in spark ([#2309](https://github.com/moj-analytical-services/splink/pull/2309))
diff --git a/splink/internals/athena/database_api.py b/splink/internals/athena/database_api.py
index 4de9822221..7102dff090 100644
--- a/splink/internals/athena/database_api.py
+++ b/splink/internals/athena/database_api.py
@@ -3,13 +3,13 @@
 import json
 import logging
 import os
-from typing import Any, Sequence
+from typing import Any
 
 import awswrangler as wr
 import boto3
 import pandas as pd
 
-from ..database_api import AcceptableInputTableType, DatabaseAPI
+from ..database_api import DatabaseAPI
 from ..dialects import AthenaDialect
 from ..sql_transform import sqlglot_transform_sql
 from .athena_helpers.athena_transforms import cast_concat_as_varchar
@@ -250,17 +250,3 @@ def accepted_df_dtypes(self):
         except ImportError:
             pass
         return accepted_df_dtypes
-
-    def load_from_file(self, file_path: str) -> str:
-        raise NotImplementedError(
-            "Loading from file is not supported for Athena. "
-            "Please use the `table` method to load data."
-        )
-
-    def process_input_tables(
-        self, input_tables: Sequence[AcceptableInputTableType]
-    ) -> Sequence[AcceptableInputTableType]:
-        input_tables = super().process_input_tables(input_tables)
-        return [
-            self.load_from_file(t) if isinstance(t, str) else t for t in input_tables
-        ]
diff --git a/splink/internals/duckdb/database_api.py b/splink/internals/duckdb/database_api.py
index f16660df4c..010c452aed 100644
--- a/splink/internals/duckdb/database_api.py
+++ b/splink/internals/duckdb/database_api.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import logging
-from typing import Sequence, Union
+from typing import Union
 
 import duckdb
 import pandas as pd
@@ -14,7 +14,6 @@
 from .dataframe import DuckDBDataFrame
 from .duckdb_helpers.duckdb_helpers import (
     create_temporary_duckdb_connection,
-    duckdb_load_from_file,
     validate_duckdb_connection,
 )
 
@@ -93,9 +92,6 @@ def table_exists_in_database(self, table_name):
             return False
         return True
 
-    def load_from_file(self, file_path: str) -> str:
-        return duckdb_load_from_file(file_path)
-
     def _execute_sql_against_backend(self, final_sql: str) -> duckdb.DuckDBPyRelation:
         return self._con.sql(final_sql)
 
@@ -110,11 +106,3 @@ def accepted_df_dtypes(self):
         except ImportError:
             pass
         return accepted_df_dtypes
-
-    def process_input_tables(
-        self, input_tables: Sequence[AcceptableInputTableType]
-    ) -> Sequence[AcceptableInputTableType]:
-        input_tables = super().process_input_tables(input_tables)
-        return [
-            self.load_from_file(t) if isinstance(t, str) else t for t in input_tables
-        ]
diff --git a/splink/internals/duckdb/dataframe.py b/splink/internals/duckdb/dataframe.py
index 2c8fa0ef13..1cf5af14b8 100644
--- a/splink/internals/duckdb/dataframe.py
+++ b/splink/internals/duckdb/dataframe.py
@@ -4,6 +4,7 @@
 import os
 from typing import TYPE_CHECKING
 
+from duckdb import DuckDBPyRelation
 from pandas import DataFrame as pd_DataFrame
 
 from splink.internals.input_column import InputColumn
@@ -50,6 +51,13 @@ def as_pandas_dataframe(self, limit: int = None) -> pd_DataFrame:
 
         return self.db_api._execute_sql_against_backend(sql).to_df()
 
+    def as_duckdbpyrelation(self, limit: int = None) -> DuckDBPyRelation:
+        sql = f"select * from {self.physical_name}"
+        if limit:
+            sql += f" limit {limit}"
+
+        return self.db_api._execute_sql_against_backend(sql)
+
     def to_parquet(self, filepath, overwrite=False):
         if not overwrite:
             self.check_file_exists(filepath)
diff --git a/splink/internals/duckdb/duckdb_helpers/duckdb_helpers.py b/splink/internals/duckdb/duckdb_helpers/duckdb_helpers.py
index f12adabe12..20b579f966 100644
--- a/splink/internals/duckdb/duckdb_helpers/duckdb_helpers.py
+++ b/splink/internals/duckdb/duckdb_helpers/duckdb_helpers.py
@@ -1,7 +1,6 @@
 import os
 import tempfile
 import uuid
-from pathlib import Path
 
 import duckdb
 
@@ -26,7 +25,7 @@ def validate_duckdb_connection(connection, logger):
 
     connection = connection.lower()
 
-    if connection in [":memory:", ":temporary:"]:
+    if connection in [":memory:", ":temporary:", ":default:"]:
         return
 
     suffixes = (".duckdb", ".db")
@@ -49,15 +48,3 @@ def create_temporary_duckdb_connection(self):
     path = os.path.join(self._temp_dir.name, f"{fname}.duckdb")
     con = duckdb.connect(database=path, read_only=False)
     return con
-
-
-def duckdb_load_from_file(path: str) -> str:
-    file_functions = {
-        ".csv": f"read_csv_auto('{path}')",
-        ".parquet": f"read_parquet('{path}')",
-    }
-    file_ext = Path(path).suffix
-    if file_ext in file_functions.keys():
-        return file_functions[file_ext]
-    else:
-        return path
diff --git a/splink/internals/linker.py b/splink/internals/linker.py
index b284708aae..24f9b0bed6 100644
--- a/splink/internals/linker.py
+++ b/splink/internals/linker.py
@@ -306,9 +306,11 @@ def _register_input_tables(
         input_tables: Sequence[AcceptableInputTableType],
         input_aliases: Optional[str | List[str]],
     ) -> Dict[str, SplinkDataFrame]:
+        input_tables_list = ensure_is_list(input_tables)
+
         if input_aliases is None:
             input_table_aliases = [
-                f"__splink__input_table_{i}" for i, _ in enumerate(input_tables)
+                f"__splink__input_table_{i}" for i, _ in enumerate(input_tables_list)
             ]
             overwrite = True
         else:
diff --git a/tests/test_full_example_duckdb.py b/tests/test_full_example_duckdb.py
index 955dddc3d3..7f6476c73e 100644
--- a/tests/test_full_example_duckdb.py
+++ b/tests/test_full_example_duckdb.py
@@ -1,5 +1,6 @@
 import os
 
+import duckdb
 import pandas as pd
 import pyarrow as pa
 import pyarrow.csv as pa_csv
@@ -8,10 +9,9 @@
 import splink.internals.comparison_level_library as cll
 import splink.internals.comparison_library as cl
+from splink import DuckDBAPI, Linker, SettingsCreator, block_on
 from splink.blocking_analysis import count_comparisons_from_blocking_rule
 from splink.exploratory import completeness_chart, profile_columns
-from splink.internals.duckdb.database_api import DuckDBAPI
-from splink.internals.linker import Linker
 
 from .basic_settings import get_settings_dict, name_comparison
 from .decorator import mark_with_dialects_including
 
@@ -197,10 +197,6 @@ def test_link_only(input, source_l, source_r):
         pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv"),
         id="DuckDB link from pandas df",
     ),
-    pytest.param(
-        "./tests/datasets/fake_1000_from_splink_demos.csv",
-        id="DuckDB load from file",
-    ),
     pytest.param(
         pa.Table.from_pandas(
             pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv")
@@ -219,7 +215,7 @@
     ],
 )
 @mark_with_dialects_including("duckdb")
-def test_duckdb_load_from_file(df):
+def test_duckdb_load_different_tablish_types(df):
     settings = get_settings_dict()
 
     db_api = DuckDBAPI()
@@ -324,3 +320,18 @@ def test_small_example_duckdb(tmp_path):
     linker.training.estimate_parameters_using_expectation_maximisation(blocking_rule)
 
     linker.inference.predict()
+
+
+@mark_with_dialects_including("duckdb")
+def test_duckdb_input_is_duckdbpyrelation():
+    df1 = duckdb.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv")
+    df2 = pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv")
+
+    settings = SettingsCreator(
+        link_type="link_and_dedupe",
+        comparisons=[cl.ExactMatch("first_name")],
+        blocking_rules_to_generate_predictions=[block_on("first_name", "surname")],
+    )
+    db_api = DuckDBAPI(connection=":default:")
+    linker = Linker([df1, df2], settings, db_api)
+    linker.inference.predict()
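
For reference, a minimal end-to-end sketch of the new input path, not part of the patch. It mirrors `test_duckdb_input_is_duckdbpyrelation` above under the same assumptions: the dataset path is the repo's test fixture (illustrative only), and `connection=":default:"` follows the test, presumably so the `Linker` shares DuckDB's default connection on which `duckdb.read_csv` creates its relation. The final step uses `as_duckdbpyrelation`, the method this diff adds to `DuckDBDataFrame`.

```python
import duckdb

import splink.internals.comparison_library as cl
from splink import DuckDBAPI, Linker, SettingsCreator, block_on

# Build the input table as a DuckDBPyRelation rather than a pandas DataFrame.
# duckdb.read_csv uses duckdb's default (module-level) connection.
df = duckdb.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv")

settings = SettingsCreator(
    link_type="dedupe_only",
    comparisons=[cl.ExactMatch("first_name")],
    blocking_rules_to_generate_predictions=[block_on("first_name", "surname")],
)

# ":default:" (newly accepted by validate_duckdb_connection in this diff)
# points the Linker at the same default connection the relation lives on.
db_api = DuckDBAPI(connection=":default:")
linker = Linker(df, settings, db_api)

predictions = linker.inference.predict()

# New in this diff: pull results back out as a DuckDBPyRelation instead of
# pandas, so downstream work can stay inside DuckDB's relational API.
rel = predictions.as_duckdbpyrelation()
print(rel.limit(5))
```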