Skip to content

Commit

Permalink
[BUG] Fix ValueError caused by different operand data types in `if_…
Browse files Browse the repository at this point in the history
…then_else` while initializing `Conv2dTransposeGemmImageTask` (#470)

Closes #469
  • Loading branch information
BolinSNLHM authored and vadiklyutiy committed Dec 19, 2024
1 parent 55e29e9 commit f9fc705
Showing 1 changed file with 3 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def fcompute(b, i, k):
return if_then_else(
cond=logical_and(xx >= 0, xx < p * sx, xx % sx == 0, yy >= 0, yy < q * sy, yy % sy == 0),
then_expr=data[ni, gi * og + ogi, xx // sx, yy // sy],
else_expr=0.0,
# If we simply pass `else_expr=0.0`, it will be interpreted as float32, which causes error when the
# input data tensor is of dtype e.g., float16.
else_expr=data.type.dtype.zero,
)

output = compute(name='gemm_x', shape=[groups, n * h * w, og * kx * ky], fcompute=fcompute)
Expand Down

0 comments on commit f9fc705

Please sign in to comment.