diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py
index cf2e2e0c7344d..baf8dc82fd84a 100644
--- a/python/pyspark/sql/connect/functions/builtin.py
+++ b/python/pyspark/sql/connect/functions/builtin.py
@@ -3684,6 +3684,14 @@ def sha1(col: "ColumnOrName") -> Column:
 
 
 def sha2(col: "ColumnOrName", numBits: int) -> Column:
+    if numBits not in [0, 224, 256, 384, 512]:
+        raise PySparkValueError(
+            error_class="VALUE_NOT_ALLOWED",
+            message_parameters={
+                "arg_name": "numBits",
+                "allowed_values": "[0, 224, 256, 384, 512]",
+            },
+        )
     return _invoke_function("sha2", _to_col(col), lit(numBits))
 
 
diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py
index 3b579f20333e9..69f082a6f998f 100644
--- a/python/pyspark/sql/functions/builtin.py
+++ b/python/pyspark/sql/functions/builtin.py
@@ -9112,6 +9112,14 @@ def sha2(col: "ColumnOrName", numBits: int) -> Column:
     |Bob  |cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961|
     +-----+----------------------------------------------------------------+
     """
+    if numBits not in [0, 224, 256, 384, 512]:
+        raise PySparkValueError(
+            error_class="VALUE_NOT_ALLOWED",
+            message_parameters={
+                "arg_name": "numBits",
+                "allowed_values": "[0, 224, 256, 384, 512]",
+            },
+        )
     return _invoke_function("sha2", _to_java_column(col), numBits)
 
 
diff --git a/python/pyspark/sql/tests/connect/test_utils.py b/python/pyspark/sql/tests/connect/test_utils.py
index 5f5f401cc6261..917cb58057f7f 100644
--- a/python/pyspark/sql/tests/connect/test_utils.py
+++ b/python/pyspark/sql/tests/connect/test_utils.py
@@ -14,16 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import unittest
 
 from pyspark.testing.connectutils import ReusedConnectTestCase
 from pyspark.sql.tests.test_utils import UtilsTestsMixin
 
 
 class ConnectUtilsTests(ReusedConnectTestCase, UtilsTestsMixin):
-    @unittest.skip("SPARK-46397: Different exception thrown")
-    def test_capture_illegalargument_exception(self):
-        super().test_capture_illegalargument_exception()
+    pass
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/test_utils.py b/python/pyspark/sql/tests/test_utils.py
index d54db78d4b651..66b5c19fc975a 100644
--- a/python/pyspark/sql/tests/test_utils.py
+++ b/python/pyspark/sql/tests/test_utils.py
@@ -20,11 +20,11 @@
 from itertools import zip_longest
 
 from pyspark.errors import QueryContextType
-from pyspark.sql.functions import sha2, to_timestamp
 from pyspark.errors import (
     AnalysisException,
     ParseException,
     PySparkAssertionError,
+    PySparkValueError,
     IllegalArgumentException,
     SparkUpgradeException,
 )
@@ -590,8 +590,8 @@ def test_assert_equal_timestamp(self):
             data=[("1", "2023-01-01 12:01:01.000")], schema=["id", "timestamp"]
         )
 
-        df1 = df1.withColumn("timestamp", to_timestamp("timestamp"))
-        df2 = df2.withColumn("timestamp", to_timestamp("timestamp"))
+        df1 = df1.withColumn("timestamp", F.to_timestamp("timestamp"))
+        df2 = df2.withColumn("timestamp", F.to_timestamp("timestamp"))
 
         assertDataFrameEqual(df1, df2, checkRowOrder=False)
         assertDataFrameEqual(df1, df2, checkRowOrder=True)
@@ -1729,17 +1729,14 @@ def test_capture_illegalargument_exception(self):
             "Setting negative mapred.reduce.tasks",
             lambda: self.spark.sql("SET mapred.reduce.tasks=-1"),
         )
+
+    def test_capture_pyspark_value_exception(self):
         df = self.spark.createDataFrame([(1, 2)], ["a", "b"])
         self.assertRaisesRegex(
-            IllegalArgumentException,
-            "1024 is not in the permitted values",
-            lambda: df.select(sha2(df.a, 1024)).collect(),
+            PySparkValueError,
+            "Value for `numBits` has to be amongst the following values",
+            lambda: df.select(F.sha2(df.a, 1024)).collect(),
         )
-        try:
-            df.select(sha2(df.a, 1024)).collect()
-        except IllegalArgumentException as e:
-            self.assertRegex(e._desc, "1024 is not in the permitted values")
-            self.assertRegex(e._stackTrace, "org.apache.spark.sql.functions")
 
     def test_get_error_class_state(self):
         # SPARK-36953: test CapturedException.getErrorClass and getSqlState (from SparkThrowable)
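
A minimal sketch of the behavior this patch introduces, assuming a local SparkSession (the `spark` and `df` names below are illustrative, not part of the patch). With the check moved into the Python client, an unsupported `numBits` now raises PySparkValueError as soon as the Column is constructed, instead of surfacing as an IllegalArgumentException from the JVM at collect time:

from pyspark.errors import PySparkValueError
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("abc",)], ["a"])

# Supported bit lengths are passed through to the underlying sha2 expression.
df.select(F.sha2(df.a, 256)).show()

# An unsupported value now fails fast on the client; no job is submitted
# and no JVM round trip is needed.
try:
    df.select(F.sha2(df.a, 1024))
except PySparkValueError as e:
    # expected message: Value for `numBits` has to be amongst the
    # following values: [0, 224, 256, 384, 512]
    print(e)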