16 changes: 13 additions & 3 deletions python/pyspark/sql/functions.py
@@ -2026,15 +2026,25 @@ def __init__(self, func, returnType, name=None):
"{0}".format(type(func)))

self.func = func
self.returnType = (
returnType if isinstance(returnType, DataType)
else _parse_datatype_string(returnType))
self._returnType = returnType
# Stores UserDefinedPythonFunctions jobj, once initialized
self._returnType_placeholder = None
self._judf_placeholder = None
self._name = name or (
func.__name__ if hasattr(func, '__name__')
else func.__class__.__name__)

@property
def returnType(self):
Contributor
We have pretty similar logic below; would it make sense to think about whether there is a nicer, more general way to handle these delayed-initialization classes?

Member Author
Hmm.. I tried several approaches, the best I could come up with, but I could not figure out a nicer way ...

# This makes sure this is called after SparkContext is initialized.
# ``_parse_datatype_string`` accesses the JVM to parse a DDL-formatted string.
if self._returnType_placeholder is None:
if isinstance(self._returnType, DataType):
self._returnType_placeholder = self._returnType
else:
self._returnType_placeholder = _parse_datatype_string(self._returnType)
return self._returnType_placeholder

@property
def _judf(self):
# It is possible that concurrent access, to newly created UDF,
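A side note on the review thread above: one generic way to handle this kind of delayed initialization would be a small caching descriptor. This is only a hypothetical sketch (the name _lazy_property and its use here are assumptions, not part of this change), and it ignores the concurrency caveat noted for _judf:

class _lazy_property(object):
    """Run the wrapped method once on first access and cache the result on the
    instance. As a non-data descriptor (no __set__), later attribute lookups
    find the cached value directly and skip this descriptor."""
    def __init__(self, func):
        self._func = func
        self._name = func.__name__

    def __get__(self, instance, owner):
        if instance is None:
            return self
        value = self._func(instance)
        instance.__dict__[self._name] = value
        return value
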
25 changes: 25 additions & 0 deletions python/pyspark/sql/tests.py
@@ -1246,6 +1246,31 @@ def test_struct_type(self):
with self.assertRaises(TypeError):
not_a_field = struct1[9.9]

def test_parse_datatype_string(self):
from pyspark.sql.types import _all_atomic_types, _parse_datatype_string
for k, t in _all_atomic_types.items():
if t != NullType:
Member Author
So, if I haven't missed anything, this PR drops support for parsing the null type. I guess it is quite rare that we explicitly set a type to null. Also, IIRC, we will soon support NullType via void (SPARK-20680) as a workaround.

self.assertEqual(t(), _parse_datatype_string(k))
self.assertEqual(IntegerType(), _parse_datatype_string("int"))
self.assertEqual(DecimalType(1, 1), _parse_datatype_string("decimal(1 ,1)"))
self.assertEqual(DecimalType(10, 1), _parse_datatype_string("decimal( 10,1 )"))
self.assertEqual(DecimalType(11, 1), _parse_datatype_string("decimal(11,1)"))
self.assertEqual(
ArrayType(IntegerType()),
_parse_datatype_string("array<int >"))
self.assertEqual(
MapType(IntegerType(), DoubleType()),
_parse_datatype_string("map< int, double >"))
self.assertEqual(
StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
_parse_datatype_string("struct<a:int, c:double >"))
self.assertEqual(
StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
_parse_datatype_string("a:int, c:double"))
self.assertEqual(
StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
_parse_datatype_string("a INT, c DOUBLE"))

def test_metadata_null(self):
from pyspark.sql.types import StructType, StringType, StructField
schema = StructType([StructField("f1", StringType(), True, None),
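To make the null point above concrete, a hedged sketch (not a test added by this change; it assumes an active SparkContext since parsing now goes through the JVM, and the exact exception type comes from the JVM parser):

from pyspark.sql.types import _parse_datatype_string

try:
    _parse_datatype_string("null")   # no longer returns NullType()
except Exception:
    pass  # expected to fail: the DDL parsers do not recognize "null" as a type name
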
88 changes: 34 additions & 54 deletions python/pyspark/sql/types.py
@@ -32,6 +32,7 @@
from py4j.protocol import register_input_converter
from py4j.java_gateway import JavaClass

from pyspark import SparkContext
from pyspark.serializers import CloudPickleSerializer

__all__ = [
@@ -727,18 +728,6 @@ def __eq__(self, other):
_BRACKETS = {'(': ')', '[': ']', '{': '}'}


def _parse_basic_datatype_string(s):
if s in _all_atomic_types.keys():
return _all_atomic_types[s]()
elif s == "int":
return IntegerType()
elif _FIXED_DECIMAL.match(s):
m = _FIXED_DECIMAL.match(s)
return DecimalType(int(m.group(1)), int(m.group(2)))
else:
raise ValueError("Could not parse datatype: %s" % s)


def _ignore_brackets_split(s, separator):
"""
Splits the given string by given separator, but ignore separators inside brackets pairs, e.g.
@@ -771,32 +760,23 @@ def _ignore_brackets_split(s, separator):
return parts


def _parse_struct_fields_string(s):
parts = _ignore_brackets_split(s, ",")
fields = []
for part in parts:
name_and_type = _ignore_brackets_split(part, ":")
if len(name_and_type) != 2:
raise ValueError("The strcut field string format is: 'field_name:field_type', " +
"but got: %s" % part)
field_name = name_and_type[0].strip()
field_type = _parse_datatype_string(name_and_type[1])
fields.append(StructField(field_name, field_type))
return StructType(fields)


def _parse_datatype_string(s):
"""
Parses the given data type string to a :class:`DataType`. The data type string format equals
to :class:`DataType.simpleString`, except that top level struct type can omit
the ``struct<>`` and atomic types use ``typeName()`` as their format, e.g. use ``byte`` instead
of ``tinyint`` for :class:`ByteType`. We can also use ``int`` as a short name
for :class:`IntegerType`.
for :class:`IntegerType`. Since Spark 2.3, this also supports a schema in a DDL-formatted
string and case-insensitive strings.

>>> _parse_datatype_string("int ")
IntegerType
>>> _parse_datatype_string("INT ")
IntegerType
>>> _parse_datatype_string("a: byte, b: decimal( 16 , 8 ) ")
StructType(List(StructField(a,ByteType,true),StructField(b,DecimalType(16,8),true)))
>>> _parse_datatype_string("a DOUBLE, b STRING")
StructType(List(StructField(a,DoubleType,true),StructField(b,StringType,true)))
>>> _parse_datatype_string("a: array< short>")
StructType(List(StructField(a,ArrayType(ShortType,true),true)))
>>> _parse_datatype_string(" map<string , string > ")
@@ -806,43 +786,43 @@ def _parse_datatype_string(s):
>>> _parse_datatype_string("blabla") # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValueError:...
ParseException:...
>>> _parse_datatype_string("a: int,") # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValueError:...
ParseException:...
>>> _parse_datatype_string("array<int") # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValueError:...
ParseException:...
>>> _parse_datatype_string("map<int, boolean>>") # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValueError:...
ParseException:...
"""
s = s.strip()
if s.startswith("array<"):
if s[-1] != ">":
raise ValueError("'>' should be the last char, but got: %s" % s)
return ArrayType(_parse_datatype_string(s[6:-1]))
elif s.startswith("map<"):
if s[-1] != ">":
raise ValueError("'>' should be the last char, but got: %s" % s)
parts = _ignore_brackets_split(s[4:-1], ",")
if len(parts) != 2:
raise ValueError("The map type string format is: 'map<key_type,value_type>', " +
"but got: %s" % s)
kt = _parse_datatype_string(parts[0])
vt = _parse_datatype_string(parts[1])
return MapType(kt, vt)
elif s.startswith("struct<"):
if s[-1] != ">":
raise ValueError("'>' should be the last char, but got: %s" % s)
return _parse_struct_fields_string(s[7:-1])
elif ":" in s:
return _parse_struct_fields_string(s)
else:
return _parse_basic_datatype_string(s)
sc = SparkContext._active_spark_context

def from_ddl_schema(type_str):
return _parse_datatype_json_string(
sc._jvm.org.apache.spark.sql.types.StructType.fromDDL(type_str).json())

def from_ddl_datatype(type_str):
return _parse_datatype_json_string(
sc._jvm.org.apache.spark.sql.api.python.PythonSQLUtils.parseDataType(type_str).json())

try:
# DDL format, "fieldname datatype, fieldname datatype".
return from_ddl_schema(s)
except Exception as e:
try:
# For backwards compatibility, "integer", "struct<fieldname: datatype>" and etc.
return from_ddl_datatype(s)
except:
try:
# For backwards compatibility, "fieldname: datatype, fieldname: datatype" case.
Contributor
Won't 'fieldname: datatype, fieldname: datatype' be parsed as a DDL schema?

Member Author
I tested a few cases, and it looks like it won't:

scala> StructType.fromDDL("a struct<a: INT, b: STRING>")
res5: org.apache.spark.sql.types.StructType = StructType(StructField(a,StructType(StructField(a,IntegerType,true), StructField(b,StringType,true)),true))

scala> StructType.fromDDL("a INT, b STRING")
res6: org.apache.spark.sql.types.StructType = StructType(StructField(a,IntegerType,true), StructField(b,StringType,true))

scala> StructType.fromDDL("a: INT, b: STRING")
org.apache.spark.sql.catalyst.parser.ParseException:
extraneous input ':' expecting ...

return from_ddl_datatype("struct<%s>" % s.strip())
except:
raise e


def _parse_datatype_json_string(json_string):
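To make the new fallback order concrete, a minimal usage sketch (assuming an active SparkContext, since all three branches now go through the JVM):

from pyspark.sql.types import _parse_datatype_string

_parse_datatype_string("a INT, b STRING")     # DDL schema, handled by StructType.fromDDL
_parse_datatype_string("array<short>")        # single data type, handled by PythonSQLUtils.parseDataType
_parse_datatype_string("a: int, b: string")   # legacy form, retried as "struct<a: int, b: string>"
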
@@ -0,0 +1,25 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.api.python

import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.types.DataType

private[sql] object PythonSQLUtils {
def parseDataType(typeText: String): DataType = CatalystSqlParser.parseDataType(typeText)
Member Author
Without this, I would have to do something like ...

getattr(getattr(sc._jvm.org.apache.spark.sql.catalyst.parser, "CatalystSqlParser$"), "MODULE$").parseDataType("a")

}
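
For context, this is roughly how the helper is reached from the Python side in from_ddl_datatype above (a sketch; assumes an active SparkContext sc):

from pyspark.sql.types import _parse_datatype_json_string

json_str = sc._jvm.org.apache.spark.sql.api.python.PythonSQLUtils.parseDataType("struct<a: int>").json()
dt = _parse_datatype_json_string(json_str)    # StructType(List(StructField(a,IntegerType,true)))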