Dygraph Recompute (#32516)
* Dygraph recompute

* unit test for Dygraph recompute

* dygraph recompute: remove the unit test on Windows and macOS
JZ-LIANG authored Apr 25, 2021
1 parent f16981b commit 583ebab
Showing 4 changed files with 355 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/paddle/distributed/fleet/utils/__init__.py
@@ -14,3 +14,4 @@

from .fs import LocalFS, HDFSClient
from .ps_util import DistributedInfer
from .recompute import recompute
177 changes: 177 additions & 0 deletions python/paddle/distributed/fleet/utils/recompute.py
@@ -0,0 +1,177 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.fluid import core
from paddle.autograd import PyLayer
from paddle.fluid import framework
import contextlib

import logging
logging.basicConfig(
    format='%(asctime)s %(levelname)-8s %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S')


def detach_variable(inputs):
    out = []
    for inp in inputs:
        if not isinstance(inp, core.VarBase):
            out.append(inp)
            continue

        x = inp.detach()
        x.stop_gradient = inp.stop_gradient
        out.append(x)
    return tuple(out)


def check_recompute_necessary(inputs):
    if not any(not input_.stop_gradient for input_ in inputs
               if isinstance(input_, paddle.Tensor)):
        logging.warning(
            "[Recompute]: none of the inputs to the current recompute block needs grad, "
            "so there is no need to recompute this block in backward!")


@contextlib.contextmanager
def swith_rng_state(rng_state):
    orig_cuda_rng_state = paddle.get_cuda_rng_state()
    paddle.set_cuda_rng_state(rng_state)
    try:
        yield
    finally:
        paddle.set_cuda_rng_state(orig_cuda_rng_state)


class RecomputeFunction(PyLayer):
    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        check_recompute_necessary(args)

        # store for recomputing
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state

        # NOTE: the number of outputs of backward() should equal the number of tensors in forward()'s input,
        # and the order of tensors in backward()'s output should match the order of tensors in forward()'s input.
        # Non-tensor inputs will not receive gradients in backward.

        # save input for backward
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if paddle.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)
        ctx.save_for_backward(*tensor_inputs)

        # NOTE recompute with RNG state restore only supports the scenario of one process per CUDA GPU;
        # one process with multiple GPUs and mixed GPU-CPU scenarios are not supported.
        if ctx.preserve_rng_state:
            cur_device = paddle.get_device()
            if 'gpu:' not in cur_device:
                raise RuntimeError(
                    "Recompute with RNG state preservation is not supported on the current device: {}.".
                    format(cur_device))
            ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state()

        # TODO support AMP

        with paddle.no_grad():
            outputs = run_function(*args)

        return outputs

    @staticmethod
    def backward(ctx, *args):
        with paddle.fluid.dygraph.guard():
            # TODO check whether the recompute call is valid or not

            # Restore inputs
            inputs = list(ctx.inputs)
            tensor_indices = ctx.tensor_indices
            tensors = ctx.saved_tensor()
            for i, idx in enumerate(tensor_indices):
                inputs[idx] = tensors[i]

            # paddle.enable_grad()
            tracer = framework._dygraph_tracer()
            tracer._has_grad = True

            # TODO support AMP

            if ctx.preserve_rng_state:
                with swith_rng_state(ctx.fw_cuda_rng_state):
                    detached_inputs = detach_variable(tuple(inputs))
                    outputs = ctx.run_function(*detached_inputs)
            else:
                detached_inputs = detach_variable(tuple(inputs))
                outputs = ctx.run_function(*detached_inputs)

            if isinstance(outputs, core.VarBase):
                outputs = (outputs, )
            assert len(outputs) == len(args)

            # run backward() only with tensors that require grad
            forward_outputs_with_grad = []
            backward_inputs = list(args)
            for i in range(len(outputs)):
                if isinstance(outputs[i],
                              core.VarBase) and not outputs[i].stop_gradient:
                    forward_outputs_with_grad.append(outputs[i])
            if len(forward_outputs_with_grad) == 0:
                raise RuntimeError(
                    "none of the outputs requires grad, so this recompute() is not necessary"
                )

            assert len(backward_inputs) == len(
                forward_outputs_with_grad
            ), "number of forward outputs is [{}], but the backward got [{}] inputs".format(
                len(forward_outputs_with_grad), len(backward_inputs))

            # actually backward
            paddle.autograd.backward(forward_outputs_with_grad, backward_inputs)

            grads = list(inp._grad_ivar() for inp in detached_inputs
                         if isinstance(inp, core.VarBase))

            return grads


def recompute(function, *args, **kwargs):
"""
recompute intermediate activations to save then memory.
Args:
function: layer of sequence of layers that describes part of forward pass of the model whose
intermediate activations will be released to save memory in forward stage and will be recomputed
in backward stage for gradient calculation.
preserve_rng_state(bool, optional): if preserve the RNG state of forward and restore it in backward.
args: inputs to the function
Returns:
Output of function on args
"""
    # Hack to mix *args with **kwargs in a python 2.7-compliant way
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(
            arg for arg in kwargs))

    return RecomputeFunction.apply(function, preserve, *args)
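
For orientation, here is a minimal usage sketch of the new API (not part of the diff; the block definition, tensor shape, and device setup below are illustrative assumptions). recompute(block, x) runs block(x) under no_grad in the forward pass and replays it during backward to rebuild the freed activations:

import paddle
from paddle.distributed.fleet.utils import recompute

paddle.set_device("gpu")  # the default preserve_rng_state=True requires a CUDA device

# a hypothetical block; any paddle.nn.Layer or callable should work
block = paddle.nn.Sequential(
    paddle.nn.Linear(16, 16), paddle.nn.Dropout(p=0.5), paddle.nn.ReLU())

x = paddle.randn([4, 16])
x.stop_gradient = False

y = recompute(block, x)  # intermediate activations inside `block` are not stored
y.mean().backward()      # `block` is re-run here to recompute them

Because preserve_rng_state defaults to True, the CUDA RNG state captured in the forward pass is restored before the replay, so stochastic layers such as Dropout produce the same mask in both passes; the unit test below relies on exactly this property.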
1 change: 1 addition & 0 deletions python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -176,6 +176,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
    list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_pipeline_layer)
    LIST(REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision)
    LIST(REMOVE_ITEM TEST_OPS test_fleet_base_single)
    LIST(REMOVE_ITEM TEST_OPS test_dygraph_recompute)
elseif(WITH_GPU)
    if (${CUDNN_VERSION} VERSION_LESS 7100)
        LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op)
176 changes: 176 additions & 0 deletions python/paddle/fluid/tests/unittests/test_dygraph_recompute.py
@@ -0,0 +1,176 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np

import paddle
from paddle.autograd import PyLayer
from paddle.distributed.fleet.utils import recompute
import random

import paddle.fluid.layers as layers


def get_fc_block(block_idx, input_size, is_last=False):
    block_name = "block_" + str(block_idx)
    block = paddle.nn.Sequential(
        (block_name + "_fc_0", paddle.nn.Linear(
            input_size, input_size, bias_attr=False)),
        (block_name + "_dropout", paddle.nn.Dropout(p=0.5)),
        (block_name + "_relu_1", paddle.nn.ReLU()),
        (block_name + "_fc_1", paddle.nn.Linear(
            input_size, input_size, bias_attr=False)),
        (block_name + "_relu_2", paddle.nn.ReLU()), )
    if is_last:
        block.add_sublayer(
            block_name + "_fc_2",
            paddle.nn.Linear(
                input_size, 1, bias_attr=False))  # add sublayer
    else:
        block.add_sublayer(
            block_name + "_fc_2",
            paddle.nn.Linear(
                input_size, input_size, bias_attr=False))  # add sublayer
    return block


class Naive_fc_net(paddle.nn.Layer):
    def __init__(self,
                 input_size=10,
                 recompute_blocks=[1, 3],
                 recompute_kwargs={}):
        super(Naive_fc_net, self).__init__()
        self.recompute_blocks = recompute_blocks
        self.recompute_kwargs = recompute_kwargs
        self.runfunc0 = get_fc_block(0, input_size, is_last=False)
        self.runfunc1 = get_fc_block(1, input_size, is_last=False)
        self.runfunc2 = get_fc_block(2, input_size, is_last=False)
        self.runfunc3 = get_fc_block(3, input_size, is_last=False)
        self.runfunc4 = get_fc_block(4, input_size, is_last=True)

    def forward(self, inputs):

        if 0 in self.recompute_blocks:
            inputs = recompute(self.runfunc0, inputs)
        else:
            inputs = self.runfunc0(inputs)

        if 1 in self.recompute_blocks:
            inputs = recompute(self.runfunc1, inputs)
        else:
            inputs = self.runfunc1(inputs)

        if 2 in self.recompute_blocks:
            inputs = recompute(self.runfunc2, inputs, **self.recompute_kwargs)
        else:
            inputs = self.runfunc2(inputs)

        if 3 in self.recompute_blocks:
            inputs = recompute(self.runfunc3, inputs)
        else:
            inputs = self.runfunc3(inputs)

        if 4 in self.recompute_blocks:
            inputs = recompute(self.runfunc4, inputs)
        else:
            inputs = self.runfunc4(inputs)

        return inputs


def run_model(cuda_state, recompute_block=[], recompute_kwargs={}):
    gen = paddle.seed(10)
    gen.manual_seed(10)
    np.random.seed(10)
    random.seed(10)

    if cuda_state:
        paddle.set_cuda_rng_state(cuda_state)

    batch_size, input_size = 1, 10
    model = Naive_fc_net(
        input_size,
        recompute_blocks=recompute_block,
        recompute_kwargs=recompute_kwargs)
    loss_fn = paddle.nn.MSELoss(reduction='mean')
    optimizer = paddle.optimizer.SGD(learning_rate=0.01,
                                     parameters=model.parameters())

    loss_ = []
    param_ = []
    grad_ = []
    for step in range(10):
        x_data = np.random.randn(batch_size, input_size).astype(np.float32)
        x = paddle.to_tensor(x_data)
        # x.stop_gradient = False
        y_pred = model(x)
        loss = y_pred.mean()

        loss_.append(np.asarray(loss).tolist())
        loss.backward()
        optimizer.step()

        param_.append(np.asarray(model.parameters()[9]).tolist())
        grad_.append(np.asarray(model.parameters()[3]._grad_ivar()).tolist())

        optimizer.clear_grad()
    return loss_, param_, grad_


class TestPyLayer(unittest.TestCase):
    def test_fc_net_with_dropout(self):
        def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
            self.assertEqual(loss_ref, loss)
            self.assertEqual(param_ref, param)
            self.assertEqual(grad_ref, grad)

        cuda_state = paddle.get_cuda_rng_state()
        # without recompute
        loss_ref, param_ref, grad_ref = run_model(
            cuda_state, recompute_block=[])

        # recompute the second and fourth blocks
        loss, param, grad = run_model(cuda_state, recompute_block=[1, 3])
        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)

        # recompute fourth block
        loss, param, grad = run_model(cuda_state, recompute_block=[3])
        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)

        # recompute second to fourth block
        loss, param, grad = run_model(cuda_state, recompute_block=[1, 2, 3])
        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)

        # recompute second & fourth block
        loss, param, grad = run_model(cuda_state, recompute_block=[1, 3])
        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)

    def test_recompute_kwargs(self):
        paddle.set_device("gpu")
        kwargs = {"is_test": False}
        with self.assertRaises(ValueError):
            loss_ref, param_ref, grad_ref = run_model(
                None, recompute_block=[2], recompute_kwargs=kwargs)

    def test_recompute_cpu_rng(self):
        paddle.set_device("cpu")
        with self.assertRaises(RuntimeError):
            loss_ref, param_ref, grad_ref = run_model(None, recompute_block=[2])


if __name__ == '__main__':
    unittest.main()
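
Since the module ends with the standard unittest entry point, the new test can also be run on its own from a GPU-enabled build (e.g. "python test_dygraph_recompute.py" from python/paddle/fluid/tests/unittests); the CMakeLists change above removes it from TEST_OPS when Paddle is built without GPU or ROCm support, because recompute with preserve_rng_state relies on CUDA RNG state.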
