-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathllvm_w.py
358 lines (312 loc) · 12.5 KB
/
llvm_w.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
# Copyright (C) 2009 Corrado Zoccolo
# This file is part of py_llvm_compile.
# py_llvm_compile is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# py_llvm_compile is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with py_llvm_compile. If not, see <http://www.gnu.org/licenses/>.
import ctypes
import llvm._core as _core # C wrappers
from llvm import *
from llvm.core import *
from llvm.ee import *
from llvm.passes import *
# One-time module state: the single LLVM module everything here compiles
# into, a per-function optimisation pipeline over it, and an execution engine.
# Guarded by '__loaded' so that re-executing this module (e.g. via reload())
# keeps the existing module, caches and engine instead of rebuilding them.
if '__loaded' not in globals():
    __llvm_module = Module.new('llvm_module')
    __fopt = FunctionPassManager.new(__llvm_module)
    __ee = ExecutionEngine.new(__llvm_module)
    # Target data layout is registered first, ahead of the passes below.
    __fopt.add(__ee.target_data)
    # Memo tables used by the helper functions in this file.
    __constant_cache, __function_cache, __ctypes_string_cache = {}, {}, {}
    for i in (PASS_SIMPLIFY_LIBCALLS, PASS_REASSOCIATE, PASS_GVN,
              PASS_SIMPLIFYCFG, PASS_TAILCALLELIM, PASS_INSTCOMBINE):
        __fopt.add(i)
    __fopt.initialize()
    __loaded = 1
def llvm_intrinsic(which, *types):
    """Return the declaration of LLVM intrinsic `which` in the shared module,
    instantiated for the LLVM types given in `types`."""
    module = __llvm_module
    return Function.intrinsic(module, which, types)
def llvm_create_global_constant(llvm_v):
    """Return a constant global variable initialised with LLVM constant `llvm_v`.

    Globals are memoised by the textual form of the constant, so equal
    constants share one global.
    """
    key = str(llvm_v)
    try:
        return __constant_cache[key]
    except KeyError:
        pass
    gv = __llvm_module.add_global_variable(llvm_v.type, 'const_$')
    gv.initializer = llvm_v
    gv.global_constant = True
    __constant_cache[key] = gv
    return gv
def llvm_add_function(sig, name):
    """Add a function named `name` with function type `sig` to the shared module."""
    new_fun = __llvm_module.add_function(sig, name)
    return new_fun
def llvm_get_function(sig, name):
    """Look up the function called `name` in the shared module.

    Returns None when the lookup fails. Raises TypeError when a function of
    that name exists but its function type differs from `sig`.
    """
    try:
        fun = __llvm_module.get_function_named(name)
    except Exception:
        # Treated as "no such function" (presumably llvm-py raises when the
        # name is absent). Catching Exception rather than using a bare except
        # lets KeyboardInterrupt/SystemExit propagate.
        return None
    if sig != llvm_type_fun(fun.type):
        raise TypeError('Wrong signature for function ' + name)
    return fun
def llvm_get_or_insert_function(sig, name):
    """Return `name` from the shared module via Module.get_or_insert_function.

    Raises TypeError if the resulting function's type differs from `sig`.
    """
    fun = __llvm_module.get_or_insert_function(sig, name)
    actual_sig = llvm_type_fun(fun.type)
    if sig != actual_sig:
        raise TypeError('Wrong signature for function ' + name)
    return fun
def llvm_function(sig, name, key, internal):
    """Return the LLVM function memoised under `key`, creating it on first use.

    Internal functions are always freshly added and given internal linkage.
    Otherwise an existing function called `name` is reused (TypeError if its
    type differs from `sig`), or a new one is added.
    """
    if key in __function_cache:
        return __function_cache[key]
    if internal:
        f = llvm_add_function(sig, name)
        f.linkage = LINKAGE_INTERNAL
    else:
        f = llvm_get_function(sig, name)
        # llvm_get_function signals "not found" with None: test for that
        # sentinel explicitly instead of relying on a Function's truthiness.
        if f is None:
            f = llvm_add_function(sig, name)
    __function_cache[key] = f
    return f
def llvm_opt_function(f):
    """Run the shared function pass pipeline over `f` and return its result."""
    outcome = __fopt.run(f)
    return outcome
def llvm_run_function(f, a):
    """Execute `f` through the shared execution engine with argument list `a`."""
    engine = __ee
    return engine.run_function(f, a)
def llvm_dump_module():
    """Print the textual IR of the shared LLVM module to stdout."""
    print(__llvm_module)
def llvm_type(py_t, py_v=None):
    """Map a Python type to an LLVM type.

    int -> i32, long -> i64, float -> double, str -> [len(py_v)+1 x i8]
    (py_v is required for str); anything else -> void. The checks use
    issubclass in this order, so e.g. bool maps like int.
    """
    if issubclass(py_t, int):
        result = Type.int(32)
    elif issubclass(py_t, long):
        result = Type.int(64)
    elif issubclass(py_t, float):
        result = Type.double()
    elif issubclass(py_t, str):
        # one extra byte for the terminating NUL
        result = Type.array(Type.int(8), len(py_v) + 1)
    else:
        result = Type.void()
    return result
def llvm_is_int(llvm_t):
    """True when `llvm_t` is an LLVM integer type of any width."""
    return llvm_t.kind == TYPE_INTEGER
def llvm_is_number(llvm_t):
    """True when `llvm_t` is an integer, float or double LLVM type."""
    kind = llvm_t.kind
    return kind == TYPE_INTEGER or kind == TYPE_FLOAT or kind == TYPE_DOUBLE
def llvm_is_array(llvm_t):
    """True when `llvm_t` is an LLVM array type."""
    return llvm_t.kind == TYPE_ARRAY
def llvm_is_pointer(llvm_t):
    """True when `llvm_t` is pointer-like: an LLVM pointer OR an array type.

    Arrays are counted as pointers here; presumably because arrays can be
    converted to element pointers (see enhanced_builder.coerce).
    """
    kind = llvm_t.kind
    return kind == TYPE_ARRAY or kind == TYPE_POINTER
def llvm_type_fun(llvm_t):
    """Return the function type denoted by `llvm_t`.

    One level of pointer is looked through, so both a function type and a
    pointer to one are accepted. Raises TypeError('Not a function') otherwise.
    """
    fun_t = llvm_t.pointee if llvm_t.kind == TYPE_POINTER else llvm_t
    if fun_t.kind == TYPE_FUNCTION:
        return fun_t
    raise TypeError('Not a function')
def llvm_type_elem(llvm_st):
    """Element type of an array, pointee of a pointer, void for anything else."""
    kind = llvm_st.kind
    if kind == TYPE_ARRAY:
        return llvm_st.element
    if kind == TYPE_POINTER:
        return llvm_st.pointee
    return Type.void()
def llvm_type_promote(llvm_t1, llvm_t2):
    """Return the common LLVM type two operand types should be converted to.

    Rules, tried in order:
      * identical types -> that type;
      * two integers -> the wider one (the first one on a tie);
      * two numbers, at least one floating -> double;
      * two arrays of equal length -> array of the promoted element type;
      * two pointer-likes (pointers/arrays) -> pointer to promoted element;
      * otherwise -> void (no common type).
    """
    a, b = llvm_t1, llvm_t2
    if a == b:
        return a
    if llvm_is_int(a) and llvm_is_int(b):
        return a if a.width >= b.width else b
    if llvm_is_number(a) and llvm_is_number(b):
        return Type.double()
    if llvm_is_array(a) and llvm_is_array(b) and a.count == b.count:
        elem = llvm_type_promote(a.element, b.element)
        return Type.array(elem, a.count)
    if llvm_is_pointer(a) and llvm_is_pointer(b):
        elem = llvm_type_promote(llvm_type_elem(a), llvm_type_elem(b))
        return Type.pointer(elem)
    return Type.void()
def llvm_type_apply(llvm_t, *llvm_ts):
    """Return the result type of calling a function of type `llvm_t`.

    Only the number of argument types in `llvm_ts` is checked, not the types
    themselves. Raises TypeError on an arity mismatch or if `llvm_t` is not a
    function (pointer) type.
    """
    fun_t = llvm_type_fun(llvm_t)
    if len(llvm_ts) == fun_t.arg_count:
        return fun_t.return_type
    raise TypeError('Wrong number of argument in function call')
def ctypes_type(llvm_t):
    """Return the ctypes type used to marshal a value of LLVM type `llvm_t`.

    integers of width <= 32 -> c_int32, wider integers -> c_int64,
    float -> c_float, double -> c_double, pointer-likes to i8 -> c_char_p,
    other pointer-likes -> c_void_p, void -> None.
    Raises TypeError for any other LLVM type.
    """
    if llvm_is_int(llvm_t):
        if llvm_t.width <= 32:
            return ctypes.c_int32
        return ctypes.c_int64
    if llvm_t.kind == TYPE_FLOAT:
        # A 32-bit float must not be passed/returned as a C double: the two
        # have different sizes and calling-convention slots.
        return ctypes.c_float
    if llvm_is_number(llvm_t):
        return ctypes.c_double
    if llvm_is_pointer(llvm_t):
        if llvm_type_elem(llvm_t) == Type.int(8):
            return ctypes.c_char_p
        return ctypes.c_void_p
    if llvm_t == Type.void():
        return None
    raise TypeError('Cannot convert llvm type ' + str(llvm_t) + ' to ctypes')
def ctypes_ptr_to_int(ctypes_ptr):
    """Return the address held by a ctypes pointer-like object as an int.

    A NULL pointer yields None (ctypes' c_void_p.value convention).
    """
    as_void = ctypes.cast(ctypes_ptr, ctypes.c_voidp)
    return as_void.value
def llvm_string_argument(llvm_t, py_v):
    """Wrap str(py_v) as a GenericValue char pointer of LLVM type `llvm_t`.

    The backing ctypes buffer is memoised per string and never evicted, so
    the pointer stays valid for the lifetime of this module (presumably
    because compiled code may hold on to it).
    """
    text = str(py_v)
    buf = __ctypes_string_cache.get(text)
    if buf is None:
        buf = ctypes.c_char_p(text)
        __ctypes_string_cache[text] = buf
    return GenericValue.pointer(llvm_t, ctypes_ptr_to_int(buf))
def llvm_arg_value(llvm_t, py_v):
    """Convert Python value `py_v` to a GenericValue of LLVM type `llvm_t`,
    suitable for llvm_run_function. Raises NotImplementedError for types
    other than integers, floats and pointers."""
    if llvm_is_int(llvm_t):
        make = GenericValue.int_signed
    elif llvm_is_number(llvm_t):
        make = GenericValue.real
    elif llvm_t == Type.pointer(Type.int(8)):
        make = llvm_string_argument
    elif llvm_is_pointer(llvm_t):
        # TODO: properly handle python iterables and ctypes here
        make = GenericValue.pointer
    else:
        raise NotImplementedError()
    return make(llvm_t, py_v)
def llvm_rt_value(llvm_t, py_v):
    """Build an LLVM Constant of type `llvm_t` holding Python value `py_v`.

    Supports integers (sign-extended), floats, and i8 arrays (built as a
    NUL-terminated string constant from `py_v`). Raises NotImplementedError
    otherwise.
    """
    if llvm_is_int(llvm_t):
        return Constant.int_signextend(llvm_t, py_v)
    if llvm_is_number(llvm_t):
        return Constant.real(llvm_t, py_v)
    is_byte_array = llvm_t.kind == TYPE_ARRAY and llvm_t.element == Type.int(8)
    if is_byte_array:
        return Constant.stringz(py_v)
    raise NotImplementedError()
def llvm_funptr(llvm_t, ptr):
    """Wrap the native function at address `ptr` as a Python callable.

    `llvm_t` is the LLVM function type; argument and return types are mapped
    with ctypes_type. When the function returns a function pointer, the
    result of each call is itself wrapped with llvm_funptr so returned
    functions are directly callable. Raises TypeError (from ctypes_type) if
    the signature cannot be expressed in ctypes.
    """
    cproto = ctypes.CFUNCTYPE(ctypes_type(llvm_t.return_type),
                              *[ctypes_type(t) for t in llvm_t.args])
    retval = cproto(ptr)
    if not llvm_is_pointer(llvm_t.return_type):
        return retval
    try:
        rt = llvm_type_fun(llvm_t.return_type)
    except TypeError:
        # Returns a data pointer, not a function pointer: no extra wrapping.
        # Only llvm_type_fun is guarded, so unrelated errors are not hidden.
        return retval

    def cast_retval(*args):
        return llvm_funptr(rt, retval(*args))
    return cast_retval
def python_value(llvm_t, llvm_v):
    """Convert GenericValue `llvm_v` of LLVM type `llvm_t` to a Python value.

    integers -> int (signed), floats -> float, i8* -> str (None for NULL),
    function pointers -> callable (see llvm_funptr).
    Raises NotImplementedError for anything else, including function
    signatures ctypes cannot express.
    """
    if llvm_is_int(llvm_t):
        return llvm_v.as_int_signed()
    if llvm_is_number(llvm_t):
        return llvm_v.as_real(llvm_t)
    if llvm_t == Type.pointer(Type.int(8)):
        p = llvm_v.as_pointer()
        return str(ctypes.string_at(p)) if p else None
    try:
        return llvm_funptr(llvm_type_fun(llvm_t), llvm_v.as_pointer())
    except TypeError:
        # llvm_type_fun: not a function type; ctypes_type (via llvm_funptr):
        # signature not representable. Both mean "unsupported" here.
        pass
    raise NotImplementedError('Cannot convert llvm type ' + str(llvm_t) +
                              ' to a python value')
class enhanced_builder(Builder):
    """An llvm-py Builder extended with type-dispatching helpers.

    The arithmetic/comparison helpers choose between the integer and the
    floating-point LLVM instruction by inspecting only the type of their
    first operand; presumably callers coerce both operands to a common type
    (see llvm_type_promote / coerce) first. `nm` is the name given to the
    emitted instruction(s).
    """

    def __init__(self, ptr):
        Builder.__init__(self, ptr)

    @staticmethod
    def new(basic_block):
        """Create an enhanced_builder positioned at the end of `basic_block`."""
        import llvm._core
        check_is_basic_block(basic_block)
        b = enhanced_builder(llvm._core.LLVMCreateBuilder())
        b.position_at_end(basic_block)
        return b

    def coerce(self, llvm_t, llvm_v, nm):
        """Emit code converting `llvm_v` to LLVM type `llvm_t`; return the result.

        * same type: returned unchanged;
        * to i1: `llvm_v != 0` (Python truthiness);
        * int -> int: sign-extend or truncate; i1 sources are zero-extended
          so that True becomes 1, as in Python;
        * float -> int: fptosi;
        * int -> float: sitofp (uitofp for i1, again so True is 1.0);
        * float -> float: fpext, or fptrunc when narrowing double -> float;
        * array -> pointer: address of the first element (constant arrays are
          first materialised as global constants).
        Non-numeric values requested as a number are returned unchanged.
        Raises TypeError for any other combination.
        """
        src_t = llvm_v.type
        if llvm_t == src_t:
            return llvm_v
        if llvm_is_int(llvm_t) and llvm_t.width == 1:
            return self.cmp('ne', llvm_v, llvm_rt_value(src_t, 0), nm)
        if llvm_is_int(llvm_t):
            if llvm_is_int(src_t):
                if src_t.width == 1:
                    # sext would turn i1 true into -1
                    return self.zext(llvm_v, llvm_t, nm)
                if src_t.width <= llvm_t.width:
                    return self.sext(llvm_v, llvm_t, nm)
                return self.trunc(llvm_v, llvm_t, nm)
            if llvm_is_number(src_t):
                return self.fptosi(llvm_v, llvm_t, nm)
            return llvm_v
        if llvm_is_number(llvm_t):
            if llvm_is_int(src_t):
                if src_t.width == 1:
                    # sitofp would turn i1 true into -1.0
                    return self.uitofp(llvm_v, llvm_t, nm)
                return self.sitofp(llvm_v, llvm_t, nm)
            if llvm_is_number(src_t):
                if llvm_t.kind == TYPE_FLOAT and src_t.kind == TYPE_DOUBLE:
                    # fpext is only valid for widening
                    return self.fptrunc(llvm_v, llvm_t, nm)
                return self.fpext(llvm_v, llvm_t, nm)
            return llvm_v
        if llvm_t.kind == TYPE_POINTER and src_t.kind == TYPE_ARRAY:
            if isinstance(llvm_v, Constant):
                llvm_v = llvm_create_global_constant(llvm_v)
            zero = llvm_rt_value(Type.int(32), 0)
            return self.gep(llvm_v, (zero, zero), nm)
        raise TypeError('Cannot convert ' + str(llvm_v.type) + ' to ' + str(llvm_t))

    def abs(self, llvm_v, nm):
        """Emit |llvm_v| for an integer or floating-point value."""
        ty = llvm_v.type
        if not llvm_is_number(ty):
            raise NotImplementedError()
        zero = llvm_rt_value(ty, 0)
        # self.sub / self.cmp dispatch to the int or float instruction.
        neg_v = self.sub(zero, llvm_v, nm + '.neg')
        is_pos = self.cmp('ge', llvm_v, zero, nm + '.pos')
        return self.select(is_pos, llvm_v, neg_v, nm)

    def add(self, llvm_v1, llvm_v2, nm):
        """Emit `add` for integers or `fadd` for floats."""
        ty = llvm_v1.type
        if llvm_is_int(ty):
            return super(enhanced_builder, self).add(llvm_v1, llvm_v2, nm)
        if llvm_is_number(ty):
            return self.fadd(llvm_v1, llvm_v2, nm)
        raise NotImplementedError()

    def sub(self, llvm_v1, llvm_v2, nm):
        """Emit `sub` for integers or `fsub` for floats."""
        ty = llvm_v1.type
        if llvm_is_int(ty):
            return super(enhanced_builder, self).sub(llvm_v1, llvm_v2, nm)
        if llvm_is_number(ty):
            return self.fsub(llvm_v1, llvm_v2, nm)
        raise NotImplementedError()

    def mul(self, llvm_v1, llvm_v2, nm):
        """Emit `mul` for integers or `fmul` for floats."""
        ty = llvm_v1.type
        if llvm_is_int(ty):
            return super(enhanced_builder, self).mul(llvm_v1, llvm_v2, nm)
        if llvm_is_number(ty):
            return self.fmul(llvm_v1, llvm_v2, nm)
        raise NotImplementedError()

    def div(self, llvm_v1, llvm_v2, nm):
        """Emit `sdiv` for integers or `fdiv` for floats.

        NOTE(review): sdiv truncates toward zero, whereas Python 2 `/` on ints
        floors; confirm the C semantics are intended.
        """
        ty = llvm_v1.type
        if llvm_is_int(ty):
            return self.sdiv(llvm_v1, llvm_v2, nm)
        if llvm_is_number(ty):
            return self.fdiv(llvm_v1, llvm_v2, nm)
        raise NotImplementedError()

    def mod(self, llvm_v1, llvm_v2, nm):
        """Emit `srem` for integers or `frem` for floats.

        These follow C semantics (the result takes the sign of the dividend),
        unlike Python's %, where it takes the sign of the divisor.
        """
        ty = llvm_v1.type
        if llvm_is_int(ty):
            return self.srem(llvm_v1, llvm_v2, nm)
        if llvm_is_number(ty):
            return self.frem(llvm_v1, llvm_v2, nm)
        raise NotImplementedError()

    def divmod(self, llvm_v1, llvm_v2, nm):
        """Emit quotient and remainder; return the pair (quot, rem).

        Same C-style semantics as div/mod (the float quotient is not floored).
        """
        ty = llvm_v1.type
        if llvm_is_int(ty):
            q = self.sdiv(llvm_v1, llvm_v2, nm + '.quot')
            r = self.srem(llvm_v1, llvm_v2, nm + '.rem')
            return (q, r)
        if llvm_is_number(ty):
            q = self.fdiv(llvm_v1, llvm_v2, nm + '.quot')
            r = self.frem(llvm_v1, llvm_v2, nm + '.rem')
            return (q, r)
        raise NotImplementedError()

    def cmp(self, op, llvm_v1, llvm_v2, nm):
        """Emit a comparison; `op` is one of 'eq', 'ge', 'gt', 'le', 'lt', 'ne'.

        Integers compare signed. Floats use ordered predicates, except 'ne',
        which is unordered so that NaN != x is true, as in Python.
        """
        if llvm_is_int(llvm_v1.type):
            return self.icmp({'eq': IPRED_EQ,
                              'ge': IPRED_SGE,
                              'gt': IPRED_SGT,
                              'le': IPRED_SLE,
                              'lt': IPRED_SLT,
                              'ne': IPRED_NE}[op], llvm_v1, llvm_v2, nm)
        if llvm_is_number(llvm_v1.type):
            return self.fcmp({'eq': RPRED_OEQ,
                              'ge': RPRED_OGE,
                              'gt': RPRED_OGT,
                              'le': RPRED_OLE,
                              'lt': RPRED_OLT,
                              'ne': RPRED_UNE}[op], llvm_v1, llvm_v2, nm)
        raise NotImplementedError()

    def pow(self, llvm_v1, llvm_v2, nm):
        """Emit a call to the llvm.pow intrinsic.

        NOTE(review): llvm.pow is a floating-point intrinsic overloaded on a
        single type; integer operands and the two-type instantiation below
        should be double-checked against the LLVM version in use.
        """
        pow_intr = llvm_intrinsic(INTR_POW, llvm_v1.type, llvm_v2.type)
        return self.call(pow_intr, (llvm_v1, llvm_v2), nm)
class llvm_compiled_fun(object):
    """Callable that runs a compiled LLVM function with keyword arguments.

    `arg_t` is a sequence of (llvm_type, argument_name) pairs in parameter
    order; `return_t` is the LLVM return type used to convert the result.
    """

    def __init__(self, llvm_fun, arg_t, return_t):
        self.llvm_fun, self.arg_t, self.return_t = llvm_fun, arg_t, return_t

    def __call__(self, **args):
        values = []
        for llvm_t, arg_name in self.arg_t:
            # a missing keyword raises KeyError here
            values.append(llvm_arg_value(llvm_t, args[arg_name]))
        result = llvm_run_function(self.llvm_fun, values)
        return python_value(self.return_t, result)
class llvm_compiled_ordered_fun(object):
    """Callable that runs a compiled LLVM function with positional arguments.

    `arg_t` is the sequence of LLVM parameter types, in order; `return_t` is
    the LLVM return type used to convert the result.
    """

    def __init__(self, llvm_fun, arg_t, return_t):
        self.llvm_fun = llvm_fun
        self.arg_t = arg_t
        self.return_t = return_t

    def __call__(self, *args):
        arg_types = list(self.arg_t)
        # zip() would silently drop extra values or leave parameters unset,
        # handing the execution engine a wrong-length argument list.
        if len(args) != len(arg_types):
            raise TypeError('Expected %d arguments, got %d'
                            % (len(arg_types), len(args)))
        arg = [llvm_arg_value(t, v) for t, v in zip(arg_types, args)]
        return python_value(self.return_t, llvm_run_function(self.llvm_fun, arg))
def llvm_compiled_funptr(llvm_fun):
    """Return `llvm_fun` as a ctypes callable (see llvm_funptr), using the
    native code address obtained from the shared execution engine."""
    fun_t = llvm_type_fun(llvm_fun.type)
    address = __ee.get_pointer_to_function(llvm_fun)
    return llvm_funptr(fun_t, address)