[tuner]: use translation_info binding (#669)
This PR is part of the task in #453: use IREE bindings for
compilation info (including lowering_config and translation_info).

Use translation_info from the IREE Python bindings.
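
For reference, the pattern the updated tests use to build translation_info through the Python bindings is sketched below. This is only a sketch: it assumes an MLIR context with the IREE dialects registered (the tuner's TunerContext provides one), and the import path for the repo's helpers in tuner/tuner/common.py is a guess — the in-repo tests use a relative `from . import common`.

    # Sketch only: mirrors the updated test setup in this commit; the `common`
    # import path is hypothetical (the tests use a relative `from . import common`).
    from iree.compiler import ir  # type: ignore
    from iree.compiler.dialects import iree_codegen, iree_gpu  # type: ignore

    from tuner.tuner import common  # assumed path to tuner/tuner/common.py

    # Assumes the context has the IREE dialects registered, as TunerContext sets up.
    with ir.Context():
        # Pass pipeline (LLVMGPUVectorDistribute) as an attribute.
        pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get(
            iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute
        )
        # gpu_pipeline_options and waves_per_eu now live in the translation_info
        # config dictionary instead of standalone Configuration fields.
        pipeline_options = iree_gpu.PipelineOptionsAttr.get(prefetch_shared_memory=True)
        config_dict = common.get_translation_info_config(pipeline_options, 8)  # 8 = waves_per_eu
        # workgroup_size = [16, 16, 1], subgroup_size = 16.
        translation_info = iree_codegen.TranslationInfoAttr.get(
            pipeline_attr, None, [16, 16, 1], 16, config_dict
        )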

---------

Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
bangtianliu authored Dec 11, 2024
1 parent da6bb12 commit 1e26b20
Showing 6 changed files with 203 additions and 173 deletions.
106 changes: 19 additions & 87 deletions tuner/tuner/candidate_gen.py
@@ -51,6 +51,12 @@ def apply_configuration(
) = lowering_config.subgroup_count_mn
workgroup_sizes = lowering_config.workgroup_tile_sizes
reduction_sizes = lowering_config.reduction_tile_sizes
gpu_pipeline_options = configuration.translation_info.configuration[
GPU_PIPELINE_OPTIONS_KEY
]
waves_per_eu = configuration.translation_info.configuration[LLVM_FUNC_ATTRS_KEY][
WAVES_PER_EU_KEY
]
tune_logger.info(f"Applying: {configuration}")
expr0 = re.compile(
r"<intrinsic = #iree_gpu\.mma_layout<(.+)>, subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>"
@@ -63,11 +69,11 @@ def apply_configuration(
expr4 = re.compile(r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>")
expr5 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"")
repl0 = f"<intrinsic = {intrinsic}, subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>"
repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.workgroup_size))}] subgroup_size = {configuration.subgroup_size},'
repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.translation_info.workgroup_size))}] subgroup_size = {configuration.translation_info.subgroup_size},'
repl2 = f"workgroup = {workgroup_sizes}"
repl3 = f"reduction = {reduction_sizes}"
repl4 = f"gpu_pipeline_options = {configuration.gpu_pipeline_options}"
repl5 = f'"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"'
repl4 = f"gpu_pipeline_options = {gpu_pipeline_options}"
repl5 = f'"amdgpu-waves-per-eu" = {waves_per_eu}'

new_mlir = ""
for line in template:
@@ -128,15 +134,6 @@ class MmtTuner(DispatchTuner, MmtParser):
def get_transform_function_mmt(
self, problem_size: ProblemSize, functionName: str, configuration: Configuration
) -> str:
lowering_config = configuration.lowering_config
intrinsic = lowering_config.mma_kind
(
subgroup_m_count,
subgroup_n_count,
) = lowering_config.subgroup_count_mn

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
return f"""
transform.named_sequence @{functionName}(%matmul: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{
%mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op
@@ -145,13 +142,8 @@ def get_transform_function_mmt(
transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value
transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value
%config = transform.param.constant #iree_codegen.compilation_info<
lowering_config = {configuration.lowering_config}>,
translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = {intrinsic},
subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>
{extra_config}}}>
lowering_config = {configuration.lowering_config},
translation_info = {configuration.translation_info}
> -> !transform.any_param
transform.yield %matmul, %config : !transform.any_op, !transform.any_param
}}
@@ -197,16 +189,6 @@ def get_transform_function_conv(
filter = f"tensor<{problem_size.rhs_type}>"
output = f"tensor<{dynamic_batch_output_ty}>"

lowering_config = configuration.lowering_config
intrinsic = lowering_config.mma_kind
(
subgroup_m_count,
subgroup_n_count,
) = lowering_config.subgroup_count_mn

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)

return f"""
transform.named_sequence @{functionName}(%conv: !transform.any_op {{transform.readonly}})
-> (!transform.any_op, !transform.any_param) {{
@@ -217,13 +199,8 @@
outs(%out : {output}) -> {output}
}} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
%config = transform.param.constant #iree_codegen.compilation_info<
lowering_config = {configuration.lowering_config}>,
translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = {intrinsic},
subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>
{extra_config}}}>
lowering_config = {configuration.lowering_config},
translation_info = {configuration.translation_info}
> -> !transform.any_param
transform.yield %conv, %config : !transform.any_op, !transform.any_param
}}
@@ -262,16 +239,6 @@ def get_transform_function_broadcast_rhs_mmt(
functionName: str,
configuration: Configuration,
) -> str:
lowering_config = configuration.lowering_config
intrinsic = lowering_config.mma_kind
(
subgroup_m_count,
subgroup_n_count,
) = lowering_config.subgroup_count_mn

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)

lhs_dynamic_batch = problem_size.lhs_type
lhs_dynamic_batch.shape = lhs_dynamic_batch.shape.copy()
lhs_dynamic_batch.shape[0] = -1
@@ -284,13 +251,8 @@
transform.iree.match.cast_compatible_type %lhs = tensor<{lhs_dynamic_batch}> : !transform.any_value
transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value
%config = transform.param.constant #iree_codegen.compilation_info<
lowering_config = {configuration.lowering_config}>,
translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = {intrinsic},
subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>
{extra_config}}}>
lowering_config = {configuration.lowering_config},
translation_info = {configuration.translation_info}
> -> !transform.any_param
transform.yield %generic, %config : !transform.any_op, !transform.any_param
}}
@@ -351,16 +313,6 @@ def get_transform_function_batch_mmt(
functionName: str,
configuration: Configuration,
) -> str:
lowering_config = configuration.lowering_config
intrinsic = lowering_config.mma_kind
(
subgroup_m_count,
subgroup_n_count,
) = lowering_config.subgroup_count_mn

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)

return f"""
transform.named_sequence @{functionName}(%generic: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{
%mmt = transform.include @match_batch_mmt_i8_i8_i32 failures(propagate) (%generic) : (!transform.any_op) -> !transform.any_op
@@ -369,13 +321,8 @@
transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value
transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value
%config = transform.param.constant #iree_codegen.compilation_info<
lowering_config = {configuration.lowering_config}>,
translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = {intrinsic},
subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>
{extra_config}}}>
lowering_config = {configuration.lowering_config},
translation_info = {configuration.translation_info}
> -> !transform.any_param
transform.yield %generic, %config : !transform.any_op, !transform.any_param
}}
@@ -421,16 +368,6 @@ def get_transform_function_batch_matmul(
input1 = f"tensor<{problem_size.rhs_type}>"
output = f"tensor<{problem_size.res_type}>"

lowering_config = configuration.lowering_config
intrinsic = lowering_config.mma_kind
(
subgroup_m_count,
subgroup_n_count,
) = lowering_config.subgroup_count_mn

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)

return f"""
transform.named_sequence @{functionName}(%batch_matmul: !transform.any_op {{transform.readonly}})
-> (!transform.any_op, !transform.any_param) {{
@@ -441,13 +378,8 @@
outs(%out : {output}) -> {output}
}} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
%config = transform.param.constant #iree_codegen.compilation_info<
lowering_config = {configuration.lowering_config}>,
translation_info = #iree_codegen.translation_info<LLVMGPUPadAndVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = {intrinsic},
subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>
{extra_config}}}>
lowering_config = {configuration.lowering_config},
translation_info = {configuration.translation_info}
> -> !transform.any_param
transform.yield %batch_matmul, %config : !transform.any_op, !transform.any_param
}}
102 changes: 68 additions & 34 deletions tuner/tuner/candidate_gen_test.py
@@ -14,6 +14,7 @@

from iree.compiler import ir # type: ignore
from iree.compiler.dialects import iree_gpu # type: ignore
from iree.compiler.dialects import iree_codegen # type: ignore

from . import candidate_gen
from . import common
@@ -56,14 +57,17 @@ def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None:
subgroup_m_count=16,
subgroup_n_count=16,
)
pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get(
iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute
)
pipeline_options = iree_gpu.PipelineOptionsAttr.get(prefetch_shared_memory=True)
config_dict = common.get_translation_info_config(pipeline_options, 8)
translation_info = iree_codegen.TranslationInfoAttr.get(
pipeline_attr, None, [16, 16, 1], 16, config_dict
)
config = common.Configuration(
subgroup_size=16,
workgroup_size=[16, 16, 1],
translation_info=translation_info,
lowering_config=lowering_config,
gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(
prefetch_shared_memory=True
),
waves_per_eu=8,
)

problem_size = common.ProblemSize(
@@ -118,16 +122,21 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None:
subgroup_m_count=1,
subgroup_n_count=4,
)
pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get(
iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute
)
pipeline_options = iree_gpu.PipelineOptionsAttr.get(
reorder_workgroups_strategy=iree_gpu.ReorderWorkgroupsStrategyAttr.get(
iree_gpu.ReorderWorkgroupsStrategy.Transpose
)
)
config_dict = common.get_translation_info_config(pipeline_options, 2)
translation_info = iree_codegen.TranslationInfoAttr.get(
pipeline_attr, None, [256, 1, 1], 64, config_dict
)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[256, 1, 1],
translation_info=translation_info,
lowering_config=lowering_config,
gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(
reorder_workgroups_strategy=iree_gpu.ReorderWorkgroupsStrategyAttr.get(
iree_gpu.ReorderWorkgroupsStrategy.Transpose
)
),
waves_per_eu=2,
)

problem_size = common.ProblemSize(
@@ -191,12 +200,17 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None:
subgroup_m_count=1,
subgroup_n_count=4,
)
pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get(
iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute
)
pipeline_options = iree_gpu.PipelineOptionsAttr.get()
config_dict = common.get_translation_info_config(pipeline_options, 2)
translation_info = iree_codegen.TranslationInfoAttr.get(
pipeline_attr, None, [256, 1, 1], 64, config_dict
)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[256, 1, 1],
translation_info=translation_info,
lowering_config=lowering_config,
gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(),
waves_per_eu=2,
)

tf_mlir = candidate_gen.ContractionTuner("mk", "nk", tile_dims).apply_params(
@@ -246,12 +260,17 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None:
subgroup_m_count=2,
subgroup_n_count=2,
)
pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get(
iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute
)
pipeline_options = iree_gpu.PipelineOptionsAttr.get()
config_dict = common.get_translation_info_config(pipeline_options, 2)
translation_info = iree_codegen.TranslationInfoAttr.get(
pipeline_attr, None, [128, 2, 1], 64, config_dict
)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[128, 2, 1],
translation_info=translation_info,
lowering_config=lowering_config,
gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(),
waves_per_eu=2,
)

tf_mlir = candidate_gen.BatchMatmulTuner("mk", "nk", tile_dims).apply_params(
@@ -304,12 +323,17 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None:
subgroup_m_count=2,
subgroup_n_count=2,
)
pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get(
iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute
)
pipeline_options = iree_gpu.PipelineOptionsAttr.get()
config_dict = common.get_translation_info_config(pipeline_options, 2)
translation_info = iree_codegen.TranslationInfoAttr.get(
pipeline_attr, None, [128, 2, 1], 64, config_dict
)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[128, 2, 1],
translation_info=translation_info,
lowering_config=lowering_config,
gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(),
waves_per_eu=2,
)

tf_mlir = candidate_gen.BatchMmtTuner().apply_params(
@@ -360,12 +384,17 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None:
subgroup_m_count=2,
subgroup_n_count=2,
)
pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get(
iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute
)
pipeline_options = iree_gpu.PipelineOptionsAttr.get()
config_dict = common.get_translation_info_config(pipeline_options, 4)
translation_info = iree_codegen.TranslationInfoAttr.get(
pipeline_attr, None, [128, 2, 1], 64, config_dict
)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[128, 2, 1],
translation_info=translation_info,
lowering_config=lowering_config,
gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(),
waves_per_eu=4,
)

tf_mlir = candidate_gen.BatchMmtTuner().apply_params(
@@ -440,12 +469,17 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None:
subgroup_m_count=2,
subgroup_n_count=2,
)
pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get(
iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute
)
pipeline_options = iree_gpu.PipelineOptionsAttr.get()
config_dict = common.get_translation_info_config(pipeline_options, 4)
translation_info = iree_codegen.TranslationInfoAttr.get(
pipeline_attr, None, [128, 2, 1], 64, config_dict
)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[128, 2, 1],
translation_info=translation_info,
lowering_config=lowering_config,
gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(),
waves_per_eu=4,
)

tf_mlir = candidate_gen.ContractionTuner(
