@@ -44,11 +44,13 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
     SparkConnectService
       .getOrCreateIsolatedSession(v.getUserContext.getUserId, v.getClientId)
       .session
-    v.getPlan.getOpTypeCase match {
-      case proto.Plan.OpTypeCase.COMMAND => handleCommand(session, v)
-      case proto.Plan.OpTypeCase.ROOT => handlePlan(session, v)
-      case _ =>
-        throw new UnsupportedOperationException(s"${v.getPlan.getOpTypeCase} not supported.")
-    }
+    session.withActive {
+      v.getPlan.getOpTypeCase match {
+        case proto.Plan.OpTypeCase.COMMAND => handleCommand(session, v)
+        case proto.Plan.OpTypeCase.ROOT => handlePlan(session, v)
+        case _ =>
+          throw new UnsupportedOperationException(s"${v.getPlan.getOpTypeCase} not supported.")
+      }
+    }
   }

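Wrapping the dispatch in session.withActive makes the isolated user session the thread's active session while the plan or command is handled, and restores the previous one afterwards, so session-scoped state such as SQL confs is visible during analysis and execution. A minimal Python sketch of the pattern, assuming only a thread-local slot (illustrative, not the PySpark API):

    import threading

    _active = threading.local()

    def with_active(session, body):
        # Install `session` as the thread-local active session, run the
        # body, then restore whatever was active before.
        previous = getattr(_active, "session", None)
        _active.session = session
        try:
            return body()
        finally:
            _active.session = previous

With the handler wrapped this way, a conf set earlier on the same session, such as spark.sql.legacy.allowNegativeScaleOfDecimal in the test below, takes effect when the next plan is handled.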
python/pyspark/sql/connect/types.py (2 additions, 1 deletion)
@@ -87,7 +87,8 @@ def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType:
     elif isinstance(data_type, DoubleType):
         ret.double.CopyFrom(pb2.DataType.Double())
     elif isinstance(data_type, DecimalType):
-        ret.decimal.CopyFrom(pb2.DataType.Decimal())
+        ret.decimal.scale = data_type.scale
+        ret.decimal.precision = data_type.precision
     elif isinstance(data_type, DateType):
         ret.date.CopyFrom(pb2.DataType.Date())
     elif isinstance(data_type, TimestampType):
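Before this change the client serialized every DecimalType as a bare Decimal message, so the precision and scale the user specified never reached the server. Setting the proto's scale and precision fields preserves them. A short sketch of the round-trip this enables, assuming the field assignments shown above:

    from pyspark.sql.connect.types import pyspark_types_to_proto_types
    from pyspark.sql.types import DecimalType

    dt = DecimalType(precision=1, scale=-1)  # negative scale, as in the test below
    msg = pyspark_types_to_proto_types(dt)
    # With this change the proto carries both fields:
    # msg.decimal.precision == 1 and msg.decimal.scale == -1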
python/pyspark/sql/tests/connect/test_parity_types.py (0 additions, 5 deletions)
@@ -113,11 +113,6 @@ def test_infer_schema_upcast_int_to_string(self):
     def test_infer_schema_with_udt(self):
         super().test_infer_schema_with_udt()
 
-    # TODO(SPARK-41834): Implement SparkSession.conf
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_negative_decimal(self):
-        super().test_negative_decimal()
-
     # TODO(SPARK-42020): createDataFrame with UDT
     @unittest.skip("Fails in Spark Connect, should enable.")
     def test_nested_udt_in_df(self):
python/pyspark/sql/tests/test_types.py (2 additions, 2 deletions)
@@ -379,13 +379,13 @@ def test_create_dataframe_from_dict_respects_schema(self):
 
     def test_negative_decimal(self):
         try:
-            self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal=true")
+            self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal=true").collect()
             df = self.spark.createDataFrame([(1,), (11,)], ["value"])
             ret = df.select(col("value").cast(DecimalType(1, -1))).collect()
             actual = list(map(lambda r: int(r.value), ret))
             self.assertEqual(actual, [0, 10])
         finally:
-            self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal=false")
+            self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal=false").collect()
 
     def test_create_dataframe_from_objects(self):
         data = [MyObject(1, "1"), MyObject(2, "2")]
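The appended .collect() calls are what let this test run under Spark Connect: there, spark.sql() only builds a plan, so the SET command does not appear to run eagerly the way it does in classic PySpark, and without an action the conf change would never reach the server before createDataFrame. A hedged sketch of the resulting flow, assuming an active session named spark:

    from pyspark.sql.functions import col
    from pyspark.sql.types import DecimalType

    # Force the SET command to execute on the server before it is needed.
    spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal=true").collect()
    try:
        df = spark.createDataFrame([(1,), (11,)], ["value"])
        # decimal(1, -1) rounds to the nearest ten: [0, 10]
        print(df.select(col("value").cast(DecimalType(1, -1))).collect())
    finally:
        spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal=false").collect()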