forked from VeriSilicon/tflite-vx-delegate
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.cc
116 lines (100 loc) · 3.8 KB
/
utils.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
/****************************************************************************
*
* Copyright (c) 2021 Vivante Corporation
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#include "utils.h"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

#include "tensorflow/lite/minimal_logging.h"
using namespace tflite;
namespace vx {
namespace delegate {
namespace utils {
// Tracks where the channel dimension ends up after a transpose.
//
// Returns the first index i such that perm[i] == channel_dim, i.e. the output
// axis that receives input axis `channel_dim`.
// Returns -1 and logs an error when channel_dim is negative or does not occur
// in perm.
int32_t TransposeChannelDim(const std::vector<uint32_t>& perm,
                            int32_t channel_dim) {
  if (channel_dim < 0) {
    TFLITE_LOG_PROD(TFLITE_LOG_ERROR, "invalid channel_dim");
    return -1;
  }
  // channel_dim is non-negative here, so comparing as uint32_t is exact and
  // avoids a signed/unsigned comparison.
  const auto target = static_cast<uint32_t>(channel_dim);
  const auto it = std::find(perm.begin(), perm.end(), target);
  if (it == perm.end()) {
    TFLITE_LOG_PROD(TFLITE_LOG_ERROR, "Can't find channel_dim");
    return -1;
  }
  return static_cast<int32_t>(it - perm.begin());
}
// Converts a TfLite transpose permutation to the permutation used by the
// vx-delegate.
//
// Both the order of the entries and the axis numbers they hold are reversed:
//   result[k] = rank - 1 - perm[rank - 1 - k],   rank = perm.size().
// This is presumably because the vx side numbers dimensions in the opposite
// order to TfLite (NOTE(review): confirm against the TIM-VX transpose op).
// Entries of `perm` that are not valid axes (>= rank) are skipped, so the
// result may then be shorter than `perm`. This matches the previous
// implementation.
std::vector<uint32_t> GetOvxTransposePerm(const std::vector<uint32_t>& perm) {
  const auto rank = static_cast<uint32_t>(perm.size());
  std::vector<uint32_t> ovx_perm;
  ovx_perm.reserve(rank);
  for (auto it = perm.rbegin(); it != perm.rend(); ++it) {
    if (*it < rank) {
      ovx_perm.push_back(rank - 1 - *it);
    }
  }
  return ovx_perm;
}
// Writes separable bilinear (triangle) interpolation weights into `data`.
// These are presumably the weights of a (de)convolution that emulates a
// bilinear resize by (scale_w, scale_h); confirm at call sites.
//
// weight_shape is read as {width, height, channel_in, channel_out}. For each
// output channel o, only the contiguous width*height slice starting at
//   o * (channel_in + 1) * width * height
// is written. Every other element of `data` is left untouched.
// NOTE(review): this assumes `data` holds at least
// width*height*channel_in*channel_out floats, is pre-initialised by the caller
// (e.g. zeroed), and that channel_in >= channel_out (normally equal).
// None of this is checked here.
//
// Each tap is
//   (1 - |w - center_w| / scale_w) * (1 - |h - center_h| / scale_h)
// where the centre is scale - 1 for an odd kernel size and scale - 0.5 for an
// even one.
void GenerateWeightsDataForBilinear(float* data,
                                    const std::vector<uint32_t>& weight_shape,
                                    uint32_t scale_w,
                                    uint32_t scale_h) {
  const int32_t width = weight_shape[0];
  const int32_t height = weight_shape[1];
  const int32_t channel_in = weight_shape[2];
  const int32_t channel_out = weight_shape[3];
  // The kernel centre depends only on the shape, so compute it once.
  const float center_w = width % 2 == 1 ? scale_w - 1.0 : scale_w - 0.5;
  const float center_h = height % 2 == 1 ? scale_h - 1.0 : scale_h - 0.5;
  // Offsets are computed in std::size_t so large weight tensors cannot
  // overflow a 32-bit signed index.
  const std::size_t plane = static_cast<std::size_t>(width) * height;
  for (int32_t o = 0; o < channel_out; ++o) {
    float* slice = data + static_cast<std::size_t>(o) *
                              (static_cast<std::size_t>(channel_in) + 1) *
                              plane;
    for (int32_t h = 0; h < height; ++h) {
      // The vertical factor is constant along a row.
      const float weight_h = 1 - std::abs(h - center_h) / scale_h;
      float* row = slice + static_cast<std::size_t>(h) * width;
      for (int32_t w = 0; w < width; ++w) {
        const float weight_w = 1 - std::abs(w - center_w) / scale_w;
        row[w] = weight_w * weight_h;
      }
    }
  }
}
// Writes an all-ones kernel into `data`. These are presumably the weights of a
// (de)convolution that emulates nearest-neighbour resize; confirm at call
// sites.
//
// weight_shape is read as {width, height, channel_in, channel_out}. For each
// output channel o, the contiguous width*height slice starting at
//   o * (channel_in + 1) * width * height
// is set to 1.0f. Every other element of `data` is left untouched.
// NOTE(review): this assumes `data` holds at least
// width*height*channel_in*channel_out floats, is pre-initialised by the caller
// (e.g. zeroed), and that channel_in >= channel_out (normally equal).
// None of this is checked here.
void GenerateWeightDataForNearest(float* data,
                                  const std::vector<uint32_t>& weight_shape) {
  const uint32_t width = weight_shape[0];
  const uint32_t height = weight_shape[1];
  const uint32_t channel_in = weight_shape[2];
  const uint32_t channel_out = weight_shape[3];
  // std::size_t arithmetic keeps large offsets from wrapping in 32 bits.
  const std::size_t plane = static_cast<std::size_t>(width) * height;
  const std::size_t channel_stride =
      (static_cast<std::size_t>(channel_in) + 1) * plane;
  for (uint32_t o = 0; o < channel_out; ++o) {
    std::fill_n(data + o * channel_stride, plane, 1.0f);
  }
}
} // namespace utils
} // namespace delegate
} // namespace vx