Merge pull request #47 from Achazwl/fix-typo
Fix typo
Achazwl authored Sep 7, 2022
2 parents 95bf7e0 + 55e4d05 commit 6912b50
Showing 6 changed files with 200 additions and 40 deletions.
40 changes: 4 additions & 36 deletions README.md
@@ -277,7 +277,7 @@ In addition, BMTrain also provides the common LRScheduler in the `bmtrain.lr_scheduler`

```python
# create a new instance of optimizer manager
optim_manager = bmtrain.optim.OptimManager()
optim_manager = bmtrain.optim.OptimManager(loss_scale=1024)
# let optim_manager handle all the optimizers and (optionally) their corresponding lr_schedulers
optim_manager.add_optimizer(optimizer, lr_scheduler)
# add_optimizer can be called multiple times to add other optimizers.
@@ -301,12 +301,12 @@ for iteration in range(1000):
# zero grad
optim_manager.zero_grad() # calling zero_grad for each optimizer

# clip grad norm
grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, max_norm=1.0)

# loss scale and backward
optim_manager.backward()

# clip grad norm
grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, max_norm=1.0)

# optimizer step
optim_manager.step()

Expand All @@ -323,38 +323,6 @@ If you are not using the mixed-precision training, you can train without `loss_s

If you are using mixed-precision training, *loss scaling* is a technique widely used to prevent gradient underflow. Use `optim_manager.backward(loss)` to scale the `loss` before the backward pass, and set `loss_scale` to a floating-point number in the `__init__` function of `OptimManager`. The `loss_scale` is adjusted adaptively based on the gradients during training.


```python
for iteration in range(1000):
# ... load data for each rank ...

# zero grad
optimizer.zero_grad()

# forward
pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1)
logits = model(
enc_input,
pos,
pos < enc_length[:, None]
)
batch, seq_len, vocab_out_size = logits.size()

loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len))

global_loss = bmtrain.sum_loss(loss).item() # sum the loss across all ranks

# backward
loss.backward()

# optimizer step
bmtrain.optim_step(optimizer, lr_scheduler)

# ... save checkpoint or print logs ...
```

Note that `bmtrain.optim_step` should be used instead of directly calling `optimizer.step()` and `lr_scheduler.step()`.

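For reference, here is a minimal sketch of one training iteration using the `OptimManager` API described above. It reuses the variable names from the earlier examples; `model`, `loss_func`, `optimizer`, `lr_scheduler`, `enc_input`, `enc_length`, and `targets` are assumed to be set up elsewhere.

```python
# minimal sketch, assuming model, optimizer, lr_scheduler and data are already prepared
optim_manager = bmtrain.optim.OptimManager(loss_scale=1024)
optim_manager.add_optimizer(optimizer, lr_scheduler)

for iteration in range(1000):
    # ... load data for each rank ...

    # zero grad for every managed optimizer
    optim_manager.zero_grad()

    # forward
    pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1)
    logits = model(enc_input, pos, pos < enc_length[:, None])
    batch, seq_len, vocab_out_size = logits.size()
    loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len))
    global_loss = bmtrain.sum_loss(loss).item()  # sum the loss across all ranks, for logging only

    # scale the loss and run backward
    optim_manager.backward(loss)

    # clip grad norm (after backward, so the gradients exist)
    grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, max_norm=1.0)

    # optimizer (and lr_scheduler) step; loss_scale is adjusted adaptively
    optim_manager.step()

    # ... save checkpoint or print logs ...
```

Calling `clip_grad_norm` after `backward`, as in the updated snippet above, ensures the gradients have been computed before they are clipped.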
<div id="performance"></div>

## Performance
4 changes: 2 additions & 2 deletions bmtrain/optim/optim_manager.py
@@ -80,7 +80,7 @@ def add_optimizer(
self.optimizers.append(optimizer)
self.lr_schedulers.append(lr_scheduler)

def loss_scale(self, loss : torch.Tensor) -> torch.Tensor:
def scale_loss(self, loss : torch.Tensor) -> torch.Tensor:
return loss * (self.loss_scale / config['world_size']) # loss scale

def backward(self, loss : torch.Tensor):
@@ -90,7 +90,7 @@ def backward(self, loss : torch.Tensor):
Args:
loss (torch.Tensor): loss
"""
loss = self.loss_scale(loss)
loss = self.scale_loss(loss)
loss.backward()
# some reduce ops of distributed parameter were launched on load stream
current_stream = torch.cuda.current_stream()
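The rename from `loss_scale` to `scale_loss` removes the clash between the method name and the numeric `loss_scale` value that the method reads (presumably stored as an instance attribute in `__init__`, given `OptimManager(loss_scale=1024)` in the README). A small illustrative sketch of what the renamed method does; the `manager` instance and the 4-GPU world size are assumptions for illustration, not part of the diff:

```python
# minimal sketch, not the library implementation:
# with loss_scale=1024 and world_size == 4, the scaled loss is loss * (1024 / 4)
scaled = manager.scale_loss(loss)
scaled.backward()  # gradients now carry the extra 1024 / 4 factor
# OptimManager.backward(loss) wraps these two steps and then handles the CUDA
# streams used for the distributed parameter reduction, as the diff above shows.
```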
3 changes: 2 additions & 1 deletion tests/test_all.py
@@ -13,11 +13,12 @@
("loss_func", 1),

("middle_hidden", 4),
("other_hidden", 4),
("model_wrapper", 4),
("send_recv", 4),
("nccl_backward", 4),

("trainging", 4),
("training", 4),
])

for t, num_gpu in tq:
1 change: 1 addition & 0 deletions tests/test_init_parameters.py
@@ -208,6 +208,7 @@ def test_main():

for i in range(10):
ret[i] = ( ret[i][0].view(-1), ret[i][1].view(-1) )
print(ret[i])
for i in range(10):
for j in range(10):
assert_all_eq(ret[i][0], ret[j][0])
4 changes: 3 additions & 1 deletion tests/test_middle_hidden.py
@@ -187,10 +187,12 @@ def test_main():
ret.append( run("zero", Model_ZERO) )
ret.append( run("pipe", Model_PIPE) )
for r in ret:
bmt.prnit_rank(r)
bmt.print_rank(r)
for r in ret:
for r2 in ret:
assert_eq(r, r2)

if __name__ == "__main__":
bmt.init_distributed(pipe_size=4)

test_main()
188 changes: 188 additions & 0 deletions tests/test_other_hidden.py
@@ -0,0 +1,188 @@
from utils import *

import bmtrain as bmt
import random
import torch
from bmtrain import config
from bmtrain.block_layer import CheckpointBlock, TransformerBlockList
from bmtrain.pipe_layer import PipelineTransformerBlockList
import torch.nn.functional as F

class Linear(bmt.DistributedModule):
def __init__(self, in_features : int, out_features: int, init_weight = None, init_bias = None) -> None:
super().__init__()

self.in_features = in_features
self.out_features = out_features
self.out = {}
if init_weight:
self.weight = bmt.DistributedParameter(torch.tensor(init_weight, dtype=torch.float, device="cuda").reshape(out_features, in_features))
else:
self.weight = bmt.DistributedParameter(torch.empty(out_features, in_features, dtype=torch.float, device="cuda"), init_method=torch.nn.init.xavier_normal_)

if init_bias:
self.bias = bmt.DistributedParameter(torch.tensor(init_bias, dtype=torch.float, device="cuda").reshape(out_features,))
else:
self.bias = bmt.DistributedParameter(torch.empty(out_features, dtype=torch.float, device="cuda"), init_method=torch.nn.init.zeros_)

def forward(self, input):
ret = F.linear(input, self.weight, self.bias)
return ret

class Model_ZERO(torch.nn.Module):
def __init__(self, pre, ms, post) -> None:
super().__init__()
self.pre = pre
self.ms = TransformerBlockList([
CheckpointBlock(m)
for m in ms
])
self.post = post

def forward(self, x, return_hidden_states=False):
x = self.pre(x)
if return_hidden_states:
x, o = self.ms(x, return_hidden_states=return_hidden_states)
return self.post(x), o
else:
x = self.ms(x, return_hidden_states=return_hidden_states)
return self.post(x)

class Model_PIPE(torch.nn.Module):
def __init__(self, pre, ms, post) -> None:
super().__init__()
self.pre = pre
self.ms = PipelineTransformerBlockList([
CheckpointBlock(m)
for m in ms
])
self.post = post

def forward(self, x, return_hidden_states=False):
x = self.pre(x)
if return_hidden_states:
x, o = self.ms(x, return_hidden_states=return_hidden_states)
return self.post(x), o
else:
x = self.ms(x, return_hidden_states=return_hidden_states)
return self.post(x)

class Model_BLOCK(torch.nn.Module):
def __init__(self, pre, ms, post) -> None:
super().__init__()
self.pre = pre
self.ms = torch.nn.ModuleList([
CheckpointBlock(m)
for m in ms
])
self.post = post

def forward(self, x, return_hidden_states=False):
x = self.pre(x)
o = []
y = x
for m in self.ms:
o.append(y)
y = m(y)
if return_hidden_states:
return self.post(y), o
else:
return self.post(y)

class Model_NORMAL(torch.nn.Module):
def __init__(self, pre, ms, post) -> None:
super().__init__()
self.pre = pre
self.ms = torch.nn.ModuleList(ms)
self.post = post

def forward(self, x, return_hidden_states=False):
x = self.pre(x)
o = []
y = x
for m in self.ms:
o.append(y)
y = m(y)
if return_hidden_states:
return self.post(y), o
else:
return self.post(y)

def manual_seed(seed=33):
torch.manual_seed(seed)
random.seed(seed)
try:
import numpy as np
np.random.seed(seed)
except ModuleNotFoundError:
pass

def sub_run(name, cls, num_layer, dim, batch, seq_len, only_pre=False, only_post=False, mix_test=False):
manual_seed()

pre = Linear(dim, dim)
post = Linear(dim, dim)
ms = [Linear(dim, dim) for i in range(num_layer)]

inp = torch.randn((batch, seq_len, dim)).cuda()
last_weight = torch.randn(pre.weight.shape).cuda()*10
middle_weight = [
torch.randn((batch, seq_len, dim)).cuda()
for i in range(len(ms))
]

bmt.init_parameters(pre)
bmt.init_parameters(post)
for m in ms:
bmt.init_parameters(m)
m = cls(pre, [m for m in ms], post)

ret = ""
if only_pre:
loss = (pre.weight * last_weight).sum()
loss.backward()
ret += f"========================only last========================\n"
ret += bmt.inspect.format_summary(
bmt.inspect.inspect_model(m, '*')
)
if only_post:
loss = (post.weight * last_weight).sum()
loss.backward()
ret += f"========================only middle========================\n"
ret += bmt.inspect.format_summary(
bmt.inspect.inspect_model(m, '*')
)
if mix_test:
loss = (pre.weight * last_weight).sum() + (post.weight * last_weight).sum()
loss.backward()
ret += f"========================mix========================\n"
ret += bmt.inspect.format_summary(
bmt.inspect.inspect_model(m, '*')
)
return ret.replace("None ", "0.0000") + "\n" # replace for matching None grad with zero_grad

def run(name, cls, num_layer=4, dim=4096, batch=32, seq_len=256):
ret = ""
ret += sub_run(name, cls, num_layer=num_layer, dim=dim, batch=batch, seq_len=seq_len, only_pre=True)
bmt.synchronize()
ret += sub_run(name, cls, num_layer=num_layer, dim=dim, batch=batch, seq_len=seq_len, only_post=True)
bmt.synchronize()
ret += sub_run(name, cls, num_layer=num_layer, dim=dim, batch=batch, seq_len=seq_len, mix_test=True)
bmt.synchronize()
return ret

def test_main():
ret = []
ret.append( run("normal", Model_NORMAL) )
ret.append( run("block", Model_BLOCK) )
ret.append( run("zero", Model_ZERO) )
ret.append( run("pipe", Model_PIPE) )
for r in ret:
bmt.print_rank(r)
for r in ret:
for r2 in ret:
assert_eq(r, r2)

if __name__ == "__main__":
bmt.init_distributed(pipe_size=1)
test_main()
