Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New features #155

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions scikits/odes/sundials/cvode.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,19 @@ cdef class CV_RootFunction:

cdef class CV_WrapRootFunction(CV_RootFunction):
cdef object _rootfn
cdef int with_userdata
cdef public int with_userdata
cpdef set_rootfn(self, object rootfn)

cdef class CV_JacRhsFunction:
cpdef int evaluate(self, DTYPE_t t,
np.ndarray[DTYPE_t, ndim=1] y,
np.ndarray[DTYPE_t, ndim=1] fy,
np.ndarray[DTYPE_t, ndim=2] J) except? -1
np.ndarray[DTYPE_t, ndim=2] J,
object userdata = *) except? -1

cdef class CV_WrapJacRhsFunction(CV_JacRhsFunction):
cdef public object _jacfn
cdef int with_userdata
cdef public int with_userdata
cpdef set_jacfn(self, object jacfn)

cdef class CV_PrecSetupFunction:
Expand Down Expand Up @@ -128,7 +129,7 @@ cdef class CVODE:
cdef N_Vector atol
cdef void* _cv_mem
cdef SUNContext sunctx
cdef dict options
cdef public dict options
cdef bint parallel_implementation, initialized, _old_api, _step_compute, _validate_flags
cdef CV_data aux_data

Expand Down
43 changes: 32 additions & 11 deletions scikits/odes/sundials/cvode.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,8 @@ cdef class CV_JacRhsFunction:
cpdef int evaluate(self, DTYPE_t t,
np.ndarray[DTYPE_t, ndim=1] y,
np.ndarray[DTYPE_t, ndim=1] fy,
np.ndarray[DTYPE_t, ndim=2] J) except? -1:
np.ndarray[DTYPE_t, ndim=2] J,
object userdata = None) except? -1:
"""
Returns the Jacobi matrix of the right hand side function, as
d(rhs)/d y
Expand All @@ -275,22 +276,30 @@ cdef class CV_WrapJacRhsFunction(CV_JacRhsFunction):
"""
Set some jacobian equations as a JacRhsFunction executable class.
"""
self.with_userdata = 0
self._jacfn = jacfn
nrarg = _get_num_args(jacfn)
if nrarg > 5:
#hopefully a class method, self gives 6 arg!
self.with_userdata = 1
elif nrarg == 5 and inspect.isfunction(jacfn):
self.with_userdata = 1
self._jacfn = jacfn

cpdef int evaluate(self, DTYPE_t t,
np.ndarray[DTYPE_t, ndim=1] y,
np.ndarray[DTYPE_t, ndim=1] fy,
np.ndarray J) except? -1:
np.ndarray J,
object userdata = None) except? -1:
"""
Returns the Jacobi matrix (for dense the full matrix, for band only
bands. Result has to be stored in the variable J, which is preallocated
to the corresponding size.
"""
## if self.with_userdata == 1:
## self._jacfn(t, y, ydot, cj, J, userdata)
## else:
## self._jacfn(t, y, ydot, cj, J)
user_flag = self._jacfn(t, y, fy, J)
if self.with_userdata == 1:
user_flag = self._jacfn(t, y, fy, J, userdata)
else:
user_flag = self._jacfn(t, y, fy, J)

if user_flag is None:
user_flag = 0
Expand Down Expand Up @@ -318,7 +327,7 @@ cdef int _jacdense(sunrealtype tt,
ff_tmp = aux_data.z_tmp
nv_s2ndarray(ff, ff_tmp)

user_flag = aux_data.jac.evaluate(tt, yy_tmp, ff_tmp, jac_tmp)
user_flag = aux_data.jac.evaluate(tt, yy_tmp, ff_tmp, jac_tmp, aux_data.user_data,)

if parallel_implementation:
raise NotImplemented
Expand Down Expand Up @@ -1068,7 +1077,7 @@ cdef class CVODE:
if not supress_supported_check:
for opt in options.keys():
if not opt in ['atol', 'rtol', 'tstop', 'rootfn', 'nr_rootfns',
'verbosity', 'one_step_compute']:
'verbosity', 'one_step_compute', 'max_step_size']:
raise ValueError("Option '%s' can''t be set runtime." % opt)

# Verbosity level
Expand Down Expand Up @@ -1172,7 +1181,7 @@ cdef class CVODE:
if ('tstop' in options) and (options['tstop'] is not None):
opts_tstop = options['tstop']
self.options['tstop'] = opts_tstop
if (not opts_tstop is None) and (opts_tstop > 0.):
if (not opts_tstop is None):
flag = CVodeSetStopTime(cv_mem, <sunrealtype> opts_tstop)
if flag == CV_ILL_INPUT:
raise ValueError('CVodeSetStopTime::Stop value is beyond '
Expand Down Expand Up @@ -1644,8 +1653,12 @@ cdef class CVODE:
else:
_test = np.empty((len(y0), len(y0)), DTYPE)
_fy_test = np.zeros(len(y0), DTYPE)
jac._jacfn(t0, y0, _fy_test, _test)
if jac.with_userdata:
jac._jacfn(t0, y0, _fy_test, _test, opts['user_data'])
else:
jac._jacfn(t0, y0, _fy_test, _test)
_test = None
_fy_test = None

#now we initialize storage which is persistent over steps
self.t_roots = []
Expand Down Expand Up @@ -1896,6 +1909,14 @@ cdef class CVODE:
self.t_tstop, self.y_tstop,
)

def rootinfo(self):
#cdef int[self.options['nr_rootfns']] rootsfound
N = self.options['nr_rootfns']
cdef np.ndarray[int, ndim=1, mode='c'] rootsfound = np.empty(N, dtype=np.int32)
#cdef int rootsfound[N]
CVodeGetRootInfo(self._cv_mem, &rootsfound[0])
return rootsfound


def step(self, DTYPE_t t, np.ndarray[DTYPE_t, ndim=1] y_retn = None):
"""
Expand Down