From 9c6a7438c824d87790a15e7d66b670ffefe281dc Mon Sep 17 00:00:00 2001
From: Andrey Bokovoy
Date: Tue, 26 Nov 2024 21:52:50 -0800
Subject: [PATCH] Optimized backward pass for ROCm devices (#3367)

Summary:
Added an optimized implementation of the backward pass for ROCm devices. It
currently supports the pooled (**not nobag**) mode and the **rowwise_adagrad**
optimizer with non-mixed embedding dimensions in [64, 128, 160, 192].

Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3367

Differential Revision: D66310520

Pulled By: leitian
---
 .../genscript/generate_backward_split.py      |  22 +
 .../embedding_backward_dense_host_cpu.cpp     |   7 +-
 ...embedding_backward_split_host_template.cpp |  15 +-
 ...ing_backward_split_kernel_warp_template.cu | 309 ++++
 ...embedding_backward_split_meta_template.cpp |   3 +
 .../embedding_backward_split_template.cu      | 134 ++++-
 ..._backward_split_device_kernel_template.hip | 461 +++++++++++++++
 .../training/python/lookup_args.template      |   1 +
 ..._embedding_codegen_lookup_invoker.template |   1 +
 ...t_table_batched_embeddings_ops_training.py |   4 +-
 .../include/fbgemm_gpu/rocm/cdna_guard.h      |  51 ++
 .../fbgemm_gpu/rocm/split_embeddings_common.h | 549 ++++++++++++++++++
 .../tbe/training/backward_optimizers_test.py  |  76 +++
 fbgemm_gpu/test/test_utils.py                 |  20 +
 14 files changed, 1643 insertions(+), 10 deletions(-)
 create mode 100644 fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip
 create mode 100644 fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h
 create mode 100644 fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h

diff --git a/fbgemm_gpu/codegen/genscript/generate_backward_split.py b/fbgemm_gpu/codegen/genscript/generate_backward_split.py
index ac60a8dad8..5e01defc83 100644
--- a/fbgemm_gpu/codegen/genscript/generate_backward_split.py
+++ b/fbgemm_gpu/codegen/genscript/generate_backward_split.py
@@ -310,6 +310,27 @@ def generate_backward_indices() -> None:
             ssd=ssd,
         )
 
+    @staticmethod
+    def generate_rocm_backward_split(**kwargs: Any) -> None:
+        # Generate backward device kernels based on weighted (True/False), VBE
+        # (True/False), no bag (True/False)
+        template_filepath = (
+            "training/backward/rocm/embedding_backward_split_device_kernel_template.hip"
+        )
+
+        BackwardSplitGenerator.render_backward_templates(
+            template_filepath,
+            "",
+            "{}gen_embedding_backward_{}_device_kernel_hip.hip",
+            {
+                "has_gpu_support": True,
+                "has_vbe_support": False,
+                "has_ssd_support": False,
+                "dense": False,
+                "gen_once": False,
+            },
+        )
+
     @staticmethod
     def generate_python_sources(
         all_optimizers: List[str], ssd_optimizers: List[str]
@@ -369,6 +390,7 @@ def generate() -> None:
             BackwardSplitGenerator.generate_backward_split(
                 ssd_tensors=ssd_tensors, **optimizer
             )
+            BackwardSplitGenerator.generate_rocm_backward_split(**optimizer)
 
         # Generate common device kernels for backwards
         BackwardSplitGenerator.generate_backward_device()
diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp
index ee608e83e0..3c18a2b9bf 100644
--- a/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp
+++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp
@@ -171,7 +171,8 @@ Tensor split_embedding_codegen_lookup_dense_function(
         Tensor>& /* vbe_B_offsets_rank_per_feature = std::nullopt */,
     c10::SymInt /* max_B = -1 */,
     c10::SymInt /* max_B_feature_rank = -1 */,
-    c10::SymInt /* vbe_output_size = -1 */) {
+    c10::SymInt /* vbe_output_size = -1 */,
+    
bool /* mixed_D = false */) { return SplitLookupFunction_Dense_Op::apply( host_weights, weights_offsets, @@ -190,7 +191,7 @@ Tensor split_embedding_codegen_lookup_dense_function( // Deprecated for fb namespace! Please use fbgemm namespace instead! TORCH_LIBRARY_FRAGMENT(fb, m) { m.def( - "dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1) -> Tensor"); + "dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1, bool mixed_D=False) -> Tensor"); DISPATCH_TO_CPU( "dense_embedding_codegen_lookup_function", split_embedding_codegen_lookup_dense_function); @@ -198,7 +199,7 @@ TORCH_LIBRARY_FRAGMENT(fb, m) { TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( - "dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1) -> Tensor"); + "dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? 
vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1, bool mixed_D=False) -> Tensor"); DISPATCH_TO_CPU( "dense_embedding_codegen_lookup_function", split_embedding_codegen_lookup_dense_function); diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp index 63fa373a50..bbb1ebbadb 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp @@ -152,6 +152,7 @@ enum SSDTensor { {%- else %} D_offsets, max_D, + mixed_D, {%- endif %} {# /* if nobag */ #} hash_size_cumsum, total_hash_size_bits, @@ -224,6 +225,7 @@ enum SSDTensor { Variable(), // D_offsets Variable(), // total_D Variable(), // max_D + Variable(), // mixed_D {%- endif %} Variable(), // hash_size_cumsum Variable(), //total_hash_size_bits @@ -304,6 +306,7 @@ enum SSDTensor { D_offsets, total_D, max_D, + mixed_D, {%- endif %} hash_size_cumsum, total_hash_size_bits, @@ -484,6 +487,7 @@ Tensor {%- else %} const Tensor& D_offsets, const c10::SymInt max_D, + const bool mixed_D, {%- endif %} const Tensor& hash_size_cumsum, const int64_t total_hash_size_bits, @@ -566,6 +570,7 @@ class {{ autograd_func }} : const Tensor& D_offsets, const c10::SymInt total_D, const c10::SymInt max_D, + const bool mixed_D, {%- else %} const c10::SymInt D, {%- endif %} @@ -762,6 +767,7 @@ class {{ autograd_func }} : {%- if not nobag %} ctx->saved_data["max_D"] = max_D; + ctx->saved_data["mixed_D"] = mixed_D; ctx->saved_data["pooling_mode"] = pooling_mode; {%- else %} ctx->saved_data["D"] = D; @@ -877,6 +883,7 @@ class {{ autograd_func }} : {%- if not nobag %} auto max_D = ctx->saved_data["max_D"].toSymInt(); + const auto mixed_D = ctx->saved_data["mixed_D"].toBool(); auto pooling_mode = ctx->saved_data["pooling_mode"].toInt(); {%- else %} auto D = ctx->saved_data["D"].toSymInt(); @@ -1072,10 +1079,11 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function( {%- if ssd %} const std::optional& ssd_tensors = std::nullopt, {%- endif %} - const double gwd_lower_bound = 0 + const double gwd_lower_bound = 0, {%- else %} - const c10::SymInt vbe_output_size = -1 + const c10::SymInt vbe_output_size = -1, {%- endif %} + const bool mixed_D = false ) { // TODO: refactor into macro {%- if has_gpu_support %} @@ -1191,7 +1199,8 @@ TORCH_LIBRARY_FRAGMENT({{ lib_name }}, m) { {%- if ssd %} " Tensor[]? 
ssd_tensors=None," {%- endif %} - " float gwd_lower_bound=0 " + " float gwd_lower_bound=0, " + " bool mixed_D=False" ") -> Tensor", {PT2_COMPLIANT_TAG}); diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index 3b230b0100..0e4f552ebc 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -521,5 +521,314 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row #endif //////////////////////////////////////////////////////////////////////////////// +{%- endif %} + +{%- if is_rocm and not is_index_select and optimizer == "rowwise_adagrad" and not dense and not is_gwd_kernel and not vbe and not ssd %} +#include +#include +#include "fbgemm_gpu/rocm/split_embeddings_common.h" +#include "gen_embedding_backward_split_{{ desc_suffix }}{{ ndesc }}_device_kernel_hip.hip" + +template < + typename emb_t, + typename grad_t, + typename cache_t, + int32_t kFixedMaxVecsPerThread, + int32_t kThreadGroupSize, + bool kUseVecBlocking, + int32_t embedding_dim, + int32_t weight_decay_mode_v> +__global__ void +hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_warp_per_row_1( + const pta::PackedTensorAccessor64 grad_output, + {%- if optimizer != "none" %} + pta::PackedTensorAccessor64 dev_weights, + {%- if not dense %} + pta::PackedTensorAccessor64 uvm_weights, + pta::PackedTensorAccessor64 lxu_cache_weights, + const pta::PackedTensorAccessor32 weights_placements, + {%- endif %} + {%- endif %} + const pta::PackedTensorAccessor32 weights_offsets, + {%- if not nobag or is_index_select %} + const pta::PackedTensorAccessor32 D_offsets, + {%- else %} + int64_t D, + {%- endif %} + const pta::PackedTensorAccessor32 hash_size_cumsum, + const pta::PackedTensorAccessor32 sorted_linear_indices_run, + const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, + {%- if not nobag %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- else %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- endif %} + {%- if not dense %} + const pta::PackedTensorAccessor32 sorted_lxu_cache_locations, + const bool use_uniq_cache_locations, + const pta::PackedTensorAccessor32 table_unique_indices_offsets, + {%- endif %} + {%- if weighted %} + const pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> sorted_indice_weights, + {%- endif %} + const pta::PackedTensorAccessor32 sorted_linear_indices_num_runs, + int32_t max_segment_length_per_warp, + {%- if not dense and optimizer != "none" %} + bool stochastic_rounding, + at::PhiloxCudaState stochastic_rounding_philox_args, + {%- else %} + pta::PackedTensorAccessor64 grad_dev_weights, + {%- endif %} // if not dense and optimizer != "none" + {%- if not nobag and vbe %} + const pta::PackedTensorAccessor32 B_offsets, + const pta::PackedTensorAccessor32 row_output_offsets, + {%- endif %} + {%- if not nobag %} + const int32_t info_B_num_bits, + const uint32_t info_B_mask, + {%- endif %} + const int32_t max_D, + const int32_t max_vecs_per_thread, + {%- if is_index_select %} + const at::PackedTensorAccessor32 grad_offsets, + const bool permute_output_dim_0_1 + {%- else %} + {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} + {%- endif %} +) { + {%- if not nobag %} + int32_t T = D_offsets.size(0) - 1; + {%- else %} + int32_t T = 
weights_offsets.size(0); + {%- endif %} + + auto p_output_grad = grad_output.data(); + auto p_emb_table = dev_weights.data(); + auto p_hash_size_cumsum = hash_size_cumsum.data(); + auto p_sorted_linear_indices_run = sorted_linear_indices_run.data(); + auto p_sorted_linear_indices_cumulative_run_lengths = sorted_linear_indices_cumulative_run_lengths.data(); + auto p_sorted_linear_indices_num_runs = sorted_linear_indices_num_runs.data(); + auto p_sorted_infos = sorted_infos.data(); + {%- if weighted %} + auto p_indice_weights_sorted = sorted_indice_weights.data(); + {%- endif %} + auto emb_dim = embedding_dim; + constexpr int32_t segment_prefetch = 2; + constexpr int32_t segment_unroll = 8; + constexpr int32_t segment_split = 0; + auto batch = grad_output.size(0); + auto num_rows = dev_weights.size(0) / T / max_D; + {%- if weighted %} + constexpr bool is_weighted = true; + {%- else %} + constexpr bool is_weighted = false; + {%- endif %} + rocm::{{optimizer}}_kernel_arg_t opt_karg; + opt_karg.p_momentum = momentum1_dev.data(); + opt_karg.eps = eps; + opt_karg.learning_rate = learning_rate; + // weight_decay(_mode) is supplied as args.split_function_args_no_defaults + opt_karg.weight_decay_mode = weight_decay_mode_v; + opt_karg.weight_decay = weight_decay; + auto batch_mdiv = [](uint32_t d) -> rocm::magic_div_u32_t { + assert(d >= 1 && d <= INT32_MAX); + uint8_t shift; + for(shift = 0; shift < 32; shift++) + if((1U << shift) >= d) + break; + + uint64_t one = 1; + uint64_t magic = ((one << 32) * ((one << shift) - d)) / d + 1; + assert(magic <= 0xffffffffUL); + + rocm::magic_div_u32_t result; + result.magic = magic; + result.shift = shift; + return result; + }(batch); + rocm::split_tbe_backward_hip_kernel_{{kdesc}}< + rocm::{{optimizer}}_optimizer_t, + rocm::{{optimizer}}_kernel_arg_t, + emb_t, + cache_t, + grad_t, + BLOCK_SIZE, + embedding_dim, + segment_prefetch, + segment_unroll, + segment_split, + is_weighted>(p_output_grad, + p_emb_table, + p_hash_size_cumsum, + p_sorted_linear_indices_run, + p_sorted_linear_indices_cumulative_run_lengths, + p_sorted_linear_indices_num_runs, + {%- if not nobag %} + info_B_num_bits, + info_B_mask, + {%- endif %} + p_sorted_infos, + batch_mdiv, + max_segment_length_per_warp, + emb_dim, + batch, + num_rows, + T, + opt_karg + {%- if weighted %} + , p_indice_weights_sorted + {%- endif %}); +} + +{%- macro hip_template_instantiation( + emb_type, + grad_type, + cache_type, + kFixedMaxVecsPerThread, + kThreadGroupSize, + kUseVecBlocking, + kEmbeddingDim, + kWeighDecayMode + ) +%} +template __global__ __launch_bounds__(kBackwardMaxThreads) void +hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_warp_per_row_1 +< {{ emb_type }}, + {{ grad_type }}, + {{ cache_type }}, + {{ kFixedMaxVecsPerThread }}, + {{ kThreadGroupSize }}, + {{ kUseVecBlocking }}, + {{ kEmbeddingDim }}, + {{ kWeighDecayMode }} +> ( + const pta::PackedTensorAccessor64<{{ grad_type }}, {{ "1" if is_index_select else "2" }}, at::RestrictPtrTraits> grad_output, + {%- if optimizer != "none" %} + pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> dev_weights, + {%- if not dense %} + pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> uvm_weights, + pta::PackedTensorAccessor64<{{ cache_type }}, 2, at::RestrictPtrTraits> lxu_cache_weights, + const pta::PackedTensorAccessor32 weights_placements, + {%- endif %} + {%- endif %} + const pta::PackedTensorAccessor32 weights_offsets, + {%- if not nobag or is_index_select %} + const 
pta::PackedTensorAccessor32 D_offsets, + {%- else %} + int64_t D, + {%- endif %} + const pta::PackedTensorAccessor32 hash_size_cumsum, + const pta::PackedTensorAccessor32 sorted_linear_indices_run, + const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, + {%- if not nobag %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- else %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- endif %} + {%- if not dense %} + const pta::PackedTensorAccessor32 sorted_lxu_cache_locations, + const bool use_uniq_cache_locations, + const pta::PackedTensorAccessor32 table_unique_indices_offsets, + {%- endif %} + {%- if weighted %} + const pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> sorted_indice_weights, + {%- endif %} + const pta::PackedTensorAccessor32 sorted_linear_indices_num_runs, + int32_t max_segment_length_per_warp, + {%- if not dense and optimizer != "none" %} + bool stochastic_rounding, + at::PhiloxCudaState stochastic_rounding_philox_args, + {%- else %} + pta::PackedTensorAccessor64< {{ emb_type }}, 1, at::RestrictPtrTraits> grad_dev_weights, + {%- endif %} // if not dense and optimizer != "none" + {%- if not nobag and vbe %} + const pta::PackedTensorAccessor32 B_offsets, + const pta::PackedTensorAccessor32 row_output_offsets, + {%- endif %} + {%- if not nobag %} + const int32_t info_B_num_bits, + const uint32_t info_B_mask, + {%- endif %} + const int32_t max_D, + const int32_t max_vecs_per_thread, + {%- if is_index_select %} + const at::PackedTensorAccessor32 grad_offsets, + const bool permute_output_dim_0_1 + {%- else %} + {{ args.split_kernel_args_no_defaults | replace_pta_namespace() | join(",\n ") | replace("cache_t", cache_type) }} + {%- endif %} +); +{%- endmacro %} + +{%- macro hip_bulk_template_instantiations(kFixedMaxVecsPerThread, kThreadGroupSize, kUseVecBlocking) %} + {%- for grad_type in ['float', 'at::Half', 'at::BFloat16'] %} + {%- for emb_type in ['float', 'at::Half'] %} + {%- for cache_type in ['float', 'at::Half'] %} + {%- for kEmbeddingDim in [64, 128, 160, 192, 256] %} + {%- for kWeighDecayMode in [0, 1, 2] %} + {{ hip_template_instantiation( + emb_type, + grad_type, + cache_type, + kFixedMaxVecsPerThread, + kThreadGroupSize, + kUseVecBlocking, + kEmbeddingDim, + kWeighDecayMode + ) + }} + {%- endfor %} + {%- endfor %} + {%- endfor %} + {%- endfor %} + {%- endfor %} +{%- endmacro %} + +{%- macro hip_instantiate_templates(use_subwarp_shuffle) %} +{%- for (kFixedMaxVecsPerThread, kThreadGroupSize, kUseVecBlocking) + in get_max_vecs_template_configs( + items_per_warp, + fixed_max_vecs_per_thread["backward"], + use_subwarp_shuffle, + use_vec_blocking=True, + ) +%} + {{ + hip_bulk_template_instantiations( + kFixedMaxVecsPerThread, + kThreadGroupSize, + kUseVecBlocking, + ) + }} +{%- endfor %} +{%- endmacro %} + +//////////////////////////////////////////////////////////////////////////////// +#ifdef FBGEMM_USE_SUBWARP_SHUFFLE +//////////////////////////////////////////////////////////////////////////////// + +{#- /* + Explicitly instantiate kernels for the FBGEMM_USE_SUBWARP_SHUFFLE case + Please see get_max_vecs_template_configs in + codegen/embedding_common_code_generator.py for more details +*/ #} + +{{ hip_instantiate_templates(use_subwarp_shuffle=True) }} + +//////////////////////////////////////////////////////////////////////////////// +#else +//////////////////////////////////////////////////////////////////////////////// + +{#- /* + Explicitly instantiate kernels for the non-FBGEMM_USE_SUBWARP_SHUFFLE case + Please see 
get_max_vecs_template_configs in + codegen/embedding_common_code_generator.py for more details +*/ #} + +{{ hip_instantiate_templates(use_subwarp_shuffle=False) }} + +//////////////////////////////////////////////////////////////////////////////// +#endif +//////////////////////////////////////////////////////////////////////////////// {%- endif %} // clang-format on diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_meta_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_meta_template.cpp index 6b3d5604d1..def21bd39d 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_meta_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_meta_template.cpp @@ -72,6 +72,9 @@ Tensor {{ mdesc }}_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ desc {%- else %} const c10::SymInt D, {%- endif %} + {%- if not nobag and not is_index_select %} + const bool mixed_D, + {%- endif %} const Tensor& hash_size_cumsum, const int64_t total_hash_size_bits, const Tensor& indices, diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index fdd9c0f798..4c8038d97c 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -26,6 +26,10 @@ #include "fbgemm_gpu/split_embeddings_utils.cuh" #include "fbgemm_gpu/utils/ops_utils.h" +{%- if is_rocm %} +#include "fbgemm_gpu/rocm/cdna_guard.h" +{%- endif %} + using Tensor = at::Tensor; using namespace fbgemm_gpu; @@ -211,6 +215,78 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} {%- endif %} ); + +{%- if is_rocm and optimizer == "rowwise_adagrad" and not dense and not is_index_select + and not is_gwd_kernel and not vbe and not ssd %} +#include "fbgemm_gpu/rocm/split_embeddings_common.h" +template < + typename emb_t, + typename grad_t, + typename cache_t, + int32_t kFixedMaxVecsPerThread, + int32_t kThreadGroupSize, + bool kUseVecBlocking, + int32_t embedding_dim, + int32_t weight_decay_mode_v> +__global__ void +hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_warp_per_row_1( + const pta::PackedTensorAccessor64 grad_output, + {%- if optimizer != "none" %} + pta::PackedTensorAccessor64 dev_weights, + {%- if not dense %} + pta::PackedTensorAccessor64 uvm_weights, + pta::PackedTensorAccessor64 lxu_cache_weights, + const pta::PackedTensorAccessor32 weights_placements, + {%- endif %} + {%- endif %} + const pta::PackedTensorAccessor32 weights_offsets, + {%- if not nobag or is_index_select %} + const pta::PackedTensorAccessor32 D_offsets, + {%- else %} + int64_t D, + {%- endif %} + const pta::PackedTensorAccessor32 hash_size_cumsum, + const pta::PackedTensorAccessor32 sorted_linear_indices_run, + const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, + {%- if not nobag %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- else %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- endif %} + {%- if not dense %} + const pta::PackedTensorAccessor32 sorted_lxu_cache_locations, + const bool use_uniq_cache_locations, + const pta::PackedTensorAccessor32 table_unique_indices_offsets, + {%- endif %} + {%- if weighted %} + const pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> sorted_indice_weights, + {%- 
endif %} + const pta::PackedTensorAccessor32 sorted_linear_indices_num_runs, + int32_t max_segment_length_per_warp, + {%- if not dense and optimizer != "none" %} + bool stochastic_rounding, + at::PhiloxCudaState stochastic_rounding_philox_args, + {%- else %} + pta::PackedTensorAccessor64 grad_dev_weights, + {%- endif %} // if not dense and optimizer != "none" + {%- if not nobag and vbe %} + const pta::PackedTensorAccessor32 B_offsets, + const pta::PackedTensorAccessor32 row_output_offsets, + {%- endif %} + {%- if not nobag %} + const int32_t info_B_num_bits, + const uint32_t info_B_mask, + {%- endif %} + const int32_t max_D, + const int32_t max_vecs_per_thread, + {%- if is_index_select %} + const at::PackedTensorAccessor32 grad_offsets, + const bool permute_output_dim_0_1 + {%- else %} + {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} + {%- endif %} +); +{%- endif %} {% if is_index_select %} namespace index_select { {% else %} @@ -452,6 +528,9 @@ Tensor {{ embedding_cuda_op }}( {%- else %} const c10::SymInt D_, {%- endif %} + {%- if not nobag and not is_index_select %} + const bool mixed_D, + {%- endif %} const Tensor& hash_size_cumsum, const int64_t total_hash_size_bits, const Tensor& indices, @@ -775,6 +854,17 @@ Tensor {{ embedding_cuda_op }}( } {%- endif %} + {%- if is_rocm and optimizer == "rowwise_adagrad" and not dense and not is_index_select + and not is_gwd_kernel and not vbe and not ssd %} + {%- set hip_kernel = "hip_split_embedding{}_backward_codegen_{}_{}{}_kernel_warp_per_row_1".format( + ndesc, + optimizer, + wdesc, + vdesc, + ) + %} + {%- endif %} + DISPATCH_EMB_GRAD_CACHE_TYPES( dev_weights.scalar_type(), aligned_grad_output.scalar_type(), @@ -1070,7 +1160,7 @@ Tensor {{ embedding_cuda_op }}( desc_suffix, ) %} - const auto backward_warp_per_row_kernel = + auto backward_warp_per_row_kernel = {{ warp_kernel }} (), segments_per_workgroup); + blockSize = dim3(256); + warp_per_row_smem_bytes = 0; + + backward_warp_per_row_kernel = + {{ hip_kernel }} + ; + } + {%- endfor %} + {%- endfor %} + } + {%- endif %} +#endif + + #ifdef FBGEMM_GPU_MEMCHECK const auto func_name4 = "{{ warp_kernel }}"; #endif backward_warp_per_row_kernel <<>>( grad_output_accessor, @@ -1222,6 +1347,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { {%- else %} " SymInt D, " {%- endif %} + {%- if not nobag and not is_index_select %} + " bool mixed_D, " + {%- endif %} " Tensor hash_size_cumsum, " " int total_hash_size_bits, " " Tensor indices, " diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip new file mode 100644 index 0000000000..0374a1724c --- /dev/null +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -0,0 +1,461 @@ +/******************************************************************************* + * Copyright (c) 2016 - 2024 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + ******************************************************************************/ + +#include +#include + +#include "fbgemm_gpu/rocm/split_embeddings_common.h" + +namespace fbgemm_gpu::rocm { +template +struct rowwise_adagrad_optimizer_t +{ + __device__ rowwise_adagrad_optimizer_t(const rowwise_adagrad_kernel_arg_t& karg_) + : karg(karg_) + { + } + + template + __device__ void update(cache_t* acc, emb_t* weight, uint32_t row_index) + { + if constexpr(segment_split == 0) + { + cache_t * p_momentum = reinterpret_cast(karg.p_momentum); + cache_t momentum = p_momentum[row_index]; // should be s_load + // compute per row square sum + cache_t local_sum_squre = .0f; + if constexpr(weight_decay_mode == 1) + { +#pragma unroll + for(auto i = 0; i < thread_length; i++) + { + cache_t w = static_cast(weight[i]); + cache_t a = acc[i] + w * karg.weight_decay; + local_sum_squre += a * a; + } + } + else + { +#pragma unroll + for(auto i = 0; i < thread_length; i++) + { + cache_t a = acc[i]; + local_sum_squre += a * a; + } + } + + cache_t avg_square = + wave_reduce, cache_t, AMDGCN_WAVE_SIZE>(local_sum_squre) / + embedding_dim; + + cache_t momentum_new = momentum + avg_square; + + cache_t multiplier = karg.learning_rate / (sqrtf(momentum_new) + karg.eps); + cache_t correction; + + if constexpr(weight_decay_mode == 1) + { + correction = 1.0 - multiplier * karg.weight_decay; + } + else if constexpr(weight_decay_mode == 2) + { + correction = 1.0 - karg.learning_rate * karg.weight_decay; + } + else + { + correction = 1.0; + } + +// update new weight value +#pragma unroll + for(auto i = 0; i < thread_length; i++) + { + cache_t w = static_cast(weight[i]); + cache_t a = acc[i]; + w = correction * w - multiplier * a; + weight[i] = static_cast(w); + } + + p_momentum[row_index] = momentum_new; + } + } + + rowwise_adagrad_kernel_arg_t karg; +}; + +template +__device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( + const grad_t* p_output_grad, + emb_t* p_emb_table, + const int64_t* p_hash_size_cumsum, + const int64_t* p_sorted_linear_indices_run, + const int32_t* p_sorted_linear_indices_cumulative_run_lengths, + const int32_t* p_sorted_linear_indices_num_runs, + {%- if not nobag %} + const int32_t info_B_num_bits, + const uint32_t info_B_mask, + {%- endif %} + {%- if not nobag %} + const int32_t* p_sorted_infos, + {%- else %} + const int64_t* p_sorted_infos, + {%- endif %} + magic_div_u32_t batch_mdiv, + uint32_t max_segment_length_per_warp, 
+ uint32_t emb_dim, + uint32_t batch, + uint32_t num_rows, + uint32_t num_tables, + optimizer_karg_t opt_karg, + const float * p_sorted_indice_weights = nullptr) +{ + constexpr uint32_t dword_per_row = (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; + constexpr uint32_t waves_per_block = block_size / AMDGCN_WAVE_SIZE; + constexpr uint32_t length_mask = ~(segment_unroll - 1); + const uint32_t wave_id = __builtin_amdgcn_readfirstlane(threadIdx.x / AMDGCN_WAVE_SIZE); + const uint32_t lane_id = threadIdx.x % AMDGCN_WAVE_SIZE; + const uint32_t run_id = wave_id + blockIdx.x * waves_per_block; + + if(run_id >= p_sorted_linear_indices_num_runs[0]) + { + return; + } + + const int64_t linear_index = p_sorted_linear_indices_run[run_id]; + + const int32_t segment_start = p_sorted_linear_indices_cumulative_run_lengths[run_id]; + const int32_t segment_end = p_sorted_linear_indices_cumulative_run_lengths[run_id + 1]; + + {%- if nobag %} + const auto info_0 = p_sorted_infos[segment_start]; + int32_t t_0 = info_0 % num_tables; + {%- else %} + const auto info_0 = reinterpret_cast(&p_sorted_infos[0])[segment_start]; + const auto t_0 = info_0 >> info_B_num_bits; + {%- endif %} + int64_t hash_size = p_hash_size_cumsum[t_0]; + + const int64_t emb_idx = linear_index - hash_size; + + p_emb_table += hash_size * emb_dim; + opt_karg.p_momentum = reinterpret_cast(reinterpret_cast(opt_karg.p_momentum) + hash_size); + + const int32_t segment_length = segment_end - segment_start; + + if(segment_length >= max_segment_length_per_warp) + return; + + const int32_t segment_length_mod = segment_length & length_mask; + + cache_t grad_acc[dword_per_row]; + int32_t infos[segment_unroll]; + grad_t grad_data[dword_per_row * segment_prefetch]; + emb_t emb_data[dword_per_row]; + float indice_weights[segment_unroll]; + + #pragma unroll + for(int i=0; i < dword_per_row; i++) + { + grad_acc[i] = .0f; + } + + int itr = 0; + if(segment_length_mod == 0) + goto L_tail_grad_acc; + + if constexpr (!weighted) { + #pragma unroll + for(int i = 0; i < segment_unroll; i++) + { + infos[i] = p_sorted_infos[segment_start + i]; + } + } else { + for(int i = 0; i < segment_unroll; i++) + { + infos[i] = p_sorted_infos[segment_start + i]; + indice_weights[i] = p_sorted_indice_weights[segment_start + i]; + } + } + + itr += segment_unroll; + p_sorted_infos += segment_unroll; + + if constexpr (weighted) { + p_sorted_indice_weights += segment_unroll; + } + + uint32_t bag_index; + uint32_t table_index; + + // LOOP + for(; itr < segment_length_mod; itr += segment_unroll) + { + {%- if nobag %} + magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); + {%- else %} + table_index = infos[0] >> info_B_num_bits; + bag_index = infos[0] & info_B_mask; + {%- endif %} + load_row_per_warp::run( + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + + {%- if nobag %} + magic_div_u32_run_with_mod(batch_mdiv, infos[1], batch, table_index, bag_index); + {%- else %} + table_index = infos[1] >> info_B_num_bits; + bag_index = infos[1] & info_B_mask; + {%- endif %} + load_row_per_warp::run( + &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + if constexpr (!weighted){ + #pragma unroll + for(int j = 2; j < segment_unroll; j += 2) + { + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id); + {%- if nobag %} + magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); + {%- else %} + table_index = infos[j] >> 
info_B_num_bits; + bag_index = infos[j] & info_B_mask; + {%- endif %} + load_row_per_warp::run( + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[dword_per_row], lane_id); + {%- if nobag %} + magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); + {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; + bag_index = infos[j + 1] & info_B_mask; + {%- endif %} + load_row_per_warp::run( + &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + } + + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id); + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[dword_per_row], lane_id); + + #pragma unroll + for(int i = 0; i < segment_unroll; i++) + { + infos[i] = p_sorted_infos[segment_start + i]; + } + p_sorted_infos += segment_unroll; + + + } else { + #pragma unroll + for(int j = 2; j < segment_unroll; j += 2) + { + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id, indice_weights[j-2]); + {%- if nobag %} + magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); + {%- else %} + table_index = infos[j] >> info_B_num_bits; + bag_index = infos[j] & info_B_mask; + {%- endif %} + load_row_per_warp::run( + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); + {%- if nobag %} + magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); + {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; + bag_index = infos[j + 1] & info_B_mask; + {%- endif %} + load_row_per_warp::run( + &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + } + + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id, indice_weights[segment_unroll-2]); + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[segment_unroll-1]); + + #pragma unroll + for(int i = 0; i < segment_unroll; i++) + { + infos[i] = p_sorted_infos[segment_start + i]; + indice_weights[i] = p_sorted_indice_weights[segment_start + i]; + } + p_sorted_infos += segment_unroll; + p_sorted_indice_weights += segment_unroll; + } + } + + // LAST + {%- if nobag %} + magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); + {%- else %} + table_index = infos[0] >> info_B_num_bits; + bag_index = infos[0] & info_B_mask; + {%- endif %} + load_row_per_warp::run( + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + + {%- if nobag %} + magic_div_u32_run_with_mod(batch_mdiv, infos[1], batch, table_index, bag_index); + {%- else %} + table_index = infos[1] >> info_B_num_bits; + bag_index = infos[1] & info_B_mask; + {%- endif %} + load_row_per_warp::run( + &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + + if constexpr (!weighted) { + #pragma unroll + for(int j = 2; j < segment_unroll; j += 2) + { + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id); + {%- if nobag %} + magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); + {%- else %} + table_index = infos[j] >> info_B_num_bits; + bag_index = infos[j] & info_B_mask; + {%- endif %} + load_row_per_warp::run( + 
&grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[dword_per_row], lane_id); + {%- if nobag %} + magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); + {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; + bag_index = infos[j + 1] & info_B_mask; + {%- endif %} + load_row_per_warp::run( + &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + } + + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id); + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[dword_per_row], lane_id); + } else { + #pragma unroll + for(int j = 2; j < segment_unroll; j += 2) + { + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id, indice_weights[j-2]); + {%- if nobag %} + magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); + {%- else %} + table_index = infos[j] >> info_B_num_bits; + bag_index = infos[j] & info_B_mask; + {%- endif %} + load_row_per_warp::run( + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); + {%- if nobag %} + magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); + {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; + bag_index = infos[j + 1] & info_B_mask; + {%- endif %} + load_row_per_warp::run( + &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + } + + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id, indice_weights[segment_unroll-2]); + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[segment_unroll-1]); + } + +L_tail_grad_acc: + if(segment_length & (segment_unroll - 1)) + { + if constexpr (!weighted){ + // last, load one by one + do + { + infos[0] = p_sorted_infos[segment_start]; + p_sorted_infos++; + + {%- if nobag %} + magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); + {%- else %} + table_index = infos[0] >> info_B_num_bits; + bag_index = infos[0] & info_B_mask; + {%- endif %} + load_row_per_warp::run( + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id); + + itr++; + } while(itr < segment_length); + } else { + do + { + infos[0] = p_sorted_infos[segment_start]; + indice_weights[0] = p_sorted_indice_weights[segment_start]; + p_sorted_infos++; + p_sorted_indice_weights++; + + {%- if nobag %} + magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); + {%- else %} + table_index = infos[0] >> info_B_num_bits; + bag_index = infos[0] & info_B_mask; + {%- endif %} + load_row_per_warp::run( + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + accumulate_row_per_warp::run( + &grad_acc[0], &grad_data[0], lane_id, indice_weights[0]); + + itr++; + } while(itr < segment_length); + } + } + + // load the old emb weight data + load_row_per_warp::run( + &emb_data[0], emb_idx, p_emb_table, lane_id); + optimizer_t optimizer(opt_karg); + optimizer.template update(grad_acc, emb_data, emb_idx); + + store_row_per_warp::run(&emb_data[0], p_emb_table + emb_idx * embedding_dim, lane_id); +} +} // namespace fbgemm_gpu::rocm \ No 
newline at end of file diff --git a/fbgemm_gpu/codegen/training/python/lookup_args.template b/fbgemm_gpu/codegen/training/python/lookup_args.template index 357aad622a..f3fd7aa87a 100644 --- a/fbgemm_gpu/codegen/training/python/lookup_args.template +++ b/fbgemm_gpu/codegen/training/python/lookup_args.template @@ -49,6 +49,7 @@ class CommonArgs(NamedTuple): {%- if ssd %} ssd_tensors: Dict[str, torch.Tensor] {%- endif %} + mixed_D: bool class OptimizerArgs(NamedTuple): diff --git a/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template b/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template index e86b27b2dc..062d526a01 100644 --- a/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template +++ b/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template @@ -409,5 +409,6 @@ def invoke( use_homogeneous_placements=common_args.use_homogeneous_placements, apply_global_weight_decay=apply_global_weight_decay, gwd_lower_bound=gwd_lower_bound, + mixed_D=common_args.mixed_D, ) {%- endif %} diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index b6057aadfa..ca8476120d 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -728,7 +728,7 @@ def __init__( # noqa C901 assert ( not mixed_D ), "OptimType.NONE does not support mixed embedding dimension" - + self.mixed_D = mixed_D if device is None: self.current_device: torch.device = ( torch.device("cpu") @@ -1778,6 +1778,7 @@ def forward( # noqa: C901 is_experimental=self.is_experimental, use_uniq_cache_locations_bwd=self.use_uniq_cache_locations_bwd, use_homogeneous_placements=self.use_homogeneous_placements, + mixed_D=self.mixed_D, ) if self.optimizer == OptimType.NONE: @@ -3621,6 +3622,7 @@ def forward( max_B=vbe_metadata.max_B, max_B_feature_rank=vbe_metadata.max_B_feature_rank, vbe_output_size=vbe_metadata.output_size, + mixed_D=self.mixed_D, ) @torch.jit.export diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h new file mode 100644 index 0000000000..b55fd72fce --- /dev/null +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h @@ -0,0 +1,51 @@ +/******************************************************************************* + * Copyright (c) 2016 - 2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + ******************************************************************************/ +#pragma once + +#include +#include +#include + +#define HIP_CHECK(c) \ + { \ + if (c != hipSuccess) { \ + printf("HIP Error : %s", hipGetErrorString(c)); \ + printf(" %s %d\n", __FILE__, __LINE__); \ + exit(c); \ + } \ + } + +namespace fbgemm_gpu::rocm { + +[[nodiscard]] inline bool is_supported_cdna() { + const std::set supported_archs{"gfx942", "gfx90a"}; + int device_id = 0; + HIP_CHECK(hipGetDevice(&device_id)); + hipDeviceProp_t dev_props; + HIP_CHECK(hipGetDeviceProperties(&dev_props, device_id)); + std::string gcn_arch = dev_props.gcnArchName; + gcn_arch = gcn_arch.substr(0, gcn_arch.find(":")); + return supported_archs.contains(gcn_arch); +} + +} // namespace fbgemm_gpu::rocm diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h new file mode 100644 index 0000000000..0058548e22 --- /dev/null +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -0,0 +1,549 @@ +/******************************************************************************* + * Copyright (c) 2016 - 2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ * + ******************************************************************************/ +#pragma once +#include +#include + +/******************************************************************************/ +typedef int32_t int32x4_t __attribute__((ext_vector_type(4))); +typedef float floatx2_t __attribute__((ext_vector_type(2))); +#define AMDGCN_BUFFER_RES_3 0x00027000 +#define AMDGCN_WAVE_SIZE 64 +#define THREADS_PER_ROW 64 +#define BLOCK_SIZE 256 + +namespace fbgemm_gpu::rocm { +template +union amdgcn_buffer_resource { + // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions + int32x4_t content; + struct { + T* address; + int32_t range; + int32_t config; + }; +}; + +template +__device__ int32x4_t amdgcn_make_buffer_resource(const T* addr) { + amdgcn_buffer_resource buffer_resource; + buffer_resource.address = const_cast(addr); + buffer_resource.range = 0xffffffff; + buffer_resource.config = AMDGCN_BUFFER_RES_3; // for gfx9 + + return buffer_resource.content; +} + +// buffer load fp32 +__device__ half llvm_amdgcn_raw_buffer_load_fp16( + int32x4_t srsrc, + int32_t voffset, + int32_t soffset, + int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16"); + +__device__ float llvm_amdgcn_raw_buffer_load_fp32( + int32x4_t srsrc, + int32_t voffset, + int32_t soffset, + int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32"); + +__device__ half2 llvm_amdgcn_raw_buffer_load_fp16x2( + int32x4_t srsrc, + int32_t voffset, + int32_t soffset, + int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16"); + +__device__ void llvm_amdgcn_raw_buffer_store_fp32( + float vdata, + int32x4_t rsrc, + int32_t voffset, + int32_t soffset, + int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32"); + +__device__ void llvm_amdgcn_raw_buffer_store_fp32x2( + floatx2_t vdata, + int32x4_t rsrc, + int32_t voffset, + int32_t soffset, + int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32"); + +/******************************************************************************/ + +template +struct load_row_per_warp { + static __device__ void run( + emb_t* emb_data, + index_t row_index, + const emb_t* p_emb_table, + int lane_id) {} +}; + +template +struct load_row_per_warp { + static constexpr int dword_per_row = + (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; + static __device__ void run( + float* emb_data, + index_t row_index, + const float* p_emb_table, + int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * embedding_dim); +#pragma unroll + for (int i = 0; i < dword_per_row; i++) { + if constexpr (embedding_dim == 160) { + if ((lane_id + i * THREADS_PER_ROW) % 192 < 160) { + emb_data[i] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + i * THREADS_PER_ROW) * sizeof(float), 0, 0); + } else { + emb_data[i] = 0.f; + } + } else { + emb_data[i] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + i * THREADS_PER_ROW) * sizeof(float), 0, 0); + } + } + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 64); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp16(emb_res, lane_id * sizeof(half), 0, 0); + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + 
row_index * 128); + *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, lane_id * sizeof(half2), 0, 0); + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 192); + *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, lane_id * sizeof(half2), 0, 0); + if ((lane_id + 128) % 192 < 160) { + emb_data[2] = llvm_amdgcn_raw_buffer_load_fp16( + emb_res, (lane_id + 128) * sizeof(half), 0, 0); + } else { + emb_data[2] = __float2half(0.0); + } + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 192); + *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, lane_id * sizeof(half2), 0, 0); + emb_data[2] = llvm_amdgcn_raw_buffer_load_fp16( + emb_res, (lane_id + 128) * sizeof(half), 0, 0); + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 256); + *reinterpret_cast(&emb_data[0]) = + llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, lane_id * sizeof(half2), 0, 0); + *reinterpret_cast(&emb_data[2]) = + llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, (lane_id + 64) * sizeof(half2), 0, 0); + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 512); + *reinterpret_cast(&emb_data[0]) = + llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, lane_id * sizeof(half2), 0, 0); + *reinterpret_cast(&emb_data[2]) = + llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, (lane_id + 64) * sizeof(half2), 0, 0); + *reinterpret_cast(&emb_data[4]) = + llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, (lane_id + 64 * 2) * sizeof(half2), 0, 0); + *reinterpret_cast(&emb_data[6]) = + llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, (lane_id + 64 * 3) * sizeof(half2), 0, 0); + } +}; + +template < + typename emb_t, + int32_t embedding_dim, + typename output_t, + bool weighted> +struct accumulate_row_per_warp { + static constexpr int dword_per_row = + (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; + static __device__ void + run(output_t* acc, emb_t* emb_data, int lane_id, float row_weight = 1.0) { + if constexpr (!weighted) { +#pragma unroll + for (int i = 0; i < dword_per_row; i++) { + acc[i] += static_cast(emb_data[i]); + } + } else { +#pragma unroll + for (int i = 0; i < dword_per_row; i++) { + acc[i] += static_cast((float)emb_data[i] * row_weight); + } + } + } +}; + +template +struct store_row_per_warp { + static constexpr int dword_per_row = + (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; + static __device__ void run(output_t* acc, output_t* p_output, int lane_id) { + if constexpr (embedding_dim == 160) { + for (int i = 0; i < dword_per_row; i++) { + if ((lane_id + i * THREADS_PER_ROW) % 192 < 160) { + p_output[lane_id + i * THREADS_PER_ROW] = acc[i]; + } + } + } else { +#pragma unroll + for (int i = 0; i < dword_per_row; i++) { + p_output[lane_id + i * THREADS_PER_ROW] = acc[i]; + } + } + } +}; + +template <> +struct store_row_per_warp { + static 
__device__ void run(float* acc, float* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp32x2( + *reinterpret_cast(acc), + out_res, + lane_id * sizeof(floatx2_t), + 0, + 0); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(float* acc, float* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp32x2( + *reinterpret_cast(acc), + out_res, + lane_id * sizeof(floatx2_t), + 0, + 0); + if ((lane_id + 128) % 192 < 160) { + llvm_amdgcn_raw_buffer_store_fp32( + acc[2], out_res, (lane_id + 128) * sizeof(float), 0, 0); + } + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(float* acc, float* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp32x2( + *reinterpret_cast(acc), + out_res, + lane_id * sizeof(floatx2_t), + 0, + 0); + llvm_amdgcn_raw_buffer_store_fp32( + acc[2], out_res, (lane_id + 128) * sizeof(float), 0, 0); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(float* acc, float* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp32x2( + *reinterpret_cast(acc), + out_res, + lane_id * sizeof(floatx2_t), + 0, + 0); + llvm_amdgcn_raw_buffer_store_fp32x2( + *reinterpret_cast(&acc[2]), + out_res, + (lane_id + 64) * sizeof(floatx2_t), + 0, + 0); + } +}; + +// Helper function to pack fp16 and fp32 into int to further pass +// into mov_dpp and readfirstlane() +template + requires( + (sizeof(to_t) == 4 || sizeof(to_t) == 2) && + (sizeof(from_t) == 4 || sizeof(from_t) == 2)) +__device__ to_t pack(const from_t& v) { + to_t result = 0; + if constexpr (sizeof(to_t) == sizeof(from_t)) { + result = __builtin_bit_cast(to_t, v); + return result; + } + + memcpy(&result, &v, 2); + + return result; +} + +namespace reduce_op { +struct sum {}; +struct sub {}; +struct mul {}; +struct div {}; +} // namespace reduce_op + +template +struct reduce_op_sum_t { + __device__ data_t operator()(const data_t& a, const data_t& b) { + return a + b; + } +}; + +#define DPP_REDUCE(OP, TYPE) \ + __asm__ volatile( \ + "v_nop\n" \ + "v_nop\n" \ + "v_nop\n" \ + "v_nop\n" \ + "v_" #OP "_" #TYPE \ + "_dpp %0 %0 %0 quad_perm:[1,0,3,2]\n" \ + "v_nop\n" \ + "v_nop\n" \ + "v_" #OP "_" #TYPE \ + "_dpp %0 %0 %0 quad_perm:[2,3,0,1]\n" \ + "v_nop\n" \ + "v_nop\n" \ + "v_" #OP "_" #TYPE \ + "_dpp %0 %0 %0 row_shr:4\n" \ + "v_nop\n" \ + "v_nop\n" \ + "v_" #OP "_" #TYPE \ + "_dpp %0 %0 %0 row_shr:8\n" \ + "v_nop\n" \ + "v_nop\n" \ + "v_" #OP "_" #TYPE \ + "_dpp %0 %0 %0 row_bcast:15\n" \ + "v_nop\n" \ + "v_nop\n" \ + "v_" #OP "_" #TYPE \ + "_dpp %0 %0 %0 row_bcast:31\n" \ + "v_nop\n" \ + "v_nop\n" \ + : "=v"(result) \ + : "0"(result)) + +#define DPP_REDUCE_F16_F32(OP) \ + if constexpr (std::is_same_v) { \ + DPP_REDUCE(OP, f32); \ + } \ + \ + if constexpr (std::is_same_v) { \ + DPP_REDUCE(OP, f16); \ + } + +template +__device__ __forceinline__ void generic_dpp_reduction(data_t& result) { + constexpr int row_mask = 0xf; + constexpr int bank_mask = 0xf; + constexpr bool bound_ctrl = false; + + reduce_op_t reduce_op; + + if constexpr (wave_size > 1) { + result = reduce_op( + result, + pack(__builtin_amdgcn_mov_dpp( + pack(result), + 0xb1, + row_mask, + bank_mask, + bound_ctrl))); // quad_perm:[1,0,3,2] + } + if constexpr (wave_size > 2) { + result = reduce_op( + result, + 
pack(__builtin_amdgcn_mov_dpp( + pack(result), + 0x4e, + row_mask, + bank_mask, + bound_ctrl))); // quad_perm:[2,3,0,1] + } + if constexpr (wave_size > 4) { + result = reduce_op( + result, + pack(__builtin_amdgcn_mov_dpp( + pack(result), + 0x114, + row_mask, + bank_mask, + bound_ctrl))); // row_shr:4 + } + if constexpr (wave_size > 8) { + result = reduce_op( + result, + pack(__builtin_amdgcn_mov_dpp( + pack(result), + 0x118, + row_mask, + bank_mask, + bound_ctrl))); // row_shr:8 + } + if constexpr (wave_size > 16) { + result = reduce_op( + result, + pack(__builtin_amdgcn_mov_dpp( + pack(result), + 0x142, + row_mask, + bank_mask, + bound_ctrl))); // row_bcast:15 + } + if constexpr (wave_size > 32) { + result = reduce_op( + result, + pack(__builtin_amdgcn_mov_dpp( + pack(result), + 0x143, + row_mask, + bank_mask, + bound_ctrl))); // row_bcast:31 + } +} + +// Use corresponding assebly instruction for dpp reduction in case +// of trivial operation with an option to use custom operation +template +__device__ __forceinline__ void dpp_reduction(data_t& result) { +#if defined(__gfx942__) || defined(__gfx90a__) + if constexpr (std::is_same_v) { + DPP_REDUCE_F16_F32(add); + return; + } else if constexpr (std::is_same_v) { + DPP_REDUCE_F16_F32(sub); + return; + } else if constexpr (std::is_same_v) { + DPP_REDUCE_F16_F32(mul); + return; + } else if constexpr (std::is_same_v) { + DPP_REDUCE_F16_F32(div); + return; + } else { + generic_dpp_reduction(result); + } +#endif +} + +template +__device__ inline data_t wave_reduce(const data_t& thread_data) { + data_t result = thread_data; + + // now the reduced value is in the last lane of wave + dpp_reduction(result); + return pack( + __builtin_amdgcn_readlane(pack(result), wave_size - 1)); +} + +struct rowwise_adagrad_kernel_arg_t { + void* p_momentum; + float eps; + float learning_rate; + float weight_decay; + int64_t weight_decay_mode; +}; + +typedef struct { + uint32_t magic; + uint32_t shift; // actually 8 bit is enough +} magic_div_u32_t; + +static inline magic_div_u32_t magic_div_u32_gen(uint32_t d) { + assert(d >= 1 && d <= INT32_MAX); + uint8_t shift; + for (shift = 0; shift < 32; shift++) + if ((1U << shift) >= d) + break; + + uint64_t one = 1; + uint64_t magic = ((one << 32) * ((one << shift) - d)) / d + 1; + assert(magic <= 0xffffffffUL); + + magic_div_u32_t result; + result.magic = magic; + result.shift = shift; + return result; +} + +// numer / denom = quotient, reminder +__device__ inline uint32_t magic_div_u32_run( + const magic_div_u32_t& mdiv, + const uint32_t& n) { + uint32_t tmp = __umulhi(n, mdiv.magic); + return (tmp + n) >> mdiv.shift; +} + +__device__ inline void magic_div_u32_run_with_mod( + const magic_div_u32_t& mdiv, + const uint32_t& n, + const uint32_t d, + uint32_t& quo, + uint32_t& rem) { + quo = magic_div_u32_run(mdiv, n); + rem = n - quo * d; +} +} // namespace fbgemm_gpu::rocm diff --git a/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py b/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py index cf7f0cbd85..5e65e40bf6 100644 --- a/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py +++ b/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py @@ -58,6 +58,7 @@ additional_decorators, gpu_unavailable, optests, + skipIfNotRocm, TEST_WITH_ROCM, use_cpu_strategy, ) @@ -66,6 +67,7 @@ additional_decorators, gpu_unavailable, optests, + skipIfNotRocm, TEST_WITH_ROCM, use_cpu_strategy, ) @@ -1080,6 +1082,80 @@ def test_backward_optimizers_adagrad( # noqa C901 weight_decay_mode, ) + @given( + 
T=st.integers(min_value=1, max_value=5), + D=st.sampled_from([16, 32, 40, 48, 64]), + B=st.integers(min_value=1, max_value=128), + log_E=st.integers(min_value=3, max_value=5), + L=st.integers(min_value=2, max_value=20), + weighted=st.booleans(), + mixed=st.just(False), + mixed_B=st.just(False), + optimizer=st.sampled_from( + [ + OptimType.EXACT_ROWWISE_ADAGRAD, + ] + ), + long_segments=st.booleans(), + pooling_mode=st.sampled_from( + [ + PoolingMode.SUM, + ] + ), + use_cpu=st.just(False), + weight_decay_mode=st.sampled_from( + [ + WeightDecayMode.NONE, + WeightDecayMode.L2, + WeightDecayMode.DECOUPLE, + ] + ), + ) + @settings( + verbosity=VERBOSITY, + max_examples=MAX_EXAMPLES_LONG_RUNNING, + deadline=None, + suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], + ) + @unittest.skipIf(*gpu_unavailable) + @skipIfNotRocm("Test only evaluates ROCm optimized kernels") + def test_new_bwd_kernel( # noqa C901 + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + weighted: bool, + mixed: bool, + mixed_B: bool, + optimizer: OptimType, + long_segments: bool, + pooling_mode: PoolingMode, + use_cpu: bool, + weight_decay_mode: WeightDecayMode, + ) -> None: + if ( + pooling_mode == PoolingMode.NONE + or optimizer != OptimType.EXACT_ROWWISE_ADAGRAD + ): + mixed_B = False + self.execute_backward_optimizers_( + T, + D, + B, + log_E, + L, + weighted, + mixed, + mixed_B, + optimizer, + long_segments, + pooling_mode, + use_cpu, + weight_decay_mode, + ) + @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256), diff --git a/fbgemm_gpu/test/test_utils.py b/fbgemm_gpu/test/test_utils.py index 853b2d070c..e073f7a383 100644 --- a/fbgemm_gpu/test/test_utils.py +++ b/fbgemm_gpu/test/test_utils.py @@ -254,6 +254,26 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: return decorator +# pyre-fixme[3]: Return annotation cannot be `Any`. +def skipIfNotRocm( + reason: str = "Test currently doesn work only on the ROCm stack", +) -> Any: + # pyre-fixme[3]: Return annotation cannot be `Any`. + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. + def decorator(fn: Callable) -> Any: + @wraps(fn) + # pyre-fixme[3]: Return annotation cannot be `Any`. + def wrapper(*args: Any, **kwargs: Any) -> Any: + if TEST_WITH_ROCM: + fn(*args, **kwargs) + else: + raise unittest.SkipTest(reason) + + return wrapper + + return decorator + + # pyre-fixme[3]: Return annotation cannot be `Any`. def skipIfRocmLessThan(min_version: int) -> Any: # pyre-fixme[3]: Return annotation cannot be `Any`.
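For reference, the sketch below shows one way to exercise the optimized path from Python through the existing `SplitTableBatchedEmbeddingBagsCodegen` module: a pooled (`PoolingMode.SUM`) lookup with the `rowwise_adagrad` optimizer and a single, non-mixed embedding dimension from the supported set. This is an illustrative sketch rather than part of the patch; the shapes and hyperparameters are made up, and the import paths follow the current fbgemm_gpu Python layout and may differ between versions. The HIP kernel is only selected on supported CDNA GPUs (gfx90a / gfx942, see cdna_guard.h).

```python
# Minimal sketch (not part of this patch): drive the TBE backward pass under
# the conditions the new ROCm kernel targets -- pooled ("not nobag") lookups,
# rowwise_adagrad, and a single non-mixed embedding dim (here 128).
import torch

from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
    EmbeddingLocation,
    PoolingMode,
)
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    ComputeDevice,
    SplitTableBatchedEmbeddingBagsCodegen,
)

T, E, D, B, L = 2, 10_000, 128, 512, 20  # tables, rows, dim, batch size, bag length

emb = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[(E, D, EmbeddingLocation.DEVICE, ComputeDevice.CUDA)] * T,
    optimizer=OptimType.EXACT_ROWWISE_ADAGRAD,
    learning_rate=0.01,
    eps=1.0e-8,
    pooling_mode=PoolingMode.SUM,  # pooled ("not nobag") mode
)

# Flattened indices for T * B bags of L indices each; on ROCm builds the
# "cuda" device maps to the HIP device.
indices = torch.randint(0, E, (T * B * L,), dtype=torch.int64, device="cuda")
offsets = torch.arange(0, T * B * L + 1, L, dtype=torch.int64, device="cuda")

out = emb(indices=indices, offsets=offsets)  # forward, output shape (B, T * D)
out.sum().backward()  # backward pass with the fused rowwise_adagrad update
```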