@@ -225,7 +225,7 @@ struct LLVMFunc{F,tt}
225225 entry:: String
226226end
227227
228- function Base. getproperty (f:: LLVMFunc{F, tt} , sym:: Symbol ) where {F, tt}
228+ function Base. getproperty (f:: LLVMFunc{F,tt} , sym:: Symbol ) where {F,tt}
229229 if sym === :fun
230230 f
231231 else
235235
236236# TODO in the future we may want to avoid doing a second cufunction compilation
237237# for computing the thread/block count (or potentially do it ourselves).
238- @noinline function CUDA. launch_configuration (f:: LLVMFunc{F, tt} ; shmem:: Union{Integer, Base.Callable} = 0 , max_threads:: Integer = 0 ) where {F, tt}
239- CUDA. launch_configuration (Base. inferencebarrier (CUDA. cufunction)(f. f, Tuple{tt. parameters[2 : end ]. .. }). fun; shmem, max_threads)
238+ @noinline function CUDA. launch_configuration (
239+ f:: LLVMFunc{F,tt} ; shmem:: Union{Integer,Base.Callable} = 0 , max_threads:: Integer = 0
240+ ) where {F,tt}
241+ return CUDA. launch_configuration (
242+ Base. inferencebarrier (CUDA. cufunction)(f. f, Tuple{tt. parameters[2 : end ]. .. }). fun;
243+ shmem,
244+ max_threads,
245+ )
240246end
241247
242248const GPUCompiler = CUDA. GPUCompiler
@@ -282,7 +288,12 @@ function compile(job)
282288 entry = GPUCompiler. JuliaContext () do ctx
283289 mod, meta = GPUCompiler. compile (
284290 # :llvm, job; optimize=false, cleanup=false, validate=false, libraries=true
285- :llvm , job; optimize= false , cleanup= false , validate= false , libraries= false
291+ :llvm ,
292+ job;
293+ optimize= false ,
294+ cleanup= false ,
295+ validate= false ,
296+ libraries= false ,
286297 # :llvm, job; optimize=false, cleanup=false, validate=true, libraries=false
287298 # :llvm, job; optimize=false, cleanup=false, validate=false, libraries=false
288299 )
@@ -357,19 +368,21 @@ function link(job, compiled)
357368end
358369
359370function to_bytes (x)
360- sz = sizeof (x)
361- ref = Ref (x)
362- GC. @preserve ref begin
363- ptr = Base. reinterpret (Ptr{UInt8}, Base. unsafe_convert (Ptr{Cvoid}, ref))
364- vec = Vector {UInt8} (undef, sz)
365- for i in 1 : sz
366- @inbounds vec[i] = Base. unsafe_load (ptr, i)
367- end
368- vec
369- end
370- end
371-
372- function Reactant. make_tracer (seen, @nospecialize (prev:: CuTracedArray ), @nospecialize (path), mode; kwargs... )
371+ sz = sizeof (x)
372+ ref = Ref (x)
373+ GC. @preserve ref begin
374+ ptr = Base. reinterpret (Ptr{UInt8}, Base. unsafe_convert (Ptr{Cvoid}, ref))
375+ vec = Vector {UInt8} (undef, sz)
376+ for i in 1 : sz
377+ @inbounds vec[i] = Base. unsafe_load (ptr, i)
378+ end
379+ vec
380+ end
381+ end
382+
383+ function Reactant. make_tracer (
384+ seen, @nospecialize (prev:: CuTracedArray ), @nospecialize (path), mode; kwargs...
385+ )
373386 x = Base. unsafe_pointer_to_objref (Base. reinterpret (Ptr{Cvoid}, prev. ptr)):: TracedRArray
374387 Reactant. make_tracer (seen, x, path, mode; kwargs... )
375388 return prev
@@ -388,7 +401,9 @@ function get_field_offset(T::Type, path)
388401 findfirst (== (field), fieldnames (current_type))
389402 end
390403 if field_idx === nothing
391- error (" Field $field not found in type $current_type , fieldnames=$(fieldnames (current_type)) T=$T path=$path " )
404+ error (
405+ " Field $field not found in type $current_type , fieldnames=$(fieldnames (current_type)) T=$T path=$path " ,
406+ )
392407 end
393408
394409 # Add the offset of this field
@@ -419,7 +434,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
419434 rarrays = TracedRArray[]
420435
421436 fname = func. entry
422-
437+
423438 wrapper_tys = MLIR. IR. Type[]
424439 ctx = MLIR. IR. context ()
425440 cullvm_ty = MLIR. IR. Type (MLIR. API. mlirLLVMPointerTypeGet (ctx, 1 ))
@@ -436,19 +451,23 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
436451 end
437452 push! (wrapper_tys, cullvm_ty)
438453 end
439-
454+
440455 sym_name = String (gensym (" call_$fname " ))
441456 mod = MLIR. IR. mmodule ()
442- CConv= MLIR. IR. Attribute (MLIR. API. mlirLLVMCConvAttrGet (ctx, MLIR. API. MlirLLVMCConvPTX_Kernel))
457+ CConv = MLIR. IR. Attribute (
458+ MLIR. API. mlirLLVMCConvAttrGet (ctx, MLIR. API. MlirLLVMCConvPTX_Kernel)
459+ )
443460 voidty = MLIR. IR. Type (MLIR. API. mlirLLVMVoidTypeGet (ctx))
444- wrapftype = MLIR. IR. Type (MLIR. API. mlirLLVMFunctionTypeGet (voidty, length (wrapper_tys), wrapper_tys, false ))
461+ wrapftype = MLIR. IR. Type (
462+ MLIR. API. mlirLLVMFunctionTypeGet (voidty, length (wrapper_tys), wrapper_tys, false )
463+ )
445464 wrapfunc = MLIR. IR. block! (MLIR. IR. body (mod)) do
446465 return MLIR. Dialects. llvm. func (;
447466 sym_name,
448467 sym_visibility= MLIR. IR. Attribute (" private" ),
449468 function_type= wrapftype,
450469 body= MLIR. IR. Region (),
451- CConv
470+ CConv,
452471 )
453472 end
454473 wrapbody = MLIR. IR. Block (wrapper_tys, [MLIR. IR. Location () for _ in wrapper_tys])
@@ -459,11 +478,17 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
459478
460479 symtab = MLIR. IR. SymbolTable (MLIR. IR. Operation (mod))
461480 gpufunc = MLIR. IR. lookup (symtab, fname)
462- MLIR. IR. attr! (gpufunc, " CConv" , MLIR. IR. Attribute (MLIR. API. mlirLLVMCConvAttrGet (ctx, MLIR. API. MlirLLVMCConvC)))
463- gpu_function_type = MLIR. IR. Type (Reactant. TracedUtils. get_attribute_by_name (gpufunc, " function_type" ))
481+ MLIR. IR. attr! (
482+ gpufunc,
483+ " CConv" ,
484+ MLIR. IR. Attribute (MLIR. API. mlirLLVMCConvAttrGet (ctx, MLIR. API. MlirLLVMCConvC)),
485+ )
486+ gpu_function_type = MLIR. IR. Type (
487+ Reactant. TracedUtils. get_attribute_by_name (gpufunc, " function_type" )
488+ )
464489
465490 trueidx = 1
466- allocs = Union{Tuple{MLIR. IR. Value, MLIR. IR. Type}, Nothing}[]
491+ allocs = Union{Tuple{MLIR. IR. Value,MLIR. IR. Type},Nothing}[]
467492
468493 llvmptr = MLIR. IR. Type (MLIR. API. mlirLLVMPointerTypeGet (ctx, 0 ))
469494 i8 = MLIR. IR. Type (UInt8)
@@ -476,18 +501,34 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
476501
477502 # TODO check for only integer and explicitly non cutraced types
478503 MLIR. IR. block! (wrapbody) do
479- argty = MLIR. IR. Type (MLIR. API. mlirLLVMFunctionTypeGetInput (gpu_function_type, trueidx- 1 ))
504+ argty = MLIR. IR. Type (
505+ MLIR. API. mlirLLVMFunctionTypeGetInput (gpu_function_type, trueidx - 1 )
506+ )
480507 trueidx += 1
481- c1 = MLIR. IR. result (MLIR. Dialects. llvm. mlir_constant (; res= MLIR. IR. Type (Int64), value= MLIR. IR. Attribute (1 )), 1 )
482- alloc = MLIR. IR. result (MLIR. Dialects. llvm. alloca (c1; elem_type= MLIR. IR. Attribute (argty), res= llvmptr), 1 )
508+ c1 = MLIR. IR. result (
509+ MLIR. Dialects. llvm. mlir_constant (;
510+ res= MLIR. IR. Type (Int64), value= MLIR. IR. Attribute (1 )
511+ ),
512+ 1 ,
513+ )
514+ alloc = MLIR. IR. result (
515+ MLIR. Dialects. llvm. alloca (
516+ c1; elem_type= MLIR. IR. Attribute (argty), res= llvmptr
517+ ),
518+ 1 ,
519+ )
483520 push! (allocs, (alloc, argty))
484521
485522 sz = sizeof (a)
486523 array_ty = MLIR. IR. Type (MLIR. API. mlirLLVMArrayTypeGet (MLIR. IR. Type (Int8), sz))
487- cdata = MLIR. IR. result (MLIR. Dialects. llvm. mlir_constant (; res= array_ty, value= MLIR. IR. DenseElementsAttribute (to_bytes (a))), 1 )
524+ cdata = MLIR. IR. result (
525+ MLIR. Dialects. llvm. mlir_constant (;
526+ res= array_ty, value= MLIR. IR. DenseElementsAttribute (to_bytes (a))
527+ ),
528+ 1 ,
529+ )
488530 MLIR. Dialects. llvm. store (cdata, alloc)
489531 end
490-
491532 end
492533
493534 argidx = 1
@@ -499,21 +540,30 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
499540 if p[1 ] != = kernelargsym
500541 continue
501542 end
502-
543+
503544 arg = arg. mlir_data
504545 arg = Reactant. TracedUtils. transpose_val (arg)
505546 push! (restys, MLIR. IR. type (arg))
506547 push! (mlir_args, arg)
507-
548+
508549 # Get the allocation corresponding to which arg we're doing
509550 alloc = allocs[p[2 ]][1 ]
510551
511552 # we need to now compute the offset in bytes of the path
512553 julia_arg = allargs[p[2 ]]
513-
554+
514555 offset = get_field_offset (typeof (julia_arg), p[3 : end ])
515556 MLIR. IR. block! (wrapbody) do
516- ptr = MLIR. IR. result (MLIR. Dialects. llvm. getelementptr (alloc, MLIR. IR. Value[], res= llvmptr, elem_type= i8, rawConstantIndices= MLIR. IR. Attribute ([Int32 (offset)])), 1 )
557+ ptr = MLIR. IR. result (
558+ MLIR. Dialects. llvm. getelementptr (
559+ alloc,
560+ MLIR. IR. Value[];
561+ res= llvmptr,
562+ elem_type= i8,
563+ rawConstantIndices= MLIR. IR. Attribute ([Int32 (offset)]),
564+ ),
565+ 1 ,
566+ )
517567 MLIR. Dialects. llvm. store (MLIR. IR. argument (wrapbody, argidx), ptr)
518568 end
519569
@@ -530,11 +580,11 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
530580 ),
531581 ),
532582 )
533-
583+
534584 argidx += 1
535585 end
536586 end
537-
587+
538588 MLIR. IR. block! (wrapbody) do
539589 for arg in allocs
540590 if arg === nothing
@@ -544,7 +594,12 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
544594 argres = MLIR. IR. result (MLIR. Dialects. llvm. load (alloc; res= argty), 1 )
545595 push! (wrapargs, argres)
546596 end
547- MLIR. Dialects. llvm. call (wrapargs, MLIR. IR. Value[]; callee= MLIR. IR. FlatSymbolRefAttribute (Base. String (fname)), op_bundle_sizes= MLIR. IR. Attribute (Int32[]))
597+ MLIR. Dialects. llvm. call (
598+ wrapargs,
599+ MLIR. IR. Value[];
600+ callee= MLIR. IR. FlatSymbolRefAttribute (Base. String (fname)),
601+ op_bundle_sizes= MLIR. IR. Attribute (Int32[]),
602+ )
548603 MLIR. Dialects. llvm. return_ (nothing )
549604 end
550605
@@ -565,7 +620,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
565620 mlir_args;
566621 result_0= restys,
567622 fn= MLIR. IR. FlatSymbolRefAttribute (sym_name),
568- output_operand_aliases= MLIR. IR. Attribute (output_operand_aliases)
623+ output_operand_aliases= MLIR. IR. Attribute (output_operand_aliases),
569624 )
570625
571626 argidx = 1
@@ -574,7 +629,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
574629 continue
575630 end
576631 arg. mlir_data = Reactant. TracedUtils. transpose_val (MLIR. IR. result (call, argidx))
577- argidx+= 1
632+ argidx += 1
578633 end
579634end
580635
0 commit comments