-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtrainML3D.cc
148 lines (128 loc) · 4.29 KB
/
trainML3D.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
/*
* Open source implementation of the ML3 classifier.
*
* If you find this software useful, please cite:
*
* "Multiclass Latent Locally Linear Support Vector Machines"
* Marco Fornoni, Barbara Caputo and Francesco Orabona
* JMLR Workshop and Conference Proceedings Volume 29 (ACML 2013 Proceedings)
*
* Copyright (c) 2013 Idiap Research Institute, http://www.idiap.ch/
* Written by Marco Fornoni <marco.fornoni@alumni.epfl.ch>
*
* This file is part of the ML3 Software.
*
* ML3 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 3 as
* published by the Free Software Foundation.
*
* ML3 is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ML3. If not, see <http://www.gnu.org/licenses/>.
*
* trainML3D.cc
*
* MATLAB interface to train an ML3 model on double-precision data
*
* Created on: Apr 20, 2013
* Author: Marco Fornoni
*/
#include "MexUtils.h"
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
if (nlhs!=1)
mexErrMsgTxt("The number of output variables must be: 1."
"\nExample syntax: model=trainML3(model,X_tr,y_tr)");
if (nrhs!=3 && nrhs!=5)
mexErrMsgTxt("The number of input variables must be either 3:\n"
"correct syntax: model=trainML3(model,X_tr,y_tr)\n"
"or 5:\n"
"\ncorrect syntax: model=trainML3(model,X_tr,y_tr,X_te,y_te)\n");
// Eigen::initParallel();
// omp_set_num_threads(4);
// mexPrintf("Num threads %d.\n", omp_get_num_threads());
// mexPrintf("test");
// std::cout<<"test"<<std::endl;
bool testAllEpochs=false;
const mxArray *mxX = prhs[1];
const uint xdim = mxGetM(mxX);
const uint xlen = mxGetN(mxX);
const mxArray *mxY = prhs[2];
const uint ylen = mxGetM(mxY);
const uint ylen2 = mxGetN(mxY);
int *dy = (int *) mxGetPr(mxY);
if (ylen != xlen){
mexErrMsgTxt("Number of samples and labels should agree.\n");
}
const mxClassID category = mxGetClassID(mxX);
const mxArray *mxXte;
const mxArray *mxYte;
int *dyte ;
uint xtedim;
uint xtelen;
uint ytelen;
ArrayXi yte;
if (nrhs==5){
mxXte = prhs[3];
xtedim = mxGetM(mxXte);
xtelen = mxGetN(mxXte);
mxYte = prhs[4];
ytelen = mxGetM(mxYte);
const uint ytelen2 = mxGetN(mxYte);
dyte = (int *) mxGetPr(mxYte);
if (xtelen>0 && ytelen>0){
if (xdim != xtedim){
mexErrMsgTxt("Training and testing samples should have the same dimensionality.\n");
}
if (ytelen != xtelen){
mexErrMsgTxt("Number of test samples and test labels should agree.\n");
}
testAllEpochs=true;
}
// MAPS THE MATLAB ARRAY INTO AN EIGEN ARRAY OF INTEGERS
yte=Map<ArrayXi>(dyte,ytelen);
}
// MAPS THE MATLAB ARRAY INTO AN EIGEN ARRAY OF INTEGERS
const Map<ArrayXi> y(dy,ylen);
if (mxIsSparse(mxX)) {
std::cout<<"Sparse input data is NOT supported."<<std::endl;
}else if(category==mxDOUBLE_CLASS) {
double *val;
//INSTANTIATE A ML3 MODEL
Model<double> model=Model<double>();
//INSTANTIATE A ML3 ALGORITHM
ML3<double> ml3=ML3<double>();
//INSTANTIATE A MEXUTILS OBJECT
MexUtils<double> mu=MexUtils<double>();
//LOADS THE MODEL PASSED AS ARGUMENT BY MATLAB
mu.load_mex_model(prhs[0], model);
double *dX = (double *) mxGetData(mxX);
// MAPS THE MATLAB MATRIX INTO AN EIGEN MATRIX OF DOUBLES
const Map<MatrixXd> X(dX,xdim,xlen);
if (model.maxCCCPIter < 0){
mexErrMsgTxt("The number of CCCP iterations must be >= 0 \n");
}else if (model.initStep==0 && model.maxCCCPIter==0){
mexErrMsgTxt("If no initialization is performed the number of CCCP iterations must be >= 1 \n");
}
mu.timer_reset();
if (testAllEpochs){
double *dXte = (double *) mxGetData(mxXte);
MatrixXd Xte;
// MAPS THE MATLAB MATRIX INTO AN EIGEN MATRIX OF DOUBLES
Xte=Map<MatrixXd>(dXte,xtedim,xtelen);
ml3.trainML3(model,X,y,Xte,yte,true);
}else{
// ml3.trainML3(model,X,y,X,y,true);
ml3.trainML3(model,X,y);
}
model.trTime=mu.timer_query();
// SETS THE model AS AN OUTPUT FOR MATLAB
mu.setOutput(plhs, model, category);
}else{
mexErrMsgTxt("The only supported datatype is double precision floating point.\n");
}
}