fix pnnx softmax/normalize/slice negative axis conversion to ncnn (#4284)
nihui authored Oct 19, 2022
1 parent 549152c commit 777e4ef
Showing 6 changed files with 31 additions and 24 deletions.
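
The bug, visible in all four C++ passes below: pnnx shapes still contain the batch dimension, and the old code decremented input_rank for that batch dimension before normalizing negative dims, so axis = input_rank + axis resolved a negative PyTorch dim one dimension too early. The fix normalizes negative dims against the full tensor rank first, and only then shifts axes past the batch dimension down by one. A minimal sketch of the corrected conversion order (a standalone illustration, not the pnnx source; the helper name and example shape are invented):

    #include <cstdio>
    #include <vector>

    // Sketch of the fixed conversion, assuming `shape` still contains the
    // batch dimension at `batch_index` (as pnnx shapes do).
    static int to_ncnn_axis(int dim, int batch_index, const std::vector<int>& shape)
    {
        if (dim < 0)
        {
            int input_rank = (int)shape.size(); // full rank, batch dim included
            dim = input_rank + dim;
        }

        // ncnn blobs carry no batch dimension, so later axes shift down by one.
        if (dim > batch_index)
            dim -= 1;

        return dim;
    }

    int main()
    {
        std::vector<int> shape = {1, 3, 16}; // hypothetical input, batch at index 0
        printf("dim -1 -> ncnn axis %d\n", to_ncnn_axis(-1, 0, shape)); // prints 1
        return 0;
    }
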
tools/pnnx/src/pass_ncnn/F_normalize.cpp: 8 additions & 5 deletions
@@ -45,11 +45,6 @@ pnnx.Output output 1 0 out
 {
     const int batch_index = op->inputs[0]->params["__batch_index"].i;
 
-    int input_rank = op->inputs[0]->shape.size();
-
-    if (batch_index >= 0 && batch_index < input_rank)
-        input_rank -= 1;
-
     int axis = captured_params.at("dim").i;
     if (axis == batch_index)
     {
@@ -58,7 +53,10 @@ pnnx.Output output 1 0 out
     }
 
     if (axis < 0)
+    {
+        int input_rank = op->inputs[0]->shape.size();
         axis = input_rank + axis;
+    }
 
     if (axis > batch_index)
         axis -= 1;
@@ -75,6 +73,11 @@ pnnx.Output output 1 0 out
         return;
     }
 
+    int input_rank = op->inputs[0]->shape.size();
+
+    if (batch_index >= 0 && batch_index < input_rank)
+        input_rank -= 1;
+
     if (input_rank == 2 || axis != 0)
     {
         fprintf(stderr, "unsupported normalize for %d-rank tensor with axis %d\n", input_rank, axis);

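F_softmax.cpp and nn_Softmax.cpp below get the identical treatment. To see why the order matters, here is the arithmetic for a hypothetical 3-rank input (shape (1, 3, 16), batch at index 0) with dim = -1, old order versus new:

    #include <cstdio>

    int main()
    {
        const int batch_index = 0;
        const int full_rank = 3; // e.g. a (1, 3, 16) input

        // Old order: the rank was pre-decremented for the batch dim,
        // then used to normalize the negative dim.
        int old_axis = (full_rank - 1) + (-1); // = 1
        if (old_axis > batch_index)
            old_axis -= 1; // = 0 -> wrong: selects dim 1 of the original input

        // New order: normalize against the full rank, then drop the batch dim.
        int new_axis = full_rank + (-1); // = 2
        if (new_axis > batch_index)
            new_axis -= 1; // = 1 -> correct: the last ncnn dimension

        printf("old=%d new=%d\n", old_axis, new_axis); // prints old=0 new=1
        return 0;
    }
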
tools/pnnx/src/pass_ncnn/F_softmax.cpp: 3 additions & 5 deletions
@@ -45,11 +45,6 @@ pnnx.Output output 1 0 out
 {
     const int batch_index = op->inputs[0]->params["__batch_index"].i;
 
-    int input_rank = op->inputs[0]->shape.size();
-
-    if (batch_index >= 0 && batch_index < input_rank)
-        input_rank -= 1;
-
     int axis = captured_params.at("dim").i;
     if (axis == batch_index)
     {
@@ -58,7 +53,10 @@ pnnx.Output output 1 0 out
     }
 
     if (axis < 0)
+    {
+        int input_rank = op->inputs[0]->shape.size();
         axis = input_rank + axis;
+    }
 
     if (axis > batch_index)
         axis -= 1;

tools/pnnx/src/pass_ncnn/Tensor_slice.cpp: 12 additions & 7 deletions
@@ -60,15 +60,17 @@ pnnx.Output output 1 0 out
 
     const int batch_index = op->inputs[0]->params["__batch_index"].i;
 
-    int input_rank = op->inputs[0]->shape.size();
+    {
+        int input_rank = op->inputs[0]->shape.size();
 
-    if (batch_index >= 0 && batch_index < input_rank)
-        input_rank -= 1;
+        if (batch_index >= 0 && batch_index < input_rank)
+            input_rank -= 1;
 
-    if (input_rank > 4)
-    {
-        fprintf(stderr, "slice %d-rank tensor with %d-rank axes is not possible!\n", input_rank, axes_rank);
-        return;
+        if (input_rank > 4)
+        {
+            fprintf(stderr, "slice %d-rank tensor with %d-rank axes is not possible!\n", input_rank, axes_rank);
+            return;
+        }
     }
 
     for (int i = 0; i < axes_rank; i++)
@@ -80,7 +82,10 @@ pnnx.Output output 1 0 out
         }
 
         if (axes[i] < 0)
+        {
+            int input_rank = op->inputs[0]->shape.size();
             axes[i] = input_rank + axes[i];
+        }
 
         if (axes[i] > batch_index)
             axes[i] -= 1;

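The slice pass converts a whole array of axes, so the fix recomputes the full rank inside the loop for every negative entry, while the batch-adjusted rank used by the rank-limit check above is now confined to its own scope block so it cannot leak into the conversion. A sketch of the per-axis loop (illustrative shape and axes values, not the pnnx source):

    #include <cstdio>
    #include <vector>

    int main()
    {
        std::vector<int> shape = {1, 3, 16, 16}; // batch at index 0
        const int batch_index = 0;
        std::vector<int> axes = {1, -1}; // slice along dim 1 and the last dim

        for (size_t i = 0; i < axes.size(); i++)
        {
            if (axes[i] < 0)
            {
                int input_rank = (int)shape.size(); // full rank, batch included
                axes[i] = input_rank + axes[i];
            }

            if (axes[i] > batch_index)
                axes[i] -= 1; // ncnn blobs carry no batch dimension
        }

        printf("ncnn axes = %d, %d\n", axes[0], axes[1]); // prints 0, 2
        return 0;
    }
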
tools/pnnx/src/pass_ncnn/nn_Softmax.cpp: 3 additions & 5 deletions
@@ -45,11 +45,6 @@ pnnx.Output output 1 0 out
 {
     const int batch_index = op->inputs[0]->params["__batch_index"].i;
 
-    int input_rank = op->inputs[0]->shape.size();
-
-    if (batch_index >= 0 && batch_index < input_rank)
-        input_rank -= 1;
-
     int axis = captured_params.at("dim").i;
     if (axis == batch_index)
     {
@@ -58,7 +53,10 @@ pnnx.Output output 1 0 out
     }
 
     if (axis < 0)
+    {
+        int input_rank = op->inputs[0]->shape.size();
         axis = input_rank + axis;
+    }
 
     if (axis > batch_index)
         axis -= 1;

tools/pnnx/tests/ncnn/test_F_softmax.py: 2 additions & 1 deletion
@@ -24,7 +24,8 @@ def forward(self, x, y, z):
         x = F.softmax(x, 0)
         y = F.softmax(y, 1)
         z = F.softmax(z, 2)
-        return x, y, z
+        z2 = F.softmax(z, -1)
+        return x, y, z, z2
 
 def test():
     net = Model()

tools/pnnx/tests/ncnn/test_nn_Softmax.py: 3 additions & 1 deletion
@@ -23,12 +23,14 @@ def __init__(self):
         self.act_0 = nn.Softmax(dim=0)
         self.act_1 = nn.Softmax(dim=1)
         self.act_2 = nn.Softmax(dim=2)
+        self.act_3 = nn.Softmax(dim=-1)
 
     def forward(self, x, y, z):
         x = self.act_0(x)
         y = self.act_1(y)
         z = self.act_2(z)
-        return x, y, z
+        z2 = self.act_3(z)
+        return x, y, z, z2
 
 def test():
     net = Model()
