diff --git a/lib/src/LibSVM.cxx b/lib/src/LibSVM.cxx index 4609600..bb304f7 100644 --- a/lib/src/LibSVM.cxx +++ b/lib/src/LibSVM.cxx @@ -198,22 +198,56 @@ void LibSVM::setP(const OT::Scalar p) Point LibSVM::getSupportVectorCoef( ) { Point res(getNumberSupportVector()); - for( UnsignedInteger j = 0 ; j < getNumberSupportVector() ; j++ ) - { + for(UnsignedInteger j = 0 ; j < getNumberSupportVector(); ++ j) res[j] = p_implementation_->p_model_->sv_coef[0][j]; - } return res; } /* KernelType accessor */ LibSVM::KernelType LibSVM::getKernelType() const { - return (LibSVM::KernelType)p_implementation_->parameter_.kernel_type; + switch (p_implementation_->parameter_.kernel_type) + { + case LINEAR: + return Linear; + case POLY: + return Polynomial; + case RBF: + return NormalRbf; + case SIGMOID: + return Sigmoid; + default: + throw InvalidArgumentException(HERE) << "LibSVM: kernel type not available."; + } } void LibSVM::setKernelType(const UnsignedInteger kernelType) { - p_implementation_->parameter_.kernel_type = kernelType; + switch (kernelType) + { + case Linear: + { + p_implementation_->parameter_.kernel_type = LINEAR; + break; + } + case Polynomial: + { + p_implementation_->parameter_.kernel_type = POLY; + break; + } + case NormalRbf: + { + p_implementation_->parameter_.kernel_type = RBF; + break; + } + case Sigmoid: + { + p_implementation_->parameter_.kernel_type = SIGMOID; + break; + } + default: + throw InvalidArgumentException(HERE) << "LibSVM: kernel type not available."; + } } /* Node accessor */ @@ -222,10 +256,24 @@ svm_node* LibSVM::getNode(const UnsignedInteger index) return p_implementation_->p_model_ -> SV[index]; } -/*SvmType accessor */ +/* SvmType accessor */ void LibSVM::setSvmType(const UnsignedInteger svmType) { - p_implementation_->parameter_.svm_type = svmType; + switch (svmType) + { + case CSupportClassification: + { + p_implementation_->parameter_.svm_type = C_SVC; + break; + } + case EpsilonSupportRegression: + { + p_implementation_->parameter_.svm_type = EPSILON_SVR; + break; + } + default: + throw InvalidArgumentException(HERE) << "LibSVM: svmType not available."; + } } /*kernelParameter accessor */ @@ -312,20 +360,15 @@ Scalar LibSVM::computeError() Scalar LibSVM::computeAccuracy() { UnsignedInteger totalerror = 0; - for ( UnsignedInteger k = 0 ; k < (UnsignedInteger)p_implementation_->problem_.l ; k++ ) - { - if(p_implementation_->problem_.y[k] != svm_predict(p_implementation_->p_model_, p_implementation_->problem_.x[k])) - { - totalerror++; - } - } - + for (UnsignedInteger k = 0 ; k < (UnsignedInteger)p_implementation_->problem_.l ; ++ k) + if (p_implementation_->problem_.y[k] != svm_predict(p_implementation_->p_model_, p_implementation_->problem_.x[k])) + ++ totalerror; return totalerror; } template T * LibSVM::Allocation(const UnsignedInteger size) const { - return ( T * )malloc( size * sizeof(T) ); + return (T *)malloc( size * sizeof(T) ); } @@ -383,7 +426,6 @@ void LibSVM::convertData(const Sample & inputSample, const Sample & outputSample p_implementation_->p_node_[j * (inputDimension + 1) + i].index = i + 1; p_implementation_->p_node_[j * (inputDimension + 1) + i].value = inputTransformation_(inputSample[j])[i]; } - p_implementation_->p_node_[j * (inputDimension + 1) + inputDimension].index = - 1; } } @@ -466,30 +508,21 @@ UnsignedInteger LibSVM::getLabelValues(const Point & vector, const SignedInteger svm_predict_values(p_implementation_->p_model_, prob.x[0], dec_values); - for(UnsignedInteger i = 0 ; i < numberclass ; i ++ ) + for (UnsignedInteger i = 0; i < numberclass; ++ i) { - for(UnsignedInteger j = i + 1 ; j < numberclass ; j ++) + for(UnsignedInteger j = i + 1; j < numberclass; ++ j) { if(dec_values[pos++] > 0) - { ++ vote[i]; - } else - { ++ vote[j]; - } } } UnsignedInteger res = 0; - - for( UnsignedInteger i = 0 ; i < numberclass ; i ++ ) - { - if( (SignedInteger)p_implementation_->p_model_->label[i] == outC ) - { + for (UnsignedInteger i = 0; i < numberclass; ++ i) + if((SignedInteger)p_implementation_->p_model_->label[i] == outC) res = i; - } - } free(dec_values); return vote[res]; @@ -512,7 +545,6 @@ Scalar LibSVM::predict(const Point & inP) const svm_get_svm_type(p_implementation_->p_model_) == EPSILON_SVR || svm_get_svm_type(p_implementation_->p_model_) == NU_SVR) { - svm_predict_values(p_implementation_->p_model_, x, &res); if (svm_get_svm_type(p_implementation_->p_model_) == ONE_CLASS) @@ -520,12 +552,13 @@ Scalar LibSVM::predict(const Point & inP) const } else { + // multiclass int i; int nr_class = svm_get_nr_class(p_implementation_->p_model_); - double *dec_values = new double[nr_class * (nr_class - 1) / 2]; - svm_predict_values(p_implementation_->p_model_, x, dec_values); + std::vectordec_values(nr_class * (nr_class - 1) / 2); + svm_predict_values(p_implementation_->p_model_, x, dec_values.data()); - int *vote = new int[nr_class]; + std::vector vote(nr_class); for (i = 0; i < nr_class; i++) vote[i] = 0; int pos = 0; @@ -543,15 +576,12 @@ Scalar LibSVM::predict(const Point & inP) const if (vote[i] > vote[vote_max_idx]) vote_max_idx = i; - int *labels = new int[nr_class]; - svm_get_labels(p_implementation_->p_model_, labels); + std::vector labels(nr_class); + svm_get_labels(p_implementation_->p_model_, labels.data()); int label = labels[vote_max_idx]; res = (double)label; res = dec_values[0] * labels[0]; - delete[] labels; - delete[] vote; - delete[] dec_values; } free(x); return res;