Optimized backward pass for ROCm devices (#3367)
Summary:
X-link: facebookresearch/FBGEMM#491

Added an optimized implementation of the backward pass for ROCm devices. It currently supports the **not nobag** (pooled) mode and the **rowwise_adagrad** optimizer, with non-mixed embedding dimensions in [64, 128, 160, 192].
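
For reference, the snippet below is a minimal usage sketch (not part of this commit) of the configuration the optimization targets: pooled (**not nobag**) lookups, the **rowwise_adagrad** optimizer, and a uniform embedding dimension drawn from [64, 128, 160, 192]. Module paths, table sizes, hyperparameters, and index data are illustrative assumptions and may differ across fbgemm_gpu versions; on ROCm builds the device is still exposed as `cuda`.

```python
import torch

# Assumed module layout of recent fbgemm_gpu releases; adjust for your version.
from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
    ComputeDevice,
    EmbeddingLocation,
)
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    SplitTableBatchedEmbeddingBagsCodegen,
)

# Two tables with the same (non-mixed) embedding dimension of 128, pooled
# ("not nobag") lookups, trained with rowwise Adagrad.
emb = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[
        (10_000, 128, EmbeddingLocation.DEVICE, ComputeDevice.CUDA),
        (20_000, 128, EmbeddingLocation.DEVICE, ComputeDevice.CUDA),
    ],
    optimizer=OptimType.EXACT_ROWWISE_ADAGRAD,
    learning_rate=0.01,
    eps=1.0e-8,
)

# Batch of 2 samples over the 2 tables: indices hold the flattened bags,
# offsets delimit them (length B * T + 1).
indices = torch.tensor([3, 7, 11, 5, 9, 2], dtype=torch.long, device="cuda")
offsets = torch.tensor([0, 2, 3, 5, 6], dtype=torch.long, device="cuda")

out = emb(indices, offsets)   # pooled forward, shape [B, sum of table dims]
out.sum().backward()          # exercises the TBE backward optimized here on ROCm
```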


Reviewed By: leitian

Differential Revision: D66310520

Pulled By: q10
avbokovoy authored and facebook-github-bot committed Dec 6, 2024
1 parent 9d78337 commit fca72b7
Showing 16 changed files with 1,673 additions and 12 deletions.
1 change: 0 additions & 1 deletion .github/scripts/fbgemm_gpu_test.bash
@@ -102,7 +102,6 @@ __configure_fbgemm_gpu_test_cuda () {

ignored_tests=(
)

}

__configure_fbgemm_gpu_test_rocm () {
21 changes: 20 additions & 1 deletion fbgemm_gpu/FbgemmGpu.cmake
@@ -279,18 +279,27 @@ foreach(optimizer ${SSD_OPTIMIZERS})
"gen_embedding_backward_${optimizer}_ssd_${wdesc}_kernel_cta.cu"
"gen_embedding_backward_${optimizer}_ssd_${wdesc}_kernel_warp.cu")
endforeach()

foreach(wdesc weighted unweighted)
list(APPEND gen_gpu_kernel_source_files
"gen_embedding_backward_${optimizer}_ssd_${wdesc}_vbe_cuda.cu"
"gen_embedding_backward_${optimizer}_ssd_${wdesc}_vbe_kernel_cta.cu"
"gen_embedding_backward_${optimizer}_ssd_${wdesc}_vbe_kernel_warp.cu")
endforeach()

endforeach()

list(APPEND gen_defused_optim_py_files
${CMAKE_BINARY_DIR}/optimizer_args.py)

################################################################################
# FBGEMM_GPU Generated HIP-Specific Sources
################################################################################

set(gen_hip_kernel_source_files)
foreach(wdesc weighted unweighted unweighted_nobag)
list(APPEND gen_hip_kernel_source_files
"gen_embedding_backward_split_${wdesc}_device_kernel_hip.hip")
endforeach()

################################################################################
# FBGEMM_GPU Static Sources
@@ -426,6 +435,9 @@ set(fbgemm_gpu_sources_gpu_gen
${gen_gpu_host_source_files}
${gen_defused_optim_source_files})

set(fbgemm_gpu_sources_hip_gen
${gen_hip_kernel_source_files})

if(USE_ROCM)
prepend_filepaths(
PREFIX ${CMAKE_BINARY_DIR}
@@ -436,6 +448,11 @@ if(USE_ROCM)
PREFIX ${CMAKE_BINARY_DIR}
INPUT ${fbgemm_gpu_sources_gpu_gen}
OUTPUT fbgemm_gpu_sources_gpu_gen)

prepend_filepaths(
PREFIX ${CMAKE_BINARY_DIR}
INPUT ${fbgemm_gpu_sources_hip_gen}
OUTPUT fbgemm_gpu_sources_hip_gen)
endif()


@@ -478,6 +495,8 @@ gpu_cpp_library(
GPU_SRCS
${fbgemm_gpu_sources_gpu_static}
${fbgemm_gpu_sources_gpu_gen}
HIP_SPECIFIC_SRCS
${fbgemm_gpu_sources_hip_gen}
GPU_FLAGS
${TORCH_CUDA_OPTIONS}
DEPS
22 changes: 22 additions & 0 deletions fbgemm_gpu/codegen/genscript/generate_backward_split.py
@@ -310,6 +310,27 @@ def generate_backward_indices() -> None:
ssd=ssd,
)

@staticmethod
def generate_rocm_backward_split(**kwargs: Any) -> None:
# Generate backward device kernels based on weighted (True/False), VBE
# (True/False), no bag (True/False)
template_filepath = (
"training/backward/rocm/embedding_backward_split_device_kernel_template.hip"
)

BackwardSplitGenerator.render_backward_templates(
template_filepath,
"",
"{}gen_embedding_backward_{}_device_kernel_hip.hip",
{
"has_gpu_support": True,
"has_vbe_support": False,
"has_ssd_support": False,
"dense": False,
"gen_once": False,
},
)

@staticmethod
def generate_python_sources(
all_optimizers: List[str], ssd_optimizers: List[str]
@@ -369,6 +390,7 @@ def generate() -> None:
BackwardSplitGenerator.generate_backward_split(
ssd_tensors=ssd_tensors, **optimizer
)
BackwardSplitGenerator.generate_rocm_backward_split()

# Generate common device kernels for backwards
BackwardSplitGenerator.generate_backward_device()
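For orientation, here is a small sketch (not part of the diff) of how the filename format string passed to `render_backward_templates` above is expected to expand. The exact prefix/descriptor substitution is an assumption, but the resulting names match the HIP kernel sources registered in `FbgemmGpu.cmake` earlier in this commit.

```python
# Assumption: the first placeholder takes an (empty) prefix and the second a
# per-variant descriptor; the rendered files then line up with the entries
# appended to gen_hip_kernel_source_files in FbgemmGpu.cmake.
filename_format = "{}gen_embedding_backward_{}_device_kernel_hip.hip"
for wdesc in ("split_weighted", "split_unweighted", "split_unweighted_nobag"):
    print(filename_format.format("", wdesc))
# gen_embedding_backward_split_weighted_device_kernel_hip.hip
# gen_embedding_backward_split_unweighted_device_kernel_hip.hip
# gen_embedding_backward_split_unweighted_nobag_device_kernel_hip.hip
```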
@@ -171,7 +171,8 @@ Tensor split_embedding_codegen_lookup_dense_function(
Tensor>& /* vbe_B_offsets_rank_per_feature = std::nullopt */,
c10::SymInt /* max_B = -1 */,
c10::SymInt /* max_B_feature_rank = -1 */,
c10::SymInt /* vbe_output_size = -1 */) {
c10::SymInt /* vbe_output_size = -1 */,
bool /* mixed_D = false */) {
return SplitLookupFunction_Dense_Op::apply(
host_weights,
weights_offsets,
@@ -190,15 +191,15 @@
// Deprecated for fb namespace! Please use fbgemm namespace instead!
TORCH_LIBRARY_FRAGMENT(fb, m) {
m.def(
"dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1) -> Tensor");
"dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1, bool mixed_D=False) -> Tensor");
DISPATCH_TO_CPU(
"dense_embedding_codegen_lookup_function",
split_embedding_codegen_lookup_dense_function);
}

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1) -> Tensor");
"dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1, bool mixed_D=False) -> Tensor");
DISPATCH_TO_CPU(
"dense_embedding_codegen_lookup_function",
split_embedding_codegen_lookup_dense_function);
@@ -152,6 +152,7 @@ enum SSDTensor {
{%- else %}
D_offsets,
max_D,
mixed_D,
{%- endif %} {# /* if nobag */ #}
hash_size_cumsum,
total_hash_size_bits,
@@ -224,6 +225,7 @@
Variable(), // D_offsets
Variable(), // total_D
Variable(), // max_D
Variable(), // mixed_D
{%- endif %}
Variable(), // hash_size_cumsum
Variable(), //total_hash_size_bits
@@ -304,6 +306,7 @@
D_offsets,
total_D,
max_D,
mixed_D,
{%- endif %}
hash_size_cumsum,
total_hash_size_bits,
@@ -484,6 +487,7 @@ Tensor
{%- else %}
const Tensor& D_offsets,
const c10::SymInt max_D,
const bool mixed_D,
{%- endif %}
const Tensor& hash_size_cumsum,
const int64_t total_hash_size_bits,
@@ -566,6 +570,7 @@ class {{ autograd_func }} :
const Tensor& D_offsets,
const c10::SymInt total_D,
const c10::SymInt max_D,
const bool mixed_D,
{%- else %}
const c10::SymInt D,
{%- endif %}
@@ -762,6 +767,7 @@ class {{ autograd_func }} :

{%- if not nobag %}
ctx->saved_data["max_D"] = max_D;
ctx->saved_data["mixed_D"] = mixed_D;
ctx->saved_data["pooling_mode"] = pooling_mode;
{%- else %}
ctx->saved_data["D"] = D;
@@ -877,6 +883,7 @@ class {{ autograd_func }} :

{%- if not nobag %}
auto max_D = ctx->saved_data["max_D"].toSymInt();
const auto mixed_D = ctx->saved_data["mixed_D"].toBool();
auto pooling_mode = ctx->saved_data["pooling_mode"].toInt();
{%- else %}
auto D = ctx->saved_data["D"].toSymInt();
@@ -1072,10 +1079,11 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function(
{%- if ssd %}
const std::optional<at::TensorList>& ssd_tensors = std::nullopt,
{%- endif %}
const double gwd_lower_bound = 0
const double gwd_lower_bound = 0,
{%- else %}
const c10::SymInt vbe_output_size = -1
const c10::SymInt vbe_output_size = -1,
{%- endif %}
const bool mixed_D = false
) {
// TODO: refactor into macro
{%- if has_gpu_support %}
@@ -1191,7 +1199,8 @@ TORCH_LIBRARY_FRAGMENT({{ lib_name }}, m) {
{%- if ssd %}
" Tensor[]? ssd_tensors=None,"
{%- endif %}
" float gwd_lower_bound=0 "
" float gwd_lower_bound=0, "
" bool mixed_D=False"
") -> Tensor",
{PT2_COMPLIANT_TAG});

