Commit a9ea543

update

sandyhouse committed Mar 22, 2022
1 parent 4a2ce73 commit a9ea543
Showing 4 changed files with 67 additions and 31 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/framework/distributed_strategy.proto
@@ -68,6 +68,7 @@ message AMPConfig {
   repeated string custom_black_varnames = 9;
   optional bool use_pure_fp16 = 10 [ default = false ];
   optional bool use_fp16_guard = 11 [ default = true ];
+  optional string amp_level = 12 [ default = "O1" ];
 }

 message LocalSGDConfig {
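
On the Python side, the new amp_level field is reachable through DistributedStrategy.amp_configs, which maps onto the AMPConfig message above. A minimal sketch of setting it (the chosen values are illustrative and not part of this commit; in Paddle's AMP convention "O1" is mixed precision with black/white lists and "O2" is pure fp16):

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.amp = True
# Keys correspond to AMPConfig fields; "amp_level" is the field added here.
strategy.amp_configs = {
    "init_loss_scaling": 32768,
    "use_pure_fp16": False,
    "amp_level": "O1",
}
fleet.init(is_collective=True, strategy=strategy)
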
5 changes: 3 additions & 2 deletions python/paddle/distributed/fleet/base/fleet_base.py
@@ -47,6 +47,7 @@

 class RecomputeModelWrapper(paddle.nn.Layer):
     def __init__(self, model, segments=1, preserve_rng_state=True):
+        super(RecomputeModelWrapper, self).__init__()
         assert isinstance(model, paddle.nn.Sequential), (
             "The model passed to RecomputeModelWrapper must be of type "
             "paddle.nn.Sequential.")
@@ -1006,10 +1007,10 @@ def forward(self, x):
 decr_every_n_nan_or_inf = strategy.amp_configs[
     'decr_every_n_nan_or_inf']
 use_dynamic_loss_scaling = strategy.amp_configs[
-    'use_dynamic_loss_scalling']
+    'use_dynamic_loss_scaling']

 global _grad_scalar
-_grad_scalar = paddle.amp.GradScalar(
+_grad_scalar = paddle.amp.GradScaler(
     init_loss_scaling=init_loss_scaling,
     incr_ratio=incr_ratio,
     decr_ratio=decr_ratio,
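
For reference, paddle.amp.GradScaler (the old GradScalar spelling does not exist) takes exactly the constructor arguments read from amp_configs above and drives dynamic loss scaling. A minimal standalone sketch; the model, optimizer, and data are placeholders, not part of this commit:

import paddle

model = paddle.nn.Linear(10, 2)
opt = paddle.optimizer.SGD(learning_rate=0.001, parameters=model.parameters())
scaler = paddle.amp.GradScaler(
    init_loss_scaling=1024,
    incr_ratio=2.0,
    decr_ratio=0.5,
    incr_every_n_steps=1000,
    decr_every_n_nan_or_inf=2,
    use_dynamic_loss_scaling=True)

with paddle.amp.auto_cast():
    loss = model(paddle.rand([4, 10])).mean()
scaled = scaler.scale(loss)      # scale the loss so fp16 gradients don't underflow
scaled.backward()
scaler.minimize(opt, scaled)     # unscale gradients, apply the update, adjust the scale
opt.clear_grad()
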
57 changes: 57 additions & 0 deletions python/paddle/fluid/tests/unittests/dygraph_fleet_api.py
@@ -0,0 +1,57 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import random
import numpy as np
import os
import shutil

import paddle
import paddle.nn as nn
from paddle.fluid import core
import datetime
from datetime import timedelta
import paddle.fluid.core as core
from paddle.fluid.framework import _test_eager_guard
from paddle.fluid.dygraph.parallel import ParallelEnv


class TestDygraphFleetAPI(unittest.TestCase):
    def setUp(self):
        paddle.seed(2022)
        random.seed(2022)
        np.random.seed(2022)
        self.config()

    def config(self):
        self.dtype = "float32"
        self.shape = (2, 10, 5)

    def test_dygraph_fleet_api(self):
        import paddle.distributed.fleet as fleet
        import paddle.distributed as dist
        strategy = fleet.DistributedStrategy()
        strategy.amp = True
        strategy.recompute = True
        fleet.init(is_collective=True, strategy=strategy)
        net = paddle.nn.Sequential(
            paddle.nn.Linear(10, 1), paddle.nn.Linear(1, 2))
        net = dist.fleet.distributed_model(net)


if __name__ == "__main__":
    unittest.main()
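
The new test builds and wraps the model but does not run it; a single training step with the wrapped net would look roughly like the following illustrative continuation (not part of the test, reusing the names from test_dygraph_fleet_api):

opt = fleet.distributed_optimizer(
    paddle.optimizer.Adam(parameters=net.parameters()))
with paddle.amp.auto_cast():
    loss = net(paddle.rand([4, 10])).mean()  # last dim 10 matches Linear(10, 1)
loss.backward()
opt.step()
opt.clear_grad()
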
35 changes: 6 additions & 29 deletions python/paddle/fluid/tests/unittests/test_dist_dygraph_apis.py
@@ -12,38 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import unittest
-import os
-import paddle
-import paddle.nn as nn
-
-
-class SimpleNet(nn.Layer):
-    def __init__(self, in_size, out_size):
-        super(SimpleNet, self).__init__()
-        self.linear = nn.Linear(in_size, out_size)
-        self.softmax = nn.Softmax(axis=-1)
-
-    def forward(self, input):
-        y = self.linear(input)
-        pred = self.softmax(y)
-        return pred
+from __future__ import print_function

+import unittest
+from test_parallel_dygraph_dataparallel import TestMultipleGpus

-class TestDygraphFleetApis(unittest.TestCase):
-    def setUp(self):
-        os.environ["PADDLE_TRAINER_ID"] = "1"
-        os.environ[
-            "PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.1:36002"
-
-    def test_pipeline_optimizer(self):
-        import paddle.distributed.fleet as fleet
-        strategy = fleet.DistributedStrategy()
-        strategy.amp = True
-        strategy.recompute = True
-        fleet.init(is_collective=True, strategy=strategy)
-        net = SimpleNet(8, 8)
-        net = dist.fleet.distributed_model(net)
+class TestDygraphFleetApi(TestMultipleGpus):
+    def test_dygraph_fleet_api(self):
+        self.run_mnist_2gpu('dygraph_fleet_api.py')


 if __name__ == "__main__":
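
run_mnist_2gpu comes from the reused TestMultipleGpus helper; by its usual behavior in this test suite it launches the named script in fresh processes on two GPUs. Roughly the manual equivalent, sketched as an assumption rather than the helper's actual implementation:

import subprocess
import sys

# Hypothetical stand-in for TestMultipleGpus.run_mnist_2gpu: run the new
# script under the distributed launcher on two GPUs and check it exits cleanly.
proc = subprocess.run([
    sys.executable, "-m", "paddle.distributed.launch", "--gpus", "0,1",
    "dygraph_fleet_api.py"
])
assert proc.returncode == 0
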
