add swa_utils modules #9781

Merged: 43 commits, merged Jan 29, 2023
Commits
f572af4  small change for alignment with pytorch (process852, Jan 4, 2023)
9972b75  Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow (process852, Jan 4, 2023)
8b84be2  make of_format (process852, Jan 4, 2023)
d9ee7e1  Merge branch 'master' into master (mergify[bot], Jan 4, 2023)
451e5a1  Merge branch 'master' into master (mergify[bot], Jan 5, 2023)
9c437dd  Merge branch 'master' into master (mergify[bot], Jan 5, 2023)
1ff7c9b  Merge branch 'master' into master (mergify[bot], Jan 5, 2023)
538571e  Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow (process852, Jan 6, 2023)
a46b476  Merge branch 'master' of https://github.com/process852/oneflow (process852, Jan 6, 2023)
e03b6bd  Merge branch 'master' into master (mergify[bot], Jan 6, 2023)
5572ed4  Merge branch 'master' into master (mergify[bot], Jan 6, 2023)
1be4cd7  fix merge conflict (process852, Jan 9, 2023)
e6af4a8  make of_format (process852, Jan 9, 2023)
09ca80f  Merge branch 'master' of https://github.com/process852/oneflow (process852, Jan 9, 2023)
34f515b  clang_format check (process852, Jan 9, 2023)
b670ddd  Merge branch 'master' into master (process852, Jan 9, 2023)
14b9d0c  Merge branch 'master' into master (mergify[bot], Jan 9, 2023)
9392854  Merge branch 'master' into master (mergify[bot], Jan 9, 2023)
22a18d3  Merge branch 'master' into master (mergify[bot], Jan 9, 2023)
494b506  add prune API (process852, Jan 10, 2023)
ba5b481  Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow (process852, Jan 10, 2023)
ccce6e0  Merge branch 'master' of https://github.com/process852/oneflow (process852, Jan 10, 2023)
fa34dac  fix oneflow.norm use reshape (process852, Jan 10, 2023)
f745722  add pytorch reference links (process852, Jan 11, 2023)
8c79c16  make of_format (process852, Jan 11, 2023)
ba1a51c  Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow (process852, Jan 11, 2023)
96ecc57  add test unit of prune (process852, Jan 16, 2023)
951621f  Merge branch 'master' into master (mergify[bot], Jan 16, 2023)
11ebf6e  Merge branch 'master' into master (mergify[bot], Jan 16, 2023)
f146c42  Merge branch 'master' into master (mergify[bot], Jan 16, 2023)
7433aff  Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow (process852, Jan 17, 2023)
cbb78d4  Merge branch 'master' of https://github.com/process852/oneflow (process852, Jan 17, 2023)
82a48e8  Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow (process852, Jan 17, 2023)
6ec4a73  Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow (process852, Jan 18, 2023)
80d5a0b  Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow (process852, Jan 20, 2023)
c4dee14  add swa_utils (process852, Jan 20, 2023)
58a8021  change docs format (process852, Jan 28, 2023)
94f85ee  delete notes (process852, Jan 29, 2023)
2b943de  Merge branch 'master' into master (mergify[bot], Jan 29, 2023)
493ca05  Merge branch 'master' into master (mergify[bot], Jan 29, 2023)
69f1ce3  fix bn error (process852, Jan 29, 2023)
1944221  Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow (process852, Jan 29, 2023)
36ce9ae  Merge branch 'master' of https://github.com/process852/oneflow (process852, Jan 29, 2023)
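The headline change is the new swa_utils module, with the commit log noting alignment with PyTorch ("add pytorch reference links", "small change for alignment with pytorch"). The sketch below is a hypothetical usage example under the assumption that the module mirrors torch.optim.swa_utils; the names AveragedModel, SWALR, and update_bn are assumed from the PyTorch counterpart and are not confirmed by this page.

    # Hypothetical sketch: assumes oneflow.optim.swa_utils mirrors
    # torch.optim.swa_utils (AveragedModel, SWALR, update_bn).
    import oneflow as flow
    import oneflow.nn as nn

    model = nn.Linear(4, 2)
    optimizer = flow.optim.SGD(model.parameters(), lr=0.1)
    swa_model = flow.optim.swa_utils.AveragedModel(model)
    swa_scheduler = flow.optim.swa_utils.SWALR(optimizer, swa_lr=0.05)

    for epoch in range(10):
        # ... run the usual training step here ...
        if epoch >= 5:  # start averaging after a warm-up phase
            swa_model.update_parameters(model)
            swa_scheduler.step()

    # Recompute BatchNorm running statistics for the averaged model;
    # this is why the BatchNorm momentum fix below matters.
    # flow.optim.swa_utils.update_bn(train_loader, swa_model)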
5 changes: 4 additions & 1 deletion python/oneflow/nn/modules/batchnorm.py
@@ -123,9 +123,12 @@ def __init__(

     def forward(self, x):
         self._check_input_dim(x)
+        exponential_average_factor = self.momentum
         if self.training and self.track_running_stats:
             if self.num_batches_tracked is not None:
                 self.num_batches_tracked.add_(1)
+                if self.momentum is None:
+                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
         if self.training:
             is_training = True
         else:
@@ -139,7 +142,7 @@ def forward(self, x):
             self.bias,
             axis=self.channel_axis,
             epsilon=self.eps,
-            momentum=self.momentum,
+            momentum=exponential_average_factor,
             is_training=is_training,
         )
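This diff makes momentum=None select a cumulative moving average of the running statistics, as in PyTorch's BatchNorm: the effective update factor becomes 1/num_batches_tracked instead of a fixed momentum. A small plain-Python sketch of the factor selection the diff introduces (values are illustrative):

    # Sketch of the factor selection the diff introduces (plain Python).
    def effective_factor(momentum, num_batches_tracked):
        # Default: use the configured momentum (exponential moving average).
        factor = momentum
        if momentum is None:
            # momentum=None means cumulative moving average: each batch
            # contributes equally, so the factor shrinks as 1/n.
            factor = 1.0 / float(num_batches_tracked)
        return factor

    print(effective_factor(0.1, 3))   # 0.1       (EMA, fixed factor)
    print(effective_factor(None, 3))  # 0.333...  (CMA, shrinking factor)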
114 changes: 114 additions & 0 deletions python/oneflow/nn/optimizer/multiplicative_lr.py
@@ -0,0 +1,114 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import types

from ...optim.optimizer import Optimizer
from .lr_scheduler import LRScheduler


class MultiplicativeLR(LRScheduler):
"""Multiply the learning rate of each parameter group by the factor given
in the specified function. When last_step=-1, sets the initial lr as lr.

The documentation is referenced from:
https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.MultiplicativeLR

Args:
optimizer (Optimizer): Wrapped optimizer.
lr_lambda (function or list): A function which computes a multiplicative
factor given an integer parameter epoch, or a list of such
functions, one for each group in optimizer.param_groups.
last_step (int): The index of the last step. Default: -1.
verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False``.

For example:

.. code-block:: python

import oneflow as flow

...
lmbda = lambda epoch: 0.95
step_lr = flow.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lmbda)
for epoch in range(num_epoch):
train(...)
step_lr.step()
"""

def __init__(self, optimizer, lr_lambda, last_step=-1, verbose=False):
self.optimizer = optimizer

if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
else:
if len(lr_lambda) != len(optimizer.param_groups):
raise ValueError(
"Expected {} lr_lambdas, but got {}".format(
len(optimizer.param_groups), len(lr_lambda)
)
)
self.lr_lambdas = list(lr_lambda)
super().__init__(optimizer, last_step, verbose)

def state_dict(self):
"""Returns the state of the scheduler as a :class:`dict`.

It contains an entry for every variable in self.__dict__ which
is not the optimizer.
The learning rate lambda functions will only be saved if they are callable objects
and not if they are functions or lambdas.
"""
state_dict = {
key: value
for key, value in self.__dict__.items()
if key not in ("optimizer", "lr_lambdas")
}
state_dict["lr_lambdas"] = [None] * len(self.lr_lambdas)

for idx, fn in enumerate(self.lr_lambdas):
if not isinstance(fn, types.FunctionType):
state_dict["lr_lambdas"][idx] = fn.__dict__.copy()

return state_dict

def load_state_dict(self, state_dict):
"""Loads the schedulers state.

Args:
state_dict (dict): scheduler state. Should be an object returned
from a call to :meth:`state_dict`.
"""
lr_lambdas = state_dict.pop("lr_lambdas")
self.__dict__.update(state_dict)
state_dict["lr_lambdas"] = lr_lambdas

for idx, fn in enumerate(lr_lambdas):
if fn is not None:
self.lr_lambdas[idx].__dict__.update(fn)

def step(self):
"""Performs a single learning rate schedule step.

"""
self.last_step += 1
if self.last_step > 0:
lrs = [
group["lr"] * lmbda(self.last_step)
for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups)
]
else:
lrs = [group["lr"] for group in self.optimizer.param_groups]
self.update_lrs(lrs)
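A short usage sketch for the new scheduler, assuming it is exposed as flow.optim.lr_scheduler.MultiplicativeLR as the file's own docstring example suggests. Note that plain lambdas are skipped by state_dict() (only other callable objects have their __dict__ saved), so a checkpoint round-trip keeps the lambda from the freshly constructed scheduler:

    import oneflow as flow
    import oneflow.nn as nn

    model = nn.Linear(4, 2)
    optimizer = flow.optim.SGD(model.parameters(), lr=0.1)
    scheduler = flow.optim.lr_scheduler.MultiplicativeLR(
        optimizer, lr_lambda=lambda step: 0.95
    )

    for step in range(3):
        # ... training step ...
        optimizer.step()
        scheduler.step()  # lr *= 0.95 on each step after step 0

    # Plain lambdas are not serialized; state_dict() stores None for them,
    # and load_state_dict() leaves the lambda constructed above in place.
    state = scheduler.state_dict()
    scheduler.load_state_dict(state)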