Skip to content

Commit

Permalink
SVMClassification: Fix serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
jschueller committed Feb 19, 2024
1 parent e6a3d66 commit d761abf
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 36 deletions.
2 changes: 1 addition & 1 deletion lib/src/LibSVM.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ Scalar LibSVM::runCrossValidation()
totalError += (p_implementation_->problem_.y[i] - target[i]) * (p_implementation_->problem_.y[i] - target[i]) / size;
}

LOGTRACE(OSS() << "LibSVM::runCrossValidation gamma=" << p_implementation_->parameter_.gamma << " C=" << p_implementation_->parameter_.C << " err=" << totalError);
LOGDEBUG(OSS() << "LibSVM::runCrossValidation gamma=" << p_implementation_->parameter_.gamma << " C=" << p_implementation_->parameter_.C << " err=" << totalError);

return totalError;
}
Expand Down
6 changes: 3 additions & 3 deletions lib/src/SVMClassification.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,9 @@ void SVMClassification::save(Advocate & adv) const
void SVMClassification::load(Advocate & adv)
{
ClassifierImplementation::load(adv);
adv.saveAttribute( "tradeoffFactor_", tradeoffFactor_ );
adv.saveAttribute( "kernelParameter_", kernelParameter_ );
adv.saveAttribute( "accuracy_", accuracy_ );
adv.loadAttribute( "tradeoffFactor_", tradeoffFactor_ );
adv.loadAttribute( "kernelParameter_", kernelParameter_ );
adv.loadAttribute( "accuracy_", accuracy_ );
}


Expand Down
4 changes: 1 addition & 3 deletions python/doc/examples/plot_example2.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@
# we create dataIn and dataOut
size = dataInOut.getSize()
dataIn = dataInOut.getMarginal([1, 2])
dataOut = ot.Indices(size)
for i in range(size):
dataOut[i] = int(dataInOut[i, 0])
dataOut = [int(dataInOut[i, 0]) for i in range(size)]

# list of C parameter
cp = [0.000001, 0.00001, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100]
Expand Down
18 changes: 3 additions & 15 deletions python/test/t_SVMClassification_multiclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,9 @@
path = os.path.abspath(os.path.dirname(__file__))
dataInOut = ot.Sample.ImportFromCSVFile(os.path.join(path, "multiclass.csv"), ",")

dataIn = ot.Sample(148, 4)
dataOut = ot.Indices(148, 0)

# we build the input Sample and the output Sample because we must separate
# dataInOut
for i in range(148):
a = dataInOut[i]
b = ot.Point(4)
b[0] = a[1]
b[1] = a[2]
b[2] = a[3]
b[3] = a[4]
dataIn[i] = b
dataOut[i] = int(a[0])

size = len(dataInOut)
dataIn = dataInOut.getMarginal([1, 2, 3, 4])
dataOut = [int(dataInOut[i, 0]) for i in range(size)]

# list of C parameter
cp = [0.000001, 0.00001, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100]
Expand Down
37 changes: 23 additions & 14 deletions python/test/t_SVMClassification_std.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,10 @@

# we retrieve the sample from the file sample.csv
path = os.path.abspath(os.path.dirname(__file__))
dataInOut = ot.Sample().ImportFromCSVFile(os.path.join(path, "sample.csv"), ",")

dataIn = ot.Sample(861, 2)
dataOut = ot.Indices(861, 0)

# we build the input Sample and the output Sample because we must separate
# dataInOut
for i in range(861):
a = dataInOut[i]
b = ot.Point(2)
b[0] = a[1]
b[1] = a[2]
dataIn[i] = b
dataOut[i] = int(a[0])
dataInOut = ot.Sample.ImportFromCSVFile(os.path.join(path, "sample.csv"), ",")
size = len(dataInOut)
dataIn = dataInOut.getMarginal([1, 2])
dataOut = [int(dataInOut[i, 0]) for i in range(size)]

# list of C parameter
cp = [0.000001, 0.00001, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100]
Expand All @@ -40,3 +30,22 @@
accuracy = algo.getAccuracy()
print(accuracy)
ott.assert_almost_equal(accuracy, 100.0)

for i in range(size):
x = dataIn[i]
c = dataOut[i]
print(f"x={x} c={c} classify={algo.classify(x)} grade={algo.grade(x, c)} predict={algo.predict(x)}")

if ot.PlatformInfo.HasFeature("libxml2"):
study = ot.Study()
fname = "study_classif.xml"
study.setStorageManager(ot.XMLStorageManager(fname))
study.add("algo", algo)
study.save()
study = ot.Study()
study.setStorageManager(ot.XMLStorageManager(fname))
study.load()
algo = otsvm.SVMClassification()
study.fillObject("algo", algo)
accuracy = algo.getAccuracy()
os.remove(fname)

0 comments on commit d761abf

Please sign in to comment.