[SPARK-40399][PS] Make pearson correlation in DataFrame.corr support missing values and min_periods
#37845
Changes from all commits
```diff
@@ -1417,15 +1417,23 @@ def aggregate(self, func: Union[List[str], Dict[Name, List[str]]]) -> "DataFrame":

     agg = aggregate

-    def corr(self, method: str = "pearson") -> "DataFrame":
+    def corr(self, method: str = "pearson", min_periods: Optional[int] = None) -> "DataFrame":
         """
         Compute pairwise correlation of columns, excluding NA/null values.

         .. versionadded:: 3.3.0

         Parameters
         ----------
         method : {'pearson', 'spearman'}
             * pearson : standard correlation coefficient
             * spearman : Spearman rank correlation
+        min_periods : int, optional
+            Minimum number of observations required per pair of columns
+            to have a valid result. Currently only available for Pearson
+            correlation.
+
+            .. versionadded:: 3.4.0

         Returns
         -------

@@ -1454,11 +1462,202 @@ def corr(self, method: str = "pearson") -> "DataFrame":
         There are behavior differences between pandas-on-Spark and pandas.

         * the `method` argument only accepts 'pearson', 'spearman'
-        * the data should not contain NaNs. pandas-on-Spark will return an error.
-        * pandas-on-Spark doesn't support the following argument(s).
+        * if the `method` is `spearman`, the data should not contain NaNs.
+        * if the `method` is `spearman`, `min_periods` argument is not supported.
+        """
+        if method not in ["pearson", "spearman", "kendall"]:
+            raise ValueError(f"Invalid method {method}")
+        if method == "kendall":
+            raise NotImplementedError("method doesn't support kendall for now")
+        if min_periods is not None and not isinstance(min_periods, int):
+            raise TypeError(f"Invalid min_periods type {type(min_periods).__name__}")
```
Comment on lines +1472 to +1473

Contributor:
Seems like pandas allows float:

```python
>>> pdf.corr('pearson', min_periods=1.4)
          dogs      cats
dogs  1.000000 -0.851064
cats -0.851064  1.000000
```

But I'm not sure if it's intended behavior or not, since they raise a `TypeError` for other types:

```python
>>> pdf.corr('pearson', min_periods='a')
Traceback (most recent call last):
...
TypeError: an integer is required
```

Contributor (Author):
I am also not sure, but the type of ...
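For context, here is a minimal sketch of how the check in this diff interacts with the behavior the reviewer describes; the dogs/cats data is borrowed from the comment above, and the exact error message assumes the `TypeError` added in this PR:

```python
# Illustration only: contrasts pandas' lenient handling of a float
# min_periods with the strict isinstance(min_periods, int) check above.
import pandas as pd
import pyspark.pandas as ps

pdf = pd.DataFrame({"dogs": [0.2, 0.0, 0.6, 0.2], "cats": [0.3, 0.6, 0.0, 0.1]})
psdf = ps.from_pandas(pdf)

pdf.corr("pearson", min_periods=1.4)   # pandas: the float is silently accepted
psdf.corr("pearson", min_periods=1.4)  # pandas-on-Spark: raises
                                       # TypeError("Invalid min_periods type float")
```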
```diff
+        if min_periods is not None and method == "spearman":
+            raise NotImplementedError("min_periods doesn't support spearman for now")
+
+        if method == "pearson":
+            min_periods = 1 if min_periods is None else min_periods
+            internal = self._internal.resolved_copy
+            numeric_labels = [
+                label
+                for label in internal.column_labels
+                if isinstance(internal.spark_type_for(label), (NumericType, BooleanType))
+            ]
+            numeric_scols: List[Column] = [
+                internal.spark_column_for(label).cast("double") for label in numeric_labels
+            ]
+            numeric_col_names: List[str] = [name_like_string(label) for label in numeric_labels]
+            num_scols = len(numeric_scols)
+
+            sdf = internal.spark_frame
+            tmp_index_1_col_name = verify_temp_column_name(sdf, "__tmp_index_1_col__")
+            tmp_index_2_col_name = verify_temp_column_name(sdf, "__tmp_index_2_col__")
+            tmp_value_1_col_name = verify_temp_column_name(sdf, "__tmp_value_1_col__")
+            tmp_value_2_col_name = verify_temp_column_name(sdf, "__tmp_value_2_col__")
+
+            # simple dataset
+            # +---+---+----+
+            # |  A|  B|   C|
+            # +---+---+----+
+            # |  1|  2| 3.0|
+            # |  4|  1|null|
+            # +---+---+----+
+
+            pair_scols: List[Column] = []
+            for i in range(0, num_scols):
```
Contributor:
nit: we can omit the `0`:

```diff
-            for i in range(0, num_scols):
+            for i in range(num_scols):
```

Either looks okay, though.

Contributor (Author):
just because other places use `range(0, ...)`
```diff
+                for j in range(i, num_scols):
+                    pair_scols.append(
+                        F.struct(
+                            F.lit(i).alias(tmp_index_1_col_name),
+                            F.lit(j).alias(tmp_index_2_col_name),
+                            numeric_scols[i].alias(tmp_value_1_col_name),
+                            numeric_scols[j].alias(tmp_value_2_col_name),
+                        )
+                    )
+
+            # +-------------------+-------------------+-------------------+-------------------+
+            # |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_value_1_col__|__tmp_value_2_col__|
+            # +-------------------+-------------------+-------------------+-------------------+
+            # |                  0|                  0|                1.0|                1.0|
+            # |                  0|                  1|                1.0|                2.0|
+            # |                  0|                  2|                1.0|                3.0|
+            # |                  1|                  1|                2.0|                2.0|
+            # |                  1|                  2|                2.0|                3.0|
+            # |                  2|                  2|                3.0|                3.0|
+            # |                  0|                  0|                4.0|                4.0|
+            # |                  0|                  1|                4.0|                1.0|
+            # |                  0|                  2|                4.0|               null|
+            # |                  1|                  1|                1.0|                1.0|
+            # |                  1|                  2|                1.0|               null|
+            # |                  2|                  2|               null|               null|
+            # +-------------------+-------------------+-------------------+-------------------+
+            tmp_tuple_col_name = verify_temp_column_name(sdf, "__tmp_tuple_col__")
+            sdf = sdf.select(F.explode(F.array(*pair_scols)).alias(tmp_tuple_col_name)).select(
+                F.col(f"{tmp_tuple_col_name}.{tmp_index_1_col_name}").alias(tmp_index_1_col_name),
+                F.col(f"{tmp_tuple_col_name}.{tmp_index_2_col_name}").alias(tmp_index_2_col_name),
+                F.col(f"{tmp_tuple_col_name}.{tmp_value_1_col_name}").alias(tmp_value_1_col_name),
+                F.col(f"{tmp_tuple_col_name}.{tmp_value_2_col_name}").alias(tmp_value_2_col_name),
+            )
+
+            # +-------------------+-------------------+------------------------+-----------------+
+            # |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_pearson_corr_col__|__tmp_count_col__|
+            # +-------------------+-------------------+------------------------+-----------------+
+            # |                  2|                  2|                    null|                1|
+            # |                  1|                  2|                    null|                1|
+            # |                  1|                  1|                     1.0|                2|
+            # |                  0|                  0|                     1.0|                2|
+            # |                  0|                  1|                    -1.0|                2|
+            # |                  0|                  2|                    null|                1|
+            # +-------------------+-------------------+------------------------+-----------------+
+            tmp_corr_col_name = verify_temp_column_name(sdf, "__tmp_pearson_corr_col__")
+            tmp_count_col_name = verify_temp_column_name(sdf, "__tmp_count_col__")
+            sdf = sdf.groupby(tmp_index_1_col_name, tmp_index_2_col_name).agg(
+                F.corr(tmp_value_1_col_name, tmp_value_2_col_name).alias(tmp_corr_col_name),
+                F.count(
+                    F.when(
+                        F.col(tmp_value_1_col_name).isNotNull()
+                        & F.col(tmp_value_2_col_name).isNotNull(),
+                        1,
+                    )
+                ).alias(tmp_count_col_name),
+            )
+
+            # +-------------------+-------------------+------------------------+
+            # |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_pearson_corr_col__|
+            # +-------------------+-------------------+------------------------+
+            # |                  2|                  2|                    null|
+            # |                  1|                  2|                    null|
+            # |                  2|                  1|                    null|
+            # |                  1|                  1|                     1.0|
+            # |                  0|                  0|                     1.0|
+            # |                  0|                  1|                    -1.0|
+            # |                  1|                  0|                    -1.0|
+            # |                  0|                  2|                    null|
+            # |                  2|                  0|                    null|
+            # +-------------------+-------------------+------------------------+
+            sdf = (
+                sdf.withColumn(
+                    tmp_corr_col_name,
+                    F.when(
+                        F.col(tmp_count_col_name) >= min_periods, F.col(tmp_corr_col_name)
+                    ).otherwise(F.lit(None)),
+                )
+                .withColumn(
+                    tmp_tuple_col_name,
+                    F.explode(
+                        F.when(
+                            F.col(tmp_index_1_col_name) == F.col(tmp_index_2_col_name),
+                            F.lit([0]),
+                        ).otherwise(F.lit([0, 1]))
+                    ),
+                )
+                .select(
+                    F.when(F.col(tmp_tuple_col_name) == 0, F.col(tmp_index_1_col_name))
+                    .otherwise(F.col(tmp_index_2_col_name))
+                    .alias(tmp_index_1_col_name),
+                    F.when(F.col(tmp_tuple_col_name) == 0, F.col(tmp_index_2_col_name))
+                    .otherwise(F.col(tmp_index_1_col_name))
+                    .alias(tmp_index_2_col_name),
+                    F.col(tmp_corr_col_name),
+                )
+            )
+
+            # +-------------------+--------------------+
+            # |__tmp_index_1_col__|   __tmp_array_col__|
+            # +-------------------+--------------------+
+            # |                  0|[{0, 1.0}, {1, -1...|
+            # |                  1|[{0, -1.0}, {1, 1...|
+            # |                  2|[{0, null}, {1, n...|
+            # +-------------------+--------------------+
+            tmp_array_col_name = verify_temp_column_name(sdf, "__tmp_array_col__")
+            sdf = (
+                sdf.groupby(tmp_index_1_col_name)
+                .agg(
+                    F.array_sort(
+                        F.collect_list(
+                            F.struct(F.col(tmp_index_2_col_name), F.col(tmp_corr_col_name))
+                        )
+                    ).alias(tmp_array_col_name)
+                )
+                .orderBy(tmp_index_1_col_name)
+            )
+
+            for i in range(0, num_scols):
+                sdf = sdf.withColumn(
+                    tmp_tuple_col_name, F.get(F.col(tmp_array_col_name), i)
+                ).withColumn(
+                    numeric_col_names[i],
+                    F.col(f"{tmp_tuple_col_name}.{tmp_corr_col_name}"),
+                )
+
+            index_col_names: List[str] = []
+            if internal.column_labels_level > 1:
+                for level in range(0, internal.column_labels_level):
+                    index_col_name = SPARK_INDEX_NAME_FORMAT(level)
+                    indices = [label[level] for label in numeric_labels]
+                    sdf = sdf.withColumn(
+                        index_col_name, F.get(F.lit(indices), F.col(tmp_index_1_col_name))
+                    )
+                    index_col_names.append(index_col_name)
+            else:
+                sdf = sdf.withColumn(
+                    SPARK_DEFAULT_INDEX_NAME,
+                    F.get(F.lit(numeric_col_names), F.col(tmp_index_1_col_name)),
+                )
+                index_col_names = [SPARK_DEFAULT_INDEX_NAME]
+
+            sdf = sdf.select(*index_col_names, *numeric_col_names)
+
+            return DataFrame(
+                InternalFrame(
+                    spark_frame=sdf,
+                    index_spark_columns=[
+                        scol_for(sdf, index_col_name) for index_col_name in index_col_names
+                    ],
+                    column_labels=numeric_labels,
+                    column_label_names=internal.column_label_names,
+                )
+            )
+
-          * `min_periods` argument is not supported
-        """
         return cast(DataFrame, ps.from_pandas(corr(self, method)))

         # TODO: add axis parameter and support more methods
```
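Stepping back, the diff above implements Pearson correlation by building every column pair as a struct, exploding the pairs into a long (i, j, x, y) table, aggregating with `F.corr` plus a pairwise non-null count, and nulling out results whose count falls below `min_periods`. The condensed sketch below shows the same strategy in plain PySpark; the names (`pairwise_corr`, `i`, `j`, `x`, `y`) are invented for illustration, and this is not the PR's code:

```python
# A condensed sketch (not the PR's implementation) of the pairwise-explode
# strategy: one row per (i, j, value_i, value_j), then F.corr + a pair count.
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()

def pairwise_corr(sdf, cols, min_periods=1):
    # Build a struct per column pair (i <= j), so NA handling is pairwise.
    pairs = [
        F.struct(
            F.lit(i).alias("i"),
            F.lit(j).alias("j"),
            F.col(cols[i]).cast("double").alias("x"),
            F.col(cols[j]).cast("double").alias("y"),
        )
        for i in range(len(cols))
        for j in range(i, len(cols))
    ]
    # Explode into a long table: one row per (pair, input row).
    exploded = sdf.select(F.explode(F.array(*pairs)).alias("p")).select("p.*")
    return (
        exploded.groupby("i", "j")
        .agg(
            F.corr("x", "y").alias("corr"),
            # Count only rows where both sides of the pair are non-null.
            F.count(F.when(F.col("x").isNotNull() & F.col("y").isNotNull(), 1)).alias("cnt"),
        )
        .select(
            "i",
            "j",
            # Pairs observed fewer than min_periods times become null, as in the PR.
            F.when(F.col("cnt") >= min_periods, F.col("corr")).alias("corr"),
        )
    )

sdf = spark.createDataFrame([(1, 2, 3.0), (4, 1, None)], ["A", "B", "C"])
pairwise_corr(sdf, ["A", "B", "C"], min_periods=2).show()
```

On the two-row A/B/C dataset from the diff's comments, this yields -1.0 for (A, B), 1.0 for the (A, A) and (B, B) diagonal entries, and null for every pair involving C, matching the example tables above.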
```diff
@@ -257,6 +257,40 @@ def test_skew_kurt_numerical_stability(self):
         self.assert_eq(psdf.skew(), pdf.skew(), almost=True)
         self.assert_eq(psdf.kurt(), pdf.kurt(), almost=True)

+    def test_dataframe_corr(self):
+        # existing 'test_corr' is mixed by df.corr and ser.corr, will delete 'test_corr'
+        # when we have separate tests for df.corr and ser.corr
+        pdf = makeMissingDataframe(0.3, 42)
+        psdf = ps.from_pandas(pdf)
+
+        with self.assertRaisesRegex(ValueError, "Invalid method"):
+            psdf.corr("std")
+        with self.assertRaisesRegex(NotImplementedError, "kendall for now"):
+            psdf.corr("kendall")
+        with self.assertRaisesRegex(TypeError, "Invalid min_periods type"):
+            psdf.corr(min_periods="3")
+        with self.assertRaisesRegex(NotImplementedError, "spearman for now"):
+            psdf.corr(method="spearman", min_periods=3)
+
+        self.assert_eq(psdf.corr(), pdf.corr(), check_exact=False)
+        self.assert_eq(psdf.corr(min_periods=1), pdf.corr(min_periods=1), check_exact=False)
+        self.assert_eq(psdf.corr(min_periods=3), pdf.corr(min_periods=3), check_exact=False)
+        self.assert_eq(
+            (psdf + 1).corr(min_periods=2), (pdf + 1).corr(min_periods=2), check_exact=False
+        )
+
+        # multi-index columns
+        columns = pd.MultiIndex.from_tuples([("X", "A"), ("X", "B"), ("Y", "C"), ("Z", "D")])
+        pdf.columns = columns
+        psdf.columns = columns
+
+        self.assert_eq(psdf.corr(), pdf.corr(), check_exact=False)
+        self.assert_eq(psdf.corr(min_periods=1), pdf.corr(min_periods=1), check_exact=False)
+        self.assert_eq(psdf.corr(min_periods=3), pdf.corr(min_periods=3), check_exact=False)
```
Contributor:
Can we also test for chained operations?

```python
self.assert_eq((psdf + 1).corr(), (pdf + 1).corr(), check_exact=False)
```

Contributor (Author):
nice, let me add it
```diff
+        self.assert_eq(
+            (psdf + 1).corr(min_periods=2), (pdf + 1).corr(min_periods=2), check_exact=False
+        )
+
     def test_corr(self):
```
Contributor (Author):
existing test is mixed by df.corr and ser.corr

Contributor:
Can we comment this at the top of the test?

Contributor (Author):
sure
```diff
         # Disable arrow execution since corr() is using UDT internally which is not supported.
         with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}):
```
Contributor:
Nice
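As a usage sketch of what these tests exercise (the small frame below is illustrative; the tests themselves use the `makeMissingDataframe` fixture): NA values are excluded pairwise, and any column pair with fewer than `min_periods` joint observations comes back as NaN:

```python
# Illustrative only: a frame with a missing value in C, mirroring what the
# tests above assert against pandas.
import pyspark.pandas as ps

psdf = ps.DataFrame({"A": [1.0, 4.0, 2.0], "B": [2.0, 1.0, 3.0], "C": [3.0, None, 5.0]})

psdf.corr()               # pearson by default; pairs involving C use only 2 rows
psdf.corr(min_periods=3)  # pairs involving C have < 3 joint observations -> NaN
```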