diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt index 59faa5dafe..3519d960aa 100644 --- a/fbgemm_gpu/CMakeLists.txt +++ b/fbgemm_gpu/CMakeLists.txt @@ -228,7 +228,8 @@ set(gen_gpu_kernel_source_files "gen_batch_index_select_dim0_backward_codegen_cuda.cu" "gen_batch_index_select_dim0_backward_kernel_cta.cu" "gen_batch_index_select_dim0_backward_kernel_warp.cu" - "gen_embedding_backward_split_grad.cu" + "gen_embedding_backward_split_grad_embedding_ops.cu" + "gen_embedding_backward_split_grad_index_select.cu" ) if(NOT USE_ROCM) diff --git a/fbgemm_gpu/codegen/embedding_backward_code_generator.py b/fbgemm_gpu/codegen/embedding_backward_code_generator.py index f42d2cda5c..791862859f 100644 --- a/fbgemm_gpu/codegen/embedding_backward_code_generator.py +++ b/fbgemm_gpu/codegen/embedding_backward_code_generator.py @@ -319,7 +319,10 @@ def index_select() -> None: ) template = env.get_template("embedding_backward_split_grad_template.cu") - write("gen_embedding_backward_split_grad.cu", template.render()) + write( + "gen_embedding_backward_split_grad_index_select.cu", + template.render(is_index_select=True), + ) def forward_quantized() -> None: @@ -461,7 +464,10 @@ class elem_type: def backward_grad() -> None: # Generate the common grad functions template = env.get_template("embedding_backward_split_grad_template.cu") - write("gen_embedding_backward_split_grad.cu", template.render()) + write( + "gen_embedding_backward_split_grad_embedding_ops.cu", + template.render(is_index_select=False), + ) def backward_indices() -> None: diff --git a/fbgemm_gpu/codegen/embedding_backward_split_grad_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_grad_template.cu index 8adb1ea4a7..7290636272 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_grad_template.cu +++ b/fbgemm_gpu/codegen/embedding_backward_split_grad_template.cu @@ -12,8 +12,16 @@ #include "fbgemm_gpu/split_embeddings_utils.cuh" using Tensor = at::Tensor; + using namespace fbgemm_gpu; +{% if is_index_select %} +namespace index_select { +{% else %} +namespace embedding_ops { +{% endif %} + + __global__ __launch_bounds__(kMaxThreads) void split_embedding_backward_codegen_find_long_segments( const pta::PackedTensorAccessor32 @@ -225,4 +233,6 @@ void grad_mean{{ vdesc }}_kernel {% endfor %} // for grad_type in ['at::Half', 'float'] {% endfor %} // for vbe in [True, False] +} + // clang-format on diff --git a/fbgemm_gpu/codegen/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_template.cu index 0cd3371a09..875ed32dbf 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/embedding_backward_split_template.cu @@ -175,6 +175,13 @@ split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc {%- endif %} ); +{% if is_index_select %} +namespace index_select { +{% else %} +namespace embedding_ops { +{% endif %} + + __global__ __launch_bounds__(kMaxThreads) void split_embedding_backward_codegen_find_long_segments( const pta::PackedTensorAccessor32 sorted_linear_indices_num_runs, @@ -222,6 +229,14 @@ split_embedding_backward_count_unique_indices_kernel( const int info_B_num_bits ); +} + +{% if is_index_select %} +using namespace index_select; +{% else %} +using namespace embedding_ops; +{% endif %} + //////////////////////////////////////////////////////////////////////////////// // Utility Macros ////////////////////////////////////////////////////////////////////////////////