From 8b21ef4a416a93b3d73cf4e2eed63ba0efdcf9fd Mon Sep 17 00:00:00 2001 From: xiayanming Date: Fri, 22 Mar 2024 17:37:12 +0800 Subject: [PATCH 1/2] for support ascvrq --- paddle/fluid/framework/op_proto_maker.h | 3 + paddle/fluid/operators/batch_fc_op.kps | 361 ++++++++++++++++++ .../collective/c_broadcast_op_xpu.cc | 111 ++++++ .../fluid/platform/device/xpu/xpu2_op_list.h | 6 + paddle/fluid/pybind/const_value.cc | 3 +- .../fleet/meta_optimizers/__init__.py | 1 + .../fleet/meta_optimizers/sharding/shard.py | 20 +- .../fleet/meta_optimizers/sharding/utils.py | 3 +- .../meta_optimizers/sharding_optimizer.py | 282 ++++++++++++-- .../unittests/xpu/test_batch_fc_op_xpu.py | 111 ++++++ 10 files changed, 868 insertions(+), 33 deletions(-) create mode 100644 paddle/fluid/operators/batch_fc_op.kps create mode 100644 paddle/fluid/operators/collective/c_broadcast_op_xpu.cc create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_batch_fc_op_xpu.py diff --git a/paddle/fluid/framework/op_proto_maker.h b/paddle/fluid/framework/op_proto_maker.h index 51aeed2e5d734e..07ab71aa38dc9a 100644 --- a/paddle/fluid/framework/op_proto_maker.h +++ b/paddle/fluid/framework/op_proto_maker.h @@ -35,6 +35,9 @@ enum class OpRole { // Tag all learning rate scheduler operators. kLRSched = 0x0010, + // scale lr(for adam) + kScaleLr = 0x0012, + kLoss = 0x0100, // The default value of op's role. This should be only used for unittests and // CreateOp inside a operator. diff --git a/paddle/fluid/operators/batch_fc_op.kps b/paddle/fluid/operators/batch_fc_op.kps new file mode 100644 index 00000000000000..65071568110ecc --- /dev/null +++ b/paddle/fluid/operators/batch_fc_op.kps @@ -0,0 +1,361 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ +#ifdef PADDLE_WITH_XPU_KP + +#include // NOLINT +#include +#include +#include +#include + +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/platform/device/xpu/enforce_xpu.h" +#include "paddle/fluid/platform/device_context.h" + +#include "xpu/kernel/xtdk.h" // NOLINT +#include "xpu/kernel/xtdk_math.h" // NOLINT +#include "xpu/kernel/xtdk_simd.h" + +#include "xpu/kernel/xtdk_io.h" + +#include "paddle/fluid/operators/batch_fc_op.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +#include "paddle/fluid/operators/xpu_api_wrapper.h" + +namespace paddle { +namespace operators { +using framework::Tensor; + + +static __device__ inline void memset_lm_float(float* dst_ptr, int size) { + for (int i = 0; i < size; i += 16) { + vstore_lm_float32x16_mz(dst_ptr + i, 0, 0); + } + mfence_lm(); +} + +template +__global__ void add_bias_kernel( + T* data, int slot_pairs_num, int ins_num, int out_dim, const T* bias) { + int cid = core_id(); + int ncores = core_num(); + if (cid >= ncores) { + return; + } + int thread_id = cluster_id() * ncores + cid; + int nthreads = cluster_num() * ncores; + + int row_per_loop = 10; + int total_row = slot_pairs_num * ins_num; + // int threads_per_bs = roundup_div(ins_num, row_per_loop); + int buf_size = row_per_loop * (out_dim + 16); + __simd__ T local_data_buf[buf_size]; + __simd__ T local_bias_buf[out_dim + 16]; + + __simd__ T out_buf[buf_size]; + memset_lm_float(out_buf, buf_size); + + for (int g_index = thread_id * row_per_loop; g_index < total_row; g_index += nthreads * row_per_loop) { + int b_index = g_index / ins_num; // bs index + int r_index = g_index % ins_num; // row index + + GM2LM_ASYNC(bias + b_index, local_bias_buf, out_dim * sizeof(T)); + + int gm_offset = b_index * ins_num * out_dim + r_index * out_dim; + GM2LM_ASYNC(data + gm_offset, local_data_buf, row_per_loop * out_dim * sizeof(T)); + + mfence(); + + int col_offset = 0; + for (int curr_row = 0; curr_row < row_per_loop; curr_row++) { + // column + for (int col_step = 0; col_step < out_dim; col_step += 16) { + col_offset = curr_row * out_dim + col_step; + float32x16_t vec_data = vload_lm_float32x16(local_data_buf + col_offset); + float32x16_t vec_bias = vload_lm_float32x16(local_bias_buf); + + vec_data = vvadd_float32x16(vec_data, vec_bias); + vstore_lm_float32x16(out_buf + col_step, vec_data); + } + mfence(); + LM2GM_ASYNC(out_buf, data + gm_offset, out_dim * sizeof(T)); + mfence(); + } + } +} + +template +void add_bias(xpu::Context* xpu_ctx, + T* data, + int slot_pairs_num, + int ins_num, + int out_dim, + const T* bias) { + auto stream = xpu_ctx->xpu_stream; + add_bias_kernel<<<8, 64, stream>>>(data, slot_pairs_num, ins_num, out_dim, bias); +} + +template +__global__ void add_bias_grad_kernel(const T* dout_data, + int slot_pairs_num, + int ins_num, + int out_dim, + T* db_data) { + int cid = core_id(); + int ncores = core_num(); + if (cid >= ncores) { + return; + } + int thread_id = cluster_id() * ncores + cid; + int nthreads = cluster_num() * ncores; + + int bs_per_loop = 64; + int total_bs = slot_pairs_num * out_dim; + int buf_size = bs_per_loop; + __simd__ T local_bias_buf[buf_size]; + __simd__ T tmp_sum_buf[buf_size]; + + __local__ float local_data_buf[1]; + + memset_lm_float(local_bias_buf, buf_size); + memset_lm_float(tmp_sum_buf, buf_size); + + __local__ T tmp_sum = static_cast(0); + for (int g_index = thread_id * bs_per_loop; g_index < total_bs; g_index += nthreads * bs_per_loop) { + int len = min(total_bs - g_index, bs_per_loop); + int r_index = g_index / out_dim; // row 
index + int c_index = g_index % out_dim; // col index + + GM2LM_ASYNC(db_data + g_index, local_bias_buf, len * sizeof(T)); + + for (int index = 0; index < len; index++) { + for (int i = 0; i < ins_num; ++i) { + int select_indx = ((r_index + 1) * i + 1) * c_index; + GM2LM_ASYNC(dout_data + select_indx, local_data_buf, sizeof(T)); + mfence(); + tmp_sum_buf[index] += local_data_buf[0]; + } + } + + mfence(); + + for (int step = 0; step < len; step += 16) { + float32x16_t vec_bias = vload_lm_float32x16(local_bias_buf + step); + float32x16_t vec_sum = vload_lm_float32x16(tmp_sum_buf + step); + vec_bias = vvadd_float32x16(vec_bias, vec_sum); + vstore_lm_float32x16(local_bias_buf + step, vec_bias); + } + mfence(); + LM2GM_ASYNC(local_bias_buf, db_data + g_index, len * sizeof(T)); + } +} + +template +void add_bias_grad(xpu::Context* xpu_ctx, + const T* dout_data, + int slot_pairs_num, + int ins_num, + int out_dim, + T* db_data) { + auto stream = xpu_ctx->xpu_stream; + add_bias_grad_kernel<<<8, 64, stream>>>( + dout_data, slot_pairs_num, ins_num, out_dim, db_data); +} + + + +template +class BatchFCXPUKernel : public framework::OpKernel { + using XPUType = typename XPUTypeTrait::Type; + + public: + void Compute(const framework::ExecutionContext& ctx) const override { + int batchcount = ctx.Attr("batchcount"); + auto transpose_weight = ctx.Attr("transpose_weight"); + if (transpose_weight) { + // TODO + PADDLE_ENFORCE_EQ( + transpose_weight, + true, + platform::errors::Unimplemented("BatchFC not support transpose_weight now.")); + return; + } + if (batchcount > 0) { + // TODO + PADDLE_ENFORCE_EQ( + (batchcount > 0), + true, + platform::errors::Unimplemented("BatchFC not support transpose_weight now.")); + } else { + // X.dim = slot_pairs_num * ins_num * in_dim + // W.dim = slot_pairs_num * in_dim * out_dim + // b.dim = slot_pairs_num * out_dim + // output.dim = slot_pairs_num * ins_num * out_dim + auto* input = ctx.Input("Input"); + auto* w = ctx.Input("W"); + auto* bias = ctx.Input("Bias"); + auto* output = ctx.Output("Out"); + auto input_dims = input->dims(); + auto w_dims = w->dims(); + auto slot_pairs_num = input_dims[0]; + auto ins_num = input_dims[1]; + // auto in_dim = input_dims[2]; + auto out_dim = w_dims[2]; + + // get data ptr + const XPUType* x_ptr = reinterpret_cast(input->data()); + const XPUType* y_ptr = reinterpret_cast(w->data()); + const XPUType* bias_data = reinterpret_cast(bias->data()); + + output->Resize({slot_pairs_num, ins_num, out_dim}); + XPUType* out_ptr = reinterpret_cast(output->mutable_data(ctx.GetPlace())); + + // initialize + auto& dev_ctx = + ctx.template device_context(); + xpu::Context* xpu_ctx = dev_ctx.x_context(); + + bool trans_x = false; + bool trans_y = false; + + T alpha = 1; + T beta = 0; + + XpuFcInfo fc_info; + GetFCInfo(input_dims, w_dims, trans_x, trans_y, &fc_info); + MatMulXPUFunction(xpu_ctx, x_ptr, y_ptr, out_ptr, fc_info, alpha); + + // add bias + add_bias(xpu_ctx, + out_ptr, + slot_pairs_num, + ins_num, + out_dim, + bias_data); + } + } +}; + +template +class BatchFCGradOpXPUKernel : public framework::OpKernel { + using XPUType = typename XPUTypeTrait::Type; + + public: + void Compute(const framework::ExecutionContext& ctx) const override { + int batchcount = ctx.Attr("batchcount"); + if (batchcount > 0) { + // TODO + PADDLE_ENFORCE_EQ( + (batchcount > 0), + true, + platform::errors::Unimplemented("BatchFC not support transpose_weight now.")); + } else { + auto* input = ctx.Input("Input"); + auto* w = ctx.Input("W"); + auto* dout = 
ctx.Input(framework::GradVarName("Out")); + + auto* dx = ctx.Output(framework::GradVarName("Input")); + auto* dw = ctx.Output(framework::GradVarName("W")); + auto* db = ctx.Output(framework::GradVarName("Bias")); + + auto input_dims = input->dims(); + auto w_dims = w->dims(); + auto slot_pairs_num = input_dims[0]; + auto ins_num = input_dims[1]; + // auto in_dim = input_dims[2]; + auto out_dim = w_dims[2]; + + const XPUType* dout_ptr = reinterpret_cast(dout->data()); + const XPUType* x_ptr = reinterpret_cast(dx->mutable_data(ctx.GetPlace())); + const XPUType* y_ptr = reinterpret_cast(dw->mutable_data(ctx.GetPlace())); + XPUType* b_ptr = reinterpret_cast(db->mutable_data(ctx.GetPlace())); + + auto& dev_ctx = + ctx.template device_context(); + xpu::Context* xpu_ctx = dev_ctx.x_context(); + xpu::ctx_guard RAII_GUARD(xpu_ctx); + + bool transpose_x = false; + bool transpose_y = false; + XpuFcInfo info_forward; + GetFCInfo(input_dims, w_dims, transpose_x, transpose_y, &info_forward); + + const XPUType* a_1 = reinterpret_cast(NULL); + const XPUType* b_1 = reinterpret_cast(NULL); + const XPUType* a_2 = reinterpret_cast(NULL); + const XPUType* b_2 = reinterpret_cast(NULL); + XPUType* c_1 = (dx == NULL) ? reinterpret_cast(NULL) + : reinterpret_cast(dx->data()); + XPUType* c_2 = (dw == NULL) ? reinterpret_cast(NULL) + : reinterpret_cast(dw->data()); + + // add bias grad + db->mutable_data(ctx.GetPlace()); + add_bias_grad(xpu_ctx, + dout_ptr, + slot_pairs_num, + ins_num, + out_dim, + b_ptr); + + T alpha = 1; + + XpuFcInfo info_dx; + XpuFcInfo info_dy; + std::tuple + fc_info = MatmulGradFcInfo(xpu_ctx, + &RAII_GUARD, + info_forward, + transpose_x, + transpose_y, + x_ptr, + y_ptr, + dout_ptr); + std::tie(info_dx, info_dy, a_1, b_1, a_2, b_2) = fc_info; + + // dx = dout_data * y^T + MatMulXPUFunction(xpu_ctx, a_1, b_1, c_1, info_dx, alpha); + + // dy = x^T * dout_data + MatMulXPUFunction(xpu_ctx, a_2, b_2, c_2, info_dy, alpha); + } + } +}; + +} // namespace operators +} // namespace paddle + + + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_KERNEL(batch_fc, KP, plat::XPUPlace, + ops::BatchFCXPUKernel); +REGISTER_OP_KERNEL(batch_fc_grad, KP, plat::XPUPlace, + ops::BatchFCGradOpXPUKernel); + +// namespace ops = paddle::operators; + +REGISTER_OP_XPU_KERNEL(batch_fc, + ops::BatchFCXPUKernel); +REGISTER_OP_XPU_KERNEL(batch_fc_grad, + ops::BatchFCGradOpXPUKernel); +#endif \ No newline at end of file diff --git a/paddle/fluid/operators/collective/c_broadcast_op_xpu.cc b/paddle/fluid/operators/collective/c_broadcast_op_xpu.cc new file mode 100644 index 00000000000000..ae2e3657ad6d69 --- /dev/null +++ b/paddle/fluid/operators/collective/c_broadcast_op_xpu.cc @@ -0,0 +1,111 @@ +/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/collective/c_broadcast_op.h" + +#include +#include +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" + +#if defined(PADDLE_WITH_XPU_BKCL) +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/device/xpu/bkcl_helper.h" +#endif + +namespace paddle { +namespace operators { + +template +class CBroadcastOpXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { +#if defined(PADDLE_WITH_XPU_BKCL) + auto x = ctx.Input("X"); + auto out = ctx.Output("Out"); + size_t numel = x->numel(); + + BKCLDataType dtype = + platform::ToBKCLDataType(framework::TransToProtoVarType(x->dtype())); + int ring_id = ctx.Attr("ring_id"); + auto place = ctx.GetPlace(); + int root = ctx.Attr("root"); + + auto comm = paddle::platform::BKCLCommContext::Instance().Get(ring_id, place); + auto stream = comm->stream(); + VLOG(3) << "BKCLCommContext ring_id " << ring_id; + + if (ctx.Attr("use_calc_stream")) { + auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); + stream = static_cast(dev_ctx) + ->x_context() + ->xpu_stream; + } + + void* send_recv_buffer = nullptr; + if (root == comm->rank()) { + send_recv_buffer = + reinterpret_cast(const_cast(x->data())); + PADDLE_ENFORCE_XPU_SUCCESS(bkcl_broadcast(comm->comm(), + send_recv_buffer, + send_recv_buffer, + numel, + dtype, + root, + stream)); + VLOG(3) << "rank " << comm->rank() << " invoke Bcast. sent " + << x->numel(); + if (out != x) { + framework::TensorCopy( + *static_cast(x), + place, + *platform::DeviceContextPool::Instance().Get(place), + static_cast(out)); + } + } else { + auto& dev_ctx = + ctx.template device_context(); + dev_ctx.template Alloc(out); + send_recv_buffer = out->data(); + PADDLE_ENFORCE_XPU_SUCCESS(bkcl_broadcast(comm->comm(), + send_recv_buffer, + send_recv_buffer, + numel, + dtype, + root, + stream)); + VLOG(3) << "rank " << comm->rank() << " invoke Bcast. 
received " + << phi::product(out->dims()); + } + + out->Resize(x->dims()); + out->set_lod(x->lod()); +#else + PADDLE_THROW(platform::errors::PreconditionNotMet( + "PaddlePaddle should compile with XPU.")); +#endif + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_XPU_KERNEL(c_broadcast, + ops::CBroadcastOpXPUKernel); diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h index 22c2bed9ce6c58..6c3e43755352f2 100644 --- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h @@ -611,6 +611,12 @@ XPUOpMap& get_kl2_ops() { XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"fused_concat_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"c_broadcast", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"c_reduce_sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"batch_fc", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"batch_fc_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, }; return s_xpu2_kernels; } diff --git a/paddle/fluid/pybind/const_value.cc b/paddle/fluid/pybind/const_value.cc index 89a3904d0003fe..5503c47197a26d 100644 --- a/paddle/fluid/pybind/const_value.cc +++ b/paddle/fluid/pybind/const_value.cc @@ -49,7 +49,8 @@ void BindConstValue(pybind11::module* m) { .value("Loss", framework::OpRole::kLoss) .value("RPC", framework::OpRole::kRPC) .value("Dist", framework::OpRole::kDist) - .value("LRSched", framework::OpRole::kLRSched); + .value("LRSched", framework::OpRole::kLRSched) + .value("ScaleLr", framework::OpRole::kScaleLr); op_proto_and_checker_maker.def( "kOpRoleAttrName", framework::OpProtoAndCheckerMaker::OpRoleAttrName); diff --git a/python/paddle/distributed/fleet/meta_optimizers/__init__.py b/python/paddle/distributed/fleet/meta_optimizers/__init__.py index 1eae4be579aa78..9c35964587cd97 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/__init__.py +++ b/python/paddle/distributed/fleet/meta_optimizers/__init__.py @@ -27,6 +27,7 @@ from .lamb_optimizer import LambOptimizer from .fp16_allreduce_optimizer import FP16AllReduceOptimizer from .sharding_optimizer import ShardingOptimizer +from .sharding_optimizer import ThreadShardingOptimizer from .dygraph_optimizer import HybridParallelOptimizer from .dygraph_optimizer import HeterParallelOptimizer from .dygraph_optimizer import HybridParallelGradScaler diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/shard.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/shard.py index 7002dfa2be5148..33c4d01e4daea8 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/shard.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/shard.py @@ -51,6 +51,7 @@ def has_var(self, var_name): self._var_device_id(var_name) == self.worker_idx def _split_params(self, params_grads, worker_idx, worker_num): + """ param2device = {} total_param_mem = 0.0 param2mem = [] @@ -62,12 +63,29 @@ def _split_params(self, params_grads, worker_idx, worker_num): device_idx = 0 mem_accu = 0.0 for param_name, mem in param2mem: - if mem_accu > total_param_mem * 1.0 * (device_idx + 1) / worker_num: + if mem_accu > total_param_mem * (device_idx + 1) / worker_num: device_idx += 1 device2params[device_idx].append(param_name) param2device[param_name] = device_idx mem_accu += mem return param2device, device2params + """ + 
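Note on the shard.py change below: the rewritten _split_params replaces the old cumulative-memory split (now disabled inside the docstring above) with a greedy, size-balanced assignment in which each parameter is placed on whichever worker currently holds the least total size. A minimal standalone sketch of that balancing rule, assuming a plain dict of parameter sizes in place of Paddle's params_grads / get_var_size:

    def greedy_split(param_sizes, worker_num):
        # param_sizes: {param_name: size}; stands in for iterating
        # [x[0] for x in params_grads] and calling get_var_size(param).
        param2device = {}
        device2params = {w: [] for w in range(worker_num)}
        sizes = [0] * worker_num
        for name, numel in param_sizes.items():
            device_idx = sizes.index(min(sizes))   # least-loaded worker so far
            device2params[device_idx].append(name)
            param2device[name] = device_idx
            sizes[device_idx] += numel
        return param2device, device2params

    # e.g. greedy_split({"w0": 4096, "w1": 1024, "w2": 1024}, 2)
    # places w0 on worker 0 and w1, w2 on worker 1.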
param2device = {} + device2params = {x: [] for x in range(worker_num)} + + sizes = [0] * worker_num + for param in [x[0] for x in params_grads]: + numel = get_var_size(param) + device_idx = sizes.index(min(sizes)) + device2params[device_idx].append(param.name) + param2device[param.name] = device_idx + sizes[device_idx] += numel + + for x in range(worker_num): + print("device id: %s, num: %s, mem: %s, names: %s" % ( + x, len(device2params[x]), sizes[x], device2params[x])) + + return param2device, device2params def _var_device_id(self, var_name): if var_name in self.global_param2device: diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py index 39f71be0cde764..605e94e94d9d6f 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py @@ -754,7 +754,7 @@ def get_first_optimize_op_idx(block): return first_opt_op_idx -def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root): +def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root,use_calc_stream=False): """ _add_broadcast_ops """ @@ -767,6 +767,7 @@ def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root): attrs={ 'ring_id': ring_id, 'root': root_device, + 'use_calc_stream': use_calc_stream, OP_ROLE_KEY: op_role }) diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index fcecc3a9a671ec..aae9d223f5a2c7 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -71,6 +71,9 @@ def __init__(self, optimizer): self._reduced_grads_to_param = {} self._shard = Shard() self._verbose = False + self._thread_mode = False + self._use_calc_stream = False + # use sharding as outer parallelism (e.g. 
inner:Megatron & outer sharding) self.mp_degree = 1 @@ -576,10 +579,12 @@ def _apply_optimize_offload_pass(self, params_grads): def _dump_program_for_debug(self): main_block = self._main_program.global_block() startup_block = self._startup_program.global_block() - with open("start_sharding_%d" % self.role_maker._worker_index(), + startup_id = str(id(startup_block.program)) + with open(("start_sharding_%d_%s" % (self.role_maker._worker_index(), startup_id)), 'w') as f: f.writelines(str(startup_block.program)) - with open("main_sharding_%d" % self.role_maker._worker_index(), + main_id = str(id(main_block.program)) + with open(("main_sharding_%d_%s" % (self.role_maker._worker_index(), main_id)), 'w') as f: f.writelines(str(main_block.program)) @@ -819,7 +824,7 @@ def collect_segment(self, segment, op_idx, block): def _split_program(self, block): for op_idx, op in reversed(list(enumerate(block.ops))): - if int(op.attr('op_role')) != int(OpRole.Optimize): + if int(op.attr('op_role')) != int(OpRole.Optimize) and int(op.attr('op_role'))!= int(OpRole.ScaleLr): last_backward_op_idx = op_idx + 1 break @@ -829,6 +834,7 @@ def _split_program(self, block): for op_idx in reversed(range(last_backward_op_idx)): op = block.ops[op_idx] assert (int(op.attr('op_role')) != int(OpRole.Optimize)) + assert (int(op.attr('op_role')) != int(OpRole.ScaleLr)) if self._sharding_segment_strategy == "segment_broadcast_MB": if segment._param_mem >= self._broadcast_MB: segment = self.collect_segment(segment, op_idx, block) @@ -874,7 +880,8 @@ def _split_program(self, block): else: broadcast_var_name = unique_name.generate(input_name + "@BroadCast") - segment._fill_constant_vars.append(broadcast_var_name) + if not self._thread_mode: + segment._fill_constant_vars.append(broadcast_var_name) # (JZ-LIANG) should use Param base name ? 
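Earlier in this hunk, _split_program now treats the new ScaleLr role (kScaleLr = 0x0012 in op_proto_maker.h, exported to Python as OpRole.ScaleLr) the same way it treats Optimize ops when it scans backwards for the last backward op. A hedged sketch of that predicate, where is_post_backward_op is a hypothetical helper name and not part of this patch:

    # Ops carrying either role sit after the backward section, so the scan
    # from the end of the block keeps skipping until neither role matches.
    def is_post_backward_op(op, OpRole):
        role = int(op.attr('op_role'))
        return role == int(OpRole.Optimize) or role == int(OpRole.ScaleLr)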
broadcast_var_base_name = input_name @@ -1094,24 +1101,26 @@ def _add_broadcast_allreduce(self, block): if self.gradient_merge_mode != "sharding_gm" or self._gradient_merge_acc_step <= 1: if self.hybrid_dp and self.hybrid_dp_mode == "sharding_hybrid_dp" and len( shard_allredue_vars) >= 1: - insert_sync_comm_ops(block, self._segments[-1]._end_idx, - self.dp_ring_id, shard_allredue_vars) + if not self._use_calc_stream: + insert_sync_comm_ops(block, self._segments[-1]._end_idx, + self.dp_ring_id, shard_allredue_vars) insert_allreduce_ops( block, self._segments[-1]._end_idx, self.dp_ring_id, shard_allredue_vars, - user_defined_strategy=self.user_defined_strategy) + user_defined_strategy=self.user_defined_strategy, + use_calc_stream=self._use_calc_stream) # gradient merge elif self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1: self.create_persistable_gradients_and_insert_merge_ops( block, self._startup_program.global_block(), self._segments[-1]._end_idx, shard_allredue_vars, self._shard) - - insert_sync_comm_ops(block, self._segments[-1]._end_idx, - self.sharding_ring_id, - self._segments[-1]._allreduce_vars) + if not self._use_calc_stream: + insert_sync_comm_ops(block, self._segments[-1]._end_idx, + self.sharding_ring_id, + self._segments[-1]._allreduce_vars) # allreduce --> reduce insert_reduce_ops(block, self._segments[-1]._end_idx, @@ -1119,7 +1128,8 @@ def _add_broadcast_allreduce(self, block): self._segments[-1]._allreduce_vars, self._shard, op_role=OpRole.Backward, - use_calc_stream=False) + use_calc_stream=self._use_calc_stream, + ) for idx, segment in reversed(list(enumerate(self._segments))): allreduce_vars = self._segments[ @@ -1162,11 +1172,12 @@ def _add_broadcast_allreduce(self, block): if self.gradient_merge_mode != "sharding_gm" or self._gradient_merge_acc_step <= 1: if self.hybrid_dp and self.hybrid_dp_mode == "sharding_hybrid_dp" and len( shard_allredue_vars) >= 1: - insert_sync_comm_ops(block, segment._end_idx, - self.dp_ring_id, shard_allredue_vars) + if not self._use_calc_stream: + insert_sync_comm_ops(block, segment._end_idx, + self.dp_ring_id, shard_allredue_vars) broad_cast_vars = [x[0] for x in broadcast_vars] - if len(broad_cast_vars) > 0: + if not self._use_calc_stream and len(broad_cast_vars) > 0: insert_sync_comm_ops(block, segment._end_idx, self.sharding_ring_id, broad_cast_vars) @@ -1174,14 +1185,14 @@ def _add_broadcast_allreduce(self, block): comm_dep_vars = allreduce_vars + [ x[0] for x in broadcast_vars ] - if len(comm_dep_vars) > 0: + if not self._use_calc_stream and len(comm_dep_vars) > 0: insert_sync_comm_ops(block, segment._end_idx, self.sharding_ring_id, comm_dep_vars) # gradient merge elif self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1: broad_cast_vars = [x[0] for x in broadcast_vars] - if len(broad_cast_vars) > 0: + if not self._use_calc_stream and len(broad_cast_vars) > 0: insert_sync_comm_ops(block, segment._end_idx, self.sharding_ring_id, broad_cast_vars) @@ -1189,7 +1200,7 @@ def _add_broadcast_allreduce(self, block): k for k, v in cast_ops.items() ] + self._segments[idx]._allreduce_vars - if len(calc_dep_vars) > 0: + if not self._use_calc_stream and len(calc_dep_vars) > 0: insert_sync_calc_op(block, segment._end_idx, [calc_dep_vars[-1]]) @@ -1208,7 +1219,7 @@ def _add_broadcast_allreduce(self, block): segment._start_idx, shard_allredue_vars, self._shard) insert_broadcast_ops(block, segment._start_idx, - self.sharding_ring_id, broadcast_vars) + self.sharding_ring_id, broadcast_vars, 
self._use_calc_stream) # step6: add all_reduce ops # dp @@ -1220,13 +1231,17 @@ def _add_broadcast_allreduce(self, block): segment._start_idx, self.dp_ring_id, shard_allredue_vars, - user_defined_strategy=self.user_defined_strategy) - insert_sync_comm_ops(block, segment._start_idx, - self.sharding_ring_id, allreduce_vars) + user_defined_strategy=self.user_defined_strategy, + use_calc_stream=self._use_calc_stream, + ) + if not self._use_calc_stream: + insert_sync_comm_ops(block, segment._start_idx, + self.sharding_ring_id, allreduce_vars) # gradient merge elif self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1: - insert_sync_comm_ops(block, segment._start_idx, - self.sharding_ring_id, allreduce_vars) + if not self._use_calc_stream: + insert_sync_comm_ops(block, segment._start_idx, + self.sharding_ring_id, allreduce_vars) # sharding # allreduce --> reduce # TODO temp change @@ -1237,17 +1252,19 @@ def _add_broadcast_allreduce(self, block): allreduce_vars, self._shard, op_role=OpRole.Backward, - use_calc_stream=False) + use_calc_stream=self._use_calc_stream) block._sync_with_cpp() if self._segments[0]._broadcast_vars: broadcast_vars = [x[0] for x in self._segments[0]._broadcast_vars] - insert_sync_comm_ops(block, self._segments[0]._start_idx, - self.sharding_ring_id, broadcast_vars) + if not self._use_calc_stream: + insert_sync_comm_ops(block, self._segments[0]._start_idx, + self.sharding_ring_id, broadcast_vars) insert_broadcast_ops(block, self._segments[0]._start_idx, self.sharding_ring_id, - self._segments[0]._broadcast_vars) + self._segments[0]._broadcast_vars, + self._use_calc_stream) fill_constant_vars = [] for x in self._segments[:2]: @@ -1260,7 +1277,7 @@ def _add_broadcast_allreduce(self, block): cast_ops[k] = v calc_deps_vars = fill_constant_vars + [k for k, v in cast_ops.items()] - if fill_constant_vars or cast_ops: + if not self._use_calc_stream and fill_constant_vars or cast_ops: insert_sync_calc_op(block, self._segments[0]._start_idx, [calc_deps_vars[-1]]) @@ -1308,7 +1325,10 @@ def _build_groups(self): self.global_word_size = self.role_maker._worker_num() self.global_rank = self.role_maker._worker_index() self.global_endpoints = self.role_maker._get_trainer_endpoints() - self.current_endpoint = self.global_endpoints[self.global_rank] + if self._thread_mode: + self.current_endpoint = self.global_endpoints[self.role_maker._role_id()] + else: + self.current_endpoint = self.global_endpoints[self.global_rank] self._collective_helper = CollectiveHelper(self.role_maker, nrings=self._nrings_sharding) assert self.global_word_size % self.mp_degree == 0, \ @@ -1844,3 +1864,205 @@ def _sharding_gradient_merge(self): 'sub_block': cond_block, 'is_scalar_condition': True, }) +class ThreadShardingOptimizer(ShardingOptimizer): + """Sharding Optimizer.""" + def __init__(self, optimizer): + super().__init__(optimizer) + self.inner_opt = optimizer + self.meta_optimizers_white_list = [ + "ParameterServerOptimizer", + "RecomputeOptimizer", + "AMPOptimizer", + "LarsOptimizer", + "LambOptimizer", + "ASPOptimizer", + # "ModelParallelOptimizer", + # "PipelineOptimizer", + ] + self._thread_mode = True + self._use_calc_stream = False + op_maker = core.op_proto_and_checker_maker + self.op_role_key = op_maker.kOpRoleAttrName() + + def _prune_main_program(self, block, shard, rings): + """ + rename BroadCast param + """ + var_names = set([]) + for idx, op in enumerate(block.ops): + for input_name in op.desc.input_arg_names(): + pos = input_name.find("@BroadCast") + if pos <= 0: + 
continue + new_name = input_name[0 : pos] + op.desc._rename_input( + input_name, new_name + ) + var_names.add(input_name) + for output_name in op.desc.output_arg_names(): + pos = output_name.find("@BroadCast") + if pos <= 0: + continue + new_name = output_name[0 : pos] + op.desc._rename_output( + output_name, new_name + ) + var_names.add(output_name) + + for var_name in var_names: + block._remove_var(var_name, sync=False) + + print("remove broadcast param count=", len(var_names)) + block._sync_with_cpp() + + def _prune_startup_program(self, block, shard): + """ + not need process + """ + block._sync_with_cpp() + + def _insert_loss_grad_scale_op(self): + """ + paddlebox grad not need scale + """ + main_block = self._main_program.global_block() + # # step6: loss div dp_degree + # global_dp_degree = self.sharding_degree * self.dp_degree + # assert int(global_dp_degree) == global_dp_degree + # if global_dp_degree > 1: + # insert_scale_loss_grad_ops(main_block, scale=global_dp_degree) + main_block._sync_with_cpp() + + def minimize_impl( + self, loss, startup_program=None, parameter_list=None, no_grad_set=None + ): + """ + reset start program and main program + """ + sharding_configs = self.user_defined_strategy.sharding_configs + if "use_calc_stream" in sharding_configs: + self._use_calc_stream = sharding_configs["use_calc_stream"] + optimize_ops, params_grads = super().minimize_impl( + loss, startup_program, parameter_list, no_grad_set) + # main_block = self._main_program.global_block() + # startup_block = self._startup_program.global_block() + loss.block.program = self._main_program + from paddle import fluid + fluid.framework.switch_startup_program(self._startup_program) + return optimize_ops, params_grads + + def _init_comm(self): + # sync var + self.role_id = self.role_maker._role_id() + self.node_nums = self.role_maker._node_num() + startup_block = self._startup_program.global_block() + if self.node_nums > 1: + node_nums = len(self.global_endpoints) + assert ( + self.node_nums == node_nums + ), "end points not equal node nums" + self.current_endpoint = self.global_endpoints[self.role_id] + + # mp ring + if self.mp_degree > 1: + self._init_communicator( + self._startup_program, + self.current_endpoint, + self.mp_group_endpoints, + self.role_id, + self.mp_ring_id, + ) + + # sharding ring + if self.sharding_degree > 1: + self._init_communicator( + self._startup_program, + self.current_endpoint, + self.sharding_group_endpoints, + self.role_id, + self.sharding_ring_id, + ) + + # pure dp ring + if self.dp_degree > 1: + self._init_communicator( + self._startup_program, + self.current_endpoint, + self.dp_group_endpoints, + self.role_id, + self.dp_ring_id, + ) + + startup_block._sync_with_cpp() + + def _wait(self): + if self.node_nums <= 1: + return + endpoints = self.global_endpoints[:] + current_endpoint = endpoints[self.role_id] + if self.global_rank == 0: + from paddle.fluid.transpiler.details import wait_server_ready + endpoints.remove(current_endpoint) + wait_server_ready(endpoints) + + def _init_communicator( + self, + program, + current_endpoint, + endpoints, + role_id, + ring_id + ): + block = program.global_block() + # init mulit node nccl + if self.node_nums > 1: + other_endpoints = endpoints[:] + other_endpoints.remove(current_endpoint) + + comm_id_var = block.create_var( + name=unique_name.generate('comm_id'), + persistable=True, + type=core.VarDesc.VarType.RAW, + ) + if core.is_compiled_with_cuda(): + block.append_op( + type='c_gen_nccl_id', + inputs={}, + outputs={'Out': 
comm_id_var}, + attrs={ + 'rank': role_id, + 'endpoint': current_endpoint, + 'other_endpoints': other_endpoints, + self.op_role_key: OpRole.Forward, + }, + ) + + elif core.is_compiled_with_xpu(): + block.append_op( + type='c_gen_bkcl_id', + inputs={}, + outputs={'Out': comm_id_var}, + attrs={ + 'rank': role_id, + 'endpoint': current_endpoint, + 'other_endpoints': other_endpoints, + self.op_role_key: OpRole.Forward, + }, + ) + + block.append_op( + type='c_comm_init_multitrainer', + inputs={'X': comm_id_var}, + outputs={}, + attrs={ + 'ntrainers': self.node_nums, + 'trainer_id': role_id, + 'ring_id': ring_id, + self.op_role_key: OpRole.Forward, + }, + ) + else: + block.append_op( + type='c_comm_init_all', + attrs={'ring_id': ring_id} + ) diff --git a/python/paddle/fluid/tests/unittests/xpu/test_batch_fc_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_batch_fc_op_xpu.py new file mode 100644 index 00000000000000..9c556b342daa5b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_batch_fc_op_xpu.py @@ -0,0 +1,111 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import random +from op_test import OpTest +from op_test_xpu import XPUOpTest +import paddle +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard +from op_test import OpTest, skip_check_grad_ci +import paddle.fluid.core as core + +paddle.enable_static() + +def np_cal_batchfc(input, w, bias): + slot_pairs_num, batch_size, in_dim = input.shape + _, _, out_dim = w.shape + res = np.zeros((slot_pairs_num, batch_size, out_dim)) + for slot in range(slot_pairs_num): + res[slot, :] = np.dot(input[slot, :], w[slot, :]) + # for slot in range(slot_pairs_num): + # for bindx in range(out_dim): + # res[slot, :, bindx] += bias[slot, bindx] + return res + + +class TestBatchFCOp(XPUOpTest): + + def config(self): + self.slot_pairs_num = 10 + self.batch_size = 5 + self.in_dim = 10 + self.out_dim = 12 + self.dtype = "float32" + + def setUp(self): + self.config() + self.input = np.random.random((self.slot_pairs_num, self.batch_size, + self.in_dim)).astype(self.dtype) + self.w = np.random.random( + (self.slot_pairs_num, self.in_dim, self.out_dim)).astype(self.dtype) + self.bias = np.random.random( + (self.slot_pairs_num, self.out_dim)).astype(self.dtype) + self.op_type = "batch_fc" + np_out = np_cal_batchfc(self.input, self.w, self.bias) + np_out = np_out.astype(self.dtype) + self.inputs = {"Input": self.input, "W": self.w, "Bias": self.bias} + self.outputs = {"Out": np_out} + + def test_check_output_xpu(self): + if core.is_compiled_with_xpu(): + self.check_output_with_place(paddle.XPUPlace(0)) + + def test_check_grad_xpu(self): + if core.is_compiled_with_xpu(): + self.check_grad_with_place(paddle.XPUPlace(0), + ["Bias", "W", "Input"], "Out") + + +class TestBatchFCOp1(XPUOpTest): + + def config(self): + self.slot_pairs_num = 10 + self.batch_size = 5 + self.in_dim = 10 + self.out_dim = 12 + self.dtype = "float32" 
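For reference, the forward computation that batch_fc implements (and that the XPU kernels above reproduce with a batched GEMM plus add_bias) is, per slot s: Out[s] = Input[s] @ W[s] + Bias[s]. Below is a minimal numpy sketch under those shape assumptions; note that the np_cal_batchfc helper above keeps its bias accumulation commented out, so it exercises only the GEMM part:

    import numpy as np

    def batch_fc_ref(x, w, b):
        # x: [slot_pairs_num, ins_num, in_dim]
        # w: [slot_pairs_num, in_dim, out_dim]
        # b: [slot_pairs_num, out_dim]
        out = np.matmul(x, w)          # per-slot GEMM
        out += b[:, np.newaxis, :]     # broadcast bias over ins_num
        return out                     # [slot_pairs_num, ins_num, out_dim]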
+ + def setUp(self): + self.config() + self.input = np.random.random((self.slot_pairs_num, self.batch_size, + self.in_dim)).astype(self.dtype) + self.w = np.random.random( + (self.slot_pairs_num, self.in_dim, self.out_dim)).astype(self.dtype) + self.bias = np.random.random( + (self.slot_pairs_num, self.out_dim)).astype(self.dtype) + self.op_type = "batch_fc" + np_out = np_cal_batchfc(self.input, self.w, self.bias) + np_out = np_out.astype(self.dtype) + self.inputs = {"Input": self.input, "W": self.w, "Bias": self.bias} + self.outputs = {"Out": np_out} + + def test_check_output_cpu(self): + try: + self.check_output_with_place(place=core.CPUPlace()) + except: + print("do not support cpu test, skip") + + def test_check_grad_cpu(self): + try: + self.check_grad_with_place(core.CPUPlace(), ["Bias", "W", "Input"], + "Out") + except: + print("do not support cpu test, skip") + + +if __name__ == "__main__": + unittest.main() From 6ce59ed38972a6159f6850708ffd3a8356e0a5b4 Mon Sep 17 00:00:00 2001 From: xiayanming Date: Fri, 22 Mar 2024 17:37:12 +0800 Subject: [PATCH 2/2] for support ascvrq --- paddle/fluid/framework/op_proto_maker.h | 3 + paddle/fluid/operators/batch_fc_op.kps | 361 ++++++++++++++++++ .../collective/c_broadcast_op_xpu.cc | 111 ++++++ .../fluid/platform/device/xpu/xpu2_op_list.h | 6 + paddle/fluid/pybind/const_value.cc | 3 +- .../fleet/meta_optimizers/__init__.py | 1 + .../fleet/meta_optimizers/sharding/shard.py | 20 +- .../fleet/meta_optimizers/sharding/utils.py | 3 +- .../meta_optimizers/sharding_optimizer.py | 282 ++++++++++++-- .../unittests/xpu/test_batch_fc_op_xpu.py | 111 ++++++ 10 files changed, 868 insertions(+), 33 deletions(-) create mode 100644 paddle/fluid/operators/batch_fc_op.kps create mode 100644 paddle/fluid/operators/collective/c_broadcast_op_xpu.cc create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_batch_fc_op_xpu.py diff --git a/paddle/fluid/framework/op_proto_maker.h b/paddle/fluid/framework/op_proto_maker.h index 51aeed2e5d734e..07ab71aa38dc9a 100644 --- a/paddle/fluid/framework/op_proto_maker.h +++ b/paddle/fluid/framework/op_proto_maker.h @@ -35,6 +35,9 @@ enum class OpRole { // Tag all learning rate scheduler operators. kLRSched = 0x0010, + // scale lr(for adam) + kScaleLr = 0x0012, + kLoss = 0x0100, // The default value of op's role. This should be only used for unittests and // CreateOp inside a operator. diff --git a/paddle/fluid/operators/batch_fc_op.kps b/paddle/fluid/operators/batch_fc_op.kps new file mode 100644 index 00000000000000..a79b556b754dfb --- /dev/null +++ b/paddle/fluid/operators/batch_fc_op.kps @@ -0,0 +1,361 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ +#ifdef PADDLE_WITH_XPU_KP + +#include // NOLINT +#include +#include +#include +#include + +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/platform/device/xpu/enforce_xpu.h" +#include "paddle/fluid/platform/device_context.h" + +#include "xpu/kernel/xtdk.h" // NOLINT +#include "xpu/kernel/xtdk_math.h" // NOLINT +#include "xpu/kernel/xtdk_simd.h" + +#include "xpu/kernel/xtdk_io.h" + +#include "paddle/fluid/operators/batch_fc_op.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +#include "paddle/fluid/operators/xpu_api_wrapper.h" + +namespace paddle { +namespace operators { +using framework::Tensor; + +static __device__ void primitive_add(const float* x, const float* y, float* z, int len) { + float32x16_t vx0; + float32x16_t vy0; + float32x16_t vx1; + float32x16_t vy1; + int len_rounddown32 = rounddown32(len); + int remain = len - len_rounddown32; + for (int i = 0; i < len_rounddown32; i += 32) { + vx0 = vload_lm_float32x16(x + i); + vx1 = vload_lm_float32x16(x + i + 16); + vy0 = vload_lm_float32x16(y + i); + vy1 = vload_lm_float32x16(y + i + 16); + vy0 = vvadd_float32x16(vx0, vy0); + vy1 = vvadd_float32x16(vx1, vy1); + vstore_lm_float32x16(z + i, vy0); + vstore_lm_float32x16(z + i + 16, vy1); + } + for (int i = 0; i < remain; i++) { + *(z + len_rounddown32 + i) = *(y + len_rounddown32 + i) + *(x + len_rounddown32 + i); + } + mfence_lm(); +} + +static __device__ inline void memset_lm_float(float* dst_ptr, int size) { + for (int i = 0; i < size; i += 16) { + vstore_lm_float32x16_mz(dst_ptr + i, 0, 0); + } + mfence_lm(); +} + +template +__global__ void add_bias_kernel( + T* data, int slot_pairs_num, int ins_num, int out_dim, const T* bias) { + int cid = core_id(); + int ncores = core_num(); + if (cid >= ncores) { + return; + } + int thread_id = cluster_id() * ncores + cid; + int total_thread = cluster_num() * ncores; + + const int buf_size = 512; + int max_seq_len = buf_size / out_dim; + + __simd__ T local_data_buf[buf_size]; + __simd__ T local_bias_buf[256]; + + __simd__ T out_buf[buf_size]; + memset_lm_float(out_buf, buf_size); + + for (int slot = thread_id; slot < slot_pairs_num; slot += total_thread) { + mfence(); + GM2LM(bias + slot * out_dim, local_bias_buf, out_dim * sizeof(T)); + + for (int i = 0; i < ins_num; i += max_seq_len) { + int len = min(ins_num - i, max_seq_len); + + GM2LM(data + slot * ins_num * out_dim, local_data_buf, len * out_dim * sizeof(T)); + for (int j = 0; j < len; j++) { + primitive_add(local_data_buf + j * out_dim, local_bias_buf, out_buf + j * out_dim, out_dim); + } + // mfence(); + LM2GM_ASYNC(out_buf, data + slot * ins_num * out_dim, len * out_dim * sizeof(T)); + } + } +} + +template +void add_bias(xpu::Context* xpu_ctx, + T* data, + int slot_pairs_num, + int ins_num, + int out_dim, + const T* bias) { + auto stream = xpu_ctx->xpu_stream; + add_bias_kernel<<<8, 64, stream>>>(data, slot_pairs_num, ins_num, out_dim, bias); +} + +template +__global__ void add_bias_grad_kernel(const T* dout_data, + int slot_pairs_num, + int ins_num, + int out_dim, + T* db_data) { + int cid = core_id(); + int ncores = core_num(); + if (cid >= ncores) { + return; + } + int thread_id = cluster_id() * ncores + cid; + int total_thread = cluster_num() * ncores; + + int buf_size = out_dim + 16; + __simd__ T local_bias_buf[buf_size]; + __simd__ T tmp_sum_buf[buf_size]; + + __local__ T local_data_buf[1]; + + // memset_lm_float(local_bias_buf, buf_size); + memset_lm_float(tmp_sum_buf, buf_size); + + __local__ T tmp_sum = static_cast(0); + for (int slot = thread_id; 
slot < slot_pairs_num; slot += total_thread) { + mfence(); + GM2LM(db_data + slot * out_dim, local_bias_buf, out_dim * sizeof(T)); + + for (int index = 0; index < out_dim; index++) { + for (int i = 0; i < ins_num; ++i) { + int select_indx = ((slot + 1) * i + 1) * index; + GM2LM_ASYNC(dout_data + select_indx, local_data_buf, sizeof(T)); + mfence(); + tmp_sum_buf[index] += local_data_buf[0]; + } + } + + // mfence(); + primitive_add(tmp_sum_buf, local_bias_buf, local_bias_buf, out_dim); + + // mfence(); + LM2GM_ASYNC(local_bias_buf, db_data + slot * out_dim, out_dim * sizeof(T)); + mfence(); + } +} + +template +void add_bias_grad(xpu::Context* xpu_ctx, + const T* dout_data, + int slot_pairs_num, + int ins_num, + int out_dim, + T* db_data) { + auto stream = xpu_ctx->xpu_stream; + add_bias_grad_kernel<<<8, 64, stream>>>( + dout_data, slot_pairs_num, ins_num, out_dim, db_data); +} + + + +template +class BatchFCXPUKernel : public framework::OpKernel { + using XPUType = typename XPUTypeTrait::Type; + + public: + void Compute(const framework::ExecutionContext& ctx) const override { + int batchcount = ctx.Attr("batchcount"); + auto transpose_weight = ctx.Attr("transpose_weight"); + if (transpose_weight) { + // TODO + PADDLE_ENFORCE_EQ( + transpose_weight, + true, + platform::errors::Unimplemented("BatchFC not support transpose_weight now.")); + return; + } + if (batchcount > 0) { + // TODO + PADDLE_ENFORCE_EQ( + (batchcount > 0), + true, + platform::errors::Unimplemented("BatchFC not support transpose_weight now.")); + } else { + // X.dim = slot_pairs_num * ins_num * in_dim + // W.dim = slot_pairs_num * in_dim * out_dim + // b.dim = slot_pairs_num * out_dim + // output.dim = slot_pairs_num * ins_num * out_dim + auto* input = ctx.Input("Input"); + auto* w = ctx.Input("W"); + auto* bias = ctx.Input("Bias"); + auto* output = ctx.Output("Out"); + auto input_dims = input->dims(); + auto w_dims = w->dims(); + auto slot_pairs_num = input_dims[0]; + auto ins_num = input_dims[1]; + // auto in_dim = input_dims[2]; + auto out_dim = w_dims[2]; + + // get data ptr + const XPUType* x_ptr = reinterpret_cast(input->data()); + const XPUType* y_ptr = reinterpret_cast(w->data()); + const XPUType* bias_data = reinterpret_cast(bias->data()); + + output->Resize({slot_pairs_num, ins_num, out_dim}); + XPUType* out_ptr = reinterpret_cast(output->mutable_data(ctx.GetPlace())); + + // initialize + auto& dev_ctx = + ctx.template device_context(); + xpu::Context* xpu_ctx = dev_ctx.x_context(); + + bool trans_x = false; + bool trans_y = false; + + T alpha = 1; + T beta = 0; + + XpuFcInfo fc_info; + GetFCInfo(input_dims, w_dims, trans_x, trans_y, &fc_info); + MatMulXPUFunction(xpu_ctx, x_ptr, y_ptr, out_ptr, fc_info, alpha); + + // add bias + add_bias(xpu_ctx, + out_ptr, + slot_pairs_num, + ins_num, + out_dim, + bias_data); + } + } +}; + +template +class BatchFCGradOpXPUKernel : public framework::OpKernel { + using XPUType = typename XPUTypeTrait::Type; + + public: + void Compute(const framework::ExecutionContext& ctx) const override { + int batchcount = ctx.Attr("batchcount"); + if (batchcount > 0) { + // TODO + PADDLE_ENFORCE_EQ( + (batchcount > 0), + true, + platform::errors::Unimplemented("BatchFC not support transpose_weight now.")); + } else { + auto* input = ctx.Input("Input"); + auto* w = ctx.Input("W"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + + auto* dx = ctx.Output(framework::GradVarName("Input")); + auto* dw = ctx.Output(framework::GradVarName("W")); + auto* db = 
ctx.Output(framework::GradVarName("Bias")); + + auto input_dims = input->dims(); + auto w_dims = w->dims(); + auto slot_pairs_num = input_dims[0]; + auto ins_num = input_dims[1]; + // auto in_dim = input_dims[2]; + auto out_dim = w_dims[2]; + + const XPUType* dout_ptr = reinterpret_cast(dout->data()); + const XPUType* x_ptr = reinterpret_cast(dx->mutable_data(ctx.GetPlace())); + const XPUType* y_ptr = reinterpret_cast(dw->mutable_data(ctx.GetPlace())); + XPUType* b_ptr = reinterpret_cast(db->mutable_data(ctx.GetPlace())); + + auto& dev_ctx = + ctx.template device_context(); + xpu::Context* xpu_ctx = dev_ctx.x_context(); + xpu::ctx_guard RAII_GUARD(xpu_ctx); + + bool transpose_x = false; + bool transpose_y = false; + XpuFcInfo info_forward; + GetFCInfo(input_dims, w_dims, transpose_x, transpose_y, &info_forward); + + const XPUType* a_1 = reinterpret_cast(NULL); + const XPUType* b_1 = reinterpret_cast(NULL); + const XPUType* a_2 = reinterpret_cast(NULL); + const XPUType* b_2 = reinterpret_cast(NULL); + XPUType* c_1 = (dx == NULL) ? reinterpret_cast(NULL) + : reinterpret_cast(dx->data()); + XPUType* c_2 = (dw == NULL) ? reinterpret_cast(NULL) + : reinterpret_cast(dw->data()); + + // add bias grad + db->mutable_data(ctx.GetPlace()); + add_bias_grad(xpu_ctx, + dout_ptr, + slot_pairs_num, + ins_num, + out_dim, + b_ptr); + + T alpha = 1; + + XpuFcInfo info_dx; + XpuFcInfo info_dy; + std::tuple + fc_info = MatmulGradFcInfo(xpu_ctx, + &RAII_GUARD, + info_forward, + transpose_x, + transpose_y, + x_ptr, + y_ptr, + dout_ptr); + std::tie(info_dx, info_dy, a_1, b_1, a_2, b_2) = fc_info; + + // dx = dout_data * y^T + MatMulXPUFunction(xpu_ctx, a_1, b_1, c_1, info_dx, alpha); + + // dy = x^T * dout_data + MatMulXPUFunction(xpu_ctx, a_2, b_2, c_2, info_dy, alpha); + } + } +}; + +} // namespace operators +} // namespace paddle + + + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_KERNEL(batch_fc, KP, plat::XPUPlace, + ops::BatchFCXPUKernel); +REGISTER_OP_KERNEL(batch_fc_grad, KP, plat::XPUPlace, + ops::BatchFCGradOpXPUKernel); + +// namespace ops = paddle::operators; + +REGISTER_OP_XPU_KERNEL(batch_fc, + ops::BatchFCXPUKernel); +REGISTER_OP_XPU_KERNEL(batch_fc_grad, + ops::BatchFCGradOpXPUKernel); +#endif \ No newline at end of file diff --git a/paddle/fluid/operators/collective/c_broadcast_op_xpu.cc b/paddle/fluid/operators/collective/c_broadcast_op_xpu.cc new file mode 100644 index 00000000000000..ae2e3657ad6d69 --- /dev/null +++ b/paddle/fluid/operators/collective/c_broadcast_op_xpu.cc @@ -0,0 +1,111 @@ +/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/collective/c_broadcast_op.h" + +#include +#include +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" + +#if defined(PADDLE_WITH_XPU_BKCL) +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/device/xpu/bkcl_helper.h" +#endif + +namespace paddle { +namespace operators { + +template +class CBroadcastOpXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { +#if defined(PADDLE_WITH_XPU_BKCL) + auto x = ctx.Input("X"); + auto out = ctx.Output("Out"); + size_t numel = x->numel(); + + BKCLDataType dtype = + platform::ToBKCLDataType(framework::TransToProtoVarType(x->dtype())); + int ring_id = ctx.Attr("ring_id"); + auto place = ctx.GetPlace(); + int root = ctx.Attr("root"); + + auto comm = paddle::platform::BKCLCommContext::Instance().Get(ring_id, place); + auto stream = comm->stream(); + VLOG(3) << "BKCLCommContext ring_id " << ring_id; + + if (ctx.Attr("use_calc_stream")) { + auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); + stream = static_cast(dev_ctx) + ->x_context() + ->xpu_stream; + } + + void* send_recv_buffer = nullptr; + if (root == comm->rank()) { + send_recv_buffer = + reinterpret_cast(const_cast(x->data())); + PADDLE_ENFORCE_XPU_SUCCESS(bkcl_broadcast(comm->comm(), + send_recv_buffer, + send_recv_buffer, + numel, + dtype, + root, + stream)); + VLOG(3) << "rank " << comm->rank() << " invoke Bcast. sent " + << x->numel(); + if (out != x) { + framework::TensorCopy( + *static_cast(x), + place, + *platform::DeviceContextPool::Instance().Get(place), + static_cast(out)); + } + } else { + auto& dev_ctx = + ctx.template device_context(); + dev_ctx.template Alloc(out); + send_recv_buffer = out->data(); + PADDLE_ENFORCE_XPU_SUCCESS(bkcl_broadcast(comm->comm(), + send_recv_buffer, + send_recv_buffer, + numel, + dtype, + root, + stream)); + VLOG(3) << "rank " << comm->rank() << " invoke Bcast. 
received " + << phi::product(out->dims()); + } + + out->Resize(x->dims()); + out->set_lod(x->lod()); +#else + PADDLE_THROW(platform::errors::PreconditionNotMet( + "PaddlePaddle should compile with XPU.")); +#endif + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_XPU_KERNEL(c_broadcast, + ops::CBroadcastOpXPUKernel); diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h index 22c2bed9ce6c58..6c3e43755352f2 100644 --- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h @@ -611,6 +611,12 @@ XPUOpMap& get_kl2_ops() { XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"fused_concat_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"c_broadcast", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"c_reduce_sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"batch_fc", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"batch_fc_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, }; return s_xpu2_kernels; } diff --git a/paddle/fluid/pybind/const_value.cc b/paddle/fluid/pybind/const_value.cc index 89a3904d0003fe..5503c47197a26d 100644 --- a/paddle/fluid/pybind/const_value.cc +++ b/paddle/fluid/pybind/const_value.cc @@ -49,7 +49,8 @@ void BindConstValue(pybind11::module* m) { .value("Loss", framework::OpRole::kLoss) .value("RPC", framework::OpRole::kRPC) .value("Dist", framework::OpRole::kDist) - .value("LRSched", framework::OpRole::kLRSched); + .value("LRSched", framework::OpRole::kLRSched) + .value("ScaleLr", framework::OpRole::kScaleLr); op_proto_and_checker_maker.def( "kOpRoleAttrName", framework::OpProtoAndCheckerMaker::OpRoleAttrName); diff --git a/python/paddle/distributed/fleet/meta_optimizers/__init__.py b/python/paddle/distributed/fleet/meta_optimizers/__init__.py index 1eae4be579aa78..9c35964587cd97 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/__init__.py +++ b/python/paddle/distributed/fleet/meta_optimizers/__init__.py @@ -27,6 +27,7 @@ from .lamb_optimizer import LambOptimizer from .fp16_allreduce_optimizer import FP16AllReduceOptimizer from .sharding_optimizer import ShardingOptimizer +from .sharding_optimizer import ThreadShardingOptimizer from .dygraph_optimizer import HybridParallelOptimizer from .dygraph_optimizer import HeterParallelOptimizer from .dygraph_optimizer import HybridParallelGradScaler diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/shard.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/shard.py index 7002dfa2be5148..33c4d01e4daea8 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/shard.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/shard.py @@ -51,6 +51,7 @@ def has_var(self, var_name): self._var_device_id(var_name) == self.worker_idx def _split_params(self, params_grads, worker_idx, worker_num): + """ param2device = {} total_param_mem = 0.0 param2mem = [] @@ -62,12 +63,29 @@ def _split_params(self, params_grads, worker_idx, worker_num): device_idx = 0 mem_accu = 0.0 for param_name, mem in param2mem: - if mem_accu > total_param_mem * 1.0 * (device_idx + 1) / worker_num: + if mem_accu > total_param_mem * (device_idx + 1) / worker_num: device_idx += 1 device2params[device_idx].append(param_name) param2device[param_name] = device_idx mem_accu += mem return param2device, device2params + """ + 
param2device = {} + device2params = {x: [] for x in range(worker_num)} + + sizes = [0] * worker_num + for param in [x[0] for x in params_grads]: + numel = get_var_size(param) + device_idx = sizes.index(min(sizes)) + device2params[device_idx].append(param.name) + param2device[param.name] = device_idx + sizes[device_idx] += numel + + for x in range(worker_num): + print("device id: %s, num: %s, mem: %s, names: %s" % ( + x, len(device2params[x]), sizes[x], device2params[x])) + + return param2device, device2params def _var_device_id(self, var_name): if var_name in self.global_param2device: diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py index 39f71be0cde764..605e94e94d9d6f 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py @@ -754,7 +754,7 @@ def get_first_optimize_op_idx(block): return first_opt_op_idx -def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root): +def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root,use_calc_stream=False): """ _add_broadcast_ops """ @@ -767,6 +767,7 @@ def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root): attrs={ 'ring_id': ring_id, 'root': root_device, + 'use_calc_stream': use_calc_stream, OP_ROLE_KEY: op_role }) diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index fcecc3a9a671ec..aae9d223f5a2c7 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -71,6 +71,9 @@ def __init__(self, optimizer): self._reduced_grads_to_param = {} self._shard = Shard() self._verbose = False + self._thread_mode = False + self._use_calc_stream = False + # use sharding as outer parallelism (e.g. 
inner:Megatron & outer sharding) self.mp_degree = 1 @@ -576,10 +579,12 @@ def _apply_optimize_offload_pass(self, params_grads): def _dump_program_for_debug(self): main_block = self._main_program.global_block() startup_block = self._startup_program.global_block() - with open("start_sharding_%d" % self.role_maker._worker_index(), + startup_id = str(id(startup_block.program)) + with open(("start_sharding_%d_%s" % (self.role_maker._worker_index(), startup_id)), 'w') as f: f.writelines(str(startup_block.program)) - with open("main_sharding_%d" % self.role_maker._worker_index(), + main_id = str(id(main_block.program)) + with open(("main_sharding_%d_%s" % (self.role_maker._worker_index(), main_id)), 'w') as f: f.writelines(str(main_block.program)) @@ -819,7 +824,7 @@ def collect_segment(self, segment, op_idx, block): def _split_program(self, block): for op_idx, op in reversed(list(enumerate(block.ops))): - if int(op.attr('op_role')) != int(OpRole.Optimize): + if int(op.attr('op_role')) != int(OpRole.Optimize) and int(op.attr('op_role'))!= int(OpRole.ScaleLr): last_backward_op_idx = op_idx + 1 break @@ -829,6 +834,7 @@ def _split_program(self, block): for op_idx in reversed(range(last_backward_op_idx)): op = block.ops[op_idx] assert (int(op.attr('op_role')) != int(OpRole.Optimize)) + assert (int(op.attr('op_role')) != int(OpRole.ScaleLr)) if self._sharding_segment_strategy == "segment_broadcast_MB": if segment._param_mem >= self._broadcast_MB: segment = self.collect_segment(segment, op_idx, block) @@ -874,7 +880,8 @@ def _split_program(self, block): else: broadcast_var_name = unique_name.generate(input_name + "@BroadCast") - segment._fill_constant_vars.append(broadcast_var_name) + if not self._thread_mode: + segment._fill_constant_vars.append(broadcast_var_name) # (JZ-LIANG) should use Param base name ? 
broadcast_var_base_name = input_name @@ -1094,24 +1101,26 @@ def _add_broadcast_allreduce(self, block): if self.gradient_merge_mode != "sharding_gm" or self._gradient_merge_acc_step <= 1: if self.hybrid_dp and self.hybrid_dp_mode == "sharding_hybrid_dp" and len( shard_allredue_vars) >= 1: - insert_sync_comm_ops(block, self._segments[-1]._end_idx, - self.dp_ring_id, shard_allredue_vars) + if not self._use_calc_stream: + insert_sync_comm_ops(block, self._segments[-1]._end_idx, + self.dp_ring_id, shard_allredue_vars) insert_allreduce_ops( block, self._segments[-1]._end_idx, self.dp_ring_id, shard_allredue_vars, - user_defined_strategy=self.user_defined_strategy) + user_defined_strategy=self.user_defined_strategy, + use_calc_stream=self._use_calc_stream) # gradient merge elif self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1: self.create_persistable_gradients_and_insert_merge_ops( block, self._startup_program.global_block(), self._segments[-1]._end_idx, shard_allredue_vars, self._shard) - - insert_sync_comm_ops(block, self._segments[-1]._end_idx, - self.sharding_ring_id, - self._segments[-1]._allreduce_vars) + if not self._use_calc_stream: + insert_sync_comm_ops(block, self._segments[-1]._end_idx, + self.sharding_ring_id, + self._segments[-1]._allreduce_vars) # allreduce --> reduce insert_reduce_ops(block, self._segments[-1]._end_idx, @@ -1119,7 +1128,8 @@ def _add_broadcast_allreduce(self, block): self._segments[-1]._allreduce_vars, self._shard, op_role=OpRole.Backward, - use_calc_stream=False) + use_calc_stream=self._use_calc_stream, + ) for idx, segment in reversed(list(enumerate(self._segments))): allreduce_vars = self._segments[ @@ -1162,11 +1172,12 @@ def _add_broadcast_allreduce(self, block): if self.gradient_merge_mode != "sharding_gm" or self._gradient_merge_acc_step <= 1: if self.hybrid_dp and self.hybrid_dp_mode == "sharding_hybrid_dp" and len( shard_allredue_vars) >= 1: - insert_sync_comm_ops(block, segment._end_idx, - self.dp_ring_id, shard_allredue_vars) + if not self._use_calc_stream: + insert_sync_comm_ops(block, segment._end_idx, + self.dp_ring_id, shard_allredue_vars) broad_cast_vars = [x[0] for x in broadcast_vars] - if len(broad_cast_vars) > 0: + if not self._use_calc_stream and len(broad_cast_vars) > 0: insert_sync_comm_ops(block, segment._end_idx, self.sharding_ring_id, broad_cast_vars) @@ -1174,14 +1185,14 @@ def _add_broadcast_allreduce(self, block): comm_dep_vars = allreduce_vars + [ x[0] for x in broadcast_vars ] - if len(comm_dep_vars) > 0: + if not self._use_calc_stream and len(comm_dep_vars) > 0: insert_sync_comm_ops(block, segment._end_idx, self.sharding_ring_id, comm_dep_vars) # gradient merge elif self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1: broad_cast_vars = [x[0] for x in broadcast_vars] - if len(broad_cast_vars) > 0: + if not self._use_calc_stream and len(broad_cast_vars) > 0: insert_sync_comm_ops(block, segment._end_idx, self.sharding_ring_id, broad_cast_vars) @@ -1189,7 +1200,7 @@ def _add_broadcast_allreduce(self, block): k for k, v in cast_ops.items() ] + self._segments[idx]._allreduce_vars - if len(calc_dep_vars) > 0: + if not self._use_calc_stream and len(calc_dep_vars) > 0: insert_sync_calc_op(block, segment._end_idx, [calc_dep_vars[-1]]) @@ -1208,7 +1219,7 @@ def _add_broadcast_allreduce(self, block): segment._start_idx, shard_allredue_vars, self._shard) insert_broadcast_ops(block, segment._start_idx, - self.sharding_ring_id, broadcast_vars) + self.sharding_ring_id, broadcast_vars, 
self._use_calc_stream) # step6: add all_reduce ops # dp @@ -1220,13 +1231,17 @@ def _add_broadcast_allreduce(self, block): segment._start_idx, self.dp_ring_id, shard_allredue_vars, - user_defined_strategy=self.user_defined_strategy) - insert_sync_comm_ops(block, segment._start_idx, - self.sharding_ring_id, allreduce_vars) + user_defined_strategy=self.user_defined_strategy, + use_calc_stream=self._use_calc_stream, + ) + if not self._use_calc_stream: + insert_sync_comm_ops(block, segment._start_idx, + self.sharding_ring_id, allreduce_vars) # gradient merge elif self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1: - insert_sync_comm_ops(block, segment._start_idx, - self.sharding_ring_id, allreduce_vars) + if not self._use_calc_stream: + insert_sync_comm_ops(block, segment._start_idx, + self.sharding_ring_id, allreduce_vars) # sharding # allreduce --> reduce # TODO temp change @@ -1237,17 +1252,19 @@ def _add_broadcast_allreduce(self, block): allreduce_vars, self._shard, op_role=OpRole.Backward, - use_calc_stream=False) + use_calc_stream=self._use_calc_stream) block._sync_with_cpp() if self._segments[0]._broadcast_vars: broadcast_vars = [x[0] for x in self._segments[0]._broadcast_vars] - insert_sync_comm_ops(block, self._segments[0]._start_idx, - self.sharding_ring_id, broadcast_vars) + if not self._use_calc_stream: + insert_sync_comm_ops(block, self._segments[0]._start_idx, + self.sharding_ring_id, broadcast_vars) insert_broadcast_ops(block, self._segments[0]._start_idx, self.sharding_ring_id, - self._segments[0]._broadcast_vars) + self._segments[0]._broadcast_vars, + self._use_calc_stream) fill_constant_vars = [] for x in self._segments[:2]: @@ -1260,7 +1277,7 @@ def _add_broadcast_allreduce(self, block): cast_ops[k] = v calc_deps_vars = fill_constant_vars + [k for k, v in cast_ops.items()] - if fill_constant_vars or cast_ops: + if not self._use_calc_stream and fill_constant_vars or cast_ops: insert_sync_calc_op(block, self._segments[0]._start_idx, [calc_deps_vars[-1]]) @@ -1308,7 +1325,10 @@ def _build_groups(self): self.global_word_size = self.role_maker._worker_num() self.global_rank = self.role_maker._worker_index() self.global_endpoints = self.role_maker._get_trainer_endpoints() - self.current_endpoint = self.global_endpoints[self.global_rank] + if self._thread_mode: + self.current_endpoint = self.global_endpoints[self.role_maker._role_id()] + else: + self.current_endpoint = self.global_endpoints[self.global_rank] self._collective_helper = CollectiveHelper(self.role_maker, nrings=self._nrings_sharding) assert self.global_word_size % self.mp_degree == 0, \ @@ -1844,3 +1864,205 @@ def _sharding_gradient_merge(self): 'sub_block': cond_block, 'is_scalar_condition': True, }) +class ThreadShardingOptimizer(ShardingOptimizer): + """Sharding Optimizer.""" + def __init__(self, optimizer): + super().__init__(optimizer) + self.inner_opt = optimizer + self.meta_optimizers_white_list = [ + "ParameterServerOptimizer", + "RecomputeOptimizer", + "AMPOptimizer", + "LarsOptimizer", + "LambOptimizer", + "ASPOptimizer", + # "ModelParallelOptimizer", + # "PipelineOptimizer", + ] + self._thread_mode = True + self._use_calc_stream = False + op_maker = core.op_proto_and_checker_maker + self.op_role_key = op_maker.kOpRoleAttrName() + + def _prune_main_program(self, block, shard, rings): + """ + rename BroadCast param + """ + var_names = set([]) + for idx, op in enumerate(block.ops): + for input_name in op.desc.input_arg_names(): + pos = input_name.find("@BroadCast") + if pos <= 0: + 
continue + new_name = input_name[0 : pos] + op.desc._rename_input( + input_name, new_name + ) + var_names.add(input_name) + for output_name in op.desc.output_arg_names(): + pos = output_name.find("@BroadCast") + if pos <= 0: + continue + new_name = output_name[0 : pos] + op.desc._rename_output( + output_name, new_name + ) + var_names.add(output_name) + + for var_name in var_names: + block._remove_var(var_name, sync=False) + + print("remove broadcast param count=", len(var_names)) + block._sync_with_cpp() + + def _prune_startup_program(self, block, shard): + """ + not need process + """ + block._sync_with_cpp() + + def _insert_loss_grad_scale_op(self): + """ + paddlebox grad not need scale + """ + main_block = self._main_program.global_block() + # # step6: loss div dp_degree + # global_dp_degree = self.sharding_degree * self.dp_degree + # assert int(global_dp_degree) == global_dp_degree + # if global_dp_degree > 1: + # insert_scale_loss_grad_ops(main_block, scale=global_dp_degree) + main_block._sync_with_cpp() + + def minimize_impl( + self, loss, startup_program=None, parameter_list=None, no_grad_set=None + ): + """ + reset start program and main program + """ + sharding_configs = self.user_defined_strategy.sharding_configs + if "use_calc_stream" in sharding_configs: + self._use_calc_stream = sharding_configs["use_calc_stream"] + optimize_ops, params_grads = super().minimize_impl( + loss, startup_program, parameter_list, no_grad_set) + # main_block = self._main_program.global_block() + # startup_block = self._startup_program.global_block() + loss.block.program = self._main_program + from paddle import fluid + fluid.framework.switch_startup_program(self._startup_program) + return optimize_ops, params_grads + + def _init_comm(self): + # sync var + self.role_id = self.role_maker._role_id() + self.node_nums = self.role_maker._node_num() + startup_block = self._startup_program.global_block() + if self.node_nums > 1: + node_nums = len(self.global_endpoints) + assert ( + self.node_nums == node_nums + ), "end points not equal node nums" + self.current_endpoint = self.global_endpoints[self.role_id] + + # mp ring + if self.mp_degree > 1: + self._init_communicator( + self._startup_program, + self.current_endpoint, + self.mp_group_endpoints, + self.role_id, + self.mp_ring_id, + ) + + # sharding ring + if self.sharding_degree > 1: + self._init_communicator( + self._startup_program, + self.current_endpoint, + self.sharding_group_endpoints, + self.role_id, + self.sharding_ring_id, + ) + + # pure dp ring + if self.dp_degree > 1: + self._init_communicator( + self._startup_program, + self.current_endpoint, + self.dp_group_endpoints, + self.role_id, + self.dp_ring_id, + ) + + startup_block._sync_with_cpp() + + def _wait(self): + if self.node_nums <= 1: + return + endpoints = self.global_endpoints[:] + current_endpoint = endpoints[self.role_id] + if self.global_rank == 0: + from paddle.fluid.transpiler.details import wait_server_ready + endpoints.remove(current_endpoint) + wait_server_ready(endpoints) + + def _init_communicator( + self, + program, + current_endpoint, + endpoints, + role_id, + ring_id + ): + block = program.global_block() + # init mulit node nccl + if self.node_nums > 1: + other_endpoints = endpoints[:] + other_endpoints.remove(current_endpoint) + + comm_id_var = block.create_var( + name=unique_name.generate('comm_id'), + persistable=True, + type=core.VarDesc.VarType.RAW, + ) + if core.is_compiled_with_cuda(): + block.append_op( + type='c_gen_nccl_id', + inputs={}, + outputs={'Out': 
comm_id_var}, + attrs={ + 'rank': role_id, + 'endpoint': current_endpoint, + 'other_endpoints': other_endpoints, + self.op_role_key: OpRole.Forward, + }, + ) + + elif core.is_compiled_with_xpu(): + block.append_op( + type='c_gen_bkcl_id', + inputs={}, + outputs={'Out': comm_id_var}, + attrs={ + 'rank': role_id, + 'endpoint': current_endpoint, + 'other_endpoints': other_endpoints, + self.op_role_key: OpRole.Forward, + }, + ) + + block.append_op( + type='c_comm_init_multitrainer', + inputs={'X': comm_id_var}, + outputs={}, + attrs={ + 'ntrainers': self.node_nums, + 'trainer_id': role_id, + 'ring_id': ring_id, + self.op_role_key: OpRole.Forward, + }, + ) + else: + block.append_op( + type='c_comm_init_all', + attrs={'ring_id': ring_id} + ) diff --git a/python/paddle/fluid/tests/unittests/xpu/test_batch_fc_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_batch_fc_op_xpu.py new file mode 100644 index 00000000000000..9c556b342daa5b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_batch_fc_op_xpu.py @@ -0,0 +1,111 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import random +from op_test import OpTest +from op_test_xpu import XPUOpTest +import paddle +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard +from op_test import OpTest, skip_check_grad_ci +import paddle.fluid.core as core + +paddle.enable_static() + +def np_cal_batchfc(input, w, bias): + slot_pairs_num, batch_size, in_dim = input.shape + _, _, out_dim = w.shape + res = np.zeros((slot_pairs_num, batch_size, out_dim)) + for slot in range(slot_pairs_num): + res[slot, :] = np.dot(input[slot, :], w[slot, :]) + # for slot in range(slot_pairs_num): + # for bindx in range(out_dim): + # res[slot, :, bindx] += bias[slot, bindx] + return res + + +class TestBatchFCOp(XPUOpTest): + + def config(self): + self.slot_pairs_num = 10 + self.batch_size = 5 + self.in_dim = 10 + self.out_dim = 12 + self.dtype = "float32" + + def setUp(self): + self.config() + self.input = np.random.random((self.slot_pairs_num, self.batch_size, + self.in_dim)).astype(self.dtype) + self.w = np.random.random( + (self.slot_pairs_num, self.in_dim, self.out_dim)).astype(self.dtype) + self.bias = np.random.random( + (self.slot_pairs_num, self.out_dim)).astype(self.dtype) + self.op_type = "batch_fc" + np_out = np_cal_batchfc(self.input, self.w, self.bias) + np_out = np_out.astype(self.dtype) + self.inputs = {"Input": self.input, "W": self.w, "Bias": self.bias} + self.outputs = {"Out": np_out} + + def test_check_output_xpu(self): + if core.is_compiled_with_xpu(): + self.check_output_with_place(paddle.XPUPlace(0)) + + def test_check_grad_xpu(self): + if core.is_compiled_with_xpu(): + self.check_grad_with_place(paddle.XPUPlace(0), + ["Bias", "W", "Input"], "Out") + + +class TestBatchFCOp1(XPUOpTest): + + def config(self): + self.slot_pairs_num = 10 + self.batch_size = 5 + self.in_dim = 10 + self.out_dim = 12 + self.dtype = "float32" 
+
+    def setUp(self):
+        self.config()
+        self.input = np.random.random((self.slot_pairs_num, self.batch_size,
+                                        self.in_dim)).astype(self.dtype)
+        self.w = np.random.random(
+            (self.slot_pairs_num, self.in_dim, self.out_dim)).astype(self.dtype)
+        self.bias = np.random.random(
+            (self.slot_pairs_num, self.out_dim)).astype(self.dtype)
+        self.op_type = "batch_fc"
+        np_out = np_cal_batchfc(self.input, self.w, self.bias)
+        np_out = np_out.astype(self.dtype)
+        self.inputs = {"Input": self.input, "W": self.w, "Bias": self.bias}
+        self.outputs = {"Out": np_out}
+
+    def test_check_output_cpu(self):
+        try:
+            self.check_output_with_place(place=core.CPUPlace())
+        except Exception:
+            print("CPU kernel for batch_fc is not supported, skipping test")
+
+    def test_check_grad_cpu(self):
+        try:
+            self.check_grad_with_place(core.CPUPlace(), ["Bias", "W", "Input"],
+                                       "Out")
+        except Exception:
+            print("CPU kernel for batch_fc is not supported, skipping test")
+
+
+if __name__ == "__main__":
+    unittest.main()
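Note on the sharding change above: the rewritten Shard._split_params drops the old prefix-sum split and instead assigns every parameter to the worker whose accumulated size is currently smallest. A minimal standalone sketch of that greedy placement follows; the function name and the (name, size) input format are illustrative only, while the real implementation walks params_grads and reads sizes with get_var_size.

    # Greedy "least-loaded worker" placement, mirroring the new _split_params.
    def greedy_split(param_sizes, worker_num):
        # param_sizes: iterable of (param_name, size) pairs, e.g. sizes in MB.
        param2device = {}
        device2params = {rank: [] for rank in range(worker_num)}
        accumulated = [0.0] * worker_num
        for name, size in param_sizes:
            rank = accumulated.index(min(accumulated))  # least-loaded rank so far
            device2params[rank].append(name)
            param2device[name] = rank
            accumulated[rank] += size
        return param2device, device2params

    # Example: four params over two workers end up balanced at 5.0 vs 5.0.
    # greedy_split([("w0", 4.0), ("w1", 3.0), ("w2", 2.0), ("w3", 1.0)], 2)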