-
Notifications
You must be signed in to change notification settings - Fork 3
/
Multrnd_Matrix_mex_fast_v1.c
118 lines (99 loc) · 3.46 KB
/
Multrnd_Matrix_mex_fast_v1.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
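/*==========================================================
* Multrnd_Matrix_mex.c - multinomial allocation of sparse counts.
*
* Xtrain is a sparse V x N count matrix, Phi is V x K and Theta is K x N.
* For every nonzero count Xtrain(v,j), the code draws that many category
* indices k with probability proportional to Phi(v,k)*Theta(k,j). It
* accumulates the draws into ZSDS (K x N, counts per category and column)
* and WSZS (V x K, counts per row and category).
*
* The calling syntax is:
*
* [ZSDS,WSZS] = Multrnd_Matrix_mex_fast(Xtrain,Phi,Theta);
*
* This is a MEX-file for MATLAB.
* Copyright 2012 Mingyuan Zhou
*
* v1: replace (double) randomMT()/RAND_MAX_32 with (double) rand() / RAND_MAX
* Oct, 2015
*========================================================*/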
/* $Revision: v1 $ */
#include "mex.h"
#include "string.h"
#include <math.h>
#include <stdlib.h>
/*
* #include "cokus.c"
* #define RAND_MAX_32 4294967295.0*/
/*
 * Inverse-CDF lookup in an unnormalised cumulative distribution.
 *
 * probrnd     point to locate, expected in [0, prob_cumsum[Ksize-1]].
 *             Values at or above the total mass are treated as equal to it.
 * prob_cumsum non-decreasing cumulative weights of Ksize categories.
 * Ksize       number of categories.
 *
 * Returns the smallest k with prob_cumsum[k] > probrnd. A category with
 * zero weight (prob_cumsum[k] == prob_cumsum[k-1]) is therefore never
 * returned while some weight is positive.
 *
 * When probrnd reaches the total mass (e.g. rand() returned RAND_MAX), it
 * returns the first k whose cumulative weight equals the total, i.e. the
 * last category with positive weight.
 *
 * Returns 0 when Ksize is 0 (without reading prob_cumsum) or when all
 * weights are zero.
 */
mwIndex BinarySearch(double probrnd, double *prob_cumsum, mwSize Ksize) {
    mwIndex lo = 0, hi, mid;
    double total;

    if (Ksize == 0)
        return 0;

    hi = Ksize - 1;
    total = prob_cumsum[hi];

    if (probrnd < total) {
        /* Upper bound: invariant prob_cumsum[hi] > probrnd, answer in [lo, hi]. */
        while (lo < hi) {
            mid = lo + (hi - lo) / 2;
            if (prob_cumsum[mid] > probrnd)
                hi = mid;
            else
                lo = mid + 1;
        }
    } else {
        /* At or beyond the total mass: the first index reaching the total is
         * the last category carrying positive weight (0 if all are zero). */
        while (lo < hi) {
            mid = lo + (hi - lo) / 2;
            if (prob_cumsum[mid] >= total)
                hi = mid;
            else
                lo = mid + 1;
        }
    }
    return lo;
}
/*
 * Multinomial allocation of the counts in a sparse V x N matrix.
 *
 * The input is given in compressed-column form (ir, jc, pr). For each stored
 * entry (v, j) with count c = pr[idx], this draws c category indices k.
 * A non-integer positive c gives ceil(c) draws. Each k is drawn with
 * probability proportional to Phi(v,k) * Theta(k,j), and each draw
 * increments ZSDS(k,j) and WSZS(v,k). Counts are added onto whatever the
 * output buffers already hold.
 *
 * ZSDS        K x N counts, column-major (incremented in place).
 * WSZS        V x K counts, column-major (incremented in place).
 * Phi         V x K weights, column-major.
 * Theta       K x N weights, column-major.
 * ir, jc, pr  row indices, column pointers and values of the sparse input.
 * Vsize, Nsize, Ksize  V, N and K.
 * prob_cumsum scratch buffer of at least Ksize doubles (overwritten).
 *
 * Random numbers come from the C library rand(); results depend on its seed.
 */
void Multrnd_Matrix(double *ZSDS, double *WSZS, double *Phi, double *Theta, mwIndex *ir, mwIndex *jc, double *pr, mwSize Vsize, mwSize Nsize, mwSize Ksize, double *prob_cumsum)
{
    double cum_sum, probrnd;
    const double *theta_col;
    mwIndex j, k, v, idx, token;

    for (j = 0; j < Nsize; j++) {
        theta_col = Theta + Ksize * j;
        /* Empty columns have jc[j] == jc[j+1], so this loop simply does not run. */
        for (idx = jc[j]; idx < jc[j+1]; idx++) {
            v = ir[idx];
            /* Unnormalised cumulative weights over the K categories for (v, j). */
            for (cum_sum = 0, k = 0; k < Ksize; k++) {
                cum_sum += Phi[v + k*Vsize] * theta_col[k];
                prob_cumsum[k] = cum_sum;
            }
            /* One categorical draw per unit of the stored count. */
            for (token = 0; token < pr[idx]; token++) {
                probrnd = (double) rand() / RAND_MAX * cum_sum;
                k = BinarySearch(probrnd, prob_cumsum, Ksize);
                ZSDS[k + Ksize*j]++;
                WSZS[v + k*Vsize]++;
            }
        }
    }
}
/*
 * The gateway function.
 *
 * MATLAB usage: [ZSDS,WSZS] = Multrnd_Matrix_mex_fast(Xtrain,Phi,Theta)
 *   Xtrain  V x N real sparse double matrix of counts.
 *   Phi     V x K real full double matrix.
 *   Theta   K x N real full double matrix.
 *   ZSDS    K x N counts per category and column of Xtrain.
 *   WSZS    V x K counts per row of Xtrain and category.
 *
 * Inputs are validated before any raw data pointer is read. A full or
 * complex Xtrain, or mismatched Phi/Theta sizes, would otherwise cause
 * invalid or out-of-bounds memory accesses.
 */
void mexFunction( int nlhs, mxArray *plhs[],
        int nrhs, const mxArray *prhs[])
{
    double *ZSDS, *WSZS, *Phi, *Theta;
    double *pr, *prob_cumsum;
    mwIndex *ir, *jc;
    mwSize Vsize, Nsize, Ksize;
    mxArray *wszs_array;

    if (nrhs != 3)
        mexErrMsgIdAndTxt("Multrnd_Matrix_mex_fast:nrhs",
                "Three inputs required: Xtrain, Phi, Theta.");
    if (nlhs > 2)
        mexErrMsgIdAndTxt("Multrnd_Matrix_mex_fast:nlhs",
                "At most two outputs: ZSDS, WSZS.");
    if (!mxIsSparse(prhs[0]) || !mxIsDouble(prhs[0]) || mxIsComplex(prhs[0]))
        mexErrMsgIdAndTxt("Multrnd_Matrix_mex_fast:Xtrain",
                "Xtrain must be a real sparse double matrix.");
    if (mxIsSparse(prhs[1]) || !mxIsDouble(prhs[1]) || mxIsComplex(prhs[1]))
        mexErrMsgIdAndTxt("Multrnd_Matrix_mex_fast:Phi",
                "Phi must be a real full double matrix.");
    if (mxIsSparse(prhs[2]) || !mxIsDouble(prhs[2]) || mxIsComplex(prhs[2]))
        mexErrMsgIdAndTxt("Multrnd_Matrix_mex_fast:Theta",
                "Theta must be a real full double matrix.");

    Vsize = mxGetM(prhs[0]);
    Nsize = mxGetN(prhs[0]);
    Ksize = mxGetN(prhs[1]);

    /* Phi and Theta are indexed with strides Vsize and Ksize respectively. */
    if (mxGetM(prhs[1]) != Vsize)
        mexErrMsgIdAndTxt("Multrnd_Matrix_mex_fast:PhiSize",
                "Phi must have size(Xtrain,1) rows.");
    if (mxGetM(prhs[2]) != Ksize || mxGetN(prhs[2]) != Nsize)
        mexErrMsgIdAndTxt("Multrnd_Matrix_mex_fast:ThetaSize",
                "Theta must be size(Phi,2)-by-size(Xtrain,2).");

    pr = mxGetPr(prhs[0]);
    ir = mxGetIr(prhs[0]);
    jc = mxGetJc(prhs[0]);
    /* jc[Nsize] is the number of stored entries; with none, K = 0 is harmless. */
    if (Ksize == 0 && jc[Nsize] > 0)
        mexErrMsgIdAndTxt("Multrnd_Matrix_mex_fast:emptyPhi",
                "Phi must have at least one column when Xtrain has nonzeros.");
    Phi = mxGetPr(prhs[1]);
    Theta = mxGetPr(prhs[2]);

    /* mxCreateDoubleMatrix zero-fills, so the sampler can accumulate in place. */
    plhs[0] = mxCreateDoubleMatrix(Ksize, Nsize, mxREAL);
    wszs_array = mxCreateDoubleMatrix(Vsize, Ksize, mxREAL);
    ZSDS = mxGetPr(plhs[0]);
    WSZS = mxGetPr(wszs_array);

    prob_cumsum = (double *) mxCalloc(Ksize, sizeof(double));
    Multrnd_Matrix(ZSDS, WSZS, Phi, Theta, ir, jc, pr, Vsize, Nsize, Ksize, prob_cumsum);
    mxFree(prob_cumsum);

    /* plhs is only guaranteed to hold max(nlhs, 1) entries. */
    if (nlhs >= 2)
        plhs[1] = wszs_array;
    else
        mxDestroyArray(wszs_array);
}