Skip to content

Commit 0a41c60

Browse files
authored
No Transpose emission for 0 rank tensor (#375)
* `stablehlo.sort` Ops
* do not transpose rank 0 tensor
* move check
* format
1 parent 3afed78 commit 0a41c60

File tree

2 files changed

+5
-11
lines changed

2 files changed

+5
-11
lines changed

ext/ReactantCUDAExt.jl

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -335,12 +335,6 @@ function link(job, compiled)
335335
return compiled
336336
end
337337

338-
function transpose_val(val)
339-
attr = MLIR.IR.DenseArrayAttribute(
340-
Int64[reverse(0:(length(size(MLIR.IR.type(val))) - 1))...]
341-
)
342-
return MLIR.IR.result(MLIR.Dialects.stablehlo.transpose(val; permutation=attr), 1)
343-
end
344338

345339
Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
346340
args...;
@@ -366,7 +360,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
366360
Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, a.ptr))::TracedRArray
367361
push!(rarrays, ta)
368362
arg = ta.mlir_data
369-
arg = transpose_val(arg)
363+
arg = Reactant.TracedUtils.transpose_val(arg)
370364
push!(restys, MLIR.IR.type(arg))
371365
push!(mlir_args, arg)
372366
push!(
@@ -399,7 +393,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
399393
)
400394
# call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute(func.mod))
401395
for (i, res) in enumerate(rarrays)
402-
res.mlir_data = transpose_val(MLIR.IR.result(call, i))
396+
res.mlir_data = Reactant.TracedUtils.transpose_val(MLIR.IR.result(call, i))
403397
end
404398

405399
@show blockdim

src/TracedUtils.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,9 @@ function transpose_ty(mlirty)
108108
return MLIR.IR.TensorType([reverse(size(mlirty))...], eltype(mlirty))
109109
end
110110
function transpose_val(val)
111-
attr = MLIR.IR.DenseArrayAttribute(
112-
Int64[reverse(0:(length(size(MLIR.IR.type(val))) - 1))...]
113-
)
111+
val_size = size(MLIR.IR.type(val))
112+
val_size == () && return val
113+
attr = MLIR.IR.DenseArrayAttribute(Int64[reverse(0:(length(val_size) - 1))...])
114114
return MLIR.IR.result(MLIR.Dialects.stablehlo.transpose(val; permutation=attr), 1)
115115
end
116116

0 commit comments

Comments (0)