diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index e5288636c596e..18f7aebae7fe8 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -964,6 +964,21 @@ def _test():
     except py4j.protocol.Py4JError:
         spark = SparkSession(sc)
 
+    hive_enabled = True
+    try:
+        sc._jvm.org.apache.hadoop.hive.conf.HiveConf()
+    except py4j.protocol.Py4JError:
+        hive_enabled = False
+    except TypeError:
+        hive_enabled = False
+
+    if not hive_enabled:
+        # if hive is not enabled, then skip doctests that need hive
+        # TODO: Need to communicate with outside world that this test
+        # has been skipped.
+        m = pyspark.sql.readwriter
+        m.__dict__["DataFrameReader"].__dict__["table"].__doc__ = ""
+
     globs['tempfile'] = tempfile
     globs['os'] = os
     globs['sc'] = sc
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 967cc83166f3f..003bcd4827433 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -2977,6 +2977,20 @@ def test_create_dateframe_from_pandas_with_dst(self):
 
 class HiveSparkSubmitTests(SparkSubmitTests):
 
+    @classmethod
+    def setUpClass(cls):
+        # get a SparkContext to check for availability of Hive
+        sc = SparkContext('local[4]', cls.__name__)
+        try:
+            sc._jvm.org.apache.hadoop.hive.conf.HiveConf()
+        except py4j.protocol.Py4JError:
+            raise unittest.SkipTest("Hive is not available")
+        except TypeError:
+            raise unittest.SkipTest("Hive is not available")
+        finally:
+            # we don't need SparkContext for the test
+            sc.stop()
+
     def test_hivecontext(self):
         # This test checks that HiveContext is using Hive metastore (SPARK-16224).
         # It sets a metastore url and checks if there is a derby dir created by
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 24dd06c26089c..39eff97c5cd3a 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -385,8 +385,23 @@ def registerJavaUDAF(self, name, javaClassName):
 
 def _test():
     import doctest
+    import os
+    import os.path
+    import glob
     from pyspark.sql import SparkSession
     import pyspark.sql.udf
+
+    SPARK_HOME = os.environ["SPARK_HOME"]
+    filename_pattern = "sql/core/target/scala-*/test-classes/" + \
+        "test/org/apache/spark/sql/JavaStringLength.class"
+    if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)):
+        # if test udf files are not compiled, then skip the below doctests
+        # TODO: Need to communicate with outside world that these tests
+        # have been skipped.
+        m = pyspark.sql.udf
+        m.__dict__["UDFRegistration"].__dict__["registerJavaFunction"].__doc__ = ""
+        m.__dict__["UDFRegistration"].__dict__["registerJavaUDAF"].__doc__ = ""
+
     globs = pyspark.sql.udf.__dict__.copy()
     spark = SparkSession.builder\
         .master("local[4]")\
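For reference, here is a minimal, self-contained sketch of the doctest-skipping trick used by both _test() hunks above: blanking a method's __doc__ before doctest collection means its examples are never found, so they are silently dropped rather than run. The names needs_hive and hive_available below are hypothetical, used only for illustration; the stand-in flag plays the role of the HiveConf() / compiled-test-class probes in the patch. Note that the setUpClass change in tests.py behaves differently: raising unittest.SkipTest there makes the whole class show up as skipped in the unittest report, whereas the blanked doctests leave no trace, which is what the TODO comments are pointing at.

import doctest
import sys


def needs_hive():
    """Doctest that only makes sense when the optional dependency is present.

    >>> needs_hive()
    42
    """
    return 42


def _test():
    hive_available = False  # stand-in for the HiveConf() probe failing

    if not hive_available:
        # Same idea as the patch: wipe the docstring so doctest never
        # collects the example. Downside (the patch's TODO): nothing in
        # the output indicates that anything was skipped.
        sys.modules[__name__].__dict__["needs_hive"].__doc__ = ""

    failure_count, _ = doctest.testmod()
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()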