From c840438b62e3071b8e658de7343c8e461387de97 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Fri, 30 Jul 2021 17:31:00 -0500 Subject: [PATCH 01/57] Squashed 'src/composable_kernel/' content from commit f6edda611 git-subtree-dir: src/composable_kernel git-subtree-split: f6edda6119ebbb237dfa6270797b34f960d7b190 --- .clang-format | 90 + CMakeLists.txt | 42 + README.md | 177 + cmake/AddKernels.cmake | 40 + cmake/TargetFlags.cmake | 50 + .../include/gridwise_operation_wrapper.hpp | 14 + ...volution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp | 272 + ...lution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp | 275 + ...volution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp | 263 + ...volution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp | 179 + ...lution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp | 129 + ...lution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp | 129 + ...lution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp | 132 + ...volution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp | 132 + .../tensor_description/cluster_descriptor.hpp | 33 + .../dynamic_multi_index_transform.hpp | 1737 +++++ .../dynamic_multi_index_transform_helper.hpp | 104 + .../dynamic_tensor_descriptor.hpp | 596 ++ .../dynamic_tensor_descriptor_helper.hpp | 150 + .../tensor_description/tensor_adaptor.hpp | 466 ++ ...lockwise_dynamic_tensor_slice_transfer.hpp | 171 + ...kwise_dynamic_tensor_slice_transfer_v2.hpp | 158 + .../blockwise_gemm_dlops_v2r2.hpp | 396 ++ .../blockwise_gemm_dlops_v2r3.hpp | 410 ++ .../blockwise_gemm_dlops_v3.hpp | 190 + .../blockwise_gemm_xdlops.hpp | 528 ++ ...ridwise_dynamic_contraction_dlops_v1r2.hpp | 664 ++ .../gridwise_dynamic_gemm_dlops_v1r2.hpp | 679 ++ .../gridwise_dynamic_gemm_dlops_v1r3.hpp | 671 ++ .../gridwise_dynamic_gemm_dlops_v2.hpp | 463 ++ .../gridwise_dynamic_gemm_xdlops_v2r3.hpp | 823 +++ .../threadwise_contraction_dlops.hpp | 230 + .../threadwise_dynamic_tensor_slice_set.hpp | 59 + ...readwise_dynamic_tensor_slice_transfer.hpp | 1449 +++++ ...dwise_dynamic_tensor_slice_transfer_v2.hpp | 789 +++ .../threadwise_gemm_dlops_v3.hpp | 162 + .../include/tensor_operation/xdlops_gemm.hpp | 801 +++ .../utility/amd_buffer_addressing_v2.hpp | 654 ++ .../include/utility/amd_dlop.hpp | 188 + .../include/utility/amd_inline_asm.hpp | 353 + .../include/utility/amd_llvm_intrinsic.hpp | 11 + .../include/utility/amd_xdlops.hpp | 499 ++ composable_kernel/include/utility/array.hpp | 63 + .../include/utility/array_multi_index.hpp | 77 + .../include/utility/common_header.hpp | 45 + composable_kernel/include/utility/config.hpp | 142 + .../utility/container_element_picker.hpp | 155 + .../include/utility/container_helper.hpp | 403 ++ .../include/utility/data_type.hpp | 1017 +++ .../include/utility/data_type_enum.hpp | 20 + .../include/utility/data_type_helper.hpp | 76 + .../include/utility/dynamic_buffer.hpp | 208 + .../include/utility/functional.hpp | 116 + .../include/utility/functional2.hpp | 48 + .../include/utility/functional3.hpp | 142 + .../include/utility/functional4.hpp | 62 + .../include/utility/integral_constant.hpp | 17 + .../include/utility/magic_division.hpp | 155 + composable_kernel/include/utility/math.hpp | 225 + .../include/utility/multi_index.hpp | 12 + composable_kernel/include/utility/number.hpp | 44 + composable_kernel/include/utility/print.hpp | 70 + .../include/utility/sequence.hpp | 882 +++ .../include/utility/sequence_helper.hpp | 36 + .../include/utility/static_buffer.hpp | 35 + .../utility/statically_indexed_array.hpp | 40 + .../statically_indexed_array_multi_index.hpp | 108 + .../include/utility/synchronization.hpp | 21 + 
composable_kernel/include/utility/tuple.hpp | 167 + .../include/utility/tuple_helper.hpp | 80 + composable_kernel/include/utility/type.hpp | 60 + composable_kernel/include/utility/utility.hpp | 14 + ...mplicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp | 374 ++ ...plicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp | 362 ++ ...plicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp | 362 ++ ...mplicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp | 392 ++ external/half/include/half.hpp | 5671 +++++++++++++++++ external/rocm/include/bfloat16_dev.hpp | 125 + host/CMakeLists.txt | 4 + host/driver_offline/CMakeLists.txt | 21 + .../conv_bwd_driver_offline.cpp | 357 ++ .../conv_fwd_driver_offline.cpp | 480 ++ ...plicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp | 341 + ...icit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp | 317 + ...mplicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp | 210 + ...plicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp | 283 + ...licit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp | 284 + ...icit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp | 206 + ...icit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp | 240 + ...icit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp | 305 + ...icit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp | 365 ++ ...mplicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp | 192 + ...mplicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp | 244 + .../driver_dynamic_contraction_dlops_v1r2.hpp | 290 + ...mplicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp | 352 + ..._gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp | 367 ++ .../driver_dynamic_gemm_dlops_v1r2.hpp | 415 ++ .../driver_dynamic_gemm_dlops_v1r3.hpp | 411 ++ .../driver_dynamic_gemm_xdlops_v2r3.hpp | 196 + host/driver_online/CMakeLists.txt | 21 + host/driver_online/conv_fwd_driver_online.cpp | 453 ++ ...nv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw.hpp | 673 ++ ..._tunable_fwd_v4r4_dlops_nchw_kcyx_nkhw.hpp | 51 + ...tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp | 73 + ...tunable_fwd_v4r4_xdlops_nhwc_kyxc_nhwk.hpp | 73 + .../convolution_problem_descriptor.hpp | 79 + ...mplicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp | 395 ++ ...plicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp | 386 ++ ...plicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp | 389 ++ ...mplicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp | 182 + .../include/online_driver_common.hpp | 44 + host/host_tensor/CMakeLists.txt | 19 + host/host_tensor/include/conv_common.hpp | 86 + host/host_tensor/include/device.hpp | 86 + host/host_tensor/include/device_tensor.hpp | 9 + host/host_tensor/include/host_conv.hpp | 326 + .../include/host_conv_bwd_data.hpp | 143 + host/host_tensor/include/host_tensor.hpp | 322 + .../include/host_tensor_generator.hpp | 60 + host/host_tensor/src/device.cpp | 67 + host/host_tensor/src/host_tensor.cpp | 48 + host/online_compilation/CMakeLists.txt | 168 + .../addkernels/CMakeLists.txt | 30 + .../addkernels/addkernels.cpp | 264 + .../addkernels/include_inliner.cpp | 213 + .../addkernels/include_inliner.hpp | 142 + .../addkernels/source_file_desc.hpp | 45 + .../hip_utility/binary_cache.cpp | 112 + .../hip_utility/exec_utils.cpp | 93 + .../hip_utility/handlehip.cpp | 285 + .../hip_utility/hip_build_utils.cpp | 346 + .../hip_utility/hipoc_kernel.cpp | 84 + .../hip_utility/hipoc_program.cpp | 139 + .../hip_utility/kernel_build_params.cpp | 66 + .../hip_utility/kernel_cache.cpp | 154 + .../online_compilation/hip_utility/logger.cpp | 43 + host/online_compilation/hip_utility/md5.cpp | 319 + .../hip_utility/target_properties.cpp | 119 + .../hip_utility/tmp_dir.cpp | 66 + .../include/binary_cache.hpp | 52 + host/online_compilation/include/config.h.in | 47 + host/online_compilation/include/env.hpp | 123 + 
.../online_compilation/include/exec_utils.hpp | 42 + host/online_compilation/include/handle.hpp | 145 + host/online_compilation/include/hipCheck.hpp | 22 + .../include/hip_build_utils.hpp | 97 + .../include/hipoc_kernel.hpp | 174 + .../include/hipoc_program.hpp | 64 + .../include/hipoc_program_impl.hpp | 61 + host/online_compilation/include/kernel.hpp | 45 + .../include/kernel_build_params.hpp | 137 + .../include/kernel_cache.hpp | 97 + host/online_compilation/include/logger.hpp | 23 + .../online_compilation/include/manage_ptr.hpp | 76 + host/online_compilation/include/md5.hpp | 12 + .../include/op_kernel_args.hpp | 35 + .../include/simple_hash.hpp | 44 + .../include/stringutils.hpp | 133 + .../include/target_properties.hpp | 56 + host/online_compilation/include/tmp_dir.hpp | 26 + .../online_compilation/include/write_file.hpp | 30 + host/online_compilation/kernel.cpp.in | 70 + .../online_compilation/kernel_includes.cpp.in | 80 + host/online_compilation/kernels_batch.cpp.in | 1 + script/cmake-rocm.sh | 42 + script/count_vgpr.sh | 259 + script/docker-rocm4.1.sh | 14 + script/hipclang_opt.sh | 25 + script/run.sh | 47 + 169 files changed, 41816 insertions(+) create mode 100644 .clang-format create mode 100644 CMakeLists.txt create mode 100644 README.md create mode 100644 cmake/AddKernels.cmake create mode 100644 cmake/TargetFlags.cmake create mode 100644 composable_kernel/include/gridwise_operation_wrapper.hpp create mode 100644 composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp create mode 100644 composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp create mode 100644 composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp create mode 100644 composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp create mode 100644 composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp create mode 100644 composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp create mode 100644 composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp create mode 100644 composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp create mode 100644 composable_kernel/include/tensor_description/cluster_descriptor.hpp create mode 100644 composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp create mode 100644 composable_kernel/include/tensor_description/dynamic_multi_index_transform_helper.hpp create mode 100644 composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp create mode 100644 composable_kernel/include/tensor_description/dynamic_tensor_descriptor_helper.hpp create mode 100644 composable_kernel/include/tensor_description/tensor_adaptor.hpp create mode 100644 composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp create mode 100644 composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer_v2.hpp create mode 100644 composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r2.hpp create mode 100644 composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r3.hpp create mode 100644 composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp create mode 100644 
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp create mode 100644 composable_kernel/include/tensor_operation/gridwise_dynamic_contraction_dlops_v1r2.hpp create mode 100644 composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v1r2.hpp create mode 100644 composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v1r3.hpp create mode 100644 composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v2.hpp create mode 100644 composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops_v2r3.hpp create mode 100644 composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp create mode 100644 composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_set.hpp create mode 100644 composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp create mode 100644 composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer_v2.hpp create mode 100644 composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp create mode 100644 composable_kernel/include/tensor_operation/xdlops_gemm.hpp create mode 100644 composable_kernel/include/utility/amd_buffer_addressing_v2.hpp create mode 100644 composable_kernel/include/utility/amd_dlop.hpp create mode 100644 composable_kernel/include/utility/amd_inline_asm.hpp create mode 100644 composable_kernel/include/utility/amd_llvm_intrinsic.hpp create mode 100644 composable_kernel/include/utility/amd_xdlops.hpp create mode 100644 composable_kernel/include/utility/array.hpp create mode 100644 composable_kernel/include/utility/array_multi_index.hpp create mode 100644 composable_kernel/include/utility/common_header.hpp create mode 100644 composable_kernel/include/utility/config.hpp create mode 100644 composable_kernel/include/utility/container_element_picker.hpp create mode 100644 composable_kernel/include/utility/container_helper.hpp create mode 100644 composable_kernel/include/utility/data_type.hpp create mode 100644 composable_kernel/include/utility/data_type_enum.hpp create mode 100644 composable_kernel/include/utility/data_type_helper.hpp create mode 100644 composable_kernel/include/utility/dynamic_buffer.hpp create mode 100644 composable_kernel/include/utility/functional.hpp create mode 100644 composable_kernel/include/utility/functional2.hpp create mode 100644 composable_kernel/include/utility/functional3.hpp create mode 100644 composable_kernel/include/utility/functional4.hpp create mode 100644 composable_kernel/include/utility/integral_constant.hpp create mode 100644 composable_kernel/include/utility/magic_division.hpp create mode 100644 composable_kernel/include/utility/math.hpp create mode 100644 composable_kernel/include/utility/multi_index.hpp create mode 100644 composable_kernel/include/utility/number.hpp create mode 100644 composable_kernel/include/utility/print.hpp create mode 100644 composable_kernel/include/utility/sequence.hpp create mode 100644 composable_kernel/include/utility/sequence_helper.hpp create mode 100644 composable_kernel/include/utility/static_buffer.hpp create mode 100644 composable_kernel/include/utility/statically_indexed_array.hpp create mode 100644 composable_kernel/include/utility/statically_indexed_array_multi_index.hpp create mode 100644 composable_kernel/include/utility/synchronization.hpp create mode 100644 composable_kernel/include/utility/tuple.hpp create mode 100644 composable_kernel/include/utility/tuple_helper.hpp create mode 100644 
composable_kernel/include/utility/type.hpp create mode 100644 composable_kernel/include/utility/utility.hpp create mode 100644 composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp create mode 100644 composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp create mode 100644 composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp create mode 100644 composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp create mode 100644 external/half/include/half.hpp create mode 100644 external/rocm/include/bfloat16_dev.hpp create mode 100644 host/CMakeLists.txt create mode 100644 host/driver_offline/CMakeLists.txt create mode 100644 host/driver_offline/conv_bwd_driver_offline.cpp create mode 100644 host/driver_offline/conv_fwd_driver_offline.cpp create mode 100644 host/driver_offline/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp create mode 100644 host/driver_offline/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp create mode 100644 host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp create mode 100644 host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp create mode 100644 host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp create mode 100644 host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp create mode 100644 host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp create mode 100644 host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp create mode 100644 host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp create mode 100644 host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp create mode 100644 host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp create mode 100644 host/driver_offline/include/driver_dynamic_contraction_dlops_v1r2.hpp create mode 100644 host/driver_offline/include/driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp create mode 100644 host/driver_offline/include/driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp create mode 100644 host/driver_offline/include/driver_dynamic_gemm_dlops_v1r2.hpp create mode 100644 host/driver_offline/include/driver_dynamic_gemm_dlops_v1r3.hpp create mode 100644 host/driver_offline/include/driver_dynamic_gemm_xdlops_v2r3.hpp create mode 100644 host/driver_online/CMakeLists.txt create mode 100644 host/driver_online/conv_fwd_driver_online.cpp create mode 100644 host/driver_online/include/conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw.hpp create mode 100644 host/driver_online/include/conv_tunable_fwd_v4r4_dlops_nchw_kcyx_nkhw.hpp create mode 100644 host/driver_online/include/conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp create mode 100644 host/driver_online/include/conv_tunable_fwd_v4r4_xdlops_nhwc_kyxc_nhwk.hpp create mode 100644 host/driver_online/include/convolution_problem_descriptor.hpp create mode 
100644 host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp create mode 100644 host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp create mode 100644 host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp create mode 100644 host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp create mode 100644 host/driver_online/include/online_driver_common.hpp create mode 100644 host/host_tensor/CMakeLists.txt create mode 100644 host/host_tensor/include/conv_common.hpp create mode 100644 host/host_tensor/include/device.hpp create mode 100644 host/host_tensor/include/device_tensor.hpp create mode 100644 host/host_tensor/include/host_conv.hpp create mode 100644 host/host_tensor/include/host_conv_bwd_data.hpp create mode 100644 host/host_tensor/include/host_tensor.hpp create mode 100644 host/host_tensor/include/host_tensor_generator.hpp create mode 100644 host/host_tensor/src/device.cpp create mode 100644 host/host_tensor/src/host_tensor.cpp create mode 100644 host/online_compilation/CMakeLists.txt create mode 100644 host/online_compilation/addkernels/CMakeLists.txt create mode 100644 host/online_compilation/addkernels/addkernels.cpp create mode 100644 host/online_compilation/addkernels/include_inliner.cpp create mode 100644 host/online_compilation/addkernels/include_inliner.hpp create mode 100644 host/online_compilation/addkernels/source_file_desc.hpp create mode 100644 host/online_compilation/hip_utility/binary_cache.cpp create mode 100644 host/online_compilation/hip_utility/exec_utils.cpp create mode 100644 host/online_compilation/hip_utility/handlehip.cpp create mode 100644 host/online_compilation/hip_utility/hip_build_utils.cpp create mode 100644 host/online_compilation/hip_utility/hipoc_kernel.cpp create mode 100644 host/online_compilation/hip_utility/hipoc_program.cpp create mode 100644 host/online_compilation/hip_utility/kernel_build_params.cpp create mode 100644 host/online_compilation/hip_utility/kernel_cache.cpp create mode 100644 host/online_compilation/hip_utility/logger.cpp create mode 100644 host/online_compilation/hip_utility/md5.cpp create mode 100644 host/online_compilation/hip_utility/target_properties.cpp create mode 100644 host/online_compilation/hip_utility/tmp_dir.cpp create mode 100644 host/online_compilation/include/binary_cache.hpp create mode 100644 host/online_compilation/include/config.h.in create mode 100644 host/online_compilation/include/env.hpp create mode 100644 host/online_compilation/include/exec_utils.hpp create mode 100644 host/online_compilation/include/handle.hpp create mode 100644 host/online_compilation/include/hipCheck.hpp create mode 100644 host/online_compilation/include/hip_build_utils.hpp create mode 100644 host/online_compilation/include/hipoc_kernel.hpp create mode 100644 host/online_compilation/include/hipoc_program.hpp create mode 100644 host/online_compilation/include/hipoc_program_impl.hpp create mode 100644 host/online_compilation/include/kernel.hpp create mode 100644 host/online_compilation/include/kernel_build_params.hpp create mode 100644 host/online_compilation/include/kernel_cache.hpp create mode 100644 host/online_compilation/include/logger.hpp create mode 100644 host/online_compilation/include/manage_ptr.hpp create mode 100644 host/online_compilation/include/md5.hpp create mode 100644 
host/online_compilation/include/op_kernel_args.hpp create mode 100644 host/online_compilation/include/simple_hash.hpp create mode 100644 host/online_compilation/include/stringutils.hpp create mode 100644 host/online_compilation/include/target_properties.hpp create mode 100644 host/online_compilation/include/tmp_dir.hpp create mode 100644 host/online_compilation/include/write_file.hpp create mode 100644 host/online_compilation/kernel.cpp.in create mode 100644 host/online_compilation/kernel_includes.cpp.in create mode 100644 host/online_compilation/kernels_batch.cpp.in create mode 100755 script/cmake-rocm.sh create mode 100755 script/count_vgpr.sh create mode 100755 script/docker-rocm4.1.sh create mode 100755 script/hipclang_opt.sh create mode 100755 script/run.sh diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000000..22f2674966 --- /dev/null +++ b/.clang-format @@ -0,0 +1,90 @@ +--- +Language: Cpp +AccessModifierOffset: 0 +AlignAfterOpenBracket: Align +AlignConsecutiveAssignments: true +AlignConsecutiveDeclarations: false +AlignEscapedNewlinesLeft: true +AlignOperands: true +AlignTrailingComments: true +AllowAllParametersOfDeclarationOnNextLine: true +AllowShortBlocksOnASingleLine: true +AllowShortCaseLabelsOnASingleLine: true +AllowShortFunctionsOnASingleLine: All +AllowShortIfStatementsOnASingleLine: false +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterDefinitionReturnType: None +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: false +AlwaysBreakTemplateDeclarations: true +BinPackArguments: false +BinPackParameters: false +BraceWrapping: + AfterClass: true + AfterControlStatement: true + AfterEnum: true + AfterFunction: true + AfterNamespace: false + AfterObjCDeclaration: true + AfterStruct: true + AfterUnion: true + BeforeCatch: true + BeforeElse: true + IndentBraces: false +BreakBeforeBinaryOperators: None +BreakBeforeBraces: Custom +BreakBeforeTernaryOperators: true +BreakConstructorInitializersBeforeComma: false +ColumnLimit: 100 +CommentPragmas: '^ IWYU pragma:' +ConstructorInitializerAllOnOneLineOrOnePerLine: true +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: true +DerivePointerAlignment: false +DisableFormat: false +ExperimentalAutoDetectBinPacking: false +ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ] +IncludeCategories: + - Regex: '^"(llvm|llvm-c|clang|clang-c)/' + Priority: 2 + - Regex: '^(<|"(gtest|isl|json)/)' + Priority: 3 + - Regex: '.*' + Priority: 1 +IndentCaseLabels: false +IndentWidth: 4 +IndentWrappedFunctionNames: false +KeepEmptyLinesAtTheStartOfBlocks: true +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBlockIndentWidth: 2 +ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: true +PenaltyBreakBeforeFirstCallParameter: 19 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 60 +PointerAlignment: Left +ReflowComments: true +SortIncludes: false +SpaceAfterCStyleCast: false +# SpaceAfterTemplateKeyword: true +SpaceBeforeAssignmentOperators: true +SpaceBeforeParens: Never +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 1 +SpacesInAngles: false +SpacesInContainerLiterals: true +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +Standard: Cpp11 +TabWidth: 8 +UseTab: Never +... 
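The `.clang-format` above pins the project style: Allman-style braces for classes, structs, and functions but not namespaces, 4-space indentation, a 100-column limit, left-aligned pointers, aligned consecutive assignments, and template declarations always broken onto their own line. As a quick illustration only (hypothetical example code, not part of the patch), a small C++ fragment laid out the way these settings format it:
```
// Illustration of the .clang-format settings above (hypothetical example code).
namespace ck {                // AfterNamespace: false, NamespaceIndentation: None
template <typename T>         // AlwaysBreakTemplateDeclarations: true
struct ExampleBuffer
{                             // BraceWrapping.AfterStruct: true
    T* data   = nullptr;      // PointerAlignment: Left
    int count = 0;            // AlignConsecutiveAssignments: true

    void Reset()
    {                         // BraceWrapping.AfterFunction: true
        data  = nullptr;
        count = 0;
    }
};
} // namespace ck
```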
+ diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000000..0cf342bb45 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,42 @@ +cmake_minimum_required(VERSION 2.8.3) +project(modular_convolution) + +list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") + +include(TargetFlags) +include(AddKernels) + +## C++ +enable_language(CXX) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) +message("CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") + +## OpenMP +if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") + # workaround: hipcc in ROCm 3.5 cannot find OpenMP + set(OpenMP_CXX "${CMAKE_CXX_COMPILER}") + set(OpenMP_CXX_FLAGS "-fopenmp=libomp -Wno-unused-command-line-argument") + set(OpenMP_CXX_LIB_NAMES "libomp" "libgomp" "libiomp5") + set(OpenMP_libomp_LIBRARY ${OpenMP_CXX_LIB_NAMES}) + set(OpenMP_libgomp_LIBRARY ${OpenMP_CXX_LIB_NAMES}) + set(OpenMP_libiomp5_LIBRARY ${OpenMP_CXX_LIB_NAMES}) +else() + find_package(OpenMP REQUIRED) +endif() + +message("OpenMP_CXX_LIB_NAMES: ${OpenMP_CXX_LIB_NAMES}") +message("OpenMP_gomp_LIBRARY: ${OpenMP_gomp_LIBRARY}") +message("OpenMP_pthread_LIBRARY: ${OpenMP_pthread_LIBRARY}") +message("OpenMP_CXX_FLAGS: ${OpenMP_CXX_FLAGS}") + +set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") +link_libraries(${OpenMP_gomp_LIBRARY}) +link_libraries(${OpenMP_pthread_LIBRARY}) + +## HIP +find_package(HIP REQUIRED) +message(STATUS "Build with HIP ${hip_VERSION}") + +add_subdirectory(host) diff --git a/README.md b/README.md new file mode 100644 index 0000000000..6e6019601a --- /dev/null +++ b/README.md @@ -0,0 +1,177 @@ +# How to build and run + +# Docker +``` +docker run \ +-it \ +--rm \ +--privileged \ +--group-add sudo \ +-w /root/workspace \ +-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ +rocm/tensorflow:rocm4.2-tf2.4-dev \ +/bin/bash +``` + +# Install Boost for online compilation +https://www.boost.org/doc/libs/1_66_0/more/getting_started/unix-variants.html#easy-build-and-install + + +# Build +Add the Boost library path +``` + export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH +``` + +``` +mkdir build && cd build +``` + +CMake command. The GPU target ID must be specified; the example below uses gfx908 +``` +cmake \ +-D CMAKE_BUILD_TYPE=Release \ +-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 -O3 --amdgpu-target=gfx908 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \ +-D HIP_ONLINE_COMPILER_FLAGS="-DCK_AMD_GPU_GFX908" \ +-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ +-D CMAKE_PREFIX_PATH=/opt/rocm \ +-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ +.. +``` + +Build the drivers: \ +``conv_fwd_driver_offline`` is the (offline compilation) driver for forward convolution, \ +``conv_bwd_driver_offline`` is the (offline compilation) driver for backward-data convolution \ +``conv_fwd_driver_online`` is the (online compilation) driver for forward convolution +``` + make -j conv_fwd_driver_offline + make -j conv_bwd_driver_offline + make -j conv_fwd_driver_online +``` + +# Run +* layout: 0 = NCHW; 1 = NHWC +* algo: algorithm +* verify: 0 = no verification; 1 = do verification +* init: 0 ~ 5.
initialization method +* log: 0 = no log; 1 = do log +* repeat: number of times the kernel is launched +``` +######################################################## layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads + ./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 + ./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1 + ./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 + ./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1 + ./host/driver_offline/conv_bwd_driver_offline 1 5 0 0 0 1 256 256 1024 3 3 14 14 1 1 1 1 1 1 1 1 +``` + +# Result +Forward convolution, FP16, NCHW +``` +./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 + +layout: 0 +in: dim 4, lengths {128, 192, 71, 71}, strides {967872, 5041, 71, 1} +wei: dim 4, lengths {256, 192, 3, 3}, strides {1728, 9, 3, 1} +out: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1296, 36, 1} +InLeftPads size 2, {1, 1, } +InRightPads size 2, {1, 1, } +ConvStrides size 2, {2, 2, } +ConvDilations size 2, {1, 1, } +device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw +a_k0_m_k1_grid_desc{216, 256, 8} +b_k0_n_k1_grid_desc{216, 165888, 8} +c_m_n_grid_desc{ 256, 165888} +launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... +Average time : 1.4155 ms, 103.686 TFlop/s +``` + +Forward convolution, FP16, NCHW +``` + ./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1 + + layout: 0 +in: dim 4, lengths {256, 256, 14, 14}, strides {50176, 196, 14, 1} +wei: dim 4, lengths {1024, 256, 3, 3}, strides {2304, 9, 3, 1} +out: dim 4, lengths {256, 1024, 14, 14}, strides {200704, 196, 14, 1} +InLeftPads size 2, {1, 1, } +InRightPads size 2, {1, 1, } +ConvStrides size 2, {1, 1, } +ConvDilations size 2, {1, 1, } +device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw +a_k0_m_k1_grid_desc{288, 1024, 8} +b_k0_n_k1_grid_desc{288, 50176, 8} +c_m_n_grid_desc{ 1024, 50176} +launch_and_time_kernel: grid_dim {1568, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... +Average time : 2.21357 ms, 106.959 TFlop/s + ``` + + Forward convolution, FP16, NHWC + ``` + ./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 + + layout: 1 +in: dim 4, lengths {128, 71, 71, 192}, strides {967872, 13632, 192, 1} +wei: dim 4, lengths {256, 3, 3, 192}, strides {1728, 576, 192, 1} +out: dim 4, lengths {128, 36, 36, 256}, strides {331776, 9216, 256, 1} +InLeftPads size 2, {1, 1, } +InRightPads size 2, {1, 1, } +ConvStrides size 2, {2, 2, } +ConvDilations size 2, {1, 1, } +device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk +a_k0_m_k1_grid_desc{216, 165888, 8} +b_k0_n_k1_grid_desc{216, 256, 8} +c_m_n_grid_desc{ 165888, 256} +launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times...
+Average time : 1.12014 ms, 131.025 TFlop/s + ``` + + Forward convolution, FP16, NHWC + ``` + ./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1 + + layout: 1 +in: dim 4, lengths {256, 14, 14, 256}, strides {50176, 3584, 256, 1} +wei: dim 4, lengths {1024, 3, 3, 256}, strides {2304, 768, 256, 1} +out: dim 4, lengths {256, 14, 14, 1024}, strides {200704, 14336, 1024, 1} +InLeftPads size 2, {1, 1, } +InRightPads size 2, {1, 1, } +ConvStrides size 2, {1, 1, } +ConvDilations size 2, {1, 1, } +device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk +a_k0_m_k1_grid_desc{288, 50176, 8} +b_k0_n_k1_grid_desc{288, 1024, 8} +c_m_n_grid_desc{ 50176, 1024} +launch_and_time_kernel: grid_dim {1568, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... +Average time : 1.86877 ms, 126.693 TFlop/s + ``` + + Backward data convolution, FP16, NHWC + ``` + ./host/driver_offline/conv_bwd_driver_offline 1 1 0 3 0 1 256 256 1024 3 3 14 14 1 1 1 1 1 1 1 1 + + layout: 1 +in: dim 4, lengths {256, 14, 14, 1024}, strides {200704, 14336, 1024, 1} +wei: dim 4, lengths {256, 3, 3, 1024}, strides {9216, 3072, 1024, 1} +out: dim 4, lengths {256, 14, 14, 256}, strides {50176, 3584, 256, 1} +InLeftPads size 2, {1, 1, } +InRightPads size 2, {1, 1, } +ConvStrides size 2, {1, 1, } +ConvDilations size 2, {1, 1, } +device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk +a_k0_m_k1_grid_desc{288, 50176, 8} +b_k0_n_k1_grid_desc{288, 1024, 8} +c_m_n_grid_desc{ 50176, 1024} +launch_and_time_kernel: grid_dim {1568, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... +Average time : 2.22461 ms, 106.428 TFlop/s +``` diff --git a/cmake/AddKernels.cmake b/cmake/AddKernels.cmake new file mode 100644 index 0000000000..429ecc47a9 --- /dev/null +++ b/cmake/AddKernels.cmake @@ -0,0 +1,40 @@ + +function(add_kernels SRC_DIR KERNEL_FILES) + set(INIT_KERNELS_LIST) + set(KERNELS_DECLS) + foreach(KERNEL_FILE ${KERNEL_FILES}) + if("${CMAKE_VERSION}" VERSION_LESS 3.0) + configure_file(${KERNEL_FILE} ${KERNEL_FILE}.delete) + else() + set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS ${KERNEL_FILE}) + endif() + get_filename_component(BASE_NAME ${KERNEL_FILE} NAME_WE) + string(TOUPPER "${BASE_NAME}" KEY_NAME) + string(MAKE_C_IDENTIFIER "${KEY_NAME}" VAR_NAME) + string(APPEND KERNELS_DECLS "extern const size_t APP_KERNEL_${VAR_NAME}_SIZE;\n") + string(APPEND KERNELS_DECLS "extern const unsigned char APP_KERNEL_${VAR_NAME}[];\n") + list(APPEND INIT_KERNELS_LIST " { \"${KEY_NAME}\", std::string(reinterpret_cast(APP_KERNEL_${VAR_NAME}), APP_KERNEL_${VAR_NAME}_SIZE) }") + endforeach() + string(REPLACE ";" ",\n" INIT_KERNELS "${INIT_KERNELS_LIST}") + configure_file(${SRC_DIR}/kernel.cpp.in ${PROJECT_BINARY_DIR}/kernel.cpp) +endfunction() + +function(add_kernel_includes SRC_DIR KERNEL_FILES) + set(INIT_KERNELS_LIST) + foreach(KERNEL_FILE ${KERNEL_FILES}) + if("${CMAKE_VERSION}" VERSION_LESS 3.0) + configure_file(${KERNEL_FILE} ${KERNEL_FILE}.delete) + else() + set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS ${KERNEL_FILE}) + endif() + get_filename_component(BASE_NAME ${KERNEL_FILE} NAME_WE) + get_filename_component(FILE_NAME ${KERNEL_FILE} NAME) + string(TOUPPER "${BASE_NAME}" KEY_NAME) + string(MAKE_C_IDENTIFIER "${KEY_NAME}" VAR_NAME) + list(APPEND INIT_KERNELS_LIST " { \"${FILE_NAME}\", std::string(reinterpret_cast(${VAR_NAME}), ${VAR_NAME}_SIZE) }") + endforeach() + string(REPLACE ";" ",\n" INIT_KERNELS 
"${INIT_KERNELS_LIST}") + configure_file(${SRC_DIR}/kernel_includes.cpp.in ${PROJECT_BINARY_DIR}/kernel_includes.cpp) +endfunction() + + diff --git a/cmake/TargetFlags.cmake b/cmake/TargetFlags.cmake new file mode 100644 index 0000000000..4f83fb5d39 --- /dev/null +++ b/cmake/TargetFlags.cmake @@ -0,0 +1,50 @@ + +function(get_target_property2 VAR TARGET PROPERTY) + get_target_property(_pflags ${TARGET} ${PROPERTY}) + if(_pflags) + set(${VAR} ${_pflags} PARENT_SCOPE) + else() + set(${VAR} "" PARENT_SCOPE) + endif() +endfunction() + + +macro(append_flags FLAGS TARGET PROPERTY PREFIX) + get_target_property2(_pflags ${TARGET} ${PROPERTY}) + foreach(FLAG ${_pflags}) + if(TARGET ${FLAG}) + target_flags(_pflags2 ${FLAG}) + string(APPEND ${FLAGS} " ${_pflags2}") + else() + string(APPEND ${FLAGS} " ${PREFIX}${FLAG}") + endif() + endforeach() +endmacro() + +macro(append_link_flags FLAGS TARGET PROPERTY) + get_target_property2(_pflags ${TARGET} ${PROPERTY}) + foreach(FLAG ${_pflags}) + if(TARGET ${FLAG}) + target_flags(_pflags2 ${FLAG}) + string(APPEND ${FLAGS} " ${_pflags2}") + elseif(FLAG MATCHES "^-.*") + string(APPEND ${FLAGS} " ${FLAG}") + elseif(EXISTS ${FLAG}) + string(APPEND ${FLAGS} " ${FLAG}") + else() + string(APPEND ${FLAGS} " -l${FLAG}") + endif() + endforeach() +endmacro() + +function(target_flags FLAGS TARGET) + set(_flags) + append_flags(_flags ${TARGET} "INTERFACE_COMPILE_OPTIONS" "") + append_flags(_flags ${TARGET} "INTERFACE_COMPILE_DEFINITIONS" "-D") + append_flags(_flags ${TARGET} "INTERFACE_INCLUDE_DIRECTORIES" "-isystem ") + append_flags(_flags ${TARGET} "INTERFACE_LINK_DIRECTORIES" "-L ") + append_flags(_flags ${TARGET} "INTERFACE_LINK_OPTIONS" "") + append_link_flags(_flags ${TARGET} "INTERFACE_LINK_LIBRARIES" "") + # message("_flags: ${_flags}") + set(${FLAGS} ${_flags} PARENT_SCOPE) +endfunction() diff --git a/composable_kernel/include/gridwise_operation_wrapper.hpp b/composable_kernel/include/gridwise_operation_wrapper.hpp new file mode 100644 index 0000000000..0a1e07ec57 --- /dev/null +++ b/composable_kernel/include/gridwise_operation_wrapper.hpp @@ -0,0 +1,14 @@ +#ifndef CK_GRIDWISE_OPERATION_KERNEL_WRAPPER +#define CK_GRIDWISE_OPERATION_KERNEL_WRAPPER + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + run_gridwise_operation(Xs... 
xs) +{ + GridwiseOp{}.Run(xs...); +} + +#endif diff --git a/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..5c582dea46 --- /dev/null +++ b/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,272 @@ +#ifndef CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" + +namespace ck { + +// Number of GEMMs = YTilda * XTilda +// GemmM = C +// GemmN = N * HTildaSlice * WTildaSlice +// GemmK = K * YDotSlice * XDotSlice +template +__host__ __device__ constexpr auto +transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( + const DynamicTensorDescriptor& wei_k_y_x_c_grid_desc, + const DynamicTensorDescriptor& out_n_ho_wo_k_grid_desc, + const DynamicTensorDescriptor& in_n_hi_wi_c_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number, + Number, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + constexpr auto IYTilda = Number{}; + constexpr auto IXTilda = Number{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilda = ConvStrideH / GcdStrideDilationH; + const auto XTilda = ConvStrideW / GcdStrideDilationW; + + const auto YDot = math::integer_divide_ceil(Y, YTilda); + const auto XDot = math::integer_divide_ceil(X, XTilda); + + const auto HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilda and WTilda that contribute to non-padding area of input tensor + const auto IHTildaSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilda - I1)), ConvStrideH); + const auto IWTildaSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilda - I1)), ConvStrideW); + + const auto IHTildaSliceEnd = + math::min(HTilda, math::integer_divide_ceil(InLeftPadH + Hi 
- I1, ConvStrideH) + I1); + const auto IWTildaSliceEnd = + math::min(WTilda, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto HTildaSlice = IHTildaSliceEnd - IHTildaSliceBegin; + const auto WTildaSlice = IWTildaSliceEnd - IWTildaSliceBegin; + + // GemmK is different for each GEMM + const auto YDotSlice = math::integer_divide_ceil(Y - IYTilda, YTilda); + const auto XDotSlice = math::integer_divide_ceil(X - IXTilda, XTilda); + + const auto K1 = GemmK1; + const auto K0 = K / K1; + + // weight tensor + const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_dynamic_tensor_descriptor( + wei_k_y_x_c_grid_desc, + make_tuple(make_pass_through_transform(K), + make_embed_transform(make_tuple(YDot, YTilda), + make_tuple(ConvStrideH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, XTilda), + make_tuple(ConvStrideW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = + transform_dynamic_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(IYTilda), + make_freeze_transform(IXTilda), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0, 1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<4>{})); + +#if 1 + const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + wei_k0_k1_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); +#else + const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + wei_k0_k1_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<0, 2, 3>{}, Sequence<4>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); +#endif + + // output tensor + // this add padding check + const auto out_n_hop_wop_k_grid_desc = transform_dynamic_tensor_descriptor( + out_n_ho_wo_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_dynamic_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(YDot, HTilda), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, WTilda), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, 
Sequence<5>{})); + + const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc = + transform_dynamic_tensor_descriptor( + out_n_ydot_htilda_xdot_wtilda_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6>{})); + +#if 1 + const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), + make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); +#else + const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc, + make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), + make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); +#endif + + // input tensor + const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_dynamic_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(YTilda, HTilda), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilda, WTilda), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_dynamic_tensor_descriptor( + in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(IYTilda), + make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice), + make_freeze_transform(IXTilda), + make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<3>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor( + in_n_htildaslice_wtildaslice_c_grid_desc, + make_tuple(make_pass_through_transform(C), + make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice))), 
+ make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc, + out_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..377a1ac29b --- /dev/null +++ b/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,275 @@ +#ifndef CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" + +namespace ck { + +// A: out +// B: wei +// C: in +// Number of GEMMs = YTilda * XTilda +// GemmM = N * HTildaSlice * WTildaSlice +// GemmN = C +// GemmK = K * YDotSlice * XDotSlice +template +__host__ __device__ constexpr auto +transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( + const DynamicTensorDescriptor& out_n_ho_wo_k_grid_desc, + const DynamicTensorDescriptor& wei_k_y_x_c_grid_desc, + const DynamicTensorDescriptor& in_n_hi_wi_c_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number, + Number, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + constexpr auto IYTilda = Number{}; + constexpr auto IXTilda = Number{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilda = ConvStrideH / GcdStrideDilationH; + const auto XTilda = ConvStrideW / GcdStrideDilationW; + + const auto YDot = math::integer_divide_ceil(Y, YTilda); + const auto XDot = math::integer_divide_ceil(X, XTilda); + + const auto HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilda and WTilda that contribute to non-padding area of input tensor + const auto IHTildaSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilda - 
I1)), ConvStrideH); + const auto IWTildaSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilda - I1)), ConvStrideW); + + const auto IHTildaSliceEnd = + math::min(HTilda, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildaSliceEnd = + math::min(WTilda, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto HTildaSlice = IHTildaSliceEnd - IHTildaSliceBegin; + const auto WTildaSlice = IWTildaSliceEnd - IWTildaSliceBegin; + + // GemmK is different for each GEMM + const auto YDotSlice = math::integer_divide_ceil(Y - IYTilda, YTilda); + const auto XDotSlice = math::integer_divide_ceil(X - IXTilda, XTilda); + + const auto K1 = GemmK1; + const auto K0 = K / K1; + + // A: output tensor + // this add padding check + const auto out_n_hop_wop_k_grid_desc = transform_dynamic_tensor_descriptor( + out_n_ho_wo_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_dynamic_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(YDot, HTilda), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, WTilda), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc = + transform_dynamic_tensor_descriptor( + out_n_ydot_htilda_xdot_wtilda_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6>{})); + +#if 1 + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), + make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); +#else + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc, + make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), + make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); +#endif + + // B: weight tensor + const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = 
transform_dynamic_tensor_descriptor( + wei_k_y_x_c_grid_desc, + make_tuple(make_pass_through_transform(K), + make_embed_transform(make_tuple(YDot, YTilda), + make_tuple(ConvStrideH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, XTilda), + make_tuple(ConvStrideW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = + transform_dynamic_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(IYTilda), + make_freeze_transform(IXTilda), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0, 1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<4>{})); + +#if 1 + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + wei_k0_k1_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); +#else + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + wei_k0_k1_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<0, 2, 3>{}, Sequence<4>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); +#endif + + // C: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_dynamic_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(YTilda, HTilda), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilda, WTilda), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_dynamic_tensor_descriptor( + in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(IYTilda), + make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice), + make_freeze_transform(IXTilda), + make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + 
Sequence<>{}, + Sequence<2>{}, + Sequence<3>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor( + in_n_htildaslice_wtildaslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..404129365f --- /dev/null +++ b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp @@ -0,0 +1,263 @@ +#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP +#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" + +namespace ck { + +// GemmM = K +// GemmN = N * Ho * Wo +// GemmK = C * Y * X +template +__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad( + const DynamicTensorDescriptor& wei_k_c_y_x_global_desc, + const DynamicTensorDescriptor& in_n_c_hi_wi_global_desc, + const DynamicTensorDescriptor& out_n_k_ho_wo_global_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); + + const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); + + const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); + + const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); + const auto X = wei_k_c_y_x_global_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + // weight tensor + const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + // input tensor + const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor( + in_n_c_hi_wi_global_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + 
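+    // For orientation, a rough numeric sketch of the GEMM view built by this pad/embed/merge
+    // chain (sizes invented for illustration, consistent with the GemmM/GemmN/GemmK comment at
+    // the top of this file): N = 1, C = 8, Hi = Wi = 16, Y = X = 3, pad = 1, stride = dilation = 1
+    // gives Ho = Wo = 16, so with K = 32 the implicit GEMM is
+    //   GemmM = K = 32, GemmN = N * Ho * Wo = 256, GemmK = C * Y * X = 72,
+    // and each GemmK step addresses one (c, y, x) filter tap of the padded input directly,
+    // i.e. the im2col matrix is never materialized in memory.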
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor( + in_n_c_hip_wip_global_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + const auto in_gemmk_gemmn_global_desc = + transform_dynamic_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc, + make_tuple(make_merge_transform(make_tuple(C, Y, X)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // output tensor + const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)), + make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple( + wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc); +} + +template +__host__ __device__ constexpr auto +transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad( + const DynamicTensorDescriptor& wei_k_c_y_x_global_desc, + const DynamicTensorDescriptor& in_n_c_hi_wi_global_desc, + const DynamicTensorDescriptor& out_n_k_ho_wo_global_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); + + const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); + + const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); + + const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); + const auto X = wei_k_c_y_x_global_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + assert(InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && InRightPadW == 0); + + // weight tensor + const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + // input tensor + const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor( + in_n_c_hi_wi_global_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_embed_transform(make_tuple(Y, Ho), 
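+    // coefficients (next argument): the lower index is y * ConvDilationH + ho * ConvStrideH,
+    // matching DynamicEmbed's idx_low = sum_i coefficients[i] * idx_up[i]; e.g. with stride 2
+    // and dilation 1, (y, ho) = (1, 3) maps to input row 1 * 1 + 3 * 2 = 7 (illustrative numbers).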
make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + const auto in_gemmk_gemmn_global_desc = + transform_dynamic_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc, + make_tuple(make_merge_transform(make_tuple(C, Y, X)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // output tensor + const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)), + make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple( + wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc); +} + +template +__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_1x1( + const DynamicTensorDescriptor& wei_k_c_y_x_global_desc, + const DynamicTensorDescriptor& in_n_c_hi_wi_global_desc, + const DynamicTensorDescriptor& out_n_k_ho_wo_global_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); + + const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); + + const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); + + const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); + const auto X = wei_k_c_y_x_global_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + assert(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 && + ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && + InRightPadW == 0); + + // weight tensor + const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + // input tensor + const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor( + in_n_c_hi_wi_global_desc, + make_tuple(make_pass_through_transform(C), make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // output tensor + const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( + 
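+    // The descriptor below completes the 1x1 specialization: with Y = X = 1, unit stride and
+    // dilation, and zero padding (checked by the assert above), Ho == Hi and Wo == Wi, so the
+    // input is read directly as a (C, N * Ho * Wo) matrix and no embed/im2col-style step is
+    // needed. Illustrative numbers (not taken from this file): C = 64, K = 128, N = 2,
+    // Hi = Wi = 7 give a GemmM x GemmN x GemmK of 128 x 98 x 64.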
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)), + make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple( + wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..79051d9512 --- /dev/null +++ b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,179 @@ +#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" + +namespace ck { + +// GemmM = K +// GemmN = N * Ho * Wo +// GemmK = C * Y * X +template +__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad( + const DynamicTensorDescriptor& wei_k_y_x_c_grid_desc, + const DynamicTensorDescriptor& in_n_hi_wi_c_grid_desc, + const DynamicTensorDescriptor& out_n_ho_wo_k_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + // weight tensor + const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + // input tensor + const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_dynamic_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + 
make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmn_grid_desc = + transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // output tensor + const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + return make_tuple( + wei_gemmk_gemmm_grid_desc, in_gemmk_gemmn_grid_desc, out_gemmm_gemmn_grid_desc); +} + +template +__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1( + const DynamicTensorDescriptor& wei_k_y_x_c_grid_desc, + const DynamicTensorDescriptor& in_n_hi_wi_c_grid_desc, + const DynamicTensorDescriptor& out_n_ho_wo_k_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + assert(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 && + ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && + InRightPadW == 0); + + // weight tensor + const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + // input tensor + const auto in_gemmk_gemmn_grid_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, C)), + make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + // output tensor + const auto 
out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + return make_tuple( + wei_gemmk_gemmm_grid_desc, in_gemmk_gemmn_grid_desc, out_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..49ae26518e --- /dev/null +++ b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp @@ -0,0 +1,129 @@ +#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP +#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" + +namespace ck { + +// GemmM = K +// GemmN = N * Ho * Wo +// GemmK = C * Y * X +template +__host__ __device__ constexpr auto +transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( + const DynamicTensorDescriptor& wei_k_c_y_x_grid_desc, + const DynamicTensorDescriptor& in_n_c_hi_wi_grid_desc, + const DynamicTensorDescriptor& out_n_k_ho_wo_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1); + + const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3); + + const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3); + + const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2); + const auto X = wei_k_c_y_x_grid_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = K; + const auto GemmN = N * Ho * Wo; + const auto GemmK = C * Y * X; + const auto GemmK0 = GemmK / GemmK1; + + // weight tensor + const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + wei_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // 
input tensor + const auto in_n_c_hip_wip_grid_desc = transform_dynamic_tensor_descriptor( + in_n_c_hi_wi_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_c_y_ho_x_wo_grid_desc = transform_dynamic_tensor_descriptor( + in_n_c_hip_wip_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + const auto in_gemmk_gemmn_grid_desc = + transform_dynamic_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc, + make_tuple(make_merge_transform(make_tuple(C, Y, X)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + in_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // output tensor + const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)), + make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..5814e66766 --- /dev/null +++ b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,129 @@ +#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" + +namespace ck { + +// GemmM = K +// GemmN = N * Ho * Wo +// GemmK = C * Y * X +template +__host__ __device__ constexpr auto +transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad( + const DynamicTensorDescriptor& wei_k_y_x_c_grid_desc, + const DynamicTensorDescriptor& in_n_hi_wi_c_grid_desc, + const DynamicTensorDescriptor& out_n_ho_wo_k_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = 
in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = K; + const auto GemmN = N * Ho * Wo; + const auto GemmK = C * Y * X; + const auto GemmK0 = GemmK / GemmK1; + + // weight tensor + const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + wei_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // input tensor + const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_dynamic_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmn_grid_desc = + transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + in_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // output tensor + const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), 
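+    // Note on the K-split used above (a hedged reading, sizes invented for illustration):
+    // GemmK = C * Y * X is unmerged into (GemmK0, GemmK1) with GemmK0 = GemmK / GemmK1, so
+    // GemmK is assumed to be divisible by the compile-time GemmK1Value; e.g. Y = X = 3,
+    // C = 64, GemmK1 = 8 gives GemmK = 576 and GemmK0 = 72.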
+ make_tuple(Sequence<1>{}, Sequence<0>{})); + + return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..ad9d99f4e7 --- /dev/null +++ b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,132 @@ +#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" + +namespace ck { + +// A: in +// B: wei +// C: out +// GemmM = N * Ho * Wo +// GemmN = K +// GemmK = Y * X * C +template +__host__ __device__ constexpr auto +transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( + const DynamicTensorDescriptor& in_n_hi_wi_c_grid_desc, + const DynamicTensorDescriptor& wei_k_y_x_c_grid_desc, + const DynamicTensorDescriptor& out_n_ho_wo_k_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = N * Ho * Wo; + const auto GemmN = K; + const auto GemmK = Y * X * C; + const auto GemmK0 = GemmK / GemmK1; + + // A: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_dynamic_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + 
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmm_grid_desc = + transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + in_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_gemmk_gemmn_grid_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..e709f768cb --- /dev/null +++ b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp @@ -0,0 +1,132 @@ +#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP +#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" + +namespace ck { + +// GemmM0 = 1 +// GemmM1 = K +// GemmN0 = N0 +// GemmN1 = (N / N0) * Ho * Wo +// GemmK0 = (C / C0) * Y * X +// GemmK1 = C0 +template +__host__ __device__ constexpr auto +transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad( + const DynamicTensorDescriptor& wei_k_c_y_x_grid_desc, + const DynamicTensorDescriptor& in_n_c_hi_wi_grid_desc, + const DynamicTensorDescriptor& out_n_k_ho_wo_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const N0Type& N0, + const C0Type& C0) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1); + + const auto Hi = 
in_n_c_hi_wi_grid_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3); + + const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3); + + const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2); + const auto X = wei_k_c_y_x_grid_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto N1 = N / N0; + const auto C1 = C / C0; + + // weight tensor + const auto wei_gk0_gm0_gm1_gk1_grid_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), + make_tuple(make_unmerge_transform(make_tuple(I1, K)), + make_unmerge_transform(make_tuple(C0, C1 * Y * X))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1, 2>{}, Sequence<3, 0>{})); + + // input tensor + const auto in_n_c_hip_wip_grid_desc = transform_dynamic_tensor_descriptor( + in_n_c_hi_wi_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n0_n1_c0_c1_y_ho_x_wo_grid_desc = transform_dynamic_tensor_descriptor( + in_n_c_hip_wip_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(N0, N1)), + make_unmerge_transform(make_tuple(C0, C1)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6, 7>{})); + + const auto in_gk0_gn0_gn1_gk1_grid_desc = transform_dynamic_tensor_descriptor( + in_n0_n1_c0_c1_y_ho_x_wo_grid_desc, + make_tuple(make_merge_transform(make_tuple(C1, Y, X)), + make_pass_through_transform(N0), + make_merge_transform(make_tuple(N1, Ho, Wo)), + make_pass_through_transform(C0)), + make_tuple(Sequence<3, 4, 6>{}, Sequence<0>{}, Sequence<1, 5, 7>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + // output tensor + const auto out_n_k_howo_grid_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)); + + const auto out_n0_n1_1_k_howo_grid_desc = transform_dynamic_tensor_descriptor( + out_n_k_howo_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(N0, N1)), + make_unmerge_transform(make_tuple(I1, K)), + make_pass_through_transform(Ho * Wo)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4>{})); + + const auto out_gm0_gm1_gn0_gn1_grid_desc = transform_dynamic_tensor_descriptor( + out_n0_n1_1_k_howo_grid_desc, + make_tuple(make_pass_through_transform(I1), + make_pass_through_transform(K), + make_pass_through_transform(N0), + make_merge_transform_v2_magic_division(make_tuple(N1, Ho * Wo))), + make_tuple(Sequence<2>{}, Sequence<3>{}, Sequence<0>{}, Sequence<1, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, 
Sequence<3>{})); + + return make_tuple( + wei_gk0_gm0_gm1_gk1_grid_desc, in_gk0_gn0_gn1_gk1_grid_desc, out_gm0_gm1_gn0_gn1_grid_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_description/cluster_descriptor.hpp b/composable_kernel/include/tensor_description/cluster_descriptor.hpp new file mode 100644 index 0000000000..c3523623d9 --- /dev/null +++ b/composable_kernel/include/tensor_description/cluster_descriptor.hpp @@ -0,0 +1,33 @@ +#ifndef CK_CLUSTER_DESCRIPTOR_HPP +#define CK_CLUSTER_DESCRIPTOR_HPP + +#include "common_header.hpp" +#include "tensor_adaptor.hpp" + +namespace ck { + +template ::type> +__host__ __device__ constexpr auto make_cluster_descriptor_v2( + const Lengths& lengths, + ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type{}) +{ + constexpr index_t ndim_low = Lengths::Size(); + + const auto reordered_lengths = container_reorder_given_new2old(lengths, order); + + const auto low_lengths = generate_tuple( + [&](auto idim_low) { return reordered_lengths[idim_low]; }, Number{}); + + const auto transform = make_merge_transform(low_lengths); + + constexpr auto low_dim_old_top_ids = ArrangeOrder{}; + + constexpr auto up_dim_new_top_ids = Sequence<0>{}; + + return make_single_stage_tensor_adaptor( + make_tuple(transform), make_tuple(low_dim_old_top_ids), make_tuple(up_dim_new_top_ids)); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp b/composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp new file mode 100644 index 0000000000..967517bef7 --- /dev/null +++ b/composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp @@ -0,0 +1,1737 @@ +#ifndef CK_DYNAMIC_MULTI_INDEX_TRANSFORM_HPP +#define CK_DYNAMIC_MULTI_INDEX_TRANSFORM_HPP + +#include "common_header.hpp" +#include "multi_index.hpp" + +namespace ck { + +template +struct DynamicPassThrough +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + + using UpLengths = decltype(make_tuple(LowLength{})); + + UpLengths up_lengths_; + + __host__ __device__ constexpr DynamicPassThrough() = default; + + __host__ __device__ constexpr DynamicPassThrough(const LowLength& low_length) + : up_lengths_{make_tuple(low_length)} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ static void CalculateLowerIndex(LowIdx& idx_low, const UpIdx& idx_up) + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = idx_up[Number<0>{}]; + } + + template + __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&, + Number) + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("DynamicPassThrough, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("}"); + } +}; + +template +struct DynamicPad +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + + using UpLengths = decltype(make_tuple(LowLength{} + LeftPad{} + RightPad{})); + + UpLengths up_lengths_; + LeftPad left_pad_; + RightPad right_pad_; + + __host__ __device__ constexpr DynamicPad() = default; + + __host__ __device__ constexpr DynamicPad(const LowLength& low_length, + const LeftPad& left_pad, + const RightPad& right_pad) + : up_lengths_{make_tuple(low_length + left_pad + right_pad)}, + left_pad_{left_pad}, + right_pad_{right_pad} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_; + } + + template + __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&, + Number) + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return SkipIsValidCheck; + } + + template + __host__ __device__ constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const + { + return SkipIsValidCheck || ((idx_up[Number<0>{}] >= left_pad_) && + (idx_up[Number<0>{}] < up_lengths_[Number<0>{}] - right_pad_)); + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("DynamicPad, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("left_pad_ %d", index_t{left_pad_}); + printf("right_pad_ %d", index_t{right_pad_}); + printf("}"); + } +}; + +template +struct DynamicLeftPad +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + + using UpLengths = decltype(make_tuple(LowLength{} + LeftPad{})); + + UpLengths up_lengths_; + LeftPad left_pad_; + + __host__ __device__ constexpr DynamicLeftPad() = default; + + __host__ __device__ constexpr DynamicLeftPad(const LowLength& low_length, + const LeftPad& left_pad) + : up_lengths_{make_tuple(low_length + left_pad)}, left_pad_{left_pad} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_; + } + + template + __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&, + Number) + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return SkipIsValidCheck; + } + + template + __host__ __device__ constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const + { + return SkipIsValidCheck || (idx_up[Number<0>{}] >= left_pad_); + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("DynamicLeftPad, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("left_pad_ %d", index_t{left_pad_}); + printf("}"); + } +}; + +template +struct DynamicRightPad +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + + using UpLengths = decltype(make_tuple(LowLength{} + RightPad{})); + + UpLengths up_lengths_; + LowLength low_length_; + RightPad right_pad_; + + __host__ __device__ constexpr DynamicRightPad() = default; + + __host__ __device__ constexpr DynamicRightPad(const LowLength& low_length, + const RightPad& right_pad) + : up_lengths_{make_tuple(low_length + right_pad)}, + low_length_{low_length}, + right_pad_{right_pad} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ static constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = idx_up[Number<0>{}]; + } + + template + __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&, + Number) + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return SkipIsValidCheck; + } + + template + __host__ __device__ constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const + { + return SkipIsValidCheck || (idx_up[Number<0>{}] < low_length_); + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("DynamicRightPad, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("low_length_ %d", index_t{low_length_}); + printf("left_pad_ %d", index_t{right_pad_}); + printf("}"); + } +}; + +// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] +// UpLengths and Coefficients can be either of the followings: +// 1) Tuple of index_t, which is known at run-time, or +// 2) Tuple of Number, which is known at compile-time, or +// 3) Tuple of mixture of index_t and Number, which is known partially at run-time and partially +// at compile-time +template ::type = false> +struct DynamicEmbed +{ + static constexpr index_t NDimUp = UpLengths::Size(); + + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex; + + UpLengths up_lengths_; + Coefficients coefficients_; + + __host__ __device__ constexpr DynamicEmbed() = default; + + __host__ __device__ constexpr DynamicEmbed(const UpLengths& up_lengths, + const Coefficients& coefficients) + : up_lengths_{up_lengths}, coefficients_{coefficients} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return NDimUp; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == NDimUp, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = 0; + + static_for<0, NDimUp, 1>{}([&idx_low, &idx_up, this](auto i) { + idx_low(Number<0>{}) += idx_up[i] * this->coefficients_[i]; + }); + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&, + Number) const + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == NDimUp && + LowIdx::Size() == 1 && UpIdx::Size() == NDimUp, + "wrong! 
inconsistent # of dimension"); + + idx_diff_low(Number<0>{}) = 0; + + static_for<0, NDimUp, 1>{}( + [&](auto i) { idx_diff_low(Number<0>{}) += idx_diff_up[i] * coefficients_[i]; }); + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("DynamicEmbed, "); + printf("up_lengths_ "); + print_multi_index(up_lengths_); + printf("coefficients_ "); + print_multi_index(coefficients_); + printf("}"); + } +}; + +// Implementation of "Merge" transformation primitive that uses regular to do lowering of +// multi-index and use carry-and-borrow check to do lowering of multi-index delta +template +struct DynamicMerge_v1_carry_check +{ + static constexpr index_t NDimLow = LowLengths::Size(); + + using LowerIndex = MultiIndex; + using UpperIndex = MultiIndex<1>; + + using LowLengthsScan = decltype( + container_reverse_exclusive_scan(LowLengths{}, math::multiplies_v2{}, Number<1>{})); + + using UpLengths = + decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies_v2{}, Number<1>{}))); + + LowLengths low_lengths_; + LowLengthsScan low_lengths_scan_; + UpLengths up_lengths_; + + __host__ __device__ constexpr DynamicMerge_v1_carry_check() = default; + + __host__ __device__ constexpr DynamicMerge_v1_carry_check(const LowLengths& low_lengths) + : low_lengths_{low_lengths}, + low_lengths_scan_{ + container_reverse_exclusive_scan(low_lengths, math::multiplies_v2{}, Number<1>{})}, + up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies_v2{}, Number<1>{}))} + { + static_assert(LowerIndex::Size() == NDimLow, "wrong!"); + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + index_t tmp = idx_up[Number<0>{}]; + + // normal division + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_low(i) = tmp / this->low_lengths_scan_[i]; + tmp -= idx_low[i] * this->low_lengths_scan_[i]; + }); + + idx_low(Number{}) = tmp; + } + + template + __host__ __device__ void UpdateLowerIndex_1a(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& /* idx_up_new */, + Number) const + { + static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 && + LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + // CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions. + // However, + // 1) If idx_diff_up is known at compile-time, then idx_diff_low_const + // can be calculated at compile-time. 
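+        //    For illustration (an assumed typical use, not taken from this file): for a merge
+        //    of low lengths (3, 4), low_lengths_scan_ is (4, 1); if the transfer loop steps the
+        //    merged dimension by a compile-time Number<8>{}, the divisions fold to
+        //    idx_diff_low_const = (2, 0) at compile time and no runtime division is emitted.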
+ // 2) If idx_diff_up is not known at compile-time, but its value + // doesn't change during the whole kernel execution, then + // idx_diff_low_const also + // doesn't change during the whole kernel execution. Compiler generated + // ISA should + // only caclculate idx_diff_low_const once and save it durinng the whole + // kernel execution + // If neither 1) nor 2) is satisfied, then the calculation will also be + // computed at + // run-time each time this function is called, and can be very expensive. + LowerIndex idx_diff_low_const; + LowerIndex idx_low_length_minus_idx_diff_low_const; + LowerIndex idx_low_length_plus_idx_diff_low_const; + +#if !CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE + index_t tmp = idx_diff_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_diff_low_const(i) = tmp / low_lengths_scan_[i]; + tmp -= idx_diff_low_const[i] * low_lengths_scan_[i]; + }); + + idx_diff_low_const(Number{}) = tmp; + + static_for<0, NDimLow, 1>{}([&](auto i) { + idx_low_length_minus_idx_diff_low_const(i) = low_lengths_[i] - idx_diff_low_const[i]; + + idx_low_length_plus_idx_diff_low_const(i) = low_lengths_[i] + idx_diff_low_const[i]; + }); +#else + // Hack: this force result into SGPR. Need to make sure the result is thread invariant + index_t tmp = idx_diff_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]); + tmp -= idx_diff_low_const[i] * low_lengths_scan_[i]; + }); + + idx_diff_low_const(Number{}) = __builtin_amdgcn_readfirstlane(tmp); + + static_for<0, NDimLow, 1>{}([&](auto i) { + idx_low_length_minus_idx_diff_low_const(i) = + __builtin_amdgcn_readfirstlane(low_lengths_[i] - idx_diff_low_const[i]); + + idx_low_length_plus_idx_diff_low_const(i) = + __builtin_amdgcn_readfirstlane(low_lengths_[i] + idx_diff_low_const[i]); + }); +#endif + + if constexpr(Hack == 1) + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + index_t carry = 0; + + static_for{}([&](auto i) { + index_t idx_low_tmp = idx_low[i] + carry; + + bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i]; + + idx_diff_low(i) = + do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i]; + + idx_diff_low(i) += carry; + + carry = do_carry ? 1 : 0; + }); + + idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry; + + idx_low += idx_diff_low; + } + else if constexpr(Hack == 2) + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + index_t borrow = 0; + + static_for{}([&](auto i) { + index_t idx_low_tmp = idx_low[i] - borrow; + + bool do_borrow = idx_low_tmp < -idx_diff_low_const[i]; + + idx_diff_low(i) = + do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low_const[i]; + + idx_diff_low(i) -= borrow; + + borrow = do_borrow ? 1 : 0; + }); + + idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] - borrow; + + idx_low += idx_diff_low; + } + else + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + index_t carry = 0; + + static_for{}([&](auto i) { + index_t idx_low_tmp = idx_low[i] + carry; + + bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i]; + bool do_borrow = idx_low_tmp < -idx_diff_low_const[i]; + + idx_diff_low(i) = + do_carry ? 
-idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i]; + idx_diff_low(i) = + do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low[i]; + + idx_diff_low(i) += carry; + + carry = do_carry ? 1 : 0; + carry = do_borrow ? -1 : carry; + }); + + idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry; + + idx_low += idx_diff_low; + } + } + + template + __host__ __device__ void UpdateLowerIndex_1b(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& /* idx_up_new */, + Number) const + { + static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 && + LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + // CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions. + // However, + // 1) If idx_diff_up is known at compile-time, then idx_diff_low_const + // can be calculated at compile-time. + // 2) If idx_diff_up is not known at compile-time, but its value + // doesn't change during the whole kernel execution, then + // idx_diff_low_const also + // doesn't change during the whole kernel execution. Compiler generated + // ISA should + // only caclculate idx_diff_low_const once and save it durinng the whole + // kernel execution + // If neither 1) nor 2) is satisfied, then the calculation will also be + // computed at + // run-time each time this function is called, and can be very expensive. + LowerIndex idx_diff_low_const; + LowerIndex idx_low_length_minus_idx_diff_low_const; + LowerIndex idx_low_length_plus_idx_diff_low_const; + +#if !CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE + index_t tmp = idx_diff_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_diff_low_const(i) = tmp / low_lengths_scan_[i]; + tmp -= idx_diff_low_const[i] * low_lengths_scan_[i]; + }); + + idx_diff_low_const(Number{}) = tmp; + + static_for<0, NDimLow, 1>{}([&](auto i) { + idx_low_length_minus_idx_diff_low_const(i) = low_lengths_[i] - idx_diff_low_const[i]; + + idx_low_length_plus_idx_diff_low_const(i) = low_lengths_[i] + idx_diff_low_const[i]; + }); +#else + // Hack: this force result into SGPR. Need to make sure the result is thread invariant + index_t tmp = idx_diff_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]); + tmp -= idx_diff_low_const[i] * low_lengths_scan_[i]; + }); + + idx_diff_low_const(Number{}) = __builtin_amdgcn_readfirstlane(tmp); + + static_for<0, NDimLow, 1>{}([&](auto i) { + idx_low_length_minus_idx_diff_low_const(i) = + __builtin_amdgcn_readfirstlane(low_lengths_[i] - idx_diff_low_const[i]); + + idx_low_length_plus_idx_diff_low_const(i) = low_lengths_[i] + idx_diff_low_const[i]; + }); +#endif + + if constexpr(Hack == 1) + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + index_t carry = 0; + + static_for{}([&](auto i) { + index_t idx_low_tmp = idx_low[i] + carry; + + bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i]; + + idx_diff_low(i) = + do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i]; + + idx_diff_low(i) += carry; + + carry = do_carry ? 
1 : 0; + }); + + idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry; + + idx_low += idx_diff_low; + } + else if constexpr(Hack == 2) + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + index_t borrow = 0; + + static_for{}([&](auto i) { + index_t negative_idx_low_tmp = borrow - idx_low[i]; + + bool do_borrow = negative_idx_low_tmp > idx_diff_low_const[i]; + + idx_diff_low(i) = + do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low_const[i]; + + idx_diff_low(i) -= borrow; + + borrow = do_borrow ? 1 : 0; + }); + + idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] - borrow; + + idx_low += idx_diff_low; + } + else + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + index_t carry = 0; + + static_for{}([&](auto i) { + index_t idx_low_tmp = idx_low[i] + carry; + + bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i]; + bool do_borrow = idx_low_tmp < -idx_diff_low_const[i]; + + idx_diff_low(i) = + do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i]; + idx_diff_low(i) = + do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low[i]; + + idx_diff_low(i) += carry; + + carry = do_carry ? 1 : 0; + carry = do_borrow ? -1 : carry; + }); + + idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry; + + idx_low += idx_diff_low; + } + } + + template + __host__ __device__ void UpdateLowerIndex_2(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& /* idx_up_new */, + Number) const + { + static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 && + LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + // CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions. + // However, + // 1) If idx_diff_up is known at compile-time, then idx_diff_low_const + // can be calculated at compile-time. + // 2) If idx_diff_up is not known at compile-time, but its value + // doesn't change during the whole kernel execution, then + // idx_diff_low_const also + // doesn't change during the whole kernel execution. Compiler generated + // ISA should + // only caclculate idx_diff_low_const once and save it durinng the whole + // kernel execution + // If neither 1) nor 2) is satisfied, then the calculation will also be + // computed at run-time each time this function is called, and can be + // very expensive. + LowerIndex idx_diff_low_const; + +#if !CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE + index_t tmp = idx_diff_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_diff_low_const(i) = tmp / low_lengths_scan_[i]; + tmp -= idx_diff_low_const[i] * low_lengths_scan_[i]; + }); + + idx_diff_low_const(Number{}) = tmp; +#else + // Hack: this force result into SGPR. 
Need to make sure the result is thread invariant + index_t tmp = idx_diff_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]); + tmp -= idx_diff_low_const[i] * low_lengths_scan_[i]; + }); + + idx_diff_low_const(Number{}) = __builtin_amdgcn_readfirstlane(tmp); +#endif + + if constexpr(Hack == 1) + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + bool do_carry = 0; + + static_for{}([&](auto i) { + idx_diff_low(i) = idx_diff_low_const[i] + do_carry; + + index_t idx_low_tmp = idx_low[i] + idx_diff_low[i]; + + do_carry = idx_low_tmp >= low_lengths_[i]; + +#if 0 + // TODO: use exec-mask inline asm, which use 1 VALU + if(do_carry) + { + idx_diff_low(i) -= low_lengths_[i]; + } +#elif 1 + // this use 2 VALU + idx_diff_low(i) = do_carry ? idx_diff_low[i] - low_lengths_[i] : idx_diff_low[i]; +#elif 1 + // this use 2 VALU + index_t idx_diff_low_tmp = idx_diff_low[i] - low_lengths_[i]; + idx_diff_low(i) = do_carry ? idx_diff_low_tmp : idx_diff_low[i]; +#endif + + idx_low(i) += idx_diff_low[i]; + }); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_low_const[I0] + do_carry; + + idx_low(I0) += idx_diff_low[I0]; + } + else if constexpr(Hack == 2) + { + // do borrow check on each low dimension in reversed order + // do not need to check the first dimension + bool do_borrow = 0; + + static_for{}([&](auto i) { + idx_diff_low(i) = idx_diff_low_const[i] - do_borrow; + + index_t idx_low_tmp = idx_low[i] + idx_diff_low[i]; + + do_borrow = idx_low_tmp < 0; + +#if 0 + // TODO: use exec-mask inline asm + if(do_borrow) + { + idx_diff_low(i) += low_lengths_[i]; + } +#elif 1 + idx_diff_low(i) = do_borrow ? idx_diff_low[i] + low_lengths_[i] : idx_diff_low[i]; +#elif 1 + index_t idx_diff_low_tmp = idx_diff_low[i] + low_lengths_[i]; + idx_diff_low(i) = do_borrow ? 
idx_diff_low_tmp : idx_diff_low[i]; +#endif + + idx_low(i) += idx_diff_low[i]; + }); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_low_const[I0] - do_borrow; + + idx_low(I0) += idx_diff_low[I0]; + } + else + { + // not implemented + } + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& idx_up_new, + Number) const + { +#if 1 + UpdateLowerIndex_1a(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number{}); +#elif 0 + UpdateLowerIndex_1b(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number{}); +#else + UpdateLowerIndex_2(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number{}); +#endif + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return false; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("DynamicMerge_v1_carry_check, "); + printf("low_lengths_ "); + print_multi_index(low_lengths_); + printf("low_lengths_scan_ "); + print_multi_index(low_lengths_scan_); + printf("up_lengths_ "); + print_multi_index(up_lengths_); + printf("}"); + } +}; + +template +struct lambda_merge_generate_MagicDivision_calculate_magic_multiplier +{ + template + __host__ __device__ constexpr auto operator()(Number i) const + { + return MagicDivision::CalculateMagicMultiplier(LowLengths{}[i]); + } +}; + +template +struct lambda_merge_generate_MagicDivision_calculate_magic_shift +{ + template + __host__ __device__ constexpr auto operator()(Number i) const + { + return MagicDivision::CalculateMagicShift(LowLengths{}[i]); + } +}; + +// Implementation of the "Merge" transformation primitive that uses magic-number division to do the +// lowering of both the multi-index and the delta of the multi-index +// Caution: +// 1. The magic number division implementation being used produces a correct result only if the +// dividend is uint32_t and its value is within the 31-bit value range of uint32_t. +// 2. Magic number division for an int32_t dividend has not been implemented; an int32_t +// dividend is bit-wise reinterpreted as uint32_t and the magic number division implementation for +// uint32_t is then used. +// 3. For the Merge primitive, the upper-index is the dividend. +// 4. When the upper-index is uint32_t, its value needs to be within the 31-bit range. +// 5. When the upper-index is int32_t (i.e. when index_t is int32_t), its value needs to be +// non-negative.
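+//
+// Illustrative example (a sketch of the general magic-number scheme assumed above; the exact
+// multiplier/shift pair returned by MagicDivision::CalculateMagicMultiplier and
+// MagicDivision::CalculateMagicShift may differ):
+//
+//   for a divisor d = 3, one valid pair is multiplier = 0xAAAAAAABu and shift = 33, so that
+//       q = (uint64_t(n) * 0xAAAAAAABu) >> 33
+//   equals n / 3 for every n in the documented 31-bit range, e.g.
+//       (uint64_t(100) * 0xAAAAAAABu) >> 33 == 33.
+//   DoMagicDivision is expected to evaluate the same q from (n, multiplier, shift) without
+//   emitting an integer-division instruction.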
+template +struct DynamicMerge_v2_magic_division +{ + static constexpr index_t NDimLow = LowLengths::Size(); + + using LowerIndex = MultiIndex; + using UpperIndex = MultiIndex<1>; + + using UpLengths = + decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies_v2{}, Number<1>{}))); + + using LowLengthsMagicDivisorMultipiler = decltype( + generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_multiplier{}, + Number{})); + + using LowLengthsMagicDivisorShift = decltype( + generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_shift{}, + Number{})); + + LowLengths low_lengths_; + LowLengthsMagicDivisorMultipiler low_lengths_magic_divisor_multiplier_; + LowLengthsMagicDivisorShift low_lengths_magic_divisor_shift_; + UpLengths up_lengths_; + + __host__ __device__ constexpr DynamicMerge_v2_magic_division() = default; + + __host__ __device__ constexpr DynamicMerge_v2_magic_division(const LowLengths& low_lengths) + : low_lengths_{low_lengths}, + low_lengths_magic_divisor_multiplier_{generate_tuple( + [&](auto i) { return MagicDivision::CalculateMagicMultiplier(low_lengths[i]); }, + Number{})}, + low_lengths_magic_divisor_shift_{generate_tuple( + [&](auto i) { return MagicDivision::CalculateMagicShift(low_lengths[i]); }, + Number{})}, + up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies_v2{}, Number<1>{}))} + { + static_assert(LowerIndex::Size() == NDimLow, "wrong!"); + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + index_t tmp = idx_up[Number<0>{}]; + + static_for{}([&, this](auto i) { + index_t tmp2 = + MagicDivision::DoMagicDivision(tmp, + this->low_lengths_magic_divisor_multiplier_[i], + this->low_lengths_magic_divisor_shift_[i]); + idx_low(i) = tmp - tmp2 * this->low_lengths_[i]; + tmp = tmp2; + }); + + idx_low(Number<0>{}) = tmp; + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff&, + LowIdx& idx_low, + const UpIdx& idx_up_new, + Number) const + { + static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 && + LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + index_t tmp = idx_up_new[Number<0>{}]; + + static_for{}([&, this](auto i) { + index_t tmp2 = + MagicDivision::DoMagicDivision(tmp, + this->low_lengths_magic_divisor_multiplier_[i], + this->low_lengths_magic_divisor_shift_[i]); + + index_t idx_low_old = idx_low[i]; + + idx_low(i) = tmp - tmp2 * this->low_lengths_[i]; + tmp = tmp2; + + idx_diff_low(i) = idx_low[i] - idx_low_old; + }); + + idx_diff_low(Number<0>{}) = tmp - idx_low(Number<0>{}); + + idx_low(Number<0>{}) = tmp; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return false; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("DynamicMerge_v2_magic_division, "); + printf("low_lengths_ "); + print_multi_index(low_lengths_); + printf("low_lengths_magic_divisor_multiplier_ "); + print_multi_index(low_lengths_magic_divisor_multiplier_); + printf("low_lengths_magic_divisor_shift_ "); + print_multi_index(low_lengths_magic_divisor_shift_); + printf("up_lengths_ "); + print_multi_index(up_lengths_); + printf("}"); + } +}; + +// Implementation of the "Merge" transformation primitive that uses magic-number division to do the +// lowering of both the multi-index and the delta of the multi-index +// Caution: +// 1. The magic number division implementation being used produces a correct result only if the +// dividend is uint32_t and its value is within the 31-bit value range of uint32_t. +// 2. Magic number division for an int32_t dividend has not been implemented; an int32_t +// dividend is bit-wise reinterpreted as uint32_t and the magic number division implementation for +// uint32_t is then used. +// 3. For the Merge primitive, the upper-index is the dividend. +// 4. When the upper-index is uint32_t, its value needs to be within the 31-bit range. +// 5. When the upper-index is int32_t (i.e. when index_t is int32_t), its value needs to be +// non-negative.
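+//
+// Worked example of the lowering itself (illustrative numbers only): for low lengths (2, 3, 4),
+// low_lengths_scan_ produced by container_reverse_exclusive_scan is (12, 4, 1) and the merged
+// upper length is 2 * 3 * 4 = 24. The upper index 17 is lowered by dividing by the scan values:
+// 17 / 12 = 1 (remainder 5), 5 / 4 = 1 (remainder 1), last component 1, giving the lower
+// multi-index (1, 1, 1); indeed 1 * 12 + 1 * 4 + 1 * 1 == 17. DynamicMerge_v2r2_magic_division
+// below performs these divisions with the magic-number scheme described above, dividing by
+// low_lengths_scan_ instead of by the individual low lengths.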
+template +struct DynamicMerge_v2r2_magic_division +{ + static constexpr index_t NDimLow = LowLengths::Size(); + + using LowerIndex = MultiIndex; + using UpperIndex = MultiIndex<1>; + + using LowLengthsScan = decltype( + container_reverse_exclusive_scan(LowLengths{}, math::multiplies_v2{}, Number<1>{})); + + using UpLengths = + decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies_v2{}, Number<1>{}))); + + using LowLengthsScanMagicDivisorMultipiler = decltype(generate_tuple( + lambda_merge_generate_MagicDivision_calculate_magic_multiplier{}, + Number{})); + + using LowLengthsScanMagicDivisorShift = decltype( + generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_shift{}, + Number{})); + + LowLengths low_lengths_; + LowLengthsScan low_lengths_scan_; + LowLengthsScanMagicDivisorMultipiler low_lengths_scan_magic_divisor_multiplier_; + LowLengthsScanMagicDivisorShift low_lengths_scan_magic_divisor_shift_; + UpLengths up_lengths_; + + __host__ __device__ constexpr DynamicMerge_v2r2_magic_division() = default; + + __host__ __device__ constexpr DynamicMerge_v2r2_magic_division(const LowLengths& low_lengths) + : low_lengths_{low_lengths}, + low_lengths_scan_{ + container_reverse_exclusive_scan(low_lengths, math::multiplies_v2{}, Number<1>{})}, + low_lengths_scan_magic_divisor_multiplier_{generate_tuple( + [&](auto i) { return MagicDivision::CalculateMagicMultiplier(low_lengths_scan_[i]); }, + Number{})}, + low_lengths_scan_magic_divisor_shift_{generate_tuple( + [&](auto i) { return MagicDivision::CalculateMagicShift(low_lengths_scan_[i]); }, + Number{})}, + up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies_v2{}, Number<1>{}))} + { + static_assert(LowerIndex::Size() == NDimLow, "wrong!"); + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + index_t tmp = idx_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&, this](auto i) { + idx_low(i) = + MagicDivision::DoMagicDivision(tmp, + this->low_lengths_scan_magic_divisor_multiplier_[i], + this->low_lengths_scan_magic_divisor_shift_[i]); + + tmp -= idx_low[i] * this->low_lengths_scan_[i]; + }); + + idx_low(Number{}) = tmp; + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff&, + LowIdx& idx_low, + const UpIdx& idx_up_new, + Number) const + { + static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 && + LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + index_t tmp = idx_up_new[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&, this](auto i) { + index_t idx_low_old = idx_low[i]; + + idx_low(i) = + MagicDivision::DoMagicDivision(tmp, + this->low_lengths_scan_magic_divisor_multiplier_[i], + this->low_lengths_scan_magic_divisor_shift_[i]); + + idx_diff_low(i) = idx_low[i] - idx_low_old; + + tmp -= idx_low[i] * this->low_lengths_scan_[i]; + }); + + idx_diff_low(Number{}) = tmp - idx_low[Number{}]; + + idx_low(Number{}) = tmp; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return false; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("DynamicMerge_v2r2_magic_division, "); + printf("low_lengths_ "); + print_multi_index(low_lengths_); + printf("low_lengths_scan "); + print_multi_index(low_lengths_scan_); + printf("low_lengths_scan_magic_divisor_multiplier_ "); + print_multi_index(low_lengths_scan_magic_divisor_multiplier_); + printf("low_lengths_scan_magic_divisor_shift_ "); + print_multi_index(low_lengths_scan_magic_divisor_shift_); + printf("up_lengths_ "); + print_multi_index(up_lengths_); + printf("}"); + } +}; + +template +struct DynamicUnMerge +{ + static constexpr index_t NDimUp = UpLengths::Size(); + + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex; + + using UpLengthsScan = + decltype(container_reverse_exclusive_scan(UpLengths{}, math::multiplies_v2{}, Number<1>{})); + + UpLengths up_lengths_; + UpLengthsScan up_lengths_scan_; + + __host__ __device__ constexpr DynamicUnMerge() = default; + + __host__ __device__ constexpr DynamicUnMerge(const UpLengths& up_lengths) + : up_lengths_{up_lengths}, + up_lengths_scan_{ + container_reverse_exclusive_scan(up_lengths, math::multiplies_v2{}, Number<1>{})} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return NDimUp; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + if constexpr(!Use24BitIntegerCalculation) + { + idx_low(Number<0>{}) = idx_up[Number{}]; + + static_for<0, NDimUp - 1, 1>{}( + [&](auto i) { idx_low(Number<0>{}) += idx_up[i] * up_lengths_scan_[i]; }); + } + else + { + idx_low(Number<0>{}) = idx_up[Number{}]; + + static_for<0, NDimUp - 1, 1>{}([&](auto i) { + idx_low(Number<0>{}) = + (0x00ffffff & idx_low[Number<0>{}]) + + (0x00ffffff & idx_up[i]) * (0x00ffffff & up_lengths_scan_[i]); + }); + } + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&, + Number) const + { + CalculateLowerIndex(idx_diff_low, idx_diff_up); + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool 
IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("DynamicUnMerge, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("up_lengths_scan_"); + print_multi_index(up_lengths_scan_); + printf("}"); + } +}; + +template +struct DynamicFreeze +{ + LowerIndex low_idx_; + + __host__ __device__ constexpr DynamicFreeze() = default; + + __host__ __device__ constexpr DynamicFreeze(const LowerIndex& low_idx) : low_idx_{low_idx} {} + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 0; } + + __host__ __device__ static constexpr auto GetUpperLengths() { return Tuple<>{}; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& /* idx_up */) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 0, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = low_idx_; + } + + template + __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& /* idx_diff_up */, + LowIdx& /* idx_low */, + const UpIdx& /* idx_up_new */, + Number) + { + idx_diff_low(Number<0>{}) = 0; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("DynamicFreeze"); + printf("low_idx_ %d", index_t{low_idx_}); + } +}; + +// Insert a dangling upper dimension without lower dimension +template +struct DynamicInsert +{ + using UpLengths = decltype(make_tuple(UpperLength{})); + + UpLengths up_lengths_; + + __host__ __device__ constexpr DynamicInsert() = default; + + __host__ __device__ constexpr DynamicInsert(const UpperLength& up_length) + : up_lengths_{make_tuple(up_length)} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 0; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr auto GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx&, const UpIdx&) const + { + static_assert(LowIdx::Size() == 0 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + } + + template + __host__ __device__ static void + UpdateLowerIndex(LowIdxDiff&, const UpIdxDiff&, LowIdx&, const UpIdx&, Number) + { + static_assert(LowIdxDiff::Size() == 0 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 0 && + UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("DynamicInsert"); + print_multi_index(up_lengths_); + } +}; + +template +struct DynamicVectorize +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + + using UpLengths = decltype(make_tuple(UpLength{})); + + UpLengths up_lengths_; + VectorSize vector_size_; + + __host__ __device__ constexpr DynamicVectorize() = default; + + __host__ __device__ constexpr DynamicVectorize(const VectorSize& vector_size, + const UpLength& up_length) + : vector_size_{vector_size}, up_lengths_{make_tuple(up_length)} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ void CalculateLowerIndex(LowIdx& idx_low, const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = vector_size_ * idx_up[Number<0>{}]; + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&, + Number) const + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = vector_size_ * idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("DynamicVectorize, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("}"); + } +}; + +template +struct DynamicSlice +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + + using UpLengths = decltype(make_tuple(SliceEnd{} - SliceBegin{})); + + UpLengths up_lengths_; + SliceBegin slice_begin_; + SliceEnd slice_end_; + + __host__ __device__ constexpr DynamicSlice() = default; + + __host__ __device__ constexpr DynamicSlice(const LowLength&, + const SliceBegin& slice_begin, + const SliceEnd& slice_end) + : up_lengths_{make_tuple(slice_end - slice_begin)}, + slice_begin_{slice_begin}, + slice_end_{slice_end} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = idx_up[Number<0>{}] + slice_begin_; + } + + template + __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&, + Number) + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx&) const + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("DynamicSlice, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("slice_begin_ %d", index_t{slice_begin_}); + printf("slice_end %d", index_t{slice_end_}); + printf("}"); + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_description/dynamic_multi_index_transform_helper.hpp b/composable_kernel/include/tensor_description/dynamic_multi_index_transform_helper.hpp new file mode 100644 index 0000000000..b3e1c60485 --- /dev/null +++ b/composable_kernel/include/tensor_description/dynamic_multi_index_transform_helper.hpp @@ -0,0 +1,104 @@ +#ifndef CK_DYNAMIC_MULTI_INDEX_TRANSFORM_HELPER_HPP +#define CK_DYNAMIC_MULTI_INDEX_TRANSFORM_HELPER_HPP + +#include "common_header.hpp" +#include "dynamic_multi_index_transform.hpp" + +namespace ck { + +template +__host__ __device__ constexpr auto make_pass_through_transform(const LowLength& low_length) +{ + return DynamicPassThrough{low_length}; +} + +template +__host__ __device__ constexpr auto +make_pad_transform(const LowLength& low_length, + const LeftPad& left_pad, + const RightPad& right_pad, + integral_constant = integral_constant{}) +{ + return DynamicPad{ + low_length, left_pad, right_pad}; +} + +template +__host__ __device__ constexpr auto make_left_pad_transform( + const LowLength& low_length, + const LeftPad& left_pad, + integral_constant = integral_constant{}) +{ + return DynamicLeftPad{low_length, left_pad}; +} + +template +__host__ __device__ constexpr auto make_right_pad_transform( + const LowLength& low_length, + const RightPad& right_pad, + integral_constant = integral_constant{}) +{ + return DynamicRightPad{low_length, right_pad}; +} + +template ::type = false> +__host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_lengths, + const Coefficients& coefficients) +{ + return DynamicEmbed{up_lengths, coefficients}; +} + +template +__host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths) +{ +#if !CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION + return DynamicMerge_v1_carry_check{low_lengths}; +#else +#if 1 + return DynamicMerge_v2_magic_division{low_lengths}; +#else + return DynamicMerge_v2r2_magic_division{low_lengths}; +#endif +#endif +} + +template +__host__ __device__ constexpr auto +make_merge_transform_v2_magic_division(const LowLengths& low_lengths) +{ + return DynamicMerge_v2_magic_division{low_lengths}; +} + +template +__host__ __device__ constexpr auto make_unmerge_transform( + const UpLengths& up_lengths, + integral_constant = integral_constant{}) +{ + return DynamicUnMerge{up_lengths}; +} + +template +__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_idx) +{ + return DynamicFreeze{low_idx}; +} + +template +__host__ __device__ constexpr auto make_slice_transform(const LowLength& 
low_length, + const SliceBegin& slice_begin, + const SliceEnd& slice_end) +{ + return DynamicSlice{low_length, slice_begin, slice_end}; +} + +template +__host__ __device__ constexpr auto make_vectorize_transform(const VectorSize& vector_size, + const UpLength& up_length) +{ + return DynamicVectorize{vector_size, up_length}; +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp b/composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp new file mode 100644 index 0000000000..b9ca26c879 --- /dev/null +++ b/composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp @@ -0,0 +1,596 @@ +#ifndef CK_DYNAMIC_TENSOR_DESCRIPTOR_HPP +#define CK_DYNAMIC_TENSOR_DESCRIPTOR_HPP + +#include "common_header.hpp" +#include "dynamic_multi_index_transform.hpp" + +namespace ck { + +template +struct DynamicTensorCoordinate; + +template +struct DynamicTensorCoordinateIterator; + +// Transforms: Tuple +// LowerDimensionIdss : Tuple, ...> +// UpperDimensionIdss : Tuple, ...> +// VisibleDimensionIds> : Sequence<...> +template +struct DynamicTensorDescriptor +{ + // TODO make these private + __host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); } + + __host__ __device__ static constexpr index_t GetNumOfVisibleDimension() + { + return VisibleDimensionIds::Size(); + } + + __host__ __device__ static constexpr index_t GetNumOfHiddenDimension() + { + constexpr auto all_low_dim_ids = unpack( + [](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionIdss{}); + + constexpr auto all_up_dim_ids = unpack( + [](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionIdss{}); + + constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids); + + using unique_sort_all_dim_ids = typename sequence_unique_sort, + math::equal>::type; + + return unique_sort_all_dim_ids::Size(); + } + + __host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms) + { + const auto lengths = generate_tuple( + [&](auto idim_visible) { + constexpr auto tmp = GetTransformAndItsUpperDimension(idim_visible); + + constexpr index_t itran = tmp[Number<0>{}]; + constexpr index_t idim_up = tmp[Number<1>{}]; + constexpr bool found = tmp[Number<2>{}]; + + static_assert(found == true, + "wrong! 
not found matching transformation and upper-dimension"); + + const auto length = + transforms[Number{}].GetUpperLengths()[Number{}]; + + return length; + }, + Number{}); + + // TODO: make container_reduce support tuple of Number and index_t + return container_reduce(lengths, math::multiplies_v2{}, Number<1>{}); + } + + template + __host__ __device__ static constexpr auto GetTransformAndItsUpperDimension(Number) + { + constexpr auto idim_visible = Number{}; + + constexpr index_t idim_hidden = VisibleDimensionIds::At(idim_visible); + + index_t itran_found = 0; + index_t idim_up_found = 0; + bool found = false; + + static_for<0, ntransform_, 1>{}([&](auto itran) { + constexpr auto up_dim_ids = UpperDimensionIdss{}[itran]; + + static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) { + if constexpr(up_dim_ids[idim_up] == idim_hidden) + { + itran_found = itran; + idim_up_found = idim_up; + found = true; + } + }); + }); + + return make_tuple(itran_found, idim_up_found, found); + } + + constexpr static index_t ntransform_ = GetNumOfTransform(); + constexpr static index_t ndim_visible_ = GetNumOfVisibleDimension(); + constexpr static index_t ndim_hidden_ = GetNumOfHiddenDimension(); + + using VisibleIndex = MultiIndex; + using HiddenIndex = MultiIndex; + using Coordinate = DynamicTensorCoordinate; + + // may be index_t or Number<> + using ElementSize = remove_cv_t; + + public: + __host__ __device__ constexpr DynamicTensorDescriptor() = default; + + __host__ __device__ constexpr DynamicTensorDescriptor(const Transforms& transforms, + ElementSpaceSize element_space_size) + : transforms_{transforms}, + element_size_{InitializeElementSize(transforms)}, + element_space_size_{element_space_size} + + { + static_assert(Transforms::Size() == ntransform_ && + LowerDimensionIdss::Size() == ntransform_ && + UpperDimensionIdss::Size() == ntransform_, + "wrong! inconsistent # of transformations"); + + // TODO check dependency of dimensions is valid + } + + __host__ __device__ static constexpr index_t GetNumOfDimension() + { + return GetNumOfVisibleDimension(); + } + + template + __host__ __device__ constexpr auto GetLength(Number) const + { + static_assert(IDim >= 0 && IDim < ndim_visible_, "wrong! out of range"); + + constexpr auto tmp = GetTransformAndItsUpperDimension(Number{}); + + constexpr index_t itran = tmp[Number<0>{}]; + constexpr index_t idim_up = tmp[Number<1>{}]; + constexpr bool found = tmp[Number<2>{}]; + + static_assert(found == true, + "wrong! not found matching transformation and upper-dimension"); + + return transforms_[Number{}].GetUpperLengths()[Number{}]; + } + + __host__ __device__ constexpr auto GetElementSize() const { return element_size_; } + + __host__ __device__ constexpr auto GetElementSpaceSize() const { return element_space_size_; } + + template + __host__ __device__ constexpr index_t CalculateOffset(const Idx& idx) const + { + static_assert(Idx::Size() == GetNumOfDimension(), "wrong! 
inconsistent # of dimension"); + + return make_dynamic_tensor_coordinate(*this, idx).GetOffset(); + } + + // TODO make these private + __host__ __device__ constexpr const auto& GetTransforms() const { return transforms_; } + + __host__ __device__ static constexpr auto GetLowerDimensionIdss() + { + return LowerDimensionIdss{}; + } + + __host__ __device__ static constexpr auto GetUpperDimensionIdss() + { + return UpperDimensionIdss{}; + } + + __host__ __device__ static constexpr auto GetVisibleDimensionIds() + { + return VisibleDimensionIds{}; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + bool is_known = true; + + static_for<0, Transforms::Size(), 1>{}([&](auto i) { + is_known &= + remove_cv_t>::IsKnownAtCompileTime(); + }); + + return is_known && is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("DynamicTensorDescriptor, "); + static_for<0, ntransform_, 1>{}([&](auto i) { + printf("transforms: "); + transforms_[i].Print(); + printf("LowerDimensionIds:"); + LowerDimensionIdss{}.At(i).Print(); + printf("UpperDimensionIds:"); + UpperDimensionIdss{}.At(i).Print(); + }); + printf("}"); + + VisibleDimensionIds::Print(); + } + + // TODO make these private + Transforms transforms_; + ElementSize element_size_; + ElementSpaceSize element_space_size_; +}; + +template +struct DynamicTensorCoordinate +{ + // TODO make these private + static constexpr index_t ndim_visible_ = VisibleDimensionIds::Size(); + + using HiddenIndex = MultiIndex; + using VisibleIndex = MultiIndex; + + public: + __host__ __device__ constexpr DynamicTensorCoordinate() = default; + + __host__ __device__ constexpr DynamicTensorCoordinate(const HiddenIndex& idx_hidden) + : idx_hidden_{idx_hidden} + { + } + + __host__ __device__ constexpr auto GetIndex() const { return GetVisibleIndex(); } + + __host__ __device__ constexpr index_t GetOffset() const { return idx_hidden_[Number<0>{}]; } + + // TODO make these private + __host__ __device__ constexpr const auto& GetHiddenIndex() const { return idx_hidden_; } + + __host__ __device__ auto& GetHiddenIndex() { return idx_hidden_; } + + __host__ __device__ constexpr auto GetVisibleIndex() const + { + return get_container_subset(idx_hidden_, VisibleDimensionIds{}); + } + + // TODO make these private + HiddenIndex idx_hidden_; +}; + +template +struct DynamicTensorCoordinateIterator +{ + // TODO make these private + using VisibleIndex = MultiIndex; + + public: + __host__ __device__ constexpr DynamicTensorCoordinateIterator() = default; + + __host__ __device__ constexpr DynamicTensorCoordinateIterator( + const VisibleIndex& idx_diff_visible, const MultiIndex& do_transforms) + : idx_diff_visible_{idx_diff_visible}, do_transforms_{do_transforms} + { + } + + __host__ __device__ constexpr const auto& GetIndexDiff() const { return GetVisibleIndexDiff(); } + + // TODO make these private + __host__ __device__ constexpr const auto& GetVisibleIndexDiff() const + { + return idx_diff_visible_; + } + + VisibleIndex idx_diff_visible_; + MultiIndex do_transforms_; + + // HACK: control UpdateLowerIndex() + static constexpr UpdateLowerIndexHack update_lower_index_hack_; +}; + +// TODO: How to fix this? 
It uses an struct instead of lambda because lambda +// doesn't have constructor, and to put it outside the scope where it is used +// (transform_dynamic_tensor_descriptor) because template cannot be defined inside a function +// template +template +struct lambda_get_up_dim_num +{ + template + __host__ __device__ constexpr auto operator()(I) const + { + using Tran = remove_reference_t; + return Number{}; + } +}; + +template +__host__ __device__ constexpr auto +transform_dynamic_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc, + const NewTransforms& new_transforms, + NewLowerDimensionOldVisibleIdss, + NewUpperDimensionNewVisibleIdss) +{ + // sanity check + { + constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); }, + NewLowerDimensionOldVisibleIdss{}); + + constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); }, + NewUpperDimensionNewVisibleIdss{}); + + static_assert(is_valid_sequence_map::value && + is_valid_sequence_map::value, + "wrong!"); + } + + // lower dimension's hidden idss + // convert lower dimension visible idss (tuple of sequences) to hidden idss (tuple of + // sequences) + constexpr auto low_dim_hidden_idss = transform_tuples( + // convert lower dimension visible ids (a sequence) to hidden ids (a sequence) + [](auto low_dim_visible_ids) constexpr { + return transform_sequences( + // convert lower dimension visible id to hidden id + [](auto low_dim_visible_id) constexpr { + return OldTensorDescriptor::GetVisibleDimensionIds()[low_dim_visible_id]; + }, + low_dim_visible_ids); + }, + NewLowerDimensionOldVisibleIdss{}); + + constexpr index_t num_new_transform = NewTransforms::Size(); + + // upper dimension's hidden idss + constexpr index_t old_hidden_dim_number = OldTensorDescriptor::GetNumOfHiddenDimension(); + + constexpr auto up_dim_numbers = + generate_sequence(lambda_get_up_dim_num{}, Number{}); + + constexpr auto up_dim_numbers_scan = merge_sequences( + Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, math::plus{}, Number<0>{})); + + constexpr auto up_dim_hidden_idss = generate_tuple( + [ old_hidden_dim_number, up_dim_numbers_scan ](auto i) constexpr { + return + typename arithmetic_sequence_gen::type{}; + }, + Number{}); + + // new visible dimension's hidden ids + constexpr auto unordered_new_visible_dim_hidden_ids = unpack( + [](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss); + + constexpr auto new_visible_dim_unordered2ordered = unpack( + [](auto... 
xs) constexpr { return merge_sequences(xs...); }, + NewUpperDimensionNewVisibleIdss{}); + + constexpr auto new_visible_dim_hidden_ids = + unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered); + + // put everything together + const auto all_transforms = container_concat(old_tensor_desc.GetTransforms(), new_transforms); + + constexpr auto all_low_dim_hidden_idss = + container_concat(OldTensorDescriptor::GetLowerDimensionIdss(), low_dim_hidden_idss); + + constexpr auto all_up_dim_hidden_idss = + container_concat(OldTensorDescriptor::GetUpperDimensionIdss(), up_dim_hidden_idss); + + const auto element_space_size = old_tensor_desc.GetElementSpaceSize(); + + return DynamicTensorDescriptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t>{all_transforms, + element_space_size}; +} + +template +__host__ __device__ constexpr auto make_dynamic_tensor_coordinate(const TensorDesc& tensor_desc, + const VisibleIndex& idx_visible) +{ + static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(), + "wrong! # of dimension inconsistent"); + + constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); + constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension(); + constexpr auto visible_dim_ids = TensorDesc::GetVisibleDimensionIds(); + + MultiIndex idx_hidden; + + // initialize visible index + set_container_subset(idx_hidden, visible_dim_ids, idx_visible); + + // calculate hidden index + static_for{}([&tensor_desc, &idx_hidden](auto itran_p1) { + auto itran = itran_p1 - Number<1>{}; + const auto& tran = tensor_desc.GetTransforms().At(itran); + constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran); + constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran); + + const auto idx_up = get_container_subset(idx_hidden, dims_up); + + MultiIndex idx_low; + + tran.CalculateLowerIndex(idx_low, idx_up); + + set_container_subset(idx_hidden, dims_low, idx_low); + }); + + return DynamicTensorCoordinate{idx_hidden}; +} + +// UpdateLowerIndexHack: Sequence<...> +// HACK: control UpdateLowerIndex +template +__host__ __device__ constexpr auto make_dynamic_tensor_coordinate_iterator( + const TensorDesc&, const VisibleIndex& idx_diff_visible, UpdateLowerIndexHack) +{ + static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(), + "wrong! 
# of dimension inconsistent"); + + constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); + constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension(); + constexpr index_t ndim_visible = TensorDesc::GetNumOfVisibleDimension(); + constexpr auto visible_dim_ids = TensorDesc::GetVisibleDimensionIds(); + + static_assert(UpdateLowerIndexHack::Size() == ntransform, "wrong!"); + + // use index_t for boolean type + auto do_transforms = make_zero_multi_index(); + auto is_non_zero_diff = make_zero_multi_index(); + + // decide do_transform by checkout non-zero index diff components + MultiIndex non_zero_diff_pick_visible; + + static_for<0, ndim_visible, 1>{}( + [&](auto i) { non_zero_diff_pick_visible(i) = (idx_diff_visible[i] != 0); }); + + set_container_subset(is_non_zero_diff, visible_dim_ids, non_zero_diff_pick_visible); + + static_for{}([&](auto itran) { + constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran); + constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran); + + const auto non_zero_diff_pick_up = get_container_subset(is_non_zero_diff, dims_up); + + MultiIndex non_zero_diff_pick_low; + + // if any of upper index diff components is non-zero, then + // 1) Need to do this transform + // 2) all components of lower index diff will assume to be non-zero and need to be + // computed + const bool idx_diff_up_has_non_zero = container_reduce( + non_zero_diff_pick_up, [](auto a, auto b) constexpr { return a or b; }, false); + + do_transforms(itran) = idx_diff_up_has_non_zero; + + static_for<0, dims_low.Size(), 1>{}( + [&](auto i) { non_zero_diff_pick_low(i) = idx_diff_up_has_non_zero; }); + + set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low); + }); + + return DynamicTensorCoordinateIterator{ + idx_diff_visible, do_transforms}; +} + +template +__host__ __device__ constexpr auto +make_dynamic_tensor_coordinate_iterator(const TensorDesc&, const VisibleIndex& idx_diff_visible) +{ + constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); + + return make_dynamic_tensor_coordinate_iterator( + TensorDesc{}, idx_diff_visible, typename uniform_sequence_gen::type{}); +} + +template +__host__ __device__ constexpr void move_dynamic_tensor_coordinate( + const TensorDesc& tensor_desc, TensorCoord& coord, const TensorCoordIterator& coord_iterator) +{ + constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension(); + constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); + + // this is what needs to be calculated + auto idx_diff_hidden = make_zero_multi_index(); + + // initialize visible index diff + set_container_subset(idx_diff_hidden, + TensorDesc::GetVisibleDimensionIds(), + coord_iterator.GetVisibleIndexDiff()); + + // this is what needs to be updated + auto& idx_hidden = coord.GetHiddenIndex(); + + // update visible index + auto idx_hidden_pick_visible = + get_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds()); + + idx_hidden_pick_visible += coord_iterator.GetIndexDiff(); + + set_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds(), idx_hidden_pick_visible); + + // update rest of hidden index + static_for{}([&](auto itran) { + if(coord_iterator.do_transforms_[itran]) + { + const auto& tran = tensor_desc.GetTransforms().At(itran); + constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran); + constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran); + + const auto idx_up_new = get_container_subset(idx_hidden, dims_up); + auto idx_low = 
get_container_subset(idx_hidden, dims_low); + const auto idx_diff_up = get_container_subset(idx_diff_hidden, dims_up); + + MultiIndex idx_diff_low; + + // HACK: control UpdateLowerIndex for DynamicMerge using hack + constexpr index_t Hack = decltype(coord_iterator.update_lower_index_hack_)::At(itran); + + tran.UpdateLowerIndex(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number{}); + + set_container_subset(idx_diff_hidden, dims_low, idx_diff_low); + set_container_subset(idx_hidden, dims_low, idx_low); + } + }); +} + +template +__host__ __device__ constexpr bool +coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc& tensor_desc, + const TensorCoord& coord) +{ + bool valid = true; + + constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); + + const auto& idx_hidden = coord.GetHiddenIndex(); + + static_for{}([&tensor_desc, &idx_hidden, &valid](auto itran) { + const auto tran = tensor_desc.GetTransforms().At(itran); + + // check validity, only if current transformation does not always have a valid mapping + if constexpr(!decltype(tran)::IsValidUpperIndexAlwaysMappedToValidLowerIndex()) + { + const auto idx_up = + get_container_subset(idx_hidden, TensorDesc::GetUpperDimensionIdss().At(itran)); + + // Comment: using valid = valid && .. will result in weird control flow in ISA + valid &= tran.IsValidUpperIndexMappedToValidLowerIndex(idx_up); + } + }); + + return valid; +} + +template +__host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc& tensor_desc, + const TensorCoord& coord) +{ + // check visible index + const auto& idx_visible = coord.GetVisibleIndex(); + + bool is_visible_index_valid = true; + + static_for<0, TensorDesc::GetNumOfDimension(), 1>{}( + [&is_visible_index_valid, &idx_visible, &tensor_desc](auto i) { + is_visible_index_valid = + is_visible_index_valid && + (idx_visible[i] >= 0 && idx_visible[i] < tensor_desc.GetLength(i)); + }); + + // check other hidden index + return is_visible_index_valid && + coordinate_has_valid_offset_assuming_visible_index_is_valid(tensor_desc, coord); +} + +template +using DynamicTensorCoordinate_t = decltype(make_dynamic_tensor_coordinate( + TensorDesc{}, MultiIndex>::GetNumOfDimension()>{})); + +template +using DynamicTensorCoordinateIterator_t = decltype(make_dynamic_tensor_coordinate_iterator( + TensorDesc{}, MultiIndex>::GetNumOfDimension()>{})); + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_description/dynamic_tensor_descriptor_helper.hpp b/composable_kernel/include/tensor_description/dynamic_tensor_descriptor_helper.hpp new file mode 100644 index 0000000000..2e36451a66 --- /dev/null +++ b/composable_kernel/include/tensor_description/dynamic_tensor_descriptor_helper.hpp @@ -0,0 +1,150 @@ +#ifndef CK_DYNAMIC_TENSOR_DESCRIPTOR_HELPER_HPP +#define CK_DYNAMIC_TENSOR_DESCRIPTOR_HELPER_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_multi_index_transform_helper.hpp" + +namespace ck { + +/* + * These functions create tensor descriptors at runtime. If they are not constexpr, you will + * likely see usage of scratch memory during construction of these tensor descriptors. It is + * therefore better to call these functions on the host and then pass the constructed tensor + * descriptors to the GPU. If the tensor descriptors being constructed are constexpr, then you + * can call them on the GPU without worrying about scratch memory usage.
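+ *
+ * A minimal usage sketch (illustrative only; n, c, h and w stand for whatever runtime index_t
+ * lengths the caller already has):
+ *
+ *   // on the host
+ *   const auto in_desc =
+ *       make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, c, h, w));
+ *   // pass in_desc to the kernel by value instead of re-constructing it on the GPU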
+ */ + +#if CK_WORKAROUND_SWDEV_275126 +template +__host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengths& lengths, + const Strides& strides, + Number i, + AccOld acc_old) +{ + auto acc_new = acc_old + (lengths[i] - Number<1>{}) * strides[i]; + + if constexpr(i.value < Lengths::Size() - 1) + { + return calculate_element_space_size_impl(lengths, strides, i + Number<1>{}, acc_new); + } + else + { + return acc_new; + } +} +#endif + +template ::type = false> +__host__ __device__ constexpr auto +make_dynamic_naive_tensor_descriptor_v2(const Tuple& lengths, + const Tuple& strides) +{ + constexpr index_t N = sizeof...(Lengths); + + const auto transforms = make_tuple(make_embed_transform(lengths, strides)); + + constexpr auto low_dim_hidden_idss = make_tuple(Sequence<0>{}); + + constexpr auto up_dim_hidden_idss = + make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{}); + + constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{}; + +#if !CK_WORKAROUND_SWDEV_275126 + // rocm-4.1 compiler would crash for recursive labmda + // recursive function for reduction + auto f = [&](auto fs, auto i, auto acc_old) { + auto acc_new = acc_old + (lengths[i] - Number<1>{}) * strides[i]; + + if constexpr(i.value < N - 1) + { + return fs(fs, i + Number<1>{}, acc_new); + } + else + { + return acc_new; + } + }; + + const auto element_space_size = f(f, Number<0>{}, Number<1>{}); +#else + const auto element_space_size = + calculate_element_space_size_impl(lengths, strides, Number<0>{}, Number<1>{}); +#endif + + return DynamicTensorDescriptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t>{transforms, + element_space_size}; +} + +// Lengths... can be: +// 1) index_t, which is known at run-time +// 2) Number<>, which is known at compile-time +template +__host__ __device__ constexpr auto +make_dynamic_naive_tensor_descriptor_packed_v2(const Tuple& lengths) +{ + constexpr index_t N = sizeof...(Lengths); + + const auto transforms = make_tuple(make_unmerge_transform(lengths)); + + constexpr auto low_dim_hidden_idss = make_tuple(Sequence<0>{}); + + constexpr auto up_dim_hidden_idss = + make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{}); + + constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{}; + + const auto element_space_size = container_reduce(lengths, math::multiplies_v2{}, Number<1>{}); + + return DynamicTensorDescriptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t>{transforms, + element_space_size}; +} + +template +__host__ __device__ constexpr auto +make_dynamic_naive_tensor_descriptor_aligned_v2(const Tuple& lengths, Align align) +{ + constexpr auto I1 = Number<1>{}; + + constexpr index_t N = sizeof...(Lengths); + + const auto stride_n_minus_2 = math::integer_least_multiple(lengths[Number{}], align); + + auto strides = generate_tuple( + [&](auto i) { + if constexpr(i.value == N - 1) + { + return I1; + } + else if constexpr(i.value == N - 2) + { + return Number{}; + } + else + { + return container_reduce(lengths, + math::multiplies_v2{}, + Number{}, + i + I1, + Number{}, + I1); + } + }, + Number{}); + + return make_dynamic_naive_tensor_descriptor_v2(lengths, strides); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_description/tensor_adaptor.hpp b/composable_kernel/include/tensor_description/tensor_adaptor.hpp new file mode 100644 index 0000000000..6affe6141f --- /dev/null +++ 
b/composable_kernel/include/tensor_description/tensor_adaptor.hpp @@ -0,0 +1,466 @@ +#ifndef CK_TENSOR_ADAPTOR_HPP +#define CK_TENSOR_ADAPTOR_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" + +namespace ck { + +// Transforms: Tuple +// LowerDimensionHiddenIdss : Tuple, ...> +// UpperDimensionHiddenIdss : Tuple, ...> +// BottomDimensionHiddenIds : Sequence<...> +// TopDimensionHiddenIds : Sequence<...> +template +struct TensorAdaptor +{ + __host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); } + + __host__ __device__ constexpr const auto& GetTransforms() const { return transforms_; } + + __host__ __device__ static constexpr auto GetLowerDimensionHiddenIdss() + { + return LowerDimensionHiddenIdss{}; + } + + __host__ __device__ static constexpr auto GetUpperDimensionHiddenIdss() + { + return UpperDimensionHiddenIdss{}; + } + + __host__ __device__ static constexpr auto GetTopDimensionHiddenIds() + { + return TopDimensionHiddenIds{}; + } + + __host__ __device__ static constexpr auto GetBottomDimensionHiddenIds() + { + return BottomDimensionHiddenIds{}; + } + + __host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms) + { + const auto lengths = generate_tuple( + [&](auto idim_top) { + constexpr auto tmp = GetTransformAndItsUpperDimension(idim_top); + + constexpr index_t itran = tmp[Number<0>{}]; + constexpr index_t idim_up = tmp[Number<1>{}]; + constexpr bool found = tmp[Number<2>{}]; + + static_assert(found == true, + "wrong! not found matching transformation and upper-dimension"); + + const auto length = + transforms[Number{}].GetUpperLengths()[Number{}]; + + return length; + }, + Number{}); + + // TODO: make container_reduce support tuple of Number and index_t + return container_reduce(lengths, math::multiplies_v2{}, Number<1>{}); + } + + template + __host__ __device__ static constexpr auto GetTransformAndItsUpperDimension(Number) + { + constexpr auto idim_top = Number{}; + + constexpr index_t idim_hidden = TopDimensionHiddenIds::At(idim_top); + + index_t itran_found = 0; + index_t idim_up_found = 0; + bool found = false; + + static_for<0, ntransform_, 1>{}([&](auto itran) { + constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran]; + + static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) { + if constexpr(up_dim_ids[idim_up] == idim_hidden) + { + itran_found = itran; + idim_up_found = idim_up; + found = true; + } + }); + }); + + return make_tuple(itran_found, idim_up_found, found); + } + + __host__ __device__ static constexpr index_t GetNumOfBottomDimension() + { + return BottomDimensionHiddenIds::Size(); + } + + __host__ __device__ static constexpr index_t GetNumOfTopDimension() + { + return TopDimensionHiddenIds::Size(); + } + + __host__ __device__ static constexpr index_t GetNumOfHiddenDimension() + { + constexpr auto all_low_dim_ids = unpack( + [](auto&&... xs) constexpr { return merge_sequences(xs...); }, + LowerDimensionHiddenIdss{}); + + constexpr auto all_up_dim_ids = unpack( + [](auto&&... 
xs) constexpr { return merge_sequences(xs...); }, + UpperDimensionHiddenIdss{}); + + constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids); + + using unique_sort_all_dim_ids = typename sequence_unique_sort, + math::equal>::type; + + return unique_sort_all_dim_ids::Size(); + } + + constexpr static index_t ntransform_ = GetNumOfTransform(); + constexpr static index_t ndim_hidden_ = GetNumOfHiddenDimension(); + constexpr static index_t ndim_bottom_ = GetNumOfBottomDimension(); + constexpr static index_t ndim_top_ = GetNumOfTopDimension(); + + using HiddenIndex = MultiIndex; + using BottomIndex = MultiIndex; + using TopIndex = MultiIndex; + + // may be index_t or Number<> + using ElementSize = remove_cv_t; + + public: + __host__ __device__ constexpr TensorAdaptor() = default; + + __host__ __device__ constexpr TensorAdaptor(const Transforms& transforms) + : transforms_{transforms}, element_size_{InitializeElementSize(transforms)} + { + static_assert(Transforms::Size() == ntransform_ && + LowerDimensionHiddenIdss::Size() == ntransform_ && + UpperDimensionHiddenIdss::Size() == ntransform_, + "wrong! inconsistent # of transformations"); + + // TODO check dependency of dimensions is valid + } + + __host__ __device__ constexpr auto GetElementSize() const { return element_size_; } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + static_assert(TopIdx::Size() == TopDimensionHiddenIds::Size(), + "wrong! # of dimension inconsistent"); + + constexpr index_t ntransform = GetNumOfTransform(); + constexpr index_t ndim_hidden = GetNumOfHiddenDimension(); + + MultiIndex idx_hidden; + + // initialize uppest index + set_container_subset(idx_hidden, GetTopDimensionHiddenIds(), idx_top); + + // calculate hidden index + static_for{}([&](auto itran_p1) { + auto itran = itran_p1 - Number<1>{}; + const auto& tran = GetTransforms().At(itran); + constexpr auto dims_low = GetLowerDimensionHiddenIdss().At(itran); + constexpr auto dims_up = GetUpperDimensionHiddenIdss().At(itran); + + const auto idx_up = get_container_subset(idx_hidden, dims_up); + + MultiIndex idx_low; + + tran.CalculateLowerIndex(idx_low, idx_up); + + set_container_subset(idx_hidden, dims_low, idx_low); + }); + + return get_container_subset(idx_hidden, BottomDimensionHiddenIds{}); + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + bool is_known = true; + + static_for<0, Transforms::Size(), 1>{}([&](auto i) { + is_known &= + remove_cv_t>::IsKnownAtCompileTime(); + }); + + return is_known && is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("TensorAdaptor, "); + static_for<0, ntransform_, 1>{}([&](auto i) { + printf("transforms: "); + transforms_[i].Print(); + printf("LowerDimensionHiddenIds:"); + LowerDimensionHiddenIdss{}.At(i).Print(); + printf("UpperDimensionHiddenIds:"); + UpperDimensionHiddenIdss{}.At(i).Print(); + }); + + printf("BottomDimensionHiddenIds:"); + BottomDimensionHiddenIds::Print(); + printf("TopDimensionHiddenIds:"); + TopDimensionHiddenIds::Print(); + + printf("}"); + } + + private: + Transforms transforms_; + ElementSize element_size_; +}; + +template +__host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& adaptor0, + const TensorAdaptor1& adaptor1) +{ + static_assert(TensorAdaptor0::GetNumOfTopDimension() == + TensorAdaptor1::GetNumOfBottomDimension(), + "wrong!"); + + // all_transforms = transform0 + transform1 + const auto all_transforms = + 
container_concat(adaptor0.GetTransforms(), adaptor1.GetTransforms()); + + // shift + constexpr index_t adaptor0_max_hidden_id = [&]() { + index_t adaptor0_max_hidden_id_ = NumericLimits::Min(); + + static_for<0, TensorAdaptor0::GetNumOfTransform(), 1>{}([&](auto itran) { + constexpr index_t ndim_low = + TensorAdaptor0{}.GetTransforms()[itran].GetNumOfLowerDimension(); + + static_for<0, ndim_low, 1>{}([&](auto idim_low) { + adaptor0_max_hidden_id_ = + math::max(adaptor0_max_hidden_id_, + TensorAdaptor0::GetLowerDimensionHiddenIdss()[itran][idim_low].value); + }); + + constexpr index_t ndim_up = + TensorAdaptor0{}.GetTransforms()[itran].GetNumOfUpperDimension(); + + static_for<0, ndim_up, 1>{}([&](auto idim_up) { + adaptor0_max_hidden_id_ = + math::max(adaptor0_max_hidden_id_, + TensorAdaptor0::GetUpperDimensionHiddenIdss()[itran][idim_up].value); + }); + }); + + return adaptor0_max_hidden_id_; + }(); + + constexpr index_t adaptor1_min_hidden_id = [&]() { + index_t adaptor1_min_hidden_id_ = NumericLimits::Max(); + + static_for<0, TensorAdaptor1::GetNumOfTransform(), 1>{}([&](auto itran) { + constexpr index_t ndim_low = + TensorAdaptor1{}.GetTransforms()[itran].GetNumOfLowerDimension(); + + // get the min of all lower dimenions, but not bottom dimension (because their id will + // be matched with top id from adaptor0) + static_for<0, ndim_low, 1>{}([&](auto idim_low) { + constexpr index_t low_dim_hidden_id = + TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran][idim_low].value; + + bool is_bottom_dim = false; + static_for<0, TensorAdaptor1::GetNumOfBottomDimension(), 1>{}([&](auto i) { + if constexpr(low_dim_hidden_id == + TensorAdaptor1::GetBottomDimensionHiddenIds()[i]) + { + is_bottom_dim = true; + } + }); + + if(!is_bottom_dim) + { + adaptor1_min_hidden_id_ = math::min(adaptor1_min_hidden_id_, low_dim_hidden_id); + } + }); + + constexpr index_t ndim_up = + TensorAdaptor1{}.GetTransforms()[itran].GetNumOfUpperDimension(); + + // get the min of all upper dimensions + static_for<0, ndim_up, 1>{}([&](auto idim_up) { + adaptor1_min_hidden_id_ = + math::min(adaptor1_min_hidden_id_, + TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran][idim_up].value); + }); + }); + + return adaptor1_min_hidden_id_; + }(); + + constexpr index_t adaptor1_hidden_id_shift = + adaptor0_max_hidden_id + 1 - adaptor1_min_hidden_id; + + constexpr index_t ndim_bottom_1 = TensorAdaptor1::GetNumOfBottomDimension(); + + // all_low_dim_hidden_idss = + // low_dim_hidden_idss_0 + match_hidden_id_for_1(shift_hidden_id_for_1(low_dim_hiden_idss_1)) + constexpr auto low_dim_hidden_idss_1 = generate_tuple( + // generate sequence of ids for a transform + [&](auto itran) { + constexpr auto ndim_low_1 = TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran].Size(); + + constexpr auto low_dim_hidden_ids_1 = + TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran]; + + // sequence in, sequence out + constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr + { + auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1); + + // shift hidden id so every dim id is unique + static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) { + low_dim_hidden_ids_1_mod_(idim_low_1) += adaptor1_hidden_id_shift; + }); + + // match hidden id + static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) { + static_for<0, ndim_bottom_1, 1>{}([&](auto idim_bottom_1) { + // if this low dim is bottom dim, then do id matching + if constexpr(low_dim_hidden_ids_1[idim_low_1] == + TensorAdaptor1::GetBottomDimensionHiddenIds()[idim_bottom_1]) + { + 
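+                            // id matching: this bottom dim of adaptor1 takes over the hidden id of
+                            // the corresponding top dim of adaptor0, so the chained adaptor routes
+                            // indices from adaptor1's transforms directly into adaptor0's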
low_dim_hidden_ids_1_mod_(idim_low_1) = + TensorAdaptor0::GetTopDimensionHiddenIds()[idim_bottom_1]; + } + }); + }); + + return low_dim_hidden_ids_1_mod_; + } + (); + + return generate_sequence_v2( + [&](auto i) constexpr { return Number{}; }, + Number{}); + }, + Number{}); + + constexpr auto all_low_dim_hidden_idss = + container_concat(TensorAdaptor0::GetLowerDimensionHiddenIdss(), low_dim_hidden_idss_1); + + // all_up_dim_hidden_idss = + // up_dim_hidden_idss_0 + shift_hidden_id_for_1(up_dim_hiden_idss_1) + constexpr auto up_dim_hidden_idss_1 = generate_tuple( + // generate sequence of ids for a transform + [&](auto itran) { + constexpr auto ndim_up_1 = TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran].Size(); + + constexpr auto up_dim_hidden_ids_1 = + TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran]; + + // sequence in, constexpr tuple out + constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr + { + auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1); + + // shift hidden id + static_for<0, ndim_up_1, 1>{}([&](auto idim_up_1) { + up_dim_hidden_ids_1_mod_(idim_up_1) += adaptor1_hidden_id_shift; + }); + + return up_dim_hidden_ids_1_mod_; + } + (); + + // constexpr tuple to sequence + return generate_sequence_v2( + [&](auto i) constexpr { return Number{}; }, + Number{}); + }, + Number{}); + + constexpr auto all_up_dim_hidden_idss = + container_concat(TensorAdaptor0::GetUpperDimensionHiddenIdss(), up_dim_hidden_idss_1); + + // bottom_dim_hidden_ids = bottom_dim_hidden_ids_0 + constexpr auto bottom_dim_hidden_ids = TensorAdaptor0::GetBottomDimensionHiddenIds(); + + // top_dim_hidden_ids = shift_hidden_id(top_dim_hidden_ids_1) + constexpr auto top_dim_hidden_ids = + TensorAdaptor1::GetTopDimensionHiddenIds() + Number{}; + + // put everything together + return TensorAdaptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t>{all_transforms}; +} + +// Transforms: Tuple +// LowerDimensionOldTopIdss: Tuple, ...> +// UpperDimensionNewTopIdss: Tuple, ...> +template +__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms& transforms, + LowerDimensionOldTopIdss, + UpperDimensionNewTopIdss) +{ + constexpr index_t ntransform = Transforms::Size(); + + static_assert(LowerDimensionOldTopIdss::Size() == ntransform && + UpperDimensionNewTopIdss::Size() == ntransform, + "wrong!"); + + // sanity check on LowerDimensionOldTopIdss and UpperDimensionNewTopIdss + constexpr auto all_low_dim_old_top_ids = unpack( + [](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionOldTopIdss{}); + + constexpr auto all_up_dim_new_top_ids = unpack( + [](auto&&... 
xs) constexpr { return merge_sequences(xs...); }, UpperDimensionNewTopIdss{}); + + static_assert(is_valid_sequence_map::value && + is_valid_sequence_map::value, + "wrong!"); + + constexpr index_t ndim_old_top = all_low_dim_old_top_ids.Size(); + constexpr index_t ndim_new_top = all_up_dim_new_top_ids.Size(); + + // low_dim_hidden_idss + constexpr auto low_dim_hidden_idss = LowerDimensionOldTopIdss{}; + + // up_dim_hidden_idss: shift UpperDimensionNewTopIdss by ndim_bottom + constexpr auto up_dim_hidden_idss = generate_tuple( + [](auto itran) { return UpperDimensionNewTopIdss{}[itran] + Number{}; }, + Number{}); + + // bottom_dim_hidden_ids + constexpr auto bottom_dim_hidden_ids = + typename arithmetic_sequence_gen<0, ndim_old_top, 1>::type{}; + + // top_dim_hidden_ids + constexpr auto top_dim_hidden_ids = + typename arithmetic_sequence_gen<0, ndim_new_top, 1>::type{} + Number{}; + + return TensorAdaptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t>{transforms}; +} + +template = 2, bool>::type = false> +__host__ __device__ constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs) +{ + return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...)); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp b/composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp new file mode 100644 index 0000000000..694b2fd2cc --- /dev/null +++ b/composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp @@ -0,0 +1,171 @@ +#ifndef CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP +#define CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "cluster_descriptor.hpp" +#include "threadwise_dynamic_tensor_slice_transfer.hpp" + +namespace ck { + +// this version does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray instead of C array for thread buffer +// 2. ThreadwiseDynamicTensorSliceTransfer_v3 does not keep reference to tensor descriptor +// 3. ThreadwiseDynamicTensorSliceTransfer_v3::Run() does not construct new tensor coordinate +template +struct BlockwiseDynamicTensorSliceTransfer_v4 +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + + using Index = MultiIndex; + + __device__ constexpr BlockwiseDynamicTensorSliceTransfer_v4(const SrcDesc& src_desc, + const Index& src_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin) + : threadwise_transfer_( + src_desc, make_zero_multi_index(), dst_desc, make_zero_multi_index()) + + { + static_assert(nDim == remove_reference_t>::GetNumOfDimension() && + nDim == remove_reference_t>::GetNumOfDimension() && + nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(), + "wrong! 
BlockSize too small"); + + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto thread_data_idx_begin = thread_cluster_idx * ThreadSliceLengths{}; + + threadwise_transfer_.SetSrcSliceOrigin(src_desc, + src_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + template + __device__ void RunRead(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const SrcIteratorHacks& src_iterator_hacks) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_desc, src_buf, src_iterator_hacks); + } + } + + template + __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_desc, src_buf); + } + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunWrite(dst_desc, dst_buf); + } + } + + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_desc, step); + } + } + + // SrcMoveSliceWindowIteratorHack to control index calculation move slice window + template + __device__ void + MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& step, + const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow( + src_desc, step, src_move_slice_window_iterator_hack); + } + } + + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor_v2(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseDynamicTensorSliceTransfer_v3; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer_v2.hpp b/composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer_v2.hpp new file mode 100644 index 0000000000..20f3225f82 --- /dev/null +++ b/composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer_v2.hpp @@ -0,0 +1,158 @@ +#ifndef CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP +#define CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "cluster_descriptor.hpp" +#include 
"threadwise_dynamic_tensor_slice_transfer_v2.hpp" + +namespace ck { + +// this version does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray instead of C array for thread buffer +// 2. ThreadwiseDynamicTensorSliceTransfer_v3 does not keep reference to tensor descriptor +// 3. ThreadwiseDynamicTensorSliceTransfer_v3::Run() does not construct new tensor coordinate +template +struct BlockwiseDynamicTensorSliceTransfer_v4r1 +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + + using Index = MultiIndex; + + __device__ constexpr BlockwiseDynamicTensorSliceTransfer_v4r1( + const SrcDesc& src_desc, + const Index& src_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin) + : threadwise_transfer_( + src_desc, make_zero_multi_index(), dst_desc, make_zero_multi_index()) + + { + static_assert(nDim == remove_reference_t>::GetNumOfDimension() && + nDim == remove_reference_t>::GetNumOfDimension() && + nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(), + "wrong! BlockSize too small"); + + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto thread_data_idx_begin = thread_cluster_idx * ThreadSliceLengths{}; + + threadwise_transfer_.SetSrcSliceOrigin(src_desc, + src_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + template + __device__ void RunRead(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const SrcIteratorHacks& src_iterator_hacks) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_desc, src_buf, src_iterator_hacks); + } + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunWrite(dst_desc, dst_buf); + } + } + + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_desc, step); + } + } + + // SrcMoveSliceWindowIteratorHack to control index calculation move slice window + template + __device__ void + MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& step, + const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow( + src_desc, step, src_move_slice_window_iterator_hack); + } + } + + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(BlockSize == 
thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor_v2(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseDynamicTensorSliceTransfer_v3r1; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r2.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r2.hpp new file mode 100644 index 0000000000..694cf9c6a3 --- /dev/null +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r2.hpp @@ -0,0 +1,396 @@ +#ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP +#define CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP + +#include "common_header.hpp" +#include "tensor_adaptor.hpp" +#include "threadwise_dynamic_tensor_slice_transfer.hpp" +#include "threadwise_contraction_dlops.hpp" + +namespace ck { + +// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1] +// A and B are visable to the whole block, C is distributed among each thread +// Assume: +// 1. A: +// 1. AKMBlockDesc is known at compile-time +// 2. ABlockBuffer is DynamicBuffer +// 2. B: +// 1. BKNBlockDesc is known at compile-time +// 2. BBlockBuffer is DynamicBuffer +// 3. C: +// 1. CM0M1N0N1ThreadDesc is known at compile-time +// 2. CThreadBuffer is StaticBuffer +// Also assume: +// M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization) +template ::type = false> +struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2 +{ + using AIndex = MultiIndex<3>; + using BIndex = MultiIndex<3>; + using CIndex = MultiIndex<4>; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr index_t K = AKMBlockDesc{}.GetLength(I0); + static constexpr index_t M = AKMBlockDesc{}.GetLength(I1); + static constexpr index_t N = BKNBlockDesc{}.GetLength(I1); + + static constexpr index_t M100 = M1N1ThreadClusterM100; + static constexpr index_t N100 = M1N1ThreadClusterN100; + + static constexpr index_t M101 = M1N1ThreadClusterM101; + static constexpr index_t N101 = M1N1ThreadClusterN101; + + static constexpr index_t M11 = M1PerThreadM11; + static constexpr index_t N11 = N1PerThreadN11; + + static constexpr index_t M1 = M1N1ThreadClusterM100 * M1N1ThreadClusterM101 * M1PerThreadM11; + static constexpr index_t N1 = M1N1ThreadClusterN100 * M1N1ThreadClusterN101 * N1PerThreadN11; + + static constexpr index_t M0 = M / M1; + static constexpr index_t N0 = N / N1; + + __host__ __device__ static constexpr auto + MakeAKM0M1BlockDescriptor(const AKMBlockDesc& a_k_m_block_desc) + { + const auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor( + AKMBlockDesc{}, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{})); + + return a_k_m0_m1_block_desc; + } + + __host__ __device__ static constexpr auto + MakeBKN0N1BlockDescriptor(const BKNBlockDesc& b_k_n_block_desc) + { + const auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor( + BKNBlockDesc{}, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), 
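+                // the Sequence tuple above gives the lower dims each transform reads,
+                // the one below the upper dims it produces: K passes through as dim 0,
+                // while N is unmerged into (N0, N1) occupying dims 1 and 2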
+ make_tuple(Sequence<0>{}, Sequence<1, 2>{})); + + return b_k_n0_n1_block_desc; + } + + __host__ __device__ static constexpr auto MakeCM0M100M101M11N0N100N101N11ToMNBlockAdaptor() + { + // upper: [M0, M100, M101, M11, N0, N100, N101, N11] + // lower: [M, N] + constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{})); + + return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor; + } + + __host__ __device__ static constexpr auto + MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor() + { + // upper: [M0, M100, M101, M11, N0, N100, N101, N11] + // lower: [M0, M1, N0, N1] + constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{})); + + return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor; + } + + __host__ __device__ static constexpr auto GetCM0M1N0N1ThreadTensorLengths() + { + return Sequence{}; + } + + static constexpr auto a_k_m0_m1_block_desc_ = MakeAKM0M1BlockDescriptor(AKMBlockDesc{}); + static constexpr auto b_k_n0_n1_block_desc_ = MakeBKN0N1BlockDescriptor(BKNBlockDesc{}); + + public: + __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2() + : c_thread_origin_data_idx_{CalculateCM0M1N0N1ThreadOriginOnBlock( + get_thread_local_1d_id())}, + a_thread_copy_{ + make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])}, + b_thread_copy_{ + make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3])} + { + static_assert(AKMBlockDesc::IsKnownAtCompileTime() && BKNBlockDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(BlockSize == M101 * M100 * N101 * N100, + "wrong! blocksize and cluster size not consistent"); + + static_assert(M % M1 == 0 && N % N1 == 0, "wrong!"); + + static_assert(AKMBlockDesc{}.GetLength(I0) == BKNBlockDesc{}.GetLength(I0), + "wrong! 
K dimension not consistent"); + + // TODO: remove this restriction + static_assert(M0 == 2 && N0 == 2, "wrong"); + } + + __device__ static CIndex CalculateCM0M1N0N1ThreadOriginOnBlock(index_t thread_id) + { + // lower: [M0, M1, N0, N1] + // upper: [M0, M100, M101, M11, N0, N100, N101, N11] + constexpr auto adaptor0 = MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor(); + + // lower: [M0, M100, M101, M11, N0, N100, N101, N11] + // upper: [Tid, M0, M11, N0, N11] + constexpr auto adaptor1 = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M100, N100, M101, N101)), + make_pass_through_transform(M0), + make_pass_through_transform(M11), + make_pass_through_transform(N0), + make_pass_through_transform(N11)), + make_tuple( + Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1); + + return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0)); + } + + __host__ __device__ static constexpr index_t GetABlockAlignment() { return M1PerThreadM11; } + + __host__ __device__ static constexpr auto GetBBlockAlignment() { return N1PerThreadN11; } + + template + __device__ void Run(const CM0M1N0N1ThreadDesc& c_m0_m1_n0_n1_thread_desc, + const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + static_assert(CM0M1N0N1ThreadDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + // TODO: remove this restriction + static_assert(M0 == 2 && N0 == 2 && CM0M1N0N1ThreadDesc{}.GetLength(I0) == M0 && + CM0M1N0N1ThreadDesc{}.GetLength(I2) == N0, + "wrong"); + + auto a_thread_buf = make_static_buffer( + a_k_m0_m1_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_k_n0_n1_thread_desc_.GetElementSpaceSize()); + + constexpr auto threadwise_gemm = + ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1, + Sequence<1, M1PerThreadM11>, + Sequence<1, N1PerThreadN11>>{}; + + // read A_sub_0 + a_thread_copy_.Run(a_k_m0_m1_block_desc_, + make_tuple(I0, I0, I0), + a_block_buf, + a_k_m0_m1_thread_desc_, + make_tuple(I0, I0, I0), + a_thread_buf); + + // read B_sub_0 + b_thread_copy_.Run(b_k_n0_n1_block_desc_, + make_tuple(I0, I0, I0), + b_block_buf, + b_k_n0_n1_thread_desc_, + make_tuple(I0, I0, I0), + b_thread_buf); + + // read B_sub_1 + b_thread_copy_.Run(b_k_n0_n1_block_desc_, + make_tuple(I0, I1, I0), + b_block_buf, + b_k_n0_n1_thread_desc_, + make_tuple(I0, I1, I0), + b_thread_buf); + + // read A_sub_1 + a_thread_copy_.Run(a_k_m0_m1_block_desc_, + make_tuple(I0, I1, I0), + a_block_buf, + a_k_m0_m1_thread_desc_, + make_tuple(I0, I1, I0), + a_thread_buf); + + // C_sub_00 += transpose(A_sub_0) * B_sub_0 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I0, I0)); + + // C_sub_01 += transpose(A_sub_0) * B_sub_1 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0), + c_thread_buf, + make_tuple(I0, I0, I1, I0)); + + // loop over rest of k + static_for{}([&](auto k) { + // read A_sub_0 + a_thread_copy_.Run(a_k_m0_m1_block_desc_, + make_tuple(k, I0, I0), + a_block_buf, + a_k_m0_m1_thread_desc_, + make_tuple(I0, I0, I0), + a_thread_buf); + + // C_sub_10 += transpose(A_sub_1) * B_sub_0 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I1, I0), + b_thread_buf, + 
make_tuple(I0, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I0, I0)); + + // read B_sub_0 + b_thread_copy_.Run(b_k_n0_n1_block_desc_, + make_tuple(k, I0, I0), + b_block_buf, + b_k_n0_n1_thread_desc_, + make_tuple(I0, I0, I0), + b_thread_buf); + + // C_sub_11 += transpose(A_sub_1) * B_sub_1 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I1, I0), + b_thread_buf, + make_tuple(I0, I1, I0), + c_thread_buf, + make_tuple(I1, I0, I1, I0)); + + // read B_sub_1 + b_thread_copy_.Run(b_k_n0_n1_block_desc_, + make_tuple(k, I1, I0), + b_block_buf, + b_k_n0_n1_thread_desc_, + make_tuple(I0, I1, I0), + b_thread_buf); + + // read A_sub_1 + a_thread_copy_.Run(a_k_m0_m1_block_desc_, + make_tuple(k, I1, I0), + a_block_buf, + a_k_m0_m1_thread_desc_, + make_tuple(I0, I1, I0), + a_thread_buf); + + // C_sub_00 += transpose(A_sub_0) * B_sub_0 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I0, I0)); + + // C_sub_01 += transpose(A_sub_0) * B_sub_1 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0), + c_thread_buf, + make_tuple(I0, I0, I1, I0)); + }); + + // C_sub_10 += transpose(A_sub_1) * B_sub_0 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I1, I0), + b_thread_buf, + make_tuple(I0, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I0, I0)); + + // C_sub_11 += transpose(A_sub_1) * B_sub_1 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I1, I0), + b_thread_buf, + make_tuple(I0, I1, I0), + c_thread_buf, + make_tuple(I1, I0, I1, I0)); + } + + private: + // A[K, M0, M1] + static constexpr auto a_k_m0_m1_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(Number{}, Number{}, Number{})); + + // B[K, N0, N1] + static constexpr auto b_k_n0_n1_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(Number{}, Number{}, Number{})); + + using AThreadCopy = + ThreadwiseDynamicTensorSliceTransfer_v4, + Sequence<0, 1, 2>, + 2, + AThreadCopyScalarPerVector_M11, + 1>; + + using BThreadCopy = + ThreadwiseDynamicTensorSliceTransfer_v4, + Sequence<0, 1, 2>, + 2, + BThreadCopyScalarPerVector_N11, + 1>; + + CIndex c_thread_origin_data_idx_; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r3.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r3.hpp new file mode 100644 index 0000000000..6a3885936e --- /dev/null +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r3.hpp @@ -0,0 +1,410 @@ +#ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP +#define CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP + +#include "common_header.hpp" +#include "tensor_adaptor.hpp" +#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp" +#include "threadwise_contraction_dlops.hpp" + +namespace ck { + +// C[BM0, BM1, BN0, BN1] += transpose(A[K, BM0, BM1]) * B[K, BN0, BN1] +// A and B are visable to the whole block, C is distributed among each thread +// Assume: +// 1. A: +// 1. ABlockDesc_BK0_BM_BK1 is known at compile-time +// 2. ABlockBuffer is DynamicBuffer +// 2. B: +// 1. BBlockDesc_BK0_BN_BK1 is known at compile-time +// 2. BBlockBuffer is DynamicBuffer +// 3. C: +// 1. CThreadDesc_BM0_BM11_BN0_BN11 is known at compile-time +// 2. CThreadBuffer is StaticBuffer +// Also assume: +// BM10BN10ThreadClusterBM10Xs::Size() = BM10BN10ThreadClusterBN10Xs::Size() == 2 +// BM0 = BN0 = 2. 
It will do 2x2 pipelined read and fma (ABBA optimization) +template + typename BM10BN10ThreadClusterBN10Xs, // Sequence + index_t AThreadCopyScalarPerVector_BM11, + index_t BThreadCopyScalarPerVector_BN11, + typename std::enable_if::type = false> +struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2 +{ + using AIndex = MultiIndex<3>; + using BIndex = MultiIndex<3>; + using CIndex = MultiIndex<4>; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr index_t BK0 = ABlockDesc_BK0_BM_BK1{}.GetLength(I0); + static constexpr index_t BK1 = ABlockDesc_BK0_BM_BK1{}.GetLength(I2); + static constexpr index_t BM = ABlockDesc_BK0_BM_BK1{}.GetLength(I1); + static constexpr index_t BN = BBlockDesc_BK0_BN_BK1{}.GetLength(I1); + + static constexpr index_t BM100 = BM10BN10ThreadClusterBM10Xs{}[I0]; + static constexpr index_t BN100 = BM10BN10ThreadClusterBN10Xs{}[I0]; + + static constexpr index_t BM101 = BM10BN10ThreadClusterBM10Xs{}[I1]; + static constexpr index_t BN101 = BM10BN10ThreadClusterBN10Xs{}[I1]; + + static constexpr index_t BM11 = BM1PerThreadBM11; + static constexpr index_t BN11 = BN1PerThreadBN11; + + static constexpr index_t BM1 = BM100 * BM101 * BM11; + static constexpr index_t BN1 = BN100 * BN101 * BN11; + + static constexpr index_t BM0 = BM / BM1; + static constexpr index_t BN0 = BN / BN1; + + __host__ __device__ static constexpr auto + MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1) + { + const auto a_block_bk0_bm0_bm1_bk1 = transform_dynamic_tensor_descriptor( + a_block_desc_bk0_bm_bk1, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return a_block_bk0_bm0_bm1_bk1; + } + + __host__ __device__ static constexpr auto + MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1) + { + const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_dynamic_tensor_descriptor( + b_block_desc_bk0_bn_bk1, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return b_block_desc_bk0_bn0_bn1_bk1; + } + + __host__ __device__ static constexpr auto + MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM_BN() + { + // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] + // lower: [BM, BN] + constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n = + make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{})); + + return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n; + } + + __host__ __device__ static constexpr auto + MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1() + { + // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] + // lower: [BM0, BM1, BN0, BN1] + constexpr auto 
c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1 = + make_single_stage_tensor_adaptor( + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{})); + + return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1; + } + + __host__ __device__ static constexpr auto GetCThreadTensorLengths_BM0_BM1_BN0_BN1() + { + return Sequence{}; + } + + static constexpr auto a_block_desc_bk0_bm0_bm1_bk1_ = + MakeABlockDescriptor_BK0_BM0_BM1_BK1(ABlockDesc_BK0_BM_BK1{}); + + static constexpr auto b_block_desc_bk0_bn0_bn1_bk1_ = + MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{}); + + public: + __device__ BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2() + : c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( + get_thread_local_1d_id())}, + a_thread_copy_{ + make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1], 0)}, + b_thread_copy_{ + make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3], 0)} + { + static_assert(ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() && + BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(BlockSize == BM101 * BM100 * BN101 * BN100, + "wrong! blocksize and cluster size not consistent"); + + static_assert(BM % BM1 == 0 && BN % BN1 == 0, "wrong!"); + + static_assert(ABlockDesc_BK0_BM_BK1{}.GetLength(I0) == + BBlockDesc_BK0_BN_BK1{}.GetLength(I0), + "wrong! K dimension not consistent"); + + // TODO remove this restriction + static_assert(BM10BN10ThreadClusterBM10Xs::Size() == 2 && + BM10BN10ThreadClusterBN10Xs::Size() == 2, + "wrong!"); + + // TODO: remove this restriction + static_assert(BM0 == 2 && BN0 == 2, "wrong"); + } + + __device__ static CIndex CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(index_t thread_id) + { + // lower: [BM0, BM1, BN0, BN1] + // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] + constexpr auto adaptor0 = + MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1(); + + // lower: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] + // upper: [Tid, BM0, BM11, BN0, BN11] + constexpr auto adaptor1 = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(BM100, BN100, BM101, BN101)), + make_pass_through_transform(BM0), + make_pass_through_transform(BM11), + make_pass_through_transform(BN0), + make_pass_through_transform(BN11)), + make_tuple( + Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1); + + return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0)); + } + + template + __device__ void Run(const CThreadDesc_BM0_BM11_BN0_BN11&, + const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + static_assert(CThreadDesc_BM0_BM11_BN0_BN11::IsKnownAtCompileTime(), + "wrong! 
Desc should be known at compile-time"); + + // TODO: remove this restriction + static_assert(BM0 == 2 && BN0 == 2 && + CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I0) == BM0 && + CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I2) == BN0, + "wrong"); + + auto a_thread_buf = make_static_buffer( + a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize()); + + constexpr auto threadwise_contraction = + ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1< + FloatA, + FloatB, + FloatC, + decltype(a_thread_desc_bk0_bm0_bm1_bk1_), + decltype(b_thread_desc_bk0_bn0_bn1_bk1_), + CThreadDesc_BM0_BM11_BN0_BN11, + Sequence, + Sequence<1, BM1PerThreadBM11>, + Sequence<1, BN1PerThreadBN11>>{}; + + // read A_sub_0 + a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_, + make_tuple(I0, I0, I0, I0), + a_block_buf, + a_thread_desc_bk0_bm0_bm1_bk1_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + // read B_sub_0 + b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_, + make_tuple(I0, I0, I0, I0), + b_block_buf, + b_thread_desc_bk0_bn0_bn1_bk1_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + + // read B_sub_1 + b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_, + make_tuple(I0, I1, I0, I0), + b_block_buf, + b_thread_desc_bk0_bn0_bn1_bk1_, + make_tuple(I0, I1, I0, I0), + b_thread_buf); + + // read A_sub_1 + a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_, + make_tuple(I0, I1, I0, I0), + a_block_buf, + a_thread_desc_bk0_bm0_bm1_bk1_, + make_tuple(I0, I1, I0, I0), + a_thread_buf); + + // C_sub_00 += transpose(A_sub_0) * B_sub_0 + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I0, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I0, I0)); + + // C_sub_01 += transpose(A_sub_0) * B_sub_1 + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I0, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I1, I0)); + + // loop over rest of bk0 + static_for{}([&](auto bk0) { + // read A_sub_0 + a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_, + make_tuple(bk0, I0, I0, I0), + a_block_buf, + a_thread_desc_bk0_bm0_bm1_bk1_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + // C_sub_10 += transpose(A_sub_1) * B_sub_0 + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I1, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I0, I0)); + + // read B_sub_0 + b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_, + make_tuple(bk0, I0, I0, I0), + b_block_buf, + b_thread_desc_bk0_bn0_bn1_bk1_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + + // C_sub_11 += transpose(A_sub_1) * B_sub_1 + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I1, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I1, I0)); + + // read B_sub_1 + b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_, + make_tuple(bk0, I1, I0, I0), + b_block_buf, + b_thread_desc_bk0_bn0_bn1_bk1_, + make_tuple(I0, I1, I0, I0), + b_thread_buf); + + // read A_sub_1 + a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_, + make_tuple(bk0, I1, I0, I0), + a_block_buf, + a_thread_desc_bk0_bm0_bm1_bk1_, + make_tuple(I0, I1, I0, I0), + a_thread_buf); + + // C_sub_00 += transpose(A_sub_0) * B_sub_0 + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I0, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I0, I0)); + + // C_sub_01 += 
transpose(A_sub_0) * B_sub_1 + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I0, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I1, I0)); + }); + + // C_sub_10 += transpose(A_sub_1) * B_sub_0 + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I1, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I0, I0)); + + // C_sub_11 += transpose(A_sub_1) * B_sub_1 + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I1, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I1, I0)); + } + + private: + // A[BK0, BM0, BM1, BK1] + static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( + Number{}, Number{}, Number{}, Number{})); + + // B[BK0, BN0, BN1, BK1] + static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( + Number{}, Number{}, Number{}, Number{})); + + using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4r1< + FloatA, + FloatA, + decltype(a_block_desc_bk0_bm0_bm1_bk1_), + decltype(a_thread_desc_bk0_bm0_bm1_bk1_), + Sequence, // SliceLengths + Sequence<0, 1, 2, 3>, // DimAccessOrder + Sequence<1, 1, BM1PerThreadBM11, BK1>, // SrcVectorTensorLengths + Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder + + using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4r1< + FloatB, + FloatB, + decltype(b_block_desc_bk0_bn0_bn1_bk1_), + decltype(b_thread_desc_bk0_bn0_bn1_bk1_), + Sequence, // SliceLengths + Sequence<0, 1, 2, 3>, // DimAccessOrder + Sequence<1, 1, BN1PerThreadBN11, BK1>, // SrcVectorTensorLengths + Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder + + CIndex c_thread_origin_data_idx_; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp new file mode 100644 index 0000000000..074d519b76 --- /dev/null +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp @@ -0,0 +1,190 @@ +#ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP +#define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP + +#include "common_header.hpp" +#include "threadwise_gemm_dlops_v3.hpp" + +namespace ck { + +template +struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 +{ + struct MatrixIndex + { + index_t k; + index_t h; + index_t w; + }; + + // HACK: fix this @Jing Zhang + static constexpr index_t KPerThreadSubC = 4; + + static constexpr auto a_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(Number{}, Number{})); + + static constexpr auto b_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( + Number{}, Number<1>{}, Number{}, Number{})); + + static constexpr auto c_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( + Number{}, Number<1>{}, Number{}, Number{})); + + using AThreadCopy = + ThreadwiseDynamicTensorSliceTransfer_v4, + Sequence<0, 1>, + 1, + ThreadGemmADataPerRead_K, + 1>; + + __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3() + : c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())}, + a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.k * KPerThread)} + { + static_assert(BlockMatrixA::IsKnownAtCompileTime() && + BlockMatrixB::IsKnownAtCompileTime() && + ThreadMatrixC::IsKnownAtCompileTime(), + "wrong! 
Desc should be known at compile-time"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0), + "wrong! K dimension not consistent\n"); + + constexpr index_t K = BlockMatrixA{}.GetLength(I1); // A is transposed + constexpr index_t N = BlockMatrixB{}.GetLength(I1); + constexpr index_t H = BlockMatrixB{}.GetLength(I2); + constexpr index_t W = BlockMatrixB{}.GetLength(I3); + + static_assert(K % KPerThread == 0 && H % HPerThread == 0 && W % WPerThread == 0, + "wrong! Cannot evenly divide work among\n"); + + constexpr auto KThreadCluster = K / KPerThread; + constexpr auto HThreadCluster = H / HPerThread; + constexpr auto WThreadCluster = W / WPerThread; + + static_assert(BlockSize == KThreadCluster * HThreadCluster * WThreadCluster, + "wrong! wrong blocksize\n"); + } + + __device__ static constexpr auto GetThreadMatrixCLengths() + { + return Sequence{}; + } + + __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) + { + constexpr index_t H = BlockMatrixB{}.GetLength(Number<2>{}); + constexpr index_t W = BlockMatrixB{}.GetLength(Number<3>{}); + + constexpr auto num_w_threads = W / WPerThread; + constexpr auto num_h_threads = H / HPerThread; + constexpr auto num_hw_threads = num_w_threads * num_h_threads; + + index_t k_thread_id = thread_id / num_hw_threads; + index_t hw_thread_id = thread_id % num_hw_threads; + + index_t h_thread_id = hw_thread_id / num_w_threads; + index_t w_thread_id = hw_thread_id % num_w_threads; + + return MatrixIndex{k_thread_id, h_thread_id, w_thread_id}; + } + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BThreadBuffer& b_thread_buf, + CThreadBuffer& c_thread_buf) const + { + static_assert(is_same>, + remove_cv_t>>::value && + is_same>, + remove_cv_t>>::value && + is_same>, + remove_cv_t>>::value && + "wrong! 
inconsistent type"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto a_block_mtx = BlockMatrixA{}; + + constexpr auto EPerBlock = a_block_mtx.GetLength(I0); + + // HACK: fix this @Jing Zhang + constexpr auto HoPerThreadSubC = 2; + constexpr auto WoPerThreadSubC = 2; + + static_assert(KPerThread % KPerThreadSubC == 0, ""); + static_assert(HPerThread % HoPerThreadSubC == 0, ""); + static_assert(WPerThread % WoPerThreadSubC == 0, ""); + + // thread A buffer for GEMM + StaticBuffer + a_thread_buf; + + constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3{}; + + static_for<0, EPerBlock, EPerThreadLoop>{}([&](auto e_begin) { + static_for<0, KPerThread, KPerThreadSubC>{}([&](auto k_begin) { + a_thread_copy_.Run(a_block_mtx, + make_tuple(e_begin, k_begin), + a_block_buf, + a_thread_mtx_, + make_tuple(I0, I0), + a_thread_buf); + + static_for<0, HPerThread, HoPerThreadSubC>{}([&](auto h_begin) { + static_for<0, WPerThread, WoPerThreadSubC>{}([&](auto w_begin) { + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0), + b_thread_buf, + make_tuple(e_begin, I0, h_begin, w_begin), + c_thread_buf, + make_tuple(k_begin, I0, h_begin, w_begin)); + }); + }); + }); + }); + } + + template + __device__ void MoveASliceWindow(const BlockMatrixA&, + const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx) + { + a_thread_copy_.MoveSrcSliceWindow(BlockMatrixA{}, a_block_slice_move_step_idx); + } + + private: + MatrixIndex c_thread_begin_mtx_idx_; + + AThreadCopy a_thread_copy_; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp new file mode 100644 index 0000000000..98407ab7fc --- /dev/null +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp @@ -0,0 +1,528 @@ +#ifndef CK_BLOCKWISE_GEMM_XDLOPS_HPP +#define CK_BLOCKWISE_GEMM_XDLOPS_HPP + +#include "common_header.hpp" +#include "threadwise_dynamic_tensor_slice_transfer.hpp" +#include "xdlops_gemm.hpp" + +namespace ck { + +template +struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 +{ + + using CIndex = MultiIndex<2>; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr index_t WaveSize = 64; + + static constexpr index_t M0 = ABlockDesc{}.GetLength(I1); + static constexpr index_t M1 = ABlockDesc{}.GetLength(I2); + + static constexpr index_t N0 = BBlockDesc{}.GetLength(I1); + static constexpr index_t N1 = BBlockDesc{}.GetLength(I2); + + static constexpr auto xdlops_gemm = XdlopsGemm{}; + + static constexpr index_t MWaves = M1 / MPerWave; + static constexpr index_t NWaves = N1 / NPerWave; + + static constexpr index_t MRepeat = M0; + static constexpr index_t NRepeat = N0; + + __device__ constexpr auto GetCLayout() const { return xdlops_gemm.GetCLayout(); } + + __device__ constexpr auto GetNumBlks() const { return xdlops_gemm.GetCLayout().GetNumBlks(); } + + __device__ constexpr auto GetBlkSize() const { return xdlops_gemm.GetCLayout().GetBlkSize(); } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const index_t thread_id = get_thread_local_1d_id(); + const index_t waveId = thread_id / WaveSize; + const index_t laneId = thread_id % WaveSize; + const index_t waveId_m = waveId / NWaves; + const index_t waveId_n = waveId % NWaves; + + if 
constexpr(xdlops_gemm.IsKReduction) + { + const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId); + const index_t k_offset = xdlops_gemm.GetBlkId(laneId); + return make_tuple(k_offset, 0, m_offset, 0); + } + else + { + const index_t m_offset = waveId_m * MPerWave + laneId; + const index_t k_offset = 0; + return make_tuple(k_offset, 0, m_offset, 0); + } + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const index_t thread_id = get_thread_local_1d_id(); + const index_t waveId = thread_id / WaveSize; + const index_t laneId = thread_id % WaveSize; + const index_t waveId_m = waveId / NWaves; + const index_t waveId_n = waveId % NWaves; + + if constexpr(xdlops_gemm.IsKReduction) + { + const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId); + const index_t k_offset = xdlops_gemm.GetBlkId(laneId); + return make_tuple(k_offset, 0, n_offset, 0); + } + else + { + const index_t n_offset = waveId_n * NPerWave + laneId; + const index_t k_offset = 0; + return make_tuple(k_offset, 0, n_offset, 0); + } + } + + template + __device__ static CIndex + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + { + + const index_t waveId = get_thread_local_1d_id() / WaveSize; + + const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); + + const index_t waveId_m = waveId / NWaves; + const index_t waveId_n = waveId % NWaves; + + const index_t m_offset = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk[I0]; + const index_t n_offset = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk[I1]; + + return CIndex{m_offset, n_offset}; + } + + __device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1() + : a_thread_copy_{CalculateAThreadOriginDataIndex()}, + b_thread_copy_{CalculateBThreadOriginDataIndex()} + { + static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0), + "wrong! K dimension not consistent"); + + static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3), + "wrong! 
K1 dimension not consistent"); + + static_assert(BlockSize == MWaves * NWaves * WaveSize, + "BlockSize != MWaves * NWaves * WaveSize\n"); + + static_assert(K1 == BBlockDesc{}.GetLength(I3), "K1 is wrong!"); + + constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); + + static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!"); + + static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!"); + } + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); + + vector_type a_thread_vec; + + vector_type b_thread_vec; + + static_for<0, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k) { + // read A + a_thread_copy_.Run(ABlockDesc{}, + make_tuple(k, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + // read B + b_thread_copy_.Run(BBlockDesc{}, + make_tuple(k, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + + using mfma_input_type = + typename vector_type::type; + + static_for<0, a_thread_desc_.GetElementSpaceSize(), 1>{}([&](auto i) { + a_thread_vec.template AsType()(Number{}) = a_thread_buf[Number{}]; + }); + + static_for<0, b_thread_desc_.GetElementSpaceSize(), 1>{}([&](auto i) { + b_thread_vec.template AsType()(Number{}) = b_thread_buf[Number{}]; + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + xdlops_gemm.template Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf); + }); + }); + }); + } + + private: + // A[K, M] + static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(I1, Number{}, I1, Number{})); + + // B[K, N] + static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(I1, Number{}, I1, Number{})); + + static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(Number{}, Number{})); + + using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + K1, + 1>; + + using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + K1, + 1>; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + +template +struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline +{ + + using CIndex = MultiIndex<2>; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto xdlops_gemm = XdlopsGemm{}; + + static constexpr index_t WaveSize = 64; + + static constexpr index_t M0 = ABlockDesc{}.GetLength(I1); + static constexpr index_t M1 = ABlockDesc{}.GetLength(I2); + + static constexpr index_t N0 = BBlockDesc{}.GetLength(I1); + static constexpr index_t N1 = BBlockDesc{}.GetLength(I2); + + static constexpr index_t MWaves = M1 / MPerWave; + static constexpr index_t NWaves = N1 / NPerWave; + + static constexpr index_t MRepeat = M0; + static constexpr index_t NRepeat = N0; + + __device__ constexpr auto GetCLayout() const { return xdlops_gemm.GetCLayout(); } + + __device__ constexpr auto GetNumBlks() const { return xdlops_gemm.GetCLayout().GetNumBlks(); } + + __device__ constexpr 
auto GetBlkSize() const { return xdlops_gemm.GetCLayout().GetBlkSize(); } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const index_t thread_id = get_thread_local_1d_id(); + const index_t waveId = thread_id / WaveSize; + const index_t laneId = thread_id % WaveSize; + const index_t waveId_m = waveId / NWaves; + const index_t waveId_n = waveId % NWaves; + + if constexpr(xdlops_gemm.IsKReduction) + { + const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId); + const index_t k_offset = xdlops_gemm.GetBlkId(laneId); + return make_tuple(k_offset, 0, m_offset, 0); + } + else + { + const index_t m_offset = waveId_m * MPerWave + laneId; + const index_t k_offset = 0; + return make_tuple(k_offset, 0, m_offset, 0); + } + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const index_t thread_id = get_thread_local_1d_id(); + const index_t waveId = thread_id / WaveSize; + const index_t laneId = thread_id % WaveSize; + const index_t waveId_m = waveId / NWaves; + const index_t waveId_n = waveId % NWaves; + + if constexpr(xdlops_gemm.IsKReduction) + { + const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId); + const index_t k_offset = xdlops_gemm.GetBlkId(laneId); + return make_tuple(k_offset, 0, n_offset, 0); + } + else + { + const index_t n_offset = waveId_n * NPerWave + laneId; + const index_t k_offset = 0; + return make_tuple(k_offset, 0, n_offset, 0); + } + } + + template + __device__ static CIndex + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + { + + const index_t waveId = get_thread_local_1d_id() / WaveSize; + + const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); + + const index_t waveId_m = waveId / NWaves; + const index_t waveId_n = waveId % NWaves; + + const index_t m_offset = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk[I0]; + const index_t n_offset = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk[I1]; + + return CIndex{m_offset, n_offset}; + } + + __device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline() + : a_thread_copy_{CalculateAThreadOriginDataIndex()}, + b_thread_copy_{CalculateBThreadOriginDataIndex()} + { + static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0), + "wrong! K dimension not consistent"); + + static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3), + "wrong! 
K1 dimension not consistent"); + + static_assert(BlockSize == MWaves * NWaves * WaveSize, + "BlockSize != MWaves * NWaves * WaveSize\n"); + + static_assert(K1 == BBlockDesc{}.GetLength(I3), "K1 is wrong!"); + + constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); + + static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!"); + + static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!"); + } + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); + + // read A_sub_0 + a_thread_copy_.Run(ABlockDesc{}, + make_tuple(I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + // read B_sub_0 + b_thread_copy_.Run(BBlockDesc{}, + make_tuple(I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + + // read B_sub_1 + b_thread_copy_.Run(BBlockDesc{}, + make_tuple(I0, I1, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I1, I0, I0), + b_thread_buf); + + // read A_sub_1 + a_thread_copy_.Run(ABlockDesc{}, + make_tuple(I0, I1, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I1, I0, I0), + a_thread_buf); + + // C_sub_00 += transpose(A_sub_0) * B_sub_0 + xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); + + // C_sub_01 += transpose(A_sub_0) * B_sub_1 + xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); + + static_for{}([&](auto k) { + // read A_sub_0 + a_thread_copy_.Run(ABlockDesc{}, + make_tuple(k, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + // C_sub_10 += transpose(A_sub_1) * B_sub_0 + xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); + + // read B_sub_0 + b_thread_copy_.Run(BBlockDesc{}, + make_tuple(k, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + + // C_sub_11 += transpose(A_sub_1) * B_sub_1 + xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); + + // read B_sub_1 + b_thread_copy_.Run(BBlockDesc{}, + make_tuple(k, I1, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I1, I0, I0), + b_thread_buf); + + // read A_sub_1 + a_thread_copy_.Run(ABlockDesc{}, + make_tuple(k, I1, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I1, I0, I0), + a_thread_buf); + + // C_sub_00 += transpose(A_sub_0) * B_sub_0 + xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); + + // C_sub_01 += transpose(A_sub_0) * B_sub_1 + xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); + }); + + // C_sub_10 += transpose(A_sub_1) * B_sub_0 + xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); + + // C_sub_11 += transpose(A_sub_1) * B_sub_1 + xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); + } + + private: + // A[K, M] + static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(I1, Number{}, I1, Number{})); + + // B[K, N] + static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(I1, Number{}, I1, Number{})); + + static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(Number{}, Number{})); + + using AThreadCopy 
= ThreadwiseDynamicTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + 1, // K1, + 1>; + + using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + 1, // K1, + 1>; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_contraction_dlops_v1r2.hpp b/composable_kernel/include/tensor_operation/gridwise_dynamic_contraction_dlops_v1r2.hpp new file mode 100644 index 0000000000..6d48a18169 --- /dev/null +++ b/composable_kernel/include/tensor_operation/gridwise_dynamic_contraction_dlops_v1r2.hpp @@ -0,0 +1,664 @@ +#ifndef CK_GRIDWISE_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP +#define CK_GRIDWISE_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP + +#include "common_header.hpp" +#include "dynamic_multi_index_transform_helper.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "blockwise_gemm_dlops_v2r3.hpp" +#include "blockwise_dynamic_tensor_slice_transfer_v2.hpp" +#include "threadwise_dynamic_tensor_slice_transfer.hpp" +#include "threadwise_dynamic_tensor_slice_set.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_dynamic_contraction_dlops_v1r2( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_GK0_GM0_GM10_GM11_GK1 a_grid_desc_gk0_gm0_gm10_gm11_gk1, + const BGridDesc_GK0_GN0_GN10_GN11_GK1 b_grid_desc_gk0_gn0_gn10_gn11_gk1, + const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + const CGridBlockCluster_BlockId_To_GM10_GN10 c_grid_block_cluster_blockid_to_gm10_gn10) +{ + constexpr index_t shared_block_size = + GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseContraction::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_grid_desc_gk0_gm0_gm10_gm11_gk1, + b_grid_desc_gk0_gn0_gn10_gn11_gk1, + c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + c_grid_block_cluster_blockid_to_gm10_gn10, + integral_constant{}, + integral_constant{}); +} + +template +struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + // GM0 and GN0 need to known at compile-time + static constexpr auto GM0 = CGridDesc_GM0_GM1_GN0_GN1{}.GetLength(I0); + static constexpr auto GN0 = CGridDesc_GM0_GM1_GN0_GN1{}.GetLength(I2); + static constexpr auto GK1 = AGridDesc_GK0_GM0_GM1_GK1{}.GetLength(I3); + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // lds max alignment + // TODO: part of them should be moved into blockwise-gemm + // TODO: change this. 
I think it needs multi-dimensional alignment + constexpr auto max_lds_align = GK1; + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = + make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, GM0, I1, Number{}, GK1), + max_lds_align); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = + make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, GN0, I1, Number{}, GK1), + max_lds_align); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = math::integer_least_multiple( + a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = math::integer_least_multiple( + b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize(), max_lds_align); + + return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB); + } + + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1, + const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1, + const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1) + { + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! GM0 and GN0 need to be known at compile-time"); + + const auto GM1 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2); + const auto GN1 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2); + const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0); + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + + return ( + (GM0 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I0) && + GM1 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1) && + GN0 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I2) && + GN1 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3) && + GM0 == a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I1) && + GM1 == a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2) && + GN0 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I1) && + GN1 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2) && + GK0 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I0) && + GK1 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I3)) && + (GM1 % GM1PerBlockGM11 == 0 && GN1 % GN1PerBlockGN11 == 0 && GK0 % GK0PerBlock == 0)); + } + + __host__ __device__ static constexpr index_t + CalculateGridSize(const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1) + { + const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1); + const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3); + + constexpr index_t GM11 = GM1PerBlockGM11; + constexpr index_t GN11 = GN1PerBlockGN11; + + const index_t GM10 = GM1 / GM11; + const index_t GN10 = GN1 / GN11; + + const index_t grid_size = GM10 * GN10; + + return grid_size; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t GK0) + { + const bool has_main_k_block_loop = (GK0 + GK0PerBlock) / (2 * GK0PerBlock) > 1; + + return has_main_k_block_loop; + } + + __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t GK0) + { + const bool has_double_tail_k_block_loop = (GK0 / GK0PerBlock) % 2 == 0; + + return has_double_tail_k_block_loop; + } + + __host__ __device__ static constexpr auto MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1( + const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1) + { + const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0); + const auto GM1 = 
a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2); + + const auto GM11 = Number{}; + const auto GM10 = GM1 / GM11; + + const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = transform_dynamic_tensor_descriptor( + a_grid_desc_gk0_gm0_gm1_gk1, + make_tuple(make_pass_through_transform(GK0), + make_pass_through_transform(GM0), + make_unmerge_transform(make_tuple(GM10, GM11)), + make_pass_through_transform(GK1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{})); + + return a_grid_desc_gk0_gm0_gm10_gm11_gk1; + } + + __host__ __device__ static constexpr auto MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1( + const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1) + { + const auto GK0 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I0); + const auto GN1 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2); + + const auto GN11 = Number{}; + const auto GN10 = GN1 / GN11; + + const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = transform_dynamic_tensor_descriptor( + b_grid_desc_gk0_gn0_gn1_gk1, + make_tuple(make_pass_through_transform(GK0), + make_pass_through_transform(GN0), + make_unmerge_transform(make_tuple(GN10, GN11)), + make_pass_through_transform(GK1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{})); + + return b_grid_desc_gk0_gn0_gn10_gn11_gk1; + } + + __host__ __device__ static constexpr auto MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1( + const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1) + { + const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1); + const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3); + + constexpr auto GM11 = Number{}; + constexpr auto GN11 = Number{}; + + const auto GM10 = GM1 / GM11; + const auto GN10 = GN1 / GN11; + + constexpr auto BM = GM0 * GM11; + constexpr auto BN = GN0 * GN11; + + constexpr auto BM1 = + Number{}; + constexpr auto BN1 = + Number{}; + + constexpr auto BM0 = BM / BM1; + constexpr auto BN0 = BN / BN1; + + const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_dynamic_tensor_descriptor( + c_grid_desc_gm0_gm1_gn0_gn1, + make_tuple(make_pass_through_transform(GM0), + make_unmerge_transform(make_tuple(GM10, GM11)), + make_pass_through_transform(GN0), + make_unmerge_transform(make_tuple(GN10, GN11))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{})); + + const auto c_gm10_bm_gn10_bn_grid_desc = transform_dynamic_tensor_descriptor( + c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc, + make_tuple(make_pass_through_transform(GM10), + make_merge_transform(make_tuple(GM0, GM11)), + make_pass_through_transform(GN10), + make_merge_transform(make_tuple(GN0, GN11))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<4>{}, Sequence<3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = transform_dynamic_tensor_descriptor( + c_gm10_bm_gn10_bn_grid_desc, + make_tuple(make_pass_through_transform(GM10), + make_unmerge_transform(make_tuple(BM0, BM1)), + make_pass_through_transform(GN10), + make_unmerge_transform(make_tuple(BN0, BN1))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{})); + + return c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1; + } + + __host__ __device__ static constexpr auto 
MakeCGridBlockCluster_BlockId_To_GM10_GN10( + const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1) + { + const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1); + const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3); + + constexpr auto GM11 = Number{}; + constexpr auto GN11 = Number{}; + + const auto GM10 = GM1 / GM11; + const auto GN10 = GN1 / GN11; + + const auto c_grid_block_cluster_blockid_to_gm10_gn10 = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(GM10, GN10))), + make_tuple(Sequence<0, 1>{}), + make_tuple(Sequence<0>{})); + + return c_grid_block_cluster_blockid_to_gm10_gn10; + } + + using AGridDesc_GK0_GM0_GM10_GM11_GK1 = + decltype(MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(AGridDesc_GK0_GM0_GM1_GK1{})); + using BGridDesc_GK0_GN0_GN10_GN11_GK1 = + decltype(MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(BGridDesc_GK0_GN0_GN1_GK1{})); + using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 = + decltype(MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(CGridDesc_GM0_GM1_GN0_GN1{})); + using CGridBlockCluster_BlockId_To_GM10_GN10 = + decltype(MakeCGridBlockCluster_BlockId_To_GM10_GN10(CGridDesc_GM0_GM1_GN0_GN1{})); + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + FloatAB* __restrict__ p_shared_block, + const AGridDesc_GK0_GM0_GM10_GM11_GK1& a_grid_desc_gk0_gm0_gm10_gm11_gk1, + const BGridDesc_GK0_GN0_GN10_GN11_GK1& b_grid_desc_gk0_gn0_gn10_gn11_gk1, + const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1& c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + const CGridBlockCluster_BlockId_To_GM10_GN10& c_grid_block_cluster_blockid_to_gm10_gn10, + integral_constant, + integral_constant) + { + const auto a_global_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetElementSpaceSize()); + + const auto GK0 = a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0); + + // divide block work by [GM10, GN10] + const auto c_gm10_gn10_block_cluster_idx = + c_grid_block_cluster_blockid_to_gm10_gn10.CalculateBottomIndex( + make_multi_index(get_block_1d_id())); + + // HACK: this force index data into SGPR + const index_t igm10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I0]); + const index_t ign10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I1]); + + // lds max alignment + // TODO: part of them should be moved into blockwise-gemm + // TODO: change this. 
I think it needs multi-dimensional alignment + constexpr auto max_lds_align = GK1; + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = + make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, GM0, I1, Number{}, GK1), + max_lds_align); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = + make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, GN0, I1, Number{}, GK1), + max_lds_align); + + // A matrix in LDS memory for blockwise GEMM + // be careful of LDS alignment + constexpr auto a_block_desc_gk0_bm_gk1 = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, GM0 * Number{}, GK1), max_lds_align); + + // B matrix in LDS memory for blockwise GEMM + // be careful of LDS alignment + constexpr auto b_block_desc_gk0_bn_gk1 = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, GN0 * Number{}, GK1), max_lds_align); + + static_assert(a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize() == + a_block_desc_gk0_bm_gk1.GetElementSpaceSize() && + b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize() == + b_block_desc_gk0_bn_gk1.GetElementSpaceSize(), + "wrong!"); + + // A matrix blockwise copy + auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1< + BlockSize, + InMemoryDataOperationEnum_t::Set, + Sequence, + ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_gk0_gm0_gm10_gm11_gk1), + decltype(a_block_desc_gk0_gm0_gm10_gm11_gk1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2, 3, 4>, + ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, // SrcVectorTensorLengths + ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, // DstVectorTensorLengths + ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder + Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder + false, + true>(a_grid_desc_gk0_gm0_gm10_gm11_gk1, + make_multi_index(0, 0, igm10, 0, 0), + a_block_desc_gk0_gm0_gm10_gm11_gk1, + make_multi_index(0, 0, 0, 0, 0)); + + // B matrix blockwise copy + auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1< + BlockSize, + InMemoryDataOperationEnum_t::Set, + Sequence, + BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_gk0_gn0_gn10_gn11_gk1), + decltype(b_block_desc_gk0_gn0_gn10_gn11_gk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2, 3, 4>, + BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, // SrcVectorTensorLengths + BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, // DstVectorTensorLengths + BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder + Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder + false, + true>(b_grid_desc_gk0_gn0_gn10_gn11_gk1, + make_multi_index(0, 0, ign10, 0, 0), + b_block_desc_gk0_gn0_gn10_gn11_gk1, + make_multi_index(0, 0, 0, 0, 0)); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[GK0PerBlock, GM1PerBlockGM11] is in LDS + // b_mtx[KPerBlocl, GN1PerBlockGN11] is in LDS + // c_mtx[GM1PerBlockGM11, GN1PerBlockGN11] is distributed among threads, and saved in + // 
register + const auto blockwise_gemm = + BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< + BlockSize, + FloatAB, + FloatAB, + FloatAcc, + decltype(a_block_desc_gk0_bm_gk1), + decltype(b_block_desc_gk0_bn_gk1), + BM1PerThreadBM11, + BN1PerThreadBN11, + BK0PerThread, + BM10BN10ThreadClusterBM10Xs, + BM10BN10ThreadClusterBN10Xs, + BM1PerThreadBM11, + BN1PerThreadBN11>{}; + + constexpr auto c_thread_tensor_lengths_bm0_bm1_bn0_bn1 = + decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); + + constexpr auto c_thread_desc_bm0_bm1_bn0_bn1 = + make_dynamic_naive_tensor_descriptor_packed_v2( + sequence_to_tuple_of_number(c_thread_tensor_lengths_bm0_bm1_bn0_bn1)); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = math::integer_least_multiple( + a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = math::integer_least_multiple( + b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block_double = p_shared_block; + FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; + + // register allocation for output + auto c_thread_buf = make_static_buffer( + c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize()); + + ThreadwiseDynamicTensorSliceSet_v1{} + .Run(c_thread_desc_bm0_bm1_bn0_bn1, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + FloatAcc{0}); + + constexpr auto a_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0); + + auto a_block_even_buf = make_dynamic_buffer( + p_a_block_double, a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize()); + auto b_block_even_buf = make_dynamic_buffer( + p_b_block_double, b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize()); + + auto a_block_odd_buf = make_dynamic_buffer( + p_a_block_double + a_block_aligned_space_size, + a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize()); + auto b_block_odd_buf = make_dynamic_buffer( + p_b_block_double + b_block_aligned_space_size, + b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize()); + + // LDS double buffer: preload data into LDS + { + a_blockwise_copy.RunRead( + a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{}); + b_blockwise_copy.RunRead( + b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{}); + + a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf); + b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf); + } + + if constexpr(HasMainKBlockLoop) + { + index_t gk0_block_on_grid = 0; + + // LDS double buffer: main body + // use Do-While loop instead of For loop to simplify control flow + do + { + // even iteration + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1, + a_block_slice_copy_step, + AGridMoveSliceWindowIteratorHacks{}); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1, + b_block_slice_copy_step, + BGridMoveSliceWindowIteratorHacks{}); + + __syncthreads(); + + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead( + a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{}); + b_blockwise_copy.RunRead( + b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{}); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(c_thread_desc_bm0_bm1_bn0_bn1, + a_block_even_buf, + 
b_block_even_buf, + c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_odd_buf); + + // odd iteration + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1, + a_block_slice_copy_step, + AGridMoveSliceWindowIteratorHacks{}); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1, + b_block_slice_copy_step, + BGridMoveSliceWindowIteratorHacks{}); + + __syncthreads(); + + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead( + a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{}); + b_blockwise_copy.RunRead( + b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{}); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run( + c_thread_desc_bm0_bm1_bn0_bn1, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf); + b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf); + + gk0_block_on_grid += 2 * GK0PerBlock; + } while(gk0_block_on_grid < GK0 - 2 * GK0PerBlock); + } + + // LDS double buffer: tail + if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left + { + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1, + a_block_slice_copy_step, + AGridMoveSliceWindowIteratorHacks{}); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1, + b_block_slice_copy_step, + BGridMoveSliceWindowIteratorHacks{}); + + __syncthreads(); + + // LDS double buffer: load last data from device mem + a_blockwise_copy.RunRead( + a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{}); + b_blockwise_copy.RunRead( + b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{}); + + // LDS double buffer: GEMM on 2nd-last data + blockwise_gemm.Run( + c_thread_desc_bm0_bm1_bn0_bn1, a_block_even_buf, b_block_even_buf, c_thread_buf); + + // LDS double buffer: store last data to LDS + a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_odd_buf); + + __syncthreads(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_thread_desc_bm0_bm1_bn0_bn1, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + } + else // if has 1 iteration left + { + __syncthreads(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_thread_desc_bm0_bm1_bn0_bn1, a_block_even_buf, b_block_even_buf, c_thread_buf); + } + + // output: register to global memory + { + constexpr auto c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1 = + make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + const auto c_thread_origin_on_block_bm0_bm1_bn0_bn1 = + blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( + get_thread_local_1d_id()); + + ThreadwiseDynamicTensorSliceTransfer_v1r3< + FloatAcc, + FloatC, + decltype(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1), + decltype(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1), + Sequence<1, + c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I0], + c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I1], + 1, + c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I2], + c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I3]>, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + 
CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + false>{c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + make_multi_index(igm10, + c_thread_origin_on_block_bm0_bm1_bn0_bn1[I0], + c_thread_origin_on_block_bm0_bm1_bn0_bn1[I1], + ign10, + c_thread_origin_on_block_bm0_bm1_bn0_bn1[I2], + c_thread_origin_on_block_bm0_bm1_bn0_bn1[I3])} + .Run(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1, + make_tuple(I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + c_grid_buf, + CGridIteratorHacks{}); + } + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v1r2.hpp b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v1r2.hpp new file mode 100644 index 0000000000..7a4ef1d7ea --- /dev/null +++ b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v1r2.hpp @@ -0,0 +1,679 @@ +#ifndef CK_GRIDWISE_DYNAMIC_GEMM_DLOPS_V1R2_HPP +#define CK_GRIDWISE_DYNAMIC_GEMM_DLOPS_V1R2_HPP + +#include "common_header.hpp" +#include "dynamic_multi_index_transform_helper.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "blockwise_gemm_dlops_v2r2.hpp" +#include "blockwise_dynamic_tensor_slice_transfer.hpp" +#include "threadwise_dynamic_tensor_slice_transfer.hpp" +#include "threadwise_dynamic_tensor_slice_set.hpp" + +namespace ck { + +#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_dynamic_gemm_dlops_v1r2( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AKM0M1GridDesc a_k_m0_m1_grid_desc, + const BKN0N1GridDesc b_k_n0_n1_grid_desc, + const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc, + const CBlockIdToM0N0BlockClusterAdaptor c_blockid_to_m0_n0_block_cluster_adaptor) +{ + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_k_m0_m1_grid_desc, + b_k_n0_n1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor, + integral_constant{}, + integral_constant{}); +} +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER +// pass tensor descriptor by CONSTANT void pointer +// CONSTANT is needed to inform compiler void pointers in the kernel signature are pointing to +// non-modifiable parameter address space, so compiler can enable corresponding optimization +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_dynamic_gemm_dlops_v1r2( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const void CONSTANT* p_a_k_m0_m1_grid_desc, + const void CONSTANT* p_b_k_n0_n1_grid_desc, + const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc, + const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) +{ + // first cast void CONSTANT void* to void* + // second cast void* to Desc* + // the copy constructor of tensor descriptor doesn't take address_space(4) + const auto a_k_m0_m1_grid_desc = + *reinterpret_cast((const void*)p_a_k_m0_m1_grid_desc); + const auto b_k_n0_n1_grid_desc = + *reinterpret_cast((const void*)p_b_k_n0_n1_grid_desc); + const auto 
c_m0_m10_m11_n0_n10_n11_grid_desc = + *reinterpret_cast( + (const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc); + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + *reinterpret_cast( + (const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); + + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_k_m0_m1_grid_desc, + b_k_n0_n1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor, + integral_constant{}, + integral_constant{}); +} +#endif + +template +struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = math::lcm(Number{}, + Number{}, + Number{}, + Number{}); + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, Number{}), max_lds_align); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, Number{}), max_lds_align); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = + math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = + math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align); + + return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB); + } + + __host__ __device__ static constexpr bool CheckValidity(const AKMGridDesc& a_k_m_grid_desc, + const BKNGridDesc& b_k_n_grid_desc, + const CMNGridDesc& c_m_n_grid_desc) + { + const auto M = a_k_m_grid_desc.GetLength(I1); + const auto N = b_k_n_grid_desc.GetLength(I1); + const auto K = a_k_m_grid_desc.GetLength(I0); + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + + return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && + K == b_k_n_grid_desc.GetLength(I0)) && + (M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K % KPerBlock == 0); + } + + __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) + { + const index_t grid_size = (M / MPerBlockM1) * (N / NPerBlockN1); + + return grid_size; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1; + + return has_main_k_block_loop; + } + + __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K) + { + const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0; + + return has_double_tail_k_block_loop; + } + + __host__ __device__ static constexpr auto + MakeAKM0M1GridDescriptor(const AKMGridDesc& a_k_m_grid_desc) + { + const auto K = a_k_m_grid_desc.GetLength(I0); + const auto M = a_k_m_grid_desc.GetLength(I1); + + const auto M1 = Number{}; + const auto M0 = M / M1; + + const auto a_k_m0_m1_grid_desc = transform_dynamic_tensor_descriptor( + a_k_m_grid_desc, + make_tuple(make_pass_through_transform(K), 
make_unmerge_transform(make_tuple(M0, M1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{})); + + return a_k_m0_m1_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeBKN0N1GridDescriptor(const BKNGridDesc& b_k_n_grid_desc) + { + const auto K = b_k_n_grid_desc.GetLength(I0); + const auto N = b_k_n_grid_desc.GetLength(I1); + + const auto N1 = Number{}; + const auto N0 = N / N1; + + const auto b_k_n0_n1_grid_desc = transform_dynamic_tensor_descriptor( + b_k_n_grid_desc, + make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(N0, N1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{})); + + return b_k_n0_n1_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + constexpr auto M11 = + Number{}; + constexpr auto N11 = + Number{}; + + constexpr auto M10 = M1 / M11; + constexpr auto N10 = N1 / N11; + + const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_dynamic_tensor_descriptor( + c_m_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)), + make_unmerge_transform(make_tuple(N0, N10, N11))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return c_m0_m10_m11_n0_n10_n11_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))), + make_tuple(Sequence<0, 1>{}), + make_tuple(Sequence<0>{})); + + return c_blockid_to_m0_n0_block_cluster_adaptor; + } + + using AKM0M1GridDesc = decltype(MakeAKM0M1GridDescriptor(AKMGridDesc{})); + using BKN0N1GridDesc = decltype(MakeBKN0N1GridDescriptor(BKNGridDesc{})); + using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{})); + using CBlockIdToM0N0BlockClusterAdaptor = + decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{})); + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + FloatAB* __restrict__ p_shared_block, + const AKM0M1GridDesc& a_k_m0_m1_grid_desc, + const BKN0N1GridDesc& b_k_n0_n1_grid_desc, + const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc, + const CBlockIdToM0N0BlockClusterAdaptor& c_blockid_to_m0_n0_block_cluster_adaptor, + integral_constant, + integral_constant) + { + const auto a_global_buf = make_dynamic_buffer( + p_a_grid, a_k_m0_m1_grid_desc.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_grid, b_k_n0_n1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize()); + + const auto K = a_k_m0_m1_grid_desc.GetLength(I0); + + // divide block work by [M, N] + const auto c_m0_n0_block_cluster_idx = + 
c_blockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex( + make_multi_index(get_block_1d_id())); + + // HACK: this force index data into SGPR + const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]); + const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(Number{}, + Number{}, + Number{}, + Number{}); + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, Number{}), max_lds_align); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, Number{}), max_lds_align); + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_k_m0_m1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, I1, Number{}), max_lds_align); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_k_n0_n1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, I1, Number{}), max_lds_align); + + // A matrix blockwise copy + auto a_blockwise_copy = + BlockwiseDynamicTensorSliceTransfer_v4, + ABlockTransferThreadSliceLengths_K_M0_M1, + ABlockTransferThreadClusterLengths_K_M0_M1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_k_m0_m1_grid_desc), + decltype(a_k_m0_m1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_M1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>(a_k_m0_m1_grid_desc, + make_multi_index(0, im0, 0), + a_k_m0_m1_block_desc, + make_multi_index(0, 0, 0)); + + // B matrix blockwise copy + auto b_blockwise_copy = + BlockwiseDynamicTensorSliceTransfer_v4, + BBlockTransferThreadSliceLengths_K_N0_N1, + BBlockTransferThreadClusterLengths_K_N0_N1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_k_n0_n1_grid_desc), + decltype(b_k_n0_n1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_N1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_k_n0_n1_grid_desc, + make_multi_index(0, in0, 0), + b_k_n0_n1_block_desc, + make_multi_index(0, 0, 0)); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[KPerBlock, MPerBlockM1] is in LDS + // b_mtx[KPerBlocl, NPerBlockN1] is in LDS + // c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in + // register + const auto blockwise_gemm = + BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2{}; + constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = + decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths(); + + constexpr auto c_m10_m11_n10_n11_thread_desc = + make_dynamic_naive_tensor_descriptor_packed_v2( + sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths)); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = + math::integer_least_multiple(a_k_m0_m1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = + 
math::integer_least_multiple(b_k_n0_n1_block_desc.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block_double = p_shared_block; + FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; + + // register allocation for output + auto c_thread_buf = make_static_buffer( + c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize()); + + ThreadwiseDynamicTensorSliceSet_v1{} + .Run(c_m10_m11_n10_n11_thread_desc, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + FloatAcc{0}); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); + + // hack to control index calculation when iterating over A and B matrix for threadwise copy + constexpr auto a_k_m0_m1_global_iterator_hacks = AGridIteratorHacks{}; + constexpr auto b_k_n0_n1_global_iterator_hacks = BGridIteratorHacks{}; + + // hack to control index calculation when move slice window for A and B matrix for + // threadwise copy + constexpr auto a_k_m0_m1_global_move_slice_window_iterator_hack = + AGridMoveSliceWindowIteratorHacks{}; + constexpr auto b_k_n0_n1_global_move_slice_window_iterator_hack = + BGridMoveSliceWindowIteratorHacks{}; + + auto a_block_even_buf = make_dynamic_buffer( + p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize()); + auto b_block_even_buf = make_dynamic_buffer( + p_b_block_double, b_k_n0_n1_block_desc.GetElementSpaceSize()); + + auto a_block_odd_buf = make_dynamic_buffer( + p_a_block_double + a_block_aligned_space_size, + a_k_m0_m1_block_desc.GetElementSpaceSize()); + auto b_block_odd_buf = make_dynamic_buffer( + p_b_block_double + b_block_aligned_space_size, + b_k_n0_n1_block_desc.GetElementSpaceSize()); + + // LDS double buffer: preload data into LDS + { + a_blockwise_copy.RunRead( + a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); + b_blockwise_copy.RunRead( + b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); + + a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf); + b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf); + } + + if constexpr(HasMainKBlockLoop) + { + index_t k_block_data_begin = 0; + + // LDS double buffer: main body + // use Do-While loop instead of For loop to simplify control flow + do + { + // even iteration + a_blockwise_copy.MoveSrcSliceWindow( + a_k_m0_m1_grid_desc, + a_block_slice_copy_step, + a_k_m0_m1_global_move_slice_window_iterator_hack); + b_blockwise_copy.MoveSrcSliceWindow( + b_k_n0_n1_grid_desc, + b_block_slice_copy_step, + b_k_n0_n1_global_move_slice_window_iterator_hack); + + __syncthreads(); + + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead( + a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); + b_blockwise_copy.RunRead( + b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc, + a_block_even_buf, + b_block_even_buf, + c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf); + + // odd iteration + a_blockwise_copy.MoveSrcSliceWindow( + a_k_m0_m1_grid_desc, + a_block_slice_copy_step, + a_k_m0_m1_global_move_slice_window_iterator_hack); + b_blockwise_copy.MoveSrcSliceWindow( + b_k_n0_n1_grid_desc, + b_block_slice_copy_step, + b_k_n0_n1_global_move_slice_window_iterator_hack); + + __syncthreads(); 
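+                // note (clarifying comment): this barrier is what makes the LDS double-buffer
+                // scheme safe. It guarantees every thread has finished the even-phase GEMM
+                // (which reads the even LDS buffers) and the even-phase RunWrite (which fills
+                // the odd LDS buffers) before the odd phase below starts reading the odd
+                // buffers and later overwrites the even ones. The global RunRead issued next
+                // can then be in flight while the odd-phase GEMM works on data already
+                // resident in LDS.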
+ + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead( + a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); + b_blockwise_copy.RunRead( + b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run( + c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf); + b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf); + + k_block_data_begin += 2 * KPerBlock; + } while(k_block_data_begin < K - 2 * KPerBlock); + } + + // LDS double buffer: tail + if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left + { + a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc, + a_block_slice_copy_step, + a_k_m0_m1_global_move_slice_window_iterator_hack); + b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc, + b_block_slice_copy_step, + b_k_n0_n1_global_move_slice_window_iterator_hack); + + __syncthreads(); + + // LDS double buffer: load last data from device mem + a_blockwise_copy.RunRead( + a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); + b_blockwise_copy.RunRead( + b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); + + // LDS double buffer: GEMM on 2nd-last data + blockwise_gemm.Run( + c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); + + // LDS double buffer: store last data to LDS + a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf); + + __syncthreads(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + } + else // if has 1 iteration left + { + __syncthreads(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); + } + + // output: register to global memory + { + constexpr index_t M11 = + M1PerThreadM111 * M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101; + constexpr index_t N11 = + N1PerThreadN111 * M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101; + + constexpr index_t M10 = MPerBlockM1 / M11; + constexpr index_t N10 = NPerBlockN1 / N11; + + constexpr index_t M111 = M1PerThreadM111; + constexpr index_t N111 = N1PerThreadN111; + + constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc = + make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + const auto c_m10_m11_n10_n11_thread_origin_idx_on_block = + blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id()); + + ThreadwiseDynamicTensorSliceTransfer_v1r3< + FloatAcc, + FloatC, + decltype(c_m0_m10_m11_n0_n10_n11_thread_desc), + decltype(c_m0_m10_m11_n0_n10_n11_grid_desc), + Sequence<1, + c_m10_m11_n10_n11_thread_tensor_lengths[I0], + c_m10_m11_n10_n11_thread_tensor_lengths[I1], + 1, + c_m10_m11_n10_n11_thread_tensor_lengths[I2], + c_m10_m11_n10_n11_thread_tensor_lengths[I3]>, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{c_m0_m10_m11_n0_n10_n11_grid_desc, + make_multi_index(im0, + c_m10_m11_n10_n11_thread_origin_idx_on_block[I0], + c_m10_m11_n10_n11_thread_origin_idx_on_block[I1], + in0, + 
c_m10_m11_n10_n11_thread_origin_idx_on_block[I2], + c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])} + .Run(c_m0_m10_m11_n0_n10_n11_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_grid_buf, + CGridIteratorHacks{}); + } + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v1r3.hpp b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v1r3.hpp new file mode 100644 index 0000000000..db3cb99121 --- /dev/null +++ b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v1r3.hpp @@ -0,0 +1,671 @@ +#ifndef CK_GRIDWISE_DYNAMIC_GEMM_V1R3_HPP +#define CK_GRIDWISE_DYNAMIC_GEMM_V1R3_HPP + +#include "common_header.hpp" +#include "dynamic_multi_index_transform_helper.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "blockwise_gemm_dlops_v2r3.hpp" +#include "blockwise_dynamic_tensor_slice_transfer_v2.hpp" +#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp" +#include "threadwise_dynamic_tensor_slice_set.hpp" + +namespace ck { + +#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_dynamic_gemm_dlops_v1r3( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AK0M0M1K1GridDesc a_k0_m0_m1_k1_grid_desc, + const BK0N0N1K1GridDesc b_k0_n0_n1_k1_grid_desc, + const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc, + const CBlockIdToM0N0BlockClusterAdaptor c_blockid_to_m0_n0_block_cluster_adaptor) +{ + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_k0_m0_m1_k1_grid_desc, + b_k0_n0_n1_k1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor, + integral_constant{}, + integral_constant{}); +} +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER +// pass tensor descriptor by CONSTANT void pointer +// CONSTANT is needed to inform compiler void pointers in the kernel signature are pointing to +// non-modifiable parameter address space, so compiler can enable corresponding optimization +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_dynamic_gemm_dlops_v1r3( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const void CONSTANT* p_a_k0_m0_m1_k1_grid_desc, + const void CONSTANT* p_b_k0_n0_n1_k1_grid_desc, + const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc, + const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) +{ + // first cast void CONSTANT void* to void* + // second cast void* to Desc* + // the copy constructor of tensor descriptor doesn't take address_space(4) + const auto a_k0_m0_m1_k1_grid_desc = + *reinterpret_cast((const void*)p_a_k0_m0_m1_k1_grid_desc); + const auto b_k0_n0_n1_k1_grid_desc = + *reinterpret_cast((const void*)p_b_k0_n0_n1_k1_grid_desc); + const auto c_m0_m10_m11_n0_n10_n11_grid_desc = + *reinterpret_cast( + (const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc); + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + *reinterpret_cast( + (const 
void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); + + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_k0_m0_m1_k1_grid_desc, + b_k0_n0_n1_k1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor, + integral_constant{}, + integral_constant{}); +} +#endif + +template +struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + // K1 should be Number<...> + static constexpr auto K1 = AK0MK1GridDesc{}.GetLength(I2); + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // TODO: change this. I think it needs multi-dimensional alignment + constexpr auto max_lds_align = K1; + + // TODO: check alignment + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // TODO: check alignment + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // TODO: check alignment + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = + math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = + math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align); + + return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB); + } + + __host__ __device__ static constexpr bool + CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc, + const BK0NK1GridDesc& b_k0_n_k1_grid_desc, + const CMNGridDesc& c_m_n_grid_desc) + { + const auto M = a_k0_m_k1_grid_desc.GetLength(I1); + const auto N = b_k0_n_k1_grid_desc.GetLength(I1); + const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); + const auto K1 = a_k0_m_k1_grid_desc.GetLength(I2); + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + + return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && + K0 == b_k0_n_k1_grid_desc.GetLength(I0) && + K1 == b_k0_n_k1_grid_desc.GetLength(I2)) && + (M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K0 % KPerBlock == 0); + } + + __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) + { + const index_t grid_size = (M / MPerBlockM1) * (N / NPerBlockN1); + + return grid_size; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0) + { + const bool has_main_k_block_loop = (K0 + KPerBlock) / (2 * KPerBlock) > 1; + + return has_main_k_block_loop; + } + + __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0) + { + const bool has_double_tail_k_block_loop = (K0 / KPerBlock) % 2 == 0; + + return has_double_tail_k_block_loop; + } + + __host__ __device__ static constexpr auto + MakeAK0M0M1K1GridDescriptor(const AK0MK1GridDesc& a_k0_m_k1_grid_desc) + { + const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); + const auto M = a_k0_m_k1_grid_desc.GetLength(I1); + + const auto M1 = Number{}; + const auto M0 = M / M1; + + const auto a_k0_m0_m1_k1_grid_desc 
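// Illustrative note on the two loop-shape predicates above: the ping-pong main loop
// only exists when at least three K0-slices are available (the preload consumes one
// and the tail consumes one or two), and the tail processes two slices when the
// slice count is even, otherwise one. A stand-alone check of that bookkeeping,
// assuming K0 % KPerBlock == 0:

#include <cassert>

void check_loop_shape(int K0, int KPerBlock)
{
    const int  num_slices      = K0 / KPerBlock;
    const bool has_main_loop   = (K0 + KPerBlock) / (2 * KPerBlock) > 1; // i.e. num_slices >= 3
    const bool has_double_tail = (num_slices % 2 == 0);

    const int main_loop_slices = has_main_loop ? num_slices - (has_double_tail ? 2 : 1) : 0;
    const int tail_slices      = has_double_tail ? 2 : 1;

    assert(main_loop_slices + tail_slices == num_slices);
}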
= transform_dynamic_tensor_descriptor( + a_k0_m_k1_grid_desc, + make_tuple(make_pass_through_transform(K0), + make_unmerge_transform(make_tuple(M0, M1)), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return a_k0_m0_m1_k1_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeBK0N0N1K1GridDescriptor(const BK0NK1GridDesc& b_k0_n_k1_grid_desc) + { + const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0); + const auto N = b_k0_n_k1_grid_desc.GetLength(I1); + + const auto N1 = Number{}; + const auto N0 = N / N1; + + const auto b_k0_n0_n1_k1_grid_desc = transform_dynamic_tensor_descriptor( + b_k0_n_k1_grid_desc, + make_tuple(make_pass_through_transform(K0), + make_unmerge_transform(make_tuple(N0, N1)), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return b_k0_n0_n1_k1_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + constexpr auto M11 = + Number{}; + constexpr auto N11 = + Number{}; + + constexpr auto M10 = M1 / M11; + constexpr auto N10 = N1 / N11; + + const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_dynamic_tensor_descriptor( + c_m_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)), + make_unmerge_transform(make_tuple(N0, N10, N11))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return c_m0_m10_m11_n0_n10_n11_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))), + make_tuple(Sequence<0, 1>{}), + make_tuple(Sequence<0>{})); + + return c_blockid_to_m0_n0_block_cluster_adaptor; + } + + using AK0M0M1K1GridDesc = decltype(MakeAK0M0M1K1GridDescriptor(AK0MK1GridDesc{})); + using BK0N0N1K1GridDesc = decltype(MakeBK0N0N1K1GridDescriptor(BK0NK1GridDesc{})); + using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{})); + using CBlockIdToM0N0BlockClusterAdaptor = + decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{})); + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + FloatAB* __restrict__ p_shared_block, + const AK0M0M1K1GridDesc& a_k0_m0_m1_k1_grid_desc, + const BK0N0N1K1GridDesc& b_k0_n0_n1_k1_grid_desc, + const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc, + const CBlockIdToM0N0BlockClusterAdaptor& c_blockid_to_m0_n0_block_cluster_adaptor, + integral_constant, + integral_constant) + { + const auto a_global_buf = make_dynamic_buffer( + p_a_grid, a_k0_m0_m1_k1_grid_desc.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_grid, 
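// Illustrative note: the descriptor transforms above never move data; they only
// re-index it. A is viewed as (K0, M0, M1, K1) with M unmerged into (M0, M1),
// C is viewed as (M0, M10, M11, N0, N10, N11), and the block-cluster adaptor merges
// (M0, N0) back into a flat block id. The underlying arithmetic is mixed-radix
// decomposition; a sketch, assuming the conventional order in which the last
// dimension varies fastest:

#include <tuple>

std::tuple<int, int> block_id_to_m0_n0(int block_id, int N0)
{
    return {block_id / N0, block_id % N0}; // inverse of merge(M0, N0)
}

std::tuple<int, int, int> m_to_m0_m10_m11(int m, int M10, int M11)
{
    const int m0  = m / (M10 * M11);
    const int m10 = (m / M11) % M10;
    const int m11 = m % M11;
    return {m0, m10, m11}; // unmerge into (M0, M10, M11)
}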
b_k0_n0_n1_k1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto c_m0_n0_block_cluster_idx = + c_blockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex( + make_multi_index(get_block_1d_id())); + + // HACK: this force index data into SGPR + const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]); + const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]); + + // TODO: change this. I think it needs multi-dimensional alignment + constexpr auto max_lds_align = K1; + + // TODO: check alignment + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_k0_m0_m1_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, I1, Number{}, K1), max_lds_align); + + // TODO: check alignment + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_k0_n0_n1_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, I1, Number{}, K1), max_lds_align); + + // TODO: check alignment + // A matrix in LDS memory, for blockwise GEMM + constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // TODO: check alignment + // B matrix in LDS memory, for blockwise GEMM + constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + static_assert(a_k0_m0_m1_k1_block_desc.GetElementSpaceSize() == + a_k0_m_k1_block_desc.GetElementSpaceSize() && + b_k0_n0_n1_k1_block_desc.GetElementSpaceSize() == + b_k0_n_k1_block_desc.GetElementSpaceSize() && + "wrong!"); + + // A matrix blockwise copy + auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1< + BlockSize, + InMemoryDataOperationEnum_t::Set, + Sequence, + ABlockTransferThreadSliceLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_k0_m0_m1_k1_grid_desc), + decltype(a_k0_m0_m1_k1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2, 3>, + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths + ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, // DstVectorTensorLengths + ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder + Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder + false, + true>(a_k0_m0_m1_k1_grid_desc, + make_multi_index(0, im0, 0, 0), + a_k0_m0_m1_k1_block_desc, + make_multi_index(0, 0, 0, 0)); + + // B matrix blockwise copy + auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1< + BlockSize, + InMemoryDataOperationEnum_t::Set, + Sequence, + BBlockTransferThreadSliceLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_k0_n0_n1_k1_grid_desc), + decltype(b_k0_n0_n1_k1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2, 3>, + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths + BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths + BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder + Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder + false, + 
true>(b_k0_n0_n1_k1_grid_desc, + make_multi_index(0, in0, 0, 0), + b_k0_n0_n1_k1_block_desc, + make_multi_index(0, 0, 0, 0)); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[KPerBlock, MPerBlockM1] is in LDS + // b_mtx[KPerBlocl, NPerBlockN1] is in LDS + // c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in + // register + const auto blockwise_gemm = + BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< + BlockSize, + FloatAB, + FloatAB, + FloatAcc, + decltype(a_k0_m_k1_block_desc), + decltype(b_k0_n_k1_block_desc), + M1PerThreadM111, + N1PerThreadN111, + KPerThread, + M11N11ThreadClusterM110Xs, + M11N11ThreadClusterN110Xs, + M1PerThreadM111, + N1PerThreadN111>{}; + + constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = + decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); + + constexpr auto c_m10_m11_n10_n11_thread_desc = + make_dynamic_naive_tensor_descriptor_packed_v2( + sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths)); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = math::integer_least_multiple( + a_k0_m0_m1_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = math::integer_least_multiple( + b_k0_n0_n1_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block_double = p_shared_block; + FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; + + // register allocation for output + auto c_thread_buf = make_static_buffer( + c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize()); + + ThreadwiseDynamicTensorSliceSet_v1{} + .Run(c_m10_m11_n10_n11_thread_desc, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + FloatAcc{0}); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0); + + auto a_block_even_buf = make_dynamic_buffer( + p_a_block_double, a_k0_m0_m1_k1_block_desc.GetElementSpaceSize()); + auto b_block_even_buf = make_dynamic_buffer( + p_b_block_double, b_k0_n0_n1_k1_block_desc.GetElementSpaceSize()); + + auto a_block_odd_buf = make_dynamic_buffer( + p_a_block_double + a_block_aligned_space_size, + a_k0_m0_m1_k1_block_desc.GetElementSpaceSize()); + auto b_block_odd_buf = make_dynamic_buffer( + p_b_block_double + b_block_aligned_space_size, + b_k0_n0_n1_k1_block_desc.GetElementSpaceSize()); + + // LDS double buffer: preload data into LDS + { + a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{}); + b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{}); + + a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf); + b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf); + } + + if constexpr(HasMainKBlockLoop) + { + const auto K0 = a_k0_m0_m1_k1_grid_desc.GetLength(I0); + + index_t k_block_data_begin = 0; + + // LDS double buffer: main body + // use Do-While loop instead of For loop to simplify control flow + do + { + // even iteration + a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, + a_block_slice_copy_step, + AGridMoveSliceWindowIteratorHacks{}); + b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, + b_block_slice_copy_step, + BGridMoveSliceWindowIteratorHacks{}); + + __syncthreads(); + + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead( + a_k0_m0_m1_k1_grid_desc, 
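// Illustrative note: the single __shared__ allocation is carved into four regions:
// the A double buffer occupies the first 2 * a_block_aligned_space_size elements
// (even half, then odd half) and the B double buffer occupies the next
// 2 * b_block_aligned_space_size. A sketch of the resulting offsets, in units of
// FloatAB elements (the sizes are placeholders for whatever the aligned
// element-space sizes evaluate to):

#include <cstdio>

void print_lds_layout(int a_aligned_size, int b_aligned_size)
{
    std::printf("a_even: [%d, %d)\n", 0, a_aligned_size);
    std::printf("a_odd : [%d, %d)\n", a_aligned_size, 2 * a_aligned_size);
    std::printf("b_even: [%d, %d)\n", 2 * a_aligned_size, 2 * a_aligned_size + b_aligned_size);
    std::printf("b_odd : [%d, %d)\n",
                2 * a_aligned_size + b_aligned_size,
                2 * (a_aligned_size + b_aligned_size));
}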
a_global_buf, AGridIteratorHacks{}); + b_blockwise_copy.RunRead( + b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{}); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc, + a_block_even_buf, + b_block_even_buf, + c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf); + + // odd iteration + a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, + a_block_slice_copy_step, + AGridMoveSliceWindowIteratorHacks{}); + b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, + b_block_slice_copy_step, + BGridMoveSliceWindowIteratorHacks{}); + + __syncthreads(); + + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead( + a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{}); + b_blockwise_copy.RunRead( + b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{}); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run( + c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf); + b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf); + + k_block_data_begin += 2 * KPerBlock; + } while(k_block_data_begin < K0 - 2 * KPerBlock); + } + + // LDS double buffer: tail + if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left + { + a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, + a_block_slice_copy_step, + AGridMoveSliceWindowIteratorHacks{}); + b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, + b_block_slice_copy_step, + BGridMoveSliceWindowIteratorHacks{}); + + __syncthreads(); + + // LDS double buffer: load last data from device mem + a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{}); + b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{}); + + // LDS double buffer: GEMM on 2nd-last data + blockwise_gemm.Run( + c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); + + // LDS double buffer: store last data to LDS + a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf); + + __syncthreads(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + } + else // if has 1 iteration left + { + __syncthreads(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); + } + + // output: register to global memory + { + constexpr auto M11 = + Number{}; + constexpr auto N11 = + Number{}; + + constexpr index_t M10 = MPerBlockM1 / M11; + constexpr index_t N10 = NPerBlockN1 / N11; + + constexpr index_t M111 = M1PerThreadM111; + constexpr index_t N111 = N1PerThreadN111; + + constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc = + make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + const auto c_m10_m11_n10_n11_thread_origin_idx_on_block = + blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( + get_thread_local_1d_id()); + + ThreadwiseDynamicTensorSliceTransfer_v1r3< + FloatAcc, + FloatC, + 
decltype(c_m0_m10_m11_n0_n10_n11_thread_desc), + decltype(c_m0_m10_m11_n0_n10_n11_grid_desc), + Sequence<1, + c_m10_m11_n10_n11_thread_tensor_lengths[I0], + c_m10_m11_n10_n11_thread_tensor_lengths[I1], + 1, + c_m10_m11_n10_n11_thread_tensor_lengths[I2], + c_m10_m11_n10_n11_thread_tensor_lengths[I3]>, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{c_m0_m10_m11_n0_n10_n11_grid_desc, + make_multi_index(im0, + c_m10_m11_n10_n11_thread_origin_idx_on_block[I0], + c_m10_m11_n10_n11_thread_origin_idx_on_block[I1], + in0, + c_m10_m11_n10_n11_thread_origin_idx_on_block[I2], + c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])} + .Run(c_m0_m10_m11_n0_n10_n11_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_grid_buf, + CGridIteratorHacks{}); + } + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v2.hpp b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v2.hpp new file mode 100644 index 0000000000..34dea34833 --- /dev/null +++ b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v2.hpp @@ -0,0 +1,463 @@ +#ifndef CK_GRIDWISE_DYNAMIC_GEMM_V2_HPP +#define CK_GRIDWISE_DYNAMIC_GEMM_V2_HPP + +#include "common_header.hpp" +#include "dynamic_multi_index_transform_helper.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "blockwise_dynamic_tensor_slice_transfer.hpp" +#include "threadwise_dynamic_tensor_slice_transfer.hpp" +#include "blockwise_gemm_dlops_v3.hpp" + +namespace ck { + +template +struct GridwiseDynamicGemmDlops_km_kn_mn_v3 +{ + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto E = EPerBlock * 3 * 3; + + constexpr auto max_lds_align = + math::lcm(Number{}, Number{}); + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_e_k_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, Number{}), max_lds_align); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_e_k_desc.GetElementSpaceSize(), max_lds_align); + + return a_block_space_size * sizeof(FloatAB); + } + + template + __device__ void Run(const AGlobalDesc& a_e_k_global_desc, + const FloatAB* __restrict__ p_a_global, + const BGlobalDesc& b_e_n_ho_wo_global_desc, + const FloatAB* __restrict__ p_b_global, + const CGlobalDesc& c_k_n_ho_wo_global_desc, + FloatC* __restrict__ p_c_global, + FloatAB* __restrict__ p_shared_block, + integral_constant, + integral_constant) const + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto a_global_buf = make_dynamic_buffer( + p_a_global, a_e_k_global_desc.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_global, b_e_n_ho_wo_global_desc.GetElementSpaceSize()); + auto c_global_buf = make_dynamic_buffer( + p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize()); + + constexpr auto E = EPerBlock * 3 * 3; + + // const auto E = a_e_k_global_desc.GetLength(I0); + const auto K = a_e_k_global_desc.GetLength(I1); + + const auto N = b_e_n_ho_wo_global_desc.GetLength(I1); + const auto Ho = b_e_n_ho_wo_global_desc.GetLength(I2); + const auto Wo = 
b_e_n_ho_wo_global_desc.GetLength(I3); + +// divide block work by [M, N] +#if 0 + const auto k_block_work_num = K / Number{}; + const auto ho_block_work_num = Ho / Number{}; + const auto wo_block_work_num = Wo / Number{}; + const auto hwo_block_work_num = ho_block_work_num * wo_block_work_num; + + const index_t k_block_work_id = get_block_1d_id() / hwo_block_work_num; + const index_t hwo_block_work_id = get_block_1d_id() - k_block_work_id * hwo_block_work_num; + + const index_t ho_block_work_id = hwo_block_work_id / wo_block_work_num; + const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num; +#else + // Hack: this force result into SGPR + const index_t k_block_work_num = __builtin_amdgcn_readfirstlane(K / KPerBlock); + const index_t ho_block_work_num = __builtin_amdgcn_readfirstlane(Ho / HoPerBlock); + const index_t wo_block_work_num = __builtin_amdgcn_readfirstlane(Wo / WoPerBlock); + const index_t hwo_block_work_num = ho_block_work_num * wo_block_work_num; + + const index_t k_block_work_id = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / hwo_block_work_num); + const index_t hwo_block_work_id = get_block_1d_id() - k_block_work_id * hwo_block_work_num; + + const index_t ho_block_work_id = + __builtin_amdgcn_readfirstlane(hwo_block_work_id / wo_block_work_num); + const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num; +#endif + + // lds max alignment + constexpr auto max_lds_align = + math::lcm(Number{}, Number{}); + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_e_k_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, Number{}), max_lds_align); + + constexpr auto a_e_k_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, Number{}), max_lds_align); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_e_n_ho_wo_block_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( + Number{}, Number<1>{}, Number{}, Number{})); + + // c_thread_mtx definition: this is a mess + // TODO:: more elegent way of defining c_thread_mtx + constexpr auto c_k_n_ho_wo_thread_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( + Number{}, Number<1>{}, Number{}, Number{})); + + auto blockwise_gemm = + BlockwiseGemmDlops_km_kn_m0m1n0n1_v3{}; + + auto c_thread_mtx_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); + + const auto k_thread_id = c_thread_mtx_index.k; + const auto ho_thread_id = c_thread_mtx_index.h; + const auto wo_thread_id = c_thread_mtx_index.w; + + const index_t k_block_data_on_global = k_block_work_id * KPerBlock; + const index_t ho_block_data_on_global = ho_block_work_id * HoPerBlock; + const index_t wo_block_data_on_global = wo_block_work_id * WoPerBlock; + + const index_t ho_thread_data_on_global = + ho_block_data_on_global + ho_thread_id * HoPerThread; + const index_t wo_thread_data_on_global = + wo_block_data_on_global + wo_thread_id * WoPerThread; + + // A matrix blockwise copy + auto a_blockwise_copy = + BlockwiseDynamicTensorSliceTransfer_v4, + ABlockTransferThreadSliceLengths_E_K, + ABlockTransferThreadClusterLengths_E_K, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_e_k_global_desc), + decltype(a_e_k_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 1>, + ABlockTransferSrcVectorDim, + 1, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K, + 1, + 
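// Illustrative note: each workgroup owns a KPerBlock x HoPerBlock x WoPerBlock
// output tile, so the flat block id is decomposed into (k, ho, wo) tile ids; the
// __builtin_amdgcn_readfirstlane wrappers above only assert wave-uniformity so the
// results stay in scalar registers. The plain arithmetic, assuming Ho and Wo tile
// exactly:

#include <tuple>

std::tuple<int, int, int>
block_id_to_k_ho_wo(int block_id, int Ho, int Wo, int HoPerBlock, int WoPerBlock)
{
    const int ho_blocks  = Ho / HoPerBlock;
    const int wo_blocks  = Wo / WoPerBlock;
    const int hwo_blocks = ho_blocks * wo_blocks;

    const int k_id   = block_id / hwo_blocks;
    const int hwo_id = block_id - k_id * hwo_blocks;
    const int ho_id  = hwo_id / wo_blocks;
    const int wo_id  = hwo_id - ho_id * wo_blocks;

    return {k_id, ho_id, wo_id};
}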
1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_e_k_global_desc, + make_multi_index(0, k_block_data_on_global), + a_e_k_desc, + make_multi_index(0, 0)); + + constexpr auto b_e_n_ho_wo_thread_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( + Number{}, Number<1>{}, Number{}, Number{})); + + auto b_threadwise_transfer = ThreadwiseDynamicTensorSliceTransfer_v2< + FloatAB, + FloatAB, + decltype(b_e_n_ho_wo_global_desc), + decltype(b_e_n_ho_wo_thread_desc), + Sequence, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + 1, + true>(b_e_n_ho_wo_global_desc, + make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global)); + + auto a_block_buf = make_dynamic_buffer( + p_shared_block, a_e_k_desc.GetElementSpaceSize()); + + // register allocation for output + StaticBuffer + c_thread_buf; + + // initialize output thread tensor + ThreadwiseDynamicTensorSliceSet_v1>{} + .Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0}); + + constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0); + + // hack to control index calculation when iterating over A and B matrix for threadwise copy + constexpr auto a_e_k_global_iterator_hacks = AGlobalIteratorHacks{}; + constexpr auto b_e_n_ho_wo_global_iterator_hacks = BGlobalIteratorHacks{}; + + // hack to control index calculation when move slice window for A and B matrix for + // threadwise copy + constexpr auto a_e_k_global_move_slice_window_iterator_hack = + AGlobalMoveSliceWindowIteratorHacks{}; + constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack = + BGlobalMoveSliceWindowIteratorHacks{}; + + // double regsiter buffer for b + StaticBuffer + b_thread_even_buf, b_thread_odd_buf; + + // LDS double buffer: preload data + { + a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_iterator_hacks); + + b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc, + b_global_buf, + b_e_n_ho_wo_thread_desc, + make_tuple(I0, I0, I0, I0), + b_thread_even_buf, + b_e_n_ho_wo_global_iterator_hacks); + + a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf); + } + + __syncthreads(); + + if constexpr(HasMainKBlockLoop) + { + index_t e_block_data_begin = 0; + + // LDS double buffer: main body + // use Do-While loop instead of For loop to simplify control flow + do + { + // even iteration + b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc, + b_thread_slice_copy_step); + + b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc, + b_global_buf, + b_e_n_ho_wo_thread_desc, + make_tuple(I0, I0, I0, I0), + b_thread_odd_buf, + b_e_n_ho_wo_global_iterator_hacks); + + // LDS double buffer: GEMM on current data + // TODO: @Zhang Jing: blockwise gemm should be able to move slice window + blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); + + blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0)); + + b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc, + b_thread_slice_copy_step); + + b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc, + b_global_buf, + b_e_n_ho_wo_thread_desc, + make_tuple(I0, I0, I0, I0), + b_thread_even_buf, + b_e_n_ho_wo_global_iterator_hacks); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); + + blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0)); + + e_block_data_begin += 2 * EPerBlock; + + } while(e_block_data_begin < E - 2 * EPerBlock); + } + + // LDS double 
buffer: tail + if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left + { + b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc, + b_thread_slice_copy_step); + + b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc, + b_global_buf, + b_e_n_ho_wo_thread_desc, + make_tuple(I0, I0, I0, I0), + b_thread_odd_buf, + b_e_n_ho_wo_global_iterator_hacks); + + // LDS double buffer: GEMM on 2nd-last data + blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); + + blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0)); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); + } + else // if has 1 iteration left + { + // LDS double buffer: GEMM on last data + blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); + } + + // output: register to global memory + { + // hack to control index calculation when iterating over c_k_n_ho_wo_global tensor + constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{}; + + const index_t k_thread_data_on_global = + k_block_data_on_global + k_thread_id * KPerThread; + + ThreadwiseDynamicTensorSliceTransfer_v1r3< + FloatAcc, + FloatC, + decltype(c_k_n_ho_wo_thread_desc), + decltype(c_k_n_ho_wo_global_desc), + Sequence, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>( + c_k_n_ho_wo_global_desc, + make_multi_index( + k_thread_data_on_global, 0, ho_thread_data_on_global, wo_thread_data_on_global)) + .Run(c_k_n_ho_wo_thread_desc, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + c_k_n_ho_wo_global_desc, + c_global_buf, + c_k_n_ho_wo_global_tensor_iterator_hacks); + } + } + + // pass tensor descriptor by reference + template + __device__ void Run(const AGlobalDesc& a_e_k_global_desc, + const FloatAB* __restrict__ p_a_global, + const BGlobalDesc& b_e_n_ho_wo_global_desc, + const FloatAB* __restrict__ p_b_global, + const CGlobalDesc& c_k_n_ho_wo_global_desc, + FloatC* __restrict__ p_c_global, + integral_constant, + integral_constant) const + { + constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + Run(a_e_k_global_desc, + p_a_global, + b_e_n_ho_wo_global_desc, + p_b_global, + c_k_n_ho_wo_global_desc, + p_c_global, + p_shared_block, + integral_constant{}, + integral_constant{}); + } + + // pass tensor descriptors by their pointers + template + __device__ void Run(const AGlobalDesc* p_a_e_k_global_desc, + const FloatAB* __restrict__ p_a_global, + const BGlobalDesc* p_b_e_n_ho_wo_global_desc, + const FloatAB* __restrict__ p_b_global, + const CGlobalDesc* p_c_k_n_ho_wo_global_desc, + FloatC* __restrict__ p_c_global, + integral_constant, + integral_constant) const + { + const auto a_e_k_global_desc = *p_a_e_k_global_desc; + const auto b_e_n_ho_wo_global_desc = *p_b_e_n_ho_wo_global_desc; + const auto c_k_n_ho_wo_global_desc = *p_c_k_n_ho_wo_global_desc; + + Run(a_e_k_global_desc, + p_a_global, + b_e_n_ho_wo_global_desc, + p_b_global, + c_k_n_ho_wo_global_desc, + p_c_global, + integral_constant{}, + integral_constant{}); + } + + // pass tensor descriptors by void* + template + __device__ void Run(const void* p_a_e_k_global_desc, + const FloatAB* __restrict__ p_a_global, + const void* p_b_e_n_ho_wo_global_desc, + const FloatAB* __restrict__ p_b_global, + const void* p_c_k_n_ho_wo_global_desc, + FloatC* __restrict__ p_c_global, + integral_constant, + 
integral_constant) const + { + const auto a_e_k_global_desc = *reinterpret_cast(p_a_e_k_global_desc); + const auto b_e_n_ho_wo_global_desc = + *reinterpret_cast(p_b_e_n_ho_wo_global_desc); + const auto c_k_n_ho_wo_global_desc = + *reinterpret_cast(p_c_k_n_ho_wo_global_desc); + + Run(a_e_k_global_desc, + p_a_global, + b_e_n_ho_wo_global_desc, + p_b_global, + c_k_n_ho_wo_global_desc, + p_c_global, + integral_constant{}, + integral_constant{}); + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops_v2r3.hpp b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops_v2r3.hpp new file mode 100644 index 0000000000..a5b1de79a7 --- /dev/null +++ b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops_v2r3.hpp @@ -0,0 +1,823 @@ +#ifndef CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2R3_HPP +#define CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2R3_HPP + +#include "common_header.hpp" +#include "dynamic_multi_index_transform_helper.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "blockwise_gemm_xdlops.hpp" +#include "blockwise_dynamic_tensor_slice_transfer.hpp" +#include "threadwise_dynamic_tensor_slice_transfer.hpp" +#include "threadwise_dynamic_tensor_slice_set.hpp" + +namespace ck { + +#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_dynamic_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AK0MK1GridDesc a_k0_m_k1_grid_desc, + const BK0NK1GridDesc b_k0_n_k1_grid_desc, + const CM0M1M2NGridDesc c_m0_m1_m2_n_grid_desc, + const CBlockClusterAdaptor c_block_cluster_adaptor) +{ + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m0_m1_m2_n_grid_desc, + c_block_cluster_adaptor); +} +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_dynamic_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const void CONSTANT* p_a_k0_m_k1_grid_desc, + const void CONSTANT* p_b_k0_n_k1_grid_desc, + const void CONSTANT* p_c_m0_m1_m2_n_grid_desc, + const void CONSTANT* p_c_block_cluster_adaptor) +{ + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + const auto a_k0_m_k1_grid_desc = + *reinterpret_cast((const void*)p_a_k0_m_k1_grid_desc); + const auto b_k0_n_k1_grid_desc = + *reinterpret_cast((const void*)p_b_k0_n_k1_grid_desc); + const auto c_m0_m1_m2_n_grid_desc = + *reinterpret_cast((const void*)p_c_m0_m1_m2_n_grid_desc); + const auto c_block_cluster_adaptor = + *reinterpret_cast((const void*)p_c_block_cluster_adaptor); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m0_m1_m2_n_grid_desc, + c_block_cluster_adaptor); +} +#endif + +template +struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 +{ + static constexpr auto I0 = 
Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size = + math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + return (a_block_space_size + b_block_space_size) * sizeof(FloatAB); + } + + __host__ __device__ static constexpr bool + CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc, + const BK0NK1GridDesc& b_k0_n_k1_grid_desc, + const CMNGridDesc& c_m_n_grid_desc) + { + // TODO: turn on this + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + const auto M = a_k0_m_k1_grid_desc.GetLength(I1); + const auto N = b_k0_n_k1_grid_desc.GetLength(I1); + const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + + return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && + K0 == b_k0_n_k1_grid_desc.GetLength(I0) && + K1 == a_k0_m_k1_grid_desc.GetLength(I2) && + K1 == b_k0_n_k1_grid_desc.GetLength(I2)) && + (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0) && + (MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0); + } + + __host__ __device__ static constexpr index_t + CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); + + return grid_size; + } + + __host__ __device__ static constexpr auto + MakeCM0M1M2NGridDescriptor(const CMNGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + constexpr auto xdlops_gemm = XdlopsGemm{}; + + constexpr auto CLayout = xdlops_gemm.GetCLayout(); + + constexpr auto M0 = Number{}; + constexpr auto M1 = Number{}; + constexpr auto M2 = Number{}; + + constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat); + constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat); + + constexpr auto N0 = Number{}; + constexpr auto N1 = Number{}; + + const auto c_m0_m1_m2_n_grid_desc = transform_dynamic_tensor_descriptor( + c_m_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, M0, M1, M2)), + make_unmerge_transform(make_tuple(NRepeat, NWaves, N1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{})); + + return c_m0_m1_m2_n_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeCBlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + 
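// Illustrative note: CheckValidity above encodes the tiling contract (the problem
// must tile exactly into MPerBlock x NPerBlock x KPerBlock blocks, and each block
// tile must tile into waves), and CalculateGridSize launches one workgroup per
// MPerBlock x NPerBlock output tile. A compact host-side sketch of that contract,
// with the descriptor-consistency checks omitted:

int xdlops_grid_size(int M, int N, int K0,
                     int MPerBlock, int NPerBlock, int KPerBlock,
                     int MPerWave, int NPerWave)
{
    const bool valid = M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0 &&
                       MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0;

    return valid ? (M / MPerBlock) * (N / NPerBlock) : -1; // -1 signals an invalid tiling
}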
const auto N = c_m_n_grid_desc.GetLength(I1); + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + +#if 1 + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))), + make_tuple(Sequence<0, 1>{}), + make_tuple(Sequence<0>{})); +#elif 1 + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(N0, M0))), + make_tuple(Sequence<1, 0>{}), + make_tuple(Sequence<0>{})); +#endif + + return c_blockid_to_m0_n0_block_cluster_adaptor; + } + + using CM0M1M2NGridDesc = decltype(MakeCM0M1M2NGridDescriptor(CMNGridDesc{})); + using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{})); + + __device__ static void Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + FloatAB* __restrict__ p_shared_block, + const AK0MK1GridDesc& a_k0_m_k1_grid_desc, + const BK0NK1GridDesc& b_k0_n_k1_grid_desc, + const CM0M1M2NGridDesc& c_m0_m1_m2_n_grid_desc, + const CBlockClusterAdaptor& c_block_cluster_adaptor) + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_m0_m1_m2_n_grid_desc.GetElementSpaceSize()); + + const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); + const auto M = a_k0_m_k1_grid_desc.GetLength(I1); + const auto N = b_k0_n_k1_grid_desc.GetLength(I1); + + // divide block work by [M, N] + const auto block_work_idx = + c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // A matrix blockwise copy + auto a_blockwise_copy = + BlockwiseDynamicTensorSliceTransfer_v4, + ABlockTransferThreadSliceLengths_K0_M_K1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_k0_m_k1_grid_desc), + decltype(a_k0_m_k1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_k0_m_k1_grid_desc, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_k0_m_k1_block_desc, + make_multi_index(0, 0, 0)); + + // B matrix blockwise copy + auto b_blockwise_copy = + BlockwiseDynamicTensorSliceTransfer_v4, + 
BBlockTransferThreadSliceLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_k0_n_k1_grid_desc), + decltype(b_k0_n_k1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_k0_n_k1_grid_desc, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_k0_n_k1_block_desc, + make_multi_index(0, 0, 0)); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[KPerBlock, MPerBlock] is in LDS + // b_mtx[KPerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + + static_assert(MPerBlock % (MPerWave * MRepeat) == 0 && + NPerBlock % (NPerWave * NRepeat) == 0, + "wrong!"); + + constexpr auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor( + a_k0_m_k1_block_desc, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + constexpr auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor( + b_k0_n_k1_block_desc, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto blockwise_gemm = + BlockwiseGemmXdlops_km_kn_m0m1m2n_v1{}; + + constexpr auto CLayout = blockwise_gemm.GetCLayout(); + + constexpr index_t BlkSize = CLayout.GetBlkSize(); + constexpr index_t NumBlks = CLayout.GetNumBlks(); + constexpr index_t NumXdlops = CLayout.GetNumXdlops(); + + static_assert(NumBlks == 1 && NumXdlops == 1, "K Reduction Mfma only"); + + constexpr auto c_mr_nr_blk_desc = make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(Number{}, Number{})); + + StaticBuffer, + c_mr_nr_blk_desc.GetElementSpaceSize()> + c_thread_buf; + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size = + math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block = p_shared_block; + FloatAB* p_b_block = p_shared_block + a_block_space_size; + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); + + // hack to control index calculation when iterating over A and B matrix for threadwise copy + constexpr auto a_k0_m_k1_grid_iterator_hacks = AGridIteratorHacks{}; + constexpr auto b_k0_n_k1_grid_iterator_hacks = BGridIteratorHacks{}; + + // hack to control index calculation when move slice window for A and B matrix for + // threadwise copy + constexpr auto a_k0_m_k1_grid_move_slice_window_iterator_hack = + AGridMoveSliceWindowIteratorHacks{}; + constexpr auto b_k0_n_k1_grid_move_slice_window_iterator_hack = + BGridMoveSliceWindowIteratorHacks{}; + + auto a_block_buf = make_dynamic_buffer( + p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); + auto b_block_buf = make_dynamic_buffer( + 
p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); + + // preload data into LDS + { + a_blockwise_copy.RunRead( + a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_iterator_hacks); + b_blockwise_copy.RunRead( + b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_iterator_hacks); + + a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf); + } + + // main body + index_t k_block_data_begin = 0; + + do + { + a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc, + a_block_slice_copy_step, + a_k0_m_k1_grid_move_slice_window_iterator_hack); + b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc, + b_block_slice_copy_step, + b_k0_n_k1_grid_move_slice_window_iterator_hack); + + a_blockwise_copy.RunRead( + a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_iterator_hacks); + + block_sync_lds(); + + b_blockwise_copy.RunRead( + b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_iterator_hacks); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf); + + k_block_data_begin += KPerBlock; + } while(k_block_data_begin < (K0 - KPerBlock)); + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + +#if 0 + // output: register to global memory + { + constexpr index_t M0 = CLayout.M1(); + constexpr index_t M1 = CLayout.N1(); + constexpr index_t M2 = CLayout.M0(); + + constexpr index_t N0 = CLayout.N1(); + constexpr index_t N1 = CLayout.N0(); + + constexpr auto c_m0_m1_m2_n_thread_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number{}, + Number{}, + Number<1>{}, + Number<1>{}, + Number{}, + Number<1>{}, + Number{}, + Number<1>{})); + + StaticBuffer + c_blk_buf_; + + static_for<0, MRepeat, 1>{}([&](auto mr_i) { + static_for<0, NRepeat, 1>{}([&](auto nr_i) { + constexpr auto blk_off = + c_mr_nr_blk_desc.CalculateOffset(make_tuple(mr_i, nr_i)); + + static_for<0, BlkSize, 1>{}([&](auto j) { + c_blk_buf_(Number{}) = + c_thread_buf[Number{}] + .template AsType()[Number{}]; + }); + }); + }); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_grid = + m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; + + const index_t n_thread_data_on_grid = + n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; + + constexpr auto c_m0_m1_m2_n_grid_tensor_iterator_hacks = CGridIteratorHacks{}; + + constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat); + constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat); + + ThreadwiseDynamicTensorSliceTransfer_v1r3< + FloatC, + FloatC, + decltype(c_m0_m1_m2_n_thread_desc), + decltype(c_m0_m1_m2_n_grid_desc), + Sequence, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{ + c_m0_m1_m2_n_grid_desc, + make_multi_index(m_thread_data_on_grid / (M2 * M1 * M0 * MWaves), + n_thread_data_on_grid / (N1 * NWaves), + m_thread_data_on_grid % (M2 * M1 * M0 * MWaves) / (M2 * M1 * M0), + n_thread_data_on_grid % (N1 * NWaves) / N1, + m_thread_data_on_grid % (M2 * M1 * M0) / (M2 * M1), + m_thread_data_on_grid % (M2 * M1) / M2, + m_thread_data_on_grid % M2, + n_thread_data_on_grid % N1)} + 
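// Illustrative note on the main loop above: unlike the dlops kernels, this xdlops
// variant keeps a single LDS copy of each input tile, so correctness comes from the
// two block_sync_lds calls per iteration rather than from double buffering. The
// next tile is first staged in registers by RunRead; the first barrier makes the
// previous RunWrite visible to every wave before the MFMA GEMM reads LDS, and the
// second barrier keeps RunWrite from overwriting LDS until every wave has finished
// that GEMM. Global loads still overlap with math because RunRead only touches
// registers, not LDS.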
.Run(c_m0_m1_m2_n_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), + c_blk_buf_, + c_m0_m1_m2_n_grid_desc, + c_grid_buf, + c_m0_m1_m2_n_grid_tensor_iterator_hacks); + } +#else + { + constexpr index_t M0 = CLayout.M1(); + constexpr index_t M1 = CLayout.N1(); + constexpr index_t M2 = CLayout.M0(); + + constexpr auto c_m0_m1_m2_n_thread_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( + I1, I1, I1, I1, Number{}, Number<1>{}, Number{}, Number<1>{})); + + StaticBuffer c_blk_buf_; + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_grid = + m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; + + const index_t n_thread_data_on_grid = + n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; + + constexpr auto c_m0_m1_m2_n_grid_tensor_iterator_hacks = CGridIteratorHacks{}; + + auto c_thread_copy = + ThreadwiseDynamicTensorSliceTransfer_v1r3, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{ + c_m0_m1_m2_n_grid_desc, + make_multi_index(0, + 0, + 0, + 0, + m_thread_data_on_grid / (M2 * M1), + m_thread_data_on_grid % (M2 * M1) / M2, + m_thread_data_on_grid % M2, + n_thread_data_on_grid)}; + + auto init_copy = [&](auto c_thread_idx_) { + constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); + c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), + c_thread_buf[Number{}].template AsType(), + c_m0_m1_m2_n_grid_desc, + c_grid_buf, + c_m0_m1_m2_n_grid_tensor_iterator_hacks); + + return c_thread_idx_; + }; + + auto mrepeat_plus_copy = [&](auto c_thread_idx_) { + constexpr auto mrepeat_step_plus = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0); + c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, mrepeat_step_plus); + + constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); + c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), + c_thread_buf[Number{}].template AsType(), + c_m0_m1_m2_n_grid_desc, + c_grid_buf, + c_m0_m1_m2_n_grid_tensor_iterator_hacks); + }; + + auto nrepeat_plus_copy = [&](auto c_thread_idx_) { + constexpr auto nrepeat_step_plus = make_multi_index(0, 1, 0, 0, 0, 0, 0, 0); + c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, nrepeat_step_plus); + + constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); + c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), + c_thread_buf[Number{}].template AsType(), + c_m0_m1_m2_n_grid_desc, + c_grid_buf, + c_m0_m1_m2_n_grid_tensor_iterator_hacks); + }; + + auto mrepeat_minus_copy = [&](auto c_thread_idx_) { + constexpr auto mrepeat_step_plus = make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0); + c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, mrepeat_step_plus); + + constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); + c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), + c_thread_buf[Number{}].template AsType(), + c_m0_m1_m2_n_grid_desc, + c_grid_buf, + c_m0_m1_m2_n_grid_tensor_iterator_hacks); + }; + + auto nrepeat_minus_copy = [&](auto c_thread_idx_) { + constexpr auto nrepeat_step_minus = make_multi_index(0, -1, 0, 0, 0, 0, 0, 0); + c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, 
nrepeat_step_minus); + + constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); + c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), + c_thread_buf[Number{}].template AsType(), + c_m0_m1_m2_n_grid_desc, + c_grid_buf, + c_m0_m1_m2_n_grid_tensor_iterator_hacks); + }; + + static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or + (MRepeat == 2 && NRepeat == 4) or (MRepeat == 2 && NRepeat == 2) or + (MRepeat == 2 && NRepeat == 1) or (MRepeat == 1 && NRepeat == 2) or + (MRepeat == 1 && NRepeat == 1), + "wrong"); + + if constexpr(MRepeat == 4 && NRepeat == 4) + { + init_copy(make_tuple(I0, I0)); + + if constexpr(CAccessOrderMRepeatNRepeat) + { + nrepeat_plus_copy(make_tuple(I0, I1)); + nrepeat_plus_copy(make_tuple(I0, I2)); + nrepeat_plus_copy(make_tuple(I0, I3)); + mrepeat_plus_copy(make_tuple(I1, I3)); + nrepeat_minus_copy(make_tuple(I1, I2)); + nrepeat_minus_copy(make_tuple(I1, I1)); + nrepeat_minus_copy(make_tuple(I1, I0)); + mrepeat_plus_copy(make_tuple(I2, I0)); + nrepeat_plus_copy(make_tuple(I2, I1)); + nrepeat_plus_copy(make_tuple(I2, I2)); + nrepeat_plus_copy(make_tuple(I2, I3)); + mrepeat_plus_copy(make_tuple(I3, I3)); + nrepeat_minus_copy(make_tuple(I3, I2)); + nrepeat_minus_copy(make_tuple(I3, I1)); + nrepeat_minus_copy(make_tuple(I3, I0)); + } + else + { + mrepeat_plus_copy(make_tuple(I1, I0)); + mrepeat_plus_copy(make_tuple(I2, I0)); + mrepeat_plus_copy(make_tuple(I3, I0)); + nrepeat_plus_copy(make_tuple(I3, I1)); + mrepeat_minus_copy(make_tuple(I2, I1)); + mrepeat_minus_copy(make_tuple(I1, I1)); + mrepeat_minus_copy(make_tuple(I0, I1)); + nrepeat_plus_copy(make_tuple(I0, I2)); + mrepeat_plus_copy(make_tuple(I1, I2)); + mrepeat_plus_copy(make_tuple(I2, I2)); + mrepeat_plus_copy(make_tuple(I3, I2)); + nrepeat_plus_copy(make_tuple(I3, I3)); + mrepeat_minus_copy(make_tuple(I2, I3)); + mrepeat_minus_copy(make_tuple(I1, I3)); + mrepeat_minus_copy(make_tuple(I0, I3)); + } + } + else if constexpr(MRepeat == 4 && NRepeat == 2) + { + init_copy(make_tuple(I0, I0)); + + if constexpr(CAccessOrderMRepeatNRepeat) + { + nrepeat_plus_copy(make_tuple(I0, I1)); + mrepeat_plus_copy(make_tuple(I1, I1)); + nrepeat_minus_copy(make_tuple(I1, I0)); + mrepeat_plus_copy(make_tuple(I2, I0)); + nrepeat_plus_copy(make_tuple(I2, I1)); + mrepeat_plus_copy(make_tuple(I3, I1)); + nrepeat_minus_copy(make_tuple(I3, I0)); + } + else + { + mrepeat_plus_copy(make_tuple(I1, I0)); + mrepeat_plus_copy(make_tuple(I2, I0)); + mrepeat_plus_copy(make_tuple(I3, I0)); + nrepeat_plus_copy(make_tuple(I3, I1)); + mrepeat_minus_copy(make_tuple(I2, I1)); + mrepeat_minus_copy(make_tuple(I1, I1)); + mrepeat_minus_copy(make_tuple(I0, I1)); + } + } + else if constexpr(MRepeat == 2 && NRepeat == 4) + { + init_copy(make_tuple(I0, I0)); + + if constexpr(CAccessOrderMRepeatNRepeat) + { + nrepeat_plus_copy(make_tuple(I0, I1)); + nrepeat_plus_copy(make_tuple(I0, I2)); + nrepeat_plus_copy(make_tuple(I0, I3)); + mrepeat_plus_copy(make_tuple(I1, I3)); + nrepeat_minus_copy(make_tuple(I1, I2)); + nrepeat_minus_copy(make_tuple(I1, I1)); + nrepeat_minus_copy(make_tuple(I1, I0)); + } + else + { + mrepeat_plus_copy(make_tuple(I1, I0)); + nrepeat_plus_copy(make_tuple(I1, I1)); + mrepeat_minus_copy(make_tuple(I0, I1)); + nrepeat_plus_copy(make_tuple(I0, I2)); + mrepeat_plus_copy(make_tuple(I1, I2)); + nrepeat_plus_copy(make_tuple(I1, I3)); + mrepeat_minus_copy(make_tuple(I0, I3)); + } + } + else if constexpr(MRepeat == 2 && NRepeat == 2) + { + init_copy(make_tuple(I0, 
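// Illustrative note: the dispatch above walks the (MRepeat, NRepeat) sub-tiles in a
// snake-like order, so consecutive writes differ by +/-1 in exactly one repeat
// coordinate and each step is a single MoveDstSliceWindow instead of a full
// destination-coordinate recomputation. A generator for the MRepeat-major variant
// (the other branch snakes over MRepeat within each NRepeat column):

#include <cstdio>

void print_snake_order(int MRepeat, int NRepeat)
{
    for(int m = 0; m < MRepeat; ++m)
    {
        for(int j = 0; j < NRepeat; ++j)
        {
            const int n = (m % 2 == 0) ? j : NRepeat - 1 - j; // reverse direction on odd rows
            std::printf("(%d, %d)\n", m, n);
        }
    }
}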
I0)); + + if constexpr(CAccessOrderMRepeatNRepeat) + { + nrepeat_plus_copy(make_tuple(I0, I1)); + mrepeat_plus_copy(make_tuple(I1, I1)); + nrepeat_minus_copy(make_tuple(I1, I0)); + } + else + { + mrepeat_plus_copy(make_tuple(I1, I0)); + nrepeat_plus_copy(make_tuple(I1, I1)); + mrepeat_minus_copy(make_tuple(I0, I1)); + } + } + else if constexpr(MRepeat == 2 && NRepeat == 1) + { + init_copy(make_tuple(I0, I0)); + mrepeat_plus_copy(make_tuple(I1, I0)); + } + else if constexpr(MRepeat == 1 && NRepeat == 2) + { + init_copy(make_tuple(I0, I0)); + nrepeat_plus_copy(make_tuple(I0, I1)); + } + else if constexpr(MRepeat == 1 && NRepeat == 1) + { + init_copy(make_tuple(I0, I0)); + } + } +#endif + } +}; // namespace ck + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp b/composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp new file mode 100644 index 0000000000..7e7bb9c8c3 --- /dev/null +++ b/composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp @@ -0,0 +1,230 @@ +#ifndef CK_THREADWISE_CONTRACTION_DLOPS_HPP +#define CK_THREADWISE_CONTRACTION_DLOPS_HPP + +#include "common_header.hpp" +#include "math.hpp" + +namespace ck { + +// C[TM0, TM1, TN0, TN1] += A[TK, TM0, TM1] * B[TK, TN0, TN1] +// Tensor element can be vectorized data +// Assume: +// 1. AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, CThreadDesc_TM0_TM1_TN0_TN1 are +// known at compile-time +// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time +template ::type = false> +struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1 +{ + __device__ constexpr ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1() + { + static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && + BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && + CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + // TODO: sanity-check: compare AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, + // CThreadDesc_TM0_TM1_TN0_TN1 Size with KLenghts, TMLengths and TNLengths + + // TODO remove this restriction + static_assert(TKLengths::Size() == 1 && TMLengths::Size() == 2 && TNLengths::Size() == 2, + "wrong!"); + } + + template + __device__ static void Run(const ABuffer& a_buf, + AOriginIdx, + const BBuffer& b_buf, + BOriginIdx, + CBuffer& c_buf, + COriginIdx) + { + static_assert( + is_known_at_compile_time>>::value && + is_known_at_compile_time>>::value && + is_known_at_compile_time>>::value, + "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); + + static_assert(is_same>, + remove_cv_t>>::value && + is_same>, + remove_cv_t>>::value && + is_same>, + remove_cv_t>>::value && + "wrong! 
inconsistent type"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + constexpr auto TK = TKLengths{}[I0]; + constexpr auto TM0 = TMLengths{}[I0]; + constexpr auto TM1 = TMLengths{}[I1]; + constexpr auto TN0 = TNLengths{}[I0]; + constexpr auto TN1 = TNLengths{}[I1]; + + constexpr auto a_origin_idx = to_multi_index(AOriginIdx{}); + constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); + constexpr auto c_origin_idx = to_multi_index(COriginIdx{}); + + static_for<0, TK, 1>{}([&](auto tk) { + static_for<0, TM0, 1>{}([&](auto tm0) { + static_for<0, TM1, 1>{}([&](auto tm1) { + static_for<0, TN0, 1>{}([&](auto tn0) { + static_for<0, TN1, 1>{}([&](auto tn1) { + constexpr index_t a_offset = + AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset( + a_origin_idx + make_multi_index(tk, tm0, tm1)); + constexpr index_t b_offset = + BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset( + b_origin_idx + make_multi_index(tk, tn0, tn1)); + constexpr index_t c_offset = + CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset( + c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1)); + + amd_inner_product_dlop( + a_buf[Number{}], + b_buf[Number{}], + c_buf(Number{})); + }); + }); + }); + }); + }); + } +}; + +// C[TM0, TM1, TN0, TN1] += A[TK0, TM0, TM1, TK1] * B[TK0, TN0, TN1, TK1] +// Tensor element can be vectorized data +// Assume: +// 1. AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, CThreadDesc_TM0_TM1_TN0_TN1 are +// known at compile-time +// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time +template ::type = false> +struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1 +{ + __device__ constexpr ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1() + { + static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && + BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && + CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + // TODO: sanity-check: compare AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, + // CThreadDesc_TM0_TM1_TN0_TN1 Size with KLenghts, TMLengths and TNLengths + + // TODO remove this restriction + static_assert(TKLengths::Size() == 2 && TMLengths::Size() == 2 && TNLengths::Size() == 2, + "wrong!"); + } + + template + __device__ static void Run(const ABuffer& a_buf, + AOriginIdx, + const BBuffer& b_buf, + BOriginIdx, + CBuffer& c_buf, + COriginIdx) + { + static_assert( + is_known_at_compile_time>>::value && + is_known_at_compile_time>>::value && + is_known_at_compile_time>>::value, + "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); + + static_assert(is_same>, + remove_cv_t>>::value && + is_same>, + remove_cv_t>>::value && + is_same>, + remove_cv_t>>::value && + "wrong! 
inconsistent type"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + constexpr index_t TK0 = TKLengths{}[I0]; + constexpr index_t TK1 = TKLengths{}[I1]; + constexpr index_t TM0 = TMLengths{}[I0]; + constexpr index_t TM1 = TMLengths{}[I1]; + constexpr index_t TN0 = TNLengths{}[I0]; + constexpr index_t TN1 = TNLengths{}[I1]; + + constexpr auto a_origin_idx = to_multi_index(AOriginIdx{}); + constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); + constexpr auto c_origin_idx = to_multi_index(COriginIdx{}); + + static_for<0, TK0, 1>{}([&](auto tk0) { + static_for<0, TM0, 1>{}([&](auto tm0) { + static_for<0, TM1, 1>{}([&](auto tm1) { + static_for<0, TN0, 1>{}([&](auto tn0) { + static_for<0, TN1, 1>{}([&](auto tn1) { + vector_type a_vec; + vector_type b_vec; + + static_for<0, TK1, 1>{}([&](auto tk1) { + constexpr index_t a_offset = + AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset( + a_origin_idx + make_multi_index(tk0, tm0, tm1, tk1)); + + constexpr index_t b_offset = + BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset( + b_origin_idx + make_multi_index(tk0, tn0, tn1, tk1)); + + a_vec.template AsType()(tk1) = a_buf[Number{}]; + b_vec.template AsType()(tk1) = b_buf[Number{}]; + }); + + using a_vector_t = typename vector_type::type; + using b_vector_t = typename vector_type::type; + + constexpr index_t c_offset = + CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset( + c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1)); + + amd_inner_product_dlop( + a_vec.template AsType()[I0], + b_vec.template AsType()[I0], + c_buf(Number{})); + }); + }); + }); + }); + }); + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_set.hpp b/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_set.hpp new file mode 100644 index 0000000000..f1b632aa84 --- /dev/null +++ b/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_set.hpp @@ -0,0 +1,59 @@ +#ifndef CK_THREADWISE_DYNAMIC_TENSOR_SET_HPP +#define CK_THREADWISE_DYNAMIC_TENSOR_SET_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" + +namespace ck { + +// Assume: +// 1. Desc is known at compile-time +// 2. Buffer is StaticBuffer +// 3. OriginIdx is known at compile-time +// 4. use #-iterator +template ::type = false> +struct ThreadwiseDynamicTensorSliceSet_v1 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + template + __device__ void Run(const Desc&, const OriginIdx&, Buffer& buf, const Data& initial_value) const + { + static_assert(Desc::IsKnownAtCompileTime(), + "wrong! SrcDesc and DstDesc need to known at compile-time"); + + static_assert(Buffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); + + static_assert(is_known_at_compile_time>>::value, + "wrong! 
OriginIdx need to be known at compile-time"); + + // Desc is known at compile-time + constexpr auto desc = remove_cv_t>{}; + + // OriginIdx is known at compile-time + constexpr auto origin_idx = to_multi_index(OriginIdx{}); + + static_ford{}([&](auto access_idx) { + constexpr auto coord = make_dynamic_tensor_coordinate(desc, origin_idx + access_idx); + + constexpr bool is_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(desc, coord); + + constexpr index_t offset = coord.GetOffset(); + + if constexpr(is_valid) + { + buf(Number{}) = initial_value; + } + }); + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp b/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp new file mode 100644 index 0000000000..9626113686 --- /dev/null +++ b/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp @@ -0,0 +1,1449 @@ +#ifndef CK_THREADWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP +#define CK_THREADWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" + +namespace ck { + +// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory +// and sometimes useless instructions: +// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument +// instead +// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same +// tensor coordinate instead +// 3. Don't use a pointer to VGPR buffer, use vector instead + +namespace detail { +// TODO: How to fix this? It uses an struct instead of lambda because lambda +// doesn't have constructor +template +struct lambda_scalar_per_access +{ + __host__ __device__ constexpr auto operator()(index_t i) const + { + return (i == VectorDim) ? ScalarPerVector : 1; + } +}; + +template +struct lambda_scalar_step_in_vector +{ + __host__ __device__ constexpr auto operator()(index_t i) const + { + return (i == VectorDim) ? 1 : 0; + } +}; +} // namespace detail + +// Assume: +// 1. src: +// 1. SrcDesc is known at compile-time +// 2. SrcBuffer is StaticBuffer +// 3. SrcSliceOrginIdx is known at compile-time +// 2. dst: +// 1. DstDesc is not known at compile-time +// 2. DstBuffer is DynamicBuffer +// 3. DstSliceOrginIdx is not known at compile time +template ::type = false> +struct ThreadwiseDynamicTensorSliceTransfer_v1r3 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using DstCoord = decltype(make_dynamic_tensor_coordinate(DstDesc{}, Index{})); + + using DstCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(DstDesc{}, Index{})); + + __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v1r3( + const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + : dst_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx)) + { + static_assert(SrcDesc::IsKnownAtCompileTime(), + "wrong! 
SrcDesc need to known at compile-time"); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void Run(const SrcDesc&, + const SrcSliceOriginIdx&, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf, + const DstIteratorHacks& dst_iterator_hacks) + { + static_assert(SrcDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc need to known at compile-time"); + + static_assert( + is_known_at_compile_time>>::value, + "wrong! SrcSliceOrigin need to known at compile-time"); + + static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer"); + + // static_assert(is_same>, + // remove_cv_t>>::value, + //"wrong! SrcBuffer data type is wrong"); + + // SrcDesc and src_slice_origin_idx are known at compile-time + constexpr auto src_desc = remove_cv_t>{}; + constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + // make forward iterators + const auto dst_forward_iterators = generate_tuple( + [&](auto i) { + Index forward_step; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; + }); + + return make_dynamic_tensor_coordinate_iterator( + dst_desc, forward_step, dst_iterator_hacks[I0][i]); + }, + Number{}); + + // make backward iterators + const auto dst_backward_iterators = generate_tuple( + [&](auto i) { + Index backward_step; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; + }); + + return make_dynamic_tensor_coordinate_iterator( + dst_desc, backward_step, dst_iterator_hacks[I1][i]); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_access_idx[I0]; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] + ? 
ordered_access_idx[i] + : ordered_access_lengths[i] - 1 - ordered_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, dim_access_order) * + dst_scalar_per_access; + }(); + + typename vector_type_maker::type dst_vector; + + using dst_vector_t = + typename vector_type_maker::type::type; + + // copy data from src_buf into dst_vector + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + constexpr index_t src_offset = src_desc.CalculateOffset( + src_slice_origin_idx + dst_data_idx + i * dst_scalar_step_in_vector); + + dst_vector.template AsType()(i) = + type_convert{}(src_buf[Number{}]); + }); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + // copy data from dst_vector into dst_buf + dst_buf.template Set( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector.template AsType()[Number<0>{}]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_dynamic_tensor_coordinate( + dst_desc, dst_coord_, dst_forward_iterators[dim_access_order[i]]); + } + else + { + move_dynamic_tensor_coordinate( + dst_desc, dst_coord_, dst_backward_iterators[dim_access_order[i]]); + } + } + }); + }); + + // move dst coordinate back to slice origin (or not) + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_iterator = + make_dynamic_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep()); + + move_dynamic_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator); + } + } + + template + __device__ void Run(const SrcDesc&, + const SrcSliceOriginIdx&, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform(); + + constexpr auto zeros = typename uniform_sequence_gen::type{}; + + constexpr auto dst_iterator_hacks = + make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), + generate_tuple([&](auto) { return zeros; }, Number{})); + + Run(SrcDesc{}, SrcSliceOriginIdx{}, src_buf, dst_desc, dst_buf, dst_iterator_hacks); + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + constexpr auto I0 = Number<0>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_access_lengths[I0] - 1; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index after last iteration in 
Run(), if it has not being reset by + // RunWrite() + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, dim_access_order) * + dst_scalar_per_access; + }(); + + // + constexpr auto reset_dst_data_step = [&]() { + Index reset_dst_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); + + return reset_dst_data_step_; + }(); + + return reset_dst_data_step; + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRun ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = + make_dynamic_tensor_coordinate_iterator(dst_desc, adjusted_step_idx); + + move_dynamic_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + private: + DstCoord dst_coord_; +}; // namespace ck + +// Assume: +// 1. src: +// 1. SrcDesc is not known at compile-time +// 2. SrcBuffer is DynamicBuffer +// 3. src_slice_origin_idx is not known at compile-time +// 2. dst: +// 1. DstDesc is known at compile-time +// 2. DstBuffer is StaticBuffer +// 3. dst_slice_origin_idx is known at compile-time +template ::type = false> +struct ThreadwiseDynamicTensorSliceTransfer_v2 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{})); + + using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{})); + + __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v2(const SrcDesc& src_desc, + const Index& src_slice_origin_idx) + : src_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx)) + { + static_assert(DstDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc need to known at compile-time"); + } + + __device__ void SetDstSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc&, + const DstSliceOriginIdx&, + DstBuffer& dst_buf, + const SrcIteratorHacks& src_iterator_hacks) + { + static_assert(DstDesc::IsKnownAtCompileTime(), + "wrong! DstDesc need to known at compile-time"); + + static_assert( + is_known_at_compile_time>>::value, + "wrong! DstSliceOrigin need to known at compile-time"); + + static_assert(is_same>, + remove_cv_t>>::value && + "wrong! 
inconsistent type"); + + // DstDesc and dst_slice_origin_idx are known at compile-time + constexpr auto dst_desc = remove_cv_t>{}; + constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{}; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + // make forward iterators + const auto src_forward_iterators = generate_tuple( + [&](auto i) { + Index forward_step; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; + }); + + return make_dynamic_tensor_coordinate_iterator( + src_desc, forward_step, src_iterator_hacks[I0][i]); + }, + Number{}); + + // make backward iterators + const auto src_backward_iterators = generate_tuple( + [&](auto i) { + Index backward_step; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; + }); + + return make_dynamic_tensor_coordinate_iterator( + src_desc, backward_step, src_iterator_hacks[I1][i]); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_access_idx[I0]; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] + ? 
ordered_access_idx[i] + : ordered_access_lengths[i] - 1 - ordered_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, dim_access_order) * + src_scalar_per_access; + }(); + + typename vector_type_maker::type src_vector; + + using src_vector_t = + typename vector_type_maker::type::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + // copy data from src_buf into src_vector + src_vector.template AsType()(Number<0>{}) = + src_buf.template Get(src_coord_.GetOffset(), is_src_valid); + + // copy data from src_vector into dst_buf + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t dst_offset = + dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx + + i * src_scalar_step_in_vector); + + dst_buf(Number{}) = src_vector.template AsType()[i]; + }); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_dynamic_tensor_coordinate( + src_desc, src_coord_, src_forward_iterators[dim_access_order[i]]); + } + else + { + move_dynamic_tensor_coordinate( + src_desc, src_coord_, src_backward_iterators[dim_access_order[i]]); + } + } + }); + }); + + // move src coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_iterator = + make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep()); + + move_dynamic_tensor_coordinate(src_desc, src_coord_, src_reset_iterator); + } + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc&, + const DstSliceOriginIdx&, + DstBuffer& dst_buf) + { + constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform(); + + constexpr auto zeros = typename uniform_sequence_gen::type{}; + + constexpr auto src_iterator_hacks = + make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), + generate_tuple([&](auto) { return zeros; }, Number{})); + + Run(src_desc, src_buf, DstDesc{}, DstSliceOriginIdx{}, dst_buf, src_iterator_hacks); + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + constexpr auto I0 = Number<0>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_access_lengths[I0] - 1; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index after last iteration in 
Run(), if it has not being reset by + // RunWrite() + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, dim_access_order) * + src_scalar_per_access; + }(); + + // + constexpr auto reset_src_data_step = [&]() { + Index reset_src_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); + + return reset_src_data_step_; + }(); + + return reset_src_data_step; + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = + make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx); + + move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + private: + SrcCoord src_coord_; +}; // namespace ck + +// Assume: +// 1. src_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +// 4. Use thread buffer +template // control whether to move back dst coordinate after each + // RunWrite(), will be fused with MoveDstSliceWindow to + // save addr computation +struct ThreadwiseDynamicTensorSliceTransfer_v3 +{ + static constexpr index_t nDim = SliceLengths::Size(); + using Index = MultiIndex; + + using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_dynamic_tensor_coordinate(DstDesc{}, Index{})); + + using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{})); + using DstCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(DstDesc{}, Index{})); + + __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v3(const SrcDesc& src_desc, + const Index& src_slice_origin, + const DstDesc& dst_desc, + const Index& dst_slice_origin) + : src_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin)), + dst_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin)) + { + // TODO: fix this + static_assert(is_same::value, + "wrong! current implementation assume SrcData and DstData are same type"); + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void RunRead(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const SrcIteratorHacks& src_iterator_hacks) + { + static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or + SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, + "wrong!"); + + static_assert(is_same>, + remove_cv_t>>::value, + "wrong! 
SrcBuffer and SrcData data type are inconsistent"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // make forward iterators + const auto src_forward_iterators = generate_tuple( + [&](auto i) { + Index forward_step; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; + }); + + return make_dynamic_tensor_coordinate_iterator( + src_desc, forward_step, src_iterator_hacks[I0][i]); + }, + Number{}); + + // make backward iterators + const auto src_backward_iterators = generate_tuple( + [&](auto i) { + Index backward_step; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; + }); + + return make_dynamic_tensor_coordinate_iterator( + src_desc, backward_step, src_iterator_hacks[I1][i]); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_idx[I0]; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? 
ordered_src_access_idx[i] + : ordered_src_access_lengths[i] - 1 - + ordered_src_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + vector_type_maker_t src_tmp_vector; + + using src_vector_t = typename decltype(src_tmp_vector)::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + // copy data from src_buf to src_tmp_vector + src_tmp_vector.template AsType()(Number<0>{}) = + src_buf.template Get(src_coord_.GetOffset(), is_src_valid); + + // copy data from src_tmp_vector to buffer_ + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t buffer_offset = + buffer_desc_.CalculateOffset(src_data_idx + i * src_scalar_step_in_vector); + + buffer_(Number{}) = src_tmp_vector.template AsType()[i]; + }); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_dynamic_tensor_coordinate( + src_desc, src_coord_, src_forward_iterators[src_dim_access_order[i]]); + } + else + { + move_dynamic_tensor_coordinate( + src_desc, src_coord_, src_backward_iterators[src_dim_access_order[i]]); + } + } + }); + }); + + // move src coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_iterator = + make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep()); + + move_dynamic_tensor_coordinate(src_desc, src_coord_, src_reset_iterator); + } + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, + DstBuffer& dst_buf, + const DstIteratorHacks& dst_iterator_hacks) + { + static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or + DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, + "wrong!"); + + static_assert(is_same>, + remove_cv_t>>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + // src scalar per access on each dim + // TODO: don't use this + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // make forward iterators + const auto dst_forward_iterators = generate_tuple( + [&](auto i) { + Index forward_step; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step(j) = (i.value == j.value) ? 
dst_scalar_per_access[i] : 0; + }); + + const auto forward_iterator = make_dynamic_tensor_coordinate_iterator( + dst_desc, forward_step, dst_iterator_hacks[I0][i]); + + return forward_iterator; + }, + Number{}); + + // make backward iterators + const auto dst_backward_iterators = generate_tuple( + [&](auto i) { + Index backward_step; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; + }); + + const auto backward_iterator = make_dynamic_tensor_coordinate_iterator( + dst_desc, backward_step, dst_iterator_hacks[I1][i]); + + return backward_iterator; + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_dst_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_idx[I0]; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i] + : ordered_dst_access_lengths[i] - 1 - + ordered_dst_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + }(); + + vector_type_maker_t dst_tmp_vector; + + // copy data from buffer_ to dst_tmp_vector + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + constexpr index_t buffer_offset = + buffer_desc_.CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector); + + dst_tmp_vector.template AsType()(i) = buffer_[Number{}]; + }); + + using dst_vector_t = typename decltype(dst_tmp_vector)::type; + + // copy data from dst_tmp_vector to dst_buf + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + dst_buf.template Set( + dst_coord_.GetOffset(), + is_dst_valid, + dst_tmp_vector.template AsType()[Number<0>{}]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_dynamic_tensor_coordinate( + dst_desc, dst_coord_, dst_forward_iterators[dst_dim_access_order[i]]); + } + else + { + move_dynamic_tensor_coordinate( + dst_desc, dst_coord_, dst_backward_iterators[dst_dim_access_order[i]]); + } + } + }); + }); + + // move dst coordinate back to slice origin (or not) + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_iterator = + make_dynamic_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep()); + + move_dynamic_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator); + } + } + + template + __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) + { + constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform(); + + constexpr auto zeros = typename uniform_sequence_gen::type{}; + + constexpr auto src_iterator_hacks = + 
make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), + generate_tuple([&](auto) { return zeros; }, Number{})); + + RunRead(src_desc, src_buf, src_iterator_hacks); + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf) + { + constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform(); + + constexpr auto zeros = typename uniform_sequence_gen::type{}; + + constexpr auto dst_iterator_hacks = + make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), + generate_tuple([&](auto) { return zeros; }, Number{})); + + RunWrite(dst_desc, dst_buf, dst_iterator_hacks); + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + constexpr auto I0 = Number<0>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_lengths[I0] - 1; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index after last iteration in RunRead(), if it has not being reset by + // RunRead() + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? 
ordered_src_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + // + constexpr auto reset_src_data_step = [&]() { + Index reset_src_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); + + return reset_src_data_step_; + }(); + + return reset_src_data_step; + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + constexpr auto I0 = Number<0>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_lengths[I0] - 1; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index after last iteration in RunWrite(), if it has not being reset by + // RunWrite() + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + }(); + + // + constexpr auto reset_dst_data_step = [&]() { + Index reset_dst_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); + + return reset_dst_data_step_; + }(); + + return reset_dst_data_step; + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = + make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx); + + move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void + MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx, + const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? 
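+ // (illustrative, assuming an identity SrcDimAccessOrder and SrcScalarPerVector == 1)
+ // GetSrcCoordinateResetStep() is the negative of the data index reached by the last
+ // access of the zig-zag sweep; when SrcResetCoordinateAfterRun is false it is added to
+ // the window step above, fusing "move back to the slice origin" with "advance the slice
+ // window" into the single coordinate move below. E.g. for SliceLengths = [2, 4] the last
+ // access lands on data index [1, 0], so the reset step is [-1, 0].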
+ const auto adjusted_step = make_dynamic_tensor_coordinate_iterator( + src_desc, adjusted_step_idx, src_move_slice_window_iterator_hack); + + move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by RunWrite(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRun ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = + make_dynamic_tensor_coordinate_iterator(dst_desc, adjusted_step_idx); + + move_dynamic_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + private: + static constexpr auto buffer_desc_ = + make_dynamic_naive_tensor_descriptor_packed_v2(sequence_to_tuple_of_number(SliceLengths{})); + + static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize(); + + StaticBuffer buffer_; + + SrcCoord src_coord_; + DstCoord dst_coord_; +}; + +// Assume: +// 1. src: +// 1. SrcDesc is known at compile-time +// 2. SrcBuffer is DynamicBuffer +// 3. src_ref_idx is known at run-time +// 4. SrcRefToOriginDisplacement is known at compile-time +// 5. use #-iterator +// 2. dst: +// 1. DstDesc is known at compile-time +// 2. DstBuffer is StaticBuffer +// 3. DstOriginIdx is known at compile-time +// 4. use direct address calculation +// 3. vector access on src +template < + typename SrcData, + typename DstData, + typename SrcDesc, + typename DstDesc, + typename SliceLengths, + typename DimAccessOrder, + index_t SrcVectorDim, + index_t SrcScalarPerVector, + index_t SrcScalarStrideInVector, + typename std::enable_if::type = false> +struct ThreadwiseDynamicTensorSliceTransfer_v4 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{})); + + using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{})); + + __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v4(const Index& src_ref_idx) + : src_ref_coord_(make_dynamic_tensor_coordinate(SrcDesc{}, src_ref_idx)) + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc and DstDesc need to known at compile-time"); + + static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, "wrong!"); + } + + template + __device__ void Run(const SrcDesc&, + const SrcRefToOriginDisplacement&, + const SrcBuffer& src_buf, + const DstDesc&, + const DstOriginIdx&, + DstBuffer& dst_buf) const + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc and DstDesc need to known at compile-time"); + + static_assert(is_same>, + remove_cv_t>>::value && + is_same>, + remove_cv_t>>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); + + static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); + + static_assert( + is_known_at_compile_time< + remove_cv_t>>::value && + is_known_at_compile_time>>::value, + "wrong! 
SrcOriginToRefDistance and DstOriginToRefDistance need to be known " + "at compile-time"); + + // SrcDesc and DstDesc are known at compile-time + constexpr auto src_desc = remove_cv_t>{}; + constexpr auto dst_desc = remove_cv_t>{}; + + // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time + constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{}); + constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{}); + + // scalar per access of each dim + constexpr auto src_scalar_per_access = generate_sequence_v2( + [&](auto i) constexpr { + if constexpr(i == SrcVectorDim) + { + return Number{}; + } + else + { + return Number<1>{}; + } + }, + Number{}); + + // scalar step (if steping on SrcVectorDim) of each dim + constexpr auto src_scalar_step_in_vector = generate_sequence_v2( + [&](auto i) constexpr { + if constexpr(i == SrcVectorDim) + { + return Number<1>{}; + } + else + { + return Number<0>{}; + } + }, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + static_ford{}([&](auto ordered_access_idx) { +#if 0 + // TODO: unable to compile + // position in slice window + constexpr auto data_to_origin_disp_idx = + container_reorder_given_old2new(ordered_access_idx, dim_access_order) * + src_scalar_per_access; +#else + // position in slice window + constexpr auto data_to_origin_disp_idx = + ordered_access_idx.ReorderGivenOld2New(dim_access_order) * src_scalar_per_access; +#endif + // src coordinate + constexpr auto src_ref_to_data_disp_idx = + src_ref_to_origin_disp_idx + data_to_origin_disp_idx; + + constexpr auto src_ref_to_data_disp_coord_iterator = + make_dynamic_tensor_coordinate_iterator(src_desc, src_ref_to_data_disp_idx); + + auto src_data_coord = src_ref_coord_; + + move_dynamic_tensor_coordinate( + src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator); + + vector_type_maker_t src_tmp_vector; + + using src_vector_t = typename decltype(src_tmp_vector)::type; + + const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( + src_desc, src_data_coord); + + // copy data from src_buf into src_tmp_vector + src_tmp_vector.template AsType()(Number<0>{}) = + src_buf.template Get(src_data_coord.GetOffset(), is_src_valid); + + // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to + // DstData) + vector_type_maker_t dst_tmp_vector; + + // TODO: if SrcData and DstData are vetor type, then static_cast may not compile + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + dst_tmp_vector.template AsType()(i) = + type_convert{}(src_tmp_vector.template AsType()[i]); + }); + + // copy data from dst_tmp_vector into dst_buf + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); + + dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; + }); + }); + } + + template + __device__ void MoveSrcSliceWindow(const SrcDesc&, + const SrcSliceMoveStepIdx& src_slice_move_step_idx) + { + constexpr auto src_desc = SrcDesc{}; + + const auto src_slice_move_step_iter = make_dynamic_tensor_coordinate_iterator( + src_desc, to_multi_index(src_slice_move_step_idx)); + + move_dynamic_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter); + } + + 
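+ // (illustrative usage sketch; a_block_desc, a_block_buf, a_thread_desc, a_thread_buf and
+ // the index values are hypothetical) the run-time reference coordinate is set once in the
+ // constructor, and each Run() addresses its slice by a compile-time displacement from it:
+ //
+ //   auto a_copy = ThreadwiseDynamicTensorSliceTransfer_v4<...>(src_ref_idx);
+ //   a_copy.Run(a_block_desc, make_tuple(I0, I0, I0),  // compile-time ref-to-origin step
+ //              a_block_buf,                           // dynamic (e.g. LDS) source buffer
+ //              a_thread_desc, make_tuple(I0, I0, I0), // compile-time dst origin
+ //              a_thread_buf);                         // static destination buffer
+ //   a_copy.MoveSrcSliceWindow(a_block_desc, make_tuple(I1, I0, I0));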
private: + SrcCoord src_ref_coord_; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer_v2.hpp b/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer_v2.hpp new file mode 100644 index 0000000000..ba60e26c38 --- /dev/null +++ b/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer_v2.hpp @@ -0,0 +1,789 @@ +#ifndef CK_THREADWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP +#define CK_THREADWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" + +namespace ck { + +// Assume: +// 1. src_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +// 4. Use thread buffer +template // control whether to move back dst coordinate after each + // RunWrite(), will be fused with MoveDstSliceWindow to + // save addr computation +struct ThreadwiseDynamicTensorSliceTransfer_v3r1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t nDim = SliceLengths::Size(); + using Index = MultiIndex; + + using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_dynamic_tensor_coordinate(DstDesc{}, Index{})); + + using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{})); + using DstCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(DstDesc{}, Index{})); + + __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v3r1(const SrcDesc& src_desc, + const Index& src_slice_origin, + const DstDesc& dst_desc, + const Index& dst_slice_origin) + : src_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin)), + dst_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin)) + { + // TODO: fix this + static_assert(is_same::value, + "wrong! current implementation assume SrcData and DstData are same type"); + + static_for<0, nDim, 1>{}([](auto i) { + static_assert(SliceLengths::At(i) % SrcVectorTensorLengths::At(i) == 0 && + SliceLengths::At(i) % DstVectorTensorLengths::At(i) == 0, + "wrong!"); + }); + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void RunRead(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const SrcIteratorHacks& src_iterator_hacks) + { + static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or + SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, + "wrong!"); + + static_assert(is_same>, + remove_cv_t>>::value, + "wrong! 
SrcBuffer and SrcData data type are inconsistent"); + + // tensor descriptor for src_vector + constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{}; + + constexpr auto src_vector_tensor_strides = container_reorder_given_old2new( + container_reverse_exclusive_scan( + container_reorder_given_new2old(src_vector_tensor_lengths, + SrcVectorTensorContiguousDimOrder{}), + math::multiplies_v2{}, + I1), + SrcVectorTensorContiguousDimOrder{}); + + constexpr auto src_vector_desc = make_dynamic_naive_tensor_descriptor_v2( + sequence_to_tuple_of_number(src_vector_tensor_lengths), + sequence_to_tuple_of_number(src_vector_tensor_strides)); + + // access order and lengths + constexpr auto src_access_lengths = SliceLengths{} / src_vector_tensor_lengths; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // make forward iterators + const auto src_forward_iterators = generate_tuple( + [&](auto i) { + Index forward_step; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step(j) = (i.value == j.value) ? src_vector_tensor_lengths[i] : 0; + }); + + return make_dynamic_tensor_coordinate_iterator( + src_desc, forward_step, src_iterator_hacks[I0][i]); + }, + Number{}); + + // make backward iterators + const auto src_backward_iterators = generate_tuple( + [&](auto i) { + Index backward_step; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step(j) = (i.value == j.value) ? -src_vector_tensor_lengths[i] : 0; + }); + + return make_dynamic_tensor_coordinate_iterator( + src_desc, backward_step, src_iterator_hacks[I1][i]); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_idx[I0]; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? 
ordered_src_access_idx[i] + : ordered_src_access_lengths[i] - 1 - + ordered_src_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_vector_tensor_lengths; + }(); + + vector_type_maker_t src_vector; + + using src_vector_t = typename decltype(src_vector)::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + // copy data from src_buf to src_vector + src_vector.template AsType()(I0) = + src_buf.template Get(src_coord_.GetOffset(), is_src_valid); + + // copy data from src_vector to buffer_ + static_ford{}([&](auto src_vector_idx_) { + constexpr auto src_vector_idx = to_multi_index(src_vector_idx_); + + constexpr index_t src_vector_offset = + src_vector_desc.CalculateOffset(src_vector_idx); + + constexpr index_t buffer_offset = + buffer_desc_.CalculateOffset(src_data_idx + src_vector_idx); + + buffer_(Number{}) = + src_vector.template AsType()[Number{}]; + }); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_dynamic_tensor_coordinate( + src_desc, src_coord_, src_forward_iterators[src_dim_access_order[i]]); + } + else + { + move_dynamic_tensor_coordinate( + src_desc, src_coord_, src_backward_iterators[src_dim_access_order[i]]); + } + } + }); + }); + + // move src coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_iterator = + make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep()); + + move_dynamic_tensor_coordinate(src_desc, src_coord_, src_reset_iterator); + } + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, + DstBuffer& dst_buf, + const DstIteratorHacks& dst_iterator_hacks) + { + static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or + DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, + "wrong!"); + + static_assert(is_same>, + remove_cv_t>>::value, + "wrong! 
SrcBuffer or DstBuffer data type is wrong"); + + // tensor descriptor for dst_vector + constexpr auto dst_vector_tensor_lengths = DstVectorTensorLengths{}; + + constexpr auto dst_vector_tensor_strides = container_reorder_given_old2new( + container_reverse_exclusive_scan( + container_reorder_given_new2old(dst_vector_tensor_lengths, + DstVectorTensorContiguousDimOrder{}), + math::multiplies_v2{}, + I1), + DstVectorTensorContiguousDimOrder{}); + + constexpr auto dst_vector_desc = make_dynamic_naive_tensor_descriptor_v2( + sequence_to_tuple_of_number(dst_vector_tensor_lengths), + sequence_to_tuple_of_number(dst_vector_tensor_strides)); + + // dst access order and lengths + constexpr auto dst_access_lengths = SliceLengths{} / dst_vector_tensor_lengths; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // make forward iterators + const auto dst_forward_iterators = generate_tuple( + [&](auto i) { + Index forward_step; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step(j) = (i.value == j.value) ? dst_vector_tensor_lengths[i] : 0; + }); + + const auto forward_iterator = make_dynamic_tensor_coordinate_iterator( + dst_desc, forward_step, dst_iterator_hacks[I0][i]); + + return forward_iterator; + }, + Number{}); + + // make backward iterators + const auto dst_backward_iterators = generate_tuple( + [&](auto i) { + Index backward_step; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step(j) = (i.value == j.value) ? -dst_vector_tensor_lengths[i] : 0; + }); + + const auto backward_iterator = make_dynamic_tensor_coordinate_iterator( + dst_desc, backward_step, dst_iterator_hacks[I1][i]); + + return backward_iterator; + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_dst_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_idx[I0]; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? 
ordered_dst_access_idx[i] + : ordered_dst_access_lengths[i] - 1 - + ordered_dst_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_vector_tensor_lengths; + }(); + + vector_type_maker_t dst_vector; + + // copy data from buffer_ to dst_vector (also cast from SrcData to DstData) + static_ford{}([&](auto dst_vector_idx_) { + constexpr auto dst_vector_idx = to_multi_index(dst_vector_idx_); + + constexpr index_t buffer_offset = + buffer_desc_.CalculateOffset(dst_data_idx + dst_vector_idx); + + constexpr index_t dst_vector_offset = + dst_vector_desc.CalculateOffset(dst_vector_idx); + + dst_vector.template AsType()(Number{}) = + type_convert{}(buffer_[Number{}]); + }); + + using dst_vector_t = typename decltype(dst_vector)::type; + + // copy data from dst_vector to dst_buf + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + dst_buf.template Set( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector.template AsType()[Number<0>{}]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_dynamic_tensor_coordinate( + dst_desc, dst_coord_, dst_forward_iterators[dst_dim_access_order[i]]); + } + else + { + move_dynamic_tensor_coordinate( + dst_desc, dst_coord_, dst_backward_iterators[dst_dim_access_order[i]]); + } + } + }); + }); + + // move dst coordinate back to slice origin (or not) + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_iterator = + make_dynamic_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep()); + + move_dynamic_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator); + } + } + + template + __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) + { + constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform(); + + constexpr auto zeros = typename uniform_sequence_gen::type{}; + + constexpr auto src_iterator_hacks = + make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), + generate_tuple([&](auto) { return zeros; }, Number{})); + + RunRead(src_desc, src_buf, src_iterator_hacks); + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf) + { + constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform(); + + constexpr auto zeros = typename uniform_sequence_gen::type{}; + + constexpr auto dst_iterator_hacks = + make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), + generate_tuple([&](auto) { return zeros; }, Number{})); + + RunWrite(dst_desc, dst_buf, dst_iterator_hacks); + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{}; + + constexpr auto src_access_lengths = SliceLengths{} / src_vector_tensor_lengths; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + 
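// Sweep direction per dimension: the walk over the access grid is a snake /
+            // zig-zag order, so dimension i moves forward when the running linearization
+            // of the lower-order access indices (in access order) is even and backward
+            // when it is odd. For the reset step this parity is evaluated at the final
+            // access position, i.e. with every ordered index at its access length - 1.
+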
StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_lengths[I0] - 1; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index after last iteration in RunRead(), if it has not being reset by + // RunRead() + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_vector_tensor_lengths; + }(); + + // + constexpr auto reset_src_data_step = [&]() { + Index reset_src_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); + + return reset_src_data_step_; + }(); + + return reset_src_data_step; + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + constexpr auto dst_vector_tensor_lengths = DstVectorTensorLengths{}; + + constexpr auto dst_access_lengths = SliceLengths{} / dst_vector_tensor_lengths; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_lengths[I0] - 1; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index after last iteration in RunWrite(), if it has not being reset by + // RunWrite() + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_vector_tensor_lengths; + }(); + + // + constexpr auto reset_dst_data_step = [&]() { + Index reset_dst_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); + + return reset_dst_data_step_; + }(); + + return reset_dst_data_step; + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? 
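+        // Constructing the iterator on every call only derives the per-transform
+        // index deltas for this particular step; it should be folded away by the
+        // compiler as long as the step index is known at compile time, as the
+        // comment above this function requires.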
+ const auto adjusted_step = + make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx); + + move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void + MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx, + const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_dynamic_tensor_coordinate_iterator( + src_desc, adjusted_step_idx, src_move_slice_window_iterator_hack); + + move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by RunWrite(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRun ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = + make_dynamic_tensor_coordinate_iterator(dst_desc, adjusted_step_idx); + + move_dynamic_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + private: + static constexpr auto buffer_desc_ = + make_dynamic_naive_tensor_descriptor_packed_v2(sequence_to_tuple_of_number(SliceLengths{})); + + static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize(); + + StaticBuffer buffer_; + + SrcCoord src_coord_; + DstCoord dst_coord_; +}; + +// Assume: +// 1. src: +// 1. SrcDesc is known at compile-time +// 2. SrcBuffer is DynamicBuffer +// 3. src_ref_idx is known at run-time +// 4. SrcRefToOriginDisplacement is known at compile-time +// 5. use #-iterator +// 2. dst: +// 1. DstDesc is known at compile-time +// 2. DstBuffer is StaticBuffer +// 3. DstOriginIdx is known at compile-time +// 4. use direct address calculation +// 3. vector access on src +template < + typename SrcData, + typename DstData, + typename SrcDesc, + typename DstDesc, + typename SliceLengths, + typename DimAccessOrder, + typename SrcVectorTensorLengths, + typename SrcVectorTensorContiguousDimOrder, + typename std::enable_if::type = false> +struct ThreadwiseDynamicTensorSliceTransfer_v4r1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{})); + + using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{})); + + __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v4r1(const Index& src_ref_idx) + : src_ref_coord_(make_dynamic_tensor_coordinate(SrcDesc{}, src_ref_idx)) + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! 
SrcDesc and DstDesc need to known at compile-time"); + + static_for<0, nDim, 1>{}([](auto i) { + static_assert(SliceLengths::At(i) % SrcVectorTensorLengths::At(i) == 0, "wrong!"); + }); + } + + template + __device__ void Run(const SrcDesc&, + const SrcRefToOriginDisplacement&, + const SrcBuffer& src_buf, + const DstDesc&, + const DstOriginIdx&, + DstBuffer& dst_buf) const + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc and DstDesc need to known at compile-time"); + + static_assert(is_same>, + remove_cv_t>>::value && + is_same>, + remove_cv_t>>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); + + static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); + + static_assert( + is_known_at_compile_time< + remove_cv_t>>::value && + is_known_at_compile_time>>::value, + "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known " + "at compile-time"); + + // SrcDesc and DstDesc are known at compile-time + constexpr auto src_desc = remove_cv_t>{}; + constexpr auto dst_desc = remove_cv_t>{}; + + // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time + constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{}); + constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{}); + + // tensor descriptor for src_vector + constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{}; + + constexpr auto src_vector_tensor_strides = container_reorder_given_old2new( + container_reverse_exclusive_scan( + container_reorder_given_new2old(src_vector_tensor_lengths, + SrcVectorTensorContiguousDimOrder{}), + math::multiplies_v2{}, + I1), + SrcVectorTensorContiguousDimOrder{}); + + constexpr auto src_vector_desc = make_dynamic_naive_tensor_descriptor_v2( + sequence_to_tuple_of_number(src_vector_tensor_lengths), + sequence_to_tuple_of_number(src_vector_tensor_strides)); + + // access order and lengths + constexpr auto access_lengths = SliceLengths{} / src_vector_tensor_lengths; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + static_ford{}([&](auto ordered_access_idx) { + // position in slice window + constexpr auto data_to_origin_disp_idx = + ordered_access_idx.ReorderGivenOld2New(dim_access_order) * + src_vector_tensor_lengths; + + // src coordinate at starting point of src_vector + constexpr auto src_ref_to_data_disp_idx = + src_ref_to_origin_disp_idx + data_to_origin_disp_idx; + + constexpr auto src_ref_to_data_disp_coord_iterator = + make_dynamic_tensor_coordinate_iterator(src_desc, src_ref_to_data_disp_idx); + + auto src_data_coord = src_ref_coord_; + + move_dynamic_tensor_coordinate( + src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator); + + vector_type_maker_t src_vector; + + using src_vector_t = typename decltype(src_vector)::type; + + const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( + src_desc, src_data_coord); + + // copy data from src_buf into src_vector + src_vector.template AsType()(I0) = + src_buf.template Get(src_data_coord.GetOffset(), is_src_valid); + + // copy data from src_vector into dst_buf (also cast from SrcData to DstData) + static_ford{}([&](auto src_vector_idx_) { + constexpr auto src_vector_idx = to_multi_index(src_vector_idx_); + + constexpr index_t src_vector_offset = + src_vector_desc.CalculateOffset(src_vector_idx); + + constexpr index_t 
dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + src_vector_idx); + + dst_buf(Number{}) = type_convert{}( + src_vector.template AsType()[Number{}]); + }); + }); + } + + template + __device__ void MoveSrcSliceWindow(const SrcDesc&, + const SrcSliceMoveStepIdx& src_slice_move_step_idx) + { + constexpr auto src_desc = SrcDesc{}; + + const auto src_slice_move_step_iter = make_dynamic_tensor_coordinate_iterator( + src_desc, to_multi_index(src_slice_move_step_idx)); + + move_dynamic_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter); + } + + private: + SrcCoord src_ref_coord_; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp b/composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp new file mode 100644 index 0000000000..153d512df7 --- /dev/null +++ b/composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp @@ -0,0 +1,162 @@ +#ifndef CK_THREADWISE_GEMM_DLOPS_V3_HPP +#define CK_THREADWISE_GEMM_DLOPS_V3_HPP + +#include "common_header.hpp" +#include "math.hpp" + +namespace ck { + +// C[M, N] += transpose(A[K, M]) * B[K, N] +// Element of matrix can be vectorized data +// Assume: +// 1. ADesc, BDesc, CDesc are known at compile-time +// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time +template ::type = false> +struct ThreadwiseGemmDlops_km_kn_mn_v3 +{ + template + __device__ static void Run(const ABuffer& a_buf, + AOriginIdx, + const BBuffer& b_buf, + BOriginIdx, + CBuffer& c_buf, + COriginIdx) + { + static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() && + CDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert( + is_known_at_compile_time>>::value && + is_known_at_compile_time>>::value && + is_known_at_compile_time>>::value, + "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); + + static_assert(is_same>, + remove_cv_t>>::value && + is_same>, + remove_cv_t>>::value && + is_same>, + remove_cv_t>>::value && + "wrong! 
inconsistent type"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto E = ADesc{}.GetLength(I0); + constexpr auto K = ADesc{}.GetLength(I1); + + constexpr auto a_origin_idx = to_multi_index(AOriginIdx{}); + constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); + constexpr auto c_origin_idx = to_multi_index(COriginIdx{}); + + static_for<0, E, 1>{}([&](auto e) { + static_for<0, K, 1>{}([&](auto k) { + constexpr index_t a_offset = + ADesc{}.CalculateOffset(a_origin_idx + make_tuple(e, k)); + + if constexpr(H == 2 && W == 2) + { + constexpr index_t b_offset_0 = + BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 0)); + constexpr index_t b_offset_1 = + BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 1)); + constexpr index_t b_offset_2 = + BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 0)); + constexpr index_t b_offset_3 = + BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 1)); + + constexpr index_t c_offset_0 = + CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 0)); + constexpr index_t c_offset_1 = + CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 1)); + constexpr index_t c_offset_2 = + CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 0)); + constexpr index_t c_offset_3 = + CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 1)); + + amd_assembly_outer_product_1x4(a_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + c_buf(Number{}), + c_buf(Number{}), + c_buf(Number{}), + c_buf(Number{})); + } + else if constexpr(H == 4 && W == 1) + { + constexpr index_t b_offset_0 = + BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 0)); + constexpr index_t b_offset_1 = + BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 0)); + constexpr index_t b_offset_2 = + BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 2, 0)); + constexpr index_t b_offset_3 = + BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 3, 0)); + + constexpr index_t c_offset_0 = + CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 0)); + constexpr index_t c_offset_1 = + CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 0)); + constexpr index_t c_offset_2 = + CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 2, 0)); + constexpr index_t c_offset_3 = + CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 3, 0)); + + amd_assembly_outer_product_1x4(a_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + c_buf(Number{}), + c_buf(Number{}), + c_buf(Number{}), + c_buf(Number{})); + } + else + { + static_for<0, H, 1>{}([&](auto h) { + static_for<0, W, 1>{}([&](auto w) { + constexpr index_t b_offset = + BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, h, w)); + + constexpr index_t c_offset = + CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, h, w)); + +#if 0 + c_buf(Number{}) += inner_product_with_conversion{}( + a_buf[Number{}], b_buf[Number{}]); +#else + amd_assembly_inner_product(a_buf[Number{}], + b_buf[Number{}], + c_buf(Number{})); +#endif + }); + }); + } + }); + }); + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp new file mode 100644 index 0000000000..876a1174e7 --- /dev/null +++ b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp @@ -0,0 +1,801 @@ +#ifndef CK_XDLOPS_GEMM_HPP 
+#define CK_XDLOPS_GEMM_HPP + +#include "common_header.hpp" +#include "math.hpp" +#include "amd_xdlops.hpp" + +namespace ck { + +enum struct mfma_instr +{ + /// fp32 + mfma_f32_32x32x1xf32 = 0, + mfma_f32_16x16x1xf32, + mfma_f32_4x4x1xf32, + mfma_f32_32x32x2xf32, // k reduction + mfma_f32_16x16x4xf32, // k reduction + /// fp16 + mfma_f32_32x32x4f16, + mfma_f32_16x16x4f16, + mfma_f32_4x4x4f16, + mfma_f32_32x32x8f16, // k reduction + mfma_f32_16x16x16f16, // k reduction + /// bfp16 + mfma_f32_32x32x2bf16, + mfma_f32_16x16x2bf16, + mfma_f32_4x4x2bf16, + mfma_f32_32x32x4bf16, // k reduction + mfma_f32_16x16x8bf16, // k reduction +}; + +template +struct mfma_info; + +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 4; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 2; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 32; + static constexpr index_t n = 32; + static constexpr index_t k = 1; + static constexpr index_t cycles = 64; + static constexpr index_t k_base = 1; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x1f32::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 4; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 1; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 32; + static constexpr index_t n = 32; + static constexpr index_t k = 2; + static constexpr index_t cycles = 64; + static constexpr index_t k_base = 1; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x2f32::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 1; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 1; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 16; + static constexpr index_t n = 16; + static constexpr index_t k = 4; + static constexpr index_t cycles = 32; + static constexpr index_t k_base = 1; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x4f32::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 1; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 4; + static 
constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 16; + static constexpr index_t n = 16; + static constexpr index_t k = 1; + static constexpr index_t cycles = 32; + static constexpr index_t k_base = 1; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x1f32::Run(a, b, reg_c); + } +}; + +// treat 4x4x1 as a single-blk 4x64 mfma +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 1; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 64; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 1; + static constexpr index_t num_output_blks = 1; + static constexpr index_t num_regs_xdlops = 4; + static constexpr index_t m = 4; + static constexpr index_t n = 64; + static constexpr index_t k = 1; + static constexpr index_t cycles = 8; + static constexpr index_t k_base = 1; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_4x4x1f32::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 4; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 2; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 32; + static constexpr index_t n = 32; + static constexpr index_t k = 4; + static constexpr index_t cycles = 64; + static constexpr index_t k_base = 4; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x4f16::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 4; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 1; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 32; + static constexpr index_t n = 32; + static constexpr index_t k = 8; + static constexpr index_t cycles = 64; + static constexpr index_t k_base = 4; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x8f16::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 1; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 1; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 16; + static constexpr index_t n = 16; + static constexpr index_t k = 16; + static constexpr index_t cycles = 32; + static constexpr index_t k_base = 4; + + template + __device__ void 
run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x16f16::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 1; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 4; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 16; + static constexpr index_t n = 16; + static constexpr index_t k = 4; + static constexpr index_t cycles = 32; + static constexpr index_t k_base = 4; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x4f16::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 1; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 64; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 1; + static constexpr index_t num_output_blks = 1; + static constexpr index_t num_regs_xdlops = 4; + static constexpr index_t m = 4; + static constexpr index_t n = 64; + static constexpr index_t k = 4; + static constexpr index_t cycles = 8; + static constexpr index_t k_base = 4; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_4x4x4f16::Run(a, b, reg_c); + } +}; + +#if 0 +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 4; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 2; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 32; + static constexpr index_t n = 32; + static constexpr index_t k = 2; + static constexpr index_t cycles = 64; + static constexpr index_t k_base = 2; + + template + __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const + { + const auto p_a = reinterpret_cast(a); + const auto p_b = reinterpret_cast(b); + + return intrin_mfma_f32_32x32x2bf16::run( + p_a, p_b, reg_c); + } +}; + +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 4; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 1; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 32; + static constexpr index_t n = 32; + static constexpr index_t k = 4; + static constexpr index_t cycles = 64; + static constexpr index_t k_base = 2; + + template + __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const + { + const auto p_a = reinterpret_cast(a); + const auto p_b = reinterpret_cast(b); + + return intrin_mfma_f32_32x32x4bf16(p_a, p_b, reg_c); + } +}; + +template <> 
+struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 1; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 1; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 16; + static constexpr index_t n = 16; + static constexpr index_t k = 8; + static constexpr index_t cycles = 32; + static constexpr index_t k_base = 2; + + template + __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const + { + const auto p_a = reinterpret_cast(a); + const auto p_b = reinterpret_cast(b); + + return intrin_mfma_f32_16x16x8bf16(p_a, p_b, reg_c); + } +}; + +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 1; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 4; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 16; + static constexpr index_t n = 16; + static constexpr index_t k = 2; + static constexpr index_t cycles = 32; + static constexpr index_t k_base = 2; + + template + __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const + { + const auto p_a = reinterpret_cast(a); + const auto p_b = reinterpret_cast(b); + + return intrin_mfma_f32_16x16x2bf16(p_a, p_b, reg_c); + } +}; + +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 1; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 64; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 1; + static constexpr index_t num_output_blks = 1; + static constexpr index_t num_regs_xdlops = 4; + static constexpr index_t m = 4; + static constexpr index_t n = 64; + static constexpr index_t k = 2; + static constexpr index_t cycles = 8; + static constexpr index_t k_base = 2; + + template + __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const + { + const auto p_a = reinterpret_cast(a); + const auto p_b = reinterpret_cast(b); + + return intrin_mfma_f32_4x4x2bf16::run(p_a, p_b, reg_c); + } +}; +#endif + +template +struct xdlops_info +{ + static constexpr auto mfma_type = mfma_info{}; + + static constexpr index_t MPerXdlops = MPerXdlops_; + static constexpr index_t NPerXdlops = NPerXdlops_; + + static constexpr bool IsABroadcast() + { + static_assert(NPerXdlops >= MPerXdlops, "only support ABroadcast"); + return true; + } + + static constexpr bool IsKReduction() + { + return (mfma_type.num_output_blks == 1) && (mfma_type.num_input_blks > 1); + } + + static constexpr index_t GetKPerXdlops() + { + return IsKReduction() ? 
mfma_type.num_input_blks : 1; + } + + static constexpr index_t GetNumCRegs() { return MPerXdlops * NPerXdlops / mfma_type.wave_size; } +}; + +template +struct XdlopsGemm +{ + template + static constexpr auto GetXdlopsInfo(); + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + +#if 0 + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } +#endif + + using CIndex = MultiIndex<2>; + + __device__ static constexpr index_t GetNumBlks() { return mfma_type.num_output_blks; } + + __device__ static constexpr index_t GetNumXdlops() + { + return MPerXdlops * NPerXdlops / (mfma_type.m * mfma_type.n * mfma_type.num_output_blks); + } + + __host__ __device__ constexpr XdlopsGemm() + { + static_assert(NPerXdlops == 4 || NPerXdlops == 8 || NPerXdlops == 16 || NPerXdlops == 32 || + NPerXdlops == 64, + "Only support GemmNPerXdlops == 4, 8, 16, 32 or 64 for xdlops"); + + static_assert(MPerXdlops == 4 || MPerXdlops == 8 || MPerXdlops == 16 || MPerXdlops == 32 || + MPerXdlops == 64, + "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops"); + + static_assert(mfma_type.num_threads_blk == mfma_type.n, "n != num_threads_blk"); + static_assert(mfma_type.num_regs_blk * mfma_type.num_input_blks == mfma_type.m, + "m != num_input_blks * num_regs_blk"); + static_assert(mfma_type.num_output_blks == mfma_type.num_input_blks || + mfma_type.num_output_blks == 1, + "incorrect num_output_blks"); + static_assert(mfma_type.num_regs_blk * mfma_type.wave_size == 
mfma_type.m * mfma_type.n, + "num_regs_blk incorrect"); + + static_assert(mfma_type.k % mfma_type.k_base == 0, "k % kbase != 0!"); + } + + __device__ static constexpr index_t GetRegSizePerXdlops() + { + return MPerXdlops * NPerXdlops / mfma_type.wave_size; + } + + template + __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const + { + static_assert(is_same::value || is_same::value || + is_same::value, + "base base_type must be float, half, ushort!"); + + static_assert(KPack % mfma_type.k_base == 0, "KPack cannot be divided by k_base"); + + constexpr index_t c_offset = CDesc{}.CalculateOffset(make_tuple(m0, n0)) * GetNumXdlops(); + + static_for<0, KPack, mfma_type.k_base>{}([&](auto k) { + constexpr index_t a_offset = ADesc{}.CalculateOffset(make_tuple(0, m0, 0, k)); + constexpr index_t b_offset = BDesc{}.CalculateOffset(make_tuple(0, n0, 0, k)); + + mfma_type.template run( + p_a_wave[Number{}], + p_b_wave[Number{}], + p_c_thread); + }); + } + + __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i) + { + const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size; + const index_t blk_id = laneId / mfma_type.num_threads_blk; + const index_t blk_td = laneId % mfma_type.num_threads_blk; + + index_t n_offset = blk_i * mfma_type.n + blk_td; + index_t m_offset = xdlops_i * mfma_type.m + blk_id * mfma_type.group_size; + + return CIndex{m_offset, n_offset}; + } + + static constexpr index_t MRepeats = GetXdlopsInfo().MRepeats; + static constexpr index_t NRepeats = GetXdlopsInfo().NRepeats; + static constexpr index_t MPerXdlops = GetXdlopsInfo().MPerXdlops; + static constexpr index_t NPerXdlops = GetXdlopsInfo().NPerXdlops; + + static constexpr bool IsKReduction = GetXdlopsInfo().IsKReduction(); + static constexpr bool IsABroadcast = GetXdlopsInfo().IsABroadcast(); + static constexpr index_t KPerXdlops = GetXdlopsInfo().GetKPerXdlops(); + + static constexpr auto GetBlkId(const index_t lane_id) + { + return lane_id / mfma_type.num_threads_blk; + } + + static constexpr auto GetBlkTd(const index_t lane_id) + { + return lane_id % mfma_type.num_threads_blk; + } + + static constexpr auto mfma_type = GetXdlopsInfo().mfma_type; + + struct CLayout + { + __host__ __device__ static constexpr index_t M1() { return mfma_type.num_groups_blk; } + __host__ __device__ static constexpr index_t M0() { return mfma_type.group_size; } + __host__ __device__ static constexpr index_t N1() { return mfma_type.num_input_blks; } + __host__ __device__ static constexpr index_t N0() { return mfma_type.num_threads_blk; } + + __device__ static constexpr index_t GetBlkSize() { return mfma_type.num_regs_blk; } + + __device__ static constexpr index_t GetNumBlks() { return mfma_type.num_output_blks; } + + __device__ static constexpr index_t GetNumXdlops() + { + return MPerXdlops * NPerXdlops / + (mfma_type.m * mfma_type.n * mfma_type.num_output_blks); + } + }; + + __host__ __device__ static constexpr auto GetCLayout() { return CLayout{}; } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/amd_buffer_addressing_v2.hpp b/composable_kernel/include/utility/amd_buffer_addressing_v2.hpp new file mode 100644 index 0000000000..0139bceb61 --- /dev/null +++ b/composable_kernel/include/utility/amd_buffer_addressing_v2.hpp @@ -0,0 +1,654 @@ +#ifndef CK_AMD_BUFFER_ADDRESSING_V2_HPP +#define CK_AMD_BUFFER_ADDRESSING_V2_HPP + +#include "data_type.hpp" + +namespace ck { + +template +union BufferResource_v2 +{ + // 128 bit SGPRs to supply buffer 
resource in buffer instructions + // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions + int32x4_t data; + StaticallyIndexedArray address; + StaticallyIndexedArray range; + StaticallyIndexedArray config; +}; + +template +__device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t data_space_size) +{ + BufferResource_v2 wave_buffer_resource; + + // wavewise base address (64 bit) + wave_buffer_resource.address(Number<0>{}) = const_cast*>(p_wave); + // wavewise range (32 bit) + wave_buffer_resource.range(Number<2>{}) = data_space_size * sizeof(T); + // wavewise setting (32 bit) + wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD; + + return wave_buffer_resource.data; +} + +// load +__device__ int8_t +llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8"); + +__device__ int8x2_t +llvm_amdgcn_raw_buffer_load_i8x2(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i8"); + +__device__ int8x4_t +llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8"); + +__device__ int16_t +llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32"); +__device__ int32_t +llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32"); + +__device__ int32x2_t +llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32"); + +__device__ int32x4_t +llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32"); +// half +__device__ half_t +llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16"); + +__device__ half2_t +llvm_amdgcn_raw_buffer_load_fp16x2(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16"); + +__device__ half4_t +llvm_amdgcn_raw_buffer_load_fp16x4(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16"); + +// float +__device__ float +llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32"); + +__device__ float2_t +llvm_amdgcn_raw_buffer_load_fp32x2(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f32"); + +__device__ float4_t +llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32"); + +// store +__device__ void +llvm_amdgcn_raw_buffer_store_i8(int8_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8"); + +__device__ void +llvm_amdgcn_raw_buffer_store_i8x2(int8x2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i8"); + +__device__ void +llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t 
soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8"); + +__device__ void +llvm_amdgcn_raw_buffer_store_i16(int16_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16"); + +__device__ void +llvm_amdgcn_raw_buffer_store_i32(int32_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32"); + +__device__ void +llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32"); + +__device__ void +llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32"); + +// half +__device__ void +llvm_amdgcn_raw_buffer_store_fp16(half_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16"); + +__device__ void +llvm_amdgcn_raw_buffer_store_fp16x2(half2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16"); + +__device__ void +llvm_amdgcn_raw_buffer_store_fp16x4(half4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16"); +// float +__device__ void +llvm_amdgcn_raw_buffer_store_fp32(float vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32"); + +__device__ void +llvm_amdgcn_raw_buffer_store_fp32x2(float2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32"); + +__device__ void +llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32"); + +template +__device__ typename vector_type::type +amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, + index_t src_thread_addr_offset, + index_t src_wave_addr_offset) +{ + static_assert( + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)), + "wrong! 
not implemented"); + + if constexpr(is_same::value) + { + if constexpr(N == 1) + { + return llvm_amdgcn_raw_buffer_load_fp32( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 2) + { + return llvm_amdgcn_raw_buffer_load_fp32x2( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 4) + { + return llvm_amdgcn_raw_buffer_load_fp32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 8) + { + vector_type tmp; + + tmp.AsType()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_fp32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + tmp.AsType()(Number<1>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(float), + 0); + + return tmp.AsType()(Number<0>{}); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + return llvm_amdgcn_raw_buffer_load_fp16( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 2) + { + return llvm_amdgcn_raw_buffer_load_fp16x2( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 4) + { + return llvm_amdgcn_raw_buffer_load_fp16x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 8) + { +#if 0 + vector_type tmp; + + tmp.AsType()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_fp16x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + tmp.AsType()(Number<1>{}) = + llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(half_t), + 0); + + return tmp.AsType()(Number<0>{}); +#else + float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + return as_type(tmp); +#endif + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + return llvm_amdgcn_raw_buffer_load_i32( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 2) + { + return llvm_amdgcn_raw_buffer_load_i32x2( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 4) + { + return llvm_amdgcn_raw_buffer_load_i32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 8) + { + vector_type tmp; + + tmp.AsType()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + tmp.AsType()(Number<1>{}) = + llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(int32_t), + 0); + return tmp.AsType()(Number<0>{}); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + return llvm_amdgcn_raw_buffer_load_i8( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 2) + { +#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE + return llvm_amdgcn_raw_buffer_load_i8x2( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); +#else + int16_t tmp = llvm_amdgcn_raw_buffer_load_i16( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + return as_type(tmp); +#endif + } + else if 
constexpr(N == 4) + { +#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE + return llvm_amdgcn_raw_buffer_load_i8x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); +#else + int32_t tmp = llvm_amdgcn_raw_buffer_load_i32( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + return as_type(tmp); +#endif + } + else if constexpr(N == 8) + { +#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE + vector_type tmp; + + tmp.AsType()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i8x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + tmp.AsType()(Number<1>{}) = + llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(int8_t), + 0); + + return tmp.AsType()(Number<0>{}); +#else + int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + return as_type(tmp); +#endif + } + else if constexpr(N == 16) + { +#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE + vector_type tmp; + + tmp.AsType()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i8x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + tmp.AsType()(Number<1>{}) = + llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(int8_t), + 0); + + tmp.AsType()(Number<2>{}) = + llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 8 * sizeof(int8_t), + 0); + + tmp.AsType()(Number<3>{}) = + llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 12 * sizeof(int8_t), + 0); + + return tmp.AsType()(Number<0>{}); +#else + int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + return as_type(tmp); +#endif + } + } +} + +template +__device__ void amd_buffer_store_impl_v2(const typename vector_type::type src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) +{ + static_assert( + (is_same::value && (N == 1 || N == 2 || N == 4)) || + (is_same::value && (N == 1 || N == 2 || N == 4)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)), + "wrong! 
not implemented"); + + if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_fp32(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_store_fp32x2(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 4) + { + llvm_amdgcn_raw_buffer_store_fp32x4(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_i32(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_store_i32x2(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 4) + { + llvm_amdgcn_raw_buffer_store_i32x4(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_i8(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { +#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE + llvm_amdgcn_raw_buffer_store_i8x2(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); +#else + llvm_amdgcn_raw_buffer_store_i16(as_type(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); +#endif + } + else if constexpr(N == 4) + { +#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE + llvm_amdgcn_raw_buffer_store_i8x4(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); +#else + llvm_amdgcn_raw_buffer_store_i32(as_type(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); +#endif + } + else if constexpr(N == 8) + { + llvm_amdgcn_raw_buffer_store_i32x2(as_type(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 16) + { + llvm_amdgcn_raw_buffer_store_i32x4(as_type(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_fp16(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 4) + { + llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 8) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 4 * sizeof(half_t), + 0); + } + } +} + +// buffer_load requires: +// 1) p_src_wave must be in global memory space +// 2) 
p_src_wave to be a wavewise pointer. +// It is user's responsibility to make sure that is true. +template +__device__ typename vector_type_maker::type::type +amd_buffer_load_v2(const T* p_src_wave, + index_t src_thread_data_offset, + bool src_thread_data_valid, + index_t src_element_space) +{ + const int32x4_t src_wave_buffer_resource = + make_wave_buffer_resource(p_src_wave, src_element_space); + + index_t src_thread_addr_offset = src_thread_data_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + constexpr index_t vector_size = scalar_type::vector_size; + +#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK + uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff; + + return amd_buffer_load_impl_v2( + src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); +#else + vector_t tmp = amd_buffer_load_impl_v2( + src_wave_buffer_resource, src_thread_addr_offset, 0); + + return src_thread_data_valid ? tmp : vector_t(0); +#endif +} + +// buffer_store requires: +// 1) p_dst_wave must be global memory +// 2) p_dst_wave to be a wavewise pointer. +// It is user's responsibility to make sure that is true. +template +__device__ void +amd_buffer_store_v2(const typename vector_type_maker::type::type src_thread_data, + T* p_dst_wave, + const index_t dst_thread_data_offset, + const bool dst_thread_data_valid, + const index_t dst_element_space) +{ + const int32x4_t dst_wave_buffer_resource = + make_wave_buffer_resource(p_dst_wave, dst_element_space); + + index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + constexpr index_t vector_size = scalar_type::vector_size; + +#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK + uint32_t dst_addr_shift = dst_thread_data_valid ? 
0 : 0x7fffffff; + + amd_buffer_store_impl_v2( + src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); +#else + if(dst_thread_data_valid) + { + amd_buffer_store_impl_v2( + src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + } +#endif +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/amd_dlop.hpp b/composable_kernel/include/utility/amd_dlop.hpp new file mode 100644 index 0000000000..8ce19012e9 --- /dev/null +++ b/composable_kernel/include/utility/amd_dlop.hpp @@ -0,0 +1,188 @@ +#ifndef CK_AMD_DLOP_HPP +#define CK_AMD_DLOP_HPP + +#include "data_type.hpp" + +namespace ck { + +template +__device__ void amd_inner_product_dlop(const TA& a, const TB& b, TC& c); + +template <> +__device__ void +amd_inner_product_dlop(const float& a, const float& b, float& c) +{ +#if CK_USE_AMD_DLOP_INLINE_ASM + asm volatile("\n \ + v_fmac_f32 %0, %1, %2 \n \ + " + : "=v"(c) + : "v"(a), "v"(b), "0"(c)); +#else + c += a * b; +#endif +} + +template <> +__device__ void +amd_inner_product_dlop(const float2_t& a, const float2_t& b, float& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + amd_inner_product_dlop(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + amd_inner_product_dlop(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); +} + +template <> +__device__ void +amd_inner_product_dlop(const float4_t& a, const float4_t& b, float& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + amd_inner_product_dlop(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + amd_inner_product_dlop(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); + + amd_inner_product_dlop(vector_type{a}.AsType()[I2], + vector_type{b}.AsType()[I2], + c); + + amd_inner_product_dlop(vector_type{a}.AsType()[I3], + vector_type{b}.AsType()[I3], + c); +} + +#if CK_USE_AMD_DLOP +template <> +__device__ void +amd_inner_product_dlop(const half2_t& a, const half2_t& b, float& c) +{ +#if CK_USE_AMD_DLOP_INLINE_ASM + asm volatile("\n \ + v_dot2_f32_f16 %0, %1, %2, %0\n \ + " + : "=v"(c) + : "v"(a), "v"(b), "0"(c)); +#else + c = __builtin_amdgcn_sdot2(a, b, c, false); +#endif +} + +template <> +__device__ void +amd_inner_product_dlop(const half4_t& a, const half4_t& b, float& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + amd_inner_product_dlop(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + amd_inner_product_dlop(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); +} + +template <> +__device__ void +amd_inner_product_dlop(const half8_t& a, const half8_t& b, float& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + amd_inner_product_dlop(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + amd_inner_product_dlop(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); + + amd_inner_product_dlop(vector_type{a}.AsType()[I2], + vector_type{b}.AsType()[I2], + c); + + amd_inner_product_dlop(vector_type{a}.AsType()[I3], + vector_type{b}.AsType()[I3], + c); +} + +template <> +__device__ void amd_inner_product_dlop(const int8x4_t& a, + const int8x4_t& b, + int32_t& c) +{ +#if CK_USE_AMD_DLOP_INLINE_ASM + asm volatile("\n \ + v_dot4_i32_i8 %0, %1, %2, %0\n \ + " + : "=v"(c) + : "v"(as_type(a)), 
"v"(as_type(b)), "0"(c)); +#else + c = __builtin_amdgcn_sdot4(as_type(a), as_type(b), c, false); +#endif +} + +template <> +__device__ void amd_inner_product_dlop(const int8x8_t& a, + const int8x8_t& b, + int32_t& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + amd_inner_product_dlop(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + amd_inner_product_dlop(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); +} + +template <> +__device__ void amd_inner_product_dlop(const int8x16_t& a, + const int8x16_t& b, + int32_t& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + amd_inner_product_dlop(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + amd_inner_product_dlop(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); + + amd_inner_product_dlop(vector_type{a}.AsType()[I2], + vector_type{b}.AsType()[I2], + c); + + amd_inner_product_dlop(vector_type{a}.AsType()[I3], + vector_type{b}.AsType()[I3], + c); +} +#endif // CK_USE_AMD_DLOP + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/amd_inline_asm.hpp b/composable_kernel/include/utility/amd_inline_asm.hpp new file mode 100644 index 0000000000..ce80fc0549 --- /dev/null +++ b/composable_kernel/include/utility/amd_inline_asm.hpp @@ -0,0 +1,353 @@ +#ifndef CK_AMD_INLINE_ASM_HPP +#define CK_AMD_INLINE_ASM_HPP + +#include "data_type.hpp" + +namespace ck { + +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1) +{ + asm volatile("\n \ + v_fmac_f32 %0, %2, %3 \n \ + v_fmac_f32 %1, %2, %4 \n \ + " + : "=v"(c0), "=v"(c1) + : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1)); +} + +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +// c2 += inner_product(a, b2) +// c3 += inner_product(a, b3) +__device__ void amd_assembly_outer_product_1x4( + float a, float b0, float b1, float b2, float b3, float& c0, float& c1, float& c2, float& c3) +{ + asm volatile("\n \ + v_fmac_f32 %0, %4, %5 \n \ + v_fmac_f32 %1, %4, %6 \n \ + v_fmac_f32 %2, %4, %7 \n \ + v_fmac_f32 %3, %4, %8 \n \ + " + : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) + : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3)); +} + +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +__device__ void +amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, float& c1) +{ + asm volatile("\n \ + v_dot2_f32_f16 %0, %2, %3, %0\n \ + v_dot2_f32_f16 %1, %2, %4, %1\n \ + " + : "=v"(c0), "=v"(c1) + : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1)); +} + +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +__device__ void +amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1) +{ + // TODO remove pointer casting + const half2_t* p_a_half2 = reinterpret_cast(&a); + const half2_t* p_b0_half2 = reinterpret_cast(&b0); + const half2_t* p_b1_half2 = reinterpret_cast(&b1); + + // do dot2 two times + asm volatile("\n \ + v_dot2_f32_f16 %0, %2, %4, %0\n \ + v_dot2_f32_f16 %1, %2, %6, %1\n \ + v_dot2_f32_f16 %0, %3, %5, %0\n \ + v_dot2_f32_f16 %1, %3, %7, %1\n \ + " + : "=v"(c0), "=v"(c1) + : "v"(p_a_half2[0]), + "v"(p_a_half2[1]), + "v"(p_b0_half2[0]), + "v"(p_b0_half2[1]), + "v"(p_b1_half2[0]), + "v"(p_b1_half2[1]), + "0"(c0), + "1"(c1)); +} + +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +// 
c2 += inner_product(a, b2) +// c3 += inner_product(a, b3) +__device__ void amd_assembly_outer_product_1x4(half2_t a, + half2_t b0, + half2_t b1, + half2_t b2, + half2_t b3, + float& c0, + float& c1, + float& c2, + float& c3) +{ + asm volatile("\n \ + v_dot2_f32_f16 %0, %4, %5, %0\n \ + v_dot2_f32_f16 %1, %4, %6, %1\n \ + v_dot2_f32_f16 %2, %4, %7, %2\n \ + v_dot2_f32_f16 %3, %4, %8, %3\n \ + " + : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) + : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3)); +} + +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +// c2 += inner_product(a, b2) +// c3 += inner_product(a, b3) +__device__ void amd_assembly_outer_product_1x4(half4_t a, + half4_t b0, + half4_t b1, + half4_t b2, + half4_t b3, + float& c0, + float& c1, + float& c2, + float& c3) +{ + // TODO remove pointer casting + const half2_t* p_a_half2 = reinterpret_cast(&a); + const half2_t* p_b0_half2 = reinterpret_cast(&b0); + const half2_t* p_b1_half2 = reinterpret_cast(&b1); + const half2_t* p_b2_half2 = reinterpret_cast(&b2); + const half2_t* p_b3_half2 = reinterpret_cast(&b3); + + // do dot2 two times + asm volatile("\n \ + v_dot2_f32_f16 %0, %4, %6, %0\n \ + v_dot2_f32_f16 %1, %4, %8, %1\n \ + v_dot2_f32_f16 %2, %4, %10, %2\n \ + v_dot2_f32_f16 %3, %4, %12, %3\n \ + v_dot2_f32_f16 %0, %5, %7, %0\n \ + v_dot2_f32_f16 %1, %5, %9, %1\n \ + v_dot2_f32_f16 %2, %5, %11, %2\n \ + v_dot2_f32_f16 %3, %5, %13, %3\n \ + " + : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) + : "v"(p_a_half2[0]), + "v"(p_a_half2[1]), + "v"(p_b0_half2[0]), + "v"(p_b0_half2[1]), + "v"(p_b1_half2[0]), + "v"(p_b1_half2[1]), + "v"(p_b2_half2[0]), + "v"(p_b2_half2[1]), + "v"(p_b3_half2[0]), + "v"(p_b3_half2[1]), + "0"(c0), + "1"(c1), + "2"(c2), + "3"(c3)); +} + +__device__ void amd_assembly_outer_product_1x4(half8_t a, + half8_t b0, + half8_t b1, + half8_t b2, + half8_t b3, + float& c0, + float& c1, + float& c2, + float& c3) +{ + + // TODO remove pointer casting + const half4_t* p_a_half4 = reinterpret_cast(&a); + const half4_t* p_b0_half4 = reinterpret_cast(&b0); + const half4_t* p_b1_half4 = reinterpret_cast(&b1); + const half4_t* p_b2_half4 = reinterpret_cast(&b2); + const half4_t* p_b3_half4 = reinterpret_cast(&b3); + + amd_assembly_outer_product_1x4( + p_a_half4[0], p_b0_half4[0], p_b1_half4[0], p_b2_half4[0], p_b3_half4[0], c0, c1, c2, c3); + + amd_assembly_outer_product_1x4( + p_a_half4[1], p_b0_half4[1], p_b1_half4[1], p_b2_half4[1], p_b3_half4[1], c0, c1, c2, c3); +} + +__device__ void amd_assembly_outer_product_1x4(half16_t a, + half16_t b0, + half16_t b1, + half16_t b2, + half16_t b3, + float& c0, + float& c1, + float& c2, + float& c3) +{ + // TODO remove pointer casting + const half8_t* p_a_half8 = reinterpret_cast(&a); + const half8_t* p_b0_half8 = reinterpret_cast(&b0); + const half8_t* p_b1_half8 = reinterpret_cast(&b1); + const half8_t* p_b2_half8 = reinterpret_cast(&b2); + const half8_t* p_b3_half8 = reinterpret_cast(&b3); + + amd_assembly_outer_product_1x4( + p_a_half8[0], p_b0_half8[0], p_b1_half8[0], p_b2_half8[0], p_b3_half8[0], c0, c1, c2, c3); + + amd_assembly_outer_product_1x4( + p_a_half8[1], p_b0_half8[1], p_b1_half8[1], p_b2_half8[1], p_b3_half8[1], c0, c1, c2, c3); +} + +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +__device__ void +amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0, int32_t& c1) +{ +#if 1 + asm volatile("\n \ + v_dot4_i32_i8 %0, %2, %3, %0\n \ + v_dot4_i32_i8 %1, %2, %4, %1\n \ + " + : "=v"(c0), "=v"(c1) + : 
"v"(as_type(a)), + "v"(as_type(b0)), + "v"(as_type(b1)), + "0"(c0), + "1"(c1)); +#else + c0 = __builtin_amdgcn_sdot4(as_type(a), as_type(b0), c0, false); + c1 = __builtin_amdgcn_sdot4(as_type(a), as_type(b1), c1, false); +#endif +} + +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +// c2 += inner_product(a, b2) +// c3 += inner_product(a, b3) +__device__ void amd_assembly_outer_product_1x4(int8x4_t a, + int8x4_t b0, + int8x4_t b1, + int8x4_t b2, + int8x4_t b3, + int32_t& c0, + int32_t& c1, + int32_t& c2, + int32_t& c3) +{ +#if 1 + asm volatile("\n \ + v_dot4_i32_i8 %0, %4, %5, %0\n \ + v_dot4_i32_i8 %1, %4, %6, %1\n \ + v_dot4_i32_i8 %2, %4, %7, %2\n \ + v_dot4_i32_i8 %3, %4, %8, %3\n \ + " + : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) + : "v"(as_type(a)), + "v"(as_type(b0)), + "v"(as_type(b1)), + "v"(as_type(b2)), + "v"(as_type(b3)), + "0"(c0), + "1"(c1), + "2"(c2), + "3"(c3)); +#else + c0 = __builtin_amdgcn_sdot4(as_type(a), as_type(b0), c0, false); + c1 = __builtin_amdgcn_sdot4(as_type(a), as_type(b1), c1, false); + c2 = __builtin_amdgcn_sdot4(as_type(a), as_type(b2), c2, false); + c3 = __builtin_amdgcn_sdot4(as_type(a), as_type(b3), c3, false); +#endif +} + +__device__ void amd_assembly_outer_product_1x4(int8x8_t a, + int8x8_t b0, + int8x8_t b1, + int8x8_t b2, + int8x8_t b3, + int32_t& c0, + int32_t& c1, + int32_t& c2, + int32_t& c3) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + amd_assembly_outer_product_1x4(vector_type{a}.AsType()[I0], + vector_type{b0}.AsType()[I0], + vector_type{b1}.AsType()[I0], + vector_type{b2}.AsType()[I0], + vector_type{b3}.AsType()[I0], + c0, + c1, + c2, + c3); + + amd_assembly_outer_product_1x4(vector_type{a}.AsType()[I1], + vector_type{b0}.AsType()[I1], + vector_type{b1}.AsType()[I1], + vector_type{b2}.AsType()[I1], + vector_type{b3}.AsType()[I1], + c0, + c1, + c2, + c3); +} + +__device__ void amd_assembly_outer_product_1x4(int8x16_t a, + int8x16_t b0, + int8x16_t b1, + int8x16_t b2, + int8x16_t b3, + int32_t& c0, + int32_t& c1, + int32_t& c2, + int32_t& c3) + +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + amd_assembly_outer_product_1x4(vector_type{a}.AsType()[I0], + vector_type{b0}.AsType()[I0], + vector_type{b1}.AsType()[I0], + vector_type{b2}.AsType()[I0], + vector_type{b3}.AsType()[I0], + c0, + c1, + c2, + c3); + + amd_assembly_outer_product_1x4(vector_type{a}.AsType()[I1], + vector_type{b0}.AsType()[I1], + vector_type{b1}.AsType()[I1], + vector_type{b2}.AsType()[I1], + vector_type{b3}.AsType()[I1], + c0, + c1, + c2, + c3); + + amd_assembly_outer_product_1x4(vector_type{a}.AsType()[I2], + vector_type{b0}.AsType()[I2], + vector_type{b1}.AsType()[I2], + vector_type{b2}.AsType()[I2], + vector_type{b3}.AsType()[I2], + c0, + c1, + c2, + c3); + + amd_assembly_outer_product_1x4(vector_type{a}.AsType()[I3], + vector_type{b0}.AsType()[I3], + vector_type{b1}.AsType()[I3], + vector_type{b2}.AsType()[I3], + vector_type{b3}.AsType()[I3], + c0, + c1, + c2, + c3); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/amd_llvm_intrinsic.hpp b/composable_kernel/include/utility/amd_llvm_intrinsic.hpp new file mode 100644 index 0000000000..841d48f81c --- /dev/null +++ b/composable_kernel/include/utility/amd_llvm_intrinsic.hpp @@ -0,0 +1,11 @@ +#ifndef CK_AMD_LLVM_INTRINSIC_HPP +#define CK_AMD_LLVM_INTRINSIC_HPP + +#include "data_type.hpp" + +namespace ck { + +__device__ int32_t 
llvm_amdgcn_readfirstlane_i32(int32_t i) __asm("llvm.amdgcn.readfirstlane"); + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/amd_xdlops.hpp b/composable_kernel/include/utility/amd_xdlops.hpp new file mode 100644 index 0000000000..da74fe1d48 --- /dev/null +++ b/composable_kernel/include/utility/amd_xdlops.hpp @@ -0,0 +1,499 @@ +#ifndef CK_AMD_XDLOPS_HPP +#define CK_AMD_XDLOPS_HPP + +#include "data_type.hpp" + +namespace ck { + +// A, B, C, cbsz, abid, blgp +extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x1f32( + float, float, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x1f32"); + +extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x2f32( + float, float, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2f32"); + +extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x4f32( + float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x4f32"); + +extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x1f32( + float, float, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x1f32"); + +extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x1f32( + float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x1f32"); + +extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x4f16( + half4_t, half4_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4f16"); + +extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x8f16( + half4_t, half4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x8f16"); + +extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x16f16( + half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x16f16"); + +extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x4f16( + half4_t, half4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x4f16"); + +extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x4f16( + half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x4f16"); + +extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x2bf16( + ushort2_t, ushort2_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2bf16"); + +extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x4bf16( + ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4bf16"); + +extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x8bf16( + ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x8bf16"); + +extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16( + ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x2bf16"); + +extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16( + ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16"); + +template +struct intrin_mfma_f32_32x32x1f32; + +template +struct intrin_mfma_f32_32x32x1f32<64, 64, COffset> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_32x32x1f32( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 1, + 0, + 0); + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_32x32x1f32( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 1, + 1, + 0); + } +}; + +template +struct intrin_mfma_f32_32x32x1f32<32, 64, 
COffset> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_32x32x1f32( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 1, + 0, + 0); + } +}; + +template +struct intrin_mfma_f32_32x32x2f32; + +template +struct intrin_mfma_f32_32x32x2f32<32, 32, COffset> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_32x32x2f32( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 0, + 0, + 0); + } +}; + +template +struct intrin_mfma_f32_16x16x4f32; + +template +struct intrin_mfma_f32_16x16x4f32<16, 16, COffset> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_16x16x4f32( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 0, + 0, + 0); + } +}; + +template +struct intrin_mfma_f32_16x16x1f32; + +template +struct intrin_mfma_f32_16x16x1f32<16, 64, COffset> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_16x16x1f32( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 2, + 0, + 0); + } +}; + +template +struct intrin_mfma_f32_4x4x1f32; + +template +struct intrin_mfma_f32_4x4x1f32<4, 64, COffset> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_4x4x1f32( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 4, + 0, + 0); + } +}; + +template +struct intrin_mfma_f32_4x4x1f32<8, 64, COffset> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_4x4x1f32( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 4, + 0, + 0); + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_4x4x1f32( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 4, + 1, + 0); + } +}; + +template +struct intrin_mfma_f32_32x32x4f16; + +template +struct intrin_mfma_f32_32x32x4f16<64, 64, COffset> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_32x32x4f16( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 1, + 0, + 0); + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_32x32x4f16( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 1, + 1, + 0); + } +}; + +template +struct intrin_mfma_f32_32x32x4f16<32, 64, COffset> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_32x32x4f16( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 1, + 0, + 0); + } +}; + +template +struct intrin_mfma_f32_32x32x8f16; + +template +struct intrin_mfma_f32_32x32x8f16<32, 32, COffset> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& 
reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_32x32x8f16( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 0, + 0, + 0); + } +}; + +template +struct intrin_mfma_f32_16x16x16f16; + +template +struct intrin_mfma_f32_16x16x16f16<16, 16, COffset> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_16x16x16f16( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 0, + 0, + 0); + } +}; + +template +struct intrin_mfma_f32_16x16x4f16; + +template +struct intrin_mfma_f32_16x16x4f16<16, 64, COffset> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_16x16x4f16( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 2, + 0, + 0); + } +}; + +template +struct intrin_mfma_f32_4x4x4f16; + +template +struct intrin_mfma_f32_4x4x4f16<4, 64, COffset> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_4x4x4f16( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 4, + 0, + 0); + } +}; + +template +struct intrin_mfma_f32_4x4x4f16<8, 64, COffset> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_4x4x4f16( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 4, + 0, + 0); + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_4x4x4f16( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 4, + 1, + 0); + } +}; + +#if 0 +template +struct intrin_mfma_f32_32x32x2bf16; + +template +struct intrin_mfma_f32_32x32x2bf16<128, 64, AStride, BStride> +{ + __device__ static c_vec32_4_t::VecType + run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c) + { + reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0); + reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0); + + reg_c.s.z = + llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0); + reg_c.s.w = + llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0); + + return reg_c; + } +}; + +template +struct intrin_mfma_f32_32x32x2bf16<64, 128, AStride, BStride> +{ + __device__ static c_vec32_4_t::VecType + run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c) + { + reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0); + reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0); + + reg_c.s.z = + llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0); + reg_c.s.w = + llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0); + + return reg_c; + } +}; + +template +struct intrin_mfma_f32_32x32x2bf16<64, 64, AStride, BStride> +{ + __device__ static c_vec32_2_t::VecType + run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_2_t::VecType reg_c) + { + reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0); + reg_c.s.y = 
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0); + + return reg_c; + } +}; + +template +struct intrin_mfma_f32_32x32x2bf16<64, 32, AStride, BStride> +{ + __device__ static c_vec32_1_t::VecType + run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c) + { + reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1); + + return reg_c; + } +}; + +template +struct intrin_mfma_f32_32x32x2bf16<32, 64, AStride, BStride> +{ + __device__ static c_vec32_1_t::VecType + run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c) + { + reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0); + return reg_c; + } +}; + +__device__ c_vec16_1_t::VecType intrin_mfma_f32_32x32x4bf16(const ushort2_t* reg_a, + const ushort2_t* reg_b, + c_vec16_1_t::VecType reg_c) +{ + reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0); + return reg_c; +} + +__device__ c_vec4_1_t::VecType intrin_mfma_f32_16x16x8bf16(const ushort2_t* reg_a, + const ushort2_t* reg_b, + c_vec4_1_t::VecType reg_c) +{ + reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0); + return reg_c; +} + +template +__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16(const ushort2_t* reg_a, + const ushort2_t* reg_b, + c_vec16_1_t::VecType reg_c); + +template <> +__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t* reg_a, + const ushort2_t* reg_b, + c_vec16_1_t::VecType reg_c) +{ + reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 2, 0, 0); + return reg_c; +} + +template <> +__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<64, 16>(const ushort2_t* reg_a, + const ushort2_t* reg_b, + c_vec16_1_t::VecType reg_c) +{ + reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 4); + return reg_c; +} + +template +struct intrin_mfma_f32_4x4x2bf16; + +template <> +struct intrin_mfma_f32_4x4x2bf16<4, 64> +{ + __device__ static c_vec4_1_t::VecType + run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_1_t::VecType reg_c) + { + reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0); + return reg_c; + } +}; + +template <> +struct intrin_mfma_f32_4x4x2bf16<8, 64> +{ + __device__ static c_vec4_2_t::VecType + run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_2_t::VecType reg_c) + { + reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0); + reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 4, 1, 0); + return reg_c; + } +}; + +#endif + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/array.hpp b/composable_kernel/include/utility/array.hpp new file mode 100644 index 0000000000..7271094d39 --- /dev/null +++ b/composable_kernel/include/utility/array.hpp @@ -0,0 +1,63 @@ +#ifndef CK_ARRAY_HPP +#define CK_ARRAY_HPP + +#include "functional2.hpp" +#include "sequence.hpp" + +namespace ck { + +template +struct Array +{ + using type = Array; + using data_type = TData; + + TData mData[NSize]; + + __host__ __device__ static constexpr index_t Size() { return NSize; } + + __host__ __device__ constexpr const TData& At(index_t i) const { return mData[i]; } + + __host__ __device__ constexpr TData& At(index_t i) { return mData[i]; } + + __host__ __device__ constexpr const TData& operator[](index_t i) const { return At(i); } 
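+    // Accessor convention (a descriptive note, inferred from the surrounding code): At(i)
+    // performs the underlying indexing, operator[] is the const (read) accessor, and
+    // operator() is the mutable (write) accessor; ContainerElementPicker in
+    // container_element_picker.hpp follows the same pattern.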
+ + __host__ __device__ constexpr TData& operator()(index_t i) { return At(i); } + + template + __host__ __device__ constexpr auto operator=(const T& a) + { + static_assert(T::Size() == Size(), "wrong! size not the same"); + + static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; }); + + return *this; + } +}; + +// empty Array +template +struct Array +{ + using type = Array; + using data_type = TData; + + __host__ __device__ static constexpr index_t Size() { return 0; } +}; + +template +__host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs) +{ + using data_type = remove_cv_t>; + return Array{{std::forward(x), std::forward(xs)...}}; +} + +// make empty array +template +__host__ __device__ constexpr auto make_array() +{ + return Array{}; +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/array_multi_index.hpp b/composable_kernel/include/utility/array_multi_index.hpp new file mode 100644 index 0000000000..f692fb5143 --- /dev/null +++ b/composable_kernel/include/utility/array_multi_index.hpp @@ -0,0 +1,77 @@ +#ifndef CK_ARRAY_MULTI_INDEX_HPP +#define CK_ARRAY_MULTI_INDEX_HPP + +#include "common_header.hpp" + +namespace ck { + +template +using MultiIndex = Array; + +template +__host__ __device__ constexpr auto make_multi_index(Xs&&... xs) +{ + return make_array(index_t{xs}...); +} + +template +__host__ __device__ constexpr auto make_zero_multi_index() +{ + return unpack([](auto... xs) { return make_multi_index(xs...); }, + typename uniform_sequence_gen::type{}); +} + +template +__host__ __device__ constexpr auto to_multi_index(const T& x) +{ + return unpack([](auto... ys) { return make_multi_index(ys...); }, x); +} + +template +__host__ __device__ constexpr auto operator+=(MultiIndex& y, const X& x) +{ + static_assert(X::Size() == NSize, "wrong! size not the same"); + static_for<0, NSize, 1>{}([&](auto i) { y(i) += x[i]; }); + return y; +} + +template +__host__ __device__ constexpr auto operator-=(MultiIndex& y, const X& x) +{ + static_assert(X::Size() == NSize, "wrong! size not the same"); + static_for<0, NSize, 1>{}([&](auto i) { y(i) -= x[i]; }); + return y; +} + +template +__host__ __device__ constexpr auto operator+(const MultiIndex& a, const T& b) +{ + using type = MultiIndex; + static_assert(T::Size() == NSize, "wrong! size not the same"); + type r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = a[i] + b[i]; }); + return r; +} + +template +__host__ __device__ constexpr auto operator-(const MultiIndex& a, const T& b) +{ + using type = MultiIndex; + static_assert(T::Size() == NSize, "wrong! size not the same"); + type r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = a[i] - b[i]; }); + return r; +} + +template +__host__ __device__ constexpr auto operator*(const MultiIndex& a, const T& b) +{ + using type = MultiIndex; + static_assert(T::Size() == NSize, "wrong! 
size not the same"); + type r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = a[i] * b[i]; }); + return r; +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/common_header.hpp b/composable_kernel/include/utility/common_header.hpp new file mode 100644 index 0000000000..5ff7688a1c --- /dev/null +++ b/composable_kernel/include/utility/common_header.hpp @@ -0,0 +1,45 @@ +#ifndef CK_COMMON_HEADER_HPP +#define CK_COMMON_HEADER_HPP + +#include "config.hpp" +#include "array.hpp" +#include "container_helper.hpp" +#include "statically_indexed_array.hpp" +#include "container_element_picker.hpp" +#include "multi_index.hpp" +#include "data_type_enum.hpp" +#include "data_type.hpp" +#include "data_type_helper.hpp" +#include "functional.hpp" +#include "functional2.hpp" +#include "functional3.hpp" +#include "functional4.hpp" +#include "integral_constant.hpp" +#include "math.hpp" +#include "number.hpp" +#include "sequence.hpp" +#include "sequence_helper.hpp" +#include "synchronization.hpp" +#include "tuple.hpp" +#include "tuple_helper.hpp" +#include "type.hpp" +#include "utility.hpp" +#include "magic_division.hpp" +#include "amd_buffer_addressing_v2.hpp" +#include "static_buffer.hpp" +#include "dynamic_buffer.hpp" + +// TODO: remove this +#if CK_USE_AMD_INLINE_ASM +#include "amd_inline_asm.hpp" +#endif + +#if CK_USE_AMD_DLOP +#include "amd_dlop.hpp" +#endif + +#if CK_USE_AMD_XDLOPS +#include "amd_xdlops.hpp" +#endif + +#endif diff --git a/composable_kernel/include/utility/config.hpp b/composable_kernel/include/utility/config.hpp new file mode 100644 index 0000000000..4908d8d818 --- /dev/null +++ b/composable_kernel/include/utility/config.hpp @@ -0,0 +1,142 @@ +#ifndef CK_CONFIG_AMD_HPP +#define CK_CONFIG_AMD_HPP + +#ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS +#include "hip/hip_runtime.h" +#include "hip/hip_fp16.h" +#endif +#include "bfloat16_dev.hpp" + +// address space for kernel parameter +#define CONSTANT __attribute__((address_space(4))) + +// GPU target +// should enable one and only one GPU target +#if !(defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \ + defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A) || defined(CK_AMD_GPU_GFX1030)) +#error Need to define a single GPU target +#endif + +// HIP version +#ifndef CK_HIP_VERSION_FLAT +#define CK_HIP_VERSION_FLAT 0 +#endif + +// launch bounds +#define CK_USE_LAUNCH_BOUNDS 1 + +#ifdef CK_USE_LAUNCH_BOUNDS +#define CK_MAX_THREAD_PER_BLOCK 256 +#define CK_MIN_BLOCK_PER_CU 2 +#endif + +// buffer resourse +#if defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \ + defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A) +#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 +#elif defined(CK_AMD_GPU_GFX1030) +#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 +#endif + +// multi index +#define CK_USE_DYNAMICALLY_INDEXED_MULTI_INDEX 0 + +// AMD inline asm +#ifndef CK_USE_AMD_INLINE_ASM +#define CK_USE_AMD_INLINE_ASM 1 +#endif + +// AMD DLOPS +#ifndef CK_USE_AMD_DLOP +#define CK_USE_AMD_DLOP 1 +#endif + +#ifndef CK_USE_AMD_DLOP_INLINE_ASM +#define CK_USE_AMD_DLOP_INLINE_ASM 1 +#endif + +// AMD buffer addressing +#ifndef CK_USE_AMD_BUFFER_ADDRESSING +#define CK_USE_AMD_BUFFER_ADDRESSING 1 +#endif + +// only gfx908 support native floating point atomic add +#ifndef CK_USE_AMD_BUFFER_ATOMIC_FADD +#define CK_USE_AMD_BUFFER_ATOMIC_FADD 0 +#endif + +// AMD XDLOPS +#ifndef CK_USE_AMD_XDLOPS +#define CK_USE_AMD_XDLOPS 0 +#endif + +// block synchronization only 
s_wait lgkmcnt(0), not vmcnt(0) +#ifndef CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM +#define CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1 +#endif + +// experimental implementation +#ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0 +#endif + +#ifndef CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK +#define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1 +#endif + +#ifndef CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK +#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK 1 +#endif + +// pass tensor descriptor by value or void* +#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 0 +#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 1 + +// merge transformation use magic number division +#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 0 + +// hack: have underlying assumption that need to be satsified, otherwise it's a bug +// hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be +// thread-invariant, otherwise it's a bug +// TODO: separate index calculation into "compile-time", "global", "block", "wave", "thread" +#ifndef CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE +#define CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0 +#endif + +// workaround for compiler crash when compiling recursive lambda +#ifndef CK_WORKAROUND_SWDEV_275126 +#define CK_WORKAROUND_SWDEV_275126 1 +#endif + +// workaround for compiler crash when using buffer load/store for i8 +#ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE +#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE 1 +#endif + +// workaround for compiler crash when using buffer load/store for i8 +#ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE +#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1 +#endif + +namespace ck { + +enum AddressSpaceEnum_t +{ + Generic, + Global, + Lds, + Sgpr, + Vgpr +}; + +enum InMemoryDataOperationEnum_t +{ + Set, + AtomicAdd +}; + +// index type +using index_t = int32_t; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/container_element_picker.hpp b/composable_kernel/include/utility/container_element_picker.hpp new file mode 100644 index 0000000000..54915125ac --- /dev/null +++ b/composable_kernel/include/utility/container_element_picker.hpp @@ -0,0 +1,155 @@ +#ifndef CK_CONTAINER_ELEMENT_PICKER_HPP +#define CK_CONTAINER_ELEMENT_PICKER_HPP + +#include "functional2.hpp" +#include "sequence.hpp" + +namespace ck { + +// Arr: Array or StaticallyIndexedArray +// Picks: Sequence<...> +template +struct ContainerElementPicker +{ + using type = ContainerElementPicker; +#if 0 + using data_type = typename Arr::data_type; +#endif + + __host__ __device__ constexpr ContainerElementPicker() = delete; + + __host__ __device__ constexpr ContainerElementPicker(Arr& array) : mArray{array} + { + constexpr index_t imax = + reduce_on_sequence(Picks{}, math::maximize{}, Number<0>{}); + + static_assert(imax < Arr::Size(), "wrong! 
exceeding # array element"); + } + + __host__ __device__ static constexpr auto Size() { return Picks::Size(); } + + template + __host__ __device__ constexpr const auto& At(Number i) const + { + static_assert(I < Size(), "wrong!"); + + constexpr auto IP = Picks{}[i]; + return mArray[IP]; + } + + template + __host__ __device__ constexpr auto& At(Number i) + { + static_assert(I < Size(), "wrong!"); + + constexpr auto IP = Picks{}[i]; + return mArray(IP); + } + + template + __host__ __device__ constexpr const auto& operator[](Number i) const + { + return At(i); + } + + template + __host__ __device__ constexpr auto& operator()(Number i) + { + return At(i); + } + + template + __host__ __device__ constexpr auto operator=(const T& a) + { + static_assert(T::Size() == Size(), "wrong! size not the same"); + + static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; }); + + return *this; + } + + private: + Arr& mArray; +}; + +// Arr: Array or StaticallyIndexedArray +// Picks: Sequence<...> +template +struct ConstantContainerElementPicker +{ + using type = ConstantContainerElementPicker; +#if 0 + using data_type = typename Arr::data_type; +#endif + + __host__ __device__ constexpr ConstantContainerElementPicker() = delete; + + __host__ __device__ constexpr ConstantContainerElementPicker(const Arr& array) : mArray{array} + { + constexpr index_t imax = + reduce_on_sequence(Picks{}, math::maximize{}, Number<0>{}); + + static_assert(imax < Arr::Size(), "wrong! exceeding # array element"); + } + + __host__ __device__ static constexpr auto Size() { return Picks::Size(); } + + template + __host__ __device__ constexpr const auto& At(Number i) const + { + static_assert(I < Size(), "wrong!"); + + constexpr auto IP = Picks{}[i]; + return mArray[IP]; + } + + template + __host__ __device__ constexpr const auto& operator[](Number i) const + { + return At(i); + } + + private: + const Arr& mArray; +}; + +template +__host__ __device__ constexpr auto operator+=(ContainerElementPicker& y, const X& x) +{ + using Y = ContainerElementPicker; + constexpr index_t nsize = Y::Size(); + + static_assert(nsize == X::Size(), "wrong! size not the same"); + + static_for<0, nsize, 1>{}([&](auto i) { y(i) += x[i]; }); + + return y; +} + +template +__host__ __device__ constexpr auto operator-=(ContainerElementPicker& y, const X& x) +{ + using Y = ContainerElementPicker; + constexpr index_t nsize = Y::Size(); + + static_assert(nsize == X::Size(), "wrong! 
size not the same"); + + static_for<0, nsize, 1>{}([&](auto i) { y(i) -= x[i]; }); + + return y; +} + +template +__host__ __device__ constexpr auto pick_container_element(Arr& a, Picks) +{ + return ContainerElementPicker(a); +} + +template +__host__ __device__ constexpr auto pick_container_element(const Arr& a, Picks) +{ + return ConstantContainerElementPicker(a); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/container_helper.hpp b/composable_kernel/include/utility/container_helper.hpp new file mode 100644 index 0000000000..a7ed8ec059 --- /dev/null +++ b/composable_kernel/include/utility/container_helper.hpp @@ -0,0 +1,403 @@ +#ifndef CK_CONTAINER_HELPER_HPP +#define CK_CONTAINER_HELPER_HPP + +#include "sequence.hpp" +#include "sequence_helper.hpp" +#include "array.hpp" +#include "tuple.hpp" +#include "tuple_helper.hpp" +#include "statically_indexed_array.hpp" +#include "container_element_picker.hpp" + +namespace ck { + +template +__host__ __device__ constexpr auto container_push_back(const Array& a, const TData& x) +{ + Array r; + + static_for<0, NSize, 1>{}([&r, &a ](auto i) constexpr { r(i) = a[i]; }); + + r(Number{}) = x; + + return r; +} + +template +__host__ __device__ constexpr auto container_push_front(const Tuple& a, const T& x) +{ + return container_concat(make_tuple(x), a); +} + +template +__host__ __device__ constexpr auto container_push_back(const Tuple& a, const T& x) +{ + return container_concat(a, make_tuple(x)); +} + +template +__host__ __device__ constexpr auto +container_reorder_given_new2old(const Array& old_array, Sequence /*new2old*/) +{ + static_assert(NSize == sizeof...(IRs), "wrong! size not consistent"); + + static_assert(is_valid_sequence_map>{}, "wrong! invalid reorder map"); + + return make_array(old_array[Number{}]...); +} + +template +__host__ __device__ constexpr auto +container_reorder_given_old2new(const Array& old_array, Sequence old2new) +{ + return container_reorder_given_new2old( + old_array, typename sequence_map_inverse::type{}); +} + +template +__host__ __device__ constexpr auto container_reorder_given_new2old(const Tuple& old_tuple, + Sequence /*new2old*/) +{ + static_assert(sizeof...(Ts) == sizeof...(IRs), "wrong! size not consistent"); + + static_assert(is_valid_sequence_map>{}, "wrong! invalid reorder map"); + + return make_tuple(old_tuple[Number{}]...); +} + +template +__host__ __device__ constexpr auto container_reorder_given_old2new(const Tuple& old_tuple, + Sequence old2new) +{ + return container_reorder_given_new2old( + old_tuple, typename sequence_map_inverse::type{}); +} + +template +__host__ __device__ constexpr auto container_reorder_given_new2old(Sequence /* old_seq */, + Sequence /*new2old*/) +{ + static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent"); + + static_assert(is_valid_sequence_map>{}, "wrong! invalid reorder map"); + + return Sequence::At(Number{})...>{}; +} + +template +__host__ __device__ constexpr auto container_reorder_given_old2new(Sequence old_seq, + Sequence /* old2new */) +{ + static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent"); + + static_assert(is_valid_sequence_map>{}, "wrong! 
invalid reorder map"); + + constexpr auto new2old = typename sequence_map_inverse>::type{}; + + return container_reorder_given_new2old(old_seq, new2old); +} + +#if !CK_WORKAROUND_SWDEV_275126 +// rocm-4.1 compiler would crash for recursive lambda +template +__host__ __device__ constexpr auto container_reduce(const Container& x, + Reduce reduce, + Init init, + Number = Number<0>{}, + Number = Number{}, + Number = Number<1>{}) +{ + static_assert((IEnd - IBegin) % IStep == 0, "wrong!"); + + // f is recursive function, fs is a dummy of f + // i is index, y_old is current scan, r_old is current reduction + auto f = [&](auto fs, auto i, auto r_old) { + auto r_new = reduce(x[i], r_old); + + if constexpr(i.value < IEnd - IStep) + { + // recursively call f/fs + return fs(fs, i + Number{}, r_new); + } + else + { + return r_new; + } + }; + + // start recursion + return f(f, Number{}, init); +} +#else +// i is index, y_old is current scan, r_old is current reduction +template +__host__ __device__ constexpr auto container_reduce_impl( + const Container& x, Reduce reduce, ROld r_old, Number i, Number, Number) +{ + auto r_new = reduce(x[i], r_old); + + if constexpr(i.value < IEnd - IStep) + { + return container_reduce_impl( + x, reduce, r_new, i + Number{}, Number{}, Number{}); + } + else + { + return r_new; + } +} + +// rocm-4.1 compiler would crash for recursive lambda +// container reduce with initial value +template +__host__ __device__ constexpr auto container_reduce(const Container& x, + Reduce reduce, + Init init, + Number = Number<0>{}, + Number = Number{}, + Number = Number<1>{}) +{ + static_assert((IEnd - IBegin) % IStep == 0, "wrong!"); + + if constexpr(IEnd > IBegin) + { + return container_reduce_impl( + x, reduce, init, Number{}, Number{}, Number{}); + } + else + { + return init; + } +} +#endif + +template +__host__ __device__ constexpr auto +container_reverse_inclusive_scan(const Array& x, Reduce f, TData init) +{ + Array y; + + TData r = init; + + static_for{}([&](auto i) { + r = f(r, x[i]); + y(i) = r; + }); + + r = f(r, x[Number<0>{}]); + y(Number<0>{}) = r; + + return y; +} + +template +__host__ __device__ constexpr auto +container_reverse_exclusive_scan(const Array& x, Reduce f, TData init) +{ + Array y; + + TData r = init; + + static_for{}([&](auto i) { + y(i) = r; + r = f(r, x[i]); + }); + + y(Number<0>{}) = r; + + return y; +} + +template +__host__ __device__ constexpr auto +container_reverse_exclusive_scan(const Sequence& seq, Reduce f, Number) +{ + return reverse_exclusive_scan_sequence(seq, f, Number{}); +} + +#if !CK_WORKAROUND_SWDEV_275126 +// rocm4.1 compiler would crash with recursive lambda +template +__host__ __device__ constexpr auto +container_reverse_exclusive_scan(const Tuple& x, Reduce reduce, Init init) +{ + constexpr index_t NSize = sizeof...(Xs); + + // f is recursive function, fs is a dummy of f + // i is index, y_old is current scan, r_old is current reduction + auto f = [&](auto fs, auto i, auto y_old, auto r_old) { + auto r_new = reduce(x[i], r_old); + + auto y_new = container_push_front(y_old, r_new); + + if constexpr(i.value > 1) + { + // recursively call f/fs + return fs(fs, i - Number<1>{}, y_new, r_new); + } + else + { + return y_new; + } + }; + + // start recursion + return f(f, Number{}, make_tuple(init), init); +} +#else +// i is index, y_old is current scan, r_old is current reduction +template +__host__ __device__ constexpr auto container_reverse_exclusive_scan_impl( + const Tuple& x, Reduce reduce, Number i, YOld y_old, ROld r_old) +{ + auto r_new = 
reduce(x[i], r_old); + + auto y_new = container_push_front(y_old, r_new); + + if constexpr(i.value > 1) + { + // recursively call f/fs + return container_reverse_exclusive_scan_impl(x, reduce, i - Number<1>{}, y_new, r_new); + } + else + { + return y_new; + } +} + +template +__host__ __device__ constexpr auto +container_reverse_exclusive_scan(const Tuple& x, Reduce reduce, Init init) +{ + constexpr index_t NSize = sizeof...(Xs); + + return container_reverse_exclusive_scan_impl( + x, reduce, Number{}, make_tuple(init), init); +} +#endif + +// TODO: update to like container_reverse_exclusive_scan to deal with Tuple of Numebr<> +template +__host__ __device__ constexpr auto +container_reverse_inclusive_scan(const Tuple& x, Reduce f, TData init) +{ + constexpr index_t NSize = sizeof...(Xs); + + Tuple y; + + TData r = init; + + static_for{}([&](auto i) { + r = f(r, x[i]); + y(i) = r; + }); + + r = f(r, x[Number<0>{}]); + y(Number<0>{}) = r; + + return y; +} + +template +__host__ __device__ constexpr auto container_concat(const X& x, const Ys&... ys) +{ + return container_concat(x, container_concat(ys...)); +} + +template +__host__ __device__ constexpr auto container_concat(const Array& ax, const Array& ay) +{ + return unpack2( + [&](auto&&... zs) { return make_array(std::forward(zs)...); }, ax, ay); +} + +template +__host__ __device__ constexpr auto container_concat(const Tuple& tx, const Tuple& ty) +{ + return unpack2( + [&](auto&&... zs) { return make_tuple(std::forward(zs)...); }, tx, ty); +} + +template +__host__ __device__ constexpr auto container_concat(const Container& x) +{ + return x; +} + +template +__host__ __device__ constexpr auto get_container_subset(const Array& arr, Sequence) +{ + static_assert(N >= sizeof...(Is), "wrong! size"); + + return make_array(arr[Number{}]...); +} + +template +__host__ __device__ constexpr auto get_container_subset(const Tuple& tup, Sequence) +{ + static_assert(sizeof...(Ts) >= sizeof...(Is), "wrong! size"); + + return make_tuple(tup[Number{}]...); +} + +template +__host__ __device__ constexpr void +set_container_subset(Array& y, Sequence picks, const Array& x) +{ + static_assert(N >= sizeof...(Is), "wrong! size"); + + static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; }); +} + +template +__host__ __device__ constexpr void +set_container_subset(Tuple& y, Sequence picks, const Tuple& x) +{ + static_assert(sizeof...(Ys) >= sizeof...(Is) && sizeof...(Is) == sizeof...(Xs), "wrong! 
size"); + + static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; }); +} + +template +__host__ __device__ constexpr auto to_tuple_of_number(const Container&) +{ + static_assert(is_known_at_compile_time::value, "wrong!"); + + return generate_tuple( + [&](auto i) { + constexpr index_t tmp = Container::At(i); + return Number{}; + }, + Container::Size()); +} + +template +__host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence) +{ + using Seq = Sequence; + + return generate_tuple( + [&](auto i) { + constexpr index_t tmp = Seq::At(i); + return Number{}; + }, + Seq::Size()); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/data_type.hpp b/composable_kernel/include/utility/data_type.hpp new file mode 100644 index 0000000000..24a2190e84 --- /dev/null +++ b/composable_kernel/include/utility/data_type.hpp @@ -0,0 +1,1017 @@ +#ifndef CK_FLOAT_TYPE_AMD_HPP +#define CK_FLOAT_TYPE_AMD_HPP + +#include "statically_indexed_array.hpp" + +namespace ck { + +using half_t = _Float16; + +// vector_type +template +struct vector_type; + +// Caution: DO NOT REMOVE +// intentionally have only declaration but no definition to cause compilation failure when trying to +// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of +// vectors" +template +struct vector_type; + +// Caution: DO NOT REMOVE +// intentionally have only declaration but no definition to cause compilation failure when trying to +// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of +// vectors" +template +struct vector_type, N>; + +// vector_type_maker +// This is the right way to handle "vector of vectors": making a bigger vector instead +template +struct vector_type_maker +{ + using type = vector_type; +}; + +template +struct vector_type_maker +{ + using type = vector_type; +}; + +template +struct vector_type_maker, N0> +{ + using type = vector_type; +}; + +template +using vector_type_maker_t = typename vector_type_maker::type; + +template +__host__ __device__ constexpr auto make_vector_type(Number) +{ + return typename vector_type_maker::type{}; +} + +// scalar_type +template +struct scalar_type; + +template +struct scalar_type +{ + using type = T; + static constexpr index_t vector_size = N; +}; + +template +struct scalar_type> +{ + using type = T; + static constexpr index_t vector_size = N; +}; + +// +template <> +struct scalar_type +{ + using type = float; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = half_t; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = ushort; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = int32_t; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = int8_t; + static constexpr index_t vector_size = 1; +}; + +// +template +struct vector_type +{ + using d1_t = T; + using type = d1_t; + + union + { + T d1_; + StaticallyIndexedArray d1x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value, "wrong!"); + + return data_.d1x1_; + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value, "wrong!"); + + return data_.d1x1_; + } +}; + +template 
+struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + + using type = d2_t; + + union + { + d2_t d2_; + StaticallyIndexedArray d1x2_; + StaticallyIndexedArray d2x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value, "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x2_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value, "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x2_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + } +}; + +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + + using type = d4_t; + + union + { + d4_t d4_; + StaticallyIndexedArray d1x4_; + StaticallyIndexedArray d2x2_; + StaticallyIndexedArray d4x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x4_; + } + else if constexpr(is_same::value) + { + return data_.d2x2_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x4_; + } + else if constexpr(is_same::value) + { + return data_.d2x2_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + } +}; + +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + + using type = d8_t; + + union + { + d8_t d8_; + StaticallyIndexedArray d1x8_; + StaticallyIndexedArray d2x4_; + StaticallyIndexedArray d4x2_; + StaticallyIndexedArray d8x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x8_; + } + else if constexpr(is_same::value) + { + return data_.d2x4_; + } + else if constexpr(is_same::value) + { + return data_.d4x2_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x8_; + } + else if constexpr(is_same::value) + { + return data_.d2x4_; + } + else if constexpr(is_same::value) + { + return data_.d4x2_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + } +}; + +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t 
__attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + + using type = d16_t; + + union + { + d16_t d16_; + StaticallyIndexedArray d1x16_; + StaticallyIndexedArray d2x8_; + StaticallyIndexedArray d4x4_; + StaticallyIndexedArray d8x2_; + StaticallyIndexedArray d16x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x16_; + } + else if constexpr(is_same::value) + { + return data_.d2x8_; + } + else if constexpr(is_same::value) + { + return data_.d4x4_; + } + else if constexpr(is_same::value) + { + return data_.d8x2_; + } + else if constexpr(is_same::value) + { + return data_.d16x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x16_; + } + else if constexpr(is_same::value) + { + return data_.d2x8_; + } + else if constexpr(is_same::value) + { + return data_.d4x4_; + } + else if constexpr(is_same::value) + { + return data_.d8x2_; + } + else if constexpr(is_same::value) + { + return data_.d16x1_; + } + } +}; + +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + + using type = d32_t; + + union + { + d32_t d32_; + StaticallyIndexedArray d1x32_; + StaticallyIndexedArray d2x16_; + StaticallyIndexedArray d4x8_; + StaticallyIndexedArray d8x4_; + StaticallyIndexedArray d16x2_; + StaticallyIndexedArray d32x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x32_; + } + else if constexpr(is_same::value) + { + return data_.d2x16_; + } + else if constexpr(is_same::value) + { + return data_.d4x8_; + } + else if constexpr(is_same::value) + { + return data_.d8x4_; + } + else if constexpr(is_same::value) + { + return data_.d16x2_; + } + else if constexpr(is_same::value) + { + return data_.d32x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x32_; + } + else if constexpr(is_same::value) + { + return data_.d2x16_; + } + else if constexpr(is_same::value) + { + return data_.d4x8_; + } + else if constexpr(is_same::value) + { + return data_.d8x4_; + } + else if constexpr(is_same::value) + { + return data_.d16x2_; + } + else if constexpr(is_same::value) + { + return data_.d32x1_; + } + 
} +}; + +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + typedef T d64_t __attribute__((ext_vector_type(64))); + + using type = d64_t; + + union + { + d64_t d64_; + StaticallyIndexedArray d1x64_; + StaticallyIndexedArray d2x32_; + StaticallyIndexedArray d4x16_; + StaticallyIndexedArray d8x8_; + StaticallyIndexedArray d16x4_; + StaticallyIndexedArray d32x2_; + StaticallyIndexedArray d64x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + } +}; + +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + typedef T d64_t __attribute__((ext_vector_type(64))); + typedef T d128_t __attribute__((ext_vector_type(128))); + + using type = d128_t; + + union + { + d128_t d128_; + StaticallyIndexedArray d1x128_; + StaticallyIndexedArray d2x64_; + StaticallyIndexedArray d4x32_; + StaticallyIndexedArray d8x16_; + StaticallyIndexedArray d16x8_; + StaticallyIndexedArray d32x4_; + StaticallyIndexedArray d64x2_; + StaticallyIndexedArray d128x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x128_; + } + else if constexpr(is_same::value) + { + return data_.d2x64_; + } + else if constexpr(is_same::value) + { + 
return data_.d4x32_; + } + else if constexpr(is_same::value) + { + return data_.d8x16_; + } + else if constexpr(is_same::value) + { + return data_.d16x8_; + } + else if constexpr(is_same::value) + { + return data_.d32x4_; + } + else if constexpr(is_same::value) + { + return data_.d64x2_; + } + else if constexpr(is_same::value) + { + return data_.d128x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x128_; + } + else if constexpr(is_same::value) + { + return data_.d2x64_; + } + else if constexpr(is_same::value) + { + return data_.d4x32_; + } + else if constexpr(is_same::value) + { + return data_.d8x16_; + } + else if constexpr(is_same::value) + { + return data_.d16x8_; + } + else if constexpr(is_same::value) + { + return data_.d32x4_; + } + else if constexpr(is_same::value) + { + return data_.d64x2_; + } + else if constexpr(is_same::value) + { + return data_.d128x1_; + } + } +}; + +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + typedef T d64_t __attribute__((ext_vector_type(64))); + typedef T d128_t __attribute__((ext_vector_type(128))); + typedef T d256_t __attribute__((ext_vector_type(256))); + + using type = d256_t; + + union + { + d256_t d256_; + StaticallyIndexedArray d1x256_; + StaticallyIndexedArray d2x128_; + StaticallyIndexedArray d4x64_; + StaticallyIndexedArray d8x32_; + StaticallyIndexedArray d16x16_; + StaticallyIndexedArray d32x8_; + StaticallyIndexedArray d64x4_; + StaticallyIndexedArray d128x2_; + StaticallyIndexedArray d256x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert( + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x256_; + } + else if constexpr(is_same::value) + { + return data_.d2x128_; + } + else if constexpr(is_same::value) + { + return data_.d4x64_; + } + else if constexpr(is_same::value) + { + return data_.d8x32_; + } + else if constexpr(is_same::value) + { + return data_.d16x16_; + } + else if constexpr(is_same::value) + { + return data_.d32x8_; + } + else if constexpr(is_same::value) + { + return data_.d64x4_; + } + else if constexpr(is_same::value) + { + return data_.d128x2_; + } + else if constexpr(is_same::value) + { + return data_.d256x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert( + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x256_; + } + else if constexpr(is_same::value) + { + return data_.d2x128_; + } + else if constexpr(is_same::value) + { + return data_.d4x64_; + } + else if constexpr(is_same::value) + { + return 
data_.d8x32_; + } + else if constexpr(is_same::value) + { + return data_.d16x16_; + } + else if constexpr(is_same::value) + { + return data_.d32x8_; + } + else if constexpr(is_same::value) + { + return data_.d64x4_; + } + else if constexpr(is_same::value) + { + return data_.d128x2_; + } + else if constexpr(is_same::value) + { + return data_.d256x1_; + } + } +}; + +// fp32 +using float2_t = typename vector_type::type; +using float4_t = typename vector_type::type; +using float8_t = typename vector_type::type; +using float16_t = typename vector_type::type; +using float32_t = typename vector_type::type; +using float64_t = typename vector_type::type; + +// fp16 +using half2_t = typename vector_type::type; +using half4_t = typename vector_type::type; +using half8_t = typename vector_type::type; +using half16_t = typename vector_type::type; +using half32_t = typename vector_type::type; +using half64_t = typename vector_type::type; + +// bfp16 +using ushort2_t = typename vector_type::type; +using ushort4_t = typename vector_type::type; +using ushort8_t = typename vector_type::type; +using ushort16_t = typename vector_type::type; +using ushort32_t = typename vector_type::type; +using ushort64_t = typename vector_type::type; + +// i32 +using int32x2_t = typename vector_type::type; +using int32x4_t = typename vector_type::type; +using int32x8_t = typename vector_type::type; +using int32x16_t = typename vector_type::type; +using int32x32_t = typename vector_type::type; +using int32x64_t = typename vector_type::type; + +// i8 +using int8x2_t = typename vector_type::type; +using int8x4_t = typename vector_type::type; +using int8x8_t = typename vector_type::type; +using int8x16_t = typename vector_type::type; +using int8x32_t = typename vector_type::type; +using int8x64_t = typename vector_type::type; + +// data type conversion +template +struct type_convert +{ + template + __device__ T operator()(X x) const + { + return static_cast(x); + } +}; + +template <> +template <> +__device__ float type_convert::operator()(ushort x) const +{ + return bfloat16_to_float(x); +} + +template <> +template <> +__device__ ushort type_convert::operator()(float x) const +{ + return float_to_bfloat16(x); +} + +// TODO: deprecate this +template +struct inner_product_with_conversion +{ + static constexpr auto convert = type_convert(); + + template + __device__ T operator()(typename vector_type::type a, + typename vector_type::type b) const + { + const vector_type a_vector{a}; + const vector_type b_vector{b}; + + T acc = 0; + + static_for<0, N, 1>{}([&](auto i) { + acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]); + }); + + return acc; + } + + __device__ T operator()(float_t a, float_t b) const { return convert(a) * convert(b); } + + __device__ T operator()(int8x4_t a, int8x4_t b) const + { + const vector_type a_vector{a}; + const vector_type b_vector{b}; + + T acc = 0; + + static_for<0, 4, 1>{}([&](auto i) { + acc += convert(a_vector.AsType()[i]) * convert(b_vector.AsType()[i]); + }); + + return acc; + } + + __device__ T operator()(int8x8_t a, int8x8_t b) const + { + const vector_type a_vector{a}; + const vector_type b_vector{b}; + + T acc = 0; + + static_for<0, 8, 1>{}([&](auto i) { + acc += convert(a_vector.AsType()[i]) * convert(b_vector.AsType()[i]); + }); + + return acc; + } + + __device__ T operator()(int8x16_t a, int8x16_t b) const + { + const vector_type a_vector{a}; + const vector_type b_vector{b}; + + T acc = 0; + + static_for<0, 16, 1>{}([&](auto i) { + acc += convert(a_vector.AsType()[i]) * 
convert(b_vector.AsType()[i]); + }); + + return acc; + } +}; + +template +struct NumericLimits; + +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr int32_t Min() + { + return std::numeric_limits::min(); + } + + __host__ __device__ static constexpr int32_t Max() + { + return std::numeric_limits::max(); + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/data_type_enum.hpp b/composable_kernel/include/utility/data_type_enum.hpp new file mode 100644 index 0000000000..43499605dc --- /dev/null +++ b/composable_kernel/include/utility/data_type_enum.hpp @@ -0,0 +1,20 @@ +#ifndef CK_DATA_TYPE_ENUM_HPP +#define CK_DATA_TYPE_ENUM_HPP + +namespace ck { + +// this enumerate should be synchronized with include/miopen.h +typedef enum +{ + Half = 0, + Float = 1, + Int32 = 2, + Int8 = 3, + Int8x4 = 4, + BFloat16 = 5, + Double = 6, + Unknown = 100, +} DataTypeEnum_t; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/data_type_helper.hpp b/composable_kernel/include/utility/data_type_helper.hpp new file mode 100644 index 0000000000..6a234cd10b --- /dev/null +++ b/composable_kernel/include/utility/data_type_helper.hpp @@ -0,0 +1,76 @@ +#ifndef CK_DATA_TYPE_HELPER_HPP +#define CK_DATA_TYPE_HELPER_HPP + +#include "data_type.hpp" +#include "data_type_enum.hpp" + +namespace ck { + +template +struct get_datatype_from_enum; + +template <> +struct get_datatype_from_enum +{ + using type = int8_t; +}; + +template <> +struct get_datatype_from_enum +{ + using type = int32_t; +}; + +template <> +struct get_datatype_from_enum +{ + using type = half_t; +}; + +template <> +struct get_datatype_from_enum +{ + using type = float; +}; + +template <> +struct get_datatype_from_enum +{ + using type = double; +}; + +template +struct get_datatype_enum_from_type; + +template <> +struct get_datatype_enum_from_type +{ + static constexpr DataTypeEnum_t value = DataTypeEnum_t::Int8; +}; + +template <> +struct get_datatype_enum_from_type +{ + static constexpr DataTypeEnum_t value = DataTypeEnum_t::Int32; +}; + +template <> +struct get_datatype_enum_from_type +{ + static constexpr DataTypeEnum_t value = DataTypeEnum_t::Half; +}; + +template <> +struct get_datatype_enum_from_type +{ + static constexpr DataTypeEnum_t value = DataTypeEnum_t::Float; +}; + +template <> +struct get_datatype_enum_from_type +{ + static constexpr DataTypeEnum_t value = DataTypeEnum_t::Double; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/dynamic_buffer.hpp b/composable_kernel/include/utility/dynamic_buffer.hpp new file mode 100644 index 0000000000..5f5f386306 --- /dev/null +++ b/composable_kernel/include/utility/dynamic_buffer.hpp @@ -0,0 +1,208 @@ +#ifndef CK_DYNAMIC_BUFFER_HPP +#define CK_DYNAMIC_BUFFER_HPP + +namespace ck { + +#include "amd_buffer_addressing_v2.hpp" + +template +struct DynamicBuffer +{ + using type = T; + + T* p_data_; + ElementSpaceSize element_space_size_; + + __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size) + : p_data_{p_data}, element_space_size_{element_space_size} + { + } + + __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace() + { + return BufferAddressSpace; + } + + __host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; } + + __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; } + + template >>::type, + typename scalar_type>>::type>::value, + bool>::type = false> + __host__ __device__ constexpr auto 
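// --- editorial aside, not part of the patch ---
// The two trait families above (get_datatype_from_enum and
// get_datatype_enum_from_type) are inverse lookup tables between
// DataTypeEnum_t and C++ types. The host-only sketch below shows how such a
// pair can be kept consistent with a compile-time round-trip check; the
// Demo*/demo_* names are illustrative, only the ck:: spellings above are real.
#include <cstdint>

enum class DemoDataType { Half = 0, Float = 1, Int32 = 2, Int8 = 3 };

template <DemoDataType E> struct demo_type_from_enum;
template <> struct demo_type_from_enum<DemoDataType::Float> { using type = float; };
template <> struct demo_type_from_enum<DemoDataType::Int32> { using type = std::int32_t; };
template <> struct demo_type_from_enum<DemoDataType::Int8>  { using type = std::int8_t;  };

template <typename T> struct demo_enum_from_type;
template <> struct demo_enum_from_type<float>        { static constexpr auto value = DemoDataType::Float; };
template <> struct demo_enum_from_type<std::int32_t> { static constexpr auto value = DemoDataType::Int32; };
template <> struct demo_enum_from_type<std::int8_t>  { static constexpr auto value = DemoDataType::Int8;  };

// round trip evaluated entirely at compile time: enum -> type -> enum
static_assert(demo_enum_from_type<demo_type_from_enum<DemoDataType::Float>::type>::value ==
                  DemoDataType::Float,
              "enum/type tables out of sync");
// --- end editorial aside ---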
Get(index_t i, bool is_valid_offset) const + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = + scalar_type>>::vector_size; + + constexpr index_t scalar_per_x_vector = + scalar_type>>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X need to be multiple T"); + + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global) + { +#if CK_USE_AMD_BUFFER_ADDRESSING + return amd_buffer_load_v2>, t_per_x>( + p_data_, i, is_valid_offset, element_space_size_); +#else + return is_valid_offset ? *reinterpret_cast(&p_data_[i]) : X{0}; +#endif + } + else + { + return is_valid_offset ? *reinterpret_cast(&p_data_[i]) : X{0}; + } + } + + template >>::type, + typename scalar_type>>::type>::value, + bool>::type = false> + __host__ __device__ void Set(index_t i, bool is_valid_offset, const X& x) + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = + scalar_type>>::vector_size; + + constexpr index_t scalar_per_x_vector = + scalar_type>>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X need to be multiple T"); + + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global) + { +#if CK_USE_AMD_BUFFER_ADDRESSING + amd_buffer_store_v2>, t_per_x>( + x, p_data_, i, is_valid_offset, element_space_size_); +#else + if(is_valid_offset) + { + *reinterpret_cast(&p_data_[i]) = x; + } +#endif + } + else if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Lds) + { + if(is_valid_offset) + { +#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE + *reinterpret_cast(&p_data_[i]) = x; +#else + // HACK: compiler would lower IR "store address_space(3)" into + // inefficient + // ISA, so I try to let compiler emit IR "store" which would be lower to + // ds_write_b128 + // TODO: remove this after compiler fix + if constexpr(is_same>>::type, + int8_t>::value) + { + static_assert( + (is_same>, int8_t>::value && + is_same>, int8_t>::value) || + (is_same>, int8_t>::value && + is_same>, int8x2_t>::value) || + (is_same>, int8_t>::value && + is_same>, int8x4_t>::value) || + (is_same>, int8x4_t>::value && + is_same>, int8x4_t>::value) || + (is_same>, int8x8_t>::value && + is_same>, int8x8_t>::value) || + (is_same>, int8x16_t>::value && + is_same>, int8x16_t>::value), + "wrong! 
not implemented for this combination, please add " + "implementation"); + + if constexpr(is_same>, int8_t>::value && + is_same>, int8_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *reinterpret_cast(&p_data_[i]) = + *reinterpret_cast(&x); + } + else if constexpr(is_same>, int8_t>::value && + is_same>, int8x2_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *reinterpret_cast(&p_data_[i]) = + *reinterpret_cast(&x); + } + else if constexpr(is_same>, int8_t>::value && + is_same>, int8x4_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *reinterpret_cast(&p_data_[i]) = + *reinterpret_cast(&x); + } + else if constexpr(is_same>, + int8x4_t>::value && + is_same>, int8x4_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *reinterpret_cast(&p_data_[i]) = + *reinterpret_cast(&x); + } + else if constexpr(is_same>, + int8x8_t>::value && + is_same>, int8x8_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *reinterpret_cast(&p_data_[i]) = + *reinterpret_cast(&x); + } + else if constexpr(is_same>, + int8x16_t>::value && + is_same>, int8x16_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *reinterpret_cast(&p_data_[i]) = + *reinterpret_cast(&x); + } + } + else + { + *reinterpret_cast(&p_data_[i]) = x; + } +#endif + } + } + else + { + if(is_valid_offset) + { + *reinterpret_cast(&p_data_[i]) = x; + } + } + } + + __host__ __device__ static constexpr bool IsStaticBuffer() { return false; } + + __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; } +}; + +template +__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size) +{ + return DynamicBuffer{p, element_space_size}; +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/functional.hpp b/composable_kernel/include/utility/functional.hpp new file mode 100644 index 0000000000..b84b617f44 --- /dev/null +++ b/composable_kernel/include/utility/functional.hpp @@ -0,0 +1,116 @@ +#ifndef CK_FUNCTIONAL_HPP +#define CK_FUNCTIONAL_HPP + +#include "integral_constant.hpp" +#include "type.hpp" + +namespace ck { + +// TODO: right? wrong? +struct forwarder +{ + template + __host__ __device__ constexpr T&& operator()(T&& x) const + { + return static_cast(x); + } +}; + +struct swallow +{ + template + __host__ __device__ constexpr swallow(Ts&&...) 
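// --- editorial aside, not part of the patch ---
// DynamicBuffer::Get / Set above perform predicated, vector-width accesses:
// X packs several T, the access is skipped (or zero-filled) when
// is_valid_offset is false, and for global memory the AMD buffer intrinsics
// are used instead of plain loads/stores. The host-only sketch below shows
// only the generic fallback behaviour; memcpy stands in for the
// reinterpret_cast so the sketch has no aliasing issues, and all demo_*
// names are illustrative.
#include <cstddef>
#include <cstring>

template <typename X, typename T>
X demo_buffer_get(const T* p_data, std::ptrdiff_t i, bool is_valid_offset)
{
    static_assert(sizeof(X) % sizeof(T) == 0, "X must pack a whole number of T");
    X x{};
    if(is_valid_offset)
        std::memcpy(&x, p_data + i, sizeof(X)); // one vector-sized transaction
    return x;                                   // zero-filled when predicated off
}

template <typename X, typename T>
void demo_buffer_set(T* p_data, std::ptrdiff_t i, bool is_valid_offset, const X& x)
{
    static_assert(sizeof(X) % sizeof(T) == 0, "X must pack a whole number of T");
    if(is_valid_offset)
        std::memcpy(p_data + i, &x, sizeof(X)); // skipped when out of range
}
// --- end editorial aside ---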
+ { + } +}; + +template +struct logical_and +{ + constexpr bool operator()(const T& x, const T& y) const { return x && y; } +}; + +template +struct logical_or +{ + constexpr bool operator()(const T& x, const T& y) const { return x || y; } +}; + +template +struct logical_not +{ + constexpr bool operator()(const T& x) const { return !x; } +}; + +// Emulate if constexpr +template +struct static_if; + +template <> +struct static_if +{ + using Type = static_if; + + template + __host__ __device__ constexpr auto operator()(F f) const + { + // This is a trick for compiler: + // Pass forwarder to lambda "f" as "auto" argument, and make sure "f" will + // use it, + // this will make "f" a generic lambda, so that "f" won't be compiled + // until being + // instantiated here + f(forwarder{}); + return Type{}; + } + + template + __host__ __device__ static void Else(F) + { + } +}; + +template <> +struct static_if +{ + using Type = static_if; + + template + __host__ __device__ constexpr auto operator()(F) const + { + return Type{}; + } + + template + __host__ __device__ static void Else(F f) + { + // This is a trick for compiler: + // Pass forwarder to lambda "f" as "auto" argument, and make sure "f" will + // use it, + // this will make "f" a generic lambda, so that "f" won't be compiled + // until being + // instantiated here + f(forwarder{}); + } +}; + +template +struct conditional; + +template +struct conditional +{ + using type = X; +}; + +template +struct conditional +{ + using type = Y; +}; + +template +using conditional_t = typename conditional::type; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/functional2.hpp b/composable_kernel/include/utility/functional2.hpp new file mode 100644 index 0000000000..371182a05e --- /dev/null +++ b/composable_kernel/include/utility/functional2.hpp @@ -0,0 +1,48 @@ +#ifndef CK_FUNCTIONAL2_HPP +#define CK_FUNCTIONAL2_HPP + +#include "functional.hpp" +#include "sequence.hpp" + +namespace ck { + +namespace detail { + +template +struct static_for_impl; + +template +struct static_for_impl> +{ + template + __host__ __device__ constexpr void operator()(F f) const + { + swallow{(f(Number{}), 0)...}; + } +}; + +} // namespace detail + +// F signature: F(Number) +template +struct static_for +{ + __host__ __device__ constexpr static_for() + { + static_assert(Increment != 0 && (NEnd - NBegin) % Increment == 0, + "Wrong! should satisfy (NEnd - NBegin) % Increment == 0"); + static_assert((Increment > 0 && NBegin <= NEnd) || (Increment < 0 && NBegin >= NEnd), + "wrongs! should (Increment > 0 && NBegin <= NEnd) || (Increment < 0 && " + "NBegin >= NEnd)"); + } + + template + __host__ __device__ constexpr void operator()(F f) const + { + detail::static_for_impl::type>{}( + f); + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/functional3.hpp b/composable_kernel/include/utility/functional3.hpp new file mode 100644 index 0000000000..6a400f3ca6 --- /dev/null +++ b/composable_kernel/include/utility/functional3.hpp @@ -0,0 +1,142 @@ +#ifndef CK_FUNCTIONAL3_HPP +#define CK_FUNCTIONAL3_HPP + +#include "functional.hpp" +#include "functional2.hpp" +#include "sequence.hpp" +#include "multi_index.hpp" + +namespace ck { + +namespace detail { + +// RemainLengths: Sequence<...> +// Orders: Sequence<...> +template +struct static_ford_impl +{ + __host__ __device__ constexpr static_ford_impl() + { + static_assert(RemainLengths::GetSize() > 0, "wrong! 
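// --- editorial aside, not part of the patch ---
// static_for above hands the loop body a distinct Number<I> per iteration by
// expanding a Sequence into a pack and "swallowing" the resulting comma
// expressions, so every index is a compile-time constant and the loop is
// fully instantiated (and typically unrolled). The host-only sketch below
// reproduces that effect with std::index_sequence and a fold expression;
// all demo_* names are illustrative.
#include <array>
#include <cstddef>
#include <iostream>
#include <type_traits>
#include <utility>

template <std::size_t... Is, typename F>
constexpr void demo_static_for_impl(std::index_sequence<Is...>, F f)
{
    // the fold expression plays the role of the swallow{...} trick above
    (f(std::integral_constant<std::size_t, Is>{}), ...);
}

template <std::size_t N, typename F>
constexpr void demo_static_for(F f)
{
    demo_static_for_impl(std::make_index_sequence<N>{}, f);
}

int main()
{
    std::array<int, 4> a{1, 2, 3, 4};
    int sum = 0;
    demo_static_for<4>([&](auto i) {
        constexpr std::size_t idx = decltype(i)::value; // compile-time index
        sum += std::get<idx>(a);                        // usable as a template argument
    });
    std::cout << sum << "\n"; // prints 10
}
// --- end editorial aside ---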
should not get here"); + } + + // F signature: F(Sequence<...>) + // CurrentOrderedId: Sequence<...> + template + __host__ __device__ constexpr void operator()(F f, CurrentOrderedId) const + { + static_for<0, RemainLengths::Front(), 1>{}([=](auto I) { + static_ford_impl{}( + f, CurrentOrderedId::PushBack(I)); + }); + } +}; + +template +struct static_ford_impl, Orders> +{ + // F signature: F(Sequence<...>) + // OrderedId: Sequence<...> + template + __host__ __device__ constexpr void operator()(F f, OrderedId) const + { + // retrive unordered Id + f(OrderedId::ReorderGivenOld2New(Orders{})); + } +}; + +// RemainLengths: Sequence<...> +// Orders: Sequence<...> +template +struct ford_impl +{ + __host__ __device__ constexpr ford_impl() + { + static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here"); + } + + // F signature: F(Array<...> multi_id) + // CurrentOrderdId: Array<...> + template + __host__ __device__ constexpr void operator()(F f, CurrentOrderedId current_ordered_id) const + { + for(index_t i = 0; i < RemainLengths::Front(); ++i) + { + ford_impl{}( + f, container_push_back(current_ordered_id, i)); + } + } +}; + +template +struct ford_impl, Orders> +{ + // F signature: F(Array<...> multi_id) + // CurrentOrderdId: Array<...> + template + __host__ __device__ constexpr void operator()(F f, CurrentOrderedId current_ordered_id) const + { + // retrive unordered Id + f(container_reorder_given_old2new(current_ordered_id, Orders{})); + } +}; + +} // namespace detail + +// Lengths is Sequence<...>, it is the length of each dimension for +// N-dimensional loop +// Orders is Sequence<...>, it is the order of dimension in which static_ford +// will loop over each +// dimension +template ::type> +struct static_ford +{ + __host__ __device__ constexpr static_ford() + { + static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty"); + static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size"); + } + + // F signature: F(Sequence<...> multi_id) + // multi_id is the unordered multi-index + template + __host__ __device__ constexpr void operator()(F f) const + { + constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{}); + detail::static_ford_impl{}(f, Sequence<>{}); + } +}; + +// Lengths is Sequence<...>, it is the length of each dimension for +// N-dimensional loop +// Orders is Sequence<...>, it is the order of dimension in which ford will loop +// over each +// dimension +template ::type> +struct ford +{ + __host__ __device__ constexpr ford() + { + static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty"); + static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! 
inconsistent size"); + } + + // F signature: F(Array<...> multi_id) + // multi_id is the unordered multi-index + template + __host__ __device__ constexpr void operator()(F f) const + { + constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{}); + + for(index_t i = 0; i < ordered_lengths.Front(); ++i) + { + detail::ford_impl{}(f, + make_multi_index(i)); + } + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/functional4.hpp b/composable_kernel/include/utility/functional4.hpp new file mode 100644 index 0000000000..b039644380 --- /dev/null +++ b/composable_kernel/include/utility/functional4.hpp @@ -0,0 +1,62 @@ +#ifndef CK_FUNCTIONAL4_HPP +#define CK_FUNCTIONAL4_HPP + +#include "sequence.hpp" +#include "tuple.hpp" +#include "array.hpp" + +namespace ck { + +namespace detail { + +template +struct unpack_impl; + +template +struct unpack_impl> +{ + template + __host__ __device__ constexpr auto operator()(F&& f, X&& x) const + { + return std::forward(f)(std::forward(x).At(Number{})...); + } +}; + +template +struct unpack2_impl; + +// TODO: remove this, after properly implementing unpack that takes any number of containers +template +struct unpack2_impl, Sequence> +{ + template + __host__ __device__ constexpr auto operator()(F&& f, X&& x, Y&& y) const + { + return std::forward(f)(std::forward(x).At(Number{})..., + std::forward(y).At(Number{})...); + } +}; + +} // namespace detail + +template +__host__ __device__ constexpr auto unpack(F&& f, X&& x) +{ + using X_ = remove_reference_t; + return detail::unpack_impl::type>{}( + std::forward(f), std::forward(x)); +} + +// TODO: properly implement unpack that takes any number of containers +template +__host__ __device__ constexpr auto unpack2(F&& f, X&& x, Y&& y) +{ + using X_ = remove_reference_t; + using Y_ = remove_reference_t; + return detail::unpack2_impl::type, + typename arithmetic_sequence_gen<0, Y_::Size(), 1>::type>{}( + std::forward(f), std::forward(x), std::forward(y)); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/integral_constant.hpp b/composable_kernel/include/utility/integral_constant.hpp new file mode 100644 index 0000000000..14f3df894b --- /dev/null +++ b/composable_kernel/include/utility/integral_constant.hpp @@ -0,0 +1,17 @@ +#ifndef CK_INTEGRAL_CONSTANT_HPP +#define CK_INTEGRAL_CONSTANT_HPP + +namespace ck { + +template +struct integral_constant +{ + static constexpr T value = v; + typedef T value_type; + typedef integral_constant type; + __host__ __device__ constexpr operator value_type() const noexcept { return value; } + __host__ __device__ constexpr value_type operator()() const noexcept { return value; } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/magic_division.hpp b/composable_kernel/include/utility/magic_division.hpp new file mode 100644 index 0000000000..b7489016e9 --- /dev/null +++ b/composable_kernel/include/utility/magic_division.hpp @@ -0,0 +1,155 @@ +#ifndef CK_MAGIC_DIVISION_HPP +#define CK_MAGIC_DIVISION_HPP + +#include "config.hpp" +#include "integral_constant.hpp" +#include "number.hpp" +#include "type.hpp" +#include "tuple.hpp" + +namespace ck { + +// magic number division +// Caution: +// 1. For uint32_t as dividend: magic number division implementation being used would produce +// correct result if the dividend is uint32_t and its value is within 31-bit value range. +// 2. 
For int32_t as dividendd: magic number division for int32_t dividened has not been +// implemented, the int32_t dividend would be bit-wise interpreted as uint32_t and magic number +// division implementation for uint32_t is then used. Therefore, dividend value need to be +// non-negative. +// TODO: +// 1. Implement magic number divison for int32_t +// 2. Implement magic number divison for unit32_t with 32-bit value range +struct MagicDivision +{ + // uint32_t + __host__ __device__ static constexpr auto CalculateMagicNumbers(uint32_t divisor) + { + // assert(divisior >= 1 && divisior <= INT32_MAX); + uint32_t shift = 0; + for(shift = 0; shift < 32; ++shift) + { + if((1U << shift) >= divisor) + { + break; + } + } + + uint64_t one = 1; + uint64_t multiplier = ((one << 32) * ((one << shift) - divisor)) / divisor + 1; + // assert(multiplier <= 0xffffffffUL); + + return make_tuple(uint32_t(multiplier), shift); + } + + __host__ __device__ static constexpr uint32_t CalculateMagicMultiplier(uint32_t divisor) + { + auto tmp = CalculateMagicNumbers(divisor); + + return tmp[Number<0>{}]; + } + + __host__ __device__ static constexpr uint32_t CalculateMagicShift(uint32_t divisor) + { + auto tmp = CalculateMagicNumbers(divisor); + + return tmp[Number<1>{}]; + } + + // integral_constant + template + __host__ __device__ static constexpr auto + CalculateMagicNumbers(integral_constant) + { + constexpr auto tmp = CalculateMagicNumbers(uint32_t{Divisor}); + + constexpr uint32_t multiplier = tmp[Number<0>{}]; + constexpr uint32_t shift = tmp[Number<1>{}]; + + return make_tuple(integral_constant{}, + integral_constant{}); + } + + template + __host__ __device__ static constexpr auto + CalculateMagicMultiplier(integral_constant) + { + constexpr uint32_t multiplier = CalculateMagicMultiplier(uint32_t{Divisor}); + + return integral_constant{}; + } + + template + __host__ __device__ static constexpr auto + CalculateMagicShift(integral_constant) + { + constexpr uint32_t shift = CalculateMagicShift(uint32_t{Divisor}); + + return integral_constant{}; + } + + // integral_constant + template + __host__ __device__ static constexpr auto + CalculateMagicNumbers(integral_constant) + { + return CalculateMagicNumbers(integral_constant{}); + } + + template + __host__ __device__ static constexpr auto + CalculateMagicMultiplier(integral_constant) + { + return CalculateMagicMultiplier(integral_constant{}); + } + + template + __host__ __device__ static constexpr auto + CalculateMagicShift(integral_constant) + { + return CalculateMagicShift(integral_constant{}); + } + + // magic division for uint32_t + __host__ __device__ static constexpr uint32_t + DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift) + { + uint32_t tmp = (uint64_t(dividend) * uint64_t(multiplier)) >> 32; + return (tmp + dividend) >> shift; + } + +#if 1 // debug + // HACK: magic division for int32_t + // HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be + // non-negative for result to be correct + // TODO: figure out how to do magic number divison for int32_t as dividended + __host__ __device__ static constexpr int32_t + DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) + { + uint32_t dividend_u32 = as_type(dividend_i32); + uint32_t tmp = + (static_cast(dividend_u32) * static_cast(multiplier)) >> 32; + return (tmp + dividend_u32) >> shift; + } +#else + // the inline ASM is producing wrong result + __host__ __device__ static int32_t + DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t 
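// --- editorial aside, not part of the patch ---
// Worked example of the magic-number division defined above, restated as a
// host-only program with a brute-force check. For divisor = 7 the same
// construction as CalculateMagicNumbers gives shift = 3 (smallest s with
// 2^s >= 7) and multiplier = floor(2^32 * (2^3 - 7) / 7) + 1 = 613566757,
// so q = (((x * multiplier) >> 32) + x) >> shift equals x / 7 within the
// 31-bit dividend range documented in the header comment. All demo_* names
// are illustrative.
#include <cassert>
#include <cstdint>

inline std::uint32_t demo_magic_divide(std::uint32_t dividend,
                                       std::uint32_t multiplier,
                                       std::uint32_t shift)
{
    // same arithmetic as DoMagicDivision above
    std::uint32_t tmp = static_cast<std::uint32_t>(
        (std::uint64_t(dividend) * std::uint64_t(multiplier)) >> 32);
    return (tmp + dividend) >> shift;
}

int main()
{
    const std::uint32_t divisor = 7;

    // same construction as CalculateMagicNumbers above
    std::uint32_t shift = 0;
    while((1U << shift) < divisor)
        ++shift;
    const std::uint64_t one        = 1;
    const std::uint32_t multiplier = static_cast<std::uint32_t>(
        ((one << 32) * ((one << shift) - divisor)) / divisor + 1);

    assert(shift == 3 && multiplier == 613566757u);

    // exhaustive check over a small range of dividends
    for(std::uint32_t x = 0; x < 1000000u; ++x)
        assert(demo_magic_divide(x, multiplier, shift) == x / divisor);

    return 0;
}
// --- end editorial aside ---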
shift) + { + uint32_t r; + asm volatile("\n \ + v_mul_hi_u32 %0, %1, %2 \n \ + v_add_u32_e32 %0, %1, %0 \n \ + v_lshrrev_b32_e32 %0, %3, %0 \n \ + " + : "=v"(r) + : "v"(as_type(dividend_i32)), "s"(multiplier), "s"(shift)); + + return as_type(r); + } +#endif +}; + +} // namespace ck + +#endif diff --git a/composable_kernel/include/utility/math.hpp b/composable_kernel/include/utility/math.hpp new file mode 100644 index 0000000000..e451059647 --- /dev/null +++ b/composable_kernel/include/utility/math.hpp @@ -0,0 +1,225 @@ +#ifndef CK_MATH_HPP +#define CK_MATH_HPP + +#include "config.hpp" +#include "integral_constant.hpp" +#include "number.hpp" +#include "type.hpp" + +namespace ck { +namespace math { + +template +struct scales +{ + __host__ __device__ constexpr T operator()(T a) const { return s * a; } +}; + +template +struct plus +{ + __host__ __device__ constexpr T operator()(T a, T b) const { return a + b; } +}; + +template +struct minus +{ + __host__ __device__ constexpr T operator()(T a, T b) const { return a - b; } +}; + +template +struct multiplies +{ + __host__ __device__ constexpr T operator()(T a, T b) const { return a * b; } +}; + +struct multiplies_v2 +{ + template + __host__ __device__ constexpr auto operator()(const A& a, const B& b) const + { + return a * b; + } +}; + +template +struct maximize +{ + __host__ __device__ constexpr T operator()(T a, T b) const { return a >= b ? a : b; } +}; + +template +struct minimize +{ + __host__ __device__ constexpr T operator()(T a, T b) const { return a <= b ? a : b; } +}; + +template +struct integer_divide_ceiler +{ + __host__ __device__ constexpr T operator()(T a, T b) const + { + static_assert(is_same{} || is_same{}, "wrong type"); + + return (a + b - Number<1>{}) / b; + } +}; + +template +__host__ __device__ constexpr auto integer_divide_floor(X x, Y y) +{ + return x / y; +} + +template +__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y) +{ + return (x + y - Number<1>{}) / y; +} + +template +__host__ __device__ constexpr auto integer_least_multiple(X x, Y y) +{ + return y * integer_divide_ceil(x, y); +} + +template +__host__ __device__ constexpr T max(T x) +{ + return x; +} + +template +__host__ __device__ constexpr T max(T x, T y) +{ + return x > y ? x : y; +} + +template +__host__ __device__ constexpr index_t max(Number, index_t y) +{ + return X > y ? X : y; +} + +template +__host__ __device__ constexpr index_t max(index_t x, Number) +{ + return x > Y ? x : Y; +} + +template +__host__ __device__ constexpr auto max(X x, Ys... ys) +{ + static_assert(sizeof...(Ys) > 0, "not enough argument"); + + return max(x, max(ys...)); +} + +template +__host__ __device__ constexpr T min(T x) +{ + return x; +} + +template +__host__ __device__ constexpr T min(T x, T y) +{ + return x < y ? x : y; +} + +template +__host__ __device__ constexpr index_t min(Number, index_t y) +{ + return X < y ? X : y; +} + +template +__host__ __device__ constexpr index_t min(index_t x, Number) +{ + return x < Y ? x : Y; +} + +template +__host__ __device__ constexpr auto min(X x, Ys... 
ys) +{ + static_assert(sizeof...(Ys) > 0, "not enough argument"); + + return min(x, min(ys...)); +} + +// greatest common divisor, aka highest common factor +__host__ __device__ constexpr index_t gcd(index_t x, index_t y) +{ + if(x < 0) + { + return gcd(-x, y); + } + else if(y < 0) + { + return gcd(x, -y); + } + else if(x == y || x == 0) + { + return y; + } + else if(y == 0) + { + return x; + } + else if(x > y) + { + return gcd(x % y, y); + } + else + { + return gcd(x, y % x); + } +} + +template +__host__ __device__ constexpr auto gcd(Number, Number) +{ + constexpr auto r = gcd(X, Y); + + return Number{}; +} + +template = 2, bool>::type = false> +__host__ __device__ constexpr auto gcd(X x, Ys... ys) +{ + return gcd(x, gcd(ys...)); +} + +// least common multiple +template +__host__ __device__ constexpr auto lcm(X x, Y y) +{ + return (x * y) / gcd(x, y); +} + +template = 2, bool>::type = false> +__host__ __device__ constexpr auto lcm(X x, Ys... ys) +{ + return lcm(x, lcm(ys...)); +} + +template +struct equal +{ + __host__ __device__ constexpr bool operator()(T x, T y) const { return x == y; } +}; + +template +struct less +{ + __host__ __device__ constexpr bool operator()(T x, T y) const { return x < y; } +}; + +} // namespace math +} // namespace ck + +#endif diff --git a/composable_kernel/include/utility/multi_index.hpp b/composable_kernel/include/utility/multi_index.hpp new file mode 100644 index 0000000000..0bb34fb1e2 --- /dev/null +++ b/composable_kernel/include/utility/multi_index.hpp @@ -0,0 +1,12 @@ +#ifndef CK_MULTI_INDEX_HPP +#define CK_MULTI_INDEX_HPP + +#include "common_header.hpp" + +#if CK_USE_DYNAMICALLY_INDEXED_MULTI_INDEX +#include "array_multi_index.hpp" +#else +#include "statically_indexed_array_multi_index.hpp" +#endif + +#endif diff --git a/composable_kernel/include/utility/number.hpp b/composable_kernel/include/utility/number.hpp new file mode 100644 index 0000000000..f8c5643694 --- /dev/null +++ b/composable_kernel/include/utility/number.hpp @@ -0,0 +1,44 @@ +#ifndef CK_NUMBER_HPP +#define CK_NUMBER_HPP + +#include "integral_constant.hpp" + +namespace ck { + +template +using Number = integral_constant; + +template +__host__ __device__ constexpr auto operator+(Number, Number) +{ + return Number{}; +} + +template +__host__ __device__ constexpr auto operator-(Number, Number) +{ + static_assert(Y <= X, "wrong!"); + return Number{}; +} + +template +__host__ __device__ constexpr auto operator*(Number, Number) +{ + return Number{}; +} + +template +__host__ __device__ constexpr auto operator/(Number, Number) +{ + static_assert(Y > 0, "wrong!"); + return Number{}; +} + +template +__host__ __device__ constexpr auto operator%(Number, Number) +{ + static_assert(Y > 0, "wrong!"); + return Number{}; +} +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/print.hpp b/composable_kernel/include/utility/print.hpp new file mode 100644 index 0000000000..0dd646153a --- /dev/null +++ b/composable_kernel/include/utility/print.hpp @@ -0,0 +1,70 @@ +#ifndef CK_PRINT_HPP +#define CK_PRINT_HPP + +#include "array.hpp" +#include "statically_indexed_array.hpp" +#include "container_helper.hpp" +#include "sequence.hpp" + +namespace ck { + +template +__host__ __device__ void print_array(const char* s, T a) +{ + using data_type = decltype(a.At(Number<0>{})); + constexpr index_t nsize = a.Size(); + +#if 0 + if constexpr(is_same{}) + { + printf("%s size %u, {", s, nsize); + static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%u, ", uint32_t{a[i]}); }); + printf("}\n"); + } + 
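// --- editorial aside, not part of the patch ---
// Usage sketch for the integer helpers in math.hpp and the Number<>
// arithmetic in number.hpp above. It assumes "math.hpp" (and the headers it
// pulls in) is reachable on the include path; the expected values follow
// directly from the definitions above.
#include "math.hpp"

namespace ck_math_usage_sketch {

using ck::Number;

// ceil(10 / 4) = 3, and the least multiple of 4 that is >= 10 is 12
static_assert(ck::math::integer_divide_ceil(10, 4) == 3, "");
static_assert(ck::math::integer_least_multiple(10, 4) == 12, "");

// gcd / lcm also accept Number<> so the result stays a compile-time constant
static_assert(ck::math::gcd(Number<12>{}, Number<18>{}) == 6, "");
static_assert(ck::math::lcm(Number<6>{}, Number<4>{}) == 12, "");

// Number<> arithmetic from number.hpp: the value lives in the type
static_assert((Number<6>{} * Number<7>{}).value == 42, "");

} // namespace ck_math_usage_sketch
// --- end editorial aside ---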
else if constexpr(is_same{}) + { + printf("%s size %d, {", s, nsize); + static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", int32_t{a[i]}); }); + printf("}\n"); + } + else if constexpr(is_same{}) + { + printf("%s size %d, {", s, nsize); + static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", bool{a[i]}); }); + printf("}\n"); + } +#else + printf("%s size %d, {", s, nsize); + static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", int32_t{a[i]}); }); + printf("}\n"); +#endif +} + +template +__host__ __device__ void print_array_v2(const char* s, T a) +{ + using data_type = decltype(a.At(Number<0>{})); + constexpr index_t nsize = a.Size(); + +#if 0 + if constexpr(is_same{}) + { + printf("%s size %u, {", s, nsize); + static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%u] %u, ", i.value, a[i]); }); + printf("}\n"); + } + else if constexpr(is_same{}) + { + printf("%s size %d, {", s, nsize); + static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%d] %d, ", i.value, a[i]); }); + printf("}\n"); + } +#else + printf("%s size %d, {", s, nsize); + static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%d] %d, ", i.value, a[i]); }); + printf("}\n"); +#endif +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/sequence.hpp b/composable_kernel/include/utility/sequence.hpp new file mode 100644 index 0000000000..81eb488715 --- /dev/null +++ b/composable_kernel/include/utility/sequence.hpp @@ -0,0 +1,882 @@ +#ifndef CK_SEQUENCE_HPP +#define CK_SEQUENCE_HPP + +#include "integral_constant.hpp" +#include "type.hpp" +#include "functional.hpp" +#include "math.hpp" + +namespace ck { + +template +struct static_for; + +template +struct Sequence; + +template +struct sequence_split; + +template +struct sequence_reverse; + +template +struct sequence_map_inverse; + +template +struct is_valid_sequence_map; + +template +__host__ __device__ constexpr auto sequence_pop_front(Sequence); + +template +__host__ __device__ constexpr auto sequence_pop_back(Seq); + +template +struct Sequence +{ + using Type = Sequence; + using data_type = index_t; + + static constexpr index_t mSize = sizeof...(Is); + + __host__ __device__ static constexpr auto Size() { return Number{}; } + + __host__ __device__ static constexpr auto GetSize() { return Size(); } + + __host__ __device__ static constexpr index_t At(index_t I) + { + // the last dummy element is to prevent compiler complain about empty array, when mSize = 0 + const index_t mData[mSize + 1] = {Is..., 0}; + return mData[I]; + } + + template + __host__ __device__ static constexpr auto At(Number) + { + static_assert(I < mSize, "wrong! I too large"); + + return Number{}; + } + + template + __host__ __device__ static constexpr auto Get(Number) + { + return At(Number{}); + } + + template + __host__ __device__ constexpr auto operator[](I i) const + { + return At(i); + } + + template + __host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence /*new2old*/) + { + static_assert(sizeof...(Is) == sizeof...(IRs), + "wrong! reorder map should have the same size as Sequence to be rerodered"); + + static_assert(is_valid_sequence_map>::value, "wrong! invalid reorder map"); + + return Sequence{})...>{}; + } + + // MapOld2New is Sequence<...> + template + __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New) + { + static_assert(MapOld2New::Size() == Size(), + "wrong! reorder map should have the same size as Sequence to be rerodered"); + + static_assert(is_valid_sequence_map::value, "wrong! 
invalid reorder map"); + + return ReorderGivenNew2Old(typename sequence_map_inverse::type{}); + } + + __host__ __device__ static constexpr auto Reverse() + { + return typename sequence_reverse::type{}; + } + + __host__ __device__ static constexpr auto Front() + { + static_assert(mSize > 0, "wrong!"); + return At(Number<0>{}); + } + + __host__ __device__ static constexpr auto Back() + { + static_assert(mSize > 0, "wrong!"); + return At(Number{}); + } + + __host__ __device__ static constexpr auto PopFront() { return sequence_pop_front(Type{}); } + + __host__ __device__ static constexpr auto PopBack() { return sequence_pop_back(Type{}); } + + template + __host__ __device__ static constexpr auto PushFront(Sequence) + { + return Sequence{}; + } + + template + __host__ __device__ static constexpr auto PushFront(Number...) + { + return Sequence{}; + } + + template + __host__ __device__ static constexpr auto PushBack(Sequence) + { + return Sequence{}; + } + + template + __host__ __device__ static constexpr auto PushBack(Number...) + { + return Sequence{}; + } + + template + __host__ __device__ static constexpr auto Extract(Number...) + { + return Sequence{})...>{}; + } + + template + __host__ __device__ static constexpr auto Extract(Sequence) + { + return Sequence{})...>{}; + } + + template + __host__ __device__ static constexpr auto Modify(Number, Number) + { + static_assert(I < Size(), "wrong!"); + + using seq_split = sequence_split; + constexpr auto seq_left = typename seq_split::left_type{}; + constexpr auto seq_right = typename seq_split::right_type{}.PopFront(); + + return seq_left.PushBack(Number{}).PushBack(seq_right); + } + + template + __host__ __device__ static constexpr auto Transform(F f) + { + return Sequence{}; + } + + __host__ __device__ static void Print() + { + printf("{"); + printf("size %d, ", index_t{Size()}); + static_for<0, Size(), 1>{}([&](auto i) { printf("%d ", At(i).value); }); + printf("}"); + } +}; + +// merge sequence +template +struct sequence_merge +{ + using type = typename sequence_merge::type>::type; +}; + +template +struct sequence_merge, Sequence> +{ + using type = Sequence; +}; + +template +struct sequence_merge +{ + using type = Seq; +}; + +// generate sequence +template +struct sequence_gen +{ + template + struct sequence_gen_impl + { + static constexpr index_t NRemainLeft = NRemain / 2; + static constexpr index_t NRemainRight = NRemain - NRemainLeft; + static constexpr index_t IMiddle = IBegin + NRemainLeft; + + using type = typename sequence_merge< + typename sequence_gen_impl::type, + typename sequence_gen_impl::type>::type; + }; + + template + struct sequence_gen_impl + { + static constexpr index_t Is = G{}(Number{}); + using type = Sequence; + }; + + template + struct sequence_gen_impl + { + using type = Sequence<>; + }; + + using type = typename sequence_gen_impl<0, NSize, F>::type; +}; + +// arithmetic sequence +template +struct arithmetic_sequence_gen +{ + struct F + { + __host__ __device__ constexpr index_t operator()(index_t i) const + { + return i * Increment + IBegin; + } + }; + + using type = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type; +}; + +// uniform sequence +template +struct uniform_sequence_gen +{ + struct F + { + __host__ __device__ constexpr index_t operator()(index_t) const { return I; } + }; + + using type = typename sequence_gen::type; +}; + +// reverse inclusive scan (with init) sequence +template +struct sequence_reverse_inclusive_scan; + +template +struct sequence_reverse_inclusive_scan, Reduce, Init> +{ + using 
old_scan = typename sequence_reverse_inclusive_scan, Reduce, Init>::type; + + static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front()); + + using type = typename sequence_merge, old_scan>::type; +}; + +template +struct sequence_reverse_inclusive_scan, Reduce, Init> +{ + using type = Sequence; +}; + +template +struct sequence_reverse_inclusive_scan, Reduce, Init> +{ + using type = Sequence<>; +}; + +// split sequence +template +struct sequence_split +{ + static constexpr index_t NSize = Seq{}.Size(); + + using range0 = typename arithmetic_sequence_gen<0, I, 1>::type; + using range1 = typename arithmetic_sequence_gen::type; + + using left_type = decltype(Seq::Extract(range0{})); + using right_type = decltype(Seq::Extract(range1{})); +}; + +// reverse sequence +template +struct sequence_reverse +{ + static constexpr index_t NSize = Seq{}.Size(); + + using seq_split = sequence_split; + using type = typename sequence_merge< + typename sequence_reverse::type, + typename sequence_reverse::type>::type; +}; + +template +struct sequence_reverse> +{ + using type = Sequence; +}; + +template +struct sequence_reverse> +{ + using type = Sequence; +}; + +#if 1 +template +struct sequence_reduce +{ + using type = typename sequence_reduce::type>::type; +}; + +template +struct sequence_reduce, Sequence> +{ + using type = Sequence; +}; + +template +struct sequence_reduce +{ + using type = Seq; +}; +#endif + +template +struct sequence_sort_impl +{ + template + struct sorted_sequence_merge_impl + { + static constexpr bool choose_left = LeftValues::Front() < RightValues::Front(); + + static constexpr index_t chosen_value = + choose_left ? LeftValues::Front() : RightValues::Front(); + static constexpr index_t chosen_id = choose_left ? LeftIds::Front() : RightIds::Front(); + + using new_merged_values = decltype(MergedValues::PushBack(Number{})); + using new_merged_ids = decltype(MergedIds::PushBack(Number{})); + + using new_left_values = + typename conditional::type; + using new_left_ids = + typename conditional::type; + + using new_right_values = + typename conditional::type; + using new_right_ids = + typename conditional::type; + + using merge = sorted_sequence_merge_impl; + // this is output + using merged_values = typename merge::merged_values; + using merged_ids = typename merge::merged_ids; + }; + + template + struct sorted_sequence_merge_impl, + Sequence<>, + MergedValues, + MergedIds, + Comp> + { + using merged_values = typename sequence_merge::type; + using merged_ids = typename sequence_merge::type; + }; + + template + struct sorted_sequence_merge_impl, + Sequence<>, + RightValues, + RightIds, + MergedValues, + MergedIds, + Comp> + { + using merged_values = typename sequence_merge::type; + using merged_ids = typename sequence_merge::type; + }; + + template + struct sorted_sequence_merge + { + using merge = sorted_sequence_merge_impl, + Sequence<>, + Comp>; + + using merged_values = typename merge::merged_values; + using merged_ids = typename merge::merged_ids; + }; + + static constexpr index_t nsize = Values::Size(); + + using split_unsorted_values = sequence_split; + using split_unsorted_ids = sequence_split; + + using left_unsorted_values = typename split_unsorted_values::left_type; + using left_unsorted_ids = typename split_unsorted_ids::left_type; + using left_sort = sequence_sort_impl; + using left_sorted_values = typename left_sort::sorted_values; + using left_sorted_ids = typename left_sort::sorted_ids; + + using right_unsorted_values = typename split_unsorted_values::right_type; + 
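// --- editorial aside, not part of the patch ---
// Usage sketch for the compile-time Sequence utilities in this header,
// including the scan helper defined further below. It assumes this file is
// reachable as "sequence.hpp" together with the headers it includes; the
// expected results in the static_asserts follow directly from the
// definitions (element-wise operators, Reverse, ReorderGivenNew2Old picking
// new[k] = old[new2old[k]], and the reverse exclusive scan that turns packed
// lengths into strides).
#include "sequence.hpp"

namespace ck_sequence_usage_sketch {

using ck::Sequence;
using ck::Number;
using ck::is_same;

// element-wise arithmetic between two sequences of equal size
static_assert(is_same<decltype(Sequence<1, 2, 3>{} + Sequence<10, 20, 30>{}),
                      Sequence<11, 22, 33>>::value,
              "");

// reversing and reordering (new2old = {2, 0, 1} puts old element 2 first)
static_assert(is_same<decltype(Sequence<2, 5, 3>::Reverse()), Sequence<3, 5, 2>>::value, "");
static_assert(is_same<decltype(Sequence<2, 5, 3>::ReorderGivenNew2Old(Sequence<2, 0, 1>{})),
                      Sequence<3, 2, 5>>::value,
              "");

// packed strides from lengths {2, 3, 4}: stride[i] = product of lengths after i
static_assert(is_same<decltype(ck::reverse_exclusive_scan_sequence(
                          Sequence<2, 3, 4>{}, ck::math::multiplies<ck::index_t>{}, Number<1>{})),
                      Sequence<12, 4, 1>>::value,
              "");

} // namespace ck_sequence_usage_sketch
// --- end editorial aside ---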
using right_unsorted_ids = typename split_unsorted_ids::right_type; + using right_sort = sequence_sort_impl; + using right_sorted_values = typename right_sort::sorted_values; + using right_sorted_ids = typename right_sort::sorted_ids; + + using merged_sorted = sorted_sequence_merge; + + using sorted_values = typename merged_sorted::merged_values; + using sorted_ids = typename merged_sorted::merged_ids; +}; + +template +struct sequence_sort_impl, Sequence, Compare> +{ + static constexpr bool choose_x = Compare{}(ValueX, ValueY); + + using sorted_values = + typename conditional, Sequence>::type; + using sorted_ids = typename conditional, Sequence>::type; +}; + +template +struct sequence_sort_impl, Sequence, Compare> +{ + using sorted_values = Sequence; + using sorted_ids = Sequence; +}; + +template +struct sequence_sort_impl, Sequence<>, Compare> +{ + using sorted_values = Sequence<>; + using sorted_ids = Sequence<>; +}; + +template +struct sequence_sort +{ + using unsorted_ids = typename arithmetic_sequence_gen<0, Values::Size(), 1>::type; + using sort = sequence_sort_impl; + + // this is output + using type = typename sort::sorted_values; + using sorted2unsorted_map = typename sort::sorted_ids; +}; + +template +struct sequence_unique_sort +{ + template + struct sorted_sequence_uniquify_impl + { + static constexpr index_t current_value = RemainValues::Front(); + static constexpr index_t current_id = RemainIds::Front(); + + static constexpr bool is_unique_value = (current_value != UniquifiedValues::Back()); + + using new_remain_values = decltype(RemainValues::PopFront()); + using new_remain_ids = decltype(RemainIds::PopFront()); + + using new_uniquified_values = + typename conditional{})), + UniquifiedValues>::type; + + using new_uniquified_ids = + typename conditional{})), + UniquifiedIds>::type; + + using uniquify = sorted_sequence_uniquify_impl; + + // this is output + using uniquified_values = typename uniquify::uniquified_values; + using uniquified_ids = typename uniquify::uniquified_ids; + }; + + template + struct sorted_sequence_uniquify_impl, + Sequence<>, + UniquifiedValues, + UniquifiedIds, + Eq> + { + using uniquified_values = UniquifiedValues; + using uniquified_ids = UniquifiedIds; + }; + + template + struct sorted_sequence_uniquify + { + using uniquify = sorted_sequence_uniquify_impl, + Sequence, + Eq>; + + using uniquified_values = typename uniquify::uniquified_values; + using uniquified_ids = typename uniquify::uniquified_ids; + }; + + using sort = sequence_sort; + using sorted_values = typename sort::type; + using sorted_ids = typename sort::sorted2unsorted_map; + + using uniquify = sorted_sequence_uniquify; + + // this is output + using type = typename uniquify::uniquified_values; + using sorted2unsorted_map = typename uniquify::uniquified_ids; +}; + +template +struct is_valid_sequence_map : is_same::type, + typename sequence_sort>::type> +{ +}; + +template +struct sequence_map_inverse +{ + template + struct sequence_map_inverse_impl + { + static constexpr auto new_y2x = + WorkingY2X::Modify(X2Y::At(Number{}), Number{}); + + using type = + typename sequence_map_inverse_impl:: + type; + }; + + template + struct sequence_map_inverse_impl + { + using type = WorkingY2X; + }; + + using type = + typename sequence_map_inverse_impl::type, + 0, + SeqMap::Size()>::type; +}; + +template +__host__ __device__ constexpr auto operator+(Sequence, Sequence) +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! 
inconsistent size"); + + return Sequence<(Xs + Ys)...>{}; +} + +template +__host__ __device__ constexpr auto operator-(Sequence, Sequence) +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); + + return Sequence<(Xs - Ys)...>{}; +} + +template +__host__ __device__ constexpr auto operator*(Sequence, Sequence) +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); + + return Sequence<(Xs * Ys)...>{}; +} + +template +__host__ __device__ constexpr auto operator/(Sequence, Sequence) +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); + + return Sequence<(Xs / Ys)...>{}; +} + +template +__host__ __device__ constexpr auto operator%(Sequence, Sequence) +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); + + return Sequence<(Xs % Ys)...>{}; +} + +template +__host__ __device__ constexpr auto operator+(Sequence, Number) +{ + return Sequence<(Xs + Y)...>{}; +} + +template +__host__ __device__ constexpr auto operator-(Sequence, Number) +{ + return Sequence<(Xs - Y)...>{}; +} + +template +__host__ __device__ constexpr auto operator*(Sequence, Number) +{ + return Sequence<(Xs * Y)...>{}; +} + +template +__host__ __device__ constexpr auto operator/(Sequence, Number) +{ + return Sequence<(Xs / Y)...>{}; +} + +template +__host__ __device__ constexpr auto operator%(Sequence, Number) +{ + return Sequence<(Xs % Y)...>{}; +} + +template +__host__ __device__ constexpr auto operator+(Number, Sequence) +{ + return Sequence<(Y + Xs)...>{}; +} + +template +__host__ __device__ constexpr auto operator-(Number, Sequence) +{ + constexpr auto seq_x = Sequence{}; + + return Sequence<(Y - Xs)...>{}; +} + +template +__host__ __device__ constexpr auto operator*(Number, Sequence) +{ + return Sequence<(Y * Xs)...>{}; +} + +template +__host__ __device__ constexpr auto operator/(Number, Sequence) +{ + return Sequence<(Y / Xs)...>{}; +} + +template +__host__ __device__ constexpr auto operator%(Number, Sequence) +{ + return Sequence<(Y % Xs)...>{}; +} + +template +__host__ __device__ constexpr auto sequence_pop_front(Sequence) +{ + return Sequence{}; +} + +template +__host__ __device__ constexpr auto sequence_pop_back(Seq) +{ + static_assert(Seq::Size() > 0, "wrong! cannot pop an empty Sequence!"); + return sequence_pop_front(Seq::Reverse()).Reverse(); +} + +template +__host__ __device__ constexpr auto merge_sequences(Seqs...) 
+{ + return typename sequence_merge::type{}; +} + +template +__host__ __device__ constexpr auto transform_sequences(F f, Sequence) +{ + return Sequence{}; +} + +template +__host__ __device__ constexpr auto transform_sequences(F f, Sequence, Sequence) +{ + static_assert(Sequence::mSize == Sequence::mSize, "Dim not the same"); + + return Sequence{}; +} + +template +__host__ __device__ constexpr auto +transform_sequences(F f, Sequence, Sequence, Sequence) +{ + static_assert(Sequence::mSize == Sequence::mSize && + Sequence::mSize == Sequence::mSize, + "Dim not the same"); + + return Sequence{}; +} + +template +__host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, Number) +{ + return typename sequence_reverse_inclusive_scan::type{}; +} + +template +__host__ __device__ constexpr auto reverse_exclusive_scan_sequence(Seq, Reduce, Number) +{ + return reverse_inclusive_scan_sequence(Seq::PopFront(), Reduce{}, Number{}) + .PushBack(Number{}); +} + +template +__host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number) +{ + return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number{}).Reverse(); +} + +template +__host__ __device__ constexpr auto pick_sequence_elements_by_ids(Seq, Sequence /* ids */) +{ + return Sequence{})...>{}; +} + +#if 1 +namespace detail { +template +struct pick_sequence_elements_by_mask_impl +{ + using new_work_seq = typename conditional::type; + + using type = + typename pick_sequence_elements_by_mask_impl::type; +}; + +template +struct pick_sequence_elements_by_mask_impl, Sequence<>> +{ + using type = WorkSeq; +}; + +} // namespace detail + +template +__host__ __device__ constexpr auto pick_sequence_elements_by_mask(Seq, Mask) +{ + static_assert(Seq::Size() == Mask::Size(), "wrong!"); + + return typename detail::pick_sequence_elements_by_mask_impl, Seq, Mask>::type{}; +} + +namespace detail { +template +struct modify_sequence_elements_by_ids_impl +{ + using new_work_seq = decltype(WorkSeq::Modify(RemainIds::Front(), RemainValues::Front())); + + using type = + typename modify_sequence_elements_by_ids_impl::type; +}; + +template +struct modify_sequence_elements_by_ids_impl, Sequence<>> +{ + using type = WorkSeq; +}; +} // namespace detail + +template +__host__ __device__ constexpr auto modify_sequence_elements_by_ids(Seq, Values, Ids) +{ + static_assert(Values::Size() == Ids::Size() && Seq::Size() >= Values::Size(), "wrong!"); + + return typename detail::modify_sequence_elements_by_ids_impl::type{}; +} +#endif + +template +__host__ __device__ constexpr index_t +reduce_on_sequence(Seq, Reduce f, Number /*initial_value*/) +{ + index_t result = Init; + + for(index_t i = 0; i < Seq::Size(); ++i) + { + result = f(result, Seq::At(i)); + } + + return result; +} + +// TODO: a generic any_of for any container +template +__host__ __device__ constexpr bool sequence_any_of(Seq, F f) +{ + bool flag = false; + + for(index_t i = 0; i < Seq::Size(); ++i) + { + flag = flag || f(Seq::At(i)); + } + + return flag; +} + +// TODO: a generic all_of for any container +template +__host__ __device__ constexpr bool sequence_all_of(Seq, F f) +{ + bool flag = true; + + for(index_t i = 0; i < Seq::Size(); ++i) + { + flag = flag && f(Seq::At(i)); + } + + return flag; +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/sequence_helper.hpp b/composable_kernel/include/utility/sequence_helper.hpp new file mode 100644 index 0000000000..88d7da63e8 --- /dev/null +++ b/composable_kernel/include/utility/sequence_helper.hpp @@ -0,0 
+1,36 @@ +#ifndef CK_SEQUENCE_HELPER_HPP +#define CK_SEQUENCE_HELPER_HPP + +#include "tuple.hpp" + +namespace ck { + +template +__host__ __device__ constexpr auto make_sequence(Number...) +{ + return Sequence{}; +} + +// F returns index_t +template +__host__ __device__ constexpr auto generate_sequence(F, Number) +{ + return typename sequence_gen::type{}; +} + +// F returns Number<> +template +__host__ __device__ constexpr auto generate_sequence_v2(F&& f, Number) +{ + return unpack([&f](auto&&... xs) { return make_sequence(f(xs)...); }, + typename arithmetic_sequence_gen<0, N, 1>::type{}); +} + +template +__host__ __device__ constexpr auto to_sequence(Tuple...>) +{ + return Sequence{}; +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/static_buffer.hpp b/composable_kernel/include/utility/static_buffer.hpp new file mode 100644 index 0000000000..a23cf4f80d --- /dev/null +++ b/composable_kernel/include/utility/static_buffer.hpp @@ -0,0 +1,35 @@ +#ifndef CK_STATIC_BUFFER_HPP +#define CK_STATIC_BUFFER_HPP + +#include "statically_indexed_array.hpp" + +namespace ck { + +template +struct StaticBuffer : public StaticallyIndexedArray +{ + using type = T; + using base = StaticallyIndexedArray; + + __host__ __device__ constexpr StaticBuffer() : base{} {} + + __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace() + { + return BufferAddressSpace; + } + + __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } + + __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } +}; + +template +__host__ __device__ constexpr auto make_static_buffer(Number) +{ + return StaticBuffer{}; +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/statically_indexed_array.hpp b/composable_kernel/include/utility/statically_indexed_array.hpp new file mode 100644 index 0000000000..f30a3a9ee6 --- /dev/null +++ b/composable_kernel/include/utility/statically_indexed_array.hpp @@ -0,0 +1,40 @@ +#ifndef CK_STATICALLY_INDEXED_ARRAY_HPP +#define CK_STATICALLY_INDEXED_ARRAY_HPP + +#include "functional2.hpp" +#include "sequence.hpp" +#include "tuple.hpp" + +namespace ck { + +namespace detail { + +template +__host__ __device__ constexpr auto generate_same_type_tuple() +{ + return generate_tuple([](auto) -> T { return T{}; }, Number{}); +} + +template +using same_type_tuple = decltype(generate_same_type_tuple()); + +} // namespace detail + +template +using StaticallyIndexedArray = detail::same_type_tuple; + +template +__host__ __device__ constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs) +{ + return StaticallyIndexedArray(x, static_cast(xs)...); +} + +// make empty StaticallyIndexedArray +template +__host__ __device__ constexpr auto make_statically_indexed_array() +{ + return StaticallyIndexedArray(); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/statically_indexed_array_multi_index.hpp b/composable_kernel/include/utility/statically_indexed_array_multi_index.hpp new file mode 100644 index 0000000000..9e96f06d73 --- /dev/null +++ b/composable_kernel/include/utility/statically_indexed_array_multi_index.hpp @@ -0,0 +1,108 @@ +#ifndef CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP +#define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP + +#include "common_header.hpp" + +namespace ck { + +template +using MultiIndex = StaticallyIndexedArray; + +template +__host__ __device__ constexpr auto make_multi_index(Xs&&... 
xs) +{ + return make_statically_indexed_array(index_t{xs}...); +} + +template +__host__ __device__ constexpr auto make_zero_multi_index() +{ + return unpack([](auto... xs) { return make_multi_index(xs...); }, + typename uniform_sequence_gen::type{}); +} + +template +__host__ __device__ constexpr auto to_multi_index(const T& x) +{ + return unpack([](auto... ys) { return make_multi_index(ys...); }, x); +} + +// Here should use MultiIndex, instead of Tuple, although the former +// is the alias of the latter. This is because compiler cannot infer the NSize if +// using MultiIndex +// TODO: how to fix this? +template +__host__ __device__ constexpr auto operator+=(Tuple& y, const X& x) +{ + static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same"); + constexpr index_t NSize = sizeof...(Ys); + static_for<0, NSize, 1>{}([&](auto i) { y(i) += x[i]; }); + return y; +} + +template +__host__ __device__ constexpr auto operator-=(Tuple& y, const X& x) +{ + static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same"); + constexpr index_t NSize = sizeof...(Ys); + static_for<0, NSize, 1>{}([&](auto i) { y(i) -= x[i]; }); + return y; +} + +template +__host__ __device__ constexpr auto operator+(const Tuple& x, const Y& y) +{ + static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same"); + constexpr index_t NSize = sizeof...(Xs); + + Tuple r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = x[i] + y[i]; }); + return r; +} + +template +__host__ __device__ constexpr auto operator-(const Tuple& x, const Y& y) +{ + static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same"); + constexpr index_t NSize = sizeof...(Xs); + + Tuple r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = x[i] - y[i]; }); + return r; +} + +template +__host__ __device__ constexpr auto operator*(const Tuple& x, const Y& y) +{ + static_assert(Y::Size() == sizeof...(Xs), "wrong! 
size not the same"); + constexpr index_t NSize = sizeof...(Xs); + + Tuple r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = x[i] * y[i]; }); + return r; +} + +// MultiIndex = index_t * MultiIndex +template +__host__ __device__ constexpr auto operator*(index_t a, const Tuple& x) +{ + constexpr index_t NSize = sizeof...(Xs); + + Tuple r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = a * x[i]; }); + return r; +} + +template +__host__ __device__ void print_multi_index(const Tuple& x) +{ + printf("{"); + printf("MultiIndex, "); + printf("size %d,", index_t{sizeof...(Xs)}); + static_for<0, sizeof...(Xs), 1>{}( + [&](auto i) { printf("%d ", static_cast(x.At(i))); }); + printf("}"); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/synchronization.hpp b/composable_kernel/include/utility/synchronization.hpp new file mode 100644 index 0000000000..da74f2074d --- /dev/null +++ b/composable_kernel/include/utility/synchronization.hpp @@ -0,0 +1,21 @@ +#ifndef CK_SYNCHRONIZATION_AMD_HPP +#define CK_SYNCHRONIZATION_AMD_HPP + +#include "config.hpp" + +namespace ck { + +__device__ void block_sync_lds() +{ +#if CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM + asm volatile("\ + s_waitcnt lgkmcnt(0) \n \ + s_barrier \ + " ::); +#else + __syncthreads(); +#endif +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/tuple.hpp b/composable_kernel/include/utility/tuple.hpp new file mode 100644 index 0000000000..15b73011b4 --- /dev/null +++ b/composable_kernel/include/utility/tuple.hpp @@ -0,0 +1,167 @@ +#ifndef CK_TUPLE_HPP +#define CK_TUPLE_HPP + +#include "integral_constant.hpp" +#include "sequence.hpp" +#include "type.hpp" + +namespace ck { + +namespace detail { + +template +struct TupleElementKey +{ + __host__ __device__ constexpr TupleElementKey() = default; +}; + +template +struct TupleElement +{ + __host__ __device__ constexpr TupleElement() = default; + + template < + typename T, + typename std::enable_if>, TupleElement>::value, + bool>::type = false> + __host__ __device__ constexpr TupleElement(T&& v) : mData(std::forward(v)) + { + } + + Data mData; +}; + +template +__host__ __device__ constexpr const Data& get_tuple_element(const TupleElement& x) +{ + return static_cast(x.mData); +} + +template +__host__ __device__ constexpr Data& get_tuple_element(TupleElement& x) +{ + return x.mData; +} + +// TODO: not sure the use of reference is correct +template +__host__ __device__ constexpr Data&& get_tuple_element(TupleElement&& x) +{ + return static_cast(x.mData); +} + +template +struct TupleImpl; + +template +struct TupleImpl, Xs...> : TupleElement, Xs>... +{ + __host__ __device__ constexpr TupleImpl() = default; + + template < + typename Y, + typename std::enable_if>, TupleImpl>::value, + bool>::type = false> + __host__ __device__ constexpr TupleImpl(Y&& y) + : TupleElement, Xs>(std::forward(y))... + { + } + + template = 2, bool>::type = false> + __host__ __device__ constexpr TupleImpl(Ys&&... ys) + : TupleElement, Xs>(std::forward(ys))... + { + static_assert(sizeof...(Is) == sizeof...(Xs) && sizeof...(Is) == sizeof...(Ys), + "wrong! 
inconsistent size"); + } + + __host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); } + + template + __host__ __device__ constexpr const auto& GetElementByKey(TupleElementKey) const + { + return get_tuple_element>(*this); + } + + template + __host__ __device__ constexpr auto& GetElementByKey(TupleElementKey) + { + return get_tuple_element>(*this); + } +}; + +} // namespace detail + +template +struct Tuple : detail::TupleImpl::type, Xs...> +{ + using base = + detail::TupleImpl::type, Xs...>; + + __host__ __device__ constexpr Tuple() = default; + + template >, Tuple>::value, + bool>::type = false> + __host__ __device__ constexpr Tuple(Y&& y) : base(std::forward(y)) + { + } + + template = 2, + bool>::type = false> + __host__ __device__ constexpr Tuple(Ys&&... ys) : base(std::forward(ys)...) + { + } + + __host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); } + + template + __host__ __device__ constexpr const auto& At(Number) const + { + static_assert(I < base::Size(), "wrong! out of range"); + return base::GetElementByKey(detail::TupleElementKey{}); + } + + template + __host__ __device__ constexpr auto& At(Number) + { + static_assert(I < base::Size(), "wrong! out of range"); + return base::GetElementByKey(detail::TupleElementKey{}); + } + + template + __host__ __device__ constexpr const auto& operator[](Number i) const + { + return At(i); + } + + template + __host__ __device__ constexpr auto& operator()(Number i) + { + return At(i); + } + + template + __host__ __device__ constexpr auto operator=(const T& a) + { + static_assert(T::Size() == Size(), "wrong! size not the same"); + + static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; }); + + return *this; + } + + __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } +}; + +template +__host__ __device__ constexpr auto make_tuple(Xs&&... xs) +{ + return Tuple>...>(std::forward(xs)...); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/tuple_helper.hpp b/composable_kernel/include/utility/tuple_helper.hpp new file mode 100644 index 0000000000..9499a3596c --- /dev/null +++ b/composable_kernel/include/utility/tuple_helper.hpp @@ -0,0 +1,80 @@ +#ifndef CK_TUPLE_HELPER_HPP +#define CK_TUPLE_HELPER_HPP + +#include "functional4.hpp" +#include "tuple.hpp" + +namespace ck { + +template +struct is_known_at_compile_time> +{ + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return container_reduce( + Tuple{}, + [](auto x, bool r) { + return is_known_at_compile_time< + remove_cv_t>>::value & + r; + }, + true); + } + + static constexpr bool value = IsKnownAtCompileTime(); +}; + +template +__host__ __device__ constexpr auto generate_tuple(F&& f, Number) +{ + return unpack([&f](auto&&... 
xs) { return make_tuple(f(xs)...); }, + typename arithmetic_sequence_gen<0, N, 1>::type{}); +} + +namespace detail { + +template +__host__ __device__ constexpr auto transform_tuples_impl(F f, const X& x, Sequence) +{ + return make_tuple(f(x.At(Number{}))...); +} + +template +__host__ __device__ constexpr auto +transform_tuples_impl(F f, const X& x, const Y& y, Sequence) +{ + return make_tuple(f(x.At(Number{}), y.At(Number{}))...); +} + +template +__host__ __device__ constexpr auto +transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, Sequence) +{ + return make_tuple(f(x.At(Number{}), y.At(Number{}), z.At(Number{}))...); +} + +} // namespace detail + +template +__host__ __device__ constexpr auto transform_tuples(F f, const X& x) +{ + return detail::transform_tuples_impl( + f, x, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{}); +} + +template +__host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y) +{ + return detail::transform_tuples_impl( + f, x, y, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{}); +} + +template +__host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z) +{ + return detail::transform_tuples_impl( + f, x, y, z, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{}); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/type.hpp b/composable_kernel/include/utility/type.hpp new file mode 100644 index 0000000000..32f7dfb569 --- /dev/null +++ b/composable_kernel/include/utility/type.hpp @@ -0,0 +1,60 @@ +#ifndef CK_TYPE_HPP +#define CK_TYPE_HPP + +#include "integral_constant.hpp" + +namespace ck { + +template +struct is_same : public integral_constant +{ +}; + +template +struct is_same : public integral_constant +{ +}; + +template +using remove_reference_t = typename std::remove_reference::type; + +template +using remove_cv_t = typename std::remove_cv::type; + +template +constexpr std::remove_reference_t&& move(T&& t) noexcept +{ + return static_cast::type&&>(t); +} + +template +struct is_known_at_compile_time; + +template <> +struct is_known_at_compile_time +{ + static constexpr bool value = false; +}; + +template +struct is_known_at_compile_time> +{ + static constexpr bool value = true; +}; + +template ::type = false> +__host__ __device__ constexpr Y as_type(X x) +{ + union AsType + { + X x; + Y y; + }; + + return AsType{x}.y; +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/utility.hpp b/composable_kernel/include/utility/utility.hpp new file mode 100644 index 0000000000..9f34e044b7 --- /dev/null +++ b/composable_kernel/include/utility/utility.hpp @@ -0,0 +1,14 @@ +#ifndef CK_UTILITY_HPP +#define CK_UTILITY_HPP + +#include "config.hpp" + +namespace ck { + +__device__ index_t get_thread_local_1d_id() { return threadIdx.x; } + +__device__ index_t get_block_1d_id() { return blockIdx.x; } + +} // namespace ck + +#endif diff --git a/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp b/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp new file mode 100644 index 0000000000..652ccdb926 --- /dev/null +++ b/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp @@ -0,0 +1,374 @@ +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "gridwise_dynamic_gemm_dlops_v1r2.hpp" +#include 
"transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp" + +using namespace ck; + +constexpr DataTypeEnum_t ABDataTypeEnum = static_cast(CK_PARAM_ABDataTypeEnum); +constexpr DataTypeEnum_t AccDataTypeEnum = static_cast(CK_PARAM_AccDataTypeEnum); +constexpr DataTypeEnum_t CDataTypeEnum = static_cast(CK_PARAM_CDataTypeEnum); + +using FloatAB = typename get_datatype_from_enum::type; +using FloatAcc = typename get_datatype_from_enum::type; +using FloatC = typename get_datatype_from_enum::type; + +constexpr index_t BlockSize = CK_PARAM_BlockSize; + +constexpr index_t MPerBlock = CK_PARAM_MPerBlock; +constexpr index_t NPerBlock = CK_PARAM_NPerBlock; +constexpr index_t KPerBlock = CK_PARAM_KPerBlock; +constexpr index_t M1PerThread = CK_PARAM_M1PerThread; +constexpr index_t N1PerThread = CK_PARAM_N1PerThread; +constexpr index_t KPerThread = CK_PARAM_KPerThread; +constexpr index_t M1N1ThreadClusterM10 = CK_PARAM_M1N1ThreadClusterM10; +constexpr index_t M1N1ThreadClusterN10 = CK_PARAM_M1N1ThreadClusterN10; +constexpr index_t M1N1ThreadClusterM11 = CK_PARAM_M1N1ThreadClusterM11; +constexpr index_t M1N1ThreadClusterN11 = CK_PARAM_M1N1ThreadClusterN11; + +using ABlockTransferThreadSliceLengths_K_M0_M1 = + Sequence; +using ABlockTransferThreadClusterLengths_K_M0_M1 = + Sequence; +using ABlockTransferThreadClusterArrangeOrder = + Sequence; +using ABlockTransferSrcAccessOrder = Sequence; + +constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim; +constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector; +constexpr index_t ABlockTransferDstScalarPerVector_M1 = + CK_PARAM_ABlockTransferDstScalarPerVector_M1; +constexpr bool AThreadTransferSrcResetCoordinateAfterRun = + static_cast(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun); + +using BBlockTransferThreadSliceLengths_K_N0_N1 = + Sequence; +using BBlockTransferThreadClusterLengths_K_N0_N1 = + Sequence; +using BBlockTransferThreadClusterArrangeOrder = + Sequence; +using BBlockTransferSrcAccessOrder = Sequence; + +constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim; +constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector; +constexpr index_t BBlockTransferDstScalarPerVector_N1 = + CK_PARAM_BBlockTransferDstScalarPerVector_N1; +constexpr bool BThreadTransferSrcResetCoordinateAfterRun = + static_cast(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun); + +using CThreadTransferSrcDstAccessOrder = Sequence; +constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim; +constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector; + +constexpr bool HasMainKBlockLoop = static_cast(CK_PARAM_HAS_MAIN_KBLOCK_LOOP); +constexpr bool HasDoubleTailKBlockLoop = static_cast(CK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP); + +extern "C" __global__ void +dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw_prepare( + int n, + int c, + int hi, + int wi, + int k, + int y, + int x, + int convStrideH, + int convStrideW, + int convDilationY, + int convDilationX, + int leftPadH, + int leftPadW, + int rightPadH, + int rightPadW, + void* p_a_k_m0_m1_grid_desc, + void* p_b_k_n0_n1_grid_desc, + void* p_c_m0_m10_m11_n0_n10_n11_grid_desc, + void* p_c_blockid_to_m0_n0_block_cluster_adaptor) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 
1) - 1) / convStrideH + 1; + const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1; + + const auto in_n_c_hi_wi_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, c, hi, wi)); + const auto wei_k_c_y_x_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(k, c, y, x)); + const auto out_n_k_ho_wo_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, k, ho, wo)); + + const auto descs = transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad( + wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + make_tuple(convStrideH, convStrideW), + make_tuple(convDilationY, convDilationX), + make_tuple(leftPadH, leftPadW), + make_tuple(rightPadH, rightPadW)); + + const auto a_k_m_grid_desc = descs[I0]; + const auto b_k_n_grid_desc = descs[I1]; + const auto c_m_n_grid_desc = descs[I2]; + + using AKMGridDesc = decltype(a_k_m_grid_desc); + using BKNGridDesc = decltype(b_k_n_grid_desc); + using CMNGridDesc = decltype(c_m_n_grid_desc); + + using AGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}))); + + using BGridIteratorHacks = + decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}))); + + using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}))); + + using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>; + using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; + + using GridwiseGemm = + GridwiseDynamicGemmDlops_km_kn_mn_v1r2; + + auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc); + auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc); + auto c_m0_m10_m11_n0_n10_n11_grid_desc = + GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); + auto c_blockid_to_m0_n0_block_cluster_adaptor = + GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc); + + if(hipThreadIdx_x == 0) + { + *static_cast(p_a_k_m0_m1_grid_desc) = a_k_m0_m1_grid_desc; + *static_cast(p_b_k_n0_n1_grid_desc) = b_k_n0_n1_grid_desc; + *static_cast( + p_c_m0_m10_m11_n0_n10_n11_grid_desc) = c_m0_m10_m11_n0_n10_n11_grid_desc; + *static_cast( + p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor; + }; +}; + +extern "C" __global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const void CONSTANT* p_a_k_m0_m1_grid_desc, + const void CONSTANT* 
p_b_k_n0_n1_grid_desc, + const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc, + const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + constexpr auto in_n_c_hi_wi_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28)); + constexpr auto wei_k_c_y_x_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 3, 3)); + constexpr auto out_n_k_ho_wo_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28)); + + constexpr auto descs = + transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + make_tuple(1, 1), + make_tuple(1, 1), + make_tuple(1, 1), + make_tuple(1, 1)); + + constexpr auto a_k_m_grid_desc = descs[I0]; + constexpr auto b_k_n_grid_desc = descs[I1]; + constexpr auto c_m_n_grid_desc = descs[I2]; + + using AKMGridDesc = decltype(a_k_m_grid_desc); + using BKNGridDesc = decltype(b_k_n_grid_desc); + using CMNGridDesc = decltype(c_m_n_grid_desc); + + using AGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}))); + + using BGridIteratorHacks = + decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}))); + + using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}))); + + using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>; + using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; + + using GridwiseGemm = + GridwiseDynamicGemmDlops_km_kn_mn_v1r2; + + constexpr auto a_k_m0_m1_grid_desc_tmp = + GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc); + constexpr auto b_k_n0_n1_grid_desc_tmp = + GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc); + constexpr auto c_m0_m10_m11_n0_n10_n11_grid_desc_tmp = + GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); + constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp = + GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc); + + using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc_tmp); + using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc_tmp); + using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc_tmp); + using CBlockIdToM0N0BlockClusterAdaptor = + decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp); + + const auto a_k_m0_m1_grid_desc = + *reinterpret_cast((const void*)p_a_k_m0_m1_grid_desc); + const auto b_k_n0_n1_grid_desc = + *reinterpret_cast((const void*)p_b_k_n0_n1_grid_desc); + const auto c_m0_m10_m11_n0_n10_n11_grid_desc = + *reinterpret_cast( + (const 
void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc); + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + *reinterpret_cast( + (const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); + + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_k_m0_m1_grid_desc, + b_k_n0_n1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor, + integral_constant{}, + integral_constant{}); +}; diff --git a/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp b/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp new file mode 100644 index 0000000000..d33bc74aa6 --- /dev/null +++ b/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp @@ -0,0 +1,362 @@ +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "gridwise_dynamic_gemm_xdlops_v2r3.hpp" +#include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp" + +using namespace ck; + +constexpr DataTypeEnum_t ABDataTypeEnum = static_cast(CK_PARAM_ABDataTypeEnum); +constexpr DataTypeEnum_t AccDataTypeEnum = static_cast(CK_PARAM_AccDataTypeEnum); +constexpr DataTypeEnum_t CDataTypeEnum = static_cast(CK_PARAM_CDataTypeEnum); + +using FloatAB = typename get_datatype_from_enum::type; +using FloatAcc = typename get_datatype_from_enum::type; +using FloatC = typename get_datatype_from_enum::type; + +constexpr index_t BlockSize = CK_PARAM_BlockSize; + +constexpr index_t MPerBlock = CK_PARAM_MPerBlock; +constexpr index_t NPerBlock = CK_PARAM_NPerBlock; +constexpr index_t KPerBlock = CK_PARAM_KPerBlock; + +constexpr index_t MPerWave = CK_PARAM_MPerWave; +constexpr index_t NPerWave = CK_PARAM_NPerWave; +constexpr index_t MRepeat = CK_PARAM_MRepeat; +constexpr index_t NRepeat = CK_PARAM_NRepeat; +constexpr index_t K1 = CK_PARAM_K1; + +using ABlockTransferThreadSliceLengths_K0_M_K1 = + Sequence; +using ABlockTransferThreadClusterLengths_K0_M_K1 = + Sequence; +using ABlockTransferThreadClusterArrangeOrder = + Sequence; +using ABlockTransferSrcAccessOrder = Sequence; + +constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim; +constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector; +constexpr index_t ABlockTransferDstScalarPerVector_K1 = + CK_PARAM_ABlockTransferDstScalarPerVector_K1; +constexpr bool AThreadTransferSrcResetCoordinateAfterRun = + static_cast(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun); + +using BBlockTransferThreadSliceLengths_K0_N_K1 = + Sequence; +using BBlockTransferThreadClusterLengths_K0_N_K1 = + Sequence; +using BBlockTransferThreadClusterArrangeOrder = + Sequence; +using BBlockTransferSrcAccessOrder = Sequence; + +constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim; +constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector; +constexpr index_t BBlockTransferDstScalarPerVector_K1 = + CK_PARAM_BBlockTransferDstScalarPerVector_K1; +constexpr bool BThreadTransferSrcResetCoordinateAfterRun = + static_cast(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun); + +using CThreadTransferSrcDstAccessOrder = Sequence; +constexpr index_t 
CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim; +constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector; + +extern "C" __global__ void +dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_prepare( + int n, + int c, + int hi, + int wi, + int k, + int y, + int x, + int convStrideH, + int convStrideW, + int convDilationY, + int convDilationX, + int leftPadH, + int leftPadW, + int rightPadH, + int rightPadW, + void* p_a_k0_m_k1_grid_desc, + void* p_b_k0_n_k1_grid_desc, + void* p_c_m0_m1_m2_n_grid_desc, + void* p_c_blockid_to_m0_n0_block_cluster_adaptor) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1; + const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1; + + const auto in_n_c_hi_wi_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, c, hi, wi)); + const auto wei_k_c_y_x_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(k, c, y, x)); + const auto out_n_k_ho_wo_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, k, ho, wo)); + + const auto descs = transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( + wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + make_tuple(convStrideH, convStrideW), + make_tuple(convDilationY, convDilationX), + make_tuple(leftPadH, leftPadW), + make_tuple(rightPadH, rightPadW), + Number{}); + + const auto a_k0_m_k1_grid_desc = descs[I0]; + const auto b_k0_n_k1_grid_desc = descs[I1]; + const auto c_m_n_grid_desc = descs[I2]; + + using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc); + using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc); + using CMNGridDesc = decltype(c_m_n_grid_desc); + + using AGridIteratorHacks = decltype(make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), + make_tuple( + Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); + + using BGridIteratorHacks = + decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); + + using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}))); + + using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>; + using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; + + using GridwiseGemm = + GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3; + + auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); + + auto c_blockid_to_m0_n0_block_cluster_adaptor = + GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); + + if(hipThreadIdx_x == 0) + { + 
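+        // Every thread computes the same descriptors above; only thread 0 stores
+        // them into the caller-provided buffers, and the main kernel below reloads
+        // them through its `const void CONSTANT*` arguments.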
*static_cast*>(p_a_k0_m_k1_grid_desc) = + a_k0_m_k1_grid_desc; + *static_cast*>(p_b_k0_n_k1_grid_desc) = + b_k0_n_k1_grid_desc; + *static_cast(p_c_m0_m1_m2_n_grid_desc) = + c_m0_m1_m2_n_grid_desc; + *static_cast( + p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor; + } +}; + +extern "C" __global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const void CONSTANT* p_a_k0_m_k1_grid_desc, + const void CONSTANT* p_b_k0_n_k1_grid_desc, + const void CONSTANT* p_c_m0_m1_m2_n_grid_desc, + const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) +{ + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + constexpr auto in_n_c_hi_wi_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28)); + constexpr auto wei_k_c_y_x_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 3, 3)); + constexpr auto out_n_k_ho_wo_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28)); + + constexpr auto descs = + transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + make_tuple(1, 1), + make_tuple(1, 1), + make_tuple(1, 1), + make_tuple(1, 1), + Number{}); + + constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0]; + constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1]; + constexpr auto c_m_n_grid_desc = descs[I2]; + + using AGridIteratorHacks = decltype(make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), + make_tuple( + Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); + + using BGridIteratorHacks = + decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); + + using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}))); + + using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>; + using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; + + using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp); + using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp); + using CMNGridDesc = decltype(c_m_n_grid_desc); + + using GridwiseGemm = + GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3; + + constexpr auto c_m0_m1_m2_n_grid_desc_tmp = + GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); + constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp = + GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); + + using CM0M1M2NGridDesc = 
decltype(c_m0_m1_m2_n_grid_desc_tmp); + using CBlockIdToM0N0BlockClusterAdaptor = + decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp); + + const auto a_k0_m_k1_grid_desc = + *reinterpret_cast((const void*)p_a_k0_m_k1_grid_desc); + const auto b_k0_n_k1_grid_desc = + *reinterpret_cast((const void*)p_b_k0_n_k1_grid_desc); + const auto c_m0_m1_m2_n_grid_desc = + *reinterpret_cast((const void*)p_c_m0_m1_m2_n_grid_desc); + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + *reinterpret_cast( + (const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); + + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m0_m1_m2_n_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); +}; diff --git a/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp b/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp new file mode 100644 index 0000000000..d49693b511 --- /dev/null +++ b/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp @@ -0,0 +1,362 @@ +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "gridwise_dynamic_gemm_xdlops_v2r3.hpp" +#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp" + +using namespace ck; + +constexpr DataTypeEnum_t ABDataTypeEnum = static_cast(CK_PARAM_ABDataTypeEnum); +constexpr DataTypeEnum_t AccDataTypeEnum = static_cast(CK_PARAM_AccDataTypeEnum); +constexpr DataTypeEnum_t CDataTypeEnum = static_cast(CK_PARAM_CDataTypeEnum); + +using FloatAB = typename get_datatype_from_enum::type; +using FloatAcc = typename get_datatype_from_enum::type; +using FloatC = typename get_datatype_from_enum::type; + +constexpr index_t BlockSize = CK_PARAM_BlockSize; + +constexpr index_t MPerBlock = CK_PARAM_MPerBlock; +constexpr index_t NPerBlock = CK_PARAM_NPerBlock; +constexpr index_t KPerBlock = CK_PARAM_KPerBlock; + +constexpr index_t MPerWave = CK_PARAM_MPerWave; +constexpr index_t NPerWave = CK_PARAM_NPerWave; +constexpr index_t MRepeat = CK_PARAM_MRepeat; +constexpr index_t NRepeat = CK_PARAM_NRepeat; +constexpr index_t K1 = CK_PARAM_K1; + +using ABlockTransferThreadSliceLengths_K0_M_K1 = + Sequence; +using ABlockTransferThreadClusterLengths_K0_M_K1 = + Sequence; +using ABlockTransferThreadClusterArrangeOrder = + Sequence; +using ABlockTransferSrcAccessOrder = Sequence; + +constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim; +constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector; +constexpr index_t ABlockTransferDstScalarPerVector_K1 = + CK_PARAM_ABlockTransferDstScalarPerVector_K1; +constexpr bool AThreadTransferSrcResetCoordinateAfterRun = + static_cast(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun); + +using BBlockTransferThreadSliceLengths_K0_N_K1 = + Sequence; +using BBlockTransferThreadClusterLengths_K0_N_K1 = + Sequence; +using BBlockTransferThreadClusterArrangeOrder = + Sequence; +using BBlockTransferSrcAccessOrder = Sequence; + +constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim; +constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector; 
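+// All CK_PARAM_* macros in this translation unit are expected to be supplied as
+// compile-time definitions by the online-compilation driver (for example
+// -D CK_PARAM_BlockSize=256 -D CK_PARAM_KPerBlock=8; these values are
+// illustrative only). Each macro is captured as a constexpr value or Sequence
+// alias so that the tuning parameters become template arguments of the
+// GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 instantiation below.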
+constexpr index_t BBlockTransferDstScalarPerVector_K1 = + CK_PARAM_BBlockTransferDstScalarPerVector_K1; +constexpr bool BThreadTransferSrcResetCoordinateAfterRun = + static_cast(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun); + +using CThreadTransferSrcDstAccessOrder = Sequence; +constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim; +constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector; + +extern "C" __global__ void +dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare( + int n, + int hi, + int wi, + int c, + int k, + int y, + int x, + int convStrideH, + int convStrideW, + int convDilationY, + int convDilationX, + int leftPadH, + int leftPadW, + int rightPadH, + int rightPadW, + void* p_a_k0_m_k1_grid_desc, + void* p_b_k0_n_k1_grid_desc, + void* p_c_m0_m1_m2_n_grid_desc, + void* p_c_blockid_to_m0_n0_block_cluster_adaptor) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1; + const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1; + + const auto in_n_hi_wi_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, hi, wi, c)); + const auto wei_k_y_x_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(k, y, x, c)); + const auto out_n_ho_wo_k_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, ho, wo, k)); + + const auto descs = transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( + in_n_hi_wi_c_desc, + wei_k_y_x_c_desc, + out_n_ho_wo_k_desc, + make_tuple(convStrideH, convStrideW), + make_tuple(convDilationY, convDilationX), + make_tuple(leftPadH, leftPadW), + make_tuple(rightPadH, rightPadW), + Number{}); + + const auto a_k0_m_k1_grid_desc = descs[I0]; + const auto b_k0_n_k1_grid_desc = descs[I1]; + const auto c_m_n_grid_desc = descs[I2]; + + using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc); + using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc); + using CMNGridDesc = decltype(c_m_n_grid_desc); + + using BGridIteratorHacks = decltype(make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), + make_tuple( + Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); + + using AGridIteratorHacks = + decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); + + using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}))); + + using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; + using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>; + + using 
GridwiseGemm = + GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3; + + auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); + + auto c_blockid_to_m0_n0_block_cluster_adaptor = + GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); + + if(hipThreadIdx_x == 0) + { + *static_cast*>(p_a_k0_m_k1_grid_desc) = + a_k0_m_k1_grid_desc; + *static_cast*>(p_b_k0_n_k1_grid_desc) = + b_k0_n_k1_grid_desc; + *static_cast(p_c_m0_m1_m2_n_grid_desc) = + c_m0_m1_m2_n_grid_desc; + *static_cast( + p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor; + } +}; + +extern "C" __global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const void CONSTANT* p_a_k0_m_k1_grid_desc, + const void CONSTANT* p_b_k0_n_k1_grid_desc, + const void CONSTANT* p_c_m0_m1_m2_n_grid_desc, + const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) +{ + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto in_n_hi_wi_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 28, 28, 256)); + constexpr auto wei_k_y_x_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 3, 3, 256)); + constexpr auto out_n_ho_wo_k_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 28, 28, 256)); + + constexpr auto descs = + transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc, + wei_k_y_x_c_desc, + out_n_ho_wo_k_desc, + make_tuple(1, 1), + make_tuple(1, 1), + make_tuple(1, 1), + make_tuple(1, 1), + Number{}); + + constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0]; + constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1]; + constexpr auto c_m_n_grid_desc = descs[I2]; + + using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp); + using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp); + using CMNGridDesc = decltype(c_m_n_grid_desc); + + using BGridIteratorHacks = decltype(make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), + make_tuple( + Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); + + using AGridIteratorHacks = + decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); + + using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}))); + + using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; + using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>; 
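+    // Note: the constexpr descriptors built above from the fixed 256 x 28 x 28 x 256
+    // NHWC problem (3 x 3 filter, stride 1, dilation 1, pad 1, so the output stays
+    // 28 x 28: (28 + 1 + 1 - 2 - 1) / 1 + 1 = 28) are used only to derive the
+    // descriptor and adaptor *types*; the descriptor *values* actually used at run
+    // time are the ones computed by the *_prepare kernel and reloaded below from
+    // the CONSTANT pointer arguments.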
+ + using GridwiseGemm = + GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3; + constexpr auto c_m0_m1_m2_n_grid_desc_tmp = + GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); + constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp = + GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); + + using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp); + using CBlockIdToM0N0BlockClusterAdaptor = + decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp); + + const auto a_k0_m_k1_grid_desc = + *reinterpret_cast((const void*)p_a_k0_m_k1_grid_desc); + const auto b_k0_n_k1_grid_desc = + *reinterpret_cast((const void*)p_b_k0_n_k1_grid_desc); + const auto c_m0_m1_m2_n_grid_desc = + *reinterpret_cast((const void*)p_c_m0_m1_m2_n_grid_desc); + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + *reinterpret_cast( + (const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); + + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m0_m1_m2_n_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); +}; diff --git a/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp b/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp new file mode 100644 index 0000000000..90c957bb0b --- /dev/null +++ b/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp @@ -0,0 +1,392 @@ +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "gridwise_dynamic_contraction_dlops_v1r2.hpp" +#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp" + +using namespace ck; + +constexpr DataTypeEnum_t ABDataTypeEnum = static_cast(CK_PARAM_ABDataTypeEnum); +constexpr DataTypeEnum_t AccDataTypeEnum = static_cast(CK_PARAM_AccDataTypeEnum); +constexpr DataTypeEnum_t CDataTypeEnum = static_cast(CK_PARAM_CDataTypeEnum); + +using FloatAB = typename get_datatype_from_enum::type; +using FloatAcc = typename get_datatype_from_enum::type; +using FloatC = typename get_datatype_from_enum::type; + +constexpr index_t BlockSize = CK_PARAM_BlockSize; + +constexpr auto GN0 = Number{}; +constexpr auto GK1 = Number{}; + +constexpr index_t GM1PerBlockGM11 = CK_PARAM_GM1PerBlockGM11; +constexpr index_t GN1PerBlockGN11 = CK_PARAM_GN1PerBlockGN11; +constexpr index_t GK0PerBlock = CK_PARAM_GK0PerBlock; + +constexpr index_t BM1PerThreadBM11 = CK_PARAM_BM1PerThreadBM11; +constexpr index_t BN1PerThreadBN11 = CK_PARAM_BN1PerThreadBN11; +constexpr index_t BK0PerThread = CK_PARAM_BK0PerThread; + +using BM10BN10ThreadClusterBM10Xs = Sequence; +using BM10BN10ThreadClusterBN10Xs = Sequence; + +using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = + Sequence; +using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = + Sequence; +using ABlockTransferThreadClusterArrangeOrder = Sequence<1, 2, 3, 0, 4>; +using ABlockTransferSrcAccessOrder = Sequence<3, 2, 1, 0, 4>; +using ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = + Sequence; +using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = + Sequence; +using ABlockTransferSrcVectorTensorContiguousDimOrder = Sequence<0, 1, 2, 3, 4>; + +using 
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = + Sequence; +using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = + Sequence; +using BBlockTransferThreadClusterArrangeOrder = Sequence<0, 4, 1, 2, 3>; +using BBlockTransferSrcAccessOrder = Sequence<4, 3, 2, 0, 1>; +using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = + Sequence; +using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = + Sequence; +using BBlockTransferSrcVectorTensorContiguousDimOrder = Sequence<0, 1, 2, 3, 4>; + +using CThreadTransferSrcDstAccessOrder = Sequence<3, 4, 5, 0, 1, 2>; +constexpr index_t CThreadTransferSrcDstVectorDim = 5; +constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector; + +constexpr bool HasMainKBlockLoop = static_cast(CK_PARAM_HasMainKBlockLoop); +constexpr bool HasDoubleTailKBlockLoop = static_cast(CK_PARAM_HasDoubleTailKBlockLoop); + +extern "C" __global__ void +dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N, + index_t C, + index_t Hi, + index_t Wi, + index_t K, + index_t Y, + index_t X, + index_t ConvStrideH, + index_t ConvStrideW, + index_t ConvDilationH, + index_t ConvDilationW, + index_t InLeftPadH, + index_t InLeftPadW, + index_t InRightPadH, + index_t InRightPadW, + void* p_desc_tuple) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + const index_t Ho = + (Hi + InLeftPadH + InRightPadH - ConvDilationH * (Y - 1) - 1) / ConvStrideH + 1; + const index_t Wo = + (Wi + InLeftPadW + InRightPadW - ConvDilationW * (X - 1) - 1) / ConvStrideW + 1; + + const auto in_n_c_hi_wi_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, C, Hi, Wi)); + const auto wei_k_c_y_x_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C, Y, X)); + const auto out_n_k_ho_wo_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho, Wo)); + + const auto descs = transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad( + wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + make_tuple(ConvStrideH, ConvStrideW), + make_tuple(ConvDilationH, ConvDilationW), + make_tuple(InLeftPadH, InLeftPadW), + make_tuple(InRightPadH, InRightPadW), + GN0, + GK1); + + const auto a_grid_desc_gk0_gm0_gm1_gk1 = descs[I0]; + const auto b_grid_desc_gk0_gn0_gn1_gk1 = descs[I1]; + const auto c_grid_desc_gm0_gm1_gn0_gn1 = descs[I2]; + + using AGridDesc_GK0_GM0_GM1_GK1 = decltype(a_grid_desc_gk0_gm0_gm1_gk1); + using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1); + using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1); + + using AGridIteratorHacks = + decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3+: GM11 + Sequence<0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1-: GM0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11 + Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1 + + using BGridIteratorHacks = decltype(make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10 + Sequence<0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 3+: GN11 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GN10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1 + + using CGridIteratorHacks = decltype(make_tuple( + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 2+: BM1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: GN10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 4+: BN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 5+: GN1 + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 1-: BM0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 2-: BM1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: GN10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1 + + using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0>; + + using BGridMoveSliceWindowIteratorHacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>; + + using GridwiseContraction = + GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1< + BlockSize, + FloatAB, + FloatAcc, + FloatC, + InMemoryDataOperationEnum_t::Set, + AGridDesc_GK0_GM0_GM1_GK1, + BGridDesc_GK0_GN0_GN1_GK1, + CGridDesc_GM0_GM1_GN0_GN1, + GM1PerBlockGM11, + GN1PerBlockGN11, + GK0PerBlock, + BM1PerThreadBM11, + BN1PerThreadBN11, + BK0PerThread, + BM10BN10ThreadClusterBM10Xs, + BM10BN10ThreadClusterBN10Xs, + ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferSrcVectorTensorContiguousDimOrder, + BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferSrcVectorTensorContiguousDimOrder, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + AGridIteratorHacks, + BGridIteratorHacks, + CGridIteratorHacks, + AGridMoveSliceWindowIteratorHacks, + BGridMoveSliceWindowIteratorHacks>; + + if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0) + { + auto desc_tuple = + make_tuple(GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1( + a_grid_desc_gk0_gm0_gm1_gk1), + GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1( + 
b_grid_desc_gk0_gn0_gn1_gk1), + GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1( + c_grid_desc_gm0_gm1_gn0_gn1), + GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10( + c_grid_desc_gm0_gm1_gn0_gn1)); + + *static_cast(p_desc_tuple) = desc_tuple; + } +}; + +extern "C" __global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const void CONSTANT* p_desc_tuple) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto in_n_c_hi_wi_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28)); + constexpr auto wei_k_c_y_x_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 3, 3)); + constexpr auto out_n_k_ho_wo_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28)); + + constexpr auto descs = + transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + make_tuple(1, 1), + make_tuple(1, 1), + make_tuple(1, 1), + make_tuple(1, 1), + GN0, + GK1); + + constexpr auto a_grid_desc_gk0_gm0_gm1_gk1 = descs[I0]; + constexpr auto b_grid_desc_gk0_gn0_gn1_gk1 = descs[I1]; + constexpr auto c_grid_desc_gm0_gm1_gn0_gn1 = descs[I2]; + + using AGridDesc_GK0_GM0_GM1_GK1 = decltype(a_grid_desc_gk0_gm0_gm1_gk1); + using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1); + using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1); + + using AGridIteratorHacks = + decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3+: GM11 + Sequence<0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1-: GM0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11 + Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1 + + using BGridIteratorHacks = decltype(make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 3+: GN11 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GN10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1 + + using CGridIteratorHacks = decltype(make_tuple( + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 2+: BM1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0>{}, // 3+: GN10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 4+: BN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 5+: GN1 + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 1-: BM0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 2-: BM1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: GN10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1 + + using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0>; + + using BGridMoveSliceWindowIteratorHacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>; + + using GridwiseContraction = + GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1< + BlockSize, + FloatAB, + FloatAcc, + FloatC, + InMemoryDataOperationEnum_t::Set, + AGridDesc_GK0_GM0_GM1_GK1, + BGridDesc_GK0_GN0_GN1_GK1, + CGridDesc_GM0_GM1_GN0_GN1, + GM1PerBlockGM11, + GN1PerBlockGN11, + GK0PerBlock, + BM1PerThreadBM11, + BN1PerThreadBN11, + BK0PerThread, + BM10BN10ThreadClusterBM10Xs, + BM10BN10ThreadClusterBN10Xs, + ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferSrcVectorTensorContiguousDimOrder, + BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferSrcVectorTensorContiguousDimOrder, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + AGridIteratorHacks, + BGridIteratorHacks, + CGridIteratorHacks, + AGridMoveSliceWindowIteratorHacks, + BGridMoveSliceWindowIteratorHacks>; + + using AGridDesc_GK0_GM0_GM10_GM11_GK1 = + decltype(GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1( + a_grid_desc_gk0_gm0_gm1_gk1)); + using BGridDesc_GK0_GN0_GN10_GN11_GK1 = + decltype(GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1( + b_grid_desc_gk0_gn0_gn1_gk1)); + using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 = + decltype(GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1( + c_grid_desc_gm0_gm1_gn0_gn1)); + using CGridBlockCluster_BlockId_To_GM10_GN10 = + decltype(GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10( + c_grid_desc_gm0_gm1_gn0_gn1)); + + using DescTuple = decltype(make_tuple(AGridDesc_GK0_GM0_GM10_GM11_GK1{}, + BGridDesc_GK0_GN0_GN10_GN11_GK1{}, + CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1{}, + CGridBlockCluster_BlockId_To_GM10_GN10{})); + + const auto desc_tuple = *reinterpret_cast( +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wold-style-cast" + // TODO: how to cast? 
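        // The tuple of transformed tensor descriptors was written to device memory by the
        // *_prepare kernel above; p_desc_tuple is passed through the CONSTANT address space,
        // so it is first stripped to a generic const void* (old-style cast, hence the
        // suppressed warning) and then reinterpreted as a pointer to the statically typed
        // DescTuple before being dereferenced.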
+ (const void*)p_desc_tuple +#pragma clang diagnostic pop + ); + + const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = desc_tuple[I0]; + const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = desc_tuple[I1]; + const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = desc_tuple[I2]; + const auto c_grid_block_cluster_blockid_to_gm10_gn10 = desc_tuple[I3]; + + constexpr index_t shared_block_size = + GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseContraction::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_grid_desc_gk0_gm0_gm10_gm11_gk1, + b_grid_desc_gk0_gn0_gn10_gn11_gk1, + c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + c_grid_block_cluster_blockid_to_gm10_gn10, + integral_constant{}, + integral_constant{}); +}; diff --git a/external/half/include/half.hpp b/external/half/include/half.hpp new file mode 100644 index 0000000000..b698aac39f --- /dev/null +++ b/external/half/include/half.hpp @@ -0,0 +1,5671 @@ +// half - IEEE 754-based half-precision floating-point library. +// +// Copyright (c) 2012-2019 Christian Rau +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and +// associated documentation +// files (the "Software"), to deal in the Software without restriction, including without limitation +// the rights to use, copy, +// modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit +// persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or +// substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT +// NOT LIMITED TO THE +// WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +// SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF +// CONTRACT, TORT OR OTHERWISE, +// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// Version 2.1.0 + +/// \file +/// Main header file for half-precision functionality. 
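// A minimal sketch (illustrative only; the names below are not part of half.hpp, and a
// C++11 compiler is assumed): the bit manipulations throughout this header rely on the
// IEEE 754 binary16 layout -- 1 sign bit, 5 exponent bits with bias 15, 10 mantissa bits.
namespace half_layout_sketch {
constexpr unsigned sign_bit(unsigned h) { return (h >> 15) & 0x1; }  // 1 bit
constexpr unsigned exponent(unsigned h) { return (h >> 10) & 0x1F; } // 5 bits, bias 15
constexpr unsigned mantissa(unsigned h) { return h & 0x3FF; }        // 10 bits
// 0x3C00 encodes 1.0 (exponent 15, empty mantissa); 0x7C00 encodes +infinity; any value
// whose low 15 bits exceed 0x7C00 is a NaN, and bit 0x200 marks it as quiet.
static_assert(exponent(0x3C00) == 15 && mantissa(0x3C00) == 0, "0x3C00 is 1.0");
static_assert(exponent(0x7C00) == 31 && mantissa(0x7C00) == 0, "0x7C00 is +inf");
static_assert((0x7E00 & 0x7FFF) > 0x7C00 && (0x7E00 & 0x200) != 0, "0x7E00 is a quiet NaN");
} // namespace half_layout_sketch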
+ +#ifndef HALF_HALF_HPP +#define HALF_HALF_HPP + +#define HALF_GCC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__) + +#if defined(__INTEL_COMPILER) +#define HALF_ICC_VERSION __INTEL_COMPILER +#elif defined(__ICC) +#define HALF_ICC_VERSION __ICC +#elif defined(__ICL) +#define HALF_ICC_VERSION __ICL +#else +#define HALF_ICC_VERSION 0 +#endif + +// check C++11 language features +#if defined(__clang__) // clang +#if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) +#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 +#endif +#if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR) +#define HALF_ENABLE_CPP11_CONSTEXPR 1 +#endif +#if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT) +#define HALF_ENABLE_CPP11_NOEXCEPT 1 +#endif +#if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS) +#define HALF_ENABLE_CPP11_USER_LITERALS 1 +#endif +#if __has_feature(cxx_thread_local) && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL) +#define HALF_ENABLE_CPP11_THREAD_LOCAL 1 +#endif +#if(defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && \ + !defined(HALF_ENABLE_CPP11_LONG_LONG) +#define HALF_ENABLE_CPP11_LONG_LONG 1 +#endif +#elif HALF_ICC_VERSION && defined(__INTEL_CXX11_MODE__) // Intel C++ +#if HALF_ICC_VERSION >= 1500 && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL) +#define HALF_ENABLE_CPP11_THREAD_LOCAL 1 +#endif +#if HALF_ICC_VERSION >= 1500 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) +#define HALF_ENABLE_CPP11_USER_LITERALS 1 +#endif +#if HALF_ICC_VERSION >= 1400 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) +#define HALF_ENABLE_CPP11_CONSTEXPR 1 +#endif +#if HALF_ICC_VERSION >= 1400 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) +#define HALF_ENABLE_CPP11_NOEXCEPT 1 +#endif +#if HALF_ICC_VERSION >= 1110 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) +#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 +#endif +#if HALF_ICC_VERSION >= 1110 && !defined(HALF_ENABLE_CPP11_LONG_LONG) +#define HALF_ENABLE_CPP11_LONG_LONG 1 +#endif +#elif defined(__GNUC__) // gcc +#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L +#if HALF_GCC_VERSION >= 408 && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL) +#define HALF_ENABLE_CPP11_THREAD_LOCAL 1 +#endif +#if HALF_GCC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) +#define HALF_ENABLE_CPP11_USER_LITERALS 1 +#endif +#if HALF_GCC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) +#define HALF_ENABLE_CPP11_CONSTEXPR 1 +#endif +#if HALF_GCC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) +#define HALF_ENABLE_CPP11_NOEXCEPT 1 +#endif +#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) +#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 +#endif +#if !defined(HALF_ENABLE_CPP11_LONG_LONG) +#define HALF_ENABLE_CPP11_LONG_LONG 1 +#endif +#endif +#define HALF_TWOS_COMPLEMENT_INT 1 +#elif defined(_MSC_VER) // Visual C++ +#if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL) +#define HALF_ENABLE_CPP11_THREAD_LOCAL 1 +#endif +#if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) +#define HALF_ENABLE_CPP11_USER_LITERALS 1 +#endif +#if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) +#define HALF_ENABLE_CPP11_CONSTEXPR 1 +#endif +#if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) +#define HALF_ENABLE_CPP11_NOEXCEPT 1 +#endif +#if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) +#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 +#endif +#if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG) +#define HALF_ENABLE_CPP11_LONG_LONG 1 
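// Note (illustrative): every HALF_ENABLE_CPP11_* switch in this detection block is guarded
// by !defined(...), so it can be predefined before including the header to override the
// auto-detection, e.g. for a compiler that is not recognized here:
//   #define HALF_ENABLE_CPP11_CONSTEXPR 1
//   #define HALF_ENABLE_CPP11_NOEXCEPT 1
//   #include "half.hpp"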
+#endif +#define HALF_TWOS_COMPLEMENT_INT 1 +#define HALF_POP_WARNINGS 1 +#pragma warning(push) +#pragma warning(disable : 4099 4127 4146) // struct vs class, constant in if, negative unsigned +#endif + +// check C++11 library features +#include +#if defined(_LIBCPP_VERSION) // libc++ +#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 +#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS +#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 +#endif +#ifndef HALF_ENABLE_CPP11_CSTDINT +#define HALF_ENABLE_CPP11_CSTDINT 1 +#endif +#ifndef HALF_ENABLE_CPP11_CMATH +#define HALF_ENABLE_CPP11_CMATH 1 +#endif +#ifndef HALF_ENABLE_CPP11_HASH +#define HALF_ENABLE_CPP11_HASH 1 +#endif +#ifndef HALF_ENABLE_CPP11_CFENV +#define HALF_ENABLE_CPP11_CFENV 1 +#endif +#endif +#elif defined(__GLIBCXX__) // libstdc++ +#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 +#ifdef __clang__ +#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) +#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 +#endif +#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT) +#define HALF_ENABLE_CPP11_CSTDINT 1 +#endif +#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH) +#define HALF_ENABLE_CPP11_CMATH 1 +#endif +#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH) +#define HALF_ENABLE_CPP11_HASH 1 +#endif +#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CFENV) +#define HALF_ENABLE_CPP11_CFENV 1 +#endif +#else +#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) +#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 +#endif +#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT) +#define HALF_ENABLE_CPP11_CSTDINT 1 +#endif +#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH) +#define HALF_ENABLE_CPP11_CMATH 1 +#endif +#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH) +#define HALF_ENABLE_CPP11_HASH 1 +#endif +#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CFENV) +#define HALF_ENABLE_CPP11_CFENV 1 +#endif +#endif +#endif +#elif defined(_CPPLIB_VER) // Dinkumware/Visual C++ +#if _CPPLIB_VER >= 520 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) +#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 +#endif +#if _CPPLIB_VER >= 520 && !defined(HALF_ENABLE_CPP11_CSTDINT) +#define HALF_ENABLE_CPP11_CSTDINT 1 +#endif +#if _CPPLIB_VER >= 520 && !defined(HALF_ENABLE_CPP11_HASH) +#define HALF_ENABLE_CPP11_HASH 1 +#endif +#if _CPPLIB_VER >= 610 && !defined(HALF_ENABLE_CPP11_CMATH) +#define HALF_ENABLE_CPP11_CMATH 1 +#endif +#if _CPPLIB_VER >= 610 && !defined(HALF_ENABLE_CPP11_CFENV) +#define HALF_ENABLE_CPP11_CFENV 1 +#endif +#endif +#undef HALF_GCC_VERSION +#undef HALF_ICC_VERSION + +// any error throwing C++ exceptions? +#if defined(HALF_ERRHANDLING_THROW_INVALID) || defined(HALF_ERRHANDLING_THROW_DIVBYZERO) || \ + defined(HALF_ERRHANDLING_THROW_OVERFLOW) || defined(HALF_ERRHANDLING_THROW_UNDERFLOW) || \ + defined(HALF_ERRHANDLING_THROW_INEXACT) +#define HALF_ERRHANDLING_THROWS 1 +#endif + +// any error handling enabled? 
+#define HALF_ERRHANDLING \ + (HALF_ERRHANDLING_FLAGS || HALF_ERRHANDLING_ERRNO || HALF_ERRHANDLING_FENV || \ + HALF_ERRHANDLING_THROWS) + +#if HALF_ERRHANDLING +#define HALF_UNUSED_NOERR(name) name +#else +#define HALF_UNUSED_NOERR(name) +#endif + +// support constexpr +#if HALF_ENABLE_CPP11_CONSTEXPR +#define HALF_CONSTEXPR constexpr +#define HALF_CONSTEXPR_CONST constexpr +#if HALF_ERRHANDLING +#define HALF_CONSTEXPR_NOERR +#else +#define HALF_CONSTEXPR_NOERR constexpr +#endif +#else +#define HALF_CONSTEXPR +#define HALF_CONSTEXPR_CONST const +#define HALF_CONSTEXPR_NOERR +#endif + +// support noexcept +#if HALF_ENABLE_CPP11_NOEXCEPT +#define HALF_NOEXCEPT noexcept +#define HALF_NOTHROW noexcept +#else +#define HALF_NOEXCEPT +#define HALF_NOTHROW throw() +#endif + +// support thread storage +#if HALF_ENABLE_CPP11_THREAD_LOCAL +#define HALF_THREAD_LOCAL thread_local +#else +#define HALF_THREAD_LOCAL static +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if HALF_ENABLE_CPP11_TYPE_TRAITS +#include +#endif +#if HALF_ENABLE_CPP11_CSTDINT +#include +#endif +#if HALF_ERRHANDLING_ERRNO +#include +#endif +#if HALF_ENABLE_CPP11_CFENV +#include +#endif +#if HALF_ENABLE_CPP11_HASH +#include +#endif +#if HALF_ENABLE_F16C_INTRINSICS +#include +#endif + +#ifndef HALF_ENABLE_F16C_INTRINSICS +/// Enable F16C intruction set intrinsics. +/// Defining this to 1 enables the use of [F16C compiler +/// intrinsics](https://en.wikipedia.org/wiki/F16C) for converting between +/// half-precision and single-precision values which may result in improved performance. This will +/// not perform additional checks +/// for support of the F16C instruction set, so an appropriate target platform is required when +/// enabling this feature. +/// +/// Unless predefined it will be enabled automatically when the `__F16C__` symbol is defined, which +/// some compilers do on supporting platforms. +#define HALF_ENABLE_F16C_INTRINSICS __F16C__ +#endif + +#ifdef HALF_DOXYGEN_ONLY +/// Type for internal floating-point computations. +/// This can be predefined to a built-in floating-point type (`float`, `double` or `long double`) to +/// override the internal +/// half-precision implementation to use this type for computing arithmetic operations and +/// mathematical function (if available). +/// This can result in improved performance for arithmetic operators and mathematical functions but +/// might cause results to +/// deviate from the specified half-precision rounding mode and inhibits proper detection of +/// half-precision exceptions. +#define HALF_ARITHMETIC_TYPE (undefined) + +/// Enable internal exception flags. +/// Defining this to 1 causes operations on half-precision values to raise internal floating-point +/// exception flags according to +/// the IEEE 754 standard. These can then be cleared and checked with clearexcept(), testexcept(). +#define HALF_ERRHANDLING_FLAGS 0 + +/// Enable exception propagation to `errno`. +/// Defining this to 1 causes operations on half-precision values to propagate floating-point +/// exceptions to +/// [errno](https://en.cppreference.com/w/cpp/error/errno) from ``. Specifically this will +/// propagate domain errors as +/// [EDOM](https://en.cppreference.com/w/cpp/error/errno_macros) and pole, overflow and underflow +/// errors as +/// [ERANGE](https://en.cppreference.com/w/cpp/error/errno_macros). Inexact errors won't be +/// propagated. 
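// Usage sketch (illustrative, mirroring the surrounding documentation): error handling is
// opt-in and is enabled by predefining one or more of these switches before the #include:
//   #define HALF_ERRHANDLING_FLAGS 1                        // record FE_* flags internally
//   #define HALF_ERRHANDLING_THROW_OVERFLOW "half overflow" // throw std::overflow_error on overflow
//   #include "half.hpp"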
+#define HALF_ERRHANDLING_ERRNO 0 + +/// Enable exception propagation to built-in floating-point platform. +/// Defining this to 1 causes operations on half-precision values to propagate floating-point +/// exceptions to the built-in +/// single- and double-precision implementation's exception flags using the +/// [C++11 floating-point environment control](https://en.cppreference.com/w/cpp/numeric/fenv) from +/// ``. However, this +/// does not work in reverse and single- or double-precision exceptions will not raise the +/// corresponding half-precision +/// exception flags, nor will explicitly clearing flags clear the corresponding built-in flags. +#define HALF_ERRHANDLING_FENV 0 + +/// Throw C++ exception on domain errors. +/// Defining this to a string literal causes operations on half-precision values to throw a +/// [std::domain_error](https://en.cppreference.com/w/cpp/error/domain_error) with the specified +/// message on domain errors. +#define HALF_ERRHANDLING_THROW_INVALID (undefined) + +/// Throw C++ exception on pole errors. +/// Defining this to a string literal causes operations on half-precision values to throw a +/// [std::domain_error](https://en.cppreference.com/w/cpp/error/domain_error) with the specified +/// message on pole errors. +#define HALF_ERRHANDLING_THROW_DIVBYZERO (undefined) + +/// Throw C++ exception on overflow errors. +/// Defining this to a string literal causes operations on half-precision values to throw a +/// [std::overflow_error](https://en.cppreference.com/w/cpp/error/overflow_error) with the specified +/// message on overflows. +#define HALF_ERRHANDLING_THROW_OVERFLOW (undefined) + +/// Throw C++ exception on underflow errors. +/// Defining this to a string literal causes operations on half-precision values to throw a +/// [std::underflow_error](https://en.cppreference.com/w/cpp/error/underflow_error) with the +/// specified message on underflows. +#define HALF_ERRHANDLING_THROW_UNDERFLOW (undefined) + +/// Throw C++ exception on rounding errors. +/// Defining this to 1 causes operations on half-precision values to throw a +/// [std::range_error](https://en.cppreference.com/w/cpp/error/range_error) with the specified +/// message on general rounding errors. +#define HALF_ERRHANDLING_THROW_INEXACT (undefined) +#endif + +#ifndef HALF_ERRHANDLING_OVERFLOW_TO_INEXACT +/// Raise INEXACT exception on overflow. +/// Defining this to 1 (default) causes overflow errors to automatically raise inexact exceptions in +/// addition. +/// These will be raised after any possible handling of the underflow exception. +#define HALF_ERRHANDLING_OVERFLOW_TO_INEXACT 1 +#endif + +#ifndef HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT +/// Raise INEXACT exception on underflow. +/// Defining this to 1 (default) causes underflow errors to automatically raise inexact exceptions +/// in addition. +/// These will be raised after any possible handling of the underflow exception. +/// +/// **Note:** This will actually cause underflow (and the accompanying inexact) exceptions to be +/// raised *only* when the result +/// is inexact, while if disabled bare underflow errors will be raised for *any* (possibly exact) +/// subnormal result. +#define HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT 1 +#endif + +/// Default rounding mode. 
+/// This specifies the rounding mode used for all conversions between [half](\ref half_float::half)s +/// and more precise types +/// (unless using half_cast() and specifying the rounding mode directly) as well as in arithmetic +/// operations and mathematical +/// functions. It can be redefined (before including half.hpp) to one of the standard rounding modes +/// using their respective +/// constants or the equivalent values of +/// [std::float_round_style](https://en.cppreference.com/w/cpp/types/numeric_limits/float_round_style): +/// +/// `std::float_round_style` | value | rounding +/// ---------------------------------|-------|------------------------- +/// `std::round_indeterminate` | -1 | fastest +/// `std::round_toward_zero` | 0 | toward zero +/// `std::round_to_nearest` | 1 | to nearest (default) +/// `std::round_toward_infinity` | 2 | toward positive infinity +/// `std::round_toward_neg_infinity` | 3 | toward negative infinity +/// +/// By default this is set to `1` (`std::round_to_nearest`), which rounds results to the nearest +/// representable value. It can even +/// be set to +/// [std::numeric_limits::round_style](https://en.cppreference.com/w/cpp/types/numeric_limits/round_style) +/// to synchronize +/// the rounding mode with that of the built-in single-precision implementation (which is likely +/// `std::round_to_nearest`, though). +#ifndef HALF_ROUND_STYLE +#define HALF_ROUND_STYLE 1 // = std::round_to_nearest +#endif + +/// Value signaling overflow. +/// In correspondence with `HUGE_VAL[F|L]` from `` this symbol expands to a positive value +/// signaling the overflow of an +/// operation, in particular it just evaluates to positive infinity. +/// +/// **See also:** Documentation for +/// [HUGE_VAL](https://en.cppreference.com/w/cpp/numeric/math/HUGE_VAL) +#define HUGE_VALH std::numeric_limits::infinity() + +/// Fast half-precision fma function. +/// This symbol is defined if the fma() function generally executes as fast as, or faster than, a +/// separate +/// half-precision multiplication followed by an addition, which is always the case. +/// +/// **See also:** Documentation for +/// [FP_FAST_FMA](https://en.cppreference.com/w/cpp/numeric/math/fma) +#define FP_FAST_FMAH 1 + +/// Half rounding mode. +/// In correspondence with `FLT_ROUNDS` from `` this symbol expands to the rounding mode +/// used for +/// half-precision operations. It is an alias for [HALF_ROUND_STYLE](\ref HALF_ROUND_STYLE). +/// +/// **See also:** Documentation for +/// [FLT_ROUNDS](https://en.cppreference.com/w/cpp/types/climits/FLT_ROUNDS) +#define HLF_ROUNDS HALF_ROUND_STYLE + +#ifndef FP_ILOGB0 +#define FP_ILOGB0 INT_MIN +#endif +#ifndef FP_ILOGBNAN +#define FP_ILOGBNAN INT_MAX +#endif +#ifndef FP_SUBNORMAL +#define FP_SUBNORMAL 0 +#endif +#ifndef FP_ZERO +#define FP_ZERO 1 +#endif +#ifndef FP_NAN +#define FP_NAN 2 +#endif +#ifndef FP_INFINITE +#define FP_INFINITE 3 +#endif +#ifndef FP_NORMAL +#define FP_NORMAL 4 +#endif + +#if !HALF_ENABLE_CPP11_CFENV && !defined(FE_ALL_EXCEPT) +#define FE_INVALID 0x10 +#define FE_DIVBYZERO 0x08 +#define FE_OVERFLOW 0x04 +#define FE_UNDERFLOW 0x02 +#define FE_INEXACT 0x01 +#define FE_ALL_EXCEPT (FE_INVALID | FE_DIVBYZERO | FE_OVERFLOW | FE_UNDERFLOW | FE_INEXACT) +#endif + +/// Main namespace for half-precision functionality. +/// This namespace contains all the functionality provided by the library. +namespace half_float { +class half; + +#if HALF_ENABLE_CPP11_USER_LITERALS +/// Library-defined half-precision literals. 
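// Usage sketch (illustrative): the default rounding mode can be changed by predefining
// HALF_ROUND_STYLE before the #include, using the values from the std::float_round_style
// table in its documentation above:
//   #define HALF_ROUND_STYLE 0   // round toward zero instead of to nearest
//   #include "half.hpp"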
+/// Import this namespace to enable half-precision floating-point literals: +/// ~~~~{.cpp} +/// using namespace half_float::literal; +/// half_float::half = 4.2_h; +/// ~~~~ +namespace literal { +half operator"" _h(long double); +} +#endif + +/// \internal +/// \brief Implementation details. +namespace detail { +#if HALF_ENABLE_CPP11_TYPE_TRAITS +/// Conditional type. +template +struct conditional : std::conditional +{ +}; + +/// Helper for tag dispatching. +template +struct bool_type : std::integral_constant +{ +}; +using std::false_type; +using std::true_type; + +/// Type traits for floating-point types. +template +struct is_float : std::is_floating_point +{ +}; +#else +/// Conditional type. +template +struct conditional +{ + typedef T type; +}; +template +struct conditional +{ + typedef F type; +}; + +/// Helper for tag dispatching. +template +struct bool_type +{ +}; +typedef bool_type true_type; +typedef bool_type false_type; + +/// Type traits for floating-point types. +template +struct is_float : false_type +{ +}; +template +struct is_float : is_float +{ +}; +template +struct is_float : is_float +{ +}; +template +struct is_float : is_float +{ +}; +template <> +struct is_float : true_type +{ +}; +template <> +struct is_float : true_type +{ +}; +template <> +struct is_float : true_type +{ +}; +#endif + +/// Type traits for floating-point bits. +template +struct bits +{ + typedef unsigned char type; +}; +template +struct bits : bits +{ +}; +template +struct bits : bits +{ +}; +template +struct bits : bits +{ +}; + +#if HALF_ENABLE_CPP11_CSTDINT +/// Unsigned integer of (at least) 16 bits width. +typedef std::uint_least16_t uint16; + +/// Fastest unsigned integer of (at least) 32 bits width. +typedef std::uint_fast32_t uint32; + +/// Fastest signed integer of (at least) 32 bits width. +typedef std::int_fast32_t int32; + +/// Unsigned integer of (at least) 32 bits width. +template <> +struct bits +{ + typedef std::uint_least32_t type; +}; + +/// Unsigned integer of (at least) 64 bits width. +template <> +struct bits +{ + typedef std::uint_least64_t type; +}; +#else +/// Unsigned integer of (at least) 16 bits width. +typedef unsigned short uint16; + +/// Fastest unsigned integer of (at least) 32 bits width. +typedef unsigned long uint32; + +/// Fastest unsigned integer of (at least) 32 bits width. +typedef long int32; + +/// Unsigned integer of (at least) 32 bits width. +template <> +struct bits + : conditional::digits >= 32, unsigned int, unsigned long> +{ +}; + +#if HALF_ENABLE_CPP11_LONG_LONG +/// Unsigned integer of (at least) 64 bits width. +template <> +struct bits : conditional::digits >= 64, + unsigned long, + unsigned long long> +{ +}; +#else +/// Unsigned integer of (at least) 64 bits width. +template <> +struct bits +{ + typedef unsigned long type; +}; +#endif +#endif + +#ifdef HALF_ARITHMETIC_TYPE +/// Type to use for arithmetic computations and mathematic functions internally. +typedef HALF_ARITHMETIC_TYPE internal_t; +#endif + +/// Tag type for binary construction. +struct binary_t +{ +}; + +/// Tag for binary construction. +HALF_CONSTEXPR_CONST binary_t binary = binary_t(); + +/// \name Implementation defined classification and arithmetic +/// \{ + +/// Check for infinity. 
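// Note (illustrative): the bits<T> trait above maps a built-in floating-point type to an
// unsigned integer type wide enough to hold its bit pattern (at least 32 bits for float,
// 64 bits for double). The conversion routines further down memcpy a float or double into
// that integer type and operate directly on the raw IEEE 754 representation.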
+/// \tparam T argument type (builtin floating-point type) +/// \param arg value to query +/// \retval true if infinity +/// \retval false else +template +bool builtin_isinf(T arg) +{ +#if HALF_ENABLE_CPP11_CMATH + return std::isinf(arg); +#elif defined(_MSC_VER) + return !::_finite(static_cast(arg)) && !::_isnan(static_cast(arg)); +#else + return arg == std::numeric_limits::infinity() || arg == -std::numeric_limits::infinity(); +#endif +} + +/// Check for NaN. +/// \tparam T argument type (builtin floating-point type) +/// \param arg value to query +/// \retval true if not a number +/// \retval false else +template +bool builtin_isnan(T arg) +{ +#if HALF_ENABLE_CPP11_CMATH + return std::isnan(arg); +#elif defined(_MSC_VER) + return ::_isnan(static_cast(arg)) != 0; +#else + return arg != arg; +#endif +} + +/// Check sign. +/// \tparam T argument type (builtin floating-point type) +/// \param arg value to query +/// \retval true if signbit set +/// \retval false else +template +bool builtin_signbit(T arg) +{ +#if HALF_ENABLE_CPP11_CMATH + return std::signbit(arg); +#else + return arg < T() || (arg == T() && T(1) / arg < T()); +#endif +} + +/// Platform-independent sign mask. +/// \param arg integer value in two's complement +/// \retval -1 if \a arg negative +/// \retval 0 if \a arg positive +inline uint32 sign_mask(uint32 arg) +{ + static const int N = std::numeric_limits::digits - 1; +#if HALF_TWOS_COMPLEMENT_INT + return static_cast(arg) >> N; +#else + return -((arg >> N) & 1); +#endif +} + +/// Platform-independent arithmetic right shift. +/// \param arg integer value in two's complement +/// \param i shift amount (at most 31) +/// \return \a arg right shifted for \a i bits with possible sign extension +inline uint32 arithmetic_shift(uint32 arg, int i) +{ +#if HALF_TWOS_COMPLEMENT_INT + return static_cast(arg) >> i; +#else + return static_cast(arg) / (static_cast(1) << i) - + ((arg >> (std::numeric_limits::digits - 1)) & 1); +#endif +} + +/// \} +/// \name Error handling +/// \{ + +/// Internal exception flags. +/// \return reference to global exception flags +inline int& errflags() +{ + HALF_THREAD_LOCAL int flags = 0; + return flags; +} + +/// Raise floating-point exception. 
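// Worked example (illustrative): sign_mask() yields 0xFFFFFFFF for a negative
// two's-complement input and 0 otherwise, so fixed2half() below can take an absolute value
// without branching: with msign = sign_mask(m), (m ^ msign) - msign flips the bits and adds
// one when m is negative (two's-complement negation) and leaves a non-negative m untouched.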
+/// \param flags exceptions to raise +/// \param cond condition to raise exceptions for +inline void raise(int HALF_UNUSED_NOERR(flags), bool HALF_UNUSED_NOERR(cond) = true) +{ +#if HALF_ERRHANDLING + if(!cond) + return; +#if HALF_ERRHANDLING_FLAGS + errflags() |= flags; +#endif +#if HALF_ERRHANDLING_ERRNO + if(flags & FE_INVALID) + errno = EDOM; + else if(flags & (FE_DIVBYZERO | FE_OVERFLOW | FE_UNDERFLOW)) + errno = ERANGE; +#endif +#if HALF_ERRHANDLING_FENV && HALF_ENABLE_CPP11_CFENV + std::feraiseexcept(flags); +#endif +#ifdef HALF_ERRHANDLING_THROW_INVALID + if(flags & FE_INVALID) + throw std::domain_error(HALF_ERRHANDLING_THROW_INVALID); +#endif +#ifdef HALF_ERRHANDLING_THROW_DIVBYZERO + if(flags & FE_DIVBYZERO) + throw std::domain_error(HALF_ERRHANDLING_THROW_DIVBYZERO); +#endif +#ifdef HALF_ERRHANDLING_THROW_OVERFLOW + if(flags & FE_OVERFLOW) + throw std::overflow_error(HALF_ERRHANDLING_THROW_OVERFLOW); +#endif +#ifdef HALF_ERRHANDLING_THROW_UNDERFLOW + if(flags & FE_UNDERFLOW) + throw std::underflow_error(HALF_ERRHANDLING_THROW_UNDERFLOW); +#endif +#ifdef HALF_ERRHANDLING_THROW_INEXACT + if(flags & FE_INEXACT) + throw std::range_error(HALF_ERRHANDLING_THROW_INEXACT); +#endif +#if HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT + if((flags & FE_UNDERFLOW) && !(flags & FE_INEXACT)) + raise(FE_INEXACT); +#endif +#if HALF_ERRHANDLING_OVERFLOW_TO_INEXACT + if((flags & FE_OVERFLOW) && !(flags & FE_INEXACT)) + raise(FE_INEXACT); +#endif +#endif +} + +/// Check and signal for any NaN. +/// \param x first half-precision value to check +/// \param y second half-precision value to check +/// \retval true if either \a x or \a y is NaN +/// \retval false else +/// \exception FE_INVALID if \a x or \a y is NaN +inline HALF_CONSTEXPR_NOERR bool compsignal(unsigned int x, unsigned int y) +{ +#if HALF_ERRHANDLING + raise(FE_INVALID, (x & 0x7FFF) > 0x7C00 || (y & 0x7FFF) > 0x7C00); +#endif + return (x & 0x7FFF) > 0x7C00 || (y & 0x7FFF) > 0x7C00; +} + +/// Signal and silence signaling NaN. +/// \param nan half-precision NaN value +/// \return quiet NaN +/// \exception FE_INVALID if \a nan is signaling NaN +inline HALF_CONSTEXPR_NOERR unsigned int signal(unsigned int nan) +{ +#if HALF_ERRHANDLING + raise(FE_INVALID, !(nan & 0x200)); +#endif + return nan | 0x200; +} + +/// Signal and silence signaling NaNs. +/// \param x first half-precision value to check +/// \param y second half-precision value to check +/// \return quiet NaN +/// \exception FE_INVALID if \a x or \a y is signaling NaN +inline HALF_CONSTEXPR_NOERR unsigned int signal(unsigned int x, unsigned int y) +{ +#if HALF_ERRHANDLING + raise(FE_INVALID, + ((x & 0x7FFF) > 0x7C00 && !(x & 0x200)) || ((y & 0x7FFF) > 0x7C00 && !(y & 0x200))); +#endif + return ((x & 0x7FFF) > 0x7C00) ? (x | 0x200) : (y | 0x200); +} + +/// Signal and silence signaling NaNs. +/// \param x first half-precision value to check +/// \param y second half-precision value to check +/// \param z third half-precision value to check +/// \return quiet NaN +/// \exception FE_INVALID if \a x, \a y or \a z is signaling NaN +inline HALF_CONSTEXPR_NOERR unsigned int signal(unsigned int x, unsigned int y, unsigned int z) +{ +#if HALF_ERRHANDLING + raise(FE_INVALID, + ((x & 0x7FFF) > 0x7C00 && !(x & 0x200)) || ((y & 0x7FFF) > 0x7C00 && !(y & 0x200)) || + ((z & 0x7FFF) > 0x7C00 && !(z & 0x200))); +#endif + return ((x & 0x7FFF) > 0x7C00) ? (x | 0x200) + : ((y & 0x7FFF) > 0x7C00) ? (y | 0x200) : (z | 0x200); +} + +/// Select value or signaling NaN. 
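// Worked example (illustrative): a half-precision NaN has all exponent bits set and a
// non-zero mantissa, i.e. (x & 0x7FFF) > 0x7C00, and bit 0x200 distinguishes quiet from
// signaling payloads. 0x7D00 is therefore a signaling NaN: signal(0x7D00) raises FE_INVALID
// and returns 0x7F00, the same payload with the quiet bit set, while invalid() below
// returns 0x7FFF, the canonical quiet NaN used for domain errors.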
+/// \param x preferred half-precision value +/// \param y ignored half-precision value except for signaling NaN +/// \return \a y if signaling NaN, \a x otherwise +/// \exception FE_INVALID if \a y is signaling NaN +inline HALF_CONSTEXPR_NOERR unsigned int select(unsigned int x, unsigned int HALF_UNUSED_NOERR(y)) +{ +#if HALF_ERRHANDLING + return (((y & 0x7FFF) > 0x7C00) && !(y & 0x200)) ? signal(y) : x; +#else + return x; +#endif +} + +/// Raise domain error and return NaN. +/// return quiet NaN +/// \exception FE_INVALID +inline HALF_CONSTEXPR_NOERR unsigned int invalid() +{ +#if HALF_ERRHANDLING + raise(FE_INVALID); +#endif + return 0x7FFF; +} + +/// Raise pole error and return infinity. +/// \param sign half-precision value with sign bit only +/// \return half-precision infinity with sign of \a sign +/// \exception FE_DIVBYZERO +inline HALF_CONSTEXPR_NOERR unsigned int pole(unsigned int sign = 0) +{ +#if HALF_ERRHANDLING + raise(FE_DIVBYZERO); +#endif + return sign | 0x7C00; +} + +/// Check value for underflow. +/// \param arg non-zero half-precision value to check +/// \return \a arg +/// \exception FE_UNDERFLOW if arg is subnormal +inline HALF_CONSTEXPR_NOERR unsigned int check_underflow(unsigned int arg) +{ +#if HALF_ERRHANDLING && !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT + raise(FE_UNDERFLOW, !(arg & 0x7C00)); +#endif + return arg; +} + +/// \} +/// \name Conversion and rounding +/// \{ + +/// Half-precision overflow. +/// \tparam R rounding mode to use +/// \param sign half-precision value with sign bit only +/// \return rounded overflowing half-precision value +/// \exception FE_OVERFLOW +template +HALF_CONSTEXPR_NOERR unsigned int overflow(unsigned int sign = 0) +{ +#if HALF_ERRHANDLING + raise(FE_OVERFLOW); +#endif + return (R == std::round_toward_infinity) + ? (sign + 0x7C00 - (sign >> 15)) + : (R == std::round_toward_neg_infinity) + ? (sign + 0x7BFF + (sign >> 15)) + : (R == std::round_toward_zero) ? (sign | 0x7BFF) : (sign | 0x7C00); +} + +/// Half-precision underflow. +/// \tparam R rounding mode to use +/// \param sign half-precision value with sign bit only +/// \return rounded underflowing half-precision value +/// \exception FE_UNDERFLOW +template +HALF_CONSTEXPR_NOERR unsigned int underflow(unsigned int sign = 0) +{ +#if HALF_ERRHANDLING + raise(FE_UNDERFLOW); +#endif + return (R == std::round_toward_infinity) + ? (sign + 1 - (sign >> 15)) + : (R == std::round_toward_neg_infinity) ? (sign + (sign >> 15)) : sign; +} + +/// Round half-precision number. +/// \tparam R rounding mode to use +/// \tparam I `true` to always raise INEXACT exception, `false` to raise only for rounded results +/// \param value finite half-precision number to round +/// \param g guard bit (most significant discarded bit) +/// \param s sticky bit (or of all but the most significant discarded bits) +/// \return rounded half-precision value +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if value had to be rounded or \a I is `true` +template +HALF_CONSTEXPR_NOERR unsigned int rounded(unsigned int value, int g, int s) +{ +#if HALF_ERRHANDLING + value += (R == std::round_to_nearest) + ? (g & (s | value)) + : (R == std::round_toward_infinity) + ? (~(value >> 15) & (g | s)) + : (R == std::round_toward_neg_infinity) ? 
((value >> 15) & (g | s)) : 0; + if((value & 0x7C00) == 0x7C00) + raise(FE_OVERFLOW); + else if(value & 0x7C00) + raise(FE_INEXACT, I || (g | s) != 0); + else + raise(FE_UNDERFLOW, !(HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT) || I || (g | s) != 0); + return value; +#else + return (R == std::round_to_nearest) + ? (value + (g & (s | value))) + : (R == std::round_toward_infinity) + ? (value + (~(value >> 15) & (g | s))) + : (R == std::round_toward_neg_infinity) ? (value + ((value >> 15) & (g | s))) + : value; +#endif +} + +/// Round half-precision number to nearest integer value. +/// \tparam R rounding mode to use +/// \tparam E `true` for round to even, `false` for round away from zero +/// \tparam I `true` to raise INEXACT exception (if inexact), `false` to never raise it +/// \param value half-precision value to round +/// \return half-precision bits for nearest integral value +/// \exception FE_INVALID for signaling NaN +/// \exception FE_INEXACT if value had to be rounded and \a I is `true` +template +unsigned int integral(unsigned int value) +{ + unsigned int abs = value & 0x7FFF; + if(abs < 0x3C00) + { + raise(FE_INEXACT, I); + return ((R == std::round_to_nearest) + ? (0x3C00 & -static_cast(abs >= (0x3800 + E))) + : (R == std::round_toward_infinity) + ? (0x3C00 & -(~(value >> 15) & (abs != 0))) + : (R == std::round_toward_neg_infinity) + ? (0x3C00 & -static_cast(value > 0x8000)) + : 0) | + (value & 0x8000); + } + if(abs >= 0x6400) + return (abs > 0x7C00) ? signal(value) : value; + unsigned int exp = 25 - (abs >> 10), mask = (1 << exp) - 1; + raise(FE_INEXACT, I && (value & mask)); + return (((R == std::round_to_nearest) + ? ((1 << (exp - 1)) - (~(value >> exp) & E)) + : (R == std::round_toward_infinity) + ? (mask & ((value >> 15) - 1)) + : (R == std::round_toward_neg_infinity) ? (mask & -(value >> 15)) : 0) + + value) & + ~mask; +} + +/// Convert fixed point to half-precision floating-point. +/// \tparam R rounding mode to use +/// \tparam F number of fractional bits (at least 11) +/// \tparam S `true` for signed, `false` for unsigned +/// \tparam N `true` for additional normalization step, `false` if already normalized to 1.F +/// \tparam I `true` to always raise INEXACT exception, `false` to raise only for rounded results +/// \param m mantissa in Q1.F fixed point format +/// \param exp exponent +/// \param sign half-precision value with sign bit only +/// \param s sticky bit (or of all but the most significant already discarded bits) +/// \return value converted to half-precision +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if value had to be rounded or \a I is `true` +template +unsigned int fixed2half(uint32 m, int exp = 14, unsigned int sign = 0, int s = 0) +{ + if(S) + { + uint32 msign = sign_mask(m); + m = (m ^ msign) - msign; + sign = msign & 0x8000; + } + if(N) + for(; m < (static_cast(1) << F) && exp; m <<= 1, --exp) + ; + else if(exp < 0) + return rounded(sign + (m >> (F - 10 - exp)), + (m >> (F - 11 - exp)) & 1, + s | ((m & ((static_cast(1) << (F - 11 - exp)) - 1)) != 0)); + return rounded(sign + (exp << 10) + (m >> (F - 10)), + (m >> (F - 11)) & 1, + s | ((m & ((static_cast(1) << (F - 11)) - 1)) != 0)); +} + +/// Convert IEEE single-precision to half-precision. +/// Credit for this goes to [Jeroen van der +/// Zijp](ftp://ftp.fox-toolkit.org/pub/fasthalffloatconversion.pdf). 
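// Worked example (illustrative) for rounded() above: g is the guard bit (most significant
// discarded bit) and s the sticky bit (OR of the remaining discarded bits). For
// round-to-nearest the update is value + (g & (s | value)), so the result is bumped by one
// ulp only when g is set and either s is set (strictly above the halfway point) or the LSB
// of value is set (an exact tie rounds to even): e.g.
// rounded<std::round_to_nearest, false>(0x3E01, 1, 0) yields 0x3E02, while
// rounded<std::round_to_nearest, false>(0x3E00, 1, 0) stays at 0x3E00.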
+/// \tparam R rounding mode to use +/// \param value single-precision value to convert +/// \return rounded half-precision value +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if value had to be rounded +template +unsigned int float2half_impl(float value, true_type) +{ +#if HALF_ENABLE_F16C_INTRINSICS + return _mm_cvtsi128_si32(_mm_cvtps_ph(_mm_set_ss(value), + (R == std::round_to_nearest) + ? _MM_FROUND_TO_NEAREST_INT + : (R == std::round_toward_zero) + ? _MM_FROUND_TO_ZERO + : (R == std::round_toward_infinity) + ? _MM_FROUND_TO_POS_INF + : (R == std::round_toward_neg_infinity) + ? _MM_FROUND_TO_NEG_INF + : _MM_FROUND_CUR_DIRECTION)); +#else + bits::type fbits; + std::memcpy(&fbits, &value, sizeof(float)); +#if 1 + unsigned int sign = (fbits >> 16) & 0x8000; + fbits &= 0x7FFFFFFF; + if(fbits >= 0x7F800000) + return sign | 0x7C00 | ((fbits > 0x7F800000) ? (0x200 | ((fbits >> 13) & 0x3FF)) : 0); + if(fbits >= 0x47800000) + return overflow(sign); + if(fbits >= 0x38800000) + return rounded(sign | (((fbits >> 23) - 112) << 10) | ((fbits >> 13) & 0x3FF), + (fbits >> 12) & 1, + (fbits & 0xFFF) != 0); + if(fbits >= 0x33000000) + { + int i = 125 - (fbits >> 23); + fbits = (fbits & 0x7FFFFF) | 0x800000; + return rounded(sign | (fbits >> (i + 1)), + (fbits >> i) & 1, + (fbits & ((static_cast(1) << i) - 1)) != 0); + } + if(fbits != 0) + return underflow(sign); + return sign; +#else + static const uint16 base_table[512] = { + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0001, 0x0002, 0x0004, 0x0008, 0x0010, 0x0020, 0x0040, + 0x0080, 0x0100, 0x0200, 0x0400, 0x0800, 0x0C00, 0x1000, 0x1400, 0x1800, 0x1C00, 0x2000, + 0x2400, 0x2800, 0x2C00, 0x3000, 0x3400, 0x3800, 0x3C00, 0x4000, 0x4400, 0x4800, 0x4C00, + 0x5000, 0x5400, 0x5800, 0x5C00, 0x6000, 0x6400, 0x6800, 0x6C00, 0x7000, 0x7400, 0x7800, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 
0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7C00, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8001, 0x8002, 0x8004, 0x8008, + 0x8010, 0x8020, 0x8040, 0x8080, 0x8100, 0x8200, 0x8400, 0x8800, 0x8C00, 0x9000, 0x9400, + 0x9800, 0x9C00, 0xA000, 0xA400, 0xA800, 0xAC00, 0xB000, 0xB400, 0xB800, 0xBC00, 0xC000, + 0xC400, 0xC800, 0xCC00, 0xD000, 0xD400, 0xD800, 0xDC00, 0xE000, 0xE400, 0xE800, 0xEC00, + 0xF000, 0xF400, 0xF800, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFC00}; + static const unsigned char shift_table[256] = { + 24, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, + 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, + 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, + 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, + 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 24, 23, 22, 21, 20, 19, 18, 17, + 16, 15, 14, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, + 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 13}; + int sexp = fbits >> 23, exp = sexp & 0xFF, i = shift_table[exp]; + fbits &= 0x7FFFFF; + uint32 m = (fbits | ((exp != 0) << 23)) & -static_cast(exp != 0xFF); + return rounded(base_table[sexp] + (fbits >> i), + (m >> (i - 1)) & 1, + (((static_cast(1) << (i - 1)) - 1) & m) != 0); +#endif +#endif +} + +/// Convert IEEE 
double-precision to half-precision. +/// \tparam R rounding mode to use +/// \param value double-precision value to convert +/// \return rounded half-precision value +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if value had to be rounded +template +unsigned int float2half_impl(double value, true_type) +{ +#if HALF_ENABLE_F16C_INTRINSICS + if(R == std::round_indeterminate) + return _mm_cvtsi128_si32( + _mm_cvtps_ph(_mm_cvtpd_ps(_mm_set_sd(value)), _MM_FROUND_CUR_DIRECTION)); +#endif + bits::type dbits; + std::memcpy(&dbits, &value, sizeof(double)); + uint32 hi = dbits >> 32, lo = dbits & 0xFFFFFFFF; + unsigned int sign = (hi >> 16) & 0x8000; + hi &= 0x7FFFFFFF; + if(hi >= 0x7FF00000) + return sign | 0x7C00 | ((dbits & 0xFFFFFFFFFFFFF) ? (0x200 | ((hi >> 10) & 0x3FF)) : 0); + if(hi >= 0x40F00000) + return overflow(sign); + if(hi >= 0x3F100000) + return rounded(sign | (((hi >> 20) - 1008) << 10) | ((hi >> 10) & 0x3FF), + (hi >> 9) & 1, + ((hi & 0x1FF) | lo) != 0); + if(hi >= 0x3E600000) + { + int i = 1018 - (hi >> 20); + hi = (hi & 0xFFFFF) | 0x100000; + return rounded(sign | (hi >> (i + 1)), + (hi >> i) & 1, + ((hi & ((static_cast(1) << i) - 1)) | lo) != 0); + } + if((hi | lo) != 0) + return underflow(sign); + return sign; +} + +/// Convert non-IEEE floating-point to half-precision. +/// \tparam R rounding mode to use +/// \tparam T source type (builtin floating-point type) +/// \param value floating-point value to convert +/// \return rounded half-precision value +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if value had to be rounded +template +unsigned int float2half_impl(T value, ...) +{ + unsigned int hbits = static_cast(builtin_signbit(value)) << 15; + if(value == T()) + return hbits; + if(builtin_isnan(value)) + return hbits | 0x7FFF; + if(builtin_isinf(value)) + return hbits | 0x7C00; + int exp; + std::frexp(value, &exp); + if(exp > 16) + return overflow(hbits); + if(exp < -13) + value = std::ldexp(value, 25); + else + { + value = std::ldexp(value, 12 - exp); + hbits |= ((exp + 13) << 10); + } + T ival, frac = std::modf(value, &ival); + int m = std::abs(static_cast(ival)); + return rounded(hbits + (m >> 1), m & 1, frac != T()); +} + +/// Convert floating-point to half-precision. +/// \tparam R rounding mode to use +/// \tparam T source type (builtin floating-point type) +/// \param value floating-point value to convert +/// \return rounded half-precision value +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if value had to be rounded +template +unsigned int float2half(T value) +{ + return float2half_impl(value, + bool_type < std::numeric_limits::is_iec559 && + sizeof(typename bits::type) == sizeof(T) > ()); +} + +/// Convert integer to half-precision floating-point. 
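// Worked example (illustrative) for the single-precision path above: 1.5f has the bit
// pattern 0x3FC00000, which lands in the normal range (>= 0x38800000). The exponent is
// re-biased from 127 to 127 - 112 = 15 and the top ten mantissa bits are 0x200, giving the
// candidate 0x3E00 with guard and sticky bits both zero, so
// float2half<std::round_to_nearest>(1.5f) returns 0x3E00 -- exponent 15, mantissa 0x200,
// i.e. 1.5 in half precision.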
+/// \tparam R rounding mode to use +/// \tparam T type to convert (builtin integer type) +/// \param value integral value to convert +/// \return rounded half-precision value +/// \exception FE_OVERFLOW on overflows +/// \exception FE_INEXACT if value had to be rounded +template +unsigned int int2half(T value) +{ + unsigned int bits = static_cast(value < 0) << 15; + if(!value) + return bits; + if(bits) + value = -value; + if(value > 0xFFFF) + return overflow(bits); + unsigned int m = static_cast(value), exp = 24; + for(; m < 0x400; m <<= 1, --exp) + ; + for(; m > 0x7FF; m >>= 1, ++exp) + ; + bits |= (exp << 10) + m; + return (exp > 24) ? rounded( + bits, (value >> (exp - 25)) & 1, (((1 << (exp - 25)) - 1) & value) != 0) + : bits; +} + +/// Convert half-precision to IEEE single-precision. +/// Credit for this goes to [Jeroen van der +/// Zijp](ftp://ftp.fox-toolkit.org/pub/fasthalffloatconversion.pdf). +/// \param value half-precision value to convert +/// \return single-precision value +inline float half2float_impl(unsigned int value, float, true_type) +{ +#if HALF_ENABLE_F16C_INTRINSICS + return _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(value))); +#else +#if 0 + bits::type fbits = static_cast::type>(value&0x8000) << 16; + int abs = value & 0x7FFF; + if(abs) + { + fbits |= 0x38000000 << static_cast(abs>=0x7C00); + for(; abs<0x400; abs<<=1,fbits-=0x800000) ; + fbits += static_cast::type>(abs) << 13; + } +#else + static const bits::type mantissa_table[2048] = { + 0x00000000, 0x33800000, 0x34000000, 0x34400000, 0x34800000, 0x34A00000, 0x34C00000, + 0x34E00000, 0x35000000, 0x35100000, 0x35200000, 0x35300000, 0x35400000, 0x35500000, + 0x35600000, 0x35700000, 0x35800000, 0x35880000, 0x35900000, 0x35980000, 0x35A00000, + 0x35A80000, 0x35B00000, 0x35B80000, 0x35C00000, 0x35C80000, 0x35D00000, 0x35D80000, + 0x35E00000, 0x35E80000, 0x35F00000, 0x35F80000, 0x36000000, 0x36040000, 0x36080000, + 0x360C0000, 0x36100000, 0x36140000, 0x36180000, 0x361C0000, 0x36200000, 0x36240000, + 0x36280000, 0x362C0000, 0x36300000, 0x36340000, 0x36380000, 0x363C0000, 0x36400000, + 0x36440000, 0x36480000, 0x364C0000, 0x36500000, 0x36540000, 0x36580000, 0x365C0000, + 0x36600000, 0x36640000, 0x36680000, 0x366C0000, 0x36700000, 0x36740000, 0x36780000, + 0x367C0000, 0x36800000, 0x36820000, 0x36840000, 0x36860000, 0x36880000, 0x368A0000, + 0x368C0000, 0x368E0000, 0x36900000, 0x36920000, 0x36940000, 0x36960000, 0x36980000, + 0x369A0000, 0x369C0000, 0x369E0000, 0x36A00000, 0x36A20000, 0x36A40000, 0x36A60000, + 0x36A80000, 0x36AA0000, 0x36AC0000, 0x36AE0000, 0x36B00000, 0x36B20000, 0x36B40000, + 0x36B60000, 0x36B80000, 0x36BA0000, 0x36BC0000, 0x36BE0000, 0x36C00000, 0x36C20000, + 0x36C40000, 0x36C60000, 0x36C80000, 0x36CA0000, 0x36CC0000, 0x36CE0000, 0x36D00000, + 0x36D20000, 0x36D40000, 0x36D60000, 0x36D80000, 0x36DA0000, 0x36DC0000, 0x36DE0000, + 0x36E00000, 0x36E20000, 0x36E40000, 0x36E60000, 0x36E80000, 0x36EA0000, 0x36EC0000, + 0x36EE0000, 0x36F00000, 0x36F20000, 0x36F40000, 0x36F60000, 0x36F80000, 0x36FA0000, + 0x36FC0000, 0x36FE0000, 0x37000000, 0x37010000, 0x37020000, 0x37030000, 0x37040000, + 0x37050000, 0x37060000, 0x37070000, 0x37080000, 0x37090000, 0x370A0000, 0x370B0000, + 0x370C0000, 0x370D0000, 0x370E0000, 0x370F0000, 0x37100000, 0x37110000, 0x37120000, + 0x37130000, 0x37140000, 0x37150000, 0x37160000, 0x37170000, 0x37180000, 0x37190000, + 0x371A0000, 0x371B0000, 0x371C0000, 0x371D0000, 0x371E0000, 0x371F0000, 0x37200000, + 0x37210000, 0x37220000, 0x37230000, 0x37240000, 0x37250000, 0x37260000, 0x37270000, + 
0x37280000, 0x37290000, 0x372A0000, 0x372B0000, 0x372C0000, 0x372D0000, 0x372E0000, + 0x372F0000, 0x37300000, 0x37310000, 0x37320000, 0x37330000, 0x37340000, 0x37350000, + 0x37360000, 0x37370000, 0x37380000, 0x37390000, 0x373A0000, 0x373B0000, 0x373C0000, + 0x373D0000, 0x373E0000, 0x373F0000, 0x37400000, 0x37410000, 0x37420000, 0x37430000, + 0x37440000, 0x37450000, 0x37460000, 0x37470000, 0x37480000, 0x37490000, 0x374A0000, + 0x374B0000, 0x374C0000, 0x374D0000, 0x374E0000, 0x374F0000, 0x37500000, 0x37510000, + 0x37520000, 0x37530000, 0x37540000, 0x37550000, 0x37560000, 0x37570000, 0x37580000, + 0x37590000, 0x375A0000, 0x375B0000, 0x375C0000, 0x375D0000, 0x375E0000, 0x375F0000, + 0x37600000, 0x37610000, 0x37620000, 0x37630000, 0x37640000, 0x37650000, 0x37660000, + 0x37670000, 0x37680000, 0x37690000, 0x376A0000, 0x376B0000, 0x376C0000, 0x376D0000, + 0x376E0000, 0x376F0000, 0x37700000, 0x37710000, 0x37720000, 0x37730000, 0x37740000, + 0x37750000, 0x37760000, 0x37770000, 0x37780000, 0x37790000, 0x377A0000, 0x377B0000, + 0x377C0000, 0x377D0000, 0x377E0000, 0x377F0000, 0x37800000, 0x37808000, 0x37810000, + 0x37818000, 0x37820000, 0x37828000, 0x37830000, 0x37838000, 0x37840000, 0x37848000, + 0x37850000, 0x37858000, 0x37860000, 0x37868000, 0x37870000, 0x37878000, 0x37880000, + 0x37888000, 0x37890000, 0x37898000, 0x378A0000, 0x378A8000, 0x378B0000, 0x378B8000, + 0x378C0000, 0x378C8000, 0x378D0000, 0x378D8000, 0x378E0000, 0x378E8000, 0x378F0000, + 0x378F8000, 0x37900000, 0x37908000, 0x37910000, 0x37918000, 0x37920000, 0x37928000, + 0x37930000, 0x37938000, 0x37940000, 0x37948000, 0x37950000, 0x37958000, 0x37960000, + 0x37968000, 0x37970000, 0x37978000, 0x37980000, 0x37988000, 0x37990000, 0x37998000, + 0x379A0000, 0x379A8000, 0x379B0000, 0x379B8000, 0x379C0000, 0x379C8000, 0x379D0000, + 0x379D8000, 0x379E0000, 0x379E8000, 0x379F0000, 0x379F8000, 0x37A00000, 0x37A08000, + 0x37A10000, 0x37A18000, 0x37A20000, 0x37A28000, 0x37A30000, 0x37A38000, 0x37A40000, + 0x37A48000, 0x37A50000, 0x37A58000, 0x37A60000, 0x37A68000, 0x37A70000, 0x37A78000, + 0x37A80000, 0x37A88000, 0x37A90000, 0x37A98000, 0x37AA0000, 0x37AA8000, 0x37AB0000, + 0x37AB8000, 0x37AC0000, 0x37AC8000, 0x37AD0000, 0x37AD8000, 0x37AE0000, 0x37AE8000, + 0x37AF0000, 0x37AF8000, 0x37B00000, 0x37B08000, 0x37B10000, 0x37B18000, 0x37B20000, + 0x37B28000, 0x37B30000, 0x37B38000, 0x37B40000, 0x37B48000, 0x37B50000, 0x37B58000, + 0x37B60000, 0x37B68000, 0x37B70000, 0x37B78000, 0x37B80000, 0x37B88000, 0x37B90000, + 0x37B98000, 0x37BA0000, 0x37BA8000, 0x37BB0000, 0x37BB8000, 0x37BC0000, 0x37BC8000, + 0x37BD0000, 0x37BD8000, 0x37BE0000, 0x37BE8000, 0x37BF0000, 0x37BF8000, 0x37C00000, + 0x37C08000, 0x37C10000, 0x37C18000, 0x37C20000, 0x37C28000, 0x37C30000, 0x37C38000, + 0x37C40000, 0x37C48000, 0x37C50000, 0x37C58000, 0x37C60000, 0x37C68000, 0x37C70000, + 0x37C78000, 0x37C80000, 0x37C88000, 0x37C90000, 0x37C98000, 0x37CA0000, 0x37CA8000, + 0x37CB0000, 0x37CB8000, 0x37CC0000, 0x37CC8000, 0x37CD0000, 0x37CD8000, 0x37CE0000, + 0x37CE8000, 0x37CF0000, 0x37CF8000, 0x37D00000, 0x37D08000, 0x37D10000, 0x37D18000, + 0x37D20000, 0x37D28000, 0x37D30000, 0x37D38000, 0x37D40000, 0x37D48000, 0x37D50000, + 0x37D58000, 0x37D60000, 0x37D68000, 0x37D70000, 0x37D78000, 0x37D80000, 0x37D88000, + 0x37D90000, 0x37D98000, 0x37DA0000, 0x37DA8000, 0x37DB0000, 0x37DB8000, 0x37DC0000, + 0x37DC8000, 0x37DD0000, 0x37DD8000, 0x37DE0000, 0x37DE8000, 0x37DF0000, 0x37DF8000, + 0x37E00000, 0x37E08000, 0x37E10000, 0x37E18000, 0x37E20000, 0x37E28000, 0x37E30000, + 0x37E38000, 0x37E40000, 
0x37E48000, 0x37E50000, 0x37E58000, 0x37E60000, 0x37E68000, + 0x37E70000, 0x37E78000, 0x37E80000, 0x37E88000, 0x37E90000, 0x37E98000, 0x37EA0000, + 0x37EA8000, 0x37EB0000, 0x37EB8000, 0x37EC0000, 0x37EC8000, 0x37ED0000, 0x37ED8000, + 0x37EE0000, 0x37EE8000, 0x37EF0000, 0x37EF8000, 0x37F00000, 0x37F08000, 0x37F10000, + 0x37F18000, 0x37F20000, 0x37F28000, 0x37F30000, 0x37F38000, 0x37F40000, 0x37F48000, + 0x37F50000, 0x37F58000, 0x37F60000, 0x37F68000, 0x37F70000, 0x37F78000, 0x37F80000, + 0x37F88000, 0x37F90000, 0x37F98000, 0x37FA0000, 0x37FA8000, 0x37FB0000, 0x37FB8000, + 0x37FC0000, 0x37FC8000, 0x37FD0000, 0x37FD8000, 0x37FE0000, 0x37FE8000, 0x37FF0000, + 0x37FF8000, 0x38000000, 0x38004000, 0x38008000, 0x3800C000, 0x38010000, 0x38014000, + 0x38018000, 0x3801C000, 0x38020000, 0x38024000, 0x38028000, 0x3802C000, 0x38030000, + 0x38034000, 0x38038000, 0x3803C000, 0x38040000, 0x38044000, 0x38048000, 0x3804C000, + 0x38050000, 0x38054000, 0x38058000, 0x3805C000, 0x38060000, 0x38064000, 0x38068000, + 0x3806C000, 0x38070000, 0x38074000, 0x38078000, 0x3807C000, 0x38080000, 0x38084000, + 0x38088000, 0x3808C000, 0x38090000, 0x38094000, 0x38098000, 0x3809C000, 0x380A0000, + 0x380A4000, 0x380A8000, 0x380AC000, 0x380B0000, 0x380B4000, 0x380B8000, 0x380BC000, + 0x380C0000, 0x380C4000, 0x380C8000, 0x380CC000, 0x380D0000, 0x380D4000, 0x380D8000, + 0x380DC000, 0x380E0000, 0x380E4000, 0x380E8000, 0x380EC000, 0x380F0000, 0x380F4000, + 0x380F8000, 0x380FC000, 0x38100000, 0x38104000, 0x38108000, 0x3810C000, 0x38110000, + 0x38114000, 0x38118000, 0x3811C000, 0x38120000, 0x38124000, 0x38128000, 0x3812C000, + 0x38130000, 0x38134000, 0x38138000, 0x3813C000, 0x38140000, 0x38144000, 0x38148000, + 0x3814C000, 0x38150000, 0x38154000, 0x38158000, 0x3815C000, 0x38160000, 0x38164000, + 0x38168000, 0x3816C000, 0x38170000, 0x38174000, 0x38178000, 0x3817C000, 0x38180000, + 0x38184000, 0x38188000, 0x3818C000, 0x38190000, 0x38194000, 0x38198000, 0x3819C000, + 0x381A0000, 0x381A4000, 0x381A8000, 0x381AC000, 0x381B0000, 0x381B4000, 0x381B8000, + 0x381BC000, 0x381C0000, 0x381C4000, 0x381C8000, 0x381CC000, 0x381D0000, 0x381D4000, + 0x381D8000, 0x381DC000, 0x381E0000, 0x381E4000, 0x381E8000, 0x381EC000, 0x381F0000, + 0x381F4000, 0x381F8000, 0x381FC000, 0x38200000, 0x38204000, 0x38208000, 0x3820C000, + 0x38210000, 0x38214000, 0x38218000, 0x3821C000, 0x38220000, 0x38224000, 0x38228000, + 0x3822C000, 0x38230000, 0x38234000, 0x38238000, 0x3823C000, 0x38240000, 0x38244000, + 0x38248000, 0x3824C000, 0x38250000, 0x38254000, 0x38258000, 0x3825C000, 0x38260000, + 0x38264000, 0x38268000, 0x3826C000, 0x38270000, 0x38274000, 0x38278000, 0x3827C000, + 0x38280000, 0x38284000, 0x38288000, 0x3828C000, 0x38290000, 0x38294000, 0x38298000, + 0x3829C000, 0x382A0000, 0x382A4000, 0x382A8000, 0x382AC000, 0x382B0000, 0x382B4000, + 0x382B8000, 0x382BC000, 0x382C0000, 0x382C4000, 0x382C8000, 0x382CC000, 0x382D0000, + 0x382D4000, 0x382D8000, 0x382DC000, 0x382E0000, 0x382E4000, 0x382E8000, 0x382EC000, + 0x382F0000, 0x382F4000, 0x382F8000, 0x382FC000, 0x38300000, 0x38304000, 0x38308000, + 0x3830C000, 0x38310000, 0x38314000, 0x38318000, 0x3831C000, 0x38320000, 0x38324000, + 0x38328000, 0x3832C000, 0x38330000, 0x38334000, 0x38338000, 0x3833C000, 0x38340000, + 0x38344000, 0x38348000, 0x3834C000, 0x38350000, 0x38354000, 0x38358000, 0x3835C000, + 0x38360000, 0x38364000, 0x38368000, 0x3836C000, 0x38370000, 0x38374000, 0x38378000, + 0x3837C000, 0x38380000, 0x38384000, 0x38388000, 0x3838C000, 0x38390000, 0x38394000, + 0x38398000, 0x3839C000, 0x383A0000, 0x383A4000, 
0x383A8000, 0x383AC000, 0x383B0000, + 0x383B4000, 0x383B8000, 0x383BC000, 0x383C0000, 0x383C4000, 0x383C8000, 0x383CC000, + 0x383D0000, 0x383D4000, 0x383D8000, 0x383DC000, 0x383E0000, 0x383E4000, 0x383E8000, + 0x383EC000, 0x383F0000, 0x383F4000, 0x383F8000, 0x383FC000, 0x38400000, 0x38404000, + 0x38408000, 0x3840C000, 0x38410000, 0x38414000, 0x38418000, 0x3841C000, 0x38420000, + 0x38424000, 0x38428000, 0x3842C000, 0x38430000, 0x38434000, 0x38438000, 0x3843C000, + 0x38440000, 0x38444000, 0x38448000, 0x3844C000, 0x38450000, 0x38454000, 0x38458000, + 0x3845C000, 0x38460000, 0x38464000, 0x38468000, 0x3846C000, 0x38470000, 0x38474000, + 0x38478000, 0x3847C000, 0x38480000, 0x38484000, 0x38488000, 0x3848C000, 0x38490000, + 0x38494000, 0x38498000, 0x3849C000, 0x384A0000, 0x384A4000, 0x384A8000, 0x384AC000, + 0x384B0000, 0x384B4000, 0x384B8000, 0x384BC000, 0x384C0000, 0x384C4000, 0x384C8000, + 0x384CC000, 0x384D0000, 0x384D4000, 0x384D8000, 0x384DC000, 0x384E0000, 0x384E4000, + 0x384E8000, 0x384EC000, 0x384F0000, 0x384F4000, 0x384F8000, 0x384FC000, 0x38500000, + 0x38504000, 0x38508000, 0x3850C000, 0x38510000, 0x38514000, 0x38518000, 0x3851C000, + 0x38520000, 0x38524000, 0x38528000, 0x3852C000, 0x38530000, 0x38534000, 0x38538000, + 0x3853C000, 0x38540000, 0x38544000, 0x38548000, 0x3854C000, 0x38550000, 0x38554000, + 0x38558000, 0x3855C000, 0x38560000, 0x38564000, 0x38568000, 0x3856C000, 0x38570000, + 0x38574000, 0x38578000, 0x3857C000, 0x38580000, 0x38584000, 0x38588000, 0x3858C000, + 0x38590000, 0x38594000, 0x38598000, 0x3859C000, 0x385A0000, 0x385A4000, 0x385A8000, + 0x385AC000, 0x385B0000, 0x385B4000, 0x385B8000, 0x385BC000, 0x385C0000, 0x385C4000, + 0x385C8000, 0x385CC000, 0x385D0000, 0x385D4000, 0x385D8000, 0x385DC000, 0x385E0000, + 0x385E4000, 0x385E8000, 0x385EC000, 0x385F0000, 0x385F4000, 0x385F8000, 0x385FC000, + 0x38600000, 0x38604000, 0x38608000, 0x3860C000, 0x38610000, 0x38614000, 0x38618000, + 0x3861C000, 0x38620000, 0x38624000, 0x38628000, 0x3862C000, 0x38630000, 0x38634000, + 0x38638000, 0x3863C000, 0x38640000, 0x38644000, 0x38648000, 0x3864C000, 0x38650000, + 0x38654000, 0x38658000, 0x3865C000, 0x38660000, 0x38664000, 0x38668000, 0x3866C000, + 0x38670000, 0x38674000, 0x38678000, 0x3867C000, 0x38680000, 0x38684000, 0x38688000, + 0x3868C000, 0x38690000, 0x38694000, 0x38698000, 0x3869C000, 0x386A0000, 0x386A4000, + 0x386A8000, 0x386AC000, 0x386B0000, 0x386B4000, 0x386B8000, 0x386BC000, 0x386C0000, + 0x386C4000, 0x386C8000, 0x386CC000, 0x386D0000, 0x386D4000, 0x386D8000, 0x386DC000, + 0x386E0000, 0x386E4000, 0x386E8000, 0x386EC000, 0x386F0000, 0x386F4000, 0x386F8000, + 0x386FC000, 0x38700000, 0x38704000, 0x38708000, 0x3870C000, 0x38710000, 0x38714000, + 0x38718000, 0x3871C000, 0x38720000, 0x38724000, 0x38728000, 0x3872C000, 0x38730000, + 0x38734000, 0x38738000, 0x3873C000, 0x38740000, 0x38744000, 0x38748000, 0x3874C000, + 0x38750000, 0x38754000, 0x38758000, 0x3875C000, 0x38760000, 0x38764000, 0x38768000, + 0x3876C000, 0x38770000, 0x38774000, 0x38778000, 0x3877C000, 0x38780000, 0x38784000, + 0x38788000, 0x3878C000, 0x38790000, 0x38794000, 0x38798000, 0x3879C000, 0x387A0000, + 0x387A4000, 0x387A8000, 0x387AC000, 0x387B0000, 0x387B4000, 0x387B8000, 0x387BC000, + 0x387C0000, 0x387C4000, 0x387C8000, 0x387CC000, 0x387D0000, 0x387D4000, 0x387D8000, + 0x387DC000, 0x387E0000, 0x387E4000, 0x387E8000, 0x387EC000, 0x387F0000, 0x387F4000, + 0x387F8000, 0x387FC000, 0x38000000, 0x38002000, 0x38004000, 0x38006000, 0x38008000, + 0x3800A000, 0x3800C000, 0x3800E000, 0x38010000, 0x38012000, 0x38014000, 
0x38016000, + 0x38018000, 0x3801A000, 0x3801C000, 0x3801E000, 0x38020000, 0x38022000, 0x38024000, + 0x38026000, 0x38028000, 0x3802A000, 0x3802C000, 0x3802E000, 0x38030000, 0x38032000, + 0x38034000, 0x38036000, 0x38038000, 0x3803A000, 0x3803C000, 0x3803E000, 0x38040000, + 0x38042000, 0x38044000, 0x38046000, 0x38048000, 0x3804A000, 0x3804C000, 0x3804E000, + 0x38050000, 0x38052000, 0x38054000, 0x38056000, 0x38058000, 0x3805A000, 0x3805C000, + 0x3805E000, 0x38060000, 0x38062000, 0x38064000, 0x38066000, 0x38068000, 0x3806A000, + 0x3806C000, 0x3806E000, 0x38070000, 0x38072000, 0x38074000, 0x38076000, 0x38078000, + 0x3807A000, 0x3807C000, 0x3807E000, 0x38080000, 0x38082000, 0x38084000, 0x38086000, + 0x38088000, 0x3808A000, 0x3808C000, 0x3808E000, 0x38090000, 0x38092000, 0x38094000, + 0x38096000, 0x38098000, 0x3809A000, 0x3809C000, 0x3809E000, 0x380A0000, 0x380A2000, + 0x380A4000, 0x380A6000, 0x380A8000, 0x380AA000, 0x380AC000, 0x380AE000, 0x380B0000, + 0x380B2000, 0x380B4000, 0x380B6000, 0x380B8000, 0x380BA000, 0x380BC000, 0x380BE000, + 0x380C0000, 0x380C2000, 0x380C4000, 0x380C6000, 0x380C8000, 0x380CA000, 0x380CC000, + 0x380CE000, 0x380D0000, 0x380D2000, 0x380D4000, 0x380D6000, 0x380D8000, 0x380DA000, + 0x380DC000, 0x380DE000, 0x380E0000, 0x380E2000, 0x380E4000, 0x380E6000, 0x380E8000, + 0x380EA000, 0x380EC000, 0x380EE000, 0x380F0000, 0x380F2000, 0x380F4000, 0x380F6000, + 0x380F8000, 0x380FA000, 0x380FC000, 0x380FE000, 0x38100000, 0x38102000, 0x38104000, + 0x38106000, 0x38108000, 0x3810A000, 0x3810C000, 0x3810E000, 0x38110000, 0x38112000, + 0x38114000, 0x38116000, 0x38118000, 0x3811A000, 0x3811C000, 0x3811E000, 0x38120000, + 0x38122000, 0x38124000, 0x38126000, 0x38128000, 0x3812A000, 0x3812C000, 0x3812E000, + 0x38130000, 0x38132000, 0x38134000, 0x38136000, 0x38138000, 0x3813A000, 0x3813C000, + 0x3813E000, 0x38140000, 0x38142000, 0x38144000, 0x38146000, 0x38148000, 0x3814A000, + 0x3814C000, 0x3814E000, 0x38150000, 0x38152000, 0x38154000, 0x38156000, 0x38158000, + 0x3815A000, 0x3815C000, 0x3815E000, 0x38160000, 0x38162000, 0x38164000, 0x38166000, + 0x38168000, 0x3816A000, 0x3816C000, 0x3816E000, 0x38170000, 0x38172000, 0x38174000, + 0x38176000, 0x38178000, 0x3817A000, 0x3817C000, 0x3817E000, 0x38180000, 0x38182000, + 0x38184000, 0x38186000, 0x38188000, 0x3818A000, 0x3818C000, 0x3818E000, 0x38190000, + 0x38192000, 0x38194000, 0x38196000, 0x38198000, 0x3819A000, 0x3819C000, 0x3819E000, + 0x381A0000, 0x381A2000, 0x381A4000, 0x381A6000, 0x381A8000, 0x381AA000, 0x381AC000, + 0x381AE000, 0x381B0000, 0x381B2000, 0x381B4000, 0x381B6000, 0x381B8000, 0x381BA000, + 0x381BC000, 0x381BE000, 0x381C0000, 0x381C2000, 0x381C4000, 0x381C6000, 0x381C8000, + 0x381CA000, 0x381CC000, 0x381CE000, 0x381D0000, 0x381D2000, 0x381D4000, 0x381D6000, + 0x381D8000, 0x381DA000, 0x381DC000, 0x381DE000, 0x381E0000, 0x381E2000, 0x381E4000, + 0x381E6000, 0x381E8000, 0x381EA000, 0x381EC000, 0x381EE000, 0x381F0000, 0x381F2000, + 0x381F4000, 0x381F6000, 0x381F8000, 0x381FA000, 0x381FC000, 0x381FE000, 0x38200000, + 0x38202000, 0x38204000, 0x38206000, 0x38208000, 0x3820A000, 0x3820C000, 0x3820E000, + 0x38210000, 0x38212000, 0x38214000, 0x38216000, 0x38218000, 0x3821A000, 0x3821C000, + 0x3821E000, 0x38220000, 0x38222000, 0x38224000, 0x38226000, 0x38228000, 0x3822A000, + 0x3822C000, 0x3822E000, 0x38230000, 0x38232000, 0x38234000, 0x38236000, 0x38238000, + 0x3823A000, 0x3823C000, 0x3823E000, 0x38240000, 0x38242000, 0x38244000, 0x38246000, + 0x38248000, 0x3824A000, 0x3824C000, 0x3824E000, 0x38250000, 0x38252000, 0x38254000, + 0x38256000, 
0x38258000, 0x3825A000, 0x3825C000, 0x3825E000, 0x38260000, 0x38262000, + 0x38264000, 0x38266000, 0x38268000, 0x3826A000, 0x3826C000, 0x3826E000, 0x38270000, + 0x38272000, 0x38274000, 0x38276000, 0x38278000, 0x3827A000, 0x3827C000, 0x3827E000, + 0x38280000, 0x38282000, 0x38284000, 0x38286000, 0x38288000, 0x3828A000, 0x3828C000, + 0x3828E000, 0x38290000, 0x38292000, 0x38294000, 0x38296000, 0x38298000, 0x3829A000, + 0x3829C000, 0x3829E000, 0x382A0000, 0x382A2000, 0x382A4000, 0x382A6000, 0x382A8000, + 0x382AA000, 0x382AC000, 0x382AE000, 0x382B0000, 0x382B2000, 0x382B4000, 0x382B6000, + 0x382B8000, 0x382BA000, 0x382BC000, 0x382BE000, 0x382C0000, 0x382C2000, 0x382C4000, + 0x382C6000, 0x382C8000, 0x382CA000, 0x382CC000, 0x382CE000, 0x382D0000, 0x382D2000, + 0x382D4000, 0x382D6000, 0x382D8000, 0x382DA000, 0x382DC000, 0x382DE000, 0x382E0000, + 0x382E2000, 0x382E4000, 0x382E6000, 0x382E8000, 0x382EA000, 0x382EC000, 0x382EE000, + 0x382F0000, 0x382F2000, 0x382F4000, 0x382F6000, 0x382F8000, 0x382FA000, 0x382FC000, + 0x382FE000, 0x38300000, 0x38302000, 0x38304000, 0x38306000, 0x38308000, 0x3830A000, + 0x3830C000, 0x3830E000, 0x38310000, 0x38312000, 0x38314000, 0x38316000, 0x38318000, + 0x3831A000, 0x3831C000, 0x3831E000, 0x38320000, 0x38322000, 0x38324000, 0x38326000, + 0x38328000, 0x3832A000, 0x3832C000, 0x3832E000, 0x38330000, 0x38332000, 0x38334000, + 0x38336000, 0x38338000, 0x3833A000, 0x3833C000, 0x3833E000, 0x38340000, 0x38342000, + 0x38344000, 0x38346000, 0x38348000, 0x3834A000, 0x3834C000, 0x3834E000, 0x38350000, + 0x38352000, 0x38354000, 0x38356000, 0x38358000, 0x3835A000, 0x3835C000, 0x3835E000, + 0x38360000, 0x38362000, 0x38364000, 0x38366000, 0x38368000, 0x3836A000, 0x3836C000, + 0x3836E000, 0x38370000, 0x38372000, 0x38374000, 0x38376000, 0x38378000, 0x3837A000, + 0x3837C000, 0x3837E000, 0x38380000, 0x38382000, 0x38384000, 0x38386000, 0x38388000, + 0x3838A000, 0x3838C000, 0x3838E000, 0x38390000, 0x38392000, 0x38394000, 0x38396000, + 0x38398000, 0x3839A000, 0x3839C000, 0x3839E000, 0x383A0000, 0x383A2000, 0x383A4000, + 0x383A6000, 0x383A8000, 0x383AA000, 0x383AC000, 0x383AE000, 0x383B0000, 0x383B2000, + 0x383B4000, 0x383B6000, 0x383B8000, 0x383BA000, 0x383BC000, 0x383BE000, 0x383C0000, + 0x383C2000, 0x383C4000, 0x383C6000, 0x383C8000, 0x383CA000, 0x383CC000, 0x383CE000, + 0x383D0000, 0x383D2000, 0x383D4000, 0x383D6000, 0x383D8000, 0x383DA000, 0x383DC000, + 0x383DE000, 0x383E0000, 0x383E2000, 0x383E4000, 0x383E6000, 0x383E8000, 0x383EA000, + 0x383EC000, 0x383EE000, 0x383F0000, 0x383F2000, 0x383F4000, 0x383F6000, 0x383F8000, + 0x383FA000, 0x383FC000, 0x383FE000, 0x38400000, 0x38402000, 0x38404000, 0x38406000, + 0x38408000, 0x3840A000, 0x3840C000, 0x3840E000, 0x38410000, 0x38412000, 0x38414000, + 0x38416000, 0x38418000, 0x3841A000, 0x3841C000, 0x3841E000, 0x38420000, 0x38422000, + 0x38424000, 0x38426000, 0x38428000, 0x3842A000, 0x3842C000, 0x3842E000, 0x38430000, + 0x38432000, 0x38434000, 0x38436000, 0x38438000, 0x3843A000, 0x3843C000, 0x3843E000, + 0x38440000, 0x38442000, 0x38444000, 0x38446000, 0x38448000, 0x3844A000, 0x3844C000, + 0x3844E000, 0x38450000, 0x38452000, 0x38454000, 0x38456000, 0x38458000, 0x3845A000, + 0x3845C000, 0x3845E000, 0x38460000, 0x38462000, 0x38464000, 0x38466000, 0x38468000, + 0x3846A000, 0x3846C000, 0x3846E000, 0x38470000, 0x38472000, 0x38474000, 0x38476000, + 0x38478000, 0x3847A000, 0x3847C000, 0x3847E000, 0x38480000, 0x38482000, 0x38484000, + 0x38486000, 0x38488000, 0x3848A000, 0x3848C000, 0x3848E000, 0x38490000, 0x38492000, + 0x38494000, 0x38496000, 0x38498000, 
0x3849A000, 0x3849C000, 0x3849E000, 0x384A0000, + 0x384A2000, 0x384A4000, 0x384A6000, 0x384A8000, 0x384AA000, 0x384AC000, 0x384AE000, + 0x384B0000, 0x384B2000, 0x384B4000, 0x384B6000, 0x384B8000, 0x384BA000, 0x384BC000, + 0x384BE000, 0x384C0000, 0x384C2000, 0x384C4000, 0x384C6000, 0x384C8000, 0x384CA000, + 0x384CC000, 0x384CE000, 0x384D0000, 0x384D2000, 0x384D4000, 0x384D6000, 0x384D8000, + 0x384DA000, 0x384DC000, 0x384DE000, 0x384E0000, 0x384E2000, 0x384E4000, 0x384E6000, + 0x384E8000, 0x384EA000, 0x384EC000, 0x384EE000, 0x384F0000, 0x384F2000, 0x384F4000, + 0x384F6000, 0x384F8000, 0x384FA000, 0x384FC000, 0x384FE000, 0x38500000, 0x38502000, + 0x38504000, 0x38506000, 0x38508000, 0x3850A000, 0x3850C000, 0x3850E000, 0x38510000, + 0x38512000, 0x38514000, 0x38516000, 0x38518000, 0x3851A000, 0x3851C000, 0x3851E000, + 0x38520000, 0x38522000, 0x38524000, 0x38526000, 0x38528000, 0x3852A000, 0x3852C000, + 0x3852E000, 0x38530000, 0x38532000, 0x38534000, 0x38536000, 0x38538000, 0x3853A000, + 0x3853C000, 0x3853E000, 0x38540000, 0x38542000, 0x38544000, 0x38546000, 0x38548000, + 0x3854A000, 0x3854C000, 0x3854E000, 0x38550000, 0x38552000, 0x38554000, 0x38556000, + 0x38558000, 0x3855A000, 0x3855C000, 0x3855E000, 0x38560000, 0x38562000, 0x38564000, + 0x38566000, 0x38568000, 0x3856A000, 0x3856C000, 0x3856E000, 0x38570000, 0x38572000, + 0x38574000, 0x38576000, 0x38578000, 0x3857A000, 0x3857C000, 0x3857E000, 0x38580000, + 0x38582000, 0x38584000, 0x38586000, 0x38588000, 0x3858A000, 0x3858C000, 0x3858E000, + 0x38590000, 0x38592000, 0x38594000, 0x38596000, 0x38598000, 0x3859A000, 0x3859C000, + 0x3859E000, 0x385A0000, 0x385A2000, 0x385A4000, 0x385A6000, 0x385A8000, 0x385AA000, + 0x385AC000, 0x385AE000, 0x385B0000, 0x385B2000, 0x385B4000, 0x385B6000, 0x385B8000, + 0x385BA000, 0x385BC000, 0x385BE000, 0x385C0000, 0x385C2000, 0x385C4000, 0x385C6000, + 0x385C8000, 0x385CA000, 0x385CC000, 0x385CE000, 0x385D0000, 0x385D2000, 0x385D4000, + 0x385D6000, 0x385D8000, 0x385DA000, 0x385DC000, 0x385DE000, 0x385E0000, 0x385E2000, + 0x385E4000, 0x385E6000, 0x385E8000, 0x385EA000, 0x385EC000, 0x385EE000, 0x385F0000, + 0x385F2000, 0x385F4000, 0x385F6000, 0x385F8000, 0x385FA000, 0x385FC000, 0x385FE000, + 0x38600000, 0x38602000, 0x38604000, 0x38606000, 0x38608000, 0x3860A000, 0x3860C000, + 0x3860E000, 0x38610000, 0x38612000, 0x38614000, 0x38616000, 0x38618000, 0x3861A000, + 0x3861C000, 0x3861E000, 0x38620000, 0x38622000, 0x38624000, 0x38626000, 0x38628000, + 0x3862A000, 0x3862C000, 0x3862E000, 0x38630000, 0x38632000, 0x38634000, 0x38636000, + 0x38638000, 0x3863A000, 0x3863C000, 0x3863E000, 0x38640000, 0x38642000, 0x38644000, + 0x38646000, 0x38648000, 0x3864A000, 0x3864C000, 0x3864E000, 0x38650000, 0x38652000, + 0x38654000, 0x38656000, 0x38658000, 0x3865A000, 0x3865C000, 0x3865E000, 0x38660000, + 0x38662000, 0x38664000, 0x38666000, 0x38668000, 0x3866A000, 0x3866C000, 0x3866E000, + 0x38670000, 0x38672000, 0x38674000, 0x38676000, 0x38678000, 0x3867A000, 0x3867C000, + 0x3867E000, 0x38680000, 0x38682000, 0x38684000, 0x38686000, 0x38688000, 0x3868A000, + 0x3868C000, 0x3868E000, 0x38690000, 0x38692000, 0x38694000, 0x38696000, 0x38698000, + 0x3869A000, 0x3869C000, 0x3869E000, 0x386A0000, 0x386A2000, 0x386A4000, 0x386A6000, + 0x386A8000, 0x386AA000, 0x386AC000, 0x386AE000, 0x386B0000, 0x386B2000, 0x386B4000, + 0x386B6000, 0x386B8000, 0x386BA000, 0x386BC000, 0x386BE000, 0x386C0000, 0x386C2000, + 0x386C4000, 0x386C6000, 0x386C8000, 0x386CA000, 0x386CC000, 0x386CE000, 0x386D0000, + 0x386D2000, 0x386D4000, 0x386D6000, 0x386D8000, 0x386DA000, 
0x386DC000, 0x386DE000, + 0x386E0000, 0x386E2000, 0x386E4000, 0x386E6000, 0x386E8000, 0x386EA000, 0x386EC000, + 0x386EE000, 0x386F0000, 0x386F2000, 0x386F4000, 0x386F6000, 0x386F8000, 0x386FA000, + 0x386FC000, 0x386FE000, 0x38700000, 0x38702000, 0x38704000, 0x38706000, 0x38708000, + 0x3870A000, 0x3870C000, 0x3870E000, 0x38710000, 0x38712000, 0x38714000, 0x38716000, + 0x38718000, 0x3871A000, 0x3871C000, 0x3871E000, 0x38720000, 0x38722000, 0x38724000, + 0x38726000, 0x38728000, 0x3872A000, 0x3872C000, 0x3872E000, 0x38730000, 0x38732000, + 0x38734000, 0x38736000, 0x38738000, 0x3873A000, 0x3873C000, 0x3873E000, 0x38740000, + 0x38742000, 0x38744000, 0x38746000, 0x38748000, 0x3874A000, 0x3874C000, 0x3874E000, + 0x38750000, 0x38752000, 0x38754000, 0x38756000, 0x38758000, 0x3875A000, 0x3875C000, + 0x3875E000, 0x38760000, 0x38762000, 0x38764000, 0x38766000, 0x38768000, 0x3876A000, + 0x3876C000, 0x3876E000, 0x38770000, 0x38772000, 0x38774000, 0x38776000, 0x38778000, + 0x3877A000, 0x3877C000, 0x3877E000, 0x38780000, 0x38782000, 0x38784000, 0x38786000, + 0x38788000, 0x3878A000, 0x3878C000, 0x3878E000, 0x38790000, 0x38792000, 0x38794000, + 0x38796000, 0x38798000, 0x3879A000, 0x3879C000, 0x3879E000, 0x387A0000, 0x387A2000, + 0x387A4000, 0x387A6000, 0x387A8000, 0x387AA000, 0x387AC000, 0x387AE000, 0x387B0000, + 0x387B2000, 0x387B4000, 0x387B6000, 0x387B8000, 0x387BA000, 0x387BC000, 0x387BE000, + 0x387C0000, 0x387C2000, 0x387C4000, 0x387C6000, 0x387C8000, 0x387CA000, 0x387CC000, + 0x387CE000, 0x387D0000, 0x387D2000, 0x387D4000, 0x387D6000, 0x387D8000, 0x387DA000, + 0x387DC000, 0x387DE000, 0x387E0000, 0x387E2000, 0x387E4000, 0x387E6000, 0x387E8000, + 0x387EA000, 0x387EC000, 0x387EE000, 0x387F0000, 0x387F2000, 0x387F4000, 0x387F6000, + 0x387F8000, 0x387FA000, 0x387FC000, 0x387FE000}; + static const bits::type exponent_table[64] = { + 0x00000000, 0x00800000, 0x01000000, 0x01800000, 0x02000000, 0x02800000, 0x03000000, + 0x03800000, 0x04000000, 0x04800000, 0x05000000, 0x05800000, 0x06000000, 0x06800000, + 0x07000000, 0x07800000, 0x08000000, 0x08800000, 0x09000000, 0x09800000, 0x0A000000, + 0x0A800000, 0x0B000000, 0x0B800000, 0x0C000000, 0x0C800000, 0x0D000000, 0x0D800000, + 0x0E000000, 0x0E800000, 0x0F000000, 0x47800000, 0x80000000, 0x80800000, 0x81000000, + 0x81800000, 0x82000000, 0x82800000, 0x83000000, 0x83800000, 0x84000000, 0x84800000, + 0x85000000, 0x85800000, 0x86000000, 0x86800000, 0x87000000, 0x87800000, 0x88000000, + 0x88800000, 0x89000000, 0x89800000, 0x8A000000, 0x8A800000, 0x8B000000, 0x8B800000, + 0x8C000000, 0x8C800000, 0x8D000000, 0x8D800000, 0x8E000000, 0x8E800000, 0x8F000000, + 0xC7800000}; + static const unsigned short offset_table[64] = { + 0, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, + 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, + 1024, 1024, 1024, 1024, 1024, 1024, 0, 1024, 1024, 1024, 1024, 1024, 1024, + 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, + 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024}; + bits::type fbits = + mantissa_table[offset_table[value >> 10] + (value & 0x3FF)] + exponent_table[value >> 10]; +#endif + float out; + std::memcpy(&out, &fbits, sizeof(float)); + return out; +#endif +} + +/// Convert half-precision to IEEE double-precision. 
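+// Worked example for the table-based float conversion above, as a reading aid
+// (all numbers taken directly from mantissa_table/exponent_table/offset_table):
+// for the half-precision pattern 0x3C00 (= 1.0), value>>10 = 15 and value&0x3FF = 0, so
+//   fbits = mantissa_table[offset_table[15] + 0] + exponent_table[15]
+//         = mantissa_table[1024] + 0x07800000
+//         = 0x38000000 + 0x07800000 = 0x3F800000, i.e. 1.0f.
+// Likewise 0xC000 (= -2.0) uses row 48: 0x38000000 + 0x88000000 = 0xC0000000 = -2.0f.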
+/// \param value half-precision value to convert +/// \return double-precision value +inline double half2float_impl(unsigned int value, double, true_type) +{ +#if HALF_ENABLE_F16C_INTRINSICS + return _mm_cvtsd_f64(_mm_cvtps_pd(_mm_cvtph_ps(_mm_cvtsi32_si128(value)))); +#else + uint32 hi = static_cast(value & 0x8000) << 16; + unsigned int abs = value & 0x7FFF; + if(abs) + { + hi |= 0x3F000000 << static_cast(abs >= 0x7C00); + for(; abs < 0x400; abs <<= 1, hi -= 0x100000) + ; + hi += static_cast(abs) << 10; + } + bits::type dbits = static_cast::type>(hi) << 32; + double out; + std::memcpy(&out, &dbits, sizeof(double)); + return out; +#endif +} + +/// Convert half-precision to non-IEEE floating-point. +/// \tparam T type to convert to (builtin integer type) +/// \param value half-precision value to convert +/// \return floating-point value +template +T half2float_impl(unsigned int value, T, ...) +{ + T out; + unsigned int abs = value & 0x7FFF; + if(abs > 0x7C00) + out = + (std::numeric_limits::has_signaling_NaN && !(abs & 0x200)) + ? std::numeric_limits::signaling_NaN() + : std::numeric_limits::has_quiet_NaN ? std::numeric_limits::quiet_NaN() : T(); + else if(abs == 0x7C00) + out = std::numeric_limits::has_infinity ? std::numeric_limits::infinity() + : std::numeric_limits::max(); + else if(abs > 0x3FF) + out = std::ldexp(static_cast((abs & 0x3FF) | 0x400), (abs >> 10) - 25); + else + out = std::ldexp(static_cast(abs), -24); + return (value & 0x8000) ? -out : out; +} + +/// Convert half-precision to floating-point. +/// \tparam T type to convert to (builtin integer type) +/// \param value half-precision value to convert +/// \return floating-point value +template +T half2float(unsigned int value) +{ + return half2float_impl(value, + T(), + bool_type < std::numeric_limits::is_iec559 && + sizeof(typename bits::type) == sizeof(T) > ()); +} + +/// Convert half-precision floating-point to integer. +/// \tparam R rounding mode to use +/// \tparam E `true` for round to even, `false` for round away from zero +/// \tparam I `true` to raise INEXACT exception (if inexact), `false` to never raise it +/// \tparam T type to convert to (buitlin integer type with at least 16 bits precision, excluding +/// any implicit sign bits) +/// \param value half-precision value to convert +/// \return rounded integer value +/// \exception FE_INVALID if value is not representable in type \a T +/// \exception FE_INEXACT if value had to be rounded and \a I is `true` +template +T half2int(unsigned int value) +{ + unsigned int abs = value & 0x7FFF; + if(abs >= 0x7C00) + { + raise(FE_INVALID); + return (value & 0x8000) ? std::numeric_limits::min() : std::numeric_limits::max(); + } + if(abs < 0x3800) + { + raise(FE_INEXACT, I); + return (R == std::round_toward_infinity) + ? T(~(value >> 15) & (abs != 0)) + : (R == std::round_toward_neg_infinity) ? -T(value > 0x8000) : T(); + } + int exp = 25 - (abs >> 10); + unsigned int m = (value & 0x3FF) | 0x400; + int32 i = static_cast( + (exp <= 0) + ? (m << -exp) + : ((m + ((R == std::round_to_nearest) ? ((1 << (exp - 1)) - (~(m >> exp) & E)) + : (R == std::round_toward_infinity) + ? (((1 << exp) - 1) & ((value >> 15) - 1)) + : (R == std::round_toward_neg_infinity) + ? (((1 << exp) - 1) & -(value >> 15)) + : 0)) >> + exp)); + if((!std::numeric_limits::is_signed && (value & 0x8000)) || + (std::numeric_limits::digits < 16 && + ((value & 0x8000) ? 
(-i < std::numeric_limits::min()) + : (i > std::numeric_limits::max())))) + raise(FE_INVALID); + else if(I && exp > 0 && (m & ((1 << exp) - 1))) + raise(FE_INEXACT); + return static_cast((value & 0x8000) ? -i : i); +} + +/// \} +/// \name Mathematics +/// \{ + +/// upper part of 64-bit multiplication. +/// \tparam R rounding mode to use +/// \param x first factor +/// \param y second factor +/// \return upper 32 bit of \a x * \a y +template +uint32 mulhi(uint32 x, uint32 y) +{ + uint32 xy = (x >> 16) * (y & 0xFFFF), yx = (x & 0xFFFF) * (y >> 16), + c = (xy & 0xFFFF) + (yx & 0xFFFF) + (((x & 0xFFFF) * (y & 0xFFFF)) >> 16); + return (x >> 16) * (y >> 16) + (xy >> 16) + (yx >> 16) + (c >> 16) + + ((R == std::round_to_nearest) + ? ((c >> 15) & 1) + : (R == std::round_toward_infinity) ? ((c & 0xFFFF) != 0) : 0); +} + +/// 64-bit multiplication. +/// \param x first factor +/// \param y second factor +/// \return upper 32 bit of \a x * \a y rounded to nearest +inline uint32 multiply64(uint32 x, uint32 y) +{ +#if HALF_ENABLE_CPP11_LONG_LONG + return static_cast( + (static_cast(x) * static_cast(y) + 0x80000000) >> + 32); +#else + return mulhi(x, y); +#endif +} + +/// 64-bit division. +/// \param x upper 32 bit of dividend +/// \param y divisor +/// \param s variable to store sticky bit for rounding +/// \return (\a x << 32) / \a y +inline uint32 divide64(uint32 x, uint32 y, int& s) +{ +#if HALF_ENABLE_CPP11_LONG_LONG + unsigned long long xx = static_cast(x) << 32; + return s = (xx % y != 0), static_cast(xx / y); +#else + y >>= 1; + uint32 rem = x, div = 0; + for(unsigned int i = 0; i < 32; ++i) + { + div <<= 1; + if(rem >= y) + { + rem -= y; + div |= 1; + } + rem <<= 1; + } + return s = rem > 1, div; +#endif +} + +/// Half precision positive modulus. +/// \tparam Q `true` to compute full quotient, `false` else +/// \tparam R `true` to compute signed remainder, `false` for positive remainder +/// \param x first operand as positive finite half-precision value +/// \param y second operand as positive finite half-precision value +/// \param quo adress to store quotient at, `nullptr` if \a Q `false` +/// \return modulus of \a x / \a y +template +unsigned int mod(unsigned int x, unsigned int y, int* quo = NULL) +{ + unsigned int q = 0; + if(x > y) + { + int absx = x, absy = y, expx = 0, expy = 0; + for(; absx < 0x400; absx <<= 1, --expx) + ; + for(; absy < 0x400; absy <<= 1, --expy) + ; + expx += absx >> 10; + expy += absy >> 10; + int mx = (absx & 0x3FF) | 0x400, my = (absy & 0x3FF) | 0x400; + for(int d = expx - expy; d; --d) + { + if(!Q && mx == my) + return 0; + if(mx >= my) + { + mx -= my; + q += Q; + } + mx <<= 1; + q <<= static_cast(Q); + } + if(!Q && mx == my) + return 0; + if(mx >= my) + { + mx -= my; + ++q; + } + if(Q) + { + q &= (1 << (std::numeric_limits::digits - 1)) - 1; + if(!mx) + return *quo = q, 0; + } + for(; mx < 0x400; mx <<= 1, --expy) + ; + x = (expy > 0) ? ((expy << 10) | (mx & 0x3FF)) : (mx >> (1 - expy)); + } + if(R) + { + unsigned int a, b; + if(y < 0x800) + { + a = (x < 0x400) ? (x << 1) : (x + 0x400); + b = y; + } + else + { + a = x; + b = y - 0x400; + } + if(a > b || (a == b && (q & 1))) + { + int exp = (y >> 10) + (y <= 0x3FF), d = exp - (x >> 10) - (x <= 0x3FF); + int m = (((y & 0x3FF) | ((y > 0x3FF) << 10)) << 1) - + (((x & 0x3FF) | ((x > 0x3FF) << 10)) << (1 - d)); + for(; m < 0x800 && exp > 1; m <<= 1, --exp) + ; + x = 0x8000 + ((exp - 1) << 10) + (m >> 1); + q += Q; + } + } + if(Q) + *quo = q; + return x; +} + +/// Fixed point square root. 
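+// Hand-checked sample values for the 64-bit helpers above, as a reading aid:
+// multiply64() returns the upper 32 bits of the 64-bit product rounded to nearest, e.g.
+//   multiply64(0xC0000000, 0xC0000000) == 0x90000000   // 1.5 * 1.5 = 2.25 (Q1.31 in, Q2.30 out)
+// and divide64() returns the upper 32 bits of (x << 32) / y plus a sticky flag, e.g.
+//   int s; divide64(0x40000000, 0xC0000000, s) == 0x55555555 with s == 1   // 1.0 / 1.5, inexact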
+/// \tparam F number of fractional bits +/// \param r radicand in Q1.F fixed point format +/// \param exp exponent +/// \return square root as Q1.F/2 +template +uint32 sqrt(uint32& r, int& exp) +{ + int i = exp & 1; + r <<= i; + exp = (exp - i) / 2; + uint32 m = 0; + for(uint32 bit = static_cast(1) << F; bit; bit >>= 2) + { + if(r < m + bit) + m >>= 1; + else + { + r -= m + bit; + m = (m >> 1) + bit; + } + } + return m; +} + +/// Fixed point binary exponential. +/// This uses the BKM algorithm in E-mode. +/// \param m exponent in [0,1) as Q0.31 +/// \param n number of iterations (at most 32) +/// \return 2 ^ \a m as Q1.31 +inline uint32 exp2(uint32 m, unsigned int n = 32) +{ + static const uint32 logs[] = { + 0x80000000, 0x4AE00D1D, 0x2934F098, 0x15C01A3A, 0x0B31FB7D, 0x05AEB4DD, 0x02DCF2D1, + 0x016FE50B, 0x00B84E23, 0x005C3E10, 0x002E24CA, 0x001713D6, 0x000B8A47, 0x0005C53B, + 0x0002E2A3, 0x00017153, 0x0000B8AA, 0x00005C55, 0x00002E2B, 0x00001715, 0x00000B8B, + 0x000005C5, 0x000002E3, 0x00000171, 0x000000B9, 0x0000005C, 0x0000002E, 0x00000017, + 0x0000000C, 0x00000006, 0x00000003, 0x00000001}; + if(!m) + return 0x80000000; + uint32 mx = 0x80000000, my = 0; + for(unsigned int i = 1; i < n; ++i) + { + uint32 mz = my + logs[i]; + if(mz <= m) + { + my = mz; + mx += mx >> i; + } + } + return mx; +} + +/// Fixed point binary logarithm. +/// This uses the BKM algorithm in L-mode. +/// \param m mantissa in [1,2) as Q1.30 +/// \param n number of iterations (at most 32) +/// \return log2(\a m) as Q0.31 +inline uint32 log2(uint32 m, unsigned int n = 32) +{ + static const uint32 logs[] = { + 0x80000000, 0x4AE00D1D, 0x2934F098, 0x15C01A3A, 0x0B31FB7D, 0x05AEB4DD, 0x02DCF2D1, + 0x016FE50B, 0x00B84E23, 0x005C3E10, 0x002E24CA, 0x001713D6, 0x000B8A47, 0x0005C53B, + 0x0002E2A3, 0x00017153, 0x0000B8AA, 0x00005C55, 0x00002E2B, 0x00001715, 0x00000B8B, + 0x000005C5, 0x000002E3, 0x00000171, 0x000000B9, 0x0000005C, 0x0000002E, 0x00000017, + 0x0000000C, 0x00000006, 0x00000003, 0x00000001}; + if(m == 0x40000000) + return 0; + uint32 mx = 0x40000000, my = 0; + for(unsigned int i = 1; i < n; ++i) + { + uint32 mz = mx + (mx >> i); + if(mz <= m) + { + mx = mz; + my += logs[i]; + } + } + return my; +} + +/// Fixed point sine and cosine. +/// This uses the CORDIC algorithm in rotation mode. +/// \param mz angle in [-pi/2,pi/2] as Q1.30 +/// \param n number of iterations (at most 31) +/// \return sine and cosine of \a mz as Q1.30 +inline std::pair sincos(uint32 mz, unsigned int n = 31) +{ + static const uint32 angles[] = { + 0x3243F6A9, 0x1DAC6705, 0x0FADBAFD, 0x07F56EA7, 0x03FEAB77, 0x01FFD55C, 0x00FFFAAB, + 0x007FFF55, 0x003FFFEB, 0x001FFFFD, 0x00100000, 0x00080000, 0x00040000, 0x00020000, + 0x00010000, 0x00008000, 0x00004000, 0x00002000, 0x00001000, 0x00000800, 0x00000400, + 0x00000200, 0x00000100, 0x00000080, 0x00000040, 0x00000020, 0x00000010, 0x00000008, + 0x00000004, 0x00000002, 0x00000001}; + uint32 mx = 0x26DD3B6A, my = 0; + for(unsigned int i = 0; i < n; ++i) + { + uint32 sign = sign_mask(mz); + uint32 tx = mx - (arithmetic_shift(my, i) ^ sign) + sign; + uint32 ty = my + (arithmetic_shift(mx, i) ^ sign) - sign; + mx = tx; + my = ty; + mz -= (angles[i] ^ sign) - sign; + } + return std::make_pair(my, mx); +} + +/// Fixed point arc tangent. +/// This uses the CORDIC algorithm in vectoring mode. 
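+// Fixed-point conventions of the helpers above, with approximate sample values
+// (the BKM/CORDIC iterations converge to these within a few ULP):
+//   exp2(0x40000000)           ~ 0xB504F334   // 2^0.5 = sqrt(2) ~ 1.4142 as Q1.31
+//   log2(0x60000000)           ~ 0x4AE00D1D   // log2(1.5) ~ 0.58496 as Q0.31
+//   sincos(0x3243F6A9).first   ~ 0x2D413CCD   // sin(pi/4) ~ 0.7071 as Q1.30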
+/// \param my y coordinate as Q0.30 +/// \param mx x coordinate as Q0.30 +/// \param n number of iterations (at most 31) +/// \return arc tangent of \a my / \a mx as Q1.30 +inline uint32 atan2(uint32 my, uint32 mx, unsigned int n = 31) +{ + static const uint32 angles[] = { + 0x3243F6A9, 0x1DAC6705, 0x0FADBAFD, 0x07F56EA7, 0x03FEAB77, 0x01FFD55C, 0x00FFFAAB, + 0x007FFF55, 0x003FFFEB, 0x001FFFFD, 0x00100000, 0x00080000, 0x00040000, 0x00020000, + 0x00010000, 0x00008000, 0x00004000, 0x00002000, 0x00001000, 0x00000800, 0x00000400, + 0x00000200, 0x00000100, 0x00000080, 0x00000040, 0x00000020, 0x00000010, 0x00000008, + 0x00000004, 0x00000002, 0x00000001}; + uint32 mz = 0; + for(unsigned int i = 0; i < n; ++i) + { + uint32 sign = sign_mask(my); + uint32 tx = mx + (arithmetic_shift(my, i) ^ sign) - sign; + uint32 ty = my - (arithmetic_shift(mx, i) ^ sign) + sign; + mx = tx; + my = ty; + mz += (angles[i] ^ sign) - sign; + } + return mz; +} + +/// Reduce argument for trigonometric functions. +/// \param abs half-precision floating-point value +/// \param k value to take quarter period +/// \return \a abs reduced to [-pi/4,pi/4] as Q0.30 +inline uint32 angle_arg(unsigned int abs, int& k) +{ + uint32 m = (abs & 0x3FF) | ((abs > 0x3FF) << 10); + int exp = (abs >> 10) + (abs <= 0x3FF) - 15; + if(abs < 0x3A48) + return k = 0, m << (exp + 20); +#if HALF_ENABLE_CPP11_LONG_LONG + unsigned long long y = m * 0xA2F9836E4E442, mask = (1ULL << (62 - exp)) - 1, + yi = (y + (mask >> 1)) & ~mask, f = y - yi; + uint32 sign = -static_cast(f >> 63); + k = static_cast(yi >> (62 - exp)); + return (multiply64(static_cast((sign ? -f : f) >> (31 - exp)), 0xC90FDAA2) ^ sign) - + sign; +#else + uint32 yh = m * 0xA2F98 + mulhi(m, 0x36E4E442), + yl = (m * 0x36E4E442) & 0xFFFFFFFF; + uint32 mask = (static_cast(1) << (30 - exp)) - 1, yi = (yh + (mask >> 1)) & ~mask, + sign = -static_cast(yi > yh); + k = static_cast(yi >> (30 - exp)); + uint32 fh = (yh ^ sign) + (yi ^ ~sign) - ~sign, fl = (yl ^ sign) - sign; + return (multiply64((exp > -1) + ? (((fh << (1 + exp)) & 0xFFFFFFFF) | ((fl & 0xFFFFFFFF) >> (31 - exp))) + : fh, + 0xC90FDAA2) ^ + sign) - + sign; +#endif +} + +/// Get arguments for atan2 function. +/// \param abs half-precision floating-point value +/// \return \a abs and sqrt(1 - \a abs^2) as Q0.30 +inline std::pair atan2_args(unsigned int abs) +{ + int exp = -15; + for(; abs < 0x400; abs <<= 1, --exp) + ; + exp += abs >> 10; + uint32 my = ((abs & 0x3FF) | 0x400) << 5, r = my * my; + int rexp = 2 * exp; + r = 0x40000000 - + ((rexp > -31) ? ((r >> -rexp) | ((r & ((static_cast(1) << -rexp) - 1)) != 0)) : 1); + for(rexp = 0; r < 0x40000000; r <<= 1, --rexp) + ; + uint32 mx = sqrt<30>(r, rexp); + int d = exp - rexp; + if(d < 0) + return std::make_pair((d < -14) ? ((my >> (-d - 14)) + ((my >> (-d - 15)) & 1)) + : (my << (14 + d)), + (mx << 14) + (r << 13) / mx); + if(d > 0) + return std::make_pair(my << 14, + (d > 14) + ? ((mx >> (d - 14)) + ((mx >> (d - 15)) & 1)) + : ((d == 14) ? 
mx : ((mx << (14 - d)) + (r << (13 - d)) / mx))); + return std::make_pair(my << 13, (mx << 13) + (r << 12) / mx); +} + +/// Get exponentials for hyperbolic computation +/// \param abs half-precision floating-point value +/// \param exp variable to take unbiased exponent of larger result +/// \param n number of BKM iterations (at most 32) +/// \return exp(abs) and exp(-\a abs) as Q1.31 with same exponent +inline std::pair hyperbolic_args(unsigned int abs, int& exp, unsigned int n = 32) +{ + uint32 mx = detail::multiply64(static_cast((abs & 0x3FF) + ((abs > 0x3FF) << 10)) << 21, + 0xB8AA3B29), + my; + int e = (abs >> 10) + (abs <= 0x3FF); + if(e < 14) + { + exp = 0; + mx >>= 14 - e; + } + else + { + exp = mx >> (45 - e); + mx = (mx << (e - 14)) & 0x7FFFFFFF; + } + mx = exp2(mx, n); + int d = exp << 1, s; + if(mx > 0x80000000) + { + my = divide64(0x80000000, mx, s); + my |= s; + ++d; + } + else + my = mx; + return std::make_pair( + mx, (d < 31) ? ((my >> d) | ((my & ((static_cast(1) << d) - 1)) != 0)) : 1); +} + +/// Postprocessing for binary exponential. +/// \tparam R rounding mode to use +/// \tparam I `true` to always raise INEXACT exception, `false` to raise only for rounded results +/// \param m mantissa as Q1.31 +/// \param exp absolute value of unbiased exponent +/// \param esign sign of actual exponent +/// \param sign sign bit of result +/// \return value converted to half-precision +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if value had to be rounded or \a I is `true` +template +unsigned int exp2_post(uint32 m, int exp, bool esign, unsigned int sign = 0) +{ + int s = 0; + if(esign) + { + if(m > 0x80000000) + { + m = divide64(0x80000000, m, s); + ++exp; + } + if(exp > 25) + return underflow(sign); + else if(exp == 25) + return rounded(sign, 1, (m & 0x7FFFFFFF) != 0); + exp = -exp; + } + else if(exp > 15) + return overflow(sign); + return fixed2half(m, exp + 14, sign, s); +} + +/// Postprocessing for binary logarithm. +/// \tparam R rounding mode to use +/// \tparam L logarithm for base transformation as Q1.31 +/// \param m fractional part of logarithm as Q0.31 +/// \param ilog signed integer part of logarithm +/// \param exp biased exponent of result +/// \param sign sign bit of result +/// \return value base-transformed and converted to half-precision +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if no other exception occurred +template +unsigned int log2_post(uint32 m, int ilog, int exp, unsigned int sign = 0) +{ + uint32 msign = sign_mask(ilog); + m = (((static_cast(ilog) << 27) + (m >> 4)) ^ msign) - msign; + if(!m) + return 0; + for(; m < 0x80000000; m <<= 1, --exp) + ; + int i = m >= L, s; + exp += i; + m >>= 1 + i; + sign ^= msign & 0x8000; + if(exp < -11) + return underflow(sign); + m = divide64(m, L, s); + return fixed2half(m, exp, sign, 1); +} + +/// Hypotenuse square root and postprocessing. 
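+// Reading aid for hyperbolic_args() above: it returns (exp(|x|), exp(-|x|)) as Q1.31
+// values sharing the exponent written to `exp`, so a caller (the hyperbolic functions,
+// which live outside this excerpt) can form e.g.
+//   sinh(|x|) = (e.first - e.second) / 2   and   cosh(|x|) = (e.first + e.second) / 2,
+// where the division by 2 is just a decrement of the shared exponent before rounding.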
+/// \tparam R rounding mode to use +/// \param r mantissa as Q2.30 +/// \param exp unbiased exponent +/// \return square root converted to half-precision +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if value had to be rounded +template +unsigned int hypot_post(uint32 r, int exp) +{ + int i = r >> 31; + if((exp += i) > 46) + return overflow(); + if(exp < -34) + return underflow(); + r = (r >> i) | (r & i); + uint32 m = sqrt<30>(r, exp += 15); + return fixed2half(m, exp - 1, 0, r != 0); +} + +/// Division and postprocessing for tangents. +/// \tparam R rounding mode to use +/// \param my dividend as Q1.31 +/// \param mx divisor as Q1.31 +/// \param exp biased exponent of result +/// \param sign sign bit of result +/// \return quotient converted to half-precision +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if no other exception occurred +template +unsigned int tangent_post(uint32 my, uint32 mx, int exp, unsigned int sign = 0) +{ + int i = my >= mx, s; + exp += i; + if(exp > 29) + return overflow(sign); + if(exp < -11) + return underflow(sign); + uint32 m = divide64(my >> (i + 1), mx, s); + return fixed2half(m, exp, sign, s); +} + +/// Area function and postprocessing. +/// This computes the value directly in Q2.30 using the representation `asinh|acosh(x) = +/// log(x+sqrt(x^2+|-1))`. +/// \tparam R rounding mode to use +/// \tparam S `true` for asinh, `false` for acosh +/// \param arg half-precision argument +/// \return asinh|acosh(\a arg) converted to half-precision +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if no other exception occurred +template +unsigned int area(unsigned int arg) +{ + int abs = arg & 0x7FFF, expx = (abs >> 10) + (abs <= 0x3FF) - 15, expy = -15, ilog, i; + uint32 mx = static_cast((abs & 0x3FF) | ((abs > 0x3FF) << 10)) << 20, my, r; + for(; abs < 0x400; abs <<= 1, --expy) + ; + expy += abs >> 10; + r = ((abs & 0x3FF) | 0x400) << 5; + r *= r; + i = r >> 31; + expy = 2 * expy + i; + r >>= i; + if(S) + { + if(expy < 0) + { + r = 0x40000000 + ((expy > -30) ? ((r >> -expy) | + ((r & ((static_cast(1) << -expy) - 1)) != 0)) + : 1); + expy = 0; + } + else + { + r += 0x40000000 >> expy; + i = r >> 31; + r = (r >> i) | (r & i); + expy += i; + } + } + else + { + r -= 0x40000000 >> expy; + for(; r < 0x40000000; r <<= 1, --expy) + ; + } + my = sqrt<30>(r, expy); + my = (my << 15) + (r << 14) / my; + if(S) + { + mx >>= expy - expx; + ilog = expy; + } + else + { + my >>= expx - expy; + ilog = expx; + } + my += mx; + i = my >> 31; + static const int G = S && (R == std::round_to_nearest); + return log2_post( + log2(my >> i, 26 + S + G) + (G << 3), ilog + i, 17, arg & (static_cast(S) << 15)); +} + +/// Class for 1.31 unsigned floating-point computation +struct f31 +{ + /// Constructor. + /// \param mant mantissa as 1.31 + /// \param e exponent + HALF_CONSTEXPR f31(uint32 mant, int e) : m(mant), exp(e) {} + + /// Constructor. + /// \param abs unsigned half-precision value + f31(unsigned int abs) : exp(-15) + { + for(; abs < 0x400; abs <<= 1, --exp) + ; + m = static_cast((abs & 0x3FF) | 0x400) << 21; + exp += (abs >> 10); + } + + /// Addition operator. + /// \param a first operand + /// \param b second operand + /// \return \a a + \a b + friend f31 operator+(f31 a, f31 b) + { + if(b.exp > a.exp) + std::swap(a, b); + int d = a.exp - b.exp; + uint32 m = a.m + ((d < 32) ? 
(b.m >> d) : 0); + int i = (m & 0xFFFFFFFF) < a.m; + return f31(((m + i) >> i) | 0x80000000, a.exp + i); + } + + /// Subtraction operator. + /// \param a first operand + /// \param b second operand + /// \return \a a - \a b + friend f31 operator-(f31 a, f31 b) + { + int d = a.exp - b.exp, exp = a.exp; + uint32 m = a.m - ((d < 32) ? (b.m >> d) : 0); + if(!m) + return f31(0, -32); + for(; m < 0x80000000; m <<= 1, --exp) + ; + return f31(m, exp); + } + + /// Multiplication operator. + /// \param a first operand + /// \param b second operand + /// \return \a a * \a b + friend f31 operator*(f31 a, f31 b) + { + uint32 m = multiply64(a.m, b.m); + int i = m >> 31; + return f31(m << (1 - i), a.exp + b.exp + i); + } + + /// Division operator. + /// \param a first operand + /// \param b second operand + /// \return \a a / \a b + friend f31 operator/(f31 a, f31 b) + { + int i = a.m >= b.m, s; + uint32 m = divide64((a.m + i) >> i, b.m, s); + return f31(m, a.exp - b.exp + i - 1); + } + + uint32 m; ///< mantissa as 1.31. + int exp; ///< exponent. +}; + +/// Error function and postprocessing. +/// This computes the value directly in Q1.31 using the approximations given +/// [here](https://en.wikipedia.org/wiki/Error_function#Approximation_with_elementary_functions). +/// \tparam R rounding mode to use +/// \tparam C `true` for comlementary error function, `false` else +/// \param arg half-precision function argument +/// \return approximated value of error function in half-precision +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if no other exception occurred +template +unsigned int erf(unsigned int arg) +{ + unsigned int abs = arg & 0x7FFF, sign = arg & 0x8000; + f31 x(abs), x2 = x * x * f31(0xB8AA3B29, 0), + t = f31(0x80000000, 0) / (f31(0x80000000, 0) + f31(0xA7BA054A, -2) * x), t2 = t * t; + f31 e = ((f31(0x87DC2213, 0) * t2 + f31(0xB5F0E2AE, 0)) * t2 + f31(0x82790637, -2) - + (f31(0xBA00E2B8, 0) * t2 + f31(0x91A98E62, -2)) * t) * + t / + ((x2.exp < 0) ? f31(exp2((x2.exp > -32) ? (x2.m >> -x2.exp) : 0, 30), 0) + : f31(exp2((x2.m << x2.exp) & 0x7FFFFFFF, 22), x2.m >> (31 - x2.exp))); + return (!C || sign) + ? fixed2half( + 0x80000000 - (e.m >> (C - e.exp)), 14 + C, sign & (C - 1U)) + : (e.exp < -25) + ? underflow() + : fixed2half(e.m >> 1, e.exp + 14, 0, e.m & 1); +} + +/// Gamma function and postprocessing. +/// This approximates the value of either the gamma function or its logarithm directly in Q1.31. +/// \tparam R rounding mode to use +/// \tparam L `true` for lograithm of gamma function, `false` for gamma function +/// \param arg half-precision floating-point value +/// \return lgamma/tgamma(\a arg) in half-precision +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if \a arg is not a positive integer +template +unsigned int gamma(unsigned int arg) +{ + /* static const double p[] ={ 2.50662827563479526904, 225.525584619175212544, + -268.295973841304927459, 80.9030806934622512966, -5.00757863970517583837, + 0.0114684895434781459556 }; double t = arg + 4.65, s = p[0]; for(unsigned int i=0; i<5; ++i) + s += p[i+1] / (arg+i); + return std::log(s) + (arg-0.5)*std::log(t) - t; +*/ static const f31 pi(0xC90FDAA2, 1), + lbe(0xB8AA3B29, 0); + unsigned int abs = arg & 0x7FFF, sign = arg & 0x8000; + bool bsign = sign != 0; + f31 z(abs), x = sign ? 
(z + f31(0x80000000, 0)) : z, t = x + f31(0x94CCCCCD, 2), + s = f31(0xA06C9901, 1) + f31(0xBBE654E2, -7) / (x + f31(0x80000000, 2)) + + f31(0xA1CE6098, 6) / (x + f31(0x80000000, 1)) + f31(0xE1868CB7, 7) / x - + f31(0x8625E279, 8) / (x + f31(0x80000000, 0)) - + f31(0xA03E158F, 2) / (x + f31(0xC0000000, 1)); + int i = (s.exp >= 2) + (s.exp >= 4) + (s.exp >= 8) + (s.exp >= 16); + s = f31((static_cast(s.exp) << (31 - i)) + (log2(s.m >> 1, 28) >> i), i) / lbe; + if(x.exp != -1 || x.m != 0x80000000) + { + i = (t.exp >= 2) + (t.exp >= 4) + (t.exp >= 8); + f31 l = f31((static_cast(t.exp) << (31 - i)) + (log2(t.m >> 1, 30) >> i), i) / lbe; + s = (x.exp < -1) ? (s - (f31(0x80000000, -1) - x) * l) + : (s + (x - f31(0x80000000, -1)) * l); + } + s = x.exp ? (s - t) : (t - s); + if(bsign) + { + if(z.exp >= 0) + { + sign &= (L | ((z.m >> (31 - z.exp)) & 1)) - 1; + for(z = f31((z.m << (1 + z.exp)) & 0xFFFFFFFF, -1); z.m < 0x80000000; + z.m <<= 1, --z.exp) + ; + } + if(z.exp == -1) + z = f31(0x80000000, 0) - z; + if(z.exp < -1) + { + z = z * pi; + z.m = sincos(z.m >> (1 - z.exp), 30).first; + for(z.exp = 1; z.m < 0x80000000; z.m <<= 1, --z.exp) + ; + } + else + z = f31(0x80000000, 0); + } + if(L) + { + if(bsign) + { + f31 l(0x92868247, 0); + if(z.exp < 0) + { + uint32 m = log2((z.m + 1) >> 1, 27); + z = f31(-((static_cast(z.exp) << 26) + (m >> 5)), 5); + for(; z.m < 0x80000000; z.m <<= 1, --z.exp) + ; + l = l + z / lbe; + } + sign = static_cast(x.exp && (l.exp < s.exp || (l.exp == s.exp && l.m < s.m))) + << 15; + s = sign ? (s - l) : x.exp ? (l - s) : (l + s); + } + else + { + sign = static_cast(x.exp == 0) << 15; + if(s.exp < -24) + return underflow(sign); + if(s.exp > 15) + return overflow(sign); + } + } + else + { + s = s * lbe; + uint32 m; + if(s.exp < 0) + { + m = s.m >> -s.exp; + s.exp = 0; + } + else + { + m = (s.m << s.exp) & 0x7FFFFFFF; + s.exp = (s.m >> (31 - s.exp)); + } + s.m = exp2(m, 27); + if(!x.exp) + s = f31(0x80000000, 0) / s; + if(bsign) + { + if(z.exp < 0) + s = s * z; + s = pi / s; + if(s.exp < -24) + return underflow(sign); + } + else if(z.exp > 0 && !(z.m & ((1 << (31 - z.exp)) - 1))) + return ((s.exp + 14) << 10) + (s.m >> 21); + if(s.exp > 15) + return overflow(sign); + } + return fixed2half(s.m, s.exp + 14, sign); +} +/// \} + +template +struct half_caster; +} // namespace detail + +/// Half-precision floating-point type. +/// This class implements an IEEE-conformant half-precision floating-point type with the usual +/// arithmetic +/// operators and conversions. It is implicitly convertible to single-precision floating-point, +/// which makes artihmetic +/// expressions and functions with mixed-type operands to be of the most precise operand type. +/// +/// According to the C++98/03 definition, the half type is not a POD type. But according to C++11's +/// less strict and +/// extended definitions it is both a standard layout type and a trivially copyable type (even if +/// not a POD type), which +/// means it can be standard-conformantly copied using raw binary copies. But in this context some +/// more words about the +/// actual size of the type. Although the half is representing an IEEE 16-bit type, it does not +/// neccessarily have to be of +/// exactly 16-bits size. But on any reasonable implementation the actual binary representation of +/// this type will most +/// probably not ivolve any additional "magic" or padding beyond the simple binary representation of +/// the underlying 16-bit +/// IEEE number, even if not strictly guaranteed by the standard. 
But even then it only has an +/// actual size of 16 bits if +/// your C++ implementation supports an unsigned integer type of exactly 16 bits width. But this +/// should be the case on +/// nearly any reasonable platform. +/// +/// So if your C++ implementation is not totally exotic or imposes special alignment requirements, +/// it is a reasonable +/// assumption that the data of a half is just comprised of the 2 bytes of the underlying IEEE +/// representation. +class half +{ + public: + /// \name Construction and assignment + /// \{ + + /// Default constructor. + /// This initializes the half to 0. Although this does not match the builtin types' + /// default-initialization semantics + /// and may be less efficient than no initialization, it is needed to provide proper + /// value-initialization semantics. + HALF_CONSTEXPR half() HALF_NOEXCEPT : data_() {} + + /// Conversion constructor. + /// \param rhs float to convert + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + explicit half(float rhs) + : data_(static_cast(detail::float2half(rhs))) + { + } + + /// Conversion to single-precision. + /// \return single precision value representing expression value + operator float() const { return detail::half2float(data_); } + + /// Assignment operator. + /// \param rhs single-precision value to copy from + /// \return reference to this half + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + half& operator=(float rhs) + { + data_ = static_cast(detail::float2half(rhs)); + return *this; + } + + /// \} + /// \name Arithmetic updates + /// \{ + + /// Arithmetic assignment. + /// \tparam T type of concrete half expression + /// \param rhs half expression to add + /// \return reference to this half + /// \exception FE_... according to operator+(half,half) + half& operator+=(half rhs) { return *this = *this + rhs; } + + /// Arithmetic assignment. + /// \tparam T type of concrete half expression + /// \param rhs half expression to subtract + /// \return reference to this half + /// \exception FE_... according to operator-(half,half) + half& operator-=(half rhs) { return *this = *this - rhs; } + + /// Arithmetic assignment. + /// \tparam T type of concrete half expression + /// \param rhs half expression to multiply with + /// \return reference to this half + /// \exception FE_... according to operator*(half,half) + half& operator*=(half rhs) { return *this = *this * rhs; } + + /// Arithmetic assignment. + /// \tparam T type of concrete half expression + /// \param rhs half expression to divide by + /// \return reference to this half + /// \exception FE_... according to operator/(half,half) + half& operator/=(half rhs) { return *this = *this / rhs; } + + /// Arithmetic assignment. + /// \param rhs single-precision value to add + /// \return reference to this half + /// \exception FE_... according to operator=() + half& operator+=(float rhs) { return *this = *this + rhs; } + + /// Arithmetic assignment. + /// \param rhs single-precision value to subtract + /// \return reference to this half + /// \exception FE_... according to operator=() + half& operator-=(float rhs) { return *this = *this - rhs; } + + /// Arithmetic assignment. + /// \param rhs single-precision value to multiply with + /// \return reference to this half + /// \exception FE_... according to operator=() + half& operator*=(float rhs) { return *this = *this * rhs; } + + /// Arithmetic assignment. 
+ /// \param rhs single-precision value to divide by + /// \return reference to this half + /// \exception FE_... according to operator=() + half& operator/=(float rhs) { return *this = *this / rhs; } + + /// \} + /// \name Increment and decrement + /// \{ + + /// Prefix increment. + /// \return incremented half value + /// \exception FE_... according to operator+(half,half) + half& operator++() { return *this = *this + half(detail::binary, 0x3C00); } + + /// Prefix decrement. + /// \return decremented half value + /// \exception FE_... according to operator-(half,half) + half& operator--() { return *this = *this + half(detail::binary, 0xBC00); } + + /// Postfix increment. + /// \return non-incremented half value + /// \exception FE_... according to operator+(half,half) + half operator++(int) + { + half out(*this); + ++*this; + return out; + } + + /// Postfix decrement. + /// \return non-decremented half value + /// \exception FE_... according to operator-(half,half) + half operator--(int) + { + half out(*this); + --*this; + return out; + } + /// \} + + private: + /// Rounding mode to use + static const std::float_round_style round_style = (std::float_round_style)(HALF_ROUND_STYLE); + + /// Constructor. + /// \param bits binary representation to set half to + HALF_CONSTEXPR half(detail::binary_t, unsigned int bits) HALF_NOEXCEPT + : data_(static_cast(bits)) + { + } + + /// Internal binary representation + detail::uint16 data_; + +#ifndef HALF_DOXYGEN_ONLY + friend HALF_CONSTEXPR_NOERR bool operator==(half, half); + friend HALF_CONSTEXPR_NOERR bool operator!=(half, half); + friend HALF_CONSTEXPR_NOERR bool operator<(half, half); + friend HALF_CONSTEXPR_NOERR bool operator>(half, half); + friend HALF_CONSTEXPR_NOERR bool operator<=(half, half); + friend HALF_CONSTEXPR_NOERR bool operator>=(half, half); + friend HALF_CONSTEXPR half operator-(half); + friend half operator+(half, half); + friend half operator-(half, half); + friend half operator*(half, half); + friend half operator/(half, half); + template + friend std::basic_ostream& operator<<(std::basic_ostream&, half); + template + friend std::basic_istream& operator>>(std::basic_istream&, half&); + friend HALF_CONSTEXPR half fabs(half); + friend half fmod(half, half); + friend half remainder(half, half); + friend half remquo(half, half, int*); + friend half fma(half, half, half); + friend HALF_CONSTEXPR_NOERR half fmax(half, half); + friend HALF_CONSTEXPR_NOERR half fmin(half, half); + friend half fdim(half, half); + friend half nanh(const char*); + friend half exp(half); + friend half exp2(half); + friend half expm1(half); + friend half log(half); + friend half log10(half); + friend half log2(half); + friend half log1p(half); + friend half sqrt(half); + friend half cbrt(half); + friend half hypot(half, half); + friend half hypot(half, half, half); + friend half pow(half, half); + friend void sincos(half, half*, half*); + friend half sin(half); + friend half cos(half); + friend half tan(half); + friend half asin(half); + friend half acos(half); + friend half atan(half); + friend half atan2(half, half); + friend half sinh(half); + friend half cosh(half); + friend half tanh(half); + friend half asinh(half); + friend half acosh(half); + friend half atanh(half); + friend half erf(half); + friend half erfc(half); + friend half lgamma(half); + friend half tgamma(half); + friend half ceil(half); + friend half floor(half); + friend half trunc(half); + friend half round(half); + friend long lround(half); + friend half rint(half); + friend long 
lrint(half); + friend half nearbyint(half); +#ifdef HALF_ENABLE_CPP11_LONG_LONG + friend long long llround(half); + friend long long llrint(half); +#endif + friend half frexp(half, int*); + friend half scalbln(half, long); + friend half modf(half, half*); + friend int ilogb(half); + friend half logb(half); + friend half nextafter(half, half); + friend half nexttoward(half, long double); + friend HALF_CONSTEXPR half copysign(half, half); + friend HALF_CONSTEXPR int fpclassify(half); + friend HALF_CONSTEXPR bool isfinite(half); + friend HALF_CONSTEXPR bool isinf(half); + friend HALF_CONSTEXPR bool isnan(half); + friend HALF_CONSTEXPR bool isnormal(half); + friend HALF_CONSTEXPR bool signbit(half); + friend HALF_CONSTEXPR bool isgreater(half, half); + friend HALF_CONSTEXPR bool isgreaterequal(half, half); + friend HALF_CONSTEXPR bool isless(half, half); + friend HALF_CONSTEXPR bool islessequal(half, half); + friend HALF_CONSTEXPR bool islessgreater(half, half); + template + friend struct detail::half_caster; + friend class std::numeric_limits; +#if HALF_ENABLE_CPP11_HASH + friend struct std::hash; +#endif +#if HALF_ENABLE_CPP11_USER_LITERALS + friend half literal::operator"" _h(long double); +#endif +#endif +}; + +#if HALF_ENABLE_CPP11_USER_LITERALS +namespace literal { +/// Half literal. +/// While this returns a properly rounded half-precision value, half literals can unfortunately not +/// be constant +/// expressions due to rather involved conversions. So don't expect this to be a literal literal +/// without involving +/// conversion operations at runtime. It is a convenience feature, not a performance optimization. +/// \param value literal value +/// \return half with of given value (possibly rounded) +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half operator"" _h(long double value) +{ + return half(detail::binary, detail::float2half(value)); +} +} // namespace literal +#endif + +namespace detail { +/// Helper class for half casts. +/// This class template has to be specialized for all valid cast arguments to define an appropriate +/// static +/// `cast` member function and a corresponding `type` member denoting its return type. +/// \tparam T destination type +/// \tparam U source type +/// \tparam R rounding mode to use +template +struct half_caster +{ +}; +template +struct half_caster +{ +#if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS + static_assert(std::is_arithmetic::value, "half_cast from non-arithmetic type unsupported"); +#endif + + static half cast(U arg) { return cast_impl(arg, is_float()); }; + + private: + static half cast_impl(U arg, true_type) { return half(binary, float2half(arg)); } + static half cast_impl(U arg, false_type) { return half(binary, int2half(arg)); } +}; +template +struct half_caster +{ +#if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS + static_assert(std::is_arithmetic::value, "half_cast to non-arithmetic type unsupported"); +#endif + + static T cast(half arg) { return cast_impl(arg, is_float()); } + + private: + static T cast_impl(half arg, true_type) { return half2float(arg.data_); } + static T cast_impl(half arg, false_type) { return half2int(arg.data_); } +}; +template +struct half_caster +{ + static half cast(half arg) { return arg; } +}; +} // namespace detail +} // namespace half_float + +/// Extensions to the C++ standard library. +namespace std { +/// Numeric limits for half-precision floats. 
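+// Minimal usage sketch of the public interface declared above (assumes a C++11 compiler
+// so that the _h literal and std::hash support are enabled; half_cast<> is the cast
+// front end served by detail::half_caster):
+//
+//   using half_float::half;
+//   using namespace half_float::literal;
+//   half a(3.5f);                              // rounded conversion from float
+//   half b = 1.5_h;                            // user-defined literal (converted at runtime)
+//   half c = a * b + b;                        // arithmetic through the friend operators
+//   float f = c;                               // implicit conversion back to single precision
+//   int i = half_float::half_cast<int>(c);     // rounding cast through half_caster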
+/// **See also:** Documentation for +/// [std::numeric_limits](https://en.cppreference.com/w/cpp/types/numeric_limits) +template <> +class numeric_limits +{ + public: + /// Is template specialization. + static HALF_CONSTEXPR_CONST bool is_specialized = true; + + /// Supports signed values. + static HALF_CONSTEXPR_CONST bool is_signed = true; + + /// Is not an integer type. + static HALF_CONSTEXPR_CONST bool is_integer = false; + + /// Is not exact. + static HALF_CONSTEXPR_CONST bool is_exact = false; + + /// Doesn't provide modulo arithmetic. + static HALF_CONSTEXPR_CONST bool is_modulo = false; + + /// Has a finite set of values. + static HALF_CONSTEXPR_CONST bool is_bounded = true; + + /// IEEE conformant. + static HALF_CONSTEXPR_CONST bool is_iec559 = true; + + /// Supports infinity. + static HALF_CONSTEXPR_CONST bool has_infinity = true; + + /// Supports quiet NaNs. + static HALF_CONSTEXPR_CONST bool has_quiet_NaN = true; + + /// Supports signaling NaNs. + static HALF_CONSTEXPR_CONST bool has_signaling_NaN = true; + + /// Supports subnormal values. + static HALF_CONSTEXPR_CONST float_denorm_style has_denorm = denorm_present; + + /// Supports no denormalization detection. + static HALF_CONSTEXPR_CONST bool has_denorm_loss = false; + +#if HALF_ERRHANDLING_THROWS + static HALF_CONSTEXPR_CONST bool traps = true; +#else + /// Traps only if [HALF_ERRHANDLING_THROW_...](\ref HALF_ERRHANDLING_THROW_INVALID) is + /// acitvated. + static HALF_CONSTEXPR_CONST bool traps = false; +#endif + + /// Does not support no pre-rounding underflow detection. + static HALF_CONSTEXPR_CONST bool tinyness_before = false; + + /// Rounding mode. + static HALF_CONSTEXPR_CONST float_round_style round_style = half_float::half::round_style; + + /// Significant digits. + static HALF_CONSTEXPR_CONST int digits = 11; + + /// Significant decimal digits. + static HALF_CONSTEXPR_CONST int digits10 = 3; + + /// Required decimal digits to represent all possible values. + static HALF_CONSTEXPR_CONST int max_digits10 = 5; + + /// Number base. + static HALF_CONSTEXPR_CONST int radix = 2; + + /// One more than smallest exponent. + static HALF_CONSTEXPR_CONST int min_exponent = -13; + + /// Smallest normalized representable power of 10. + static HALF_CONSTEXPR_CONST int min_exponent10 = -4; + + /// One more than largest exponent + static HALF_CONSTEXPR_CONST int max_exponent = 16; + + /// Largest finitely representable power of 10. + static HALF_CONSTEXPR_CONST int max_exponent10 = 4; + + /// Smallest positive normal value. + static HALF_CONSTEXPR half_float::half min() HALF_NOTHROW + { + return half_float::half(half_float::detail::binary, 0x0400); + } + + /// Smallest finite value. + static HALF_CONSTEXPR half_float::half lowest() HALF_NOTHROW + { + return half_float::half(half_float::detail::binary, 0xFBFF); + } + + /// Largest finite value. + static HALF_CONSTEXPR half_float::half max() HALF_NOTHROW + { + return half_float::half(half_float::detail::binary, 0x7BFF); + } + + /// Difference between 1 and next representable value. + static HALF_CONSTEXPR half_float::half epsilon() HALF_NOTHROW + { + return half_float::half(half_float::detail::binary, 0x1400); + } + + /// Maximum rounding error in ULP (units in the last place). + static HALF_CONSTEXPR half_float::half round_error() HALF_NOTHROW + { + return half_float::half(half_float::detail::binary, + (round_style == std::round_to_nearest) ? 0x3800 : 0x3C00); + } + + /// Positive infinity. 
+ static HALF_CONSTEXPR half_float::half infinity() HALF_NOTHROW + { + return half_float::half(half_float::detail::binary, 0x7C00); + } + + /// Quiet NaN. + static HALF_CONSTEXPR half_float::half quiet_NaN() HALF_NOTHROW + { + return half_float::half(half_float::detail::binary, 0x7FFF); + } + + /// Signaling NaN. + static HALF_CONSTEXPR half_float::half signaling_NaN() HALF_NOTHROW + { + return half_float::half(half_float::detail::binary, 0x7DFF); + } + + /// Smallest positive subnormal value. + static HALF_CONSTEXPR half_float::half denorm_min() HALF_NOTHROW + { + return half_float::half(half_float::detail::binary, 0x0001); + } +}; + +#if HALF_ENABLE_CPP11_HASH +/// Hash function for half-precision floats. +/// This is only defined if C++11 `std::hash` is supported and enabled. +/// +/// **See also:** Documentation for [std::hash](https://en.cppreference.com/w/cpp/utility/hash) +template <> +struct hash +{ + /// Type of function argument. + typedef half_float::half argument_type; + + /// Function return type. + typedef size_t result_type; + + /// Compute hash function. + /// \param arg half to hash + /// \return hash value + result_type operator()(argument_type arg) const + { + return hash()(arg.data_ & + -static_cast(arg.data_ != 0x8000)); + } +}; +#endif +} // namespace std + +namespace half_float { +/// \anchor compop +/// \name Comparison operators +/// \{ + +/// Comparison for equality. +/// \param x first operand +/// \param y second operand +/// \retval true if operands equal +/// \retval false else +/// \exception FE_INVALID if \a x or \a y is NaN +inline HALF_CONSTEXPR_NOERR bool operator==(half x, half y) +{ + return !detail::compsignal(x.data_, y.data_) && + (x.data_ == y.data_ || !((x.data_ | y.data_) & 0x7FFF)); +} + +/// Comparison for inequality. +/// \param x first operand +/// \param y second operand +/// \retval true if operands not equal +/// \retval false else +/// \exception FE_INVALID if \a x or \a y is NaN +inline HALF_CONSTEXPR_NOERR bool operator!=(half x, half y) +{ + return detail::compsignal(x.data_, y.data_) || + (x.data_ != y.data_ && ((x.data_ | y.data_) & 0x7FFF)); +} + +/// Comparison for less than. +/// \param x first operand +/// \param y second operand +/// \retval true if \a x less than \a y +/// \retval false else +/// \exception FE_INVALID if \a x or \a y is NaN +inline HALF_CONSTEXPR_NOERR bool operator<(half x, half y) +{ + return !detail::compsignal(x.data_, y.data_) && + ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) < + ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)); +} + +/// Comparison for greater than. +/// \param x first operand +/// \param y second operand +/// \retval true if \a x greater than \a y +/// \retval false else +/// \exception FE_INVALID if \a x or \a y is NaN +inline HALF_CONSTEXPR_NOERR bool operator>(half x, half y) +{ + return !detail::compsignal(x.data_, y.data_) && + ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) > + ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)); +} + +/// Comparison for less equal. 
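A short sketch of the `std::numeric_limits<half_float::half>` specialization and the NaN-aware comparisons defined above; the printed constants follow directly from the values in the specialization:

    #include "half.hpp"
    #include <limits>
    #include <iostream>

    int main()
    {
        using half_float::half;
        typedef std::numeric_limits<half> lim;

        std::cout << lim::digits    << '\n'    // 11 significand bits
                  << lim::max()     << '\n'    // 0x7BFF == 65504
                  << lim::epsilon() << '\n';   // 0x1400 == 2^-10

        half nan = lim::quiet_NaN();
        // operator== / operator!= above return the IEEE results for NaN operands
        // (and raise FE_INVALID if error handling is enabled).
        std::cout << (nan == nan) << ' ' << (nan != nan) << '\n';   // prints "0 1"
        return 0;
    }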
+/// \param x first operand +/// \param y second operand +/// \retval true if \a x less equal \a y +/// \retval false else +/// \exception FE_INVALID if \a x or \a y is NaN +inline HALF_CONSTEXPR_NOERR bool operator<=(half x, half y) +{ + return !detail::compsignal(x.data_, y.data_) && + ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) <= + ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)); +} + +/// Comparison for greater equal. +/// \param x first operand +/// \param y second operand +/// \retval true if \a x greater equal \a y +/// \retval false else +/// \exception FE_INVALID if \a x or \a y is NaN +inline HALF_CONSTEXPR_NOERR bool operator>=(half x, half y) +{ + return !detail::compsignal(x.data_, y.data_) && + ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) >= + ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)); +} + +/// \} +/// \anchor arithmetics +/// \name Arithmetic operators +/// \{ + +/// Identity. +/// \param arg operand +/// \return unchanged operand +inline HALF_CONSTEXPR half operator+(half arg) { return arg; } + +/// Negation. +/// \param arg operand +/// \return negated operand +inline HALF_CONSTEXPR half operator-(half arg) { return half(detail::binary, arg.data_ ^ 0x8000); } + +/// Addition. +/// This operation is exact to rounding for all rounding modes. +/// \param x left operand +/// \param y right operand +/// \return sum of half expressions +/// \exception FE_INVALID if \a x and \a y are infinities with different signs or signaling NaNs +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half operator+(half x, half y) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half( + detail::binary, + detail::float2half(detail::half2float(x.data_) + + detail::half2float(y.data_))); +#else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF; + bool sub = ((x.data_ ^ y.data_) & 0x8000) != 0; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, + (absx > 0x7C00 || absy > 0x7C00) + ? detail::signal(x.data_, y.data_) + : (absy != 0x7C00) ? x.data_ + : (sub && absx == 0x7C00) ? detail::invalid() : y.data_); + if(!absx) + return absy ? y + : half(detail::binary, + (half::round_style == std::round_toward_neg_infinity) + ? (x.data_ | y.data_) + : (x.data_ & y.data_)); + if(!absy) + return x; + unsigned int sign = ((sub && absy > absx) ? y.data_ : x.data_) & 0x8000; + if(absy > absx) + std::swap(absx, absy); + int exp = (absx >> 10) + (absx <= 0x3FF), d = exp - (absy >> 10) - (absy <= 0x3FF), + mx = ((absx & 0x3FF) | ((absx > 0x3FF) << 10)) << 3, my; + if(d < 13) + { + my = ((absy & 0x3FF) | ((absy > 0x3FF) << 10)) << 3; + my = (my >> d) | ((my & ((1 << d) - 1)) != 0); + } + else + my = 1; + if(sub) + { + if(!(mx -= my)) + return half(detail::binary, + static_cast(half::round_style == std::round_toward_neg_infinity) + << 15); + for(; mx < 0x2000 && exp > 1; mx <<= 1, --exp) + ; + } + else + { + mx += my; + int i = mx >> 14; + if((exp += i) > 30) + return half(detail::binary, detail::overflow(sign)); + mx = (mx >> i) | (mx & i); + } + return half(detail::binary, + detail::rounded( + sign + ((exp - 1) << 10) + (mx >> 3), (mx >> 2) & 1, (mx & 0x3) != 0)); +#endif +} + +/// Subtraction. +/// This operation is exact to rounding for all rounding modes. 
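The addition above rounds its exact sum to the nearest representable half; with only 11 significand bits the spacing between values at 2048 is already 2, which the following sketch makes visible (assuming the default round-to-nearest mode and the `half(float)` constructor defined earlier in this header):

    #include "half.hpp"
    #include <iostream>

    int main()
    {
        using half_float::half;
        half big(2048.0f), one(1.0f);   // half(float) constructor assumed from earlier in the header
        half sum = big + one;           // exact sum 2049 ties back to 2048 under round-to-nearest-even
        std::cout << sum << ' ' << (sum == big) << '\n';   // prints "2048 1"
        return 0;
    }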
+/// \param x left operand +/// \param y right operand +/// \return difference of half expressions +/// \exception FE_INVALID if \a x and \a y are infinities with equal signs or signaling NaNs +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half operator-(half x, half y) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half( + detail::binary, + detail::float2half(detail::half2float(x.data_) - + detail::half2float(y.data_))); +#else + return x + -y; +#endif +} + +/// Multiplication. +/// This operation is exact to rounding for all rounding modes. +/// \param x left operand +/// \param y right operand +/// \return product of half expressions +/// \exception FE_INVALID if multiplying 0 with infinity or if \a x or \a y is signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half operator*(half x, half y) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half( + detail::binary, + detail::float2half(detail::half2float(x.data_) * + detail::half2float(y.data_))); +#else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, exp = -16; + unsigned int sign = (x.data_ ^ y.data_) & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, + (absx > 0x7C00 || absy > 0x7C00) + ? detail::signal(x.data_, y.data_) + : ((absx == 0x7C00 && !absy) || (absy == 0x7C00 && !absx)) + ? detail::invalid() + : (sign | 0x7C00)); + if(!absx || !absy) + return half(detail::binary, sign); + for(; absx < 0x400; absx <<= 1, --exp) + ; + for(; absy < 0x400; absy <<= 1, --exp) + ; + detail::uint32 m = static_cast((absx & 0x3FF) | 0x400) * + static_cast((absy & 0x3FF) | 0x400); + int i = m >> 21, s = m & i; + exp += (absx >> 10) + (absy >> 10) + i; + if(exp > 29) + return half(detail::binary, detail::overflow(sign)); + else if(exp < -11) + return half(detail::binary, detail::underflow(sign)); + return half( + detail::binary, + detail::fixed2half(m >> i, exp, sign, s)); +#endif +} + +/// Division. +/// This operation is exact to rounding for all rounding modes. +/// \param x left operand +/// \param y right operand +/// \return quotient of half expressions +/// \exception FE_INVALID if dividing 0s or infinities with each other or if \a x or \a y is +/// signaling NaN +/// \exception FE_DIVBYZERO if dividing finite value by 0 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half operator/(half x, half y) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half( + detail::binary, + detail::float2half(detail::half2float(x.data_) / + detail::half2float(y.data_))); +#else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, exp = 14; + unsigned int sign = (x.data_ ^ y.data_) & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, + (absx > 0x7C00 || absy > 0x7C00) + ? detail::signal(x.data_, y.data_) + : (absx == absy) ? detail::invalid() + : (sign | ((absx == 0x7C00) ? 0x7C00 : 0))); + if(!absx) + return half(detail::binary, absy ? 
sign : detail::invalid()); + if(!absy) + return half(detail::binary, detail::pole(sign)); + for(; absx < 0x400; absx <<= 1, --exp) + ; + for(; absy < 0x400; absy <<= 1, ++exp) + ; + detail::uint32 mx = (absx & 0x3FF) | 0x400, my = (absy & 0x3FF) | 0x400; + int i = mx < my; + exp += (absx >> 10) - (absy >> 10) - i; + if(exp > 29) + return half(detail::binary, detail::overflow(sign)); + else if(exp < -11) + return half(detail::binary, detail::underflow(sign)); + mx <<= 12 + i; + my <<= 1; + return half(detail::binary, + detail::fixed2half( + mx / my, exp, sign, mx % my != 0)); +#endif +} + +/// \} +/// \anchor streaming +/// \name Input and output +/// \{ + +/// Output operator. +/// This uses the built-in functionality for streaming out floating-point numbers. +/// \param out output stream to write into +/// \param arg half expression to write +/// \return reference to output stream +template +std::basic_ostream& operator<<(std::basic_ostream& out, half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return out << detail::half2float(arg.data_); +#else + return out << detail::half2float(arg.data_); +#endif +} + +/// Input operator. +/// This uses the built-in functionality for streaming in floating-point numbers, specifically +/// double precision floating +/// point numbers (unless overridden with [HALF_ARITHMETIC_TYPE](\ref HALF_ARITHMETIC_TYPE)). So the +/// input string is first +/// rounded to double precision using the underlying platform's current floating-point rounding mode +/// before being rounded +/// to half-precision using the library's half-precision rounding mode. +/// \param in input stream to read from +/// \param arg half to read into +/// \return reference to input stream +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +template +std::basic_istream& operator>>(std::basic_istream& in, half& arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + detail::internal_t f; +#else + double f; +#endif + if(in >> f) + arg.data_ = detail::float2half(f); + return in; +} + +/// \} +/// \anchor basic +/// \name Basic mathematical operations +/// \{ + +/// Absolute value. +/// **See also:** Documentation for +/// [std::fabs](https://en.cppreference.com/w/cpp/numeric/math/fabs). +/// \param arg operand +/// \return absolute value of \a arg +inline HALF_CONSTEXPR half fabs(half arg) { return half(detail::binary, arg.data_ & 0x7FFF); } + +/// Absolute value. +/// **See also:** Documentation for [std::abs](https://en.cppreference.com/w/cpp/numeric/math/fabs). +/// \param arg operand +/// \return absolute value of \a arg +inline HALF_CONSTEXPR half abs(half arg) { return fabs(arg); } + +/// Remainder of division. +/// **See also:** Documentation for +/// [std::fmod](https://en.cppreference.com/w/cpp/numeric/math/fmod). +/// \param x first operand +/// \param y second operand +/// \return remainder of floating-point division. +/// \exception FE_INVALID if \a x is infinite or \a y is 0 or if \a x or \a y is signaling NaN +inline half fmod(half x, half y) +{ + unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, sign = x.data_ & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, + (absx > 0x7C00 || absy > 0x7C00) + ? detail::signal(x.data_, y.data_) + : (absx == 0x7C00) ? detail::invalid() : x.data_); + if(!absy) + return half(detail::binary, detail::invalid()); + if(!absx) + return x; + if(absx == absy) + return half(detail::binary, sign); + return half(detail::binary, sign | detail::mod(absx, absy)); +} + +/// Remainder of division. 
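Since `operator>>` above parses into a double (or `HALF_ARITHMETIC_TYPE`) first and then rounds, textual round trips are only as precise as half precision allows, while `fmod` keeps exactly representable results exact. A small sketch:

    #include "half.hpp"
    #include <sstream>
    #include <iostream>

    int main()
    {
        using half_float::half;

        half x(0.0f);
        std::istringstream in("0.333333");
        in >> x;                                 // text -> double -> half rounding
        std::cout << x << '\n';                  // nearest half, about 0.33325

        half r = fmod(half(7.5f), half(2.0f));   // found via ADL in namespace half_float
        std::cout << r << '\n';                  // 1.5, exactly representable
        return 0;
    }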
+/// **See also:** Documentation for +/// [std::remainder](https://en.cppreference.com/w/cpp/numeric/math/remainder). +/// \param x first operand +/// \param y second operand +/// \return remainder of floating-point division. +/// \exception FE_INVALID if \a x is infinite or \a y is 0 or if \a x or \a y is signaling NaN +inline half remainder(half x, half y) +{ + unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, sign = x.data_ & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, + (absx > 0x7C00 || absy > 0x7C00) + ? detail::signal(x.data_, y.data_) + : (absx == 0x7C00) ? detail::invalid() : x.data_); + if(!absy) + return half(detail::binary, detail::invalid()); + if(absx == absy) + return half(detail::binary, sign); + return half(detail::binary, sign ^ detail::mod(absx, absy)); +} + +/// Remainder of division. +/// **See also:** Documentation for +/// [std::remquo](https://en.cppreference.com/w/cpp/numeric/math/remquo). +/// \param x first operand +/// \param y second operand +/// \param quo address to store some bits of quotient at +/// \return remainder of floating-point division. +/// \exception FE_INVALID if \a x is infinite or \a y is 0 or if \a x or \a y is signaling NaN +inline half remquo(half x, half y, int* quo) +{ + unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, value = x.data_ & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, + (absx > 0x7C00 || absy > 0x7C00) + ? detail::signal(x.data_, y.data_) + : (absx == 0x7C00) ? detail::invalid() : (*quo = 0, x.data_)); + if(!absy) + return half(detail::binary, detail::invalid()); + bool qsign = ((value ^ y.data_) & 0x8000) != 0; + int q = 1; + if(absx != absy) + value ^= detail::mod(absx, absy, &q); + return *quo = qsign ? -q : q, half(detail::binary, value); +} + +/// Fused multiply add. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for [std::fma](https://en.cppreference.com/w/cpp/numeric/math/fma). +/// \param x first operand +/// \param y second operand +/// \param z third operand +/// \return ( \a x * \a y ) + \a z rounded as one operation. +/// \exception FE_INVALID according to operator*() and operator+() unless any argument is a quiet +/// NaN and no argument is a signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding the final addition +inline half fma(half x, half y, half z) +{ +#ifdef HALF_ARITHMETIC_TYPE + detail::internal_t fx = detail::half2float(x.data_), + fy = detail::half2float(y.data_), + fz = detail::half2float(z.data_); +#if HALF_ENABLE_CPP11_CMATH && FP_FAST_FMA + return half(detail::binary, detail::float2half(std::fma(fx, fy, fz))); +#else + return half(detail::binary, detail::float2half(fx * fy + fz)); +#endif +#else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, absz = z.data_ & 0x7FFF, exp = -15; + unsigned int sign = (x.data_ ^ y.data_) & 0x8000; + bool sub = ((sign ^ z.data_) & 0x8000) != 0; + if(absx >= 0x7C00 || absy >= 0x7C00 || absz >= 0x7C00) + return (absx > 0x7C00 || absy > 0x7C00 || absz > 0x7C00) + ? half(detail::binary, detail::signal(x.data_, y.data_, z.data_)) + : (absx == 0x7C00) ? half(detail::binary, + (!absy || (sub && absz == 0x7C00)) ? detail::invalid() + : (sign | 0x7C00)) + : (absy == 0x7C00) ? half(detail::binary, + (!absx || (sub && absz == 0x7C00)) + ? detail::invalid() + : (sign | 0x7C00)) + : z; + if(!absx || !absy) + return absz + ? 
z + : half(detail::binary, + (half::round_style == std::round_toward_neg_infinity) ? (z.data_ | sign) + : (z.data_ & sign)); + for(; absx < 0x400; absx <<= 1, --exp) + ; + for(; absy < 0x400; absy <<= 1, --exp) + ; + detail::uint32 m = static_cast((absx & 0x3FF) | 0x400) * + static_cast((absy & 0x3FF) | 0x400); + int i = m >> 21; + exp += (absx >> 10) + (absy >> 10) + i; + m <<= 3 - i; + if(absz) + { + int expz = 0; + for(; absz < 0x400; absz <<= 1, --expz) + ; + expz += absz >> 10; + detail::uint32 mz = static_cast((absz & 0x3FF) | 0x400) << 13; + if(expz > exp || (expz == exp && mz > m)) + { + std::swap(m, mz); + std::swap(exp, expz); + if(sub) + sign = z.data_ & 0x8000; + } + int d = exp - expz; + mz = (d < 23) ? ((mz >> d) | ((mz & ((static_cast(1) << d) - 1)) != 0)) : 1; + if(sub) + { + m = m - mz; + if(!m) + return half( + detail::binary, + static_cast(half::round_style == std::round_toward_neg_infinity) + << 15); + for(; m < 0x800000; m <<= 1, --exp) + ; + } + else + { + m += mz; + i = m >> 24; + m = (m >> i) | (m & i); + exp += i; + } + } + if(exp > 30) + return half(detail::binary, detail::overflow(sign)); + else if(exp < -10) + return half(detail::binary, detail::underflow(sign)); + return half(detail::binary, + detail::fixed2half(m, exp - 1, sign)); +#endif +} + +/// Maximum of half expressions. +/// **See also:** Documentation for +/// [std::fmax](https://en.cppreference.com/w/cpp/numeric/math/fmax). +/// \param x first operand +/// \param y second operand +/// \return maximum of operands, ignoring quiet NaNs +/// \exception FE_INVALID if \a x or \a y is signaling NaN +inline HALF_CONSTEXPR_NOERR half fmax(half x, half y) +{ + return half(detail::binary, + (!isnan(y) && (isnan(x) || (x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) < + (y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))))) + ? detail::select(y.data_, x.data_) + : detail::select(x.data_, y.data_)); +} + +/// Minimum of half expressions. +/// **See also:** Documentation for +/// [std::fmin](https://en.cppreference.com/w/cpp/numeric/math/fmin). +/// \param x first operand +/// \param y second operand +/// \return minimum of operands, ignoring quiet NaNs +/// \exception FE_INVALID if \a x or \a y is signaling NaN +inline HALF_CONSTEXPR_NOERR half fmin(half x, half y) +{ + return half(detail::binary, + (!isnan(y) && (isnan(x) || (x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) > + (y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))))) + ? detail::select(y.data_, x.data_) + : detail::select(x.data_, y.data_)); +} + +/// Positive difference. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::fdim](https://en.cppreference.com/w/cpp/numeric/math/fdim). +/// \param x first operand +/// \param y second operand +/// \return \a x - \a y or 0 if difference negative +/// \exception FE_... according to operator-(half,half) +inline half fdim(half x, half y) +{ + if(isnan(x) || isnan(y)) + return half(detail::binary, detail::signal(x.data_, y.data_)); + return (x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) <= + (y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + ? half(detail::binary, 0) + : (x - y); +} + +/// Get NaN value. +/// **See also:** Documentation for [std::nan](https://en.cppreference.com/w/cpp/numeric/math/nan). 
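Because `fma` above performs a single rounding, it can retain low-order bits that the separate multiply-then-subtract path discards; the concrete values in this sketch assume the default round-to-nearest mode:

    #include "half.hpp"
    #include <iostream>

    int main()
    {
        using half_float::half;
        half a(1.0029296875f);                  // 1 + 3*2^-10, exactly representable in half
        half two_step = a * a - half(1.0f);     // a*a is rounded to half before the subtraction
        half fused    = fma(a, a, half(-1.0f)); // single rounding of a*a - 1
        std::cout << two_step << ' ' << fused << ' '
                  << (two_step == fused) << '\n';   // last field: 0, the results differ
        return 0;
    }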
+/// \param arg string code +/// \return quiet NaN +inline half nanh(const char* arg) +{ + unsigned int value = 0x7FFF; + while(*arg) + value ^= static_cast(*arg++) & 0xFF; + return half(detail::binary, value); +} + +/// \} +/// \anchor exponential +/// \name Exponential functions +/// \{ + +/// Exponential function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for [std::exp](https://en.cppreference.com/w/cpp/numeric/math/exp). +/// \param arg function argument +/// \return e raised to \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half exp(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::exp(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF; + if(!abs) + return half(detail::binary, 0x3C00); + if(abs >= 0x7C00) + return half(detail::binary, + (abs == 0x7C00) ? (0x7C00 & ((arg.data_ >> 15) - 1U)) + : detail::signal(arg.data_)); + if(abs >= 0x4C80) + return half(detail::binary, + (arg.data_ & 0x8000) ? detail::underflow() + : detail::overflow()); + detail::uint32 m = detail::multiply64( + static_cast((abs & 0x3FF) + ((abs > 0x3FF) << 10)) << 21, 0xB8AA3B29); + int e = (abs >> 10) + (abs <= 0x3FF), exp; + if(e < 14) + { + exp = 0; + m >>= 14 - e; + } + else + { + exp = m >> (45 - e); + m = (m << (e - 14)) & 0x7FFFFFFF; + } + return half(detail::binary, + detail::exp2_post( + detail::exp2(m, 26), exp, (arg.data_ & 0x8000) != 0)); +#endif +} + +/// Binary exponential. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::exp2](https://en.cppreference.com/w/cpp/numeric/math/exp2). +/// \param arg function argument +/// \return 2 raised to \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half exp2(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::exp2(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF; + if(!abs) + return half(detail::binary, 0x3C00); + if(abs >= 0x7C00) + return half(detail::binary, + (abs == 0x7C00) ? (0x7C00 & ((arg.data_ >> 15) - 1U)) + : detail::signal(arg.data_)); + if(abs >= 0x4E40) + return half(detail::binary, + (arg.data_ & 0x8000) ? detail::underflow() + : detail::overflow()); + int e = (abs >> 10) + (abs <= 0x3FF), exp = (abs & 0x3FF) + ((abs > 0x3FF) << 10); + detail::uint32 m = detail::exp2((static_cast(exp) << (6 + e)) & 0x7FFFFFFF, 28); + exp >>= 25 - e; + if(m == 0x80000000) + { + if(arg.data_ & 0x8000) + exp = -exp; + else if(exp > 15) + return half(detail::binary, detail::overflow()); + return half(detail::binary, + detail::fixed2half(m, exp + 14)); + } + return half(detail::binary, + detail::exp2_post(m, exp, (arg.data_ & 0x8000) != 0)); +#endif +} + +/// Exponential minus one. +/// This function may be 1 ULP off the correctly rounded exact result in <0.05% of inputs for +/// `std::round_to_nearest` +/// and in <1% of inputs for any other rounding mode. +/// +/// **See also:** Documentation for +/// [std::expm1](https://en.cppreference.com/w/cpp/numeric/math/expm1). 
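The early-out thresholds in `exp`/`exp2` above reflect the narrow half range (largest finite value 65504); a quick sketch, assuming the default round-to-nearest mode in which overflow produces infinity:

    #include "half.hpp"
    #include <iostream>

    int main()
    {
        using half_float::half;
        std::cout << exp2(half(15.0f)) << '\n';          // 32768, still finite
        std::cout << isinf(exp2(half(16.0f))) << '\n';   // 1: 2^16 = 65536 > 65504 overflows
        return 0;
    }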
+/// \param arg function argument +/// \return e raised to \a arg and subtracted by 1 +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half expm1(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::expm1(detail::half2float(arg.data_)))); +#else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, + (abs == 0x7C00) ? (0x7C00 + (sign >> 1)) : detail::signal(arg.data_)); + if(abs >= 0x4A00) + return half(detail::binary, + (arg.data_ & 0x8000) ? detail::rounded(0xBBFF, 1, 1) + : detail::overflow()); + detail::uint32 m = detail::multiply64( + static_cast((abs & 0x3FF) + ((abs > 0x3FF) << 10)) << 21, 0xB8AA3B29); + int e = (abs >> 10) + (abs <= 0x3FF), exp; + if(e < 14) + { + exp = 0; + m >>= 14 - e; + } + else + { + exp = m >> (45 - e); + m = (m << (e - 14)) & 0x7FFFFFFF; + } + m = detail::exp2(m); + if(sign) + { + int s = 0; + if(m > 0x80000000) + { + ++exp; + m = detail::divide64(0x80000000, m, s); + } + m = 0x80000000 - + ((m >> exp) | ((m & ((static_cast(1) << exp) - 1)) != 0) | s); + exp = 0; + } + else + m -= (exp < 31) ? (0x80000000 >> exp) : 1; + for(exp += 14; m < 0x80000000 && exp; m <<= 1, --exp) + ; + if(exp > 29) + return half(detail::binary, detail::overflow()); + return half(detail::binary, + detail::rounded( + sign + (exp << 10) + (m >> 21), (m >> 20) & 1, (m & 0xFFFFF) != 0)); +#endif +} + +/// Natural logarithm. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for [std::log](https://en.cppreference.com/w/cpp/numeric/math/log). +/// \param arg function argument +/// \return logarithm of \a arg to base e +/// \exception FE_INVALID for signaling NaN or negative argument +/// \exception FE_DIVBYZERO for 0 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half log(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::log(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = -15; + if(!abs) + return half(detail::binary, detail::pole(0x8000)); + if(arg.data_ & 0x8000) + return half(detail::binary, + (arg.data_ <= 0xFC00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs >= 0x7C00) + return (abs == 0x7C00) ? arg : half(detail::binary, detail::signal(arg.data_)); + for(; abs < 0x400; abs <<= 1, --exp) + ; + exp += abs >> 10; + return half(detail::binary, + detail::log2_post( + detail::log2(static_cast((abs & 0x3FF) | 0x400) << 20, 27) + 8, + exp, + 17)); +#endif +} + +/// Common logarithm. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::log10](https://en.cppreference.com/w/cpp/numeric/math/log10). +/// \param arg function argument +/// \return logarithm of \a arg to base 10 +/// \exception FE_INVALID for signaling NaN or negative argument +/// \exception FE_DIVBYZERO for 0 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half log10(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::log10(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = -15; + if(!abs) + return half(detail::binary, detail::pole(0x8000)); + if(arg.data_ & 0x8000) + return half(detail::binary, + (arg.data_ <= 0xFC00) ? 
detail::invalid() : detail::signal(arg.data_)); + if(abs >= 0x7C00) + return (abs == 0x7C00) ? arg : half(detail::binary, detail::signal(arg.data_)); + switch(abs) + { + case 0x4900: return half(detail::binary, 0x3C00); + case 0x5640: return half(detail::binary, 0x4000); + case 0x63D0: return half(detail::binary, 0x4200); + case 0x70E2: return half(detail::binary, 0x4400); + } + for(; abs < 0x400; abs <<= 1, --exp) + ; + exp += abs >> 10; + return half(detail::binary, + detail::log2_post( + detail::log2(static_cast((abs & 0x3FF) | 0x400) << 20, 27) + 8, + exp, + 16)); +#endif +} + +/// Binary logarithm. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::log2](https://en.cppreference.com/w/cpp/numeric/math/log2). +/// \param arg function argument +/// \return logarithm of \a arg to base 2 +/// \exception FE_INVALID for signaling NaN or negative argument +/// \exception FE_DIVBYZERO for 0 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half log2(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::log2(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = -15, s = 0; + if(!abs) + return half(detail::binary, detail::pole(0x8000)); + if(arg.data_ & 0x8000) + return half(detail::binary, + (arg.data_ <= 0xFC00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs >= 0x7C00) + return (abs == 0x7C00) ? arg : half(detail::binary, detail::signal(arg.data_)); + if(abs == 0x3C00) + return half(detail::binary, 0); + for(; abs < 0x400; abs <<= 1, --exp) + ; + exp += (abs >> 10); + if(!(abs & 0x3FF)) + { + unsigned int value = static_cast(exp < 0) << 15, m = std::abs(exp) << 6; + for(exp = 18; m < 0x400; m <<= 1, --exp) + ; + return half(detail::binary, value + (exp << 10) + m); + } + detail::uint32 ilog = exp, sign = detail::sign_mask(ilog), + m = (((ilog << 27) + + (detail::log2(static_cast((abs & 0x3FF) | 0x400) << 20, + 28) >> + 4)) ^ + sign) - + sign; + if(!m) + return half(detail::binary, 0); + for(exp = 14; m < 0x8000000 && exp; m <<= 1, --exp) + ; + for(; m > 0xFFFFFFF; m >>= 1, ++exp) + s |= m & 1; + return half( + detail::binary, + detail::fixed2half(m, exp, sign & 0x8000, s)); +#endif +} + +/// Natural logarithm plus one. +/// This function may be 1 ULP off the correctly rounded exact result in <0.05% of inputs for +/// `std::round_to_nearest` +/// and in ~1% of inputs for any other rounding mode. +/// +/// **See also:** Documentation for +/// [std::log1p](https://en.cppreference.com/w/cpp/numeric/math/log1p). +/// \param arg function argument +/// \return logarithm of \a arg plus 1 to base e +/// \exception FE_INVALID for signaling NaN or argument <-1 +/// \exception FE_DIVBYZERO for -1 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half log1p(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::log1p(detail::half2float(arg.data_)))); +#else + if(arg.data_ >= 0xBC00) + return half(detail::binary, + (arg.data_ == 0xBC00) + ? detail::pole(0x8000) + : (arg.data_ <= 0xFC00) ? detail::invalid() : detail::signal(arg.data_)); + int abs = arg.data_ & 0x7FFF, exp = -15; + if(!abs || abs >= 0x7C00) + return (abs > 0x7C00) ? 
half(detail::binary, detail::signal(arg.data_)) : arg; + for(; abs < 0x400; abs <<= 1, --exp) + ; + exp += abs >> 10; + detail::uint32 m = static_cast((abs & 0x3FF) | 0x400) << 20; + if(arg.data_ & 0x8000) + { + m = 0x40000000 - (m >> -exp); + for(exp = 0; m < 0x40000000; m <<= 1, --exp) + ; + } + else + { + if(exp < 0) + { + m = 0x40000000 + (m >> -exp); + exp = 0; + } + else + { + m += 0x40000000 >> exp; + int i = m >> 31; + m >>= i; + exp += i; + } + } + return half(detail::binary, + detail::log2_post(detail::log2(m), exp, 17)); +#endif +} + +/// \} +/// \anchor power +/// \name Power functions +/// \{ + +/// Square root. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::sqrt](https://en.cppreference.com/w/cpp/numeric/math/sqrt). +/// \param arg function argument +/// \return square root of \a arg +/// \exception FE_INVALID for signaling NaN and negative arguments +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half sqrt(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::sqrt(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = 15; + if(!abs || arg.data_ >= 0x7C00) + return half(detail::binary, + (abs > 0x7C00) ? detail::signal(arg.data_) + : (arg.data_ > 0x8000) ? detail::invalid() : arg.data_); + for(; abs < 0x400; abs <<= 1, --exp) + ; + detail::uint32 r = static_cast((abs & 0x3FF) | 0x400) << 10, + m = detail::sqrt<20>(r, exp += abs >> 10); + return half( + detail::binary, + detail::rounded((exp << 10) + (m & 0x3FF), r > m, r != 0)); +#endif +} + +/// Cubic root. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::cbrt](https://en.cppreference.com/w/cpp/numeric/math/cbrt). +/// \param arg function argument +/// \return cubic root of \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half cbrt(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::cbrt(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = -15; + if(!abs || abs == 0x3C00 || abs >= 0x7C00) + return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + for(; abs < 0x400; abs <<= 1, --exp) + ; + detail::uint32 ilog = exp + (abs >> 10), sign = detail::sign_mask(ilog), f, + m = (((ilog << 27) + + (detail::log2(static_cast((abs & 0x3FF) | 0x400) << 20, + 24) >> + 4)) ^ + sign) - + sign; + for(exp = 2; m < 0x80000000; m <<= 1, --exp) + ; + m = detail::multiply64(m, 0xAAAAAAAB); + int i = m >> 31, s; + exp += i; + m <<= 1 - i; + if(exp < 0) + { + f = m >> -exp; + exp = 0; + } + else + { + f = (m << exp) & 0x7FFFFFFF; + exp = m >> (31 - exp); + } + m = detail::exp2(f, (half::round_style == std::round_to_nearest) ? 29 : 26); + if(sign) + { + if(m > 0x80000000) + { + m = detail::divide64(0x80000000, m, s); + ++exp; + } + exp = -exp; + } + return half(detail::binary, + (half::round_style == std::round_to_nearest) + ? detail::fixed2half( + m, exp + 14, arg.data_ & 0x8000) + : detail::fixed2half( + (m + 0x80) >> 8, exp + 14, arg.data_ & 0x8000)); +#endif +} + +/// Hypotenuse function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::hypot](https://en.cppreference.com/w/cpp/numeric/math/hypot). 
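`hypot` above rescales its operands internally, so it avoids the intermediate overflow that the naive formula hits almost immediately at half precision; a sketch under the default round-to-nearest mode:

    #include "half.hpp"
    #include <iostream>

    int main()
    {
        using half_float::half;
        half x(300.0f), y(400.0f);
        std::cout << sqrt(x * x + y * y) << '\n';   // x*x = 90000 overflows to inf, so this prints inf
        std::cout << hypot(x, y) << '\n';           // 500
        return 0;
    }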
+/// \param x first argument +/// \param y second argument +/// \return square root of sum of squares without internal over- or underflows +/// \exception FE_INVALID if \a x or \a y is signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding of the final square root +inline half hypot(half x, half y) +{ +#ifdef HALF_ARITHMETIC_TYPE + detail::internal_t fx = detail::half2float(x.data_), + fy = detail::half2float(y.data_); +#if HALF_ENABLE_CPP11_CMATH + return half(detail::binary, detail::float2half(std::hypot(fx, fy))); +#else + return half(detail::binary, + detail::float2half(std::sqrt(fx * fx + fy * fy))); +#endif +#else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, expx = 0, expy = 0; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, + (absx == 0x7C00) ? detail::select(0x7C00, y.data_) + : (absy == 0x7C00) ? detail::select(0x7C00, x.data_) + : detail::signal(x.data_, y.data_)); + if(!absx) + return half(detail::binary, absy ? detail::check_underflow(absy) : 0); + if(!absy) + return half(detail::binary, detail::check_underflow(absx)); + if(absy > absx) + std::swap(absx, absy); + for(; absx < 0x400; absx <<= 1, --expx) + ; + for(; absy < 0x400; absy <<= 1, --expy) + ; + detail::uint32 mx = (absx & 0x3FF) | 0x400, my = (absy & 0x3FF) | 0x400; + mx *= mx; + my *= my; + int ix = mx >> 21, iy = my >> 21; + expx = 2 * (expx + (absx >> 10)) - 15 + ix; + expy = 2 * (expy + (absy >> 10)) - 15 + iy; + mx <<= 10 - ix; + my <<= 10 - iy; + int d = expx - expy; + my = (d < 30) ? ((my >> d) | ((my & ((static_cast(1) << d) - 1)) != 0)) : 1; + return half(detail::binary, detail::hypot_post(mx + my, expx)); +#endif +} + +/// Hypotenuse function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::hypot](https://en.cppreference.com/w/cpp/numeric/math/hypot). +/// \param x first argument +/// \param y second argument +/// \param z third argument +/// \return square root of sum of squares without internal over- or underflows +/// \exception FE_INVALID if \a x, \a y or \a z is signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding of the final square root +inline half hypot(half x, half y, half z) +{ +#ifdef HALF_ARITHMETIC_TYPE + detail::internal_t fx = detail::half2float(x.data_), + fy = detail::half2float(y.data_), + fz = detail::half2float(z.data_); + return half(detail::binary, + detail::float2half(std::sqrt(fx * fx + fy * fy + fz * fz))); +#else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, absz = z.data_ & 0x7FFF, expx = 0, + expy = 0, expz = 0; + if(!absx) + return hypot(y, z); + if(!absy) + return hypot(x, z); + if(!absz) + return hypot(x, y); + if(absx >= 0x7C00 || absy >= 0x7C00 || absz >= 0x7C00) + return half(detail::binary, + (absx == 0x7C00) + ? detail::select(0x7C00, detail::select(y.data_, z.data_)) + : (absy == 0x7C00) + ? detail::select(0x7C00, detail::select(x.data_, z.data_)) + : (absz == 0x7C00) + ? 
detail::select(0x7C00, detail::select(x.data_, y.data_)) + : detail::signal(x.data_, y.data_, z.data_)); + if(absz > absy) + std::swap(absy, absz); + if(absy > absx) + std::swap(absx, absy); + if(absz > absy) + std::swap(absy, absz); + for(; absx < 0x400; absx <<= 1, --expx) + ; + for(; absy < 0x400; absy <<= 1, --expy) + ; + for(; absz < 0x400; absz <<= 1, --expz) + ; + detail::uint32 mx = (absx & 0x3FF) | 0x400, my = (absy & 0x3FF) | 0x400, + mz = (absz & 0x3FF) | 0x400; + mx *= mx; + my *= my; + mz *= mz; + int ix = mx >> 21, iy = my >> 21, iz = mz >> 21; + expx = 2 * (expx + (absx >> 10)) - 15 + ix; + expy = 2 * (expy + (absy >> 10)) - 15 + iy; + expz = 2 * (expz + (absz >> 10)) - 15 + iz; + mx <<= 10 - ix; + my <<= 10 - iy; + mz <<= 10 - iz; + int d = expy - expz; + mz = (d < 30) ? ((mz >> d) | ((mz & ((static_cast(1) << d) - 1)) != 0)) : 1; + my += mz; + if(my & 0x80000000) + { + my = (my >> 1) | (my & 1); + if(++expy > expx) + { + std::swap(mx, my); + std::swap(expx, expy); + } + } + d = expx - expy; + my = (d < 30) ? ((my >> d) | ((my & ((static_cast(1) << d) - 1)) != 0)) : 1; + return half(detail::binary, detail::hypot_post(mx + my, expx)); +#endif +} + +/// Power function. +/// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in +/// ~0.00025% of inputs. +/// +/// **See also:** Documentation for [std::pow](https://en.cppreference.com/w/cpp/numeric/math/pow). +/// \param x base +/// \param y exponent +/// \return \a x raised to \a y +/// \exception FE_INVALID if \a x or \a y is signaling NaN or if \a x is finite an negative and \a y +/// is finite and not integral +/// \exception FE_DIVBYZERO if \a x is 0 and \a y is negative +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half pow(half x, half y) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::pow(detail::half2float(x.data_), + detail::half2float(y.data_)))); +#else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, exp = -15; + if(!absy || x.data_ == 0x3C00) + return half(detail::binary, + detail::select(0x3C00, (x.data_ == 0x3C00) ? y.data_ : x.data_)); + bool is_int = absy >= 0x6400 || (absy >= 0x3C00 && !(absy & ((1 << (25 - (absy >> 10))) - 1))); + unsigned int sign = + x.data_ & + (static_cast((absy < 0x6800) && is_int && ((absy >> (25 - (absy >> 10))) & 1)) + << 15); + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, + (absx > 0x7C00 || absy > 0x7C00) + ? detail::signal(x.data_, y.data_) + : (absy == 0x7C00) + ? ((absx == 0x3C00) + ? 0x3C00 + : (!absx && y.data_ == 0xFC00) + ? detail::pole() + : (0x7C00 & -((y.data_ >> 15) ^ (absx > 0x3C00)))) + : (sign | (0x7C00 & ((y.data_ >> 15) - 1U)))); + if(!absx) + return half(detail::binary, (y.data_ & 0x8000) ? 
detail::pole(sign) : sign); + if((x.data_ & 0x8000) && !is_int) + return half(detail::binary, detail::invalid()); + if(x.data_ == 0xBC00) + return half(detail::binary, sign | 0x3C00); + if(y.data_ == 0x3800) + return sqrt(x); + if(y.data_ == 0x3C00) + return half(detail::binary, detail::check_underflow(x.data_)); + if(y.data_ == 0x4000) + return x * x; + for(; absx < 0x400; absx <<= 1, --exp) + ; + detail::uint32 ilog = exp + (absx >> 10), msign = detail::sign_mask(ilog), f, + m = (((ilog << 27) + + ((detail::log2(static_cast((absx & 0x3FF) | 0x400) << 20) + + 8) >> + 4)) ^ + msign) - + msign; + for(exp = -11; m < 0x80000000; m <<= 1, --exp) + ; + for(; absy < 0x400; absy <<= 1, --exp) + ; + m = detail::multiply64(m, static_cast((absy & 0x3FF) | 0x400) << 21); + int i = m >> 31; + exp += (absy >> 10) + i; + m <<= 1 - i; + if(exp < 0) + { + f = m >> -exp; + exp = 0; + } + else + { + f = (m << exp) & 0x7FFFFFFF; + exp = m >> (31 - exp); + } + return half(detail::binary, + detail::exp2_post( + detail::exp2(f), exp, ((msign & 1) ^ (y.data_ >> 15)) != 0, sign)); +#endif +} + +/// \} +/// \anchor trigonometric +/// \name Trigonometric functions +/// \{ + +/// Compute sine and cosine simultaneously. +/// This returns the same results as sin() and cos() but is faster than calling each function +/// individually. +/// +/// This function is exact to rounding for all rounding modes. +/// \param arg function argument +/// \param sin variable to take sine of \a arg +/// \param cos variable to take cosine of \a arg +/// \exception FE_INVALID for signaling NaN or infinity +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline void sincos(half arg, half* sin, half* cos) +{ +#ifdef HALF_ARITHMETIC_TYPE + detail::internal_t f = detail::half2float(arg.data_); + *sin = half(detail::binary, detail::float2half(std::sin(f))); + *cos = half(detail::binary, detail::float2half(std::cos(f))); +#else + int abs = arg.data_ & 0x7FFF, sign = arg.data_ >> 15, k; + if(abs >= 0x7C00) + *sin = *cos = + half(detail::binary, (abs == 0x7C00) ? 
detail::invalid() : detail::signal(arg.data_)); + else if(!abs) + { + *sin = arg; + *cos = half(detail::binary, 0x3C00); + } + else if(abs < 0x2500) + { + *sin = half(detail::binary, detail::rounded(arg.data_ - 1, 1, 1)); + *cos = half(detail::binary, detail::rounded(0x3BFF, 1, 1)); + } + else + { + if(half::round_style != std::round_to_nearest) + { + switch(abs) + { + case 0x48B7: + *sin = half( + detail::binary, + detail::rounded((~arg.data_ & 0x8000) | 0x1D07, 1, 1)); + *cos = half(detail::binary, detail::rounded(0xBBFF, 1, 1)); + return; + case 0x598C: + *sin = half( + detail::binary, + detail::rounded((arg.data_ & 0x8000) | 0x3BFF, 1, 1)); + *cos = half(detail::binary, detail::rounded(0x80FC, 1, 1)); + return; + case 0x6A64: + *sin = half( + detail::binary, + detail::rounded((~arg.data_ & 0x8000) | 0x3BFE, 1, 1)); + *cos = half(detail::binary, detail::rounded(0x27FF, 1, 1)); + return; + case 0x6D8C: + *sin = half( + detail::binary, + detail::rounded((arg.data_ & 0x8000) | 0x0FE6, 1, 1)); + *cos = half(detail::binary, detail::rounded(0x3BFF, 1, 1)); + return; + } + } + std::pair sc = + detail::sincos(detail::angle_arg(abs, k), 28); + switch(k & 3) + { + case 1: sc = std::make_pair(sc.second, -sc.first); break; + case 2: sc = std::make_pair(-sc.first, -sc.second); break; + case 3: sc = std::make_pair(-sc.second, sc.first); break; + } + *sin = half(detail::binary, + detail::fixed2half( + (sc.first ^ -static_cast(sign)) + sign)); + *cos = half(detail::binary, + detail::fixed2half(sc.second)); + } +#endif +} + +/// Sine function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for [std::sin](https://en.cppreference.com/w/cpp/numeric/math/sin). +/// \param arg function argument +/// \return sine value of \a arg +/// \exception FE_INVALID for signaling NaN or infinity +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half sin(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::sin(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, k; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, + (abs == 0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs < 0x2900) + return half(detail::binary, detail::rounded(arg.data_ - 1, 1, 1)); + if(half::round_style != std::round_to_nearest) + switch(abs) + { + case 0x48B7: + return half( + detail::binary, + detail::rounded((~arg.data_ & 0x8000) | 0x1D07, 1, 1)); + case 0x6A64: + return half( + detail::binary, + detail::rounded((~arg.data_ & 0x8000) | 0x3BFE, 1, 1)); + case 0x6D8C: + return half( + detail::binary, + detail::rounded((arg.data_ & 0x8000) | 0x0FE6, 1, 1)); + } + std::pair sc = detail::sincos(detail::angle_arg(abs, k), 28); + detail::uint32 sign = -static_cast(((k >> 1) & 1) ^ (arg.data_ >> 15)); + return half(detail::binary, + detail::fixed2half( + (((k & 1) ? sc.second : sc.first) ^ sign) - sign)); +#endif +} + +/// Cosine function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for [std::cos](https://en.cppreference.com/w/cpp/numeric/math/cos). 
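`sincos` above shares one argument reduction between both results, so it is the cheaper call when sine and cosine are both needed; a brief sketch:

    #include "half.hpp"
    #include <iostream>

    int main()
    {
        using half_float::half;
        half s(0.0f), c(0.0f);
        sincos(half(1.0f), &s, &c);           // one reduction, two results
        std::cout << s << ' ' << c << '\n';   // roughly 0.8413 and 0.5405 after half rounding
        return 0;
    }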
+/// \param arg function argument +/// \return cosine value of \a arg +/// \exception FE_INVALID for signaling NaN or infinity +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half cos(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::cos(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, k; + if(!abs) + return half(detail::binary, 0x3C00); + if(abs >= 0x7C00) + return half(detail::binary, + (abs == 0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs < 0x2500) + return half(detail::binary, detail::rounded(0x3BFF, 1, 1)); + if(half::round_style != std::round_to_nearest && abs == 0x598C) + return half(detail::binary, detail::rounded(0x80FC, 1, 1)); + std::pair sc = detail::sincos(detail::angle_arg(abs, k), 28); + detail::uint32 sign = -static_cast(((k >> 1) ^ k) & 1); + return half(detail::binary, + detail::fixed2half( + (((k & 1) ? sc.first : sc.second) ^ sign) - sign)); +#endif +} + +/// Tangent function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for [std::tan](https://en.cppreference.com/w/cpp/numeric/math/tan). +/// \param arg function argument +/// \return tangent value of \a arg +/// \exception FE_INVALID for signaling NaN or infinity +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half tan(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::tan(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = 13, k; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, + (abs == 0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs < 0x2700) + return half(detail::binary, detail::rounded(arg.data_, 0, 1)); + if(half::round_style != std::round_to_nearest) + switch(abs) + { + case 0x658C: + return half( + detail::binary, + detail::rounded((arg.data_ & 0x8000) | 0x07E6, 1, 1)); + case 0x7330: + return half( + detail::binary, + detail::rounded((~arg.data_ & 0x8000) | 0x4B62, 1, 1)); + } + std::pair sc = detail::sincos(detail::angle_arg(abs, k), 30); + if(k & 1) + sc = std::make_pair(-sc.second, sc.first); + detail::uint32 signy = detail::sign_mask(sc.first), signx = detail::sign_mask(sc.second); + detail::uint32 my = (sc.first ^ signy) - signy, mx = (sc.second ^ signx) - signx; + for(; my < 0x80000000; my <<= 1, --exp) + ; + for(; mx < 0x80000000; mx <<= 1, ++exp) + ; + return half( + detail::binary, + detail::tangent_post(my, mx, exp, (signy ^ signx ^ arg.data_) & 0x8000)); +#endif +} + +/// Arc sine. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::asin](https://en.cppreference.com/w/cpp/numeric/math/asin). +/// \param arg function argument +/// \return arc sine value of \a arg +/// \exception FE_INVALID for signaling NaN or if abs(\a arg) > 1 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half asin(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::asin(detail::half2float(arg.data_)))); +#else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if(!abs) + return arg; + if(abs >= 0x3C00) + return half(detail::binary, + (abs > 0x7C00) + ? detail::signal(arg.data_) + : (abs > 0x3C00) + ? 
detail::invalid() + : detail::rounded(sign | 0x3E48, 0, 1)); + if(abs < 0x2900) + return half(detail::binary, detail::rounded(arg.data_, 0, 1)); + if(half::round_style != std::round_to_nearest && (abs == 0x2B44 || abs == 0x2DC3)) + return half(detail::binary, detail::rounded(arg.data_ + 1, 1, 1)); + std::pair sc = detail::atan2_args(abs); + detail::uint32 m = + detail::atan2(sc.first, sc.second, (half::round_style == std::round_to_nearest) ? 27 : 26); + return half(detail::binary, + detail::fixed2half(m, 14, sign)); +#endif +} + +/// Arc cosine function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::acos](https://en.cppreference.com/w/cpp/numeric/math/acos). +/// \param arg function argument +/// \return arc cosine value of \a arg +/// \exception FE_INVALID for signaling NaN or if abs(\a arg) > 1 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half acos(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::acos(detail::half2float(arg.data_)))); +#else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ >> 15; + if(!abs) + return half(detail::binary, detail::rounded(0x3E48, 0, 1)); + if(abs >= 0x3C00) + return half(detail::binary, + (abs > 0x7C00) + ? detail::signal(arg.data_) + : (abs > 0x3C00) + ? detail::invalid() + : sign ? detail::rounded(0x4248, 0, 1) : 0); + std::pair cs = detail::atan2_args(abs); + detail::uint32 m = detail::atan2(cs.second, cs.first, 28); + return half(detail::binary, + detail::fixed2half( + sign ? (0xC90FDAA2 - m) : m, 15, 0, sign)); +#endif +} + +/// Arc tangent function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::atan](https://en.cppreference.com/w/cpp/numeric/math/atan). +/// \param arg function argument +/// \return arc tangent value of \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half atan(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::atan(detail::half2float(arg.data_)))); +#else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, + (abs == 0x7C00) ? detail::rounded(sign | 0x3E48, 0, 1) + : detail::signal(arg.data_)); + if(abs <= 0x2700) + return half(detail::binary, detail::rounded(arg.data_ - 1, 1, 1)); + int exp = (abs >> 10) + (abs <= 0x3FF); + detail::uint32 my = (abs & 0x3FF) | ((abs > 0x3FF) << 10); + detail::uint32 m = (exp > 15) + ? detail::atan2(my << 19, + 0x20000000 >> (exp - 15), + (half::round_style == std::round_to_nearest) ? 26 : 24) + : detail::atan2(my << (exp + 4), + 0x20000000, + (half::round_style == std::round_to_nearest) ? 30 : 28); + return half(detail::binary, + detail::fixed2half(m, 14, sign)); +#endif +} + +/// Arc tangent function. +/// This function may be 1 ULP off the correctly rounded exact result in ~0.005% of inputs for +/// `std::round_to_nearest`, +/// in ~0.1% of inputs for `std::round_toward_zero` and in ~0.02% of inputs for any other rounding +/// mode. +/// +/// **See also:** Documentation for +/// [std::atan2](https://en.cppreference.com/w/cpp/numeric/math/atan2). 
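`atan2` above recovers the quadrant from the signs of both operands, which a plain `atan(y/x)` cannot; a small sketch:

    #include "half.hpp"
    #include <iostream>

    int main()
    {
        using half_float::half;
        half y(1.0f), x(-1.0f);
        std::cout << atan(y / x) << '\n';   // about -pi/4: the sign of x is lost in the division
        std::cout << atan2(y, x) << '\n';   // about 3*pi/4, the angle in the second quadrant
        return 0;
    }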
+/// \param y numerator +/// \param x denominator +/// \return arc tangent value +/// \exception FE_INVALID if \a x or \a y is signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half atan2(half y, half x) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::atan2(detail::half2float(y.data_), + detail::half2float(x.data_)))); +#else + unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, signx = x.data_ >> 15, + signy = y.data_ & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + { + if(absx > 0x7C00 || absy > 0x7C00) + return half(detail::binary, detail::signal(x.data_, y.data_)); + if(absy == 0x7C00) + return half(detail::binary, + (absx < 0x7C00) + ? detail::rounded(signy | 0x3E48, 0, 1) + : signx + ? detail::rounded(signy | 0x40B6, 0, 1) + : detail::rounded(signy | 0x3A48, 0, 1)); + return (x.data_ == 0x7C00) + ? half(detail::binary, signy) + : half(detail::binary, + detail::rounded(signy | 0x4248, 0, 1)); + } + if(!absy) + return signx ? half(detail::binary, + detail::rounded(signy | 0x4248, 0, 1)) + : y; + if(!absx) + return half(detail::binary, detail::rounded(signy | 0x3E48, 0, 1)); + int d = (absy >> 10) + (absy <= 0x3FF) - (absx >> 10) - (absx <= 0x3FF); + if(d > (signx ? 18 : 12)) + return half(detail::binary, detail::rounded(signy | 0x3E48, 0, 1)); + if(signx && d < -11) + return half(detail::binary, detail::rounded(signy | 0x4248, 0, 1)); + if(!signx && d < ((half::round_style == std::round_toward_zero) ? -15 : -9)) + { + for(; absy < 0x400; absy <<= 1, --d) + ; + detail::uint32 mx = ((absx << 1) & 0x7FF) | 0x800, my = ((absy << 1) & 0x7FF) | 0x800; + int i = my < mx; + d -= i; + if(d < -25) + return half(detail::binary, detail::underflow(signy)); + my <<= 11 + i; + return half(detail::binary, + detail::fixed2half( + my / mx, d + 14, signy, my % mx != 0)); + } + detail::uint32 m = detail::atan2( + ((absy & 0x3FF) | ((absy > 0x3FF) << 10)) << (19 + ((d < 0) ? d : (d > 0) ? 0 : -1)), + ((absx & 0x3FF) | ((absx > 0x3FF) << 10)) << (19 - ((d > 0) ? d : (d < 0) ? 0 : 1))); + return half(detail::binary, + detail::fixed2half( + signx ? (0xC90FDAA2 - m) : m, 15, signy, signx)); +#endif +} + +/// \} +/// \anchor hyperbolic +/// \name Hyperbolic functions +/// \{ + +/// Hyperbolic sine. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::sinh](https://en.cppreference.com/w/cpp/numeric/math/sinh). +/// \param arg function argument +/// \return hyperbolic sine value of \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half sinh(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::sinh(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp; + if(!abs || abs >= 0x7C00) + return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + if(abs <= 0x2900) + return half(detail::binary, detail::rounded(arg.data_, 0, 1)); + std::pair mm = + detail::hyperbolic_args(abs, exp, (half::round_style == std::round_to_nearest) ? 29 : 27); + detail::uint32 m = mm.first - mm.second; + for(exp += 13; m < 0x80000000 && exp; m <<= 1, --exp) + ; + unsigned int sign = arg.data_ & 0x8000; + if(exp > 29) + return half(detail::binary, detail::overflow(sign)); + return half(detail::binary, + detail::fixed2half(m, exp, sign)); +#endif +} + +/// Hyperbolic cosine. 
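With 65504 as the largest finite value, the hyperbolic functions above saturate very early; a sketch assuming the default round-to-nearest mode, in which overflow yields infinity:

    #include "half.hpp"
    #include <iostream>

    int main()
    {
        using half_float::half;
        std::cout << sinh(half(11.0f)) << '\n';          // about 2.99e4, still finite
        std::cout << isinf(sinh(half(12.0f))) << '\n';   // 1: e^12 / 2 is about 8.1e4 > 65504
        return 0;
    }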
+/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::cosh](https://en.cppreference.com/w/cpp/numeric/math/cosh). +/// \param arg function argument +/// \return hyperbolic cosine value of \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half cosh(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::cosh(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp; + if(!abs) + return half(detail::binary, 0x3C00); + if(abs >= 0x7C00) + return half(detail::binary, (abs > 0x7C00) ? detail::signal(arg.data_) : 0x7C00); + std::pair mm = + detail::hyperbolic_args(abs, exp, (half::round_style == std::round_to_nearest) ? 23 : 26); + detail::uint32 m = mm.first + mm.second, i = (~m & 0xFFFFFFFF) >> 31; + m = (m >> i) | (m & i) | 0x80000000; + if((exp += 13 + i) > 29) + return half(detail::binary, detail::overflow()); + return half(detail::binary, + detail::fixed2half(m, exp)); +#endif +} + +/// Hyperbolic tangent. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::tanh](https://en.cppreference.com/w/cpp/numeric/math/tanh). +/// \param arg function argument +/// \return hyperbolic tangent value of \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half tanh(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::tanh(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, + (abs > 0x7C00) ? detail::signal(arg.data_) : (arg.data_ - 0x4000)); + if(abs >= 0x4500) + return half(detail::binary, + detail::rounded((arg.data_ & 0x8000) | 0x3BFF, 1, 1)); + if(abs < 0x2700) + return half(detail::binary, detail::rounded(arg.data_ - 1, 1, 1)); + if(half::round_style != std::round_to_nearest && abs == 0x2D3F) + return half(detail::binary, detail::rounded(arg.data_ - 3, 0, 1)); + std::pair mm = detail::hyperbolic_args(abs, exp, 27); + detail::uint32 my = mm.first - mm.second - (half::round_style != std::round_to_nearest), + mx = mm.first + mm.second, i = (~mx & 0xFFFFFFFF) >> 31; + for(exp = 13; my < 0x80000000; my <<= 1, --exp) + ; + mx = (mx >> i) | 0x80000000; + return half(detail::binary, + detail::tangent_post(my, mx, exp - i, arg.data_ & 0x8000)); +#endif +} + +/// Hyperbolic area sine. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::asinh](https://en.cppreference.com/w/cpp/numeric/math/asinh). +/// \param arg function argument +/// \return area sine value of \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half asinh(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::asinh(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF; + if(!abs || abs >= 0x7C00) + return (abs > 0x7C00) ? 
half(detail::binary, detail::signal(arg.data_)) : arg; + if(abs <= 0x2900) + return half(detail::binary, detail::rounded(arg.data_ - 1, 1, 1)); + if(half::round_style != std::round_to_nearest) + switch(abs) + { + case 0x32D4: + return half(detail::binary, + detail::rounded(arg.data_ - 13, 1, 1)); + case 0x3B5B: + return half(detail::binary, + detail::rounded(arg.data_ - 197, 1, 1)); + } + return half(detail::binary, detail::area(arg.data_)); +#endif +} + +/// Hyperbolic area cosine. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::acosh](https://en.cppreference.com/w/cpp/numeric/math/acosh). +/// \param arg function argument +/// \return area cosine value of \a arg +/// \exception FE_INVALID for signaling NaN or arguments <1 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half acosh(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::acosh(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF; + if((arg.data_ & 0x8000) || abs < 0x3C00) + return half(detail::binary, + (abs <= 0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs == 0x3C00) + return half(detail::binary, 0); + if(arg.data_ >= 0x7C00) + return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + return half(detail::binary, detail::area(arg.data_)); +#endif +} + +/// Hyperbolic area tangent. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::atanh](https://en.cppreference.com/w/cpp/numeric/math/atanh). +/// \param arg function argument +/// \return area tangent value of \a arg +/// \exception FE_INVALID for signaling NaN or if abs(\a arg) > 1 +/// \exception FE_DIVBYZERO for +/-1 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half atanh(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::atanh(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = 0; + if(!abs) + return arg; + if(abs >= 0x3C00) + return half(detail::binary, + (abs == 0x3C00) + ? detail::pole(arg.data_ & 0x8000) + : (abs <= 0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs < 0x2700) + return half(detail::binary, detail::rounded(arg.data_, 0, 1)); + detail::uint32 m = static_cast((abs & 0x3FF) | ((abs > 0x3FF) << 10)) + << ((abs >> 10) + (abs <= 0x3FF) + 6), + my = 0x80000000 + m, mx = 0x80000000 - m; + for(; mx < 0x80000000; mx <<= 1, ++exp) + ; + int i = my >= mx, s; + return half(detail::binary, + detail::log2_post( + detail::log2((detail::divide64(my >> i, mx, s) + 1) >> 1, 27) + 0x10, + exp + i - 1, + 16, + arg.data_ & 0x8000)); +#endif +} + +/// \} +/// \anchor special +/// \name Error and gamma functions +/// \{ + +/// Error function. +/// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in <0.5% +/// of inputs. +/// +/// **See also:** Documentation for [std::erf](https://en.cppreference.com/w/cpp/numeric/math/erf). 
+/// \param arg function argument +/// \return error function value of \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half erf(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::erf(detail::half2float(arg.data_)))); +#else + unsigned int abs = arg.data_ & 0x7FFF; + if(!abs || abs >= 0x7C00) + return (abs >= 0x7C00) + ? half(detail::binary, + (abs == 0x7C00) ? (arg.data_ - 0x4000) : detail::signal(arg.data_)) + : arg; + if(abs >= 0x4200) + return half(detail::binary, + detail::rounded((arg.data_ & 0x8000) | 0x3BFF, 1, 1)); + return half(detail::binary, detail::erf(arg.data_)); +#endif +} + +/// Complementary error function. +/// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in <0.5% +/// of inputs. +/// +/// **See also:** Documentation for +/// [std::erfc](https://en.cppreference.com/w/cpp/numeric/math/erfc). +/// \param arg function argument +/// \return 1 minus error function value of \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half erfc(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::erfc(detail::half2float(arg.data_)))); +#else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if(abs >= 0x7C00) + return (abs >= 0x7C00) + ? half(detail::binary, (abs == 0x7C00) ? (sign >> 1) : detail::signal(arg.data_)) + : arg; + if(!abs) + return half(detail::binary, 0x3C00); + if(abs >= 0x4400) + return half( + detail::binary, + detail::rounded((sign >> 1) - (sign >> 15), sign >> 15, 1)); + return half(detail::binary, detail::erf(arg.data_)); +#endif +} + +/// Natural logarithm of gamma function. +/// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in +/// ~0.025% of inputs. +/// +/// **See also:** Documentation for +/// [std::lgamma](https://en.cppreference.com/w/cpp/numeric/math/lgamma). +/// \param arg function argument +/// \return natural logarith of gamma function for \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_DIVBYZERO for 0 or negative integer arguments +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half lgamma(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::lgamma(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF; + if(abs >= 0x7C00) + return half(detail::binary, (abs == 0x7C00) ? 0x7C00 : detail::signal(arg.data_)); + if(!abs || arg.data_ >= 0xE400 || + (arg.data_ >= 0xBC00 && !(abs & ((1 << (25 - (abs >> 10))) - 1)))) + return half(detail::binary, detail::pole()); + if(arg.data_ == 0x3C00 || arg.data_ == 0x4000) + return half(detail::binary, 0); + return half(detail::binary, detail::gamma(arg.data_)); +#endif +} + +/// Gamma function. +/// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in +/// <0.25% of inputs. +/// +/// **See also:** Documentation for +/// [std::tgamma](https://en.cppreference.com/w/cpp/numeric/math/tgamma). 
+/// \param arg function argument +/// \return gamma function value of \a arg +/// \exception FE_INVALID for signaling NaN, negative infinity or negative integer arguments +/// \exception FE_DIVBYZERO for 0 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half tgamma(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::tgamma(detail::half2float(arg.data_)))); +#else + unsigned int abs = arg.data_ & 0x7FFF; + if(!abs) + return half(detail::binary, detail::pole(arg.data_)); + if(abs >= 0x7C00) + return (arg.data_ == 0x7C00) ? arg : half(detail::binary, detail::signal(arg.data_)); + if(arg.data_ >= 0xE400 || (arg.data_ >= 0xBC00 && !(abs & ((1 << (25 - (abs >> 10))) - 1)))) + return half(detail::binary, detail::invalid()); + if(arg.data_ >= 0xCA80) + return half( + detail::binary, + detail::underflow((1 - ((abs >> (25 - (abs >> 10))) & 1)) << 15)); + if(arg.data_ <= 0x100 || (arg.data_ >= 0x4900 && arg.data_ < 0x8000)) + return half(detail::binary, detail::overflow()); + if(arg.data_ == 0x3C00) + return arg; + return half(detail::binary, detail::gamma(arg.data_)); +#endif +} + +/// \} +/// \anchor rounding +/// \name Rounding +/// \{ + +/// Nearest integer not less than half value. +/// **See also:** Documentation for +/// [std::ceil](https://en.cppreference.com/w/cpp/numeric/math/ceil). +/// \param arg half to round +/// \return nearest integer not less than \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_INEXACT if value had to be rounded +inline half ceil(half arg) +{ + return half(detail::binary, + detail::integral(arg.data_)); +} + +/// Nearest integer not greater than half value. +/// **See also:** Documentation for +/// [std::floor](https://en.cppreference.com/w/cpp/numeric/math/floor). +/// \param arg half to round +/// \return nearest integer not greater than \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_INEXACT if value had to be rounded +inline half floor(half arg) +{ + return half(detail::binary, + detail::integral(arg.data_)); +} + +/// Nearest integer not greater in magnitude than half value. +/// **See also:** Documentation for +/// [std::trunc](https://en.cppreference.com/w/cpp/numeric/math/trunc). +/// \param arg half to round +/// \return nearest integer not greater in magnitude than \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_INEXACT if value had to be rounded +inline half trunc(half arg) +{ + return half(detail::binary, detail::integral(arg.data_)); +} + +/// Nearest integer. +/// **See also:** Documentation for +/// [std::round](https://en.cppreference.com/w/cpp/numeric/math/round). +/// \param arg half to round +/// \return nearest integer, rounded away from zero in half-way cases +/// \exception FE_INVALID for signaling NaN +/// \exception FE_INEXACT if value had to be rounded +inline half round(half arg) +{ + return half(detail::binary, detail::integral(arg.data_)); +} + +/// Nearest integer. +/// **See also:** Documentation for +/// [std::lround](https://en.cppreference.com/w/cpp/numeric/math/round). +/// \param arg half to round +/// \return nearest integer, rounded away from zero in half-way cases +/// \exception FE_INVALID if value is not representable as `long` +inline long lround(half arg) +{ + return detail::half2int(arg.data_); +} + +/// Nearest integer using half's internal rounding mode. 
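The practical distinction in this group is that round and lround always break ties away from zero, while rint, lrint and nearbyint (defined next) honor half::round_style. A minimal sketch, assuming the header is reachable as "half.hpp" and the library's default round-to-nearest(-even) configuration:

    #include <iostream>
    #include "half.hpp" // assumed include path

    int main()
    {
        using half_float::half;
        half x(2.5f); // an exactly representable halfway case
        std::cout << "round(2.5)   = " << half_float::round(x) << '\n';            // 3: ties away from zero
        std::cout << "rint(2.5)    = " << half_float::rint(x) << '\n';             // 2 under ties-to-even
        std::cout << "lround(-2.5) = " << half_float::lround(half(-2.5f)) << '\n'; // -3
        return 0;
    }
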
+/// **See also:** Documentation for +/// [std::rint](https://en.cppreference.com/w/cpp/numeric/math/rint). +/// \param arg half expression to round +/// \return nearest integer using default rounding mode +/// \exception FE_INVALID for signaling NaN +/// \exception FE_INEXACT if value had to be rounded +inline half rint(half arg) +{ + return half(detail::binary, detail::integral(arg.data_)); +} + +/// Nearest integer using half's internal rounding mode. +/// **See also:** Documentation for +/// [std::lrint](https://en.cppreference.com/w/cpp/numeric/math/rint). +/// \param arg half expression to round +/// \return nearest integer using default rounding mode +/// \exception FE_INVALID if value is not representable as `long` +/// \exception FE_INEXACT if value had to be rounded +inline long lrint(half arg) +{ + return detail::half2int(arg.data_); +} + +/// Nearest integer using half's internal rounding mode. +/// **See also:** Documentation for +/// [std::nearbyint](https://en.cppreference.com/w/cpp/numeric/math/nearbyint). +/// \param arg half expression to round +/// \return nearest integer using default rounding mode +/// \exception FE_INVALID for signaling NaN +inline half nearbyint(half arg) +{ + return half(detail::binary, detail::integral(arg.data_)); +} +#if HALF_ENABLE_CPP11_LONG_LONG +/// Nearest integer. +/// **See also:** Documentation for +/// [std::llround](https://en.cppreference.com/w/cpp/numeric/math/round). +/// \param arg half to round +/// \return nearest integer, rounded away from zero in half-way cases +/// \exception FE_INVALID if value is not representable as `long long` +inline long long llround(half arg) +{ + return detail::half2int(arg.data_); +} + +/// Nearest integer using half's internal rounding mode. +/// **See also:** Documentation for +/// [std::llrint](https://en.cppreference.com/w/cpp/numeric/math/rint). +/// \param arg half expression to round +/// \return nearest integer using default rounding mode +/// \exception FE_INVALID if value is not representable as `long long` +/// \exception FE_INEXACT if value had to be rounded +inline long long llrint(half arg) +{ + return detail::half2int(arg.data_); +} +#endif + +/// \} +/// \anchor float +/// \name Floating point manipulation +/// \{ + +/// Decompress floating-point number. +/// **See also:** Documentation for +/// [std::frexp](https://en.cppreference.com/w/cpp/numeric/math/frexp). +/// \param arg number to decompress +/// \param exp address to store exponent at +/// \return significant in range [0.5, 1) +/// \exception FE_INVALID for signaling NaN +inline half frexp(half arg, int* exp) +{ + *exp = 0; + unsigned int abs = arg.data_ & 0x7FFF; + if(abs >= 0x7C00 || !abs) + return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + for(; abs < 0x400; abs <<= 1, --*exp) + ; + *exp += (abs >> 10) - 14; + return half(detail::binary, (arg.data_ & 0x8000) | 0x3800 | (abs & 0x3FF)); +} + +/// Multiply by power of two. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::scalbln](https://en.cppreference.com/w/cpp/numeric/math/scalbn). 
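frexp above and ldexp/scalbn below are exact inverses for finite values. The sketch below (again assuming the "half.hpp" include path) decomposes a half into significand and exponent and reassembles it:

    #include <cstdio>
    #include "half.hpp" // assumed include path

    int main()
    {
        using half_float::half;
        half x(6.5f);
        int e = 0;
        half m = half_float::frexp(x, &e);      // m in [0.5, 1), x == m * 2^e
        std::printf("6.5 = %f * 2^%d\n", static_cast<float>(m), e); // 0.8125 * 2^3
        half y = half_float::ldexp(m, e);       // exact round trip
        std::printf("round trip: %f\n", static_cast<float>(y));     // 6.5
        return 0;
    }
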
+/// \param arg number to modify +/// \param exp power of two to multiply with +/// \return \a arg multplied by 2 raised to \a exp +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half scalbln(half arg, long exp) +{ + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if(abs >= 0x7C00 || !abs) + return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + for(; abs < 0x400; abs <<= 1, --exp) + ; + exp += abs >> 10; + if(exp > 30) + return half(detail::binary, detail::overflow(sign)); + else if(exp < -10) + return half(detail::binary, detail::underflow(sign)); + else if(exp > 0) + return half(detail::binary, sign | (exp << 10) | (abs & 0x3FF)); + unsigned int m = (abs & 0x3FF) | 0x400; + return half(detail::binary, + detail::rounded( + sign | (m >> (1 - exp)), (m >> -exp) & 1, (m & ((1 << -exp) - 1)) != 0)); +} + +/// Multiply by power of two. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::scalbn](https://en.cppreference.com/w/cpp/numeric/math/scalbn). +/// \param arg number to modify +/// \param exp power of two to multiply with +/// \return \a arg multplied by 2 raised to \a exp +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half scalbn(half arg, int exp) { return scalbln(arg, exp); } + +/// Multiply by power of two. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::ldexp](https://en.cppreference.com/w/cpp/numeric/math/ldexp). +/// \param arg number to modify +/// \param exp power of two to multiply with +/// \return \a arg multplied by 2 raised to \a exp +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half ldexp(half arg, int exp) { return scalbln(arg, exp); } + +/// Extract integer and fractional parts. +/// **See also:** Documentation for +/// [std::modf](https://en.cppreference.com/w/cpp/numeric/math/modf). +/// \param arg number to decompress +/// \param iptr address to store integer part at +/// \return fractional part +/// \exception FE_INVALID for signaling NaN +inline half modf(half arg, half* iptr) +{ + unsigned int abs = arg.data_ & 0x7FFF; + if(abs > 0x7C00) + { + arg = half(detail::binary, detail::signal(arg.data_)); + return *iptr = arg, arg; + } + if(abs >= 0x6400) + return *iptr = arg, half(detail::binary, arg.data_ & 0x8000); + if(abs < 0x3C00) + return iptr->data_ = arg.data_ & 0x8000, arg; + unsigned int exp = abs >> 10, mask = (1 << (25 - exp)) - 1, m = arg.data_ & mask; + iptr->data_ = arg.data_ & ~mask; + if(!m) + return half(detail::binary, arg.data_ & 0x8000); + for(; m < 0x400; m <<= 1, --exp) + ; + return half(detail::binary, (arg.data_ & 0x8000) | (exp << 10) | (m & 0x3FF)); +} + +/// Extract exponent. +/// **See also:** Documentation for +/// [std::ilogb](https://en.cppreference.com/w/cpp/numeric/math/ilogb). +/// \param arg number to query +/// \return floating-point exponent +/// \retval FP_ILOGB0 for zero +/// \retval FP_ILOGBNAN for NaN +/// \retval INT_MAX for infinity +/// \exception FE_INVALID for 0 or infinite values +inline int ilogb(half arg) +{ + int abs = arg.data_ & 0x7FFF, exp; + if(!abs || abs >= 0x7C00) + { + detail::raise(FE_INVALID); + return !abs ? FP_ILOGB0 : (abs == 0x7C00) ? 
INT_MAX : FP_ILOGBNAN; + } + for(exp = (abs >> 10) - 15; abs < 0x200; abs <<= 1, --exp) + ; + return exp; +} + +/// Extract exponent. +/// **See also:** Documentation for +/// [std::logb](https://en.cppreference.com/w/cpp/numeric/math/logb). +/// \param arg number to query +/// \return floating-point exponent +/// \exception FE_INVALID for signaling NaN +/// \exception FE_DIVBYZERO for 0 +inline half logb(half arg) +{ + int abs = arg.data_ & 0x7FFF, exp; + if(!abs) + return half(detail::binary, detail::pole(0x8000)); + if(abs >= 0x7C00) + return half(detail::binary, (abs == 0x7C00) ? 0x7C00 : detail::signal(arg.data_)); + for(exp = (abs >> 10) - 15; abs < 0x200; abs <<= 1, --exp) + ; + unsigned int value = static_cast(exp < 0) << 15; + if(exp) + { + unsigned int m = std::abs(exp) << 6; + for(exp = 18; m < 0x400; m <<= 1, --exp) + ; + value |= (exp << 10) + m; + } + return half(detail::binary, value); +} + +/// Next representable value. +/// **See also:** Documentation for +/// [std::nextafter](https://en.cppreference.com/w/cpp/numeric/math/nextafter). +/// \param from value to compute next representable value for +/// \param to direction towards which to compute next value +/// \return next representable value after \a from in direction towards \a to +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW for infinite result from finite argument +/// \exception FE_UNDERFLOW for subnormal result +inline half nextafter(half from, half to) +{ + int fabs = from.data_ & 0x7FFF, tabs = to.data_ & 0x7FFF; + if(fabs > 0x7C00 || tabs > 0x7C00) + return half(detail::binary, detail::signal(from.data_, to.data_)); + if(from.data_ == to.data_ || !(fabs | tabs)) + return to; + if(!fabs) + { + detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT); + return half(detail::binary, (to.data_ & 0x8000) + 1); + } + unsigned int out = + from.data_ + + (((from.data_ >> 15) ^ + static_cast((from.data_ ^ (0x8000 | (0x8000 - (from.data_ >> 15)))) < + (to.data_ ^ (0x8000 | (0x8000 - (to.data_ >> 15)))))) + << 1) - + 1; + detail::raise(FE_OVERFLOW, fabs < 0x7C00 && (out & 0x7C00) == 0x7C00); + detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT && (out & 0x7C00) < 0x400); + return half(detail::binary, out); +} + +/// Next representable value. +/// **See also:** Documentation for +/// [std::nexttoward](https://en.cppreference.com/w/cpp/numeric/math/nexttoward). +/// \param from value to compute next representable value for +/// \param to direction towards which to compute next value +/// \return next representable value after \a from in direction towards \a to +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW for infinite result from finite argument +/// \exception FE_UNDERFLOW for subnormal result +inline half nexttoward(half from, long double to) +{ + int fabs = from.data_ & 0x7FFF; + if(fabs > 0x7C00) + return half(detail::binary, detail::signal(from.data_)); + long double lfrom = static_cast(from); + if(detail::builtin_isnan(to) || lfrom == to) + return half(static_cast(to)); + if(!fabs) + { + detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT); + return half(detail::binary, (static_cast(detail::builtin_signbit(to)) << 15) + 1); + } + unsigned int out = + from.data_ + (((from.data_ >> 15) ^ static_cast(lfrom < to)) << 1) - 1; + detail::raise(FE_OVERFLOW, (out & 0x7FFF) == 0x7C00); + detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT && (out & 0x7FFF) < 0x400); + return half(detail::binary, out); +} + +/// Take sign. 
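nextafter steps by exactly one representable half value, which makes it handy for probing the 10-bit format's ULP spacing; a small sketch under the same assumed "half.hpp" include path:

    #include <cstdio>
    #include "half.hpp" // assumed include path

    int main()
    {
        using half_float::half;
        half one(1.0f);
        half up = half_float::nextafter(one, half(2.0f)); // next value above 1
        // With a 10-bit significand the spacing just above 1 is 2^-10.
        std::printf("nextafter(1,2) - 1 = %g\n", static_cast<float>(up) - 1.0f); // ~0.000976562
        std::printf("ilogb(0.25) = %d\n", half_float::ilogb(half(0.25f)));       // -2
        return 0;
    }
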
+/// **See also:** Documentation for +/// [std::copysign](https://en.cppreference.com/w/cpp/numeric/math/copysign). +/// \param x value to change sign for +/// \param y value to take sign from +/// \return value equal to \a x in magnitude and to \a y in sign +inline HALF_CONSTEXPR half copysign(half x, half y) +{ + return half(detail::binary, x.data_ ^ ((x.data_ ^ y.data_) & 0x8000)); +} + +/// \} +/// \anchor classification +/// \name Floating point classification +/// \{ + +/// Classify floating-point value. +/// **See also:** Documentation for +/// [std::fpclassify](https://en.cppreference.com/w/cpp/numeric/math/fpclassify). +/// \param arg number to classify +/// \retval FP_ZERO for positive and negative zero +/// \retval FP_SUBNORMAL for subnormal numbers +/// \retval FP_INFINITY for positive and negative infinity +/// \retval FP_NAN for NaNs +/// \retval FP_NORMAL for all other (normal) values +inline HALF_CONSTEXPR int fpclassify(half arg) +{ + return !(arg.data_ & 0x7FFF) + ? FP_ZERO + : ((arg.data_ & 0x7FFF) < 0x400) + ? FP_SUBNORMAL + : ((arg.data_ & 0x7FFF) < 0x7C00) + ? FP_NORMAL + : ((arg.data_ & 0x7FFF) == 0x7C00) ? FP_INFINITE : FP_NAN; +} + +/// Check if finite number. +/// **See also:** Documentation for +/// [std::isfinite](https://en.cppreference.com/w/cpp/numeric/math/isfinite). +/// \param arg number to check +/// \retval true if neither infinity nor NaN +/// \retval false else +inline HALF_CONSTEXPR bool isfinite(half arg) { return (arg.data_ & 0x7C00) != 0x7C00; } + +/// Check for infinity. +/// **See also:** Documentation for +/// [std::isinf](https://en.cppreference.com/w/cpp/numeric/math/isinf). +/// \param arg number to check +/// \retval true for positive or negative infinity +/// \retval false else +inline HALF_CONSTEXPR bool isinf(half arg) { return (arg.data_ & 0x7FFF) == 0x7C00; } + +/// Check for NaN. +/// **See also:** Documentation for +/// [std::isnan](https://en.cppreference.com/w/cpp/numeric/math/isnan). +/// \param arg number to check +/// \retval true for NaNs +/// \retval false else +inline HALF_CONSTEXPR bool isnan(half arg) { return (arg.data_ & 0x7FFF) > 0x7C00; } + +/// Check if normal number. +/// **See also:** Documentation for +/// [std::isnormal](https://en.cppreference.com/w/cpp/numeric/math/isnormal). +/// \param arg number to check +/// \retval true if normal number +/// \retval false if either subnormal, zero, infinity or NaN +inline HALF_CONSTEXPR bool isnormal(half arg) +{ + return ((arg.data_ & 0x7C00) != 0) & ((arg.data_ & 0x7C00) != 0x7C00); +} + +/// Check sign. +/// **See also:** Documentation for +/// [std::signbit](https://en.cppreference.com/w/cpp/numeric/math/signbit). +/// \param arg number to check +/// \retval true for negative number +/// \retval false for positive number +inline HALF_CONSTEXPR bool signbit(half arg) { return (arg.data_ & 0x8000) != 0; } + +/// \} +/// \anchor compfunc +/// \name Comparison +/// \{ + +/// Quiet comparison for greater than. +/// **See also:** Documentation for +/// [std::isgreater](https://en.cppreference.com/w/cpp/numeric/math/isgreater). +/// \param x first operand +/// \param y second operand +/// \retval true if \a x greater than \a y +/// \retval false else +inline HALF_CONSTEXPR bool isgreater(half x, half y) +{ + return ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) > + ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)) && + !isnan(x) && !isnan(y); +} + +/// Quiet comparison for greater equal. 
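The classification helpers above mirror their <cmath> counterparts but inspect the half bit pattern directly. A short sketch (same assumed "half.hpp" include) over a zero, a normal value, a value that is subnormal in half precision, and one that overflows to infinity:

    #include <cstdio>
    #include "half.hpp" // assumed include path

    int main()
    {
        using half_float::half;
        half values[] = {half(0.0f), half(1.5f),
                         half(1e-6f),     // subnormal in half precision
                         half(65536.0f)}; // above the half range, becomes +inf
        for(half v : values)
            std::printf("%-12g finite=%d normal=%d inf=%d nan=%d\n",
                        static_cast<float>(v),
                        half_float::isfinite(v), half_float::isnormal(v),
                        half_float::isinf(v), half_float::isnan(v));
        return 0;
    }
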
+/// **See also:** Documentation for +/// [std::isgreaterequal](https://en.cppreference.com/w/cpp/numeric/math/isgreaterequal). +/// \param x first operand +/// \param y second operand +/// \retval true if \a x greater equal \a y +/// \retval false else +inline HALF_CONSTEXPR bool isgreaterequal(half x, half y) +{ + return ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) >= + ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)) && + !isnan(x) && !isnan(y); +} + +/// Quiet comparison for less than. +/// **See also:** Documentation for +/// [std::isless](https://en.cppreference.com/w/cpp/numeric/math/isless). +/// \param x first operand +/// \param y second operand +/// \retval true if \a x less than \a y +/// \retval false else +inline HALF_CONSTEXPR bool isless(half x, half y) +{ + return ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) < + ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)) && + !isnan(x) && !isnan(y); +} + +/// Quiet comparison for less equal. +/// **See also:** Documentation for +/// [std::islessequal](https://en.cppreference.com/w/cpp/numeric/math/islessequal). +/// \param x first operand +/// \param y second operand +/// \retval true if \a x less equal \a y +/// \retval false else +inline HALF_CONSTEXPR bool islessequal(half x, half y) +{ + return ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) <= + ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)) && + !isnan(x) && !isnan(y); +} + +/// Quiet comarison for less or greater. +/// **See also:** Documentation for +/// [std::islessgreater](https://en.cppreference.com/w/cpp/numeric/math/islessgreater). +/// \param x first operand +/// \param y second operand +/// \retval true if either less or greater +/// \retval false else +inline HALF_CONSTEXPR bool islessgreater(half x, half y) +{ + return x.data_ != y.data_ && ((x.data_ | y.data_) & 0x7FFF) && !isnan(x) && !isnan(y); +} + +/// Quiet check if unordered. +/// **See also:** Documentation for +/// [std::isunordered](https://en.cppreference.com/w/cpp/numeric/math/isunordered). +/// \param x first operand +/// \param y second operand +/// \retval true if unordered (one or two NaN operands) +/// \retval false else +inline HALF_CONSTEXPR bool isunordered(half x, half y) { return isnan(x) || isnan(y); } + +/// \} +/// \anchor casting +/// \name Casting +/// \{ + +/// Cast to or from half-precision floating-point number. +/// This casts between [half](\ref half_float::half) and any built-in arithmetic type. The values +/// are converted +/// directly using the default rounding mode, without any roundtrip over `float` that a +/// `static_cast` would otherwise do. +/// +/// Using this cast with neither of the two types being a [half](\ref half_float::half) or with any +/// of the two types +/// not being a built-in arithmetic type (apart from [half](\ref half_float::half), of course) +/// results in a compiler +/// error and casting between [half](\ref half_float::half)s returns the argument unmodified. 
+/// \tparam T destination type (half or built-in arithmetic type) +/// \tparam U source type (half or built-in arithmetic type) +/// \param arg value to cast +/// \return \a arg converted to destination type +/// \exception FE_INVALID if \a T is integer type and result is not representable as \a T +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +template +T half_cast(U arg) +{ + return detail::half_caster::cast(arg); +} + +/// Cast to or from half-precision floating-point number. +/// This casts between [half](\ref half_float::half) and any built-in arithmetic type. The values +/// are converted +/// directly using the specified rounding mode, without any roundtrip over `float` that a +/// `static_cast` would otherwise do. +/// +/// Using this cast with neither of the two types being a [half](\ref half_float::half) or with any +/// of the two types +/// not being a built-in arithmetic type (apart from [half](\ref half_float::half), of course) +/// results in a compiler +/// error and casting between [half](\ref half_float::half)s returns the argument unmodified. +/// \tparam T destination type (half or built-in arithmetic type) +/// \tparam R rounding mode to use. +/// \tparam U source type (half or built-in arithmetic type) +/// \param arg value to cast +/// \return \a arg converted to destination type +/// \exception FE_INVALID if \a T is integer type and result is not representable as \a T +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +template +T half_cast(U arg) +{ + return detail::half_caster::cast(arg); +} +/// \} + +/// \} +/// \anchor errors +/// \name Error handling +/// \{ + +/// Clear exception flags. +/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is +/// disabled, +/// but in that case manual flag management is the only way to raise flags. +/// +/// **See also:** Documentation for +/// [std::feclearexcept](https://en.cppreference.com/w/cpp/numeric/fenv/feclearexcept). +/// \param excepts OR of exceptions to clear +/// \retval 0 all selected flags cleared successfully +inline int feclearexcept(int excepts) +{ + detail::errflags() &= ~excepts; + return 0; +} + +/// Test exception flags. +/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is +/// disabled, +/// but in that case manual flag management is the only way to raise flags. +/// +/// **See also:** Documentation for +/// [std::fetestexcept](https://en.cppreference.com/w/cpp/numeric/fenv/fetestexcept). +/// \param excepts OR of exceptions to test +/// \return OR of selected exceptions if raised +inline int fetestexcept(int excepts) { return detail::errflags() & excepts; } + +/// Raise exception flags. +/// This raises the specified floating point exceptions and also invokes any additional automatic +/// exception handling as +/// configured with the [HALF_ERRHANDLIG_...](\ref HALF_ERRHANDLING_ERRNO) preprocessor symbols. +/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is +/// disabled, +/// but in that case manual flag management is the only way to raise flags. +/// +/// **See also:** Documentation for +/// [std::feraiseexcept](https://en.cppreference.com/w/cpp/numeric/fenv/feraiseexcept). 
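half_cast converts to and from built-in types directly, without the float round trip a static_cast would take, and integer targets report unrepresentable inputs through the FE_ flags handled by the functions in this section. A sketch of both, assuming the "half.hpp" include path; note that the flag test only fires when the header is compiled with HALF_ERRHANDLING_FLAGS enabled:

    #include <cfenv>
    #include <cstdio>
    #include <limits>
    #include "half.hpp" // assumed include path

    int main()
    {
        using half_float::half;
        // double -> half directly, with an explicitly requested rounding mode
        half a = half_float::half_cast<half, std::round_toward_zero>(1.0 / 3.0);
        int  b = half_float::half_cast<int>(half(3.75f)); // rounds with half::round_style

        half inf_val(100000.0f); // overflows the half range, becomes +inf
        half_float::feclearexcept(FE_ALL_EXCEPT);
        (void)half_float::half_cast<int>(inf_val); // +inf is not representable as int
        if(half_float::fetestexcept(FE_INVALID))   // needs HALF_ERRHANDLING_FLAGS
            std::printf("FE_INVALID raised\n");

        std::printf("a = %f, b = %d\n", static_cast<float>(a), b);
        return 0;
    }
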
+/// \param excepts OR of exceptions to raise +/// \retval 0 all selected exceptions raised successfully +inline int feraiseexcept(int excepts) +{ + detail::errflags() |= excepts; + detail::raise(excepts); + return 0; +} + +/// Save exception flags. +/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is +/// disabled, +/// but in that case manual flag management is the only way to raise flags. +/// +/// **See also:** Documentation for +/// [std::fegetexceptflag](https://en.cppreference.com/w/cpp/numeric/fenv/feexceptflag). +/// \param flagp adress to store flag state at +/// \param excepts OR of flags to save +/// \retval 0 for success +inline int fegetexceptflag(int* flagp, int excepts) +{ + *flagp = detail::errflags() & excepts; + return 0; +} + +/// Restore exception flags. +/// This only copies the specified exception state (including unset flags) without incurring any +/// additional exception handling. +/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is +/// disabled, +/// but in that case manual flag management is the only way to raise flags. +/// +/// **See also:** Documentation for +/// [std::fesetexceptflag](https://en.cppreference.com/w/cpp/numeric/fenv/feexceptflag). +/// \param flagp adress to take flag state from +/// \param excepts OR of flags to restore +/// \retval 0 for success +inline int fesetexceptflag(const int* flagp, int excepts) +{ + detail::errflags() = (detail::errflags() | (*flagp & excepts)) & (*flagp | ~excepts); + return 0; +} + +/// Throw C++ exceptions based on set exception flags. +/// This function manually throws a corresponding C++ exception if one of the specified flags is +/// set, +/// no matter if automatic throwing (via [HALF_ERRHANDLING_THROW_...](\ref +/// HALF_ERRHANDLING_THROW_INVALID)) is enabled or not. +/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is +/// disabled, +/// but in that case manual flag management is the only way to raise flags. +/// \param excepts OR of exceptions to test +/// \param msg error message to use for exception description +/// \throw std::domain_error if `FE_INVALID` or `FE_DIVBYZERO` is selected and set +/// \throw std::overflow_error if `FE_OVERFLOW` is selected and set +/// \throw std::underflow_error if `FE_UNDERFLOW` is selected and set +/// \throw std::range_error if `FE_INEXACT` is selected and set +inline void fethrowexcept(int excepts, const char* msg = "") +{ + excepts &= detail::errflags(); + if(excepts & (FE_INVALID | FE_DIVBYZERO)) + throw std::domain_error(msg); + if(excepts & FE_OVERFLOW) + throw std::overflow_error(msg); + if(excepts & FE_UNDERFLOW) + throw std::underflow_error(msg); + if(excepts & FE_INEXACT) + throw std::range_error(msg); +} +/// \} +} // namespace half_float + +#undef HALF_UNUSED_NOERR +#undef HALF_CONSTEXPR +#undef HALF_CONSTEXPR_CONST +#undef HALF_CONSTEXPR_NOERR +#undef HALF_NOEXCEPT +#undef HALF_NOTHROW +#undef HALF_THREAD_LOCAL +#undef HALF_TWOS_COMPLEMENT_INT +#ifdef HALF_POP_WARNINGS +#pragma warning(pop) +#undef HALF_POP_WARNINGS +#endif + +#endif diff --git a/external/rocm/include/bfloat16_dev.hpp b/external/rocm/include/bfloat16_dev.hpp new file mode 100644 index 0000000000..52d00346cf --- /dev/null +++ b/external/rocm/include/bfloat16_dev.hpp @@ -0,0 +1,125 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2019 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef BFLOAT16_DEVICE_HPP +#define BFLOAT16_DEVICE_HPP + +#ifdef __cplusplus +extern "C" { +#endif + +#ifdef __HIP_PLATFORM_HCC__ +#define EXECUTION_SPECIFIER __device__ +#else +#define EXECUTION_SPECIFIER +#endif // MIOPEN_BACKEND_HIP + +typedef union +{ + uint u32; + ushort2 ushortx2; + +// Composable kernels are written in HIP language. The language doesnt support +// ushort2.hi or ushort2.low. +#ifdef __HIP_PLATFORM_HCC__ + ushort ushortvec[2]; +#endif // MIOPEN_BACKEND_HIP + float f32; +} cvt_bf16_fp32_t; + +EXECUTION_SPECIFIER float bfloat16_to_float(ushort src_val) +{ + cvt_bf16_fp32_t target_val; + +#ifdef __HIP_PLATFORM_HCC__ + target_val.ushortx2 = make_ushort2(0, src_val); +#else + target_val.ushortx2 = (ushort2)(0, src_val); +#endif + + return target_val.f32; +} + +EXECUTION_SPECIFIER ushort float_to_bfloat16(float src_val) +{ + cvt_bf16_fp32_t target_val; + target_val.f32 = src_val; + // BF16 round and NaN preservation code matches + // https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/include/rocblas_bfloat16.h + if((~target_val.u32 & 0x7f800000) == 0) // Inf or NaN + { + // When all of the exponent bits are 1, the value is Inf or NaN. + // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero + // mantissa bit. Quiet NaN is indicated by the most significant mantissa + // bit being 1. Signaling NaN is indicated by the most significant + // mantissa bit being 0 but some other bit(s) being 1. If any of the + // lower 16 bits of the mantissa are 1, we set the least significant bit + // of the bfloat16 mantissa, in order to preserve signaling NaN in case + // the bloat16's mantissa bits are all 0. + if((target_val.u32 & 0xffff) != 0) + { + target_val.u32 |= 0x10000; // Preserve signaling NaN + } + } + else + { +#ifdef MIOPEN_USE_RNE_BFLOAT16 +// When the exponent bits are not all 1s, then the value is zero, normal, +// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus +// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). +// This causes the bfloat16's mantissa to be incremented by 1 if the 16 +// least significant bits of the float mantissa are greater than 0x8000, +// or if they are equal to 0x8000 and the least significant bit of the +// bfloat16 mantissa is 1 (odd). 
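+// Worked example of the rounding step below (illustrative values added for
+// clarity, not from the original source): for f32 bits 0x3F808000 (1.00390625)
+// the low 16 bits are exactly 0x8000 and the bf16 lsb (bit 16) is 0, so adding
+// 0x7FFF gives 0x3F80FFFF and the upper halfword stays 0x3F80, i.e. the tie
+// rounds down to the even value 1.0. For 0x3F818000 the bf16 lsb is 1, so
+// adding 0x8000 gives 0x3F820000 and the tie rounds up to the even value 0x3F82.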
This causes it to be rounded to even when +// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already +// has the value 0x7f, then incrementing it causes it to become 0x00 and +// the exponent is incremented by one, which is the next higher FP value +// to the unrounded bfloat16 value. When the bfloat16 value is subnormal +// with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up +// to a normal value with an exponent of 0x01 and a mantissa of 0x00. +// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, +// incrementing it causes it to become an exponent of 0xFF and a mantissa +// of 0x00, which is Inf, the next higher value to the unrounded value. +#ifdef __HIP_PLATFORM_HCC__ + target_val.u32 += (0x7fff + (target_val.ushortvec[1] & 1)); +#else + target_val.u32 += + (0x7fff + (target_val.ushortx2.hi & 1)); // Round to nearest, round to even +#endif // MIOPEN_BACKEND_HIP +#endif // MIOPEN_USE_RNE_BFLOAT16 + } + +#ifdef __HIP_PLATFORM_HCC__ + return target_val.ushortvec[1]; +#else + return target_val.ushortx2.hi; +#endif // MIOPEN_BACKEND_HIP +} + +#ifdef __cplusplus +} +#endif + +#endif // BFLOAT16_DEVICE_HPP diff --git a/host/CMakeLists.txt b/host/CMakeLists.txt new file mode 100644 index 0000000000..c9779398a6 --- /dev/null +++ b/host/CMakeLists.txt @@ -0,0 +1,4 @@ +add_subdirectory(host_tensor) +add_subdirectory(online_compilation) +add_subdirectory(driver_offline) +add_subdirectory(driver_online) diff --git a/host/driver_offline/CMakeLists.txt b/host/driver_offline/CMakeLists.txt new file mode 100644 index 0000000000..85bd31fbca --- /dev/null +++ b/host/driver_offline/CMakeLists.txt @@ -0,0 +1,21 @@ +include_directories(BEFORE + include + ${PROJECT_SOURCE_DIR}/host/host_tensor/include + ${PROJECT_SOURCE_DIR}/composable_kernel/include + ${PROJECT_SOURCE_DIR}/composable_kernel/include/utility + ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description + ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation + ${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform + ${PROJECT_SOURCE_DIR}/composable_kernel/include/driver + ${PROJECT_SOURCE_DIR}/external/rocm/include + ${PROJECT_SOURCE_DIR}/external/half/include +) + +set(CONV_FWD_DRIVER_OFFLINE_SOURCE conv_fwd_driver_offline.cpp) +set(CONV_BWD_DRIVER_OFFLINE_SOURCE conv_bwd_driver_offline.cpp) + +add_executable(conv_fwd_driver_offline ${CONV_FWD_DRIVER_OFFLINE_SOURCE}) +add_executable(conv_bwd_driver_offline ${CONV_BWD_DRIVER_OFFLINE_SOURCE}) + +target_link_libraries(conv_fwd_driver_offline PRIVATE host_tensor) +target_link_libraries(conv_bwd_driver_offline PRIVATE host_tensor) diff --git a/host/driver_offline/conv_bwd_driver_offline.cpp b/host/driver_offline/conv_bwd_driver_offline.cpp new file mode 100644 index 0000000000..61c3fc385d --- /dev/null +++ b/host/driver_offline/conv_bwd_driver_offline.cpp @@ -0,0 +1,357 @@ +#include +#include +#include +#include +#include +#include +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "conv_common.hpp" +#include "host_conv_bwd_data.hpp" +#include "device_tensor.hpp" +#include "device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp" +#include "device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp" + +#define USE_DYNAMIC_MODE 1 +#define USE_CONV_BWD_V4R1_XDL_NHWC 1 +#define USE_CONV_BWD_V4R1R2_XDL_NHWC 1 + +enum ConvBackwardDataAlgo +{ + V4R1XDLNHWC, + 
V4R1R2XDLNHWC, +}; + +int main(int argc, char* argv[]) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + +#if USE_DYNAMIC_MODE + // dynamic mode + if(argc != 22) + { + printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n"); + exit(1); + } + + const ConvTensorLayout layout = static_cast(atoi(argv[1])); + const ConvBackwardDataAlgo algo = static_cast(atoi(argv[2])); + const bool do_verification = atoi(argv[3]); + const int init_method = atoi(argv[4]); + const bool do_log = atoi(argv[5]); + const int nrepeat = atoi(argv[6]); + + const index_t N = atoi(argv[7]); + const index_t K = atoi(argv[8]); + const index_t C = atoi(argv[9]); + const index_t Y = atoi(argv[10]); + const index_t X = atoi(argv[11]); + const index_t Hi = atoi(argv[12]); + const index_t Wi = atoi(argv[13]); + + const index_t conv_stride_h = atoi(argv[14]); + const index_t conv_stride_w = atoi(argv[15]); + const index_t conv_dilation_h = atoi(argv[16]); + const index_t conv_dilation_w = atoi(argv[17]); + const index_t in_left_pad_h = atoi(argv[18]); + const index_t in_left_pad_w = atoi(argv[19]); + const index_t in_right_pad_h = atoi(argv[20]); + const index_t in_right_pad_w = atoi(argv[21]); + + const index_t YEff = (Y - 1) * conv_dilation_h + 1; + const index_t XEff = (X - 1) * conv_dilation_w + 1; + + const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; +#else + // static mode + if(argc < 7) + { + printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + exit(1); + } + + const ConvTensorLayout layout = static_cast(atoi(argv[1])); + const ConvBackwardDataAlgo algo = static_cast(atoi(argv[2])); + const bool do_verification = atoi(argv[3]); + const int init_method = atoi(argv[4]); + const bool do_log = atoi(argv[5]); + const int nrepeat = atoi(argv[6]); + + constexpr index_t N = 128; + constexpr index_t C = 192; + constexpr index_t Hi = 71; + constexpr index_t Wi = 71; + constexpr index_t K = 256; + constexpr index_t Y = 3; + constexpr index_t X = 3; + + const index_t conv_stride_h = 2; + const index_t conv_stride_w = 2; + const index_t conv_dilation_h = 1; + const index_t conv_dilation_w = 1; + const index_t in_left_pad_h = 1; + const index_t in_left_pad_w = 1; + const index_t in_right_pad_h = 1; + const index_t in_right_pad_w = 1; + + const index_t YEff = (Y - 1) * conv_dilation_h + 1; + const index_t XEff = (X - 1) * conv_dilation_w + 1; + + const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; +#endif + +#if 0 + constexpr index_t in_vector_size = 1; + using in_data_t = float; + using acc_data_t = float; + using out_data_t = float; +#elif 1 + constexpr index_t in_vector_size = 1; + using in_data_t = half_t; + using acc_data_t = float; + using out_data_t = half_t; +#endif + + std::vector in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4); + + switch(layout) + { + case ConvTensorLayout::NCHW: + // NCHW + in_lengths_host[0] = static_cast(N); + in_lengths_host[1] = static_cast(C); + in_lengths_host[2] = static_cast(Hi); 
+ in_lengths_host[3] = static_cast(Wi); + wei_lengths_host[0] = static_cast(K); + wei_lengths_host[1] = static_cast(C); + wei_lengths_host[2] = static_cast(Y); + wei_lengths_host[3] = static_cast(X); + out_lengths_host[0] = static_cast(N); + out_lengths_host[1] = static_cast(K); + out_lengths_host[2] = static_cast(Ho); + out_lengths_host[3] = static_cast(Wo); + break; + case ConvTensorLayout::NHWC: + // NHWC + in_lengths_host[0] = static_cast(N); + in_lengths_host[1] = static_cast(Hi); + in_lengths_host[2] = static_cast(Wi); + in_lengths_host[3] = static_cast(C); + wei_lengths_host[0] = static_cast(K); + wei_lengths_host[1] = static_cast(Y); + wei_lengths_host[2] = static_cast(X); + wei_lengths_host[3] = static_cast(C); + out_lengths_host[0] = static_cast(N); + out_lengths_host[1] = static_cast(Ho); + out_lengths_host[2] = static_cast(Wo); + out_lengths_host[3] = static_cast(K); + break; + default: throw std::runtime_error("wrong! not implemented"); + } + + Tensor in_host(in_lengths_host); + Tensor in_device(in_lengths_host); + Tensor wei(wei_lengths_host); + Tensor out(out_lengths_host); + + std::cout << "layout: " << layout << std::endl; + ostream_HostTensorDescriptor(in_host.mDesc, std::cout << "in: "); + ostream_HostTensorDescriptor(wei.mDesc, std::cout << "wei: "); + ostream_HostTensorDescriptor(out.mDesc, std::cout << "out: "); + print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w)); + print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w)); + print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w)); + print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w)); + + std::size_t num_thread = std::thread::hardware_concurrency(); + + switch(init_method) + { + case 0: + // no initialization + break; + case 1: + out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 2: + out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + case 3: + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 4: + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + case 5: + out.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + break; + default: + out.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); + + auto gen_wei = [](auto... is) { + return GeneratorTensor_2{1, 5}(is...) 
* GeneratorTensor_Checkboard{}(is...); + }; + wei.GenerateTensorValue(gen_wei, num_thread); + } + + auto f_make_for_device_nchw = [&]() { +#if USE_DYNAMIC_MODE + const auto in_lengths_dev = make_tuple(N, C, Hi, Wi); + const auto wei_lengths_dev = make_tuple(K, C, Y, X); + const auto out_lengths_dev = make_tuple(N, K, Ho, Wo); + const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); + const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); + const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); + const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); +#else + const auto in_lengths_dev = + make_tuple(Number{}, Number{}, Number{}, Number{}); + const auto wei_lengths_dev = make_tuple(Number{}, Number{}, Number{}, Number{}); + const auto out_lengths_dev = + make_tuple(Number{}, Number{}, Number{}, Number{}); + const auto conv_strides_dev = make_tuple(Number{}, Number{}); + const auto conv_dilations_dev = + make_tuple(Number{}, Number{}); + const auto in_left_pads_dev = make_tuple(Number{}, Number{}); + const auto in_right_pads_dev = + make_tuple(Number{}, Number{}); +#endif + + return make_tuple(in_lengths_dev, + wei_lengths_dev, + out_lengths_dev, + conv_strides_dev, + conv_dilations_dev, + in_left_pads_dev, + in_right_pads_dev); + }; + + auto f_make_for_device_nhwc = [&]() { +#if USE_DYNAMIC_MODE + const auto in_lengths_dev = make_tuple(N, Hi, Wi, C); + const auto wei_lengths_dev = make_tuple(K, Y, X, C); + const auto out_lengths_dev = make_tuple(N, Ho, Wo, K); + const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); + const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); + const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); + const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); +#else + const auto in_lengths_dev = + make_tuple(Number{}, Number{}, Number{}, Number{}); + const auto wei_lengths_dev = make_tuple(Number{}, Number{}, Number{}, Number{}); + const auto out_lengths_dev = + make_tuple(Number{}, Number{}, Number{}, Number{}); + const auto conv_strides_dev = make_tuple(Number{}, Number{}); + const auto conv_dilations_dev = + make_tuple(Number{}, Number{}); + const auto in_left_pads_dev = make_tuple(Number{}, Number{}); + const auto in_right_pads_dev = + make_tuple(Number{}, Number{}); +#endif + + return make_tuple(in_lengths_dev, + wei_lengths_dev, + out_lengths_dev, + conv_strides_dev, + conv_dilations_dev, + in_left_pads_dev, + in_right_pads_dev); + }; + + const auto nhwc_desc = f_make_for_device_nhwc(); + +#if USE_CONV_BWD_V4R1_XDL_NHWC + if(algo == ConvBackwardDataAlgo::V4R1XDLNHWC) + { + if(layout != ConvTensorLayout::NHWC) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nhwc(); + + device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk< + in_data_t, + acc_data_t, + out_data_t>(tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in_device, + wei, + out, + nrepeat); + } +#endif + +#if USE_CONV_BWD_V4R1R2_XDL_NHWC + if(algo == ConvBackwardDataAlgo::V4R1R2XDLNHWC) + { + if(layout != ConvTensorLayout::NHWC) + { + throw std::runtime_error("wrong! 
layout"); + } + + const auto tmp = f_make_for_device_nhwc(); + + device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk< + in_data_t, + acc_data_t, + out_data_t>(tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in_device, + wei, + out, + nrepeat); + } +#endif + + if(do_verification) + { + host_direct_convolution_backward_data(in_host, + wei, + out, + make_tuple(conv_stride_h, conv_stride_w), + make_tuple(conv_dilation_h, conv_dilation_w), + make_tuple(in_left_pad_h, in_left_pad_w), + make_tuple(in_right_pad_h, in_right_pad_w), + layout); + + check_error(in_host, in_device); + + if(do_log) + { + LogRangeAsType(std::cout << "out : ", out.mData, ",") << std::endl; + LogRangeAsType(std::cout << "wei: ", wei.mData, ",") << std::endl; + LogRangeAsType(std::cout << "in_host : ", in_host.mData, ",") << std::endl; + LogRangeAsType(std::cout << "in_device: ", in_device.mData, ",") << std::endl; + } + } +} diff --git a/host/driver_offline/conv_fwd_driver_offline.cpp b/host/driver_offline/conv_fwd_driver_offline.cpp new file mode 100644 index 0000000000..ef2e16c4fa --- /dev/null +++ b/host/driver_offline/conv_fwd_driver_offline.cpp @@ -0,0 +1,480 @@ +#include +#include +#include +#include +#include +#include +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "conv_common.hpp" +#include "host_conv.hpp" +#include "device_tensor.hpp" +#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp" +#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp" +#include "device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp" +#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp" +#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" +#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" + +#define USE_DYNAMIC_MODE 1 +#define USE_CONV_FWD_V4R4_NCHW 1 +#define USE_CONV_FWD_V4R4R2_NHWC 1 +#define USE_CONV_FWD_V6R1_NCHW 1 +#define USE_CONV_FWD_V5R1_NCHW 0 +#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0 +#define USE_CONV_FWD_V4R4R4_XDL_NHWC 0 + +enum ConvForwardAlgo +{ + V4R4NCHW, // 0 + V4R4R2NHWC, // 1 + V6R1NCHW, // 2 + V5R1NCHW, // 3 + V4R4R2XDLNCHW, // 4 + V4R4R4XDLNHWC // 5 +}; + +int main(int argc, char* argv[]) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + +#if USE_DYNAMIC_MODE + // dynamic mode + if(argc != 22) + { + printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n"); + exit(1); + } + + const ConvTensorLayout layout = static_cast(atoi(argv[1])); + const ConvForwardAlgo algo = static_cast(atoi(argv[2])); + const bool do_verification = atoi(argv[3]); + const int init_method = atoi(argv[4]); + const bool do_log = atoi(argv[5]); + const int nrepeat = atoi(argv[6]); + + const index_t N = atoi(argv[7]); + const index_t K = atoi(argv[8]); + const index_t C = atoi(argv[9]); + const index_t Y = atoi(argv[10]); + const index_t X = atoi(argv[11]); + const index_t Hi = atoi(argv[12]); + const index_t Wi = atoi(argv[13]); + + const 
index_t conv_stride_h = atoi(argv[14]); + const index_t conv_stride_w = atoi(argv[15]); + const index_t conv_dilation_h = atoi(argv[16]); + const index_t conv_dilation_w = atoi(argv[17]); + const index_t in_left_pad_h = atoi(argv[18]); + const index_t in_left_pad_w = atoi(argv[19]); + const index_t in_right_pad_h = atoi(argv[20]); + const index_t in_right_pad_w = atoi(argv[21]); + + const index_t YEff = (Y - 1) * conv_dilation_h + 1; + const index_t XEff = (X - 1) * conv_dilation_w + 1; + + const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; +#else + // static mode + if(argc < 7) + { + printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + exit(1); + } + + const ConvTensorLayout layout = static_cast(atoi(argv[1])); + const ConvForwardAlgo algo = static_cast(atoi(argv[2])); + const bool do_verification = atoi(argv[3]); + const int init_method = atoi(argv[4]); + const bool do_log = atoi(argv[5]); + const int nrepeat = atoi(argv[6]); + + constexpr index_t N = 128; + constexpr index_t C = 192; + constexpr index_t Hi = 71; + constexpr index_t Wi = 71; + constexpr index_t K = 256; + constexpr index_t Y = 3; + constexpr index_t X = 3; + + const index_t conv_stride_h = 2; + const index_t conv_stride_w = 2; + const index_t conv_dilation_h = 1; + const index_t conv_dilation_w = 1; + const index_t in_left_pad_h = 1; + const index_t in_left_pad_w = 1; + const index_t in_right_pad_h = 1; + const index_t in_right_pad_w = 1; + + const index_t YEff = (Y - 1) * conv_dilation_h + 1; + const index_t XEff = (X - 1) * conv_dilation_w + 1; + + const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; +#endif + +#if 1 + using in_data_t = float; + using acc_data_t = float; + using out_data_t = float; +#elif 1 + using in_data_t = half_t; + using acc_data_t = float; + using out_data_t = half_t; +#elif 1 + using in_data_t = int8_t; + using acc_data_t = int32_t; + using out_data_t = int8_t; +#endif + + std::vector in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4); + + switch(layout) + { + case ConvTensorLayout::NCHW: + // NCHW + in_lengths_host[0] = static_cast(N); + in_lengths_host[1] = static_cast(C); + in_lengths_host[2] = static_cast(Hi); + in_lengths_host[3] = static_cast(Wi); + wei_lengths_host[0] = static_cast(K); + wei_lengths_host[1] = static_cast(C); + wei_lengths_host[2] = static_cast(Y); + wei_lengths_host[3] = static_cast(X); + out_lengths_host[0] = static_cast(N); + out_lengths_host[1] = static_cast(K); + out_lengths_host[2] = static_cast(Ho); + out_lengths_host[3] = static_cast(Wo); + break; + case ConvTensorLayout::NHWC: + // NHWC + in_lengths_host[0] = static_cast(N); + in_lengths_host[1] = static_cast(Hi); + in_lengths_host[2] = static_cast(Wi); + in_lengths_host[3] = static_cast(C); + wei_lengths_host[0] = static_cast(K); + wei_lengths_host[1] = static_cast(Y); + wei_lengths_host[2] = static_cast(X); + wei_lengths_host[3] = static_cast(C); + out_lengths_host[0] = static_cast(N); + out_lengths_host[1] = static_cast(Ho); + out_lengths_host[2] = static_cast(Wo); + out_lengths_host[3] = static_cast(K); + break; + default: throw std::runtime_error("wrong! 
not implemented"); + } + + Tensor in(in_lengths_host); + Tensor wei(wei_lengths_host); + Tensor out_host(out_lengths_host); + Tensor out_device(out_lengths_host); + + std::cout << "layout: " << layout << std::endl; + ostream_HostTensorDescriptor(in.mDesc, std::cout << "in: "); + ostream_HostTensorDescriptor(wei.mDesc, std::cout << "wei: "); + ostream_HostTensorDescriptor(out_host.mDesc, std::cout << "out: "); + print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w)); + print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w)); + print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w)); + print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w)); + + std::size_t num_thread = std::thread::hardware_concurrency(); + + switch(init_method) + { + case 0: + // no initialization + break; + case 1: + in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 2: + in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + case 3: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 4: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + case 5: + in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + break; + default: + in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); + + auto gen_wei = [](auto... is) { + return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...); + }; + wei.GenerateTensorValue(gen_wei, num_thread); + } + + auto f_make_for_device_nchw = [&]() { +#if USE_DYNAMIC_MODE + const auto in_lengths_dev = make_tuple(N, C, Hi, Wi); + const auto wei_lengths_dev = make_tuple(K, C, Y, X); + const auto out_lengths_dev = make_tuple(N, K, Ho, Wo); + const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); + const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); + const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); + const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); +#else + const auto in_lengths_dev = + make_tuple(Number{}, Number{}, Number{}, Number{}); + const auto wei_lengths_dev = make_tuple(Number{}, Number{}, Number{}, Number{}); + const auto out_lengths_dev = + make_tuple(Number{}, Number{}, Number{}, Number{}); + const auto conv_strides_dev = make_tuple(Number{}, Number{}); + const auto conv_dilations_dev = + make_tuple(Number{}, Number{}); + const auto in_left_pads_dev = make_tuple(Number{}, Number{}); + const auto in_right_pads_dev = + make_tuple(Number{}, Number{}); +#endif + + return make_tuple(in_lengths_dev, + wei_lengths_dev, + out_lengths_dev, + conv_strides_dev, + conv_dilations_dev, + in_left_pads_dev, + in_right_pads_dev); + }; + + auto f_make_for_device_nhwc = [&]() { +#if USE_DYNAMIC_MODE + const auto in_lengths_dev = make_tuple(N, Hi, Wi, C); + const auto wei_lengths_dev = make_tuple(K, Y, X, C); + const auto out_lengths_dev = make_tuple(N, Ho, Wo, K); + const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); + const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); + const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); + const auto in_right_pads_dev = 
make_tuple(in_right_pad_h, in_right_pad_w); +#else + const auto in_lengths_dev = + make_tuple(Number{}, Number{}, Number{}, Number{}); + const auto wei_lengths_dev = make_tuple(Number{}, Number{}, Number{}, Number{}); + const auto out_lengths_dev = + make_tuple(Number{}, Number{}, Number{}, Number{}); + const auto conv_strides_dev = make_tuple(Number{}, Number{}); + const auto conv_dilations_dev = + make_tuple(Number{}, Number{}); + const auto in_left_pads_dev = make_tuple(Number{}, Number{}); + const auto in_right_pads_dev = + make_tuple(Number{}, Number{}); +#endif + + return make_tuple(in_lengths_dev, + wei_lengths_dev, + out_lengths_dev, + conv_strides_dev, + conv_dilations_dev, + in_left_pads_dev, + in_right_pads_dev); + }; + +#if USE_CONV_FWD_V4R4_NCHW + if(algo == ConvForwardAlgo::V4R4NCHW) + { + if(layout != ConvTensorLayout::NCHW) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nchw(); + + device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( + tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei, + out_device, + nrepeat); + } +#endif + +#if USE_CONV_FWD_V4R4R2_NHWC + if(algo == ConvForwardAlgo::V4R4R2NHWC) + { + if(layout != ConvTensorLayout::NHWC) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nhwc(); + + device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk( + tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei, + out_device, + nrepeat); + } +#endif + +#if USE_CONV_FWD_V6R1_NCHW + if(algo == ConvForwardAlgo::V6R1NCHW) + { + if(layout != ConvTensorLayout::NCHW) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nchw(); + + device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( + tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei, + out_device, + nrepeat); + } +#endif + +#if USE_CONV_FWD_V5R1_NCHW + if(algo == ConvForwardAlgo::V5R1NCHW) + { + if(layout != ConvTensorLayout::NCHW) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nchw(); + + device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw( + tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei, + out_device, + nrepeat); + } +#endif + +#if USE_CONV_FWD_V4R4R2_XDL_NCHW + if(algo == ConvForwardAlgo::V4R4R2XDLNCHW) + { + if(layout != ConvTensorLayout::NCHW) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nchw(); + + device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( + tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei, + out_device, + nrepeat); + } +#endif + +#if USE_CONV_FWD_V4R4R4_XDL_NHWC + if(algo == ConvForwardAlgo::V4R4R4XDLNHWC) + { + if(layout != ConvTensorLayout::NHWC) + { + throw std::runtime_error("wrong! 
layout"); + } + + const auto tmp = f_make_for_device_nhwc(); + + device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( + tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei, + out_device, + nrepeat); + } +#endif + + if(do_verification) + { + host_direct_convolution(in, + wei, + out_host, + make_tuple(conv_stride_h, conv_stride_w), + make_tuple(conv_dilation_h, conv_dilation_w), + make_tuple(in_left_pad_h, in_left_pad_w), + make_tuple(in_right_pad_h, in_right_pad_w), + layout); + + check_error(out_host, out_device); + +#if 0 + if(do_log) + { + LogRangeAsType(std::cout << "in : ", in.mData, ",") << std::endl; + LogRangeAsType(std::cout << "wei: ", wei.mData, ",") << std::endl; + LogRangeAsType(std::cout << "out_host : ", out_host.mData, ",") << std::endl; + LogRangeAsType(std::cout << "out_device: ", out_device.mData, ",") << std::endl; + } +#endif + } +} diff --git a/host/driver_offline/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..49e0223b33 --- /dev/null +++ b/host/driver_offline/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,341 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp" +#include "driver_dynamic_gemm_xdlops_v2r3.hpp" + +template +void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( + const InLengths& in_n_hi_wi_c_lengths, + const WeiLengths& wei_k_y_x_c_lengths, + const OutLengths& out_n_ho_wo_k_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Tensor& in_n_hi_wi_c, + const Tensor& wei_k_y_x_c, + const Tensor& out_n_ho_wo_k, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + constexpr auto I8 = Number<8>{}; + + DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); + DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); + DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); + + in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); + + const auto in_n_hi_wi_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths); + +#if 1 + // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 4; + + constexpr 
index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 256; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t 
GemmABlockTransferSrcScalarPerVector_GemmM = 2; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#elif 0 + // [M, N, K0, K1] = [256, 128, 4, 4] + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#endif + + const auto descs = + transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(wei_k_y_x_c_desc, + out_n_ho_wo_k_desc, + in_n_hi_wi_c_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + I0, + I0, + Number{}); + + const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto out_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto in_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmm + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: Gemmk0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmm + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1 + + constexpr auto out_gemmk0_gemmn_gemmk1_grid_iterator_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmn + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmn + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1 + + constexpr auto in_m0_m1_m2_n_grid_iterator_hacks = make_tuple( + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: MRepeat + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: NRepeat + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: MWaves + 
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 3+: NWaves + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 7+: N1 + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: MRepeat + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: NRepeat + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: MWaves + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 3-: NWaves + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 7-: N1 + + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}; + + constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = driver_dynamic_gemm_xdlops_v2r3< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperationEnum_t::Set, + decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), + decltype(out_gemmk0_gemmn_gemmk1_grid_desc), + decltype(in_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerWave, + GemmNPerWave, + GemmK1, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<2, 0, 1>, + Sequence<0, 2, 1>, + 1, + GemmABlockTransferSrcScalarPerVector_GemmM, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmBBlockTransferSrcScalarPerVector_GemmK1, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + Sequence<1, 3, 7, 0, 2, 4, 5, 6>, + 6, + GemmCThreadTransferDstScalarPerVector, + decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks), + decltype(out_gemmk0_gemmn_gemmk1_grid_iterator_hacks), + decltype(in_m0_m1_m2_n_grid_iterator_hacks), + decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks), + decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks), + false // CAccessOrderMRepeatNRepeat + >(static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + wei_gemmk0_gemmm_gemmk1_grid_desc, + out_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc, + wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks, + out_gemmk0_gemmn_gemmk1_grid_iterator_hacks, + in_m0_m1_m2_n_grid_iterator_hacks, + wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks, + out_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks, + nrepeat); + + { + const auto N = out_n_ho_wo_k_lengths[I0]; + const auto K = out_n_ho_wo_k_lengths[I3]; + const auto C = 
wei_k_y_x_c_lengths[I3]; + + const auto Hi = in_n_hi_wi_c_lengths[I1]; + const auto Wi = in_n_hi_wi_c_lengths[I2]; + + const auto Ho = out_n_ho_wo_k_lengths[I1]; + const auto Wo = out_n_ho_wo_k_lengths[I2]; + + const auto Y = wei_k_y_x_c_lengths[I1]; + const auto X = wei_k_y_x_c_lengths[I2]; + + float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } + + // copy result back to host + in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data()); +} diff --git a/host/driver_offline/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..ce4dd155f6 --- /dev/null +++ b/host/driver_offline/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,317 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp" +#include "driver_dynamic_gemm_xdlops_v2r3.hpp" + +template +void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk( + const InLengths& in_n_hi_wi_c_lengths, + const WeiLengths& wei_k_y_x_c_lengths, + const OutLengths& out_n_ho_wo_k_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Tensor& in_n_hi_wi_c, + const Tensor& wei_k_y_x_c, + const Tensor& out_n_ho_wo_k, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + constexpr auto I8 = Number<8>{}; + + DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); + DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); + DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); + + in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); + + const auto in_n_hi_wi_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths); + +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + 
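+    // Descriptive note on this tuning block (verifiable from the values above; the
+    // identifiers are the ones already declared in this configuration): each
+    // block-wise copy is expected to tile the block with its thread cluster, i.e.
+    // the per-thread slice lengths times the cluster lengths give the block tile
+    // and the cluster accounts for every thread in the block. For A here:
+    //   Sequence<1, 4, 4> (slice) * Sequence<4, 64, 1> (cluster) = [4, 256, 4]
+    //                               = [GemmKPerBlock, GemmMPerBlock, GemmK1]
+    //   4 * 64 * 1 = 256 = BlockSize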
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 256; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t 
GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#endif + + const auto descs = + transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(out_n_ho_wo_k_desc, + wei_k_y_x_c_desc, + in_n_hi_wi_c_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + I0, + I0, + Number{}); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto in_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto out_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmm + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmm + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1 + + constexpr auto wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmn + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: Gemmk0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmn + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1 + + constexpr auto in_m0_m1_m2_n_grid_iterator_hacks = make_tuple( + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: MRepeat + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: NRepeat + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 2+: MWaves + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: NWaves + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 4+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 5+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 6+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N1 + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: MRepeat + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: NRepeat + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 2-: MWaves + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: NWaves + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 4-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 5-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 6-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N1 + + constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{}; + + constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}; + + for(index_t i = 0; i < 5; 
++i) + { + float ave_time = driver_dynamic_gemm_xdlops_v2r3< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperationEnum_t::Set, + decltype(out_gemmk0_gemmm_gemmk1_grid_desc), + decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), + decltype(in_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerWave, + GemmNPerWave, + GemmK1, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmABlockTransferSrcScalarPerVector_GemmK1, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<2, 0, 1>, + Sequence<0, 2, 1>, + 1, + GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy +#if 0 + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, +#else + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, +#endif + 7, + GemmCThreadTransferDstScalarPerVector, + decltype(out_gemmk0_gemmm_gemmk1_grid_iterator_hacks), + decltype(wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks), + decltype(in_m0_m1_m2_n_grid_iterator_hacks), + decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks), + decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks), + true // CAccessOrderMRepeatNRepeat + >(static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc, + out_gemmk0_gemmm_gemmk1_grid_iterator_hacks, + wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks, + in_m0_m1_m2_n_grid_iterator_hacks, + out_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks, + wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks, + nrepeat); + + { + const auto N = out_n_ho_wo_k_lengths[I0]; + const auto K = out_n_ho_wo_k_lengths[I3]; + const auto C = wei_k_y_x_c_lengths[I3]; + + const auto Hi = in_n_hi_wi_c_lengths[I1]; + const auto Wi = in_n_hi_wi_c_lengths[I2]; + + const auto Ho = out_n_ho_wo_k_lengths[I1]; + const auto Wo = out_n_ho_wo_k_lengths[I2]; + + const auto Y = wei_k_y_x_c_lengths[I1]; + const auto X = wei_k_y_x_c_lengths[I2]; + + float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } + + // copy result back to host + in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data()); +} diff --git a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..24ba775309 --- /dev/null +++ b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp @@ -0,0 +1,210 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp" +#include "driver_dynamic_gemm_dlops_v1r2.hpp" + +template +void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( + const InLengths& in_n_c_hi_wi_lengths, + 
const WeiLengths& wei_k_c_y_x_lengths, + const OutLengths& out_n_k_ho_wo_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + constexpr auto I8 = Number<8>{}; + + DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + + const auto in_n_c_hi_wi_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths); + const auto wei_k_c_y_x_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths); + const auto out_n_k_ho_wo_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths); + +#if 1 + // cdata = 64, BlockSize = 256, 128x128x8 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlockM1 = 128; + constexpr index_t GemmNPerBlockN1 = 128; + constexpr index_t GemmKPerBlock = 8; + + constexpr index_t GemmM1PerThreadM111 = 4; + constexpr index_t GemmN1PerThreadN111 = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmM11N11ThreadClusterM1100 = 8; + constexpr index_t GemmM11N11ThreadClusterN1100 = 8; + constexpr index_t GemmM11N11ThreadClusterM1101 = 2; + constexpr index_t GemmM11N11ThreadClusterN1101 = 2; + + using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<4, 1, 1>; + using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<2, 1, 128>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1; + + using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 1>; + using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<2, 1, 128>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_N1 = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 1; +#endif + + const auto descs = + transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads); + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{})); + + constexpr auto in_gemmk_gemmn0_gemmn1_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); + + constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{})); + + constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0>{}; + + constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; + + const auto wei_gemmk_gemmm_grid_desc = descs[I0]; + const auto in_gemmk_gemmn_grid_desc = descs[I1]; + const auto out_gemmm_gemmn_grid_desc = descs[I2]; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = driver_dynamic_gemm_dlops_v1r2< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperationEnum_t::Set, + decltype(wei_gemmk_gemmm_grid_desc), + decltype(in_gemmk_gemmn_grid_desc), + decltype(out_gemmm_gemmn_grid_desc), + GemmMPerBlockM1, + GemmNPerBlockN1, + GemmKPerBlock, + GemmM1PerThreadM111, + GemmN1PerThreadN111, + GemmKPerThread, + GemmM11N11ThreadClusterM1100, + GemmM11N11ThreadClusterN1100, + GemmM11N11ThreadClusterM1101, + GemmM11N11ThreadClusterN1101, + GemmABlockTransferThreadSliceLengths_K_M0_M1, + GemmABlockTransferThreadClusterLengths_K_M0_M1, + Sequence<2, 1, 0>, // ABlockTransferThreadClusterArrangeOrder + Sequence<2, 1, 0>, // ABlockTransferSrcAccessOrder + 0, // ABlockTransferSrcVectorDim + GemmABlockTransferSrcScalarPerVector_K, + GemmABlockTransferDstScalarPerVector_M1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_K_N0_N1, + GemmBBlockTransferThreadClusterLengths_K_N0_N1, + Sequence<0, 1, 2>, // BBlockTransferThreadClusterArrangeOrder + Sequence<0, 1, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + GemmBBlockTransferSrcScalarPerVector_N1, + GemmBBlockTransferDstScalarPerVector_N1, + false, // don't move back src coordinate after threadwise copy + Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder + 5, // CThreadTransferSrcDstVectorDim + GemmCThreadTransferDstScalarPerVector_N11, + decltype(wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks), + decltype(in_gemmk_gemmn0_gemmn1_grid_iterator_hacks), + decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks), + decltype(wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks), + decltype(in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks)>( + static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), + static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), + static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), + wei_gemmk_gemmm_grid_desc, + in_gemmk_gemmn_grid_desc, + out_gemmm_gemmn_grid_desc, + wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks, + in_gemmk_gemmn0_gemmn1_grid_iterator_hacks, + out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks, + wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks, + in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks, + nrepeat); + + float perf = (float)calculate_convolution_flops( + 
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data()); +} diff --git a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..b6b1cc8969 --- /dev/null +++ b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp @@ -0,0 +1,283 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp" + +template +void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw( + const InLengths& in_n_c_hi_wi_lengths, + const WeiLengths& wei_k_c_y_x_lengths, + const OutLengths& out_n_k_ho_wo_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + constexpr auto I8 = Number<8>{}; + + DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + + const auto in_n_c_hi_wi_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths); + const auto wei_k_c_y_x_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths); + const auto out_n_k_ho_wo_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths); + +#if 0 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 64; + constexpr index_t GemmNPerWave = 64; + constexpr index_t GemmKPack = 8; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; + + constexpr index_t 
GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; +#elif 0 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 64; + constexpr index_t GemmNPerWave = 64; + constexpr index_t GemmKPack = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; +#elif 0 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 64; + constexpr index_t GemmNPerWave = 64; + constexpr index_t GemmKPack = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 4] + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 64; + constexpr index_t GemmNPerWave = 64; + constexpr index_t GemmKPack = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 4] + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 64; + constexpr index_t GemmNPerWave = 64; + constexpr 
index_t GemmKPack = 4; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; +#endif + + const auto descs = +#if 1 + transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad +#else + transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_1x1 +#endif + ( + wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads); + + for(index_t i = 0; i < 5; ++i) + { +#if 0 + float ave_time = launch_kernel_dynamic_gemm_xdlops_v1 +#else + float ave_time = launch_kernel_dynamic_gemm_xdlops_v2 +#endif + , + Sequence<1, 0, 2>, + 2, + GemmABlockTransferSrcScalarPerVector_GemmK, + GemmABlockTransferDstScalarPerVector_KPack, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<0, 2, 1>, + Sequence<1, 0, 2>, + 1, + GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_KPack, + false, // don't move back src coordinate after threadwise copy, which will be fused + // with MoveSrcSliceWindow() to save addr computation + Sequence<2, 3, 0, 1>, + 3, + GemmCThreadTransferDstScalarPerVector_GemmN1, + decltype(descs[I4]), + decltype(descs[I5]), + decltype(descs[I6]), + decltype(descs[I7]), + decltype(descs[I8])>(static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), + static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), + static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), + descs[I0], + descs[I1], + descs[I2], + descs[I3], + descs[I4], + descs[I5], + descs[I6], + descs[I7], + descs[I8], + nrepeat); + + float perf = (float)calculate_convolution_flops( + in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data()); +} diff --git a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..cdd1084c0d --- /dev/null +++ b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,284 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp" +#include "driver_dynamic_gemm_dlops_v1r3.hpp" + +template +void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk( + const InLengths& in_n_hi_wi_c_lengths, + const WeiLengths& wei_k_y_x_c_lengths, + const 
OutLengths& out_n_ho_wo_k_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_hi_wi_c, + const Tensor& wei_k_y_x_c, + Tensor& out_n_ho_wo_k, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + constexpr auto I8 = Number<8>{}; + + DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); + DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); + DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); + + in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); + + const auto in_n_hi_wi_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths); + +#if 1 + // [M, N, K0, K1] = [128, 128, 8, 1] for fp32 + // cdata = 64, BlockSize = 256 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlockM1 = 128; + constexpr index_t GemmNPerBlockN1 = 128; + constexpr index_t GemmKPerBlock = 8; + constexpr index_t GemmK1 = 1; + + constexpr index_t GemmM1PerThreadM111 = 4; + constexpr index_t GemmN1PerThreadN111 = 4; + constexpr index_t GemmKPerThread = 1; + + using GemmM11N11ThreadClusterM110Xs = Sequence<8, 2>; + using GemmM11N11ThreadClusterN110Xs = Sequence<8, 2>; + + using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 1>; + using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>; + + using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 1>; + using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 1>; + + using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 1>; + using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>; + + using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 1>; + using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 1>; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4; +#elif 1 + // [M, N, K0, K1] = [128, 128, 8, 2] for fp16 + // cdata = 64, BlockSize = 256 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlockM1 = 128; + constexpr index_t GemmNPerBlockN1 = 128; + constexpr index_t GemmKPerBlock = 8; + constexpr index_t GemmK1 = 2; + + constexpr index_t GemmM1PerThreadM111 = 4; + constexpr index_t GemmN1PerThreadN111 = 4; + constexpr index_t GemmKPerThread = 1; + + using GemmM11N11ThreadClusterM110Xs = Sequence<8, 2>; + using GemmM11N11ThreadClusterN110Xs = Sequence<8, 2>; + + using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 2>; + using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>; + + using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 2>; + using 
GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 2>; + + using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 2>; + using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>; + + using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 2>; + using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 2>; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4; +#elif 1 + // [M, N, K0, K1] = [128, 128, 8, 4] for i8 + // cdata = 64, BlockSize = 256 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlockM1 = 128; + constexpr index_t GemmNPerBlockN1 = 128; + constexpr index_t GemmKPerBlock = 8; + constexpr index_t GemmK1 = 4; + + constexpr index_t GemmM1PerThreadM111 = 4; + constexpr index_t GemmN1PerThreadN111 = 4; + constexpr index_t GemmKPerThread = 1; + + using GemmM11N11ThreadClusterM110Xs = Sequence<8, 2>; + using GemmM11N11ThreadClusterN110Xs = Sequence<8, 2>; + + using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 4>; + using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>; + + using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 4>; + using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 4>; + + using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 4>; + using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>; + + using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 4>; + using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 4>; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4; +#endif + + const auto descs = + transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc, + wei_k_y_x_c_desc, + out_n_ho_wo_k_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto out_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_iterator_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GemmM0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmM1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}), // 3+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GemmM0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GemmM1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1 + + constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmN1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}), // 3+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: GemmN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmN1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1 + + constexpr auto 
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmM0 + Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM10 + Sequence<0, 0, 0, 0, 0>{}, // 2+: GemmM11 + Sequence<0, 0, 0, 0, 0>{}, // 3+: GemmN0 + Sequence<0, 0, 0, 0, 0>{}, // 4+: GemmN10 + Sequence<0, 0, 0, 0, 0>{}), // 5+: GemmN11 + make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmM0 + Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmM10 + Sequence<0, 0, 0, 0, 0>{}, // 2-: GemmM11 + Sequence<0, 0, 0, 0, 0>{}, // 3-: GemmN0 + Sequence<0, 0, 0, 0, 0>{}, // 4-: GemmN10 + Sequence<0, 0, 0, 0, 0>{})); // 5-: GemmN11 + + constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{}; + + constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = driver_dynamic_gemm_dlops_v1r3< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperationEnum_t::Set, + decltype(in_gemmk0_gemmm_gemmk1_grid_desc), + decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), + decltype(out_gemmm_gemmn_grid_desc), + GemmMPerBlockM1, + GemmNPerBlockN1, + GemmKPerBlock, + GemmM1PerThreadM111, + GemmN1PerThreadN111, + GemmKPerThread, + GemmM11N11ThreadClusterM110Xs, + GemmM11N11ThreadClusterN110Xs, + GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1, + GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1, + Sequence<1, 2, 0, 3>, // ABlockTransferThreadClusterArrangeOrder + Sequence<1, 2, 0, 3>, // ABlockTransferSrcAccessOrder + GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, + Sequence<1, 2, 0, 3>, // ABlockTransferSrcVectorTensorContiguousDimOrder + GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, + GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1, + GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1, + Sequence<1, 2, 0, 3>, // BBlockTransferThreadClusterArrangeOrder + Sequence<1, 2, 0, 3>, // BBlockTransferSrcAccessOrder + GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, + Sequence<1, 2, 0, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder + GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, + Sequence<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder + 5, // CThreadTransferSrcDstVectorDim + GemmCThreadTransferDstScalarPerVector_N11, + decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_iterator_hacks), + decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_iterator_hacks), + decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks), + decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_iterator_hacks), + decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_iterator_hacks)>( + static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + in_gemmk0_gemmm0_gemmm1_gemmk1_grid_iterator_hacks, + wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_iterator_hacks, + out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks, + in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_iterator_hacks, + wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_iterator_hacks, + nrepeat); + + { + const auto N = out_n_ho_wo_k_lengths[I0]; + const auto K = out_n_ho_wo_k_lengths[I3]; + const auto C = wei_k_y_x_c_lengths[I3]; + + const auto Hi = in_n_hi_wi_c_lengths[I1]; + 
const auto Wi = in_n_hi_wi_c_lengths[I2]; + + const auto Ho = out_n_ho_wo_k_lengths[I1]; + const auto Wo = out_n_ho_wo_k_lengths[I2]; + + const auto Y = wei_k_y_x_c_lengths[I1]; + const auto X = wei_k_y_x_c_lengths[I2]; + + float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } + + // copy result back to host + out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); +} diff --git a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..b56cbc0335 --- /dev/null +++ b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp @@ -0,0 +1,206 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp" +#include "driver_dynamic_gemm_xdlops_v2r3.hpp" + +template +void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( + const InLengths& in_n_c_hi_wi_lengths, + const WeiLengths& wei_k_c_y_x_lengths, + const OutLengths& out_n_k_ho_wo_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + constexpr auto I8 = Number<8>{}; + + DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + + const auto in_n_c_hi_wi_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths); + const auto wei_k_c_y_x_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths); + const auto out_n_k_ho_wo_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths); + +#if 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using 
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#endif + + const auto descs = + transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}); + + const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto out_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), + make_tuple( + Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); + + constexpr auto in_gemmk0_gemmn_gemmk1_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); + + constexpr auto out_m0_m1_m2_n_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{})); + + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0>{}; + + constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = driver_dynamic_gemm_xdlops_v2r3< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperationEnum_t::Set, + decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), + decltype(in_gemmk0_gemmn_gemmk1_grid_desc), + decltype(out_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerWave, + GemmNPerWave, + GemmK1, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmABlockTransferSrcScalarPerVector_GemmK1, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<0, 2, 1>, + Sequence<1, 0, 2>, + 1, + GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + Sequence<3, 0, 1, 2, 7, 5, 4, 6>, + 7, + GemmCThreadTransferDstScalarPerVector, + 
decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_iterator_hacks), + decltype(out_m0_m1_m2_n_grid_iterator_hacks), + decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks), + false>(static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), + static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), + static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), + wei_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks, + in_gemmk0_gemmn_gemmk1_grid_iterator_hacks, + out_m0_m1_m2_n_grid_iterator_hacks, + wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks, + in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks, + nrepeat); + + float perf = (float)calculate_convolution_flops( + in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data()); +} diff --git a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..10284b48f3 --- /dev/null +++ b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,240 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp" +#include "driver_dynamic_gemm_xdlops_v2r2.hpp" + +template +void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk( + const InLengths& in_n_hi_wi_c_lengths, + const WeiLengths& wei_k_y_x_c_lengths, + const OutLengths& out_n_ho_wo_k_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_hi_wi_c, + const Tensor& wei_k_y_x_c, + Tensor& out_n_ho_wo_k, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + constexpr auto I8 = Number<8>{}; + + DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); + DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); + DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); + + in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); + + const auto in_n_hi_wi_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths); + +#if 1 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr 
index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 64; + constexpr index_t GemmNPerWave = 64; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 64; + constexpr index_t GemmNPerWave = 64; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#endif + + const auto descs = + transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc, + in_n_hi_wi_c_desc, + out_n_ho_wo_k_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}); + + const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto out_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), + make_tuple( + Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); + + constexpr auto in_gemmk0_gemmn_gemmk1_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); + + constexpr auto out_m0_m1_m2_n_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + 
Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{})); + + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0>{}; + + constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = driver_dynamic_gemm_xdlops_v2r2< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperationEnum_t::Set, + decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), + decltype(in_gemmk0_gemmn_gemmk1_grid_desc), + decltype(out_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerWave, + GemmNPerWave, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmABlockTransferSrcScalarPerVector_GemmK1, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmBBlockTransferSrcScalarPerVector_GemmK1, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + Sequence<2, 3, 0, 1>, + 2, + GemmCThreadTransferDstScalarPerVector, + decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_iterator_hacks), + decltype(out_m0_m1_m2_n_grid_iterator_hacks), + decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks)>( + static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + wei_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks, + in_gemmk0_gemmn_gemmk1_grid_iterator_hacks, + out_m0_m1_m2_n_grid_iterator_hacks, + wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks, + in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks, + nrepeat); + + { + const auto N = out_n_ho_wo_k_lengths[I0]; + const auto K = out_n_ho_wo_k_lengths[I3]; + const auto C = wei_k_y_x_c_lengths[I3]; + + const auto Hi = in_n_hi_wi_c_lengths[I1]; + const auto Wi = in_n_hi_wi_c_lengths[I2]; + + const auto Ho = out_n_ho_wo_k_lengths[I1]; + const auto Wo = out_n_ho_wo_k_lengths[I2]; + + const auto Y = wei_k_y_x_c_lengths[I1]; + const auto X = wei_k_y_x_c_lengths[I2]; + + float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } + + // copy result back to host + out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); +} diff --git a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..f2a30fb525 --- /dev/null +++ b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,305 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" 
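+// Note on the reported numbers: every driver in this family prints throughput as
+// forward-convolution FLOPs (2 * N * K * Ho * Wo * C * Y * X multiply-accumulates,
+// equivalently calculate_convolution_flops()) divided by 10^9 and by the measured
+// average kernel time in milliseconds, which is numerically TFlop/s.  The helper
+// below is a minimal, self-contained sketch of that bookkeeping; its name is
+// illustrative only and it is not referenced by the drivers in this patch.
+#include <cstddef>
+
+inline float convolution_forward_tflops(std::size_t N,
+                                        std::size_t K,
+                                        std::size_t Ho,
+                                        std::size_t Wo,
+                                        std::size_t C,
+                                        std::size_t Y,
+                                        std::size_t X,
+                                        float ave_time_ms)
+{
+    // total forward-convolution FLOPs: one multiply and one add per MAC
+    const std::size_t flops = std::size_t(2) * N * K * Ho * Wo * C * Y * X;
+
+    // GFLOP per millisecond equals TFlop/s, matching the drivers' perf printout
+    return static_cast<float>(flops) / (std::size_t(1000) * 1000 * 1000) / ave_time_ms;
+}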
+#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp" +#include "driver_dynamic_gemm_xdlops_v2r3.hpp" + +template +void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk( + const InLengths& in_n_hi_wi_c_lengths, + const WeiLengths& wei_k_y_x_c_lengths, + const OutLengths& out_n_ho_wo_k_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_hi_wi_c, + const Tensor& wei_k_y_x_c, + Tensor& out_n_ho_wo_k, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + constexpr auto I8 = Number<8>{}; + + DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); + DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); + DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); + + in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); + + const auto in_n_hi_wi_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths); + +#if 1 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t 
GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#elif 0 + // [M, N, K0, K1] = [256, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 256; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#endif + + const auto descs = + transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc, + in_n_hi_wi_c_desc, + out_n_ho_wo_k_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}); + + const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto out_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), + make_tuple( + Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); + + constexpr auto in_gemmk0_gemmn_gemmk1_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); + + constexpr auto out_m0_m1_m2_n_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{})); + + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0>{}; + + constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = driver_dynamic_gemm_xdlops_v2r3< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperationEnum_t::Set, + decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), + decltype(in_gemmk0_gemmn_gemmk1_grid_desc), + decltype(out_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerWave, + GemmNPerWave, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmABlockTransferSrcScalarPerVector_GemmK1, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmBBlockTransferSrcScalarPerVector_GemmK1, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + 6, + GemmCThreadTransferDstScalarPerVector, + decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_iterator_hacks), + decltype(out_m0_m1_m2_n_grid_iterator_hacks), + decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks), + false // CAccessOrderMRepeatNRepeat + >(static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + wei_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks, + in_gemmk0_gemmn_gemmk1_grid_iterator_hacks, + out_m0_m1_m2_n_grid_iterator_hacks, + wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks, + in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks, + nrepeat); + + { + const auto N = out_n_ho_wo_k_lengths[I0]; + const auto K = out_n_ho_wo_k_lengths[I3]; + const auto C = wei_k_y_x_c_lengths[I3]; + + const auto Hi = in_n_hi_wi_c_lengths[I1]; + const auto Wi = in_n_hi_wi_c_lengths[I2]; + + const auto Ho = out_n_ho_wo_k_lengths[I1]; + const auto Wo = out_n_ho_wo_k_lengths[I2]; + + const auto Y = wei_k_y_x_c_lengths[I1]; + const auto X = wei_k_y_x_c_lengths[I2]; + + float perf = (float)(std::size_t(2) * N * K * Ho * 
Wo * C * Y * X) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } + + // copy result back to host + out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); +} diff --git a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..601878c347 --- /dev/null +++ b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,365 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp" +#include "driver_dynamic_gemm_xdlops_v2r3.hpp" + +template +void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( + const InLengths& in_n_hi_wi_c_lengths, + const WeiLengths& wei_k_y_x_c_lengths, + const OutLengths& out_n_ho_wo_k_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_hi_wi_c, + const Tensor& wei_k_y_x_c, + Tensor& out_n_ho_wo_k, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + constexpr auto I8 = Number<8>{}; + + DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); + DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); + DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); + + in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); + + const auto in_n_hi_wi_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths); + +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t 
GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [256, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 256; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr 
index_t GemmNPerBlock = 256; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#endif + + const auto descs = + transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc, + wei_k_y_x_c_desc, + out_n_ho_wo_k_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto out_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto in_gemmk0_gemmm_gemmk1_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: GemmM + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: GemmM + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); // 2-: GemmK1 + + constexpr auto wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN + Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN + Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 + + constexpr auto out_m0_m1_m2_n_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: MRepeat + Sequence<0, 0, 0, 0, 0>{}, // 1+: 
NRepeat + Sequence<0, 0, 0, 0, 0>{}, // 2+: MWaves + Sequence<0, 0, 0, 0, 0>{}, // 3+: NWaves + Sequence<0, 0, 0, 0, 0>{}, // 4+: M0 + Sequence<0, 0, 0, 0, 0>{}, // 5+: M1 + Sequence<0, 0, 0, 0, 0>{}, // 6+: M2 + Sequence<0, 0, 0, 0, 0>{}), // 7+: N1 + make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: MRepeat + Sequence<0, 0, 0, 0, 0>{}, // 1-: NRepeat + Sequence<0, 0, 0, 0, 0>{}, // 2-: MWaves + Sequence<0, 0, 0, 0, 0>{}, // 3-: NWaves + Sequence<0, 0, 0, 0, 0>{}, // 4-: M0 + Sequence<0, 0, 0, 0, 0>{}, // 5-: M1 + Sequence<0, 0, 0, 0, 0>{}, // 6-: M2 + Sequence<0, 0, 0, 0, 0>{})); // 7-: N1 + + constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; + + constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = driver_dynamic_gemm_xdlops_v2r3< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperationEnum_t::Set, + decltype(in_gemmk0_gemmm_gemmk1_grid_desc), + decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), + decltype(out_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerWave, + GemmNPerWave, + GemmK1, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmABlockTransferSrcScalarPerVector_GemmK1, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmBBlockTransferSrcScalarPerVector_GemmK1, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + 7, + GemmCThreadTransferDstScalarPerVector, + decltype(in_gemmk0_gemmm_gemmk1_grid_iterator_hacks), + decltype(wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks), + decltype(out_m0_m1_m2_n_grid_iterator_hacks), + decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks), + decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks), + false // CAccessOrderMRepeatNRepeat + >(static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + in_gemmk0_gemmm_gemmk1_grid_iterator_hacks, + wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks, + out_m0_m1_m2_n_grid_iterator_hacks, + in_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks, + wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks, + nrepeat); + + { + const auto N = out_n_ho_wo_k_lengths[I0]; + const auto K = out_n_ho_wo_k_lengths[I3]; + const auto C = wei_k_y_x_c_lengths[I3]; + + const auto Hi = in_n_hi_wi_c_lengths[I1]; + const auto Wi = in_n_hi_wi_c_lengths[I2]; + + const auto Ho = out_n_ho_wo_k_lengths[I1]; + const auto Wo = out_n_ho_wo_k_lengths[I2]; + + const auto Y = wei_k_y_x_c_lengths[I1]; + const auto X = wei_k_y_x_c_lengths[I2]; + + float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } + + // copy result back to host + 
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); +} diff --git a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..ca0d47c33a --- /dev/null +++ b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp @@ -0,0 +1,192 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp" +#include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp" + +template +void device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw( + const InLengths& in_n_c_hi_wi_lengths, + const WeiLengths& wei_k_c_y_x_lengths, + const OutLengths& out_n_k_ho_wo_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto N = out_n_k_ho_wo_lengths[I0]; + const auto K = out_n_k_ho_wo_lengths[I1]; + const auto C = wei_k_c_y_x_lengths[I1]; + + const auto Hi = in_n_c_hi_wi_lengths[I2]; + const auto Wi = in_n_c_hi_wi_lengths[I3]; + + const auto Ho = out_n_k_ho_wo_lengths[I2]; + const auto Wo = out_n_k_ho_wo_lengths[I3]; + + const auto Y = wei_k_c_y_x_lengths[I2]; + const auto X = wei_k_c_y_x_lengths[I3]; + + const auto C0 = C / Number{}; + const auto C1 = Number{}; + + const auto K0 = K / Number{}; + const auto K1 = Number{}; + + Tensor in_n_c0_hi_wi_c1( + HostTensorDescriptor(std::initializer_list{N, C0, Hi, Wi, C1})); + Tensor wei_k_c0_y_x_c1( + HostTensorDescriptor(std::initializer_list{K, C0, Y, X, C1})); + Tensor out_n_k0_ho_wo_k1( + HostTensorDescriptor(std::initializer_list{N, K0, Ho, Wo, K1})); + + auto f_nchw2nc0hwc1 = [&](auto n, auto hi, auto wi, auto c) { + in_n_c0_hi_wi_c1(n, c / InWeiVectorSize, hi, wi, c % InWeiVectorSize) = + in_n_c_hi_wi(n, c, hi, wi); + }; + + auto f_kcyx2kc0yxc1 = [&](auto k, auto y, auto x, auto c) { + wei_k_c0_y_x_c1(k, c / InWeiVectorSize, y, x, c % InWeiVectorSize) = + wei_k_c_y_x(k, c, y, x); + }; + + make_ParallelTensorFunctor(f_nchw2nc0hwc1, N, Hi, Wi, C)(); + make_ParallelTensorFunctor(f_kcyx2kc0yxc1, K, Y, X, C)(); + + DeviceMem in_n_c0_hi_wi_c1_device_buf(sizeof(TInWei) * + in_n_c0_hi_wi_c1.mDesc.GetElementSpace()); + DeviceMem wei_k_c0_y_x_c1_device_buf(sizeof(TInWei) * wei_k_c0_y_x_c1.mDesc.GetElementSpace()); + DeviceMem out_n_k0_ho_wo_k1_device_buf(sizeof(TOut) * + out_n_k0_ho_wo_k1.mDesc.GetElementSpace()); + + in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data()); + wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data()); + + const auto in_n_c0_hi_wi_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, C0, Hi, Wi)); + const auto wei_k_c0_y_x_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C0, Y, X)); + const auto out_n_k0_ho_wo_k1_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)); + +#if 1 + // cdata = 64, BlockSize = 64, 16x8x32x4 + constexpr index_t 
BlockSize = 64; + + constexpr index_t KPerBlock = 16; + constexpr index_t HoPerBlock = 8; + constexpr index_t WoPerBlock = 32; + constexpr index_t EPerBlock = 1; + + constexpr index_t KPerThread = KPerBlock; + constexpr index_t HoPerThread = 2; + constexpr index_t WoPerThread = 2; + constexpr index_t EPerThread = EPerBlock; + + using ABlockTransferThreadSliceLengths_E_K = Sequence<3, 1>; + using ABlockTransferThreadClusterLengths_E_K = Sequence<3 * EPerBlock, KPerBlock>; + + constexpr index_t ABlockTransferSrcScalarPerVector_E = 1; + constexpr index_t ABlockTransferDstScalarPerVector_K = 1; + + constexpr index_t BThreadTransferSrcScalarPerVector_W = 1; + + constexpr index_t CThreadTransferDstScalarPerVector_W = 16; + + static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, ""); +#else + constexpr index_t BlockSize = 64; + + constexpr index_t KPerBlock = 16; + constexpr index_t HoPerBlock = 8; + constexpr index_t WoPerBlock = 32; + constexpr index_t EPerBlock = 1; + + constexpr index_t KPerThread = 16; + constexpr index_t HoPerThread = 2; + constexpr index_t WoPerThread = 2; + constexpr index_t EPerThread = EPerBlock; + + using ABlockTransferThreadSliceLengths_E_K = Sequence<9, 1>; + using ABlockTransferThreadClusterLengths_E_K = Sequence; + + constexpr index_t ABlockTransferSrcScalarPerVector_E = 1; + constexpr index_t ABlockTransferDstScalarPerVector_K = 1; + + constexpr index_t BThreadTransferSrcScalarPerVector_W = 1; + + constexpr index_t CThreadTransferDstScalarPerVector_W = K1; + + static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, ""); +#endif + + constexpr auto conv_driver = +#if 0 + DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad +#else + DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad +#endif + ::type, + TAcc, + TOut, + KPerBlock, + HoPerBlock, + WoPerBlock, + EPerBlock, + KPerThread, + HoPerThread, + WoPerThread, + EPerThread, + ABlockTransferThreadSliceLengths_E_K, + ABlockTransferThreadClusterLengths_E_K, + ABlockTransferSrcScalarPerVector_E, + ABlockTransferDstScalarPerVector_K, + BThreadTransferSrcScalarPerVector_W, + CThreadTransferDstScalarPerVector_W>{}; + + conv_driver.Run(wei_k_c0_y_x_desc, + in_n_c0_hi_wi_desc, + out_n_k0_ho_wo_k1_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + static_cast::type*>( + wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()), + static_cast::type*>( + in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()), + static_cast(out_n_k0_ho_wo_k1_device_buf.GetDeviceBuffer())); + + out_n_k0_ho_wo_k1_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data()); + + auto f_nk0hwk1_to_nkhw = [&](auto n, auto k, auto ho, auto wo) { + out_n_k_ho_wo(n, k, ho, wo) = + out_n_k0_ho_wo_k1(n, k / InWeiVectorSize, ho, wo, k % InWeiVectorSize); + }; + + make_ParallelTensorFunctor(f_nk0hwk1_to_nkhw, N, K, Ho, Wo)(); +} diff --git a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..8fb276b464 --- /dev/null +++ b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp @@ -0,0 +1,244 @@ +#pragma once +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp" +#include "driver_dynamic_contraction_dlops_v1r2.hpp" + +template +void 
device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( + const InLengths& in_n_c_hi_wi_lengths, + const WeiLengths& wei_k_c_y_x_lengths, + const OutLengths& out_n_k_ho_wo_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + + const auto in_desc_n_c_hi_wi = + make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths); + const auto wei_desc_k_c_y_x = + make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths); + const auto out_desc_n_k_ho_wo = + make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths); + +#if 1 + // [8, 1, 128, 1] * [8, 4, 32, 1] = [1, 128, 4, 32] for fp32 + // cdata = 64, BlockSize = 256 + constexpr index_t BlockSize = 256; + + constexpr index_t GN0 = 4; + constexpr index_t GK1 = 1; + + constexpr index_t GM1PerBlockGM11 = 128; + constexpr index_t GN1PerBlockGN11 = 32; + constexpr index_t GK0PerBlock = 8; + + constexpr index_t BM1PerThreadBM11 = 4; + constexpr index_t BN1PerThreadBN11 = 4; + constexpr index_t BK0PerThread = 1; + + using BM10BN10ThreadClusterBM10Xs = Sequence<8, 2>; + using BM10BN10ThreadClusterBN10Xs = Sequence<8, 2>; + + using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>; + using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>; + + using ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>; + using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<1, 1, 1, 1, 1>; + + using BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 4, 1, 1, 1>; + using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<8, 1, 1, 32, 1>; + + using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>; + using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>; + + constexpr index_t CThreadTransferDstScalarPerVector_BN1 = 1; +#elif 1 + // [8, 1, 128, 2] * [8, 4, 32, 2] = [1, 128, 4, 32] for fp16 + // cdata = 64, BlockSize = 256 + constexpr index_t BlockSize = 256; + + constexpr index_t GN0 = 4; + constexpr index_t GK1 = 2; + + constexpr index_t GM1PerBlockGM11 = 128; + constexpr index_t GN1PerBlockGN11 = 32; + constexpr index_t GK0PerBlock = 8; + + constexpr index_t BM1PerThreadBM11 = 4; + constexpr index_t BN1PerThreadBN11 = 4; + constexpr index_t BK0PerThread = 1; + + using BM10BN10ThreadClusterBM10Xs = Sequence<8, 2>; + using BM10BN10ThreadClusterBN10Xs = Sequence<8, 2>; + + using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 2>; + using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>; + + using 
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>; + using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<1, 1, 1, 1, 2>; + + using BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 4, 1, 1, 2>; + using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<8, 1, 1, 32, 1>; + + using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>; + using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 2>; + + constexpr index_t CThreadTransferDstScalarPerVector_BN1 = 1; +#endif + + const auto descs = + transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(wei_desc_k_c_y_x, + in_desc_n_c_hi_wi, + out_desc_n_k_ho_wo, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}, + Number{}); + + const auto wei_grid_desc_gk0_gm0_gm1_gk1 = descs[I0]; + const auto in_grid_desc_gk0_gn0_gn1_gk1 = descs[I1]; + const auto out_grid_desc_gm0_gm1_gn0_gn1 = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto wei_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3+: GM11 + Sequence<0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1-: GM0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11 + Sequence<0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1 + + constexpr auto in_grid_iterator_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 3+: GN11 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GN10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1 + + constexpr auto out_grid_iterator_hacks = make_tuple( + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 2+: BM1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: GN10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 4+: BN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 5+: GN1 + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 1-: BM0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 2-: BM1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: GN10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0 + 
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 5-: GN1 + + constexpr auto wei_grid_move_slice_window_iterator_hacks = Sequence<0, 0, 0, 0, 0, 0, 0>{}; + + constexpr auto in_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = driver_dynamic_contraction_dlops_v1r2< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperationEnum_t::Set, + decltype(wei_grid_desc_gk0_gm0_gm1_gk1), + decltype(in_grid_desc_gk0_gn0_gn1_gk1), + decltype(out_grid_desc_gm0_gm1_gn0_gn1), + GM1PerBlockGM11, + GN1PerBlockGN11, + GK0PerBlock, + BM1PerThreadBM11, + BN1PerThreadBN11, + BK0PerThread, + BM10BN10ThreadClusterBM10Xs, + BM10BN10ThreadClusterBN10Xs, + ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, + Sequence<1, 2, 3, 0, 4>, // ABlockTransferThreadClusterArrangeOrder + Sequence<3, 2, 1, 0, 4>, // ABlockTransferSrcAccessOrder + ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, + Sequence<0, 1, 2, 3, 4>, // ABlockTransferSrcVectorTensorContiguousDimOrder + BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, + Sequence<0, 4, 1, 2, 3>, // BBlockTransferThreadClusterArrangeOrder + Sequence<4, 3, 2, 0, 1>, // BBlockTransferSrcAccessOrder + BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, + Sequence<0, 1, 2, 3, 4>, // BBlockTransferSrcVectorTensorContiguousDimOrder + Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder + 5, // CThreadTransferSrcDstVectorDim + CThreadTransferDstScalarPerVector_BN1, + decltype(wei_grid_iterator_hacks), + decltype(in_grid_iterator_hacks), + decltype(out_grid_iterator_hacks), + decltype(wei_grid_move_slice_window_iterator_hacks), + decltype(in_grid_move_slice_window_iterator_hacks)>( + static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), + static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), + static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), + wei_grid_desc_gk0_gm0_gm1_gk1, + in_grid_desc_gk0_gn0_gn1_gk1, + out_grid_desc_gm0_gm1_gn0_gn1, + wei_grid_iterator_hacks, + in_grid_iterator_hacks, + out_grid_iterator_hacks, + wei_grid_move_slice_window_iterator_hacks, + in_grid_move_slice_window_iterator_hacks, + nrepeat); + + float perf = (float)calculate_convolution_flops( + in_desc_n_c_hi_wi, wei_desc_k_c_y_x, out_desc_n_k_ho_wo) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data()); +} diff --git a/host/driver_offline/include/driver_dynamic_contraction_dlops_v1r2.hpp b/host/driver_offline/include/driver_dynamic_contraction_dlops_v1r2.hpp new file mode 100644 index 0000000000..2f175962c1 --- /dev/null +++ b/host/driver_offline/include/driver_dynamic_contraction_dlops_v1r2.hpp @@ -0,0 +1,290 @@ +#ifndef DRIVER_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP +#define DRIVER_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "gridwise_dynamic_contraction_dlops_v1r2.hpp" + +template +__host__ float +driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid, + const 
FloatAB* p_b_grid, + FloatC* p_c_grid, + const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1, + const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1, + const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1, + AGridIteratorHacks, + BGridIteratorHacks, + CGridIteratorHacks, + AGridMoveSliceWindowIteratorHacks, + BGridMoveSliceWindowIteratorHacks, + ck::index_t nrepeat) + +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + + // GEMM + using GridwiseContraction = + GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1< + BlockSize, + FloatAB, + FloatAcc, + FloatC, + CGlobalMemoryDataOperation, + AGridDesc_GK0_GM0_GM1_GK1, + BGridDesc_GK0_GN0_GN1_GK1, + CGridDesc_GM0_GM1_GN0_GN1, + GM1PerBlockGM11, + GN1PerBlockGN11, + GK0PerBlock, + BM1PerThreadBM11, + BN1PerThreadBN11, + BK0PerThread, + BM10BN10ThreadClusterBM10Xs, + BM10BN10ThreadClusterBN10Xs, + ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferSrcVectorTensorContiguousDimOrder, + BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferSrcVectorTensorContiguousDimOrder, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + AGridIteratorHacks, + BGridIteratorHacks, + CGridIteratorHacks, + AGridMoveSliceWindowIteratorHacks, + BGridMoveSliceWindowIteratorHacks>; + + const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0); + + if(!GridwiseContraction::CheckValidity( + a_grid_desc_gk0_gm0_gm1_gk1, b_grid_desc_gk0_gn0_gn1_gk1, c_grid_desc_gm0_gm1_gn0_gn1)) + { + throw std::runtime_error("wrong! 
" + "GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_" + "GM0_GM1_GN0_GN1 has invalid setting"); + } + + const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = + GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(a_grid_desc_gk0_gm0_gm1_gk1); + const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = + GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(b_grid_desc_gk0_gn0_gn1_gk1); + + using AGridDesc_GK0_GM0_GM10_GM11_GK1 = decltype(a_grid_desc_gk0_gm0_gm10_gm11_gk1); + using BGridDesc_GK0_GN0_GN10_GN11_GK1 = decltype(b_grid_desc_gk0_gn0_gn10_gn11_gk1); + + // c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 + const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = + GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1( + c_grid_desc_gm0_gm1_gn0_gn1); + + using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 = decltype(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1); + + // c_grid_block_cluster_blockid_to_gm10_gn10 + const auto c_grid_block_cluster_blockid_to_gm10_gn10 = + GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10( + c_grid_desc_gm0_gm1_gn0_gn1); + + using CGridBlockCluster_BlockId_To_GM10_GN10 = + decltype(c_grid_block_cluster_blockid_to_gm10_gn10); + + const index_t grid_size = GridwiseContraction::CalculateGridSize(c_grid_desc_gm0_gm1_gn0_gn1); + + const bool has_main_k_block_loop = GridwiseContraction::CalculateHasMainKBlockLoop(GK0); + + const bool has_double_tail_k_block_loop = + GridwiseContraction::CalculateHasDoubleTailKBlockLoop(GK0); + + { + std::cout << "a_grid_desc_gk0_gm0_gm10_gm11_gk1{" + << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0) << ", " + << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I1) << ", " + << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I2) << ", " + << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I3) << ", " + << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I4) << "}" << std::endl; + + std::cout << "b_grid_desc_gk0_gn0_gn10_gn11_gk1{" + << b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I0) << ", " + << b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I1) << ", " + << b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I2) << ", " + << b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I3) << ", " + << b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I4) << "}" << std::endl; + + std::cout << "c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1{ " + << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I0) << ", " + << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I1) << ", " + << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I2) << ", " + << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I3) << ", " + << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I4) << ", " + << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I5) << "}" << std::endl; + } + + float ave_time = 0; + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = kernel_dynamic_contraction_dlops_v1r2< + GridwiseContraction, + FloatAB, + FloatC, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_grid_desc_gk0_gm0_gm10_gm11_gk1, + b_grid_desc_gk0_gn0_gn10_gn11_gk1, + c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + c_grid_block_cluster_blockid_to_gm10_gn10); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = kernel_dynamic_contraction_dlops_v1r2< + GridwiseContraction, + FloatAB, + FloatC, + remove_reference_t, + remove_reference_t, + remove_reference_t, + 
remove_reference_t, + true, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_grid_desc_gk0_gm0_gm10_gm11_gk1, + b_grid_desc_gk0_gn0_gn10_gn11_gk1, + c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + c_grid_block_cluster_blockid_to_gm10_gn10); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = kernel_dynamic_contraction_dlops_v1r2< + GridwiseContraction, + FloatAB, + FloatC, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_grid_desc_gk0_gm0_gm10_gm11_gk1, + b_grid_desc_gk0_gn0_gn10_gn11_gk1, + c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + c_grid_block_cluster_blockid_to_gm10_gn10); + } + else + { + const auto kernel = kernel_dynamic_contraction_dlops_v1r2< + GridwiseContraction, + FloatAB, + FloatC, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_grid_desc_gk0_gm0_gm10_gm11_gk1, + b_grid_desc_gk0_gn0_gn10_gn11_gk1, + c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + c_grid_block_cluster_blockid_to_gm10_gn10); + } + + return ave_time; +} +#endif diff --git a/host/driver_offline/include/driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..7c4b1043f3 --- /dev/null +++ b/host/driver_offline/include/driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp @@ -0,0 +1,352 @@ +#ifndef DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP +#define DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "gridwise_dynamic_gemm_dlops_v2.hpp" +#include "gridwise_operation_wrapper.hpp" + +template +struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad +{ + template + __host__ void Run(const ck::DynamicTensorDescriptor& wei_k_c_y_x_global_desc, + const ck::DynamicTensorDescriptor& in_n_c_hi_wi_global_desc, + const ck::DynamicTensorDescriptor& out_n_k0_ho_wo_k1_global_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const FloatAB* __restrict__ p_wei_global, + const FloatAB* __restrict__ p_in_global, + FloatC* __restrict__ p_out_global) const + { + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + + const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); + const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1); + + const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); + + const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2); + const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3); + + const auto K1 = 
out_n_k0_ho_wo_k1_global_desc.GetLength(I4); + + const auto K = wei_k_c_y_x_global_desc.GetLength(I0); + const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); + const auto X = wei_k_c_y_x_global_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + // weight tensor + const auto wei_e_k_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + // input tensor + const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor( + in_n_c_hi_wi_global_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor( + in_n_c_hip_wip_global_desc, + make_tuple( + make_pass_through_transform(N), + make_pass_through_transform(C), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + const auto in_e_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor( + in_n_c_y_ho_x_wo_global_desc, + make_tuple(make_merge_transform(make_tuple(C, Y, X)), + make_pass_through_transform(N), + make_pass_through_transform(Ho), + make_pass_through_transform(Wo)), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + // output tensor + const auto out_k_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)), + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(N), + make_pass_through_transform(Ho), + make_pass_through_transform(Wo)), + make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto E = C * Y * X; + + if(!((K % KPerBlock) == 0 && (Ho % HoPerBlock) == 0 && (Wo % WoPerBlock) == 0 && + (E % EPerBlock) == 0)) + { + throw std::runtime_error("wrong! 
GEMM size no divisible"); + } + + // hack to control index calculation when iterating over a_k_m_global tensor + constexpr auto a_e_k_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); + + constexpr auto a_e_k_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{}; + + constexpr auto b_e_n_ho_wo_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); + + constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}; + + // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor + // hack for NKHW format + constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{})); + +#if 1 + // GEMM + using gridwise_gemm = GridwiseDynamicGemmDlops_km_kn_mn_v3< + BlockSize, + FloatAB, + FloatAcc, + FloatC, + InMemoryDataOperationEnum_t::Set, + decltype(wei_e_k_global_desc), + decltype(in_e_n_ho_wo_global_desc), + decltype(out_k_n_ho_wo_global_desc), + KPerBlock, + HoPerBlock, + WoPerBlock, + EPerBlock, + KPerThread, + HoPerThread, + WoPerThread, + EPerThread, + ABlockTransferThreadSliceLengths_E_K, + ABlockTransferThreadClusterLengths_E_K, + Sequence<1, 0>, + Sequence<1, 0>, + 0, + ABlockTransferSrcScalarPerVector_E, + ABlockTransferDstScalarPerVector_K, + false, // don't move back src coordinate after threadwise copy + Sequence<0, 2, 3, 1>, + 3, + BThreadTransferSrcScalarPerVector_W, + false, // don't move back src coordinate after threadwise copy, which will be fused with + // MoveSrcSliceWindow() to save addr computation + Sequence<0, 2, 3, 1>, + 0, + CThreadTransferDstScalarPerVector_W, + decltype(a_e_k_global_iterator_hacks), + decltype(b_e_n_ho_wo_global_iterator_hacks), + decltype(c_k_n_ho_wo_global_tensor_iterator_hacks), + decltype(a_e_k_global_move_slice_window_iterator_hack), + decltype(b_e_n_ho_wo_global_move_slice_window_iterator_hack)>; + + const auto GridSize = (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock) * N; + + const bool has_main_k_block_loop = (E + EPerBlock) / (2 * EPerBlock) > 1; + + const bool has_double_tail_k_block_loop = (E / EPerBlock) % 2 == 0; + + index_t nrepeat = 100; + + for(index_t i = 0; i < 5; ++i) + { + std::cout << "Start running " << nrepeat << " times..." 
<< std::endl; + + KernelTimer timer; + timer.Start(); + std::cout << "has_main_k_block_loop: " << has_main_k_block_loop + << " has_double_tail_k_block_loop: " << has_double_tail_k_block_loop + << std::endl; + + for(index_t j = 0; j < nrepeat; ++j) + { + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_e_k_global_desc, + p_wei_global, + in_e_n_ho_wo_global_desc, + p_in_global, + out_k_n_ho_wo_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_e_k_global_desc, + p_wei_global, + in_e_n_ho_wo_global_desc, + p_in_global, + out_k_n_ho_wo_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_e_k_global_desc, + p_wei_global, + in_e_n_ho_wo_global_desc, + p_in_global, + out_k_n_ho_wo_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + else + { + const auto kernel = run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_e_k_global_desc, + p_wei_global, + in_e_n_ho_wo_global_desc, + p_in_global, + out_k_n_ho_wo_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + } + + timer.End(); + + float ave_time = timer.GetElapsedTime() / nrepeat; + + float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc, + wei_k_c_y_x_global_desc, + out_n_k0_ho_wo_k1_global_desc) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } +#endif + } +}; +#endif diff --git a/host/driver_offline/include/driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp b/host/driver_offline/include/driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp new file mode 100644 index 0000000000..b7f8e6039c --- /dev/null +++ b/host/driver_offline/include/driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp @@ -0,0 +1,367 @@ +#ifndef DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP +#define DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "gridwise_dynamic_gemm_dlops_v2.hpp" +#include "gridwise_operation_wrapper.hpp" + +template +struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad +{ + template + __host__ void Run(const ck::DynamicTensorDescriptor& wei_k_c_y_x_global_desc, + const ck::DynamicTensorDescriptor& in_n_c_hi_wi_global_desc, + const ck::DynamicTensorDescriptor& out_n_k0_ho_wo_k1_global_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const FloatAB* __restrict__ p_wei_global, + const FloatAB* __restrict__ p_in_global, + FloatC* __restrict__ 
p_out_global) const + { + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + + const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); + const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1); + + const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); + + const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2); + const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3); + + const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4); + + const auto K = wei_k_c_y_x_global_desc.GetLength(I0); + const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); + const auto X = wei_k_c_y_x_global_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock; + const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock; + + const auto OutRightPadH = Hop - Ho; + const auto OutRightPadW = Wop - Wo; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0] + OutRightPadH * ConvStrideH; + const auto InRightPadW = in_right_pads[I1] + OutRightPadW * ConvStrideW; + + std::cerr << "OutRightPadH = " << OutRightPadH << " OutRightPadW = " << OutRightPadW + << std::endl; + std::cerr << "InRightPadH = " << InRightPadH << " InRightPadW = " << InRightPadW + << std::endl; + + // weight tensor + const auto wei_e_k_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + // input tensor + const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor( + in_n_c_hi_wi_global_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor( + in_n_c_hip_wip_global_desc, + make_tuple( + make_pass_through_transform(N), + make_pass_through_transform(C), + make_embed_transform(make_tuple(Y, Hop), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wop), make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + const auto in_e_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor( + in_n_c_y_ho_x_wo_global_desc, + make_tuple(make_merge_transform(make_tuple(C, Y, X)), + make_pass_through_transform(N), + make_pass_through_transform(Hop), + make_pass_through_transform(Wop)), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + // output tensor + const auto out_k_n_hop_wop_global_desc = 
transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)), + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(N), + make_pad_transform(Ho, 0, OutRightPadH), + make_pad_transform(Wo, 0, OutRightPadW)), + make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto E = C * Y * X; + + std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl; + + if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 && + (E % EPerBlock) == 0)) + { + throw std::runtime_error("wrong! GEMM size no divisible"); + } + + // hack to control index calculation when iterating over a_k_m_global tensor + constexpr auto a_e_k_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); + + constexpr auto a_e_k_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{}; + + constexpr auto b_e_n_ho_wo_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); + + constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}; + + // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor + // hack for NKHW format + constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{})); + + // GEMM + using gridwise_gemm = GridwiseDynamicGemmDlops_km_kn_mn_v3< + BlockSize, + FloatAB, + FloatAcc, + FloatC, + InMemoryDataOperationEnum_t::Set, + decltype(wei_e_k_global_desc), + decltype(in_e_n_ho_wo_global_desc), + decltype(out_k_n_hop_wop_global_desc), + KPerBlock, + HoPerBlock, + WoPerBlock, + EPerBlock, + KPerThread, + HoPerThread, + WoPerThread, + EPerThread, + ABlockTransferThreadSliceLengths_E_K, + ABlockTransferThreadClusterLengths_E_K, + Sequence<1, 0>, + Sequence<1, 0>, + 0, + ABlockTransferSrcScalarPerVector_E, + ABlockTransferDstScalarPerVector_K, + false, // don't move back src coordinate after threadwise copy + Sequence<0, 2, 3, 1>, + 3, + BThreadTransferSrcScalarPerVector_W, + false, // don't move back src coordinate after threadwise copy, which will be fused with + // MoveSrcSliceWindow() to save addr computation + Sequence<0, 2, 3, 1>, + 0, + CThreadTransferDstScalarPerVector_W, + decltype(a_e_k_global_iterator_hacks), + decltype(b_e_n_ho_wo_global_iterator_hacks), + decltype(c_k_n_ho_wo_global_tensor_iterator_hacks), + decltype(a_e_k_global_move_slice_window_iterator_hack), + decltype(b_e_n_ho_wo_global_move_slice_window_iterator_hack)>; + + const auto GridSize = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N; + + const bool has_main_k_block_loop = (E + EPerBlock) / (2 * EPerBlock) > 1; + + const bool has_double_tail_k_block_loop = (E / 
EPerBlock) % 2 == 0; + + index_t nrepeat = 100; + + for(index_t i = 0; i < 5; ++i) + { + std::cout << "Start running " << nrepeat << " times..." << std::endl; + + KernelTimer timer; + timer.Start(); + std::cout << "has_main_k_block_loop: " << has_main_k_block_loop + << " has_double_tail_k_block_loop: " << has_double_tail_k_block_loop + << std::endl; + + for(index_t j = 0; j < nrepeat; ++j) + { + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_e_k_global_desc, + p_wei_global, + in_e_n_ho_wo_global_desc, + p_in_global, + out_k_n_hop_wop_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_e_k_global_desc, + p_wei_global, + in_e_n_ho_wo_global_desc, + p_in_global, + out_k_n_hop_wop_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_e_k_global_desc, + p_wei_global, + in_e_n_ho_wo_global_desc, + p_in_global, + out_k_n_hop_wop_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + else + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_e_k_global_desc, + p_wei_global, + in_e_n_ho_wo_global_desc, + p_in_global, + out_k_n_hop_wop_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + } + + timer.End(); + + float ave_time = timer.GetElapsedTime() / nrepeat; + + float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc, + wei_k_c_y_x_global_desc, + out_n_k0_ho_wo_k1_global_desc) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } +}; +#endif diff --git a/host/driver_offline/include/driver_dynamic_gemm_dlops_v1r2.hpp b/host/driver_offline/include/driver_dynamic_gemm_dlops_v1r2.hpp new file mode 100644 index 0000000000..0ebc68b48a --- /dev/null +++ b/host/driver_offline/include/driver_dynamic_gemm_dlops_v1r2.hpp @@ -0,0 +1,415 @@ +#ifndef DRIVER_DYNAMIC_GEMM_DLOPS_V1R2 +#define DRIVER_DYNAMIC_GEMM_DLOPS_V1R2 + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "gridwise_dynamic_gemm_dlops_v1r2.hpp" + +template +__host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid, + const FloatAB* p_b_grid, + FloatC* p_c_grid, + const AKMGridDesc& a_k_m_grid_desc, + const BKNGridDesc& b_k_n_grid_desc, + const CMNGridDesc& c_m_n_grid_desc, + AGridIteratorHacks, + BGridIteratorHacks, + CGridIteratorHacks, + AGridMoveSliceWindowIteratorHacks, + BGridMoveSliceWindowIteratorHacks, + ck::index_t nrepeat) + +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + + // GEMM + using GridwiseGemm = + GridwiseDynamicGemmDlops_km_kn_mn_v1r2; + 
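+    // [Editorial sketch, not part of the original commit] The four kernel
+    // instantiations below are selected by two booleans derived from the reduction
+    // extent K. The authoritative formulas live in
+    // GridwiseGemm::CalculateHasMainKBlockLoop / CalculateHasDoubleTailKBlockLoop;
+    // the v5r1 convolution drivers in this patch spell the same idea out inline,
+    // roughly (assuming KPerBlock-sized K slices consumed two at a time by the
+    // double-buffered main loop):
+    //
+    //     const bool has_main_k_block_loop        = (K + KPerBlock) / (2 * KPerBlock) > 1;
+    //     const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;
+    //
+    //     e.g. K = 64, KPerBlock = 8  ->  8 slices: main loop runs, double tail
+    //          K = 8,  KPerBlock = 8  ->  1 slice : no main loop, single tail
+    //
+    // The lines above are only an illustration of why both a "main loop" and a
+    // "double tail" flavor of the kernel exist; the code below uses the gridwise
+    // helper functions rather than this inline form.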
+ const auto M = a_k_m_grid_desc.GetLength(I1); + const auto N = b_k_n_grid_desc.GetLength(I1); + const auto K = a_k_m_grid_desc.GetLength(I0); + + if(!GridwiseGemm::CheckValidity(a_k_m_grid_desc, b_k_n_grid_desc, c_m_n_grid_desc)) + { + throw std::runtime_error( + "wrong! GridwiseDynamicGemmDlops_km_kn_mn_v1r2 has invalid setting"); + } + + const auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc); + const auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc); + + using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc); + using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc); + + // c_m0_m10_m11_n0_n10_n11_grid_desc + const auto c_m0_m10_m11_n0_n10_n11_grid_desc = + GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); + + using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc); + + // c_blockid_to_m0_n0_block_cluster_adaptor + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc); + + using CBlockIdToM0N0BlockClusterAdaptor = decltype(c_blockid_to_m0_n0_block_cluster_adaptor); + + const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N); + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K); + + const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K); + + { + std::cout << "a_k_m0_m1_grid_desc{" << a_k_m0_m1_grid_desc.GetLength(I0) << ", " + << a_k_m0_m1_grid_desc.GetLength(I1) << ", " << a_k_m0_m1_grid_desc.GetLength(I2) + << "}" << std::endl; + + std::cout << "b_k_n0_n1_grid_desc{" << b_k_n0_n1_grid_desc.GetLength(I0) << ", " + << b_k_n0_n1_grid_desc.GetLength(I1) << ", " << b_k_n0_n1_grid_desc.GetLength(I2) + << "}" << std::endl; + + std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl; + } + +#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE + float ave_time = 0; + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_dynamic_gemm_dlops_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k_m0_m1_grid_desc, + b_k_n0_n1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + kernel_dynamic_gemm_dlops_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k_m0_m1_grid_desc, + b_k_n0_n1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_dynamic_gemm_dlops_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + 
dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k_m0_m1_grid_desc, + b_k_n0_n1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + else + { + const auto kernel = + kernel_dynamic_gemm_dlops_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k_m0_m1_grid_desc, + b_k_n0_n1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + + return ave_time; +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER + DeviceMem a_k_m0_m1_grid_desc_dev_buf(sizeof(AKM0M1GridDesc)); + DeviceMem b_k_n0_n1_grid_desc_dev_buf(sizeof(BKN0N1GridDesc)); + DeviceMem c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf(sizeof(CM0M10M11N0N10N11GridDesc)); + DeviceMem c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf( + sizeof(CBlockIdToM0N0BlockClusterAdaptor)); + + a_k_m0_m1_grid_desc_dev_buf.ToDevice(&a_k_m0_m1_grid_desc); + b_k_n0_n1_grid_desc_dev_buf.ToDevice(&b_k_n0_n1_grid_desc); + c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.ToDevice(&c_m0_m10_m11_n0_n10_n11_grid_desc); + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.ToDevice( + &c_blockid_to_m0_n0_block_cluster_adaptor); + + float ave_time = 0; + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_dynamic_gemm_dlops_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + true>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + (void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + kernel_dynamic_gemm_dlops_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + false>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + (void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_dynamic_gemm_dlops_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + true>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + (void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); + } + else + { + const auto kernel = + kernel_dynamic_gemm_dlops_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + false>; + + ave_time = launch_and_time_kernel( + kernel, 
+ nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + (void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); + } + + return ave_time; +#endif +} +#endif diff --git a/host/driver_offline/include/driver_dynamic_gemm_dlops_v1r3.hpp b/host/driver_offline/include/driver_dynamic_gemm_dlops_v1r3.hpp new file mode 100644 index 0000000000..d075eac822 --- /dev/null +++ b/host/driver_offline/include/driver_dynamic_gemm_dlops_v1r3.hpp @@ -0,0 +1,411 @@ +#ifndef DRIVER_DYNAMIC_GEMM_DLOPS_V1R3 +#define DRIVER_DYNAMIC_GEMM_DLOPS_V1R3 + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "gridwise_dynamic_gemm_dlops_v1r3.hpp" + +template +__host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid, + const FloatAB* p_b_grid, + FloatC* p_c_grid, + const AK0MK1GridDesc& a_k0_m_k1_grid_desc, + const BK0NK1GridDesc& b_k0_n_k1_grid_desc, + const CMNGridDesc& c_m_n_grid_desc, + AGridIteratorHacks, + BGridIteratorHacks, + CGridIteratorHacks, + AGridMoveSliceWindowIteratorHacks, + BGridMoveSliceWindowIteratorHacks, + ck::index_t nrepeat) + +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + + // GEMM + using GridwiseGemm = + GridwiseDynamicGemmDlops_km_kn_mn_v1r3; + + const auto M = a_k0_m_k1_grid_desc.GetLength(I1); + const auto N = b_k0_n_k1_grid_desc.GetLength(I1); + const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); + + if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc)) + { + throw std::runtime_error( + "wrong! 
GridwiseDynamicGemmDlops_km_kn_mn_v1r3 has invalid setting"); + } + + const auto a_k0_m0_m1_k1_grid_desc = + GridwiseGemm::MakeAK0M0M1K1GridDescriptor(a_k0_m_k1_grid_desc); + const auto b_k0_n0_n1_k1_grid_desc = + GridwiseGemm::MakeBK0N0N1K1GridDescriptor(b_k0_n_k1_grid_desc); + + using AK0M0M1K1GridDesc = decltype(a_k0_m0_m1_k1_grid_desc); + using BK0N0N1K1GridDesc = decltype(b_k0_n0_n1_k1_grid_desc); + + // c_m0_m10_m11_n0_n10_n11_grid_desc + const auto c_m0_m10_m11_n0_n10_n11_grid_desc = + GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); + + using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc); + + // c_blockid_to_m0_n0_block_cluster_adaptor + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc); + + using CBlockIdToM0N0BlockClusterAdaptor = decltype(c_blockid_to_m0_n0_block_cluster_adaptor); + + const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N); + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); + + const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); + + { + std::cout << "a_k0_m0_m1_k1_grid_desc{" << a_k0_m0_m1_k1_grid_desc.GetLength(I0) << ", " + << a_k0_m0_m1_k1_grid_desc.GetLength(I1) << ", " + << a_k0_m0_m1_k1_grid_desc.GetLength(I2) << ", " + << a_k0_m0_m1_k1_grid_desc.GetLength(I3) << "}" << std::endl; + + std::cout << "b_k0_n0_n1_k1_grid_desc{" << b_k0_n0_n1_k1_grid_desc.GetLength(I0) << ", " + << b_k0_n0_n1_k1_grid_desc.GetLength(I1) << ", " + << b_k0_n0_n1_k1_grid_desc.GetLength(I2) << ", " + << b_k0_n0_n1_k1_grid_desc.GetLength(I3) << "}" << std::endl; + + std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl; + } + +#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE + float ave_time = 0; + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_dynamic_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k0_m0_m1_k1_grid_desc, + b_k0_n0_n1_k1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + kernel_dynamic_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k0_m0_m1_k1_grid_desc, + b_k0_n0_n1_k1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_dynamic_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + 
a_k0_m0_m1_k1_grid_desc, + b_k0_n0_n1_k1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + else + { + const auto kernel = + kernel_dynamic_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k0_m0_m1_k1_grid_desc, + b_k0_n0_n1_k1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + + return ave_time; +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER + DeviceMem a_k0_m0_m1_k1_grid_desc_dev_buf(sizeof(AK0M0M1K1GridDesc)); + DeviceMem b_k0_n0_n1_k1_grid_desc_dev_buf(sizeof(BK0N0N1K1GridDesc)); + DeviceMem c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf(sizeof(CM0M10M11N0N10N11GridDesc)); + DeviceMem c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf( + sizeof(CBlockIdToM0N0BlockClusterAdaptor)); + + a_k0_m0_m1_k1_grid_desc_dev_buf.ToDevice(&a_k0_m0_m1_k1_grid_desc); + b_k0_n0_n1_k1_grid_desc_dev_buf.ToDevice(&b_k0_n0_n1_k1_grid_desc); + c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.ToDevice(&c_m0_m10_m11_n0_n10_n11_grid_desc); + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.ToDevice( + &c_blockid_to_m0_n0_block_cluster_adaptor); + + float ave_time = 0; + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_dynamic_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + true>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + (void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + kernel_dynamic_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + false>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + (void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_dynamic_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + true>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + (void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); + } + else + { + const auto kernel = + kernel_dynamic_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + false>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + 
dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + (void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); + } + + return ave_time; +#endif +} +#endif diff --git a/host/driver_offline/include/driver_dynamic_gemm_xdlops_v2r3.hpp b/host/driver_offline/include/driver_dynamic_gemm_xdlops_v2r3.hpp new file mode 100644 index 0000000000..481d08188d --- /dev/null +++ b/host/driver_offline/include/driver_dynamic_gemm_xdlops_v2r3.hpp @@ -0,0 +1,196 @@ +#ifndef DRIVER_DYNAMIC_GEMM_XDLOPS_V2R3 +#define DRIVER_DYNAMIC_GEMM_XDLOPS_V2R3 + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "gridwise_dynamic_gemm_xdlops_v2r3.hpp" + +template +__host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid, + const FloatAB* p_b_grid, + FloatC* p_c_grid, + const AK0MK1GridDesc& a_k0_m_k1_grid_desc, + const BK0NK1GridDesc& b_k0_n_k1_grid_desc, + const CMNGridDesc& c_m_n_grid_desc, + AGridIteratorHacks, + BGridIteratorHacks, + CGridIteratorHacks, + AGridMoveSliceWindowIteratorHacks, + BGridMoveSliceWindowIteratorHacks, + ck::index_t nrepeat) + +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + + using GridwiseGemm = + GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3; + + { + std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", " + << a_k0_m_k1_grid_desc.GetLength(I1) << ", " << a_k0_m_k1_grid_desc.GetLength(I2) + << "}" << std::endl; + + std::cout << "b_k0_n_k1_grid_desc{" << b_k0_n_k1_grid_desc.GetLength(I0) << ", " + << b_k0_n_k1_grid_desc.GetLength(I1) << ", " << b_k0_n_k1_grid_desc.GetLength(I2) + << "}" << std::endl; + + std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", " + << c_m_n_grid_desc.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc)) + { + throw std::runtime_error( + "wrong! 
GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); + } + + const auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); + + using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc); + + const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); + + using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor); + + const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc); + + const auto kernel = kernel_dynamic_gemm_xdlops_v2r3, + remove_reference_t, + remove_reference_t, + remove_reference_t>; + +#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE + float ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m0_m1_m2_n_grid_desc, + c_block_cluster_adaptor); + +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER + DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc)); + DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc)); + DeviceMem c_m0_m1_m2_n_grid_desc_dev_buf(sizeof(CM0M1M2NGridDesc)); + DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor)); + + a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc); + b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc); + c_m0_m1_m2_n_grid_desc_dev_buf.ToDevice(&c_m0_m1_m2_n_grid_desc); + c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor); + + float ave_time = + launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + (void CONSTANT*)a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer(), + (void CONSTANT*)c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); +#endif + return ave_time; +} +#endif diff --git a/host/driver_online/CMakeLists.txt b/host/driver_online/CMakeLists.txt new file mode 100644 index 0000000000..2ae05e0ba5 --- /dev/null +++ b/host/driver_online/CMakeLists.txt @@ -0,0 +1,21 @@ +include_directories(BEFORE + include + ${PROJECT_BINARY_DIR}/host/online_compilation/include + ${PROJECT_SOURCE_DIR}/host/online_compilation/include + ${PROJECT_SOURCE_DIR}/host/host_tensor/include + ${PROJECT_SOURCE_DIR}/composable_kernel/include + ${PROJECT_SOURCE_DIR}/composable_kernel/include/utility + ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description + ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation + ${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform + ${PROJECT_SOURCE_DIR}/composable_kernel/include/driver + ${PROJECT_SOURCE_DIR}/external/rocm/include + ${PROJECT_SOURCE_DIR}/external/half/include +) + +set(CONV_FWD_DRIVER_ONLINE_SOURCE conv_fwd_driver_online.cpp) + +add_executable(conv_fwd_driver_online ${CONV_FWD_DRIVER_ONLINE_SOURCE}) + +target_link_libraries(conv_fwd_driver_online PRIVATE host_tensor) +target_link_libraries(conv_fwd_driver_online PRIVATE online_compilation) diff --git a/host/driver_online/conv_fwd_driver_online.cpp b/host/driver_online/conv_fwd_driver_online.cpp new file mode 100644 index 0000000000..c91f76fa24 --- /dev/null +++ b/host/driver_online/conv_fwd_driver_online.cpp @@ -0,0 +1,453 @@ +#include +#include +#include +#include +#include +#include +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" 
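+// [Editorial usage sketch, not part of the original commit] The argument parsing
+// further below expects argc == 22: six control arguments (layout, algo,
+// do_verification, init_method, do_log, nrepeat) followed by the fifteen
+// convolution parameters (N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx,
+// RightPy, RightPx). A hypothetical invocation (all values illustrative only,
+// layout/algo indices as defined by ConvTensorLayout / ConvForwardAlgo):
+//
+//     ./conv_fwd_driver_online 0 1 1 4 0 100  128 256 192 3 3 71 71  2 2 1 1 1 1 1 1
+//
+// With these numbers the driver computes YEff = (Y - 1) * Dy + 1 = 3 and
+// Ho = (Hi + LeftPy + RightPy - YEff) / Sy + 1 = (71 + 1 + 1 - 3) / 2 + 1 = 36
+// (and likewise Wo = 36), matching the formulas used right after argument parsing.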
+#include "conv_common.hpp" +#include "host_conv.hpp" +#include "device_tensor.hpp" +#include "handle.hpp" +#include "hipCheck.hpp" +#include "online_device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp" +#include "online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp" +#include "online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp" +#include "online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp" + +#define USE_CONV_FWD_V4R4_NCHW 1 +#define USE_CONV_FWD_V6R1_NCHW 1 +#define USE_CONV_FWD_V4R4_XDLOPS_NCHW 1 +#define USE_CONV_FWD_V4R4_XDLOPS_NHWC 1 + +enum ConvForwardAlgo +{ + V4R4NCHW, // 0 + V6R1NCHW, // 1 + V4R4XDLNCHW, // 2 + V4R4XDLNHWC // 3 +}; + +int main(int argc, char* argv[]) +{ + using namespace ck; + using namespace ck_driver; + using size_t = std::size_t; + + hipStream_t stream; + olCompile::Handle* handle; + + MY_HIP_CHECK(hipStreamCreate(&stream)); + + handle = new olCompile::Handle(stream); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + + if(argc != 22) + { + printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n"); + exit(1); + } + + const ConvTensorLayout layout = static_cast(atoi(argv[1])); + const ConvForwardAlgo algo = static_cast(atoi(argv[2])); + const bool do_verification = atoi(argv[3]); + const int init_method = atoi(argv[4]); + const bool do_log = atoi(argv[5]); + const int nrepeat = atoi(argv[6]); + + const index_t N = atoi(argv[7]); + const index_t K = atoi(argv[8]); + const index_t C = atoi(argv[9]); + const index_t Y = atoi(argv[10]); + const index_t X = atoi(argv[11]); + const index_t Hi = atoi(argv[12]); + const index_t Wi = atoi(argv[13]); + + const index_t conv_stride_h = atoi(argv[14]); + const index_t conv_stride_w = atoi(argv[15]); + const index_t conv_dilation_h = atoi(argv[16]); + const index_t conv_dilation_w = atoi(argv[17]); + const index_t in_left_pad_h = atoi(argv[18]); + const index_t in_left_pad_w = atoi(argv[19]); + const index_t in_right_pad_h = atoi(argv[20]); + const index_t in_right_pad_w = atoi(argv[21]); + + const index_t YEff = (Y - 1) * conv_dilation_h + 1; + const index_t XEff = (X - 1) * conv_dilation_w + 1; + + const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + +#if 1 + using in_data_t = float; + using acc_data_t = float; + using out_data_t = float; +#elif 0 + using in_data_t = half_t; + using acc_data_t = float; + using out_data_t = half_t; +#elif 1 + using in_data_t = int8_t; + using acc_data_t = int32_t; + using out_data_t = int8_t; +#endif + + std::vector in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4); + + switch(layout) + { + case ConvTensorLayout::NCHW: + // NCHW + in_lengths_host[0] = static_cast(N); + in_lengths_host[1] = static_cast(C); + in_lengths_host[2] = static_cast(Hi); + in_lengths_host[3] = static_cast(Wi); + + wei_lengths_host[0] = static_cast(K); + wei_lengths_host[1] = static_cast(C); + wei_lengths_host[2] = static_cast(Y); + wei_lengths_host[3] = static_cast(X); + + out_lengths_host[0] = static_cast(N); + out_lengths_host[1] = 
static_cast(K); + out_lengths_host[2] = static_cast(Ho); + out_lengths_host[3] = static_cast(Wo); + break; + case ConvTensorLayout::NHWC: + // NHWC + in_lengths_host[0] = static_cast(N); + in_lengths_host[1] = static_cast(Hi); + in_lengths_host[2] = static_cast(Wi); + in_lengths_host[3] = static_cast(C); + + wei_lengths_host[0] = static_cast(K); + wei_lengths_host[1] = static_cast(Y); + wei_lengths_host[2] = static_cast(X); + wei_lengths_host[3] = static_cast(C); + + out_lengths_host[0] = static_cast(N); + out_lengths_host[1] = static_cast(Ho); + out_lengths_host[2] = static_cast(Wo); + out_lengths_host[3] = static_cast(K); + break; + default: throw std::runtime_error("wrong! not implemented"); + } + + Tensor in(in_lengths_host); + Tensor wei(wei_lengths_host); + Tensor out_host(out_lengths_host); + Tensor out_device(out_lengths_host); + + std::cout << "layout: " << layout << std::endl; + ostream_HostTensorDescriptor(in.mDesc, std::cout << "in: "); + ostream_HostTensorDescriptor(wei.mDesc, std::cout << "wei: "); + ostream_HostTensorDescriptor(out_host.mDesc, std::cout << "out: "); + print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w)); + print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w)); + print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w)); + print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w)); + + std::size_t num_thread = std::thread::hardware_concurrency(); + + switch(init_method) + { + case 0: + // no initialization + break; + case 1: + in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 2: + in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + case 3: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 4: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + case 5: + in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + break; + default: + in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); + + auto gen_wei = [](auto... is) { + return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...); + }; + wei.GenerateTensorValue(gen_wei, num_thread); + } + + auto f_make_for_device_nchw = [&]() { + const auto in_lengths_dev = make_tuple(N, C, Hi, Wi); + const auto wei_lengths_dev = make_tuple(K, C, Y, X); + const auto out_lengths_dev = make_tuple(N, K, Ho, Wo); + + return make_tuple(in_lengths_dev, wei_lengths_dev, out_lengths_dev); + }; + + auto f_make_for_device_nhwc = [&]() { + const auto in_lengths_dev = make_tuple(N, Hi, Wi, C); + const auto wei_lengths_dev = make_tuple(K, Y, X, C); + const auto out_lengths_dev = make_tuple(N, Ho, Wo, K); + + return make_tuple(in_lengths_dev, wei_lengths_dev, out_lengths_dev); + }; + + const auto conv_strides = make_tuple(conv_stride_h, conv_stride_w); + const auto conv_dilations = make_tuple(conv_dilation_h, conv_dilation_w); + const auto in_left_pads = make_tuple(in_left_pad_h, in_left_pad_w); + const auto in_right_pads = make_tuple(in_right_pad_h, in_right_pad_w); + +#if USE_CONV_FWD_V4R4_NCHW + if(algo == ConvForwardAlgo::V4R4NCHW) + { + if(layout != ConvTensorLayout::NCHW) + { + throw std::runtime_error("wrong! 
layout"); + } + + const auto tmp = f_make_for_device_nchw(); + + tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw* tunable = + &default_tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw; + + online_device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw< + in_data_t, + acc_data_t, + out_data_t>(handle, + tmp[I0], + tmp[I1], + tmp[I2], + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + in, + wei, + out_device, + tunable, + nrepeat); + } +#endif + +#if USE_CONV_FWD_V6R1_NCHW + if(algo == ConvForwardAlgo::V6R1NCHW) + { + if(layout != ConvTensorLayout::NCHW) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nchw(); + +#if 1 + const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw compile_param = { + get_datatype_enum_from_type::value, + get_datatype_enum_from_type::value, + get_datatype_enum_from_type::value, + 256, + 4, + 1, + 128, + 32, + 8, + 4, + 4, + 1, + {8, 2}, + {8, 2}, + {4, 1, 1, 1, 1}, + {2, 1, 1, 128, 1}, + {4, 1, 1, 1, 1}, + {1, 1, 1, 1, 1}, + {1, 4, 1, 1, 1}, + {8, 1, 1, 32, 1}, + {1, 1, 1, 1, 1}, + {1, 1, 1, 1, 1}, + 4, + true, + true}; +#elif 0 + const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw compile_param = { + get_datatype_enum_from_type::value, + get_datatype_enum_from_type::value, + get_datatype_enum_from_type::value, + 256, + 4, + 2, + 128, + 32, + 8, + 4, + 4, + 1, + {8, 2}, + {8, 2}, + {4, 1, 1, 1, 2}, + {2, 1, 1, 128, 1}, + {4, 1, 1, 1, 1}, + {1, 1, 1, 1, 1}, + {1, 4, 1, 1, 2}, + {8, 1, 1, 32, 1}, + {1, 1, 1, 1, 1}, + {1, 1, 1, 1, 1}, + 4, + true, + true}; +#elif 1 + const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw compile_param = { + get_datatype_enum_from_type::value, + get_datatype_enum_from_type::value, + get_datatype_enum_from_type::value, + 256, + 4, + 4, + 128, + 32, + 8, + 4, + 4, + 1, + {8, 2}, + {8, 2}, + {4, 1, 1, 1, 4}, + {2, 1, 1, 128, 1}, + {4, 1, 1, 1, 1}, + {1, 1, 1, 1, 1}, + {1, 4, 1, 1, 4}, + {8, 1, 1, 32, 1}, + {1, 1, 1, 1, 1}, + {1, 1, 1, 1, 1}, + 4, + true, + true}; +#endif + + online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw< + in_data_t, + acc_data_t, + out_data_t>(handle, + tmp[I0], + tmp[I1], + tmp[I2], + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + in, + wei, + out_device, + compile_param, + nrepeat); + } +#endif + +#if USE_CONV_FWD_V4R4_XDLOPS_NCHW + if(algo == ConvForwardAlgo::V4R4XDLNCHW) + { + if(layout != ConvTensorLayout::NCHW) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nchw(); + + tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw* tunable = + &default_tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw; + + online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw< + in_data_t, + acc_data_t, + out_data_t>(handle, + tmp[I0], + tmp[I1], + tmp[I2], + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + in, + wei, + out_device, + tunable, + nrepeat); + } +#endif + +#if USE_CONV_FWD_V4R4_XDLOPS_NHWC + if(algo == ConvForwardAlgo::V4R4XDLNHWC) + { + if(layout != ConvTensorLayout::NHWC) + { + throw std::runtime_error("wrong! 
layout"); + } + + const auto tmp = f_make_for_device_nhwc(); + + tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk* tunable = + &default_tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk; + + online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk< + in_data_t, + acc_data_t, + out_data_t>(handle, + tmp[I0], + tmp[I1], + tmp[I2], + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + in, + wei, + out_device, + tunable, + nrepeat); + } +#endif + + if(do_verification) + { + host_direct_convolution(in, + wei, + out_host, + make_tuple(conv_stride_h, conv_stride_w), + make_tuple(conv_dilation_h, conv_dilation_w), + make_tuple(in_left_pad_h, in_left_pad_w), + make_tuple(in_right_pad_h, in_right_pad_w), + layout); + + check_error(out_host, out_device); + +#if 0 + if(do_log) + { + LogRangeAsType(std::cout << "in : ", in.mData, ",") << std::endl; + LogRangeAsType(std::cout << "wei: ", wei.mData, ",") << std::endl; + LogRangeAsType(std::cout << "out_host : ", out_host.mData, ",") << std::endl; + LogRangeAsType(std::cout << "out_device: ", out_device.mData, ",") << std::endl; + } +#endif + } + + delete handle; + MY_HIP_CHECK(hipStreamDestroy(stream)); +} diff --git a/host/driver_online/include/conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw.hpp b/host/driver_online/include/conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..b0c4921019 --- /dev/null +++ b/host/driver_online/include/conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw.hpp @@ -0,0 +1,673 @@ +#ifndef CONV_IGEMM_FWD_V6R1_DLOPS_NCHW_KCYX_NKHW_HPP +#define CONV_IGEMM_FWD_V6R1_DLOPS_NCHW_KCYX_NKHW_HPP + +#include + +namespace ck_driver { + +struct CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw +{ + ck::DataTypeEnum_t ABDataTypeEnum; + ck::DataTypeEnum_t AccDataTypeEnum; + ck::DataTypeEnum_t CDataTypeEnum; + + int BlockSize; + + int GN0; + int GK1; + + int GM1PerBlockGM11; + int GN1PerBlockGN11; + int GK0PerBlock; + + int BM1PerThreadBM11; + int BN1PerThreadBN11; + int BK0PerThread; + + std::array BM10BN10ThreadClusterBM10Xs; + std::array BM10BN10ThreadClusterBN10Xs; + + std::array ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1; + std::array ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1; + std::array ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1; + std::array ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1; + + std::array BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1; + std::array BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1; + std::array BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1; + std::array BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1; + + int CThreadTransferDstScalarPerVector; + + bool HasMainKBlockLoop; + bool HasDoubleTailKBlockLoop; + + auto GetCompileParameterString() const + { + // clang-format off + return + " -DCK_PARAM_ABDataTypeEnum=" + + std::to_string(ABDataTypeEnum) + + " -DCK_PARAM_AccDataTypeEnum=" + + std::to_string(AccDataTypeEnum) + + " -DCK_PARAM_CDataTypeEnum=" + + std::to_string(CDataTypeEnum) + + " -DCK_PARAM_BlockSize=" + + std::to_string(BlockSize) + + " -DCK_PARAM_GN0=" + + std::to_string(GN0) + + " -DCK_PARAM_GK1=" + + std::to_string(GK1) + + " -DCK_PARAM_GM1PerBlockGM11=" + + std::to_string(GM1PerBlockGM11) + + " -DCK_PARAM_GN1PerBlockGN11=" + + std::to_string(GN1PerBlockGN11) + + " -DCK_PARAM_GK0PerBlock=" + + std::to_string(GK0PerBlock) + + " -DCK_PARAM_BM1PerThreadBM11=" + + std::to_string(BM1PerThreadBM11) + + " -DCK_PARAM_BN1PerThreadBN11=" + + 
std::to_string(BN1PerThreadBN11) + + " -DCK_PARAM_BK0PerThread=" + + std::to_string(BK0PerThread) + + " -DCK_PARAM_BM10BN10ThreadClusterBM10Xs=" + + std::to_string(BM10BN10ThreadClusterBM10Xs[0]) + "," + + std::to_string(BM10BN10ThreadClusterBM10Xs[1]) + + " -DCK_PARAM_BM10BN10ThreadClusterBN10Xs=" + + std::to_string(BM10BN10ThreadClusterBN10Xs[0]) + "," + + std::to_string(BM10BN10ThreadClusterBN10Xs[1]) + + " -DCK_PARAM_ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1=" + + std::to_string(ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[0]) + "," + + std::to_string(ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[1]) + "," + + std::to_string(ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[2]) + "," + + std::to_string(ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[3]) + "," + + std::to_string(ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[4]) + + " -DCK_PARAM_ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1=" + + std::to_string(ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[0]) + "," + + std::to_string(ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[1]) + "," + + std::to_string(ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[2]) + "," + + std::to_string(ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[3]) + "," + + std::to_string(ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[4]) + + " -DCK_PARAM_ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1=" + + std::to_string(ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[0]) + "," + + std::to_string(ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[1]) + "," + + std::to_string(ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[2]) + "," + + std::to_string(ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[3]) + "," + + std::to_string(ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[4]) + + " -DCK_PARAM_ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1=" + + std::to_string(ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[0]) + "," + + std::to_string(ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[1]) + "," + + std::to_string(ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[2]) + "," + + std::to_string(ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[3]) + "," + + std::to_string(ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[4]) + + " -DCK_PARAM_BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1=" + + std::to_string(BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[0]) + "," + + std::to_string(BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[1]) + "," + + std::to_string(BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[2]) + "," + + std::to_string(BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[3]) + "," + + std::to_string(BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[4]) + + " -DCK_PARAM_BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1=" + + std::to_string(BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[0]) + "," + + std::to_string(BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[1]) + "," + + std::to_string(BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[2]) + "," + + std::to_string(BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[3]) + "," + + std::to_string(BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[4]) + + " -DCK_PARAM_BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1=" + + 
std::to_string(BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[0]) + "," + + std::to_string(BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[1]) + "," + + std::to_string(BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[2]) + "," + + std::to_string(BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[3]) + "," + + std::to_string(BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[4]) + + " -DCK_PARAM_BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1=" + + std::to_string(BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[0]) + "," + + std::to_string(BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[1]) + "," + + std::to_string(BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[2]) + "," + + std::to_string(BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[3]) + "," + + std::to_string(BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[4]) + + " -DCK_PARAM_CThreadTransferDstScalarPerVector=" + + std::to_string(CThreadTransferDstScalarPerVector) + + " -DCK_PARAM_HasMainKBlockLoop=" + + std::to_string(HasMainKBlockLoop) + + " -DCK_PARAM_HasDoubleTailKBlockLoop=" + + std::to_string(HasDoubleTailKBlockLoop); + // clang-format on + } +}; + +struct TunableConvIgemmFwdV6r1DlopsNchwKcyxNkhw +{ + ck::DataTypeEnum_t ABDataTypeEnum; + ck::DataTypeEnum_t CDataTypeEnum; + + int BlockSize; + + int GN0; + int GK1; + + int GM1PerBlockGM11; + int GN1PerBlockGN11; + int GK0PerBlock; + + int BM1PerThreadBM11; + int BN1PerThreadBN11; + int BK0PerThread; + + std::array BM10BN10ThreadClusterBM10Xs; + std::array BM10BN10ThreadClusterBN10Xs; + + std::array ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1; + std::array ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1; + std::array ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1; + std::array ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1; + + std::array BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1; + std::array BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1; + std::array BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1; + std::array BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1; +}; + +inline static auto generate_tunable_list_conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw() +{ + constexpr auto f32 = ck::DataTypeEnum_t::Float; + constexpr auto f16 = ck::DataTypeEnum_t::Half; + constexpr auto i8 = ck::DataTypeEnum_t::Int8; + + return std::vector{ + // clang-format off + // fp32 + {f32, f32, 256, 1, 1, 128, 128, 16, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 2, 1}, {4, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 1, 1, 4, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}}, + + {f32, f32, 256, 1, 1, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}}, + {f32, f32, 256, 1, 1, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 2, 1}, {1, 1, 1, 4, 1}}, + {f32, f32, 256, 1, 1, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}}, + + {f32, f32, 256, 1, 1, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {4, 1, 1, 1, 1}, { 2, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + {f32, f32, 256, 2, 1, 
128, 64, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 2, 1, 1, 1}, { 4, 1, 1, 64, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + {f32, f32, 256, 4, 1, 128, 32, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 4, 1, 1, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + + {f32, f32, 256, 8, 1, 128, 16, 16, 4, 4, 1, {8, 2}, {8, 2}, {8, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 8, 1, 1, 1}, {16, 1, 1, 16, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + + {f32, f32, 128, 1, 1, 64, 128, 8, 4, 4, 1, {4, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {8, 1, 1, 1, 1}, { 1, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + + // fp16 + {f16, f16, 256, 1, 2, 128, 128, 16, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 2, 2}, {4, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 1, 1, 4, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}}, + + {f16, f16, 256, 1, 2, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}}, + {f16, f16, 256, 1, 2, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 2, 1}, {1, 1, 1, 4, 1}}, + {f16, f16, 256, 1, 2, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}}, + + {f16, f16, 256, 1, 2, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {4, 1, 1, 1, 2}, { 2, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + {f16, f16, 256, 2, 2, 128, 64, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 2, 1, 1, 2}, { 4, 1, 1, 64, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + {f16, f16, 256, 4, 2, 128, 32, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 4, 1, 1, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + + {f16, f16, 256, 8, 2, 128, 16, 16, 4, 4, 1, {8, 2}, {8, 2}, {8, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 8, 1, 1, 2}, {16, 1, 1, 16, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + + {f16, f16, 128, 1, 2, 64, 128, 8, 4, 4, 1, {4, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {8, 1, 1, 1, 2}, { 1, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + + // i8 + { i8, i8, 256, 1, 4, 128, 128, 16, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 2, 4}, {4, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 1, 1, 4, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}}, + + { i8, i8, 256, 1, 4, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}}, + { i8, i8, 256, 1, 4, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 2, 1}, {1, 1, 1, 4, 1}}, + { i8, i8, 256, 1, 4, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}}, + + { i8, i8, 256, 1, 4, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, 
{4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {4, 1, 1, 1, 4}, { 2, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + { i8, i8, 256, 2, 4, 128, 64, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 2, 1, 1, 4}, { 4, 1, 1, 64, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + { i8, i8, 256, 4, 4, 128, 32, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 4, 1, 1, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + + { i8, i8, 256, 8, 4, 128, 16, 16, 4, 4, 1, {8, 2}, {8, 2}, {8, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 8, 1, 1, 4}, {16, 1, 1, 16, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + + { i8, i8, 128, 1, 4, 64, 128, 8, 4, 4, 1, {4, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {8, 1, 1, 1, 4}, { 1, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}} + // clang-format on + }; +} + +// TODO make this common interface and write specs for it +struct ConvIgemmFwdV6r1DlopsNchwKcyxNkhw +{ + static auto + CalculateCompileParameterBasedOnTunable(const ConvolutionProblemDescriptor& conv_problem_desc, + const TunableConvIgemmFwdV6r1DlopsNchwKcyxNkhw& tunable) + { + using namespace ck; + + const int C = conv_problem_desc.C; + const int Y = conv_problem_desc.Y; + const int X = conv_problem_desc.X; + const int Ho = conv_problem_desc.Ho; + const int Wo = conv_problem_desc.Wo; + + if(!(conv_problem_desc.InDataTypeEnum == tunable.ABDataTypeEnum && + conv_problem_desc.WeiDataTypeEnum == tunable.ABDataTypeEnum && + conv_problem_desc.OutDataTypeEnum == tunable.CDataTypeEnum)) + return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false); + + const auto ABDataTypeEnum = conv_problem_desc.InDataTypeEnum; + const auto CDataTypeEnum = conv_problem_desc.OutDataTypeEnum; + + DataTypeEnum_t AccDataTypeEnum; + + switch(ABDataTypeEnum) + { + case DataTypeEnum_t::Float: + case DataTypeEnum_t::Half: AccDataTypeEnum = DataTypeEnum_t::Float; break; + case DataTypeEnum_t::Int8: AccDataTypeEnum = DataTypeEnum_t::Int32; break; + default: return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false); + } + + const int BlockSize = tunable.BlockSize; + + const int GN0 = tunable.GN0; + const int GK1 = tunable.GK1; + + const int GM11 = tunable.GM1PerBlockGM11; + const int GN11 = tunable.GN1PerBlockGN11; + const int GK0PerBlock = tunable.GK0PerBlock; + + const int BM11 = tunable.BM1PerThreadBM11; + const int BN11 = tunable.BN1PerThreadBN11; + const int BK0PerThread = tunable.BK0PerThread; + + const auto BM10BN10ThreadClusterBM10Xs = tunable.BM10BN10ThreadClusterBM10Xs; + const auto BM10BN10ThreadClusterBN10Xs = tunable.BM10BN10ThreadClusterBN10Xs; + + const auto ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = + tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1; + const auto ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = + tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1; + const auto ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = + tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1; + const auto ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = + tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1; + + const auto BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = + tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1; + const auto BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = + 
tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1; + const auto BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = + tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1; + const auto BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = + tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1; + + // C threadwise copy: {BN11} or {BN} or {BN1} or {GN11} is Dst vector dim + const int CThreadTransferDstScalarPerVector = gcd(4, GN11, BN11, Ho * Wo); + + const int C0 = GK1; + + if(!(C % C0 == 0)) + return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false); + + const int C1 = C / C0; + + const int GK0 = C1 * Y * X; + + if(!(GK0 % GK0PerBlock == 0)) + return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false); + + const bool HasMainKBlockLoop = ((GK0 + GK0PerBlock) / (2 * GK0PerBlock) > 1); + + const bool HasDoubleTailKBlockLoop = ((GK0 / GK0PerBlock) % 2 == 0); + + return std::make_tuple( + CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{ + ABDataTypeEnum, + AccDataTypeEnum, + CDataTypeEnum, + BlockSize, + GN0, + GK1, + GM11, + GN11, + GK0PerBlock, + BM11, + BN11, + BK0PerThread, + BM10BN10ThreadClusterBM10Xs, + BM10BN10ThreadClusterBN10Xs, + ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, + BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, + CThreadTransferDstScalarPerVector, + HasMainKBlockLoop, + HasDoubleTailKBlockLoop}, + true); + } + + static auto GetDefaultCompileParameter(const ConvolutionProblemDescriptor& conv_problem_desc) + { + for(const auto& tunable : generate_tunable_list_conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw()) + { + CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw compile_param; + bool found = false; + + std::tie(compile_param, found) = + CalculateCompileParameterBasedOnTunable(conv_problem_desc, tunable); + + if(found && IsValidCompileParameter(conv_problem_desc, compile_param)) + return std::make_tuple(compile_param, true); + } + + return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false); + } + + static bool IsApplicable(const ConvolutionProblemDescriptor& conv_problem_desc) + { + bool found = false; + + std::tie(std::ignore, found) = GetDefaultCompileParameter(conv_problem_desc); + + return found; + } + + static bool + IsValidCompileParameter(const ConvolutionProblemDescriptor& conv_problem_desc, + const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param) + { + using namespace ck; + + const int N = conv_problem_desc.N; + const int K = conv_problem_desc.K; + const int C = conv_problem_desc.C; + const int Y = conv_problem_desc.Y; + const int X = conv_problem_desc.X; + const int Ho = conv_problem_desc.Ho; + const int Wo = conv_problem_desc.Wo; + + const int GK1 = compile_param.GK1; + const int GN0 = compile_param.GN0; + const int GM11 = compile_param.GM1PerBlockGM11; + const int GN11 = compile_param.GN1PerBlockGN11; + + const int BM11 = compile_param.BM1PerThreadBM11; + const int BN11 = compile_param.BN1PerThreadBN11; + + const int C0 = GK1; + const int N0 = GN0; + + if(!(C % C0 == 0)) + return false; + + const int C1 = C / C0; + + if(!(N % N0 == 0)) + 
return false; + + const int N1 = N / N0; + + const int GM0 = 1; + const int GM1 = K; + const int GN1 = N1 * Ho * Wo; + const int GK0 = C1 * Y * X; + + // check data type + { + if(!(conv_problem_desc.InDataTypeEnum == conv_problem_desc.WeiDataTypeEnum && + conv_problem_desc.InDataTypeEnum == compile_param.ABDataTypeEnum)) + return false; + + if(compile_param.ABDataTypeEnum == DataTypeEnum_t::Float || + compile_param.ABDataTypeEnum == DataTypeEnum_t::Half) + { + if(!(compile_param.AccDataTypeEnum == DataTypeEnum_t::Float)) + return false; + } + else if(compile_param.ABDataTypeEnum == DataTypeEnum_t::Int8) + { + if(!(compile_param.AccDataTypeEnum == DataTypeEnum_t::Int32)) + return false; + } + } + + // check gridwise contraction + { + if(!(GM1 % GM11 == 0 && GN1 % GN11 == 0 && GK0 % compile_param.GK0PerBlock == 0)) + return false; + + const bool has_main_k_block_loop = + ((GK0 + compile_param.GK0PerBlock) / (2 * compile_param.GK0PerBlock) > 1); + + const bool has_double_tail_k_block_loop = ((GK0 / compile_param.GK0PerBlock) % 2 == 0); + + if(!(has_main_k_block_loop == compile_param.HasMainKBlockLoop && + has_double_tail_k_block_loop == compile_param.HasDoubleTailKBlockLoop)) + return false; + } + + // check A blockwise copy + { + const auto block_slice_lengths = + std::array{compile_param.GK0PerBlock, GM0, 1, GM11, GK1}; + const auto& cluster_lengths = + compile_param.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1; + const auto& thread_slice_lengths = + compile_param.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1; + const auto& src_vector_lengths = + compile_param.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1; + const auto& dst_vector_lengths = + compile_param.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1; + + // check number of working thread + const int num_work_thread = std::accumulate( + cluster_lengths.begin(), cluster_lengths.end(), 1, std::multiplies{}); + + if(!(compile_param.BlockSize >= num_work_thread)) + return false; + + // check block slice lengths vs thread slice lengths vs cluster lengths + for(int i = 0; i < 5; ++i) + { + if(!(cluster_lengths[i] * thread_slice_lengths[i] == block_slice_lengths[i])) + return false; + } + + // check thread slice lengths vs vector lengths + for(int i = 0; i < 5; ++i) + { + if(!(thread_slice_lengths[i] % src_vector_lengths[i] == 0)) + return false; + + if(!(thread_slice_lengths[i] % dst_vector_lengths[i] == 0)) + return false; + } + + // check Src vectorization, GK0 is global mem vector dim + if(!(src_vector_lengths[1] == 1 && src_vector_lengths[2] == 1 && + src_vector_lengths[3] == 1 && src_vector_lengths[4] == 1)) + return false; + + // check Dst vectorization, {GM11, GK1} are LDS vector dims + if(dst_vector_lengths[4] == GK1) + { // vectorize on {GM11, GK1} + if(!(GM11 % dst_vector_lengths[3] == 0)) + return false; + } + else + { // vectorize on {GK1} only + if(!(GK1 % dst_vector_lengths[4] == 0)) + return false; + + if(!(dst_vector_lengths[3] == 1)) + return false; + } + } + + // check B blockwise copy + { + const auto block_slice_lengths = + std::array{compile_param.GK0PerBlock, GN0, 1, GN11, GK1}; + const auto& cluster_lengths = + compile_param.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1; + const auto& thread_slice_lengths = + compile_param.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1; + const auto& src_vector_lengths = + compile_param.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1; + const auto& dst_vector_lengths = + 
compile_param.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1; + + // check number of working thread + const int num_work_thread = std::accumulate( + cluster_lengths.begin(), cluster_lengths.end(), 1, std::multiplies{}); + + if(!(compile_param.BlockSize >= num_work_thread)) + return false; + + // check block slice lengths vs thread slice lengths vs cluster lengths + for(int i = 0; i < 5; ++i) + { + if(!(cluster_lengths[i] * thread_slice_lengths[i] == block_slice_lengths[i])) + return false; + } + + // check thread slice lengths vs vector lengths + for(int i = 0; i < 5; ++i) + { + if(!(thread_slice_lengths[i] % src_vector_lengths[i] == 0 && + thread_slice_lengths[i] % dst_vector_lengths[i] == 0)) + return false; + } + + // check Src vectorization: {GN11} is global mem vector dim + if(!(src_vector_lengths[0] == 1 && src_vector_lengths[1] == 1 && + src_vector_lengths[2] == 1 && src_vector_lengths[4] == 1)) + return false; + + // check Src tensor layout related vectorization + if(Y == 1 && X == 1 && conv_problem_desc.ConvStrideH == 1 && + conv_problem_desc.ConvStrideW == 1 && conv_problem_desc.InLeftPadH == 0 && + conv_problem_desc.InLeftPadW == 0 && conv_problem_desc.InRightPadH == 0 && + conv_problem_desc.InRightPadW == 0) + { + if(!((Ho * Wo) % src_vector_lengths[3] == 0)) + return false; + } + else if(conv_problem_desc.ConvStrideW == 1 && conv_problem_desc.InLeftPadW == 0 && + conv_problem_desc.InRightPadW == 0) + { + if(!(Wo % src_vector_lengths[3] == 0)) + return false; + } + else + { + if(!(src_vector_lengths[3] == 1)) + return false; + } + + // check Dst vectorization: {GN11, GK1} are LDS vector dims + if(dst_vector_lengths[4] == GK1) + { // vectorize on {GN11, GK1} + if(!(GN11 % dst_vector_lengths[3] == 0)) + return false; + } + else + { // vectorize on {GK1} only + if(!(dst_vector_lengths[3] == 1)) + return false; + + if(!(GK1 % dst_vector_lengths[4] == 0)) + return false; + } + } + + // check blockwise GEMM + { + const int BM10 = std::accumulate(compile_param.BM10BN10ThreadClusterBM10Xs.begin(), + compile_param.BM10BN10ThreadClusterBM10Xs.end(), + 1, + std::multiplies{}); + + const int BN10 = std::accumulate(compile_param.BM10BN10ThreadClusterBN10Xs.begin(), + compile_param.BM10BN10ThreadClusterBN10Xs.end(), + 1, + std::multiplies{}); + + if(!(compile_param.BlockSize == BM10 * BN10)) + return false; + + const int BM = GM0 * GM11; + const int BN = GN0 * GN11; + + const int BM1 = BM10 * BM11; + const int BN1 = BN10 * BN11; + + if(!(BM % BM1 == 0 && BN % BN1 == 0)) + return false; + + const int BM0 = BM / BM1; + const int BN0 = BN / BN1; + + // blockwise GEMM currently only support BM0 == 2 && BN0 == 2 + if(!(BM0 == 2 && BN0 == 2)) + return false; + + if(!(compile_param.GK0PerBlock % compile_param.BK0PerThread == 0)) + return false; + } + + // check C threadwise copy + { + // {BN11} or {BN} or {BN1} or {GN11} is Dst vector dim + const int dst_vector_len_gn11 = compile_param.CThreadTransferDstScalarPerVector; + + // check slice length vs Dst vector length: + if(!(BN11 % dst_vector_len_gn11 == 0 && GN11 % dst_vector_len_gn11 == 0)) + return false; + + // check Dst memory layout related vectorization: + if(!((Ho * Wo) % compile_param.CThreadTransferDstScalarPerVector == 0)) + return false; + } + + return true; + }; + + static int GetBlockSize(const ConvolutionProblemDescriptor&, + const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param) + { + return compile_param.BlockSize; + } + + static int GetGridSize(const ConvolutionProblemDescriptor& conv_problem_desc, + 
const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param) + { + const int N = conv_problem_desc.N; + const int K = conv_problem_desc.K; + const int Ho = conv_problem_desc.Ho; + const int Wo = conv_problem_desc.Wo; + + const int N0 = compile_param.GN0; + const int N1 = N / N0; + + const int GM1 = K; + const int GN1 = N1 * Ho * Wo; + + const int GM11 = compile_param.GM1PerBlockGM11; + const int GN11 = compile_param.GN1PerBlockGN11; + + const int GM10 = GM1 / GM11; + const int GN10 = GN1 / GN11; + + return GM10 * GN10; + } + + static std::size_t GetWorkSpaceSize(const ConvolutionProblemDescriptor&, + const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw&) + { + // workspace is used for save transformed tensor descritpors created by prepare kernel + return 4096L; + } + + static std::size_t GetMaxWorkSpaceSize(const ConvolutionProblemDescriptor&) { return 4096L; } + + static auto GetTunableList() + { + return generate_tunable_list_conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw(); + } +}; + +} // namespace ck_driver +#endif diff --git a/host/driver_online/include/conv_tunable_fwd_v4r4_dlops_nchw_kcyx_nkhw.hpp b/host/driver_online/include/conv_tunable_fwd_v4r4_dlops_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..58fe588ad9 --- /dev/null +++ b/host/driver_online/include/conv_tunable_fwd_v4r4_dlops_nchw_kcyx_nkhw.hpp @@ -0,0 +1,51 @@ +#ifndef CONV_TUNABLE_FWD_V4R4_DLOPS_NCHW_KCYX_NKHW_HPP +#define CONV_TUNABLE_FWD_V4R4_DLOPS_NCHW_KCYX_NKHW_HPP + +struct tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw +{ + int BlockSize; + + int MPerBlock; + int NPerBlock; + int KPerBlock; + + int M1PerThread; + int N1PerThread; + int KPerThread; + + int M1N1ThreadClusterM10; + int M1N1ThreadClusterN10; + int M1N1ThreadClusterM11; + int M1N1ThreadClusterN11; + + std::array ABlockTransferThreadSliceLengths_K_M0_M1; + std::array ABlockTransferThreadClusterLengths_K_M0_M1; + std::array ABlockTransferThreadClusterArrangeOrder; + std::array ABlockTransferSrcAccessOrder; + int ABlockTransferSrcVectorDim; + int ABlockTransferSrcScalarPerVector; + int ABlockTransferDstScalarPerVector_M1; + bool AThreadTransferSrcResetCoordinateAfterRun; + + std::array BBlockTransferThreadSliceLengths_K_N0_N1; + std::array BBlockTransferThreadClusterLengths_K_N0_N1; + std::array BBlockTransferThreadClusterArrangeOrder; + std::array BBlockTransferSrcAccessOrder; + int BBlockTransferSrcVectorDim; + int BBlockTransferSrcScalarPerVector; + int BBlockTransferDstScalarPerVector_N1; + bool BThreadTransferSrcResetCoordinateAfterRun; + + std::array CThreadTransferSrcDstAccessOrder; + int CThreadTransferSrcDstVectorDim; + int CThreadTransferDstScalarPerVector; +}; + +static tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw + default_tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw = { + 256, 128, 128, 8, 4, 4, 1, + 8, 8, 2, 2, {4, 1, 1}, {2, 1, 128}, {2, 1, 0}, + {2, 1, 0}, 0, 4, 1, false, {4, 1, 1}, {2, 1, 128}, + {0, 1, 2}, {0, 1, 2}, 2, 1, 1, false, {3, 4, 5, 0, 1, 2}, + 5, 1}; +#endif diff --git a/host/driver_online/include/conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_online/include/conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..97ce326346 --- /dev/null +++ b/host/driver_online/include/conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp @@ -0,0 +1,73 @@ +#ifndef CONV_TUNABLE_FWD_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP +#define CONV_TUNABLE_FWD_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP + +struct tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw +{ + int BlockSize; + + int MPerBlock; + int NPerBlock; + 
int KPerBlock; + + int MPerWave; + int NPerWave; + int K1; + + int MRepeat; + int NRepeat; + + std::array ABlockTransferThreadSliceLengths_K0_M_K1; + std::array ABlockTransferThreadClusterLengths_K0_M_K1; + std::array ABlockTransferThreadClusterArrangeOrder; + std::array ABlockTransferSrcAccessOrder; + int ABlockTransferSrcVectorDim; + int ABlockTransferSrcScalarPerVector; + int ABlockTransferDstScalarPerVector_K1; + bool AThreadTransferSrcResetCoordinateAfterRun; + + std::array BBlockTransferThreadSliceLengths_K0_N_K1; + std::array BBlockTransferThreadClusterLengths_K0_N_K1; + std::array BBlockTransferThreadClusterArrangeOrder; + std::array BBlockTransferSrcAccessOrder; + int BBlockTransferSrcVectorDim; + int BBlockTransferSrcScalarPerVector; + int BBlockTransferDstScalarPerVector_K1; + bool BThreadTransferSrcResetCoordinateAfterRun; + + std::array CThreadTransferSrcDstAccessOrder; + int CThreadTransferSrcDstVectorDim; + int CThreadTransferDstScalarPerVector; +}; + +static tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw + default_tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw = { + 256, // BlockSize + 128, // MPerBlock, + 128, // NPerBlock, + 4, // KPerBlock, + 32, // MPerWave, + 32, // NPerWave, + 4, // K1, + 2, // MRepeat, + 2, // NRepeat, + {1, 2, 4}, // ABlockTransferThreadSliceLengths_K0_M_K1, + {4, 64, 1}, // ABlockTransferThreadClusterLengths_K0_M_K1, + {1, 0, 2}, // ABlockTransferThreadClusterArrangeOrder, + {1, 0, 2}, // ABlockTransferSrcAccessOrder, + 2, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector, + 4, // ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + {1, 2, 4}, // BBlockTransferThreadSliceLengths_K0_N_K1, + {4, 64, 1}, // BBlockTransferThreadClusterLengths_K0_N_K1, + {0, 2, 1}, // BBlockTransferThreadClusterArrangeOrder, + {1, 0, 2}, // BBlockTransferSrcAccessOrder, + 1, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 4, // BBlockTransferDstScalarPerVector_K1 + false, // BThreadTransferSrcResetCoordinateAfterRun + {3, 0, 1, 2, 7, 5, 4, 6}, // CThreadTransferSrcDstAccessOrder + 7, // CThreadTransferSrcDstVectorDim, + 1 // CThreadTransferDstScalarPerVector +}; +#endif diff --git a/host/driver_online/include/conv_tunable_fwd_v4r4_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_online/include/conv_tunable_fwd_v4r4_xdlops_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..263c21a13b --- /dev/null +++ b/host/driver_online/include/conv_tunable_fwd_v4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,73 @@ +#ifndef CONV_TUNABLE_FWD_V4R4_XDLOPS_NHWC_KYXC_NHWK_HPP +#define CONV_TUNABLE_FWD_V4R4_XDLOPS_NHWC_KYXC_NHWK_HPP + +struct tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk +{ + int BlockSize; + + int MPerBlock; + int NPerBlock; + int KPerBlock; + + int MPerWave; + int NPerWave; + int K1; + + int MRepeat; + int NRepeat; + + std::array ABlockTransferThreadSliceLengths_K0_M_K1; + std::array ABlockTransferThreadClusterLengths_K0_M_K1; + std::array ABlockTransferThreadClusterArrangeOrder; + std::array ABlockTransferSrcAccessOrder; + int ABlockTransferSrcVectorDim; + int ABlockTransferSrcScalarPerVector; + int ABlockTransferDstScalarPerVector_K1; + bool AThreadTransferSrcResetCoordinateAfterRun; + + std::array BBlockTransferThreadSliceLengths_K0_N_K1; + std::array BBlockTransferThreadClusterLengths_K0_N_K1; + std::array BBlockTransferThreadClusterArrangeOrder; + std::array BBlockTransferSrcAccessOrder; + int BBlockTransferSrcVectorDim; + int BBlockTransferSrcScalarPerVector; + int 
BBlockTransferDstScalarPerVector_K1; + bool BThreadTransferSrcResetCoordinateAfterRun; + + std::array CThreadTransferSrcDstAccessOrder; + int CThreadTransferSrcDstVectorDim; + int CThreadTransferDstScalarPerVector; +}; + +static tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk + default_tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk = { + 256, // BlockSize + 128, // MPerBlock, + 128, // NPerBlock, + 4, // KPerBlock, + 32, // MPerWave, + 32, // NPerWave, + 4, // K1, + 2, // MRepeat, + 2, // NRepeat, + {1, 2, 4}, // ABlockTransferThreadSliceLengths_K0_M_K1, + {4, 64, 1}, // ABlockTransferThreadClusterLengths_K0_M_K1, + {1, 0, 2}, // ABlockTransferThreadClusterArrangeOrder, + {1, 0, 2}, // ABlockTransferSrcAccessOrder, + 2, // ABlockTransferSrcVectorDim + 4, // ABlockTransferSrcScalarPerVector, + 4, // ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + {1, 2, 4}, // BBlockTransferThreadSliceLengths_K0_N_K1, + {4, 64, 1}, // BBlockTransferThreadClusterLengths_K0_N_K1, + {1, 0, 2}, // BBlockTransferThreadClusterArrangeOrder, + {1, 0, 2}, // BBlockTransferSrcAccessOrder, + 2, // BBlockTransferSrcVectorDim + 4, // BBlockTransferSrcScalarPerVector + 4, // BBlockTransferDstScalarPerVector_K1 + false, // BThreadTransferSrcResetCoordinateAfterRun + {2, 3, 0, 1, 7, 5, 4, 6}, // CThreadTransferSrcDstAccessOrder + 7, // CThreadTransferSrcDstVectorDim, + 1 // CThreadTransferDstScalarPerVector +}; +#endif diff --git a/host/driver_online/include/convolution_problem_descriptor.hpp b/host/driver_online/include/convolution_problem_descriptor.hpp new file mode 100644 index 0000000000..df9c110e70 --- /dev/null +++ b/host/driver_online/include/convolution_problem_descriptor.hpp @@ -0,0 +1,79 @@ +#ifndef CONVOLUTION_PROBLEM_DESCRIPTOR +#define CONVOLUTION_PROBLEM_DESCRIPTOR + +namespace ck_driver { + +struct ConvolutionProblemDescriptor +{ + ConvolutionProblemDescriptor() = default; + + ConvolutionProblemDescriptor(int N_, + int K_, + int C_, + int Y_, + int X_, + int Hi_, + int Wi_, + int Ho_, + int Wo_, + int ConvStrideH_, + int ConvStrideW_, + int ConvDilationH_, + int ConvDilationW_, + int InLeftPadH_, + int InLeftPadW_, + int InRightPadH_, + int InRightPadW_, + ck::DataTypeEnum_t InDataTypeEnum_, + ck::DataTypeEnum_t WeiDataTypeEnum_, + ck::DataTypeEnum_t OutDataTypeEnum_) + : N{N_}, + K{K_}, + C{C_}, + Y{Y_}, + X{X_}, + Hi{Hi_}, + Wi{Wi_}, + Ho{Ho_}, + Wo{Wo_}, + ConvStrideH{ConvStrideH_}, + ConvStrideW{ConvStrideW_}, + ConvDilationH{ConvDilationH_}, + ConvDilationW{ConvDilationW_}, + InLeftPadH{InLeftPadH_}, + InLeftPadW{InLeftPadW_}, + InRightPadH{InRightPadH_}, + InRightPadW{InRightPadW_}, + InDataTypeEnum{InDataTypeEnum_}, + WeiDataTypeEnum{WeiDataTypeEnum_}, + OutDataTypeEnum{OutDataTypeEnum_} + { + } + + int N; + int K; + int C; + int Y; + int X; + int Hi; + int Wi; + int Ho; + int Wo; + int ConvStrideH; + int ConvStrideW; + int ConvDilationH; + int ConvDilationW; + int InLeftPadH; + int InLeftPadW; + int InRightPadH; + int InRightPadW; + + ck::DataTypeEnum_t InDataTypeEnum; + ck::DataTypeEnum_t WeiDataTypeEnum; + ck::DataTypeEnum_t OutDataTypeEnum; + + std::size_t CalculateFlop() const { return 2L * N * K * C * Y * X * Ho * Wo; } +}; + +} // namespace ck_driver +#endif diff --git a/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..628bb6d96d --- 
/dev/null +++ b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp @@ -0,0 +1,395 @@ +#pragma once +#include "device.hpp" +#include "host_tensor.hpp" +#include "handle.hpp" +#include "online_driver_common.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp" +#include "conv_tunable_fwd_v4r4_dlops_nchw_kcyx_nkhw.hpp" + +namespace detail_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw { + +template +static std::string get_network_config_string_from_types() +{ + using namespace ck; + + std::string out; + + out += std::to_string(get_datatype_enum_from_type::value) + "_" + + std::to_string(get_datatype_enum_from_type::value) + "_" + + std::to_string(get_datatype_enum_from_type::value); + + return (out); +}; + +static std::string +get_network_config_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw* pt) +{ + std::string out("TUN_"); + + out += std::to_string(pt->BlockSize) + "_"; + + out += std::to_string(pt->MPerBlock) + "x" + std::to_string(pt->NPerBlock) + "x" + + std::to_string(pt->KPerBlock) + "_"; + out += std::to_string(pt->M1PerThread) + "x" + std::to_string(pt->N1PerThread) + "x" + + std::to_string(pt->KPerThread) + "_"; + out += std::to_string(pt->M1N1ThreadClusterM10) + "x" + + std::to_string(pt->M1N1ThreadClusterN10) + "x" + + std::to_string(pt->M1N1ThreadClusterM11) + "x" + + std::to_string(pt->M1N1ThreadClusterN11) + "_"; + + out += std::to_string(pt->ABlockTransferThreadSliceLengths_K_M0_M1[0]) + "x" + + std::to_string(pt->ABlockTransferThreadSliceLengths_K_M0_M1[1]) + "x" + + std::to_string(pt->ABlockTransferThreadSliceLengths_K_M0_M1[2]) + "_"; + + out += std::to_string(pt->ABlockTransferThreadClusterLengths_K_M0_M1[0]) + "x" + + std::to_string(pt->ABlockTransferThreadClusterLengths_K_M0_M1[1]) + "x" + + std::to_string(pt->ABlockTransferThreadClusterLengths_K_M0_M1[2]) + "_"; + + out += std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "x" + + std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "x" + + std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]) + "_"; + + out += std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "x" + + std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "x" + + std::to_string(pt->ABlockTransferSrcAccessOrder[2]) + "_"; + + out += std::to_string(pt->ABlockTransferSrcVectorDim) + "_"; + out += std::to_string(pt->ABlockTransferSrcScalarPerVector) + "_"; + out += std::to_string(pt->ABlockTransferDstScalarPerVector_M1) + "_"; + out += std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun) + "_"; + + out += std::to_string(pt->BBlockTransferThreadSliceLengths_K_N0_N1[0]) + "x" + + std::to_string(pt->BBlockTransferThreadSliceLengths_K_N0_N1[1]) + "x" + + std::to_string(pt->BBlockTransferThreadSliceLengths_K_N0_N1[2]) + "_"; + + out += std::to_string(pt->BBlockTransferThreadClusterLengths_K_N0_N1[0]) + "x" + + std::to_string(pt->BBlockTransferThreadClusterLengths_K_N0_N1[1]) + "x" + + std::to_string(pt->BBlockTransferThreadClusterLengths_K_N0_N1[2]) + "_"; + + out += std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "x" + + std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "x" + + std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]) + "_"; + + out += std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "x" + + std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "x" + + 
std::to_string(pt->BBlockTransferSrcAccessOrder[2]) + "_"; + + out += std::to_string(pt->BBlockTransferSrcVectorDim) + "_"; + out += std::to_string(pt->BBlockTransferSrcScalarPerVector) + "_"; + out += std::to_string(pt->BBlockTransferDstScalarPerVector_N1) + "_"; + out += std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun) + "_"; + + out += std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "x" + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "x" + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "x" + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "x" + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "x" + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]) + "_"; + + out += std::to_string(pt->CThreadTransferSrcDstVectorDim) + "_"; + out += std::to_string(pt->CThreadTransferDstScalarPerVector); + + return (out); +}; + +template +static std::string get_definition_string_from_types() +{ + using namespace ck; + + std::string out; + + out += + " -DCK_PARAM_ABDataTypeEnum=" + std::to_string(get_datatype_enum_from_type::value) + + " -DCK_PARAM_AccDataTypeEnum=" + std::to_string(get_datatype_enum_from_type::value) + + " -DCK_PARAM_CDataTypeEnum=" + std::to_string(get_datatype_enum_from_type::value); + + return (out); +}; + +static std::string +get_definition_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw* pt) +{ + std::string out; + + out += " -DCK_PARAM_BlockSize=" + std::to_string(pt->BlockSize); + + out += " -DCK_PARAM_MPerBlock=" + std::to_string(pt->MPerBlock) + + " -DCK_PARAM_NPerBlock=" + std::to_string(pt->NPerBlock) + + " -DCK_PARAM_KPerBlock=" + std::to_string(pt->KPerBlock); + out += " -DCK_PARAM_M1PerThread=" + std::to_string(pt->M1PerThread) + + " -DCK_PARAM_N1PerThread=" + std::to_string(pt->N1PerThread) + + " -DCK_PARAM_KPerThread=" + std::to_string(pt->KPerThread); + + out += " -DCK_PARAM_M1N1ThreadClusterM10=" + std::to_string(pt->M1N1ThreadClusterM10) + + " -DCK_PARAM_M1N1ThreadClusterN10=" + std::to_string(pt->M1N1ThreadClusterN10) + + " -DCK_PARAM_M1N1ThreadClusterM11=" + std::to_string(pt->M1N1ThreadClusterM11) + + " -DCK_PARAM_M1N1ThreadClusterN11=" + std::to_string(pt->M1N1ThreadClusterN11); + + out += " -DCK_PARAM_ABlockTransferThreadSliceLengths_K_M0_M1=" + + std::to_string(pt->ABlockTransferThreadSliceLengths_K_M0_M1[0]) + "," + + std::to_string(pt->ABlockTransferThreadSliceLengths_K_M0_M1[1]) + "," + + std::to_string(pt->ABlockTransferThreadSliceLengths_K_M0_M1[2]); + + out += " -DCK_PARAM_ABlockTransferThreadClusterLengths_K_M0_M1=" + + std::to_string(pt->ABlockTransferThreadClusterLengths_K_M0_M1[0]) + "," + + std::to_string(pt->ABlockTransferThreadClusterLengths_K_M0_M1[1]) + "," + + std::to_string(pt->ABlockTransferThreadClusterLengths_K_M0_M1[2]); + + out += " -DCK_PARAM_ABlockTransferThreadClusterArrangeOrder=" + + std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "," + + std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "," + + std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]); + + out += " -DCK_PARAM_ABlockTransferSrcAccessOrder=" + + std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "," + + std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "," + + std::to_string(pt->ABlockTransferSrcAccessOrder[2]); + + out += + " -DCK_PARAM_ABlockTransferSrcVectorDim=" + std::to_string(pt->ABlockTransferSrcVectorDim); + out += " -DCK_PARAM_ABlockTransferSrcScalarPerVector=" + + 
std::to_string(pt->ABlockTransferSrcScalarPerVector); + out += " -DCK_PARAM_ABlockTransferDstScalarPerVector_M1=" + + std::to_string(pt->ABlockTransferDstScalarPerVector_M1); + out += " -DCK_PARAM_AThreadTransferSrcResetCoordinateAfterRun=" + + std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun); + + out += " -DCK_PARAM_BBlockTransferThreadSliceLengths_K_N0_N1=" + + std::to_string(pt->BBlockTransferThreadSliceLengths_K_N0_N1[0]) + "," + + std::to_string(pt->BBlockTransferThreadSliceLengths_K_N0_N1[1]) + "," + + std::to_string(pt->BBlockTransferThreadSliceLengths_K_N0_N1[2]); + + out += " -DCK_PARAM_BBlockTransferThreadClusterLengths_K_N0_N1=" + + std::to_string(pt->BBlockTransferThreadClusterLengths_K_N0_N1[0]) + "," + + std::to_string(pt->BBlockTransferThreadClusterLengths_K_N0_N1[1]) + "," + + std::to_string(pt->BBlockTransferThreadClusterLengths_K_N0_N1[2]); + + out += " -DCK_PARAM_BBlockTransferThreadClusterArrangeOrder=" + + std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "," + + std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "," + + std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]); + + out += " -DCK_PARAM_BBlockTransferSrcAccessOrder=" + + std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "," + + std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "," + + std::to_string(pt->BBlockTransferSrcAccessOrder[2]); + + out += + " -DCK_PARAM_BBlockTransferSrcVectorDim=" + std::to_string(pt->BBlockTransferSrcVectorDim); + out += " -DCK_PARAM_BBlockTransferSrcScalarPerVector=" + + std::to_string(pt->BBlockTransferSrcScalarPerVector); + out += " -DCK_PARAM_BBlockTransferDstScalarPerVector_N1=" + + std::to_string(pt->BBlockTransferDstScalarPerVector_N1); + out += " -DCK_PARAM_BThreadTransferSrcResetCoordinateAfterRun=" + + std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun); + + out += " -DCK_PARAM_CThreadTransferSrcDstAccessOrder=" + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "," + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "," + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "," + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "," + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "," + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]); + + out += " -DCK_PARAM_CThreadTransferSrcDstVectorDim=" + + std::to_string(pt->CThreadTransferSrcDstVectorDim); + out += " -DCK_PARAM_CThreadTransferDstScalarPerVector=" + + std::to_string(pt->CThreadTransferDstScalarPerVector); + + return (out); +}; + +} // namespace detail_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw + +template +void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( + olCompile::Handle* handle, + const InLengths& in_n_c_hi_wi_lengths, + const WeiLengths& wei_k_c_y_x_lengths, + const OutLengths& out_n_k_ho_wo_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + const tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw* tunable, + ck::index_t nrepeat) +{ + using namespace ck; + using namespace ck_driver; + using namespace detail_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw; + using size_t = std::size_t; + + ///////////////////////////////////////////////////////////////////////////////////////////////////////////// + // The follow codes are only used for computing the grid_size, hasMainKBlockLoop, + // 
hasDoubleTailKBlockLoop + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto in_n_c_hi_wi_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths); + const auto wei_k_c_y_x_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths); + const auto out_n_k_ho_wo_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths); + + const auto descs = + transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads); + const auto a_k_m_grid_desc = descs[I0]; + const auto c_m_n_grid_desc = descs[I2]; + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + const auto K = a_k_m_grid_desc.GetLength(I0); + + const index_t grid_size = (M / tunable->MPerBlock) * (N / tunable->NPerBlock); + const bool hasMainKBlockLoop = ((K + tunable->KPerBlock) / (2 * tunable->KPerBlock) > 1); + const bool hasDoubleTailKBlockLoop = ((K / tunable->KPerBlock) % 2 == 0); + ///////////////////////////////////////////////////////////////////////////////////////////////////////////// + + // these buffers are usually provided by the user application + DeviceMem in_n_c_hi_wi_dev_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_k_c_y_x_dev_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_n_k_ho_wo_dev_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + in_n_c_hi_wi_dev_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_k_c_y_x_dev_buf.ToDevice(wei_k_c_y_x.mData.data()); + out_n_k_ho_wo_dev_buf.ToDevice(out_n_k_ho_wo.mData.data()); + + // these are workspace buffers that should be expressed to the user by the corresponding + // workspace API + DeviceMem workspace_buf(4096); + + void* a_k_m0_m1_grid_desc_dev_buf = workspace_buf.GetDeviceBuffer(); + void* b_k_n0_n1_grid_desc_dev_buf = + static_cast(static_cast(workspace_buf.GetDeviceBuffer()) + 1024); + void* c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf = + static_cast(static_cast(workspace_buf.GetDeviceBuffer()) + 2048); + void* c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf = + static_cast(static_cast(workspace_buf.GetDeviceBuffer()) + 3072); + + const std::vector vld = {static_cast(tunable->BlockSize), 1, 1}; + const std::vector vgd1 = {static_cast(tunable->BlockSize), 1, 1}; + const std::vector vgd2 = {static_cast(grid_size * tunable->BlockSize), 1, 1}; + + std::string program_name = + "dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp"; + std::string algo_name = "implicit_gemm_conv_fwd_v4r4_dlops_nchw"; + + std::string param = " -std=c++17 "; + std::string network_config; + + param += get_definition_string_from_types() + " " + + get_definition_string_from_tunable(tunable) + + " -DCK_PARAM_HAS_MAIN_KBLOCK_LOOP=" + std::to_string(hasMainKBlockLoop) + + " -DCK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP=" + std::to_string(hasDoubleTailKBlockLoop); + network_config = get_network_config_string_from_types() + "_" + + get_network_config_string_from_tunable(tunable) + "_" + + std::to_string(hasMainKBlockLoop) + "_" + + std::to_string(hasDoubleTailKBlockLoop); + + std::vector kernel1_times; + std::vector kernel2_times; + + for(index_t i = 0; i < nrepeat; ++i) + { + KernelTimer timer1, timer2; + std::string kernel_name; + + kernel_name = 
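+        // A quick sanity check of the K-loop flags computed above, with purely
+        // illustrative numbers (not taken from any tunable in this patch):
+        // assume KPerBlock = 8 and a GEMM reduction length K = C*Y*X = 32, then
+        //   hasMainKBlockLoop       = ((32 + 8) / (2 * 8) > 1) = (2 > 1)  -> true
+        //   hasDoubleTailKBlockLoop = ((32 / 8) % 2 == 0)      = (0 == 0) -> true
+        // which the kernel uses to decide whether to run its main K loop and how
+        // to finish the final K block(s).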
"dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw_prepare"; + auto network_config_1 = network_config + "_1"; + + timer1.Start(); + handle->AddKernel(algo_name, network_config_1, program_name, kernel_name, vld, vgd1, param)( + static_cast(in_n_c_hi_wi_lengths[I0]), + static_cast(in_n_c_hi_wi_lengths[I1]), + static_cast(in_n_c_hi_wi_lengths[I2]), + static_cast(in_n_c_hi_wi_lengths[I3]), + static_cast(wei_k_c_y_x_lengths[I0]), + static_cast(wei_k_c_y_x_lengths[I2]), + static_cast(wei_k_c_y_x_lengths[I3]), + conv_strides[I0], + conv_strides[I1], + conv_dilations[I0], + conv_dilations[I1], + in_left_pads[I0], + in_left_pads[I1], + in_right_pads[I0], + in_right_pads[I1], + a_k_m0_m1_grid_desc_dev_buf, + b_k_n0_n1_grid_desc_dev_buf, + c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf, + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf); + timer1.End(); + + kernel_name = "dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw"; + auto network_config_2 = network_config + "_2"; + + timer2.Start(); + handle->AddKernel(algo_name, network_config_2, program_name, kernel_name, vld, vgd2, param)( + reinterpret_cast(wei_k_c_y_x_dev_buf.GetDeviceBuffer()), + reinterpret_cast(in_n_c_hi_wi_dev_buf.GetDeviceBuffer()), + reinterpret_cast(out_n_k_ho_wo_dev_buf.GetDeviceBuffer()), + (const void*)(a_k_m0_m1_grid_desc_dev_buf), + (const void*)(b_k_n0_n1_grid_desc_dev_buf), + (const void*)(c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf), + (const void*)(c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf)); + timer2.End(); + + kernel1_times.push_back(timer1.GetElapsedTime()); + kernel2_times.push_back(timer2.GetElapsedTime()); + } + + { + auto ave_time1 = + std::accumulate( + std::next(kernel1_times.begin()), kernel1_times.end(), 0., std::plus{}) / + (nrepeat - 1); + auto ave_time2 = + std::accumulate( + std::next(kernel2_times.begin()), kernel2_times.end(), 0., std::plus{}) / + (nrepeat - 1); + + const auto N = in_n_c_hi_wi_lengths[I0]; + const auto C = in_n_c_hi_wi_lengths[I1]; + + const auto K = out_n_k_ho_wo_lengths[I1]; + const auto Ho = out_n_k_ho_wo_lengths[I2]; + const auto Wo = out_n_k_ho_wo_lengths[I3]; + + const auto Y = wei_k_c_y_x_lengths[I2]; + const auto X = wei_k_c_y_x_lengths[I3]; + + float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / + (std::size_t(1000) * 1000 * 1000) / (ave_time1 + ave_time2); + + std::cout << "Average time : " << ave_time1 + ave_time2 << " ms(" << ave_time1 << ", " + << ave_time2 << "), " << perf << " TFlop/s" << std::endl; + }; + + // copy result back to host + out_n_k_ho_wo_dev_buf.FromDevice(out_n_k_ho_wo.mData.data()); +} diff --git a/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..1e213b92e1 --- /dev/null +++ b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp @@ -0,0 +1,386 @@ +#include "device.hpp" +#include "host_tensor.hpp" +#include "handle.hpp" +#include "online_driver_common.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp" + +namespace detail_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw { + +template +static std::string get_network_config_string_from_types() +{ + using namespace ck; + + std::string out; + + out += 
std::to_string(get_datatype_enum_from_type::value) + "_" + + std::to_string(get_datatype_enum_from_type::value) + "_" + + std::to_string(get_datatype_enum_from_type::value); + + return (out); +}; + +static std::string +get_network_config_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw* pt) +{ + std::string out("TUN_"); + + out += std::to_string(pt->BlockSize) + "_"; + + out += std::to_string(pt->MPerBlock) + "x" + std::to_string(pt->NPerBlock) + "x" + + std::to_string(pt->KPerBlock) + "_"; + out += std::to_string(pt->MPerWave) + "x" + std::to_string(pt->NPerWave) + "x" + + std::to_string(pt->MRepeat) + "x" + std::to_string(pt->NRepeat) + "x" + + std::to_string(pt->K1) + "_"; + + out += std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[0]) + "x" + + std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[1]) + "x" + + std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[2]) + "_"; + + out += std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[0]) + "x" + + std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[1]) + "x" + + std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[2]) + "_"; + + out += std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "x" + + std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "x" + + std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]) + "_"; + + out += std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "x" + + std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "x" + + std::to_string(pt->ABlockTransferSrcAccessOrder[2]) + "_"; + + out += std::to_string(pt->ABlockTransferSrcVectorDim) + "_"; + out += std::to_string(pt->ABlockTransferSrcScalarPerVector) + "_"; + out += std::to_string(pt->ABlockTransferDstScalarPerVector_K1) + "_"; + out += std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun) + "_"; + + out += std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[0]) + "x" + + std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[1]) + "x" + + std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[2]) + "_"; + + out += std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[0]) + "x" + + std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[1]) + "x" + + std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[2]) + "_"; + + out += std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "x" + + std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "x" + + std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]) + "_"; + + out += std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "x" + + std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "x" + + std::to_string(pt->BBlockTransferSrcAccessOrder[2]) + "_"; + + out += std::to_string(pt->BBlockTransferSrcVectorDim) + "_"; + out += std::to_string(pt->BBlockTransferSrcScalarPerVector) + "_"; + out += std::to_string(pt->BBlockTransferDstScalarPerVector_K1) + "_"; + out += std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun) + "_"; + + out += std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "x" + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "x" + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "x" + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "x" + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "x" + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]) + "x" + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[6]) + "x" + + 
std::to_string(pt->CThreadTransferSrcDstAccessOrder[7]) + "_"; + + out += std::to_string(pt->CThreadTransferSrcDstVectorDim) + "_"; + out += std::to_string(pt->CThreadTransferDstScalarPerVector); + + return (out); +}; + +template +static std::string get_definition_string_from_types() +{ + using namespace ck; + + std::string out; + + out += + " -DCK_PARAM_ABDataTypeEnum=" + std::to_string(get_datatype_enum_from_type::value) + + " -DCK_PARAM_AccDataTypeEnum=" + std::to_string(get_datatype_enum_from_type::value) + + " -DCK_PARAM_CDataTypeEnum=" + std::to_string(get_datatype_enum_from_type::value); + + return (out); +}; + +static std::string +get_definition_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw* pt) +{ + std::string out; + + out += " -DCK_PARAM_BlockSize=" + std::to_string(pt->BlockSize); + + out += " -DCK_PARAM_MPerBlock=" + std::to_string(pt->MPerBlock) + + " -DCK_PARAM_NPerBlock=" + std::to_string(pt->NPerBlock) + + " -DCK_PARAM_KPerBlock=" + std::to_string(pt->KPerBlock); + out += " -DCK_PARAM_MPerWave=" + std::to_string(pt->MPerWave) + + " -DCK_PARAM_NPerWave=" + std::to_string(pt->NPerWave) + + " -DCK_PARAM_K1=" + std::to_string(pt->K1) + + " -DCK_PARAM_MRepeat=" + std::to_string(pt->MRepeat) + + " -DCK_PARAM_NRepeat=" + std::to_string(pt->NRepeat); + + out += " -DCK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1=" + + std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[0]) + "," + + std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[1]) + "," + + std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[2]); + + out += " -DCK_PARAM_ABlockTransferThreadClusterLengths_K0_M_K1=" + + std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[0]) + "," + + std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[1]) + "," + + std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[2]); + + out += " -DCK_PARAM_ABlockTransferThreadClusterArrangeOrder=" + + std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "," + + std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "," + + std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]); + + out += " -DCK_PARAM_ABlockTransferSrcAccessOrder=" + + std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "," + + std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "," + + std::to_string(pt->ABlockTransferSrcAccessOrder[2]); + + out += + " -DCK_PARAM_ABlockTransferSrcVectorDim=" + std::to_string(pt->ABlockTransferSrcVectorDim); + out += " -DCK_PARAM_ABlockTransferSrcScalarPerVector=" + + std::to_string(pt->ABlockTransferSrcScalarPerVector); + out += " -DCK_PARAM_ABlockTransferDstScalarPerVector_K1=" + + std::to_string(pt->ABlockTransferDstScalarPerVector_K1); + out += " -DCK_PARAM_AThreadTransferSrcResetCoordinateAfterRun=" + + std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun); + + out += " -DCK_PARAM_BBlockTransferThreadSliceLengths_K0_N_K1=" + + std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[0]) + "," + + std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[1]) + "," + + std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[2]); + + out += " -DCK_PARAM_BBlockTransferThreadClusterLengths_K0_N_K1=" + + std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[0]) + "," + + std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[1]) + "," + + std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[2]); + + out += " -DCK_PARAM_BBlockTransferThreadClusterArrangeOrder=" + + 
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "," + + std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "," + + std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]); + + out += " -DCK_PARAM_BBlockTransferSrcAccessOrder=" + + std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "," + + std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "," + + std::to_string(pt->BBlockTransferSrcAccessOrder[2]); + + out += + " -DCK_PARAM_BBlockTransferSrcVectorDim=" + std::to_string(pt->BBlockTransferSrcVectorDim); + out += " -DCK_PARAM_BBlockTransferSrcScalarPerVector=" + + std::to_string(pt->BBlockTransferSrcScalarPerVector); + out += " -DCK_PARAM_BBlockTransferDstScalarPerVector_K1=" + + std::to_string(pt->BBlockTransferDstScalarPerVector_K1); + out += " -DCK_PARAM_BThreadTransferSrcResetCoordinateAfterRun=" + + std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun); + + out += " -DCK_PARAM_CThreadTransferSrcDstAccessOrder=" + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "," + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "," + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "," + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "," + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "," + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]) + "," + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[6]) + "," + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[7]); + + out += " -DCK_PARAM_CThreadTransferSrcDstVectorDim=" + + std::to_string(pt->CThreadTransferSrcDstVectorDim); + out += " -DCK_PARAM_CThreadTransferDstScalarPerVector=" + + std::to_string(pt->CThreadTransferDstScalarPerVector); + + return (out); +}; + +} // namespace detail_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw + +template +void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw( + olCompile::Handle* handle, + const InLengths& in_n_c_hi_wi_lengths, + const WeiLengths& wei_k_c_y_x_lengths, + const OutLengths& out_n_k_ho_wo_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + const tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw* tunable, + ck::index_t nrepeat) +{ + using namespace ck; + using namespace ck_driver; + using namespace detail_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw; + using size_t = std::size_t; + + ///////////////////////////////////////////////////////////////////////////////////////////////////////////// + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto in_n_c_hi_wi_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths); + const auto wei_k_c_y_x_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths); + const auto out_n_k_ho_wo_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths); + + const auto n = in_n_c_hi_wi_desc.GetLength(I0); + const auto c = in_n_c_hi_wi_desc.GetLength(I1); + const auto hi = in_n_c_hi_wi_desc.GetLength(I2); + const auto wi = in_n_c_hi_wi_desc.GetLength(I3); + const auto k = wei_k_c_y_x_desc.GetLength(I0); + const auto y = wei_k_c_y_x_desc.GetLength(I2); + const auto x = wei_k_c_y_x_desc.GetLength(I3); + const auto ho = out_n_k_ho_wo_desc.GetLength(I2); + const auto wo = 
out_n_k_ho_wo_desc.GetLength(I3); + + const auto M = k; + const auto N = n * ho * wo; + const auto K = c * y * x; + const auto K0 = K / tunable->K1; + + const index_t grid_size = (M / tunable->MPerBlock) * (N / tunable->NPerBlock); + ///////////////////////////////////////////////////////////////////////////////////////////////////////////// + + // these buffers are usually provided by the user application + DeviceMem in_n_c_hi_wi_dev_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_k_c_y_x_dev_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_n_k_ho_wo_dev_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + in_n_c_hi_wi_dev_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_k_c_y_x_dev_buf.ToDevice(wei_k_c_y_x.mData.data()); + out_n_k_ho_wo_dev_buf.ToDevice(out_n_k_ho_wo.mData.data()); + + // these are workspace buffers that should be expressed to the user by the corresponding + // workspace API + DeviceMem workspace_buf(4096); + + void* a_k_m0_m1_grid_desc_dev_buf = workspace_buf.GetDeviceBuffer(); + void* b_k_n0_n1_grid_desc_dev_buf = + static_cast(static_cast(workspace_buf.GetDeviceBuffer()) + 1024); + void* c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf = + static_cast(static_cast(workspace_buf.GetDeviceBuffer()) + 2048); + void* c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf = + static_cast(static_cast(workspace_buf.GetDeviceBuffer()) + 3072); + + const std::vector vld = {static_cast(tunable->BlockSize), 1, 1}; + const std::vector vgd1 = {static_cast(tunable->BlockSize), 1, 1}; + const std::vector vgd2 = {static_cast(grid_size * tunable->BlockSize), 1, 1}; + + std::string program_name = + "dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp"; + std::string algo_name = "implicit_gemm_conv_fwd_v4r4_xdlops_nchw"; + + std::string param = " -std=c++17 "; + std::string network_config; + + param += get_definition_string_from_types() + " " + " -DCK_USE_AMD_XDLOPS" + + get_definition_string_from_tunable(tunable); + + network_config = get_network_config_string_from_types() + "_" + + get_network_config_string_from_tunable(tunable); + + std::vector kernel1_times; + std::vector kernel2_times; + + for(index_t i = 0; i < nrepeat; ++i) + { + KernelTimer timer1, timer2; + std::string kernel_name; + + kernel_name = + "dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_prepare"; + auto network_config_1 = network_config + "_1"; + + timer1.Start(); + handle->AddKernel(algo_name, network_config_1, program_name, kernel_name, vld, vgd1, param)( + static_cast(in_n_c_hi_wi_lengths[I0]), + static_cast(in_n_c_hi_wi_lengths[I1]), + static_cast(in_n_c_hi_wi_lengths[I2]), + static_cast(in_n_c_hi_wi_lengths[I3]), + static_cast(wei_k_c_y_x_lengths[I0]), + static_cast(wei_k_c_y_x_lengths[I2]), + static_cast(wei_k_c_y_x_lengths[I3]), + conv_strides[I0], + conv_strides[I1], + conv_dilations[I0], + conv_dilations[I1], + in_left_pads[I0], + in_left_pads[I1], + in_right_pads[I0], + in_right_pads[I1], + a_k_m0_m1_grid_desc_dev_buf, + b_k_n0_n1_grid_desc_dev_buf, + c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf, + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf); + timer1.End(); + + kernel_name = "dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw"; + auto network_config_2 = network_config + "_2"; + + timer2.Start(); + handle->AddKernel(algo_name, network_config_2, program_name, kernel_name, vld, vgd2, param)( + reinterpret_cast(wei_k_c_y_x_dev_buf.GetDeviceBuffer()), + 
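+            // Illustration of the implicit-GEMM problem sizes derived above
+            // (M = k, N = n*ho*wo, K = c*y*x, K0 = K / K1); the numbers below are
+            // hypothetical and only meant to show the mapping:
+            //   n = 2, c = 64, k = 256, 3x3 filter, ho = wo = 16, K1 = 4
+            //   => M = 256, N = 2*16*16 = 512, K = 64*3*3 = 576, K0 = 576/4 = 144
+            //   and with MPerBlock = NPerBlock = 128:
+            //   grid_size = (256/128) * (512/128) = 2 * 4 = 8 workgroups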
reinterpret_cast(in_n_c_hi_wi_dev_buf.GetDeviceBuffer()), + reinterpret_cast(out_n_k_ho_wo_dev_buf.GetDeviceBuffer()), + (const void*)(a_k_m0_m1_grid_desc_dev_buf), + (const void*)(b_k_n0_n1_grid_desc_dev_buf), + (const void*)(c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf), + (const void*)(c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf)); + timer2.End(); + + kernel1_times.push_back(timer1.GetElapsedTime()); + kernel2_times.push_back(timer2.GetElapsedTime()); + } + + { + auto ave_time1 = + std::accumulate( + std::next(kernel1_times.begin()), kernel1_times.end(), 0., std::plus{}) / + (nrepeat - 1); + auto ave_time2 = + std::accumulate( + std::next(kernel2_times.begin()), kernel2_times.end(), 0., std::plus{}) / + (nrepeat - 1); + + const auto N = in_n_c_hi_wi_lengths[I0]; + const auto C = in_n_c_hi_wi_lengths[I1]; + + const auto K = out_n_k_ho_wo_lengths[I1]; + const auto Ho = out_n_k_ho_wo_lengths[I2]; + const auto Wo = out_n_k_ho_wo_lengths[I3]; + + const auto Y = wei_k_c_y_x_lengths[I2]; + const auto X = wei_k_c_y_x_lengths[I3]; + + float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / + (std::size_t(1000) * 1000 * 1000) / (ave_time1 + ave_time2); + + std::cout << "Average time : " << ave_time1 + ave_time2 << " ms(" << ave_time1 << ", " + << ave_time2 << "), " << perf << " TFlop/s" << std::endl; + }; + + // copy result back to host + out_n_k_ho_wo_dev_buf.FromDevice(out_n_k_ho_wo.mData.data()); +} diff --git a/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..8eed1a9934 --- /dev/null +++ b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,389 @@ +#include "device.hpp" +#include "host_tensor.hpp" +#include "handle.hpp" +#include "online_driver_common.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp" +#include "conv_tunable_fwd_v4r4_xdlops_nhwc_kyxc_nhwk.hpp" + +namespace detail_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk { + +template +static std::string get_network_config_string_from_types() +{ + using namespace ck; + + std::string out; + + out += std::to_string(get_datatype_enum_from_type::value) + "_" + + std::to_string(get_datatype_enum_from_type::value) + "_" + + std::to_string(get_datatype_enum_from_type::value); + + return (out); +}; + +static std::string +get_network_config_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk* pt) +{ + std::string out("TUN_"); + + out += std::to_string(pt->BlockSize) + "_"; + + out += std::to_string(pt->MPerBlock) + "x" + std::to_string(pt->NPerBlock) + "x" + + std::to_string(pt->KPerBlock) + "_"; + out += std::to_string(pt->MPerWave) + "x" + std::to_string(pt->NPerWave) + "x" + + std::to_string(pt->MRepeat) + "x" + std::to_string(pt->NRepeat) + "x" + + std::to_string(pt->K1) + "_"; + + out += std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[0]) + "x" + + std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[1]) + "x" + + std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[2]) + "_"; + + out += std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[0]) + "x" + + std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[1]) + "x" + + 
std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[2]) + "_"; + + out += std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "x" + + std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "x" + + std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]) + "_"; + + out += std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "x" + + std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "x" + + std::to_string(pt->ABlockTransferSrcAccessOrder[2]) + "_"; + + out += std::to_string(pt->ABlockTransferSrcVectorDim) + "_"; + out += std::to_string(pt->ABlockTransferSrcScalarPerVector) + "_"; + out += std::to_string(pt->ABlockTransferDstScalarPerVector_K1) + "_"; + out += std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun) + "_"; + + out += std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[0]) + "x" + + std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[1]) + "x" + + std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[2]) + "_"; + + out += std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[0]) + "x" + + std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[1]) + "x" + + std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[2]) + "_"; + + out += std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "x" + + std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "x" + + std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]) + "_"; + + out += std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "x" + + std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "x" + + std::to_string(pt->BBlockTransferSrcAccessOrder[2]) + "_"; + + out += std::to_string(pt->BBlockTransferSrcVectorDim) + "_"; + out += std::to_string(pt->BBlockTransferSrcScalarPerVector) + "_"; + out += std::to_string(pt->BBlockTransferDstScalarPerVector_K1) + "_"; + out += std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun) + "_"; + + out += std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "x" + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "x" + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "x" + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "x" + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "x" + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]) + "x" + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[6]) + "x" + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[7]) + "_"; + + out += std::to_string(pt->CThreadTransferSrcDstVectorDim) + "_"; + out += std::to_string(pt->CThreadTransferDstScalarPerVector); + + return (out); +}; + +template +static std::string get_definition_string_from_types() +{ + using namespace ck; + + std::string out; + + out += + " -DCK_PARAM_ABDataTypeEnum=" + std::to_string(get_datatype_enum_from_type::value) + + " -DCK_PARAM_AccDataTypeEnum=" + std::to_string(get_datatype_enum_from_type::value) + + " -DCK_PARAM_CDataTypeEnum=" + std::to_string(get_datatype_enum_from_type::value); + + return (out); +}; + +static std::string +get_definition_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk* pt) +{ + std::string out; + + out += " -DCK_PARAM_BlockSize=" + std::to_string(pt->BlockSize); + + out += " -DCK_PARAM_MPerBlock=" + std::to_string(pt->MPerBlock) + + " -DCK_PARAM_NPerBlock=" + std::to_string(pt->NPerBlock) + + " -DCK_PARAM_KPerBlock=" + std::to_string(pt->KPerBlock); + out += " -DCK_PARAM_MPerWave=" + std::to_string(pt->MPerWave) + + " -DCK_PARAM_NPerWave=" + 
std::to_string(pt->NPerWave) + + " -DCK_PARAM_K1=" + std::to_string(pt->K1) + + " -DCK_PARAM_MRepeat=" + std::to_string(pt->MRepeat) + + " -DCK_PARAM_NRepeat=" + std::to_string(pt->NRepeat); + + out += " -DCK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1=" + + std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[0]) + "," + + std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[1]) + "," + + std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[2]); + + out += " -DCK_PARAM_ABlockTransferThreadClusterLengths_K0_M_K1=" + + std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[0]) + "," + + std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[1]) + "," + + std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[2]); + + out += " -DCK_PARAM_ABlockTransferThreadClusterArrangeOrder=" + + std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "," + + std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "," + + std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]); + + out += " -DCK_PARAM_ABlockTransferSrcAccessOrder=" + + std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "," + + std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "," + + std::to_string(pt->ABlockTransferSrcAccessOrder[2]); + + out += + " -DCK_PARAM_ABlockTransferSrcVectorDim=" + std::to_string(pt->ABlockTransferSrcVectorDim); + out += " -DCK_PARAM_ABlockTransferSrcScalarPerVector=" + + std::to_string(pt->ABlockTransferSrcScalarPerVector); + out += " -DCK_PARAM_ABlockTransferDstScalarPerVector_K1=" + + std::to_string(pt->ABlockTransferDstScalarPerVector_K1); + out += " -DCK_PARAM_AThreadTransferSrcResetCoordinateAfterRun=" + + std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun); + + out += " -DCK_PARAM_BBlockTransferThreadSliceLengths_K0_N_K1=" + + std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[0]) + "," + + std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[1]) + "," + + std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[2]); + + out += " -DCK_PARAM_BBlockTransferThreadClusterLengths_K0_N_K1=" + + std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[0]) + "," + + std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[1]) + "," + + std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[2]); + + out += " -DCK_PARAM_BBlockTransferThreadClusterArrangeOrder=" + + std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "," + + std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "," + + std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]); + + out += " -DCK_PARAM_BBlockTransferSrcAccessOrder=" + + std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "," + + std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "," + + std::to_string(pt->BBlockTransferSrcAccessOrder[2]); + + out += + " -DCK_PARAM_BBlockTransferSrcVectorDim=" + std::to_string(pt->BBlockTransferSrcVectorDim); + out += " -DCK_PARAM_BBlockTransferSrcScalarPerVector=" + + std::to_string(pt->BBlockTransferSrcScalarPerVector); + out += " -DCK_PARAM_BBlockTransferDstScalarPerVector_K1=" + + std::to_string(pt->BBlockTransferDstScalarPerVector_K1); + out += " -DCK_PARAM_BThreadTransferSrcResetCoordinateAfterRun=" + + std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun); + + out += " -DCK_PARAM_CThreadTransferSrcDstAccessOrder=" + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "," + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "," + + 
std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "," + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "," + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "," + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]) + "," + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[6]) + "," + + std::to_string(pt->CThreadTransferSrcDstAccessOrder[7]); + + out += " -DCK_PARAM_CThreadTransferSrcDstVectorDim=" + + std::to_string(pt->CThreadTransferSrcDstVectorDim); + out += " -DCK_PARAM_CThreadTransferDstScalarPerVector=" + + std::to_string(pt->CThreadTransferDstScalarPerVector); + + return (out); +}; + +} // namespace detail_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk + +template +void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk( + olCompile::Handle* handle, + const InLengths& in_n_hi_wi_c_lengths, + const WeiLengths& wei_k_y_x_c_lengths, + const OutLengths& out_n_ho_wo_k_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_hi_wi_c, + const Tensor& wei_k_y_x_c, + Tensor& out_n_ho_wo_k, + const tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk* tunable, + ck::index_t nrepeat) +{ + using namespace ck; + using namespace detail_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk; + using size_t = std::size_t; + + ///////////////////////////////////////////////////////////////////////////////////////////////////////////// + // The follow codes are only used for computing the grid_size, hasMainKBlockLoop, + // hasDoubleTailKBlockLoop + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto in_n_hi_wi_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths); + + const auto n = in_n_hi_wi_c_desc.GetLength(I0); + const auto hi = in_n_hi_wi_c_desc.GetLength(I1); + const auto wi = in_n_hi_wi_c_desc.GetLength(I2); + const auto c = in_n_hi_wi_c_desc.GetLength(I3); + + const auto k = wei_k_y_x_c_desc.GetLength(I0); + const auto y = wei_k_y_x_c_desc.GetLength(I1); + const auto x = wei_k_y_x_c_desc.GetLength(I2); + + const auto ho = out_n_ho_wo_k_desc.GetLength(I1); + const auto wo = out_n_ho_wo_k_desc.GetLength(I2); + + const auto M = k; + const auto N = n * ho * wo; + const auto K = c * y * x; + const auto K0 = K / tunable->K1; + + const index_t grid_size = (M / tunable->MPerBlock) * (N / tunable->NPerBlock); + + // these buffers are usually provided by the user application + DeviceMem in_n_hi_wi_c_dev_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); + DeviceMem wei_k_y_x_c_dev_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); + DeviceMem out_n_ho_wo_k_dev_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); + + in_n_hi_wi_c_dev_buf.ToDevice(in_n_hi_wi_c.mData.data()); + wei_k_y_x_c_dev_buf.ToDevice(wei_k_y_x_c.mData.data()); + out_n_ho_wo_k_dev_buf.ToDevice(out_n_ho_wo_k.mData.data()); + + // these are workspace buffers that should be expressed to the user by the corresponding + // workspace API + DeviceMem workspace_buf(4096); + + void* a_k0_m_k1_grid_desc_dev_buf = workspace_buf.GetDeviceBuffer(); + void* b_k0_n_k1_grid_desc_dev_buf = + 
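+    // The 4096-byte workspace above is carved into four 1024-byte regions
+    // (offsets 0 / 1024 / 2048 / 3072), one per transformed descriptor written by
+    // the *_prepare kernel; the implicit assumption is that each packed descriptor
+    // fits within 1 KiB, which this driver does not check.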
static_cast(static_cast(workspace_buf.GetDeviceBuffer()) + 1024); + void* c_m0_m1_m2_n_grid_desc_dev_buf = + static_cast(static_cast(workspace_buf.GetDeviceBuffer()) + 2048); + void* c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf = + static_cast(static_cast(workspace_buf.GetDeviceBuffer()) + 3072); + + const std::vector vld = {static_cast(tunable->BlockSize), 1, 1}; + const std::vector vgd1 = {static_cast(tunable->BlockSize), 1, 1}; + const std::vector vgd2 = {static_cast(grid_size * tunable->BlockSize), 1, 1}; + + std::string program_name = + "dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp"; + std::string algo_name = "implicit_gemm_conv_fwd_v4r4_xdlops_nhwc"; + + std::string param = " -std=c++17 "; + std::string network_config; + + param += get_definition_string_from_types() + " -DCK_USE_AMD_XDLOPS "; + param += get_definition_string_from_tunable(tunable); + + network_config = get_network_config_string_from_types() + "_" + + get_network_config_string_from_tunable(tunable); + + std::vector kernel1_times; + std::vector kernel2_times; + + for(index_t i = 0; i < nrepeat; ++i) + { + KernelTimer timer1, timer2; + std::string kernel_name; + + kernel_name = + "dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare"; + auto network_config_1 = network_config + "_1"; + + timer1.Start(); + handle->AddKernel(algo_name, network_config_1, program_name, kernel_name, vld, vgd1, param)( + static_cast(in_n_hi_wi_c_lengths[I0]), + static_cast(in_n_hi_wi_c_lengths[I1]), + static_cast(in_n_hi_wi_c_lengths[I2]), + static_cast(in_n_hi_wi_c_lengths[I3]), + static_cast(wei_k_y_x_c_lengths[I0]), + static_cast(wei_k_y_x_c_lengths[I1]), + static_cast(wei_k_y_x_c_lengths[I2]), + conv_strides[I0], + conv_strides[I1], + conv_dilations[I0], + conv_dilations[I1], + in_left_pads[I0], + in_left_pads[I1], + in_right_pads[I0], + in_right_pads[I1], + a_k0_m_k1_grid_desc_dev_buf, + b_k0_n_k1_grid_desc_dev_buf, + c_m0_m1_m2_n_grid_desc_dev_buf, + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf); + timer1.End(); + + kernel_name = "dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk"; + auto network_config_2 = network_config + "_2"; + + timer2.Start(); + handle->AddKernel(algo_name, network_config_2, program_name, kernel_name, vld, vgd2, param)( + reinterpret_cast(in_n_hi_wi_c_dev_buf.GetDeviceBuffer()), + reinterpret_cast(wei_k_y_x_c_dev_buf.GetDeviceBuffer()), + reinterpret_cast(out_n_ho_wo_k_dev_buf.GetDeviceBuffer()), + (const void*)(a_k0_m_k1_grid_desc_dev_buf), + (const void*)(b_k0_n_k1_grid_desc_dev_buf), + (const void*)(c_m0_m1_m2_n_grid_desc_dev_buf), + (const void*)(c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf)); + timer2.End(); + + kernel1_times.push_back(timer1.GetElapsedTime()); + kernel2_times.push_back(timer2.GetElapsedTime()); + } + + { + auto ave_time1 = + std::accumulate( + std::next(kernel1_times.begin()), kernel1_times.end(), 0., std::plus{}) / + (nrepeat - 1); + auto ave_time2 = + std::accumulate( + std::next(kernel2_times.begin()), kernel2_times.end(), 0., std::plus{}) / + (nrepeat - 1); + + const auto N = in_n_hi_wi_c_lengths[I0]; + const auto C = in_n_hi_wi_c_lengths[I3]; + + const auto Ho = out_n_ho_wo_k_lengths[I1]; + const auto Wo = out_n_ho_wo_k_lengths[I2]; + const auto K = out_n_ho_wo_k_lengths[I3]; + + const auto Y = wei_k_y_x_c_lengths[I1]; + const auto X = wei_k_y_x_c_lengths[I2]; + + float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / + (std::size_t(1000) * 1000 * 1000) / ave_time2; + + std::cout << "Average 
time : " << ave_time1 + ave_time2 << " ms(" << ave_time1 << ", " + << ave_time2 << "), " << perf << " TFlop/s" << std::endl; + }; + + // copy result back to host + out_n_ho_wo_k_dev_buf.FromDevice(out_n_ho_wo_k.mData.data()); +} diff --git a/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..260c94ee0e --- /dev/null +++ b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp @@ -0,0 +1,182 @@ +#pragma once +#include "device.hpp" +#include "host_tensor.hpp" +#include "handle.hpp" +#include "online_driver_common.hpp" +#include "convolution_problem_descriptor.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp" +#include "conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw.hpp" + +template +void online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( + olCompile::Handle* handle, + const InLengths& in_n_c_hi_wi_lengths, + const WeiLengths& wei_k_c_y_x_lengths, + const OutLengths& out_n_k_ho_wo_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + const ck_driver::CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param, + ck::index_t nrepeat) +{ + using namespace ck; + using namespace ck_driver; + using size_t = std::size_t; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + ConvolutionProblemDescriptor conv_problem_desc{in_n_c_hi_wi_lengths[I0], + out_n_k_ho_wo_lengths[I1], + in_n_c_hi_wi_lengths[I1], + wei_k_c_y_x_lengths[I2], + wei_k_c_y_x_lengths[I3], + in_n_c_hi_wi_lengths[I2], + in_n_c_hi_wi_lengths[I3], + out_n_k_ho_wo_lengths[I2], + out_n_k_ho_wo_lengths[I3], + conv_strides[I0], + conv_strides[I1], + conv_dilations[I0], + conv_dilations[I1], + in_left_pads[I0], + in_left_pads[I1], + in_right_pads[I0], + in_right_pads[I1], + get_datatype_enum_from_type::value, + get_datatype_enum_from_type::value, + get_datatype_enum_from_type::value}; + + if(!ConvIgemmFwdV6r1DlopsNchwKcyxNkhw::IsValidCompileParameter(conv_problem_desc, + compile_param)) + { + throw std::runtime_error("wrong! 
IsValidCompileParameter fail"); + } + + DeviceMem in_n_c_hi_wi_dev_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_k_c_y_x_dev_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_n_k_ho_wo_dev_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + in_n_c_hi_wi_dev_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_k_c_y_x_dev_buf.ToDevice(wei_k_c_y_x.mData.data()); + out_n_k_ho_wo_dev_buf.ToDevice(out_n_k_ho_wo.mData.data()); + + // workspace is used for save transformed tensor descritpors created by prepare kernel + DeviceMem workspace_dev_buf( + ConvIgemmFwdV6r1DlopsNchwKcyxNkhw::GetWorkSpaceSize(conv_problem_desc, compile_param)); + + const auto block_size = std::size_t( + ConvIgemmFwdV6r1DlopsNchwKcyxNkhw::GetBlockSize(conv_problem_desc, compile_param)); + + const auto grid_size = std::size_t( + ConvIgemmFwdV6r1DlopsNchwKcyxNkhw::GetGridSize(conv_problem_desc, compile_param)); + + const std::vector vld1 = {1, 1, 1}; + const std::vector vgd1 = {1, 1, 1}; + + const std::vector vld2 = {static_cast(block_size), 1, 1}; + const std::vector vgd2 = {static_cast(grid_size * block_size), 1, 1}; + + std::string program_name = + "dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp"; + std::string algo_name = "implicit_gemm_conv_fwd_v6r1_dlops_nchw"; + + std::string compile_param_string = " -std=c++17 " + compile_param.GetCompileParameterString(); + std::string network_config = compile_param_string; + + std::vector kernel1_times; + std::vector kernel2_times; + + for(index_t i = 0; i < nrepeat; ++i) + { + KernelTimer timer1, timer2; + std::string kernel_name; + + kernel_name = "dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare"; + auto network_config_1 = network_config + "_1"; + + timer1.Start(); + handle->AddKernel(algo_name, + network_config_1, + program_name, + kernel_name, + vld1, + vgd1, + compile_param_string)(static_cast(in_n_c_hi_wi_lengths[I0]), + static_cast(in_n_c_hi_wi_lengths[I1]), + static_cast(in_n_c_hi_wi_lengths[I2]), + static_cast(in_n_c_hi_wi_lengths[I3]), + static_cast(wei_k_c_y_x_lengths[I0]), + static_cast(wei_k_c_y_x_lengths[I2]), + static_cast(wei_k_c_y_x_lengths[I3]), + conv_strides[I0], + conv_strides[I1], + conv_dilations[I0], + conv_dilations[I1], + in_left_pads[I0], + in_left_pads[I1], + in_right_pads[I0], + in_right_pads[I1], + (void*)(workspace_dev_buf.GetDeviceBuffer())); + timer1.End(); + + kernel_name = "dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw"; + auto network_config_2 = network_config + "_2"; + + timer2.Start(); + handle->AddKernel(algo_name, + network_config_2, + program_name, + kernel_name, + vld2, + vgd2, + compile_param_string)( + reinterpret_cast(wei_k_c_y_x_dev_buf.GetDeviceBuffer()), + reinterpret_cast(in_n_c_hi_wi_dev_buf.GetDeviceBuffer()), + reinterpret_cast(out_n_k_ho_wo_dev_buf.GetDeviceBuffer()), + (const void*)(workspace_dev_buf.GetDeviceBuffer())); + timer2.End(); + + kernel1_times.push_back(timer1.GetElapsedTime()); + kernel2_times.push_back(timer2.GetElapsedTime()); + } + + { + auto ave_time1 = + std::accumulate( + std::next(kernel1_times.begin()), kernel1_times.end(), 0., std::plus{}) / + (nrepeat - 1); + auto ave_time2 = + std::accumulate( + std::next(kernel2_times.begin()), kernel2_times.end(), 0., std::plus{}) / + (nrepeat - 1); + + float perf = (float)(conv_problem_desc.CalculateFlop()) / + (std::size_t(1000) * 1000 * 1000) / (ave_time1 + ave_time2); + + std::cout << "Average time : " << ave_time1 + ave_time2 << " 
ms(" << ave_time1 << ", " + << ave_time2 << "), " << perf << " TFlop/s" << std::endl; + }; + + // copy result back to host + out_n_k_ho_wo_dev_buf.FromDevice(out_n_k_ho_wo.mData.data()); +} diff --git a/host/driver_online/include/online_driver_common.hpp b/host/driver_online/include/online_driver_common.hpp new file mode 100644 index 0000000000..472ffb52dc --- /dev/null +++ b/host/driver_online/include/online_driver_common.hpp @@ -0,0 +1,44 @@ +#ifndef ONLINE_DRIVER_COMMON_HPP +#define ONLINE_DRIVER_COMMON_HPP + +namespace ck_driver { + +// greatest common divisor, aka highest common factor +inline int gcd(int x, int y) +{ + if(x < 0) + { + return gcd(-x, y); + } + else if(y < 0) + { + return gcd(x, -y); + } + else if(x == y || x == 0) + { + return y; + } + else if(y == 0) + { + return x; + } + else if(x > y) + { + return gcd(x % y, y); + } + else + { + return gcd(x, y % x); + } +} + +template = 2, bool>::type = false> +auto gcd(X x, Ys... ys) +{ + return gcd(x, gcd(ys...)); +} + +} // namespace ck_driver +#endif diff --git a/host/host_tensor/CMakeLists.txt b/host/host_tensor/CMakeLists.txt new file mode 100644 index 0000000000..9c30275220 --- /dev/null +++ b/host/host_tensor/CMakeLists.txt @@ -0,0 +1,19 @@ +include_directories(BEFORE + include +) + +set(HOST_TENSOR_SOURCE + src/host_tensor.cpp; + src/device.cpp; +) + +## the library target +add_library(host_tensor SHARED ${HOST_TENSOR_SOURCE}) + +target_link_libraries(host_tensor PRIVATE hip::device) +target_link_libraries(host_tensor INTERFACE hip::host) + +target_compile_features(host_tensor PUBLIC) +set_target_properties(host_tensor PROPERTIES POSITION_INDEPENDENT_CODE ON) + +install(TARGETS host_tensor LIBRARY DESTINATION lib) diff --git a/host/host_tensor/include/conv_common.hpp b/host/host_tensor/include/conv_common.hpp new file mode 100644 index 0000000000..73126b3c79 --- /dev/null +++ b/host/host_tensor/include/conv_common.hpp @@ -0,0 +1,86 @@ +#ifndef CONV_COMMON_HPP +#define CONV_COMMON_HPP + +#include "dynamic_tensor_descriptor.hpp" + +enum ConvTensorLayout +{ + NCHW, + NHWC, + CHWN, + NCHWc, + NHWCc +}; + +template +constexpr auto get_convolution_output_default_4d_tensor_descriptor( + const ck::DynamicTensorDescriptor& in_desc, + const ck::DynamicTensorDescriptor& wei_desc, + const ConvStrides& conv_strides, + const ConvDilations conv_dilations, + const LeftPads& left_pads, + const RightPads& right_pads) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + assert(in_desc.GetNumOfDimension() == 4); + assert(wei_desc.GetNumOfDimension() == 4); + assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1)); + + const auto N = in_desc.GetLength(I0); + const auto Hi = in_desc.GetLength(I2); + const auto Wi = in_desc.GetLength(I3); + + const auto K = wei_desc.GetLength(I0); + const auto Y = wei_desc.GetLength(I2); + const auto X = wei_desc.GetLength(I3); + + const auto LeftPadH = left_pads[I0]; + const auto LeftPadW = left_pads[I1]; + + const auto RightPadH = right_pads[I0]; + const auto RightPadW = right_pads[I1]; + + const auto YEff = (Y - I1) * conv_dilations[I0] + I1; + const auto XEff = (X - I1) * conv_dilations[I1] + I1; + + const auto Ho = (Hi + LeftPadH + RightPadH - YEff) / conv_strides[I0] + I1; + const auto Wo = (Wi + LeftPadW + RightPadW - XEff) / conv_strides[I1] + I1; + + return make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho, Wo)); +} + +template +constexpr std::size_t 
+calculate_convolution_flops(const InDesc& in_desc, const WeiDesc& wei_desc, const OutDesc& out_desc) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const index_t N = out_desc.GetLength(I0); + const index_t K = out_desc.GetLength(I1); + const index_t Ho = out_desc.GetLength(I2); + const index_t Wo = out_desc.GetLength(I3); + + const index_t C = wei_desc.GetLength(I1); + const index_t Y = wei_desc.GetLength(I2); + const index_t X = wei_desc.GetLength(I3); + + return std::size_t(2) * N * K * Ho * Wo * C * Y * X; +} + +#endif diff --git a/host/host_tensor/include/device.hpp b/host/host_tensor/include/device.hpp new file mode 100644 index 0000000000..2299e14921 --- /dev/null +++ b/host/host_tensor/include/device.hpp @@ -0,0 +1,86 @@ +#ifndef DEVICE_HPP +#define DEVICE_HPP + +#include +#include "hip/hip_runtime.h" +#include "hip/hip_fp16.h" + +struct DeviceMem +{ + DeviceMem() = delete; + DeviceMem(std::size_t mem_size); + void* GetDeviceBuffer(); + void ToDevice(const void* p); + void FromDevice(void* p); + ~DeviceMem(); + + void* mpDeviceBuf; + std::size_t mMemSize; +}; + +struct KernelTimerImpl; + +struct KernelTimer +{ + KernelTimer(); + ~KernelTimer(); + void Start(); + void End(); + float GetElapsedTime() const; + + std::unique_ptr impl; +}; + +using device_stream_t = hipStream_t; + +template +void launch_kernel(F kernel, + dim3 grid_dim, + dim3 block_dim, + std::size_t lds_byte, + hipStream_t stream_id, + Args... args) +{ + hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...); +} + +template +float launch_and_time_kernel(F kernel, + int nrepeat, + dim3 grid_dim, + dim3 block_dim, + std::size_t lds_byte, + hipStream_t stream_id, + Args... 
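+// On the FLOP count in calculate_convolution_flops above: every output element
+// needs C*Y*X multiply-accumulates, counted as 2 FLOPs each (one mul, one add),
+// over N*K*Ho*Wo outputs, hence 2*N*K*Ho*Wo*C*Y*X. The drivers divide this by
+// 10^9 (GFLOP) and by the elapsed time in milliseconds, so the printed value is
+// GFLOP/ms, i.e. TFlop/s.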
args) +{ + KernelTimer timer; + + printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", + __func__, + grid_dim.x, + grid_dim.y, + grid_dim.z, + block_dim.x, + block_dim.y, + block_dim.z); + + printf("Warm up\n"); + + // warm up + hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...); + + printf("Start running %d times...\n", nrepeat); + + timer.Start(); + + for(int i = 0; i < nrepeat; ++i) + { + hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...); + } + + timer.End(); + + return timer.GetElapsedTime() / nrepeat; +} + +#endif diff --git a/host/host_tensor/include/device_tensor.hpp b/host/host_tensor/include/device_tensor.hpp new file mode 100644 index 0000000000..1a7a34a4cf --- /dev/null +++ b/host/host_tensor/include/device_tensor.hpp @@ -0,0 +1,9 @@ +#pragma once +#include "host_tensor.hpp" +#include "common_header.hpp" + +template +void ostream_tensor_descriptor(TensorDesc, std::ostream& os = std::cout) +{ + ostream_HostTensorDescriptor(make_HostTensorDescriptor(TensorDesc{}), os); +} diff --git a/host/host_tensor/include/host_conv.hpp b/host/host_tensor/include/host_conv.hpp new file mode 100644 index 0000000000..7f26cb42f7 --- /dev/null +++ b/host/host_tensor/include/host_conv.hpp @@ -0,0 +1,326 @@ +#pragma once +#include "host_tensor.hpp" + +template +void host_direct_convolution(const Tensor& in, + const Tensor& wei, + Tensor& out, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const ConvTensorLayout layout = ConvTensorLayout::NCHW) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { + double v = 0; + for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c) + { + for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y) + { + int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; + for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x) + { + int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; + if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && + wi < in.mDesc.GetLengths()[3]) + { + v += static_cast(in(n, c, hi, wi)) * + static_cast(wei(k, c, y, x)); + } + } + } + } + out(n, k, ho, wo) = v; + }; + + auto f_nhwc = [&](auto n, auto ho, auto wo, auto k) { + double v = 0; + for(int c = 0; c < wei.mDesc.GetLengths()[3]; ++c) + { + for(int y = 0; y < wei.mDesc.GetLengths()[1]; ++y) + { + int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; + for(int x = 0; x < wei.mDesc.GetLengths()[2]; ++x) + { + int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; + if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 && + wi < in.mDesc.GetLengths()[2]) + { + v += static_cast(in(n, hi, wi, c)) * + static_cast(wei(k, y, x, c)); + } + } + } + } + out(n, ho, wo, k) = v; + }; + + switch(layout) + { + case ConvTensorLayout::NCHW: + make_ParallelTensorFunctor(f_nchw, + out.mDesc.GetLengths()[0], + out.mDesc.GetLengths()[1], + out.mDesc.GetLengths()[2], + out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + break; + case ConvTensorLayout::NHWC: + make_ParallelTensorFunctor(f_nhwc, + out.mDesc.GetLengths()[0], + out.mDesc.GetLengths()[1], + out.mDesc.GetLengths()[2], + out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + break; + default: throw std::runtime_error("wrong! 
not supported layout"); + } +} + +template +void host_winograd_3x3_convolution(const Tensor& in_nchw, + const Tensor& wei_kcyx, + Tensor& out_nkhw, + InLeftPads, + InRightPads) +{ + using namespace ck; + + constexpr std::size_t HoPerTile = 2; + constexpr std::size_t WoPerTile = 2; + + std::size_t N = in_nchw.mDesc.GetLengths()[0]; + std::size_t C = in_nchw.mDesc.GetLengths()[1]; + std::size_t HI = in_nchw.mDesc.GetLengths()[2]; + std::size_t WI = in_nchw.mDesc.GetLengths()[3]; + + std::size_t K = wei_kcyx.mDesc.GetLengths()[0]; + std::size_t Y = wei_kcyx.mDesc.GetLengths()[2]; + std::size_t X = wei_kcyx.mDesc.GetLengths()[3]; + + std::size_t HO = out_nkhw.mDesc.GetLengths()[2]; + std::size_t WO = out_nkhw.mDesc.GetLengths()[3]; + + index_t h_pad_low = InLeftPads{}.Get(Number<0>{}); + index_t w_pad_low = InLeftPads{}.Get(Number<1>{}); + + std::size_t HiPerTile = HoPerTile + Y - 1; + std::size_t WiPerTile = WoPerTile + X - 1; + + std::size_t HTile = (HO + HoPerTile - 1) / HoPerTile; + std::size_t WTile = (WO + WoPerTile - 1) / WoPerTile; + + Tensor in_hold({N, C, HTile, WTile, HiPerTile, WiPerTile}); + Tensor in_transform({N, C, HTile, WTile, HiPerTile, WiPerTile}); + Tensor wei_transform({K, C, HiPerTile, WiPerTile}); + Tensor out_transform({N, K, HTile, WTile, HiPerTile, HiPerTile}); + Tensor out_hold({N, K, HTile, WTile, HoPerTile, WoPerTile}); + + auto f_in_hold = [&](auto n, auto c, auto htile, auto wtile) { + for(int j = 0; j < HiPerTile; ++j) + { + int hi = HoPerTile * htile + j - h_pad_low; + for(int i = 0; i < WiPerTile; ++i) + { + int wi = WoPerTile * wtile + i - w_pad_low; + + if(hi >= 0 && hi < in_nchw.mDesc.GetLengths()[2] && wi >= 0 && + wi < in_nchw.mDesc.GetLengths()[3]) + { + in_hold(n, c, htile, wtile, j, i) = in_nchw(n, c, hi, wi); + } + else + { + in_hold(n, c, htile, wtile, j, i) = TIn(0); + } + } + } + }; + + auto f_in_transform = [&](auto n, auto c, auto htile, auto wtile) { + in_transform(n, c, htile, wtile, 0, 0) = + in_hold(n, c, htile, wtile, 0, 0) - in_hold(n, c, htile, wtile, 0, 2) - + in_hold(n, c, htile, wtile, 2, 0) + in_hold(n, c, htile, wtile, 2, 2); + in_transform(n, c, htile, wtile, 0, 1) = + in_hold(n, c, htile, wtile, 0, 1) + in_hold(n, c, htile, wtile, 0, 2) - + in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 2); + in_transform(n, c, htile, wtile, 0, 2) = + -in_hold(n, c, htile, wtile, 0, 1) + in_hold(n, c, htile, wtile, 0, 2) + + in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 2); + in_transform(n, c, htile, wtile, 0, 3) = + in_hold(n, c, htile, wtile, 0, 1) - in_hold(n, c, htile, wtile, 0, 3) - + in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 3); + + in_transform(n, c, htile, wtile, 1, 0) = + in_hold(n, c, htile, wtile, 1, 0) - in_hold(n, c, htile, wtile, 1, 2) + + in_hold(n, c, htile, wtile, 2, 0) - in_hold(n, c, htile, wtile, 2, 2); + in_transform(n, c, htile, wtile, 1, 1) = + in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) + + in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2); + in_transform(n, c, htile, wtile, 1, 2) = + -in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) - + in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2); + in_transform(n, c, htile, wtile, 1, 3) = + in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 3) + + in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 3); + + in_transform(n, c, htile, wtile, 2, 0) = + -in_hold(n, c, htile, wtile, 1, 0) + in_hold(n, c, 
htile, wtile, 1, 2) + + in_hold(n, c, htile, wtile, 2, 0) - in_hold(n, c, htile, wtile, 2, 2); + in_transform(n, c, htile, wtile, 2, 1) = + -in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 2) + + in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2); + in_transform(n, c, htile, wtile, 2, 2) = + in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 2) - + in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2); + in_transform(n, c, htile, wtile, 2, 3) = + -in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 3) + + in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 3); + + in_transform(n, c, htile, wtile, 3, 0) = + in_hold(n, c, htile, wtile, 1, 0) - in_hold(n, c, htile, wtile, 1, 2) - + in_hold(n, c, htile, wtile, 3, 0) + in_hold(n, c, htile, wtile, 3, 2); + in_transform(n, c, htile, wtile, 3, 1) = + in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) - + in_hold(n, c, htile, wtile, 3, 1) - in_hold(n, c, htile, wtile, 3, 2); + in_transform(n, c, htile, wtile, 3, 2) = + -in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) + + in_hold(n, c, htile, wtile, 3, 1) - in_hold(n, c, htile, wtile, 3, 2); + in_transform(n, c, htile, wtile, 3, 3) = + in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 3) - + in_hold(n, c, htile, wtile, 3, 1) + in_hold(n, c, htile, wtile, 3, 3); + }; + + auto f_wei_transform = [&](auto k, auto c) { + wei_transform(k, c, 0, 0) = double(wei_kcyx(k, c, 0, 0)); + wei_transform(k, c, 0, 1) = 0.5 * double(wei_kcyx(k, c, 0, 0)) + + 0.5 * double(wei_kcyx(k, c, 0, 1)) + + 0.5 * double(wei_kcyx(k, c, 0, 2)); + wei_transform(k, c, 0, 2) = 0.5 * double(wei_kcyx(k, c, 0, 0)) - + 0.5 * double(wei_kcyx(k, c, 0, 1)) + + 0.5 * double(wei_kcyx(k, c, 0, 2)); + wei_transform(k, c, 0, 3) = double(wei_kcyx(k, c, 0, 2)); + + wei_transform(k, c, 1, 0) = 0.5 * double(wei_kcyx(k, c, 0, 0)) + + 0.5 * double(wei_kcyx(k, c, 1, 0)) + + 0.5 * double(wei_kcyx(k, c, 2, 0)); + wei_transform(k, c, 1, 1) = + 0.25 * double(wei_kcyx(k, c, 0, 0)) + 0.25 * double(wei_kcyx(k, c, 0, 1)) + + 0.25 * double(wei_kcyx(k, c, 0, 2)) + 0.25 * double(wei_kcyx(k, c, 1, 0)) + + 0.25 * double(wei_kcyx(k, c, 1, 1)) + 0.25 * double(wei_kcyx(k, c, 1, 2)) + + 0.25 * double(wei_kcyx(k, c, 2, 0)) + 0.25 * double(wei_kcyx(k, c, 2, 1)) + + 0.25 * double(wei_kcyx(k, c, 2, 2)); + wei_transform(k, c, 1, 2) = + 0.25 * double(wei_kcyx(k, c, 0, 0)) - 0.25 * double(wei_kcyx(k, c, 0, 1)) + + 0.25 * double(wei_kcyx(k, c, 0, 2)) + 0.25 * double(wei_kcyx(k, c, 1, 0)) - + 0.25 * double(wei_kcyx(k, c, 1, 1)) + 0.25 * double(wei_kcyx(k, c, 1, 2)) + + 0.25 * double(wei_kcyx(k, c, 2, 0)) - 0.25 * double(wei_kcyx(k, c, 2, 1)) + + 0.25 * double(wei_kcyx(k, c, 2, 2)); + wei_transform(k, c, 1, 3) = 0.5 * double(wei_kcyx(k, c, 0, 2)) + + 0.5 * double(wei_kcyx(k, c, 1, 2)) + + 0.5 * double(wei_kcyx(k, c, 2, 2)); + + wei_transform(k, c, 2, 0) = 0.5 * double(wei_kcyx(k, c, 0, 0)) - + 0.5 * double(wei_kcyx(k, c, 1, 0)) + + 0.5 * double(wei_kcyx(k, c, 2, 0)); + wei_transform(k, c, 2, 1) = + 0.25 * double(wei_kcyx(k, c, 0, 0)) + 0.25 * double(wei_kcyx(k, c, 0, 1)) + + 0.25 * double(wei_kcyx(k, c, 0, 2)) - 0.25 * double(wei_kcyx(k, c, 1, 0)) - + 0.25 * double(wei_kcyx(k, c, 1, 1)) - 0.25 * double(wei_kcyx(k, c, 1, 2)) + + 0.25 * double(wei_kcyx(k, c, 2, 0)) + 0.25 * double(wei_kcyx(k, c, 2, 1)) + + 0.25 * double(wei_kcyx(k, c, 2, 2)); + wei_transform(k, c, 2, 2) = + 0.25 * double(wei_kcyx(k, c, 0, 0)) - 0.25 * 
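f_wei_transform is the matching 3x3 filter transform G * w * G^T written out by hand; the 0.5 and 0.25 coefficients are products of the rows of G:

    // Standard Winograd F(2x2,3x3) filter-transform matrix; wei_transform = G * w * Gt,
    // with w the 3x3 filter and Gt the transpose of G.
    constexpr double G[4][3] = {
        {1.0, 0.0, 0.0},
        {0.5, 0.5, 0.5},
        {0.5, -0.5, 0.5},
        {0.0, 0.0, 1.0},
    };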
double(wei_kcyx(k, c, 0, 1)) + + 0.25 * double(wei_kcyx(k, c, 0, 2)) - 0.25 * double(wei_kcyx(k, c, 1, 0)) + + 0.25 * double(wei_kcyx(k, c, 1, 1)) - 0.25 * double(wei_kcyx(k, c, 1, 2)) + + 0.25 * double(wei_kcyx(k, c, 2, 0)) - 0.25 * double(wei_kcyx(k, c, 2, 1)) + + 0.25 * double(wei_kcyx(k, c, 2, 2)); + wei_transform(k, c, 2, 3) = 0.5 * double(wei_kcyx(k, c, 0, 2)) - + 0.5 * double(wei_kcyx(k, c, 1, 2)) + + 0.5 * double(wei_kcyx(k, c, 2, 2)); + + wei_transform(k, c, 3, 0) = double(wei_kcyx(k, c, 2, 0)); + wei_transform(k, c, 3, 1) = 0.5 * double(wei_kcyx(k, c, 2, 0)) + + 0.5 * double(wei_kcyx(k, c, 2, 1)) + + 0.5 * double(wei_kcyx(k, c, 2, 2)); + wei_transform(k, c, 3, 2) = 0.5 * double(wei_kcyx(k, c, 2, 0)) - + 0.5 * double(wei_kcyx(k, c, 2, 1)) + + 0.5 * double(wei_kcyx(k, c, 2, 2)); + wei_transform(k, c, 3, 3) = double(wei_kcyx(k, c, 2, 2)); + }; + + auto f_out_transform = [&](auto n, auto k, auto htile, auto wtile) { + for(int j = 0; j < HiPerTile; ++j) + { + for(int i = 0; i < WiPerTile; ++i) + { + double v = 0; + for(int c = 0; c < C; ++c) + { + v += in_transform(n, c, htile, wtile, j, i) * wei_transform(k, c, j, i); + } + + out_transform(n, k, htile, wtile, j, i) = v; + } + } + }; + + auto f_out_hold = [&](auto n, auto k, auto htile, auto wtile) { + out_hold(n, k, htile, wtile, 0, 0) = + out_transform(n, k, htile, wtile, 0, 0) + out_transform(n, k, htile, wtile, 0, 1) + + out_transform(n, k, htile, wtile, 0, 2) + out_transform(n, k, htile, wtile, 1, 0) + + out_transform(n, k, htile, wtile, 1, 1) + out_transform(n, k, htile, wtile, 1, 2) + + out_transform(n, k, htile, wtile, 2, 0) + out_transform(n, k, htile, wtile, 2, 1) + + out_transform(n, k, htile, wtile, 2, 2); + out_hold(n, k, htile, wtile, 0, 1) = + out_transform(n, k, htile, wtile, 0, 1) - out_transform(n, k, htile, wtile, 0, 2) - + out_transform(n, k, htile, wtile, 0, 3) + out_transform(n, k, htile, wtile, 1, 1) - + out_transform(n, k, htile, wtile, 1, 2) - out_transform(n, k, htile, wtile, 1, 3) + + out_transform(n, k, htile, wtile, 2, 1) - out_transform(n, k, htile, wtile, 2, 2) - + out_transform(n, k, htile, wtile, 2, 3); + out_hold(n, k, htile, wtile, 1, 0) = + out_transform(n, k, htile, wtile, 1, 0) + out_transform(n, k, htile, wtile, 1, 1) + + out_transform(n, k, htile, wtile, 1, 2) - out_transform(n, k, htile, wtile, 2, 0) - + out_transform(n, k, htile, wtile, 2, 1) - out_transform(n, k, htile, wtile, 2, 2) - + out_transform(n, k, htile, wtile, 3, 0) - out_transform(n, k, htile, wtile, 3, 1) - + out_transform(n, k, htile, wtile, 3, 2); + out_hold(n, k, htile, wtile, 1, 1) = + out_transform(n, k, htile, wtile, 1, 1) - out_transform(n, k, htile, wtile, 1, 2) - + out_transform(n, k, htile, wtile, 1, 3) - out_transform(n, k, htile, wtile, 2, 1) + + out_transform(n, k, htile, wtile, 2, 2) + out_transform(n, k, htile, wtile, 2, 3) - + out_transform(n, k, htile, wtile, 3, 1) + out_transform(n, k, htile, wtile, 3, 2) + + out_transform(n, k, htile, wtile, 3, 3); + }; + + auto f_out = [&](auto n, auto k, auto htile, auto wtile) { + for(int j = 0; j < HoPerTile; ++j) + { + std::size_t ho = HoPerTile * htile + j; + for(int i = 0; i < WoPerTile; ++i) + { + std::size_t wo = WoPerTile * wtile + i; + out_nkhw(n, k, ho, wo) = out_hold(n, k, htile, wtile, j, i); + } + } + }; + + std::size_t num_thread = std::thread::hardware_concurrency(); + + make_ParallelTensorFunctor(f_in_hold, N, C, HTile, WTile)(num_thread); + make_ParallelTensorFunctor(f_in_transform, N, C, HTile, WTile)(num_thread); + make_ParallelTensorFunctor(f_wei_transform, 
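f_out_transform reduces the element-wise product of the two transformed tiles over the input channels, and f_out_hold applies the inverse transform A^T * m * A to recover each 2x2 output tile, so per tile the routine computes out = A^T * [ (G*w*G^T) .* (B^T*d*B), summed over C ] * A. The output-transform matrix implied by those sums:

    // Standard Winograd F(2x2,3x3) output-transform matrix; out_hold = At * m * A,
    // with m the channel-reduced 4x4 product tile and A the transpose of At.
    constexpr int At[2][4] = {
        {1, 1, 1, 0},
        {0, 1, -1, -1},
    };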
K, C)(num_thread); + make_ParallelTensorFunctor(f_out_transform, N, K, HTile, WTile)(num_thread); + make_ParallelTensorFunctor(f_out_hold, N, K, HTile, WTile)(num_thread); + make_ParallelTensorFunctor(f_out, N, K, HTile, WTile)(num_thread); +} diff --git a/host/host_tensor/include/host_conv_bwd_data.hpp b/host/host_tensor/include/host_conv_bwd_data.hpp new file mode 100644 index 0000000000..07617c3926 --- /dev/null +++ b/host/host_tensor/include/host_conv_bwd_data.hpp @@ -0,0 +1,143 @@ +#pragma once +#include "host_tensor.hpp" + +template +void host_direct_convolution_backward_data(Tensor& in, + const Tensor& wei, + const Tensor& out, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const ConvTensorLayout layout = ConvTensorLayout::NCHW) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + auto f_nchw = [&](auto n, auto c, auto hi, auto wi) { + std::size_t N = in.mDesc.GetLengths()[I0]; + std::size_t C = in.mDesc.GetLengths()[I1]; + std::size_t Hi = in.mDesc.GetLengths()[I2]; + std::size_t Wi = in.mDesc.GetLengths()[I3]; + + std::size_t K = wei.mDesc.GetLengths()[I0]; + std::size_t Y = wei.mDesc.GetLengths()[I2]; + std::size_t X = wei.mDesc.GetLengths()[I3]; + + std::size_t Ho = out.mDesc.GetLengths()[I2]; + std::size_t Wo = out.mDesc.GetLengths()[I3]; + + double v = 0; + + for(int y = 0; y < Y; ++y) + { + int h_tmp = hi + in_left_pads[I0] - y * conv_dilations[I0]; + + if(h_tmp % conv_strides[I0] == 0) + { + int ho = h_tmp / conv_strides[I0]; + + if(ho >= 0 && ho < Ho) + { + for(int x = 0; x < X; ++x) + { + int w_tmp = wi + in_left_pads[I1] - x * conv_dilations[I1]; + + if(w_tmp % conv_strides[I1] == 0) + { + int wo = w_tmp / conv_strides[I1]; + + if(wo >= 0 && wo < Wo) + { + for(int k = 0; k < K; ++k) + { + v += out(n, k, ho, wo) * wei(k, c, y, x); + } + } + } + } + } + } + } + + in(n, c, hi, wi) = v; + }; + + auto f_nhwc = [&](auto n, auto hi, auto wi, auto c) { + std::size_t N = in.mDesc.GetLengths()[I0]; + std::size_t Hi = in.mDesc.GetLengths()[I1]; + std::size_t Wi = in.mDesc.GetLengths()[I2]; + std::size_t C = in.mDesc.GetLengths()[I3]; + + std::size_t K = wei.mDesc.GetLengths()[I0]; + std::size_t Y = wei.mDesc.GetLengths()[I1]; + std::size_t X = wei.mDesc.GetLengths()[I2]; + + std::size_t Ho = out.mDesc.GetLengths()[I1]; + std::size_t Wo = out.mDesc.GetLengths()[I2]; + + double v = 0; + + for(int y = 0; y < Y; ++y) + { + int h_tmp = hi + in_left_pads[I0] - y * conv_dilations[I0]; + + if(h_tmp % conv_strides[I0] == 0) + { + int ho = h_tmp / conv_strides[I0]; + + if(ho >= 0 && ho < Ho) + { + for(int x = 0; x < X; ++x) + { + int w_tmp = wi + in_left_pads[I1] - x * conv_dilations[I1]; + + if(w_tmp % conv_strides[I1] == 0) + { + int wo = w_tmp / conv_strides[I1]; + + if(wo >= 0 && wo < Wo) + { + for(int k = 0; k < K; ++k) + { + v += out(n, ho, wo, k) * wei(k, y, x, c); + } + } + } + } + } + } + } + + in(n, hi, wi, c) = v; + }; + + switch(layout) + { + case ConvTensorLayout::NCHW: + make_ParallelTensorFunctor(f_nchw, + in.mDesc.GetLengths()[0], + in.mDesc.GetLengths()[1], + in.mDesc.GetLengths()[2], + in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + break; + case ConvTensorLayout::NHWC: + make_ParallelTensorFunctor(f_nhwc, + in.mDesc.GetLengths()[0], + in.mDesc.GetLengths()[1], + in.mDesc.GetLengths()[2], + 
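host_direct_convolution_backward_data inverts the forward index mapping: for every input pixel it looks for output rows ho with hi = ho*stride + y*dilation - pad, i.e. ho = (hi + pad - y*dilation) / stride, and accepts a tap only when that division is exact and ho is in range. A compact sketch of the test used by f_nchw / f_nhwc (the helper name is illustrative, not part of the patch):

    // Hypothetical helper equivalent to the h_tmp / ho logic in the lambdas above.
    inline bool input_to_output_row(
        int hi, int y, int stride_h, int dilation_h, int left_pad_h, int Ho, int& ho)
    {
        const int h_tmp = hi + left_pad_h - y * dilation_h;
        if(h_tmp < 0 || h_tmp % stride_h != 0)
            return false;
        ho = h_tmp / stride_h;
        return ho < Ho;
    }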
in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + break; + default: throw std::runtime_error("wrong! not supported layout"); + } +} diff --git a/host/host_tensor/include/host_tensor.hpp b/host/host_tensor/include/host_tensor.hpp new file mode 100644 index 0000000000..70778a4a94 --- /dev/null +++ b/host/host_tensor/include/host_tensor.hpp @@ -0,0 +1,322 @@ +#ifndef HOST_TENSOR_HPP +#define HOST_TENSOR_HPP + +#include +#include +#include +#include +#include +#include +#include + +template +std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim) +{ + bool first = true; + for(auto&& v : range) + { + if(first) + first = false; + else + os << delim; + os << v; + } + return os; +} + +template +std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim) +{ + bool first = true; + for(auto&& v : range) + { + if(first) + first = false; + else + os << delim; + os << T{v}; + } + return os; +} + +typedef enum +{ + Half = 0, + Float = 1, +} DataType_t; + +template +struct DataType; + +template <> +struct DataType : std::integral_constant +{ +}; + +template +auto call_f_unpack_args_impl(F f, T args, std::index_sequence) +{ + return f(std::get(args)...); +} + +template +auto call_f_unpack_args(F f, T args) +{ + constexpr std::size_t N = std::tuple_size{}; + + return call_f_unpack_args_impl(f, args, std::make_index_sequence{}); +} + +template +auto construct_f_unpack_args_impl(T args, std::index_sequence) +{ + return F(std::get(args)...); +} + +template +auto construct_f_unpack_args(F, T args) +{ + constexpr std::size_t N = std::tuple_size{}; + + return construct_f_unpack_args_impl(args, std::make_index_sequence{}); +} + +struct HostTensorDescriptor +{ + HostTensorDescriptor() = delete; + + template + HostTensorDescriptor(std::vector lens); + + template + HostTensorDescriptor(std::vector lens, std::vector strides); + + void CalculateStrides(); + + template + HostTensorDescriptor(const Range& lens) : mLens(lens.begin(), lens.end()) + { + this->CalculateStrides(); + } + + template + HostTensorDescriptor(const Range1& lens, const Range2& strides) + : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end()) + { + } + + std::size_t GetNumOfDimension() const; + std::size_t GetElementSize() const; + std::size_t GetElementSpace() const; + + const std::vector& GetLengths() const; + const std::vector& GetStrides() const; + + template + std::size_t GetOffsetFromMultiIndex(Is... is) const + { + assert(sizeof...(Is) == this->GetNumOfDimension()); + std::initializer_list iss{static_cast(is)...}; + return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0}); + } + + private: + std::vector mLens; + std::vector mStrides; +}; + +struct joinable_thread : std::thread +{ + template + joinable_thread(Xs&&... xs) : std::thread(std::forward(xs)...) + { + } + + joinable_thread(joinable_thread&&) = default; + joinable_thread& operator=(joinable_thread&&) = default; + + ~joinable_thread() + { + if(this->joinable()) + this->join(); + } +}; + +template +struct ParallelTensorFunctor +{ + F mF; + static constexpr std::size_t NDIM = sizeof...(Xs); + std::array mLens; + std::array mStrides; + std::size_t mN1d; + + ParallelTensorFunctor(F f, Xs... 
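HostTensorDescriptor keeps lengths and strides; when strides are not supplied, CalculateStrides (defined in host_tensor.cpp further down) builds packed row-major strides, and GetOffsetFromMultiIndex is a dot product of the multi-index with the strides. A small worked example, assuming the packed constructor above:

    #include <cassert>
    #include <vector>

    void descriptor_offset_example()
    {
        // Packed lengths {2, 3, 4, 5} give strides {60, 20, 5, 1}.
        HostTensorDescriptor desc(std::vector<std::size_t>{2, 3, 4, 5});
        assert(desc.GetStrides() == (std::vector<std::size_t>{60, 20, 5, 1}));
        // offset(1, 2, 3, 4) = 1*60 + 2*20 + 3*5 + 4*1 = 119
        assert(desc.GetOffsetFromMultiIndex(1, 2, 3, 4) == 119);
    }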
xs) : mF(f), mLens({static_cast(xs)...}) + { + mStrides.back() = 1; + std::partial_sum(mLens.rbegin(), + mLens.rend() - 1, + mStrides.rbegin() + 1, + std::multiplies()); + mN1d = mStrides[0] * mLens[0]; + } + + std::array GetNdIndices(std::size_t i) const + { + std::array indices; + + for(int idim = 0; idim < NDIM; ++idim) + { + indices[idim] = i / mStrides[idim]; + i -= indices[idim] * mStrides[idim]; + } + + return indices; + } + + void operator()(std::size_t num_thread = std::thread::hardware_concurrency()) const + { + std::size_t work_per_thread = (mN1d + num_thread - 1) / num_thread; + + std::vector threads(num_thread); + + for(std::size_t it = 0; it < num_thread; ++it) + { + std::size_t iw_begin = it * work_per_thread; + std::size_t iw_end = std::min((it + 1) * work_per_thread, mN1d); + + auto f = [=] { + for(std::size_t iw = iw_begin; iw < iw_end; ++iw) + { + call_f_unpack_args(mF, GetNdIndices(iw)); + } + }; + threads[it] = joinable_thread(f); + } + } +}; + +template +auto make_ParallelTensorFunctor(F f, Xs... xs) +{ + return ParallelTensorFunctor(f, xs...); +} + +template +struct Tensor +{ + template + Tensor(std::initializer_list lens) : mDesc(lens), mData(mDesc.GetElementSpace()) + { + } + + template + Tensor(std::vector lens) : mDesc(lens), mData(mDesc.GetElementSpace()) + { + } + + template + Tensor(std::vector lens, std::vector strides) + : mDesc(lens, strides), mData(mDesc.GetElementSpace()) + { + } + + Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpace()) {} + + template + void GenerateTensorValue(G g, std::size_t num_thread = 1) + { + switch(mDesc.GetNumOfDimension()) + { + case 1: { + auto f = [&](auto i) { (*this)(i) = g(i); }; + make_ParallelTensorFunctor(f, mDesc.GetLengths()[0])(num_thread); + break; + } + case 2: { + auto f = [&](auto i0, auto i1) { (*this)(i0, i1) = g(i0, i1); }; + make_ParallelTensorFunctor(f, mDesc.GetLengths()[0], mDesc.GetLengths()[1])(num_thread); + break; + } + case 3: { + auto f = [&](auto i0, auto i1, auto i2) { (*this)(i0, i1, i2) = g(i0, i1, i2); }; + make_ParallelTensorFunctor( + f, mDesc.GetLengths()[0], mDesc.GetLengths()[1], mDesc.GetLengths()[2])(num_thread); + break; + } + case 4: { + auto f = [&](auto i0, auto i1, auto i2, auto i3) { + (*this)(i0, i1, i2, i3) = g(i0, i1, i2, i3); + }; + make_ParallelTensorFunctor(f, + mDesc.GetLengths()[0], + mDesc.GetLengths()[1], + mDesc.GetLengths()[2], + mDesc.GetLengths()[3])(num_thread); + break; + } + default: throw std::runtime_error("unspported dimension"); + } + } + + template + T& operator()(Is... is) + { + return mData[mDesc.GetOffsetFromMultiIndex(is...)]; + } + + template + const T& operator()(Is... 
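ParallelTensorFunctor flattens the N-d index space into mN1d points, hands each thread a contiguous chunk, and unflattens every flat index with GetNdIndices before calling the functor; Tensor::GenerateTensorValue and the host reference convolutions are built on it. A minimal usage sketch (assuming Tensor is a class template over the element type, as in this header):

    #include <thread>
    #include <vector>

    void parallel_fill_example()
    {
        // Fill a 4 x 8 tensor in parallel with the helpers above (illustrative only).
        Tensor<float> t(std::vector<std::size_t>{4, 8});
        auto fill = [&](auto i0, auto i1) { t(i0, i1) = static_cast<float>(i0 * 8 + i1); };
        make_ParallelTensorFunctor(fill, 4, 8)(std::thread::hardware_concurrency());
    }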
is) const + { + return mData[mDesc.GetOffsetFromMultiIndex(is...)]; + } + + typename std::vector::iterator begin() { return mData.begin(); } + + typename std::vector::iterator end() { return mData.end(); } + + typename std::vector::const_iterator begin() const { return mData.begin(); } + + typename std::vector::const_iterator end() const { return mData.end(); } + + HostTensorDescriptor mDesc; + std::vector mData; +}; + +template +HostTensorDescriptor::HostTensorDescriptor(std::vector lens) : mLens(lens) +{ + this->CalculateStrides(); +} + +template +HostTensorDescriptor::HostTensorDescriptor(std::vector lens, std::vector strides) + : mLens(lens), mStrides(strides) +{ +} + +void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os = std::cout); + +template +void check_error(const Tensor& ref, const Tensor& result) +{ + float error = 0; + float max_diff = -1; + float ref_value = 0, result_value = 0; + for(int i = 0; i < ref.mData.size(); ++i) + { + error += std::abs(double(ref.mData[i]) - double(result.mData[i])); + float diff = std::abs(double(ref.mData[i]) - double(result.mData[i])); + if(max_diff < diff) + { + max_diff = diff; + ref_value = ref.mData[i]; + result_value = result.mData[i]; + } + } + + std::cout << "error: " << error << std::endl; + std::cout << "max_diff: " << max_diff << ", " << ref_value << ", " << result_value << std::endl; +} + +#endif diff --git a/host/host_tensor/include/host_tensor_generator.hpp b/host/host_tensor/include/host_tensor_generator.hpp new file mode 100644 index 0000000000..98192e066f --- /dev/null +++ b/host/host_tensor/include/host_tensor_generator.hpp @@ -0,0 +1,60 @@ +#ifndef HOST_TENSOR_GENERATOR_HPP +#define HOST_TENSOR_GENERATOR_HPP + +#include +#include "config.hpp" + +struct GeneratorTensor_1 +{ + int value = 1; + + template + float operator()(Is... is) + { + return value; + } +}; + +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + float operator()(Is...) + { + return (std::rand() % (max_value - min_value)) + min_value; + } +}; + +template +struct GeneratorTensor_3 +{ + T min_value = 0; + T max_value = 1; + + template + float operator()(Is...) + { + float tmp = float(std::rand()) / float(RAND_MAX); + + return min_value + tmp * (max_value - min_value); + } +}; + +struct GeneratorTensor_Checkboard +{ + template + float operator()(Ts... Xs) const + { + std::array dims = {{static_cast(Xs)...}}; + return std::accumulate(dims.begin(), + dims.end(), + true, + [](bool init, ck::index_t x) -> int { return init != (x % 2); }) + ? 
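check_error prints the accumulated absolute error plus the single worst element, and is normally paired with the random generators below to validate a device kernel against the host reference. A sketch of the usual flow (the driver wiring itself is not part of this hunk):

    #include <thread>

    void generate_and_compare_example(Tensor<float>& in,
                                      Tensor<float>& wei,
                                      Tensor<float>& out_host,
                                      Tensor<float>& out_device)
    {
        const std::size_t num_thread = std::thread::hardware_concurrency();
        in.GenerateTensorValue(GeneratorTensor_3<float>{-0.5f, 0.5f}, num_thread);
        wei.GenerateTensorValue(GeneratorTensor_3<float>{-0.5f, 0.5f}, num_thread);
        // ... run the host reference into out_host and the GPU kernel into out_device ...
        check_error(out_host, out_device); // total and max absolute difference
    }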
1 + : -1; + } +}; + +#endif diff --git a/host/host_tensor/src/device.cpp b/host/host_tensor/src/device.cpp new file mode 100644 index 0000000000..d0d74a4c2a --- /dev/null +++ b/host/host_tensor/src/device.cpp @@ -0,0 +1,67 @@ +#include "device.hpp" + +DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size) +{ + hipGetErrorString(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); +} + +void* DeviceMem::GetDeviceBuffer() { return mpDeviceBuf; } + +void DeviceMem::ToDevice(const void* p) +{ + hipGetErrorString( + hipMemcpy(mpDeviceBuf, const_cast(p), mMemSize, hipMemcpyHostToDevice)); +} + +void DeviceMem::FromDevice(void* p) +{ + hipGetErrorString(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost)); +} + +DeviceMem::~DeviceMem() { hipGetErrorString(hipFree(mpDeviceBuf)); } + +struct KernelTimerImpl +{ + KernelTimerImpl() + { + hipEventCreate(&mStart); + hipEventCreate(&mEnd); + } + + ~KernelTimerImpl() + { + hipEventDestroy(mStart); + hipEventDestroy(mEnd); + } + + void Start() + { + hipDeviceSynchronize(); + hipEventRecord(mStart, 0); + } + + void End() + { + hipEventRecord(mEnd, 0); + hipEventSynchronize(mEnd); + } + + float GetElapsedTime() const + { + float time; + hipEventElapsedTime(&time, mStart, mEnd); + return time; + } + + hipEvent_t mStart, mEnd; +}; + +KernelTimer::KernelTimer() : impl(new KernelTimerImpl()) {} + +KernelTimer::~KernelTimer() {} + +void KernelTimer::Start() { impl->Start(); } + +void KernelTimer::End() { impl->End(); } + +float KernelTimer::GetElapsedTime() const { return impl->GetElapsedTime(); } diff --git a/host/host_tensor/src/host_tensor.cpp b/host/host_tensor/src/host_tensor.cpp new file mode 100644 index 0000000000..e840baf7f5 --- /dev/null +++ b/host/host_tensor/src/host_tensor.cpp @@ -0,0 +1,48 @@ +#include +#include + +#include "host_tensor.hpp" + +void HostTensorDescriptor::CalculateStrides() +{ + mStrides.clear(); + mStrides.resize(mLens.size(), 0); + if(mStrides.empty()) + return; + + mStrides.back() = 1; + std::partial_sum( + mLens.rbegin(), mLens.rend() - 1, mStrides.rbegin() + 1, std::multiplies()); +} + +std::size_t HostTensorDescriptor::GetNumOfDimension() const { return mLens.size(); } + +std::size_t HostTensorDescriptor::GetElementSize() const +{ + assert(mLens.size() == mStrides.size()); + return std::accumulate( + mLens.begin(), mLens.end(), std::size_t{1}, std::multiplies()); +} + +std::size_t HostTensorDescriptor::GetElementSpace() const +{ + auto ls = mLens | boost::adaptors::transformed([](std::size_t v) { return v - 1; }); + return std::inner_product(ls.begin(), ls.end(), mStrides.begin(), std::size_t{0}) + 1; +} + +const std::vector& HostTensorDescriptor::GetLengths() const { return mLens; } + +const std::vector& HostTensorDescriptor::GetStrides() const { return mStrides; } + +void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os) +{ + os << "dim " << desc.GetNumOfDimension() << ", "; + + os << "lengths {"; + LogRange(os, desc.GetLengths(), ", "); + os << "}, "; + + os << "strides {"; + LogRange(os, desc.GetStrides(), ", "); + os << "}" << std::endl; +} diff --git a/host/online_compilation/CMakeLists.txt b/host/online_compilation/CMakeLists.txt new file mode 100644 index 0000000000..02f6795308 --- /dev/null +++ b/host/online_compilation/CMakeLists.txt @@ -0,0 +1,168 @@ +set(CMAKE_CXX_COMPILER /opt/rocm/llvm/bin/clang++) + +## for online-compiling of HIP kernels +set(OLC_HIP_COMPILER ${CMAKE_CXX_COMPILER} CACHE PATH "") + +## reset to avoid the C++ options from the parent project 
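Two small notes on the utilities in this hunk: GetElementSpace returns 1 + sum_i (len_i - 1) * stride_i, the minimum number of elements a buffer must hold to cover the descriptor (equal to GetElementSize for packed strides, larger for padded ones), and DeviceMem / KernelTimer wrap the plain HIP allocation, copy and event APIs. An illustrative round trip (error handling omitted, as in the sources above):

    #include <vector>

    void device_roundtrip_example(const Tensor<float>& host)
    {
        // Packed lens {2,3,4,5}, strides {60,20,5,1}:
        //   element space = 1 + 1*60 + 2*20 + 3*5 + 4*1 = 120 = 2*3*4*5 elements.
        DeviceMem buf(sizeof(float) * host.mDesc.GetElementSpace());
        buf.ToDevice(host.mData.data());
        // ... launch a kernel that reads/writes buf.GetDeviceBuffer() ...
        std::vector<float> result(host.mDesc.GetElementSpace());
        buf.FromDevice(result.data());
    }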
+set(CMAKE_CXX_FLAGS "") +message("Compiling options for library and kernels: ${CMAKE_CXX_FLAGS}") + +# look for and register clang-offload-bundler +if(OLC_HIP_COMPILER MATCHES ".*clang\\+\\+$") + find_program(OLC_OFFLOADBUNDLER_BIN clang-offload-bundler + PATH_SUFFIXES bin + PATHS + /opt/rocm/llvm + ${CMAKE_INSTALL_PREFIX}/llvm + ) +endif() + +if(OLC_OFFLOADBUNDLER_BIN) + message(STATUS "clang-offload-bundler found: ${OLC_OFFLOADBUNDLER_BIN}") + set(OLC_OFFLOADBUNDLER_BIN "${OLC_OFFLOADBUNDLER_BIN}") +else() + # look for and register extractkernel + message(STATUS "clang-offload-bundler not found") + + find_program(EXTRACTKERNEL_BIN extractkernel + PATH_SUFFIXES bin + PATHS + /opt/rocm/hip + /opt/rocm/hcc + /opt/rocm + ${CMAKE_INSTALL_PREFIX}/hip + ${CMAKE_INSTALL_PREFIX}/hcc + ${CMAKE_INSTALL_PREFIX} + + ) + if(EXTRACTKERNEL_BIN) + message(STATUS "extractkernel found: ${EXTRACTKERNEL_BIN}") + set(EXTRACTKERNEL_BIN "${EXTRACTKERNEL_BIN}") + else() + message(FATAL_ERROR "extractkernel not found") + endif() +endif() + +option(Boost_USE_STATIC_LIBS "Use boost static libraries" OFF) +set(BOOST_COMPONENTS filesystem) +add_definitions(-DBOOST_ALL_NO_LIB=1) +find_package(Boost REQUIRED COMPONENTS ${BOOST_COMPONENTS}) + +# HIP is always required +find_package(hip REQUIRED PATHS /opt/rocm) +message(STATUS "Build with HIP ${hip_VERSION}") +target_flags(HIP_COMPILER_FLAGS hip::device) +# Remove cuda arch flags +string(REGEX REPLACE --cuda-gpu-arch=[a-z0-9]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}") +string(REGEX REPLACE --offload-arch=[a-z0-9]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}") + +set(OLC_hip_VERSION_MAJOR "${hip_VERSION_MAJOR}") +set(OLC_hip_VERSION_MINOR "${hip_VERSION_MINOR}") +set(OLC_hip_VERSION_PATCH "${hip_VERSION_PATCH}") + +option(ENABLE_DEBUG "Build to enable debugging" ON) +if(ENABLE_DEBUG) + set(OLC_DEBUG 1) +else() + set(OLC_DEBUG 0) +endif() + +configure_file("${PROJECT_SOURCE_DIR}/host/online_compilation/include/config.h.in" "${PROJECT_BINARY_DIR}/host/online_compilation/include/config.h") + +include_directories(BEFORE + ${PROJECT_BINARY_DIR}/host/online_compilation/include +) + +message(STATUS "Hip compiler flags: ${HIP_COMPILER_FLAGS}") + +## HIP_COMPILER_FLAGS will be used for on-line compiling of the HIP kernels +set(HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS} ${HIP_ONLINE_COMPILER_FLAGS}") +add_definitions("-DHIP_COMPILER_FLAGS=${HIP_COMPILER_FLAGS}") + +file(GLOB_RECURSE COMPOSABLE_KERNEL_INCLUDE_1 "${PROJECT_SOURCE_DIR}/composable_kernel/include/*/*.hpp") +file(GLOB COMPOSABLE_KERNEL_INCLUDE_2 "${PROJECT_SOURCE_DIR}/external/rocm/include/bfloat16_dev.hpp") +set(MCONV_KERNEL_INCLUDES + ${COMPOSABLE_KERNEL_INCLUDE_1} + ${COMPOSABLE_KERNEL_INCLUDE_2} + ) + +file(GLOB_RECURSE MCONV_KERNELS "${PROJECT_SOURCE_DIR}/composable_kernel/src/kernel_wrapper/*.cpp") + +add_kernels(${CMAKE_CURRENT_SOURCE_DIR} "${MCONV_KERNELS}") +add_kernel_includes(${CMAKE_CURRENT_SOURCE_DIR} "${MCONV_KERNEL_INCLUDES}") + +set(ONLINE_COMPILATION_SOURCE + ${PROJECT_BINARY_DIR}/kernel.cpp + ${PROJECT_BINARY_DIR}/kernel_includes.cpp +) + +include_directories(BEFORE + ${PROJECT_BINARY_DIR}/host/online_compilation/include + include +) + +set(OLC_HIP_UTILITY_CPPS + hip_utility/logger.cpp + hip_utility/tmp_dir.cpp + hip_utility/md5.cpp + hip_utility/exec_utils.cpp + hip_utility/target_properties.cpp + hip_utility/handlehip.cpp + hip_utility/kernel_build_params.cpp + hip_utility/hip_build_utils.cpp + hip_utility/hipoc_program.cpp + hip_utility/hipoc_kernel.cpp + hip_utility/kernel_cache.cpp + 
hip_utility/binary_cache.cpp + ) + +list(APPEND OLC_SOURCES ${OLC_HIP_UTILITY_CPPS} ${OLC_HIP_UTILITY_HEADERS}) + +## addkernels provide the tool to create inlined kernels in one header +add_subdirectory(addkernels) + +function(inline_kernels_src KERNELS KERNEL_INCLUDES) + set(KERNEL_SRC_HPP_FILENAME batch_all.cpp.hpp) + set(KERNEL_SRC_HPP_PATH ${PROJECT_BINARY_DIR}/inlined_kernels/${KERNEL_SRC_HPP_FILENAME}) + set(KERNEL_SRC_CPP_PATH ${PROJECT_BINARY_DIR}/inlined_kernels/batch_all.cpp) + + add_custom_command( + OUTPUT ${KERNEL_SRC_HPP_PATH} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + DEPENDS addkernels ${KERNELS} ${KERNEL_INCLUDES} + COMMAND $ -target ${KERNEL_SRC_HPP_PATH} -extern -source ${KERNELS} + COMMENT "Inlining All kernels" + ) + configure_file(kernels_batch.cpp.in ${KERNEL_SRC_CPP_PATH}) + list(APPEND OLC_SOURCES ${KERNEL_SRC_CPP_PATH} ${KERNEL_SRC_HPP_PATH}) + + set(OLC_SOURCES ${OLC_SOURCES} PARENT_SCOPE) +endfunction() + +inline_kernels_src("${MCONV_KERNELS}" "${MCONV_KERNEL_INCLUDES}") + +list(APPEND ONLINE_COMPILATION_SOURCE ${OLC_SOURCES} ${PROJECT_BINARY_DIR}/olc_kernel_includes.h) + +add_custom_command( + OUTPUT ${PROJECT_BINARY_DIR}/olc_kernel_includes.h + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + DEPENDS addkernels ${MCONV_KERNEL_INCLUDES} + COMMAND $ -no-recurse -guard GUARD_OLC_KERNEL_INCLUDES_HPP_ -target ${PROJECT_BINARY_DIR}/olc_kernel_includes.h -source ${MCONV_KERNEL_INCLUDES} + COMMENT "Inlining HIP kernel includes" + ) + +## the library target +add_library(online_compilation SHARED ${ONLINE_COMPILATION_SOURCE}) + +target_include_directories(online_compilation PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/online_compilation/include/) +target_include_directories(online_compilation PRIVATE ${PROJECT_BINARY_DIR}) +target_include_directories(online_compilation PRIVATE ${PROJECT_SOURCE_DIR}/external/half/include/) + +target_link_libraries(online_compilation PRIVATE hip::device) +target_link_libraries(online_compilation INTERFACE hip::host) +target_link_libraries(online_compilation PRIVATE Boost::filesystem) + +target_compile_features(online_compilation PUBLIC) +set_target_properties(online_compilation PROPERTIES POSITION_INDEPENDENT_CODE ON) + +install(TARGETS online_compilation LIBRARY DESTINATION lib) diff --git a/host/online_compilation/addkernels/CMakeLists.txt b/host/online_compilation/addkernels/CMakeLists.txt new file mode 100644 index 0000000000..874cba6a5e --- /dev/null +++ b/host/online_compilation/addkernels/CMakeLists.txt @@ -0,0 +1,30 @@ +################################################################################ +# +# MIT License +# +# Copyright (c) 2017 Advanced Micro Devices, Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +################################################################################ + +set(ADD_KERNELS_SOURCE include_inliner.cpp addkernels.cpp) + +add_executable(addkernels EXCLUDE_FROM_ALL ${ADD_KERNELS_SOURCE}) + diff --git a/host/online_compilation/addkernels/addkernels.cpp b/host/online_compilation/addkernels/addkernels.cpp new file mode 100644 index 0000000000..5be523d97b --- /dev/null +++ b/host/online_compilation/addkernels/addkernels.cpp @@ -0,0 +1,264 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2021 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include "include_inliner.hpp" +#include +#include +#include +#include +#include +#include +#include + +void Bin2Hex(std::istream& source, + std::ostream& target, + const std::string& variable, + bool nullTerminate, + size_t bufferSize, + size_t lineSize) +{ + source.seekg(0, std::ios::end); + std::unique_ptr buffer(new unsigned char[bufferSize]); + std::streamoff sourceSize = source.tellg(); + std::streamoff blockStart = 0; + + if(variable.length() != 0) + { + target << "extern const size_t " << variable << "_SIZE;" << std::endl; + target << "extern const unsigned char " << variable << "[];" << std::endl; + target << "const size_t " << variable << "_SIZE = " << std::setbase(10) << sourceSize << ";" + << std::endl; + target << "const unsigned char " << variable << "[] = {" << std::endl; + } + + target << std::setbase(16) << std::setfill('0'); + source.seekg(0, std::ios::beg); + + while(blockStart < sourceSize) + { + source.read(reinterpret_cast(buffer.get()), bufferSize); + + std::streamoff pos = source.tellg(); + std::streamoff blockSize = (pos < 0 ? 
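Bin2Hex converts a kernel source file into a C array plus a _SIZE constant so the online-compilation library can embed kernel sources in the binary; for a variable named FOO the emitted text looks roughly like the following (illustrative, the exact byte values and line width come from the loop below):

    // extern const size_t FOO_SIZE;
    // extern const unsigned char FOO[];
    // const size_t FOO_SIZE = 1234;
    // const unsigned char FOO[] = {
    // 0x23,0x70,0x72,0x61, ...
    // 0x00,   // extra byte emitted only when null termination is requested
    // };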
sourceSize : pos) - blockStart; + std::streamoff i = 0; + + while(i < blockSize) + { + size_t j = i; + size_t end = std::min(i + lineSize, blockSize); + + for(; j < end; j++) + target << "0x" << std::setw(2) << static_cast(buffer[j]) << ","; + + target << std::endl; + i = end; + } + + blockStart += blockSize; + } + + if(nullTerminate) + target << "0x00," << std::endl; + + if(variable.length() != 0) + { + target << "};" << std::endl; + } +} + +void PrintHelp() +{ + std::cout << "Usage: bin2hex {