Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions python/pyspark/sql/connect/functions/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3684,6 +3684,14 @@ def sha1(col: "ColumnOrName") -> Column:


def sha2(col: "ColumnOrName", numBits: int) -> Column:
if numBits not in [0, 224, 256, 384, 512]:
raise PySparkValueError(
error_class="VALUE_NOT_ALLOWED",
message_parameters={
"arg_name": "numBits",
"allowed_values": "[0, 224, 256, 384, 512]",
},
)
return _invoke_function("sha2", _to_col(col), lit(numBits))


Expand Down
8 changes: 8 additions & 0 deletions python/pyspark/sql/functions/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9112,6 +9112,14 @@ def sha2(col: "ColumnOrName", numBits: int) -> Column:
|Bob |cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961|
+-----+----------------------------------------------------------------+
"""
if numBits not in [0, 224, 256, 384, 512]:
raise PySparkValueError(
error_class="VALUE_NOT_ALLOWED",
message_parameters={
"arg_name": "numBits",
"allowed_values": "[0, 224, 256, 384, 512]",
},
)
return _invoke_function("sha2", _to_java_column(col), numBits)


Expand Down
5 changes: 1 addition & 4 deletions python/pyspark/sql/tests/connect/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest

from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.sql.tests.test_utils import UtilsTestsMixin


class ConnectUtilsTests(ReusedConnectTestCase, UtilsTestsMixin):
@unittest.skip("SPARK-46397: Different exception thrown")
def test_capture_illegalargument_exception(self):
super().test_capture_illegalargument_exception()
pass


if __name__ == "__main__":
Expand Down
19 changes: 8 additions & 11 deletions python/pyspark/sql/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
from itertools import zip_longest

from pyspark.errors import QueryContextType
from pyspark.sql.functions import sha2, to_timestamp
from pyspark.errors import (
AnalysisException,
ParseException,
PySparkAssertionError,
PySparkValueError,
IllegalArgumentException,
SparkUpgradeException,
)
Expand Down Expand Up @@ -590,8 +590,8 @@ def test_assert_equal_timestamp(self):
data=[("1", "2023-01-01 12:01:01.000")], schema=["id", "timestamp"]
)

df1 = df1.withColumn("timestamp", to_timestamp("timestamp"))
df2 = df2.withColumn("timestamp", to_timestamp("timestamp"))
df1 = df1.withColumn("timestamp", F.to_timestamp("timestamp"))
df2 = df2.withColumn("timestamp", F.to_timestamp("timestamp"))

assertDataFrameEqual(df1, df2, checkRowOrder=False)
assertDataFrameEqual(df1, df2, checkRowOrder=True)
Expand Down Expand Up @@ -1729,17 +1729,14 @@ def test_capture_illegalargument_exception(self):
"Setting negative mapred.reduce.tasks",
lambda: self.spark.sql("SET mapred.reduce.tasks=-1"),
)

def test_capture_pyspark_value_exception(self):
df = self.spark.createDataFrame([(1, 2)], ["a", "b"])
self.assertRaisesRegex(
IllegalArgumentException,
"1024 is not in the permitted values",
lambda: df.select(sha2(df.a, 1024)).collect(),
PySparkValueError,
"Value for `numBits` has to be amongst the following values",
lambda: df.select(F.sha2(df.a, 1024)).collect(),
)
try:
df.select(sha2(df.a, 1024)).collect()
except IllegalArgumentException as e:
self.assertRegex(e._desc, "1024 is not in the permitted values")
self.assertRegex(e._stackTrace, "org.apache.spark.sql.functions")

def test_get_error_class_state(self):
# SPARK-36953: test CapturedException.getErrorClass and getSqlState (from SparkThrowable)
Expand Down