Skip to content

Commit bb3737f

Browse files
committed
Merged PR 1144: Implement transpose kernel
Implement transpose kernel. Related work items: #147
1 parent 119fdd7 commit bb3737f

File tree

3 files changed

+235
-0
lines changed

3 files changed

+235
-0
lines changed
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
#include "core/providers/cpu/tensor/transpose.h"
2+
3+
namespace Lotus {
4+
5+
/* A permutation [a,b,c,...] indicates that
6+
- The 0-th dimension of the output corresponds to the a-th dimension of input
7+
- The 1-st dimension of the output corresponds to the b-th dimension of input
8+
- The 2-nd dimension of the output corresponds to the c-th dimension of input
9+
etc.
10+
*/
11+
12+
// The following is a reference (unoptimized) implementation of Transpose.
13+
// TODO: Optimize the implementation to use memcpy for sub-blocks that can be so copied.
14+
15+
template <>
16+
Status Transpose<float>::compute(OpKernelContext* ctx) const {
17+
const Tensor& X = *ctx->input<Tensor>(0);
18+
const TensorShape& input_shape = X.shape();
19+
const std::vector<int64_t>& input_dims = input_shape.GetDims();
20+
size_t rank = input_dims.size();
21+
22+
// Determine permutation to use:
23+
// If no permutation was specified in the attributes, the default is [rank-1, ..., 0]
24+
const std::vector<int64_t>* p_perm;
25+
std::vector<int64_t> default_perm(rank);
26+
27+
if (perm_specified_)
28+
p_perm = &perm_;
29+
else {
30+
for (int i = 0; i < rank; ++i)
31+
default_perm[i] = rank - i - 1;
32+
p_perm = &default_perm;
33+
}
34+
35+
// Determine shape of output, as well as stride to be used:
36+
// stride[i] indicates the stride for the input-tensor dimension corresponding
37+
// to the i-th dimension of the output
38+
39+
std::vector<int64_t> output_dims(rank);
40+
std::vector<size_t> stride(rank);
41+
for (int i = 0; i < rank; i++) {
42+
size_t inpdim = (*p_perm)[i];
43+
output_dims[i] = input_dims[inpdim];
44+
if (inpdim + 1 < rank)
45+
stride[i] = input_shape.SizeFromDimension(inpdim + 1);
46+
else
47+
stride[i] = 1;
48+
}
49+
50+
TensorShape output_shape{output_dims};
51+
Tensor* Y = ctx->output(0, output_shape);
52+
const float* Xdata = X.data<float>();
53+
float* Ydata = Y->mutable_data<float>();
54+
auto size = output_shape.Size();
55+
std::vector<int64_t> y_index(rank, 0); // index used to iterate over Y's iteration-space
56+
for (size_t i = 0; i < size; ++i) {
57+
// convert y_index into offset in X's data
58+
size_t x_offset = 0;
59+
for (int j = 0; j < rank; ++j) {
60+
x_offset += y_index[j] * stride[j];
61+
}
62+
// copy
63+
LOTUS_ENFORCE((0 <= x_offset) && (x_offset < size));
64+
*(Ydata + i) = *(Xdata + x_offset);
65+
// increment y_index:
66+
for (int64_t k = rank - 1; k >= 0; --k) {
67+
y_index[k]++;
68+
if (y_index[k] < output_dims[k]) break;
69+
y_index[k] = 0;
70+
}
71+
}
72+
73+
return Status::OK();
74+
}
75+
76+
// Kernel registration for Transpose<float>: the kernel definition is named
// "Transpose", placed in the ONNX domain, bound to the CPU execution provider,
// and constrains type parameter "T" to float tensors.
// NOTE(review): SinceVersion(1, 2) presumably denotes an operator-set version
// range -- confirm against the KernelDef API.
REGISTER_KERNEL(KernelDef("Transpose")
77+
.Domain(LotusIR::kOnnxDomain)
78+
.SinceVersion(1, 2)
79+
.Provider(LotusIR::kCpuExecutionProvider)
80+
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
81+
Transpose<float>);
82+
83+
} // namespace Lotus
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#pragma once
2+
3+
#include "gsl/gsl_util"
4+
#include "core/common/common.h"
5+
#include "core/framework/op_kernel.h"
6+
7+
namespace Lotus {
8+
template <typename T>
9+
class Transpose final : public OpKernel {
10+
public:
11+
Transpose(const OpKernelInfo& info) : OpKernel{info}, perm_specified_(false) {
12+
Status status = info.GetAttrs<int64_t>("perm", perm_);
13+
14+
if (status.IsOK()) {
15+
perm_specified_ = true;
16+
size_t rank = perm_.size();
17+
std::vector<bool> seen(rank, false);
18+
// Check that perm_ is a valid permutation of [0,rank-1]
19+
for (auto i : perm_) {
20+
if ((i < 0) || (i >= gsl::narrow<int64_t>(rank)))
21+
LOTUS_THROW("Attribute perm of Transpose has an invalid value. Value ", i, " is outside range.");
22+
if (seen[i])
23+
LOTUS_THROW("Attribute perm of Transpose has an invalid value. Value ", i, " is repeated.");
24+
seen[i] = true;
25+
}
26+
}
27+
}
28+
29+
Status compute(OpKernelContext* context) const override;
30+
31+
private:
32+
bool perm_specified_;
33+
std::vector<int64_t> perm_;
34+
};
35+
} // namespace Lotus
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
#include "core/providers/cpu/tensor/transpose.h"
2+
#include "gtest/gtest.h"
3+
#include "test/test_utils.h"
4+
5+
namespace Lotus {
6+
namespace Test {
7+
8+
template <size_t count>
9+
void TransposeTest(std::vector<int64_t>& input_shape,
10+
std::vector<float>& input_vals,
11+
std::vector<int64_t>* p_perm,
12+
std::vector<int64_t> expected_shape,
13+
const float (&expected_vals)[count]) {
14+
TypeProto tensor_float;
15+
tensor_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
16+
LotusIR::NodeArg input_def("X", &tensor_float), output_def("Y", &tensor_float);
17+
std::vector<LotusIR::NodeArg*> input_defs{&input_def};
18+
std::vector<LotusIR::NodeArg*> output_defs{&output_def};
19+
20+
TestModel model("TransposeTest", input_defs, output_defs);
21+
22+
if (nullptr != p_perm)
23+
model.Node().AddAttribute("perm", *p_perm);
24+
25+
SimpleFloatTest<Transpose> test(model);
26+
test.AddInput(input_shape, input_vals);
27+
test.AddOutput(expected_shape);
28+
test.Run(expected_shape, expected_vals);
29+
}
30+
31+
// Test 2 dimensional transpose, with no permutation attribute specified
32+
// 2x3 input, no "perm" attribute: expects the 3x2 output with rows and
// columns swapped.
TEST(TransposeOpTest, TwoDimNoAttr) {
  std::vector<int64_t> in_dims{2, 3};
  std::vector<float> in_data{1.0f, 2.0f, 3.0f,
                             4.0f, 5.0f, 6.0f};

  std::vector<int64_t> out_dims{3, 2};
  float out_data[] = {1.0f, 4.0f,
                      2.0f, 5.0f,
                      3.0f, 6.0f};

  TransposeTest(in_dims, in_data, nullptr, out_dims, out_data);
}
46+
47+
// Test 2 dimensional transpose, with permutation attribute specified
48+
// 2x3 input with an explicit perm of {1, 0}: expects the 3x2 output with rows
// and columns swapped.
TEST(TransposeOpTest, TwoDim) {
  std::vector<int64_t> in_dims{2, 3};
  std::vector<float> in_data{1.0f, 2.0f, 3.0f,
                             4.0f, 5.0f, 6.0f};

  std::vector<int64_t> axes_order{1, 0};
  std::vector<int64_t> out_dims{3, 2};
  float out_data[] = {1.0f, 4.0f,
                      2.0f, 5.0f,
                      3.0f, 6.0f};

  TransposeTest(in_dims, in_data, &axes_order, out_dims, out_data);
}
63+
64+
// Test 3 dimensional transpose, with permutation attribute specified
65+
// 4x2x3 input with perm {0, 2, 1}: the leading dimension stays in place and
// each 2x3 slice is expected to come out as its 3x2 transpose.
TEST(TransposeOpTest, ThreeDim) {
  std::vector<int64_t> in_dims{4, 2, 3};
  std::vector<float> in_data{
      // slice 0
      1.0f, 2.0f, 3.0f,
      4.0f, 5.0f, 6.0f,
      // slice 1
      1.1f, 2.1f, 3.1f,
      4.1f, 5.1f, 6.1f,
      // slice 2
      1.2f, 2.2f, 3.2f,
      4.2f, 5.2f, 6.2f,
      // slice 3
      1.3f, 2.3f, 3.3f,
      4.3f, 5.3f, 6.3f};

  std::vector<int64_t> axes_order{0, 2, 1};
  std::vector<int64_t> out_dims{4, 3, 2};
  float out_data[] = {
      // slice 0
      1.0f, 4.0f,
      2.0f, 5.0f,
      3.0f, 6.0f,
      // slice 1
      1.1f, 4.1f,
      2.1f, 5.1f,
      3.1f, 6.1f,
      // slice 2
      1.2f, 4.2f,
      2.2f, 5.2f,
      3.2f, 6.2f,
      // slice 3
      1.3f, 4.3f,
      2.3f, 5.3f,
      3.3f, 6.3f,
  };

  TransposeTest(in_dims, in_data, &axes_order, out_dims, out_data);
}
115+
116+
} // namespace Test
117+
} // namespace Lotus

0 commit comments

Comments
 (0)