Skip to content

Commit b236e87

Browse files
committed
mean iou debug jolibrain#2
1 parent d839ab2 commit b236e87

File tree

2 files changed

+32
-9
lines changed

2 files changed

+32
-9
lines changed

src/supervisedoutputconnector.h

+14-5
Original file line numberDiff line numberDiff line change
@@ -673,14 +673,22 @@ namespace dd
673673
double accc = c_sum / c_total_targ;
674674
mean_acc[c] += accc;
675675
mean_acc_bs[c]++;
676-
}
677676

678-
// mean intersection over union
677+
}
678+
679+
// mean intersection over union
679680
double c_false_neg = static_cast<double>((ddiffc.array() == -2-c).count());
680-
double c_false_pos = static_cast<double>((ddiffc.array() == c+1).count());
681-
double iou = (c_sum==0) ? 0: c_sum / (c_false_pos + c_sum + c_false_neg);
681+
double c_false_pos = static_cast<double>((ddiffc.array() == c+1).count());
682+
// below corner case where nothing is to predict : put correct to zero
683+
// but do not devide by zero
684+
double iou = (c_sum == 0)? 0 : c_sum / (c_false_pos + c_sum + c_false_neg);
682685
mean_iou[c] += iou;
683-
mean_iou_bs[c]++;
686+
// ... and divide one time less when normalizing by batch size
687+
if (c_total_targ !=0)
688+
mean_iou_bs[c]++;
689+
// another possible waywould be to put artificially iou to one if nothing is to be
690+
// predicted for class c
691+
684692
}
685693
}
686694
int c_nclasses = 0;
@@ -696,6 +704,7 @@ namespace dd
696704
meaniou += mean_iou[c];
697705
}
698706
clacc = mean_acc;
707+
// corner case where prediction is wrong
699708
if (c_nclasses > 0) {
700709
meanacc /= static_cast<double>(c_nclasses);
701710
meaniou /= static_cast<double>(c_nclasses);

tests/ut-conn.cc

+18-4
Original file line numberDiff line numberDiff line change
@@ -56,21 +56,35 @@ TEST(outputconn,acc)
5656

5757
TEST(outputconn,acc_v)
5858
{
59-
std::vector<double> targets = {0.0, 0.0, 1.0, 1.0};
60-
std::vector<double> pred1 = {0.0,1.0,1.0,1.0};
61-
std::vector<std::vector<double>> preds = { pred1 };
6259
APIData res_ad;
63-
res_ad.add("batch_size",static_cast<int>(1));
60+
res_ad.add("batch_size",static_cast<int>(2));
61+
res_ad.add("nclasses",static_cast<int>(2));
62+
6463
APIData bad;
64+
std::vector<double> targets = {0.0, 1.0 };
65+
std::vector<double> pred1 = {0.0, 1.0};
6566
bad.add("pred",pred1);
6667
bad.add("target",targets);
6768
std::vector<APIData> vad = {bad};
6869
res_ad.add(std::to_string(0),vad);
70+
71+
72+
APIData bad2;
73+
std::vector<double> targets2 = {0.0, 0.0};
74+
std::vector<double> pred2 = {0.0, 1.0};
75+
bad2.add("pred",pred2);
76+
bad2.add("target",targets2);
77+
std::vector<APIData> vad2 = {bad2};
78+
res_ad.add(std::to_string(1),vad2);
79+
80+
81+
6982
SupervisedOutput so;
7083
double meanacc = 0.0, meaniou = 0.0;
7184
std::vector<double> clacc;
7285
double acc = so.acc_v(res_ad,meanacc,meaniou,clacc);
7386
ASSERT_EQ(0.75,acc);
87+
ASSERT_EQ(0.875,meaniou)
7488
}
7589

7690
TEST(outputconn,acck)

0 commit comments

Comments
 (0)