Skip to content

Commit

Permalink
Merge pull request #2196 from Shaikh-Ubaid/fix_ccall_for_cpython
Browse files Browse the repository at this point in the history
Support ccall() for symengine and other libs
  • Loading branch information
certik authored Jul 20, 2023
2 parents d559090 + f6eab9c commit d41c25f
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 13 deletions.
9 changes: 5 additions & 4 deletions integration_tests/symbolics_07.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from lpython import ccall
from lpython import ccall, CPtr
import os

@ccall(header="symengine/cwrapper.h")
@ccall(header="symengine/cwrapper.h", c_shared_lib="symengine", c_shared_lib_path=f"{os.environ['CONDA_PREFIX']}/lib")
def basic_new_heap() -> CPtr:
pass

@ccall(header="symengine/cwrapper.h")
@ccall(header="symengine/cwrapper.h", c_shared_lib="symengine", c_shared_lib_path=f"{os.environ['CONDA_PREFIX']}/lib")
def basic_const_pi(x: CPtr) -> None:
pass

@ccall(header="symengine/cwrapper.h")
@ccall(header="symengine/cwrapper.h", c_shared_lib="symengine", c_shared_lib_path=f"{os.environ['CONDA_PREFIX']}/lib")
def basic_str(x: CPtr) -> str:
pass

Expand Down
28 changes: 19 additions & 9 deletions src/runtime/lpython/lpython.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ class CTypes:
A wrapper class for interfacing C via ctypes.
"""

def __init__(self, f):
def __init__(self, f, py_mod = None, py_mod_path = None):
def get_rtlib_dir():
current_dir = os.path.dirname(os.path.abspath(__file__))
return os.path.join(current_dir, "..")
Expand All @@ -349,17 +349,20 @@ def get_lib_name(name):
else:
raise NotImplementedError("Platform not implemented")
def get_crtlib_path():
py_mod = os.environ.get("LPYTHON_PY_MOD_NAME", "")
nonlocal py_mod, py_mod_path
if py_mod is None:
py_mod = os.environ.get("LPYTHON_PY_MOD_NAME", "")
if py_mod == "":
return os.path.join(get_rtlib_dir(),
get_lib_name("lpython_runtime"))
else:
py_mod_path = os.environ["LPYTHON_PY_MOD_PATH"]
if py_mod_path is None:
py_mod_path = os.environ["LPYTHON_PY_MOD_PATH"]
return os.path.join(py_mod_path, get_lib_name(py_mod))
self.name = f.__name__
self.args = f.__code__.co_varnames
self.annotations = f.__annotations__
if "LPYTHON_PY_MOD_NAME" in os.environ:
if ("LPYTHON_PY_MOD_NAME" in os.environ) or (py_mod is not None):
crtlib = get_crtlib_path()
self.library = ctypes.CDLL(crtlib)
self.cf = self.library[self.name]
Expand Down Expand Up @@ -388,7 +391,10 @@ def __call__(self, *args, **kwargs):
new_args.append(arg.ctypes.data_as(ctypes.POINTER(convert_numpy_dtype_to_ctype(arg.dtype))))
else:
new_args.append(arg)
return self.cf(*new_args)
res = self.cf(*new_args)
if self.cf.restype == ctypes.c_char_p:
res = res.decode("utf-8")
return res

def convert_to_ctypes_Union(f):
fields = []
Expand Down Expand Up @@ -465,10 +471,14 @@ def __init__(self, *args):

return ctypes_Structure

def ccall(f):
if isclass(f) and issubclass(f, Union):
return f
return CTypes(f)
def ccall(f=None, header=None, c_shared_lib=None, c_shared_lib_path=None):
def wrap(func):
if not isclass(func) or not issubclass(func, Union):
func = CTypes(func, c_shared_lib, c_shared_lib_path)
return func
if f:
return wrap(f)
return wrap

def pythoncall(*args, **kwargs):
def inner(fn):
Expand Down

0 comments on commit d41c25f

Please sign in to comment.