Skip to content

Commit

Permalink
delete tranpose_kernel optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
wyushun committed Dec 9, 2021
1 parent 360ec25 commit e026434
Showing 1 changed file with 2 additions and 19 deletions.
21 changes: 2 additions & 19 deletions oneflow/user/kernels/transpose_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,6 @@ namespace oneflow {

namespace user_op {

namespace {
bool IsOrderedPermute(const std::vector<int32_t>& perm) {
for (auto i = 0; i < perm.size(); i++) {
if (perm[i] != i) { return false; }
}
return true;
}
} // namespace

template<typename Context>
std::unique_ptr<ep::primitive::Permute> NewPermutePrimitive(Context* ctx) {
const int64_t num_dims = ctx->TensorDesc4ArgNameAndIndex("output", 0)->shape().NumAxes();
Expand Down Expand Up @@ -58,16 +49,8 @@ class TransposeKernel final : public OpKernel, public user_op::CudaGraphSupport
int64_t elem_cnt = tensor_out->shape().elem_cnt();

if (elem_cnt != 0) {
if (IsOrderedPermute(perm)) {
// if permute vector is 0,1,...,n, do data copy directly
AutoMemcpy(ctx->stream(), tensor_out->mut_dptr(), tensor_in->dptr(),
elem_cnt * GetSizeOfDataType(dtype), tensor_out->mem_case(),
tensor_in->mem_case());
} else {
primitive->Launch(ctx->stream(), dtype, num_dims, src_dims, tensor_in->dptr(), perm.data(),
tensor_out->mut_dptr());
}

primitive->Launch(ctx->stream(), dtype, num_dims, src_dims, tensor_in->dptr(), perm.data(),
tensor_out->mut_dptr());
} else {
// For 0-d Tensor
return;
Expand Down

0 comments on commit e026434

Please sign in to comment.