Skip to content

Commit

Permalink
Optimzed backward pass for ROCm devices (pytorch#3367)
Browse files Browse the repository at this point in the history
Summary:
Added optimized implementation of backward pass for ROCm devices. Currently support **not nobag** mode, **rowwise_adagrad** optimizer with non-mixed dimensions in [64, 128, 160, 192].

Pull Request resolved: pytorch#3367

Differential Revision: D66310520

Pulled By: leitian
  • Loading branch information
avbokovoy authored and facebook-github-bot committed Dec 5, 2024
1 parent 09b7b96 commit f8102c0
Show file tree
Hide file tree
Showing 14 changed files with 1,643 additions and 10 deletions.
22 changes: 22 additions & 0 deletions fbgemm_gpu/codegen/genscript/generate_backward_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,27 @@ def generate_backward_indices() -> None:
ssd=ssd,
)

@staticmethod
def generate_rocm_backward_split(**kwargs: Any) -> None:
# Generate backward device kernels based on weighted (True/False), VBE
# (True/False), no bag (True/False)
template_filepath = (
"training/backward/rocm/embedding_backward_split_device_kernel_template.hip"
)

BackwardSplitGenerator.render_backward_templates(
template_filepath,
"",
"{}gen_embedding_backward_{}_device_kernel_hip.hip",
{
"has_gpu_support": True,
"has_vbe_support": False,
"has_ssd_support": False,
"dense": False,
"gen_once": False,
},
)

@staticmethod
def generate_python_sources(
all_optimizers: List[str], ssd_optimizers: List[str]
Expand Down Expand Up @@ -369,6 +390,7 @@ def generate() -> None:
BackwardSplitGenerator.generate_backward_split(
ssd_tensors=ssd_tensors, **optimizer
)
BackwardSplitGenerator.generate_rocm_backward_split(**optimizer)

# Generate common device kernels for backwards
BackwardSplitGenerator.generate_backward_device()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,8 @@ Tensor split_embedding_codegen_lookup_dense_function(
Tensor>& /* vbe_B_offsets_rank_per_feature = std::nullopt */,
c10::SymInt /* max_B = -1 */,
c10::SymInt /* max_B_feature_rank = -1 */,
c10::SymInt /* vbe_output_size = -1 */) {
c10::SymInt /* vbe_output_size = -1 */,
bool /* mixed_D = false */) {
return SplitLookupFunction_Dense_Op::apply(
host_weights,
weights_offsets,
Expand All @@ -190,15 +191,15 @@ Tensor split_embedding_codegen_lookup_dense_function(
// Deprecated for fb namespace! Please use fbgemm namespace instead!
TORCH_LIBRARY_FRAGMENT(fb, m) {
m.def(
"dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1) -> Tensor");
"dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1, bool mixed_D=False) -> Tensor");
DISPATCH_TO_CPU(
"dense_embedding_codegen_lookup_function",
split_embedding_codegen_lookup_dense_function);
}

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1) -> Tensor");
"dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1, bool mixed_D=False) -> Tensor");
DISPATCH_TO_CPU(
"dense_embedding_codegen_lookup_function",
split_embedding_codegen_lookup_dense_function);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ enum SSDTensor {
{%- else %}
D_offsets,
max_D,
mixed_D,
{%- endif %} {# /* if nobag */ #}
hash_size_cumsum,
total_hash_size_bits,
Expand Down Expand Up @@ -224,6 +225,7 @@ enum SSDTensor {
Variable(), // D_offsets
Variable(), // total_D
Variable(), // max_D
Variable(), // mixed_D
{%- endif %}
Variable(), // hash_size_cumsum
Variable(), //total_hash_size_bits
Expand Down Expand Up @@ -304,6 +306,7 @@ enum SSDTensor {
D_offsets,
total_D,
max_D,
mixed_D,
{%- endif %}
hash_size_cumsum,
total_hash_size_bits,
Expand Down Expand Up @@ -484,6 +487,7 @@ Tensor
{%- else %}
const Tensor& D_offsets,
const c10::SymInt max_D,
const bool mixed_D,
{%- endif %}
const Tensor& hash_size_cumsum,
const int64_t total_hash_size_bits,
Expand Down Expand Up @@ -566,6 +570,7 @@ class {{ autograd_func }} :
const Tensor& D_offsets,
const c10::SymInt total_D,
const c10::SymInt max_D,
const bool mixed_D,
{%- else %}
const c10::SymInt D,
{%- endif %}
Expand Down Expand Up @@ -762,6 +767,7 @@ class {{ autograd_func }} :

{%- if not nobag %}
ctx->saved_data["max_D"] = max_D;
ctx->saved_data["mixed_D"] = mixed_D;
ctx->saved_data["pooling_mode"] = pooling_mode;
{%- else %}
ctx->saved_data["D"] = D;
Expand Down Expand Up @@ -877,6 +883,7 @@ class {{ autograd_func }} :

{%- if not nobag %}
auto max_D = ctx->saved_data["max_D"].toSymInt();
const auto mixed_D = ctx->saved_data["mixed_D"].toBool();
auto pooling_mode = ctx->saved_data["pooling_mode"].toInt();
{%- else %}
auto D = ctx->saved_data["D"].toSymInt();
Expand Down Expand Up @@ -1072,10 +1079,11 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function(
{%- if ssd %}
const std::optional<at::TensorList>& ssd_tensors = std::nullopt,
{%- endif %}
const double gwd_lower_bound = 0
const double gwd_lower_bound = 0,
{%- else %}
const c10::SymInt vbe_output_size = -1
const c10::SymInt vbe_output_size = -1,
{%- endif %}
const bool mixed_D = false
) {
// TODO: refactor into macro
{%- if has_gpu_support %}
Expand Down Expand Up @@ -1191,7 +1199,8 @@ TORCH_LIBRARY_FRAGMENT({{ lib_name }}, m) {
{%- if ssd %}
" Tensor[]? ssd_tensors=None,"
{%- endif %}
" float gwd_lower_bound=0 "
" float gwd_lower_bound=0, "
" bool mixed_D=False"
") -> Tensor",
{PT2_COMPLIANT_TAG});

Expand Down
Loading

0 comments on commit f8102c0

Please sign in to comment.