Skip to content

Commit

Permalink
added ccompile class
Browse files Browse the repository at this point in the history
  • Loading branch information
aerorohit committed Jun 14, 2022
1 parent 9b85b9b commit 9d24c78
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 13 deletions.
76 changes: 65 additions & 11 deletions compyle/c_backend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,27 @@
from compyle.profile import profile
from .translator import ocl_detect_type, KnownType
from .cython_generator import CythonGenerator, get_func_definition
from .cython_generator import getsourcelines
from mako.template import Template
from .ext_module import get_md5
from .cimport import Cmodule
from .transpiler import Transpiler
from . import array

import pybind11
import numpy as np


elwise_c_pybind = '''
PYBIND11_MODULE(${modname}, m) {
m.def("${modname}", [](${pyb11_args}){
return ${name}(${pyb11_call});
});
}
'''


class CBackend(CythonGenerator):
Expand Down Expand Up @@ -41,17 +62,50 @@ def ctype_to_pyb11(self, c_type):
def _get_self_type(self):
return KnownType('GLOBAL_MEM %s*' % self._class_name)


elwise_c_pybind = '''
PYBIND11_MODULE(${modname}, m) {
m.def("${modname}", [](${pyb11_args}){
return ${name}(${pyb11_call});
});
}
'''
class CCompile(CBackend):
def __init__(self, func):
super(CCompile, self).__init__()
self.func = func
self.src = "not yet generated"
self.tp = Transpiler(backend='c')
self.c_func = self._compile()

def _compile(self):
self.tp.add(self.func)
self.src = self.tp.get_code()

py_data, c_data = self.get_func_signature_pyb11(self.func)

pyb11_args = ', '.join(py_data[0][:])
pyb11_call = ', '.join(py_data[1][:])
hash_fn = get_md5(self.src)
modname = f'm_{hash_fn}'
template = Template(elwise_c_pybind)
src_bind = template.render(
name=self.func.__name__,
modname=modname,
pyb11_args=pyb11_args,
pyb11_call=pyb11_call
)
self.src += src_bind

mod = Cmodule(self.src, hash_fn, openmp=False,
extra_inc_dir=[pybind11.get_include()])
module = mod.load()
return getattr(module, modname)

def _massage_arg(self, x):
if isinstance(x, array.Array):
return x.dev
elif isinstance(x, np.ndarray):
return x
else:
return np.asarray(x)

@profile
def __call__(self, *args, **kwargs):
c_args = [self._massage_arg(x) for x in args]
self.c_func(*c_args)

elwise_c_template = '''
Expand Down
18 changes: 16 additions & 2 deletions compyle/tests/test_c_backend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest
from unittest import TestCase
from ..c_backend import CBackend
from ..c_backend import CBackend, CCompile
from ..types import annotate
import numpy as np

Expand All @@ -24,6 +24,20 @@ def test_fn(x, y, z=2, w=3.0):
self.assertListEqual(c_args, exp_c_args)
self.assertListEqual(c_call, exp_c_call)


class TestCCompile(TestCase):
def test_compile(self):
@annotate(int='n, p', intp='x, y')
def get_pow(n, p, x, y):
for i in range(n):
y[i] = x[i]**p
c_get_pow = CCompile(get_pow)
n = 5
p = 5
x = np.ones(n, dtype=np.int32) * 2
y = np.zeros(n, dtype=np.int32)
y_exp = np.ones(n, dtype=np.int32) * 32
c_get_pow(n, p, x, y)
assert(np.all(y == y_exp))

if __name__ == '__main__':
unittest.main()

0 comments on commit 9d24c78

Please sign in to comment.