Skip to content

Commit

Permalink
SVMKernel: Fix setParameter
Browse files Browse the repository at this point in the history
  • Loading branch information
jschueller committed Feb 19, 2024
1 parent d761abf commit c6194c8
Show file tree
Hide file tree
Showing 9 changed files with 50 additions and 23 deletions.
4 changes: 3 additions & 1 deletion lib/src/ExponentialRBF.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ Scalar ExponentialRBF::getSigma() const
return sigma_;
}

void ExponentialRBF::setSigma(Scalar sigma)
void ExponentialRBF::setSigma(const Scalar sigma)
{
sigma_ = sigma;
}
Expand All @@ -77,6 +77,8 @@ Point ExponentialRBF::getParameter() const

void ExponentialRBF::setParameter(const Point & parameter)
{
if (parameter.getDimension() != 1)
throw InvalidArgumentException(HERE) << "ExponentialRBF expected a parameter of dimension 1";
sigma_ = parameter[0];
}

Expand Down
8 changes: 5 additions & 3 deletions lib/src/NormalRBF.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ Scalar NormalRBF::getSigma() const
return sigma_;
}

void NormalRBF::setSigma(Scalar sigma)
void NormalRBF::setSigma(const Scalar sigma)
{
sigma_ = sigma;
}
Expand All @@ -77,6 +77,8 @@ Point NormalRBF::getParameter() const

void NormalRBF::setParameter(const Point & parameter)
{
if (parameter.getDimension() != 1)
throw InvalidArgumentException(HERE) << "NormalRBF expected a parameter of dimension 1";
sigma_ = parameter[0];
}

Expand All @@ -91,8 +93,8 @@ Description NormalRBF::getParameterDescription() const
/* Operator () */
Scalar NormalRBF::operator() (const Point & x1, const Point & x2) const
{
Point difference(x1 - x2);
Scalar value = exp(- difference.normSquare() / (2.0 * sigma_ * sigma_));
const Point difference(x1 - x2);
const Scalar value = exp(- difference.normSquare() / (2.0 * sigma_ * sigma_));
return value;
}

Expand Down
7 changes: 5 additions & 2 deletions lib/src/PolynomialKernel.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,16 @@ void PolynomialKernel::setConstant(Scalar constant)
/* Accessor to the parameter used for cross-validation */
Point PolynomialKernel::getParameter() const
{
return {degree_};
return {degree_, linear_, constant_};
}

void PolynomialKernel::setParameter(const Point & parameter)
{
if (parameter.getDimension() != 3)
throw InvalidArgumentException(HERE) << "PolynomialKernel expected a parameter of dimension 3";
degree_ = parameter[0];
constant_ = parameter[1];
linear_ = parameter[1];
constant_ = parameter[2];
}


Expand Down
2 changes: 2 additions & 0 deletions lib/src/RationalKernel.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ Point RationalKernel::getParameter() const

void RationalKernel::setParameter(const Point & parameter)
{
if (parameter.getDimension() != 1)
throw InvalidArgumentException(HERE) << "RationalKernel expected a parameter of dimension 1";
constant_ = parameter[0];
}

Expand Down
16 changes: 9 additions & 7 deletions lib/src/SigmoidKernel.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ SigmoidKernel * SigmoidKernel::clone() const
String SigmoidKernel::__repr__() const
{
return OSS() << "class=" << getClassName()
<< "constant=" << constant_
<< "linear=" << linear_;
<< " constant=" << constant_
<< " linear=" << linear_;
}


Expand All @@ -66,7 +66,7 @@ Scalar SigmoidKernel::getLinear() const
return linear_;
}

void SigmoidKernel::setLinear( Scalar linear )
void SigmoidKernel::setLinear(const Scalar linear)
{
linear_ = linear;
}
Expand All @@ -78,13 +78,15 @@ Scalar SigmoidKernel::getConstant() const
return constant_;
}

void SigmoidKernel::setConstant( Scalar constant )
void SigmoidKernel::setConstant(const Scalar constant)
{
constant_ = constant;
}

void SigmoidKernel::setParameter(const Point & parameter)
{
if (parameter.getDimension() != 2)
throw InvalidArgumentException(HERE) << "SigmoidKernel expected a parameter of dimension 2";
linear_ = parameter[0];
constant_ = parameter[1];
}
Expand All @@ -101,10 +103,10 @@ Description SigmoidKernel::getParameterDescription() const
}

/* Operator () */
Scalar SigmoidKernel::operator() ( const Point & x1, const Point & x2 ) const
Scalar SigmoidKernel::operator() (const Point & x1, const Point & x2) const
{
Scalar dotProduct = x1.dot(x2);
Scalar value = tanh( linear_ * dotProduct + constant_ );
const Scalar dotProduct = x1.dot(x2);
const Scalar value = tanh( linear_ * dotProduct + constant_ );
return value;
}

Expand Down
8 changes: 1 addition & 7 deletions lib/src/otsvm/ExponentialRBF.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,16 @@ namespace OTSVM
{



/**
* @class ExponentialRBF
*
* Implementation of the Exponential RBF kernel
*/



class OTSVM_API ExponentialRBF
: public SVMKernelImplementation
{
CLASSNAME
public:


/** Constructor with parameters */
explicit ExponentialRBF(const OT::Scalar sigma = 1.0);

Expand All @@ -54,7 +48,7 @@ public:

/** Sigma parameter accessor */
virtual OT::Scalar getSigma() const;
virtual void setSigma(OT::Scalar sigma);
virtual void setSigma(const OT::Scalar sigma);

/** Accessor to the parameter used for cross-validation */
OT::Point getParameter() const override;
Expand Down
2 changes: 1 addition & 1 deletion lib/src/otsvm/NormalRBF.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public:

/** Sigma parameter accessor */
virtual OT::Scalar getSigma() const;
virtual void setSigma(OT::Scalar sigma);
virtual void setSigma(const OT::Scalar sigma);

/** Accessor to the parameter used for cross-validation */
OT::Point getParameter() const override;
Expand Down
4 changes: 2 additions & 2 deletions lib/src/otsvm/SigmoidKernel.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ public:

/* Linear term accessor */
virtual OT::Scalar getLinear() const;
virtual void setLinear(OT::Scalar linear);
virtual void setLinear(const OT::Scalar linear);

/* Constant term accessor */
virtual OT::Scalar getConstant() const;
virtual void setConstant(OT::Scalar constant);
virtual void setConstant(const OT::Scalar constant);

/* Parameters value and description accessor */
OT::Point getParameter() const override;
Expand Down
22 changes: 22 additions & 0 deletions python/test/t_SVMKernel_std.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#! /usr/bin/env python

import openturns as ot
import openturns.testing as ott
import otsvm
import os

kernels = [
otsvm.NormalRBF(2.0),
Expand All @@ -21,6 +23,7 @@
V.append([x, y])

for kernel in kernels:
print(kernel)
for v in V:
x, y = v
print("x,y=", x, y)
Expand All @@ -43,3 +46,22 @@

kxydx2 = kernel.partialHessian(x, y)
print("d2kernel/(dx_i*dx_j)(x,y)=", repr(kxydx2))

# parameter accessor
param = kernel.getParameter()
kernel.setParameter(param)

# serialization
if ot.PlatformInfo.HasFeature("libxml2"):
study = ot.Study()
fname = "study_kernel.xml"
study.setStorageManager(ot.XMLStorageManager(fname))
study.add("kernel", kernel)
study.save()
study = ot.Study()
study.setStorageManager(ot.XMLStorageManager(fname))
study.load()
kernel = otsvm.SVMKernel()
study.fillObject("kernel", kernel)
print(kernel)
os.remove(fname)

0 comments on commit c6194c8

Please sign in to comment.