Skip to content

Commit

Permalink
Test more WMMA configurations (#171)
Browse files Browse the repository at this point in the history
* Test more WMMA configurations

* Skip benchmarks with invalid config

* Retrigger CI
  • Loading branch information
thomasfaingnaert authored Jan 2, 2024
1 parent 3052b52 commit 3c328d1
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 40 deletions.
79 changes: 44 additions & 35 deletions benchmarks/runbenchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,54 +112,63 @@ for cf in get_configs()
@info "Running benchmark $( cf.name )..."
c_h, a, b, c, d = generate_inputs(cf)

# warmup
run_gemm(cf, a, b, c, d)
try
# warmup
run_gemm(cf, a, b, c, d)

# benchmark
profile_results = CUDA.@profile begin
for sample in 1:NUM_SAMPLES
run_gemm(cf, a, b, c, d)
# benchmark
profile_results = CUDA.@profile begin
for sample in 1:NUM_SAMPLES
run_gemm(cf, a, b, c, d)
end
end
end

# XXX: This works for now, since every GEMM is one kernel, but later on we may want to benchmark
# operations consisting of multiple kernel launches...
profile_results = profile_results.device
# XXX: This works for now, since every GEMM is one kernel, but later on we may want to benchmark
# operations consisting of multiple kernel launches...
profile_results = profile_results.device

# get info
details[cf.name] = Dict(
"registers" => profile_results[1, "registers"],
"dynamic_shared_mem" => profile_results[1, "shared_mem"].dynamic,
"static_shared_mem" => profile_results[1, "shared_mem"].static,
"local_mem" => profile_results[1, "local_mem"].thread
)
# get info
details[cf.name] = Dict(
"registers" => profile_results[1, "registers"],
"dynamic_shared_mem" => profile_results[1, "shared_mem"].dynamic,
"static_shared_mem" => profile_results[1, "shared_mem"].static,
"local_mem" => profile_results[1, "local_mem"].thread
)

times = 1e9 .* (profile_results[!, "stop"] - profile_results[!, "start"])
@assert length(times) == NUM_SAMPLES
times = 1e9 .* (profile_results[!, "stop"] - profile_results[!, "start"])
@assert length(times) == NUM_SAMPLES

@info "\tGemmKernels: $(prettytime(times)) $(prettyflops(times, cf.config.matmul_shape))"
@info "\tGemmKernels: $(prettytime(times)) $(prettyflops(times, cf.config.matmul_shape))"

if !isnothing(cf.baseline)
# benchmark baseline
baseline_profile_results = CUDA.@profile begin
for sample in 1:NUM_SAMPLES
run_baseline(cf, a, b, c, d)
if !isnothing(cf.baseline)
# benchmark baseline
baseline_profile_results = CUDA.@profile begin
for sample in 1:NUM_SAMPLES
run_baseline(cf, a, b, c, d)
end
end
end

baseline_profile_results = baseline_profile_results.device
@assert size(baseline_profile_results, 1) % NUM_SAMPLES == 0
baseline_profile_results = baseline_profile_results.device
@assert size(baseline_profile_results, 1) % NUM_SAMPLES == 0

baseline_times = 1e9 .* sum.(Iterators.partition(baseline_profile_results[!, "stop"] - baseline_profile_results[!, "start"], size(baseline_profile_results, 1) ÷ NUM_SAMPLES))
@assert length(baseline_times) == NUM_SAMPLES
baseline_times = 1e9 .* sum.(Iterators.partition(baseline_profile_results[!, "stop"] - baseline_profile_results[!, "start"], size(baseline_profile_results, 1) ÷ NUM_SAMPLES))
@assert length(baseline_times) == NUM_SAMPLES

baseline_ratio = "$(round(100 * minimum(baseline_times) / minimum(times); sigdigits=3))"
@info "\tBaseline: $(prettytime(baseline_times)) $(prettyflops(baseline_times, cf.config.matmul_shape)) (GemmKernels: $(baseline_ratio)%)"
baseline_ratio = "$(round(100 * minimum(baseline_times) / minimum(times); sigdigits=3))"
@info "\tBaseline: $(prettytime(baseline_times)) $(prettyflops(baseline_times, cf.config.matmul_shape)) (GemmKernels: $(baseline_ratio)%)"

baseline_results[cf.name] = Dict("times" => baseline_times)
end
baseline_results[cf.name] = Dict("times" => baseline_times)
end

results[cf.name] = Dict("times" => times)
results[cf.name] = Dict("times" => times)
catch err
if isa(err, GemmKernels.ConfigError)
# Skip this benchmark.
@warn "Skipping benchmark $(cf.name): Invalid configuration: $(err)."
else
rethrow()
end
end
end

function save_results(results_file, details_file, results, details)
Expand Down
19 changes: 17 additions & 2 deletions configs/configs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ macro get_wmma_config()
mul!,
Epilogue.Default(),
verify_default,
Kernel.matmul_pipelined,
kernel,
wmma_baseline)
end end)
end
Expand Down Expand Up @@ -520,7 +520,22 @@ function get_configs()
[2, 2, 1],
[1, 1, 2],
[2, 2, 2]], [[2048, 2048, 2048]]),
zero_c in [false]
zero_c in [false],
kernel in [Kernel.matmul_pipelined]

push!(rv, @get_wmma_config)
end

# WMMA GEMM parameters
for (M, N, K) in [(256, 256, 256)],
(AB_type, CD_type) in [(Float16, Float32)],
transpose_a in [false, true],
transpose_b in [false, true],
(BLOCK_M, BLOCK_N, BLOCK_K) in filter(x -> prod(x[1:2]) <= 128*128, collect(Iterators.product([64, 128, 256], [64, 128, 256], [16, 32, 64]))[:]),
(WARPS_M, WARPS_N) in filter(x -> prod(x) >= 4, collect(Iterators.product([1, 2, 4], [1, 2, 4]))[:]),
zero_c in [false, true],
(OP_M, OP_N, OP_K) in [(16, 16, 16)],
kernel in [Kernel.matmul_singlestage, Kernel.matmul_pipelined]

push!(rv, @get_wmma_config)
end
Expand Down
33 changes: 33 additions & 0 deletions src/config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,39 @@
is_b_col_major
end

function Base.show(io::IO, config::Config)
println(io, "matmul_shape: $(config.matmul_shape)")
println(io, "block_shape: $(config.block_shape)")
println(io, "warps_per_block: $(config.warps_per_block)")

println(io, "mem_a_warp: $(config.mem_a_warp)")
println(io, "mem_a_thread: $(config.mem_a_thread)")

println(io, "mem_b_warp: $(config.mem_b_warp)")
println(io, "mem_b_thread: $(config.mem_b_thread)")

println(io, "mem_cd_warp: $(config.mem_cd_warp)")
println(io, "mem_cd_thread: $(config.mem_cd_thread)")

println(io, "compute_warp: $(config.compute_warp)")
println(io, "compute_op_shape: $(config.compute_op_shape)")

println(io, "global_a_layout: $(config.global_a_layout)")
println(io, "global_b_layout: $(config.global_b_layout)")
println(io, "global_c_layout: $(config.global_c_layout)")
println(io, "global_d_layout: $(config.global_d_layout)")

println(io, "shared_a_layout: $(config.shared_a_layout)")
println(io, "shared_b_layout: $(config.shared_b_layout)")
println(io, "shared_c_layout: $(config.shared_c_layout)")
println(io, "shared_d_layout: $(config.shared_d_layout)")

println(io, "operator: $(config.operator)")

println(io, "is_a_col_major: $(config.is_a_col_major)")
println(io, "is_b_col_major: $(config.is_b_col_major)")
end

struct ConfigError <: Exception
message::String
end
Expand Down
15 changes: 12 additions & 3 deletions test/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,17 @@ include("../configs/configs.jl")

@testset "Matrix multiplication" begin
@testcase "$( cf.name )" for cf in get_configs()
c_h, a, b, c, d = generate_inputs(cf)
run_gemm(cf, a, b, c, d)
@test verify(cf, c_h, d)
try
c_h, a, b, c, d = generate_inputs(cf)
run_gemm(cf, a, b, c, d)
@test verify(cf, c_h, d)
catch err
# Count tests with config errors as "broken".
if isa(err, GemmKernels.ConfigError)
@test true skip=true
else
rethrow()
end
end
end
end

0 comments on commit 3c328d1

Please sign in to comment.