Add new optimizer state row_counter for Adam [Backend] (#3342)
Summary:

X-link: facebookresearch/FBGEMM#436

A new optional optimizer state, `row_counter`, is added to Adam to perform bias correction per embedding row. `row_counter` serves as the iteration counter for each row: it is incremented whenever that row (an index) occurs and is used to compute the bias correction for that row.


Without rowwise bias correction (existing Adam),
```
m_hat_t = m_t / (1.0 - powf(beta1, iter));
v_hat_t = v_t / (1.0 - powf(beta2, iter));
```

With rowwise bias correction enabled,
```
// when index `idx` occurs
_row_counter = row_counter[idx] + 1;
m_hat_t = m_t / (1.0 - powf(beta1, _row_counter));
v_hat_t = v_t / (1.0 - powf(beta2, _row_counter));
```
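
For intuition, below is a minimal, self-contained Python sketch (not the actual FBGEMM kernel) of the difference: with rowwise correction, each index keeps its own counter that only advances when that index appears in a batch, and that counter drives the bias correction. The hyperparameter values and the dict-based counter are illustrative stand-ins.
```
import math

# Illustrative hyperparameters -- not taken from this commit.
beta1, beta2, eps, lr = 0.9, 0.999, 1e-8, 0.01


def adam_row_update(weight, grad, m, v, correction_step):
    """One Adam step for a single (scalar) weight of an embedding row.

    `correction_step` is the count used for bias correction: the global
    `iter` without rowwise correction, or this row's `row_counter` value
    when rowwise correction is enabled.
    """
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * grad * grad
    m_hat = m / (1.0 - math.pow(beta1, correction_step))
    v_hat = v / (1.0 - math.pow(beta2, correction_step))
    weight -= lr * (m_hat / (math.sqrt(v_hat) + eps))
    return weight, m, v


# Per-row counters: a row's counter only advances when that row occurs.
row_counter = {}


def step_with_rowwise_correction(idx, weight, grad, m, v):
    row_counter[idx] = row_counter.get(idx, 0) + 1
    return adam_row_update(weight, grad, m, v, row_counter[idx])
```
A rarely seen row is thus corrected according to how many times it has actually been updated, rather than by the much larger global iteration count, which would effectively cancel the correction for sparse rows.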

This request comes from IG, to allow all the models to be scaled on sparse features, with an expected 1.5% NE impact on Stories.

-------

**__The functionality is not enabled by default.__** Frontend: D64848802

To enable rowwise bias correction, `use_rowwise_bias_correction` needs to be set to `True` through `extra_optimizer_config`:
```
extra_optimizer_config = UserEnabledConfigDefinition(use_rowwise_bias_correction=True)
emb_op = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[
        (E, D, M, compute_device) for (E, D, M) in zip(Es, Ds, managed)
    ],
    optimizer=OptimType.ADAM,
    extra_optimizer_config=extra_optimizer_config,
    ...
)
```
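
For context, here is a rough end-to-end usage sketch under stated assumptions: the import paths (in particular where `UserEnabledConfigDefinition` lives after the frontend diff D64848802), the table shape, and the batch are made up for illustration and are not part of this commit.
```
# Hypothetical usage sketch -- import locations and sizes are assumptions.
import torch

from fbgemm_gpu.split_embedding_configs import (  # UserEnabledConfigDefinition location assumed
    EmbOptimType as OptimType,
    UserEnabledConfigDefinition,
)
from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    ComputeDevice,
    SplitTableBatchedEmbeddingBagsCodegen,
)

emb_op = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[
        # (num_embeddings, embedding_dim, location, compute_device)
        (1000, 8, EmbeddingLocation.DEVICE, ComputeDevice.CUDA),
    ],
    optimizer=OptimType.ADAM,
    learning_rate=0.01,
    extra_optimizer_config=UserEnabledConfigDefinition(
        use_rowwise_bias_correction=True
    ),
)

# Only the rows referenced by `indices` have their row_counter advanced in
# this step; rows that do not occur keep their previous counters.
indices = torch.tensor([1, 5, 5, 42], dtype=torch.int64, device="cuda")
offsets = torch.tensor([0, 2, 4], dtype=torch.int64, device="cuda")  # 1 table, 2 bags
emb_op(indices=indices, offsets=offsets).sum().backward()
```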
------
**__Performance from Kineto__** (unweighted)
```
                   Baseline* |  default** | enabled*** 
forward  | cpu  |   2.293 s  |   2.188 s  |   2.043 s
         | cuda |  12.512 ms |  12.539 ms |  12.547 ms
backward | cpu  |  69.861 ms |  66.546 ms |  65.880 ms
         | cuda | 103.429 ms | 103.395 ms | 103.130 ms
```
\* Baseline: before changes
\** default: default setting; `use_rowwise_bias_correction = False`
\*** enabled: `use_rowwise_bias_correction = True`

Reviewed By: sryap

Differential Revision: D64808460
spcyppt authored and facebook-github-bot committed Nov 13, 2024
1 parent d9d4066 commit f2cf409
Showing 10 changed files with 375 additions and 106 deletions.
353 changes: 284 additions & 69 deletions fbgemm_gpu/codegen/genscript/optimizer_args.py

Large diffs are not rendered by default.

37 changes: 31 additions & 6 deletions fbgemm_gpu/codegen/genscript/optimizers.py
@@ -1001,6 +1001,29 @@ def adam() -> Dict[str, Any]:


def adam() -> Dict[str, Any]:
    split_precomputation = """
    at::acc_type<cache_t, true>* __restrict__ row_counter;
    at::acc_type<cache_t, true> _row_counter = iter;
    if (use_rowwise_bias_correction) {
        const auto row_counter_placement = static_cast<PlacementType>(row_counter_placements[t]);
        const int64_t row_counter_offset = row_counter_offsets[t];
        if (row_counter_placement == PlacementType::DEVICE) {
            row_counter = &row_counter_dev[row_counter_offset];
        } else {
            row_counter = &row_counter_uvm[row_counter_offset];
        }
        // need to compute bias correction for each row
        if (threadIdx.x == 0) {
            _row_counter = row_counter[idx] + 1;
            row_counter[idx] = _row_counter;
        }
        // broadcast bias correction to all threads
        _row_counter = SHFL_SYNC(_row_counter, 0);
    }
    """

split_weight_update = """
Vec4T<cache_t> m_t(&momentum1[idx * D + d]);
m_t.acc.x *= beta1;
@@ -1023,10 +1046,10 @@ def adam() -> Dict[str, Any]:
v_t.fma_(grad, 1.0 - beta2);
v_t.store(&momentum2[idx * D + d]);
weight_new.acc.x -= learning_rate * (m_t.acc.x / (1.0 - powf(beta1, iter)) / (sqrtf((v_t.acc.x / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight_new.acc.x);
weight_new.acc.y -= learning_rate * (m_t.acc.y / (1.0 - powf(beta1, iter)) / (sqrtf((v_t.acc.y / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight_new.acc.y);
weight_new.acc.z -= learning_rate * (m_t.acc.z / (1.0 - powf(beta1, iter)) / (sqrtf((v_t.acc.z / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight_new.acc.z);
weight_new.acc.w -= learning_rate * (m_t.acc.w / (1.0 - powf(beta1, iter)) / (sqrtf((v_t.acc.w / (1.0 - powf(beta2, iter)))) + eps) + weight_decay * weight_new.acc.w);
weight_new.acc.x -= learning_rate * (m_t.acc.x / (1.0 - powf(beta1, _row_counter)) / (sqrtf((v_t.acc.x / (1.0 - powf(beta2, _row_counter)))) + eps) + weight_decay * weight_new.acc.x);
weight_new.acc.y -= learning_rate * (m_t.acc.y / (1.0 - powf(beta1, _row_counter)) / (sqrtf((v_t.acc.y / (1.0 - powf(beta2, _row_counter)))) + eps) + weight_decay * weight_new.acc.y);
weight_new.acc.z -= learning_rate * (m_t.acc.z / (1.0 - powf(beta1, _row_counter)) / (sqrtf((v_t.acc.z / (1.0 - powf(beta2, _row_counter)))) + eps) + weight_decay * weight_new.acc.z);
weight_new.acc.w -= learning_rate * (m_t.acc.w / (1.0 - powf(beta1, _row_counter)) / (sqrtf((v_t.acc.w / (1.0 - powf(beta2, _row_counter)))) + eps) + weight_decay * weight_new.acc.w);
"""
split_weight_update_cpu = "" # TODO

@@ -1043,12 +1066,14 @@ def adam() -> Dict[str, Any]:
OptimItem(ArgType.FLOAT, "beta2"),
OptimItem(ArgType.FLOAT, "weight_decay"),
OptimItem(ArgType.INT, "iter"),
OptimItem(ArgType.BOOL, "use_rowwise_bias_correction"),
OptimItem(ArgType.TENSOR, "row_counter", is_optional=True),
],
{
"v1": "Tensor momentum1, Tensor momentum2, float learning_rate = 0, float eps = 0, float beta1 = 0, float beta2 = 0, float weight_decay = 0, int iter = 0"
"v1": "Tensor momentum1, Tensor momentum2, float learning_rate = 0, float eps = 0, float beta1 = 0, float beta2 = 0, float weight_decay = 0, int iter = 0, bool use_rowwise_bias_correction = False, Tensor? row_counter = None",
},
),
"split_precomputation": "",
"split_precomputation": split_precomputation,
"split_weight_update": split_weight_update,
"split_post_update": "",
"split_weight_update_cpu": split_weight_update_cpu,
1 change: 1 addition & 0 deletions fbgemm_gpu/codegen/genscript/torch_type_utils.py
@@ -26,6 +26,7 @@ class ArgType(IntEnum):
INT = 7
FLOAT = 8
SYM_INT = 9
BOOL = 10


@dataclass
@@ -64,14 +64,14 @@ class SplitLookupFunction_{{ optimizer }}_Op : public torch::autograd::Function<
bool gradient_clipping,
double max_gradient,
bool stochastic_rounding,
{{ args.split_function_args | join(", ") }},
{{ args.split_function_args_autograd | join(", ") }},
int64_t output_dtype = static_cast<int64_t>(SparseType::FP32)) {
Tensor indice_weights_value = indice_weights.value_or(Tensor());
Tensor feature_requires_grad_value =
feature_requires_grad.value_or(Tensor());
ctx->save_for_backward({
host_weights, weights_placements, weights_offsets, D_offsets, hash_size_cumsum,
indices, offsets, indice_weights_value, feature_requires_grad_value, {{ args.split_saved_tensors | join(", ") }} });
indices, offsets, indice_weights_value, feature_requires_grad_value, {{ args.split_saved_tensors_optional | join(", ") }} });

ctx->saved_data["total_D"] = total_D;
ctx->saved_data["max_D"] = max_D;
@@ -242,7 +242,7 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function_cpu(
gradient_clipping,
max_gradient,
stochastic_rounding,
{{ args.split_function_arg_names | join(", ") }},
{{ args.split_function_arg_names_autograd | join(", ") }},
output_dtype)[0];
{% else %}
TORCH_CHECK(false, "split_embedding_codegen_lookup_{{ optimizer }}_function_cpu is deprecated. Please see https://github.com/pytorch/FBGEMM/discussions/1727 for more detail.");
@@ -360,7 +360,7 @@ enum SSDTensor {
{%- if ssd %}
ssd_tensors.value(),
{%- endif %}
{{ args.split_function_arg_names | join(", ") }}
{{ args.split_function_arg_names_autograd | join(", ") }}
{%- endif %}
)[0];
{%- endmacro %}
@@ -618,7 +618,7 @@ class {{ autograd_func }} :
{%- if ssd %}
const at::TensorList& ssd_tensors,
{%- endif %}
{{ args.split_function_args | join(", ") }}
{{ args.split_function_args_autograd | join(", ") }}
{%- else %}
{%- if vbe %}
const std::optional<Tensor>& B_offsets,
@@ -757,7 +757,7 @@ class {{ autograd_func }} :
ssd_tensors[SSDTensor::{{ tensor | upper }}],
{%- endfor %}
{%- endif %}
{{ args.split_saved_tensors | join(", ") }}
{{ args.split_saved_tensors_optional | join(", ") }}
});

{%- if not nobag %}
@@ -476,12 +476,35 @@ enum SSDTensor {

/* This macro generates a code blob for unpacking the tensor list
*/
{%- macro unpack_tensor_list(tensor_list) %}
const Tensor {{ tensor_list }}_host = {{ tensor_list }}[0];
const Tensor {{ tensor_list }}_dev = {{ tensor_list }}[1];
const Tensor {{ tensor_list }}_uvm = {{ tensor_list }}[2];
const Tensor {{ tensor_list }}_placements = {{ tensor_list }}[3];
const Tensor {{ tensor_list }}_offsets = {{ tensor_list }}[4];
{%- macro unpack_tensorlist(name) %}
const Tensor {{ name }}_host = {{ name }}[0];
const Tensor {{ name }}_dev = {{ name }}[1];
const Tensor {{ name }}_uvm = {{ name }}[2];
const Tensor {{ name }}_placements = {{ name }}[3];
const Tensor {{ name }}_offsets = {{ name }}[4];
{%- endmacro %}

{%- macro unpack_tensorlist_optional(name) %}
Tensor {{ name }}_host;
Tensor {{ name }}_dev;
Tensor {{ name }}_uvm;
Tensor {{ name }}_placements;
Tensor {{ name }}_offsets;
if ({{ name }}.has_value()) {
at::TensorList _{{ name }} = {{ name }}.value();
{{ name }}_host = _{{ name }}[0];
{{ name }}_dev = _{{ name }}[1];
{{ name }}_uvm = _{{ name }}[2];
{{ name }}_placements = _{{ name }}[3];
{{ name }}_offsets = _{{ name }}[4];
}
else{
{{ name }}_host = at::empty({0}, weights_host.options());
{{ name }}_dev = at::empty({0}, weights_dev.options());
{{ name }}_uvm = at::empty({0}, weights_uvm.options());
{{ name }}_placements = at::empty({0}, weights_placements.options());
{{ name }}_offsets = at::empty({0}, weights_offsets.options());
}
{%- endmacro %}


@@ -581,9 +604,12 @@ class {{ autograd_func }} :
{{ args_pt2.unified_pt2.split_function_args | join(", ") }}) {

// unpack Tensor lists
{{ unpack_tensor_list("weights") }}
{%- for arg_name in args_pt2.unified_pt2.split_saved_tensor_list %}
{{ unpack_tensor_list(arg_name) }}
{{ unpack_tensorlist("weights") }}
{%- for arg_name in args_pt2.unified_pt2.split_saved_tensorlist %}
{{ unpack_tensorlist(arg_name) }}
{%- endfor %}
{%- for arg_name in args_pt2.unified_pt2.split_saved_tensorlist_optional %}
{{ unpack_tensorlist_optional(arg_name) }}
{%- endfor %}

const auto T = weights_offsets.sym_numel();
@@ -60,7 +60,7 @@ def invoke(
{%- if "prev_iter_dev" in args.split_function_arg_names %}
prev_iter: Momentum,
{%- endif %}
{%- if "row_counter_dev" in args.split_function_arg_names %}
{%- if "row_counter_dev" in args.split_function_arg_names and "row_counter" not in args_pt2.unified_pt2.split_saved_tensorlist_optional %}
row_counter: Momentum,
{%- endif %}
{%- if "iter" in args.split_function_arg_names %}
@@ -209,7 +209,7 @@ def invoke(
prev_iter_placements=prev_iter.placements,
{%- endif %}
# row_counter
{%- if "row_counter_dev" in args.split_function_arg_names %}
{%- if "row_counter_dev" in args.split_function_arg_names and "row_counter" not in args_pt2.unified_pt2.split_saved_tensorlist_optional %}
row_counter_host=row_counter.host,
row_counter_offsets=row_counter.offsets,
row_counter_placements=row_counter.placements,
@@ -387,7 +387,7 @@ def invoke(
prev_iter_dev=prev_iter_dev,
{%- endif %}
# row_counter
{%- if "row_counter_dev" in args.split_function_arg_names %}
{%- if "row_counter_dev" in args.split_function_arg_names and "row_counter" not in args_pt2.unified_pt2.split_saved_tensorlist_optional %}
row_counter_dev=row_counter.dev,
row_counter_uvm=row_counter.uvm,
row_counter_offsets=row_counter.offsets,
@@ -63,18 +63,6 @@
"gwd_lower_bound": st.sampled_from([0, 0.01, 0.001]),
}

additional_decorators.update(
{
# learning rate tensor needs to be on CPU to avoid D->H sync point since it will be used as float in the kernel
# this fails fake_tensor test as the test expects all tensors to be on the same device
"test_pt2_compliant_tag_fbgemm_split_embedding_codegen_lookup_rowwise_adagrad_function": [
unittest.skip(
"Operator failed on FakeTensor test since learning rate tensor is always on CPU regardless of other tensors"
),
]
}
)


def compare_output(
output_ref: torch.Tensor,
7 changes: 7 additions & 0 deletions fbgemm_gpu/test/tbe/training/forward_test.py
@@ -76,6 +76,13 @@
"test_faketensor__test_forward_gpu_uvm_cache_int8": [
unittest.skip("Operator not implemented for Meta tensors"),
],
# learning rate tensor needs to be on CPU to avoid D->H sync point since it will be used as float in the kernel
# this fails fake_tensor test as the test expects all tensors to be on the same device
"test_pt2_compliant_tag_fbgemm_split_embedding_codegen_lookup_rowwise_adagrad_function": [
unittest.skip(
"Operator failed on FakeTensor test since learning rate tensor is always on CPU regardless of other tensors"
),
],
}
)

9 changes: 8 additions & 1 deletion fbgemm_gpu/test/test_utils.py
@@ -34,7 +34,14 @@
# fake_tensor test is added in failures_dict but failing fake_tensor test still cause pt2_compliant tag test to fail
"test_pt2_compliant_tag_fbgemm_split_embedding_codegen_lookup_rowwise_adagrad_function_pt2": [
unittest.skip("Operator failed on pt2 compliant tag"),
]
],
# learning rate tensor needs to be on CPU to avoid D->H sync point since it will be used as float in the kernel
# this fails fake_tensor test as the test expects all tensors to be on the same device
"test_pt2_compliant_tag_fbgemm_split_embedding_codegen_lookup_rowwise_adagrad_function": [
unittest.skip(
"Operator failed on FakeTensor test since learning rate tensor is always on CPU regardless of other tensors"
),
],
}

# Used for `@unittest.skipIf`
