Skip to content

Commit

Permalink
[Python] Enhance Func1 API
Browse files Browse the repository at this point in the history
  • Loading branch information
ischoegl committed Jun 30, 2023
1 parent 74d5cdb commit e11a6e2
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 38 deletions.
18 changes: 12 additions & 6 deletions interfaces/cython/cantera/func1.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@ from .ctcxx cimport *
cdef extern from "cantera/numerics/Func1.h":
cdef cppclass CxxFunc1 "Cantera::Func1":
double eval(double) except +translate_exception
string type()

cdef cppclass CxxTabulated1 "Cantera::Tabulated1" (CxxFunc1):
CxxTabulated1(int, double*, double*, string) except +translate_exception
double eval(double) except +translate_exception

cdef extern from "cantera/cython/funcWrapper.h":
ctypedef double (*callback_wrapper)(double, void*, void**) except? 0.0
Expand All @@ -31,12 +29,20 @@ cdef extern from "cantera/cython/funcWrapper.h":
void setExceptionValue(PyObject*)


cdef extern from "cantera/numerics/Func1Factory.h":
cdef shared_ptr[CxxFunc1] CxxNewFunc1 "Cantera::newFunc1" (
string, double) except +translate_exception
cdef shared_ptr[CxxFunc1] CxxNewFunc1 "Cantera::newFunc1" (
string, vector[double]&) except +translate_exception
cdef shared_ptr[CxxFunc1] CxxNewFunc1 "Cantera::newFunc1" (
string, shared_ptr[CxxFunc1], shared_ptr[CxxFunc1]) except +translate_exception
cdef shared_ptr[CxxFunc1] CxxNewFunc1 "Cantera::newFunc1" (
string, shared_ptr[CxxFunc1], double) except +translate_exception


cdef class Func1:
cdef shared_ptr[CxxFunc1] _func
cdef CxxFunc1* func
cdef object callable
cdef object exception
cpdef void _set_callback(self, object) except *

cdef class TabulatedFunction(Func1):
cpdef void _set_tables(self, object, object, string) except *
112 changes: 80 additions & 32 deletions interfaces/cython/cantera/func1.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -65,34 +65,92 @@ cdef class Func1:
self.exception = None
self.callable = None

def __init__(self, c):
def __init__(self, c, init=True):
if init is False:
# used by 'create' classmethod
return
if hasattr(c, '__call__'):
# callback function
self._set_callback(c)
else:
return

cdef Func1 func
try:
arr = np.array(c)
try:
if arr.ndim == 0:
# handle constants or unsized numpy arrays
k = float(c)
self._set_callback(lambda t: k)
elif arr.size == 1:
# handle lists, tuples or numpy arrays with a single element
k = float(c[0])
self._set_callback(lambda t: k)
else:
raise TypeError

except TypeError:
raise TypeError(
"'Func1' objects must be constructed from a number or "
"a callable object") from None
if arr.ndim == 0:
# handle constants or unsized numpy arrays
k = float(c)
elif arr.size == 1:
# handle lists, tuples or numpy arrays with a single element
k = float(c[0])
else:
raise TypeError
func = Func1.create("constant", k)
self._func = func._func
self.func = self._func.get()

except TypeError:
raise TypeError(
"'Func1' objects must be constructed from a number or "
"a callable object") from None

cpdef void _set_callback(self, c) except *:
self.callable = c
self._func.reset(new CxxFunc1Py(func_callback, <void*>self))
self.func = self._func.get()

@property
def type(self):
"""
Return the type of the underlying C++ functor object.
.. versionadded:: 3.0
"""
return pystr(self.func.type())

@classmethod
def create(cls, functor_type, *args):
"""
Create new C++ `Func1` functor (advanced feature).
For supported functor types, see the Cantera C++ documentation.
.. versionadded:: 3.0
"""
cdef Func1 out = cls(None, False)
cdef Func1 f0
cdef Func1 f1
cdef string cxx_string = stringify(functor_type)
cdef vector[double] arr
if len(args) == 0:
# simple functor with no parameter
out._func = CxxNewFunc1(cxx_string, 1.)
elif len(args) == 1:
if hasattr(args[0], "__len__"):
# advanced functor with array and no parameter
for v in args[0]:
arr.push_back(v)
out._func = CxxNewFunc1(cxx_string, arr)
else:
# simple functor with scalar parameter
out._func = CxxNewFunc1(cxx_string, float(args[0]))
elif len(args) == 2:
if isinstance(args[0], Func1) and isinstance(args[1], Func1):
# compound functor
f0 = args[0]
f1 = args[1]
out._func = CxxNewFunc1(cxx_string, f0._func, f1._func)
elif isinstance(args[0], Func1):
# modified functor
f0 = args[0]
out._func = CxxNewFunc1(cxx_string, f0._func, float(args[1]))
else:
raise ValueError("Invalid arguments")
else:
raise ValueError("Invalid arguments")
out.func = out._func.get()
return out

def __call__(self, t):
return self.func.eval(t)

Expand Down Expand Up @@ -134,17 +192,7 @@ cdef class TabulatedFunction(Func1):
"""

def __init__(self, time, fval, method='linear'):
self._set_tables(time, fval, stringify(method))

cpdef void _set_tables(self, time, fval, string method) except *:
tt = np.asarray(time, dtype=np.double)
ff = np.asarray(fval, dtype=np.double)
if tt.size != ff.size:
raise ValueError("Sizes of arrays do not match "
"({} vs {})".format(tt.size, ff.size))
elif tt.size == 0:
raise ValueError("Arrays must not be empty.")
cdef np.ndarray[np.double_t, ndim=1] tvec = tt
cdef np.ndarray[np.double_t, ndim=1] fvec = ff
self.func = <CxxFunc1*>(new CxxTabulated1(tt.size, &tvec[0], &fvec[0],
method))
arr = np.hstack([np.array(time), np.array(fval)])
cdef Func1 func = Func1.create(f"tabulated-{method}", arr)
self._func = func._func
self.func = self._func.get()

0 comments on commit e11a6e2

Please sign in to comment.