Skip to content

Commit

Permalink
implement 4d memorydata (#4074)
Browse files Browse the repository at this point in the history
* implement 4d memorydata

* fix ncnnoptimize memorydata 4d
  • Loading branch information
nihui committed Jul 25, 2022
1 parent 13a9533 commit 4f414c1
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 6 deletions.
7 changes: 6 additions & 1 deletion src/layer/memorydata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,19 @@ int MemoryData::load_param(const ParamDict& pd)
{
w = pd.get(0, 0);
h = pd.get(1, 0);
d = pd.get(11, 0);
c = pd.get(2, 0);

return 0;
}

int MemoryData::load_model(const ModelBin& mb)
{
if (c != 0)
if (d != 0)
{
data = mb.load(w, h, d, c, 1);
}
else if (c != 0)
{
data = mb.load(w, h, c, 1);
}
Expand Down
1 change: 1 addition & 0 deletions src/layer/memorydata.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class MemoryData : public Layer
public:
int w;
int h;
int d;
int c;

Mat data;
Expand Down
6 changes: 3 additions & 3 deletions src/layer/vulkan/memorydata_vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ int MemoryData_vulkan::create_pipeline(const Option& opt)
int out_elempack = 1;
if (out_shape.dims == 1) out_elempack = opt.use_shader_pack8 && out_shape.w % 8 == 0 ? 8 : out_shape.w % 4 == 0 ? 4 : 1;
if (out_shape.dims == 2) out_elempack = opt.use_shader_pack8 && out_shape.h % 8 == 0 ? 8 : out_shape.h % 4 == 0 ? 4 : 1;
if (out_shape.dims == 3) out_elempack = opt.use_shader_pack8 && out_shape.c % 8 == 0 ? 8 : out_shape.c % 4 == 0 ? 4 : 1;
if (out_shape.dims == 3 || out_shape.dims == 4) out_elempack = opt.use_shader_pack8 && out_shape.c % 8 == 0 ? 8 : out_shape.c % 4 == 0 ? 4 : 1;

size_t out_elemsize;
if (opt.use_fp16_storage)
Expand All @@ -50,7 +50,7 @@ int MemoryData_vulkan::create_pipeline(const Option& opt)
Mat out_shape_packed;
if (out_shape.dims == 1) out_shape_packed = Mat(out_shape.w / out_elempack, (void*)0, out_elemsize, out_elempack);
if (out_shape.dims == 2) out_shape_packed = Mat(out_shape.w, out_shape.h / out_elempack, (void*)0, out_elemsize, out_elempack);
if (out_shape.dims == 3) out_shape_packed = Mat(out_shape.w, out_shape.h, out_shape.c / out_elempack, (void*)0, out_elemsize, out_elempack);
if (out_shape.dims == 3 || out_shape.dims == 4) out_shape_packed = Mat(out_shape.w, out_shape.h, out_shape.c / out_elempack, (void*)0, out_elemsize, out_elempack);

// check blob shape
if (!vkdev->shape_support_image_storage(out_shape_packed))
Expand All @@ -68,7 +68,7 @@ int MemoryData_vulkan::upload_model(VkTransfer& cmd, const Option& opt)
int elempack = 1;
if (shape.dims == 1) elempack = opt.use_shader_pack8 && shape.w % 8 == 0 ? 8 : shape.w % 4 == 0 ? 4 : 1;
if (shape.dims == 2) elempack = opt.use_shader_pack8 && shape.h % 8 == 0 ? 8 : shape.h % 4 == 0 ? 4 : 1;
if (shape.dims == 3) elempack = opt.use_shader_pack8 && shape.c % 8 == 0 ? 8 : shape.c % 4 == 0 ? 4 : 1;
if (shape.dims == 3 || shape.dims == 4) elempack = opt.use_shader_pack8 && shape.c % 8 == 0 ? 8 : shape.c % 4 == 0 ? 4 : 1;

Mat data_packed;
convert_packing(data, data_packed, elempack, opt);
Expand Down
9 changes: 9 additions & 0 deletions src/modelbin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@ Mat ModelBin::load(int w, int h, int c, int type) const
return m.reshape(w, h, c);
}

Mat ModelBin::load(int w, int h, int d, int c, int type) const
{
Mat m = load(w * h * d * c, type);
if (m.empty())
return m;

return m.reshape(w, h, d, c);
}

class ModelBinFromDataReaderPrivate
{
public:
Expand Down
2 changes: 2 additions & 0 deletions src/modelbin.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class NCNN_EXPORT ModelBin
virtual Mat load(int w, int h, int type) const;
// load dim
virtual Mat load(int w, int h, int c, int type) const;
// load cube
virtual Mat load(int w, int h, int d, int c, int type) const;
};

class ModelBinFromDataReaderPrivate;
Expand Down
5 changes: 3 additions & 2 deletions tools/modelwriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ int ModelWriter::fwrite_weight_tag_data(const ncnn::Mat& data, FILE* bp, float a
{
int p0 = ftell(bp);

ncnn::Mat data_flattened = data.reshape(data.w * data.h * data.c);
ncnn::Mat data_flattened = data.reshape(data.w * data.h * data.d * data.c);
if (gen_random_weight)
Randomize(data_flattened, a, b);

Expand Down Expand Up @@ -660,7 +660,7 @@ int ModelWriter::fwrite_weight_data(const ncnn::Mat& data, FILE* bp, float a, fl
{
int p0 = ftell(bp);

ncnn::Mat data_flattened = data.reshape(data.w * data.h * data.c);
ncnn::Mat data_flattened = data.reshape(data.w * data.h * data.d * data.c);
if (gen_random_weight)
Randomize(data_flattened, a, b);

Expand Down Expand Up @@ -1761,6 +1761,7 @@ int ModelWriter::save(const char* parampath, const char* binpath)
fprintf_param_value(" 0=%d", w)
fprintf_param_value(" 1=%d", h)
fprintf_param_value(" 2=%d", c)
fprintf_param_value(" 11=%d", d)
fwrite_weight_data(op->data, bp);
}
else if (layer->type == "MultiHeadAttention")
Expand Down

0 comments on commit 4f414c1

Please sign in to comment.