diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 8392d7f29af5..5066b3c4a225 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -2303,6 +2303,10 @@ def test_py4j_exception_message(self): self.assertTrue('NullPointerException' in _exception_message(context.exception)) + def test_parsing_version_string(self): + from pyspark.util import VersionUtils + self.assertRaises(ValueError, lambda: VersionUtils.majorMinorVersion("abced")) + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 04df835bf671..59cc2a632935 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -62,24 +62,31 @@ def _get_argspec(f): return argspec -def majorMinorVersion(version): +class VersionUtils(object): """ - Get major and minor version numbers for given Spark version string. - - >>> version = "2.4.0" - >>> majorMinorVersion(version) - (2, 4) + Provides utility method to determine Spark versions with given input string. + """ + @staticmethod + def majorMinorVersion(sparkVersion): + """ + Given a Spark version string, return the (major version number, minor version number). + E.g., for 2.0.1-SNAPSHOT, return (2, 0). - >>> version = "abc" - >>> majorMinorVersion(version) is None - True + >>> sparkVersion = "2.4.0" + >>> VersionUtils.majorMinorVersion(sparkVersion) + (2, 4) + >>> sparkVersion = "2.3.0-SNAPSHOT" + >>> VersionUtils.majorMinorVersion(sparkVersion) + (2, 3) - """ - m = re.search('^(\d+)\.(\d+)(\..*)?$', version) - if m is None: - return None - else: - return (int(m.group(1)), int(m.group(2))) + """ + m = re.search('^(\d+)\.(\d+)(\..*)?$', sparkVersion) + if m is not None: + return (int(m.group(1)), int(m.group(2))) + else: + raise ValueError("Spark tried to parse '%s' as a Spark" % sparkVersion + + " version string, but it could not find the major and minor" + + " version numbers.") if __name__ == "__main__":