Commit

cinn(py-dsl): add pybind interface to python (PaddlePaddle#57644)
This PR wraps the C++ and Python-layer interfaces required by the Python DSL.

For unit tests and e2e tests, see the main PR: PaddlePaddle#56393
6clc authored and jiahy0825 committed Oct 16, 2023
1 parent 75f3965 commit a63bfd6
Showing 16 changed files with 415 additions and 9 deletions.
3 changes: 2 additions & 1 deletion paddle/cinn/pybind/CMakeLists.txt
@@ -12,7 +12,8 @@ set(srcs
pe.cc
frontend.cc
framework.cc
utils.cc)
utils.cc
schedule.cc)

if(WITH_CUDA)
message(STATUS "Compile core_api with CUDA support")
3 changes: 3 additions & 0 deletions paddle/cinn/pybind/bind.cc
@@ -41,6 +41,8 @@ PYBIND11_MODULE(core_api, m) {
"framework", "namespace cinn::hlir::framework, CINN framework");
py::module utils =
m.def_submodule("utils", "namespace cinn::utils, CINN framework");
py::module schedule = m.def_submodule(
"schedule", "namespace cinn::ir::schedule, CINN Schedule");

BindRuntime(&runtime);
BindCommon(&common);
@@ -53,6 +55,7 @@ PYBIND11_MODULE(core_api, m) {
BindFrontend(&frontend);
BindFramework(&framework);
BindUtils(&utils);
BindSchedule(&schedule);
}

} // namespace cinn::pybind
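The submodule registered here carries the docstring passed to def_submodule, so a quick import check confirms the binding is wired up. A minimal sketch in Python, assuming the built core_api extension sits inside the cinn package as python/cinn/framework.py in this commit suggests:

# Minimal check that the new submodule is registered; the package layout
# (cinn.core_api) is an assumption inferred from this commit.
from cinn import core_api

print(core_api.schedule.__doc__)  # "namespace cinn::ir::schedule, CINN Schedule"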
1 change: 1 addition & 0 deletions paddle/cinn/pybind/bind.h
@@ -52,5 +52,6 @@ void BindPE(pybind11::module *m);
void BindFrontend(pybind11::module *m);
void BindFramework(pybind11::module *m);
void BindUtils(pybind11::module *m);
void BindSchedule(pybind11::module *m);

} // namespace cinn::pybind
19 changes: 19 additions & 0 deletions paddle/cinn/pybind/framework.cc
@@ -20,6 +20,7 @@
#include "paddle/cinn/common/cinn_value.h"
#include "paddle/cinn/frontend/interpreter.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/instruction.h"
#include "paddle/cinn/hlir/framework/node.h"
#include "paddle/cinn/hlir/framework/op.h"
#include "paddle/cinn/hlir/framework/op_strategy.h"
@@ -211,5 +212,23 @@ void BindFramework(pybind11::module *m) {
CINN_NOT_IMPLEMENTED
}
});

py::class_<Instruction> instruction(*m, "Instruction");
instruction
.def(py::init<const Target &,
Scope *,
const std::vector<std::string> &,
const std::vector<std::string> &,
const std::string &>())
.def("run",
[](Instruction &self,
backends::Compiler &compiler,
const std::string fn_name,
std::map<std::string, cinn_pod_value_t> &name_to_pod) {
auto fn_ptr = compiler.Lookup(fn_name);
self.Finalize();
self.SetLoweredFunc(fn_ptr);
self.Run(&name_to_pod);
});
}
} // namespace cinn::pybind
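The new Instruction binding exposes the five-argument constructor, and run looks the named function up in a compiled backend, finalizes the instruction, sets the lowered function pointer, and executes on a name-to-pod-value map. A hedged usage sketch in Python; target, scope, compiler, and the buffer values are assumed to come from existing bindings, and "fn_add" is a hypothetical function name:

# Sketch only: target, scope, compiler, and the *_buf values are assumed
# to be created through existing CINN bindings; "fn_add" is hypothetical.
from cinn.framework import Instruction

instr = Instruction(target, scope, ["x", "y"], ["out"], "fn_add")
# run() resolves "fn_add" in the compiled module, finalizes the
# instruction, and executes it on the name->pod-value map.
instr.run(compiler, "fn_add", {"x": x_buf, "y": y_buf, "out": out_buf})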
4 changes: 2 additions & 2 deletions paddle/cinn/pybind/ir/ir_context.h
@@ -45,7 +45,7 @@ class IRContextNode : public common::Object {
};

/**
* The lifecycle of RAII resource management for IRContextNode
* The life cycle of RAII resource management for IRContextNode
* is determined at the Python.
*/
class IRContext {
@@ -215,7 +215,7 @@ class IRBuilderNode : public common::Object {
};

/**
* The lifecycle of RAII resource management for IRBuilderNode
* The life cycle of RAII resource management for IRBuilderNode
* is determined at the Python.
*/
class IRBuilder {
15 changes: 14 additions & 1 deletion paddle/cinn/pybind/lang.cc
@@ -20,12 +20,15 @@
#include "paddle/cinn/backends/codegen_c.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/module.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/schedule/ir_schedule_util.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/lang/buffer.h"
#include "paddle/cinn/lang/builtin.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/lang/placeholder.h"
#include "paddle/cinn/optim/transform_gpu_forloop.h"
#include "paddle/cinn/pybind/bind.h"
#include "paddle/cinn/pybind/bind_utils.h"

@@ -148,7 +151,17 @@ void BindModule(py::module *m) {

py::class_<ir::Module::Builder> builder(module, "Builder");
builder.def(py::init<const std::string &, const common::Target &>())
.def("add_function", &ir::Module::Builder::AddFunction)
.def("add_function",
[](ir::Module::Builder &self, ir::LoweredFunc func) {
if (self.GetTargetArch() == Target::Arch::NVGPU) {
#ifdef CINN_WITH_CUDA
auto func_expr = Expr(func);
ir::SetCudaAxisInfo(&func_expr);
optim::OptimizeExprGPU(&(func->body));
#endif
}
self.AddFunction(func);
})
.def("add_buffer", &ir::Module::Builder::AddBuffer)
.def("build", &ir::Module::Builder::Build);
}
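add_function changes from a direct method binding to a lambda: when the builder's target is NVGPU and CINN is built with CUDA, it stamps CUDA axis info onto the function and runs the GPU for-loop optimization over its body before adding it. The Python call site is unchanged; a sketch, assuming target and lowered_func come from the existing lang/ir bindings:

# Sketch: the call site is unchanged; the GPU post-processing now happens
# inside add_function whenever the builder's target is NVGPU.
builder = Module.Builder("demo_module", target)  # module name is hypothetical
builder.add_function(lowered_func)
module = builder.build()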
62 changes: 60 additions & 2 deletions paddle/cinn/pybind/runtime.cc
@@ -21,10 +21,18 @@
#include <cstring>
#include <memory>

#include "paddle/cinn/common/common.h"
#include "paddle/cinn/pybind/bind.h"
#include "paddle/cinn/runtime/cinn_runtime.h"
#include "paddle/cinn/runtime/flags.h"

#ifdef CINN_WITH_CUDA
#include <cuda.h>
#include <cuda_runtime.h>

#include "paddle/cinn/backends/cuda_util.h"
#endif

namespace py = pybind11;
namespace cinn::pybind {
namespace {
@@ -66,6 +74,48 @@ cinn_buffer_t *CreateBufferFromNumpy(py::array data,
return buffer;
}

cinn_buffer_t *CreateBufferFromNumpy(
py::array data,
common::Target target = common::DefaultHostTarget(),
int align = 0) {
if (target == common::DefaultHostTarget()) {
return CreateBufferFromNumpy(data, cinn_x86_device);
} else if (target.arch == Target::Arch::NVGPU) {
#ifdef CINN_WITH_CUDA
std::vector<int> shape;
std::copy_n(data.shape(), data.ndim(), std::back_inserter(shape));
auto *buffer = new cinn_buffer_t();
buffer->device = cinn_nvgpu_device;
buffer->memory_size = data.nbytes();
CUDA_CALL(cudaMalloc(&buffer->memory, data.nbytes()));
CUDA_CALL(cudaMemcpy(
buffer->memory, data.data(), data.nbytes(), cudaMemcpyHostToDevice));
return buffer;
#else
LOG(FATAL) << "To use CUDA backends, you need to set WITH_CUDA ON!";
#endif
} else {
CINN_NOT_IMPLEMENTED
}
}

void BufferCopyTo(const cinn_buffer_t &buffer, py::array array) {
void *array_data = array.mutable_data();
if (buffer.device == cinn_x86_device) {
std::memcpy(array_data, buffer.memory, array.nbytes());
} else if (buffer.device == cinn_nvgpu_device) {
#ifdef CINN_WITH_CUDA
CUDA_CALL(cudaMemcpy(
array_data, buffer.memory, array.nbytes(), cudaMemcpyDeviceToHost));
#else
LOG(FATAL) << "To use CUDA backends, you need to set WITH_CUDA ON!";
#endif

} else {
CINN_NOT_IMPLEMENTED
}
}

py::array BufferHostMemoryToNumpy(cinn_buffer_t &buffer) { // NOLINT
py::dtype dt;
if (buffer.type == cinn_int32_t()) {
@@ -162,6 +212,7 @@ void BindCinnRuntime(py::module *m) {
.value("cinn_x86_device", cinn_x86_device)
.value("cinn_opencl_device", cinn_opencl_device)
.value("cinn_arm_device", cinn_arm_device)
.value("cinn_nvgpu_device", cinn_nvgpu_device)
.export_values();

py::enum_<cinn_buffer_kind_t> cinn_buffer_kind(*m, "cinn_buffer_kind_t");
@@ -220,10 +271,17 @@ void BindCinnRuntime(py::module *m) {
.def("set_flag", &cinn_buffer_t::set_flag)
// Python methods
.def("numpy", &BufferHostMemoryToNumpy)
.def(py::init(&CreateBufferFromNumpy),
.def(py::init(py::overload_cast<py::array, cinn_device_kind_t, int>(
&CreateBufferFromNumpy)),
arg("data"),
arg("device"),
arg("align") = 0);
arg("align") = 0)
.def(py::init(py::overload_cast<py::array, common::Target, int>(
&CreateBufferFromNumpy)),
arg("data"),
arg("target"),
arg("align") = 0)
.def("copy_to", &BufferCopyTo);

m->def("cinn_x86_device_interface", &cinn_x86_device_interface)
.def("cinn_buffer_load_float32", &cinn_buffer_load_float32)
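Taken together, the runtime additions let a cinn_buffer_t be built straight from a numpy array plus a Target (allocating device memory and copying host-to-device on NVGPU), and copied back to host via copy_to. A round-trip sketch on the host target; the cinn.common and cinn.runtime module paths are assumptions based on the package layout:

import numpy as np

from cinn import common, runtime  # module paths assumed, not confirmed here

target = common.DefaultHostTarget()     # an NVGPU target follows the CUDA path
x = np.arange(16, dtype="float32").reshape(4, 4)

buf = runtime.cinn_buffer_t(x, target)  # new Target-based constructor overload
out = np.empty_like(x)
buf.copy_to(out)                        # new binding: memcpy on x86, cudaMemcpy D2H on GPU
assert (out == x).all()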
151 changes: 151 additions & 0 deletions paddle/cinn/pybind/schedule.cc
@@ -0,0 +1,151 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <string>

#include "paddle/cinn/ir/schedule/ir_schedule.h"

namespace py = pybind11;

namespace cinn::pybind {

void BindSchedule(py::module *m) {
py::class_<ir::IRSchedule> ir_schedule(*m, "IRSchedule");
ir_schedule
.def(py::init<const ir::ModuleExpr &,
utils::LinearRandomEngine::StateType,
bool,
utils::ErrorMessageLevel>(),
py::arg("modexpr"),
py::arg("rand_seed") = -1,
py::arg("debug_flag") = false,
py::arg("err_msg_level") = utils::ErrorMessageLevel::kGeneral)
.def_static(
"make",
[](ir::LoweredFunc &ir_func) {
ir::ModuleExpr *module_expr = new ir::ModuleExpr({ir_func->body});
auto scheduler = std::make_unique<ir::IRSchedule>(*module_expr);
return scheduler;
})
.def("fuse",
py::overload_cast<const std::vector<Expr> &>(&ir::IRSchedule::Fuse))
.def("split",
py::overload_cast<const Expr &, const std::vector<int> &>(
&ir::IRSchedule::Split),
py::arg("loop"),
py::arg("factors"))
.def("compute_at",
py::overload_cast<const Expr &, const Expr &, bool>(
&ir::IRSchedule::ComputeAt),
py::arg("block"),
py::arg("loop"),
py::arg("keep_unit_loops") = false)
.def("simple_compute_at",
py::overload_cast<const Expr &, const Expr &>(
&ir::IRSchedule::SimpleComputeAt),
py::arg("block"),
py::arg("loop"))
.def("reverse_compute_at",
py::overload_cast<const Expr &, const Expr &, bool>(
&ir::IRSchedule::ReverseComputeAt),
py::arg("block"),
py::arg("loop"),
py::arg("keep_unit_loops") = false)
.def("cache_read",
py::overload_cast<const Expr &, int, const std::string &>(
&ir::IRSchedule::CacheRead))
.def("cache_write",
py::overload_cast<const Expr &, int, const std::string &>(
&ir::IRSchedule::CacheWrite))
.def("sync_threads",
py::overload_cast<const Expr &, bool>(&ir::IRSchedule::SyncThreads),
py::arg("ir_node"),
py::arg("after_node") = true)
.def("set_buffer",
py::overload_cast<Expr &, const std::string &, bool>(
&ir::IRSchedule::SetBuffer),
py::arg("block"),
py::arg("memory_type"),
py::arg("fixed") = false)
.def("reorder",
py::overload_cast<const std::vector<Expr> &>(
&ir::IRSchedule::Reorder))
.def("parallel",
py::overload_cast<const Expr &>(&ir::IRSchedule::Parallel))
.def("vectorize",
py::overload_cast<const Expr &, int>(&ir::IRSchedule::Vectorize))
.def("unroll", py::overload_cast<const Expr &>(&ir::IRSchedule::Unroll))

.def("compute_inline",
py::overload_cast<const Expr &>(&ir::IRSchedule::ComputeInline))
.def("reverse_compute_inline",
py::overload_cast<const Expr &>(
&ir::IRSchedule::ReverseComputeInline))
.def("bind", &ir::IRSchedule::Bind)
.def("copy_transform_and_loop_info",
py::overload_cast<const Expr &, const Expr &>(
&ir::IRSchedule::CopyTransformAndLoopInfo))
.def("rfactor",
py::overload_cast<const Expr &, int>(&ir::IRSchedule::Rfactor))
.def("annotate",
py::overload_cast<const Expr &,
const std::string &,
const ir::attr_t &>(&ir::IRSchedule::Annotate))
.def("unannotate",
py::overload_cast<Expr &, const std::string &>(
&ir::IRSchedule::Unannotate))
.def("flatten_loops",
py::overload_cast<const std::vector<Expr> &, const bool>(
&ir::IRSchedule::FlattenLoops),
py::arg("loops"),
py::arg("force_flat") = false)
.def("sample_perfect_tile",
py::overload_cast<const Expr &, int, int, const std::vector<int> &>(
&ir::IRSchedule::SamplePerfectTile),
py::arg("loop"),
py::arg("n"),
py::arg("max_innermost_factor"),
py::arg("decision") = std::vector<int>())
.def("sample_categorical",
py::overload_cast<const std::vector<int> &,
const std::vector<float> &,
const std::vector<int> &>(
&ir::IRSchedule::SampleCategorical),
py::arg("candidates"),
py::arg("probs"),
py::arg("decision") = std::vector<int>())
.def("get_module",
py::overload_cast<>(&ir::IRSchedule::GetModule, py::const_))
.def("get_root_block", &ir::IRSchedule::GetRootBlock)
.def("get_block",
py::overload_cast<const std::string &>(&ir::IRSchedule::GetBlock,
py::const_))
.def("get_all_blocks",
py::overload_cast<>(&ir::IRSchedule::GetAllBlocks, py::const_))
.def("get_loops",
py::overload_cast<const std::string &>(&ir::IRSchedule::GetLoops,
py::const_))
.def("get_name2loops_dict",
[](const ir::IRSchedule &self, const std::string &block_name) {
std::vector<ir::Expr> loops = self.GetLoops(block_name);
std::map<std::string, ir::Expr> name2loops;
for (const ir::Expr &loop : loops) {
name2loops[loop.As<ir::For>()->loop_var->name] = loop;
}
return name2loops;
});
}
} // namespace cinn::pybind
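BindSchedule mirrors the C++ ir::IRSchedule primitives: make wraps a lowered function's body in a ModuleExpr, and the loop and block handles returned by get_loops/get_block feed the transforms. A hedged sketch of a typical schedule; the import path, the block name "C", and the two-deep loop nest are all hypothetical:

# Hedged sketch: lowered_func comes from an earlier lowering step; the
# import path, block name "C", and loop structure are assumptions.
from cinn.core_api.schedule import IRSchedule

sch = IRSchedule.make(lowered_func)
i, j = sch.get_loops("C")            # loops enclosing the block named "C"
fused = sch.fuse([i, j])
outer, inner = sch.split(fused, [-1, 128])
sch.bind(outer, "blockIdx.x")
sch.bind(inner, "threadIdx.x")
mod = sch.get_module()               # the scheduled ModuleExpr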
5 changes: 5 additions & 0 deletions paddle/cinn/pybind/utils.cc
@@ -13,7 +13,9 @@
// limitations under the License.

#include "paddle/cinn/pybind/bind.h"
#include "paddle/cinn/utils/error.h"
#include "paddle/cinn/utils/profiler.h"
#include "paddle/cinn/utils/random_engine.h"

namespace py = pybind11;

@@ -69,6 +71,9 @@ void BindUtils(py::module *m) {
"type",
[](HostEvent &self) -> const EventType & { return self.type_; },
[](HostEvent &self, const EventType &v) { self.type_ = v; });

py::class_<utils::LinearRandomEngine>(*m, "LinearRandomEngine");
py::class_<utils::ErrorMessageLevel>(*m, "ErrorMessageLevel");
}

} // namespace pybind
1 change: 1 addition & 0 deletions python/cinn/framework.py
@@ -13,6 +13,7 @@
# limitations under the License.

from .core_api.framework import ( # noqa: F401
Instruction,
NodeAttr,
Operator,
OpValueType,