Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
sandyhouse committed Mar 23, 2022
1 parent a9ea543 commit 51ca6b2
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 6 deletions.
2 changes: 0 additions & 2 deletions paddle/fluid/framework/distributed_strategy.proto
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ message RecomputeConfig {
repeated string checkpoints = 1;
optional bool enable_offload = 2 [ default = false ];
repeated int32 checkpoint_shape = 3;
optional int32 segments = 4 [ default = 1 ];
}

message ShardingConfig {
Expand Down Expand Up @@ -68,7 +67,6 @@ message AMPConfig {
repeated string custom_black_varnames = 9;
optional bool use_pure_fp16 = 10 [ default = false ];
optional bool use_fp16_guard = 11 [ default = true ];
optional string amp_level = 12 [ default = "O1" ];
}

message LocalSGDConfig {
Expand Down
8 changes: 4 additions & 4 deletions python/paddle/distributed/fleet/base/fleet_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
_grad_scalar = None


class RecomputeModelWrapper(paddle.nn.Layer):
class _RecomputeModelWrapper(paddle.nn.Layer):
def __init__(self, model, segments=1, preserve_rng_state=True):
super(RecomputeModelWrapper, self).__init__()
assert isinstance(model, paddle.nn.Sequential), (
Expand All @@ -69,6 +69,7 @@ def _checkpoint(self, func, *args, **kwargs):
return RecomputeFunction.apply(func, self._preserve_rng_state, *args)

def forward(self, input):
end = 0
for begin in range(0, self._segment_size * (self._segments - 1),
self._segment_size):
end = begin + self._segment_size
Expand Down Expand Up @@ -992,7 +993,7 @@ def forward(self, x):
strategy = self._user_defined_strategy
if strategy.amp == True:
amp_enable = True
amp_level = strategy.amp_configs['amp_level']
amp_level = "O2" if strategy.amp_configs['use_pure_fp16'] else "O1"
if amp_level.upper() == "O2":
model = paddle.amp.decorate(
models=model,
Expand Down Expand Up @@ -1020,8 +1021,7 @@ def forward(self, x):

if strategy.recompute == True:
recompute_enable = True
segments = strategy.recompute_configs['segments']
model = RecomputeModelWrapper(model)
model = _RecomputeModelWrapper(model)

if self._user_defined_strategy.heter_ccl_mode == True:
model = paddle.DataParallel(
Expand Down
1 change: 1 addition & 0 deletions python/paddle/fluid/tests/unittests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -939,6 +939,7 @@ if (WITH_DISTRIBUTE)
set_tests_properties(test_dist_fleet_infer PROPERTIES TIMEOUT 200)
set_tests_properties(test_dist_fleet_raw_program_optimizer PROPERTIES TIMEOUT 120)
set_tests_properties(test_dist_fleet_raw_program_optimizer_fuse_allreduce PROPERTIES TIMEOUT 60)
set_tests_properties(test_dist_dygraph_apis PROPERTIES TIMEOUT 120)
endif()

if (WITH_DISTRIBUTE AND NOT APPLE)
Expand Down
3 changes: 3 additions & 0 deletions python/paddle/fluid/tests/unittests/dygraph_fleet_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def test_dygraph_fleet_api(self):
net = paddle.nn.Sequential(
paddle.nn.Linear(10, 1), paddle.nn.Linear(1, 2))
net = dist.fleet.distributed_model(net)
data = np.random.uniform(-1, 1, [30, 10]).astype('float32')
data = paddle.to_tensor(data)
net(data)


if __name__ == "__main__":
Expand Down

0 comments on commit 51ca6b2

Please sign in to comment.