diff --git a/pom.xml b/pom.xml
index 666d5d7169a15..d18831df1db6d 100644
--- a/pom.xml
+++ b/pom.xml
@@ -185,6 +185,10 @@
<paranamer.version>2.8</paranamer.version>
<maven-antrun.version>1.8</maven-antrun.version>
<commons-crypto.version>1.0.0</commons-crypto.version>
+<!--
+If you are changing Arrow version specification, please check
+./python/pyspark/sql/utils.py, ./python/run-tests.py and ./python/setup.py too.
+-->
<arrow.version>0.8.0</arrow.version>

<test.java.home>${java.home}</test.java.home>
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 59a417015b949..8ec24db8717b2 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1913,6 +1913,9 @@ def toPandas(self):
0 2 Alice
1 5 Bob
"""
+ from pyspark.sql.utils import require_minimum_pandas_version
+ require_minimum_pandas_version()
+
import pandas as pd
if self.sql_ctx.getConf("spark.sql.execution.pandas.respectSessionTimeZone").lower() \
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 1ed04298bc899..b3af9b82953f3 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -646,6 +646,9 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
except Exception:
has_pandas = False
if has_pandas and isinstance(data, pandas.DataFrame):
+ from pyspark.sql.utils import require_minimum_pandas_version
+ require_minimum_pandas_version()
+
if self.conf.get("spark.sql.execution.pandas.respectSessionTimeZone").lower() \
== "true":
timezone = self.conf.get("spark.sql.session.timeZone")
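createDataFrame gets the same guard, but only on the pandas branch; a minimal sketch, again assuming a hypothetical `spark` session:

    import pandas as pd  # still imports even when the installed version is too old
    pdf = pd.DataFrame({"a": [1, 2, 3]})
    try:
        spark.createDataFrame(pdf)  # guard fires before any conversion work
    except ImportError as e:
        print(e)  # "Pandas >= 0.19.2 must be installed; however, your version was ..."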
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 53da7dd45c2f2..58359b61dc83a 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -48,19 +48,26 @@
else:
import unittest
-_have_pandas = False
-_have_old_pandas = False
+_pandas_requirement_message = None
try:
- import pandas
- try:
- from pyspark.sql.utils import require_minimum_pandas_version
- require_minimum_pandas_version()
- _have_pandas = True
- except:
- _have_old_pandas = True
-except:
- # No Pandas, but that's okay, we'll skip those tests
- pass
+ from pyspark.sql.utils import require_minimum_pandas_version
+ require_minimum_pandas_version()
+except ImportError as e:
+ from pyspark.util import _exception_message
+ # If Pandas version requirement is not satisfied, skip related tests.
+ _pandas_requirement_message = _exception_message(e)
+
+_pyarrow_requirement_message = None
+try:
+ from pyspark.sql.utils import require_minimum_pyarrow_version
+ require_minimum_pyarrow_version()
+except ImportError as e:
+ from pyspark.util import _exception_message
+ # If Arrow version requirement is not satisfied, skip related tests.
+ _pyarrow_requirement_message = _exception_message(e)
+
+_have_pandas = _pandas_requirement_message is None
+_have_pyarrow = _pyarrow_requirement_message is None
from pyspark import SparkContext
from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
@@ -75,15 +82,6 @@
from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException
-_have_arrow = False
-try:
- import pyarrow
- _have_arrow = True
-except:
- # No Arrow, but that's okay, we'll skip those tests
- pass
-
-
class UTCOffsetTimezone(datetime.tzinfo):
"""
Specifies timezone in UTC offset
@@ -2794,7 +2792,6 @@ def count_bucketed_cols(names, table="pyspark_bucket"):
def _to_pandas(self):
from datetime import datetime, date
- import numpy as np
schema = StructType().add("a", IntegerType()).add("b", StringType())\
.add("c", BooleanType()).add("d", FloatType())\
.add("dt", DateType()).add("ts", TimestampType())
@@ -2807,7 +2804,7 @@ def _to_pandas(self):
df = self.spark.createDataFrame(data, schema)
return df.toPandas()
- @unittest.skipIf(not _have_pandas, "Pandas not installed")
+ @unittest.skipIf(not _have_pandas, _pandas_requirement_message)
def test_to_pandas(self):
import numpy as np
pdf = self._to_pandas()
@@ -2819,13 +2816,13 @@ def test_to_pandas(self):
self.assertEquals(types[4], np.object) # datetime.date
self.assertEquals(types[5], 'datetime64[ns]')
- @unittest.skipIf(not _have_old_pandas, "Old Pandas not installed")
- def test_to_pandas_old(self):
+ @unittest.skipIf(_have_pandas, "Required Pandas was found.")
+ def test_to_pandas_required_pandas_not_found(self):
with QuietTest(self.sc):
with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'):
self._to_pandas()
- @unittest.skipIf(not _have_pandas, "Pandas not installed")
+ @unittest.skipIf(not _have_pandas, _pandas_requirement_message)
def test_to_pandas_avoid_astype(self):
import numpy as np
schema = StructType().add("a", IntegerType()).add("b", StringType())\
@@ -2843,7 +2840,7 @@ def test_create_dataframe_from_array_of_long(self):
df = self.spark.createDataFrame(data)
self.assertEqual(df.first(), Row(longarray=[-9223372036854775808, 0, 9223372036854775807]))
- @unittest.skipIf(not _have_pandas, "Pandas not installed")
+ @unittest.skipIf(not _have_pandas, _pandas_requirement_message)
def test_create_dataframe_from_pandas_with_timestamp(self):
import pandas as pd
from datetime import datetime
@@ -2858,14 +2855,16 @@ def test_create_dataframe_from_pandas_with_timestamp(self):
self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType))
self.assertTrue(isinstance(df.schema['d'].dataType, DateType))
- @unittest.skipIf(not _have_old_pandas, "Old Pandas not installed")
- def test_create_dataframe_from_old_pandas(self):
- import pandas as pd
- from datetime import datetime
- pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)],
- "d": [pd.Timestamp.now().date()]})
+ @unittest.skipIf(_have_pandas, "Required Pandas was found.")
+ def test_create_dataframe_required_pandas_not_found(self):
with QuietTest(self.sc):
- with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'):
+ with self.assertRaisesRegexp(
+ ImportError,
+ '(Pandas >= .* must be installed|No module named pandas)'):
+ import pandas as pd
+ from datetime import datetime
+ pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)],
+ "d": [pd.Timestamp.now().date()]})
self.spark.createDataFrame(pdf)
@@ -3383,7 +3382,9 @@ def __init__(self, **kwargs):
_make_type_verifier(data_type, nullable=False)(obj)
-@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
+@unittest.skipIf(
+ not _have_pandas or not _have_pyarrow,
+ _pandas_requirement_message or _pyarrow_requirement_message)
class ArrowTests(ReusedSQLTestCase):
@classmethod
@@ -3641,7 +3642,9 @@ def test_createDataFrame_with_int_col_names(self):
self.assertEqual(pdf_col_names, df_arrow.columns)
-@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
+@unittest.skipIf(
+ not _have_pandas or not _have_pyarrow,
+ _pandas_requirement_message or _pyarrow_requirement_message)
class PandasUDFTests(ReusedSQLTestCase):
def test_pandas_udf_basic(self):
from pyspark.rdd import PythonEvalType
@@ -3765,7 +3768,9 @@ def foo(k, v):
return k
-@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
+@unittest.skipIf(
+ not _have_pandas or not _have_pyarrow,
+ _pandas_requirement_message or _pyarrow_requirement_message)
class ScalarPandasUDFTests(ReusedSQLTestCase):
@classmethod
@@ -4278,7 +4283,9 @@ def test_register_vectorized_udf_basic(self):
self.assertEquals(expected.collect(), res2.collect())
-@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
+@unittest.skipIf(
+ not _have_pandas or not _have_pyarrow,
+ _pandas_requirement_message or _pyarrow_requirement_message)
class GroupedMapPandasUDFTests(ReusedSQLTestCase):
@property
@@ -4447,7 +4454,9 @@ def test_unsupported_types(self):
df.groupby('id').apply(f).collect()
-@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
+@unittest.skipIf(
+ not _have_pandas or not _have_pyarrow,
+ _pandas_requirement_message or _pyarrow_requirement_message)
class GroupedAggPandasUDFTests(ReusedSQLTestCase):
@property
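Distilled into a standalone sketch, the new skip pattern works as follows; the test class name is hypothetical, and str(e) stands in for pyspark.util._exception_message:

    import unittest

    _pandas_requirement_message = None
    try:
        from pyspark.sql.utils import require_minimum_pandas_version
        require_minimum_pandas_version()
    except ImportError as e:
        # Capture the reason so the skip message says *why* (missing vs. too old).
        _pandas_requirement_message = str(e)

    _have_pandas = _pandas_requirement_message is None

    @unittest.skipIf(not _have_pandas, _pandas_requirement_message)
    class PandasDependentTests(unittest.TestCase):  # hypothetical example class
        def test_uses_pandas(self):
            import pandas as pd
            self.assertTrue(hasattr(pd, "DataFrame"))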
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index 08c34c6dccc5e..578298632dd4c 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -115,18 +115,32 @@ def toJArray(gateway, jtype, arr):
def require_minimum_pandas_version():
""" Raise ImportError if minimum version of Pandas is not installed
"""
+ # TODO(HyukjinKwon): Relocate and deduplicate the version specification.
+ minimum_pandas_version = "0.19.2"
+
from distutils.version import LooseVersion
- import pandas
- if LooseVersion(pandas.__version__) < LooseVersion('0.19.2'):
- raise ImportError("Pandas >= 0.19.2 must be installed on calling Python process; "
- "however, your version was %s." % pandas.__version__)
+ try:
+ import pandas
+ except ImportError:
+ raise ImportError("Pandas >= %s must be installed; however, "
+ "it was not found." % minimum_pandas_version)
+ if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version):
+ raise ImportError("Pandas >= %s must be installed; however, "
+ "your version was %s." % (minimum_pandas_version, pandas.__version__))
def require_minimum_pyarrow_version():
""" Raise ImportError if minimum version of pyarrow is not installed
"""
+ # TODO(HyukjinKwon): Relocate and deduplicate the version specification.
+ minimum_pyarrow_version = "0.8.0"
+
from distutils.version import LooseVersion
- import pyarrow
- if LooseVersion(pyarrow.__version__) < LooseVersion('0.8.0'):
- raise ImportError("pyarrow >= 0.8.0 must be installed on calling Python process; "
- "however, your version was %s." % pyarrow.__version__)
+ try:
+ import pyarrow
+ except ImportError:
+ raise ImportError("PyArrow >= %s must be installed; however, "
+ "it was not found." % minimum_pyarrow_version)
+ if LooseVersion(pyarrow.__version__) < LooseVersion(minimum_pyarrow_version):
+ raise ImportError("PyArrow >= %s must be installed; however, "
+ "your version was %s." % (minimum_pyarrow_version, pyarrow.__version__))
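For reference, both guards rely on distutils' LooseVersion, which compares release segments numerically rather than lexicographically; a quick illustration:

    from distutils.version import LooseVersion

    assert LooseVersion("0.10.0") > LooseVersion("0.8.0")    # a plain string compare would invert this
    assert LooseVersion("0.19.10") > LooseVersion("0.19.2")  # 10 > 2 numerically
    assert not LooseVersion("0.18.1") >= LooseVersion("0.19.2")  # old pandas trips the guard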
diff --git a/python/setup.py b/python/setup.py
index 251d4526d4dd0..6a98401941d8d 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -100,6 +100,11 @@ def _supports_symlinks():
file=sys.stderr)
exit(-1)
+# If you are changing the versions here, please also change ./python/pyspark/sql/utils.py and
+# ./python/run-tests.py. In case of Arrow, you should also check ./pom.xml.
+_minimum_pandas_version = "0.19.2"
+_minimum_pyarrow_version = "0.8.0"
+
try:
# We copy the shell script to be under pyspark/python/pyspark so that the launcher scripts
# find it where expected. The rest of the files aren't copied because they are accessed
@@ -201,7 +206,10 @@ def _supports_symlinks():
extras_require={
'ml': ['numpy>=1.7'],
'mllib': ['numpy>=1.7'],
- 'sql': ['pandas>=0.19.2', 'pyarrow>=0.8.0']
+ 'sql': [
+ 'pandas>=%s' % _minimum_pandas_version,
+ 'pyarrow>=%s' % _minimum_pyarrow_version,
+ ]
},
classifiers=[
'Development Status :: 5 - Production/Stable',
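With this change, pip install 'pyspark[sql]' resolves the extra to pandas>=0.19.2 and pyarrow>=0.8.0, so the version floors live in the two module-level variables above rather than being repeated as inline literals in extras_require.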