Skip to content

Commit

Permalink
squeeze and expanddims 4d (#4346)
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui authored Nov 13, 2022
1 parent 6a47f8d commit 498ca73
Show file tree
Hide file tree
Showing 10 changed files with 310 additions and 91 deletions.
40 changes: 40 additions & 0 deletions src/layer/expanddims.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ int ExpandDims::load_param(const ParamDict& pd)
{
expand_w = pd.get(0, 0);
expand_h = pd.get(1, 0);
expand_d = pd.get(11, 0);
expand_c = pd.get(2, 0);
axes = pd.get(3, Mat());

Expand All @@ -36,16 +37,19 @@ int ExpandDims::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt
{
int w = bottom_blob.w;
int h = bottom_blob.h;
int channels = bottom_blob.c;
int dims = bottom_blob.dims;

bool _expand_w = false;
bool _expand_h = false;
bool _expand_d = false;
bool _expand_c = false;

if (axes.empty())
{
_expand_w = expand_w;
_expand_h = expand_h;
_expand_d = expand_d;
_expand_c = expand_c;
}
else
Expand Down Expand Up @@ -77,6 +81,22 @@ int ExpandDims::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt
{
_expand_w = true;
}
if (dims == 3 && axis == 0)
{
_expand_c = true;
}
if (dims == 3 && axis == 1)
{
_expand_d = true;
}
if (dims == 3 && axis == 2)
{
_expand_h = true;
}
if (dims == 3 && axis == 3)
{
_expand_w = true;
}
}
}

Expand Down Expand Up @@ -114,6 +134,26 @@ int ExpandDims::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt
}
}

if (dims == 3)
{
if (_expand_w)
{
top_blob = bottom_blob.reshape(1, w, h, channels, opt.blob_allocator);
}
else if (_expand_h)
{
top_blob = bottom_blob.reshape(w, 1, h, channels, opt.blob_allocator);
}
else if (_expand_d)
{
top_blob = bottom_blob.reshape(w, h, 1, channels, opt.blob_allocator);
}
else if (_expand_c)
{
top_blob = bottom_blob.reshape(w, h, channels, 1, opt.blob_allocator);
}
}

if (top_blob.empty())
return -100;

Expand Down
1 change: 1 addition & 0 deletions src/layer/expanddims.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class ExpandDims : public Layer
public:
int expand_w;
int expand_h;
int expand_d;
int expand_c;
Mat axes;
};
Expand Down
84 changes: 84 additions & 0 deletions src/layer/squeeze.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ int Squeeze::load_param(const ParamDict& pd)
{
squeeze_w = pd.get(0, 0);
squeeze_h = pd.get(1, 0);
squeeze_d = pd.get(11, 0);
squeeze_c = pd.get(2, 0);
axes = pd.get(3, Mat());

Expand All @@ -36,17 +37,20 @@ int Squeeze::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) c
{
int w = bottom_blob.w;
int h = bottom_blob.h;
int d = bottom_blob.d;
int channels = bottom_blob.c;
int dims = bottom_blob.dims;

bool _squeeze_w = false;
bool _squeeze_h = false;
bool _squeeze_d = false;
bool _squeeze_c = false;

if (axes.empty())
{
_squeeze_w = w == 1 && squeeze_w;
_squeeze_h = h == 1 && squeeze_h;
_squeeze_d = d == 1 && squeeze_d;
_squeeze_c = channels == 1 && squeeze_c;
}
else
Expand Down Expand Up @@ -82,6 +86,22 @@ int Squeeze::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) c
{
_squeeze_w = w == 1;
}
if (dims == 4 && axis == 0)
{
_squeeze_c = channels == 1;
}
if (dims == 4 && axis == 1)
{
_squeeze_d = d == 1;
}
if (dims == 4 && axis == 2)
{
_squeeze_h = h == 1;
}
if (dims == 4 && axis == 3)
{
_squeeze_w = w == 1;
}
}
}

Expand Down Expand Up @@ -143,6 +163,70 @@ int Squeeze::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) c
}
}

if (dims == 4)
{
if (_squeeze_w && _squeeze_h && _squeeze_d && _squeeze_c)
{
top_blob = bottom_blob.reshape(1, opt.blob_allocator);
}
else if (_squeeze_w && _squeeze_h && _squeeze_d)
{
top_blob = bottom_blob.reshape(channels, opt.blob_allocator);
}
else if (_squeeze_h && _squeeze_d && _squeeze_c)
{
top_blob = bottom_blob.reshape(w, opt.blob_allocator);
}
else if (_squeeze_w && _squeeze_d && _squeeze_c)
{
top_blob = bottom_blob.reshape(h, opt.blob_allocator);
}
else if (_squeeze_w && _squeeze_h && _squeeze_c)
{
top_blob = bottom_blob.reshape(d, opt.blob_allocator);
}
else if (_squeeze_w && _squeeze_h)
{
top_blob = bottom_blob.reshape(d, channels, opt.blob_allocator);
}
else if (_squeeze_w && _squeeze_d)
{
top_blob = bottom_blob.reshape(h, channels, opt.blob_allocator);
}
else if (_squeeze_h && _squeeze_d)
{
top_blob = bottom_blob.reshape(w, channels, opt.blob_allocator);
}
else if (_squeeze_h && _squeeze_c)
{
top_blob = bottom_blob.reshape(w, d, opt.blob_allocator);
}
else if (_squeeze_w && _squeeze_c)
{
top_blob = bottom_blob.reshape(h, d, opt.blob_allocator);
}
else if (_squeeze_d && _squeeze_c)
{
top_blob = bottom_blob.reshape(w, h, opt.blob_allocator);
}
else if (_squeeze_w)
{
top_blob = bottom_blob.reshape(h, d, channels, opt.blob_allocator);
}
else if (_squeeze_h)
{
top_blob = bottom_blob.reshape(w, d, channels, opt.blob_allocator);
}
else if (_squeeze_d)
{
top_blob = bottom_blob.reshape(w, h, channels, opt.blob_allocator);
}
else if (_squeeze_c)
{
top_blob = bottom_blob.reshape(w, h, d, opt.blob_allocator);
}
}

if (top_blob.empty())
return -100;

Expand Down
1 change: 1 addition & 0 deletions src/layer/squeeze.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class Squeeze : public Layer
public:
int squeeze_w;
int squeeze_h;
int squeeze_d;
int squeeze_c;
Mat axes;
};
Expand Down
115 changes: 76 additions & 39 deletions tests/test_expanddims.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,20 @@
#include "layer/expanddims.h"
#include "testutil.h"

static int test_expanddims(const ncnn::Mat& a, int expand_w, int expand_h, int expand_c)
static int test_expanddims(const ncnn::Mat& a, int expand_w, int expand_h, int expand_d, int expand_c)
{
ncnn::ParamDict pd;
pd.set(0, expand_w);
pd.set(1, expand_h);
pd.set(11, expand_d);
pd.set(2, expand_c);

std::vector<ncnn::Mat> weights(0);

int ret = test_layer<ncnn::ExpandDims>("ExpandDims", pd, weights, a);
if (ret != 0)
{
fprintf(stderr, "test_expanddims failed a.dims=%d a=(%d %d %d) expand_w=%d expand_h=%d expand_c=%d\n", a.dims, a.w, a.h, a.c, expand_w, expand_h, expand_c);
fprintf(stderr, "test_expanddims failed a.dims=%d a=(%d %d %d %d) expand_w=%d expand_h=%d expand_d=%d expand_c=%d\n", a.dims, a.w, a.h, a.d, a.c, expand_w, expand_h, expand_d, expand_c);
}

return ret;
Expand Down Expand Up @@ -60,6 +61,17 @@ static ncnn::Mat IntArrayMat(int a0, int a1, int a2)
return m;
}

static ncnn::Mat IntArrayMat(int a0, int a1, int a2, int a3)
{
ncnn::Mat m(4);
int* p = m;
p[0] = a0;
p[1] = a1;
p[2] = a2;
p[3] = a3;
return m;
}

static void print_int_array(const ncnn::Mat& a)
{
const int* pa = a;
Expand All @@ -82,7 +94,7 @@ static int test_expanddims_axes(const ncnn::Mat& a, const ncnn::Mat& axes)
int ret = test_layer<ncnn::ExpandDims>("ExpandDims", pd, weights, a);
if (ret != 0)
{
fprintf(stderr, "test_expanddims_axes failed a.dims=%d a=(%d %d %d)\n", a.dims, a.w, a.h, a.c);
fprintf(stderr, "test_expanddims_axes failed a.dims=%d a=(%d %d %d %d)\n", a.dims, a.w, a.h, a.d, a.c);
fprintf(stderr, " axes=");
print_int_array(axes);
fprintf(stderr, "\n");
Expand All @@ -91,48 +103,73 @@ static int test_expanddims_axes(const ncnn::Mat& a, const ncnn::Mat& axes)
return ret;
}

static int test_expand_0()
static int test_expanddims_all_params(const ncnn::Mat& a)
{
ncnn::Mat as[7];
as[0] = RandomMat(1, 1, 1);
as[1] = RandomMat(14, 16);
as[2] = RandomMat(1, 14);
as[3] = RandomMat(11, 1);
as[4] = RandomMat(1, 1);
as[5] = RandomMat(120);
as[6] = RandomMat(1);

for (int i = 0; i < 7; i++)
{
const ncnn::Mat& a = as[i];
int ret = 0
|| test_expanddims(a, 0, 0, 0)
|| test_expanddims(a, 0, 0, 1)
|| test_expanddims(a, 0, 1, 0)
|| test_expanddims(a, 0, 1, 1)
|| test_expanddims(a, 1, 0, 0)
|| test_expanddims(a, 1, 0, 1)
|| test_expanddims(a, 1, 1, 0)
|| test_expanddims(a, 1, 1, 1)

|| test_expanddims_axes(a, IntArrayMat(0))
|| test_expanddims_axes(a, IntArrayMat(1))
|| test_expanddims_axes(a, IntArrayMat(2))
|| test_expanddims_axes(a, IntArrayMat(0, 1))
|| test_expanddims_axes(a, IntArrayMat(0, 2))
|| test_expanddims_axes(a, IntArrayMat(1, 2))
|| test_expanddims_axes(a, IntArrayMat(0, 1, 2));

if (ret != 0)
return ret;
}
return 0
|| test_expanddims(a, 0, 0, 0, 0)
|| test_expanddims(a, 0, 0, 0, 1)
|| test_expanddims(a, 0, 0, 1, 0)
|| test_expanddims(a, 0, 0, 1, 1)
|| test_expanddims(a, 0, 1, 0, 0)
|| test_expanddims(a, 0, 1, 0, 1)
|| test_expanddims(a, 0, 1, 1, 0)
|| test_expanddims(a, 0, 1, 1, 1)
|| test_expanddims(a, 1, 0, 0, 0)
|| test_expanddims(a, 1, 0, 0, 1)
|| test_expanddims(a, 1, 0, 1, 0)
|| test_expanddims(a, 1, 0, 1, 1)
|| test_expanddims(a, 1, 1, 0, 0)
|| test_expanddims(a, 1, 1, 0, 1)
|| test_expanddims(a, 1, 1, 1, 0)
|| test_expanddims(a, 1, 1, 1, 1)

|| test_expanddims_axes(a, IntArrayMat(0))
|| test_expanddims_axes(a, IntArrayMat(1))
|| test_expanddims_axes(a, IntArrayMat(2))
|| test_expanddims_axes(a, IntArrayMat(3))
|| test_expanddims_axes(a, IntArrayMat(0, 1))
|| test_expanddims_axes(a, IntArrayMat(0, 2))
|| test_expanddims_axes(a, IntArrayMat(0, 3))
|| test_expanddims_axes(a, IntArrayMat(1, 2))
|| test_expanddims_axes(a, IntArrayMat(1, 3))
|| test_expanddims_axes(a, IntArrayMat(2, 3))
|| test_expanddims_axes(a, IntArrayMat(0, 1, 2))
|| test_expanddims_axes(a, IntArrayMat(0, 1, 3))
|| test_expanddims_axes(a, IntArrayMat(0, 2, 3))
|| test_expanddims_axes(a, IntArrayMat(1, 2, 3))
|| test_expanddims_axes(a, IntArrayMat(0, 1, 2, 3));
}

static int test_expanddims_0()
{
return 0
|| test_expanddims_all_params(RandomMat(3, 12, 16))
|| test_expanddims_all_params(RandomMat(3, 1, 16))
|| test_expanddims_all_params(RandomMat(1, 33, 15))
|| test_expanddims_all_params(RandomMat(1, 14, 1))
|| test_expanddims_all_params(RandomMat(12, 13, 1))
|| test_expanddims_all_params(RandomMat(1, 1, 1));
}

return 0;
static int test_expanddims_1()
{
return 0
|| test_expanddims_all_params(RandomMat(14, 16))
|| test_expanddims_all_params(RandomMat(1, 14))
|| test_expanddims_all_params(RandomMat(11, 1))
|| test_expanddims_all_params(RandomMat(1, 1));
}

static int test_expanddims_2()
{
return 0
|| test_expanddims_all_params(RandomMat(120))
|| test_expanddims_all_params(RandomMat(1));
}

int main()
{
SRAND(7767517);

return test_expand_0();
return test_expanddims_0() || test_expanddims_1() || test_expanddims_2();
}
Loading

0 comments on commit 498ca73

Please sign in to comment.