-
Notifications
You must be signed in to change notification settings - Fork 501
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add new optimizer state
row_counter
for Adam [Backend] (#3342)
Summary: X-link: facebookresearch/FBGEMM#436 A new optional optimizer state `row_counter` is added to Adam to perform bias correction per embedding row. `row_counter` serves as the iteration counter when a row (an index) occurs and used to do bias correction. Without rowwise bias correction (existing Adam), ``` m_hat_t = m_t / (1.0 - powf(beta1, iter)); v_hat_t = v_t / (1.0 - powf(beta2, iter)); ``` With rowwise bias correction enabled. ``` // when index `idx` occurs _row_counter = row_counter[idx] + 1; m_hat_t = m_t / (1.0 - powf(beta1, _row_counter)); v_hat_t = v_t / (1.0 - powf(beta2, _row_counter)); ``` This request is from IG to allow all the models to be scaled on sparse features with expected 1.5% NE on Stories. ------- **__The functionality is not set by default.__** Frontend: D64848802 To enable the bias correction, `use_rowwise_bias_correction` needs to be set to True through extra_optimizer_config. ``` extra_optimizer_config = UserEnabledConfigDefinition(use_rowwise_bias_correction=True) emb_op = SplitTableBatchedEmbeddingBagsCodegen ( embedding_specs=[ (E, D, M, compute_device) for (E, D, M) in zip(Es, Ds, managed) ], optimizer=OptimType.Adam extra_optimizer_config=extra_optimizer_config, ... ) ``` ------ **__Performance from Kineto__** (unweighted) ``` Baseline* | default** | enabled*** forward | cpu | 2.293 s | 2.188 s | 2.043 s | cuda | 12.512 ms | 12.539 ms | 12.547 ms backward | cpu | 69.861 ms | 66.546 ms | 65.880 ms | cuda | 103.429 ms | 103.395 ms | 103.130 ms ``` \* Baseline: before changes \** default: default setting; use_bias_correction = False \*** enabled: use_bias_correction = True Reviewed By: sryap Differential Revision: D64808460
- Loading branch information
1 parent
d9d4066
commit f2cf409
Showing
10 changed files
with
375 additions
and
106 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,6 +26,7 @@ class ArgType(IntEnum): | |
INT = 7 | ||
FLOAT = 8 | ||
SYM_INT = 9 | ||
BOOL = 10 | ||
|
||
|
||
@dataclass | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters