Skip to content

Commit

Permalink
enums
Browse files Browse the repository at this point in the history
  • Loading branch information
jschueller committed Nov 2, 2023
1 parent c6e62ef commit 110ab50
Showing 1 changed file with 69 additions and 39 deletions.
108 changes: 69 additions & 39 deletions lib/src/LibSVM.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -198,22 +198,56 @@ void LibSVM::setP(const OT::Scalar p)
Point LibSVM::getSupportVectorCoef( )
{
Point res(getNumberSupportVector());
for( UnsignedInteger j = 0 ; j < getNumberSupportVector() ; j++ )
{
for(UnsignedInteger j = 0 ; j < getNumberSupportVector(); ++ j)
res[j] = p_implementation_->p_model_->sv_coef[0][j];
}
return res;
}

/* KernelType accessor */
LibSVM::KernelType LibSVM::getKernelType() const
{
return (LibSVM::KernelType)p_implementation_->parameter_.kernel_type;
switch (p_implementation_->parameter_.kernel_type)
{
case LINEAR:
return Linear;
case POLY:
return Polynomial;
case RBF:
return NormalRbf;
case SIGMOID:
return Sigmoid;
default:
throw InvalidArgumentException(HERE) << "LibSVM: kernel type not available.";
}
}

void LibSVM::setKernelType(const UnsignedInteger kernelType)
{
p_implementation_->parameter_.kernel_type = kernelType;
switch (kernelType)
{
case Linear:
{
p_implementation_->parameter_.kernel_type = LINEAR;
break;
}
case Polynomial:
{
p_implementation_->parameter_.kernel_type = POLY;
break;
}
case NormalRbf:
{
p_implementation_->parameter_.kernel_type = RBF;
break;
}
case Sigmoid:
{
p_implementation_->parameter_.kernel_type = SIGMOID;
break;
}
default:
throw InvalidArgumentException(HERE) << "LibSVM: kernel type not available.";
}
}

/* Node accessor */
Expand All @@ -222,10 +256,24 @@ svm_node* LibSVM::getNode(const UnsignedInteger index)
return p_implementation_->p_model_ -> SV[index];
}

/*SvmType accessor */
/* SvmType accessor */
void LibSVM::setSvmType(const UnsignedInteger svmType)
{
p_implementation_->parameter_.svm_type = svmType;
switch (svmType)
{
case CSupportClassification:
{
p_implementation_->parameter_.svm_type = C_SVC;
break;
}
case EpsilonSupportRegression:
{
p_implementation_->parameter_.svm_type = EPSILON_SVR;
break;
}
default:
throw InvalidArgumentException(HERE) << "LibSVM: svmType not available.";
}
}

/*kernelParameter accessor */
Expand Down Expand Up @@ -312,20 +360,15 @@ Scalar LibSVM::computeError()
Scalar LibSVM::computeAccuracy()
{
UnsignedInteger totalerror = 0;
for ( UnsignedInteger k = 0 ; k < (UnsignedInteger)p_implementation_->problem_.l ; k++ )
{
if(p_implementation_->problem_.y[k] != svm_predict(p_implementation_->p_model_, p_implementation_->problem_.x[k]))
{
totalerror++;
}
}

for (UnsignedInteger k = 0 ; k < (UnsignedInteger)p_implementation_->problem_.l ; ++ k)
if (p_implementation_->problem_.y[k] != svm_predict(p_implementation_->p_model_, p_implementation_->problem_.x[k]))
++ totalerror;
return totalerror;
}

template <typename T> T * LibSVM::Allocation(const UnsignedInteger size) const
{
return ( T * )malloc( size * sizeof(T) );
return (T *)malloc( size * sizeof(T) );
}


Expand Down Expand Up @@ -383,7 +426,6 @@ void LibSVM::convertData(const Sample & inputSample, const Sample & outputSample
p_implementation_->p_node_[j * (inputDimension + 1) + i].index = i + 1;
p_implementation_->p_node_[j * (inputDimension + 1) + i].value = inputTransformation_(inputSample[j])[i];
}

p_implementation_->p_node_[j * (inputDimension + 1) + inputDimension].index = - 1;
}
}
Expand Down Expand Up @@ -466,30 +508,21 @@ UnsignedInteger LibSVM::getLabelValues(const Point & vector, const SignedInteger

svm_predict_values(p_implementation_->p_model_, prob.x[0], dec_values);

for(UnsignedInteger i = 0 ; i < numberclass ; i ++ )
for (UnsignedInteger i = 0; i < numberclass; ++ i)
{
for(UnsignedInteger j = i + 1 ; j < numberclass ; j ++)
for(UnsignedInteger j = i + 1; j < numberclass; ++ j)
{
if(dec_values[pos++] > 0)
{
++ vote[i];
}
else
{
++ vote[j];
}
}
}

UnsignedInteger res = 0;

for( UnsignedInteger i = 0 ; i < numberclass ; i ++ )
{
if( (SignedInteger)p_implementation_->p_model_->label[i] == outC )
{
for (UnsignedInteger i = 0; i < numberclass; ++ i)
if((SignedInteger)p_implementation_->p_model_->label[i] == outC)
res = i;
}
}

free(dec_values);
return vote[res];
Expand All @@ -512,20 +545,20 @@ Scalar LibSVM::predict(const Point & inP) const
svm_get_svm_type(p_implementation_->p_model_) == EPSILON_SVR ||
svm_get_svm_type(p_implementation_->p_model_) == NU_SVR)
{

svm_predict_values(p_implementation_->p_model_, x, &res);

if (svm_get_svm_type(p_implementation_->p_model_) == ONE_CLASS)
res = (res > 0) ? 1 : -1;
}
else
{
// multiclass
int i;
int nr_class = svm_get_nr_class(p_implementation_->p_model_);
double *dec_values = new double[nr_class * (nr_class - 1) / 2];
svm_predict_values(p_implementation_->p_model_, x, dec_values);
std::vector<double>dec_values(nr_class * (nr_class - 1) / 2);
svm_predict_values(p_implementation_->p_model_, x, dec_values.data());

int *vote = new int[nr_class];
std::vector<int> vote(nr_class);
for (i = 0; i < nr_class; i++)
vote[i] = 0;
int pos = 0;
Expand All @@ -543,15 +576,12 @@ Scalar LibSVM::predict(const Point & inP) const
if (vote[i] > vote[vote_max_idx])
vote_max_idx = i;

int *labels = new int[nr_class];
svm_get_labels(p_implementation_->p_model_, labels);
std::vector<int> labels(nr_class);
svm_get_labels(p_implementation_->p_model_, labels.data());
int label = labels[vote_max_idx];

res = (double)label;
res = dec_values[0] * labels[0];
delete[] labels;
delete[] vote;
delete[] dec_values;
}
free(x);
return res;
Expand Down

0 comments on commit 110ab50

Please sign in to comment.