Skip to content

Commit

Permalink
[Prim] [NewIR]Automatic code generation for vjp rules (#56512)
Browse files Browse the repository at this point in the history
* support ir api form prim

* support ir api for prim

* support vjp prim mode in new ir

* remove useless code

* remove useless code

* auto code generator for primitive vjp methods

* add vjp and backend manual and fix segment fault

---------

Co-authored-by: cyber-pioneer <chenzhuo@tju.edu.cn>
Co-authored-by: cxxly <chenxx_id@163.com>
  • Loading branch information
3 people authored Aug 22, 2023
1 parent f02261b commit 14b81d5
Show file tree
Hide file tree
Showing 24 changed files with 881 additions and 770 deletions.
5 changes: 5 additions & 0 deletions paddle/fluid/primitive/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
backend/generated/*.cc
backend/generated/*.h
primitive/primitive.h
rule/vjp/generated/generated_vjp.h
rule/vjp/generated/generated_vjp.cc
1 change: 1 addition & 0 deletions paddle/fluid/primitive/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
add_subdirectory(utils)
add_subdirectory(backend)
add_subdirectory(rule)
add_subdirectory(codegen)
11 changes: 9 additions & 2 deletions paddle/fluid/primitive/backend/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
set(eager_backend_files
${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/backend/generated/generated_eager_backend.cc
)
if(WITH_PYTHON OR NOT ON_INFER)
cc_library(
primitive_backend_eager_experimental
SRCS eager_backend.cc
SRCS ${eager_backend_files}
DEPS final_dygraph_function eager_utils phi)
endif()
set(static_backend_files
${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/backend/generated/generated_static_backend.cc
${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/backend/manual/manual_static_backend.cc
)
cc_library(
primitive_backend_static_experimental
SRCS static_backend.cc
SRCS ${static_backend_files}
DEPS pd_dialect_api)
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,5 @@

#pragma once

#include <string>
#include <vector>

#include "paddle/phi/api/include/tensor.h"

namespace paddle {
namespace primitive {
namespace backend {} // namespace backend
} // namespace primitive
} // namespace paddle
#include "paddle/fluid/primitive/backend/generated/generated_backend.h"
#include "paddle/fluid/primitive/backend/manual/manual_backend.h"
38 changes: 38 additions & 0 deletions paddle/fluid/primitive/backend/manual/manual_backend.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <vector>

#include "paddle/phi/api/include/tensor.h"

namespace paddle {
namespace primitive {
namespace backend {

using Tensor = paddle::Tensor;
using Scalar = paddle::experimental::Scalar;
using IntArray = paddle::experimental::IntArray;
using DataType = phi::DataType;

template <typename T>
std::vector<Tensor> concat_grad(const std::vector<Tensor>& x,
const Tensor& out_grad,
const Tensor& axis);

} // namespace backend
} // namespace primitive
} // namespace paddle
59 changes: 59 additions & 0 deletions paddle/fluid/primitive/backend/manual/manual_static_backend.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_api.h"
#include "paddle/fluid/primitive/backend/manual/manual_backend.h"
#include "paddle/fluid/primitive/primitive/primitive.h"
#include "paddle/fluid/primitive/type/lazy_tensor.h"

namespace paddle {
namespace primitive {
namespace backend {

using LazyTensor = paddle::primitive::LazyTensor;

template <>
std::vector<Tensor> concat_grad<LazyTensor>(const std::vector<Tensor>& x,
const Tensor& out_grad,
const Tensor& axis) {
std::vector<ir::OpResult> x_res;
for (uint64_t idx = 0; idx < x.size(); idx++) {
x_res.emplace_back(std::static_pointer_cast<LazyTensor>(x[idx].impl())
->getValue()
.dyn_cast<ir::OpResult>());
}

ir::OpResult out_grad_res =
std::static_pointer_cast<LazyTensor>(out_grad.impl())
->getValue()
.dyn_cast<ir::OpResult>();

ir::OpResult axis_res = std::static_pointer_cast<LazyTensor>(axis.impl())
->getValue()
.dyn_cast<ir::OpResult>();

std::vector<ir::OpResult> op_res =
paddle::dialect::concat_grad(x_res, out_grad_res, axis_res);

std::vector<Tensor> op_result;
for (uint64_t idx = 0; idx < op_res.size(); idx++) {
op_result.emplace_back(
std::make_shared<primitive::LazyTensor>(op_res[idx]));
}
return op_result;
}

} // namespace backend
} // namespace primitive
} // namespace paddle
Loading

0 comments on commit 14b81d5

Please sign in to comment.