
[SPARK-27387][PYTHON][TESTS] Replace sqlutils.assertPandasEqual with Pandas assert_frame_equal

## What changes were proposed in this pull request?

Running PySpark tests with Pandas 0.24.x causes a failure in `test_pandas_udf_grouped_map` `test_supported_types`:
`ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()`

This is because one of the columns is an ArrayType, and the helper `sqlutils.ReusedSQLTestCase.assertPandasEqual` does not check this case properly.
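
As a minimal sketch of the underlying behavior (plain NumPy, not Spark-specific): comparing array-valued cells yields an element-wise boolean array, and coercing that array to a single truth value is what raises the error above.

```python
import numpy as np

left = np.array([1, 2])
right = np.array([1, 2])

# `left == right` evaluates to array([ True,  True]) rather than a single
# bool, so using it in a boolean context raises the ValueError quoted above.
try:
    bool(left == right)
except ValueError as e:
    print(e)  # The truth value of an array with more than one element is ambiguous...
```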

This PR removes `assertPandasEqual` and replaces it with the built-in `pandas.util.testing.assert_frame_equal`, which properly handles columns of ArrayType and also prints a better diff between the DataFrames when an assertion fails.
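
For illustration, a small sketch of the replacement in use (the frames and values here are made up, not taken from the test suite):

```python
import pandas as pd
from pandas.util.testing import assert_frame_equal  # moved to pandas.testing in later pandas

expected = pd.DataFrame({"v": [[1, 2], [3, 4]]})
result = pd.DataFrame({"v": [[1, 2], [3, 4]]})

# Compares list-valued (ArrayType-backed) columns element by element; on a
# mismatch it raises an AssertionError that includes a diff of the frames.
assert_frame_equal(expected, result)
```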

Additionally, imports of pandas and pyarrow were moved to the top of related test files to avoid duplicating the same import many times.

## How was this patch tested?

Existing tests

Closes apache#24306 from BryanCutler/python-pandas-assert_frame_equal-SPARK-27387.

Authored-by: Bryan Cutler <cutlerb@gmail.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
BryanCutler authored and rshkv committed May 1, 2020
1 parent da8a7ff commit 95c106c
Showing 7 changed files with 115 additions and 137 deletions.
44 changes: 18 additions & 26 deletions python/pyspark/sql/tests/test_arrow.py
@@ -30,6 +30,13 @@
 from pyspark.testing.utils import QuietTest
 from pyspark.util import _exception_message
 
+if have_pandas:
+    import pandas as pd
+    from pandas.util.testing import assert_frame_equal
+
+if have_pyarrow:
+    import pyarrow as pa
+
 
 @unittest.skipIf(
     not have_pandas or not have_pyarrow,
@@ -41,7 +48,6 @@ def setUpClass(cls):
         from datetime import date, datetime
         from decimal import Decimal
         from distutils.version import LooseVersion
-        import pyarrow as pa
         super(ArrowTests, cls).setUpClass()
         cls.warnings_lock = threading.Lock()
 
@@ -90,7 +96,6 @@ def tearDownClass(cls):
         super(ArrowTests, cls).tearDownClass()
 
     def create_pandas_data_frame(self):
-        import pandas as pd
         import numpy as np
         data_dict = {}
         for j, name in enumerate(self.schema.names):
@@ -101,8 +106,6 @@ def create_pandas_data_frame(self):
         return pd.DataFrame(data=data_dict)
 
     def test_toPandas_fallback_enabled(self):
-        import pandas as pd
-
         with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}):
             schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
             df = self.spark.createDataFrame([({u'a': 1},)], schema=schema)
@@ -118,11 +121,10 @@ def test_toPandas_fallback_enabled(self):
                         self.assertTrue(len(user_warns) > 0)
                         self.assertTrue(
                             "Attempting non-optimization" in _exception_message(user_warns[-1]))
-                        self.assertPandasEqual(pdf, pd.DataFrame({u'map': [{u'a': 1}]}))
+                        assert_frame_equal(pdf, pd.DataFrame({u'map': [{u'a': 1}]}))
 
     def test_toPandas_fallback_disabled(self):
         from distutils.version import LooseVersion
-        import pyarrow as pa
 
         schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
         df = self.spark.createDataFrame([(None,)], schema=schema)
@@ -158,8 +160,8 @@ def test_toPandas_arrow_toggle(self):
         df = self.spark.createDataFrame(self.data, schema=self.schema)
         pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
         expected = self.create_pandas_data_frame()
-        self.assertPandasEqual(expected, pdf)
-        self.assertPandasEqual(expected, pdf_arrow)
+        assert_frame_equal(expected, pdf)
+        assert_frame_equal(expected, pdf_arrow)
 
     def test_toPandas_respect_session_timezone(self):
         df = self.spark.createDataFrame(self.data, schema=self.schema)
@@ -169,13 +171,13 @@ def test_toPandas_respect_session_timezone(self):
                 "spark.sql.execution.pandas.respectSessionTimeZone": False,
                 "spark.sql.session.timeZone": timezone}):
             pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df)
-            self.assertPandasEqual(pdf_arrow_la, pdf_la)
+            assert_frame_equal(pdf_arrow_la, pdf_la)
 
         with self.sql_conf({
                 "spark.sql.execution.pandas.respectSessionTimeZone": True,
                 "spark.sql.session.timeZone": timezone}):
             pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df)
-            self.assertPandasEqual(pdf_arrow_ny, pdf_ny)
+            assert_frame_equal(pdf_arrow_ny, pdf_ny)
 
         self.assertFalse(pdf_ny.equals(pdf_la))
 
@@ -185,13 +187,13 @@ def test_toPandas_respect_session_timezone(self):
             if isinstance(field.dataType, TimestampType):
                 pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz(
                     pdf_la_corrected[field.name], timezone)
-        self.assertPandasEqual(pdf_ny, pdf_la_corrected)
+        assert_frame_equal(pdf_ny, pdf_la_corrected)
 
     def test_pandas_round_trip(self):
         pdf = self.create_pandas_data_frame()
         df = self.spark.createDataFrame(self.data, schema=self.schema)
         pdf_arrow = df.toPandas()
-        self.assertPandasEqual(pdf_arrow, pdf)
+        assert_frame_equal(pdf_arrow, pdf)
 
     def test_filtered_frame(self):
         df = self.spark.range(3).toDF("i")
@@ -265,7 +267,7 @@ def test_createDataFrame_with_schema(self):
         df = self.spark.createDataFrame(pdf, schema=self.schema)
         self.assertEquals(self.schema, df.schema)
         pdf_arrow = df.toPandas()
-        self.assertPandasEqual(pdf_arrow, pdf)
+        assert_frame_equal(pdf_arrow, pdf)
 
     def test_createDataFrame_with_incorrect_schema(self):
         pdf = self.create_pandas_data_frame()
@@ -287,7 +289,6 @@ def test_createDataFrame_with_names(self):
         self.assertEquals(df.schema.fieldNames(), new_names)
 
     def test_createDataFrame_column_name_encoding(self):
-        import pandas as pd
         pdf = pd.DataFrame({u'a': [1]})
         columns = self.spark.createDataFrame(pdf).columns
         self.assertTrue(isinstance(columns[0], str))
@@ -297,13 +298,11 @@ def test_createDataFrame_column_name_encoding(self):
         self.assertEquals(columns[0], 'b')
 
     def test_createDataFrame_with_single_data_type(self):
-        import pandas as pd
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(ValueError, ".*IntegerType.*not supported.*"):
                 self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int")
 
     def test_createDataFrame_does_not_modify_input(self):
-        import pandas as pd
         # Some series get converted for Spark to consume, this makes sure input is unchanged
         pdf = self.create_pandas_data_frame()
         # Use a nanosecond value to make sure it is not truncated
@@ -321,7 +320,6 @@ def test_schema_conversion_roundtrip(self):
         self.assertEquals(self.schema, schema_rt)
 
     def test_createDataFrame_with_array_type(self):
-        import pandas as pd
         pdf = pd.DataFrame({"a": [[1, 2], [3, 4]], "b": [[u"x", u"y"], [u"y", u"z"]]})
         df, df_arrow = self._createDataFrame_toggle(pdf)
         result = df.collect()
@@ -347,16 +345,13 @@ def test_toPandas_with_array_type(self):
 
     def test_createDataFrame_with_int_col_names(self):
         import numpy as np
-        import pandas as pd
         pdf = pd.DataFrame(np.random.rand(4, 2))
         df, df_arrow = self._createDataFrame_toggle(pdf)
         pdf_col_names = [str(c) for c in pdf.columns]
         self.assertEqual(pdf_col_names, df.columns)
         self.assertEqual(pdf_col_names, df_arrow.columns)
 
     def test_createDataFrame_fallback_enabled(self):
-        import pandas as pd
-
         with QuietTest(self.sc):
             with self.sql_conf({"spark.sql.execution.arrow.fallback.enabled": True}):
                 with warnings.catch_warnings(record=True) as warns:
@@ -374,8 +369,6 @@ def test_createDataFrame_fallback_enabled(self):
 
     def test_createDataFrame_fallback_disabled(self):
         from distutils.version import LooseVersion
-        import pandas as pd
-        import pyarrow as pa
 
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(TypeError, 'Unsupported type'):
@@ -391,7 +384,6 @@ def test_createDataFrame_fallback_disabled(self):
 
     # Regression test for SPARK-23314
     def test_timestamp_dst(self):
-        import pandas as pd
         # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am
         dt = [datetime.datetime(2015, 11, 1, 0, 30),
               datetime.datetime(2015, 11, 1, 1, 30),
@@ -401,8 +393,8 @@ def test_timestamp_dst(self):
         df_from_python = self.spark.createDataFrame(dt, 'timestamp').toDF('time')
         df_from_pandas = self.spark.createDataFrame(pdf)
 
-        self.assertPandasEqual(pdf, df_from_python.toPandas())
-        self.assertPandasEqual(pdf, df_from_pandas.toPandas())
+        assert_frame_equal(pdf, df_from_python.toPandas())
+        assert_frame_equal(pdf, df_from_pandas.toPandas())
 
     def test_toPandas_batch_order(self):
 
@@ -418,7 +410,7 @@ def run_test(num_records, num_parts, max_records, use_delay=False):
                 df = df.rdd.mapPartitionsWithIndex(delay_first_part).toDF()
             with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": max_records}):
                 pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
-                self.assertPandasEqual(pdf, pdf_arrow)
+                assert_frame_equal(pdf, pdf_arrow)
 
         cases = [
             (1024, 512, 2),  # Use large num partitions for more likely collecting out of order
7 changes: 4 additions & 3 deletions python/pyspark/sql/tests/test_dataframe.py
@@ -581,14 +581,15 @@ def test_create_dataframe_required_pandas_not_found(self):
 
     # Regression test for SPARK-23360
     @unittest.skipIf(not have_pandas, pandas_requirement_message)
-    def test_create_dateframe_from_pandas_with_dst(self):
+    def test_create_dataframe_from_pandas_with_dst(self):
         import pandas as pd
+        from pandas.util.testing import assert_frame_equal
         from datetime import datetime
 
         pdf = pd.DataFrame({'time': [datetime(2015, 10, 31, 22, 30)]})
 
         df = self.spark.createDataFrame(pdf)
-        self.assertPandasEqual(pdf, df.toPandas())
+        assert_frame_equal(pdf, df.toPandas())
 
         orig_env_tz = os.environ.get('TZ', None)
         try:
@@ -597,7 +598,7 @@ def test_create_dateframe_from_pandas_with_dst(self):
             time.tzset()
             with self.sql_conf({'spark.sql.session.timeZone': tz}):
                 df = self.spark.createDataFrame(pdf)
-                self.assertPandasEqual(pdf, df.toPandas())
+                assert_frame_equal(pdf, df.toPandas())
         finally:
             del os.environ['TZ']
             if orig_env_tz is not None:
60 changes: 31 additions & 29 deletions python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py
@@ -26,6 +26,10 @@
     pandas_requirement_message, pyarrow_requirement_message
 from pyspark.testing.utils import QuietTest
 
+if have_pandas:
+    import pandas as pd
+    from pandas.util.testing import assert_frame_equal
+
 
 @unittest.skipIf(
     not have_pandas or not have_pyarrow,
@@ -50,8 +54,6 @@ def plus_one(v):
 
     @property
     def pandas_scalar_plus_two(self):
-        import pandas as pd
-
         @pandas_udf('double', PandasUDFType.SCALAR)
         def plus_two(v):
             assert isinstance(v, pd.Series)
@@ -107,7 +109,7 @@ def test_manual(self):
              [9, 335.0, 33.5, [33.5]]],
             ['id', 'sum(v)', 'avg(v)', 'avg(array(v))'])
 
-        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+        assert_frame_equal(expected1.toPandas(), result1.toPandas())
 
     def test_basic(self):
         df = self.data
@@ -116,19 +118,19 @@ def test_basic(self):
         # Groupby one column and aggregate one UDF with literal
         result1 = df.groupby('id').agg(weighted_mean_udf(df.v, lit(1.0))).sort('id')
         expected1 = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort('id')
-        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+        assert_frame_equal(expected1.toPandas(), result1.toPandas())
 
         # Groupby one expression and aggregate one UDF with literal
         result2 = df.groupby((col('id') + 1)).agg(weighted_mean_udf(df.v, lit(1.0)))\
             .sort(df.id + 1)
         expected2 = df.groupby((col('id') + 1))\
             .agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort(df.id + 1)
-        self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
+        assert_frame_equal(expected2.toPandas(), result2.toPandas())
 
         # Groupby one column and aggregate one UDF without literal
         result3 = df.groupby('id').agg(weighted_mean_udf(df.v, df.w)).sort('id')
         expected3 = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, w)')).sort('id')
-        self.assertPandasEqual(expected3.toPandas(), result3.toPandas())
+        assert_frame_equal(expected3.toPandas(), result3.toPandas())
 
         # Groupby one expression and aggregate one UDF without literal
         result4 = df.groupby((col('id') + 1).alias('id'))\
@@ -137,7 +139,7 @@ def test_basic(self):
         expected4 = df.groupby((col('id') + 1).alias('id'))\
             .agg(mean(df.v).alias('weighted_mean(v, w)'))\
             .sort('id')
-        self.assertPandasEqual(expected4.toPandas(), result4.toPandas())
+        assert_frame_equal(expected4.toPandas(), result4.toPandas())
 
     def test_unsupported_types(self):
         with QuietTest(self.sc):
@@ -166,7 +168,7 @@ def test_alias(self):
         result1 = df.groupby('id').agg(mean_udf(df.v).alias('mean_alias'))
         expected1 = df.groupby('id').agg(mean(df.v).alias('mean_alias'))
 
-        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+        assert_frame_equal(expected1.toPandas(), result1.toPandas())
 
     def test_mixed_sql(self):
         """
@@ -200,9 +202,9 @@ def test_mixed_sql(self):
                      .agg(sum(df.v + 1) + 2)
                      .sort('id'))
 
-        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
-        self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
-        self.assertPandasEqual(expected3.toPandas(), result3.toPandas())
+        assert_frame_equal(expected1.toPandas(), result1.toPandas())
+        assert_frame_equal(expected2.toPandas(), result2.toPandas())
+        assert_frame_equal(expected3.toPandas(), result3.toPandas())
 
     def test_mixed_udfs(self):
         """
@@ -262,12 +264,12 @@ def test_mixed_udfs(self):
                      .agg(plus_two(sum(plus_two(df.v))))
                      .sort('plus_two(id)'))
 
-        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
-        self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
-        self.assertPandasEqual(expected3.toPandas(), result3.toPandas())
-        self.assertPandasEqual(expected4.toPandas(), result4.toPandas())
-        self.assertPandasEqual(expected5.toPandas(), result5.toPandas())
-        self.assertPandasEqual(expected6.toPandas(), result6.toPandas())
+        assert_frame_equal(expected1.toPandas(), result1.toPandas())
+        assert_frame_equal(expected2.toPandas(), result2.toPandas())
+        assert_frame_equal(expected3.toPandas(), result3.toPandas())
+        assert_frame_equal(expected4.toPandas(), result4.toPandas())
+        assert_frame_equal(expected5.toPandas(), result5.toPandas())
+        assert_frame_equal(expected6.toPandas(), result6.toPandas())
 
     def test_multiple_udfs(self):
         """
@@ -291,7 +293,7 @@ def test_multiple_udfs(self):
                      .sort('id')
                      .toPandas())
 
-        self.assertPandasEqual(expected1, result1)
+        assert_frame_equal(expected1, result1)
 
     def test_complex_groupby(self):
         df = self.data
@@ -327,13 +329,13 @@ def test_complex_groupby(self):
         result7 = df.groupby(df.v % 2, plus_two(df.id)).agg(sum_udf(df.v)).sort('sum(v)')
         expected7 = df.groupby(df.v % 2, plus_two(df.id)).agg(sum(df.v)).sort('sum(v)')
 
-        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
-        self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
-        self.assertPandasEqual(expected3.toPandas(), result3.toPandas())
-        self.assertPandasEqual(expected4.toPandas(), result4.toPandas())
-        self.assertPandasEqual(expected5.toPandas(), result5.toPandas())
-        self.assertPandasEqual(expected6.toPandas(), result6.toPandas())
-        self.assertPandasEqual(expected7.toPandas(), result7.toPandas())
+        assert_frame_equal(expected1.toPandas(), result1.toPandas())
+        assert_frame_equal(expected2.toPandas(), result2.toPandas())
+        assert_frame_equal(expected3.toPandas(), result3.toPandas())
+        assert_frame_equal(expected4.toPandas(), result4.toPandas())
+        assert_frame_equal(expected5.toPandas(), result5.toPandas())
+        assert_frame_equal(expected6.toPandas(), result6.toPandas())
+        assert_frame_equal(expected7.toPandas(), result7.toPandas())
 
     def test_complex_expressions(self):
         df = self.data
@@ -404,9 +406,9 @@ def test_complex_expressions(self):
                      .sort('id')
                      .toPandas())
 
-        self.assertPandasEqual(expected1, result1)
-        self.assertPandasEqual(expected2, result2)
-        self.assertPandasEqual(expected3, result3)
+        assert_frame_equal(expected1, result1)
+        assert_frame_equal(expected2, result2)
+        assert_frame_equal(expected3, result3)
 
     def test_retain_group_columns(self):
         with self.sql_conf({"spark.sql.retainGroupColumns": False}):
@@ -415,7 +417,7 @@ def test_retain_group_columns(self):
 
             result1 = df.groupby(df.id).agg(sum_udf(df.v))
             expected1 = df.groupby(df.id).agg(sum(df.v))
-            self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+            assert_frame_equal(expected1.toPandas(), result1.toPandas())
 
     def test_array_type(self):
         df = self.data
4 changed files not shown
