diff --git a/Project.toml b/Project.toml
index a62a301c..ad3ab107 100644
--- a/Project.toml
+++ b/Project.toml
@@ -5,6 +5,7 @@ uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 CUDAapi = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
 CUDAdrv = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
+Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
 LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
 Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
@@ -20,3 +21,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
 test = ["Test", "BenchmarkTools", "SpecialFunctions"]
+
+[compat]
+julia = ">= 1.1"
diff --git a/REQUIRE b/REQUIRE
index b900bcb2..fe245a10 100644
--- a/REQUIRE
+++ b/REQUIRE
@@ -1,4 +1,4 @@
-julia 1.0
+julia 1.1
 CUDAdrv 1.0
 LLVM 0.9.14
 CUDAapi 0.4.0
diff --git a/src/CUDAnative.jl b/src/CUDAnative.jl
index b676f59b..43b0f22d 100644
--- a/src/CUDAnative.jl
+++ b/src/CUDAnative.jl
@@ -23,6 +23,7 @@ end
 include("utils.jl")
 
 # needs to be loaded _before_ the compiler infrastructure, because of generated functions
+isdevice() = false
 include(joinpath("device", "tools.jl"))
 include(joinpath("device", "pointer.jl"))
 include(joinpath("device", "array.jl"))
@@ -31,6 +32,7 @@ include(joinpath("device", "cuda_intrinsics.jl"))
 include(joinpath("device", "runtime_intrinsics.jl"))
 
 include("compiler.jl")
+include("context.jl")
 include("execution.jl")
 include("reflection.jl")
 
diff --git a/src/context.jl b/src/context.jl
new file mode 100644
index 00000000..5ac838a4
--- /dev/null
+++ b/src/context.jl
@@ -0,0 +1,50 @@
+##
+# Implements contextual dispatch through Cassette.jl
+# Goals:
+# - Rewrite common CPU functions to appropriate GPU intrinsics
+#
+# TODO:
+# - error (erf, ...)
+# - pow
+# - min, max
+# - mod, rem
+# - gamma
+# - bessel
+# - distributions
+# - unsorted
+
+using Cassette
+
+function transform(ctx, ref)
+    ci = ref.code_info
+    ci.inlineable = true
+    return ci
+end
+const InlinePass = Cassette.@pass transform
+
+Cassette.@context CUDACtx
+const cudactx = Cassette.disablehooks(CUDACtx(pass = InlinePass))
+
+function Cassette.overdub(ctx::CUDACtx, ::typeof(isdevice))
+    return true
+end
+
+# libdevice.jl
+for f in (:cos, :cospi, :sin, :sinpi, :tan,
+          :acos, :asin, :atan,
+          :cosh, :sinh, :tanh,
+          :acosh, :asinh, :atanh,
+          :log, :log10, :log1p, :log2,
+          :exp, :exp2, :exp10, :expm1, :ldexp,
+          :isfinite, :isinf, :isnan,
+          :signbit, :abs,
+          :sqrt, :cbrt,
+          :ceil, :floor,)
+    @eval function Cassette.overdub(ctx::CUDACtx, ::typeof(Base.$f), x::Union{Float32, Float64})
+        @Base._inline_meta
+        return CUDAnative.$f(x)
+    end
+end
+
+contextualize(f::F) where F = (args...) -> Cassette.overdub(cudactx, f, args...)
+
diff --git a/src/execution.jl b/src/execution.jl
index 4977d245..96f9dc90 100644
--- a/src/execution.jl
+++ b/src/execution.jl
@@ -172,7 +172,8 @@ kernel to determine the launch configuration:
     GC.@preserve args begin
         kernel_args = cudaconvert.(args)
         kernel_tt = Tuple{Core.Typeof.(kernel_args)...}
-        kernel = cufunction(f, kernel_tt; compilation_kwargs)
+        kernel_f = contextualize(f)
+        kernel = cufunction(kernel_f, kernel_tt; compilation_kwargs)
         kernel(kernel_args...; launch_kwargs)
     end
 """
@@ -202,7 +203,8 @@ macro cuda(ex...)
         GC.@preserve $(vars...) begin
             local kernel_args = cudaconvert.(($(var_exprs...),))
             local kernel_tt = Tuple{Core.Typeof.(kernel_args)...}
-            local kernel = cufunction($(esc(f)), kernel_tt; $(map(esc, compiler_kwargs)...))
+            local kernel_f = contextualize($(esc(f)))
+            local kernel = cufunction(kernel_f, kernel_tt; $(map(esc, compiler_kwargs)...))
             kernel(kernel_args...; $(map(esc, call_kwargs)...))
         end
     end)
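Usage sketch (reviewer note, not part of the patch): with this change applied, @cuda overdubs the kernel through CUDACtx before compilation, so plain Base math inside device code is rewritten to the corresponding CUDAnative intrinsics. In the snippet below, isdevice and contextualize are internal (unexported) names introduced by this patch, and d_a is a hypothetical 32-element Float32 device array (e.g. from CuArrays.jl):

using CUDAnative

# Host-side check of the contextual dispatch: the plain call hits the
# isdevice() = false fallback, while the overdubbed call hits the
# Cassette.overdub method defined in src/context.jl.
CUDAnative.isdevice()                              # false
CUDAnative.contextualize(CUDAnative.isdevice)()    # true

# In a kernel, Base.sin is rewritten to CUDAnative.sin by the pass, so
# no explicit intrinsic call is needed.
function kernel(a)
    i = threadIdx().x
    @inbounds a[i] = sin(a[i])
    return nothing
end

@cuda threads=32 kernel(d_a)    # d_a: hypothetical Float32 device array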