Commit f8c4a19

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into tensor_inherit_from_dense_tensor

jim19930609 committed Jan 7, 2022
2 parents: 2d91053 + 7f3b087
Showing 46 changed files with 631 additions and 520 deletions.
18 changes: 17 additions & 1 deletion paddle/fluid/framework/new_executor/interpretercore.cc
@@ -413,7 +413,23 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
   if (op_with_kernel == nullptr) {
     instr_node.OpBase()->Run(*local_scope, place_);
   } else {
-    instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get());
+    // fit for pten
+    if (instr_node.PtenKernel() && instr_node.PtenKernel()->IsValid()) {
+      VLOG(4) << "Run pten kernel: " << op->Type();
+      VLOG(4) << instr_node.InnerRuntimeContext().get() << " "
+              << &instr_node.DeviceContext();
+      op_with_kernel->BuildPtenKernelContext(
+          *instr_node.InnerRuntimeContext().get(),
+          const_cast<platform::DeviceContext*>(&instr_node.DeviceContext()));
+
+      (*instr_node.PtenKernel())(instr_node.PtenKernelContext());
+
+      op_with_kernel->WriteBackToOutputs(
+          instr_node.InnerRuntimeContext().get());
+      instr_node.PtenKernelContext()->ClearData();
+    } else {
+      instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get());
+    }
   }
 }

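The branch added above implements a "prefer the new kernel, fall back to the legacy kernel" dispatch: if a valid pten kernel was chosen for the instruction, its context is built and run, outputs are written back, and per-run data is cleared; otherwise the fluid kernel runs as before. A minimal standalone sketch of that pattern, using illustrative stand-in names rather than Paddle's actual types:

#include <functional>
#include <iostream>

// Stand-in for a pten kernel handle; IsValid() mirrors the check above.
struct NewKernel {
  std::function<void()> fn;
  bool IsValid() const { return static_cast<bool>(fn); }
};

// Prefer the new kernel when it exists and is valid, else run the legacy one.
void RunOneInstruction(const NewKernel* new_kernel,
                       const std::function<void()>& legacy_kernel) {
  if (new_kernel != nullptr && new_kernel->IsValid()) {
    new_kernel->fn();  // build context, run, write back outputs, clear data
  } else {
    legacy_kernel();  // unchanged legacy path
  }
}

int main() {
  NewKernel pten_like{[] { std::cout << "pten kernel\n"; }};
  RunOneInstruction(&pten_like, [] { std::cout << "fluid kernel\n"; });
  RunOneInstruction(nullptr, [] { std::cout << "fluid kernel\n"; });
  return 0;
}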
48 changes: 38 additions & 10 deletions paddle/fluid/framework/new_executor/interpretercore_util.cc
@@ -19,10 +19,13 @@
 #include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
 #include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
 #include "paddle/fluid/operators/controlflow/while_op_helper.h"
+#include "paddle/pten/core/kernel_factory.h"
 
 PADDLE_DEFINE_EXPORTED_bool(
     new_executor_sequential_run, false,
     "Enable sequential execution for standalone executor, used for debug");
+DECLARE_bool(run_pten_kernel);
+
 namespace paddle {
 namespace framework {
 namespace interpreter {
@@ -338,6 +341,8 @@ void build_op_func_list(const platform::Place& place,
       // op is not an OperatorWithKernel, so directly run OperatorBase::Run()
       deal_operator_base(place, var_scope, ops[i], &op_func_node, local_scope);
     } else {
+      auto op_with_kernel =
+          static_cast<const framework::OperatorWithKernel*>(op);
       // construct RuntimeContext and analysis KernelType
       RuntimeContext runtime_context({}, {});
       runtime_context.inputs.swap(ins_map);
@@ -350,8 +355,7 @@ void build_op_func_list(const platform::Place& place,
         // TODO(Aurelius84): In case of control flow ops, they are NOT
         // inherited
         // from OperatorWithKernel.
-        static_cast<const framework::OperatorWithKernel*>(op)->InferShape(
-            &infer_shape_ctx);
+        op_with_kernel->InferShape(&infer_shape_ctx);
       }
 
       auto kernels_iter = all_op_kernels.find(op->Type());
@@ -367,21 +371,25 @@ void build_op_func_list(const platform::Place& place,
           platform::DeviceContextPool::Instance();
       auto* dev_ctx = pool.Get(place);
       Scope scope;
-      auto expected_kernel_key =
-          dynamic_cast<const framework::OperatorWithKernel*>(op)
-              ->GetExpectedKernelType(
-                  ExecutionContext(*op, scope, *dev_ctx, runtime_context));
+      auto expected_kernel_key = op_with_kernel->GetExpectedKernelType(
+          ExecutionContext(*op, scope, *dev_ctx, runtime_context));
 
       // change device by the device_guard()
       apply_device_guard(op, place, &expected_kernel_key);
       VLOG(3) << "expected_kernel_key : " << expected_kernel_key;
 
       // step 3. apply data transforms and insert data transfer ops
       VariableValueMap& ins_map_temp = runtime_context.inputs;
+
+      // NOTE(zhiqiu): op_func_node->operator_base_ maybe changed in
+      // ApplyDataTransform
       ApplyDataTransform(expected_kernel_key, place, &ins_map_temp, var_scope,
                          &op_func_node, vec_func_list, use_local_scope);
+      op_with_kernel = static_cast<const framework::OperatorWithKernel*>(
+          op_func_node.operator_base_.get());
 
       // step 4. Run op kernel
-      VLOG(3) << op->Type()
+      VLOG(3) << op_with_kernel->Type()
               << " : expected_kernel_key : " << expected_kernel_key;
 
       if (platform::is_gpu_place(expected_kernel_key.place_)) {
@@ -397,7 +405,8 @@ void build_op_func_list(const platform::Place& place,
       }
       op_func_node.dev_ctx_ = dev_ctx;
 
-      auto exec_ctx = ExecutionContext(*op, scope, *dev_ctx, runtime_context);
+      auto exec_ctx =
+          ExecutionContext(*op_with_kernel, scope, *dev_ctx, runtime_context);
 
       auto kernel_iter = kernels.find(expected_kernel_key);
       PADDLE_ENFORCE_NE(
@@ -406,8 +415,27 @@ void build_op_func_list(const platform::Place& place,
           "Operator (%s) does not have kernel for %s.", op->Type(),
           KernelTypeToString(expected_kernel_key)));
 
-      op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second);
-      op_func_node.kernel_func_(exec_ctx);
+      auto run_pten_kernel = false;
+
+      if (FLAGS_run_pten_kernel &&
+          pten::KernelFactory::Instance().HasCompatiblePtenKernel(
+              op_with_kernel->Type())) {
+        op_with_kernel->ChoosePtenKernel(exec_ctx);
+        run_pten_kernel = op_with_kernel->PtenKernel()->IsValid();
+      }
+
+      if (run_pten_kernel) {
+        op_with_kernel->BuildPtenKernelContext(runtime_context, dev_ctx);
+        op_func_node.pt_kernel_ = op_with_kernel->PtenKernel();
+        op_func_node.pt_kernel_context_ = op_with_kernel->PtenKernelContext();
+
+        (*op_func_node.pt_kernel_)(op_func_node.pt_kernel_context_);
+        op_with_kernel->WriteBackToOutputs(&runtime_context);
+        op_func_node.pt_kernel_context_->ClearData();
+      } else {
+        op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second);
+        op_func_node.kernel_func_(exec_ctx);
+      }
 
       // post-process grad_op.outputs if need cast complex grad into real grad.
       // NOTE(Aurelius84): insert a transfer_dtype_op in place to cast it.
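At build time the executor now decides once per op whether the pten path applies: the global flag must be on, the kernel factory must report a compatible kernel for the op type, and the concretely chosen kernel must be valid. A compilable sketch of that three-part gate, with stand-in names for FLAGS_run_pten_kernel and the pten kernel factory:

#include <string>
#include <unordered_set>

static bool FLAGS_run_new_kernel = true;  // stand-in for the real flag

struct KernelFactory {
  std::unordered_set<std::string> compatible_ops;
  bool HasCompatibleKernel(const std::string& op_type) const {
    return compatible_ops.count(op_type) > 0;
  }
};

// Mirrors the gate above: flag, registry compatibility, then kernel validity.
bool ShouldRunNewKernel(const KernelFactory& factory, const std::string& op_type,
                        bool chosen_kernel_is_valid) {
  return FLAGS_run_new_kernel && factory.HasCompatibleKernel(op_type) &&
         chosen_kernel_is_valid;
}

int main() {
  KernelFactory factory{{"cast", "flatten_contiguous_range"}};
  return ShouldRunNewKernel(factory, "cast", true) ? 0 : 1;
}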
8 changes: 8 additions & 0 deletions paddle/fluid/framework/new_executor/new_executor_defs.cc
@@ -673,6 +673,14 @@ OpKernelComputeFunc Instruction::KernelFunc() const {
   return op_func_node_.kernel_func_;
 }
 
+pten::Kernel* Instruction::PtenKernel() const {
+  return op_func_node_.pt_kernel_;
+}
+
+pten::KernelContext* Instruction::PtenKernelContext() const {
+  return op_func_node_.pt_kernel_context_;
+}
+
 OpFuncType Instruction::KernelType() const { return op_func_node_.type_; }
 
 OperatorBase* Instruction::OpBase() const {
9 changes: 9 additions & 0 deletions paddle/fluid/framework/new_executor/new_executor_defs.h
@@ -295,6 +295,11 @@ struct OpFuncNode {
 
   OpKernelComputeFunc kernel_func_;
   platform::DeviceContext* dev_ctx_;  // not owned
+
+  // fit for pten kernel
+  pten::Kernel* pt_kernel_{nullptr};  // not owned
+  pten::KernelContext* pt_kernel_context_{nullptr};  // not owned
+
   OpFuncType type_;
 };
 
@@ -313,6 +318,10 @@ class Instruction {
 
   OpKernelComputeFunc KernelFunc() const;
 
+  pten::Kernel* PtenKernel() const;
+
+  pten::KernelContext* PtenKernelContext() const;
+
   OpFuncType KernelType() const;
 
   OperatorBase* OpBase() const;
3 changes: 3 additions & 0 deletions paddle/fluid/framework/operator.cc
@@ -1791,6 +1791,9 @@ KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs(
 
 void OperatorWithKernel::BuildPtenKernelContext(
     const RuntimeContext& ctx, platform::DeviceContext* dev_ctx) const {
+  if (pt_kernel_context_ == nullptr) {
+    pt_kernel_context_.reset(new pten::KernelContext());
+  }
   // TODO(chenweihang): now only works for very simple cases;
   // many cases need to be dealt with later:
   // 1. the input and output are not tensor
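The added guard lazily constructs the kernel context on first use and reuses it across runs; the ClearData() calls seen in the executor changes drop per-run tensor data without destroying the container. A small self-contained sketch of this lazy-initialization pattern, with illustrative types rather than Paddle's:

#include <memory>
#include <vector>

struct KernelContext {
  std::vector<const void*> inputs;      // borrowed per-run pointers
  void ClearData() { inputs.clear(); }  // drop per-run data, keep the object
};

class Op {
 public:
  // A const method can still initialize the cache: the member is mutable.
  KernelContext* Context() const {
    if (ctx_ == nullptr) {
      ctx_.reset(new KernelContext());  // constructed once, on first use
    }
    return ctx_.get();
  }

 private:
  mutable std::unique_ptr<KernelContext> ctx_;
};

int main() {
  Op op;
  KernelContext* first = op.Context();
  first->ClearData();
  return op.Context() == first ? 0 : 1;  // same context reused across runs
}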
22 changes: 14 additions & 8 deletions paddle/fluid/framework/operator.h
@@ -555,6 +555,20 @@ class OperatorWithKernel : public OperatorBase {
   virtual KernelSignature GetExpectedPtenKernelArgs(
       const ExecutionContext& ctx) const;
 
+  /* member functions for adapting to pten lib */
+  void ChoosePtenKernel(const ExecutionContext& ctx) const;
+
+  void BuildPtenKernelContext(const RuntimeContext& ctx,
+                              platform::DeviceContext* dev_ctx) const;
+
+  void WriteBackToOutputs(RuntimeContext* ctx) const;
+
+  pten::Kernel* PtenKernel() const { return pt_kernel_.get(); }
+
+  pten::KernelContext* PtenKernelContext() const {
+    return pt_kernel_context_.get();
+  }
+
  private:
   void RunImpl(const Scope& scope, const platform::Place& place) const final;
   void RunImpl(const Scope& scope, const platform::Place& place,
@@ -595,14 +609,6 @@ class OperatorWithKernel : public OperatorBase {
   Tensor* GetTensorFormInputSafely(const ExecutionContext& ctx,
                                    const std::string& name) const;
 
-  /* member functions for adapting to pten lib */
-  void ChoosePtenKernel(const ExecutionContext& ctx) const;
-
-  void BuildPtenKernelContext(const RuntimeContext& ctx,
-                              platform::DeviceContext* dev_ctx) const;
-
-  void WriteBackToOutputs(RuntimeContext* ctx) const;
-
  protected:
   mutable std::unique_ptr<OpKernelType> kernel_type_;
   mutable std::unique_ptr<OpKernelFunc> kernel_func_;
2 changes: 1 addition & 1 deletion paddle/fluid/operators/cast_op.h
@@ -20,7 +20,7 @@ limitations under the License. */
 
 #include "paddle/pten/api/lib/utils/tensor_utils.h"
 #include "paddle/pten/include/core.h"
-#include "paddle/pten/include/manipulation.h"
+#include "paddle/pten/kernels/cast_kernel.h"
 
 namespace paddle {
 namespace operators {
11 changes: 11 additions & 0 deletions paddle/fluid/operators/expand_as_v2_op.cc
mode change 100644 → 100755
@@ -12,6 +12,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/expand_as_v2_op.h"
 #include <memory>
 #include <vector>
+#include "paddle/fluid/framework/op_version_registry.h"
 
 namespace paddle {
 namespace operators {
@@ -50,6 +51,10 @@ class ExpandAsV2OpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("X",
              "(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
              "X is the input to be expanded.");
+    AddInput("Y",
+             "(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
+             "Expand X according to the shape of Y.")
+        .AsDispensable();
     AddOutput("Out",
               "(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
               "The rank of Output(Out) have the same with Input(X). "
@@ -144,3 +149,9 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ExpandAsV2GradKernel<paddle::platform::CUDADeviceContext, float>,
     ops::ExpandAsV2GradKernel<paddle::platform::CUDADeviceContext, double>);
 #endif
+
+REGISTER_OP_VERSION(expand_as_v2)
+    .AddCheckpoint(
+        R"ROC(fix expand_as_v2 and add new input [Y])ROC",
+        paddle::framework::compatible::OpVersionDesc().NewInput(
+            "Y", "Expand X according to the shape of Y"));
33 changes: 25 additions & 8 deletions paddle/fluid/operators/expand_as_v2_op.h
mode change 100644 → 100755
@@ -91,17 +91,34 @@ class ExpandAsV2Kernel : public framework::OpKernel<T> {
       PADDLE_ENFORCE_NE(target_shape[i], 0,
                         platform::errors::InvalidArgument(
                             "The value of target shape cannot be zero."));
-      if (vec_in_dims[i] != 1) {
+      if (i < diff) {
+        PADDLE_ENFORCE_GT(
+            target_shape[i], 0,
+            platform::errors::InvalidArgument(
+                "The expanded size (%d) for non-existing dimensions must be "
+                "positive for expand_as_v2 op.",
+                target_shape[i]));
+        repeat_times[i] = target_shape[i];
+      } else if (target_shape[i] > 0) {
+        if (vec_in_dims[i] != 1) {
+          PADDLE_ENFORCE_EQ(
+              vec_in_dims[i], target_shape[i],
+              platform::errors::InvalidArgument(
+                  "The value (%d) of the non-singleton dimension does not match"
+                  " the corresponding value (%d) in shape for expand_as_v2 op.",
+                  vec_in_dims[i], target_shape[i]));
+          repeat_times[i] = 1;
+        } else {
+          repeat_times[i] = target_shape[i];
+        }
+      } else {
         PADDLE_ENFORCE_EQ(
-            vec_in_dims[i], target_shape[i],
+            target_shape[i], -1,
             platform::errors::InvalidArgument(
-                "The value (%d) of the non-singleton dimension does not match"
-                " the corresponding value (%d) in "
-                "target tensor for expand_as_v2 op.",
-                vec_in_dims[i], target_shape[i]));
+                "When the value in shape is negative for expand_as_v2 op, "
+                "only -1 is supported, but the value received is %d.",
+                target_shape[i]));
         repeat_times[i] = 1;
-      } else {
-        repeat_times[i] = target_shape[i];
       }
     }
     auto* out0 = context.Output<Tensor>("Out");
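The rewritten loop distinguishes three cases: dimensions that exist only in the target shape (i < diff) must be positive and become the repeat count; positive target entries must match any non-singleton input extent; and negative entries must be exactly -1, meaning "keep the input extent". For example, input dims [10, 1] against target [2, 10, 4] right-align with diff = 1 and yield repeat_times [2, 1, 4]. A standalone sketch of the computation, where assertions stand in for PADDLE_ENFORCE (illustrative, not the Paddle kernel itself):

#include <cassert>
#include <vector>

std::vector<int> ComputeRepeatTimes(std::vector<int> in_dims,
                                    const std::vector<int>& target_shape) {
  size_t diff = target_shape.size() - in_dims.size();
  in_dims.insert(in_dims.begin(), diff, 1);  // right-align input dims with 1s
  std::vector<int> repeat_times(target_shape.size(), 1);
  for (size_t i = 0; i < target_shape.size(); ++i) {
    assert(target_shape[i] != 0);
    if (i < diff) {
      assert(target_shape[i] > 0);  // new leading dims must be positive
      repeat_times[i] = target_shape[i];
    } else if (target_shape[i] > 0) {
      if (in_dims[i] != 1) {
        assert(in_dims[i] == target_shape[i]);  // non-singleton must match
        repeat_times[i] = 1;
      } else {
        repeat_times[i] = target_shape[i];
      }
    } else {
      assert(target_shape[i] == -1);  // -1 means "keep this extent"
      repeat_times[i] = 1;
    }
  }
  return repeat_times;
}

int main() {
  std::vector<int> result = ComputeRepeatTimes({10, 1}, {2, 10, 4});
  std::vector<int> expected{2, 1, 4};
  assert(result == expected);
  return 0;
}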
6 changes: 5 additions & 1 deletion paddle/fluid/operators/expand_v2_op.cc
mode change 100644 → 100755
@@ -65,7 +65,11 @@ class ExpandV2Op : public framework::OperatorWithKernel {
       if (x_dims[i] == -1) {
         out_shape[i] = -1;
       } else if (expand_shape[i] == -1) {
-        out_shape[i] = x_dims[i];
+        if (static_cast<size_t>(x_dims.size()) > i) {
+          out_shape[i] = x_dims[i];
+        } else {
+          out_shape[i] = -1;
+        }
       } else if (expand_shape[i] == -2) {
         // We use -2 to represent that the element in expand_shape is a var.
         out_shape[i] = -1;
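The fix above guards the x_dims[i] read: when expand_shape has more entries than x_dims, a -1 entry can no longer be resolved from the input, so the output extent stays unknown (-1) instead of reading past the end of x_dims. A simplified compilable sketch, mirroring the same-index lookup in the snippet above (it ignores the -2 "shape is a variable" case and dynamic input dims):

#include <vector>

std::vector<int> InferOutShape(const std::vector<int>& x_dims,
                               const std::vector<int>& expand_shape) {
  std::vector<int> out_shape(expand_shape.size(), -1);
  for (size_t i = 0; i < expand_shape.size(); ++i) {
    if (expand_shape[i] == -1) {
      // The fix: check the bound before reading x_dims[i].
      out_shape[i] = (i < x_dims.size()) ? x_dims[i] : -1;
    } else {
      out_shape[i] = expand_shape[i];
    }
  }
  return out_shape;
}

int main() {
  // x has 2 dims but expand_shape has 3; index 2 must not read past x_dims.
  std::vector<int> out = InferOutShape({3, 4}, {2, -1, -1});
  return (out[0] == 2 && out[1] == 4 && out[2] == -1) ? 0 : 1;
}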
6 changes: 3 additions & 3 deletions paddle/fluid/operators/flatten_op.h
@@ -21,7 +21,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/pooling.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/pten/include/core.h"
-#include "paddle/pten/include/manipulation.h"
+#include "paddle/pten/kernels/flatten_kernel.h"
 
 namespace paddle {
 namespace operators {
@@ -134,8 +134,8 @@ class FlattenContiguousRangeKernel : public framework::OpKernel<T> {
     auto pt_out = paddle::experimental::MakePtenDenseTensor(*out);
 
     // call new kernel
-    pten::Flatten<T, DeviceContext>(dev_ctx, *pt_x.get(), start_axis, stop_axis,
-                                    pt_out.get());
+    pten::FlattenKernel<T, DeviceContext>(dev_ctx, *pt_x.get(), start_axis,
+                                          stop_axis, pt_out.get());
   }
 };
 
37 changes: 21 additions & 16 deletions paddle/fluid/operators/fused/fused_dropout_helper.h
@@ -250,11 +250,14 @@ class FusedDropoutLayerNormHelper : public FusedDropoutHelper<T, MaskType> {
   }
 
   // out = layernorm(residual + dropout(src + bias))
-  void LayernormResidualDropoutBias(
-      const platform::CUDADeviceContext& ctx, const T* src, const T* residual,
-      const T* bias, const LayerNormParamType<T>* gamma,
-      const LayerNormParamType<T>* beta, T* dropout_out, MaskType* mask, T* out,
-      LayerNormParamType<T>* mean, LayerNormParamType<T>* variance) {
+  template <typename P = LayerNormParamType<T>, bool is_same_type = false>
+  void LayernormResidualDropoutBias(const platform::CUDADeviceContext& ctx,
+                                    const T* src, const T* residual,
+                                    const T* bias, const P* gamma,
+                                    const P* beta, T* dropout_out,
+                                    MaskType* mask, T* out,
+                                    LayerNormParamType<T>* mean,
+                                    LayerNormParamType<T>* variance) {
     using U = LayerNormParamType<T>;
     int vec_size = MAX_CACHE_BYTES / sizeof(T);
     if (this->cols_ % vec_size != 0) {
@@ -263,25 +266,27 @@ class FusedDropoutLayerNormHelper : public FusedDropoutHelper<T, MaskType> {
     int threads = GetDesiredBlockDim(this->cols_ / vec_size);
     int increment = ((this->cols_ - 1) / (threads * vec_size) + 1) * vec_size;
     increment = this->dropout_param_.UpdateSeedAndIncrement(ctx, increment);
-    LaunchLayernormResidualDropoutBias<T, MaskType>(
+    LaunchLayernormResidualDropoutBias<T, MaskType, U, is_same_type>(
         this->rows_, this->cols_, increment, this->dropout_param_.seed,
         this->dropout_param_.dropout_prob, epsilon_,
         this->dropout_param_.is_upscale_in_train, this->dropout_param_.is_test,
         src, residual, bias, gamma, beta, mask, dropout_out, out, mean,
         variance, ctx);
   }
 
-  void LayernormResidualDropoutBiasGrad(
-      const platform::CUDADeviceContext& ctx, const T* d_out,
-      const T* layernorm_src, const MaskType* mask,
-      const LayerNormParamType<T>* gamma, const LayerNormParamType<T>* mean,
-      const LayerNormParamType<T>* variance, T* d_layernorm_src,
-      LayerNormParamType<T>* d_scale, LayerNormParamType<T>* d_layernorm_bias,
-      T* d_dropout_src, T* d_bias, T* d_residual) {
+  template <typename P = LayerNormParamType<T>, bool is_same_type = false>
+  void LayernormResidualDropoutBiasGrad(const platform::CUDADeviceContext& ctx,
+                                        const T* d_out, const T* layernorm_src,
+                                        const MaskType* mask, const P* gamma,
+                                        const LayerNormParamType<T>* mean,
+                                        const LayerNormParamType<T>* variance,
+                                        T* d_layernorm_src, P* d_scale,
+                                        P* d_layernorm_bias, T* d_dropout_src,
+                                        T* d_bias, T* d_residual) {
     using U = LayerNormParamType<T>;
-    LayerNormBackward<T, U>(layernorm_src, d_out, gamma, mean, variance,
-                            d_layernorm_src, d_scale, d_layernorm_bias,
-                            epsilon_, this->rows_, this->cols_, ctx);
+    LayerNormBackward<T, U, is_same_type>(
+        layernorm_src, d_out, gamma, mean, variance, d_layernorm_src, d_scale,
+        d_layernorm_bias, epsilon_, this->rows_, this->cols_, ctx);
     this->ResidualDropoutBiasGrad(ctx, d_layernorm_src, mask, d_dropout_src,
                                   d_residual, d_bias);
   }
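These signatures are now templated on a parameter type P, decoupling the precision of gamma/beta from the data type T (for example fp16 activations with fp32 layer-norm parameters), and is_same_type tells the launched kernels whether P == T. A minimal compilable sketch of the idea, with stand-in names rather than Paddle's helpers:

#include <cstdio>
#include <type_traits>

// Stand-in for LayerNormParamType<T>: half-like types keep float parameters.
template <typename T>
using ParamType = typename std::conditional<(sizeof(T) < 4), float, T>::type;

// P decouples scale/bias precision from data precision; is_same_type lets a
// kernel pick a fused path when P == T.
template <typename T, typename P = ParamType<T>,
          bool is_same_type = std::is_same<T, P>::value>
void ScaleShift(const T* src, const P* gamma, const P* beta, T* out, int n) {
  for (int i = 0; i < n; ++i) {
    // Accumulate in P, then cast back to the data type T.
    out[i] = static_cast<T>(static_cast<P>(src[i]) * gamma[i] + beta[i]);
  }
}

int main() {
  const float x[2] = {1.f, 2.f}, g[2] = {2.f, 2.f}, b[2] = {0.5f, 0.5f};
  float y[2];
  ScaleShift(x, g, b, y, 2);  // T == P == float, so is_same_type == true
  std::printf("%g %g\n", y[0], y[1]);
  return 0;
}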
(Diff truncated; the remaining 34 changed files are not shown.)
