From a3ed849a3cc9248a50aaa521eb62c1adf91684fb Mon Sep 17 00:00:00 2001 From: LiYuRio <63526175+LiYuRio@users.noreply.github.com> Date: Tue, 5 Dec 2023 13:03:25 +0800 Subject: [PATCH] add backward infer log (#59543) --- paddle/phi/api/lib/data_transform.cc | 28 ++++--------------- paddle/phi/api/yaml/generator/dist_api_gen.py | 9 ++++-- paddle/phi/infermeta/spmd_rules/utils.cc | 17 +++++++++++ paddle/phi/infermeta/spmd_rules/utils.h | 3 ++ 4 files changed, 33 insertions(+), 24 deletions(-) diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc index 2c2113cf6bd89..f5ba60fac11a3 100644 --- a/paddle/phi/api/lib/data_transform.cc +++ b/paddle/phi/api/lib/data_transform.cc @@ -662,7 +662,8 @@ std::shared_ptr ReshardApiInputToKernelInput( static_cast(tensor_in.get()); if (ReshardIsNeededWithPartial(dist_tensor->dist_attr(), tensor_dist_attr)) { - VLOG(6) << "ApiIn to Replicated KernelIn - " + VLOG(4) << "ApiIn to Replicated KernelIn - \n" + << "Input tensor: " << tensor.name() << ReshardDebugInfo(*dist_tensor, tensor_dist_attr); auto* func = phi::distributed::ChooseProperReshardFunction( *dist_tensor, tensor_dist_attr); @@ -699,7 +700,8 @@ ReshardApiInputToKernelInput(phi::DeviceContext* dev_ctx, phi::distributed::DistTensor* dist_tensor = static_cast(tensor_in.get()); if (ReshardIsNeededWithPartial(dist_tensor->dist_attr(), dist_attr)) { - VLOG(6) << "Vector ApiIn to Replicated KernelIn - " + VLOG(4) << "Vector ApiIn to Replicated KernelIn - \n" + << "Input tensor: " << tensors[i].name() << ReshardDebugInfo(*dist_tensor, dist_attr); auto* func = phi::distributed::ChooseProperReshardFunction(*dist_tensor, dist_attr); @@ -843,25 +845,6 @@ void SetInplaceOutputCorrectDistAttr( dev_ctx, tensors, paddle::get<1>(dist_attr), use_general_spmd_rule); } -void ReshardOutputPartialAxisToReplicated( - phi::DeviceContext* dev_ctx, phi::distributed::DistTensor* out_tensor) { - if (out_tensor->dist_attr().is_partial()) { - auto dist_attr = out_tensor->dist_attr(); - dist_attr.clean_partial_status(); - if (!IsCurRankInMesh(out_tensor->dist_attr().process_mesh())) { - VLOG(6) << "DistTensor is not in mesh, just clear its partial status and " - "skip reshard it to replicated."; - out_tensor->unsafe_set_dist_attr(dist_attr); - return; - } - VLOG(6) << "FwdAPI Output P2R - " - << ReshardDebugInfo(*out_tensor, dist_attr); - auto* func = - phi::distributed::ChooseProperReshardFunction(*out_tensor, dist_attr); - func->Eval(dev_ctx, *out_tensor, dist_attr, out_tensor); - } -} - void ReshardKernelOutputToApiOutput( phi::DeviceContext* dev_ctx, const std::shared_ptr& src_tensor, @@ -876,7 +859,8 @@ void ReshardKernelOutputToApiOutput( static_cast(tensor_out.get()); dist_tensor->unsafe_set_dims(src_tensor->dims()); if (ReshardIsNeeded(src_tensor->dist_attr(), dist_tensor->dist_attr())) { - VLOG(6) << "BwdAPI KernelOut to ApiOut - " + VLOG(4) << "BwdAPI KernelOut to ApiOut - \n" + << "Input tensor: " << dst_tensor->name() << ReshardDebugInfo(*src_tensor, dist_tensor->dist_attr()); auto* func = phi::distributed::ChooseProperReshardFunction( *src_tensor, dist_tensor->dist_attr()); diff --git a/paddle/phi/api/yaml/generator/dist_api_gen.py b/paddle/phi/api/yaml/generator/dist_api_gen.py index 9e9a64da733c6..4e46db4b84e84 100644 --- a/paddle/phi/api/yaml/generator/dist_api_gen.py +++ b/paddle/phi/api/yaml/generator/dist_api_gen.py @@ -99,9 +99,11 @@ }}""" INFER_SPMD_TEMPLATE = """ auto spmd_info = phi::distributed::{}({}); + DebugInfoForInferSpmd("{}", spmd_info); """ GENERAL_INFER_SPMD_TEMPLATE = """ auto spmd_info = phi::distributed::VariadicReplicatedInferSpmdDynamic({}); + DebugInfoForInferSpmd("{}", spmd_info); """ UNSUPPORTED_INFER_SPMD_COMMENT_TEMPLATE = """ // API `{}` does not support InferSpmd now @@ -860,7 +862,9 @@ def generate_specialized_infer_spmd_code(self) -> str: infer_spmd_code = "" infer_spmd_func_code = self.infer_meta['spmd_rule'] infer_spmd_code = INFER_SPMD_TEMPLATE.format( - infer_spmd_func_code, input_args_code[:-2] + infer_spmd_func_code, + input_args_code[:-2], + self.api, ) self.generate_infer_spmd = True @@ -921,7 +925,8 @@ def generate_general_infer_spmd_code(self) -> str: return UNSUPPORTED_INFER_SPMD_COMMENT_TEMPLATE.format(self.api) infer_spmd_code = GENERAL_INFER_SPMD_TEMPLATE.format( - input_args_code[:-2] + input_args_code[:-2], + self.api, ) self.generate_infer_spmd = True self.generate_general_infer_spmd = True diff --git a/paddle/phi/infermeta/spmd_rules/utils.cc b/paddle/phi/infermeta/spmd_rules/utils.cc index f6ba6a89739b2..6892b06bf493c 100644 --- a/paddle/phi/infermeta/spmd_rules/utils.cc +++ b/paddle/phi/infermeta/spmd_rules/utils.cc @@ -503,5 +503,22 @@ std::vector GetDimsMappingForAxes( return dims_mapping; } +void DebugInfoForInferSpmd(const std::string& rule_name, + const SpmdInfo& infer_result) { + VLOG(4) << "The infer spmd result of " << rule_name << " is as below:"; + auto dist_attr_for_inputs = infer_result.first; + VLOG(4) << "======= The dist attr of inputs after inferspmd ======="; + for (size_t i = 0; i < dist_attr_for_inputs.size(); ++i) { + VLOG(4) << "The dist attr of the " << i << "th input need to be " + << PADDLE_GET(TensorDistAttr, dist_attr_for_inputs[i]); + } + VLOG(4) << "======= The dist attr of outputs after inferspmd ======="; + auto dist_attr_for_outputs = infer_result.second; + for (size_t i = 0; i < dist_attr_for_outputs.size(); ++i) { + VLOG(4) << "The dist attr of the " << i << "th output need to be " + << PADDLE_GET(TensorDistAttr, dist_attr_for_outputs[i]); + } +} + } // namespace distributed } // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/utils.h b/paddle/phi/infermeta/spmd_rules/utils.h index 293742123e9f8..c58e80ba02608 100644 --- a/paddle/phi/infermeta/spmd_rules/utils.h +++ b/paddle/phi/infermeta/spmd_rules/utils.h @@ -202,5 +202,8 @@ std::vector GetDimsMappingForAxes( const std::unordered_map& axis_to_dim_map, const bool unsharded_miss_axis = false); +void DebugInfoForInferSpmd(const std::string& rule_name, + const SpmdInfo& infer_result); + } // namespace distributed } // namespace phi