Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ class AttributeSuite extends SparkFunSuite {
// Attribute.fromStructField should accept any NumericType, not just DoubleType
val longFldWithMeta = new StructField("x", LongType, false, metadata)
assert(Attribute.fromStructField(longFldWithMeta).isNumeric)
val decimalFldWithMeta = new StructField("x", DecimalType(None), false, metadata)
val decimalFldWithMeta = new StructField("x", DecimalType(38, 18), false, metadata)
assert(Attribute.fromStructField(decimalFldWithMeta).isNumeric)
}
}
36 changes: 21 additions & 15 deletions python/pyspark/sql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,30 +194,33 @@ def fromInternal(self, ts):

class DecimalType(FractionalType):
"""Decimal (decimal.Decimal) data type.

The DecimalType must have fixed precision (the maximum total number of digits)
and scale (the number of digits on the right of dot). For example, (5, 2) can
support the value from [-999.99 to 999.99].

The precision can be up to 38; the scale must be less than or equal to the precision.

When creating a DecimalType, the default precision and scale is (10, 0). When inferring
a schema from decimal.Decimal objects, it will be DecimalType(38, 18).

:param precision: the maximum total number of digits (default: 10)
:param scale: the number of digits on the right side of the dot (default: 0)
"""

def __init__(self, precision=None, scale=None):
def __init__(self, precision=10, scale=0):
self.precision = precision
self.scale = scale
self.hasPrecisionInfo = precision is not None
self.hasPrecisionInfo = True # this is public API

def simpleString(self):
if self.hasPrecisionInfo:
return "decimal(%d,%d)" % (self.precision, self.scale)
else:
return "decimal(10,0)"
return "decimal(%d,%d)" % (self.precision, self.scale)

def jsonValue(self):
if self.hasPrecisionInfo:
return "decimal(%d,%d)" % (self.precision, self.scale)
else:
return "decimal"
return "decimal(%d,%d)" % (self.precision, self.scale)

def __repr__(self):
if self.hasPrecisionInfo:
return "DecimalType(%d,%d)" % (self.precision, self.scale)
else:
return "DecimalType()"
return "DecimalType(%d,%d)" % (self.precision, self.scale)


class DoubleType(FractionalType):
Expand Down Expand Up @@ -761,7 +764,10 @@ def _infer_type(obj):
return obj.__UDT__

dataType = _type_mappings.get(type(obj))
if dataType is not None:
if dataType is DecimalType:
# the precision and scale of `obj` may be different from row to row.
return DecimalType(38, 18)
elif dataType is not None:
return dataType()

if isinstance(obj, dict):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,18 @@ public static ArrayType createArrayType(DataType elementType, boolean containsNu
return new ArrayType(elementType, containsNull);
}

/**
* Creates a DecimalType by specifying the precision and scale.
*/
public static DecimalType createDecimalType(int precision, int scale) {
return DecimalType$.MODULE$.apply(precision, scale);
}

/**
* Creates a DecimalType with default precision and scale, which are 10 and 0.
*/
public static DecimalType createDecimalType() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add javadoc here; can you also add javadoc for the one above?

return DecimalType$.MODULE$.Unlimited();
return DecimalType$.MODULE$.USER_DEFAULT();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ private [sql] object JavaTypeInference {
case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true)
case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true)

case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true)
case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType.SYSTEM_DEFAULT, true)
case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,10 @@ trait ScalaReflection {
case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true)
case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true)
case t if t <:< localTypeOf[java.sql.Date] => Schema(DateType, nullable = true)
case t if t <:< localTypeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
case t if t <:< localTypeOf[BigDecimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
case t if t <:< localTypeOf[java.math.BigDecimal] =>
Schema(DecimalType.Unlimited, nullable = true)
case t if t <:< localTypeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true)
Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
case t if t <:< localTypeOf[Decimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
case t if t <:< localTypeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
case t if t <:< localTypeOf[java.lang.Long] => Schema(LongType, nullable = true)
case t if t <:< localTypeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
Expand Down Expand Up @@ -167,8 +167,8 @@ trait ScalaReflection {
case obj: Float => FloatType
case obj: Double => DoubleType
case obj: java.sql.Date => DateType
case obj: java.math.BigDecimal => DecimalType.Unlimited
case obj: Decimal => DecimalType.Unlimited
case obj: java.math.BigDecimal => DecimalType.SYSTEM_DEFAULT
case obj: Decimal => DecimalType.SYSTEM_DEFAULT
case obj: java.sql.Timestamp => TimestampType
case null => NullType
// For other cases, there is no obvious mapping from the type of the given object to a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,10 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {

protected lazy val numericLiteral: Parser[Literal] =
( integral ^^ { case i => Literal(toNarrowestIntegerType(i)) }
| sign.? ~ unsignedFloat ^^ { case s ~ f => Literal((s.getOrElse("") + f).toDouble) }
| sign.? ~ unsignedFloat ^^ {
// TODO(davies): some precision may be lost; we should create a decimal literal
case s ~ f => Literal(BigDecimal(s.getOrElse("") + f).doubleValue())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is different from our offline discussion, isn't it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will break some tests, so I'd like to do it separately, https://issues.apache.org/jira/browse/SPARK-9281?filter=-2.

}
)

protected lazy val unsignedFloat: Parser[String] =
Expand Down
Loading