diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index 7c0bbac61807e..950756c0394a5 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -413,7 +413,23 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) { if (op_with_kernel == nullptr) { instr_node.OpBase()->Run(*local_scope, place_); } else { - instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get()); + // fit for pten + if (instr_node.PtenKernel() && instr_node.PtenKernel()->IsValid()) { + VLOG(4) << "Run pten kernel: " << op->Type(); + VLOG(4) << instr_node.InnerRuntimeContext().get() << " " + << &instr_node.DeviceContext(); + op_with_kernel->BuildPtenKernelContext( + *instr_node.InnerRuntimeContext().get(), + const_cast(&instr_node.DeviceContext())); + + (*instr_node.PtenKernel())(instr_node.PtenKernelContext()); + + op_with_kernel->WriteBackToOutputs( + instr_node.InnerRuntimeContext().get()); + instr_node.PtenKernelContext()->ClearData(); + } else { + instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get()); + } } } diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpretercore_util.cc index 3817a11b9afe4..41c4faa67fbeb 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.cc +++ b/paddle/fluid/framework/new_executor/interpretercore_util.cc @@ -19,10 +19,13 @@ #include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h" #include "paddle/fluid/operators/controlflow/recurrent_op_helper.h" #include "paddle/fluid/operators/controlflow/while_op_helper.h" +#include "paddle/pten/core/kernel_factory.h" PADDLE_DEFINE_EXPORTED_bool( new_executor_sequential_run, false, "Enable sequential execution for standalone executor, used for debug"); +DECLARE_bool(run_pten_kernel); + namespace paddle { namespace framework { namespace interpreter { @@ -338,6 +341,8 @@ void build_op_func_list(const platform::Place& place, // op is not a operatorwithkernel, so direcly run OperatorBase::Run() deal_operator_base(place, var_scope, ops[i], &op_func_node, local_scope); } else { + auto op_with_kernel = + static_cast(op); // construct RuntimeContext and analysis KernelType RuntimeContext runtime_context({}, {}); runtime_context.inputs.swap(ins_map); @@ -350,8 +355,7 @@ void build_op_func_list(const platform::Place& place, // TODO(Aurelius84): In case of control flow ops, they are NOT // inheritted // from OperatorWithKernel. - static_cast(op)->InferShape( - &infer_shape_ctx); + op_with_kernel->InferShape(&infer_shape_ctx); } auto kernels_iter = all_op_kernels.find(op->Type()); @@ -367,10 +371,8 @@ void build_op_func_list(const platform::Place& place, platform::DeviceContextPool::Instance(); auto* dev_ctx = pool.Get(place); Scope scope; - auto expected_kernel_key = - dynamic_cast(op) - ->GetExpectedKernelType( - ExecutionContext(*op, scope, *dev_ctx, runtime_context)); + auto expected_kernel_key = op_with_kernel->GetExpectedKernelType( + ExecutionContext(*op, scope, *dev_ctx, runtime_context)); // change device by the device_guard() apply_device_guard(op, place, &expected_kernel_key); @@ -378,10 +380,16 @@ void build_op_func_list(const platform::Place& place, // step 3. 
apply data transforms and insert data transfer ops VariableValueMap& ins_map_temp = runtime_context.inputs; + + // NOTE(zhiqiu): op_func_node->operator_base_ maybe changed in + // ApplyDataTransform ApplyDataTransform(expected_kernel_key, place, &ins_map_temp, var_scope, &op_func_node, vec_func_list, use_local_scope); + op_with_kernel = static_cast( + op_func_node.operator_base_.get()); + // step 4. Run op kernel - VLOG(3) << op->Type() + VLOG(3) << op_with_kernel->Type() << " : expected_kernel_key : " << expected_kernel_key; if (platform::is_gpu_place(expected_kernel_key.place_)) { @@ -397,7 +405,8 @@ void build_op_func_list(const platform::Place& place, } op_func_node.dev_ctx_ = dev_ctx; - auto exec_ctx = ExecutionContext(*op, scope, *dev_ctx, runtime_context); + auto exec_ctx = + ExecutionContext(*op_with_kernel, scope, *dev_ctx, runtime_context); auto kernel_iter = kernels.find(expected_kernel_key); PADDLE_ENFORCE_NE( @@ -406,8 +415,27 @@ void build_op_func_list(const platform::Place& place, "Operator (%s) does not have kernel for %s.", op->Type(), KernelTypeToString(expected_kernel_key))); - op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second); - op_func_node.kernel_func_(exec_ctx); + auto run_pten_kernel = false; + + if (FLAGS_run_pten_kernel && + pten::KernelFactory::Instance().HasCompatiblePtenKernel( + op_with_kernel->Type())) { + op_with_kernel->ChoosePtenKernel(exec_ctx); + run_pten_kernel = op_with_kernel->PtenKernel()->IsValid(); + } + + if (run_pten_kernel) { + op_with_kernel->BuildPtenKernelContext(runtime_context, dev_ctx); + op_func_node.pt_kernel_ = op_with_kernel->PtenKernel(); + op_func_node.pt_kernel_context_ = op_with_kernel->PtenKernelContext(); + + (*op_func_node.pt_kernel_)(op_func_node.pt_kernel_context_); + op_with_kernel->WriteBackToOutputs(&runtime_context); + op_func_node.pt_kernel_context_->ClearData(); + } else { + op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second); + op_func_node.kernel_func_(exec_ctx); + } // post-process grad_op.outputs if need cast complex grad into real grad. // NOTE(Aurelius84): insert a transfer_dtype_op inplacely to cast it. 
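
For readers of this patch, the control flow that the hunks above add to build_op_func_list (and that InterpreterCore::RunInstruction later replays) can be summarized by the sketch below. The wrapper function RunWithPtenOrFluidKernel and its trimmed parameter list are hypothetical and exist only for illustration; it assumes the surrounding translation unit already provides the flag declaration machinery, as interpretercore_util.cc does. The member calls it makes (ChoosePtenKernel, BuildPtenKernelContext, PtenKernel, PtenKernelContext, WriteBackToOutputs, ClearData) are the ones introduced or made public by this diff.

// Hypothetical condensed sketch of the dispatch added above; not part of the patch.
#include "paddle/fluid/framework/operator.h"
#include "paddle/pten/core/kernel_factory.h"

DECLARE_bool(run_pten_kernel);

namespace paddle {
namespace framework {

void RunWithPtenOrFluidKernel(OperatorWithKernel* op,
                              RuntimeContext* runtime_ctx,
                              const ExecutionContext& exec_ctx,
                              platform::DeviceContext* dev_ctx,
                              const OpKernelComputeFunc& fluid_kernel) {
  bool run_pten_kernel = false;
  // The pten path is taken only when the feature flag is on and a compatible,
  // valid pten kernel is registered for this op type.
  if (FLAGS_run_pten_kernel &&
      pten::KernelFactory::Instance().HasCompatiblePtenKernel(op->Type())) {
    op->ChoosePtenKernel(exec_ctx);
    run_pten_kernel = op->PtenKernel()->IsValid();
  }

  if (run_pten_kernel) {
    // Translate the fluid RuntimeContext into a pten KernelContext, run the
    // chosen pten kernel, copy outputs back, and release borrowed buffers.
    op->BuildPtenKernelContext(*runtime_ctx, dev_ctx);
    (*op->PtenKernel())(op->PtenKernelContext());
    op->WriteBackToOutputs(runtime_ctx);
    op->PtenKernelContext()->ClearData();
  } else {
    // Otherwise fall back to the pre-existing fluid OpKernel compute function.
    fluid_kernel(exec_ctx);
  }
}

}  // namespace framework
}  // namespace paddle

In the patch itself, the selected kernel and its context are additionally cached on OpFuncNode (pt_kernel_ and pt_kernel_context_, both non-owning), which is what lets InterpreterCore::RunInstruction re-run the same pten kernel without repeating kernel selection.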
diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.cc b/paddle/fluid/framework/new_executor/new_executor_defs.cc index 73f16fe3e9cc7..4b9404fd178fd 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.cc +++ b/paddle/fluid/framework/new_executor/new_executor_defs.cc @@ -673,6 +673,14 @@ OpKernelComputeFunc Instruction::KernelFunc() const { return op_func_node_.kernel_func_; } +pten::Kernel* Instruction::PtenKernel() const { + return op_func_node_.pt_kernel_; +} + +pten::KernelContext* Instruction::PtenKernelContext() const { + return op_func_node_.pt_kernel_context_; +} + OpFuncType Instruction::KernelType() const { return op_func_node_.type_; } OperatorBase* Instruction::OpBase() const { diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h index d691a75a6d35b..ca49e7f5670d6 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.h +++ b/paddle/fluid/framework/new_executor/new_executor_defs.h @@ -295,6 +295,11 @@ struct OpFuncNode { OpKernelComputeFunc kernel_func_; platform::DeviceContext* dev_ctx_; // not owned + + // fit for pten kernel + pten::Kernel* pt_kernel_{nullptr}; // not owned + pten::KernelContext* pt_kernel_context_{nullptr}; // not onwed + OpFuncType type_; }; @@ -313,6 +318,10 @@ class Instruction { OpKernelComputeFunc KernelFunc() const; + pten::Kernel* PtenKernel() const; + + pten::KernelContext* PtenKernelContext() const; + OpFuncType KernelType() const; OperatorBase* OpBase() const; diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 50e16920a6737..2d2e198ef40ec 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1791,6 +1791,9 @@ KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs( void OperatorWithKernel::BuildPtenKernelContext( const RuntimeContext& ctx, platform::DeviceContext* dev_ctx) const { + if (pt_kernel_context_ == nullptr) { + pt_kernel_context_.reset(new pten::KernelContext()); + } // TODO(chenweihang): now only work for very simple case, // many cases need to be deal with later: // 1. 
the input and output are not tensor diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 842ef0457d7bd..59bc4813d985b 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -555,6 +555,20 @@ class OperatorWithKernel : public OperatorBase { virtual KernelSignature GetExpectedPtenKernelArgs( const ExecutionContext& ctx) const; + /* member functions for adapting to pten lib */ + void ChoosePtenKernel(const ExecutionContext& ctx) const; + + void BuildPtenKernelContext(const RuntimeContext& ctx, + platform::DeviceContext* dev_ctx) const; + + void WriteBackToOutputs(RuntimeContext* ctx) const; + + pten::Kernel* PtenKernel() const { return pt_kernel_.get(); } + + pten::KernelContext* PtenKernelContext() const { + return pt_kernel_context_.get(); + } + private: void RunImpl(const Scope& scope, const platform::Place& place) const final; void RunImpl(const Scope& scope, const platform::Place& place, @@ -595,14 +609,6 @@ class OperatorWithKernel : public OperatorBase { Tensor* GetTensorFormInputSafely(const ExecutionContext& ctx, const std::string& name) const; - /* member functions for adapting to pten lib */ - void ChoosePtenKernel(const ExecutionContext& ctx) const; - - void BuildPtenKernelContext(const RuntimeContext& ctx, - platform::DeviceContext* dev_ctx) const; - - void WriteBackToOutputs(RuntimeContext* ctx) const; - protected: mutable std::unique_ptr kernel_type_; mutable std::unique_ptr kernel_func_; diff --git a/paddle/fluid/operators/cast_op.h b/paddle/fluid/operators/cast_op.h index 4f7fe2854ae87..72aa9a195ec7c 100644 --- a/paddle/fluid/operators/cast_op.h +++ b/paddle/fluid/operators/cast_op.h @@ -20,7 +20,7 @@ limitations under the License. */ #include "paddle/pten/api/lib/utils/tensor_utils.h" #include "paddle/pten/include/core.h" -#include "paddle/pten/include/manipulation.h" +#include "paddle/pten/kernels/cast_kernel.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/expand_as_v2_op.cc b/paddle/fluid/operators/expand_as_v2_op.cc old mode 100644 new mode 100755 index 5296a144f6247..cc293a5aaa0b2 --- a/paddle/fluid/operators/expand_as_v2_op.cc +++ b/paddle/fluid/operators/expand_as_v2_op.cc @@ -12,6 +12,7 @@ limitations under the License. */ #include "paddle/fluid/operators/expand_as_v2_op.h" #include #include +#include "paddle/fluid/framework/op_version_registry.h" namespace paddle { namespace operators { @@ -50,6 +51,10 @@ class ExpandAsV2OpMaker : public framework::OpProtoAndCheckerMaker { AddInput("X", "(Tensor, default Tensor). A tensor with rank in [1, 6]." "X is the input to be expanded."); + AddInput("Y", + "(Tensor, default Tensor). A tensor with rank in [1, 6]." + "Expand X according to the shape of Y.") + .AsDispensable(); AddOutput("Out", "(Tensor, default Tensor). A tensor with rank in [1, 6]." "The rank of Output(Out) have the same with Input(X). 
" @@ -144,3 +149,9 @@ REGISTER_OP_CUDA_KERNEL( ops::ExpandAsV2GradKernel, ops::ExpandAsV2GradKernel); #endif + +REGISTER_OP_VERSION(expand_as_v2) + .AddCheckpoint( + R"ROC(fix expand_as_v2 and add new input [Y])ROC", + paddle::framework::compatible::OpVersionDesc().NewInput( + "Y", "Expand X according to the shape of Y")); \ No newline at end of file diff --git a/paddle/fluid/operators/expand_as_v2_op.h b/paddle/fluid/operators/expand_as_v2_op.h old mode 100644 new mode 100755 index 3e8f7d15880bc..9e683a792c61f --- a/paddle/fluid/operators/expand_as_v2_op.h +++ b/paddle/fluid/operators/expand_as_v2_op.h @@ -91,17 +91,34 @@ class ExpandAsV2Kernel : public framework::OpKernel { PADDLE_ENFORCE_NE(target_shape[i], 0, platform::errors::InvalidArgument( "The value of target shape cannot be zero.")); - if (vec_in_dims[i] != 1) { + if (i < diff) { + PADDLE_ENFORCE_GT( + target_shape[i], 0, + platform::errors::InvalidArgument( + "The expanded size (%d) for non-existing dimensions must be " + "positive for expand_as_v2 op.", + target_shape[i])); + repeat_times[i] = target_shape[i]; + } else if (target_shape[i] > 0) { + if (vec_in_dims[i] != 1) { + PADDLE_ENFORCE_EQ( + vec_in_dims[i], target_shape[i], + platform::errors::InvalidArgument( + "The value (%d) of the non-singleton dimension does not match" + " the corresponding value (%d) in shape for expand_as_v2 op.", + vec_in_dims[i], target_shape[i])); + repeat_times[i] = 1; + } else { + repeat_times[i] = target_shape[i]; + } + } else { PADDLE_ENFORCE_EQ( - vec_in_dims[i], target_shape[i], + target_shape[i], -1, platform::errors::InvalidArgument( - "The value (%d) of the non-singleton dimension does not match" - " the corresponding value (%d) in " - "target tensor for expand_as_v2 op.", - vec_in_dims[i], target_shape[i])); + "When the value in shape is negative for expand_as_v2 op, " + "only -1 is supported, but the value received is %d.", + target_shape[i])); repeat_times[i] = 1; - } else { - repeat_times[i] = target_shape[i]; } } auto* out0 = context.Output("Out"); diff --git a/paddle/fluid/operators/expand_v2_op.cc b/paddle/fluid/operators/expand_v2_op.cc old mode 100644 new mode 100755 index dc6da979671e5..6d803c500d90f --- a/paddle/fluid/operators/expand_v2_op.cc +++ b/paddle/fluid/operators/expand_v2_op.cc @@ -65,7 +65,11 @@ class ExpandV2Op : public framework::OperatorWithKernel { if (x_dims[i] == -1) { out_shape[i] = -1; } else if (expand_shape[i] == -1) { - out_shape[i] = x_dims[i]; + if (static_cast(x_dims.size()) > i) { + out_shape[i] = x_dims[i]; + } else { + out_shape[i] = -1; + } } else if (expand_shape[i] == -2) { // We use -2 to represent the element in expand_shape is a var. out_shape[i] = -1; diff --git a/paddle/fluid/operators/flatten_op.h b/paddle/fluid/operators/flatten_op.h index 29eb579b2a0d3..fa116d9516ecd 100644 --- a/paddle/fluid/operators/flatten_op.h +++ b/paddle/fluid/operators/flatten_op.h @@ -21,7 +21,7 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/math/pooling.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/pten/include/core.h" -#include "paddle/pten/include/manipulation.h" +#include "paddle/pten/kernels/flatten_kernel.h" namespace paddle { namespace operators { @@ -134,8 +134,8 @@ class FlattenContiguousRangeKernel : public framework::OpKernel { auto pt_out = paddle::experimental::MakePtenDenseTensor(*out); // call new kernel - pten::Flatten(dev_ctx, *pt_x.get(), start_axis, stop_axis, - pt_out.get()); + pten::FlattenKernel(dev_ctx, *pt_x.get(), start_axis, + stop_axis, pt_out.get()); } }; diff --git a/paddle/fluid/operators/fused/fused_dropout_helper.h b/paddle/fluid/operators/fused/fused_dropout_helper.h index 970b2d82e2b15..3972c60e8347b 100644 --- a/paddle/fluid/operators/fused/fused_dropout_helper.h +++ b/paddle/fluid/operators/fused/fused_dropout_helper.h @@ -250,11 +250,14 @@ class FusedDropoutLayerNormHelper : public FusedDropoutHelper { } // out = layernorm(residual + dropout(src + bias)) - void LayernormResidualDropoutBias( - const platform::CUDADeviceContext& ctx, const T* src, const T* residual, - const T* bias, const LayerNormParamType* gamma, - const LayerNormParamType* beta, T* dropout_out, MaskType* mask, T* out, - LayerNormParamType* mean, LayerNormParamType* variance) { + template , bool is_same_type = false> + void LayernormResidualDropoutBias(const platform::CUDADeviceContext& ctx, + const T* src, const T* residual, + const T* bias, const P* gamma, + const P* beta, T* dropout_out, + MaskType* mask, T* out, + LayerNormParamType* mean, + LayerNormParamType* variance) { using U = LayerNormParamType; int vec_size = MAX_CACHE_BYTES / sizeof(T); if (this->cols_ % vec_size != 0) { @@ -263,7 +266,7 @@ class FusedDropoutLayerNormHelper : public FusedDropoutHelper { int threads = GetDesiredBlockDim(this->cols_ / vec_size); int increment = ((this->cols_ - 1) / (threads * vec_size) + 1) * vec_size; increment = this->dropout_param_.UpdateSeedAndIncrement(ctx, increment); - LaunchLayernormResidualDropoutBias( + LaunchLayernormResidualDropoutBias( this->rows_, this->cols_, increment, this->dropout_param_.seed, this->dropout_param_.dropout_prob, epsilon_, this->dropout_param_.is_upscale_in_train, this->dropout_param_.is_test, @@ -271,17 +274,19 @@ class FusedDropoutLayerNormHelper : public FusedDropoutHelper { variance, ctx); } - void LayernormResidualDropoutBiasGrad( - const platform::CUDADeviceContext& ctx, const T* d_out, - const T* layernorm_src, const MaskType* mask, - const LayerNormParamType* gamma, const LayerNormParamType* mean, - const LayerNormParamType* variance, T* d_layernorm_src, - LayerNormParamType* d_scale, LayerNormParamType* d_layernorm_bias, - T* d_dropout_src, T* d_bias, T* d_residual) { + template , bool is_same_type = false> + void LayernormResidualDropoutBiasGrad(const platform::CUDADeviceContext& ctx, + const T* d_out, const T* layernorm_src, + const MaskType* mask, const P* gamma, + const LayerNormParamType* mean, + const LayerNormParamType* variance, + T* d_layernorm_src, P* d_scale, + P* d_layernorm_bias, T* d_dropout_src, + T* d_bias, T* d_residual) { using U = LayerNormParamType; - LayerNormBackward(layernorm_src, d_out, gamma, mean, variance, - d_layernorm_src, d_scale, d_layernorm_bias, - epsilon_, this->rows_, this->cols_, ctx); + LayerNormBackward( + layernorm_src, d_out, gamma, mean, variance, d_layernorm_src, d_scale, + d_layernorm_bias, epsilon_, this->rows_, this->cols_, ctx); this->ResidualDropoutBiasGrad(ctx, d_layernorm_src, 
mask, d_dropout_src, d_residual, d_bias); } diff --git a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h index 1827e137c15f1..b27b70dc9dc0c 100644 --- a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h +++ b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h @@ -24,46 +24,57 @@ using CudnnDataType = platform::CudnnDataType; template using LayerNormParamType = typename CudnnDataType::BatchNormParamType; +template +using LayerNormScaleBiasT = + typename std::conditional::type; + /** * @brief fused add_bias, dropout, add residual and leyer_norm into one * operators. Currently only support forward */ -template -__device__ void CalcLayernormY(const LayerNormParamType *scale, - const LayerNormParamType *bias, const T *x, - T *y, const int row_id, const int col_id, - const int cols, - const LayerNormParamType mean_val, - const LayerNormParamType invvar) { - using U = LayerNormParamType; +template +__device__ void CalcLayernormY( + const LayerNormScaleBiasT *scale, + const LayerNormScaleBiasT *bias, const T *x, + T *y, const int row_id, const int col_id, const int cols, + const LayerNormParamType mean_val, const LayerNormParamType invvar) { using LoadT = platform::AlignedVector; using StoreT = platform::AlignedVector; using LoadU = platform::AlignedVector; + using LoadScaleOrBias = + platform::AlignedVector, + VecSize>; for (int i = col_id * VecSize; i < cols; i += blockDim.x * VecSize) { - LoadU scale_vec; - LoadU bias_vec; + LoadScaleOrBias scale_vec; + LoadScaleOrBias bias_vec; LoadT x_vec; #pragma unroll for (int ii = 0; ii < VecSize; ii++) { - scale_vec[ii] = static_cast(1); - bias_vec[ii] = static_cast(0); + scale_vec[ii] = + static_cast>(1); + bias_vec[ii] = + static_cast>(0); } // vectorize load data from global platform::Load(&x[row_id * cols + i], &x_vec); if (scale != nullptr) { - platform::Load(&scale[i], &scale_vec); + platform::Load, + VecSize>(&scale[i], &scale_vec); } if (bias != nullptr) { - platform::Load(&bias[i], &bias_vec); + platform::Load, + VecSize>(&bias[i], &bias_vec); } StoreT y_vec; for (int ii = 0; ii < VecSize; ii++) { - y_vec[ii] = static_cast( - scale_vec[ii] * (static_cast(x_vec[ii]) - mean_val) * invvar + - bias_vec[ii]); + y_vec[ii] = + static_cast(static_cast(scale_vec[ii]) * + (static_cast(x_vec[ii]) - mean_val) * invvar + + static_cast(bias_vec[ii])); } platform::Store(y_vec, &y[row_id * cols + i]); } @@ -85,15 +96,17 @@ __device__ void CalcLayernormY(const LayerNormParamType *scale, * means: [rows]: layernorm means * vars: [rows]: layernorm vars */ -template +template __global__ void FusedLayernormResidualDropoutBias( const size_t rows, const size_t cols, uint64_t seed, const float dropout_prob, const bool is_upscale_in_train, const bool is_test, const uint64_t increment, const float epsilon, const T *src, const T *residual, const T *bias, - const LayerNormParamType *scale, - const LayerNormParamType *layernorm_bias, MaskType *mask, T *dst, - T *layernorm_dst, LayerNormParamType *mean, LayerNormParamType *var) { + const LayerNormScaleBiasT *scale, + const LayerNormScaleBiasT *layernorm_bias, + MaskType *mask, T *dst, T *layernorm_dst, LayerNormParamType *mean, + LayerNormParamType *var) { int col_id = threadIdx.x; int row_id = blockIdx.x; int idx = row_id * cols + col_id; @@ -101,7 +114,6 @@ __global__ void FusedLayernormResidualDropoutBias( curand_init(seed, idx, increment, &state); T factor = GetFactor(dropout_prob, 
is_upscale_in_train, is_test); - using U = LayerNormParamType; __shared__ U mean_share; __shared__ U var_share; @@ -121,10 +133,12 @@ __global__ void FusedLayernormResidualDropoutBias( mean_val = BlockReduceSum(mean_val, shared_mean); var_val = BlockReduceSum(var_val, shared_var); if (threadIdx.x == 0) { - auto scale = static_cast(1.) / static_cast(cols); - auto tmp = mean_val * scale; + auto scale = static_cast>( + static_cast(1.) / static_cast(cols)); + auto tmp = mean_val * static_cast(scale); mean[row_id] = mean_share = static_cast(tmp); - var_share = static_cast(var_val * scale - mean_share * mean_share); + var_share = static_cast(var_val * static_cast(scale) - + mean_share * mean_share); var_share = var_share > U(0) ? var_share : U(0); var[row_id] = var_share; } @@ -134,8 +148,9 @@ __global__ void FusedLayernormResidualDropoutBias( U invvar = rsqrt_(var_share + static_cast(epsilon)); // calculate layernorm_dst - CalcLayernormY(scale, layernorm_bias, dst, layernorm_dst, row_id, - col_id, cols, mean_val, invvar); + CalcLayernormY( + scale, layernorm_bias, dst, layernorm_dst, row_id, col_id, cols, mean_val, + invvar); } /** @@ -154,16 +169,17 @@ __global__ void FusedLayernormResidualDropoutBias( * means: [rows]: layernorm means * vars: [rows]: layernorm vars */ -template +template void LaunchLayernormResidualDropoutBias( const uint32_t rows, const uint32_t cols, const int increment, uint64_t seed, const float dropout_prob, const float epsilon, const bool is_upscale_in_train, const bool is_test, const T *src, - const T *residual, const T *bias, const LayerNormParamType *scale, - const LayerNormParamType *layernorm_bias, MaskType *mask_data, T *dst, - T *layernorm_dst, LayerNormParamType *mean, LayerNormParamType *var, - const platform::CUDADeviceContext &ctx) { - using U = LayerNormParamType; + const T *residual, const T *bias, + const LayerNormScaleBiasT *scale, + const LayerNormScaleBiasT *layernorm_bias, + MaskType *mask_data, T *dst, T *layernorm_dst, LayerNormParamType *mean, + LayerNormParamType *var, const platform::CUDADeviceContext &ctx) { // dropout_prob == 1.0f if (std::abs(dropout_prob - 1.0f) < 1e-5) { auto cuda_place = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()); @@ -175,8 +191,9 @@ void LaunchLayernormResidualDropoutBias( // call layernorm forward switch (GetDesiredBlockDim(cols)) { FIXED_BLOCK_DIM_CASE( - LayerNormForward<<>>( + LayerNormForward< + T, U, kBlockDim, + ScaleBiasWithSameTypeX><<>>( dst, scale, layernorm_bias, layernorm_dst, mean, var, epsilon, cols)); default: @@ -184,21 +201,24 @@ void LaunchLayernormResidualDropoutBias( "Product from begin_norm_axis to end must be larger than 1")); break; } + return; } const int VecSize = MAX_CACHE_BYTES / sizeof(T); if (cols % VecSize != 0) { int blockDim = GetDesiredBlockDim(cols); - FusedLayernormResidualDropoutBias<<>>( + FusedLayernormResidualDropoutBias< + T, uint8_t, 1, U, + ScaleBiasWithSameTypeX><<>>( rows, cols, seed, dropout_prob, is_upscale_in_train, is_test, increment, epsilon, src, residual, bias, scale, layernorm_bias, mask_data, dst, layernorm_dst, mean, var); } else { int blockDim = GetDesiredBlockDim(cols / VecSize); FusedLayernormResidualDropoutBias< - T, uint8_t, VecSize><<>>( + T, uint8_t, VecSize, U, + ScaleBiasWithSameTypeX><<>>( rows, cols, seed, dropout_prob, is_upscale_in_train, is_test, increment, epsilon, src, residual, bias, scale, layernorm_bias, mask_data, dst, layernorm_dst, mean, var); diff --git a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias_test.cu 
b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias_test.cu index 50e3555b4bcd6..57d3fc94dc88a 100644 --- a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias_test.cu +++ b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias_test.cu @@ -223,7 +223,7 @@ struct TestFusedLayernormResidualDropoutBias { layernorm_bias_ptr = layernorm_bias.data(); } - paddle::operators::LaunchLayernormResidualDropoutBias( + paddle::operators::LaunchLayernormResidualDropoutBias( rows, cols, increment, seed, dropout_prob, epsilon, is_upscale_in_train, is_test, src.data(), residual.data(), bias_ptr, scale_ptr, layernorm_bias_ptr, mask.data(), out.data(), diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index 996a784affa4c..f2162f55636e5 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -21,7 +21,7 @@ limitations under the License. */ #include "paddle/pten/api/lib/utils/tensor_utils.h" #include "paddle/pten/common/scalar_array.h" #include "paddle/pten/include/core.h" -#include "paddle/pten/include/manipulation.h" +#include "paddle/pten/kernels/reshape_kernel.h" namespace paddle { namespace framework { class InferShapeContext; @@ -438,18 +438,18 @@ class ReshapeKernel { } if (platform::is_cpu_place(ctx.GetPlace())) { auto &dev_ctx = ctx.device_context(); - pten::Reshape(dev_ctx, *pt_x.get(), pt_scalar_shape, pt_out); + pten::ReshapeKernel(dev_ctx, *pt_x.get(), pt_scalar_shape, pt_out); } #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (platform::is_gpu_place(ctx.GetPlace())) { auto &dev_ctx = ctx.device_context(); - pten::Reshape(dev_ctx, *pt_x.get(), pt_scalar_shape, pt_out); + pten::ReshapeKernel(dev_ctx, *pt_x.get(), pt_scalar_shape, pt_out); } #endif #ifdef PADDLE_WITH_XPU if (platform::is_xpu_place(ctx.GetPlace())) { auto &dev_ctx = ctx.device_context(); - pten::Reshape(dev_ctx, *pt_x.get(), pt_scalar_shape, pt_out); + pten::ReshapeKernel(dev_ctx, *pt_x.get(), pt_scalar_shape, pt_out); } #endif // non-inplace need move all result from pt_out to out, inplace need set diff --git a/paddle/pten/all.h b/paddle/pten/all.h index 844114c341d67..7dd517e5e6381 100644 --- a/paddle/pten/all.h +++ b/paddle/pten/all.h @@ -18,5 +18,4 @@ limitations under the License. */ #include "paddle/pten/include/core.h" #include "paddle/pten/include/infermeta.h" #include "paddle/pten/include/linalg.h" -#include "paddle/pten/include/manipulation.h" #include "paddle/pten/include/math.h" diff --git a/paddle/pten/common/device.cc b/paddle/pten/common/device.cc index 9583b521d9123..55130067ae200 100644 --- a/paddle/pten/common/device.cc +++ b/paddle/pten/common/device.cc @@ -24,7 +24,7 @@ const char* DeviceTypeStr(DeviceType type) { case DeviceType::kUndef: return "kUndef"; case DeviceType::kHost: - return "kUndef"; + return "kHost"; case DeviceType::kXpu: return "kXpu"; case DeviceType::kCuda: diff --git a/paddle/pten/include/manipulation.h b/paddle/pten/include/manipulation.h deleted file mode 100644 index a8625e52f5618..0000000000000 --- a/paddle/pten/include/manipulation.h +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -// See Note: [ How do we organize the kernel directory ] -#include "paddle/pten/api/lib/utils/storage.h" -#include "paddle/pten/include/infermeta.h" -#include "paddle/pten/kernels/cast_kernel.h" -#include "paddle/pten/kernels/flatten_kernel.h" -#include "paddle/pten/kernels/reshape_kernel.h" - -namespace pten { - -template -DenseTensor Flatten(const ContextT& dev_ctx, - const DenseTensor& x, - int start_axis, - int stop_axis) { - auto out_meta = FlattenInferMeta(x.meta(), start_axis, stop_axis); - pten::DenseTensor dense_out( - pten::make_intrusive( - dev_ctx.GetPlace()), - std::move(out_meta)); - Flatten(dev_ctx, x, start_axis, stop_axis, &dense_out); - return dense_out; -} - -template -DenseTensor Reshape(const ContextT& dev_ctx, - const DenseTensor& x, - const std::vector& shape) { - auto out_meta = InferMetaFromVecValue(x.meta(), shape); - pten::DenseTensor dense_out( - pten::make_intrusive( - dev_ctx.GetPlace()), - std::move(out_meta)); - Reshape(dev_ctx, x, ScalarArray(shape), &dense_out); - return dense_out; -} - -} // namespace pten diff --git a/paddle/pten/include/math.h b/paddle/pten/include/math.h index e46f460260adb..faa4c8db8dac3 100644 --- a/paddle/pten/include/math.h +++ b/paddle/pten/include/math.h @@ -18,7 +18,6 @@ limitations under the License. */ #include "paddle/pten/api/lib/utils/storage.h" #include "paddle/pten/include/infermeta.h" #include "paddle/pten/kernels/complex_kernel.h" -#include "paddle/pten/kernels/math_kernel.h" #include "paddle/pten/kernels/scale_kernel.h" namespace pten { @@ -34,42 +33,6 @@ DenseTensor Sign(const ContextT& dev_ctx, const DenseTensor& x) { return dense_out; } -template -DenseTensor Mean(const ContextT& dev_ctx, - const DenseTensor& x, - const std::vector& axis, - bool keep_dim) { - auto out_meta = ReduceInferMeta(x.meta(), axis, keep_dim); - pten::DenseTensor dense_out( - pten::make_intrusive( - dev_ctx.GetPlace()), - std::move(out_meta)); - bool reduce_all = false; - MeanKernel(dev_ctx, x, axis, keep_dim, reduce_all, &dense_out); - return dense_out; -} - -template -DenseTensor Sum(const ContextT& dev_ctx, - const DenseTensor& x, - const std::vector& axis, - DataType dtype, - bool keep_dim) { - auto out_meta = ReduceInferMeta(x.meta(), axis, keep_dim, dtype); - pten::DenseTensor dense_out( - pten::make_intrusive( - dev_ctx.GetPlace()), - out_meta); - - // The real value of reduce_all will be get in kernel - // so use default value(false) is OK. 
- bool reduce_all = false; - - SumKernel( - dev_ctx, x, axis, keep_dim, reduce_all, out_meta.dtype, &dense_out); - return dense_out; -} - template DenseTensor Scale(const ContextT& dev_ctx, const DenseTensor& x, diff --git a/paddle/pten/kernels/cpu/math_kernel.cc b/paddle/pten/kernels/cpu/math_kernel.cc index 2a696584bc781..be0d52355bce6 100644 --- a/paddle/pten/kernels/cpu/math_kernel.cc +++ b/paddle/pten/kernels/cpu/math_kernel.cc @@ -21,6 +21,7 @@ #include "paddle/pten/kernels/cpu/elementwise.h" #include "paddle/pten/kernels/cpu/reduce.h" #include "paddle/pten/kernels/funcs/elementwise_functor.h" +#include "paddle/pten/kernels/funcs/reduce_functor.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/framework/eigen.h" @@ -61,7 +62,7 @@ void MeanKernel(const Context& dev_ctx, bool reduce_all, DenseTensor* out) { auto out_dtype = x.dtype(); - pten::Reduce( + pten::Reduce( dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); } @@ -97,7 +98,7 @@ void SumKernel(const Context& dev_ctx, bool reduce_all, DataType out_dtype, DenseTensor* out) { - pten::Reduce( + pten::Reduce( dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); } diff --git a/paddle/pten/kernels/cpu/reduce.h b/paddle/pten/kernels/cpu/reduce.h index fc5dbe9d58d63..fa603b2163055 100644 --- a/paddle/pten/kernels/cpu/reduce.h +++ b/paddle/pten/kernels/cpu/reduce.h @@ -19,10 +19,184 @@ #include "paddle/pten/api/ext/dispatch.h" #include "paddle/pten/backends/cpu/cpu_context.h" #include "paddle/pten/kernels/cast_kernel.h" -#include "paddle/pten/kernels/hybird/eigen/reduce.h" +#include "paddle/pten/api/lib/utils/storage.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/kernels/hybird/eigen/common.h" +#include "paddle/pten/kernels/hybird/transpose.h" +// See Note [ Why still include the fluid headers? 
] +#include "paddle/fluid/operators/eigen/eigen_function.h" namespace pten { +template +void ReduceFunctor(const DeviceContext& context, + const pten::DenseTensor& input, + pten::DenseTensor* output, + const std::vector& dims, + bool keep_dim) { + auto x = EigenTensor::From(input); + auto x_rank = static_cast(x.dimensions().size()); + auto reduce_dim = Eigen::array(); + std::vector dims_ref = dims; + for (size_t i = 0; i < dims_ref.size(); ++i) { + if (dims_ref[i] < 0) dims_ref[i] = x_rank + dims_ref[i]; + reduce_dim[i] = dims_ref[i]; + } + // construct the squeezed output tensor + DDim out_dims = output->dims(); + if (keep_dim && x_rank > 1) { + const int kDelFlag = -2; + auto dims_vector = paddle::framework::vectorize(out_dims); + for (size_t i = 0; i < dims_ref.size(); ++i) { + dims_vector[dims_ref[i]] = kDelFlag; + } + dims_vector.erase(remove(dims_vector.begin(), dims_vector.end(), kDelFlag), + dims_vector.end()); + out_dims = paddle::framework::make_ddim(dims_vector); + } + auto& place = *context.eigen_device(); + Functor functor; + + if (D == 1) { + auto out = EigenScalar::From(*output); + functor(place, &x, &out, reduce_dim); + } else { + auto out = EigenTensor::From(*output, out_dims); + functor(place, &x, &out, reduce_dim); + } +} + +#define HANDLE_REDUCE_DIM(NDIM, RDIM) \ + if (ndim == NDIM && rdim == RDIM) { \ + ReduceFunctor( \ + dev_ctx, input, output, dims, keep_dim); \ + } +//////////////// HandleLargeDim + +inline void GetShuffledDim(const DDim& src_dims, + DDim* dst_dims, + const std::vector& reduced_dims, + std::vector* perm_axis) { + // check if it's a reduced dim + std::vector src_dims_check(src_dims.size(), false); + size_t src_size = src_dims.size(); + size_t reduce_size = reduced_dims.size(); + std::vector regular_reduced_dims = reduced_dims; + for (size_t i = 0; i < regular_reduced_dims.size(); i++) { + if (regular_reduced_dims[i] < 0) { + regular_reduced_dims[i] = src_size + regular_reduced_dims[i]; + } + } + + for (size_t i = 0; i < reduce_size; ++i) { + dst_dims->at(src_size - reduce_size + i) = + src_dims[regular_reduced_dims[i]]; + (*perm_axis)[src_size - reduce_size + i] = regular_reduced_dims[i]; + src_dims_check[regular_reduced_dims[i]] = true; + } + + size_t offset = 0; + for (size_t i = 0; i < src_dims_check.size(); ++i) { + bool is_reduced = src_dims_check[i]; + if (!is_reduced) { + (*perm_axis)[offset] = i; + dst_dims->at(offset++) = src_dims[i]; + } + } +} + +template +void GetShuffledInput(const DeviceContext& dev_ctx, + const pten::DenseTensor& input, + pten::DenseTensor* shuffled_input, + const std::vector& dims) { + DDim shuffled_dims(input.dims()); + std::vector perm_axis(input.dims().size()); + GetShuffledDim(input.dims(), &shuffled_dims, dims, &perm_axis); + + shuffled_input->Resize(shuffled_dims); + shuffled_input->mutable_data(); + + pten::math::TransposeNormal trans; + trans(dev_ctx, input, shuffled_input, perm_axis); +} + +template +void HandleLargeDim(const DeviceContext& dev_ctx, + const pten::DenseTensor& input, + pten::DenseTensor* output, + const std::vector& dims, + bool keep_dim) { + // shuffle the reduced dim to the end + pten::DenseTensor shuffled_input = pten::DenseTensor( + pten::make_intrusive(input.place()), + input.meta()); + + GetShuffledInput(dev_ctx, input, &shuffled_input, dims); + + // transpose to 2D tensor whose shape is {unreduced, reduced}. 
+ const int64_t unreduced = output->numel(); + const int64_t reduced = shuffled_input.numel() / unreduced; + shuffled_input.Resize({unreduced, reduced}); + DDim output_dim = output->dims(); + output->Resize({unreduced}); + ReduceFunctor( + dev_ctx, shuffled_input, output, {1}, keep_dim); + output->Resize(output_dim); +} + +////////////// ReduceKernel + +template +void ReduceKernelImpl(const DeviceContext& dev_ctx, + const pten::DenseTensor& input, + pten::DenseTensor* output, + const std::vector& dims, + bool keep_dim, + bool reduce_all) { + output->mutable_data(); + + if (reduce_all) { + // Flatten and reduce 1-D tensor + auto x = EigenVector::Flatten(input); + auto out = EigenScalar::From(*output); + auto& dev = *dev_ctx.eigen_device(); + auto reduce_dim = Eigen::array({{0}}); + + Functor functor; + functor(dev, &x, &out, reduce_dim); + } else { + int ndim = input.dims().size(); + int rdim = dims.size(); + if (ndim > 6) { + HandleLargeDim( + dev_ctx, input, output, dims, keep_dim); + + } else { + HANDLE_REDUCE_DIM(6, 5); + HANDLE_REDUCE_DIM(6, 4); + HANDLE_REDUCE_DIM(6, 3); + HANDLE_REDUCE_DIM(6, 2); + HANDLE_REDUCE_DIM(6, 1); + HANDLE_REDUCE_DIM(5, 4); + HANDLE_REDUCE_DIM(5, 3); + HANDLE_REDUCE_DIM(5, 2); + HANDLE_REDUCE_DIM(5, 1); + HANDLE_REDUCE_DIM(4, 3); + HANDLE_REDUCE_DIM(4, 2); + HANDLE_REDUCE_DIM(4, 1); + HANDLE_REDUCE_DIM(3, 2); + HANDLE_REDUCE_DIM(3, 1); + HANDLE_REDUCE_DIM(2, 1); + HANDLE_REDUCE_DIM(1, 1); + } + } +} + template void Reduce(const DeviceContext& dev_ctx, const DenseTensor& x, @@ -52,7 +226,7 @@ void Reduce(const DeviceContext& dev_ctx, // do reduce sum PD_VISIT_ALL_TYPES( out_dtype, "ReduceKernelImpl", ([&] { - pten::eigen::ReduceKernelImpl( + pten::ReduceKernelImpl( dev_ctx, x, out, dims, keep_dim, reduce_all); })); } else { @@ -66,7 +240,7 @@ void Reduce(const DeviceContext& dev_ctx, // do reduce sum PD_VISIT_ALL_TYPES( out_dtype, "ReduceKernelImpl", ([&] { - pten::eigen::ReduceKernelImpl( + pten::ReduceKernelImpl( dev_ctx, tmp_tensor, out, dims, keep_dim, reduce_all); })); } diff --git a/paddle/pten/kernels/flatten_kernel.cc b/paddle/pten/kernels/flatten_kernel.cc index df8238cbf3a91..37d4d88ccb40e 100644 --- a/paddle/pten/kernels/flatten_kernel.cc +++ b/paddle/pten/kernels/flatten_kernel.cc @@ -22,11 +22,11 @@ namespace pten { template -void Flatten(const Context& dev_ctx, - const DenseTensor& x, - int start_axis, - int stop_axis, - DenseTensor* out) { +void FlattenKernel(const Context& dev_ctx, + const DenseTensor& x, + int start_axis, + int stop_axis, + DenseTensor* out) { auto out_dims = out->dims(); pten::Copy(dev_ctx, x, false, out); out->Resize(out_dims); @@ -42,7 +42,7 @@ void FlattenWithXShape(const Context& dev_ctx, int stop_axis, DenseTensor* out, DenseTensor* xshape) { - Flatten(dev_ctx, x, start_axis, stop_axis, out); + FlattenKernel(dev_ctx, x, start_axis, stop_axis, out); funcs::SetXShape(x, xshape); } @@ -51,7 +51,7 @@ void FlattenWithXShape(const Context& dev_ctx, PT_REGISTER_CTX_KERNEL(flatten, CPU, ALL_LAYOUT, - pten::Flatten, + pten::FlattenKernel, float, double, uint8_t, @@ -74,7 +74,7 @@ PT_REGISTER_CTX_KERNEL(flatten_with_xshape, PT_REGISTER_CTX_KERNEL(flatten, GPU, ALL_LAYOUT, - pten::Flatten, + pten::FlattenKernel, float, paddle::platform::float16, double, @@ -100,7 +100,7 @@ PT_REGISTER_CTX_KERNEL(flatten_with_xshape, PT_REGISTER_CTX_KERNEL(flatten, XPU, ALL_LAYOUT, - pten::Flatten, + pten::FlattenKernel, float, paddle::platform::float16, double, diff --git a/paddle/pten/kernels/flatten_kernel.h 
b/paddle/pten/kernels/flatten_kernel.h index 5a0445489bcf3..a67e66fac4130 100644 --- a/paddle/pten/kernels/flatten_kernel.h +++ b/paddle/pten/kernels/flatten_kernel.h @@ -15,15 +15,17 @@ limitations under the License. */ #pragma once #include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/include/infermeta.h" +#include "paddle/pten/kernels/empty_kernel.h" namespace pten { template -void Flatten(const Context& dev_ctx, - const DenseTensor& x, - int start_axis, - int stop_axis, - DenseTensor* out); +void FlattenKernel(const Context& dev_ctx, + const DenseTensor& x, + int start_axis, + int stop_axis, + DenseTensor* out); template void FlattenWithXShape(const Context& dev_ctx, @@ -33,4 +35,15 @@ void FlattenWithXShape(const Context& dev_ctx, DenseTensor* out, DenseTensor* xshape); +template +DenseTensor Flatten(const Context& dev_ctx, + const DenseTensor& x, + int start_axis, + int stop_axis) { + auto out_meta = FlattenInferMeta(x.meta(), start_axis, stop_axis); + auto dense_out = Empty(dev_ctx, std::move(out_meta)); + FlattenKernel(dev_ctx, x, start_axis, stop_axis, &dense_out); + return dense_out; +} + } // namespace pten diff --git a/paddle/pten/kernels/funcs/reduce_functor.h b/paddle/pten/kernels/funcs/reduce_functor.h new file mode 100644 index 0000000000000..64ada0231892e --- /dev/null +++ b/paddle/pten/kernels/funcs/reduce_functor.h @@ -0,0 +1,37 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace pten { +namespace funcs { + +//////// Sum Functor /////// +struct SumFunctor { + template + void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) { + y->device(place) = x->sum(dim); + } +}; + +//////// Mean Functor /////// +struct MeanFunctor { + template + void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) { + y->device(place) = x->mean(dim); + } +}; + +} // namespace funcs +} // namespace pten diff --git a/paddle/pten/kernels/hybird/eigen/reduce.h b/paddle/pten/kernels/hybird/eigen/reduce.h deleted file mode 100644 index d60a416dfdb37..0000000000000 --- a/paddle/pten/kernels/hybird/eigen/reduce.h +++ /dev/null @@ -1,214 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include "paddle/pten/api/lib/utils/storage.h" -#include "paddle/pten/core/dense_tensor.h" -#include "paddle/pten/kernels/hybird/eigen/common.h" -#include "paddle/pten/kernels/hybird/transpose.h" - -// See Note [ Why still include the fluid headers? ] -#include "paddle/fluid/operators/eigen/eigen_function.h" - -namespace pten { -namespace eigen { - -template -void ReduceFunctor(const DeviceContext& context, - const pten::DenseTensor& input, - pten::DenseTensor* output, - const std::vector& dims, - bool keep_dim) { - auto x = EigenTensor::From(input); - auto x_rank = static_cast(x.dimensions().size()); - auto reduce_dim = Eigen::array(); - std::vector dims_ref = dims; - for (size_t i = 0; i < dims_ref.size(); ++i) { - if (dims_ref[i] < 0) dims_ref[i] = x_rank + dims_ref[i]; - reduce_dim[i] = dims_ref[i]; - } - // construct the squeezed output tensor - DDim out_dims = output->dims(); - if (keep_dim && x_rank > 1) { - const int kDelFlag = -2; - auto dims_vector = paddle::framework::vectorize(out_dims); - for (size_t i = 0; i < dims_ref.size(); ++i) { - dims_vector[dims_ref[i]] = kDelFlag; - } - dims_vector.erase(remove(dims_vector.begin(), dims_vector.end(), kDelFlag), - dims_vector.end()); - out_dims = paddle::framework::make_ddim(dims_vector); - } - auto& place = *context.eigen_device(); - Functor functor; - - if (D == 1) { - auto out = EigenScalar::From(*output); - functor(place, &x, &out, reduce_dim); - } else { - auto out = EigenTensor::From(*output, out_dims); - functor(place, &x, &out, reduce_dim); - } -} - -#define HANDLE_REDUCE_DIM(NDIM, RDIM) \ - if (ndim == NDIM && rdim == RDIM) { \ - ReduceFunctor( \ - dev_ctx, input, output, dims, keep_dim); \ - } -//////////////// HandleLargeDim - -inline void GetShuffledDim(const DDim& src_dims, - DDim* dst_dims, - const std::vector& reduced_dims, - std::vector* perm_axis) { - // check if it's a reduced dim - std::vector src_dims_check(src_dims.size(), false); - size_t src_size = src_dims.size(); - size_t reduce_size = reduced_dims.size(); - std::vector regular_reduced_dims = reduced_dims; - for (size_t i = 0; i < regular_reduced_dims.size(); i++) { - if (regular_reduced_dims[i] < 0) { - regular_reduced_dims[i] = src_size + regular_reduced_dims[i]; - } - } - - for (size_t i = 0; i < reduce_size; ++i) { - dst_dims->at(src_size - reduce_size + i) = - src_dims[regular_reduced_dims[i]]; - (*perm_axis)[src_size - reduce_size + i] = regular_reduced_dims[i]; - src_dims_check[regular_reduced_dims[i]] = true; - } - - size_t offset = 0; - for (size_t i = 0; i < src_dims_check.size(); ++i) { - bool is_reduced = src_dims_check[i]; - if (!is_reduced) { - (*perm_axis)[offset] = i; - dst_dims->at(offset++) = src_dims[i]; - } - } -} - -template -void GetShuffledInput(const DeviceContext& dev_ctx, - const pten::DenseTensor& input, - pten::DenseTensor* shuffled_input, - const std::vector& dims) { - DDim shuffled_dims(input.dims()); - std::vector perm_axis(input.dims().size()); - GetShuffledDim(input.dims(), &shuffled_dims, dims, &perm_axis); - - shuffled_input->Resize(shuffled_dims); - shuffled_input->mutable_data(); - - pten::math::TransposeNormal trans; - trans(dev_ctx, input, shuffled_input, perm_axis); -} - -template -void HandleLargeDim(const DeviceContext& dev_ctx, - const pten::DenseTensor& input, - pten::DenseTensor* output, - const std::vector& dims, - bool keep_dim) { - // shuffle the reduced dim to the end - pten::DenseTensor shuffled_input = pten::DenseTensor( - pten::make_intrusive(input.place()), - input.meta()); - - 
GetShuffledInput(dev_ctx, input, &shuffled_input, dims); - - // transpose to 2D tensor whose shape is {unreduced, reduced}. - const int64_t unreduced = output->numel(); - const int64_t reduced = shuffled_input.numel() / unreduced; - shuffled_input.Resize({unreduced, reduced}); - DDim output_dim = output->dims(); - output->Resize({unreduced}); - ReduceFunctor( - dev_ctx, shuffled_input, output, {1}, keep_dim); - output->Resize(output_dim); -} - -////////////// ReduceKernel - -template -void ReduceKernelImpl(const DeviceContext& dev_ctx, - const pten::DenseTensor& input, - pten::DenseTensor* output, - const std::vector& dims, - bool keep_dim, - bool reduce_all) { - output->mutable_data(); - - if (reduce_all) { - // Flatten and reduce 1-D tensor - auto x = EigenVector::Flatten(input); - auto out = EigenScalar::From(*output); - auto& dev = *dev_ctx.eigen_device(); - auto reduce_dim = Eigen::array({{0}}); - - Functor functor; - functor(dev, &x, &out, reduce_dim); - } else { - int ndim = input.dims().size(); - int rdim = dims.size(); - if (ndim > 6) { - HandleLargeDim( - dev_ctx, input, output, dims, keep_dim); - - } else { - HANDLE_REDUCE_DIM(6, 5); - HANDLE_REDUCE_DIM(6, 4); - HANDLE_REDUCE_DIM(6, 3); - HANDLE_REDUCE_DIM(6, 2); - HANDLE_REDUCE_DIM(6, 1); - HANDLE_REDUCE_DIM(5, 4); - HANDLE_REDUCE_DIM(5, 3); - HANDLE_REDUCE_DIM(5, 2); - HANDLE_REDUCE_DIM(5, 1); - HANDLE_REDUCE_DIM(4, 3); - HANDLE_REDUCE_DIM(4, 2); - HANDLE_REDUCE_DIM(4, 1); - HANDLE_REDUCE_DIM(3, 2); - HANDLE_REDUCE_DIM(3, 1); - HANDLE_REDUCE_DIM(2, 1); - HANDLE_REDUCE_DIM(1, 1); - } - } -} - -//////// Sum Functor /////// -struct SumFunctor { - template - void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) { - y->device(place) = x->sum(dim); - } -}; - -//////// Mean Functor /////// -struct MeanFunctor { - template - void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) { - y->device(place) = x->mean(dim); - } -}; - -} // namespace eigen -} // namespace pten diff --git a/paddle/pten/kernels/math_kernel.h b/paddle/pten/kernels/math_kernel.h index b1e5188f3aaef..f87d0a31b470b 100644 --- a/paddle/pten/kernels/math_kernel.h +++ b/paddle/pten/kernels/math_kernel.h @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/pten/api/lib/utils/storage.h" #include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/include/infermeta.h" +#include "paddle/pten/kernels/empty_kernel.h" namespace pten { @@ -121,4 +122,34 @@ DenseTensor Multiply(const ContextT& dev_ctx, return dense_out; } +template +DenseTensor Mean(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& axis, + bool keep_dim) { + auto out_meta = ReduceInferMeta(x.meta(), axis, keep_dim); + auto dense_out = pten::Empty(dev_ctx, std::move(out_meta)); + bool reduce_all = false; + MeanKernel(dev_ctx, x, axis, keep_dim, reduce_all, &dense_out); + return dense_out; +} + +template +DenseTensor Sum(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& axis, + DataType dtype, + bool keep_dim) { + auto out_meta = ReduceInferMeta(x.meta(), axis, keep_dim, dtype); + auto dense_out = pten::Empty(dev_ctx, std::move(out_meta)); + + // The real value of reduce_all will be get in kernel + // so use default value(false) is OK. 
+ bool reduce_all = false; + + SumKernel( + dev_ctx, x, axis, keep_dim, reduce_all, out_meta.dtype, &dense_out); + return dense_out; +} + } // namespace pten diff --git a/paddle/pten/kernels/reshape_kernel.cc b/paddle/pten/kernels/reshape_kernel.cc index 0535ea20c8cb0..d7e2e2707ee1b 100644 --- a/paddle/pten/kernels/reshape_kernel.cc +++ b/paddle/pten/kernels/reshape_kernel.cc @@ -22,10 +22,10 @@ namespace pten { template -void Reshape(const Context& dev_ctx, - const DenseTensor& x, - const ScalarArray& shape, - DenseTensor* out) { +void ReshapeKernel(const Context& dev_ctx, + const DenseTensor& x, + const ScalarArray& shape, + DenseTensor* out) { auto out_meta = InferMetaFromVecValue(x.meta(), shape.GetData()); if (x.data() == out->data() && x.numel() == out->numel()) { out->Resize(out_meta.dims); @@ -43,13 +43,16 @@ void ReshapeWithXShape(const Context& dev_ctx, DenseTensor* xshape, DenseTensor* out) { funcs::SetXShape(x, xshape); - Reshape(dev_ctx, x, shape, out); + ReshapeKernel(dev_ctx, x, shape, out); } } // namespace pten -PT_REGISTER_GENERAL_KERNEL( - reshape, CPU, ALL_LAYOUT, pten::Reshape, ALL_DTYPE) {} +PT_REGISTER_GENERAL_KERNEL(reshape, + CPU, + ALL_LAYOUT, + pten::ReshapeKernel, + ALL_DTYPE) {} PT_REGISTER_GENERAL_KERNEL(reshape_with_xshape, CPU, ALL_LAYOUT, @@ -57,8 +60,11 @@ PT_REGISTER_GENERAL_KERNEL(reshape_with_xshape, ALL_DTYPE) {} #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -PT_REGISTER_GENERAL_KERNEL( - reshape, GPU, ALL_LAYOUT, pten::Reshape, ALL_DTYPE) {} +PT_REGISTER_GENERAL_KERNEL(reshape, + GPU, + ALL_LAYOUT, + pten::ReshapeKernel, + ALL_DTYPE) {} PT_REGISTER_GENERAL_KERNEL(reshape_with_xshape, GPU, ALL_LAYOUT, @@ -67,8 +73,11 @@ PT_REGISTER_GENERAL_KERNEL(reshape_with_xshape, #endif #ifdef PADDLE_WITH_XPU -PT_REGISTER_GENERAL_KERNEL( - reshape, XPU, ALL_LAYOUT, pten::Reshape, ALL_DTYPE) {} +PT_REGISTER_GENERAL_KERNEL(reshape, + XPU, + ALL_LAYOUT, + pten::ReshapeKernel, + ALL_DTYPE) {} PT_REGISTER_GENERAL_KERNEL(reshape_with_xshape, XPU, ALL_LAYOUT, diff --git a/paddle/pten/kernels/reshape_kernel.h b/paddle/pten/kernels/reshape_kernel.h index b10e31a434c00..faa51c69ad17c 100644 --- a/paddle/pten/kernels/reshape_kernel.h +++ b/paddle/pten/kernels/reshape_kernel.h @@ -16,14 +16,16 @@ limitations under the License. */ #include "paddle/pten/common/scalar_array.h" #include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/include/infermeta.h" +#include "paddle/pten/kernels/empty_kernel.h" namespace pten { template -void Reshape(const Context& dev_ctx, - const DenseTensor& x, - const ScalarArray& shape, - DenseTensor* out); +void ReshapeKernel(const Context& dev_ctx, + const DenseTensor& x, + const ScalarArray& shape, + DenseTensor* out); template void ReshapeWithXShape(const Context& dev_ctx, @@ -32,4 +34,14 @@ void ReshapeWithXShape(const Context& dev_ctx, DenseTensor* xshape, DenseTensor* out); +template +DenseTensor Reshape(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& shape) { + auto out_meta = InferMetaFromVecValue(x.meta(), shape); + auto dense_out = Empty(dev_ctx, std::move(out_meta)); + ReshapeKernel(dev_ctx, x, ScalarArray(shape), &dense_out); + return dense_out; +} + } // namespace pten diff --git a/paddle/pten/tests/kernels/test_cast_dev_api.cc b/paddle/pten/tests/kernels/test_cast_dev_api.cc index dc3cff150b47b..cb45d827e3be9 100644 --- a/paddle/pten/tests/kernels/test_cast_dev_api.cc +++ b/paddle/pten/tests/kernels/test_cast_dev_api.cc @@ -16,7 +16,7 @@ limitations under the License. 
*/ #include #include -#include "paddle/pten/include/manipulation.h" +#include "paddle/pten/kernels/cast_kernel.h" #include "paddle/pten/api/lib/utils/allocator.h" #include "paddle/pten/common/data_type.h" diff --git a/paddle/pten/tests/kernels/test_flatten_dev_api.cc b/paddle/pten/tests/kernels/test_flatten_dev_api.cc index d2ff7480e904f..f18e5c050ba70 100644 --- a/paddle/pten/tests/kernels/test_flatten_dev_api.cc +++ b/paddle/pten/tests/kernels/test_flatten_dev_api.cc @@ -15,7 +15,7 @@ limitations under the License. */ #include #include -#include "paddle/pten/include/manipulation.h" +#include "paddle/pten/kernels/flatten_kernel.h" #include "paddle/pten/api/lib/utils/allocator.h" #include "paddle/pten/core/dense_tensor.h" diff --git a/paddle/pten/tests/kernels/test_mean_dev_api.cc b/paddle/pten/tests/kernels/test_mean_dev_api.cc index 4d062977e23bd..4b254e7e6c1ac 100644 --- a/paddle/pten/tests/kernels/test_mean_dev_api.cc +++ b/paddle/pten/tests/kernels/test_mean_dev_api.cc @@ -15,7 +15,7 @@ limitations under the License. */ #include #include -#include "paddle/pten/include/math.h" +#include "paddle/pten/kernels/math_kernel.h" #include "paddle/pten/api/lib/utils/allocator.h" #include "paddle/pten/core/dense_tensor.h" diff --git a/paddle/pten/tests/kernels/test_reshape_dev_api.cc b/paddle/pten/tests/kernels/test_reshape_dev_api.cc index 64efdc6f67201..0196e1c211004 100644 --- a/paddle/pten/tests/kernels/test_reshape_dev_api.cc +++ b/paddle/pten/tests/kernels/test_reshape_dev_api.cc @@ -15,7 +15,7 @@ limitations under the License. */ #include #include -#include "paddle/pten/include/manipulation.h" +#include "paddle/pten/kernels/reshape_kernel.h" #include "paddle/pten/api/lib/utils/allocator.h" #include "paddle/pten/core/dense_tensor.h" diff --git a/paddle/pten/tests/kernels/test_sum_dev_api.cc b/paddle/pten/tests/kernels/test_sum_dev_api.cc index 381b8fe44f532..afaf903063781 100644 --- a/paddle/pten/tests/kernels/test_sum_dev_api.cc +++ b/paddle/pten/tests/kernels/test_sum_dev_api.cc @@ -15,7 +15,7 @@ limitations under the License. 
*/ #include #include -#include "paddle/pten/include/math.h" +#include "paddle/pten/kernels/math_kernel.h" #include "paddle/pten/api/lib/utils/allocator.h" #include "paddle/pten/core/dense_tensor.h" diff --git a/python/paddle/distributed/auto_parallel/parallelizer.py b/python/paddle/distributed/auto_parallel/parallelizer.py index 9ff673b1d2901..7cad4d746bbf2 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer.py +++ b/python/paddle/distributed/auto_parallel/parallelizer.py @@ -23,6 +23,7 @@ import pickle import time import paddle +from paddle.fluid.backward import append_backward from paddle.distributed.utils import get_logger from paddle.distributed.fleet import cloud_utils import paddle.fluid.core as core @@ -96,49 +97,35 @@ def _remove_distributed_attrs(self, main_program): if suffix in attr_name: op._remove_attr(attr_name) - def _apply_serial_forward_pass(self, main_program, startup_program): + def _apply_serial_pass(self, main_program, startup_program): - # apply amp forward pass + # apply amp pass if self._dist_strategy.amp: auto_parallel_amp_pass = new_pass("auto_parallel_amp_pass", self._dist_strategy.amp_configs) - auto_parallel_amp_pass.apply_forward(main_program, startup_program, - self._pass_context) + auto_parallel_amp_pass.apply(main_program, startup_program, + self._pass_context) - # apply recompute forward pass + # apply recompute pass if self._dist_strategy.recompute: auto_parallel_recompute_pass = new_pass( "auto_parallel_recompute_pass", self._dist_strategy.recompute_configs) - auto_parallel_recompute_pass.apply_forward( - main_program, startup_program, self._pass_context) + auto_parallel_recompute_pass.apply(main_program, startup_program, + self._pass_context) def _generate_backward(self, main_program, startup_program, loss, parameter_list, no_grad_set, callbacks): - # apply recompute backward pass - if self._dist_strategy.recompute: - assert auto_parallel_recompute_pass - auto_parallel_recompute_pass.apply_forward( - main_program, startup_program, parameter_list, no_grad_set, - self._pass_context) - else: - from paddle.fluid.backward import append_backward - with program_guard(main_program, startup_program): - params_grads = append_backward( - loss, - parameter_list, - no_grad_set, - callbacks, - distop_context=self._dist_context.dist_op_context) - complete_backward_annotation( - main_program, dist_context=self._dist_context) - - # apply amp forward pass - if self._dist_strategy.amp: - assert auto_parallel_amp_pass - auto_parallel_amp_pass.apply_backward(main_program, startup_program, - self._pass_context) + with program_guard(main_program, startup_program): + params_grads = append_backward( + loss, + parameter_list, + no_grad_set, + callbacks, + distop_context=self._dist_context.dist_op_context) + complete_backward_annotation( + main_program, dist_context=self._dist_context) return params_grads @@ -192,14 +179,14 @@ def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False): completed_main_program = serial_main_program self._dist_context = copy.deepcopy(dist_context) - # serial forward pass - self._apply_serial_forward_pass(completed_main_program, - serial_startup_program) # serial backward pass params_grads = self._generate_backward( completed_main_program, serial_startup_program, serial_loss, self._parameter_list, self._no_grad_set, self._callbacks) + # serial forward pass + self._apply_serial_pass(completed_main_program, serial_startup_program) + # Logical partition rank = paddle.distributed.get_rank() partitioner = 
diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py
index 5e799c52092db..2785eae6e8a46 100644
--- a/python/paddle/distributed/passes/auto_parallel_sharding.py
+++ b/python/paddle/distributed/passes/auto_parallel_sharding.py
@@ -94,7 +94,7 @@ def _build_sharding_groups(self, main_block, params_grads):
 
     def _collective_data_parallel_groups(self, main_block):
         for op in main_block.ops:
-            if op.type in _skip_ops:
+            if not _is_forward_op(op) or op.type in _skip_ops:
                 continue
             group = _inference_data_parallel_group_for_operator(
                 self.global_rank, op, self._dist_context)
@@ -106,7 +106,7 @@ def _collective_data_parallel_groups(self, main_block):
         if len(self.dp_groups) != 1:
             raise NotImplementedError(
                 "So far Only and Exactly one data parallel group in network are supported, but got [{}] different data parallel groups".
-                format(len(groups)))
+                format(len(self.dp_groups)))
 
     def _build_sharding_infos(self, params_grads):
 
@@ -193,18 +193,32 @@ def _shard_gradient_clip(self, main_block):
             return
 
         # TODO (JZ-LIANG) support calculate global norm with tensor parallelism
-        is_clip_grad_by_global_norm = False
+        removed_op_type = ['elementwise_mul', 'squared_l2_norm', 'clip_by_norm']
+        removed_op_idx = set()
+        removed_tmp_var = set()
+
         for idx, op in list(enumerate(main_block.ops)):
             if not _is_gradient_clip_op(op):
                 continue
-            if op.type == 'sum':
-                is_clip_grad_by_global_norm = True
-                break
-        if not is_clip_grad_by_global_norm:
-            return
-
-        removed_op_idx = set()
-        removed_tmp_var = set()
+            if op.type in removed_op_type:
+                input_name = op.input("X")[0]
+                param_name = input_name[:input_name.find("@GRAD")]
+                if not self._is_parameter_in_local_shard(param_name):
+                    removed_op_idx.add(idx)
+                    if op.type in ['squared_l2_norm', 'clip_by_norm']:
+                        for output_name in op.output_arg_names:
+                            removed_tmp_var.add(output_name)
+
+        for idx, op in reversed(list(enumerate(main_block.ops))):
+            if not _is_gradient_clip_op(op):
+                continue
+            if idx in removed_op_idx:
+                main_block._remove_op(idx, sync=False)
+
+        for varname in removed_tmp_var:
+            main_block._remove_var(varname, sync=False)
+
         for idx, op in list(enumerate(main_block.ops)):
             if not _is_gradient_clip_op(op):
                 continue
@@ -218,7 +232,7 @@ def _shard_gradient_clip(self, main_block):
                 sum_op_output = op.desc.output_arg_names()[0]
                 for i, sharding_info in enumerate(self.sharding_infos):
                     new_op = main_block._insert_op(
-                        idx + i,
+                        idx + i + 1,
                         type='c_allreduce_sum',
                         inputs={'X': [sum_op_output]},
                         outputs={'Out': [sum_op_output]},
@@ -235,21 +249,6 @@ def _shard_gradient_clip(self, main_block):
                         new_op, dist_attr.process_mesh, dist_attr.dims_mapping,
                         self._dist_context)
                 break
-            for input_name in op.input_arg_names:
-                param_name = input_name[:input_name.find("@GRAD")]
-                if not self._is_parameter_in_local_shard(param_name):
-                    removed_op_idx.add(idx)
-                    for output_name in op.output_arg_names:
-                        removed_tmp_var.add(output_name)
-
-        for idx, op in reversed(list(enumerate(main_block.ops))):
-            if not _is_gradient_clip_op(op):
-                continue
-            if idx in removed_op_idx:
-                main_block._remove_op(idx, sync=False)
-
-        for varname in removed_tmp_var:
-            main_block._remove_var(varname, sync=False)
 
         main_block._sync_with_cpp()
 
@@ -424,12 +423,15 @@ def _shard_parameter(self, main_block, startup_block):
                 startup_block._remove_op(idx, sync=False)
                 continue
 
-            if op.type != "c_broadcast" and output_name in not_used_param_nane:
+            if op.type != "c_broadcast" and output_name in param_usage and sharding_info.get_var_rank(
+                    output_name) != sharding_info.local_rank:
                 startup_block._remove_op(idx, sync=False)
 
-        for varname in not_used_param_nane:
-            main_block._remove_var(varname, sync=False)
-            startup_block._remove_var(varname, sync=False)
+        for param_name in param_usage:
+            if sharding_info.get_var_rank(
+                    param_name) != sharding_info.local_rank:
+                main_block._remove_var(param_name, sync=False)
+                startup_block._remove_var(param_name, sync=False)
 
         main_block._sync_with_cpp()
         startup_block._sync_with_cpp()
@@ -594,6 +596,10 @@ def _is_param_grad_allreduce_op(op, block, dp_ring_ids):
     return block.var(base_name).is_parameter
 
 
+def _is_forward_op(op):
+    return op.attr("op_role") == 0
+
+
 def _inference_data_parallel_group_for_operator(rank_id, op, dist_context):
 
     dp_group = None
diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py b/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py
index f5eda2fdbf8e2..42bdf67824220 100644
--- a/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py
+++ b/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py
@@ -178,13 +178,13 @@ def get_gpt_model(self, strategy, place, batch_size, sequence_len,
             preds = model(tokens, position_ids, attention_mask)
             criterion = GPTPretrainingCriterion()
             loss = criterion(preds, labels, loss_mask)
-
+        clip = paddle.nn.ClipGradByNorm(clip_norm=1.0)
         optimizer = paddle.fluid.optimizer.AdamOptimizer(
             learning_rate=0.00001,
             beta1=0.9,
             beta2=0.999,
             epsilon=1e-08,
-            grad_clip=None)
+            grad_clip=clip)
         optimizer = fleet.distributed_optimizer(optimizer)
         startup_program = paddle.static.default_startup_program()
         _, _, dist_startup_prog, dist_main_prog = optimizer.minimize(
diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_sharding_pass.py b/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_sharding_pass.py
index f6b42701c2195..51e87260609df 100644
--- a/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_sharding_pass.py
+++ b/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_sharding_pass.py
@@ -46,7 +46,7 @@ def apply_passes(self):
         dist_strategy.sharding = True
         dist_strategy.sharding_configs = {
             "sharding_degree": 2,
-            "stage": 3,
+            "stage": 2,
         }
 
         fleet.init(is_collective=True, strategy=dist_strategy)
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py
index ab91c3fe7c4c2..83254de61298b 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py
@@ -157,9 +157,6 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     complete_train_program = auto.complete_annotation(train_program,
                                                       dist_context)
 
-    parallelizer._apply_serial_forward_pass(complete_train_program,
-                                            startup_program)
-
     params_grads = parallelizer._generate_backward(
         complete_train_program,
         startup_program,
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py
index 9fe5a52cf08af..3a28595c833e0 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py
@@ -478,8 +478,7 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     # auto completion
     complete_train_program = auto.complete_annotation(train_program,
                                                       dist_context)
-    parallelizer._apply_serial_forward_pass(complete_train_program,
-                                            startup_program)
+
     params_grads = parallelizer._generate_backward(
         complete_train_program,
         startup_program,
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py
index 3270cfc3c8a54..dc2ad1d900f52 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py
@@ -884,10 +884,6 @@ def test_gpt_dp_mp(self):
         complete_train_program = auto.complete_annotation(train_program,
                                                           dist_context)
 
-        # serial forward pass
-        parallelizer._apply_serial_forward_pass(complete_train_program,
-                                                startup_program)
-
         # serial backward pass
         params_grads = parallelizer._generate_backward(
             complete_train_program,
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py
index 0631cc74a32bd..614b996d26521 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py
@@ -155,9 +155,6 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     complete_train_program = auto.complete_annotation(train_program,
                                                       dist_context)
 
-    parallelizer._apply_serial_forward_pass(complete_train_program,
-                                            startup_program)
-
     params_grads = parallelizer._generate_backward(
         complete_train_program,
         startup_program,
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py
index 0e098664f7ebb..cfbb7653fad8e 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py
@@ -119,9 +119,6 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     complete_train_program = auto.complete_annotation(train_program,
                                                       dist_context)
 
-    parallelizer._apply_serial_forward_pass(complete_train_program,
-                                            startup_program)
-
     params_grads = parallelizer._generate_backward(
         complete_train_program,
         startup_program,
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py
index c6b1be652073c..272c1c212f08e 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py
@@ -134,8 +134,6 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     # serial forward & backward completion
     complete_train_program = auto.complete_annotation(train_program,
                                                       dist_context)
-    parallelizer._apply_serial_forward_pass(complete_train_program,
-                                            startup_program)
 
     params_grads = parallelizer._generate_backward(
         complete_train_program,
diff --git a/python/paddle/fluid/tests/unittests/test_python_bf16_numpy_datatype.py b/python/paddle/fluid/tests/unittests/test_python_bf16_numpy_datatype.py
new file mode 100644
index 0000000000000..a58d7d35807c6
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_python_bf16_numpy_datatype.py
@@ -0,0 +1,34 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from paddle_bfloat import bfloat16
+import unittest
+
+
+class TestBF16DataType(unittest.TestCase):
+    def test_matmul(self):
+        a_bf16 = np.random.random((6, 7)).astype(bfloat16)
+        b_bf16 = np.random.random((7, 8)).astype(bfloat16)
+        c_bf16 = np.matmul(a_bf16, b_bf16)
+
+        a_fp32 = a_bf16.astype(np.float32)
+        b_fp32 = b_bf16.astype(np.float32)
+        c_fp32 = np.matmul(a_fp32, b_fp32)
+
+        self.assertTrue(np.allclose(c_bf16, c_fp32))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py
old mode 100644
new mode 100755
index b54c3596a26a9..a15c1af391f9f
--- a/python/paddle/tensor/manipulation.py
+++ b/python/paddle/tensor/manipulation.py
@@ -1838,7 +1838,7 @@ def expand_as(x, y, name=None):
             "you must set its stop_gradient to be False by "
             "some_var.stop_gradient = True, supporting "
             "some_var as the input 'x'.")
-    inputs = {"X": [x]}
+    inputs = {"X": [x], "Y": [y]}
     helper = LayerHelper('expand_as', **locals())
     dtype = helper.input_dtype(input_param_name='x')
diff --git a/python/requirements.txt b/python/requirements.txt
index f2a4580a94e51..5f2b788a81a0a 100644
--- a/python/requirements.txt
+++ b/python/requirements.txt
@@ -5,3 +5,4 @@ Pillow
 six
 decorator
 astor
+paddle_bfloat==0.1.2
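Note: the `paddle_bfloat==0.1.2` pin added to `python/requirements.txt` supplies the numpy-compatible `bfloat16` dtype that the new `test_python_bf16_numpy_datatype.py` unit test exercises. As a rough usage sketch (this only mirrors the test above and is not part of the patch itself), the dtype plugs straight into numpy casts and ufuncs:

```python
import numpy as np
from paddle_bfloat import bfloat16  # dtype from the paddle_bfloat==0.1.2 requirement above

# Round-trip float32 data through bfloat16. bfloat16 keeps only a 7-bit
# mantissa, so for values in [0, 1) the round trip differs from float32
# by at most roughly 2**-8; a loose tolerance absorbs that loss.
x_fp32 = np.random.random((4, 4)).astype(np.float32)
x_bf16 = x_fp32.astype(bfloat16)
assert np.allclose(x_bf16.astype(np.float32), x_fp32, atol=1e-2)
```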