209 changes: 204 additions & 5 deletions python/pyspark/pandas/frame.py
@@ -1417,15 +1417,23 @@ def aggregate(self, func: Union[List[str], Dict[Name, List[str]]]) -> "DataFrame

agg = aggregate

def corr(self, method: str = "pearson") -> "DataFrame":
def corr(self, method: str = "pearson", min_periods: Optional[int] = None) -> "DataFrame":
"""
Compute pairwise correlation of columns, excluding NA/null values.

.. versionadded:: 3.3.0
Member: Nice


Parameters
----------
method : {'pearson', 'spearman'}
* pearson : standard correlation coefficient
* spearman : Spearman rank correlation
min_periods : int, optional
Minimum number of observations required per pair of columns
to have a valid result. Currently only available for Pearson
correlation.

.. versionadded:: 3.4.0

Returns
-------
@@ -1454,11 +1462,202 @@ def corr(self, method: str = "pearson") -> "DataFrame":
There are behavior differences between pandas-on-Spark and pandas.

* the `method` argument only accepts 'pearson', 'spearman'
- * the data should not contain NaNs. pandas-on-Spark will return an error.
- * pandas-on-Spark doesn't support the following argument(s).
+ * if the `method` is `spearman`, the data should not contain NaNs.
+ * if the `method` is `spearman`, the `min_periods` argument is not supported.
"""
if method not in ["pearson", "spearman", "kendall"]:
raise ValueError(f"Invalid method {method}")
if method == "kendall":
raise NotImplementedError("method doesn't support kendall for now")
if min_periods is not None and not isinstance(min_periods, int):
raise TypeError(f"Invalid min_periods type {type(min_periods).__name__}")
Comment on lines +1472 to +1473

Contributor: Seems like pandas allows float:

>>> pdf.corr('pearson', min_periods=1.4)
          dogs      cats
dogs  1.000000 -0.851064
cats -0.851064  1.000000

But I'm not sure if it's intended behavior or not, since they raise `TypeError: an integer is required` when the type is str, as below:

>>> pdf.corr('pearson', min_periods='a')
Traceback (most recent call last):
...
TypeError: an integer is required

Contributor Author: I am also not sure, but the type of min_periods is also expected to be int in pandas.
I think that pdf.corr('pearson', min_periods=1.4) works in pandas only because that validation is missing there.
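For reference, a minimal sketch of the validation difference discussed above. The dogs/cats frame is an assumption borrowed from the pandas docs example, which produces the -0.851064 value shown in the thread:

import pandas as pd
import pyspark.pandas as ps

# Hypothetical example data; any small numeric frame would do.
pdf = pd.DataFrame([(0.2, 0.3), (0.0, 0.6), (0.6, 0.0), (0.2, 0.1)],
                   columns=["dogs", "cats"])
psdf = ps.from_pandas(pdf)

pdf.corr("pearson", min_periods=1.4)   # pandas: accepted, due to the missing validation
psdf.corr("pearson", min_periods=1.4)  # pandas-on-Spark (this patch): raises TypeError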

if min_periods is not None and method == "spearman":
raise NotImplementedError("min_periods doesn't support spearman for now")

if method == "pearson":
min_periods = 1 if min_periods is None else min_periods
internal = self._internal.resolved_copy
numeric_labels = [
label
for label in internal.column_labels
if isinstance(internal.spark_type_for(label), (NumericType, BooleanType))
]
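# Only numeric and boolean columns take part in the correlation; other
# columns are dropped, mirroring pandas' numeric-only behavior in corr().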
numeric_scols: List[Column] = [
internal.spark_column_for(label).cast("double") for label in numeric_labels
]
numeric_col_names: List[str] = [name_like_string(label) for label in numeric_labels]
num_scols = len(numeric_scols)

sdf = internal.spark_frame
tmp_index_1_col_name = verify_temp_column_name(sdf, "__tmp_index_1_col__")
tmp_index_2_col_name = verify_temp_column_name(sdf, "__tmp_index_2_col__")
tmp_value_1_col_name = verify_temp_column_name(sdf, "__tmp_value_1_col__")
tmp_value_2_col_name = verify_temp_column_name(sdf, "__tmp_value_2_col__")
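# verify_temp_column_name checks that each temporary column name does not
# collide with an existing column of sdf before it is used below.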

# simple dataset
# +---+---+----+
# | A| B| C|
# +---+---+----+
# | 1| 2| 3.0|
# | 4| 1|null|
# +---+---+----+

pair_scols: List[Column] = []
for i in range(0, num_scols):
Contributor: nit: we can omit the 0 since it's the default?

-  for i in range(0, num_scols):
+  for i in range(num_scols):

Either looks okay, though.

Contributor Author: just because other places use range(0, x) instead of range(x)

for j in range(i, num_scols):
pair_scols.append(
F.struct(
F.lit(i).alias(tmp_index_1_col_name),
F.lit(j).alias(tmp_index_2_col_name),
numeric_scols[i].alias(tmp_value_1_col_name),
numeric_scols[j].alias(tmp_value_2_col_name),
)
)
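# Only the upper triangle (i <= j) is materialized here; the lower triangle
# is reconstructed later by mirroring, halving the number of aggregations.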

# +-------------------+-------------------+-------------------+-------------------+
# |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_value_1_col__|__tmp_value_2_col__|
# +-------------------+-------------------+-------------------+-------------------+
# | 0| 0| 1.0| 1.0|
# | 0| 1| 1.0| 2.0|
# | 0| 2| 1.0| 3.0|
# | 1| 1| 2.0| 2.0|
# | 1| 2| 2.0| 3.0|
# | 2| 2| 3.0| 3.0|
# | 0| 0| 4.0| 4.0|
# | 0| 1| 4.0| 1.0|
# | 0| 2| 4.0| null|
# | 1| 1| 1.0| 1.0|
# | 1| 2| 1.0| null|
# | 2| 2| null| null|
# +-------------------+-------------------+-------------------+-------------------+
tmp_tuple_col_name = verify_temp_column_name(sdf, "__tmp_tuple_col__")
sdf = sdf.select(F.explode(F.array(*pair_scols)).alias(tmp_tuple_col_name)).select(
F.col(f"{tmp_tuple_col_name}.{tmp_index_1_col_name}").alias(tmp_index_1_col_name),
F.col(f"{tmp_tuple_col_name}.{tmp_index_2_col_name}").alias(tmp_index_2_col_name),
F.col(f"{tmp_tuple_col_name}.{tmp_value_1_col_name}").alias(tmp_value_1_col_name),
F.col(f"{tmp_tuple_col_name}.{tmp_value_2_col_name}").alias(tmp_value_2_col_name),
)
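# explode() turns the array of structs into one row per (column pair, data
# row), i.e. the long format shown in the table above.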

# +-------------------+-------------------+------------------------+-----------------+
# |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_pearson_corr_col__|__tmp_count_col__|
# +-------------------+-------------------+------------------------+-----------------+
# | 2| 2| null| 1|
# | 1| 2| null| 1|
# | 1| 1| 1.0| 2|
# | 0| 0| 1.0| 2|
# | 0| 1| -1.0| 2|
# | 0| 2| null| 1|
# +-------------------+-------------------+------------------------+-----------------+
tmp_corr_col_name = verify_temp_column_name(sdf, "__tmp_pearson_corr_col__")
tmp_count_col_name = verify_temp_column_name(sdf, "__tmp_count_col__")
sdf = sdf.groupby(tmp_index_1_col_name, tmp_index_2_col_name).agg(
F.corr(tmp_value_1_col_name, tmp_value_2_col_name).alias(tmp_corr_col_name),
F.count(
F.when(
F.col(tmp_value_1_col_name).isNotNull()
& F.col(tmp_value_2_col_name).isNotNull(),
1,
)
).alias(tmp_count_col_name),
)
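# Note: F.corr is Spark's built-in Pearson correlation aggregate and yields
# null for a pair with fewer than two non-null observations; the explicit
# count column is what lets us enforce `min_periods` afterwards.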

# +-------------------+-------------------+------------------------+
# |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_pearson_corr_col__|
# +-------------------+-------------------+------------------------+
# | 2| 2| null|
# | 1| 2| null|
# | 2| 1| null|
# | 1| 1| 1.0|
# | 0| 0| 1.0|
# | 0| 1| -1.0|
# | 1| 0| -1.0|
# | 0| 2| null|
# | 2| 0| null|
# +-------------------+-------------------+------------------------+
sdf = (
sdf.withColumn(
tmp_corr_col_name,
F.when(
F.col(tmp_count_col_name) >= min_periods, F.col(tmp_corr_col_name)
).otherwise(F.lit(None)),
)
.withColumn(
tmp_tuple_col_name,
F.explode(
F.when(
F.col(tmp_index_1_col_name) == F.col(tmp_index_2_col_name),
F.lit([0]),
).otherwise(F.lit([0, 1]))
),
)
.select(
F.when(F.col(tmp_tuple_col_name) == 0, F.col(tmp_index_1_col_name))
.otherwise(F.col(tmp_index_2_col_name))
.alias(tmp_index_1_col_name),
F.when(F.col(tmp_tuple_col_name) == 0, F.col(tmp_index_2_col_name))
.otherwise(F.col(tmp_index_1_col_name))
.alias(tmp_index_2_col_name),
F.col(tmp_corr_col_name),
)
)
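# The explode over [0] vs [0, 1] above mirrors every off-diagonal pair
# (i, j) into its counterpart (j, i), making the matrix symmetric without
# computing each correlation twice.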

# +-------------------+--------------------+
# |__tmp_index_1_col__| __tmp_array_col__|
# +-------------------+--------------------+
# | 0|[{0, 1.0}, {1, -1...|
# | 1|[{0, -1.0}, {1, 1...|
# | 2|[{0, null}, {1, n...|
# +-------------------+--------------------+
tmp_array_col_name = verify_temp_column_name(sdf, "__tmp_array_col__")
sdf = (
sdf.groupby(tmp_index_1_col_name)
.agg(
F.array_sort(
F.collect_list(
F.struct(F.col(tmp_index_2_col_name), F.col(tmp_corr_col_name))
)
).alias(tmp_array_col_name)
)
.orderBy(tmp_index_1_col_name)
)

for i in range(0, num_scols):
sdf = sdf.withColumn(
tmp_tuple_col_name, F.get(F.col(tmp_array_col_name), i)
).withColumn(
numeric_col_names[i],
F.col(f"{tmp_tuple_col_name}.{tmp_corr_col_name}"),
)
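# F.get (added in Spark 3.4) extracts the i-th struct from the sorted array,
# turning each correlation value into its own output column.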

index_col_names: List[str] = []
if internal.column_labels_level > 1:
for level in range(0, internal.column_labels_level):
index_col_name = SPARK_INDEX_NAME_FORMAT(level)
indices = [label[level] for label in numeric_labels]
sdf = sdf.withColumn(
index_col_name, F.get(F.lit(indices), F.col(tmp_index_1_col_name))
)
index_col_names.append(index_col_name)
else:
sdf = sdf.withColumn(
SPARK_DEFAULT_INDEX_NAME,
F.get(F.lit(numeric_col_names), F.col(tmp_index_1_col_name)),
)
index_col_names = [SPARK_DEFAULT_INDEX_NAME]
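# Map the positional row index back to the original column labels so the
# result carries the same index as pandas' corr() output.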

sdf = sdf.select(*index_col_names, *numeric_col_names)

return DataFrame(
InternalFrame(
spark_frame=sdf,
index_spark_columns=[
scol_for(sdf, index_col_name) for index_col_name in index_col_names
],
column_labels=numeric_labels,
column_label_names=internal.column_label_names,
)
)

- * `min_periods` argument is not supported
- """
- return cast(DataFrame, ps.from_pandas(corr(self, method)))
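A hedged usage sketch of the new code path, reusing the two-row A/B/C frame from the walkthrough comments above (values are illustrative, not taken from the patch's tests):

import pyspark.pandas as ps

psdf = ps.DataFrame({"A": [1, 4], "B": [2, 1], "C": [3.0, None]})
psdf.corr("pearson")                 # min_periods defaults to 1
psdf.corr("pearson", min_periods=2)  # every pair involving C has only one
                                     # complete observation, so those entries
                                     # become NaN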

# TODO: add axis parameter and support more methods
34 changes: 34 additions & 0 deletions python/pyspark/pandas/tests/test_stats.py
@@ -257,6 +257,40 @@ def test_skew_kurt_numerical_stability(self):
self.assert_eq(psdf.skew(), pdf.skew(), almost=True)
self.assert_eq(psdf.kurt(), pdf.kurt(), almost=True)

def test_dataframe_corr(self):
# the existing 'test_corr' mixes df.corr and ser.corr; it will be deleted
# once there are separate tests for df.corr and ser.corr
pdf = makeMissingDataframe(0.3, 42)
psdf = ps.from_pandas(pdf)

with self.assertRaisesRegex(ValueError, "Invalid method"):
psdf.corr("std")
with self.assertRaisesRegex(NotImplementedError, "kendall for now"):
psdf.corr("kendall")
with self.assertRaisesRegex(TypeError, "Invalid min_periods type"):
psdf.corr(min_periods="3")
with self.assertRaisesRegex(NotImplementedError, "spearman for now"):
psdf.corr(method="spearman", min_periods=3)

self.assert_eq(psdf.corr(), pdf.corr(), check_exact=False)
self.assert_eq(psdf.corr(min_periods=1), pdf.corr(min_periods=1), check_exact=False)
self.assert_eq(psdf.corr(min_periods=3), pdf.corr(min_periods=3), check_exact=False)
self.assert_eq(
(psdf + 1).corr(min_periods=2), (pdf + 1).corr(min_periods=2), check_exact=False
)

# multi-index columns
columns = pd.MultiIndex.from_tuples([("X", "A"), ("X", "B"), ("Y", "C"), ("Z", "D")])
pdf.columns = columns
psdf.columns = columns

self.assert_eq(psdf.corr(), pdf.corr(), check_exact=False)
self.assert_eq(psdf.corr(min_periods=1), pdf.corr(min_periods=1), check_exact=False)
self.assert_eq(psdf.corr(min_periods=3), pdf.corr(min_periods=3), check_exact=False)
Contributor: Can we also test chained operations?
e.g.

self.assert_eq((psdf + 1).corr(), (pdf + 1).corr(), check_exact=False)

Contributor Author: nice, let me add it

self.assert_eq(
(psdf + 1).corr(min_periods=2), (pdf + 1).corr(min_periods=2), check_exact=False
)

def test_corr(self):
Contributor Author: the existing test mixes df.corr and ser.corr and is not easy to reuse, so I added a new one.
It will be deleted after both df.corr and ser.corr are refactored.

Contributor: Can we add this as a comment at the top of test_dataframe_corr so as not to forget?

Contributor Author: sure

# Disable arrow execution since corr() is using UDT internally which is not supported.
with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}):