Make auto select for embedding_ops and fix duplicate cu files #2269

Closed · wants to merge 1 commit
3 changes: 2 additions & 1 deletion fbgemm_gpu/CMakeLists.txt
@@ -228,7 +228,8 @@ set(gen_gpu_kernel_source_files
"gen_batch_index_select_dim0_backward_codegen_cuda.cu"
"gen_batch_index_select_dim0_backward_kernel_cta.cu"
"gen_batch_index_select_dim0_backward_kernel_warp.cu"
"gen_embedding_backward_split_grad.cu"
"gen_embedding_backward_split_grad_embedding_ops.cu"
"gen_embedding_backward_split_grad_index_select.cu"
)

if(NOT USE_ROCM)
10 changes: 8 additions & 2 deletions fbgemm_gpu/codegen/embedding_backward_code_generator.py
@@ -319,7 +319,10 @@ def index_select() -> None:
)

template = env.get_template("embedding_backward_split_grad_template.cu")
write("gen_embedding_backward_split_grad.cu", template.render())
write(
"gen_embedding_backward_split_grad_index_select.cu",
template.render(is_index_select=True),
)


def forward_quantized() -> None:
@@ -461,7 +464,10 @@ class elem_type:
def backward_grad() -> None:
# Generate the common grad functions
template = env.get_template("embedding_backward_split_grad_template.cu")
write("gen_embedding_backward_split_grad.cu", template.render())
write(
"gen_embedding_backward_split_grad_embedding_ops.cu",
template.render(is_index_select=False),
)


def backward_indices() -> None:
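Note: `index_select()` and `backward_grad()` previously rendered this same template into one file name, `gen_embedding_backward_split_grad.cu`, so the two codegen passes collided on their output. The fix keys the output file name and an `is_index_select` flag per pass. A minimal sketch of that two-pass pattern with plain jinja2 (the loader path is illustrative, not fbgemm_gpu's actual `env`/`write` helpers; the output file names match this PR):

```python
# Minimal sketch: render one Jinja template twice with a boolean flag,
# writing each rendering to a distinct file.
import jinja2

env = jinja2.Environment(loader=jinja2.FileSystemLoader("codegen"))
template = env.get_template("embedding_backward_split_grad_template.cu")

# One template, two outputs: the is_index_select flag controls which
# namespace the rendered .cu file wraps its kernels in, so the two
# generated translation units no longer define colliding symbols.
for filename, flag in [
    ("gen_embedding_backward_split_grad_index_select.cu", True),
    ("gen_embedding_backward_split_grad_embedding_ops.cu", False),
]:
    with open(filename, "w") as f:
        f.write(template.render(is_index_select=flag))
```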
10 changes: 10 additions & 0 deletions fbgemm_gpu/codegen/embedding_backward_split_grad_template.cu
@@ -12,8 +12,16 @@
#include "fbgemm_gpu/split_embeddings_utils.cuh"

using Tensor = at::Tensor;

using namespace fbgemm_gpu;
+
+{% if is_index_select %}
+namespace index_select {
+{% else %}
+namespace embedding_ops {
+{% endif %}
+
+
__global__ __launch_bounds__(kMaxThreads) void
split_embedding_backward_codegen_find_long_segments(
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
@@ -225,4 +233,6 @@ void grad_mean{{ vdesc }}_kernel
{% endfor %} // for grad_type in ['at::Half', 'float']
{% endfor %} // for vbe in [True, False]

+}
+
// clang-format on
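Note: with the template now rendered into two .cu files, both translation units define kernels with identical names (for example `split_embedding_backward_codegen_find_long_segments`). Compiled into one library, that would fail at link time with duplicate-symbol errors; the `{% if is_index_select %}` namespace wrapper gives each rendering its own mangled names. An illustrative reduction, assuming simplified signatures rather than the real packed-tensor-accessor ones:

```cuda
// Two definitions of the same unqualified kernel name in one library
// would collide at link time; distinct namespaces disambiguate them.
#include <cstdint>

namespace embedding_ops {
__global__ void find_long_segments(const int32_t* sorted_num_runs) {
  // embedding_ops rendering of the kernel body
}
}  // namespace embedding_ops

namespace index_select {
__global__ void find_long_segments(const int32_t* sorted_num_runs) {
  // index_select rendering of the same kernel body
}
}  // namespace index_select
```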
15 changes: 15 additions & 0 deletions fbgemm_gpu/codegen/embedding_backward_split_template.cu
@@ -175,6 +175,13 @@ split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc
{%- endif %}
);

+{% if is_index_select %}
+namespace index_select {
+{% else %}
+namespace embedding_ops {
+{% endif %}
+
+
__global__ __launch_bounds__(kMaxThreads) void
split_embedding_backward_codegen_find_long_segments(
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_num_runs,
@@ -222,6 +229,14 @@ split_embedding_backward_count_unique_indices_kernel(
const int info_B_num_bits
);

+}
+
+{% if is_index_select %}
+using namespace index_select;
+{% else %}
+using namespace embedding_ops;
+{% endif %}
+
////////////////////////////////////////////////////////////////////////////////
// Utility Macros
////////////////////////////////////////////////////////////////////////////////
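Note: this hunk is the 'auto select' from the PR title. The kernel forward declarations are wrapped in whichever namespace the flag chooses, and the using-directive right after reopens it, so the unqualified kernel launches later in this template pick up the matching definitions without any edits to the launch sites. An illustrative reduction under the same simplified-signature assumption:

```cuda
// The codegen flag picks which namespace the declarations land in; the
// matching using-directive then lets every unqualified launch site below
// resolve into that namespace unchanged.
#include <cstdint>

namespace embedding_ops {  // rendered as index_select when is_index_select
__global__ void find_long_segments(const int32_t* sorted_num_runs);
}  // namespace embedding_ops

using namespace embedding_ops;

void launch(const int32_t* sorted_num_runs, int num_blocks) {
  // Unqualified name resolves via the using-directive, so the existing
  // launch sites in the template body need no edits.
  find_long_segments<<<num_blocks, 64>>>(sorted_num_runs);
}
```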