diff --git a/src/config.jl b/src/config.jl
index fff1f7ad..bb6c2bee 100644
--- a/src/config.jl
+++ b/src/config.jl
@@ -128,6 +128,11 @@ check_operator_config(operator::Type{<:Operator.WMMAOp}) = check_wmma_shape(oper
 check_operator_config(operator::Type{<:Operator.WMMAComplexOp}) = check_wmma_shape(operator)
 check_operator_config(operator::Type{<:Operator.WMMADualOp}) = check_wmma_shape(operator)
 
+require_tile_sized_global(layout) = true
+require_tile_sized_global(::Type{<:Layout.Zero{T}}) where {T} = false
+require_tile_sized_global(::Type{<:Layout.ColMajor{T}}) where {T} = false
+require_tile_sized_global(::Type{<:Layout.RowMajor{T}}) where {T} = false
+
 function get_config(; gemm_shape, operator, global_a_layout, global_c_layout, kwargs...)
     params = Dict(kwargs)
 
@@ -215,6 +220,15 @@ function get_config(; gemm_shape, operator, global_a_layout, global_c_layout, kw
     prod(mem_b_warp) * warps_per_block ≤ block_shape.K * block_shape.N || throw(ConfigError("mem_b_warp is too big for the selected block shape: need at least one iteration in the memory copy loop!"))
     prod(mem_cd_warp) * warps_per_block ≤ block_shape.M * block_shape.N || throw(ConfigError("mem_cd_warp is too big for the selected block shape: need at least one iteration in the memory copy loop!"))
 
+    # Check sizes of tiles
+    check_tile_multiple(num, den, dims, msg) = all([num[dim] % den[dim] == 0 for dim in dims]) || throw(ConfigError(msg))
+
+    check_tile_multiple(block_shape, compute_warp, [:M, :N, :K], "block_shape must be a multiple of compute_warp!")
+    require_tile_sized_global(global_a_layout) && check_tile_multiple(gemm_shape, block_shape, [:M, :K], "gemm_shape.MK must be a multiple of block_shape.MK!")
+    require_tile_sized_global(global_b_layout) && check_tile_multiple(gemm_shape, block_shape, [:K, :N], "gemm_shape.KN must be a multiple of block_shape.KN!")
+    require_tile_sized_global(global_c_layout) && check_tile_multiple(gemm_shape, block_shape, [:M, :N], "gemm_shape.MN must be a multiple of block_shape.MN!")
+    require_tile_sized_global(global_d_layout) && check_tile_multiple(gemm_shape, block_shape, [:M, :N], "gemm_shape.MN must be a multiple of block_shape.MN!")
+
     return Config(
         #= Params =#
         gemm_shape,