Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update #28

Merged
merged 9 commits into from
Sep 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ option(WITH_UNITY_BUILD "Compile with UnityBuild mode" OFF)
option(WITH_STRIP "Strip so files of Whl packages" OFF)
option(NEW_RELEASE_CUBIN "PaddlePaddle next-level release strategy for pypi cubin package" OFF)
option(NEW_RELEASE_JIT "PaddlePaddle next-level release strategy for backup jit package" OFF)
option(WITH_ASCEND_INT64 "Compile with int64 kernel for ascend NPU" OFF)

# PY_VERSION
if(NOT PY_VERSION)
Expand Down
4 changes: 4 additions & 0 deletions cmake/configure.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ if(WITH_ASCEND_CL)
add_definitions(-DPADDLE_WITH_ASCEND_CL)
endif()

# Forward the WITH_ASCEND_INT64 build option to the compiled sources as the
# PADDLE_WITH_ASCEND_INT64 preprocessor definition.
if(WITH_ASCEND_INT64)
add_definitions(-DPADDLE_WITH_ASCEND_INT64)
endif()

if(WITH_XPU)
message(STATUS "Compile with XPU!")
add_definitions(-DPADDLE_WITH_XPU)
Expand Down
42 changes: 23 additions & 19 deletions paddle/fluid/framework/new_executor/interpretercore.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#if !defined(_WIN32)
#include <sched.h>
#else
#define NOMINMAX
#include <windows.h>
#endif // !_WIN32

#include "paddle/fluid/framework/new_executor/interpretercore.h"

#include <unordered_set>
Expand Down Expand Up @@ -255,10 +263,7 @@ void InterpreterCore::Convert() {
}

for (size_t i = 0; i < vec_instruction_.size(); ++i) {
// int device_type = static_cast<int>(paddle::platform::DeviceType::CUDA);
// paddle::platform::DeviceOption dev_opt(
// device_type, BOOST_GET_CONST(platform::CUDAPlace, place_).device);
gc_event_.emplace_back(place_);
gc_event_.emplace_back(place_, platform::GenerateDeviceEventFlag());

std::vector<size_t> vec_temp;
for (auto& item : vec_instruction_[i].output_index_) {
Expand Down Expand Up @@ -450,41 +455,40 @@ void InterpreterCore::CheckGC(size_t instr_id,

if (!garbages_->empty()) {
if (max_memory_size_ <= 1) {
#if defined(PADDLE_WITH_CUDA)
auto* dev_ctx = reinterpret_cast<platform::CUDADeviceContext*>(
gc_event_[instr_id].Record(
platform::DeviceContextPool::Instance().Get(place));
gc_event_[instr_id].Record(dev_ctx);
gc_event_[instr_id].SetFininshed(); // Only for CPU Event
gc_queue_->AddTask(
[ container = garbages_.release(), event = &gc_event_[instr_id] ]() {
while (!event->Query()) {
#if defined(_WIN32)
SleepEx(50, FALSE);
#else
sched_yield();
#endif
continue;
}
delete container;
});
garbages_.reset(new GarbageQueue());
#else
delete garbages_.release();
garbages_.reset(new GarbageQueue());
#endif
} else if (cur_memory_size_ >= max_memory_size_) {
#if defined(PADDLE_WITH_CUDA)
auto* dev_ctx = reinterpret_cast<platform::CUDADeviceContext*>(
gc_event_[instr_id].Record(
platform::DeviceContextPool::Instance().Get(place));
gc_event_[instr_id].Record(dev_ctx);
gc_event_[instr_id].SetFininshed(); // Only for CPU Event
gc_queue_->AddTask(
[ container = garbages_.release(), event = &gc_event_[instr_id] ]() {
while (!event->Query()) {
#if defined(_WIN32)
SleepEx(50, FALSE);
#else
sched_yield();
#endif
continue;
}
delete container;
});
garbages_.reset(new GarbageQueue());
cur_memory_size_ = 0;
#else
delete garbages_.release();
garbages_.reset(new GarbageQueue());
cur_memory_size_ = 0;
#endif
}
}
}
Expand Down
15 changes: 14 additions & 1 deletion paddle/fluid/inference/api/paddle_analysis_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,12 @@ struct PD_INFER_DECL AnalysisConfig {
///
bool tensorrt_engine_enabled() const { return use_tensorrt_; }
///
/// \brief Get the TensorRT engine precision.
///
/// \return Precision The precision mode currently stored in this config
/// (the value of tensorrt_precision_mode_).
///
Precision tensorrt_precision_mode() const { return tensorrt_precision_mode_; }
///
/// \brief Set min, max, opt shape for TensorRT Dynamic shape mode.
/// \param min_input_shape The min input shape of the subgraph input.
/// \param max_input_shape The max input shape of the subgraph input.
Expand All @@ -366,7 +372,14 @@ struct PD_INFER_DECL AnalysisConfig {
std::map<std::string, std::vector<int>> max_input_shape,
std::map<std::string, std::vector<int>> optim_input_shape,
bool disable_trt_plugin_fp16 = false);

///
/// \brief Report whether TensorRT dynamic-shape mode is in use.
///
/// \return bool True when min_input_shape_ is non-empty, false otherwise.
///
bool tensorrt_dynamic_shape_enabled() const {
  const bool no_min_shape = min_input_shape_.empty();
  return !no_min_shape;
}
///
/// \brief Prevent ops running in Paddle-TRT
/// NOTE: just experimental, not an official stable API, easy to be broken.
Expand Down
5 changes: 2 additions & 3 deletions paddle/fluid/operators/collective/c_allreduce_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ inline bool ContainsNan(const paddle::platform::NPUDeviceContext& dev_ctx,
try {
const auto& runner_mean = paddle::operators::NpuOpRunner(
"ReduceMeanD", {*in}, {mean}, {{"axes", axes}, {"keep_dims", false}});
runner_mean.Run(stream);
TensorToVector(mean, dev_ctx, &vec);
} catch (...) {
LOG(WARNING) << "ContainsNan catch exception";
Expand Down Expand Up @@ -240,8 +239,8 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
case framework::proto::VarType::FP32: {
if (FLAGS_hccl_check_nan) {
VLOG(3) << "prepare to FoundNanInf";
found_nan = ContainsNan(*dev_ctx, dev_ctx->stream(), in);
VLOG(3) << "check_numerics:" << found_nan;
// NOTE: performance related, DO NOT REMOVE!
ContainsNan(*dev_ctx, dev_ctx->stream(), in);
}
break;
}
Expand Down
33 changes: 19 additions & 14 deletions paddle/fluid/operators/collective/recv_v2_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,21 +34,26 @@ class RecvOpV2 : public framework::OperatorWithKernel {
ring_id, 0,
platform::errors::InvalidArgument(
"The ring_id (%d) for recv_v2 op must be non-negative.", ring_id));
auto out_shape = ctx->Attrs().Get<std::vector<int>>("out_shape");
PADDLE_ENFORCE_GE(out_shape.size(), 1,
platform::errors::InvalidArgument(
"The size of the output shape must be greater than 0 "
"but the value given is %d.",
out_shape.size()));
for (size_t i = 0; i < out_shape.size(); ++i) {
PADDLE_ENFORCE_GE(out_shape[i], 1,
platform::errors::InvalidArgument(
"The shape attribute for recv_v2 must be set "
"explicitly, but the %dth element is %d which "
"is less than 1.",
i, out_shape[i]));

if (ctx->GetOutputsVarType("Out").front() ==
framework::proto::VarType::LOD_TENSOR) {
auto out_shape = ctx->Attrs().Get<std::vector<int>>("out_shape");
PADDLE_ENFORCE_GE(
out_shape.size(), 1,
platform::errors::InvalidArgument(
"The size of the output shape must be greater than 0 "
"but the value given is %d.",
out_shape.size()));
for (size_t i = 0; i < out_shape.size(); ++i) {
PADDLE_ENFORCE_GE(out_shape[i], 1,
platform::errors::InvalidArgument(
"The shape attribute for recv_v2 must be set "
"explicitly, but the %dth element is %d which "
"is less than 1.",
i, out_shape[i]));
}
ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
}
ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
}

protected:
Expand Down
37 changes: 28 additions & 9 deletions paddle/fluid/operators/collective/recv_v2_op.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,6 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
platform::errors::InvalidArgument(
"The peer (%d) for recv_v2 op must be non-negative.", peer));

auto out = ctx.Output<framework::LoDTensor>("Out");
auto out_dims = out->dims();
auto numel = out->numel();
int data_type = ctx.Attr<int>("dtype");
framework::proto::VarType::Type type =
framework::proto::VarType::Type(data_type);

gpuStream_t stream = nullptr;
auto place = ctx.GetPlace();
auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
Expand All @@ -56,14 +49,40 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
} else {
stream = comm->stream();
}

PADDLE_ENFORCE_LT(
peer, comm->nranks(),
platform::errors::InvalidArgument("The value of peer (%d) you set must "
"be less than comm->nranks (%d).",
peer, comm->nranks()));
out->mutable_data<T>(out_dims, place);

int data_type = ctx.Attr<int>("dtype");
framework::proto::VarType::Type type =
framework::proto::VarType::Type(data_type);
ncclDataType_t dtype = platform::ToNCCLDataType(type);

auto *out_var = ctx.OutputVar("Out");
if (out_var->IsType<framework::LoDTensorArray>()) {
auto out_array = out_var->GetMutable<framework::LoDTensorArray>();
for (size_t idx = 0; idx < out_array->size(); ++idx) {
VLOG(3) << "LodTensorArray: idx(" << idx << ")";
auto out = &out_array->at(idx);
auto out_dims = out->dims();
out->mutable_data<T>(out_dims, place, 0);
auto numel = out->numel();
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclRecv(
out->data<T>(), numel, dtype, peer, comm->comm(), stream));
VLOG(3) << "rank " << comm->rank() << " recv "
<< framework::product(out_dims) << " from " << peer;
}
return;
}

auto out_shape = ctx.Attr<std::vector<int>>("out_shape");
auto out = ctx.Output<framework::LoDTensor>("Out");
auto out_dims = out->dims();
auto numel = out->numel();

out->mutable_data<T>(out_dims, place);
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclRecv(
out->data<T>(), numel, dtype, peer, comm->comm(), stream));
VLOG(3) << "rank " << comm->rank() << " recv "
Expand Down
10 changes: 10 additions & 0 deletions paddle/fluid/operators/collective/send_v2_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,16 @@ class SendOpV2 : public framework::OperatorWithKernel {
protected:
// Selects the kernel type for send_v2 from the input variable "X".
//
// If "X" holds a LoDTensorArray with no elements, the kernel type falls back
// to FP32 on the current device context. In every other case the data type
// is taken from "X" via IndicateVarDataType and paired with ctx.GetPlace().
framework::OpKernelType GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const override {
  const framework::Variable* var = ctx.InputVar("X");
  if (var->IsType<framework::LoDTensorArray>()) {
    // Bind by reference: only the element count is inspected, so the tensor
    // array must not be copied.
    const auto& t_arr = var->Get<framework::LoDTensorArray>();
    // NOTE(sandyhouse): Support an empty tensor array as Input.
    // In that case the kernel type is set to float (FP32).
    if (t_arr.size() == 0) {
      return framework::OpKernelType(framework::proto::VarType::FP32,
                                     ctx.device_context());
    }
  }
  return framework::OpKernelType(
      OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
Expand Down
22 changes: 19 additions & 3 deletions paddle/fluid/operators/collective/send_v2_op.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
#if (defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)) && \
NCCL_VERSION_CODE >= 2703
auto x = ctx.Input<framework::LoDTensor>("X");
int numel = x->numel();

int rid = ctx.Attr<int>("ring_id");
PADDLE_ENFORCE_GE(
rid, 0,
Expand All @@ -56,6 +53,25 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
platform::errors::InvalidArgument("The value of peer (%d) you set must "
"be less than comm->nranks (%d).",
peer, comm->nranks()));

auto* x_var = ctx.InputVar("X");
if (x_var->IsType<framework::LoDTensorArray>()) {
auto& x_array = x_var->Get<framework::LoDTensorArray>();
for (size_t idx = 0; idx < x_array.size(); idx++) {
VLOG(3) << "LodTensorArray: idx(" << idx << ")";
auto& x = x_array.at(idx);
int numel = x.numel();
ncclDataType_t dtype = platform::ToNCCLDataType(x.type());
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclSend(
x.data<T>(), numel, dtype, peer, comm->comm(), stream));
VLOG(3) << "rank " << comm->rank() << " send "
<< framework::product(x.dims()) << " to " << peer;
}
return;
}
auto x = ctx.Input<framework::LoDTensor>("X");
int numel = x->numel();

ncclDataType_t dtype = platform::ToNCCLDataType(x->type());
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclSend(
x->data<T>(), numel, dtype, peer, comm->comm(), stream));
Expand Down
39 changes: 19 additions & 20 deletions paddle/fluid/operators/fill_constant_op_npu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <memory>
#include <string>

#include "paddle/fluid/operators/fill_constant_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"
#include "paddle/fluid/operators/utils.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
template <typename T>
class FillConstantNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
Expand All @@ -32,7 +29,6 @@ class FillConstantNPUKernel : public framework::OpKernel<T> {
auto float_value = ctx.Attr<float>("value");

auto* out_var = ctx.Output<framework::Tensor>("Out");
auto place = ctx.GetPlace();
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
Expand Down Expand Up @@ -63,25 +59,28 @@ class FillConstantNPUKernel : public framework::OpKernel<T> {
}
auto shape = GetShape(ctx);

Tensor tensor_tmp(data_type);
tensor_tmp.mutable_data<T>({1}, ctx.GetPlace());
FillNpuTensorWithConstant<T>(&tensor_tmp, value);
Tensor tensor_value(data_type);
tensor_value.mutable_data<T>({1}, ctx.GetPlace());
FillNpuTensorWithConstant<T>(&tensor_value, value);

out_var->mutable_data<T>(shape, ctx.GetPlace());

out_var->mutable_data<T>(shape, place);
const auto& runner = NpuOpRunner("FillD", {tensor_tmp}, {*out_var},
{{"dims", framework::vectorize(shape)}});
runner.Run(stream);
NpuOpRunner runner;
runner.SetType("Fill")
.AddInput(framework::vectorize(shape))
.AddInput(tensor_value)
.AddOutput(*out_var)
.Run(stream);
}
};
} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_NPU_KERNEL(
fill_constant,
ops::FillConstantNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::FillConstantNPUKernel<paddle::platform::NPUDeviceContext, bool>,
ops::FillConstantNPUKernel<paddle::platform::NPUDeviceContext, int>,
ops::FillConstantNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
fill_constant, paddle::operators::FillConstantNPUKernel<float>,
paddle::operators::FillConstantNPUKernel<bool>,
paddle::operators::FillConstantNPUKernel<int>,
#ifdef PADDLE_WITH_ASCEND_INT64
paddle::operators::FillConstantNPUKernel<int64_t>,
#endif
paddle::operators::FillConstantNPUKernel<paddle::platform::float16>);
Loading