diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index cf14a5482660c..f6aead93f9eec 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -121,7 +121,6 @@
     SPARK_INDEX_NAME_PATTERN,
 )
 from pyspark.pandas.missing.frame import _MissingPandasLikeDataFrame
-from pyspark.pandas.ml import corr
 from pyspark.pandas.typedef.typehints import (
     as_spark_type,
     infer_return_type,
@@ -1430,8 +1429,7 @@ def corr(self, method: str = "pearson", min_periods: Optional[int] = None) -> "D
             * spearman : Spearman rank correlation
         min_periods : int, optional
             Minimum number of observations required per pair of columns
-            to have a valid result. Currently only available for Pearson
-            correlation.
+            to have a valid result.
 
             .. versionadded:: 3.4.0
 
@@ -1462,8 +1460,6 @@ def corr(self, method: str = "pearson", min_periods: Optional[int] = None) -> "D
         There are behavior differences between pandas-on-Spark and pandas.
 
         * the `method` argument only accepts 'pearson', 'spearman'
-        * if the `method` is `spearman`, the data should not contain NaNs.
-        * if the `method` is `spearman`, `min_periods` argument is not supported.
         """
         if method not in ["pearson", "spearman", "kendall"]:
             raise ValueError(f"Invalid method {method}")
@@ -1471,194 +1467,251 @@ def corr(self, method: str = "pearson", min_periods: Optional[int] = None) -> "D
             raise NotImplementedError("method doesn't support kendall for now")
         if min_periods is not None and not isinstance(min_periods, int):
             raise TypeError(f"Invalid min_periods type {type(min_periods).__name__}")
-        if min_periods is not None and method == "spearman":
-            raise NotImplementedError("min_periods doesn't support spearman for now")
-
-        if method == "pearson":
-            min_periods = 1 if min_periods is None else min_periods
-            internal = self._internal.resolved_copy
-            numeric_labels = [
-                label
-                for label in internal.column_labels
-                if isinstance(internal.spark_type_for(label), (NumericType, BooleanType))
-            ]
-            numeric_scols: List[Column] = [
-                internal.spark_column_for(label).cast("double") for label in numeric_labels
-            ]
-            numeric_col_names: List[str] = [name_like_string(label) for label in numeric_labels]
-            num_scols = len(numeric_scols)
-            sdf = internal.spark_frame
-            tmp_index_1_col_name = verify_temp_column_name(sdf, "__tmp_index_1_col__")
-            tmp_index_2_col_name = verify_temp_column_name(sdf, "__tmp_index_2_col__")
-            tmp_value_1_col_name = verify_temp_column_name(sdf, "__tmp_value_1_col__")
-            tmp_value_2_col_name = verify_temp_column_name(sdf, "__tmp_value_2_col__")
-
-            # simple dataset
-            # +---+---+----+
-            # |  A|  B|   C|
-            # +---+---+----+
-            # |  1|  2| 3.0|
-            # |  4|  1|null|
-            # +---+---+----+
-
-            pair_scols: List[Column] = []
-            for i in range(0, num_scols):
-                for j in range(i, num_scols):
-                    pair_scols.append(
-                        F.struct(
-                            F.lit(i).alias(tmp_index_1_col_name),
-                            F.lit(j).alias(tmp_index_2_col_name),
-                            numeric_scols[i].alias(tmp_value_1_col_name),
-                            numeric_scols[j].alias(tmp_value_2_col_name),
-                        )
-                    )
+        min_periods = 1 if min_periods is None else min_periods
+        internal = self._internal.resolved_copy
+        numeric_labels = [
+            label
+            for label in internal.column_labels
+            if isinstance(internal.spark_type_for(label), (NumericType, BooleanType))
+        ]
+        numeric_scols: List[Column] = [
+            internal.spark_column_for(label).cast("double") for label in numeric_labels
+        ]
+        numeric_col_names: List[str] = [name_like_string(label) for label in numeric_labels]
+        num_scols = len(numeric_scols)
+
+        sdf = internal.spark_frame
+        tmp_index_1_col_name = verify_temp_column_name(sdf, "__tmp_index_1_col__")
+        tmp_index_2_col_name = verify_temp_column_name(sdf, "__tmp_index_2_col__")
+        tmp_value_1_col_name = verify_temp_column_name(sdf, "__tmp_value_1_col__")
+        tmp_value_2_col_name = verify_temp_column_name(sdf, "__tmp_value_2_col__")
+
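The loop that follows builds one struct per column pair (i, j) with i <= j and explodes them into long format, so every pairwise correlation can be computed in a single aggregation. A minimal standalone sketch of that trick, assuming only a local SparkSession; the toy data and short column names are illustrative, not the pandas-on-Spark internals:

    # Sketch of the pair-explosion step (assumption: local SparkSession, toy data).
    from pyspark.sql import SparkSession, functions as F

    spark = SparkSession.builder.getOrCreate()
    sdf = spark.createDataFrame([(1, 2, 3.0), (4, 1, None)], ["A", "B", "C"])

    cols = [F.col(c).cast("double") for c in sdf.columns]
    pairs = [
        F.struct(
            F.lit(i).alias("i"),
            F.lit(j).alias("j"),
            cols[i].alias("x"),
            cols[j].alias("y"),
        )
        for i in range(len(cols))
        for j in range(i, len(cols))
    ]
    # One row per (pair, input row): the long format shown in the tables below.
    sdf.select(F.explode(F.array(*pairs)).alias("p")).select("p.*").show()
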
+        # simple dataset
+        # +---+---+----+
+        # |  A|  B|   C|
+        # +---+---+----+
+        # |  1|  2| 3.0|
+        # |  4|  1|null|
+        # +---+---+----+
+
+        pair_scols: List[Column] = []
+        for i in range(0, num_scols):
+            for j in range(i, num_scols):
+                pair_scols.append(
+                    F.struct(
+                        F.lit(i).alias(tmp_index_1_col_name),
+                        F.lit(j).alias(tmp_index_2_col_name),
+                        numeric_scols[i].alias(tmp_value_1_col_name),
+                        numeric_scols[j].alias(tmp_value_2_col_name),
+                    )
+                )
 
-            # +-------------------+-------------------+-------------------+-------------------+
-            # |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_value_1_col__|__tmp_value_2_col__|
-            # +-------------------+-------------------+-------------------+-------------------+
-            # |                  0|                  0|                1.0|                1.0|
-            # |                  0|                  1|                1.0|                2.0|
-            # |                  0|                  2|                1.0|                3.0|
-            # |                  1|                  1|                2.0|                2.0|
-            # |                  1|                  2|                2.0|                3.0|
-            # |                  2|                  2|                3.0|                3.0|
-            # |                  0|                  0|                4.0|                4.0|
-            # |                  0|                  1|                4.0|                1.0|
-            # |                  0|                  2|                4.0|               null|
-            # |                  1|                  1|                1.0|                1.0|
-            # |                  1|                  2|                1.0|               null|
-            # |                  2|                  2|               null|               null|
-            # +-------------------+-------------------+-------------------+-------------------+
-            tmp_tuple_col_name = verify_temp_column_name(sdf, "__tmp_tuple_col__")
-            sdf = sdf.select(F.explode(F.array(*pair_scols)).alias(tmp_tuple_col_name)).select(
-                F.col(f"{tmp_tuple_col_name}.{tmp_index_1_col_name}").alias(tmp_index_1_col_name),
-                F.col(f"{tmp_tuple_col_name}.{tmp_index_2_col_name}").alias(tmp_index_2_col_name),
-                F.col(f"{tmp_tuple_col_name}.{tmp_value_1_col_name}").alias(tmp_value_1_col_name),
-                F.col(f"{tmp_tuple_col_name}.{tmp_value_2_col_name}").alias(tmp_value_2_col_name),
-            )
+        # +-------------------+-------------------+-------------------+-------------------+
+        # |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_value_1_col__|__tmp_value_2_col__|
+        # +-------------------+-------------------+-------------------+-------------------+
+        # |                  0|                  0|                1.0|                1.0|
+        # |                  0|                  1|                1.0|                2.0|
+        # |                  0|                  2|                1.0|                3.0|
+        # |                  1|                  1|                2.0|                2.0|
+        # |                  1|                  2|                2.0|                3.0|
+        # |                  2|                  2|                3.0|                3.0|
+        # |                  0|                  0|                4.0|                4.0|
+        # |                  0|                  1|                4.0|                1.0|
+        # |                  0|                  2|               null|               null|
+        # |                  1|                  1|                1.0|                1.0|
+        # |                  1|                  2|               null|               null|
+        # |                  2|                  2|               null|               null|
+        # +-------------------+-------------------+-------------------+-------------------+
+        tmp_tuple_col_name = verify_temp_column_name(sdf, "__tmp_tuple_col__")
+        null_cond = F.isnull(F.col(f"{tmp_tuple_col_name}.{tmp_value_1_col_name}")) | F.isnull(
+            F.col(f"{tmp_tuple_col_name}.{tmp_value_2_col_name}")
+        )
+        sdf = sdf.select(F.explode(F.array(*pair_scols)).alias(tmp_tuple_col_name)).select(
+            F.col(f"{tmp_tuple_col_name}.{tmp_index_1_col_name}").alias(tmp_index_1_col_name),
+            F.col(f"{tmp_tuple_col_name}.{tmp_index_2_col_name}").alias(tmp_index_2_col_name),
+            F.when(null_cond, F.lit(None))
+            .otherwise(F.col(f"{tmp_tuple_col_name}.{tmp_value_1_col_name}"))
+            .alias(tmp_value_1_col_name),
+            F.when(null_cond, F.lit(None))
+            .otherwise(F.col(f"{tmp_tuple_col_name}.{tmp_value_2_col_name}"))
+            .alias(tmp_value_2_col_name),
+        )
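The null_cond masking above mirrors pandas' pairwise-complete semantics: a row contributes to a pair only when both values are present, which matters for Spearman because ranks are computed per pair. A small pandas check of the behavior being matched (illustrative only):

    # If either side of a pair is null, pandas drops the pair before ranking;
    # nulling both values above reproduces that per (i, j) partition.
    import pandas as pd

    pdf = pd.DataFrame({"A": [1.0, 4.0], "C": [3.0, None]})
    print(pdf.corr(method="spearman"))  # only the complete (1.0, 3.0) pair survives
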
-            # +-------------------+-------------------+------------------------+-----------------+
-            # |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_pearson_corr_col__|__tmp_count_col__|
-            # +-------------------+-------------------+------------------------+-----------------+
-            # |                  2|                  2|                    null|                1|
-            # |                  1|                  2|                    null|                1|
-            # |                  1|                  1|                     1.0|                2|
-            # |                  0|                  0|                     1.0|                2|
-            # |                  0|                  1|                    -1.0|                2|
-            # |                  0|                  2|                    null|                1|
-            # +-------------------+-------------------+------------------------+-----------------+
-            tmp_corr_col_name = verify_temp_column_name(sdf, "__tmp_pearson_corr_col__")
-            tmp_count_col_name = verify_temp_column_name(sdf, "__tmp_count_col__")
-            sdf = sdf.groupby(tmp_index_1_col_name, tmp_index_2_col_name).agg(
-                F.corr(tmp_value_1_col_name, tmp_value_2_col_name).alias(tmp_corr_col_name),
-                F.count(
-                    F.when(
-                        F.col(tmp_value_1_col_name).isNotNull()
-                        & F.col(tmp_value_2_col_name).isNotNull(),
-                        1,
-                    )
-                ).alias(tmp_count_col_name),
-            )
+        # convert values to avg ranks for spearman correlation
+        if method == "spearman":
+            tmp_row_number_col_name = verify_temp_column_name(sdf, "__tmp_row_number_col__")
+            tmp_dense_rank_col_name = verify_temp_column_name(sdf, "__tmp_dense_rank_col__")
+            window = Window.partitionBy(tmp_index_1_col_name, tmp_index_2_col_name)
 
-            # +-------------------+-------------------+------------------------+
-            # |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_pearson_corr_col__|
-            # +-------------------+-------------------+------------------------+
-            # |                  2|                  2|                    null|
-            # |                  1|                  2|                    null|
-            # |                  2|                  1|                    null|
-            # |                  1|                  1|                     1.0|
-            # |                  0|                  0|                     1.0|
-            # |                  0|                  1|                    -1.0|
-            # |                  1|                  0|                    -1.0|
-            # |                  0|                  2|                    null|
-            # |                  2|                  0|                    null|
-            # +-------------------+-------------------+------------------------+
+            # tmp_value_1_col_name: value -> avg rank
+            # for example:
+            # values: 3, 4, 5, 7, 7, 7, 9, 9, 10
+            # avg ranks: 1.0, 2.0, 3.0, 5.0, 5.0, 5.0, 7.5, 7.5, 9.0
             sdf = (
                 sdf.withColumn(
-                    tmp_corr_col_name,
-                    F.when(
-                        F.col(tmp_count_col_name) >= min_periods, F.col(tmp_corr_col_name)
-                    ).otherwise(F.lit(None)),
+                    tmp_row_number_col_name,
+                    F.row_number().over(window.orderBy(F.asc_nulls_last(tmp_value_1_col_name))),
                 )
                 .withColumn(
-                    tmp_tuple_col_name,
-                    F.explode(
-                        F.when(
-                            F.col(tmp_index_1_col_name) == F.col(tmp_index_2_col_name),
-                            F.lit([0]),
-                        ).otherwise(F.lit([0, 1]))
-                    ),
+                    tmp_dense_rank_col_name,
+                    F.dense_rank().over(window.orderBy(F.asc_nulls_last(tmp_value_1_col_name))),
                 )
-                .select(
-                    F.when(F.col(tmp_tuple_col_name) == 0, F.col(tmp_index_1_col_name))
-                    .otherwise(F.col(tmp_index_2_col_name))
-                    .alias(tmp_index_1_col_name),
-                    F.when(F.col(tmp_tuple_col_name) == 0, F.col(tmp_index_2_col_name))
-                    .otherwise(F.col(tmp_index_1_col_name))
-                    .alias(tmp_index_2_col_name),
-                    F.col(tmp_corr_col_name),
+                .withColumn(
+                    tmp_value_1_col_name,
+                    F.when(F.isnull(F.col(tmp_value_1_col_name)), F.lit(None)).otherwise(
+                        F.avg(tmp_row_number_col_name).over(
+                            window.orderBy(F.asc(tmp_dense_rank_col_name)).rangeBetween(0, 0)
+                        )
+                    ),
                 )
             )
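The row_number/dense_rank/avg combination above assigns each value its average rank within the (index_1, index_2) partition: tied values share the mean of their row numbers. The same convention in pandas, reproducing the example from the comment:

    # Average ranks, the tie handling Spearman requires (pandas' default).
    import pandas as pd

    s = pd.Series([3, 4, 5, 7, 7, 7, 9, 9, 10])
    print(s.rank(method="average").tolist())
    # [1.0, 2.0, 3.0, 5.0, 5.0, 5.0, 7.5, 7.5, 9.0]
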
-            # +-------------------+--------------------+
-            # |__tmp_index_1_col__|   __tmp_array_col__|
-            # +-------------------+--------------------+
-            # |                  0|[{0, 1.0}, {1, -1...|
-            # |                  1|[{0, -1.0}, {1, 1...|
-            # |                  2|[{0, null}, {1, n...|
-            # +-------------------+--------------------+
-            tmp_array_col_name = verify_temp_column_name(sdf, "__tmp_array_col__")
+            # tmp_value_2_col_name: value -> avg rank
             sdf = (
-                sdf.groupby(tmp_index_1_col_name)
-                .agg(
-                    F.array_sort(
-                        F.collect_list(
-                            F.struct(F.col(tmp_index_2_col_name), F.col(tmp_corr_col_name))
+                sdf.withColumn(
+                    tmp_row_number_col_name,
+                    F.row_number().over(window.orderBy(F.asc_nulls_last(tmp_value_2_col_name))),
+                )
+                .withColumn(
+                    tmp_dense_rank_col_name,
+                    F.dense_rank().over(window.orderBy(F.asc_nulls_last(tmp_value_2_col_name))),
+                )
+                .withColumn(
+                    tmp_value_2_col_name,
+                    F.when(F.isnull(F.col(tmp_value_2_col_name)), F.lit(None)).otherwise(
+                        F.avg(tmp_row_number_col_name).over(
+                            window.orderBy(F.asc(tmp_dense_rank_col_name)).rangeBetween(0, 0)
                         )
-                    ).alias(tmp_array_col_name)
+                    ),
                 )
-                .orderBy(tmp_index_1_col_name)
             )
-
-            for i in range(0, num_scols):
-                sdf = sdf.withColumn(
-                    tmp_tuple_col_name, F.get(F.col(tmp_array_col_name), i)
-                ).withColumn(
-                    numeric_col_names[i],
-                    F.col(f"{tmp_tuple_col_name}.{tmp_corr_col_name}"),
-                )
+
+            sdf = sdf.select(
+                tmp_index_1_col_name,
+                tmp_index_2_col_name,
+                tmp_value_1_col_name,
+                tmp_value_2_col_name,
+            )
 
-            index_col_names: List[str] = []
-            if internal.column_labels_level > 1:
-                for level in range(0, internal.column_labels_level):
-                    index_col_name = SPARK_INDEX_NAME_FORMAT(level)
-                    indices = [label[level] for label in numeric_labels]
-                    sdf = sdf.withColumn(
-                        index_col_name, F.get(F.lit(indices), F.col(tmp_index_1_col_name))
-                    )
-                    index_col_names.append(index_col_name)
-            else:
-                sdf = sdf.withColumn(
-                    SPARK_DEFAULT_INDEX_NAME,
-                    F.get(F.lit(numeric_col_names), F.col(tmp_index_1_col_name)),
-                )
-                index_col_names = [SPARK_DEFAULT_INDEX_NAME]
-
-            sdf = sdf.select(*index_col_names, *numeric_col_names)
-
-            return DataFrame(
-                InternalFrame(
-                    spark_frame=sdf,
-                    index_spark_columns=[
-                        scol_for(sdf, index_col_name) for index_col_name in index_col_names
-                    ],
-                    column_labels=numeric_labels,
-                    column_label_names=internal.column_label_names,
-                )
-            )
-
+        # +-------------------+-------------------+----------------+-----------------+
+        # |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_corr_col__|__tmp_count_col__|
+        # +-------------------+-------------------+----------------+-----------------+
+        # |                  2|                  2|            null|                1|
+        # |                  1|                  2|            null|                1|
+        # |                  1|                  1|             1.0|                2|
+        # |                  0|                  0|             1.0|                2|
+        # |                  0|                  1|            -1.0|                2|
+        # |                  0|                  2|            null|                1|
+        # +-------------------+-------------------+----------------+-----------------+
+        tmp_corr_col_name = verify_temp_column_name(sdf, "__tmp_corr_col__")
+        tmp_count_col_name = verify_temp_column_name(sdf, "__tmp_count_col__")
+
+        sdf = sdf.groupby(tmp_index_1_col_name, tmp_index_2_col_name).agg(
+            F.corr(tmp_value_1_col_name, tmp_value_2_col_name).alias(tmp_corr_col_name),
+            F.count(
+                F.when(
+                    F.col(tmp_value_1_col_name).isNotNull()
+                    & F.col(tmp_value_2_col_name).isNotNull(),
+                    1,
+                )
+            ).alias(tmp_count_col_name),
+        )
+
+        # +-------------------+-------------------+----------------+
+        # |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_corr_col__|
+        # +-------------------+-------------------+----------------+
+        # |                  2|                  2|            null|
+        # |                  1|                  2|            null|
+        # |                  2|                  1|            null|
+        # |                  1|                  1|             1.0|
+        # |                  0|                  0|             1.0|
+        # |                  0|                  1|            -1.0|
+        # |                  1|                  0|            -1.0|
+        # |                  0|                  2|            null|
+        # |                  2|                  0|            null|
+        # +-------------------+-------------------+----------------+
+        sdf = (
+            sdf.withColumn(
+                tmp_corr_col_name,
+                F.when(
+                    F.col(tmp_count_col_name) >= min_periods, F.col(tmp_corr_col_name)
+                ).otherwise(F.lit(None)),
+            )
+            .withColumn(
+                tmp_tuple_col_name,
+                F.explode(
+                    F.when(
+                        F.col(tmp_index_1_col_name) == F.col(tmp_index_2_col_name),
+                        F.lit([0]),
+                    ).otherwise(F.lit([0, 1]))
+                ),
+            )
+            .select(
+                F.when(F.col(tmp_tuple_col_name) == 0, F.col(tmp_index_1_col_name))
+                .otherwise(F.col(tmp_index_2_col_name))
+                .alias(tmp_index_1_col_name),
+                F.when(F.col(tmp_tuple_col_name) == 0, F.col(tmp_index_2_col_name))
+                .otherwise(F.col(tmp_index_1_col_name))
+                .alias(tmp_index_2_col_name),
+                F.col(tmp_corr_col_name),
+            )
+        )
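Correlations are computed for the upper triangle only; the explode of [0] versus [0, 1] above then emits each off-diagonal cell twice, once as-is and once transposed, to rebuild the full symmetric matrix. A standalone sketch of just that step (toy rows; assumes Spark 3.4+, where F.lit accepts Python lists as in the patch):

    from pyspark.sql import SparkSession, functions as F

    spark = SparkSession.builder.getOrCreate()
    upper = spark.createDataFrame([(0, 0, 1.0), (0, 1, -1.0)], ["i", "j", "corr"])
    full = upper.withColumn(
        "k", F.explode(F.when(F.col("i") == F.col("j"), F.lit([0])).otherwise(F.lit([0, 1])))
    ).select(
        F.when(F.col("k") == 0, F.col("i")).otherwise(F.col("j")).alias("i"),
        F.when(F.col("k") == 0, F.col("j")).otherwise(F.col("i")).alias("j"),
        "corr",
    )
    full.show()  # yields (0, 0), (0, 1) and its mirror (1, 0)
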
+        # +-------------------+--------------------+
+        # |__tmp_index_1_col__|   __tmp_array_col__|
+        # +-------------------+--------------------+
+        # |                  0|[{0, 1.0}, {1, -1...|
+        # |                  1|[{0, -1.0}, {1, 1...|
+        # |                  2|[{0, null}, {1, n...|
+        # +-------------------+--------------------+
+        tmp_array_col_name = verify_temp_column_name(sdf, "__tmp_array_col__")
+        sdf = (
+            sdf.groupby(tmp_index_1_col_name)
+            .agg(
+                F.array_sort(
+                    F.collect_list(F.struct(F.col(tmp_index_2_col_name), F.col(tmp_corr_col_name)))
+                ).alias(tmp_array_col_name)
+            )
+            .orderBy(tmp_index_1_col_name)
+        )
+
+        for i in range(0, num_scols):
+            sdf = sdf.withColumn(
+                tmp_tuple_col_name, F.get(F.col(tmp_array_col_name), i)
+            ).withColumn(
+                numeric_col_names[i],
+                F.col(f"{tmp_tuple_col_name}.{tmp_corr_col_name}"),
+            )
+
+        index_col_names: List[str] = []
+        if internal.column_labels_level > 1:
+            for level in range(0, internal.column_labels_level):
+                index_col_name = SPARK_INDEX_NAME_FORMAT(level)
+                indices = [label[level] for label in numeric_labels]
+                sdf = sdf.withColumn(
+                    index_col_name, F.get(F.lit(indices), F.col(tmp_index_1_col_name))
+                )
+                index_col_names.append(index_col_name)
+        else:
+            sdf = sdf.withColumn(
+                SPARK_DEFAULT_INDEX_NAME,
+                F.get(F.lit(numeric_col_names), F.col(tmp_index_1_col_name)),
+            )
+            index_col_names = [SPARK_DEFAULT_INDEX_NAME]
 
-        return cast(DataFrame, ps.from_pandas(corr(self, method)))
+        sdf = sdf.select(*index_col_names, *numeric_col_names)
+
+        return DataFrame(
+            InternalFrame(
+                spark_frame=sdf,
+                index_spark_columns=[
+                    scol_for(sdf, index_col_name) for index_col_name in index_col_names
+                ],
+                column_labels=numeric_labels,
+                column_label_names=internal.column_label_names,
+            )
+        )
 
     # TODO: add axis parameter and support more methods
     def corrwith(
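With the frame.py change in place, spearman should behave like pandas end to end. A usage sketch (not part of the patch; values follow pandas up to floating-point differences):

    import pyspark.pandas as ps

    psdf = ps.DataFrame({"A": [1, 0, 3, 4], "B": [2, 2, 1, 0]})
    print(psdf.corr(method="spearman", min_periods=2))
    # Expected to match pdf.corr(method="spearman", min_periods=2) for the
    # equivalent pandas DataFrame, as the tests below assert.
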
diff --git a/python/pyspark/pandas/tests/test_stats.py b/python/pyspark/pandas/tests/test_stats.py
index 7e2ca96e60ff1..fbe16146ff296 100644
--- a/python/pyspark/pandas/tests/test_stats.py
+++ b/python/pyspark/pandas/tests/test_stats.py
@@ -269,26 +269,68 @@ def test_dataframe_corr(self):
             psdf.corr("kendall")
         with self.assertRaisesRegex(TypeError, "Invalid min_periods type"):
             psdf.corr(min_periods="3")
-        with self.assertRaisesRegex(NotImplementedError, "spearman for now"):
-            psdf.corr(method="spearman", min_periods=3)
-        self.assert_eq(psdf.corr(), pdf.corr(), check_exact=False)
-        self.assert_eq(psdf.corr(min_periods=1), pdf.corr(min_periods=1), check_exact=False)
-        self.assert_eq(psdf.corr(min_periods=3), pdf.corr(min_periods=3), check_exact=False)
-        self.assert_eq(
-            (psdf + 1).corr(min_periods=2), (pdf + 1).corr(min_periods=2), check_exact=False
-        )
+        for method in ["pearson", "spearman"]:
+            self.assert_eq(psdf.corr(method=method), pdf.corr(method=method), check_exact=False)
+            self.assert_eq(
+                psdf.corr(method=method, min_periods=1),
+                pdf.corr(method=method, min_periods=1),
+                check_exact=False,
+            )
+            self.assert_eq(
+                psdf.corr(method=method, min_periods=3),
+                pdf.corr(method=method, min_periods=3),
+                check_exact=False,
+            )
+            self.assert_eq(
+                (psdf + 1).corr(method=method, min_periods=2),
+                (pdf + 1).corr(method=method, min_periods=2),
+                check_exact=False,
+            )
 
         # multi-index columns
         columns = pd.MultiIndex.from_tuples([("X", "A"), ("X", "B"), ("Y", "C"), ("Z", "D")])
         pdf.columns = columns
         psdf.columns = columns
-        self.assert_eq(psdf.corr(), pdf.corr(), check_exact=False)
-        self.assert_eq(psdf.corr(min_periods=1), pdf.corr(min_periods=1), check_exact=False)
-        self.assert_eq(psdf.corr(min_periods=3), pdf.corr(min_periods=3), check_exact=False)
+        for method in ["pearson", "spearman"]:
+            self.assert_eq(psdf.corr(method=method), pdf.corr(method=method), check_exact=False)
+            self.assert_eq(
+                psdf.corr(method=method, min_periods=1),
+                pdf.corr(method=method, min_periods=1),
+                check_exact=False,
+            )
+            self.assert_eq(
+                psdf.corr(method=method, min_periods=3),
+                pdf.corr(method=method, min_periods=3),
+                check_exact=False,
+            )
+            self.assert_eq(
+                (psdf + 1).corr(method=method, min_periods=2),
+                (pdf + 1).corr(method=method, min_periods=2),
+                check_exact=False,
+            )
+
+        # test spearman with identical values
+        pdf = pd.DataFrame(
+            {
+                "a": [0, 1, 1, 1, 0],
+                "b": [2, 2, -1, 1, np.nan],
+                "c": [3, 3, 3, 3, 3],
+                "d": [np.nan, np.nan, np.nan, np.nan, np.nan],
+            }
+        )
+        psdf = ps.from_pandas(pdf)
+        self.assert_eq(psdf.corr(method="spearman"), pdf.corr(method="spearman"), check_exact=False)
         self.assert_eq(
-            (psdf + 1).corr(min_periods=2), (pdf + 1).corr(min_periods=2), check_exact=False
+            psdf.corr(method="spearman", min_periods=1),
+            pdf.corr(method="spearman", min_periods=1),
+            check_exact=False,
+        )
+        self.assert_eq(
+            psdf.corr(method="spearman", min_periods=3),
+            pdf.corr(method="spearman", min_periods=3),
+            check_exact=False,
         )
 
     def test_corr(self):
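The "identical values" case added above pins down an edge the rank conversion must handle: a constant column has zero rank variance, so its Spearman correlation with any column is NaN, and an all-NaN column can never reach min_periods. The pandas behavior being matched (illustrative):

    import numpy as np
    import pandas as pd

    pdf = pd.DataFrame({"a": [0, 1, 1, 1, 0], "c": [3, 3, 3, 3, 3]})
    print(pdf.corr(method="spearman"))  # the a/c entries are NaN (zero variance in c)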