Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add grad scale for optim_manager && fix workflow action #195

Merged
merged 2 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@ on:
branches:
- 'dev'
- 'main'
push:
branches:
- 'dev'

jobs:
build-archive-wheel:

uses: OpenBMB/BMTrain/.github/workflows/build_whl.yml@main
secrets:
DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }}
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
secrets: inherit

publish:
needs: build-archive-wheel
Expand Down
6 changes: 5 additions & 1 deletion bmtrain/optim/optim_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(self,
loss_scale_steps : int = 1024,
min_loss_scale = 1,
max_loss_scale = float("inf"),
grad_scale : Optional[int] = None,
):
if loss_scale is not None:
self.loss_scale = loss_scale
Expand All @@ -64,6 +65,9 @@ def __init__(self,
self.loss_scale_steps = loss_scale_steps
self.min_loss_scale = min_loss_scale
self.max_loss_scale = max_loss_scale
if grad_scale is None:
grad_scale = config['zero_size']
self.grad_scale = grad_scale

self.optimizers = []
self.lr_schedulers = []
Expand All @@ -85,7 +89,7 @@ def add_optimizer(

def scale_loss(self, loss : torch.Tensor) -> torch.Tensor:

return loss * (self.loss_scale / (config['world_size']//(config['tp_size']*config['pipe_size']))) # loss scale
return loss * ( self.loss_scale / self.grad_scale ) # loss scale

def backward(self, loss : torch.Tensor):
"""
Expand Down
8 changes: 6 additions & 2 deletions bmtrain/synchronize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from . import distributed, nccl
from .global_var import config
import warnings
from typing import Optional

def synchronize():
"""
Expand All @@ -24,14 +25,17 @@ def wait_loader():
config['calc_stream'].record_event(config['load_event'])


def sum_loss(loss : torch.Tensor):
def sum_loss(loss : torch.Tensor, comm: Optional[nccl.NCCLCommunicator] = None):
"""
Average the loss across all workers.

This is a helper function that all-reduces the loss across workers and returns the mean (despite the name, the result is an average, not a sum).
"""
if comm is None:
comm = config['comm']
warnings.warn("bmtrain.sum_loss is deprecated and will be removed in later version. Use bmtrain.distributed.all_reduce instead.", DeprecationWarning)
return distributed.all_reduce(loss, "sum") / config['world_size']

return distributed.all_reduce(loss, "avg", comm)

def gather_result(result: torch.Tensor):
warnings.warn("bmtrain.gather_result is deprecated and will be removed in later version. Use bmtrain.distributed.all_gather instead.", DeprecationWarning)
Expand Down
Loading