Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

AdamW operator (Fixing Weight Decay Regularization in Adam) #13728

Merged
merged 3 commits into from
Dec 28, 2018

Conversation

eric-haibin-lin
Copy link
Member

@eric-haibin-lin eric-haibin-lin commented Dec 25, 2018

Description

Implement a modification of Adam in "Fixing Weight Decay Regularization in Adam" https://arxiv.org/abs/1711.05101.

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain the what the example does, the source of the dataset, expected performance on test set and reference to the original paper if applicable
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

@eric-haibin-lin
Copy link
Member Author

@sxjscience @szhengac could you guys help review this PR?

@eric-haibin-lin eric-haibin-lin changed the title [WIP] AdamW optimizer AdamW optimizer (Fixing Weight Decay Regularization in Adam) Dec 25, 2018
rescaled_grad = clip(grad * rescale_grad, clip_gradient)
m = beta1 * m + (1 - beta1) * rescaled_grad
v = beta2 * v + (1 - beta2) * (rescaled_grad**2)
w = w - learning_rate * (m / (sqrt(v) + epsilon) + wd * w)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to the paper, it has two learning rates. An alpha before m / (sqrt(v) + epsilon).

Copy link
Member Author

@eric-haibin-lin eric-haibin-lin Dec 26, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. The issue is that the learning rate and schedule multiplier is not decoupled in MXNet. Here learning_rate is effectively eta_t * alpha in the paper and wd actually needs to be set as w / alpha. In another word wd can be rescaled properly so that it does exactly the same thing in the paper. Would this be acceptable? Is so maybe I can move this to contrib for the moment

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's acceptable as long as the wd is set correctly.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On second thought I think it's better to keep it consistent with the paper

Copy link
Contributor

@sandeep-krishnamurthy sandeep-krishnamurthy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks.
Can you please provide an output of an end to end use case using AdamW optimizer?

@sandeep-krishnamurthy sandeep-krishnamurthy added Optimizer pr-awaiting-review PR is waiting for code review labels Dec 26, 2018
@eric-haibin-lin
Copy link
Member Author

@sandeep-krishnamurthy training/fine-tuning the BERT model in GluonNLP would be a use case of AdamW

kwargs['clip_gradient'] = self.clip_gradient

mean, var = state
adamw_update(weight, grad, mean, var, out=weight, lr=lr, wd=wd, **kwargs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we set wd to something like wd / self._original_lr?

@eric-haibin-lin
Copy link
Member Author

@sxjscience @szhengac I took a step back and moved the operator to contrib and use the same notation as the one in the paper. I think the optimizer API still needs more discussion, so I removed it from the PR.

@eric-haibin-lin eric-haibin-lin changed the title AdamW optimizer (Fixing Weight Decay Regularization in Adam) AdamW operator (Fixing Weight Decay Regularization in Adam) Dec 27, 2018
@eric-haibin-lin eric-haibin-lin merged commit 116d01e into apache:master Dec 28, 2018
rondogency pushed a commit to rondogency/incubator-mxnet that referenced this pull request Jan 9, 2019
…3728)

* tests

* remove optimizer and move op to contrib

* rename parameter
haohuanw pushed a commit to haohuanw/incubator-mxnet that referenced this pull request Jun 23, 2019
…3728)

* tests

* remove optimizer and move op to contrib

* rename parameter
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
Optimizer pr-awaiting-review PR is waiting for code review
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants