Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dygraph Recompute #32516

Merged
merged 3 commits into from
Apr 25, 2021
Merged

Dygraph Recompute #32516

merged 3 commits into from
Apr 25, 2021

Conversation

JZ-LIANG
Copy link
Contributor

@JZ-LIANG JZ-LIANG commented Apr 25, 2021

PR types

New features

PR changes

APIs

Describe

Dygraph Recompute

Bert example base on PaddleNLP bert

# import recompute, line 40
from paddle.distributed.fleet.utils import recompute

# modify TransformerEncoder class, line 615
class TransformerEncoder(Layer):
    # NOTE recompute modification
    # def __init__(self, encoder_layer, num_layers, norm=None):
    def __init__(self, encoder_layer, num_layers, norm=None, enable_recompute = True, preserve_rng_state = True):
        super(TransformerEncoder, self).__init__()
        self.layers = LayerList([(encoder_layer if i == 0 else
                                  type(encoder_layer)(**encoder_layer._config))
                                 for i in range(num_layers)])
        self.num_layers = num_layers
        self.norm = norm
        # NOTE recompute modification
        self.enable_recompute = enable_recompute
        self.preserve_rng_state = preserve_rng_state
        if preserve_rng_state:
            assert self.enable_recompute, "preserve_rng_state is True, but enable_recompute is False."

    def forward(self, src, src_mask=None, cache=None):
        src_mask = _convert_attention_mask(src_mask, src.dtype)

        output = src
        new_caches = []
        for i, mod in enumerate(self.layers):
            if cache is None:
                # NOTE recompute modification
                if self.enable_recompute:
                    output = recompute(mod, output, src_mask, preserve_rng_state = self.preserve_rng_state)
                else:   
                    output = mod(output, src_mask=src_mask)
            else:
                output, new_cache = mod(output,
                                        src_mask=src_mask,
                                        cache=cache[i])
                new_caches.append(new_cache)

        if self.norm is not None:
            output = self.norm(output)

        return output if cache is None else (output, new_caches)

example to recompute the second block of a naive fc net:

def get_fc_block(block_idx, input_size, is_last=False):
    block_name = "block_" + str(block_idx)
    block = paddle.nn.Sequential(
        (block_name + "_fc_0", paddle.nn.Linear(
            input_size, input_size, bias_attr=False)),
        (block_name + "_dropout", paddle.nn.Dropout(p=0.5)),
        (block_name + "_relu_1", paddle.nn.ReLU()),
        (block_name + "_fc_1", paddle.nn.Linear(
            input_size, input_size, bias_attr=False)),
        (block_name + "_relu_2", paddle.nn.ReLU()), )
    if is_last:
        block.add_sublayer(
            block_name + "_fc_2",
            paddle.nn.Linear(
                input_size, 1, bias_attr=False))  # add sublayer
    else:
        block.add_sublayer(
            block_name + "_fc_2",
            paddle.nn.Linear(
                input_size, input_size, bias_attr=False))  # add sublayer
    return block

class Naive_fc_net(paddle.nn.Layer):
    def __init__(self, input_size=10,):
        super(Naive_fc_net, self).__init__()
        self.runfunc0 = get_fc_block(0, input_size, is_last=False)
        self.runfunc1 = get_fc_block(1, input_size, is_last=False)
        self.runfunc2 = get_fc_block(2, input_size, is_last=False)

    def forward(self, inputs):

        inputs = self.runfunc0(inputs)
        # recompute 
        inputs = recompute(self.runfunc1, inputs)
        inputs = self.runfunc2(inputs)

        return inputs

@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.



@contextlib.contextmanager
def swith_rng_state(rng_state):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

swith_rng_state -> switch_rng_state

Copy link
Member

@ForFishes ForFishes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@ForFishes ForFishes merged commit 583ebab into PaddlePaddle:develop Apr 25, 2021
@FesianXu
Copy link

你好,动态图的recompute在单机多卡环境下会报错,请问你尝试过单机多卡运行吗

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants