Skip to content

Commit

Permalink
pnnx fp16 option for ncnn and onnx weight type (#4350)
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui authored Nov 14, 2022
1 parent 6967baa commit ec1b07c
Show file tree
Hide file tree
Showing 10 changed files with 196 additions and 190 deletions.
3 changes: 3 additions & 0 deletions tools/pnnx/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ Usage: pnnx [model.pt] [(key=value)...]
ncnnparam=model.ncnn.param
ncnnbin=model.ncnn.bin
ncnnpy=model_ncnn.py
fp16=1
optlevel=2
device=cpu/gpu
inputshape=[1,3,224,224],...
Expand All @@ -119,6 +120,8 @@ Parameters:

`ncnnpy` (default="*_ncnn.py"): pyncnn script for inference

`fp16` (default=1): save ncnn weights and the onnx model in fp16 data type

`optlevel` (default=2): graph optimization level

| Option | Optimization level |
Expand Down
1 change: 0 additions & 1 deletion tools/pnnx/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,6 @@ set(pnnx_pass_ncnn_SRCS
pass_ncnn/insert_split.cpp
pass_ncnn/chain_multi_output.cpp
pass_ncnn/solve_batch_index.cpp
pass_ncnn/convert_to_fp16_model.cpp

pass_ncnn/eliminate_noop.cpp
pass_ncnn/eliminate_tail_reshape_permute.cpp
Expand Down
9 changes: 7 additions & 2 deletions tools/pnnx/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ static void show_usage()
fprintf(stderr, " ncnnparam=model.ncnn.param\n");
fprintf(stderr, " ncnnbin=model.ncnn.bin\n");
fprintf(stderr, " ncnnpy=model_ncnn.py\n");
fprintf(stderr, " fp16=1\n");
fprintf(stderr, " optlevel=2\n");
fprintf(stderr, " device=cpu/gpu\n");
fprintf(stderr, " inputshape=[1,3,224,224],...\n");
Expand Down Expand Up @@ -210,6 +211,7 @@ int main(int argc, char** argv)
std::string ncnnparampath = ptbase + ".ncnn.param";
std::string ncnnbinpath = ptbase + ".ncnn.bin";
std::string ncnnpypath = ptbase + "_ncnn.py";
int fp16 = 1;
int optlevel = 2;
std::string device = "cpu";
std::vector<std::vector<int64_t> > input_shapes;
Expand Down Expand Up @@ -250,6 +252,8 @@ int main(int argc, char** argv)
ncnnbinpath = std::string(value);
if (strcmp(key, "ncnnpy") == 0)
ncnnpypath = std::string(value);
if (strcmp(key, "fp16") == 0)
fp16 = atoi(value);
if (strcmp(key, "optlevel") == 0)
optlevel = atoi(value);
if (strcmp(key, "device") == 0)
Expand All @@ -273,6 +277,7 @@ int main(int argc, char** argv)
fprintf(stderr, "ncnnparam = %s\n", ncnnparampath.c_str());
fprintf(stderr, "ncnnbin = %s\n", ncnnbinpath.c_str());
fprintf(stderr, "ncnnpy = %s\n", ncnnpypath.c_str());
fprintf(stderr, "fp16 = %d\n", fp16);
fprintf(stderr, "optlevel = %d\n", optlevel);
fprintf(stderr, "device = %s\n", device.c_str());
fprintf(stderr, "inputshape = ");
Expand Down Expand Up @@ -415,7 +420,7 @@ int main(int argc, char** argv)
pnnx_graph.python(pnnxpypath, pnnxbinpath);

#if BUILD_PNNX2ONNX
pnnx::save_onnx(pnnx_graph, pnnxonnxpath.c_str());
pnnx::save_onnx(pnnx_graph, pnnxonnxpath.c_str(), fp16);
#else
fprintf(stderr, "pnnx build without onnx-zero support, skip saving onnx\n");
#endif
Expand All @@ -426,7 +431,7 @@ int main(int argc, char** argv)

pnnx::pass_ncnn(pnnx_graph);

pnnx::save_ncnn(pnnx_graph, ncnnparampath, ncnnbinpath, ncnnpypath);
pnnx::save_ncnn(pnnx_graph, ncnnparampath, ncnnbinpath, ncnnpypath, fp16);
}

// pnnx::Graph pnnx_graph2;
Expand Down
3 changes: 0 additions & 3 deletions tools/pnnx/src/pass_ncnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
#include "pass_ncnn/insert_split.h"
#include "pass_ncnn/chain_multi_output.h"
#include "pass_ncnn/solve_batch_index.h"
#include "pass_ncnn/convert_to_fp16_model.h"

#include "pass_ncnn/eliminate_noop.h"
#include "pass_ncnn/eliminate_tail_reshape_permute.h"
Expand Down Expand Up @@ -134,8 +133,6 @@ void pass_ncnn(Graph& g)
ncnn::convert_input(g);

ncnn::eliminate_output(g);

ncnn::convert_to_fp16_model(g);
}

} // namespace pnnx
133 changes: 0 additions & 133 deletions tools/pnnx/src/pass_ncnn/convert_to_fp16_model.cpp

This file was deleted.

25 changes: 0 additions & 25 deletions tools/pnnx/src/pass_ncnn/convert_to_fp16_model.h

This file was deleted.

97 changes: 96 additions & 1 deletion tools/pnnx/src/save_ncnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,66 @@ static bool string_is_positive_integer(const std::string& t)
return true;
}

int save_ncnn(const Graph& g, const std::string& parampath, const std::string& binpath, const std::string& pypath)
// Convert a single-precision float to its IEEE 754 half-precision bit pattern.
// Denormals and underflowing normals flush to signed zero; the significand is
// truncated (round toward zero); NaN payloads collapse to a single quiet bit.
static unsigned short float32_to_float16(float value)
{
    union
    {
        unsigned int u;
        float f;
    } bits;

    bits.f = value;

    const unsigned int u32 = bits.u;

    // decompose fp32 layout: 1 sign : 8 exponent : 23 significand
    const unsigned short sign_bit = (unsigned short)(u32 >> 31);
    const unsigned short exp8 = (unsigned short)((u32 >> 23) & 0xFF);
    const unsigned int frac23 = u32 & 0x7FFFFF;

    // re-encode as fp16 layout: 1 sign : 5 exponent : 10 significand
    unsigned short half;

    if (exp8 == 0)
    {
        // zero or fp32 denormal: always below fp16 precision, flush to signed zero
        half = (unsigned short)(sign_bit << 15);
    }
    else if (exp8 == 0xFF)
    {
        // infinity or NaN (set one payload bit to keep NaN distinct from inf)
        half = (unsigned short)((sign_bit << 15) | (0x1F << 10) | (frac23 ? 0x200 : 0x00));
    }
    else
    {
        // normalized value: rebase exponent from bias 127 to bias 15
        const short rebased = (short)(exp8 - 127 + 15);
        if (rebased >= 31)
        {
            // magnitude exceeds fp16 range: return signed infinity
            half = (unsigned short)((sign_bit << 15) | (0x1F << 10));
        }
        else if (rebased <= 0)
        {
            // normal fp32 too small for a normal fp16: flush to signed zero
            half = (unsigned short)(sign_bit << 15);
        }
        else
        {
            // representable normal: truncate the significand to 10 bits
            half = (unsigned short)((sign_bit << 15) | (rebased << 10) | (frac23 >> 13));
        }
    }

    return half;
}

// Round sz up to the next multiple of n.
// NOTE(review): the -n mask trick assumes n is a power of two — true for the
// 4-byte alignment used by the fp16 weight writer; confirm before reusing.
static size_t alignSize(size_t sz, int n)
{
    const size_t mask = (size_t)(-n);
    return (sz + n - 1) & mask;
}

int save_ncnn(const Graph& g, const std::string& parampath, const std::string& binpath, const std::string& pypath, int fp16)
{
FILE* paramfp = fopen(parampath.c_str(), "wb");
if (!paramfp)
Expand Down Expand Up @@ -196,12 +255,48 @@ int save_ncnn(const Graph& g, const std::string& parampath, const std::string& b
}
}

bool is_type_flag_fp32 = false;
for (const auto& it : op->attrs)
{
// fprintf(paramfp, " @%s=", it.first.c_str());

const Attribute& attr = it.second;

if (fp16 && is_type_flag_fp32)
{
// fp32 -> fp16
const float* p = (const float*)attr.data.data();
int len = attr.data.size() / 4;
std::vector<char> data_fp16(alignSize(len * 2, 4));
unsigned short* p_fp16 = (unsigned short*)data_fp16.data();
for (int i = 0; i < len; i++)
{
p_fp16[i] = float32_to_float16(p[i]);
}

// pad size to 4bytes
if (len % 2 == 1)
{
// pad with fixed value for model hash consistency
p_fp16[len] = 0x2283;
}

fwrite(data_fp16.data(), data_fp16.size(), 1, binfp);

is_type_flag_fp32 = false;
continue;
}

if (fp16 && attr.type == 0 && attr.data == std::vector<char> {0, 0, 0, 0})
{
// write fp16 flag
unsigned int fp16_flag = 0x01306B47;
fwrite((const char*)&fp16_flag, sizeof(fp16_flag), 1, binfp);

is_type_flag_fp32 = true;
continue;
}

fwrite(attr.data.data(), attr.data.size(), 1, binfp);
}

Expand Down
2 changes: 1 addition & 1 deletion tools/pnnx/src/save_ncnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

namespace pnnx {

int save_ncnn(const Graph& g, const std::string& parampath, const std::string& binpath, const std::string& pypath);
int save_ncnn(const Graph& g, const std::string& parampath, const std::string& binpath, const std::string& pypath, int fp16);

} // namespace pnnx

Expand Down
Loading

0 comments on commit ec1b07c

Please sign in to comment.