Skip to content

Commit

Permalink
Merge cb77088 into 7ddb64c
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinL committed Mar 22, 2023
2 parents 7ddb64c + cb77088 commit 5a590cf
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "splink"
version = "3.7.1"
version = "3.7.2"
description = "Fast probabilistic data linkage at scale"
authors = ["Robin Linacre <robinlinacre@hotmail.com>", "Sam Lindsay", "Theodore Manassis", "Tom Hepworth", "Andy Bond"]
license = "MIT"
Expand Down
2 changes: 1 addition & 1 deletion splink/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.7.1"
__version__ = "3.7.2"
26 changes: 24 additions & 2 deletions splink/spark/spark_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(
catalog=None,
database=None,
repartition_after_blocking=False,
num_partitions_on_repartition=100,
num_partitions_on_repartition=None,
):
"""Initialise the linker object, which manages the data linkage process and
holds the data linkage model.
Expand Down Expand Up @@ -117,7 +117,6 @@ def __init__(
self.break_lineage_method = break_lineage_method

self.repartition_after_blocking = repartition_after_blocking
self.num_partitions_on_repartition = num_partitions_on_repartition

input_tables = ensure_is_list(input_table_or_tables)

Expand All @@ -127,6 +126,25 @@ def __init__(

self._get_spark_from_input_tables_if_not_provided(spark, input_tables)

if num_partitions_on_repartition is None:
parallelism_value = 200
try:
parallelism_value = self.spark.conf.get("spark.default.parallelism")
parallelism_value = int(parallelism_value)
except Exception:
pass

# Prefer spark.sql.shuffle.partitions if set
try:
parallelism_value = self.spark.conf.get("spark.sql.shuffle.partitions")
parallelism_value = int(parallelism_value)
except Exception:
pass

self.num_partitions_on_repartition = math.ceil(parallelism_value / 2)
else:
self.num_partitions_on_repartition = num_partitions_on_repartition

self._set_catalog_and_database_if_not_provided(catalog, database)

self._drop_splink_cached_tables()
Expand Down Expand Up @@ -292,10 +310,14 @@ def _repartition_if_needed(self, spark_df, templated_name):
r"__splink__df_representatives",
r"__splink__df_concat_with_tf_sample",
r"__splink__df_concat_with_tf",
r"__splink__df_predict",
]

num_partitions = self.num_partitions_on_repartition

if re.fullmatch(r"__splink__df_predict", templated_name):
num_partitions = math.ceil(self.num_partitions_on_repartition)

if re.fullmatch(r"__splink__df_representatives", templated_name):
num_partitions = math.ceil(self.num_partitions_on_repartition / 6)

Expand Down

0 comments on commit 5a590cf

Please sign in to comment.