Skip to content

Mark more methods as device methods #2336

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 9 commits into
base: vc/precompile_tools
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -23,6 +23,7 @@ Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
NVTX = "5da4648a-3479-48b8-97b9-01cb529c0a1f"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
@@ -65,6 +66,7 @@ Libdl = "1"
LinearAlgebra = "1"
Logging = "1"
NVTX = "0.3.2"
PrecompileTools = "1.2.1"
Preferences = "1"
PrettyTables = "2"
Printf = "1"
6 changes: 3 additions & 3 deletions src/device/intrinsics/atomics.jl
Original file line number Diff line number Diff line change
@@ -151,7 +151,7 @@ for A in (AS.Generic, AS.Global, AS.Shared), T in (:Int16, :UInt16)
end

intr = "atom$scope.cas.b16 \$0, [\$1], \$2, \$3;"
@eval @inline atomic_cas!(ptr::LLVMPtr{$T,$A}, cmp::$T, val::$T) =
@eval @device_function @inline atomic_cas!(ptr::LLVMPtr{$T,$A}, cmp::$T, val::$T) =
@asmcall($intr, "=h,l,h,h", true, $T, Tuple{Core.LLVMPtr{$T,$A},$T,$T}, ptr, cmp, val)
end

@@ -172,7 +172,7 @@ for A in (AS.Generic, AS.Global, AS.Shared)
nb = sizeof(T)*8
fn = Symbol("atomic_$(op)!")
intr = "llvm.nvvm.atomic.load.$op.$nb.p$(convert(Int, A))i$nb"
@eval @inline $fn(ptr::LLVMPtr{$T,$A}, val::$T) =
@eval @device_function @inline $fn(ptr::LLVMPtr{$T,$A}, val::$T) =
@typed_ccall($intr, llvmcall, $T, (LLVMPtr{$T,$A}, $T), ptr, val)
end
end
@@ -192,7 +192,7 @@ for A in (AS.Generic, AS.Global, AS.Shared), T in (:Float16,)
end

intr = "atom$scope.add.noftz.f16 \$0, [\$1], \$2;"
@eval @inline atomic_add!(ptr::LLVMPtr{$T,$A}, val::$T) =
@eval @device_function @inline atomic_add!(ptr::LLVMPtr{$T,$A}, val::$T) =
@asmcall($intr, "=h,l,h", true, $T, Tuple{Core.LLVMPtr{$T,$A},$T}, ptr, val)
end

11 changes: 6 additions & 5 deletions src/device/intrinsics/cooperative_groups.jl
Original file line number Diff line number Diff line change
@@ -24,7 +24,7 @@ Noteworthy missing functionality:
module CG

using ..CUDA
using ..CUDA: i32, Aligned, alignment
using ..CUDA: i32, Aligned, alignment, @device_function

using ..LLVM.Interop
using ..LLVMLoopInfo
@@ -70,7 +70,7 @@ const grid_workspace = Ptr{grid_workspace_st}
end
end

function get_grid_workspace()
@device_function function get_grid_workspace()
# interpret the address from envreg 1 and 2 as the driver's grid workspace
hi = ccall("llvm.nvvm.read.ptx.sreg.envreg1", llvmcall, UInt32, ())
lo = ccall("llvm.nvvm.read.ptx.sreg.envreg2", llvmcall, UInt32, ())
@@ -370,7 +370,7 @@ end
return oldArrive
end

@inline function barrier_wait(gg::grid_group, token)
@device_function @inline function barrier_wait(gg::grid_group, token)
arrived = gg.details.barrier

if is_cta_master()
@@ -548,11 +548,12 @@ end

## pipeline operations

pipeline_commit() = ccall("llvm.nvvm.cp.async.commit.group", llvmcall, Cvoid, ())
@device_function pipeline_commit() = ccall("llvm.nvvm.cp.async.commit.group", llvmcall, Cvoid, ())

pipeline_wait_prior(n) =
@device_function pipeline_wait_prior(n) =
ccall("llvm.nvvm.cp.async.wait.group", llvmcall, Cvoid, (Int32,), n)

# TODO: should this also be marked as a device function (@device_function)?
@generated function pipeline_memcpy_async(dst::LLVMPtr{T}, src::LLVMPtr{T}) where T
size_and_align = sizeof(T)
size_and_align in (4, 8, 16) || :(return error($"Unsupported size $size_and_align"))
3 changes: 3 additions & 0 deletions src/device/intrinsics/misc.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
export clock, nanosleep

@device_functions begin
"""
exit()
Puts a thread to sleep for a given amount of time `t` (in nanoseconds).
@asmcall("nanosleep.u32 \$0;", "r", true,
Cvoid, Tuple{UInt32}, convert(UInt32, t))
end

end
5 changes: 4 additions & 1 deletion src/device/intrinsics/synchronization.jl
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@
export sync_threads, sync_warp
export sync_threads_count, sync_threads_and, sync_threads_or

@device_functions begin
"""
sync_threads()
@@ -64,7 +65,7 @@ the warp.

export barrier_sync

barrier_sync(id=0) = ccall("llvm.nvvm.barrier.sync", llvmcall, Cvoid, (Int32,), id)
@inline barrier_sync(id=0) = ccall("llvm.nvvm.barrier.sync", llvmcall, Cvoid, (Int32,), id)


## memory barriers (membar)
@@ -107,3 +108,5 @@ host threads, and all threads in peer devices as occurring before all writes to
memory made by the calling thread after the call to `threadfence_system()`.
"""
@inline threadfence_system() = ccall("llvm.nvvm.membar.sys", llvmcall, Cvoid, ())

end
2 changes: 1 addition & 1 deletion src/device/intrinsics/version.jl
Original file line number Diff line number Diff line change
@@ -50,7 +50,7 @@ end
export compute_capability, ptx_isa_version

for var in ["sm_major", "sm_minor", "ptx_major", "ptx_minor"]
@eval @inline $(Symbol(var))() =
@eval @device_function @inline $(Symbol(var))() =
Base.llvmcall(
$("""@$var = external global i32
define i32 @entry() #0 {
6 changes: 3 additions & 3 deletions src/device/intrinsics/warp.jl
Original file line number Diff line number Diff line change
@@ -26,7 +26,7 @@ for (name, mode, mask, offset) in (("_up", :up, UInt32(0x00), src->src),
for (T,typ) in ((Int32, "i32"), (UInt32, "i32"), (Float32, "f32"))
intrinsic = "llvm.nvvm.shfl.sync.$mode.$typ"
@eval begin
@inline $fname(mask, val::$T, src, width=$ws) =
@device_function @inline $fname(mask, val::$T, src, width=$ws) =
ccall($intrinsic, llvmcall, $T,
(UInt32, $T, UInt32, UInt32),
mask, val, $(offset(:src)), pack(width, $mask))
@@ -109,7 +109,7 @@ for mode in (:all, :any, :uni)
@eval export $fname

intrinsic = "llvm.nvvm.vote.$mode.sync"
@eval @inline $fname(mask, pred) =
@eval @device_function @inline $fname(mask, pred) =
@typed_ccall($intrinsic, llvmcall, Bool, (UInt32, Bool), mask, pred)
end

@@ -119,7 +119,7 @@ for mode in (:ballot, )
@eval export $fname

intrinsic = "llvm.nvvm.vote.$mode.sync"
@eval @inline $fname(mask, pred) =
@eval @device_function @inline $fname(mask, pred) =
@typed_ccall($intrinsic, llvmcall, UInt32, (UInt32, Bool), mask, pred)
end

12 changes: 6 additions & 6 deletions src/device/intrinsics/wmma.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
export WMMA
module WMMA

using ..CUDA: AS
using ..CUDA: AS, @device_function
using Core: LLVMPtr

################################################################################
@@ -196,10 +196,10 @@ for ops in all_ldst_ops,
ptr_ty = :(LLVMPtr{$arr_ty, $addr_space_int})

if sz == 1
@eval $func_name(src_addr, stride) = tuple(ccall($ccall_name, llvmcall, $frag_ty, ($ptr_ty, Int32), src_addr, stride))
@eval @device_function $func_name(src_addr, stride) = tuple(ccall($ccall_name, llvmcall, $frag_ty, ($ptr_ty, Int32), src_addr, stride))
else
struct_ty = Symbol("LLVMStruct$sz")
@eval $func_name(src_addr, stride) = convert(NTuple{$sz, $frag_ty}, ccall($ccall_name, llvmcall, $struct_ty{$frag_ty}, ($ptr_ty, Int32), src_addr, stride))
@eval @device_function $func_name(src_addr, stride) = convert(NTuple{$sz, $frag_ty}, ccall($ccall_name, llvmcall, $struct_ty{$frag_ty}, ($ptr_ty, Int32), src_addr, stride))
end
@eval export $func_name
@eval @doc (@doc llvm_wmma_load) $func_name
@@ -263,7 +263,7 @@ export llvm_wmma_store

ptr_ty = :(LLVMPtr{$arr_ty, $addr_space_int})

@eval $func_name(dst_addr, data, stride) = ccall($ccall_name, llvmcall, Nothing, ($ptr_ty, $(frag_types...), Int32), dst_addr, $(frag_vars...), stride)
@eval @device_function $func_name(dst_addr, data, stride) = ccall($ccall_name, llvmcall, Nothing, ($ptr_ty, $(frag_types...), Int32), dst_addr, $(frag_vars...), stride)
@eval export $func_name
@eval @doc (@doc llvm_wmma_store) $func_name
end
@@ -340,10 +340,10 @@ for ops in all_wmma_ops,
c_vars = ntuple(i -> :(c[$i]), c_sz)

if d_sz == 1
@eval $func_name(a, b, c) = tuple(ccall($ccall_name, llvmcall, $d_frag_ty, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...)))
@eval @device_function $func_name(a, b, c) = tuple(ccall($ccall_name, llvmcall, $d_frag_ty, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...)))
else
struct_ty = Symbol("LLVMStruct$d_sz")
@eval $func_name(a, b, c) = convert(NTuple{$d_sz, $d_frag_ty}, ccall($ccall_name, llvmcall, $struct_ty{$d_frag_ty}, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...)))
@eval @device_function $func_name(a, b, c) = convert(NTuple{$d_sz, $d_frag_ty}, ccall($ccall_name, llvmcall, $struct_ty{$d_frag_ty}, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...)))
end
@eval export $func_name
@eval @doc (@doc llvm_wmma_mma) $func_name
6 changes: 3 additions & 3 deletions src/device/utils.jl
Original file line number Diff line number Diff line change
@@ -16,7 +16,7 @@ end
macro device_override(ex)
ex = macroexpand(__module__, ex)
esc(quote
Base.Experimental.@overlay(CUDA.method_table, $ex)
Base.Experimental.@overlay($(CUDA).method_table, $ex)
end)
end

@@ -31,7 +31,7 @@ macro device_function(ex)

esc(quote
$(combinedef(def))
@device_override $ex
$(CUDA).@device_override $ex
end)
end

@@ -47,7 +47,7 @@ macro device_functions(ex)
push!(out.args, rewrite(arg))
elseif Meta.isexpr(arg, [:function, :(=)])
# rewrite function definitions
push!(out.args, :(@device_function $arg))
push!(out.args, :($(CUDA).@device_function $arg))
else
# preserve all the rest
push!(out.args, arg)
14 changes: 14 additions & 0 deletions src/precompile.jl
Original file line number Diff line number Diff line change
@@ -14,3 +14,17 @@ precompile(run_and_collect, (Cmd,))
precompile(cudaconvert, (Function,))
precompile(Core.kwfunc(cudacall), (NamedTuple{(:threads, :blocks), Tuple{Int64, Int64}},typeof(cudacall),CuFunction,Type{Tuple{}}))
precompile(Core.kwfunc(launch), (NamedTuple{(:threads, :blocks), Tuple{Int64, Int64}},typeof(launch),CuFunction))

# Precompilation workload: ahead-of-time compile a trivial GPU kernel pipeline so that
# the first real kernel compilation at runtime is faster.
using PrecompileTools: @setup_workload, @compile_workload
# Gate on a 1.11 development build; presumably the workload relies on compiler
# infrastructure only available from that version — TODO confirm the exact cutoff.
@static if VERSION >= v"1.11.0-DEV.1603"
@setup_workload let
@compile_workload begin
# Build a compiler configuration targeting a fixed PTX capability (sm_75).
target = PTXCompilerTarget(; cap=v"7.5")
params = CUDACompilerParams(; cap=v"7.5", ptx=v"7.5")
config = CompilerConfig(target, params)
# Compile the simplest possible kernel — `identity(nothing)` — end to end,
# discarding the generated code (written to `devnull`); only the side effect
# of populating the precompile cache matters here.
mi = GPUCompiler.methodinstance(typeof(identity), Tuple{Nothing})
job = CompilerJob(mi, config)
GPUCompiler.code_native(devnull, job)
end
end
end