Merge pull request #2375 from moj-analytical-services/duckdb_py_relation
Support duckdbpyrelation as input type
RobinL authored Sep 3, 2024
2 parents 69278a6 + fababd9 commit 8e41c4e
Showing 7 changed files with 37 additions and 51 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md

@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ## Unreleased

+### Added
+
+- When using DuckDB, you can now pass `duckdb.DuckDBPyRelation`s as input tables to the `Linker` ([#2375](https://github.com/moj-analytical-services/splink/pull/2375))
+
 ### Fixed

 - Completeness chart now works correctly with indexed columns in spark ([#2309](https://github.com/moj-analytical-services/splink/pull/2309))
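For context, a minimal usage sketch of the new input path, mirroring the test added at the bottom of this PR (the CSV path is the repo's test fixture; `dedupe_only` and the comparison are chosen here for brevity):

import duckdb
import splink.internals.comparison_library as cl
from splink import DuckDBAPI, Linker, SettingsCreator, block_on

# duckdb.read_csv returns a DuckDBPyRelation on duckdb's default connection
rel = duckdb.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv")

settings = SettingsCreator(
    link_type="dedupe_only",
    comparisons=[cl.ExactMatch("first_name")],
    blocking_rules_to_generate_predictions=[block_on("first_name", "surname")],
)

# ":default:" (newly accepted in duckdb_helpers.py below) keeps the Linker on
# the same default connection the relation was created on
db_api = DuckDBAPI(connection=":default:")
linker = Linker(rel, settings, db_api)
linker.inference.predict()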
18 changes: 2 additions & 16 deletions splink/internals/athena/database_api.py

@@ -3,13 +3,13 @@
 import json
 import logging
 import os
-from typing import Any, Sequence
+from typing import Any

 import awswrangler as wr
 import boto3
 import pandas as pd

-from ..database_api import AcceptableInputTableType, DatabaseAPI
+from ..database_api import DatabaseAPI
 from ..dialects import AthenaDialect
 from ..sql_transform import sqlglot_transform_sql
 from .athena_helpers.athena_transforms import cast_concat_as_varchar
@@ -250,17 +250,3 @@ def accepted_df_dtypes(self):
         except ImportError:
             pass
         return accepted_df_dtypes
-
-    def load_from_file(self, file_path: str) -> str:
-        raise NotImplementedError(
-            "Loading from file is not supported for Athena. "
-            "Please use the `table` method to load data."
-        )
-
-    def process_input_tables(
-        self, input_tables: Sequence[AcceptableInputTableType]
-    ) -> Sequence[AcceptableInputTableType]:
-        input_tables = super().process_input_tables(input_tables)
-        return [
-            self.load_from_file(t) if isinstance(t, str) else t for t in input_tables
-        ]
14 changes: 1 addition & 13 deletions splink/internals/duckdb/database_api.py

@@ -1,7 +1,7 @@
 from __future__ import annotations

 import logging
-from typing import Sequence, Union
+from typing import Union

 import duckdb
 import pandas as pd
@@ -14,7 +14,6 @@
 from .dataframe import DuckDBDataFrame
 from .duckdb_helpers.duckdb_helpers import (
     create_temporary_duckdb_connection,
-    duckdb_load_from_file,
     validate_duckdb_connection,
 )

@@ -93,9 +92,6 @@ def table_exists_in_database(self, table_name):
             return False
         return True

-    def load_from_file(self, file_path: str) -> str:
-        return duckdb_load_from_file(file_path)
-
     def _execute_sql_against_backend(self, final_sql: str) -> duckdb.DuckDBPyRelation:
         return self._con.sql(final_sql)

@@ -110,11 +106,3 @@ def accepted_df_dtypes(self):
         except ImportError:
             pass
         return accepted_df_dtypes
-
-    def process_input_tables(
-        self, input_tables: Sequence[AcceptableInputTableType]
-    ) -> Sequence[AcceptableInputTableType]:
-        input_tables = super().process_input_tables(input_tables)
-        return [
-            self.load_from_file(t) if isinstance(t, str) else t for t in input_tables
-        ]
8 changes: 8 additions & 0 deletions splink/internals/duckdb/dataframe.py

@@ -4,6 +4,7 @@
 import os
 from typing import TYPE_CHECKING

+from duckdb import DuckDBPyRelation
 from pandas import DataFrame as pd_DataFrame

 from splink.internals.input_column import InputColumn
@@ -50,6 +51,13 @@ def as_pandas_dataframe(self, limit: int = None) -> pd_DataFrame:

         return self.db_api._execute_sql_against_backend(sql).to_df()

+    def as_duckdbpyrelation(self, limit: int = None) -> DuckDBPyRelation:
+        sql = f"select * from {self.physical_name}"
+        if limit:
+            sql += f" limit {limit}"
+
+        return self.db_api._execute_sql_against_backend(sql)
+
     def to_parquet(self, filepath, overwrite=False):
         if not overwrite:
             self.check_file_exists(filepath)
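Since `_execute_sql_against_backend` returns a `duckdb.DuckDBPyRelation`, the new accessor lets results stay inside DuckDB. A hypothetical follow-on (the `df_predict` name and 0.9 threshold are illustrative, not part of this PR):

# Assumes a configured linker as in the tests below
df_predict = linker.inference.predict()
rel = df_predict.as_duckdbpyrelation()

# DuckDBPyRelation is lazily evaluated; the filter runs natively in DuckDB
rel.filter("match_probability > 0.9").show()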
15 changes: 1 addition & 14 deletions splink/internals/duckdb/duckdb_helpers/duckdb_helpers.py

@@ -1,7 +1,6 @@
 import os
 import tempfile
 import uuid
-from pathlib import Path

 import duckdb

@@ -26,7 +25,7 @@ def validate_duckdb_connection(connection, logger):

     connection = connection.lower()

-    if connection in [":memory:", ":temporary:"]:
+    if connection in [":memory:", ":temporary:", ":default:"]:
         return

     suffixes = (".duckdb", ".db")
@@ -49,15 +48,3 @@ def create_temporary_duckdb_connection(self):
     path = os.path.join(self._temp_dir.name, f"{fname}.duckdb")
     con = duckdb.connect(database=path, read_only=False)
     return con
-
-
-def duckdb_load_from_file(path: str) -> str:
-    file_functions = {
-        ".csv": f"read_csv_auto('{path}')",
-        ".parquet": f"read_parquet('{path}')",
-    }
-    file_ext = Path(path).suffix
-    if file_ext in file_functions.keys():
-        return file_functions[file_ext]
-    else:
-        return path
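A note on the new `:default:` option: it presumably maps to duckdb's shared default connection, i.e. the one that module-level helpers such as `duckdb.read_csv` and `duckdb.sql` use. Relations are bound to the connection that created them, so the Linker must share that connection to see them:

db_api = DuckDBAPI(connection=":default:")  # assumed to wrap duckdb.default_connection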
4 changes: 3 additions & 1 deletion splink/internals/linker.py

@@ -306,9 +306,11 @@ def _register_input_tables(
         input_tables: Sequence[AcceptableInputTableType],
         input_aliases: Optional[str | List[str]],
     ) -> Dict[str, SplinkDataFrame]:
+        input_tables_list = ensure_is_list(input_tables)
+
         if input_aliases is None:
             input_table_aliases = [
-                f"__splink__input_table_{i}" for i, _ in enumerate(input_tables)
+                f"__splink__input_table_{i}" for i, _ in enumerate(input_tables_list)
             ]
             overwrite = True
         else:
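`ensure_is_list` is splink's existing normalisation helper; based on its name and this usage, its behaviour is presumably equivalent to the sketch below. The normalisation matters here because a single `DuckDBPyRelation` (or DataFrame) passed as `input_tables` should be treated as one table, not enumerated element by element:

def ensure_is_list(x):
    # assumed behaviour: wrap a single input table in a list, pass lists through
    return x if isinstance(x, list) else [x]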
25 changes: 18 additions & 7 deletions tests/test_full_example_duckdb.py

@@ -1,5 +1,6 @@
 import os

+import duckdb
 import pandas as pd
 import pyarrow as pa
 import pyarrow.csv as pa_csv
@@ -8,10 +9,9 @@

 import splink.internals.comparison_level_library as cll
 import splink.internals.comparison_library as cl
+from splink import DuckDBAPI, Linker, SettingsCreator, block_on
 from splink.blocking_analysis import count_comparisons_from_blocking_rule
 from splink.exploratory import completeness_chart, profile_columns
-from splink.internals.duckdb.database_api import DuckDBAPI
-from splink.internals.linker import Linker

 from .basic_settings import get_settings_dict, name_comparison
 from .decorator import mark_with_dialects_including
@@ -197,10 +197,6 @@ def test_link_only(input, source_l, source_r):
             pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv"),
             id="DuckDB link from pandas df",
         ),
-        pytest.param(
-            "./tests/datasets/fake_1000_from_splink_demos.csv",
-            id="DuckDB load from file",
-        ),
         pytest.param(
             pa.Table.from_pandas(
                 pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv")
@@ -219,7 +215,7 @@ def test_link_only(input, source_l, source_r):
     ],
 )
 @mark_with_dialects_including("duckdb")
-def test_duckdb_load_from_file(df):
+def test_duckdb_load_different_tablish_types(df):
     settings = get_settings_dict()

     db_api = DuckDBAPI()
@@ -324,3 +320,18 @@ def test_small_example_duckdb(tmp_path):
     linker.training.estimate_parameters_using_expectation_maximisation(blocking_rule)

     linker.inference.predict()
+
+
+@mark_with_dialects_including("duckdb")
+def test_duckdb_input_is_duckdbpyrelation():
+    df1 = duckdb.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv")
+    df2 = pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv")
+
+    settings = SettingsCreator(
+        link_type="link_and_dedupe",
+        comparisons=[cl.ExactMatch("first_name")],
+        blocking_rules_to_generate_predictions=[block_on("first_name", "surname")],
+    )
+    db_api = DuckDBAPI(connection=":default:")
+    linker = Linker([df1, df2], settings, db_api)
+    linker.inference.predict()
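To run just the new test locally (the dialect decorator may additionally require the repo's own pytest options):

pytest tests/test_full_example_duckdb.py::test_duckdb_input_is_duckdbpyrelation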
