diff --git a/benchmarks/runbenchmarks.jl b/benchmarks/runbenchmarks.jl index 5f402eec..4f84fd22 100644 --- a/benchmarks/runbenchmarks.jl +++ b/benchmarks/runbenchmarks.jl @@ -112,54 +112,63 @@ for cf in get_configs() @info "Running benchmark $( cf.name )..." c_h, a, b, c, d = generate_inputs(cf) - # warmup - run_gemm(cf, a, b, c, d) + try + # warmup + run_gemm(cf, a, b, c, d) - # benchmark - profile_results = CUDA.@profile begin - for sample in 1:NUM_SAMPLES - run_gemm(cf, a, b, c, d) + # benchmark + profile_results = CUDA.@profile begin + for sample in 1:NUM_SAMPLES + run_gemm(cf, a, b, c, d) + end end - end - # XXX: This works for now, since every GEMM is one kernel, but later on we may want to benchmark - # operations consisting of multiple kernel launches... - profile_results = profile_results.device + # XXX: This works for now, since every GEMM is one kernel, but later on we may want to benchmark + # operations consisting of multiple kernel launches... + profile_results = profile_results.device - # get info - details[cf.name] = Dict( - "registers" => profile_results[1, "registers"], - "dynamic_shared_mem" => profile_results[1, "shared_mem"].dynamic, - "static_shared_mem" => profile_results[1, "shared_mem"].static, - "local_mem" => profile_results[1, "local_mem"].thread - ) + # get info + details[cf.name] = Dict( + "registers" => profile_results[1, "registers"], + "dynamic_shared_mem" => profile_results[1, "shared_mem"].dynamic, + "static_shared_mem" => profile_results[1, "shared_mem"].static, + "local_mem" => profile_results[1, "local_mem"].thread + ) - times = 1e9 .* (profile_results[!, "stop"] - profile_results[!, "start"]) - @assert length(times) == NUM_SAMPLES + times = 1e9 .* (profile_results[!, "stop"] - profile_results[!, "start"]) + @assert length(times) == NUM_SAMPLES - @info "\tGemmKernels: $(prettytime(times)) $(prettyflops(times, cf.config.matmul_shape))" + @info "\tGemmKernels: $(prettytime(times)) $(prettyflops(times, cf.config.matmul_shape))" - if !isnothing(cf.baseline) - # benchmark baseline - baseline_profile_results = CUDA.@profile begin - for sample in 1:NUM_SAMPLES - run_baseline(cf, a, b, c, d) + if !isnothing(cf.baseline) + # benchmark baseline + baseline_profile_results = CUDA.@profile begin + for sample in 1:NUM_SAMPLES + run_baseline(cf, a, b, c, d) + end end - end - baseline_profile_results = baseline_profile_results.device - @assert size(baseline_profile_results, 1) % NUM_SAMPLES == 0 + baseline_profile_results = baseline_profile_results.device + @assert size(baseline_profile_results, 1) % NUM_SAMPLES == 0 - baseline_times = 1e9 .* sum.(Iterators.partition(baseline_profile_results[!, "stop"] - baseline_profile_results[!, "start"], size(baseline_profile_results, 1) ÷ NUM_SAMPLES)) - @assert length(baseline_times) == NUM_SAMPLES + baseline_times = 1e9 .* sum.(Iterators.partition(baseline_profile_results[!, "stop"] - baseline_profile_results[!, "start"], size(baseline_profile_results, 1) ÷ NUM_SAMPLES)) + @assert length(baseline_times) == NUM_SAMPLES - baseline_ratio = "$(round(100 * minimum(baseline_times) / minimum(times); sigdigits=3))" - @info "\tBaseline: $(prettytime(baseline_times)) $(prettyflops(baseline_times, cf.config.matmul_shape)) (GemmKernels: $(baseline_ratio)%)" + baseline_ratio = "$(round(100 * minimum(baseline_times) / minimum(times); sigdigits=3))" + @info "\tBaseline: $(prettytime(baseline_times)) $(prettyflops(baseline_times, cf.config.matmul_shape)) (GemmKernels: $(baseline_ratio)%)" - baseline_results[cf.name] = Dict("times" => baseline_times) - end + baseline_results[cf.name] = Dict("times" => baseline_times) + end - results[cf.name] = Dict("times" => times) + results[cf.name] = Dict("times" => times) + catch err + if isa(err, GemmKernels.ConfigError) + # Skip this benchmark. + @warn "Skipping benchmark $(cf.name): Invalid configuration: $(err)." + else + rethrow() + end + end end function save_results(results_file, details_file, results, details) diff --git a/configs/configs.jl b/configs/configs.jl index c88b2d8d..ca2f9a62 100644 --- a/configs/configs.jl +++ b/configs/configs.jl @@ -241,7 +241,7 @@ macro get_wmma_config() mul!, Epilogue.Default(), verify_default, - Kernel.matmul_pipelined, + kernel, wmma_baseline) end end) end @@ -520,7 +520,22 @@ function get_configs() [2, 2, 1], [1, 1, 2], [2, 2, 2]], [[2048, 2048, 2048]]), - zero_c in [false] + zero_c in [false], + kernel in [Kernel.matmul_pipelined] + + push!(rv, @get_wmma_config) + end + + # WMMA GEMM parameters + for (M, N, K) in [(256, 256, 256)], + (AB_type, CD_type) in [(Float16, Float32)], + transpose_a in [false, true], + transpose_b in [false, true], + (BLOCK_M, BLOCK_N, BLOCK_K) in filter(x -> prod(x[1:2]) <= 128*128, collect(Iterators.product([64, 128, 256], [64, 128, 256], [16, 32, 64]))[:]), + (WARPS_M, WARPS_N) in filter(x -> prod(x) >= 4, collect(Iterators.product([1, 2, 4], [1, 2, 4]))[:]), + zero_c in [false, true], + (OP_M, OP_N, OP_K) in [(16, 16, 16)], + kernel in [Kernel.matmul_singlestage, Kernel.matmul_pipelined] push!(rv, @get_wmma_config) end diff --git a/src/config.jl b/src/config.jl index bb6c2bee..4fd99afc 100644 --- a/src/config.jl +++ b/src/config.jl @@ -35,6 +35,39 @@ is_b_col_major end +function Base.show(io::IO, config::Config) + println(io, "matmul_shape: $(config.matmul_shape)") + println(io, "block_shape: $(config.block_shape)") + println(io, "warps_per_block: $(config.warps_per_block)") + + println(io, "mem_a_warp: $(config.mem_a_warp)") + println(io, "mem_a_thread: $(config.mem_a_thread)") + + println(io, "mem_b_warp: $(config.mem_b_warp)") + println(io, "mem_b_thread: $(config.mem_b_thread)") + + println(io, "mem_cd_warp: $(config.mem_cd_warp)") + println(io, "mem_cd_thread: $(config.mem_cd_thread)") + + println(io, "compute_warp: $(config.compute_warp)") + println(io, "compute_op_shape: $(config.compute_op_shape)") + + println(io, "global_a_layout: $(config.global_a_layout)") + println(io, "global_b_layout: $(config.global_b_layout)") + println(io, "global_c_layout: $(config.global_c_layout)") + println(io, "global_d_layout: $(config.global_d_layout)") + + println(io, "shared_a_layout: $(config.shared_a_layout)") + println(io, "shared_b_layout: $(config.shared_b_layout)") + println(io, "shared_c_layout: $(config.shared_c_layout)") + println(io, "shared_d_layout: $(config.shared_d_layout)") + + println(io, "operator: $(config.operator)") + + println(io, "is_a_col_major: $(config.is_a_col_major)") + println(io, "is_b_col_major: $(config.is_b_col_major)") +end + struct ConfigError <: Exception message::String end diff --git a/test/matmul.jl b/test/matmul.jl index fa625021..19f08de6 100644 --- a/test/matmul.jl +++ b/test/matmul.jl @@ -7,8 +7,17 @@ include("../configs/configs.jl") @testset "Matrix multiplication" begin @testcase "$( cf.name )" for cf in get_configs() - c_h, a, b, c, d = generate_inputs(cf) - run_gemm(cf, a, b, c, d) - @test verify(cf, c_h, d) + try + c_h, a, b, c, d = generate_inputs(cf) + run_gemm(cf, a, b, c, d) + @test verify(cf, c_h, d) + catch err + # Count tests with config errors as "broken". + if isa(err, GemmKernels.ConfigError) + @test true skip=true + else + rethrow() + end + end end end