Skip to content

Commit

Permalink
ENH: pass thru correct column in spark aggregation if present
Browse files Browse the repository at this point in the history
For spark, add the missing pass-thru column, i.e. the correct-match column.
  • Loading branch information
mbaak committed Apr 20, 2024
1 parent d543d6b commit 86948b9
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 3 deletions.
2 changes: 2 additions & 0 deletions emm/aggregation/base_entity_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def __init__(
preprocessed_col: str = "preprocessed",
gt_name_col: str = "gt_name",
gt_preprocessed_col: str = "gt_preprocessed",
correct_col: str = "correct",
aggregation_method: Literal["max_frequency_nm_score", "mean_score"] = "max_frequency_nm_score",
blacklist: list | None = None,
positive_set_col: str = "positive_set",
Expand All @@ -157,6 +158,7 @@ def __init__(
self.preprocessed_col = preprocessed_col
self.gt_name_col = gt_name_col
self.gt_preprocessed_col = gt_preprocessed_col
self.correct_col = correct_col
self.aggregation_method = aggregation_method
self.blacklist = blacklist or []
self.positive_set_col = positive_set_col
Expand Down
3 changes: 3 additions & 0 deletions emm/aggregation/pandas_entity_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(
preprocessed_col: str = "preprocessed",
gt_name_col: str = "gt_name",
gt_preprocessed_col: str = "gt_preprocessed",
correct_col: str = "correct",
aggregation_method: Literal["max_frequency_nm_score", "mean_score"] = "max_frequency_nm_score",
blacklist: list[str] | None = None,
) -> None:
Expand Down Expand Up @@ -80,6 +81,7 @@ def __init__(
preprocessed_col: Name of column of preprocessed input
gt_name_col: ground truth name column, default is "gt_name".
gt_preprocessed_col: column name of preprocessed ground truth names, default is "preprocessed".
correct_col: column indicating correct matches, if present. default is "correct". optional.
aggregation_method: default is "max_frequency_nm_score", alternative is "mean_score".
blacklist: blacklist of names to skip in clustering.
"""
Expand All @@ -97,6 +99,7 @@ def __init__(
preprocessed_col=preprocessed_col,
gt_name_col=gt_name_col,
gt_preprocessed_col=gt_preprocessed_col,
correct_col=correct_col,
aggregation_method=aggregation_method,
blacklist=blacklist or [],
)
Expand Down
14 changes: 11 additions & 3 deletions emm/aggregation/spark_entity_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.sql.functions import col, lit
from pyspark.sql.pandas.functions import PandasUDFType, pandas_udf
from pyspark.sql.types import FloatType, IntegerType, StringType, StructField
from pyspark.sql.types import BooleanType, FloatType, IntegerType, StringType, StructField

from emm.aggregation.base_entity_aggregation import BaseEntityAggregation, matching_max_candidate
from emm.helper.spark_custom_reader_writer import SparkReadable, SparkWriteable
Expand All @@ -49,7 +49,10 @@ class SparkEntityAggregation(
"uid_col",
"freq_col",
"output_col",
"processed_col",
"preprocessed_col",
"gt_name_col",
"gt_preprocessed_col",
"correct_col",
"aggregation_method",
"blacklist",
)
Expand All @@ -66,6 +69,7 @@ def __init__(
preprocessed_col: str = "preprocessed",
gt_name_col: str = "gt_name",
gt_preprocessed_col: str = "gt_preprocessed",
correct_col: str = "correct",
aggregation_method: Literal["max_frequency_nm_score", "mean_score"] = "max_frequency_nm_score",
blacklist: list | None = None,
) -> None:
Expand Down Expand Up @@ -98,6 +102,7 @@ def __init__(
preprocessed_col: Name of column of preprocessed input, default is "preprocessed".
gt_name_col: ground truth name column, default is "gt_name".
gt_preprocessed_col: column name of preprocessed ground truth names. default is "gt_preprocessed".
correct_col: column indicating correct matches, pass-thru if present. default is "correct". optional.
aggregation_method: default is "max_frequency_nm_score", alternative is "mean_score".
blacklist: blacklist of names to skip in clustering.
"""
Expand All @@ -115,6 +120,7 @@ def __init__(
preprocessed_col=preprocessed_col,
gt_name_col=gt_name_col,
gt_preprocessed_col=gt_preprocessed_col,
correct_col=correct_col,
blacklist=blacklist or [],
)

Expand Down Expand Up @@ -145,6 +151,9 @@ def _transform(self, dataframe):
schema.add(StructField(self.freq_col, IntegerType(), True))
schema.add(StructField(self.gt_name_col, StringType(), True))
schema.add(StructField(self.gt_preprocessed_col, StringType(), True))
# pass through the correct-match column if present on training or test sets; useful for accuracy testing
if self.correct_col in dataframe.columns:
schema.add(StructField(self.correct_col, BooleanType(), True))

@pandas_udf(schema, PandasUDFType.GROUPED_MAP)
def matching_max_candidate_wrapper(_, df) -> pd.DataFrame:
Expand All @@ -158,7 +167,6 @@ def matching_max_candidate_wrapper(_, df) -> pd.DataFrame:
output_col=self.output_col,
aggregation_method=self.aggregation_method,
)

return df[[c.name for c in schema]]

# remove all irrelevant non-matches before applying account matching
Expand Down

0 comments on commit 86948b9

Please sign in to comment.