Skip to content

Commit

Permalink
[Eager] Refactor TensorAdd by template (#39282)
Browse files Browse the repository at this point in the history
* Refactor TensorAdd func by template and remove gradient_accumulation in eager

* Remove needless target name

* Use overload instead of template
  • Loading branch information
veyron95 authored Jan 28, 2022
1 parent fc5fa0d commit 0bb3e5f
Show file tree
Hide file tree
Showing 10 changed files with 82 additions and 334 deletions.
4 changes: 2 additions & 2 deletions paddle/fluid/eager/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
set(eager_deps pten pten_api hook_utils tensor_utils utils global_utils backward pten_tensor legacy autograd_meta grad_node_info grad_tensor_holder gradient_accumulation accumulation_node)
set(eager_deps pten pten_api hook_utils tensor_utils utils global_utils backward pten_tensor legacy autograd_meta grad_node_info grad_tensor_holder accumulation_node)
set(fluid_deps tracer layer proto_desc operator op_registry variable_helper memcpy)
set(generated_deps dygraph_function dygraph_node)

Expand All @@ -12,7 +12,7 @@ add_subdirectory(accumulation)
add_subdirectory(legacy)

cc_library(grad_node_info SRCS grad_node_info.cc DEPS pten pten_api)
cc_library(grad_tensor_holder SRCS grad_tensor_holder.cc DEPS grad_node_info gradient_accumulation)
cc_library(grad_tensor_holder SRCS grad_tensor_holder.cc DEPS grad_node_info gradient_accumulator)

cc_library(autograd_meta SRCS autograd_meta.cc DEPS pten pten_api)
cc_library(utils SRCS utils.cc DEPS pten pten_api global_utils layer proto_desc operator op_registry variable_helper memcpy scale_op autograd_meta hook_utils)
Expand Down
3 changes: 1 addition & 2 deletions paddle/fluid/eager/accumulation/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
cc_library(gradient_accumulation SRCS gradient_accumulation.cc DEPS blas pten pten_api var_type_traits layer math_function)
cc_library(accumulation_node SRCS accumulation_node.cc DEPS gradient_accumulation pten pten_api grad_node_info)
cc_library(accumulation_node SRCS accumulation_node.cc DEPS gradient_accumulator pten pten_api grad_node_info)
4 changes: 2 additions & 2 deletions paddle/fluid/eager/accumulation/accumulation_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
// limitations under the License.

#include "paddle/fluid/eager/accumulation/accumulation_node.h"
#include "paddle/fluid/eager/accumulation/gradient_accumulation.h"
#include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/fluid/imperative/gradient_accumulator.h"

#include "paddle/pten/api/all.h"
#include "paddle/pten/core/dense_tensor.h"
Expand All @@ -35,7 +35,7 @@ static void CopyOrAddTensor(egr::EagerTensor* tensor,
*tensor = t;
} else {
// Accumulation
egr::TensorAdd(t, tensor);
paddle::imperative::TensorAdd<egr::EagerTensor>(t, tensor);
}
}

Expand Down
291 changes: 0 additions & 291 deletions paddle/fluid/eager/accumulation/gradient_accumulation.cc

This file was deleted.

23 changes: 0 additions & 23 deletions paddle/fluid/eager/accumulation/gradient_accumulation.h

This file was deleted.

10 changes: 5 additions & 5 deletions paddle/fluid/eager/grad_tensor_holder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.

#include "paddle/fluid/eager/grad_tensor_holder.h"
#include "paddle/fluid/eager/accumulation/gradient_accumulation.h"
#include "paddle/fluid/imperative/gradient_accumulator.h"

#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/math/math_function.h"
Expand Down Expand Up @@ -72,17 +72,17 @@ void GradTensorHolder::add(size_t slot_id, size_t rank,
} else {
// Accumulation
if (t.initialized() && buffer_tensor.initialized()) {
TensorAdd(t, &buffer_tensor);
paddle::imperative::TensorAdd<egr::EagerTensor>(t, &buffer_tensor);
} else if (t.Var().IsInitialized() &&
buffer_tensor.Var().IsInitialized()) {
VariableAdd(t, &buffer_tensor);
paddle::imperative::VariableAdd(t, &buffer_tensor);
} else if (t.Var().IsInitialized() && buffer_tensor.initialized()) {
// TODO(jiabin): This can be merge to upper if case.
buffer_tensor.SyncToVar();
VariableAdd(t, &buffer_tensor);
paddle::imperative::VariableAdd(t, &buffer_tensor);
} else if (t.initialized() && buffer_tensor.Var().IsInitialized()) {
buffer_tensor.SyncToTensor();
TensorAdd(t, &buffer_tensor);
paddle::imperative::TensorAdd<egr::EagerTensor>(t, &buffer_tensor);
} else {
// Should not happend case
// 1. both not init
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/imperative/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
cc_library(imperative_flag SRCS flags.cc DEPS gflags flags)
IF(WITH_XPU)
cc_library(prepared_operator SRCS prepared_operator.cc DEPS xpu_op_list proto_desc operator device_context lod_tensor selected_rows_utils var_type_traits op_kernel_type data_transform nan_inf_utils pten pten_utils)
cc_library(prepared_operator SRCS prepared_operator.cc DEPS xpu_op_list proto_desc operator device_context lod_tensor selected_rows_utils var_type_traits op_kernel_type data_transform nan_inf_utils pten pten_utils pten_api)
ELSE()
cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator device_context lod_tensor selected_rows_utils var_type_traits op_kernel_type data_transform nan_inf_utils pten pten_utils)
cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator device_context lod_tensor selected_rows_utils var_type_traits op_kernel_type data_transform nan_inf_utils pten pten_utils pten_api)
ENDIF()
cc_library(layer SRCS layer.cc DEPS prepared_operator math_function imperative_flag variable_helper op_registry)
add_subdirectory(jit)
Expand Down
Loading

0 comments on commit 0bb3e5f

Please sign in to comment.