Skip to content

Commit

Permalink
Make auto select for embedding_ops and fix duplicate cu files (#2269)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2269

Both "embedding_ops" and "index_select_ops" include "gen_batch_index_select_dim0_backward_kernel_warp.cu", which caused a duplicate-symbol error at link time.

~~So remove the "gen_embedding_backward_split_grad.cu" from the source of "index_select_ops", and include "embedding_ops" as its deps.~~

The linker cannot resolve the symbols from "gen_embedding_backward_split_grad.cu" inside the "embedding_ops" object, so we have to include this .cu file explicitly in both targets. To avoid the resulting duplicate symbols, we introduce a namespace in the template and generate two differently named files, one per namespace ("embedding_ops" and "index_select").

Reviewed By: jiaqizhai

Differential Revision: D52790445

fbshipit-source-id: a6779ebdf2028155ad20fbbf17588f635ecd6564
  • Loading branch information
houseroad authored and facebook-github-bot committed Jan 17, 2024
1 parent 85cd858 commit 54a56f0
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 3 deletions.
3 changes: 2 additions & 1 deletion fbgemm_gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,8 @@ set(gen_gpu_kernel_source_files
"gen_batch_index_select_dim0_backward_codegen_cuda.cu"
"gen_batch_index_select_dim0_backward_kernel_cta.cu"
"gen_batch_index_select_dim0_backward_kernel_warp.cu"
"gen_embedding_backward_split_grad.cu"
"gen_embedding_backward_split_grad_embedding_ops.cu"
"gen_embedding_backward_split_grad_index_select.cu"
)

if(NOT USE_ROCM)
Expand Down
10 changes: 8 additions & 2 deletions fbgemm_gpu/codegen/embedding_backward_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,10 @@ def index_select() -> None:
)

template = env.get_template("embedding_backward_split_grad_template.cu")
write("gen_embedding_backward_split_grad.cu", template.render())
write(
"gen_embedding_backward_split_grad_index_select.cu",
template.render(is_index_select=True),
)


def forward_quantized() -> None:
Expand Down Expand Up @@ -461,7 +464,10 @@ class elem_type:
def backward_grad() -> None:
# Generate the common grad functions
template = env.get_template("embedding_backward_split_grad_template.cu")
write("gen_embedding_backward_split_grad.cu", template.render())
write(
"gen_embedding_backward_split_grad_embedding_ops.cu",
template.render(is_index_select=False),
)


def backward_indices() -> None:
Expand Down
10 changes: 10 additions & 0 deletions fbgemm_gpu/codegen/embedding_backward_split_grad_template.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,16 @@
#include "fbgemm_gpu/split_embeddings_utils.cuh"

using Tensor = at::Tensor;

using namespace fbgemm_gpu;

{% if is_index_select %}
namespace index_select {
{% else %}
namespace embedding_ops {
{% endif %}


__global__ __launch_bounds__(kMaxThreads) void
split_embedding_backward_codegen_find_long_segments(
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
Expand Down Expand Up @@ -225,4 +233,6 @@ void grad_mean{{ vdesc }}_kernel
{% endfor %} // for grad_type in ['at::Half', 'float']
{% endfor %} // for vbe in [True, False]

}

// clang-format on
15 changes: 15 additions & 0 deletions fbgemm_gpu/codegen/embedding_backward_split_template.cu
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,13 @@ split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc
{%- endif %}
);

{% if is_index_select %}
namespace index_select {
{% else %}
namespace embedding_ops {
{% endif %}


__global__ __launch_bounds__(kMaxThreads) void
split_embedding_backward_codegen_find_long_segments(
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_num_runs,
Expand Down Expand Up @@ -222,6 +229,14 @@ split_embedding_backward_count_unique_indices_kernel(
const int info_B_num_bits
);

}

{% if is_index_select %}
using namespace index_select;
{% else %}
using namespace embedding_ops;
{% endif %}

////////////////////////////////////////////////////////////////////////////////
// Utility Macros
////////////////////////////////////////////////////////////////////////////////
Expand Down

0 comments on commit 54a56f0

Please sign in to comment.