-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcyglobal.pyx
85 lines (78 loc) · 2.91 KB
/
cyglobal.pyx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
## cython: profile=True
## cython: linetrace=True
## cython: binding=True
## cython: boundscheck=False
## cython: wraparound=False
## distutils: define_macros=CYTHON_TRACE_NOGIL=1
cimport numpy as np
import numpy as np
from scipy.special import logsumexp
from libc.math cimport log, exp
def calc_global_likelihood(
        np.ndarray[ndim=1, dtype=np.double_t] params,
        np.ndarray[ndim=2, dtype=np.double_t] X,
        dict obs,
        dict blims):
    """Return the log-likelihood summed over every regression in ``obs``.

    For each key ``regkey`` in ``obs``, the coefficients are
    ``params[low:high]`` with ``(low, high) = blims[regkey]``. They are
    reshaped column-major (``order='F'``) into a ``(X.shape[1], n_free)``
    matrix ``b``. The outcome log-probabilities for each row of ``X`` are
    the log-softmax of ``[X @ b, 0]``. The last outcome is therefore a
    reference category whose linear predictor is fixed at zero, and
    ``n_outcomes = n_free + 1``.

    Parameters
    ----------
    params : 1-D float64 array holding the coefficients of all regressions.
    X : 2-D float64 array; each row is one predictor vector.
    obs : dict mapping regkey -> 2-D integer array.
        The array is bound to an ``int[:, :]`` memoryview, so its dtype must
        match C ``int`` (``np.intc``). Column 0 is a row index into ``X``.
        Columns ``1 .. n_outcomes`` are the observed counts of each outcome
        for that row; any extra trailing columns are ignored. Empty entries
        are skipped.
    blims : dict mapping regkey -> ``(low, high)`` slice bounds into
        ``params``.

    Returns
    -------
    float
        Sum over observations and outcomes of ``count * log p``. Counts that
        are <= 0 contribute nothing.

    Raises
    ------
    ValueError
        If an observation array has fewer than ``n_outcomes + 1`` columns.
        Continuing would index past the count columns.
    """
    cdef:
        int rowlen, obs_idx, j, obs_count, cm_idx, n_outcomes
        int[:, :] regobs
        double[:, :] logprobs
        double ll
    rowlen = X.shape[1]
    ll = 0.0
    for regkey, regression_obs in obs.items():
        # Not every regression necessarily has observations.
        if len(regression_obs) == 0:
            continue
        regobs = regression_obs
        low, high = blims[regkey]
        b = params[low:high].reshape((rowlen, -1), order='F')
        # Append the zero-predictor reference outcome, then normalise each
        # row so Xb holds log-probabilities (log-softmax).
        Xb = np.column_stack((np.dot(X, b), np.zeros(X.shape[0])))
        Xb -= logsumexp(Xb, axis=1)[:, None]
        logprobs = Xb
        # The outcome count comes from the model rather than a hard-coded 4,
        # so the counts read below always line up with logprobs' columns.
        n_outcomes = logprobs.shape[1]
        if regobs.shape[1] < n_outcomes + 1:
            raise ValueError(
                "observations for %r have %d columns; expected at least %d "
                "(row index + %d outcome counts)"
                % (regkey, regobs.shape[1], n_outcomes + 1, n_outcomes))
        for obs_idx in range(regobs.shape[0]):
            cm_idx = regobs[obs_idx, 0]
            for j in range(n_outcomes):
                obs_count = regobs[obs_idx, j + 1]
                if obs_count > 0:
                    ll += logprobs[cm_idx, j] * obs_count
    return ll
def calc_global_gradient(
        np.ndarray[ndim=1, dtype=np.double_t] params,
        np.ndarray[ndim=2, dtype=np.double_t] X,
        dict obs,
        dict blims):
    """Return the gradient of ``calc_global_likelihood`` with respect to params.

    The model and inputs are the same as in ``calc_global_likelihood``. For
    each regression, ``params[low:high]`` is reshaped column-major into
    ``b`` with shape ``(X.shape[1], n_free)``. Outcome probabilities are
    ``softmax([X @ b, 0])``, where the last outcome is a zero-predictor
    reference with no parameters.

    Consider one observation at row ``i`` of ``X``. Let ``c_j`` be its
    count for outcome ``j`` (counts <= 0 are treated as 0) and let
    ``T = sum_j c_j``. The derivative of ``sum_j c_j * log p_ij`` with
    respect to ``b[r, k]`` is::

        X[i, r] * (c_k - T * p_ik)

    That single expression is what gets accumulated here.

    Parameters
    ----------
    params : 1-D float64 array holding the coefficients of all regressions.
    X : 2-D float64 array; each row is one predictor vector.
    obs : dict mapping regkey -> 2-D C ``int`` (``np.intc``) array.
        Column 0 is a row index into ``X``; columns ``1 .. n_outcomes``
        are outcome counts. Empty entries are skipped.
    blims : dict mapping regkey -> ``(low, high)`` slice bounds into
        ``params``.

    Returns
    -------
    numpy.ndarray
        New float64 array of length ``params.shape[0]``. Entries outside
        every ``blims`` range stay 0. Contributions from regressions whose
        ranges overlap are summed.

    Raises
    ------
    ValueError
        If an observation array has fewer than ``n_outcomes + 1`` columns.
    """
    cdef:
        int rowlen, obs_idx, k, obs_count, param_idx, low, high
        int param_outcome, param_row_idx, cm_idx, n_outcomes
        int[:, :] regobs
        double[:, :] probs
        double total_count, resid
        double[:] grad
    rowlen = X.shape[1]
    grad_np = np.zeros(params.shape[0])
    grad = grad_np
    for regkey, regression_obs in obs.items():
        # Not every regression necessarily has observations.
        if len(regression_obs) == 0:
            continue
        regobs = regression_obs
        low, high = blims[regkey]
        b = params[low:high].reshape((rowlen, -1), order='F')
        Xb = np.column_stack((np.dot(X, b), np.zeros(X.shape[0])))
        Xb -= logsumexp(Xb, axis=1)[:, None]
        # Exponentiate once per regression, outside the observation and
        # parameter loops.
        probs = np.exp(Xb)
        n_outcomes = probs.shape[1]
        if regobs.shape[1] < n_outcomes + 1:
            raise ValueError(
                "observations for %r have %d columns; expected at least %d "
                "(row index + %d outcome counts)"
                % (regkey, regobs.shape[1], n_outcomes + 1, n_outcomes))
        for obs_idx in range(regobs.shape[0]):
            cm_idx = regobs[obs_idx, 0]
            # T: total of the positive counts for this observation.
            total_count = 0.0
            for k in range(n_outcomes):
                obs_count = regobs[obs_idx, k + 1]
                if obs_count > 0:
                    total_count += obs_count
            if total_count == 0.0:
                continue
            for param_idx in range(low, high):
                # Column-major layout: param (low + k*rowlen + r) is b[r, k].
                param_outcome = (param_idx - low) // rowlen
                param_row_idx = (param_idx - low) % rowlen
                obs_count = regobs[obs_idx, param_outcome + 1]
                resid = -total_count * probs[cm_idx, param_outcome]
                if obs_count > 0:
                    resid += obs_count
                grad[param_idx] += X[cm_idx, param_row_idx] * resid
    return grad_np