Skip to content

Commit

Permalink
[Unity][Cutlass] Fix C source generation of dense operation (apache#1…
Browse files Browse the repository at this point in the history
…6476)

This commit fixes an issue that generates wrong c sources of dense operation using cutlass.

Co-authored-by: 진배 박 <jinbae@nexon.co.kr>
  • Loading branch information
creaitr and jinbaep authored Apr 30, 2024
1 parent 6252fa5 commit a320b63
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion python/tvm/contrib/cutlass/gen_tensor_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,10 @@ def get_flattened_batch_dim(arg_name, batch_rank):
transposed = "transposed" in func_name or "dense" in func_name
lhs_arg_idx = _get_optional_int_annotation(annotations, "lhs_arg_idx", 0)
rhs_arg_idx = _get_optional_int_annotation(annotations, "rhs_arg_idx", 1)
bias_arg_idx = _get_optional_int_annotation(annotations, "bias_arg_idx", None)
if "bias" in func_name:
bias_arg_idx = _get_optional_int_annotation(annotations, "bias_arg_idx", 2)
else:
bias_arg_idx = _get_optional_int_annotation(annotations, "bias_arg_idx", None)
residual_arg_idx = _get_optional_int_annotation(annotations, "residual_arg_idx", None)

lhs_arg = func_args[lhs_arg_idx]
Expand Down

0 comments on commit a320b63

Please sign in to comment.