From e8cc3a976e82cc1e5cc1eff466b96c2162152beb Mon Sep 17 00:00:00 2001 From: jiweibo Date: Thu, 25 Aug 2022 05:06:50 +0000 Subject: [PATCH] fix params sync multi times problem --- .../inference/analysis/passes/convert_to_mixed_precision.cc | 1 + .../analysis/passes/ir_params_sync_among_devices_pass.cc | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc index d786159078dbc..87750d713c6d4 100644 --- a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc +++ b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc @@ -368,6 +368,7 @@ void ProcessInputNode( in_var_type == framework::proto::VarType::FP32) { if (WeightsShouldNotConvert(in_node)) return; in_var->SetDataType(to_type); + in_var_type = to_type; } else if (!in_var->Persistable() && IsFloatVarType(in_var_type) && in_var_type != to_type) { AddCastOp(graph, diff --git a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc index bc330354e71fc..3948ca8a59fd5 100644 --- a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc +++ b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h" +#include #include #include "paddle/fluid/framework/data_layout.h" @@ -113,6 +114,7 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) { reserve_cpu_weights = true; } + std::unordered_set visited; for (auto *node : paddle::framework::ir::TopologySortOperations(graph)) { if (!node->IsOp()) continue; if (node->Op()->Type() == "feed" || node->Op()->Type() == "fetch") continue; @@ -126,6 +128,8 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) { } continue; } + if (visited.count(var_name)) continue; + visited.insert(var_name); auto *var = scope->FindLocalVar(var_name); PADDLE_ENFORCE_NOT_NULL(var, platform::errors::PreconditionNotMet(