Support multiple input-output in TransformerBlockList #92

Merged (1 commit) · May 4, 2023
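This change lets every block in a `TransformerBlockList` consume and produce several hidden states instead of exactly one: the new `num_hidden` argument tells the list how many of the leading positional arguments are hidden states that get threaded from block to block. A minimal usage sketch, modeled on the new `tests/test_multi_return.py` below (the toy layer, shapes, and `dim` here are illustrative, not part of the patch):

```python
import torch
import bmtrain as bmt
from bmtrain.block_layer import CheckpointBlock, TransformerBlockList

class TwoStreamLayer(torch.nn.Module):
    # Toy block: takes two hidden states plus one extra argument,
    # returns two hidden states for the next block.
    def forward(self, h0, h1, bias):
        return h0 * 2 + bias, h0 + h1

bmt.init_distributed()

blocks = TransformerBlockList(
    [CheckpointBlock(TwoStreamLayer()) for _ in range(4)],
    num_hidden=2,  # the first 2 positional args of forward() are hidden states
)

dim = 4096
h0 = torch.randn(dim).cuda()
h1 = torch.randn(dim).cuda()
bias = torch.randn(dim).cuda()

# Hidden states come first; everything after them is passed to every block as-is.
out0, out1 = blocks(h0, h1, bias)
```

With `num_hidden=1` (the default) the behaviour is unchanged and a single tensor is returned.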
82 changes: 53 additions & 29 deletions bmtrain/block_layer.py
@@ -680,29 +680,30 @@ def __repr__(self):

class OpTransformerBlockList(torch.autograd.Function):
@staticmethod
def forward(ctx, placeholder, self : 'TransformerBlockList', save_list, hidden_state, *args):
def forward(ctx, placeholder, self : 'TransformerBlockList', save_list, num_hidden, *args):
tensors = []
others = []
for arg in args:
for arg in args[num_hidden:]:
if torch.is_tensor(arg):
tensors.append(arg)
others.append(None)
else:
tensors.append(None)
others.append(arg)
hidden_states = args[:num_hidden]

ctx.nontensor_inputs = others
ctx.self = self
ctx.save_list = copy.deepcopy(save_list)
ctx.num_save_needed = save_list[-1][1]+1
ctx.layers_dict=[{} for _ in range(len(self))]
ctx.layers_dict = [{} for _ in range(len(self))]
layer_inputs = []
layer_inspector = []
cuda_rng_state = []
for i in range(len(self)):
with torch.no_grad():
if save_list[i][0] == i:
layer_inputs.append(hidden_state.detach())
layer_inputs += [hidden_state.detach() for hidden_state in hidden_states]
cuda_rng_state.append( torch.cuda.get_rng_state() )
if config['zero_level']==2:
flag = 1
@@ -713,29 +714,38 @@ def forward(ctx, placeholder, self : 'TransformerBlockList', save_list, hidden_s
block_ctx.enter()
# call inner module directly
with ScopedTensorInspectorContext() as inspector:
hidden_state = self._modules[str(i)]._module._call_impl(hidden_state, *args)
hidden_states = self._modules[str(i)]._module._call_impl(*hidden_states, *args[num_hidden:])
if not isinstance(hidden_states, tuple):
hidden_states = (hidden_states,)
block_ctx.exit()
for it in inspector.hidden_states:
debug.append("_inspect_hidden_states", it)
layer_inspector.append(inspector.hidden_states)

ctx.layer_inspector = layer_inspector
ctx.cuda_rng_state = cuda_rng_state
ctx.num_hidden = num_hidden

ctx.save_for_backward(*layer_inputs, *tensors)

if self.return_hidden_states:
middle_hiddens = layer_inputs
for mid in middle_hiddens:
mid.requires_grad_()
middle_hiddens = torch.stack(middle_hiddens, dim=0)
middle_hiddens = [
torch.stack(middle_hiddens[i::num_hidden], dim=0)
for i in range(num_hidden)
]
else:
middle_hiddens = None
return tuple([hidden_state, middle_hiddens] + [it["tensor"] for inspector_hiddens in ctx.layer_inspector for it in inspector_hiddens])
middle_hiddens = [None] * num_hidden
return tuple(list(hidden_states) + middle_hiddens + [it["tensor"] for inspector_hiddens in ctx.layer_inspector for it in inspector_hiddens])


@staticmethod
def backward(ctx, grad_hidden_state : torch.Tensor, grad_middle: List[torch.Tensor], *grad_inspectors):
def backward(ctx, *grads):
grad_hidden_states = grads[:ctx.num_hidden]
grad_middles = grads[ctx.num_hidden:2*ctx.num_hidden]
grad_inspectors = grads[2*ctx.num_hidden:]
def exit_prev(prev_ctx, prev_grad):
if prev_ctx is not None:
if prev_grad:
@@ -755,8 +765,8 @@ def exit_prev(prev_ctx, prev_grad):
all_inputs = []
input_requires_grad = []

layer_inputs = ctx.saved_tensors[:ctx.num_save_needed]
save_args = ctx.saved_tensors[ctx.num_save_needed:]
layer_inputs = ctx.saved_tensors[:ctx.num_save_needed * ctx.num_hidden]
save_args = ctx.saved_tensors[ctx.num_save_needed * ctx.num_hidden:]
for tensor, other in zip(save_args, ctx.nontensor_inputs):
if tensor is None:
all_inputs.append(other)
@@ -786,14 +796,23 @@ def exit_prev(prev_ctx, prev_grad):
block_ctx = CheckpointBlockContext(ctx.self._modules[str(j)], ctx.layers_dict[j], flag)
block_ctx.enter()
exit_prev(prev_ctx, prev_grad)
output = ctx.self._modules[str(j)]._module._call_impl(layer_inputs[ctx.save_list[j][1]], *all_inputs)
outputs = ctx.self._modules[str(j)]._module._call_impl(
layer_inputs[ctx.save_list[j][1]*ctx.num_hidden: ctx.save_list[j][1]*ctx.num_hidden+ctx.num_hidden],
*all_inputs
)
if not isinstance(outputs, tuple):
outputs = (outputs,)
prev_ctx = block_ctx
prev_grad = False
layer_inputs[ctx.save_list[j+1][1]].copy_(output)
for k, output in enumerate(outputs):
layer_inputs[ctx.save_list[j+1][1]*ctx.num_hidden + k].copy_(output)
ctx.save_list[j+1][0] = j+1

torch.cuda.set_rng_state(ctx.cuda_rng_state[i])
ipt = layer_inputs[ctx.save_list[i][1]].detach().requires_grad_()
ipts = [
layer_inputs[ctx.save_list[i][1]*ctx.num_hidden + k].detach().requires_grad_()
for k in range(ctx.num_hidden)
]
if config['zero_level'] == 2:
flag = 2
else:
@@ -805,7 +824,9 @@ def exit_prev(prev_ctx, prev_grad):
prev_grad = True

with ScopedTensorInspectorContext() as inspector:
output = ctx.self._modules[str(i)]._module._call_impl(ipt, *all_inputs)
outputs = ctx.self._modules[str(i)]._module._call_impl(*ipts, *all_inputs)
if not isinstance(outputs, tuple):
outputs = (outputs,)

assert len(ctx.layer_inspector[i]) == len(inspector.hidden_states), "Backward step changed"
for j, it in enumerate(inspector.hidden_states):
@@ -818,18 +839,20 @@ def exit_prev(prev_ctx, prev_grad):
ctx.layer_inspector[i][j]["requires_grad"] = it["requires_grad"]
if len(inspector.hidden_states) > 0:
torch.autograd.backward(
[output] + [hidden_state["tensor"] for hidden_state in inspector.hidden_states],
(grad_hidden_state,) + grad_inspectors[-len(inspector.hidden_states):],
list(outputs) + [hidden_state["tensor"] for hidden_state in inspector.hidden_states],
grad_hidden_states + grad_inspectors[-len(inspector.hidden_states):],
)
grad_inspectors = grad_inspectors[:-len(inspector.hidden_states)]
else:
torch.autograd.backward(
[output],
(grad_hidden_state,),
outputs,
grad_hidden_states,
)
grad_hidden_state = ipt.grad
if grad_middle is not None:
grad_hidden_state = grad_hidden_state + grad_middle[i]
grad_hidden_states = [ipt.grad for ipt in ipts]
for k in range(ctx.num_hidden):
if grad_middles[k] is not None:
grad_hidden_states[k] = grad_hidden_states[k] + grad_middles[k][i]
grad_hidden_states = tuple(grad_hidden_states)

exit_prev(prev_ctx, prev_grad)

@@ -839,7 +862,7 @@ def exit_prev(prev_ctx, prev_grad):
grads.append(inp.grad)
else:
grads.append(None)
return (None, None, None, grad_hidden_state) + tuple(grads)
return (None, None, None, None) + tuple(grad_hidden_states) + tuple(grads)

class TransformerBlockList(torch.nn.Module):
r"""
@@ -862,7 +885,7 @@ class TransformerBlockList(torch.nn.Module):
"""
_modules: Dict[str, CheckpointBlock]

def __init__(self, modules: Iterable[CheckpointBlock], sqrt=False) -> None:
def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) -> None:
super().__init__()

self._modules = {}
@@ -872,6 +895,8 @@ def __init__(self, modules: Iterable[CheckpointBlock], sqrt=False) -> None:
self._modules[str(i)] = module
self.add_module(str(i), module)

self.num_hidden = num_hidden

if sqrt:
length = len(self)
num_save_needed = 0
@@ -901,12 +926,11 @@ def __iter__(self) -> Iterator[CheckpointBlock]:
def __getitem__(self, index: Union[int, str]) -> CheckpointBlock:
return self._modules[str(index)]

def forward(self, hidden_state, *args, return_hidden_states = False):
def forward(self, *args, return_hidden_states = False):
self.return_hidden_states = return_hidden_states
placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled())
outputs = OpTransformerBlockList.apply(placeholder, self, self.save_list, hidden_state, *args)
last_hidden, middle_hiddens = outputs[:2]
outputs = OpTransformerBlockList.apply(placeholder, self, self.save_list, self.num_hidden, *args)
if return_hidden_states:
return last_hidden, middle_hiddens
return tuple(outputs[:2*self.num_hidden])
else:
return last_hidden
return tuple(outputs[:self.num_hidden]) if self.num_hidden > 1 else outputs[0]
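For orientation, the bookkeeping inside `OpTransformerBlockList` is now flattened: each checkpoint slot of `layer_inputs` holds `num_hidden` consecutive tensors, and the tuple returned from `forward` is laid out as the final hidden states, then the per-stream middle hiddens, then the inspector tensors; `backward` slices its incoming gradients the same way (`grads[:num_hidden]`, `grads[num_hidden:2*num_hidden]`, the rest for inspectors). A small standalone illustration of that indexing (plain Python; the concrete numbers are made up for the example):

```python
num_hidden = 3       # hidden states per block, as in tests/test_multi_return.py
num_save_needed = 4  # checkpoint slots, i.e. save_list[-1][1] + 1

# Flattened layer_inputs: slot s keeps its k-th hidden state at s*num_hidden + k.
def saved_index(slot, k):
    return slot * num_hidden + k

assert [saved_index(0, k) for k in range(num_hidden)] == [0, 1, 2]
assert [saved_index(2, k) for k in range(num_hidden)] == [6, 7, 8]
assert saved_index(num_save_needed - 1, num_hidden - 1) == 11  # last saved hidden

# Layout of the tuple returned by OpTransformerBlockList.forward:
#   outputs[:num_hidden]                 -> final hidden states
#   outputs[num_hidden:2 * num_hidden]   -> stacked middle hiddens (or None each)
#   outputs[2 * num_hidden:]             -> inspector tensors
```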
1 change: 1 addition & 0 deletions tests/test_all.py
@@ -15,6 +15,7 @@
("dropout", 1),
("loss_func", 1),

("multi_return", 2),
("middle_hidden", 4),
("other_hidden", 4),
("inspector_hidden", 2),
126 changes: 126 additions & 0 deletions tests/test_multi_return.py
@@ -0,0 +1,126 @@
from utils import *

import bmtrain as bmt
import torch
import random
from bmtrain import config
from bmtrain.block_layer import CheckpointBlock, TransformerBlockList
from bmtrain.pipe_layer import PipelineTransformerBlockList
import torch.nn.functional as F

class MultiInputReturn(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, a, b, c, d, e):
return a*2, b+d, c*4+e*5

class Model_ZERO(torch.nn.Module):
def __init__(self, ms) -> None:
super().__init__()
self.ms = TransformerBlockList([
CheckpointBlock(m)
for m in ms
], num_hidden=3)

def forward(self, x):
y = self.ms(*x)
return y

class Model_PIPE(torch.nn.Module):
def __init__(self, ms) -> None:
super().__init__()
self.ms = PipelineTransformerBlockList([
CheckpointBlock(m)
for m in ms
], num_hidden=3)

def forward(self, x):
y = self.ms(*x)
return y

class Model_BLOCK(torch.nn.Module):
def __init__(self, ms) -> None:
super().__init__()
self.ms = torch.nn.ModuleList([
CheckpointBlock(m)
for m in ms
])

def forward(self, x):
y = x[:3]
other = x[3:]
for m in self.ms:
y = m(*y, *other)
return y

class Model_NORMAL(torch.nn.Module):
def __init__(self, ms) -> None:
super().__init__()
self.ms = torch.nn.ModuleList(ms)

def forward(self, x):
y = x[:3]
other = x[3:]
for m in self.ms:
y = m(*y, *other)
return y

def manual_seed(seed=33):
torch.manual_seed(seed)
random.seed(seed)
try:
import numpy as np
np.random.seed(seed)
except ModuleNotFoundError:
pass

def run(name, cls, num_layer=4, dim=4096):
manual_seed()

ms = [MultiInputReturn() for i in range(num_layer)]

inps = (
torch.randn((dim,)).cuda(),
torch.randn((dim,)).cuda(),
torch.randn((dim,)).cuda(),
torch.randn((dim,)).cuda(),
torch.randn((dim,)).cuda(),
)
last_weights = (
torch.randn((dim,)).cuda(),
torch.randn((dim,)).cuda(),
torch.randn((dim,)).cuda(),
)

for inp in inps:
inp.requires_grad_(True)
m = cls(ms)

ret = ""
logits = m(inps)
loss = (logits[0]*last_weights[0] + logits[1]*last_weights[1] + logits[2]*last_weights[2]).sum()
loss.backward()
return list(logits) + [
inp.grad
for inp in inps
]

def test_main():
ret = {}
ret["normal"] = run("normal", Model_NORMAL)
ret["block"] = run("block", Model_BLOCK)
ret["zero"] = run("zero", Model_ZERO)
# ret["pipe"] = run("pipe", Model_PIPE) # TODO pipeline not support multiple input-output yet
for k, r in ret.items():
bmt.print_rank(f"============={k}============")
bmt.print_rank(r)
for r in ret.values():
for r2 in ret.values():
for i in range(len(r)):
assert_lt((r[i]-r2[i]).abs().max(), 1e-5)

if __name__ == "__main__":
bmt.init_distributed(pipe_size=2)

test_main()
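Per-layer hidden states can still be requested for every stream; with `num_hidden` streams, `return_hidden_states=True` now yields `2 * num_hidden` values instead of two. A hypothetical continuation of the test setup above (`m` is the `Model_ZERO` instance and `inps` the five input tensors; the variable names are only for illustration):

```python
outs = m.ms(*inps, return_hidden_states=True)
finals = outs[:3]    # final value of each of the 3 hidden streams
middles = outs[3:6]  # per-layer inputs of each stream, stacked along a new
                     # leading "layer" dimension by the forward pass
```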