diff --git a/lib/src/ExponentialRBF.cxx b/lib/src/ExponentialRBF.cxx index af57f6c..66ecfb5 100644 --- a/lib/src/ExponentialRBF.cxx +++ b/lib/src/ExponentialRBF.cxx @@ -63,7 +63,7 @@ Scalar ExponentialRBF::getSigma() const return sigma_; } -void ExponentialRBF::setSigma(Scalar sigma) +void ExponentialRBF::setSigma(const Scalar sigma) { sigma_ = sigma; } @@ -77,6 +77,8 @@ Point ExponentialRBF::getParameter() const void ExponentialRBF::setParameter(const Point & parameter) { + if (parameter.getDimension() != 1) + throw InvalidArgumentException(HERE) << "ExponentialRBF expected a parameter of dimension 1"; sigma_ = parameter[0]; } diff --git a/lib/src/NormalRBF.cxx b/lib/src/NormalRBF.cxx index a852cfd..da60155 100644 --- a/lib/src/NormalRBF.cxx +++ b/lib/src/NormalRBF.cxx @@ -63,7 +63,7 @@ Scalar NormalRBF::getSigma() const return sigma_; } -void NormalRBF::setSigma(Scalar sigma) +void NormalRBF::setSigma(const Scalar sigma) { sigma_ = sigma; } @@ -77,6 +77,8 @@ Point NormalRBF::getParameter() const void NormalRBF::setParameter(const Point & parameter) { + if (parameter.getDimension() != 1) + throw InvalidArgumentException(HERE) << "NormalRBF expected a parameter of dimension 1"; sigma_ = parameter[0]; } @@ -91,8 +93,8 @@ Description NormalRBF::getParameterDescription() const /* Operator () */ Scalar NormalRBF::operator() (const Point & x1, const Point & x2) const { - Point difference(x1 - x2); - Scalar value = exp(- difference.normSquare() / (2.0 * sigma_ * sigma_)); + const Point difference(x1 - x2); + const Scalar value = exp(- difference.normSquare() / (2.0 * sigma_ * sigma_)); return value; } diff --git a/lib/src/PolynomialKernel.cxx b/lib/src/PolynomialKernel.cxx index c92a077..ffc1be9 100644 --- a/lib/src/PolynomialKernel.cxx +++ b/lib/src/PolynomialKernel.cxx @@ -101,13 +101,16 @@ void PolynomialKernel::setConstant(Scalar constant) /* Accessor to the parameter used for cross-validation */ Point PolynomialKernel::getParameter() const { - return {degree_}; + return {degree_, linear_, constant_}; } void PolynomialKernel::setParameter(const Point & parameter) { + if (parameter.getDimension() != 3) + throw InvalidArgumentException(HERE) << "PolynomialKernel expected a parameter of dimension 3"; degree_ = parameter[0]; - constant_ = parameter[1]; + linear_ = parameter[1]; + constant_ = parameter[2]; } diff --git a/lib/src/RationalKernel.cxx b/lib/src/RationalKernel.cxx index 604ff3f..67e38aa 100644 --- a/lib/src/RationalKernel.cxx +++ b/lib/src/RationalKernel.cxx @@ -77,6 +77,8 @@ Point RationalKernel::getParameter() const void RationalKernel::setParameter(const Point & parameter) { + if (parameter.getDimension() != 1) + throw InvalidArgumentException(HERE) << "RationalKernel expected a parameter of dimension 1"; constant_ = parameter[0]; } diff --git a/lib/src/SigmoidKernel.cxx b/lib/src/SigmoidKernel.cxx index a0f2fb7..b4ea363 100644 --- a/lib/src/SigmoidKernel.cxx +++ b/lib/src/SigmoidKernel.cxx @@ -55,8 +55,8 @@ SigmoidKernel * SigmoidKernel::clone() const String SigmoidKernel::__repr__() const { return OSS() << "class=" << getClassName() - << "constant=" << constant_ - << "linear=" << linear_; + << " constant=" << constant_ + << " linear=" << linear_; } @@ -66,7 +66,7 @@ Scalar SigmoidKernel::getLinear() const return linear_; } -void SigmoidKernel::setLinear( Scalar linear ) +void SigmoidKernel::setLinear(const Scalar linear) { linear_ = linear; } @@ -78,13 +78,15 @@ Scalar SigmoidKernel::getConstant() const return constant_; } -void SigmoidKernel::setConstant( Scalar constant ) +void SigmoidKernel::setConstant(const Scalar constant) { constant_ = constant; } void SigmoidKernel::setParameter(const Point & parameter) { + if (parameter.getDimension() != 2) + throw InvalidArgumentException(HERE) << "SigmoidKernel expected a parameter of dimension 2"; linear_ = parameter[0]; constant_ = parameter[1]; } @@ -101,10 +103,10 @@ Description SigmoidKernel::getParameterDescription() const } /* Operator () */ -Scalar SigmoidKernel::operator() ( const Point & x1, const Point & x2 ) const +Scalar SigmoidKernel::operator() (const Point & x1, const Point & x2) const { - Scalar dotProduct = x1.dot(x2); - Scalar value = tanh( linear_ * dotProduct + constant_ ); + const Scalar dotProduct = x1.dot(x2); + const Scalar value = tanh( linear_ * dotProduct + constant_ ); return value; } diff --git a/lib/src/otsvm/ExponentialRBF.hxx b/lib/src/otsvm/ExponentialRBF.hxx index 8aa158c..f2a27b6 100644 --- a/lib/src/otsvm/ExponentialRBF.hxx +++ b/lib/src/otsvm/ExponentialRBF.hxx @@ -27,22 +27,16 @@ namespace OTSVM { - /** * @class ExponentialRBF * * Implementation of the Exponential RBF kernel */ - - - class OTSVM_API ExponentialRBF : public SVMKernelImplementation { CLASSNAME public: - - /** Constructor with parameters */ explicit ExponentialRBF(const OT::Scalar sigma = 1.0); @@ -54,7 +48,7 @@ public: /** Sigma parameter accessor */ virtual OT::Scalar getSigma() const; - virtual void setSigma(OT::Scalar sigma); + virtual void setSigma(const OT::Scalar sigma); /** Accessor to the parameter used for cross-validation */ OT::Point getParameter() const override; diff --git a/lib/src/otsvm/NormalRBF.hxx b/lib/src/otsvm/NormalRBF.hxx index 48c44f5..76b8e1c 100644 --- a/lib/src/otsvm/NormalRBF.hxx +++ b/lib/src/otsvm/NormalRBF.hxx @@ -49,7 +49,7 @@ public: /** Sigma parameter accessor */ virtual OT::Scalar getSigma() const; - virtual void setSigma(OT::Scalar sigma); + virtual void setSigma(const OT::Scalar sigma); /** Accessor to the parameter used for cross-validation */ OT::Point getParameter() const override; diff --git a/lib/src/otsvm/SigmoidKernel.hxx b/lib/src/otsvm/SigmoidKernel.hxx index e6da218..d7c04fa 100644 --- a/lib/src/otsvm/SigmoidKernel.hxx +++ b/lib/src/otsvm/SigmoidKernel.hxx @@ -48,11 +48,11 @@ public: /* Linear term accessor */ virtual OT::Scalar getLinear() const; - virtual void setLinear(OT::Scalar linear); + virtual void setLinear(const OT::Scalar linear); /* Constant term accessor */ virtual OT::Scalar getConstant() const; - virtual void setConstant(OT::Scalar constant); + virtual void setConstant(const OT::Scalar constant); /* Parameters value and description accessor */ OT::Point getParameter() const override; diff --git a/python/test/t_SVMKernel_std.py b/python/test/t_SVMKernel_std.py index 423f643..4747cfd 100755 --- a/python/test/t_SVMKernel_std.py +++ b/python/test/t_SVMKernel_std.py @@ -1,7 +1,9 @@ #! /usr/bin/env python +import openturns as ot import openturns.testing as ott import otsvm +import os kernels = [ otsvm.NormalRBF(2.0), @@ -21,6 +23,7 @@ V.append([x, y]) for kernel in kernels: + print(kernel) for v in V: x, y = v print("x,y=", x, y) @@ -43,3 +46,22 @@ kxydx2 = kernel.partialHessian(x, y) print("d2kernel/(dx_i*dx_j)(x,y)=", repr(kxydx2)) + + # parameter accessor + param = kernel.getParameter() + kernel.setParameter(param) + + # serialization + if ot.PlatformInfo.HasFeature("libxml2"): + study = ot.Study() + fname = "study_kernel.xml" + study.setStorageManager(ot.XMLStorageManager(fname)) + study.add("kernel", kernel) + study.save() + study = ot.Study() + study.setStorageManager(ot.XMLStorageManager(fname)) + study.load() + kernel = otsvm.SVMKernel() + study.fillObject("kernel", kernel) + print(kernel) + os.remove(fname)