Skip to content

Commit

Permalink
Fix bug; Update test case
Browse files Browse the repository at this point in the history
  • Loading branch information
Chong Gao committed Dec 18, 2024
1 parent 89f996b commit 27bf6f7
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 11 deletions.
51 changes: 41 additions & 10 deletions integration_tests/src/main/python/hashing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect
from data_gen import *
from marks import allow_non_gpu, ignore_order
from spark_session import is_before_spark_320

_struct_of_xxhash_gens = StructGen([(f"c{i}", g) for i, g in enumerate(_xxhash_gens)])

_xxhash_gens = [
null_gen,
Expand All @@ -35,21 +32,55 @@
decimal_gen_128bit,
float_gen,
double_gen
] + single_level_array_gens + nested_array_gens_sample + [
all_basic_struct_gen,
struct_array_gen,
_struct_of_xxhash_gens
] + map_gens_sample
]

_struct_of_xxhash_gens = StructGen([(f"c{i}", g) for i, g in enumerate(_xxhash_gens)])

_xxhash_gens = (_xxhash_gens + [_struct_of_xxhash_gens] + single_level_array_gens
+ nested_array_gens_sample + [
all_basic_struct_gen,
struct_array_gen,
_struct_of_xxhash_gens
] + map_gens_sample)

@ignore_order(local=True)
@pytest.mark.parametrize("gen", _xxhash_gens, ids=idfn)
def test_xxhash64_single_column(gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, gen).selectExpr("a", "xxhash64(a)"))
lambda spark : unary_op_df(spark, gen).selectExpr("a", "xxhash64(a)"),
{"spark.sql.legacy.allowHashOnMapType" : True})

@ignore_order(local=True)
def test_xxhash64_multi_column():
gen = StructGen(_struct_of_xxhash_gens.children, nullable=False)
col_list = ",".join(gen.data_type.fieldNames())
assert_gpu_and_cpu_are_equal_collect(
lambda spark : gen_df(spark, gen).selectExpr("c0", f"xxhash64({col_list})"))
lambda spark : gen_df(spark, gen).selectExpr("c0", f"xxhash64({col_list})"),
{"spark.sql.legacy.allowHashOnMapType" : True})

def test_xxhash64_8_depth():
gen_8_depth = StructGen([('l1', # level 1
StructGen([('l2',
StructGen([('l3',
StructGen([('l4',
StructGen([('l5',
StructGen([('l6',
StructGen([('l7',
int_gen)]))]))]))]))]))]))]) # level 8
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, gen_8_depth).selectExpr("a", "xxhash64(a)"))

@allow_non_gpu("ProjectExec")
def test_xxhash64_fallback_exceeds_stack_size():
gen_9_depth = StructGen([('l1', # level 1
StructGen([('l2',
StructGen([('l3',
StructGen([('l4',
StructGen([('l5',
StructGen([('l6',
StructGen([('l7',
StructGen([('l8',
int_gen)]))]))]))]))]))]))]))]) # level 9
assert_gpu_fallback_collect(
lambda spark : unary_op_df(spark, gen_9_depth).selectExpr("a", "xxhash64(a)"),
"ProjectExec")
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ object XxHash64Utils {
case ArrayType(c: DataType, _) => computeMaxStackSizeForFlatten(c)
case st: StructType =>
1 + st.map(f => computeMaxStackSizeForFlatten(f.dataType)).max
case _ => 0 // primitive types
case _ => 1 // primitive types
}
}

Expand Down

0 comments on commit 27bf6f7

Please sign in to comment.