Investigate the performance issues and consider moving to GemmKernels.jl #2
Comments
Thanks! If possible, could you please also print the generated PTX code and show the Nsight Systems screenshot here? This information could help in analysing the performance difference between the C implementation and the GemmKernels implementation. For the GemmKernels implementation, just go through the steps here:
Sure! Since both the PTX code and the benchmark results are quite long, I will upload the files directly here.

Here are the benchmark results by Nvidia Nsight Compute (figures of the ncu benchmark results):
- Result of
- Result of

Here is the PTX code of
Can you share your benchmarking code? I did a test myself, and with two performance fixes to GemmKernels.jl (JuliaGPU/GemmKernels.jl#182, and a tuned configuration) I'm getting very similar performance. For example, on a RTX6000 Ada using 4096x4096 Float32 inputs:
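For readers unfamiliar with the operation being benchmarked: a tropical (max-plus) GEMM replaces the usual (+, ×) of matrix multiplication with (max, +), i.e. C[i,j] = max_k (A[i,k] + B[k,j]). Below is a minimal pure-Python CPU reference of that semantics, useful e.g. for checking a GPU kernel's output; it is an illustrative sketch (the function name is hypothetical), not code from either library.

```python
def tropical_matmul(A, B):
    """Max-plus matrix product: C[i][j] = max_k (A[i][k] + B[k][j]).

    A is M x K, B is K x N, both as lists of lists; returns M x N.
    The additive identity of the max-plus semiring is -inf.
    """
    M, K, N = len(A), len(B), len(B[0])
    C = [[float("-inf")] * N for _ in range(M)]
    for i in range(M):
        for j in range(N):
            acc = float("-inf")  # semiring "zero"
            for k in range(K):
                acc = max(acc, A[i][k] + B[k][j])  # semiring "+=" and "*"
            C[i][j] = acc
    return C

print(tropical_matmul([[1.0, 2.0], [3.0, 4.0]],
                      [[5.0, 6.0], [7.0, 8.0]]))
# → [[9.0, 10.0], [11.0, 12.0]]
```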
GemmKernels.jl seems consistently a little faster than CuTropicalGEMM, even when re-using this block/operator configuration for different input sizes (i.e. where additional tuning might result in even better performance). I'm benchmarking using the following code:

```julia
using CUDA, GemmKernels, LinearAlgebra
using TropicalNumbers, CuTropicalGEMM
using BenchmarkTools

function main()
    M = K = N = 1024
    A = CUDA.rand(Float32, M, K)
    B = CUDA.rand(Float32, K, N)
    C = CUDA.zeros(Float32, M, N)

    print("CuTropicalGEMM: ")
    let
        tA = Tropical.(A)
        tB = Tropical.(B)
        tC = Tropical.(C)
        @btime begin
            mul!($tC, $tA, $tB)
            # XXX: not sure why `CUDA.@sync` doesn't work here;
            # is CuTropicalGEMM doing its own stream management?
            device_synchronize()
        end
    end

    print("GemmKernels: ")
    let
        # result of tuning
        BLOCK_M = 128
        BLOCK_N = 64
        BLOCK_K = 32
        OP_M = 16
        OP_N = 4
        OP_K = 4
        OP_MB = 8
        OP_NB = 4
        OP_KB = 1
        kernel = Kernel.matmul_pipelined

        # pow2-sized, 128-bit aligned inputs, so we can use aligned layouts.
        # we don't have transposed inputs, so everything is column major.
        @assert stride(A, 2) % 16 == 0
        global_a_layout = Layout.UnsafeAlignedColMajor{eltype(A)}
        @assert stride(B, 2) % 16 == 0
        global_b_layout = Layout.UnsafeAlignedColMajor{eltype(B)}

        # we want to do a simple C = A * B, so no need to load C first.
        global_c_layout = Layout.Zero{eltype(C)}
        @assert stride(C, 2) % 16 == 0
        global_d_layout = Layout.UnsafeAlignedColMajor{eltype(C)}

        # shared layouts are similar.
        # the frequently-accessed a/b shmems are padded to avoid bank conflicts.
        shared_a_layout = Layout.Padded{Layout.UnsafeAlignedColMajor{eltype(A)}, 8}
        shared_b_layout = Layout.Padded{Layout.UnsafeAlignedColMajor{eltype(B)}, 8}
        shared_c_layout = shared_d_layout = Layout.UnsafeAlignedColMajor{eltype(C)}

        # we use the tropical FPU operator
        compute_type = promote_type(eltype(A), eltype(B))
        operator = Operator.TropicalFPUOp{OP_M, OP_N, OP_K, OP_MB, OP_NB, OP_KB,
                                          compute_type, eltype(C)}

        # the block shape is the result of tuning
        block_shape = (M = BLOCK_M, N = BLOCK_N, K = BLOCK_K)
        @assert M % block_shape.M == 0
        @assert N % block_shape.N == 0
        @assert K % block_shape.K == 0

        conf = GemmKernels.get_config(;
            gemm_shape = (M = M, N = N, K = K),
            block_shape,
            operator,
            global_a_layout, global_b_layout, global_c_layout, global_d_layout,
            shared_a_layout, shared_b_layout, shared_c_layout, shared_d_layout,
            is_a_col_major = true,
            is_b_col_major = true
        )

        @btime CUDA.@sync GemmKernels.matmul($conf, $A, $B, $C, $C; kernel=$kernel)
    end

    CUDA.unsafe_free!(A)
    CUDA.unsafe_free!(B)
    CUDA.unsafe_free!(C)
end

isinteractive() || main()
```

Now, GemmKernels.jl likely needs some improvements to be better across the board (e.g. more generalization to handle arbitrary input sizes, a better API, etc.), but it nonetheless seems like a good starting point, with all the advantages that native Julia implementations have (arbitrary type support, ease of development, etc.).
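The `BLOCK_M`/`BLOCK_N`/`BLOCK_K` values in the listing are described as "the result of tuning". The sketch below shows what such a tuning loop could look like in principle: an exhaustive search over candidate tile sizes, keeping only tiles that evenly divide the problem (mirroring the divisibility asserts above). The `benchmark` function here is a dummy cost model standing in for timing the real GPU kernel, so the sketch runs anywhere; none of these names are GemmKernels.jl API.

```python
from itertools import product

def benchmark(block_m, block_n, block_k):
    """Placeholder cost model standing in for a real kernel timing run.

    Favors larger tiles, but rejects configurations whose A- and B-tiles
    would exceed a made-up shared-memory budget of 8192 elements.
    """
    if block_m * block_k + block_k * block_n > 8192:
        return float("inf")  # "doesn't fit in shared memory"
    return 1.0 / (block_m * block_n * block_k)

def tune(M, N, K):
    """Grid-search block shapes, returning the fastest valid (bm, bn, bk)."""
    candidates = [16, 32, 64, 128]
    best, best_time = None, float("inf")
    for bm, bn, bk in product(candidates, repeat=3):
        # only consider tiles that evenly divide the problem,
        # mirroring the @assert M % block_shape.M == 0 checks above
        if M % bm or N % bn or K % bk:
            continue
        t = benchmark(bm, bn, bk)
        if t < best_time:
            best, best_time = (bm, bn, bk), t
    return best

print(tune(1024, 1024, 1024))  # → (128, 128, 32)
```

In practice one would replace `benchmark` with an actual timed launch of the configured kernel and also sweep the operator shape (`OP_M`, `OP_N`, `OP_K`, ...), but the divide-and-filter structure stays the same.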
These results seem really great. Actually, I also have an implementation using
The following changes have been made:
The new benchmark result is shown here:
Originally posted by @ArrogantGao in #1 (comment)