[AMP] Add check_numerics API. #54301

Merged · 14 commits · Jun 8, 2023
2 changes: 0 additions & 2 deletions paddle/fluid/framework/details/nan_inf_utils_detail.cc
@@ -23,8 +23,6 @@
#include "paddle/phi/core/flags.h"
#include "paddle/phi/kernels/funcs/eigen/extensions.h"

PHI_DECLARE_int32(check_nan_inf_level);

namespace paddle {
namespace framework {
namespace details {
16 changes: 14 additions & 2 deletions paddle/fluid/framework/details/nan_inf_utils_detail.h
@@ -19,9 +19,12 @@
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/kernels/check_numerics_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/extensions.h"

PHI_DECLARE_int32(check_nan_inf_level);

namespace paddle {
namespace framework {
namespace details {
@@ -58,9 +61,18 @@ struct TensorCheckerVisitor {
auto* dev_ctx = reinterpret_cast<Context*>(
platform::DeviceContextPool::Instance().Get(tensor.place()));

phi::DenseTensor stats;
phi::DenseTensor values;
auto file_path = GetNanPath();
phi::CheckNumericsKernel<T, Context>(
*dev_ctx, tensor, op_type, var_name, GetNanInfStackLimit(), file_path);
phi::CheckNumericsKernel<T, Context>(*dev_ctx,
tensor,
op_type,
var_name,
FLAGS_check_nan_inf_level,
GetNanInfStackLimit(),
file_path,
&stats,
&values);
}

std::string op_type;
6 changes: 0 additions & 6 deletions paddle/fluid/pybind/pybind.cc
@@ -2699,12 +2699,6 @@ All parameter, weight, gradient are variables in Paddle.
m.def("set_skipped_op_list",
[](const std::string &op_list) { egr::SetSkipOpList(op_list); });

m.def("check_numerics",
[](const std::string &op_name, const paddle::Tensor &tensor) {
VLOG(4) << "Check tensor whether has nan or inf.";
egr::CheckTensorHasNanOrInf(op_name, tensor);
});

BindFleetWrapper(&m);
BindIO(&m);
BindParallelExecutor(m);
8 changes: 8 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
@@ -390,6 +390,14 @@
data_type : x
inplace : (x -> out)

- op : check_numerics
args : (Tensor tensor, str op_type = "", str var_name = "", int check_nan_inf_level = 0, int stack_height_limit = -1, str output_dir = "")
output : Tensor(stats), Tensor(values)
infer_meta :
func : CheckNumericsInferMeta
kernel :
func : check_numerics

- op : cholesky
args : (Tensor x, bool upper=false)
output : Tensor
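Note: the YAML entry above only declares the op; the C++/Python entry points come from phi's API code generation, which is not part of this diff. Assuming the usual codegen conventions (the header path, namespace, and return type below are assumptions), the generated C++ API would look roughly like this, returning the stats and values tensors described by CheckNumericsInferMeta further down:

// Sketch of the assumed auto-generated C++ API for the op declared above.
// Header path, namespace, and exact signature are assumptions, not part of this PR.
#include <tuple>
#include "paddle/phi/api/include/api.h"

void CheckNumericsExample(const paddle::Tensor& x) {
  auto outs = paddle::experimental::check_numerics(x,
                                                   /*op_type=*/"matmul",  // illustrative
                                                   /*var_name=*/"x",      // illustrative
                                                   /*check_nan_inf_level=*/0,
                                                   /*stack_height_limit=*/-1,
                                                   /*output_dir=*/"");
  paddle::Tensor stats = std::get<0>(outs);   // int64, shape [3]
  paddle::Tensor values = std::get<1>(outs);  // float32, shape [3]
}
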
14 changes: 14 additions & 0 deletions paddle/phi/infermeta/unary.cc
@@ -4959,6 +4959,20 @@ void ChannelShuffleInferMeta(const MetaTensor& x,
out->set_dims(output_dims);
}

void CheckNumericsInferMeta(const MetaTensor& tensor,
const std::string& op_type,
const std::string& var_name,
const int check_nan_inf_level,
const int stack_height_limit,
const std::string& output_dir,
MetaTensor* stats,
MetaTensor* values) {
stats->set_dtype(DataType::INT64);
stats->set_dims(phi::make_ddim({3}));
values->set_dtype(DataType::FLOAT32);
values->set_dims(phi::make_ddim({3}));
}

} // namespace phi

PD_REGISTER_INFER_META_FN(flatten, phi::FlattenInferMeta);
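The infer meta above pins both outputs to three-element tensors. Combined with the comments in the CPU kernel further down (stats holds num_nan, num_inf and num_zero; values holds max, min and mean), the outputs can be read back as in this sketch (how stats and values were produced is omitted):

// Reading back the two outputs filled by the check_numerics kernel.
// Slot meanings follow the comments in paddle/phi/kernels/cpu/check_numerics_kernel.cc.
#include "paddle/phi/core/dense_tensor.h"

void ReadCheckNumericsOutputs(const phi::DenseTensor& stats,
                              const phi::DenseTensor& values) {
  const int64_t* s = stats.data<int64_t>();
  const int64_t num_nan = s[0];   // number of NaN elements
  const int64_t num_inf = s[1];   // number of Inf elements
  const int64_t num_zero = s[2];  // number of zero elements

  const float* v = values.data<float>();
  const float max_value = v[0];
  const float min_value = v[1];
  const float mean_value = v[2];
  // ... use the statistics, e.g. log them or assert on num_nan/num_inf ...
}
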
9 changes: 9 additions & 0 deletions paddle/phi/infermeta/unary.h
@@ -74,6 +74,15 @@ void ChannelShuffleInferMeta(const MetaTensor& x,
const std::string& data_format,
MetaTensor* out);

void CheckNumericsInferMeta(const MetaTensor& tensor,
const std::string& op_type,
const std::string& var_name,
const int check_nan_inf_level,
const int stack_height_limit,
const std::string& output_dir,
MetaTensor* stats,
MetaTensor* values);

void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out);

void ClassCenterSampleInferMeta(const MetaTensor& label,
5 changes: 4 additions & 1 deletion paddle/phi/kernels/check_numerics_kernel.h
@@ -23,7 +23,10 @@ void CheckNumericsKernel(const Context& ctx,
const DenseTensor& tensor,
const std::string& op_type,
const std::string& var_name,
const int check_nan_inf_level,
const int stack_height_limit,
const std::string& output_dir);
const std::string& output_dir,
DenseTensor* stats,
DenseTensor* values);

} // namespace phi
22 changes: 16 additions & 6 deletions paddle/phi/kernels/cpu/check_numerics_kernel.cc
@@ -15,29 +15,39 @@ limitations under the License. */
#include "paddle/phi/kernels/check_numerics_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/check_numerics_utils.h"

PHI_DECLARE_int32(check_nan_inf_level);

namespace phi {

template <typename T, typename Context>
void CheckNumericsKernel(const Context& ctx,
const DenseTensor& tensor,
const std::string& op_type,
const std::string& var_name,
const int check_nan_inf_level,
const int stack_height_limit,
const std::string& output_dir) {
const std::string& output_dir,
DenseTensor* stats,
DenseTensor* values) {
// stats stores the checking result of num_nan, num_inf and num_zero.
stats->Resize({static_cast<int64_t>(3)});
int64_t* stats_ptr = ctx.template Alloc<int64_t>(stats);

// values stores the max_value, min_value and mean_value.
values->Resize({static_cast<int64_t>(3)});
float* values_ptr = ctx.template Alloc<float>(values);

std::string cpu_hint_str =
phi::funcs::GetCpuHintString<T>(op_type, var_name, tensor.place());
phi::funcs::CheckNumericsCpuImpl(tensor.data<T>(),
tensor.numel(),
cpu_hint_str,
FLAGS_check_nan_inf_level,
check_nan_inf_level,
"cpu",
output_dir);
output_dir,
stats_ptr,
values_ptr);
}

} // namespace phi
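For reference, a minimal sketch of calling the CPU kernel directly with the new signature; the context and input tensor are assumed to already exist, and the string/flag arguments are placeholders:

// Direct invocation of the new CPU kernel signature (illustrative values).
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/check_numerics_kernel.h"

void CheckOnCpu(const phi::CPUContext& dev_ctx, const phi::DenseTensor& x) {
  phi::DenseTensor stats;   // filled with {num_nan, num_inf, num_zero}
  phi::DenseTensor values;  // filled with {max, min, mean}
  phi::CheckNumericsKernel<float, phi::CPUContext>(dev_ctx,
                                                   x,
                                                   /*op_type=*/"matmul",
                                                   /*var_name=*/"x",
                                                   /*check_nan_inf_level=*/0,
                                                   /*stack_height_limit=*/-1,
                                                   /*output_dir=*/"",
                                                   &stats,
                                                   &values);
}
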
42 changes: 38 additions & 4 deletions paddle/phi/kernels/funcs/check_numerics_utils.h
@@ -61,6 +61,27 @@ HOSTDEVICE bool NeedPrint(MT max_value UNUSED,
return false;
}

template <typename T>
HOSTDEVICE static void SaveStatsAndValues(int64_t num_nan,
int64_t num_inf,
int64_t num_zero,
T max_value,
T min_value,
T mean_value,
int64_t* stats_ptr,
float* values_ptr) {
if (stats_ptr) {
stats_ptr[0] = num_nan;
stats_ptr[1] = num_inf;
stats_ptr[2] = num_zero;
}
if (values_ptr) {
values_ptr[0] = static_cast<float>(max_value);
values_ptr[1] = static_cast<float>(min_value);
values_ptr[2] = static_cast<float>(mean_value);
}
}

HOSTDEVICE static void PrintAndThrowError(const char* debug_info,
int64_t num_nan,
int64_t num_inf,
@@ -197,8 +218,10 @@ static void CheckNumericsCpuImpl(const T* value_ptr,
const int64_t numel,
const std::string& cpu_hint_str,
const int check_nan_inf_level,
const std::string log_name = "cpu",
const std::string output_dir = "") {
const std::string log_name,
const std::string output_dir,
int64_t* stats_ptr,
float* values_ptr) {
using MT = typename phi::dtype::template MPTypeTrait<T>::Type;

#ifdef _OPENMP
@@ -263,6 +286,15 @@
mean_value += thread_mean_value[i];
}

SaveStatsAndValues<MT>(num_nan,
num_inf,
num_zero,
max_value,
min_value,
mean_value,
stats_ptr,
values_ptr);

// Write log to file
if (output_dir.size() > 0) {
WriteToFileForDifferentLevel<T, MT>(cpu_hint_str.c_str(),
@@ -298,8 +330,10 @@ void CheckNumericsCpuImpl(const T* value_ptr,
const int64_t numel,
const std::string& cpu_hint_str,
const int check_nan_inf_level,
const std::string log_name = "cpu",
const std::string output_dir = "") {
const std::string log_name,
const std::string output_dir,
int64_t* stats_ptr,
float* values_ptr) {
using RealType = typename T::value_type;

RealType real_sum = 0.0f, imag_sum = 0.0f;
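Since the defaults on log_name and output_dir are gone and the two output pointers are new, callers of the CPU helper now spell out every argument. A sketch of a direct call, assuming the non-complex overload is picked by template deduction (buffer contents and string values are illustrative):

// Direct use of the CPU helper with the new argument list (illustrative).
#include "paddle/phi/kernels/funcs/check_numerics_utils.h"

void CheckBuffer(const float* data, int64_t numel) {
  int64_t stats[3] = {0, 0, 0};       // num_nan, num_inf, num_zero
  float values[3] = {0.f, 0.f, 0.f};  // max, min, mean
  phi::funcs::CheckNumericsCpuImpl(data,
                                   numel,
                                   /*cpu_hint_str=*/"example",
                                   /*check_nan_inf_level=*/0,
                                   /*log_name=*/"cpu",
                                   /*output_dir=*/"",
                                   stats,
                                   values);
}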