diff --git a/python/bridgestan/model.py b/python/bridgestan/model.py
index d1a69bd5..fe53e71b 100644
--- a/python/bridgestan/model.py
+++ b/python/bridgestan/model.py
@@ -13,8 +13,36 @@
 from .compile import compile_model, windows_dll_path_setup
 from .util import validate_readable
 
-FloatArray = npt.NDArray[np.float64]
-double_array = ndpointer(dtype=ctypes.c_double, flags=("C_CONTIGUOUS"))
+
+def array_ptr(*args, **kwargs):
+    """
+    Return a new class which can be used in a ctypes signature
+    to accept either a numpy array or a compatible
+    ``ctypes.POINTER`` instance.
+
+    All arguments are forwarded to :func:`np.ctypeslib.ndpointer`.
+    """
+    np_type = ndpointer(*args, **kwargs)
+    base = np.ctypeslib.as_ctypes_type(np_type._dtype_)
+    ctypes_type = ctypes.POINTER(base)
+
+    def from_param(cls, obj):
+        if isinstance(obj, (ctypes_type, ctypes.Array)):
+            return ctypes_type.from_param(obj)
+        return np_type.from_param(obj)
+
+    return type(np_type.__name__, (np_type,), {"from_param": classmethod(from_param)})
+
+
+FloatArray = Union[
+    npt.NDArray[np.float64],
+    ctypes.POINTER(ctypes.c_double),
+    ctypes.Array[ctypes.c_double],
+]
+double_array = array_ptr(dtype=ctypes.c_double, flags=("C_CONTIGUOUS"))
+writeable_double_array = array_ptr(
+    dtype=ctypes.c_double, flags=("C_CONTIGUOUS", "WRITEABLE")
+)
 star_star_char = ctypes.POINTER(ctypes.c_char_p)
 c_print_callback = ctypes.CFUNCTYPE(None, ctypes.POINTER(ctypes.c_char), ctypes.c_int)
 
@@ -165,6 +193,19 @@ def __init__(
         self._param_unc_num.restype = ctypes.c_int
         self._param_unc_num.argtypes = [ctypes.c_void_p]
 
+        num_params = self._param_unc_num(self.model)
+
+        param_sized_out_array = array_ptr(
+            dtype=ctypes.c_double,
+            flags=("C_CONTIGUOUS", "WRITEABLE"),
+            shape=(num_params,),
+        )
+        param_sqrd_sized_out_array = array_ptr(
+            dtype=ctypes.c_double,
+            flags=("C_CONTIGUOUS", "WRITEABLE"),
+            shape=(num_params, num_params),
+        )
+
         self._param_names = self.stanlib.bs_param_names
         self._param_names.restype = ctypes.c_char_p
         self._param_names.argtypes = [
@@ -184,7 +225,7 @@ def __init__(
             ctypes.c_int,
             ctypes.c_int,
             double_array,
-            double_array,
+            writeable_double_array,
             ctypes.c_void_p,
             star_star_char,
         ]
@@ -194,7 +235,7 @@ def __init__(
         self._param_unconstrain.argtypes = [
             ctypes.c_void_p,
             double_array,
-            double_array,
+            param_sized_out_array,
             star_star_char,
         ]
 
@@ -203,7 +244,7 @@ def __init__(
         self._param_unconstrain_json.argtypes = [
             ctypes.c_void_p,
             ctypes.c_char_p,
-            double_array,
+            param_sized_out_array,
             star_star_char,
         ]
 
@@ -226,7 +267,7 @@ def __init__(
             ctypes.c_int,
             double_array,
             ctypes.POINTER(ctypes.c_double),
-            double_array,
+            param_sized_out_array,
             star_star_char,
         ]
 
@@ -238,8 +279,8 @@ def __init__(
             ctypes.c_int,
             double_array,
             ctypes.POINTER(ctypes.c_double),
-            double_array,
-            double_array,
+            param_sized_out_array,
+            param_sqrd_sized_out_array,
             star_star_char,
         ]
 
@@ -252,7 +293,7 @@ def __init__(
             double_array,
             double_array,
             ctypes.POINTER(ctypes.c_double),
-            double_array,
+            param_sized_out_array,
             star_star_char,
         ]
 
@@ -395,7 +436,7 @@ def param_constrain(
         dims = self.param_num(include_tp=include_tp, include_gq=include_gq)
         if out is None:
             out = np.zeros(dims)
-        elif out.size != dims:
+        elif hasattr(out, "shape") and out.shape != (dims,):
             raise ValueError(
                 "Error: out must be same size as number of constrained parameters"
             )
@@ -445,12 +486,10 @@ def param_unconstrain(
         dims = self.param_unc_num()
         if out is None:
             out = np.zeros(shape=dims)
-        elif out.size != dims:
-            raise ValueError(
-                f"out size = {out.size} != unconstrained params size = {dims}"
-            )
+
         err = ctypes.c_char_p()
         rc = self._param_unconstrain(self.model, theta, out, ctypes.byref(err))
+
         if rc:
             raise self._handle_error(err, "param_unconstrain")
         return out
@@ -476,10 +515,7 @@ def param_unconstrain_json(
         dims = self.param_unc_num()
         if out is None:
             out = np.zeros(shape=dims)
-        elif out.size != dims:
-            raise ValueError(
-                f"out size = {out.size} != unconstrained params size = {dims}"
-            )
+
         chars = theta_json.encode("UTF-8")
         err = ctypes.c_char_p()
         rc = self._param_unconstrain_json(self.model, chars, out, ctypes.byref(err))
@@ -552,10 +588,10 @@ def log_density_gradient(
         dims = self.param_unc_num()
         if out is None:
             out = np.zeros(shape=dims)
-        elif out.size != dims:
-            raise ValueError(f"out size = {out.size} != params size = {dims}")
+
         lp = ctypes.c_double()
         err = ctypes.c_char_p()
+
         rc = self._log_density_gradient(
             self.model,
             int(propto),
@@ -606,17 +642,13 @@ def log_density_hessian(
         dims = self.param_unc_num()
         if out_grad is None:
             out_grad = np.zeros(shape=dims)
-        elif out_grad.shape != (dims,):
-            raise ValueError(f"out_grad size = {out_grad.size} != params size = {dims}")
-        hess_size = dims * dims
+
         if out_hess is None:
-            out_hess = np.zeros(shape=hess_size)
-        elif out_hess.shape != (dims, dims):
-            raise ValueError(
-                f"out_hess size = {out_hess.size} != params size^2 = {hess_size}"
-            )
+            out_hess = np.zeros(shape=(dims, dims))
+
         lp = ctypes.c_double()
         err = ctypes.c_char_p()
+
         rc = self._log_density_hessian(
             self.model,
             int(propto),
@@ -661,10 +693,9 @@ def log_density_hessian_vector_product(
         dims = self.param_unc_num()
         if out is None:
             out = np.zeros(shape=dims)
-        elif out.size != dims:
-            raise ValueError(f"out size = {out.size} != params size = {dims}")
         lp = ctypes.c_double()
         err = ctypes.c_char_p()
+
         rc = self._log_density_hvp(
             self.model,
             int(propto),
diff --git a/python/test/test_stanmodel.py b/python/test/test_stanmodel.py
index 0f719223..ced5a4bb 100644
--- a/python/test/test_stanmodel.py
+++ b/python/test/test_stanmodel.py
@@ -1,3 +1,4 @@
+import ctypes
 from pathlib import Path
 
 import numpy as np
@@ -274,7 +275,7 @@ def test_param_unconstrain():
     c2 = bridge.param_unconstrain(b, out=scratch)
     np.testing.assert_allclose(a, c2)
     scratch_wrong = np.zeros(16)
-    with pytest.raises(ValueError):
+    with pytest.raises(ctypes.ArgumentError):
         bridge.param_unconstrain(b, out=scratch_wrong)
 
 
@@ -294,7 +295,7 @@ def test_param_unconstrain_json():
     np.testing.assert_allclose(theta_unc, theta_unc_j_test2)
 
     scratch_bad = np.zeros(10)
-    with pytest.raises(ValueError):
+    with pytest.raises(ctypes.ArgumentError):
         bridge.param_unconstrain_json(theta_json, out=scratch_bad)
 
 
@@ -400,7 +401,7 @@ def _grad_jacobian_true(y_unc):
     np.testing.assert_allclose(_grad_logp(y_unc) + _grad_jacobian_true(y_unc), grad[0])
     #
     scratch_bad = np.zeros(bridge.param_unc_num() + 10)
-    with pytest.raises(ValueError):
+    with pytest.raises(ctypes.ArgumentError):
         bridge.log_density_gradient(y_unc, out=scratch_bad)
 
 
@@ -506,7 +507,7 @@ def _hess_jacobian_true(y_unc):
     np.testing.assert_allclose(_grad_logp(y_unc) + _grad_jacobian_true(y_unc), grad[0])
     #
     scratch_bad = np.zeros(bridge.param_unc_num() + 10)
-    with pytest.raises(ValueError):
+    with pytest.raises(ctypes.ArgumentError):
         bridge.log_density_hessian(y_unc, out_grad=scratch_bad)
 
     # test with 5 x 5 Hessian
@@ -551,9 +552,6 @@ def test_out_behavior():
     np.testing.assert_allclose(grads[0], grads[1])
 
 
-# BONUS TESTS
-
-
 def test_bernoulli():
     def _bernoulli(y, p):
         return np.sum(y * np.log(p) + (1 - y) * np.log(1 - p))
@@ -717,6 +715,39 @@ def test_reload_warning():
         model2 = bs.StanModel(relative_lib, data)
 
 
+def test_ctypes_pointers():
+    lib = STAN_FOLDER / "simple" / "simple_model.so"
+    data = STAN_FOLDER / "simple" / "simple.data.json"
+    model = bs.StanModel(lib, data)
+
+    N = 5
+    ParamType = ctypes.c_double * N
+    params = ParamType(*list(range(N)))
+
+    lp = model.log_density(
+        ctypes.cast(params, ctypes.POINTER(ctypes.c_double)), propto=False
+    )  # basic input
+    assert lp == -15.0
+    lp = model.log_density(params, propto=False)
+    assert lp == -15.0
+
+    grad_out = ParamType()
+    lp, _ = model.log_density_gradient(params, out=grad_out)
+    for i in range(N):
+        assert grad_out[i] == -1.0 * i
+    grad_out2 = ParamType()
+    lp, _ = model.log_density_gradient(
+        params, out=ctypes.cast(grad_out2, ctypes.POINTER(ctypes.c_double))
+    )
+    for i in range(N):
+        assert grad_out2[i] == -1.0 * i
+
+    # test bad type
+    with pytest.raises(ctypes.ArgumentError):
+        params = (ctypes.c_int * N)(*list(range(N)))
+        model.log_density(params)
+
+
 @pytest.fixture(scope="module")
 def recompile_simple():
     """Recompile simple_model with autodiff hessian enable, then clean-up/restore it after test"""
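
Illustration (not part of the patch): a minimal standalone sketch of the dispatch the array_ptr helper above introduces. It repeats the helper from the diff so it can be run with only numpy and ctypes, without a compiled Stan model. from_param is the hook ctypes invokes on each argument listed in argtypes; when it rejects an argument during an actual foreign-function call, ctypes re-raises the failure as ctypes.ArgumentError, which is why the tests above now expect that instead of ValueError.

import ctypes

import numpy as np
from numpy.ctypeslib import ndpointer


def array_ptr(*args, **kwargs):
    # Same helper as in the patch: build an ndpointer class, then override
    # from_param so ctypes arrays and pointers bypass the numpy checks.
    np_type = ndpointer(*args, **kwargs)
    base = np.ctypeslib.as_ctypes_type(np_type._dtype_)
    ctypes_type = ctypes.POINTER(base)

    def from_param(cls, obj):
        if isinstance(obj, (ctypes_type, ctypes.Array)):
            return ctypes_type.from_param(obj)
        return np_type.from_param(obj)

    return type(np_type.__name__, (np_type,), {"from_param": classmethod(from_param)})


double_array = array_ptr(dtype=ctypes.c_double, flags=("C_CONTIGUOUS"))

double_array.from_param(np.zeros(3))                      # numpy array: dtype/flags checked
double_array.from_param((ctypes.c_double * 3)(1.0, 2.0))  # ctypes array: passed through
double_array.from_param(
    ctypes.cast((ctypes.c_double * 3)(), ctypes.POINTER(ctypes.c_double))
)                                                         # bare pointer: passed through

try:
    double_array.from_param((ctypes.c_int * 3)())         # wrong element type is rejected
except TypeError:
    # Inside a real foreign-function call this surfaces as ctypes.ArgumentError.
    pass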