@@ -34,21 +34,68 @@ References
34
34
35
35
"""
36
36
from cpython.pycapsule cimport PyCapsule_GetPointer
37
+ from libc.stdint cimport uint64_t
37
38
38
- import numpy as np
39
39
cimport numpy as np
40
40
from numpy.random cimport BitGenerator, bitgen_t
41
41
42
- from pyhtnorm.c_htnorm cimport *
42
+ from numpy.random import default_rng
43
43
44
44
45
45
np.import_array()
46
46
47
- cdef extern from " ../src/dist .h" :
47
+ cdef extern from " htnorm_distributions .h" :
48
48
int HTNORM_ALLOC_ERROR
49
49
50
50
51
- cdef inline void validate_return_info(info):
51
+ cdef extern from " htnorm_rng.h" nogil:
52
+ ctypedef struct rng_t:
53
+ void * base
54
+ uint64_t (* next_uint64)(void * state)
55
+ double (* next_double)(void * state)
56
+
57
+
58
+ cdef extern from " htnorm.h" nogil:
59
+ ctypedef enum mat_type " type_t" :
60
+ NORMAL
61
+ DIAGONAL
62
+ IDENTITY
63
+
64
+ ctypedef struct ht_config_t:
65
+ size_t gnrow
66
+ size_t gncol
67
+ const double * mean
68
+ const double * cov
69
+ const double * g
70
+ const double * r
71
+ bint diag
72
+
73
+ ctypedef struct sp_config_t:
74
+ mat_type a_id
75
+ mat_type o_id
76
+ size_t pnrow
77
+ size_t pncol
78
+ const double * mean
79
+ const double * a
80
+ const double * phi
81
+ const double * omega
82
+ bint struct_mean
83
+
84
+ void init_ht_config(ht_config_t* conf, size_t gnrow, size_t gncol,
85
+ const double * mean, const double * cov, const double * g,
86
+ const double * r, bint diag)
87
+
88
+ void init_sp_config(sp_config_t* conf, size_t pnrow, size_t pncol,
89
+ const double * mean, const double * a, const double * phi,
90
+ const double * omega, bint struct_mean, mat_type a_id,
91
+ mat_type o_id)
92
+
93
+ int htn_hyperplane_truncated_mvn(rng_t* rng, const ht_config_t* conf, double * out)
94
+
95
+ int htn_structured_precision_mvn(rng_t* rng, const sp_config_t* conf, double * out)
96
+
97
+
98
+ cdef inline void validate_return_info(int info):
52
99
if info == HTNORM_ALLOC_ERROR:
53
100
raise MemoryError (" Not enough memory to allocate resources." )
54
101
elif info < 0 :
@@ -62,13 +109,14 @@ cdef inline void validate_return_info(info):
62
109
)
63
110
64
111
65
- cdef set _VALID_MATRIX_TYPES = {NORMAL, DIAGONAL, IDENTITY}
112
+ cdef dict MAT_TYPE = {" regular" : NORMAL, " diagonal" : DIAGONAL, " identity" : IDENTITY}
113
+ cdef const char * BITGEN_NAME = " BitGenerator"
66
114
67
115
68
- cdef inline void initialize_rng(BitGenerator bitgenerator, rng_t* htnorm_rng):
116
+ cdef inline void initialize_rng(object bitgenerator, rng_t* htnorm_rng):
69
117
cdef bitgen_t* bitgen
70
118
71
- bitgen = < bitgen_t* > PyCapsule_GetPointer(bitgenerator.capsule, " BitGenerator " )
119
+ bitgen = < bitgen_t* > PyCapsule_GetPointer(bitgenerator.capsule, BITGEN_NAME )
72
120
htnorm_rng.base = bitgen.state
73
121
htnorm_rng.next_uint64 = bitgen.next_uint64
74
122
htnorm_rng.next_double = bitgen.next_double
@@ -135,7 +183,6 @@ def hyperplane_truncated_mvnorm(
135
183
the algorithm could not successfully generate the samples.
136
184
137
185
"""
138
- cdef BitGenerator bitgenerator
139
186
cdef rng_t rng
140
187
cdef ht_config_t config
141
188
cdef int info
@@ -152,7 +199,7 @@ def hyperplane_truncated_mvnorm(
152
199
init_ht_config(& config, g.shape[0 ], g.shape[1 ], & mean[0 ],
153
200
& cov[0 , 0 ], & g[0 , 0 ], & r[0 ], diag)
154
201
155
- bitgenerator = np.random. default_rng(random_state)._bit_generator
202
+ bitgenerator = default_rng(random_state)._bit_generator
156
203
initialize_rng(bitgenerator, & rng)
157
204
158
205
with bitgenerator.lock, nogil:
@@ -169,14 +216,15 @@ def structured_precision_mvnorm(
169
216
double[:,::1] phi ,
170
217
double[:,::1] omega ,
171
218
bint mean_structured = False ,
172
- int a_type = 0 ,
173
- int o_type = 0 ,
219
+ str a_type = " regular " ,
220
+ str o_type = " regular " ,
174
221
double[:] out = None ,
175
222
random_state = None
176
223
):
177
224
"""
178
225
structured_precision_mvnorm(mean, a, phi, omega, mean_structured=False,
179
- a_type=0, o_type=0, out=None, random_state=None)
226
+ a_type="regular", o_type="regular",
227
+ out=None, random_state=None)
180
228
181
229
Sample from a MVN with a structured precision matrix :math:`\Lambda`
182
230
.. math::
@@ -197,9 +245,9 @@ def structured_precision_mvnorm(
197
245
such than ``mean = (precision)^-1 * phi^T * omega * t``. If this
198
246
is set to True, then the `mean` parameter is assumed to contain the
199
247
array ``t``.
200
- a_type : {0, 1, 2 }, optional, default=0
248
+ a_type : {"regular", "diagonal", "identity" }, optional, default="regular"
201
249
Whether `a` ia a normal, diagonal or identity matrix.
202
- o_type : {0, 1, 2 }, optional, default=0
250
+ o_type : {"regular", "diagonal", "identity" }, optional, default="regular"
203
251
Whether `omega` ia a normal, diagonal or identity matrix.
204
252
out : 1d array, optional, default=None
205
253
An array of the same shape as `mean` to store the samples. If not
@@ -228,7 +276,6 @@ def structured_precision_mvnorm(
228
276
the algorithm could not successfully generate the samples.
229
277
230
278
"""
231
- cdef BitGenerator bitgenerator
232
279
cdef rng_t rng
233
280
cdef sp_config_t config
234
281
cdef int info
@@ -238,8 +285,8 @@ def structured_precision_mvnorm(
238
285
raise ValueError (' `omega` and `a` both need to be square matrices' )
239
286
elif (phi.shape[0 ] != omega.shape[0 ]) or (phi.shape[1 ] != a.shape[0 ]):
240
287
raise ValueError (' Shapes of `phi`, `omega` and `a` are not consistent' )
241
- elif not {a_type, o_type}.issubset(_VALID_MATRIX_TYPES ):
242
- raise ValueError (f" `a_type` & `o_type` must be one of {_VALID_MATRIX_TYPES }" )
288
+ elif not {a_type, o_type}.issubset(MAT_TYPE ):
289
+ raise ValueError (f" `a_type` & `o_type` must be one of {set(MAT_TYPE) }" )
243
290
elif has_out and out.shape[0 ] != mean.shape[0 ]:
244
291
raise ValueError (" `out` must have the same size as the mean array." )
245
292
elif not has_out:
@@ -248,9 +295,9 @@ def structured_precision_mvnorm(
248
295
249
296
init_sp_config(& config, phi.shape[0 ], phi.shape[1 ], & mean[0 ], & a[0 , 0 ],
250
297
& phi[0 , 0 ], & omega[0 , 0 ], mean_structured,
251
- < mat_type > a_type, < mat_type > o_type)
298
+ MAT_TYPE[ a_type], MAT_TYPE[ o_type] )
252
299
253
- bitgenerator = np.random. default_rng(random_state)._bit_generator
300
+ bitgenerator = default_rng(random_state)._bit_generator
254
301
initialize_rng(bitgenerator, & rng)
255
302
256
303
with bitgenerator.lock, nogil:
0 commit comments