Commit

add backward infer log (#59543)
LiYuRio authored Dec 5, 2023
1 parent 787455c commit a3ed849
Showing 4 changed files with 33 additions and 24 deletions.
28 changes: 6 additions & 22 deletions paddle/phi/api/lib/data_transform.cc
@@ -662,7 +662,8 @@ std::shared_ptr<phi::distributed::DistTensor> ReshardApiInputToKernelInput(
         static_cast<phi::distributed::DistTensor*>(tensor_in.get());
     if (ReshardIsNeededWithPartial(dist_tensor->dist_attr(),
                                    tensor_dist_attr)) {
-      VLOG(6) << "ApiIn to Replicated KernelIn - "
+      VLOG(4) << "ApiIn to Replicated KernelIn - \n"
+              << "Input tensor: " << tensor.name()
               << ReshardDebugInfo(*dist_tensor, tensor_dist_attr);
       auto* func = phi::distributed::ChooseProperReshardFunction(
           *dist_tensor, tensor_dist_attr);
@@ -699,7 +700,8 @@ ReshardApiInputToKernelInput(phi::DeviceContext* dev_ctx,
       phi::distributed::DistTensor* dist_tensor =
           static_cast<phi::distributed::DistTensor*>(tensor_in.get());
       if (ReshardIsNeededWithPartial(dist_tensor->dist_attr(), dist_attr)) {
-        VLOG(6) << "Vector ApiIn to Replicated KernelIn - "
+        VLOG(4) << "Vector ApiIn to Replicated KernelIn - \n"
+                << "Input tensor: " << tensors[i].name()
                 << ReshardDebugInfo(*dist_tensor, dist_attr);
         auto* func = phi::distributed::ChooseProperReshardFunction(*dist_tensor,
                                                                    dist_attr);
@@ -843,25 +845,6 @@ void SetInplaceOutputCorrectDistAttr(
       dev_ctx, tensors, paddle::get<1>(dist_attr), use_general_spmd_rule);
 }
 
-void ReshardOutputPartialAxisToReplicated(
-    phi::DeviceContext* dev_ctx, phi::distributed::DistTensor* out_tensor) {
-  if (out_tensor->dist_attr().is_partial()) {
-    auto dist_attr = out_tensor->dist_attr();
-    dist_attr.clean_partial_status();
-    if (!IsCurRankInMesh(out_tensor->dist_attr().process_mesh())) {
-      VLOG(6) << "DistTensor is not in mesh, just clear its partial status and "
-                 "skip reshard it to replicated.";
-      out_tensor->unsafe_set_dist_attr(dist_attr);
-      return;
-    }
-    VLOG(6) << "FwdAPI Output P2R - "
-            << ReshardDebugInfo(*out_tensor, dist_attr);
-    auto* func =
-        phi::distributed::ChooseProperReshardFunction(*out_tensor, dist_attr);
-    func->Eval(dev_ctx, *out_tensor, dist_attr, out_tensor);
-  }
-}
-
 void ReshardKernelOutputToApiOutput(
     phi::DeviceContext* dev_ctx,
     const std::shared_ptr<phi::distributed::DistTensor>& src_tensor,
@@ -876,7 +859,8 @@ void ReshardKernelOutputToApiOutput(
         static_cast<phi::distributed::DistTensor*>(tensor_out.get());
     dist_tensor->unsafe_set_dims(src_tensor->dims());
     if (ReshardIsNeeded(src_tensor->dist_attr(), dist_tensor->dist_attr())) {
-      VLOG(6) << "BwdAPI KernelOut to ApiOut - "
+      VLOG(4) << "BwdAPI KernelOut to ApiOut - \n"
+              << "Input tensor: " << dst_tensor->name()
               << ReshardDebugInfo(*src_tensor, dist_tensor->dist_attr());
       auto* func = phi::distributed::ChooseProperReshardFunction(
           *src_tensor, dist_tensor->dist_attr());
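Note: the reshard traces above move from verbosity 6 to 4 and now name the tensor involved. A minimal usage sketch, assuming a local Paddle install: Paddle routes VLOG through glog, so GLOG_v should be set before the first paddle import, and the tensor program below is purely illustrative.

import os

# Set glog verbosity before importing paddle; VLOG(4) messages such as the
# reshard and InferSpmd debug lines become visible once GLOG_v >= 4.
os.environ["GLOG_v"] = "4"

import paddle

# Any tensor program works; in a distributed auto-parallel run, an op that
# triggers a reshard or an InferSpmd rule emits the debug lines added here.
x = paddle.ones([2, 2])
print(x + x)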
9 changes: 7 additions & 2 deletions paddle/phi/api/yaml/generator/dist_api_gen.py
@@ -99,9 +99,11 @@
 }}"""
 INFER_SPMD_TEMPLATE = """
     auto spmd_info = phi::distributed::{}({});
+    DebugInfoForInferSpmd("{}", spmd_info);
 """
 GENERAL_INFER_SPMD_TEMPLATE = """
     auto spmd_info = phi::distributed::VariadicReplicatedInferSpmdDynamic({});
+    DebugInfoForInferSpmd("{}", spmd_info);
 """
 UNSUPPORTED_INFER_SPMD_COMMENT_TEMPLATE = """
     // API `{}` does not support InferSpmd now
@@ -860,7 +862,9 @@ def generate_specialized_infer_spmd_code(self) -> str:
         infer_spmd_code = ""
         infer_spmd_func_code = self.infer_meta['spmd_rule']
         infer_spmd_code = INFER_SPMD_TEMPLATE.format(
-            infer_spmd_func_code, input_args_code[:-2]
+            infer_spmd_func_code,
+            input_args_code[:-2],
+            self.api,
         )
         self.generate_infer_spmd = True

@@ -921,7 +925,8 @@ def generate_general_infer_spmd_code(self) -> str:
             return UNSUPPORTED_INFER_SPMD_COMMENT_TEMPLATE.format(self.api)
 
         infer_spmd_code = GENERAL_INFER_SPMD_TEMPLATE.format(
-            input_args_code[:-2]
+            input_args_code[:-2],
+            self.api,
         )
         self.generate_infer_spmd = True
         self.generate_general_infer_spmd = True
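Both templates now receive self.api as an extra format field, which becomes the rule_name passed to DebugInfoForInferSpmd. A standalone sketch of the expansion; the rule name, argument string, and api name are illustrative rather than taken from a real yaml entry.

# Standalone sketch of the updated INFER_SPMD_TEMPLATE expansion; the values
# passed to format() are illustrative, not from a real yaml entry.
INFER_SPMD_TEMPLATE = """
    auto spmd_info = phi::distributed::{}({});
    DebugInfoForInferSpmd("{}", spmd_info);
"""

generated = INFER_SPMD_TEMPLATE.format(
    "MatmulInferSpmd",   # infer_spmd_func_code
    "meta_x, meta_y",    # input_args_code[:-2]
    "matmul",            # self.api, the newly threaded-through rule name
)
print(generated)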
17 changes: 17 additions & 0 deletions paddle/phi/infermeta/spmd_rules/utils.cc
@@ -503,5 +503,22 @@ std::vector<int64_t> GetDimsMappingForAxes(
   return dims_mapping;
 }
 
+void DebugInfoForInferSpmd(const std::string& rule_name,
+                           const SpmdInfo& infer_result) {
+  VLOG(4) << "The infer spmd result of " << rule_name << " is as below:";
+  auto dist_attr_for_inputs = infer_result.first;
+  VLOG(4) << "======= The dist attr of inputs after inferspmd =======";
+  for (size_t i = 0; i < dist_attr_for_inputs.size(); ++i) {
+    VLOG(4) << "The dist attr of the " << i << "th input needs to be "
+            << PADDLE_GET(TensorDistAttr, dist_attr_for_inputs[i]);
+  }
+  VLOG(4) << "======= The dist attr of outputs after inferspmd =======";
+  auto dist_attr_for_outputs = infer_result.second;
+  for (size_t i = 0; i < dist_attr_for_outputs.size(); ++i) {
+    VLOG(4) << "The dist attr of the " << i << "th output needs to be "
+            << PADDLE_GET(TensorDistAttr, dist_attr_for_outputs[i]);
+  }
+}
+
 } // namespace distributed
 } // namespace phi
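To preview the log shape without a distributed run, here is a pure-Python mirror of DebugInfoForInferSpmd; it assumes SpmdInfo behaves as an (input_attrs, output_attrs) pair, and the dist-attr strings are placeholders for the TensorDistAttr values glog would print.

# Pure-Python mirror of DebugInfoForInferSpmd; the dist-attr strings below are
# placeholders for the TensorDistAttr objects printed by the C++ helper.
def debug_info_for_infer_spmd(rule_name, infer_result):
    input_attrs, output_attrs = infer_result
    print(f"The infer spmd result of {rule_name} is as below:")
    print("======= The dist attr of inputs after inferspmd =======")
    for i, attr in enumerate(input_attrs):
        print(f"The dist attr of the {i}th input needs to be {attr}")
    print("======= The dist attr of outputs after inferspmd =======")
    for i, attr in enumerate(output_attrs):
        print(f"The dist attr of the {i}th output needs to be {attr}")

debug_info_for_infer_spmd(
    "matmul",
    (["{dims_mapping: [0, -1]}", "{dims_mapping: [-1, -1]}"],
     ["{dims_mapping: [0, -1]}"]),
)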
3 changes: 3 additions & 0 deletions paddle/phi/infermeta/spmd_rules/utils.h
@@ -202,5 +202,8 @@ std::vector<int64_t> GetDimsMappingForAxes(
     const std::unordered_map<std::string, int64_t>& axis_to_dim_map,
     const bool unsharded_miss_axis = false);
 
+void DebugInfoForInferSpmd(const std::string& rule_name,
+                           const SpmdInfo& infer_result);
+
 } // namespace distributed
 } // namespace phi
