Skip to content
This repository has been archived by the owner on May 27, 2021. It is now read-only.

Commit

Permalink
Try #334:
Browse files Browse the repository at this point in the history
  • Loading branch information
bors[bot] committed Feb 11, 2019
2 parents e687697 + 94c4b7c commit 01d968e
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 3 deletions.
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDAapi = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
CUDAdrv = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
Expand All @@ -20,3 +21,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "BenchmarkTools", "SpecialFunctions"]

[compat]
julia = ">= 1.1"
2 changes: 1 addition & 1 deletion REQUIRE
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
julia 1.0
julia 1.1
CUDAdrv 1.1
LLVM 0.9.14
CUDAapi 0.4.0
Expand Down
2 changes: 2 additions & 0 deletions src/CUDAnative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ end
include("utils.jl")

# needs to be loaded _before_ the compiler infrastructure, because of generated functions
isdevice() = false
include(joinpath("device", "tools.jl"))
include(joinpath("device", "pointer.jl"))
include(joinpath("device", "array.jl"))
Expand All @@ -31,6 +32,7 @@ include(joinpath("device", "cuda_intrinsics.jl"))
include(joinpath("device", "runtime_intrinsics.jl"))

include("compiler.jl")
include("context.jl")
include("execution.jl")
include("reflection.jl")

Expand Down
52 changes: 52 additions & 0 deletions src/context.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
##
# Implements contextual dispatch through Cassette.jl
# Goals:
# - Rewrite common CPU functions to appropriate GPU intrinsics
#
# TODO:
# - error (erf, ...)
# - pow
# - min, max
# - mod, rem
# - gamma
# - bessel
# - distributions
# - unsorted

using Cassette

"""
    transform(ctx, ref)

Pass body used to build `InlinePass`: mark the code of the method being
overdubbed as inlineable, unless that method was explicitly tagged `:noinline`.

`ctx` is accepted but not used. `ref` must expose the lowered code through
`ref.code_info`; that object is modified in place (its `inlineable` flag may be
set to `true`) and returned.
"""
function transform(ctx, ref)
    ci = ref.code_info
    # A single `:meta` expression may carry several tags, e.g.
    # `Expr(:meta, :inline, :noinline)`, so look for `:noinline` at any
    # position instead of only inspecting the first argument. This also avoids
    # indexing into a `:meta` expression that has no arguments at all.
    noinline = any(@nospecialize(x) -> Base.Meta.isexpr(x, :meta) && :noinline in x.args,
                   ci.code)
    if !noinline
        ci.inlineable = true
    end
    return ci
end
# Cassette pass built from `transform` (defined above).
const InlinePass = Cassette.@pass transform

# Context type used to overdub code that is compiled for the device.
Cassette.@context CUDACtx
# Shared context instance: a `CUDACtx` carrying `InlinePass`, passed through
# `Cassette.disablehooks`. This is the instance `contextualize` overdubs with.
const cudactx = Cassette.disablehooks(CUDACtx(pass = InlinePass))

# Call `datatype_align` on the type directly instead of overdubbing into its
# implementation.
function Cassette.overdub(::CUDACtx, ::typeof(datatype_align), ::Type{T}) where {T}
    return datatype_align(T)
end

# Under `CUDACtx`, a call to `isdevice()` evaluates to `true`.
function Cassette.overdub(::CUDACtx, ::typeof(isdevice))
    return true
end

# libdevice.jl
#
# Under `CUDACtx`, redirect calls to the `Base` functions listed below, when
# applied to a `Float32` or `Float64`, to the `CUDAnative` function of the same
# name.

# Functions of a single floating-point argument.
for f in (:cos, :cospi, :sin, :sinpi, :tan,
          :acos, :asin, :atan,
          :cosh, :sinh, :tanh,
          :acosh, :asinh, :atanh,
          :log, :log10, :log1p, :log2,
          :exp, :exp2, :exp10, :expm1,
          :isfinite, :isinf, :isnan,
          :signbit, :abs,
          :sqrt, :cbrt,
          :ceil, :floor,)
    @eval @inline function Cassette.overdub(::CUDACtx, ::typeof(Base.$f),
                                            x::Union{Float32, Float64})
        return CUDAnative.$f(x)
    end
end

# `ldexp(x, n)` takes the exponent as a second argument, so it cannot be part
# of the unary list above: a one-argument overdub would never match a valid
# call and would forward to a one-argument `CUDAnative.ldexp` method.
# NOTE(review): assumes `CUDAnative.ldexp(x, n::Int32)` is what the libdevice
# wrappers provide (the `__nv_ldexp` family takes a 32-bit int exponent) --
# confirm against device/libdevice.jl. Calls with any other exponent type keep
# going through the regular overdub of `Base.ldexp`.
@inline function Cassette.overdub(::CUDACtx, ::typeof(Base.ldexp),
                                  x::Union{Float32, Float64}, n::Int32)
    return CUDAnative.ldexp(x, n)
end

# Wrap `f` so that calling the result with some positional arguments runs
# `f(args...)` under the shared `cudactx` context via `Cassette.overdub`.
function contextualize(f::F) where {F}
    return function (args...)
        return Cassette.overdub(cudactx, f, args...)
    end
end

6 changes: 4 additions & 2 deletions src/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,8 @@ kernel to determine the launch configuration:
GC.@preserve args begin
kernel_args = cudaconvert.(args)
kernel_tt = Tuple{Core.Typeof.(kernel_args)...}
kernel = cufunction(f, kernel_tt; compilation_kwargs)
kernel_f = contextualize(f)
kernel = cufunction(kernel_f, kernel_tt; compilation_kwargs)
kernel(kernel_args...; launch_kwargs)
end
"""
Expand Down Expand Up @@ -205,7 +206,8 @@ macro cuda(ex...)
GC.@preserve $(vars...) begin
local kernel_args = cudaconvert.(($(var_exprs...),))
local kernel_tt = Tuple{Core.Typeof.(kernel_args)...}
local kernel = cufunction($(esc(f)), kernel_tt; $(map(esc, compiler_kwargs)...))
local kernel_f = contextualize($(esc(f)))
local kernel = cufunction(kernel_f, kernel_tt; $(map(esc, compiler_kwargs)...))
kernel(kernel_args...; $(map(esc, call_kwargs)...))
end
end)
Expand Down

0 comments on commit 01d968e

Please sign in to comment.