-
Notifications
You must be signed in to change notification settings - Fork 0
/
viterbi_cpp.cpp
177 lines (167 loc) · 6.47 KB
/
viterbi_cpp.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
// MEX function to compute viterbi loop over states
// [map_state_sequence] = viterbi(state_ids_i, state_ids_j, trans_prob_ij, ...
// initial_prob, obs_lik, gmm_from_state)
//
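// WARNING: assumes that the transitions are sorted by state_ids_j, i.e. that
// all transitions leading into the same state are stored next to each other!
// TODO: Check whether the transition models passed in always satisfy this.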
//
// INPUT:
// state_ids_i=[1:4]';
// state_ids_j=[2;1;4;3];
// trans_prob_ij=[1;1;1;1];
// initial_prob=[0.5;0;0.5;0];
// obs_lik=zeros(2, 2, 4);
// obs_lik=[0.7, 0.2, 0.7, 0.9; 0.3, 0.8, 0.3, 0.1];
// gmm_from_state=[1; 1; 2; 2];
//
// 07.05.2015 by Florian Krebs
// 17.06.2015 integrated uint16/uint32 switch by Harald Frostel
// 31.08.2015 changed input parameters to make the code independent of the beat
// tracking HMM
// ---------------------------------------------------------------------
#include <math.h>
#include <matrix.h>
#include <mex.h>
#include <algorithm>
#include <vector>
#include <stdint.h>
/* Compatibility shims for older MATLAB headers that do not provide the
   mwSize / mwIndex types: fall back to plain int and define the matching
   limit macros. When the headers already define MWSIZE_MAX, this whole
   section is skipped. */
#ifndef MWSIZE_MAX
typedef int mwSize, mwIndex, mwSignedIndex;
#define MWSIZE_MIN 0UL
#define MWINDEX_MIN 0UL
#if !defined(MX_COMPAT_32) && (defined(_LP64) || defined(_WIN64))
/* Currently 2^48 based on hardware limitations */
#define MWSIZE_MAX 281474976710655UL
#define MWINDEX_MAX 281474976710655UL
#define MWSINDEX_MAX 281474976710655L
#define MWSINDEX_MIN -281474976710655L
#else
#define MWSIZE_MAX 2147483647UL
#define MWINDEX_MAX 2147483647UL
#define MWSINDEX_MAX 2147483647L
#define MWSINDEX_MIN -2147483647L
#endif
#endif
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
if (nrhs < 6) {
mexErrMsgTxt("To few arguments.");
}
if (nrhs > 6) {
mexErrMsgTxt("To many arguments.");
}
//declare variables
mxArray *map_state_sequence, *psi;
mxArray *debug_data;
const mwSize *dims;
double sum_k, temp, debug_temp, check, new_prediction;
int num_frames, obs_idx, num_trans, i_state;
int i, j, i_trans, prev_end_state, current_start_state, current_end_state;
bool store_new_precursor = false;
//associate inputs and associate pointers
double *state_ids_i_ptr = mxGetPr(prhs[0]);
double *state_ids_j_ptr = mxGetPr(prhs[1]);
double *trans_prob_ij_ptr = mxGetPr(prhs[2]);
double *initial_prob_ptr = mxGetPr(prhs[3]);
double *obs_lik_ptr = mxGetPr(prhs[4]);
double *gmm_from_state_ptr = mxGetPr(prhs[5]);
//figure out dimensions
// num_frames
dims = mxGetDimensions(prhs[4]);
num_frames = (int)dims[1];
// num_states
dims = mxGetDimensions(prhs[3]);
const mwSize num_states = dims[0];
// number of possible transitions
dims = mxGetDimensions(prhs[1]);
num_trans = (int)dims[0];
//associate outputs
map_state_sequence = plhs[0] = mxCreateDoubleMatrix(num_frames,1,mxREAL);
// debug_data = plhs[1] = mxCreateDoubleMatrix(num_states, 1, mxREAL);
//internal variables
// if less than 65535 states are used we can store the state ids as uint16
const bool elems32 = (num_states) > 65535;
uint16_t *psi_ptr_16 = NULL;
uint32_t *psi_ptr_32 = NULL;
if (elems32) {
psi = mxCreateNumericMatrix(num_states, num_frames, mxUINT32_CLASS, mxREAL);
psi_ptr_32 = static_cast<uint32_t *>(mxGetData(psi));
} else {
psi = mxCreateNumericMatrix(num_states, num_frames, mxUINT16_CLASS, mxREAL);
psi_ptr_16 = static_cast<uint16_t *>(mxGetData(psi));
}
std::vector<double> delta(initial_prob_ptr, initial_prob_ptr+num_states);
std::vector<double> prediction(initial_prob_ptr, initial_prob_ptr+num_states);
//associate pointers (get pointer to the first element of the real data)
double *map_state_sequence_ptr = mxGetPr(map_state_sequence);
//start computing
debug_temp = 0;
for(i=0;i<num_frames;i++)
{
prev_end_state = -7;
// loop over possible transitions
for(i_trans=0;i_trans<num_trans;i_trans++)
{
// get start and end state of transition i_trans
current_start_state = (int)state_ids_i_ptr[i_trans]-1;
current_end_state = (int)state_ids_j_ptr[i_trans]-1;
if (current_end_state == prev_end_state) {
// the transition (i_trans-1) has the same end_state as transition
// (i_trans): Find the best start_state among these two
new_prediction = delta[current_start_state] * trans_prob_ij_ptr[i_trans];
if ( new_prediction > prediction[prev_end_state]) {
// found more probable precursor state -> save it
prediction[current_end_state] = new_prediction;
store_new_precursor = true;
}
} else {
// the transition i_trans has a different end-state from transition
// i_trans-1
prev_end_state = current_end_state;
prediction[current_end_state] = delta[current_start_state] * trans_prob_ij_ptr[i_trans];
store_new_precursor = true;
}
if (store_new_precursor) {
const size_t idx = current_end_state * num_frames + i;
if (elems32) {
psi_ptr_32[idx] = current_start_state;
}
else {
psi_ptr_16[idx] = current_start_state;
}
store_new_precursor = false;
}
}
sum_k = 0;
for (i_state=0; i_state<num_states; i_state++) {
delta[i_state] = prediction[i_state] * obs_lik_ptr[(int)gmm_from_state_ptr[i_state], i];
sum_k += delta[i_state];
}
// normalise
for (i_state=0; i_state<num_states; i_state++) {
delta[i_state] /= sum_k;
}
}
// Back tracing
// Find best_end_state
int best_end_state = 0;
double best_delta = -7;
for (i_state=0; i_state<num_states; i_state++) {
if (delta[i_state] > best_delta) {
best_delta = delta[i_state];
best_end_state = i_state;
}
}
// store and convert to MATLAB index
map_state_sequence_ptr[num_frames-1] = (double)best_end_state + 1;
for (i=num_frames-1; i>0; i--) {
const size_t idx = (best_end_state * num_frames + i); // idx for psi
const uint32_t value = elems32?psi_ptr_32[idx]:uint32_t(psi_ptr_16[idx]);
map_state_sequence_ptr[i-1] = value + 1;
best_end_state = value;
}
mxDestroyArray(psi);
}