From 4deae3d6e3f5eda78dbf425946e10ffea9f2e92e Mon Sep 17 00:00:00 2001 From: jiangcheng Date: Fri, 2 Sep 2022 15:23:09 +0800 Subject: [PATCH 1/8] add `isclose` op (#880) * add isclose op * fix isclose pe bug * fix cinn_nvgpu_isnan_fp32 kernel registry return not bool bug * change isclose from kBroadcast to kElemWise --- cinn/frontend/base_builder.cc | 10 ++ cinn/frontend/base_builder.h | 4 + cinn/hlir/op/broadcast.cc | 135 +++++++++++++++++++-------- cinn/hlir/pe/broadcast.cc | 52 +++++++++++ cinn/hlir/pe/broadcast.h | 8 ++ cinn/pybind/frontend.cc | 9 +- cinn/runtime/cuda/cuda_intrinsics.cc | 14 ++- python/tests/ops/test_isclose_op.py | 97 +++++++++++++++++++ 8 files changed, 286 insertions(+), 43 deletions(-) create mode 100644 python/tests/ops/test_isclose_op.py diff --git a/cinn/frontend/base_builder.cc b/cinn/frontend/base_builder.cc index 10668fd317..ccfd10c950 100644 --- a/cinn/frontend/base_builder.cc +++ b/cinn/frontend/base_builder.cc @@ -308,6 +308,16 @@ Variable BaseBuilder::ScatterAdd(const Variable& operand, const Variable& update return instr.GetOutput(0); } +Variable BaseBuilder::IsClose(const Variable& x, const Variable& y, float rtol, float atol, bool equal_nan) { + Instruction instr("isclose", {x, y}); + instr.SetAttr("rtol", rtol); + instr.SetAttr("atol", atol); + instr.SetAttr("equal_nan", equal_nan); + InferShape(instr); + AppendInstruction(instr); + return instr.GetOutput(0); +} + Variable BaseBuilder::UnaryOp(const std::string& op_type, const Variable& operand) { Instruction instr(op_type, {operand}); InferShape(instr); diff --git a/cinn/frontend/base_builder.h b/cinn/frontend/base_builder.h index aca6e717fa..1d2e165cae 100644 --- a/cinn/frontend/base_builder.h +++ b/cinn/frontend/base_builder.h @@ -135,6 +135,10 @@ class BaseBuilder { return FillConstant(shape, static_cast(value), name, common::Type2Str(common::type_of()), force_cpu); } + // This operator checks if all x and y satisfy the condition: |x - y| <= atol + rtol * |y| + Variable IsClose( + const Variable& x, const Variable& y, float rtol = 1e-05f, float atol = 1e-08f, bool equal_nan = false); + protected: void InferShape(Instruction instr) const; diff --git a/cinn/hlir/op/broadcast.cc b/cinn/hlir/op/broadcast.cc index 78a9a1ed85..d328f28d57 100644 --- a/cinn/hlir/op/broadcast.cc +++ b/cinn/hlir/op/broadcast.cc @@ -189,46 +189,9 @@ std::shared_ptr StrategyForBroadcastTo(const framework::NodeAttr &at *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}}; }); - framework::CINNSchedule broadcast_to_schedule([=](lang::Args args, lang::RetValue *ret) { - if (FLAGS_cinn_ir_schedule) { - CHECK(!args.empty()) << "The input argument of broadcast_to schedule is empty! Please check."; - CINNValuePack arg_pack = args[0]; - std::vector vec_ast; - for (int i = 0; i < arg_pack.size(); i++) { - if (arg_pack[i].is_expr()) { - Expr temp = arg_pack[i]; - vec_ast.emplace_back(temp); - } - } - CHECK(!vec_ast.empty()); - ir::ModuleExpr mod_expr(vec_ast); - ir::IRSchedule ir_sch(mod_expr); - ir_sch.MergeExprs(); - if (target.arch == Target::Arch::NVGPU) { - pe::IRCudaScheduleInjective(ir_sch, out_shape, target); - } else if (target.arch == Target::Arch::X86) { - pe::IRScheduleInjectiveCPU(ir_sch, out_shape, target, false); - } - std::vector res{CINNValue(ir_sch.GetModule().GetExprs().at(0))}; - *ret = CINNValuePack{res}; - } else { - CHECK(!args.empty()) << "The input argument of broadcast_to schedule is empty! 
Please check."; - CINNValuePack arg_pack = args[0]; - CHECK_EQ(arg_pack.size(), 2UL); - Expr Out = arg_pack[0]; - poly::StageMap stages = arg_pack.back(); - CHECK(Out.as_tensor()); - if (target.arch == Target::Arch::NVGPU) { - pe::CudaScheduleInjective(stages[Out.as_tensor_ref()], out_shape, target); - } else if (target.arch == Target::Arch::X86) { - pe::ScheduleInjectiveCPU(stages[Out.as_tensor_ref()], out_shape, target); - } - *ret = arg_pack; - } - }); - auto strategy = std::make_shared(); - strategy->AddImpl(broadcast_to_compute, broadcast_to_schedule, "strategy.broadcast_to.x86", 1); + strategy->AddImpl( + broadcast_to_compute, framework::GetInjectiveScheduleFunc(output_shapes, target), "strategy.broadcast_to.x86", 1); return strategy; } @@ -287,6 +250,89 @@ std::shared_ptr StrategyForBroadcastGrad(const framework::NodeAttr & << "Gradient operator will be decomposed into several primitive operators. Please Use Decomposer Program Pass."; } +std::shared_ptr StrategyForIsClose(const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector &output_shapes, + const Target &target) { + float rtol = 1e-05f, atol = 1e-08f; + bool equal_nan = false; + + if (attrs.attr_store.count("rtol")) { + rtol = absl::get(attrs.attr_store.at("rtol")); + } + if (attrs.attr_store.count("atol")) { + atol = absl::get(attrs.attr_store.at("atol")); + } + if (attrs.attr_store.count("equal_nan")) { + equal_nan = absl::get(attrs.attr_store.at("equal_nan")); + } + + framework::CINNCompute isclose_compute([=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of isclose compute is empty! Please check."; + CINNValuePack pack_args = args[0]; + int input_size = pack_args.size(); + + std::string tensor_name = UniqName("IsClose_output"); + if (FLAGS_cinn_ir_schedule) { + // the last pack argument is the output tensor name + tensor_name = pack_args.back().operator std::string(); + --input_size; + } + CHECK_EQ(input_size, 2) << "The input number of isclose should be 2, but here " << input_size << "! Please check."; + + // the input tensor are in front + Expr x_expr = pack_args[0]; + CHECK(x_expr.as_tensor()); + auto x_tensor = x_expr.as_tensor_ref(); + + Expr y_expr = pack_args[1]; + CHECK(y_expr.as_tensor()); + auto y_tensor = y_expr.as_tensor_ref(); + + auto out = pe::IsClose(x_tensor, y_tensor, rtol, atol, equal_nan, tensor_name); + + auto stages = CreateStages({out}); + *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}}; + }); + + auto strategy = std::make_shared(); + strategy->AddImpl( + isclose_compute, framework::GetInjectiveScheduleFunc(output_shapes, target), "strategy.assertisclose", 1); + + return strategy; +} + +std::vector InferShapeForIsClose(const std::vector &input_shapes, + const framework::AttrMapType &attrs) { + int input_size = input_shapes.size(); + CHECK_EQ(input_size, 2UL) << "The input number of isclose should be a multiple of 2, but here " << input_size + << "! Please check."; + + CHECK(input_shapes[0] == input_shapes[1]) + << "The two inputs shape of isclose should be equal, but here x:[" << cinn::utils::Join(input_shapes[0], ",") + << "] != y:[" << cinn::utils::Join(input_shapes[1], ",") << "] ! Please check."; + return {input_shapes[0]}; +} + +std::vector InferDtypeForIsClose(const std::vector &inputs_type, const framework::AttrMapType &attrs) { + int input_size = inputs_type.size(); + CHECK_EQ(input_size, 2UL) << "The input number of isclose should be a multiple of 2, but here " << input_size + << "! 
Please check."; + CHECK(inputs_type[0] == inputs_type[1]) + << "The two inputs dtype sof isclose should be equal, but here x:" << inputs_type[0] << " != y:" << inputs_type[1] + << "! Please check."; + + return {Bool()}; +} + +std::vector> InferLayoutForIsClose(const std::vector> &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { + return {{""}, input_layouts}; +} + StrategyForBinary(elementwise_add, Add); StrategyForBinary(elementwise_mul, Multiply); @@ -387,6 +433,17 @@ CINN_REGISTER_HELPER(broadcast_ops) { .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kBroadcast) .set_support_level(4); + CINN_REGISTER_OP(isclose) + .describe("This operator checks if all x and y satisfy the condition: |x - y| <= atol + rtol * |y|") + .set_num_inputs(2) + .set_num_outputs(1) + .set_attr("CINNStrategy", cinn::hlir::op::StrategyForIsClose) + .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForIsClose)) + .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForIsClose)) + .set_attr("inferlayout", MakeOpFunction(cinn::hlir::op::InferLayoutForIsClose)) + .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kElemWise) + .set_support_level(4); + return true; } diff --git a/cinn/hlir/pe/broadcast.cc b/cinn/hlir/pe/broadcast.cc index d74ed7405f..85796cf2c4 100644 --- a/cinn/hlir/pe/broadcast.cc +++ b/cinn/hlir/pe/broadcast.cc @@ -277,6 +277,58 @@ Tensor BroadcastTo(const Tensor& A, out_name); } +ir::Tensor IsClose( + const ir::Tensor& x, const ir::Tensor& y, float rtol, float atol, bool equal_nan, const std::string& out_name) { + CHECK_EQ(x->shape.size(), y->shape.size()) + << "The two inputs shape dimension of is close should be equal! Please check."; + + std::vector x_shape, y_shape; + for (int i = 0; i < x->shape.size(); ++i) { + x_shape.emplace_back(x->shape[i].as_int32()); + y_shape.emplace_back(y->shape[i].as_int32()); + } + + CHECK(x_shape == y_shape) << "The two inputs shape of isclose should be equal, but here x:[" + << cinn::utils::Join(x_shape, ",") << "] != y:[" << cinn::utils::Join(y_shape, ",") + << "] ! Please check."; + + // For each a=x[i], b=y[i]: + // ``` + // if (isnan(a) || isnan(b)) { + // out = equal_nan && isnan(a) == isnan(b); + // } else { + // T left = (a > b ? a - b : b - a); + // T right = atol + (b > 0 ? rtol * b : (-rtol) * b); + // T diff = (left > right ? left - right : right - left); + // out = a == b || left <= right || diff <= 1e-15; + // } + // ``` + return Compute( + x->shape, + [=](const std::vector& indice) { + // check whether x or y is nan + auto a = x(indice), b = y(indice); + auto check_x_nan = lang::IsNan(a); + auto check_y_nan = lang::IsNan(b); + + // out = equal_nan && isnan(a) == isnan(b); + auto check_nan_same = Expr(equal_nan) && ir::EQ::Make(check_x_nan, check_y_nan); + + // check whether x and y are close + // T left = (a > b ? a - b : b - a); + auto left = ir::Select::Make(a > b, a - b, b - a); + // T right = atol + (b > 0 ? rtol * b : (-rtol) * b); + auto right = atol + ir::Select::Make(b > 0.0f, rtol * b, (-rtol) * b); + // T diff = (left > right ? 
left - right : right - left); + auto diff = ir::Select::Make(left > right, left - right, right - left); + // out = a == b || left <= right || diff <= 1e-15; + auto check_diff = (ir::EQ::Make(a, b) || (left <= right)) || (diff <= 1e-15f); + + return ir::Select::Make(check_x_nan || check_y_nan, check_nan_same, check_diff); + }, + out_name); +} + } // namespace pe } // namespace hlir } // namespace cinn diff --git a/cinn/hlir/pe/broadcast.h b/cinn/hlir/pe/broadcast.h index b54c14a4cb..ad5f9dc250 100644 --- a/cinn/hlir/pe/broadcast.h +++ b/cinn/hlir/pe/broadcast.h @@ -102,6 +102,14 @@ ir::Tensor BroadcastTo(const ir::Tensor& A, const std::vector& broadcast_axes, const std::string& out_name = common::UniqName("T_broadcast_to_out")); +// This operator checks if all x and y satisfy the condition: |x - y| <= atol + rtol * |y| +ir::Tensor IsClose(const ir::Tensor& x, + const ir::Tensor& y, + float rtol = 1e-05f, + float atol = 1e-08f, + bool equal_nan = false, + const std::string& out_name = common::UniqName("IsClose_output")); + } // namespace pe } // namespace hlir } // namespace cinn diff --git a/cinn/pybind/frontend.cc b/cinn/pybind/frontend.cc index 58e468bd1d..006c08facc 100644 --- a/cinn/pybind/frontend.cc +++ b/cinn/pybind/frontend.cc @@ -413,7 +413,14 @@ void BindFrontend(pybind11::module *m) { py::arg("x"), py::arg("updates"), py::arg("index"), - py::arg("axis") = 0); + py::arg("axis") = 0) + .def("isclose", + &BaseBuilder::IsClose, + py::arg("x"), + py::arg("y"), + py::arg("rtol") = 1e-05f, + py::arg("atol") = 1e-08f, + py::arg("equal_nan") = false); ; py::class_(*m, "NetBuilder") diff --git a/cinn/runtime/cuda/cuda_intrinsics.cc b/cinn/runtime/cuda/cuda_intrinsics.cc index 7623767b7e..da5cb594a1 100644 --- a/cinn/runtime/cuda/cuda_intrinsics.cc +++ b/cinn/runtime/cuda/cuda_intrinsics.cc @@ -47,10 +47,18 @@ CINN_REGISTER_HELPER(cuda_intrinsics) { REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(asinh); REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(atan); REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(atanh); - REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(isnan); REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(tanh); - REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(isfinite); - REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT(isinf); + +#undef REGISTER_EXTERN_FUNC_1_IN_1_OUT_FLOAT + +#define REGISTER_EXTERN_FUNC_1_IN_1_OUT_BOOL(func__) \ + REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT(cinn_nvgpu_##func__##_fp32, target, float, bool); + + REGISTER_EXTERN_FUNC_1_IN_1_OUT_BOOL(isnan); + REGISTER_EXTERN_FUNC_1_IN_1_OUT_BOOL(isfinite); + REGISTER_EXTERN_FUNC_1_IN_1_OUT_BOOL(isinf); + +#undef REGISTER_EXTERN_FUNC_1_IN_1_OUT_BOOL FunctionProto::shape_inference_t inference_shape_globalpool = [](const std::vector &args, int offset) { diff --git a/python/tests/ops/test_isclose_op.py b/python/tests/ops/test_isclose_op.py new file mode 100644 index 0000000000..83e189df57 --- /dev/null +++ b/python/tests/ops/test_isclose_op.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2022 CINN Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +import numpy as np +from op_test import OpTest, OpTestTool +import paddle +import paddle.nn.functional as F +import cinn +from cinn.frontend import * +from cinn.common import * +import sys + + +@OpTestTool.skip_if(not is_compiled_with_cuda(), + "x86 test will be skipped due to timeout.") +class TestIsCloseOp(OpTest): + def setUp(self): + self.init_case() + + def init_case(self): + self.inputs = {"x": np.random.random((16, 16)).astype("float32")} + self.inputs['y'] = self.inputs["x"] + self.rtol = 1e-05 + self.atol = 1e-08 + self.equal_nan = False + + def build_paddle_program(self, target): + x = paddle.to_tensor(self.inputs["x"], stop_gradient=False) + y = paddle.to_tensor(self.inputs["y"], stop_gradient=False) + out = paddle.isclose(x, y, self.rtol, self.atol, self.equal_nan) + self.paddle_outputs = [out] + + def build_cinn_program(self, target): + builder = NetBuilder("isclose") + + x = builder.create_input(Float(32), self.inputs["x"].shape, "x") + y = builder.create_input(Float(32), self.inputs["y"].shape, "y") + out = builder.isclose(x, y, self.rtol, self.atol, self.equal_nan) + prog = builder.build() + forward_res = self.get_cinn_output( + prog, target, [x, y], [self.inputs["x"], self.inputs["y"]], [out]) + + self.cinn_outputs = forward_res + + def test_check_results(self): + self.check_outputs_and_grads() + + +class TestIsCloseOpCase1(TestIsCloseOp): + def init_case(self): + self.inputs = { + "x": np.random.random((16, 16)).astype("float32"), + "y": np.random.random((16, 16)).astype("float32") + } + self.rtol = 1e-05 + self.atol = 1e-08 + self.equal_nan = False + + +class TestIsCloseOpCase2(TestIsCloseOp): + def init_case(self): + self.inputs = { + "x": np.array([np.nan] * 32).astype("float32"), + "y": np.random.random((32)).astype("float32") + } + self.rtol = 1e-05 + self.atol = 1e-08 + self.equal_nan = False + + +class TestIsCloseOpCase3(TestIsCloseOp): + def init_case(self): + self.inputs = { + "x": np.array([np.nan] * 32).astype("float32"), + "y": np.array([np.nan] * 32).astype("float32") + } + self.rtol = 1e-05 + self.atol = 1e-08 + self.equal_nan = True + + +if __name__ == "__main__": + unittest.main() From 31102d111474603b86d369de4de70ec11be72894 Mon Sep 17 00:00:00 2001 From: BiynXu <62832681+BiynXu@users.noreply.github.com> Date: Mon, 5 Sep 2022 11:28:01 +0800 Subject: [PATCH 2/8] Add new schedule to unittests (#904) * update op_test, op_nn_test, op_broadcast test, transform_test * Update LowerToModule --- cinn/hlir/framework/op_test.cc | 45 ++++-- cinn/hlir/op/op_broadcast_test.cc | 37 +++-- cinn/hlir/op/op_nn_test.cc | 228 +++++++++++++----------------- cinn/hlir/op/transform_test.cc | 45 ++++-- 4 files changed, 184 insertions(+), 171 deletions(-) diff --git a/cinn/hlir/framework/op_test.cc b/cinn/hlir/framework/op_test.cc index b9d42c17dc..406e1ad628 100644 --- a/cinn/hlir/framework/op_test.cc +++ b/cinn/hlir/framework/op_test.cc @@ -20,10 +20,14 @@ #include #include "cinn/cinn.h" +#include "cinn/hlir/framework/graph_compiler.h" #include "cinn/hlir/framework/node.h" #include "cinn/hlir/framework/op_strategy.h" #include "cinn/hlir/op/use_ops.h" #include "cinn/hlir/pe/broadcast.h" +#include "cinn/runtime/flags.h" + +DECLARE_bool(cinn_ir_schedule); namespace cinn { namespace hlir { @@ -46,21 +50,36 @@ TEST(Operator, GetAttrs) { common::Target target = common::DefaultHostTarget(); auto impl = OpStrategy::SelectImpl(strategy[add](attrs, inputs, type, {{100, 32}}, target)); - common::CINNValuePack cinn_input = common::CINNValuePack{{common::CINNValue(A), 
common::CINNValue(B)}}; - common::CINNValuePack rets = impl->fcompute(cinn_input); - ASSERT_EQ(rets.size(), 2UL); - rets = impl->fschedule(rets); - ASSERT_EQ(rets.size(), 2UL); - // the last element is a StageMap - for (int i = 0; i < rets->size() - 1; i++) { - ir::Expr temp = rets[i]; - inputs.push_back(temp.as_tensor_ref()); - } - auto func = Lower("add1", rets.back(), inputs); - LOG(INFO) << "Test Strategy Codegen:\n" << func; - ASSERT_EQ(impl->name, "strategy.elementwise_add.x86"); ASSERT_EQ(add->description, "elementwise_add function"); + + std::string func_name = "add1"; + + if (FLAGS_cinn_ir_schedule) { + std::string out_name = "C"; + common::CINNValuePack cinn_input = + common::CINNValuePack{{common::CINNValue(A), common::CINNValue(B), common::CINNValue(out_name)}}; + std::vector input_output_names{"A", "B", out_name}; + + auto funcs = framework::GetFuncFromImpl(impl, cinn_input, inputs, input_output_names, func_name, target); + + for (auto func : funcs) { + LOG(INFO) << "Test Operator_ElementWise_Add_Test0's Strategy, func is :\n" << func; + } + } else { + common::CINNValuePack cinn_input = common::CINNValuePack{{common::CINNValue(A), common::CINNValue(B)}}; + common::CINNValuePack rets = impl->fcompute(cinn_input); + ASSERT_EQ(rets.size(), 2UL); + rets = impl->fschedule(rets); + ASSERT_EQ(rets.size(), 2UL); + // the last element is a StageMap + for (int i = 0; i < rets->size() - 1; i++) { + ir::Expr temp = rets[i]; + inputs.push_back(temp.as_tensor_ref()); + } + auto func = Lower(func_name, rets.back(), inputs); + LOG(INFO) << "Test Strategy Codegen:\n" << func; + } } } // namespace framework diff --git a/cinn/hlir/op/op_broadcast_test.cc b/cinn/hlir/op/op_broadcast_test.cc index b3d8166937..7ae529881f 100644 --- a/cinn/hlir/op/op_broadcast_test.cc +++ b/cinn/hlir/op/op_broadcast_test.cc @@ -254,6 +254,15 @@ TEST(Operator, Operator_BroadcastTo) { } } +common::CINNValuePack GetComputeResult(const std::shared_ptr &impl, + std::vector &cinn_inputs, + const std::string &output_name = "") { + if (FLAGS_cinn_ir_schedule) { + cinn_inputs.emplace_back(output_name); + } + return impl->fcompute(common::CINNValuePack{cinn_inputs}); +} + TEST(Operator, Operator_BroadcastTo_0) { auto const_scalar = Operator::Get("const_scalar"); auto broadcast_to = Operator::Get("broadcast_to"); @@ -283,37 +292,37 @@ TEST(Operator, Operator_BroadcastTo_0) { auto impl_0 = OpStrategy::SelectImpl(strategy[const_scalar](attrs, std::vector{}, type, {out_shape}, target)); std::vector cinn_inputs; - common::CINNValuePack rets_0 = impl_0->fcompute(common::CINNValuePack{cinn_inputs}); + common::CINNValuePack rets_0 = GetComputeResult(impl_0, cinn_inputs, "out_0"); ir::Expr out_0 = rets_0[0]; auto tensor_0 = out_0.as_tensor_ref(); poly::StageMap stages_0 = rets_0.back(); - auto impl_1 = OpStrategy::SelectImpl(strategy[broadcast_to](attrs, {tensor_0}, type, {out_shape}, target)); - auto input_1 = common::CINNValuePack{{{common::CINNValue(tensor_0)}}}; - common::CINNValuePack rets_1 = impl_1->fcompute(input_1); + auto impl_1 = OpStrategy::SelectImpl(strategy[broadcast_to](attrs, {tensor_0}, type, {out_shape}, target)); + std::vector cinn_inputs_1 = {{common::CINNValue(tensor_0)}}; + common::CINNValuePack rets_1 = GetComputeResult(impl_1, cinn_inputs_1, "out_1"); ir::Expr out_1 = rets_1[0]; auto tensor_1 = out_1.as_tensor_ref(); poly::StageMap stages_1 = rets_1.back(); - auto impl_2 = OpStrategy::SelectImpl(strategy[reduce_sum](attrs, {A.tensor()}, type, {out_shape}, target)); - auto input_2 = 
common::CINNValuePack{{{common::CINNValue(A.tensor())}}}; - common::CINNValuePack rets_2 = impl_2->fcompute(input_2); + auto impl_2 = OpStrategy::SelectImpl(strategy[reduce_sum](attrs, {A.tensor()}, type, {out_shape}, target)); + std::vector cinn_inputs_2 = {{common::CINNValue(A.tensor())}}; + common::CINNValuePack rets_2 = GetComputeResult(impl_2, cinn_inputs_2, "out_2"); ir::Expr out_2 = rets_2[0]; auto tensor_2 = out_2.as_tensor_ref(); poly::StageMap stages_2 = rets_2.back(); - auto input_4 = common::CINNValuePack{{{common::CINNValue(A.tensor())}}}; - common::CINNValuePack rets_4 = impl_2->fcompute(input_4); - ir::Expr out_4 = rets_4[0]; - auto tensor_4 = out_4.as_tensor_ref(); - poly::StageMap stages_4 = rets_4.back(); + std::vector cinn_inputs_4 = {{common::CINNValue(A.tensor())}}; + common::CINNValuePack rets_4 = GetComputeResult(impl_2, cinn_inputs_4, "out_4"); + ir::Expr out_4 = rets_4[0]; + auto tensor_4 = out_4.as_tensor_ref(); + poly::StageMap stages_4 = rets_4.back(); auto impl_3 = OpStrategy::SelectImpl(strategy[elementwise_add](attrs, {tensor_1, tensor_2}, type, {out_shape}, target)); - auto input_3 = common::CINNValuePack{{{common::CINNValue(tensor_1), common::CINNValue(tensor_2)}}}; - common::CINNValuePack rets_3 = impl_3->fcompute(input_3); + std::vector cinn_inputs_3 = {{common::CINNValue(tensor_1), common::CINNValue(tensor_2)}}; + common::CINNValuePack rets_3 = GetComputeResult(impl_3, cinn_inputs_3, "out_3"); ir::Expr out_3 = rets_3[0]; auto tensor_3 = out_3.as_tensor_ref(); diff --git a/cinn/hlir/op/op_nn_test.cc b/cinn/hlir/op/op_nn_test.cc index c0e99cf797..55f595227e 100644 --- a/cinn/hlir/op/op_nn_test.cc +++ b/cinn/hlir/op/op_nn_test.cc @@ -22,11 +22,15 @@ #include "cinn/cinn.h" #include "cinn/common/target.h" #include "cinn/common/test_helper.h" +#include "cinn/hlir/framework/graph_compiler.h" #include "cinn/hlir/framework/node.h" #include "cinn/hlir/framework/op.h" #include "cinn/hlir/framework/op_strategy.h" #include "cinn/hlir/op/use_ops.h" #include "cinn/hlir/pe/nn.h" +#include "cinn/runtime/flags.h" + +DECLARE_bool(cinn_ir_schedule); namespace cinn { namespace hlir { @@ -34,6 +38,45 @@ namespace framework { using CCompute = std::function(const std::vector)>; +Module LowerToModule(const std::string test_name, + const std::string func_name, + const std::shared_ptr &impl, + std::vector input_names, + const std::string &output_name, + std::vector &inputs, + std::vector cinn_inputs, + const Target &target) { + Module::Builder builder("module", target); + + if (FLAGS_cinn_ir_schedule) { + cinn_inputs.emplace_back(output_name); + common::CINNValuePack cinn_input = common::CINNValuePack{cinn_inputs}; + input_names.push_back(output_name); + + auto funcs = framework::GetFuncFromImpl(impl, cinn_input, inputs, input_names, func_name, target); + + for (auto func : funcs) { + LOG(INFO) << "Test" << test_name << "'s Strategy, func is :\n" << func; + builder.AddFunction(func); + } + } else { + common::CINNValuePack cinn_input = common::CINNValuePack{cinn_inputs}; + common::CINNValuePack rets = impl->fcompute(cinn_input); + rets = impl->fschedule(rets); + // the last element is a StageMap + for (int i = 0; i < rets->size() - 1; i++) { + Expr temp = rets[i]; + inputs.push_back(temp.as_tensor_ref()); + } + auto func = Lower("fn_" + func_name, rets.back(), inputs); + LOG(INFO) << "Test Strategy Codegen:\n" << func; + + builder.AddFunction(func); + } + + return builder.Build(); +} + TEST(Operator, Operator_Pool2d_Test0) { auto pool2d = Operator::Get("pool2d"); Operator temp = 
*pool2d; @@ -55,25 +98,15 @@ TEST(Operator, Operator_Pool2d_Test0) { std::vector type{Float(32)}; common::Target target = common::DefaultHostTarget(); auto impl = OpStrategy::SelectImpl(strategy[pool2d](attrs, inputs, type, {{1, 3, 10, 10}, {1, 3, 5, 5}}, target)); - common::CINNValuePack cinn_input = common::CINNValuePack{{common::CINNValue(A)}}; - common::CINNValuePack rets = impl->fcompute(cinn_input); - rets = impl->fschedule(rets); - ASSERT_EQ(rets.size(), 2UL); - // the last element is a StageMap - for (int i = 0; i < rets->size() - 1; i++) { - Expr temp = rets[i]; - inputs.push_back(temp.as_tensor_ref()); - } - auto func = Lower("pool2d", rets.back(), inputs); - LOG(INFO) << "Test Strategy Codegen:\n" << func; - Module::Builder builder("module0", target); - builder.AddFunction(func); - auto jit = backends::ExecutionEngine::Create({}); - auto module = builder.Build(); + std::string func_name = "pool2d"; + auto module = + LowerToModule("Operator_Pool2d_Test0", func_name, impl, {"A"}, "B", inputs, {common::CINNValue(A)}, target); + + auto jit = backends::ExecutionEngine::Create({}); jit->Link(module); - auto fn = jit->Lookup("pool2d"); + auto fn = jit->Lookup("fn_" + func_name); CHECK(fn); auto fn_ = reinterpret_cast(fn); @@ -111,25 +144,16 @@ TEST(Operator, Operator_Pool2d_Test1) { std::vector type{Float(32)}; common::Target target = common::DefaultHostTarget(); auto impl = OpStrategy::SelectImpl(strategy[pool2d](attrs, inputs, type, {{1, 3, 11, 11}, {1, 3, 5, 5}}, target)); - common::CINNValuePack cinn_input = common::CINNValuePack{{common::CINNValue(A)}}; - common::CINNValuePack rets = impl->fcompute(cinn_input); - rets = impl->fschedule(rets); - ASSERT_EQ(rets.size(), 2UL); - // the last element is a StageMap - for (int i = 0; i < rets->size() - 1; i++) { - Expr temp = rets[i]; - inputs.push_back(temp.as_tensor_ref()); - } - auto func = Lower("pool2d", rets.back(), inputs); - LOG(INFO) << "Test Strategy Codegen:\n" << func; - Module::Builder builder("module0", target); - builder.AddFunction(func); - auto jit = backends::ExecutionEngine::Create({}); - auto module = builder.Build(); + std::string func_name = "pool2d"; + + auto module = + LowerToModule("Operator_Pool2d_Test1", func_name, impl, {"A"}, "B", inputs, {common::CINNValue(A)}, target); + + auto jit = backends::ExecutionEngine::Create({}); jit->Link(module); - auto fn = jit->Lookup("pool2d"); + auto fn = jit->Lookup("fn_" + func_name); CHECK(fn); auto fn_ = reinterpret_cast(fn); @@ -169,25 +193,16 @@ TEST(Operator, Operator_Pool2d_Test2) { std::vector type{Float(32)}; common::Target target = common::DefaultHostTarget(); auto impl = OpStrategy::SelectImpl(strategy[pool2d](attrs, inputs, type, {{1, 11, 11, 3}, {1, 5, 5, 3}}, target)); - common::CINNValuePack cinn_input = common::CINNValuePack{{common::CINNValue(A)}}; - common::CINNValuePack rets = impl->fcompute(cinn_input); - rets = impl->fschedule(rets); - ASSERT_EQ(rets.size(), 2UL); - // the last element is a StageMap - for (int i = 0; i < rets->size() - 1; i++) { - Expr temp = rets[i]; - inputs.push_back(temp.as_tensor_ref()); - } - auto func = Lower("pool2d", rets.back(), inputs); - LOG(INFO) << "Test Strategy Codegen:\n" << func; - Module::Builder builder("module0", target); - builder.AddFunction(func); - auto jit = backends::ExecutionEngine::Create({}); - auto module = builder.Build(); + std::string func_name = "pool2d"; + + auto module = + LowerToModule("Operator_Pool2d_Test2", func_name, impl, {"A"}, "B", inputs, {common::CINNValue(A)}, target); + + auto jit = 
backends::ExecutionEngine::Create({}); jit->Link(module); - auto fn = jit->Lookup("pool2d"); + auto fn = jit->Lookup("fn_" + func_name); CHECK(fn); auto fn_ = reinterpret_cast(fn); @@ -228,25 +243,15 @@ TEST(Operator, Operator_Pool3d_Test0) { common::Target target = common::DefaultHostTarget(); auto impl = OpStrategy::SelectImpl(strategy[pool3d](attrs, inputs, type, {{1, 11, 11, 11, 3}, {1, 5, 5, 5, 3}}, target)); - common::CINNValuePack cinn_input = common::CINNValuePack{{common::CINNValue(A)}}; - common::CINNValuePack rets = impl->fcompute(cinn_input); - rets = impl->fschedule(rets); - ASSERT_EQ(rets.size(), 2UL); - // the last element is a StageMap - for (int i = 0; i < rets->size() - 1; i++) { - Expr temp = rets[i]; - inputs.push_back(temp.as_tensor_ref()); - } - auto func = Lower("pool3d", rets.back(), inputs); - LOG(INFO) << "Test Strategy Codegen:\n" << func; - Module::Builder builder("module0", target); - builder.AddFunction(func); - auto jit = backends::ExecutionEngine::Create({}); - auto module = builder.Build(); + std::string func_name = "pool3d"; + auto module = + LowerToModule("Operator_Pool3d_Test0", func_name, impl, {"A"}, "B", inputs, {common::CINNValue(A)}, target); + + auto jit = backends::ExecutionEngine::Create({}); jit->Link(module); - auto fn = jit->Lookup("pool3d"); + auto fn = jit->Lookup("fn_" + func_name); CHECK(fn); auto fn_ = reinterpret_cast(fn); @@ -286,25 +291,15 @@ TEST(Operator, Operator_Pool1d_Test0) { std::vector type{Float(32)}; common::Target target = common::DefaultHostTarget(); auto impl = OpStrategy::SelectImpl(strategy[pool1d](attrs, inputs, type, {{1, 11, 3}, {1, 5, 3}}, target)); - common::CINNValuePack cinn_input = common::CINNValuePack{{common::CINNValue(A)}}; - common::CINNValuePack rets = impl->fcompute(cinn_input); - rets = impl->fschedule(rets); - ASSERT_EQ(rets.size(), 2UL); - // the last element is a StageMap - for (int i = 0; i < rets->size() - 1; i++) { - Expr temp = rets[i]; - inputs.push_back(temp.as_tensor_ref()); - } - auto func = Lower("pool1d", rets.back(), inputs); - LOG(INFO) << "Test Strategy Codegen:\n" << func; - Module::Builder builder("module0", target); - builder.AddFunction(func); - auto jit = backends::ExecutionEngine::Create({}); - auto module = builder.Build(); + std::string func_name = "pool1d"; + auto module = + LowerToModule("Operator_Pool1d_Test0", func_name, impl, {"A"}, "B", inputs, {common::CINNValue(A)}, target); + + auto jit = backends::ExecutionEngine::Create({}); jit->Link(module); - auto fn = jit->Lookup("pool1d"); + auto fn = jit->Lookup("fn_" + func_name); CHECK(fn); auto fn_ = reinterpret_cast(fn); @@ -342,26 +337,19 @@ TEST(Operator, Operator_Select_Test0) { ASSERT_EQ(infer_shape[0][2], 64); auto impl = OpStrategy::SelectImpl(strategy[select](attrs, inputs, type, {{16, 64, 64}}, target)); - common::CINNValuePack cinn_input = common::CINNValuePack{ - {common::CINNValue(condition), common::CINNValue(true_value), common::CINNValue(false_value)}}; - common::CINNValuePack rets = impl->fcompute(cinn_input); - rets = impl->fschedule(rets); - ASSERT_EQ(rets.size(), 2UL); - // the last element is a StageMap - for (int i = 0; i < rets->size() - 1; i++) { - Expr temp = rets[i]; - inputs.push_back(temp.as_tensor_ref()); - } - auto func = Lower("select", rets.back(), inputs); - LOG(INFO) << "Test Strategy Codegen:\n" << func; - Module::Builder builder("module0", target); - builder.AddFunction(func); - auto jit = backends::ExecutionEngine::Create({}); - auto module = builder.Build(); + std::string func_name = 
"select"; + std::vector input_names = {"condition", "true_value", "false_value"}; + std::vector cinn_inputs = { + common::CINNValue(condition), common::CINNValue(true_value), common::CINNValue(false_value)}; + + auto module = LowerToModule( + "Operator_Select_Test0", func_name, impl, std::move(input_names), "output", inputs, cinn_inputs, target); + + auto jit = backends::ExecutionEngine::Create({}); jit->Link(module); - auto fn = jit->Lookup("select"); + auto fn = jit->Lookup("fn_" + func_name); CHECK(fn); auto fn_ = reinterpret_cast(fn); @@ -408,26 +396,15 @@ TEST(Operator, Operator_Reverse_Test0) { common::Target target = common::DefaultHostTarget(); auto impl = OpStrategy::SelectImpl(strategy[reverse](attrs, inputs, type, {{c, h, w}}, target)); - common::CINNValuePack cinn_input = common::CINNValuePack{{common::CINNValue(A)}}; - common::CINNValuePack rets = impl->fcompute(cinn_input); - rets = impl->fschedule(rets); - ASSERT_EQ(rets.size(), 2UL); - - // the last element is a StageMap - for (int i = 0; i < rets->size() - 1; i++) { - Expr temp = rets[i]; - inputs.push_back(temp.as_tensor_ref()); - } - auto func = Lower("reverse", rets.back(), inputs); - LOG(INFO) << "Test Strategy Codegen:\n" << func; - Module::Builder builder("module0", target); - builder.AddFunction(func); - auto jit = backends::ExecutionEngine::Create({}); - auto module = builder.Build(); + std::string func_name = "reverse"; + auto module = + LowerToModule("Operator_Reverse_Test0", func_name, impl, {"A"}, "B", inputs, {common::CINNValue(A)}, target); + + auto jit = backends::ExecutionEngine::Create({}); jit->Link(module); - auto fn = jit->Lookup("reverse"); + auto fn = jit->Lookup("fn_" + func_name); CHECK(fn); auto fn_ = reinterpret_cast(fn); @@ -492,26 +469,15 @@ TEST(Operator, Operator_Transpose_Test0) { auto output_shape = {n, h, w, c}; auto impl = OpStrategy::SelectImpl(strategy[transpose](attrs, inputs, type, {output_shape}, target)); - common::CINNValuePack cinn_input = common::CINNValuePack{{common::CINNValue(A)}}; - common::CINNValuePack rets = impl->fcompute(cinn_input); - rets = impl->fschedule(rets); - ASSERT_EQ(rets.size(), 2UL); - - // the last element is a StageMap - for (int i = 0; i < rets->size() - 1; i++) { - Expr temp = rets[i]; - inputs.push_back(temp.as_tensor_ref()); - } - auto func = Lower("transpose", rets.back(), inputs); - LOG(INFO) << "Test Strategy Codegen:\n" << func; - Module::Builder builder("module0", target); - builder.AddFunction(func); - auto jit = backends::ExecutionEngine::Create({}); - auto module = builder.Build(); + std::string func_name = "transpose"; + auto module = + LowerToModule("Operator_Transpose_Test0", func_name, impl, {"A"}, "B", inputs, {common::CINNValue(A)}, target); + + auto jit = backends::ExecutionEngine::Create({}); jit->Link(module); - auto fn = jit->Lookup("transpose"); + auto fn = jit->Lookup("fn_" + func_name); CHECK(fn); auto fn_ = reinterpret_cast(fn); diff --git a/cinn/hlir/op/transform_test.cc b/cinn/hlir/op/transform_test.cc index bafe2f6596..8dc0cc46a0 100644 --- a/cinn/hlir/op/transform_test.cc +++ b/cinn/hlir/op/transform_test.cc @@ -30,6 +30,7 @@ #include "cinn/cinn.h" #include "cinn/common/target.h" #include "cinn/common/test_helper.h" +#include "cinn/hlir/framework/graph_compiler.h" #include "cinn/hlir/framework/node.h" #include "cinn/hlir/framework/op.h" #include "cinn/hlir/framework/op_strategy.h" @@ -37,6 +38,9 @@ #include "cinn/hlir/pe/nn.h" #include "cinn/runtime/cinn_runtime.h" #include "cinn/runtime/cuda/cuda_module.h" +#include 
"cinn/runtime/flags.h" + +DECLARE_bool(cinn_ir_schedule); namespace cinn { namespace hlir { @@ -78,22 +82,37 @@ TEST(SliceAssign, SliceAssign_Op) { #endif auto impl = OpStrategy::SelectImpl(strategy(attrs, inputs, out_type, {output_shape}, target)); - common::CINNValuePack cinn_input = - common::CINNValuePack{{common::CINNValue(input.tensor()), common::CINNValue(assign.tensor())}}; - common::CINNValuePack rets = impl->fcompute(cinn_input); - rets = impl->fschedule(rets); + std::string func_name = "slice_assign"; + + if (FLAGS_cinn_ir_schedule) { + std::string out_name = "output"; + common::CINNValuePack cinn_input = common::CINNValuePack{ + {common::CINNValue(input.tensor()), common::CINNValue(assign.tensor()), common::CINNValue(out_name)}}; + std::vector input_output_names{"input", "assign", out_name}; - // the last element is a StageMap - for (int i = 0; i < rets->size() - 1; i++) { - Expr temp = rets[i]; - if (!temp.as_tensor_ref()->buffer.defined()) { - inputs.push_back(temp.as_tensor_ref()); + auto funcs = framework::GetFuncFromImpl(impl, cinn_input, inputs, input_output_names, func_name, target); + + for (auto func : funcs) { + LOG(INFO) << "Test Operator_BroadcastTo's Strategy, func is :\n" << func; } - } + } else { + common::CINNValuePack cinn_input = + common::CINNValuePack{{common::CINNValue(input.tensor()), common::CINNValue(assign.tensor())}}; + common::CINNValuePack rets = impl->fcompute(cinn_input); + rets = impl->fschedule(rets); - auto func = lang::LowerVec("slice_assign", rets.back(), inputs, {}, {}, nullptr, target); - for (auto& f : func) { - LOG(INFO) << "Test Strategy Codegen:\n" << f; + // the last element is a StageMap + for (int i = 0; i < rets->size() - 1; i++) { + Expr temp = rets[i]; + if (!temp.as_tensor_ref()->buffer.defined()) { + inputs.push_back(temp.as_tensor_ref()); + } + } + + auto func = lang::LowerVec("slice_assign", rets.back(), inputs, {}, {}, nullptr, target); + for (auto& f : func) { + LOG(INFO) << "Test Strategy Codegen:\n" << f; + } } } From 6ad58c94e43d37d4fe27708c1d1935425f945357 Mon Sep 17 00:00:00 2001 From: jiangcheng Date: Mon, 5 Sep 2022 16:05:40 +0800 Subject: [PATCH 3/8] fix cinn matmul cannot compute some shape problem (#923) * fix cinn matmul cannot compute some shape problem * fix some bug like the output shape not correct * fix by review advices --- cinn/frontend/net_builder.cc | 218 +++++++++++++++++++++- cinn/frontend/net_builder.h | 6 + cinn/frontend/op_mappers/paddle/matmul.cc | 9 +- python/tests/ops/test_matmul_op.py | 70 +++++++ 4 files changed, 299 insertions(+), 4 deletions(-) mode change 100755 => 100644 cinn/frontend/net_builder.h diff --git a/cinn/frontend/net_builder.cc b/cinn/frontend/net_builder.cc index 4bf76eea63..9ad7038369 100644 --- a/cinn/frontend/net_builder.cc +++ b/cinn/frontend/net_builder.cc @@ -281,14 +281,228 @@ std::vector NetBuilder::Conv2dGrad(const Variable& dy, return instr.GetOutputs(); } +std::pair NetBuilder::BroadcastMatmulInput( + const Variable& x, const Variable& y, bool trans_x, bool trans_y, float alpha) { + const auto &x_shape = x->shape, &y_shape = y->shape; + + auto matmul_info = [&]() { + std::stringstream ss; + ss << "matmul(X:" << x->id << "[" << cinn::utils::Join(x_shape, ", ") << "], Y:" << y->id << "[" + << cinn::utils::Join(y_shape, ", ") << "]" + << ", trans_x=" << trans_x << ", trans_y=" << trans_y << ", alpha=" << alpha << ")"; + return ss.str(); + }; + + CHECK(!x_shape.empty()) << "The input X:" << x->id << " of matmul should not empty! 
Please check."; + CHECK(!y_shape.empty()) << "The input Y:" << y->id << " of matmul should not empty! Please check."; + + int x_dim = x_shape.size(), y_dim = y_shape.size(); + int max_dim = std::max(x_shape.size(), y_shape.size()); + + std::vector new_x_shape, new_y_shape; + if (max_dim == 1) { + // vector * vector + CHECK(x_shape == y_shape) + << "The matmul input X's numbers must be equal to Y's numbers,when X/Y's dims =1. But here " << matmul_info(); + + // do not need broadcast + return {x, y}; + } else if (x_dim == 1) { + // vector * matrix + int y_K = trans_y ? y_shape[max_dim - 1] : y_shape[max_dim - 2]; + CHECK_EQ(y_K, x_shape[0]) << "The K dimension of Y:" << y_K << " should equal to X.shape[0]:" << x_shape[0] + << ". But here " << matmul_info(); + + // broadcast vector x to the same batch size + // [m] * [a, b, m, d] -> [a, b, 1, m] * [a, b, m, d] + new_x_shape = y_shape; + new_x_shape[max_dim - 2] = 1; + new_x_shape[max_dim - 1] = x_shape[0]; + } else if (y_dim == 1) { + // matrix * vector + int x_K = trans_x ? x_shape[max_dim - 2] : x_shape[max_dim - 1]; + CHECK_EQ(x_K, y_shape[0]) << "The K dimension of X:" << x_K << " should equal to Y.shape[0]:" << y_shape[0] + << ". But here " << matmul_info(); + + // broadcast vector y to the same batch size + // [a, b, c, m] * [m] -> [a, b, c, m] * [a, b, m, 1] + new_y_shape = x_shape; + new_y_shape[max_dim - 2] = y_shape[0]; + new_y_shape[max_dim - 1] = 1; + } else { + // matrix * matrix + int x_K = trans_x ? x_shape[x_dim - 2] : x_shape[x_dim - 1]; + int y_K = trans_y ? y_shape[y_dim - 1] : y_shape[y_dim - 2]; + CHECK_EQ(x_K, y_K) << "The K dimension of matmul not equal. Where " << matmul_info(); + + // if dimension of A or B greater than 2, broadcast input to the same shape + auto gen_new_shape = [max_dim](const std::vector& old_shape) { + std::vector new_shape; + if (old_shape.size() != max_dim) { + // if dim not equal, full 1 + new_shape.resize(max_dim - old_shape.size(), 1); + new_shape.insert(new_shape.end(), old_shape.begin(), old_shape.end()); + } else { + new_shape = old_shape; + } + return new_shape; + }; + new_x_shape = gen_new_shape(x_shape); + new_y_shape = gen_new_shape(y_shape); + + // keep the front batch dimension same + for (int i = 0; i < max_dim - 2; ++i) { + if (new_x_shape[i] == new_y_shape[i]) { + continue; + } + + CHECK(new_x_shape[i] == 1 || new_y_shape[i] == 1) + << "Input X and Y's batch dimension should be same or 1. But here " << matmul_info(); + + // broadcast the value 1 dimension + if (new_x_shape[i] == 1) { + new_x_shape[i] = new_y_shape[i]; + } else { + new_y_shape[i] = new_x_shape[i]; + } + } + } + + auto broad_x = x, broad_y = y; + if (!new_x_shape.empty() && new_x_shape != x_shape) { + int new_size = std::accumulate(new_x_shape.begin(), new_x_shape.end(), 1, std::multiplies()); + int old_size = std::accumulate(x_shape.begin(), x_shape.end(), 1, std::multiplies()); + + if (new_size == old_size) { + VLOG(4) << "Reshape matmul's input X from [" << cinn::utils::Join(x_shape, ", ") << "] to [" + << cinn::utils::Join(new_x_shape, ", ") << "]. Where " << matmul_info(); + broad_x = Reshape(x, new_x_shape); + } else { + VLOG(4) << "Broadcast matmul's input X from [" << cinn::utils::Join(x_shape, ", ") << "] to [" + << cinn::utils::Join(new_x_shape, ", ") << "]. 
Where " << matmul_info(); + broad_x = BroadcastTo(x, new_x_shape); + } + } + + if (!new_y_shape.empty() && new_y_shape != y_shape) { + int new_size = std::accumulate(new_y_shape.begin(), new_y_shape.end(), 1, std::multiplies()); + int old_size = std::accumulate(y_shape.begin(), y_shape.end(), 1, std::multiplies()); + + if (new_size == old_size) { + // only need reshape + VLOG(4) << "Reshape matmul's input Y from [" << cinn::utils::Join(y_shape, ", ") << "] to [" + << cinn::utils::Join(new_y_shape, ", ") << "]. Where " << matmul_info(); + broad_y = Reshape(y, new_y_shape); + } else { + // need broadcast + VLOG(4) << "Broadcast matmul's input Y from [" << cinn::utils::Join(y_shape, ", ") << "] to [" + << cinn::utils::Join(new_y_shape, ", ") << "]. Where " << matmul_info(); + broad_y = BroadcastTo(y, new_y_shape); + } + } + + return {broad_x, broad_y}; +} + +std::vector NetBuilder::GetMatmulOutputShape( + const Variable& x, const Variable& y, bool trans_x, bool trans_y, float alpha) { + const auto &x_shape = x->shape, &y_shape = y->shape; + + auto matmul_info = [&]() { + std::stringstream ss; + ss << "matmul(X:" << x->id << "[" << cinn::utils::Join(x_shape, ", ") << "], Y:" << y->id << "[" + << cinn::utils::Join(y_shape, ", ") << "]" + << ", trans_x=" << trans_x << ", trans_y=" << trans_y << ", alpha=" << alpha << ")"; + return ss.str(); + }; + + int x_dim = x_shape.size(), y_dim = y_shape.size(); + int max_dim = std::max(x_shape.size(), y_shape.size()); + + std::vector out_shape; + if (max_dim == 1) { + // vector * vector + CHECK(x_shape == y_shape) + << "The matmul input X's numbers must be equal to Y's numbers,when X/Y's dims =1. But here " << matmul_info(); + + out_shape = {1}; + } else if (x_dim == 1) { + // vector * matrix + out_shape = y_shape; + if (trans_y) { + // [m] * [a, b, d, m] -> [a, b, d] + out_shape.erase(out_shape.end() - 1); + } else { + // [m] * [a, b, m, d] -> [a, b, d] + out_shape.erase(out_shape.end() - 2); + } + } else if (y_dim == 1) { + // matrix * vector + out_shape = x_shape; + if (trans_x) { + // [a, b, m, c] * [m] -> [a, b, c] + out_shape.erase(out_shape.end() - 2); + } else { + // [a, b, c, m] * [m] -> [a, b, c] + out_shape.erase(out_shape.end() - 1); + } + } else { + // matrix * matrix + int M = trans_x ? x_shape[x_dim - 1] : x_shape[x_dim - 2]; + int N = trans_y ? y_shape[y_dim - 2] : y_shape[y_dim - 1]; + + out_shape.resize(max_dim, 1); + out_shape[max_dim - 2] = M; + out_shape[max_dim - 1] = N; + + // get the batch dimension after broadcast + int x_pos = x_dim - 3, y_pos = y_dim - 3, out_pos = max_dim - 3; + while (x_pos >= 0 && y_pos >= 0) { + CHECK(x_shape[x_pos] == y_shape[y_pos] || x_shape[x_pos] == 1 || y_shape[y_pos] == 1) + << "Input X and Y's batch dimension should be same or 1. But here " << matmul_info(); + out_shape[out_pos] = (x_shape[x_pos] == 1) ? 
y_shape[y_pos] : x_shape[x_pos]; + + out_pos--; + x_pos--; + y_pos--; + } + + while (x_pos >= 0) { + out_shape[out_pos--] = x_shape[x_pos--]; + } + while (y_pos >= 0) { + out_shape[out_pos--] = x_shape[y_pos--]; + } + } + return out_shape; +} + Variable NetBuilder::Matmul(const Variable& x, const Variable& y, bool trans_x, bool trans_y, float alpha) { - Instruction instr("matmul", {x, y}); + const auto& inputs = BroadcastMatmulInput(x, y, trans_x, trans_y, alpha); + + Instruction instr("matmul", {inputs.first, inputs.second}); instr.SetAttr("trans_a", trans_x); instr.SetAttr("trans_b", trans_y); instr.SetAttr("alpha", alpha); InferShape(instr); AppendInstruction(instr); - return instr.GetOutput(0); + auto out = instr.GetOutput(0); + + const auto& should_out_shape = GetMatmulOutputShape(x, y, trans_x, trans_y, alpha); + if (should_out_shape != out->shape) { + int should_out_size = std::accumulate(should_out_shape.begin(), should_out_shape.end(), 1, std::multiplies()); + int real_out_size = std::accumulate(out->shape.begin(), out->shape.end(), 1, std::multiplies()); + CHECK_EQ(should_out_size, real_out_size) + << "Cannot reshape the output:[" << out->id << "] of matmul from [" << cinn::utils::Join(out->shape, ", ") + << "] to [" << cinn::utils::Join(should_out_shape, ", ") << "]." + << " Whose input is " + << "matmul(X:" << x->id << "[" << cinn::utils::Join(x->shape, ", ") << "], Y:" << y->id << "[" + << cinn::utils::Join(y->shape, ", ") << "]" + << ", trans_x=" << trans_x << ", trans_y=" << trans_y << ", alpha=" << alpha << ")"; + out = Reshape(out, should_out_shape); + } + + return out; } Variable NetBuilder::ElementwiseOp(const std::string& op_type, const Variable& lhs, const Variable& rhs, int axis) { diff --git a/cinn/frontend/net_builder.h b/cinn/frontend/net_builder.h old mode 100755 new mode 100644 index e4c7355d18..10e635f595 --- a/cinn/frontend/net_builder.h +++ b/cinn/frontend/net_builder.h @@ -189,6 +189,12 @@ class NetBuilder : public BaseBuilder { protected: Variable ElementwiseOp(const std::string& op_type, const Variable& lhs, const Variable& rhs, int axis = -1); + + private: + // the helper function of Matmul + std::pair BroadcastMatmulInput( + const Variable& x, const Variable& y, bool trans_x, bool trans_y, float alpha); + std::vector GetMatmulOutputShape(const Variable& x, const Variable& y, bool trans_x, bool trans_y, float alpha); }; } // namespace frontend diff --git a/cinn/frontend/op_mappers/paddle/matmul.cc b/cinn/frontend/op_mappers/paddle/matmul.cc index c25d6fce75..7db1c86fea 100644 --- a/cinn/frontend/op_mappers/paddle/matmul.cc +++ b/cinn/frontend/op_mappers/paddle/matmul.cc @@ -28,14 +28,19 @@ void MatMulOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& c auto out_name = op_desc.Output("Out").front(); auto trans_x = utils::GetAttrOrDefault(op_desc, "trans_x", false); + trans_x = utils::GetAttrOrDefault(op_desc, "transpose_X", trans_x); + auto trans_y = utils::GetAttrOrDefault(op_desc, "trans_y", false); + trans_y = utils::GetAttrOrDefault(op_desc, "transpose_Y", trans_y); + + auto alpha = utils::GetAttrOrDefault(op_desc, "alpha", 1.0f); VLOG(4) << out_name << "=matmul{" << x_name << ", " << y_name << ", trans_x=" << trans_x << ", trans_y=" << trans_y - << "}"; + << ", alpha=" << alpha << "}"; auto x = ctx.GetVar(x_name); auto y = ctx.GetVar(y_name); - auto out = ctx.Builder()->Matmul(x, y, trans_x, trans_y); + auto out = ctx.Builder()->Matmul(x, y, trans_x, trans_y, alpha); ctx.AddVar(out_name, out); ctx.AddVarModelToProgram(out_name, 
out->id); diff --git a/python/tests/ops/test_matmul_op.py b/python/tests/ops/test_matmul_op.py index 837908e04b..4f9d6ee307 100755 --- a/python/tests/ops/test_matmul_op.py +++ b/python/tests/ops/test_matmul_op.py @@ -124,5 +124,75 @@ def init_case(self): self.transpose_y = True +class TestMatmulCase7(TestMatmulOp): + def init_case(self): + self.inputs = { + "x": np.random.random([8, 16, 4]).astype("float32"), + "y": np.random.random([1, 4, 16]).astype("float32") + } + self.transpose_x = False + self.transpose_y = False + + +class TestMatmulCase8(TestMatmulOp): + def init_case(self): + self.inputs = { + "x": np.random.random([1, 8, 16, 4]).astype("float32"), + "y": np.random.random([2, 1, 4, 16]).astype("float32") + } + self.transpose_x = False + self.transpose_y = False + + +class TestMatmulCase9(TestMatmulOp): + def init_case(self): + self.inputs = { + "x": np.random.random([8, 16, 4]).astype("float32"), + "y": np.random.random([4, 16]).astype("float32") + } + self.transpose_x = False + self.transpose_y = False + + +class TestMatmulCase10(TestMatmulOp): + def init_case(self): + self.inputs = { + "x": np.random.random([2, 8, 16, 4]).astype("float32"), + "y": np.random.random([4, 16]).astype("float32") + } + self.transpose_x = False + self.transpose_y = False + + +class TestMatmulCase11(TestMatmulOp): + def init_case(self): + self.inputs = { + "x": np.random.random([2, 8, 4, 16]).astype("float32"), + "y": np.random.random([4, 16]).astype("float32") + } + self.transpose_x = True + self.transpose_y = False + + +class TestMatmulCase12(TestMatmulOp): + def init_case(self): + self.inputs = { + "x": np.random.random([2, 8, 16, 4]).astype("float32"), + "y": np.random.random([16]).astype("float32") + } + self.transpose_x = True + self.transpose_y = False + + +class TestMatmulCase13(TestMatmulOp): + def init_case(self): + self.inputs = { + "x": np.random.random([4, 16]).astype("float32"), + "y": np.random.random([16]).astype("float32") + } + self.transpose_x = False + self.transpose_y = False + + if __name__ == "__main__": unittest.main() From 7b2cbf6070afedc60472859dd7e0fe904ce7092b Mon Sep 17 00:00:00 2001 From: BiynXu <62832681+BiynXu@users.noreply.github.com> Date: Mon, 5 Sep 2022 16:46:57 +0800 Subject: [PATCH 4/8] add new schedule to pybind (#922) --- cinn/pybind/framework.cc | 42 ++++++++++++++++++++++++++++------------ 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/cinn/pybind/framework.cc b/cinn/pybind/framework.cc index f905643703..26463ea906 100644 --- a/cinn/pybind/framework.cc +++ b/cinn/pybind/framework.cc @@ -19,12 +19,16 @@ #include "cinn/common/cinn_value.h" #include "cinn/frontend/interpreter.h" +#include "cinn/hlir/framework/graph_compiler.h" #include "cinn/hlir/framework/node.h" #include "cinn/hlir/framework/op.h" #include "cinn/hlir/framework/op_strategy.h" #include "cinn/hlir/framework/scope.h" #include "cinn/hlir/op/use_ops.h" #include "cinn/pybind/bind.h" +#include "cinn/runtime/flags.h" + +DECLARE_bool(cinn_ir_schedule); namespace cinn::pybind { @@ -52,19 +56,33 @@ void BindFramework(pybind11::module *m) { res.push_back(tensor); temp_inputs.push_back(common::CINNValue(tensor)); } - common::CINNValuePack C = impl->fcompute(common::CINNValuePack{temp_inputs}); - poly::StageMap stages = C.back(); - // make sure all the tensors in the stages before schedule launch. 
- for (int i = 0; i < C->size() - 1; i++) { - ir::Expr temp = C[i]; - stages->InsertLazily(temp.as_tensor_ref()); - } - C = impl->fschedule(C); - for (int i = 0; i < C->size() - 1; i++) { - ir::Expr temp = C[i]; - res.push_back(temp.as_tensor_ref()); + + ir::LoweredFunc func; + if (FLAGS_cinn_ir_schedule) { + std::string output_name = "out"; + temp_inputs.emplace_back(output_name); + std::vector input_output_names; + for (const auto &input : inputs) { + input_output_names.push_back(input->name); + } + input_output_names.push_back(output_name); + func = hlir::framework::GetFuncFromImpl( + impl, common::CINNValuePack{temp_inputs}, res, input_output_names, key, target)[0]; + } else { + common::CINNValuePack C = impl->fcompute(common::CINNValuePack{temp_inputs}); + poly::StageMap stages = C.back(); + // make sure all the tensors in the stages before schedule launch. + for (int i = 0; i < C->size() - 1; i++) { + ir::Expr temp = C[i]; + stages->InsertLazily(temp.as_tensor_ref()); + } + C = impl->fschedule(C); + for (int i = 0; i < C->size() - 1; i++) { + ir::Expr temp = C[i]; + res.push_back(temp.as_tensor_ref()); + } + func = Lower(key, stages, res); } - auto func = Lower(key, stages, res); return func; }); From 7f5f423f605b55969130714076ca955998115cf1 Mon Sep 17 00:00:00 2001 From: BiynXu <62832681+BiynXu@users.noreply.github.com> Date: Mon, 5 Sep 2022 20:28:07 +0800 Subject: [PATCH 5/8] Add JSONDatabase (#912) * Add JSONDatabase * rewrite jsonfile_database * add constructors to TuningRecord * rename several class and functions * delete unnecessary functions * add serialize\deserialize and save\load unit tests * Delete temp JSON file after unittests --- CMakeLists.txt | 5 +- cinn/auto_schedule/CMakeLists.txt | 6 + cinn/auto_schedule/auto_schedule.proto | 22 ++++ cinn/auto_schedule/database/CMakeLists.txt | 3 +- cinn/auto_schedule/database/database.cc | 14 ++ cinn/auto_schedule/database/database.h | 18 +++ cinn/auto_schedule/database/database_test.cc | 19 +-- .../database/jsonfile_database.cc | 101 +++++++++++++++ .../database/jsonfile_database.h | 55 ++++++++ .../database/jsonfile_database_test.cc | 121 ++++++++++++++++++ .../auto_schedule/search_space/search_state.h | 2 + 11 files changed, 354 insertions(+), 12 deletions(-) create mode 100644 cinn/auto_schedule/auto_schedule.proto create mode 100644 cinn/auto_schedule/database/jsonfile_database.cc create mode 100644 cinn/auto_schedule/database/jsonfile_database.h create mode 100644 cinn/auto_schedule/database/jsonfile_database_test.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index c361cff06f..3d9c9ad1b9 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -106,7 +106,7 @@ message(STATUS "PYTHON_LIBRARIES: ${PYTHON_LIBRARIES}") message(STATUS "PYTHON_INCLUDE_DIR: ${PYTHON_INCLUDE_DIR}") INCLUDE_DIRECTORIES(${PYTHON_INCLUDE_DIR}) -cc_library(cinnapi SHARED SRCS ${cinnapi_src} DEPS glog ${llvm_libs} framework_proto param_proto framework_proto absl isl ginac pybind) +cc_library(cinnapi SHARED SRCS ${cinnapi_src} DEPS glog ${llvm_libs} framework_proto param_proto auto_schedule_proto framework_proto absl isl ginac pybind) add_dependencies(cinnapi GEN_LLVM_RUNTIME_IR_HEADER ZLIB::ZLIB) add_dependencies(cinnapi GEN_LLVM_RUNTIME_IR_HEADER ${core_deps}) @@ -133,7 +133,7 @@ function(gen_cinncore LINKTYPE) if (${LINKTYPE} STREQUAL "STATIC") set(CINNCORE_TARGET cinncore_static) endif() - cc_library(${CINNCORE_TARGET} ${LINKTYPE} SRCS ${core_src} DEPS glog ${llvm_libs} framework_proto param_proto framework_proto absl isl ginac) + 
cc_library(${CINNCORE_TARGET} ${LINKTYPE} SRCS ${core_src} DEPS glog ${llvm_libs} framework_proto param_proto auto_schedule_proto framework_proto absl isl ginac) add_dependencies(${CINNCORE_TARGET} GEN_LLVM_RUNTIME_IR_HEADER ZLIB::ZLIB) add_dependencies(${CINNCORE_TARGET} GEN_LLVM_RUNTIME_IR_HEADER ${core_deps}) @@ -205,6 +205,7 @@ if (PUBLISH_LIBS) COMMAND cmake -E copy ${CMAKE_BINARY_DIR}/libcinncore_static.a ${CMAKE_BINARY_DIR}/dist/cinn/lib/libcinncore_static.a COMMAND cmake -E copy ${CMAKE_BINARY_DIR}/cinn/frontend/paddle/libframework_proto.a ${CMAKE_BINARY_DIR}/dist/cinn/lib/libframework_proto.a COMMAND cmake -E copy ${CMAKE_BINARY_DIR}/cinn/hlir/pe/libparam_proto.a ${CMAKE_BINARY_DIR}/dist/cinn/lib/libparam_proto.a + COMMAND cmake -E copy ${CMAKE_BINARY_DIR}/cinn/auto_schedule/libauto_schedule_proto.a ${CMAKE_BINARY_DIR}/dist/cinn/lib/libauto_schedule_proto.a COMMENT "distribute libcinncore_static.a and related header files." DEPENDS cinncore_static ) diff --git a/cinn/auto_schedule/CMakeLists.txt b/cinn/auto_schedule/CMakeLists.txt index f5892ce4d5..1f6cf71443 100644 --- a/cinn/auto_schedule/CMakeLists.txt +++ b/cinn/auto_schedule/CMakeLists.txt @@ -7,8 +7,14 @@ add_subdirectory(search_strategy) add_subdirectory(task) add_subdirectory(task_scheduler) +proto_library(auto_schedule_proto SRCS auto_schedule.proto) + core_gather_headers() gather_srcs(cinnapi_src SRCS auto_tuner.cc) cc_test(test_auto_tuner SRCS auto_tuner_test.cc DEPS cinncore) + +foreach(header ${auto_schedule_proto_HDRS}) + set(core_proto_includes "${core_proto_includes};${header}" CACHE INTERNAL "") +endforeach() diff --git a/cinn/auto_schedule/auto_schedule.proto b/cinn/auto_schedule/auto_schedule.proto new file mode 100644 index 0000000000..3920591d15 --- /dev/null +++ b/cinn/auto_schedule/auto_schedule.proto @@ -0,0 +1,22 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +syntax ="proto3"; + +package cinn.auto_schedule.proto; + +message TuningRecord { + string task_key = 1; + double execution_cost = 2; +} diff --git a/cinn/auto_schedule/database/CMakeLists.txt b/cinn/auto_schedule/database/CMakeLists.txt index 64ecb891d2..1c3ca9330b 100644 --- a/cinn/auto_schedule/database/CMakeLists.txt +++ b/cinn/auto_schedule/database/CMakeLists.txt @@ -1,5 +1,6 @@ core_gather_headers() -gather_srcs(cinnapi_src SRCS database.cc) +gather_srcs(cinnapi_src SRCS database.cc jsonfile_database.cc) cc_test(test_database SRCS database_test.cc DEPS cinncore) +cc_test(test_jsonfile_database SRCS jsonfile_database_test.cc DEPS cinncore) diff --git a/cinn/auto_schedule/database/database.cc b/cinn/auto_schedule/database/database.cc index a133c36b75..b57690d5db 100644 --- a/cinn/auto_schedule/database/database.cc +++ b/cinn/auto_schedule/database/database.cc @@ -14,6 +14,12 @@ #include "cinn/auto_schedule/database/database.h" +#include +#include +#include + +#include "cinn/ir/ir_schedule.h" + namespace cinn { namespace auto_schedule { @@ -21,6 +27,14 @@ bool TuningRecord::Compare::operator()(const TuningRecord& lhs, const TuningReco return lhs.execution_cost < rhs.execution_cost; } +proto::TuningRecord TuningRecord::ToProto() const { + proto::TuningRecord record_proto; + record_proto.set_task_key(task_key); + record_proto.set_execution_cost(execution_cost); + + return record_proto; +} + Database::Database(int capacity_per_task) : capacity_per_task_(capacity_per_task) { CHECK_GT(capacity_per_task_, 0) << "capacity_per_task_ should be greater than 0"; } diff --git a/cinn/auto_schedule/database/database.h b/cinn/auto_schedule/database/database.h index 2e0e5c5acc..e070f260c1 100644 --- a/cinn/auto_schedule/database/database.h +++ b/cinn/auto_schedule/database/database.h @@ -14,8 +14,14 @@ #pragma once +#include +#include +#include + +#include "cinn/auto_schedule/auto_schedule.pb.h" #include "cinn/auto_schedule/measure/measure.h" #include "cinn/auto_schedule/search_space/search_state.h" +#include "cinn/ir/ir_schedule.h" namespace cinn { namespace auto_schedule { @@ -34,6 +40,18 @@ struct TuningRecord { struct Compare { bool operator()(const TuningRecord& lhs, const TuningRecord& rhs) const; }; + + TuningRecord() = default; + + // initialize a TuningRecord object from a proto object + TuningRecord(const proto::TuningRecord& record_proto) + : task_key(record_proto.task_key()), execution_cost(record_proto.execution_cost()), state(ir::ModuleExpr()) {} + + TuningRecord(const std::string& task_key, double execution_cost, const SearchState& state) + : task_key(task_key), execution_cost(execution_cost), state(state) {} + + // convert to proto object + proto::TuningRecord ToProto() const; }; // A database supports insert or lookup historial tuning result with sepecified traits. 
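For reference, a minimal round-trip sketch using the constructors and ToProto() declared above (the task key and cost are made-up values, and the snippet assumes it sits inside namespace cinn::auto_schedule; it is illustrative, not part of the diff):

    TuningRecord before("fused_matmul_0", 2.5, SearchState(ir::ModuleExpr()));  // hypothetical task key and cost
    proto::TuningRecord proto_record = before.ToProto();
    TuningRecord after(proto_record);  // the proto keeps only task_key/execution_cost, so `state` is re-initialized empty
    CHECK_EQ(after.task_key, before.task_key);
    CHECK_EQ(after.execution_cost, before.execution_cost);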
diff --git a/cinn/auto_schedule/database/database_test.cc b/cinn/auto_schedule/database/database_test.cc index 19ce815719..2420e2f6ac 100644 --- a/cinn/auto_schedule/database/database_test.cc +++ b/cinn/auto_schedule/database/database_test.cc @@ -18,6 +18,7 @@ #include +#include "cinn/auto_schedule/auto_schedule.pb.h" #include "cinn/auto_schedule/search_space/search_state.h" #include "cinn/ir/ir_schedule.h" @@ -27,13 +28,13 @@ namespace auto_schedule { class TestDatabase : public ::testing::Test { public: TestDatabase() : test_db(2) { - test_db.AddRecord(TuningRecord({"k1", 1.0, SearchState(ir::ModuleExpr())})); - test_db.AddRecord(TuningRecord({"k2", 2.0, SearchState(ir::ModuleExpr())})); - test_db.AddRecord(TuningRecord({"k2", 3.0, SearchState(ir::ModuleExpr())})); - test_db.AddRecord(TuningRecord({"k3", 3.0, SearchState(ir::ModuleExpr())})); - test_db.AddRecord(TuningRecord({"k3", 4.0, SearchState(ir::ModuleExpr())})); - test_db.AddRecord(TuningRecord({"k3", 5.0, SearchState(ir::ModuleExpr())})); - test_db.AddRecord(TuningRecord({"k4", 4.0, SearchState(ir::ModuleExpr())})); + test_db.AddRecord(TuningRecord("k1", 1.0, SearchState(ir::ModuleExpr()))); + test_db.AddRecord(TuningRecord("k2", 2.0, SearchState(ir::ModuleExpr()))); + test_db.AddRecord(TuningRecord("k2", 3.0, SearchState(ir::ModuleExpr()))); + test_db.AddRecord(TuningRecord("k3", 3.0, SearchState(ir::ModuleExpr()))); + test_db.AddRecord(TuningRecord("k3", 4.0, SearchState(ir::ModuleExpr()))); + test_db.AddRecord(TuningRecord("k3", 5.0, SearchState(ir::ModuleExpr()))); + test_db.AddRecord(TuningRecord("k4", 4.0, SearchState(ir::ModuleExpr()))); } void SetUp() override {} @@ -59,8 +60,8 @@ TEST_F(TestDatabase, GetTopK) { SearchState state2(std::move(ir::ModuleExpr())); state1.predicted_cost = 1.2; state2.predicted_cost = 1.0; - test_db.AddRecord(TuningRecord({"k4", 2.0, state1})); - test_db.AddRecord(TuningRecord({"k4", 3.0, state2})); + test_db.AddRecord(TuningRecord("k4", 2.0, state1)); + test_db.AddRecord(TuningRecord("k4", 3.0, state2)); auto states = test_db.GetTopK("k4", 3); ASSERT_EQ(states.size(), 2); diff --git a/cinn/auto_schedule/database/jsonfile_database.cc b/cinn/auto_schedule/database/jsonfile_database.cc new file mode 100644 index 0000000000..04644b1cf2 --- /dev/null +++ b/cinn/auto_schedule/database/jsonfile_database.cc @@ -0,0 +1,101 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
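// A note on the storage layout implemented in this file (inferred from the code below and its tests):
// the database file is plain text with one JSON-serialized proto::TuningRecord per line, e.g.
//   {"taskKey":"k2","executionCost":3}
// Commit() only ever appends such a line; the constructor re-reads the whole file, deserializes the
// lines in parallel, and keeps at most capacity_per_task records (those with the lowest execution
// cost) per task key in memory.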
+ +#include "cinn/auto_schedule/database/jsonfile_database.h" + +#include +#include +#include + +#include + +#include "cinn/auto_schedule/auto_schedule.pb.h" +#include "cinn/utils/multi_threading.h" + +namespace cinn { +namespace auto_schedule { + +// append a line to file +void AppendLineToFile(const std::string& file_path, const std::string& line) { + std::ofstream os(file_path, std::ofstream::app); + CHECK(os.good()) << "Cannot open the file to write: " << file_path; + os << line << std::endl; +} + +// read lines from a json file +std::vector ReadLinesFromFile(const std::string& file_path, bool allow_new_file) { + std::ifstream is(file_path); + if (is.good()) { + std::vector json_strs; + for (std::string str; std::getline(is, str);) { + json_strs.push_back(str); + } + + return json_strs; + } + CHECK(allow_new_file) << "File doesn't exist: " << file_path; + std::ofstream os(file_path); + CHECK(os.good()) << "Cannot create new file: " << file_path; + return {}; +} + +JSONFileDatabase::JSONFileDatabase(int capacity_per_task, const std::string& record_file_path, bool allow_new_file) + : Database(capacity_per_task), record_file_path_(record_file_path) { + auto json_lines = ReadLinesFromFile(record_file_path_, allow_new_file); + + std::vector all_records(json_lines.size()); + auto worker_fn = [this, &json_lines, &all_records](int index) { + all_records[index] = JSONToRecord(json_lines[index]); + }; + utils::parallel_run(worker_fn, utils::SequenceDispatcher(0, json_lines.size()), -1); + + for (const auto& record : all_records) { + auto& records = this->key2record_[record.task_key]; + records.emplace(record); + if (records.size() > this->capacity_per_task_) { + records.erase(std::prev(records.end())); + } + } +} + +// convert a TuningRecord object to string in JSON format +std::string JSONFileDatabase::RecordToJSON(const TuningRecord& record) { + proto::TuningRecord record_proto = record.ToProto(); + + std::string json_string; + auto status = google::protobuf::util::MessageToJsonString(record_proto, &json_string); + CHECK(status.ok()) << "Failed to serialize record to JSON, task key = " << record.task_key; + VLOG(4) << "json_string = \n" << json_string; + + return json_string; +} + +// convert a line of string in JSON format to a TuningRecord object +TuningRecord JSONFileDatabase::JSONToRecord(const std::string& json_string) { + cinn::auto_schedule::proto::TuningRecord record_proto; + auto status = google::protobuf::util::JsonStringToMessage(json_string, &record_proto); + CHECK(status.ok()) << "Failed to parse JSON: " << json_string; + + return TuningRecord(record_proto); +} + +bool JSONFileDatabase::Commit(const TuningRecord& record) { + std::string json_string = RecordToJSON(record); + AppendLineToFile(record_file_path_, json_string); + + return true; +} + +} // namespace auto_schedule +} // namespace cinn diff --git a/cinn/auto_schedule/database/jsonfile_database.h b/cinn/auto_schedule/database/jsonfile_database.h new file mode 100644 index 0000000000..133580575a --- /dev/null +++ b/cinn/auto_schedule/database/jsonfile_database.h @@ -0,0 +1,55 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "database.h" + +namespace cinn { +namespace auto_schedule { + +// JSONFileDatabase is a database implemented by JSON file to save/load underlying data. +class JSONFileDatabase : public Database { + public: + /*! + * \brief Build a JSONFileDatabase object from a json file. + * \param capacity_per_task The max number of candidates stored. + * \param record_file_path The path of the json file. + * \param allow_new_file Whether to create new file when the given path is not found. + */ + explicit JSONFileDatabase(int capacity_per_task, const std::string& record_file_path, bool allow_new_file); + ~JSONFileDatabase() = default; + + // convert a TuningRecord object to string in JSON format + std::string RecordToJSON(const TuningRecord& record); + + // convert a line of string in JSON format to a TuningRecord object + TuningRecord JSONToRecord(const std::string& json_string); + + protected: + // commit the newly added record into json file + bool Commit(const TuningRecord& record) override; + + // the name of the json file to save tuning records. + std::string record_file_path_; +}; + +// append a line to file +void AppendLineToFile(const std::string& file_path, const std::string& line); + +// read lines from a json file +std::vector ReadLinesFromFile(const std::string& file_path, bool allow_new_file = true); + +} // namespace auto_schedule +} // namespace cinn diff --git a/cinn/auto_schedule/database/jsonfile_database_test.cc b/cinn/auto_schedule/database/jsonfile_database_test.cc new file mode 100644 index 0000000000..e0c152fa6f --- /dev/null +++ b/cinn/auto_schedule/database/jsonfile_database_test.cc @@ -0,0 +1,121 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "cinn/auto_schedule/database/jsonfile_database.h" + +#include + +#include +#include + +#include "cinn/auto_schedule/search_space/search_state.h" +#include "cinn/ir/ir_schedule.h" + +namespace cinn { +namespace auto_schedule { + +void AddTestRecords(JSONFileDatabase& test_db) { + test_db.AddRecord(TuningRecord("k1", 1.0, SearchState(ir::ModuleExpr()))); + test_db.AddRecord(TuningRecord("k2", 2.0, SearchState(ir::ModuleExpr()))); + test_db.AddRecord(TuningRecord("k2", 3.0, SearchState(ir::ModuleExpr()))); + test_db.AddRecord(TuningRecord("k3", 3.0, SearchState(ir::ModuleExpr()))); + test_db.AddRecord(TuningRecord("k3", 4.0, SearchState(ir::ModuleExpr()))); + test_db.AddRecord(TuningRecord("k3", 5.0, SearchState(ir::ModuleExpr()))); + test_db.AddRecord(TuningRecord("k4", 4.0, SearchState(ir::ModuleExpr()))); + + SearchState state1(std::move(ir::ModuleExpr())); + SearchState state2(std::move(ir::ModuleExpr())); + state1.predicted_cost = 1.2; + state2.predicted_cost = 1.0; + test_db.AddRecord(TuningRecord("k4", 2.0, state1)); + test_db.AddRecord(TuningRecord("k4", 3.0, state2)); +} + +class TestJSONFileDatabase : public ::testing::Test { + public: + TestJSONFileDatabase() : record_file_path("/tmp/test_record.json"), test_db(2, record_file_path, true) { + if (0 == test_db.Size()) { + AddTestRecords(test_db); + } + } + + void SetUp() override {} + + void TearDown() override { + auto isFileExists = [](const std::string& file_path) -> bool { + std::ifstream f(file_path.c_str()); + return f.good(); + }; + if (isFileExists(record_file_path)) { + if (remove(record_file_path.c_str()) == 0) { + LOG(INFO) << "Successfully deleted file: " << record_file_path; + } else { + LOG(INFO) << "failed to delete file: " << record_file_path; + } + } else { + LOG(INFO) << "file: " << record_file_path << "does not exist."; + } + } + + std::string record_file_path; + JSONFileDatabase test_db; +}; + +TEST_F(TestJSONFileDatabase, SerializeAndDeserialize) { + TuningRecord record1("test", 1.0, SearchState(ir::ModuleExpr())); + std::string str = test_db.RecordToJSON(record1); + EXPECT_EQ(str, "{\"taskKey\":\"test\",\"executionCost\":1}"); + + TuningRecord record2 = test_db.JSONToRecord(str); + EXPECT_EQ(record1.task_key, record2.task_key); + EXPECT_EQ(record1.execution_cost, record2.execution_cost); +} + +TEST_F(TestJSONFileDatabase, SaveLoad) { + std::vector strs = ReadLinesFromFile(record_file_path); + ASSERT_EQ(strs.size(), 9); + EXPECT_EQ(strs[0], "{\"taskKey\":\"k1\",\"executionCost\":1}"); + EXPECT_EQ(strs[1], "{\"taskKey\":\"k2\",\"executionCost\":2}"); + EXPECT_EQ(strs[2], "{\"taskKey\":\"k2\",\"executionCost\":3}"); + EXPECT_EQ(strs[3], "{\"taskKey\":\"k3\",\"executionCost\":3}"); + EXPECT_EQ(strs[4], "{\"taskKey\":\"k3\",\"executionCost\":4}"); + EXPECT_EQ(strs[5], "{\"taskKey\":\"k3\",\"executionCost\":5}"); + EXPECT_EQ(strs[6], "{\"taskKey\":\"k4\",\"executionCost\":4}"); + EXPECT_EQ(strs[7], "{\"taskKey\":\"k4\",\"executionCost\":2}"); + EXPECT_EQ(strs[8], "{\"taskKey\":\"k4\",\"executionCost\":3}"); +} + +TEST_F(TestJSONFileDatabase, Basic) { + ASSERT_EQ(test_db.Size(), 7); + auto records = test_db.LookUp("k3"); + // check the max number of stored candidates will + // be restricted to capacity_per_task + ASSERT_EQ(test_db.Count("k3"), 2); + ASSERT_EQ(records.size(), 2); + EXPECT_EQ(records[0].execution_cost, 3.0); + EXPECT_EQ(records[1].execution_cost, 4.0); +} + +TEST_F(TestJSONFileDatabase, GetTopK) { + ASSERT_TRUE(test_db.GetTopK("k5", 2).empty()); + + auto states = test_db.GetTopK("k4", 3); + 
ASSERT_EQ(states.size(), 2); + + EXPECT_FLOAT_EQ(states[0].predicted_cost, 1.2); + EXPECT_FLOAT_EQ(states[1].predicted_cost, 1); +} + +} // namespace auto_schedule +} // namespace cinn diff --git a/cinn/auto_schedule/search_space/search_state.h b/cinn/auto_schedule/search_space/search_state.h index c3788c03d4..94e8071e96 100644 --- a/cinn/auto_schedule/search_space/search_state.h +++ b/cinn/auto_schedule/search_space/search_state.h @@ -45,6 +45,8 @@ class SearchState { // Negative constant standing for a cost not being initialized static constexpr float NOT_INIT_COST = -1.0; + SearchState() = default; + SearchState(const ir::ModuleExpr& mod_expr); SearchState(ir::ModuleExpr&& mod_expr); From b823da03a8aa134ef337d61b04423bea6358961d Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous <99723454+MayYouBeProsperous@users.noreply.github.com> Date: Tue, 6 Sep 2022 14:22:44 +0800 Subject: [PATCH 6/8] =?UTF-8?q?=E3=80=90PaddlePaddle=20Hackathon=2067?= =?UTF-8?q?=E3=80=91Add=20arange=20op=20(#919)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 提案:PaddlePaddle/community#195 --- cinn/frontend/net_builder.cc | 12 +++ cinn/frontend/net_builder.h | 2 + cinn/frontend/net_builder_test.cc | 72 +++++++++++++ cinn/hlir/op/contrib/CMakeLists.txt | 2 + cinn/hlir/op/contrib/arange.cc | 159 ++++++++++++++++++++++++++++ cinn/hlir/op/contrib/arange.h | 33 ++++++ cinn/hlir/op/contrib/arange_test.cc | 94 ++++++++++++++++ cinn/hlir/op/use_ops.h | 1 + 8 files changed, 375 insertions(+) create mode 100644 cinn/hlir/op/contrib/arange.cc create mode 100644 cinn/hlir/op/contrib/arange.h create mode 100644 cinn/hlir/op/contrib/arange_test.cc diff --git a/cinn/frontend/net_builder.cc b/cinn/frontend/net_builder.cc index 9ad7038369..3237ebe108 100644 --- a/cinn/frontend/net_builder.cc +++ b/cinn/frontend/net_builder.cc @@ -258,6 +258,18 @@ Variable NetBuilder::Clip(const std::vector& inputs, const float& max_ return instr.GetOutput(0); } +Variable NetBuilder::Arange(const float start, const float stop, const float step, const std::string& dtype) { + Instruction instr("arange"); + instr.SetInputs({}); + instr.SetAttr("start", start); + instr.SetAttr("stop", stop); + instr.SetAttr("step", step); + instr.SetAttr("dtype", dtype); + InferShape(instr); + AppendInstruction(instr); + return instr.GetOutput(0); +} + // conv2d grad, output(grad_x, grad_w) std::vector NetBuilder::Conv2dGrad(const Variable& dy, const Variable& x, diff --git a/cinn/frontend/net_builder.h b/cinn/frontend/net_builder.h index 10e635f595..8030946e97 100644 --- a/cinn/frontend/net_builder.h +++ b/cinn/frontend/net_builder.h @@ -173,6 +173,8 @@ class NetBuilder : public BaseBuilder { Variable Clip(const std::vector& inputs, const float& max_val, const float& min_val); + Variable Arange(const float start, const float stop, const float step, const std::string& dtype); + // conv2d grad, output(grad_x, grad_w) std::vector Conv2dGrad(const Variable& dy, const Variable& x, diff --git a/cinn/frontend/net_builder_test.cc b/cinn/frontend/net_builder_test.cc index 5ea3422c35..55244574a7 100644 --- a/cinn/frontend/net_builder_test.cc +++ b/cinn/frontend/net_builder_test.cc @@ -453,5 +453,77 @@ TEST(net_build, program_execute_squeeze_case3) { } } +TEST(net_build, program_execute_arange_float) { + const float start = 1.5F; + const float stop = 31.5F; + const float step = 2.0F; + const std::string dtype = "float32"; + + NetBuilder builder("net_builder"); + Variable out = builder.Arange(start, stop, step, dtype); + auto program = 
builder.Build(); + + Target target = common::DefaultHostTarget(); + + auto graph = std::make_shared(program, target); + auto scope = BuildScope(target, graph); + hlir::framework::GraphCompiler gc(target, scope, graph); + auto runtime_program = gc.Build(); + + scope->Var(std::string(out->id)); + + runtime_program->Execute(); + + auto out_tensor = scope->GetTensor(std::string(out->id)); + const std::vector& out_tensor_shape = out_tensor->shape().data(); + EXPECT_EQ(out_tensor->type(), Float(32)); + EXPECT_EQ(out_tensor_shape.size(), 1UL); + + int num_elem = static_cast(std::ceil((stop - start) / step)); + EXPECT_EQ(out_tensor_shape[0], num_elem); + + float* out_data = out_tensor->mutable_data(target); + for (int i = 0; i < out_tensor_shape[0]; ++i) { + EXPECT_NEAR(out_data[i], start + step * i, 1e-5); + VLOG(6) << out_data[i]; + } +} + +TEST(net_build, program_execute_arange_int) { + const float start = 1.5F; + const float stop = 31.5F; + const float step = 1.6F; + const std::string dtype = "int32"; + + NetBuilder builder("net_builder"); + Variable out = builder.Arange(start, stop, step, dtype); + auto program = builder.Build(); + + Target target = common::DefaultHostTarget(); + + auto graph = std::make_shared(program, target); + auto scope = BuildScope(target, graph); + hlir::framework::GraphCompiler gc(target, scope, graph); + auto runtime_program = gc.Build(); + + scope->Var(std::string(out->id)); + + runtime_program->Execute(); + + auto out_tensor = scope->GetTensor(std::string(out->id)); + const std::vector& out_tensor_shape = out_tensor->shape().data(); + EXPECT_EQ(out_tensor->type(), Int(32)); + EXPECT_EQ(out_tensor_shape.size(), 1UL); + + int num_elem = static_cast(std::ceil((stop - start) / step)); + EXPECT_EQ(out_tensor_shape[0], num_elem); + + int32_t* out_data = out_tensor->mutable_data(target); + for (int i = 0; i < out_tensor_shape[0]; ++i) { + EXPECT_EQ(out_data[i], static_cast(start + step * i)); + VLOG(6) << out_data[i]; + } +} + } // namespace frontend } // namespace cinn diff --git a/cinn/hlir/op/contrib/CMakeLists.txt b/cinn/hlir/op/contrib/CMakeLists.txt index 98f9811eb3..c6aeac910a 100644 --- a/cinn/hlir/op/contrib/CMakeLists.txt +++ b/cinn/hlir/op/contrib/CMakeLists.txt @@ -4,8 +4,10 @@ gather_srcs(cinnapi_src SRCS cast.cc squeeze.cc clip.cc + arange.cc ) cc_test(test_cast SRCS cast_test.cc DEPS cinncore) cc_test(test_squeeze SRCS squeeze_test.cc DEPS cinncore) cc_test(test_clip SRCS clip_test.cc DEPS cinncore) +cc_test(test_arange SRCS arange_test.cc DEPS cinncore) diff --git a/cinn/hlir/op/contrib/arange.cc b/cinn/hlir/op/contrib/arange.cc new file mode 100644 index 0000000000..bf9d0d221b --- /dev/null +++ b/cinn/hlir/op/contrib/arange.cc @@ -0,0 +1,159 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
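// Semantics implemented below, in short: the op takes no input tensors and materializes
// ceil((stop - start) / step) values start, start + step, start + 2 * step, ..., cast to `dtype`.
// For example (matching the unit tests), start = 1.5, stop = 31.5, step = 2.0 and dtype = "float32"
// produce 15 elements: 1.5, 3.5, 5.5, ..., 29.5.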
+ +#include "cinn/hlir/op/contrib/arange.h" + +#include + +#include +#include +#include +#include + +#include "cinn/common/cas.h" +#include "cinn/common/common.h" +#include "cinn/common/context.h" +#include "cinn/common/macros.h" +#include "cinn/hlir/framework/node.h" +#include "cinn/hlir/framework/op.h" +#include "cinn/hlir/framework/op_strategy.h" +#include "cinn/hlir/pe/nn.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_base.h" +#include "cinn/ir/tensor.h" +#include "cinn/lang/builtin.h" +#include "cinn/lang/compute.h" + +DECLARE_bool(cinn_ir_schedule); + +namespace cinn { +namespace hlir { +namespace op { + +std::vector Arange( + const float start, const float stop, const float step, const Type &dtype, const std::string &output_name) { + int num_elem = static_cast(std::ceil((stop - start) / step)); + ir::Tensor res = lang::Compute( + {Expr(num_elem)}, + [=](const std::vector &indices) { + return ir::Cast::Make(dtype, start + step * cinn::common::cast(indices[0], common::Float(32))); + }, + common::UniqName(output_name)); + return {res}; +} + +std::vector> InferShapeForArange(const std::vector> &inputs_shape, + const framework::AttrMapType &attrs) { + float start = 0.0F; + float stop = 0.0F; + float step = 1.0F; + CHECK(attrs.find("stop") != attrs.end()) << "Please set the stop parameter of arange."; + + if (attrs.find("start") != attrs.end()) { + start = absl::get(attrs.at("start")); + } + if (attrs.find("stop") != attrs.end()) { + stop = absl::get(attrs.at("stop")); + } + if (attrs.find("step") != attrs.end()) { + step = absl::get(attrs.at("step")); + } + + CHECK_NE(step, 0) << "The value of step cann't be 0!"; + + int num_elem = static_cast(std::ceil((stop - start) / step)); + CHECK_GT(num_elem, 0) << "Invalid arange parameters, start = " << start << ", stop = " << stop << ", step = " << step + << ", cause num_elem = " << num_elem << " which is negative."; + + std::vector> res = {{num_elem}}; + return res; +} + +std::vector InferDtypeForArange(const std::vector &inputs_type, const framework::AttrMapType &attrs) { + std::string dtype = "float32"; + if (attrs.find("dtype") != attrs.end()) { + dtype = absl::get(attrs.at("dtype")); + } + std::vector res{common::Str2Type(dtype)}; + return res; +} + +std::shared_ptr StrategyForArange(const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + std::string dtype = "float32"; + float start = 0.0F; + float stop = 0.0F; + float step = 1.0F; + + for (auto &iter : attrs.attr_store) { + if (iter.first == "dtype") { + dtype = absl::get(iter.second); + } else if (iter.first == "start") { + start = absl::get(iter.second); + } else if (iter.first == "stop") { + stop = absl::get(iter.second); + } else if (iter.first == "step") { + step = absl::get(iter.second); + } + } + + framework::CINNCompute arange_compute([=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of arange compute is empty! 
Please check.\n"; + + std::vector out = Arange(start, stop, step, common::Str2Type(dtype), common::UniqName("T_Arange_out")); + CHECK(out.size() == 1U) << "The size of Arange's output should be 1"; + + std::vector res; + auto stages = CreateStages({}); + for (auto &t : out) { + stages->InsertLazily(t); + res.push_back(common::CINNValue(t)); + } + + res.push_back(common::CINNValue(stages)); + *ret = common::CINNValuePack{res}; + }); + + framework::CINNSchedule arange_schedule([=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of arange schedule is empty! Please check.\n"; + common::CINNValuePack arg_pack = args[0]; + Expr out = arg_pack[0]; + CHECK(out.as_tensor()); + *ret = arg_pack; + }); + + auto strategy = std::make_shared(); + strategy->AddImpl(arange_compute, arange_schedule, "strategy.arange.x86", 1); + return strategy; +} + +} // namespace op +} // namespace hlir +} // namespace cinn + +CINN_REGISTER_HELPER(arange_ops) { + CINN_REGISTER_OP(arange) + .describe("Returns evenly spaced values within a given interval.") + .set_num_inputs(0) + .set_num_outputs(1) + .set_attr("CINNStrategy", cinn::hlir::op::StrategyForArange) + .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForArange)) + .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForArange)) + .set_support_level(4); + + return true; +} diff --git a/cinn/hlir/op/contrib/arange.h b/cinn/hlir/op/contrib/arange.h new file mode 100644 index 0000000000..2ad8e2d923 --- /dev/null +++ b/cinn/hlir/op/contrib/arange.h @@ -0,0 +1,33 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_base.h" +#include "cinn/ir/tensor.h" + +namespace cinn { +namespace hlir { +namespace op { + +std::vector Arange( + const float start, const float stop, const float step, const Type& dtype, const std::string& output_name); + +} // namespace op +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/op/contrib/arange_test.cc b/cinn/hlir/op/contrib/arange_test.cc new file mode 100644 index 0000000000..f94f89d870 --- /dev/null +++ b/cinn/hlir/op/contrib/arange_test.cc @@ -0,0 +1,94 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "cinn/hlir/op/contrib/arange.h" + +#include +#include + +#include +#include + +#include "cinn/backends/codegen_c.h" +#include "cinn/backends/codegen_c_x86.h" +#include "cinn/backends/codegen_cuda_dev.h" +#include "cinn/common/context.h" +#include "cinn/lang/lower.h" +#include "cinn/lang/placeholder.h" +#include "cinn/poly/stage.h" + +namespace cinn { +namespace hlir { +namespace op { + +TEST(GenerateCode_Cpu, Arange) { + common::Context::Global().ResetNameId(); + + common::Target target = common::DefaultHostTarget(); + float start = 1.5F; + float stop = 31.5F; + float step = 2.0F; + + std::vector res = Arange(start, stop, step, common::Float(32), "test_arange"); + + poly::StageMap stages = poly::CreateStages({res}); + std::vector funcs = + lang::LowerVec("TestGenerateCodeCpu_Arange", stages, res, {}, {}, nullptr, target, true); + + VLOG(6) << "Expr before CPU codegen:"; + VLOG(6) << funcs[0]->body; + + ir::Module::Builder builder("Arange_Module", target); + for (auto &f : funcs) { + builder.AddFunction(f); + } + + backends::CodeGenCX86 codegen(target, backends::CodeGenCX86::Feature::AVX512); + codegen.SetInlineBuiltinCodes(false); + std::string code = codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); + VLOG(6) << "Cpu Codegen result:"; + VLOG(6) << code << std::endl; +} + +TEST(GenerateCode_Cuda, Arange) { + common::Context::Global().ResetNameId(); + + common::Target target = common::DefaultNVGPUTarget(); + float start = 1.5F; + float stop = 31.5F; + float step = 2.0F; + + std::vector res = Arange(start, stop, step, common::Float(32), "test_arange"); + + poly::StageMap stages = poly::CreateStages({res}); + std::vector funcs = + lang::LowerVec("TestGenerateCodeCpu_Arange", stages, res, {}, {}, nullptr, target, true); + + VLOG(6) << "Expr before CUDA codegen:"; + VLOG(6) << funcs[0]->body; + + ir::Module::Builder builder("Arange_Module", target); + for (auto &f : funcs) { + builder.AddFunction(f); + } + + backends::CodeGenCUDA_Dev codegen(target); + std::string code = codegen.Compile(builder.Build()); + VLOG(6) << "Cuda Codegen result:"; + VLOG(6) << code << std::endl; +} + +} // namespace op +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/op/use_ops.h b/cinn/hlir/op/use_ops.h index 96403dce89..1c0561581e 100644 --- a/cinn/hlir/op/use_ops.h +++ b/cinn/hlir/op/use_ops.h @@ -26,3 +26,4 @@ CINN_USE_REGISTER(cast_ops) CINN_USE_REGISTER(squeeze_ops) CINN_USE_REGISTER(reduce_ops) CINN_USE_REGISTER(clip_ops) +CINN_USE_REGISTER(arange_ops) From dded094b6187977480d482c6f6a5f74e605a6a69 Mon Sep 17 00:00:00 2001 From: haozech Date: Wed, 7 Sep 2022 15:52:58 +0800 Subject: [PATCH 7/8] fix unittest (#921) * fix unittest * fix injective schedule --- cinn/hlir/framework/op_lowering.cc | 2 +- cinn/hlir/op/nn.cc | 19 +++++--- cinn/hlir/pe/ir_schedule_pe.cc | 78 +++++++++++++++++------------- cinn/hlir/pe/schedule.cc | 5 -- cinn/ir/ir_schedule.cc | 4 +- cinn/lang/lower_impl.cc | 0 python/CMakeLists.txt | 2 +- python/tests/ops/test_relu_op.py | 0 8 files changed, 60 insertions(+), 50 deletions(-) mode change 100644 => 100755 cinn/hlir/framework/op_lowering.cc mode change 100755 => 100644 cinn/hlir/pe/ir_schedule_pe.cc mode change 100644 => 100755 cinn/hlir/pe/schedule.cc mode change 100755 => 100644 cinn/ir/ir_schedule.cc mode change 100644 => 100755 cinn/lang/lower_impl.cc mode change 100644 => 100755 python/tests/ops/test_relu_op.py diff --git a/cinn/hlir/framework/op_lowering.cc b/cinn/hlir/framework/op_lowering.cc old mode 100644 new mode 100755 index 
4a9e5903b5..095caa6758 --- a/cinn/hlir/framework/op_lowering.cc +++ b/cinn/hlir/framework/op_lowering.cc @@ -400,7 +400,7 @@ std::vector OpLowerer::IRReduceCompute(poly::StageMap& stages, std::vector ast_exprs; for (auto& node : sub_group->nodes) { auto node_data = GetNodeData(node); - VLOG(3) << node->id(); + VLOG(3) << "In ReduceCompute, process node: " << node->id() << " with op type: " << node->op()->name; std::vector cinn_inputs; std::vector tensor_inputs = std::move(CollectInputTensor(func_args, tensor_map, node)); diff --git a/cinn/hlir/op/nn.cc b/cinn/hlir/op/nn.cc index bac3434ee2..60acfffa1b 100644 --- a/cinn/hlir/op/nn.cc +++ b/cinn/hlir/op/nn.cc @@ -1816,9 +1816,9 @@ std::shared_ptr StrategyForSoftmax(const framework::NodeAttr &attrs, std::string tensor_name = UniqName("Softmax_out"); if (FLAGS_cinn_ir_schedule) { - CHECK_EQ(pack_args.size(), 2); - CHECK(pack_args[1].is_string()); - tensor_name = pack_args[1].operator std::string(); + CHECK_GE(pack_args.size(), 2); + CHECK(pack_args[pack_args.size() - 1].is_string()); + tensor_name = pack_args[pack_args.size() - 1].operator std::string(); } #ifdef CINN_WITH_MKLDNN @@ -1859,14 +1859,17 @@ std::shared_ptr StrategyForSoftmax(const framework::NodeAttr &attrs, if (target.arch == Target::Arch::NVGPU) { if (output_shapes[0].size() > 1) { auto all_blocks = ir_sch.GetAllBlocks(); - CHECK_EQ(all_blocks.size(), 2); - auto loops = ir_sch.GetLoops(all_blocks[1]); - auto splited_loops = ir_sch.Split(loops[1], {-1, 5}); + CHECK_EQ(all_blocks.size(), 3); + auto loops = ir_sch.GetLoops(all_blocks[2]); + int loop_index = 1; + if (output_shapes[0][0] == 1) loop_index--; + CHECK_GE(loops.size(), loop_index + 1); + auto splited_loops = ir_sch.Split(loops[loop_index], {-1, 5}); ir_sch.Bind(splited_loops[0], "blockIdx.z"); ir_sch.Bind(splited_loops[1], "threadIdx.z"); all_blocks = ir_sch.GetAllBlocks(); - loops = ir_sch.GetLoops(all_blocks[1]); - ir_sch.ComputeAt(all_blocks[0], loops.back()); + loops = ir_sch.GetLoops(all_blocks[2]); + ir_sch.ComputeAt(all_blocks[1], loops.back()); } std::vector res{CINNValue(ir_sch.GetModule().GetExprs().at(0))}; *ret = CINNValuePack{res}; diff --git a/cinn/hlir/pe/ir_schedule_pe.cc b/cinn/hlir/pe/ir_schedule_pe.cc old mode 100755 new mode 100644 index f71d4d9056..02d8bc3da5 --- a/cinn/hlir/pe/ir_schedule_pe.cc +++ b/cinn/hlir/pe/ir_schedule_pe.cc @@ -77,23 +77,15 @@ void IRCudaScheduleInjective(ir::IRSchedule &ir_sch, auto loops = ir_sch.GetLoops(all_blocks[0]); auto fused = ir_sch.Fuse(loops); - int num_thread = target.max_num_threads(); - int num_block = 1024; - int vector_width = 1; - int prod_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); - bool need_block_split = prod_size > num_thread * num_block * vector_width ? 
true : false; - if (need_block_split) { - auto splited = ir_sch.Split(fused, {num_block, num_thread, -1}); + int num_thread = target.max_num_threads(); + int vector_width = 1; + int prod_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + if (prod_size > num_thread) { + auto splited = ir_sch.Split(fused, {-1, num_thread}); ir_sch.Bind(splited[0], "blockIdx.x"); ir_sch.Bind(splited[1], "threadIdx.x"); } else { - if (prod_size > num_thread) { - auto splited = ir_sch.Split(fused, {-1, num_thread}); - ir_sch.Bind(splited[0], "blockIdx.x"); - ir_sch.Bind(splited[1], "threadIdx.x"); - } else { - ir_sch.Bind(fused, "threadIdx.x"); - } + ir_sch.Bind(fused, "threadIdx.x"); } VLOG(3) << "After IRCudaScheduleInjective, new ir is : " << ir_sch.GetModule().GetExprs().at(0); } @@ -162,6 +154,7 @@ void IRCudaSplitSchedule(ir::IRSchedule &ir_sch, ir_sch.Split(loops[0], {-1, target.max_num_threads()}); all_blocks = ir_sch.GetAllBlocks(); loops = ir_sch.GetLoops(all_blocks.back()); + CHECK_GT(loops.size(), 1); ir_sch.Bind(loops[0], "blockIdx.x"); ir_sch.Bind(loops[1], "threadIdx.x"); compute_at_level++; @@ -176,6 +169,7 @@ void IRCudaSplitSchedule(ir::IRSchedule &ir_sch, ir_sch.Split(loops[0], {-1, target.max_num_threads()}); all_blocks = ir_sch.GetAllBlocks(); loops = ir_sch.GetLoops(all_blocks.back()); + CHECK_GT(loops.size(), compute_at_level); ir_sch.SimpleComputeAt(all_blocks[i], loops[compute_at_level]); } } @@ -203,8 +197,7 @@ void IRCudaScheduleReduce(ir::IRSchedule &ir_sch, int index = ir_sch.GetLoops(output->name + "__reduce_init").size() - last_dimension_num; for (int idx = output_shape.size() - last_dimension_num; idx < static_cast(output_shape.size()) - 1; ++idx) { auto loops = ir_sch.GetLoops(output->name); - CHECK_GE(loops.size(), index + 2); - ir_sch.Fuse({loops[index], loops[index + 1]}); + if (loops.size() > index + 2) ir_sch.Fuse({loops[index], loops[index + 1]}); } int max_block_size = target.max_num_threads(); @@ -223,8 +216,8 @@ void IRCudaScheduleReduce(ir::IRSchedule &ir_sch, for (int idx = 0; idx < index - 1; ++idx) { auto loops = ir_sch.GetLoops(output->name); - CHECK_GE(loops.size(), 2U); - ir_sch.Fuse({loops[0], loops[1]}); + CHECK_GT(loops.size(), 2U); + if (loops.size() > 2) ir_sch.Fuse({loops[0], loops[1]}); } if (index > 0) { @@ -302,18 +295,34 @@ void IRCudaScheduleBlockReduce(ir::IRSchedule &ir_sch, ir::Tensor tmp_out, ir::Tensor out, const common::Target &target) { - int output_shape_size_without_reduce = tmp_out->shape.size() - 1; + VLOG(3) << "Begin IRCudaScheduleBlockReduce"; + int tmp_put_shape_size_without_reduce = 0; + for (auto i : tmp_out->shape) { + CHECK(i.is_constant()); + if (i.as_int32() != 1) tmp_put_shape_size_without_reduce++; + } + tmp_put_shape_size_without_reduce--; // fuse last parallel dimension - for (int idx = 0; idx < reduce_tmp_out->shape.size() - tmp_out->shape.size(); ++idx) { - auto loops = ir_sch.GetLoops(reduce_tmp_out->name); - ir_sch.Fuse({loops[output_shape_size_without_reduce], loops[output_shape_size_without_reduce + 1]}); + int reduce_temp_out_shape_size = 0; + for (auto i : reduce_tmp_out->shape) { + CHECK(i.is_constant()); + if (i.as_int32() != 1) reduce_temp_out_shape_size++; + } + + int tmp_out_shape_size = tmp_put_shape_size_without_reduce + 1; + for (int idx = 0; idx < reduce_temp_out_shape_size - tmp_out_shape_size; ++idx) { + auto loops = ir_sch.GetLoops(reduce_tmp_out->name); + int reduce_axis = reduce_tmp_out->reduce_axis.size(); + if (loops.size() >= tmp_put_shape_size_without_reduce + 2 + 
reduce_axis) + ir_sch.Fuse({loops[tmp_put_shape_size_without_reduce], loops[tmp_put_shape_size_without_reduce + 1]}); } // fuse parallel dimension - for (int idx = 0; idx < output_shape_size_without_reduce - 1; ++idx) { + for (int idx = 0; idx < tmp_put_shape_size_without_reduce - 1; ++idx) { for (auto &tensor : {reduce_tmp_out, tmp_out, out}) { - auto loops = ir_sch.GetLoops(tensor->name); - ir_sch.Fuse({loops[0], loops[1]}); + auto loops = ir_sch.GetLoops(tensor->name); + int reduce_axis = tensor->reduce_axis.size(); + if (loops.size() >= 2 + reduce_axis) ir_sch.Fuse({loops[0], loops[1]}); } } @@ -355,7 +364,7 @@ void IRCudaScheduleBlockReduce(ir::IRSchedule &ir_sch, for (auto &tensor : {reduce_tmp_out, tmp_out, out}) { auto loops = ir_sch.GetLoops(tensor->name); ir_sch.Bind(loops[0], "blockIdx.x"); - if (loops.size() > 1) { + if (loops.size() > 1U) { ir_sch.Bind(loops[1], "threadIdx.x"); } } @@ -438,10 +447,10 @@ void IRCudaTwoStepReduceSchedule(ir::IRSchedule &ir_sch, // fuse axis for (int idx = 0; idx < static_cast(internal->shape.size()) - 2; ++idx) { for (auto &tensor : {internal, tmp_out, out}) { - auto block = ir_sch.GetBlock(tensor->name); - auto loops = ir_sch.GetLoops(block); - CHECK_GE(loops.size(), 2U); - ir_sch.Fuse({loops[0], loops[1]}); + auto block = ir_sch.GetBlock(tensor->name); + auto loops = ir_sch.GetLoops(block); + int reduce_axis = tensor->reduce_axis.size(); + if (loops.size() >= 2 + reduce_axis) ir_sch.Fuse({loops[0], loops[1]}); } } @@ -481,7 +490,7 @@ void IRCudaTwoStepReduceSchedule(ir::IRSchedule &ir_sch, for (auto &tensor : {internal, tmp_out, out}) { auto block = ir_sch.GetBlock(tensor->name); auto loops = ir_sch.GetLoops(block); - ir_sch.Split(loops[0], {-1, ir::GetLoopExtent(loops[0])}); + if (!loops.empty()) ir_sch.Split(loops[0], {-1, ir::GetLoopExtent(loops[0])}); } } auto reshape_block = ir_sch.GetBlock(reshape->name); @@ -495,7 +504,7 @@ void IRCudaTwoStepReduceSchedule(ir::IRSchedule &ir_sch, for (auto &tensor : {internal, tmp_out, out}) { auto loops = ir_sch.GetLoops(tensor->name); - ir_sch.Bind(loops[0], "blockIdx.x"); + if (!loops.empty()) ir_sch.Bind(loops[0], "blockIdx.x"); if (loops.size() > 1) { ir_sch.Bind(loops[1], "threadIdx.x"); } @@ -585,8 +594,9 @@ void IRCudaScheduleConv(ir::IRSchedule &ir_sch, const common::Target &target) { VLOG(3) << "Didn't find saved param, key is: " << key; } else { VLOG(3) << "Find saved param! key is: " << key; - IRCudaScheduleConv2(ir_sch, input_pad, weights, output, target, key); - return; + // Todo:@Haoze temporarily turn off loading params + // IRCudaScheduleConv2(ir_sch, input_pad, weights, output, target, key); + // return; } ir_sch.ComputeInline(all_blocks[0]); int f_inner = GetInnerSplitter(c, h); diff --git a/cinn/hlir/pe/schedule.cc b/cinn/hlir/pe/schedule.cc old mode 100644 new mode 100755 index 689e44efa5..1dab366a3d --- a/cinn/hlir/pe/schedule.cc +++ b/cinn/hlir/pe/schedule.cc @@ -2193,7 +2193,6 @@ void CudaScheduleInjective(poly::Stage *stage, const std::vector &output_sh } int num_thread = target.max_num_threads(); - int num_block = 65535; int prod_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); if (prod_size <= num_thread) { stage->Bind(0, "threadIdx.x"); @@ -2205,10 +2204,6 @@ void CudaScheduleInjective(poly::Stage *stage, const std::vector &output_sh } if (new_num_thread == 1) LOG(FATAL) << "prod_size out of range: " << prod_size; - bool need_more_split = prod_size > new_num_thread * num_block ? 
true : false; - if (need_more_split) { - LOG(WARNING) << "prod_size out of range: " << prod_size << ", and new_num_thread is : " << new_num_thread; - } CHECK_GT(prod_size, new_num_thread); stage->Split(0, new_num_thread); stage->Bind(0, "blockIdx.x"); diff --git a/cinn/ir/ir_schedule.cc b/cinn/ir/ir_schedule.cc old mode 100755 new mode 100644 index 4ea96cd3c3..0e6427c46e --- a/cinn/ir/ir_schedule.cc +++ b/cinn/ir/ir_schedule.cc @@ -85,6 +85,8 @@ std::vector IRSchedule::Split(const std::string& block_name, int loop_inde } Expr IRSchedule::Fuse(const std::vector& loops) { + LOG(INFO) << "Tring to fuse : "; + for (auto& i : loops) LOG(INFO) << i; std::vector for_nodes; std::vector loop_vars; CHECK(!loops.empty()) << "The loops param of Fuse should not be empty! Please check."; @@ -1224,7 +1226,7 @@ std::vector ScheduleHelper::GetLoops(const Expr& block) const { } if (result.empty()) { LOG(INFO) << "exprs size is : " << exprs.size() << "\n and exprs[0] is : " << exprs[0]; - LOG(FATAL) << "Didn't find Loops containing ScheduleBlock with name: \n" << block_name << " in ModuleExepr."; + LOG(ERROR) << "Didn't find Loops containing ScheduleBlock with name: \n" << block_name << " in ModuleExepr."; } for (auto& it_for : result) VLOG(3) << "Get Loops :\n" << it_for; return result; diff --git a/cinn/lang/lower_impl.cc b/cinn/lang/lower_impl.cc old mode 100644 new mode 100755 diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index c977c7ebbd..171cec34bf 100755 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -76,7 +76,7 @@ ADD_TEST(NAME test_cinnbuilder WORKING_DIRECTORY ${CMAKE_BINARY_DIR} ) -ADD_TEST(NAME test_computation +ADD_TEST(NAME test_computation_python COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_BINARY_DIR}/python:$ENV{PYTHONPATH} python3 ${CMAKE_CURRENT_SOURCE_DIR}/tests/test_computation.py ${CMAKE_BINARY_DIR}/thirds/naive_mul_model diff --git a/python/tests/ops/test_relu_op.py b/python/tests/ops/test_relu_op.py old mode 100644 new mode 100755 From 68dfadff1a507472c9c8bbb8dce569416de09adf Mon Sep 17 00:00:00 2001 From: Huihuang Zheng Date: Wed, 7 Sep 2022 17:21:31 +0800 Subject: [PATCH 8/8] [AutoTuning] Cost Model Feature (#876) As the title. 
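In short, this patch adds an ExprCostModel (an XgbCostModel trained on features extracted from
ir::ModuleExpr) together with the feature extraction itself, and wires it into the search via the
new auto_schedule_use_cost_model flag. A rough sketch of the intended flow, based on the
expr_cost_model.h header added below (the ModuleExprs and cost labels here are placeholders):

    ExprCostModel cost_model;
    ir::ModuleExpr candidate_a, candidate_b;                    // placeholder schedules from the search space
    std::vector<const ir::ModuleExpr*> samples = {&candidate_a, &candidate_b};
    std::vector<float> measured_cost = {12.5f, 9.8f};           // placeholder measured labels
    auto target = common::DefaultNVGPUTarget();
    cost_model.Train(samples, measured_cost, target);           // extract features, fit the XGBoost model
    float predicted = cost_model.Predict(candidate_a, target);  // rank new candidates without measuring them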
--- cinn/auto_schedule/auto_tuner.cc | 1 + cinn/auto_schedule/auto_tuner_test.cc | 63 ++-- cinn/auto_schedule/cost_model/CMakeLists.txt | 14 +- .../cost_model/expr_cost_model.cc | 77 +++++ .../cost_model/expr_cost_model.h | 45 +++ cinn/auto_schedule/cost_model/feature.cc | 175 +++++++++++ cinn/auto_schedule/cost_model/feature.h | 178 +++++++++++ .../cost_model/feature_extractor.cc | 292 ++++++++++++++++++ .../cost_model/feature_extractor.h | 60 ++++ .../cost_model/feature_extractor_test.cc | 158 ++++++++++ cinn/auto_schedule/cost_model/feature_test.cc | 28 ++ .../{cost_model.cc => xgb_cost_model.cc} | 54 ++-- .../{cost_model.h => xgb_cost_model.h} | 48 ++- ...t_model_test.cc => xgb_cost_model_test.cc} | 15 +- .../search_space/search_space.cc | 11 +- .../auto_schedule/search_space/search_space.h | 4 +- .../auto_schedule/search_space/search_state.h | 5 +- .../search_strategy/evolutionary_search.cc | 5 +- .../search_strategy/evolutionary_search.h | 5 +- .../evolutionary_search_test.cc | 13 +- cinn/auto_schedule/task/task_optimizer.cc | 22 +- cinn/auto_schedule/task/task_optimizer.h | 6 +- cinn/common/CMakeLists.txt | 1 + cinn/common/cost_model.h | 40 +++ cinn/common/python_interpreter_guard.cc | 32 ++ cinn/common/python_interpreter_guard.h | 43 +++ cinn/runtime/flags.cc | 5 + 27 files changed, 1308 insertions(+), 92 deletions(-) create mode 100644 cinn/auto_schedule/cost_model/expr_cost_model.cc create mode 100644 cinn/auto_schedule/cost_model/expr_cost_model.h create mode 100644 cinn/auto_schedule/cost_model/feature.cc create mode 100644 cinn/auto_schedule/cost_model/feature.h create mode 100644 cinn/auto_schedule/cost_model/feature_extractor.cc create mode 100644 cinn/auto_schedule/cost_model/feature_extractor.h create mode 100644 cinn/auto_schedule/cost_model/feature_extractor_test.cc create mode 100644 cinn/auto_schedule/cost_model/feature_test.cc rename cinn/auto_schedule/cost_model/{cost_model.cc => xgb_cost_model.cc} (61%) rename cinn/auto_schedule/cost_model/{cost_model.h => xgb_cost_model.h} (50%) rename cinn/auto_schedule/cost_model/{cost_model_test.cc => xgb_cost_model_test.cc} (81%) create mode 100644 cinn/common/cost_model.h create mode 100644 cinn/common/python_interpreter_guard.cc create mode 100644 cinn/common/python_interpreter_guard.h diff --git a/cinn/auto_schedule/auto_tuner.cc b/cinn/auto_schedule/auto_tuner.cc index 6b3fd348c6..95f8899a96 100644 --- a/cinn/auto_schedule/auto_tuner.cc +++ b/cinn/auto_schedule/auto_tuner.cc @@ -15,6 +15,7 @@ #include "cinn/auto_schedule/auto_tuner.h" #include +#include #include #include diff --git a/cinn/auto_schedule/auto_tuner_test.cc b/cinn/auto_schedule/auto_tuner_test.cc index 7ffe0281f9..83a052bb29 100644 --- a/cinn/auto_schedule/auto_tuner_test.cc +++ b/cinn/auto_schedule/auto_tuner_test.cc @@ -24,6 +24,9 @@ #include "cinn/frontend/syntax.h" #include "cinn/hlir/framework/graph_compiler.h" #include "cinn/ir/ir_base.h" +#include "cinn/runtime/flags.h" + +DECLARE_bool(auto_schedule_use_cost_model); namespace cinn { namespace auto_schedule { @@ -88,6 +91,32 @@ class TestAutoTuner : public ::testing::Test { ASSERT_EQ(2, runtime_program->size()); runtime_program->Execute(); } + + void ZeroMeasure() { + // set config and options + AutoTuner::Config tuning_config; + tuning_config.task_schedule_strategy = "round_robin"; + + TuningOptions tuning_options; + tuning_options.num_measure_trials = 0; + auto result = InitializeAndTune(tuning_config, tuning_options); + BasicCheckResult(result); + ApplyTunedAndRun(result); + } + + void 
NonZeroMeasure() { + // set config and options + AutoTuner::Config tuning_config; + tuning_config.task_schedule_strategy = "round_robin"; + + TuningOptions tuning_options; + tuning_options.num_measure_trials = 4; + tuning_options.num_samples_per_iteration = 2; + + auto result = InitializeAndTune(tuning_config, tuning_options); + BasicCheckResult(result); + ApplyTunedAndRun(result); + } }; frontend::Program TestAutoTuner::CreateAddReluProgram() { @@ -101,30 +130,24 @@ frontend::Program TestAutoTuner::CreateAddReluProgram() { return builder.Build(); } -TEST_F(TestAutoTuner, ZeroMeasure) { - // set config and options - AutoTuner::Config tuning_config; - tuning_config.task_schedule_strategy = "round_robin"; - - TuningOptions tuning_options; - tuning_options.num_measure_trials = 0; - auto result = InitializeAndTune(tuning_config, tuning_options); - BasicCheckResult(result); - ApplyTunedAndRun(result); +TEST_F(TestAutoTuner, ZeroMeasure_DisableCostModel) { + FLAGS_auto_schedule_use_cost_model = false; + ZeroMeasure(); } -TEST_F(TestAutoTuner, NonZeroMeasure) { - // set config and options - AutoTuner::Config tuning_config; - tuning_config.task_schedule_strategy = "round_robin"; +TEST_F(TestAutoTuner, ZeroMeasure_EnableCostModel) { + FLAGS_auto_schedule_use_cost_model = true; + ZeroMeasure(); +} - TuningOptions tuning_options; - tuning_options.num_measure_trials = 4; - tuning_options.num_samples_per_iteration = 2; +TEST_F(TestAutoTuner, NonZeroMeasure_DisableCostModel) { + FLAGS_auto_schedule_use_cost_model = false; + NonZeroMeasure(); +} - auto result = InitializeAndTune(tuning_config, tuning_options); - BasicCheckResult(result); - ApplyTunedAndRun(result); +TEST_F(TestAutoTuner, NonZeroMeasure_EnableCostModel) { + FLAGS_auto_schedule_use_cost_model = true; + NonZeroMeasure(); } } // namespace auto_schedule diff --git a/cinn/auto_schedule/cost_model/CMakeLists.txt b/cinn/auto_schedule/cost_model/CMakeLists.txt index 8c465b8cbf..6e52f7a3da 100644 --- a/cinn/auto_schedule/cost_model/CMakeLists.txt +++ b/cinn/auto_schedule/cost_model/CMakeLists.txt @@ -1,13 +1,7 @@ core_gather_headers() -gather_srcs(cinnapi_src SRCS cost_model.cc) +gather_srcs(cinnapi_src SRCS xgb_cost_model.cc expr_cost_model.cc feature.cc feature_extractor.cc) -set(Python_VIRTUALENV FIRST) -find_package(PythonInterp ${PY_VERSION} REQUIRED) -find_package(PythonLibs ${PY_VERSION} REQUIRED) - -if (WITH_TESTING) - cc_test(test_cost_model SRCS cost_model_test.cc cost_model.cc DEPS pybind gtest_main) - - target_link_libraries(test_cost_model ${PYTHON_LIBRARIES}) -endif() +cc_test(test_xgb_cost_model SRCS xgb_cost_model_test.cc DEPS cinncore) +cc_test(test_feature_extractor SRCS feature_extractor_test.cc DEPS cinncore) +cc_test(test_feature SRCS feature_test.cc DEPS cinncore) diff --git a/cinn/auto_schedule/cost_model/expr_cost_model.cc b/cinn/auto_schedule/cost_model/expr_cost_model.cc new file mode 100644 index 0000000000..e41a71a409 --- /dev/null +++ b/cinn/auto_schedule/cost_model/expr_cost_model.cc @@ -0,0 +1,77 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/auto_schedule/cost_model/expr_cost_model.h" + +#include + +#include +#include + +#include "cinn/auto_schedule/cost_model/feature.h" +#include "cinn/auto_schedule/cost_model/feature_extractor.h" +#include "cinn/auto_schedule/search_space/search_state.h" +#include "cinn/common/target.h" +#include "cinn/ir/ir_schedule.h" + +namespace cinn { +namespace auto_schedule { + +float ExprCostModel::Predict(const ir::ModuleExpr& sample, const common::Target& target) const { + if (trained_times_.load() == 0) { + return SearchState::NOT_INIT_COST; + } + FeatureExtractor extractor; + Feature feature = extractor.Extract(sample, target); + std::vector feature_numbers = feature.ToFixedSizeVector(); + std::vector pred = XgbCostModel::Predict({feature_numbers}); + return pred[0]; +} + +void ExprCostModel::Train(const std::vector& samples, + const std::vector& labels, + const common::Target& target) { + trained_times_.store(1); + size_t total_size = samples.size(); + CHECK_EQ(total_size, labels.size()) << "Samples must have same size as labels"; + std::vector> train_feature_numbers(total_size); + FeatureExtractor extractor; + for (size_t i = 0; i < total_size; ++i) { + CHECK(samples[i] != nullptr) << "Train samples cannot be nullptr"; + Feature feature = extractor.Extract(*samples[i], target); + train_feature_numbers[i] = feature.ToFixedSizeVector(); + } + + XgbCostModel::Train(train_feature_numbers, labels); +} + +void ExprCostModel::Update(const std::vector& samples, + const std::vector& labels, + const common::Target& target) { + ++trained_times_; + size_t total_size = samples.size(); + CHECK_EQ(total_size, labels.size()) << "Samples must have same size as labels"; + std::vector> train_feature_numbers(total_size); + FeatureExtractor extractor; + for (size_t i = 0; i < total_size; ++i) { + CHECK(samples[i] != nullptr) << "Train samples cannot be nullptr"; + Feature feature = extractor.Extract(*samples[i], target); + train_feature_numbers[i] = feature.ToFixedSizeVector(); + } + + XgbCostModel::Update(train_feature_numbers, labels); +} + +} // namespace auto_schedule +} // namespace cinn diff --git a/cinn/auto_schedule/cost_model/expr_cost_model.h b/cinn/auto_schedule/cost_model/expr_cost_model.h new file mode 100644 index 0000000000..963521ec62 --- /dev/null +++ b/cinn/auto_schedule/cost_model/expr_cost_model.h @@ -0,0 +1,45 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include + +#include "cinn/auto_schedule/cost_model/xgb_cost_model.h" +#include "cinn/ir/ir_schedule.h" + +namespace cinn { +namespace auto_schedule { + +/** + * A C++ cost model which trains and predicts on ir::Expr + * + */ +class ExprCostModel : public XgbCostModel { + public: + float Predict(const ir::ModuleExpr& sample, const common::Target& target) const; + void Train(const std::vector& samples, + const std::vector& labels, + const common::Target& target); + void Update(const std::vector& samples, + const std::vector& labels, + const common::Target& target); + + private: + std::atomic trained_times_{0}; +}; + +} // namespace auto_schedule +} // namespace cinn diff --git a/cinn/auto_schedule/cost_model/feature.cc b/cinn/auto_schedule/cost_model/feature.cc new file mode 100644 index 0000000000..dd8ef051a7 --- /dev/null +++ b/cinn/auto_schedule/cost_model/feature.cc @@ -0,0 +1,175 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "cinn/auto_schedule/cost_model/feature.h" + +#include + +#include + +#include "cinn/common/target.h" + +namespace cinn { +namespace auto_schedule { + +Feature::Feature() + : target_(common::UnkTarget()), + stack_encoded_feature_(1), // initialze a LoopBlockFeature as root block + current_loop_block_index_(0), + parent_indices_(1, -1) {} + +Feature::Feature(const common::Target& target) + : target_(target), + stack_encoded_feature_(1), // initialze a LoopBlockFeature as root block + current_loop_block_index_(0), + parent_indices_(1, -1) {} + +std::vector Feature::ToFixedSizeVector() { + std::vector ret(LoopBlockFeature::kTotalSize + 1, 0); // LoopBlockFeature::kTotalSize plus 1 for target + + if (target_ == common::DefaultNVGPUTarget()) { + ret[0] = 1; + } // else 0 for other cases + + // loop[i] feature count should multiply iter_multi_num[i] + std::vector iter_multi_num; + for (size_t i = 0; i < stack_encoded_feature_.size(); ++i) { + int j = 1; + const LoopBlockFeature& loop_feature = stack_encoded_feature_[i]; + int loop_prod = 1; + int parent_prod = 1; + if (i != 0) { + parent_prod = iter_multi_num[parent_indices_[i]]; + loop_prod = parent_prod * loop_feature.loop_length; + } + iter_multi_num.push_back(loop_prod); + + ret[j] += (loop_feature.float_add_or_sub * loop_prod); + ++j; + ret[j] += (loop_feature.float_mul * loop_prod); + ++j; + ret[j] += (loop_feature.float_div_or_mod * loop_prod); + ++j; + ret[j] += (loop_feature.float_cmp * loop_prod); + ++j; + ret[j] += (loop_feature.float_math_func * loop_prod); + ++j; + ret[j] += (loop_feature.float_other_call * loop_prod); + ++j; + + ret[j] += (loop_feature.int_add_or_sub * loop_prod); + ++j; + ret[j] += (loop_feature.int_mul * loop_prod); + ++j; + ret[j] += (loop_feature.int_div_or_mod * loop_prod); + ++j; + ret[j] += (loop_feature.int_cmp * loop_prod); + ++j; + ret[j] += (loop_feature.int_math_func * loop_prod); + ++j; + ret[j] += (loop_feature.int_other_call * loop_prod); + ++j; + + ret[j] += (loop_feature.bool_op * loop_prod); + ++j; + ret[j] += (loop_feature.select_op * loop_prod); + ++j; + + ret[j] += (loop_feature.mem_alloc * loop_prod); + ++j; + ret[j] += (loop_feature.mem_free * loop_prod); + ++j; + ret[j] += (loop_feature.mem_read * loop_prod); + ++j; + ret[j] += (loop_feature.mem_write * loop_prod); + ++j; + + ret[j] += (loop_feature.float_reduce_sum_or_sub * loop_prod); + ++j; + ret[j] += (loop_feature.float_reduce_mul * loop_prod); + ++j; + ret[j] += (loop_feature.float_reduce_div * loop_prod); + ++j; + ret[j] += (loop_feature.float_reduce_max_or_min * loop_prod); + ++j; + ret[j] += (loop_feature.float_broadcast * loop_prod); + ++j; + + ret[j] += (loop_feature.int_reduce_sum_or_sub * loop_prod); + ++j; + ret[j] += (loop_feature.int_reduce_mul * loop_prod); + ++j; + ret[j] += (loop_feature.int_reduce_div * loop_prod); + ++j; + ret[j] += (loop_feature.int_reduce_max_or_min * loop_prod); + ++j; + ret[j] += (loop_feature.int_broadcast * loop_prod); + ++j; + + ret[j + static_cast(loop_feature.loop_opt_type)] += 1; + j += LoopBlockFeature::kOptApplySize; + + ret[j] += (loop_feature.len_blockIdx_x * parent_prod); + ++j; + ret[j] += (loop_feature.len_blockIdx_y * parent_prod); + ++j; + ret[j] += (loop_feature.len_blockIdx_z * parent_prod); + ++j; + ret[j] += (loop_feature.len_threadIdx_x * parent_prod); + ++j; + ret[j] += (loop_feature.len_threadIdx_y * parent_prod); + ++j; + ret[j] += (loop_feature.len_threadIdx_z * parent_prod); + ++j; + ret[j] += (loop_feature.len_vthread * parent_prod); + ++j; + ret[j] += 
(loop_feature.vectorize_factor * parent_prod); + ++j; + } + + for (size_t i = 0; i < ret.size(); ++i) { + ret[i] = slog(ret[i]); + } + + return ret; +} + +void Feature::IntoLoopBlock() { + stack_encoded_feature_.emplace_back(LoopBlockFeature()); + stack_encoded_feature_[current_loop_block_index_].num_sub_loops += 1; + parent_indices_.push_back(current_loop_block_index_); + current_loop_block_index_ = stack_encoded_feature_.size() - 1; +} + +void Feature::ExitLoopBlock() { current_loop_block_index_ = parent_indices_[current_loop_block_index_]; } + +LoopBlockFeature& Feature::CurrentLoopBlock() { return stack_encoded_feature_[current_loop_block_index_]; } + +const LoopBlockFeature& Feature::CurrentLoopBlock() const { return stack_encoded_feature_[current_loop_block_index_]; } + +} // namespace auto_schedule +} // namespace cinn diff --git a/cinn/auto_schedule/cost_model/feature.h b/cinn/auto_schedule/cost_model/feature.h new file mode 100644 index 0000000000..994516fe30 --- /dev/null +++ b/cinn/auto_schedule/cost_model/feature.h @@ -0,0 +1,178 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "cinn/common/target.h" +#include "cinn/ir/ir_schedule.h" + +namespace cinn { +namespace auto_schedule { + +/* Loop feature enums */ +enum class ForOptimizeFeatureEnum : int { kNone, kGpuBind, kParallel, kUnroll, kVectorize }; + +/* function to scale feature numbers */ +inline float slog(float x) { return x < 0 ? std::log2(-x + 1) : std::log2(x + 1); } + +class LoopBlockFeature { + public: + // TODO(zhhsplendid): distinguish more types such as float16, float32, + // float64, etc. However speed the gap between float and int are larger than + // different bits, so we just distinguished int and float here + /* Arithmetic features */ + int float_add_or_sub = 0; + int float_mul = 0; + int float_div_or_mod = 0; + int float_cmp = 0; + int float_math_func = 0; + int float_other_call = 0; // like simple assign, cast, etc. + + int int_add_or_sub = 0; + int int_mul = 0; + int int_div_or_mod = 0; + int int_cmp = 0; + int int_math_func = 0; + int int_other_call = 0; // like simple assign, cast, etc. + + int bool_op = 0; + int select_op = 0; + + static constexpr int kArithSize = 6 * 2 + 2; + + /** + * Buffer memory features, which is the number of memory operations. + * Note that different size of memory operation can have various speed, + * however the speed difference would be small in OS. 
A more meticulous TODO
+   * may be to collect operand sizes (like alloc size, write size, and so on)
+   */
+  int mem_alloc = 0;
+  int mem_free = 0;
+  int mem_read = 0;
+  int mem_write = 0;
+
+  static constexpr int kMemSize = 4;
+
+  /**
+   * Reduce and Broadcast features
+   */
+  int float_reduce_sum_or_sub = 0;
+  int float_reduce_mul = 0;
+  int float_reduce_div = 0;
+  int float_reduce_max_or_min = 0;
+  int float_broadcast = 0;
+
+  int int_reduce_sum_or_sub = 0;
+  int int_reduce_mul = 0;
+  int int_reduce_div = 0;
+  int int_reduce_max_or_min = 0;
+  int int_broadcast = 0;
+
+  static constexpr int kReduceBroadcastSize = 10;
+
+  /* Loop type features */
+
+  // TODO: maybe add a loop position (Inner, Outer, Middle) feature
+
+  ForOptimizeFeatureEnum loop_opt_type = ForOptimizeFeatureEnum::kNone;
+
+  static constexpr int kOptApplySize = 5;
+
+  /* Thread features if the loop is optimized by GPU or CPU parallelism.
+   * Unused in other cases.
+   */
+  int len_blockIdx_x = 0;
+  int len_blockIdx_y = 0;
+  int len_blockIdx_z = 0;
+  int len_threadIdx_x = 0;
+  int len_threadIdx_y = 0;
+  int len_threadIdx_z = 0;
+  int len_vthread = 0;  // length of virtual thread
+  int vectorize_factor = 0;
+
+  static constexpr int kThreadFeatureSize = 8;
+
+  static constexpr int kTotalSize = kArithSize + kMemSize + kReduceBroadcastSize + kOptApplySize + kThreadFeatureSize;
+
+  /* Non-feature attributes, used for bookkeeping during feature extraction */
+
+  // Number of sub-loop-blocks directly inside the current one
+  int num_sub_loops = 0;
+
+  // Number of iterations of this loop, -1 represents unknown
+  int loop_length = 1;
+};
+
+/**
+ * Feature of an Expr. It is used in the cost model.
+ */
+class Feature {
+ public:
+  Feature();
+
+  Feature(const common::Target& target);
+
+  // Convert the variable-length loop block features to a fixed-size vector
+  std::vector<float> ToFixedSizeVector();
+
+  // Call when visiting into a loop block to collect its LoopBlockFeature
+  void IntoLoopBlock();
+  // Call when exiting a loop block
+  void ExitLoopBlock();
+  // The current loop block on which we should collect features
+  LoopBlockFeature& CurrentLoopBlock();
+  // The current loop block on which we should collect features
+  const LoopBlockFeature& CurrentLoopBlock() const;
+
+ private:
+  // We encode the computation features as a variable-length vector.
+  // The root compute block is not a loop, but we treat it as a size-1 loop.
+  // Blocks are encoded like a stack. Each LoopBlockFeature records in
+  // num_sub_loops how many next-level sub-loop-blocks it contains.
+ // + // For example, code like: + // + // some_compute_0 + // loop1 { + // some_compute_1 + // loop2 { + // some_compute_2 + // } + // } + // + // loop3 { + // some_compute_3 + // } + // + // We go through the code and push loops into stack, then the features are encoded as + // [loop_block_feature_0, loop_block_feature_1, loop_block_feature_2, loop_block_feature_3] + // where loop_block_feature_i stores the features of some_compute_i (such + // as number of arithmetic operations) + // + // loop_block_feature_0.num_sub_loops = 2 + // loop_block_feature_1.num_sub_loops = 1 + // loop_block_feature_2.num_sub_loops = 0 + // loop_block_feature_3.num_sub_loops = 0 + std::vector stack_encoded_feature_; + int current_loop_block_index_; + std::vector parent_indices_; + + common::Target target_; +}; + +} // namespace auto_schedule +} // namespace cinn diff --git a/cinn/auto_schedule/cost_model/feature_extractor.cc b/cinn/auto_schedule/cost_model/feature_extractor.cc new file mode 100644 index 0000000000..f66cf7b250 --- /dev/null +++ b/cinn/auto_schedule/cost_model/feature_extractor.cc @@ -0,0 +1,292 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "cinn/auto_schedule/cost_model/feature_extractor.h" + +#include + +#include "cinn/common/target.h" +#include "cinn/common/type.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_base.h" +#include "cinn/ir/ir_printer.h" +#include "cinn/ir/ir_schedule.h" +#include "cinn/optim/ir_copy.h" +#include "cinn/optim/transform_polyfor_to_for.h" + +namespace cinn { +namespace auto_schedule { + +using namespace ::cinn::ir; + +FeatureExtractor::FeatureExtractor() {} + +void FeatureExtractor::Visit(const Expr *x) { IRVisitor::Visit(x); } + +Feature FeatureExtractor::Extract(const ir::ModuleExpr &mod_expr, const common::Target &target) { + feature_ = Feature(target); + for (const ir::Expr &e : mod_expr.GetExprs()) { + Visit(&e); + } + return feature_; +} + +#define VisitDoNothing(NodeType) \ + void FeatureExtractor::Visit(const NodeType *x) { \ + std::vector sub_exprs = x->expr_fields(); \ + for (const Expr *e : sub_exprs) { \ + Visit(e); \ + } \ + } + +VisitDoNothing(IntImm); +VisitDoNothing(UIntImm); +VisitDoNothing(FloatImm); +VisitDoNothing(StringImm); + +VisitDoNothing(Block); +VisitDoNothing(_Module_); +VisitDoNothing(_Var_); +VisitDoNothing(_LoweredFunc_); +VisitDoNothing(ScheduleBlock); +VisitDoNothing(ScheduleBlockRealize); +VisitDoNothing(Ramp); +VisitDoNothing(_Buffer_); +VisitDoNothing(_BufferRange_); + +#define NotVisitExprFields(NodeType) \ + void FeatureExtractor::Visit(const NodeType *x) {} + +NotVisitExprFields(_Tensor_) + +#define VisitForDtypePattern(NodeType, member) \ + void FeatureExtractor::Visit(const NodeType *x) { \ + if (x->type() == common::F32() || x->type() == common::F16() || x->type() == common::F64()) { \ + feature_.CurrentLoopBlock().float_##member += x->type().lanes(); \ + } else { \ + feature_.CurrentLoopBlock().int_##member += x->type().lanes(); \ + } \ + std::vector sub_exprs = x->expr_fields(); \ + for (const Expr *e : sub_exprs) { \ + Visit(e); \ + } \ + } + + VisitForDtypePattern(Add, add_or_sub); +VisitForDtypePattern(Sub, add_or_sub); +VisitForDtypePattern(Minus, add_or_sub); +VisitForDtypePattern(Mul, mul); +VisitForDtypePattern(Div, div_or_mod); +VisitForDtypePattern(Mod, div_or_mod); +VisitForDtypePattern(FracOp, div_or_mod); +VisitForDtypePattern(EQ, cmp); +VisitForDtypePattern(NE, cmp); +VisitForDtypePattern(GT, cmp); +VisitForDtypePattern(GE, cmp); +VisitForDtypePattern(LT, cmp); +VisitForDtypePattern(LE, cmp); +VisitForDtypePattern(Call, math_func); +VisitForDtypePattern(Power, math_func); +VisitForDtypePattern(PrimitiveNode, math_func); +VisitForDtypePattern(Cast, other_call); +VisitForDtypePattern(Let, other_call); + +#define VisitForMultiOperandsDtypePattern(NodeType, member) \ + void FeatureExtractor::Visit(const NodeType *x) { \ + if (x->type() == common::F32() || x->type() == common::F16() || x->type() == common::F64()) { \ + feature_.CurrentLoopBlock().float_##member += (x->operands().size() - 1); \ + } else { \ + feature_.CurrentLoopBlock().int_##member += (x->operands().size() - 1); \ + } \ + std::vector sub_exprs = x->expr_fields(); \ + for (const Expr *e : sub_exprs) { \ + Visit(e); \ + } \ + } + +VisitForMultiOperandsDtypePattern(Sum, add_or_sub); +VisitForMultiOperandsDtypePattern(Product, mul); + +#define VisitCountMemberPattern(NodeType, member) \ + void FeatureExtractor::Visit(const NodeType *x) { \ + feature_.CurrentLoopBlock().member += 1; \ + std::vector sub_exprs = x->expr_fields(); \ + for (const Expr *e : sub_exprs) { \ + Visit(e); \ + } \ + } + +VisitCountMemberPattern(And, bool_op); +VisitCountMemberPattern(Or, bool_op); 
+VisitCountMemberPattern(Not, bool_op); +VisitCountMemberPattern(Max, select_op); +VisitCountMemberPattern(Min, select_op); +VisitCountMemberPattern(IfThenElse, select_op); +VisitCountMemberPattern(Select, select_op); +VisitCountMemberPattern(Alloc, mem_alloc); +VisitCountMemberPattern(Free, mem_free); +VisitCountMemberPattern(Load, mem_read); +VisitCountMemberPattern(Store, mem_write); + +/* Visit for loops */ + +void FeatureExtractor::Visit(const For *x) { + feature_.IntoLoopBlock(); + + LoopBlockFeature &loop_feature = feature_.CurrentLoopBlock(); + if (x->min.is_constant() && x->extent.is_constant()) { + loop_feature.loop_length = (x->extent.get_constant() - x->min.get_constant()); + } else { + loop_feature.loop_length = -1; // -1 represents unknown + } + + if (x->is_parallel()) { + loop_feature.loop_opt_type = ForOptimizeFeatureEnum::kParallel; + loop_feature.len_vthread = loop_feature.loop_length; + } else if (x->is_unrolled()) { + loop_feature.loop_opt_type = ForOptimizeFeatureEnum::kUnroll; + } else if (x->is_vectorized()) { + loop_feature.loop_opt_type = ForOptimizeFeatureEnum::kVectorize; + loop_feature.vectorize_factor = x->vectorize_info().factor; + } else if (x->is_binded()) { + loop_feature.loop_opt_type = ForOptimizeFeatureEnum::kGpuBind; + const BindInfo &bind_info = x->bind_info(); + int offset = bind_info.offset; + if (bind_info.for_type == ForType::GPUBlock) { + if (offset == 0) { + loop_feature.len_blockIdx_x = loop_feature.loop_length; + } else if (offset == 1) { + loop_feature.len_blockIdx_y = loop_feature.loop_length; + } else if (offset == 2) { + loop_feature.len_blockIdx_z = loop_feature.loop_length; + } + } else if (bind_info.for_type == ForType::GPUThread) { + if (offset == 0) { + loop_feature.len_threadIdx_x = loop_feature.loop_length; + } else if (offset == 1) { + loop_feature.len_threadIdx_y = loop_feature.loop_length; + } else if (offset == 2) { + loop_feature.len_threadIdx_z = loop_feature.loop_length; + } + } + } + + std::vector sub_exprs = x->expr_fields(); + for (const Expr *e : sub_exprs) { + Visit(e); + } + + feature_.ExitLoopBlock(); +} + +void FeatureExtractor::Visit(const PolyFor *x) { + Expr copy = optim::IRCopy(Expr(x)); + feature_.IntoLoopBlock(); + optim::TransformPolyForToFor(©); + ir::For *loop = copy.As(); + CHECK(loop != nullptr); + Visit(loop); + feature_.ExitLoopBlock(); +} + +/* Visit for Reduce and Broadcast */ + +void FeatureExtractor::Visit(const Reduce *x) { + if (x->type() == common::F32() || x->type() == common::F16() || x->type() == common::F64()) { + switch (x->reduce_type) { + case Reduce::ReduceType::kSum: + feature_.CurrentLoopBlock().float_reduce_sum_or_sub += x->type().lanes(); + break; + case Reduce::ReduceType::kSub: + feature_.CurrentLoopBlock().float_reduce_sum_or_sub += x->type().lanes(); + break; + case Reduce::ReduceType::kDiv: + feature_.CurrentLoopBlock().float_reduce_div += x->type().lanes(); + break; + case Reduce::ReduceType::kMul: + feature_.CurrentLoopBlock().float_reduce_mul += x->type().lanes(); + break; + case Reduce::ReduceType::kMax: + feature_.CurrentLoopBlock().float_reduce_max_or_min += x->type().lanes(); + break; + case Reduce::ReduceType::kMin: + feature_.CurrentLoopBlock().float_reduce_max_or_min += x->type().lanes(); + break; + } + } else { + switch (x->reduce_type) { + case Reduce::ReduceType::kSum: + feature_.CurrentLoopBlock().int_reduce_sum_or_sub += x->type().lanes(); + break; + case Reduce::ReduceType::kSub: + feature_.CurrentLoopBlock().int_reduce_sum_or_sub += x->type().lanes(); + break; + 
case Reduce::ReduceType::kDiv: + feature_.CurrentLoopBlock().int_reduce_div += x->type().lanes(); + break; + case Reduce::ReduceType::kMul: + feature_.CurrentLoopBlock().int_reduce_mul += x->type().lanes(); + break; + case Reduce::ReduceType::kMax: + feature_.CurrentLoopBlock().int_reduce_max_or_min += x->type().lanes(); + break; + case Reduce::ReduceType::kMin: + feature_.CurrentLoopBlock().int_reduce_max_or_min += x->type().lanes(); + break; + } + } + std::vector sub_exprs = x->expr_fields(); + for (const Expr *e : sub_exprs) { + Visit(e); + } +} +VisitForDtypePattern(Broadcast, broadcast); + +/* Visit for IntrinsicOp */ +void FeatureExtractor::Visit(const IntrinsicOp *x) { + switch (x->getKind()) { +#define __(op__) \ + case IntrinsicKind::k##op__: \ + Visit(llvm::dyn_cast(x)); \ + break; + + INTRINSIC_KIND_FOR_EACH(__) +#undef __ + } +} + +VisitDoNothing(intrinsics::BufferGetDataHandle); +VisitDoNothing(intrinsics::BufferGetDataConstHandle); +VisitDoNothing(intrinsics::PodValueToX); +VisitDoNothing(intrinsics::BufferCreate); +VisitDoNothing(intrinsics::GetAddr); +VisitDoNothing(intrinsics::ArgsConstruct); + +VisitForDtypePattern(intrinsics::BuiltinIntrin, other_call) + +} // namespace auto_schedule +} // namespace cinn diff --git a/cinn/auto_schedule/cost_model/feature_extractor.h b/cinn/auto_schedule/cost_model/feature_extractor.h new file mode 100644 index 0000000000..073eee27ca --- /dev/null +++ b/cinn/auto_schedule/cost_model/feature_extractor.h @@ -0,0 +1,60 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "cinn/auto_schedule/cost_model/feature.h" +#include "cinn/common/target.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_base.h" +#include "cinn/ir/ir_schedule.h" +#include "cinn/ir/ir_visitor.h" + +namespace cinn { +namespace auto_schedule { + +class FeatureExtractor : public ir::IRVisitor { + public: + FeatureExtractor(); + Feature Extract(const ir::ModuleExpr& mod_expr, const common::Target& target); + + void Visit(const Expr* x) override; + +#define __(op__) void Visit(const ir::op__* x) override; + NODETY_FORALL(__) +#undef __ + +#define __(op__) virtual void Visit(const ir::intrinsics::op__* x); + INTRINSIC_KIND_FOR_EACH(__) +#undef __ + + private: + Feature feature_; +}; + +} // namespace auto_schedule +} // namespace cinn diff --git a/cinn/auto_schedule/cost_model/feature_extractor_test.cc b/cinn/auto_schedule/cost_model/feature_extractor_test.cc new file mode 100644 index 0000000000..ed0cd984c9 --- /dev/null +++ b/cinn/auto_schedule/cost_model/feature_extractor_test.cc @@ -0,0 +1,158 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/auto_schedule/cost_model/feature_extractor.h" + +#include +#include + +#include +#include +#include + +#include "cinn/common/context.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_base.h" +#include "cinn/ir/ir_schedule.h" +#include "cinn/lang/builtin.h" +#include "cinn/lang/compute.h" +#include "cinn/lang/lower.h" +#include "cinn/lang/placeholder.h" +#include "cinn/poly/stage.h" + +namespace cinn { +namespace auto_schedule { + +TEST(FeatureExtractor, SimpleAssign) { + Context::Global().ResetNameId(); +#ifdef CINN_WITH_CUDA + Target target = common::DefaultNVGPUTarget(); +#else + Target target = common::DefaultHostTarget(); +#endif + ir::Expr M(32); + ir::Expr N(32); + + lang::Placeholder A("A", {M, N}); + ir::Tensor B = lang::Compute( + {M, N}, [&](Var i, Var j) { return A(i, j); }, "B"); + + poly::StageMap stages = poly::CreateStages({A, B}); + std::vector funcs = lang::LowerVec("SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true); + ir::Expr ast_expr = funcs[0]->body; + VLOG(6) << "Expr to test: " << ast_expr; + + std::vector vec_ast{ast_expr}; + ir::ModuleExpr mod_expr(vec_ast); + + FeatureExtractor extractor; + + Feature feature = extractor.Extract(mod_expr, target); + + std::vector to_check = feature.ToFixedSizeVector(); + + ASSERT_EQ(to_check.size(), static_cast(LoopBlockFeature::kTotalSize + 1)); + VLOG(6) << "Feature data before slog:"; + for (size_t i = 0; i < to_check.size(); ++i) { + VLOG(6) << i << " " << (std::pow(2, to_check[i]) - 1); + if (i != 0 && i != 17 && i != 18 && i != 29) { + ASSERT_EQ(to_check[i], 0); + } + } + // target +#ifdef CINN_WITH_CUDA + ASSERT_EQ(to_check[0], 1); +#else + ASSERT_EQ(to_check[0], 0); +#endif + // mem_read + ASSERT_EQ(to_check[17], slog(M.get_constant() * N.get_constant())); // mem_read + // mem_write + ASSERT_EQ(to_check[18], slog(M.get_constant() * N.get_constant())); // 
mem_write + // non-opt loops, including root block + ASSERT_EQ(to_check[29], slog(3)); +} + +TEST(FeatureExtractor, MatrixMultiply) { + Context::Global().ResetNameId(); +#ifdef CINN_WITH_CUDA + Target target = common::DefaultNVGPUTarget(); +#else + Target target = common::DefaultHostTarget(); +#endif + + ir::Expr M(2); + ir::Expr N(2); + ir::Expr K(4); + + lang::Placeholder A("A", {M, K}); + lang::Placeholder B("B", {K, N}); + + ir::Var k(K.as_int32(), "reduce_axis_k"); + ir::Tensor C = lang::Compute( + {M, N}, [&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); }, "C"); + + poly::StageMap stages = poly::CreateStages({C}); + std::vector funcs = lang::LowerVec("MatrixMultiply", stages, {C}, {}, {}, nullptr, target, true); + + std::vector vec_ast{funcs[0]->body}; + ir::ModuleExpr mod_expr(vec_ast); + ir::IRSchedule ir_sch(mod_expr); + std::vector blocks = ir_sch.GetAllBlocks(); + std::vector loops = ir_sch.GetLoops(blocks[0]); + ir_sch.Bind(loops.back(), "threadIdx.x"); + + ir::Expr ast_expr = mod_expr.GetExprs()[0]; + VLOG(6) << "Expr to test: " << ast_expr; + + FeatureExtractor extractor; + Feature feature = extractor.Extract(mod_expr, target); + + std::vector to_check = feature.ToFixedSizeVector(); + + ASSERT_EQ(to_check.size(), static_cast(LoopBlockFeature::kTotalSize + 1)); + std::unordered_set non_zero_indice = {0, 1, 2, 17, 18, 29, 30, 37}; + for (size_t i = 0; i < to_check.size(); ++i) { + VLOG(6) << i << " " << (std::pow(2, to_check[i]) - 1); + if (!non_zero_indice.count(i)) { + ASSERT_EQ(to_check[i], 0); + } + } + // target +#ifdef CINN_WITH_CUDA + ASSERT_EQ(to_check[0], 1); +#else + ASSERT_EQ(to_check[0], 0); +#endif + float out_loop = M.get_constant() * N.get_constant(); + float total_loop = out_loop * K.get_constant(); + // float_mul + ASSERT_EQ(to_check[1], slog(total_loop)); + // float_add_or_sub + ASSERT_EQ(to_check[2], slog(total_loop)); + // mem_read + ASSERT_EQ(to_check[17], slog(total_loop * 3)); + // mem_write + ASSERT_EQ(to_check[18], slog(total_loop + out_loop)); + + // non-opt loops, including root block + ASSERT_EQ(to_check[29], slog(3)); + // GpuBind loop + ASSERT_EQ(to_check[30], slog(1)); + // GpuBind loop + ASSERT_EQ(to_check[37], slog(out_loop)); +} + +} // namespace auto_schedule +} // namespace cinn diff --git a/cinn/auto_schedule/cost_model/feature_test.cc b/cinn/auto_schedule/cost_model/feature_test.cc new file mode 100644 index 0000000000..908672d41b --- /dev/null +++ b/cinn/auto_schedule/cost_model/feature_test.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "cinn/auto_schedule/cost_model/feature.h" + +#include +#include + +namespace cinn { +namespace auto_schedule { + +TEST(Feature, Basic) { + // TODO(zhhsplendid): add some basic tests +} + +} // namespace auto_schedule +} // namespace cinn diff --git a/cinn/auto_schedule/cost_model/cost_model.cc b/cinn/auto_schedule/cost_model/xgb_cost_model.cc similarity index 61% rename from cinn/auto_schedule/cost_model/cost_model.cc rename to cinn/auto_schedule/cost_model/xgb_cost_model.cc index e2248d2ef1..8549442688 100644 --- a/cinn/auto_schedule/cost_model/cost_model.cc +++ b/cinn/auto_schedule/cost_model/xgb_cost_model.cc @@ -12,26 +12,32 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "cinn/auto_schedule/cost_model/cost_model.h" +#include "cinn/auto_schedule/cost_model/xgb_cost_model.h" #include +#include #include #include #include #include +#include #include #include #include +#include #include #include #include +#include #include +#include "cinn/common/python_interpreter_guard.h" + namespace cinn { namespace auto_schedule { -std::once_flag CostModel::init_once_flag_; +std::atomic XgbCostModel::xgb_cost_model_count_(0); // Convert 1D vector to py numpy template @@ -84,40 +90,46 @@ void AddDistPkgToPythonSysPath() { } } -CostModel::CostModel() { - std::call_once(init_once_flag_, AddDistPkgToPythonSysPath); - pybind11::module cost_model_py_mod = pybind11::module::import("cinn.auto_schedule.cost_model"); - python_member_ = cost_model_py_mod.attr("CostModel")(); -} - -CostModel::~CostModel() { - // Do nothing, python_member_ will be destructed after CostModel destructor +XgbCostModel::XgbCostModel() { + common::PythonInterpreterGuard::Guard(); + int previous = xgb_cost_model_count_.fetch_add(1); + if (previous == 0) { + AddDistPkgToPythonSysPath(); + } + xgb_module_ = pybind11::module::import("xgboost"); + xgb_booster_ = xgb_module_.attr("Booster")(); } -void CostModel::Train(const std::vector>& samples, const std::vector& labels) { +void XgbCostModel::Train(const std::vector>& samples, const std::vector& labels) { + update_samples_ = samples; + update_labels_ = labels; pybind11::array np_samples = VectorToNumpy(samples); pybind11::array np_labels = VectorToNumpy(labels); - python_member_.attr("train")(np_samples, np_labels); + pybind11::object dmatrix = xgb_module_.attr("DMatrix")(np_samples, np_labels); + xgb_booster_ = xgb_module_.attr("train")(pybind11::dict(), dmatrix, pybind11::int_(kTrainRound_)); } -std::vector CostModel::Predict(const std::vector>& samples) { +std::vector XgbCostModel::Predict(const std::vector>& samples) const { pybind11::array np_samples = VectorToNumpy(samples); - - pybind11::array py_result = python_member_.attr("predict")(np_samples); + pybind11::object dmatrix = xgb_module_.attr("DMatrix")(np_samples); + pybind11::array py_result = xgb_booster_.attr("predict")(dmatrix); return py_result.cast>(); } -void CostModel::Update(const std::vector>& samples, const std::vector& labels) { - pybind11::array np_samples = VectorToNumpy(samples); - pybind11::array np_labels = VectorToNumpy(labels); +void XgbCostModel::Update(const std::vector>& samples, const std::vector& labels) { + update_samples_.insert(update_samples_.end(), samples.begin(), samples.end()); + update_labels_.insert(update_labels_.end(), labels.begin(), labels.end()); + pybind11::array np_samples = VectorToNumpy(update_samples_); + pybind11::array np_labels = VectorToNumpy(update_labels_); - python_member_.attr("update")(np_samples, 
np_labels);
+  pybind11::object dmatrix = xgb_module_.attr("DMatrix")(np_samples, np_labels);
+  xgb_booster_ = xgb_module_.attr("train")(pybind11::dict(), dmatrix, pybind11::int_(kTrainRound_));
 }
 
-void CostModel::Save(const std::string& path) { python_member_.attr("save")(pybind11::str(path)); }
+void XgbCostModel::Save(const std::string& path) { xgb_booster_.attr("save_model")(pybind11::str(path)); }
 
-void CostModel::Load(const std::string& path) { python_member_.attr("load")(pybind11::str(path)); }
+void XgbCostModel::Load(const std::string& path) { xgb_booster_.attr("load_model")(pybind11::str(path)); }
 
 }  // namespace auto_schedule
 }  // namespace cinn
diff --git a/cinn/auto_schedule/cost_model/cost_model.h b/cinn/auto_schedule/cost_model/xgb_cost_model.h
similarity index 50%
rename from cinn/auto_schedule/cost_model/cost_model.h
rename to cinn/auto_schedule/cost_model/xgb_cost_model.h
index f95811e474..fc68115ddb 100644
--- a/cinn/auto_schedule/cost_model/cost_model.h
+++ b/cinn/auto_schedule/cost_model/xgb_cost_model.h
@@ -16,45 +16,59 @@
 
 #include 
 
+#include 
+#include 
 #include 
 #include 
 #include 
 
+#include "cinn/common/cost_model.h"
+
 namespace cinn {
 namespace auto_schedule {
 
 /**
- * A C++ cost model which calls Python cost model via pybind
+ * A C++ cost model which calls Python xgboost via pybind
+ *
+ * Note: this class manages the Python interpreter lifetime inside the class.
+ * If you have to call other Python functions outside this class and run into
+ * a lifetime conflict, you can check cinn::common::PythonInterpreterGuard
+ *
+ * For cinn::common::PythonInterpreterGuard, see:
+ * cinn/common/python_interpreter_guard.h .cc
  *
- * Note: this class doesn't handle Python interpreter lifttime, users should
- * manage scoped_interpreter/initialize_interpreter/finalize_interpreter by
- * themselves.
For pybind interpreter lifetime management, see: + * For pybind interpreter lifetime management, see: * * https://pybind11.readthedocs.io/en/stable/advanced/embedding.html#interpreter-lifetime * https://pybind11.readthedocs.io/en/stable/reference.html#_CPPv422initialize_interpreterbiPPCKcb */ -class CostModel { - // TODO(zhhsplendid): add CostModelType for C++ interface +class XgbCostModel : public CostModel { public: - CostModel(); - ~CostModel(); + XgbCostModel(); + ~XgbCostModel() = default; - void Train(const std::vector>& samples, const std::vector& labels); + void Train(const std::vector>& samples, const std::vector& labels) override; - std::vector Predict(const std::vector>& samples); + std::vector Predict(const std::vector>& samples) const override; - void Update(const std::vector>& samples, const std::vector& labels); + void Update(const std::vector>& samples, const std::vector& labels) override; - void Save(const std::string& path); + void Save(const std::string& path) override; - void Load(const std::string& path); + void Load(const std::string& path) override; private: - // Object points to Python CostModel - pybind11::object python_member_; + // Python xgboost module + pybind11::module xgb_module_; + // Object points to Python xgb.Booster() + pybind11::object xgb_booster_; + // atomic int to handle python interpreter life time and package dependency + static std::atomic xgb_cost_model_count_; + // Default train rounds + static constexpr int kTrainRound_ = 10; - // Flag to call_once on python inititalization function - static std::once_flag init_once_flag_; + std::vector> update_samples_; + std::vector update_labels_; }; } // namespace auto_schedule diff --git a/cinn/auto_schedule/cost_model/cost_model_test.cc b/cinn/auto_schedule/cost_model/xgb_cost_model_test.cc similarity index 81% rename from cinn/auto_schedule/cost_model/cost_model_test.cc rename to cinn/auto_schedule/cost_model/xgb_cost_model_test.cc index 445f8e6870..f237699a94 100644 --- a/cinn/auto_schedule/cost_model/cost_model_test.cc +++ b/cinn/auto_schedule/cost_model/xgb_cost_model_test.cc @@ -12,8 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "cinn/auto_schedule/cost_model/cost_model.h" +#include "cinn/auto_schedule/cost_model/xgb_cost_model.h" +#include #include #include @@ -26,8 +27,7 @@ namespace cinn { namespace auto_schedule { TEST(CostModel, Basic) { - pybind11::scoped_interpreter guard{}; - CostModel cost_model; + XgbCostModel cost_model; srand(time(NULL)); @@ -47,15 +47,22 @@ TEST(CostModel, Basic) { std::string path = "./test_cost_model.cpp_save_model"; cost_model.Save(path); - CostModel load_cost_model; + XgbCostModel load_cost_model; load_cost_model.Load(path); std::vector load_pred = cost_model.Predict(samples); ASSERT_EQ(pred.size(), load_pred.size()); for (size_t i = 0; i < pred.size(); ++i) { ASSERT_FLOAT_EQ(pred[i], load_pred[i]); + VLOG(6) << "pred[" << i << "] = " << pred[i]; } std::remove(path.c_str()); + + cost_model.Update(samples, labels); + pred = cost_model.Predict(samples); + for (size_t i = 0; i < pred.size(); ++i) { + VLOG(6) << "pred[" << i << "] = " << pred[i]; + } } } // namespace auto_schedule diff --git a/cinn/auto_schedule/search_space/search_space.cc b/cinn/auto_schedule/search_space/search_space.cc index 2ba0bf1311..59a39c9c25 100644 --- a/cinn/auto_schedule/search_space/search_space.cc +++ b/cinn/auto_schedule/search_space/search_space.cc @@ -20,12 +20,15 @@ #include #include -#include "cinn/auto_schedule/cost_model/cost_model.h" +#include "cinn/auto_schedule/cost_model/expr_cost_model.h" #include "cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h" #include "cinn/auto_schedule/task/tune_task.h" #include "cinn/ir/ir_base.h" #include "cinn/ir/ir_schedule.h" #include "cinn/optim/ir_copy.h" +#include "cinn/runtime/flags.h" + +DECLARE_bool(auto_schedule_use_cost_model); namespace cinn { namespace auto_schedule { @@ -56,15 +59,17 @@ std::vector SearchSpace::GetRandomInitialSketch(int num) { return result; } -SearchState SearchSpace::GetScheduleMutate(const SearchState& state, const CostModel& cost_model) { +SearchState SearchSpace::GetScheduleMutate(const SearchState& state, const ExprCostModel& cost_model) { VLOG(4) << "Start SearchSpace::GetScheduleMutate"; - // TODO(zhhsplendid): cost model predict bool has_manual_schedule = false; if (has_manual_schedule) { SearchState ret = ManualScheduleMutate(state); return ret; } SearchState ret = RandomScheduleMutate(state); + if (FLAGS_auto_schedule_use_cost_model) { + ret.predicted_cost = cost_model.Predict(ret.mod_expr, tune_task_.target); + } return ret; } diff --git a/cinn/auto_schedule/search_space/search_space.h b/cinn/auto_schedule/search_space/search_space.h index d16e5fb7b7..08bd693616 100644 --- a/cinn/auto_schedule/search_space/search_space.h +++ b/cinn/auto_schedule/search_space/search_space.h @@ -18,7 +18,7 @@ #include #include -#include "cinn/auto_schedule/cost_model/cost_model.h" +#include "cinn/auto_schedule/cost_model/expr_cost_model.h" #include "cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h" #include "cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.h" #include "cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.h" @@ -48,7 +48,7 @@ class SearchSpace { virtual std::vector GetRandomInitialSketch(int num); // Evolutionary search mutate, returns the mutated ModuleExpr and estimited cost - virtual SearchState GetScheduleMutate(const SearchState& state, const CostModel& cost_model); + virtual SearchState GetScheduleMutate(const SearchState& state, const ExprCostModel& cost_model); private: // TODO(zhhsplendid): mutate by manual schedule. 
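The intended data flow through the new pieces: FeatureExtractor turns an ir::ModuleExpr into a Feature, Feature::ToFixedSizeVector() flattens it into the slog-scaled fixed-size vector, and the xgboost-backed XgbCostModel trains and predicts on batches of such vectors. The expr_cost_model.cc implementation is not part of this excerpt, so the snippet below is only a hypothetical sketch of how ExprCostModel::Predict could glue these interfaces together; the actual implementation may differ.

// Hypothetical sketch only; not part of this patch.
#include <vector>

#include "cinn/auto_schedule/cost_model/expr_cost_model.h"
#include "cinn/auto_schedule/cost_model/feature.h"
#include "cinn/auto_schedule/cost_model/feature_extractor.h"

namespace cinn {
namespace auto_schedule {

float ExprCostModel::Predict(const ir::ModuleExpr& sample, const common::Target& target) const {
  // Encode the ModuleExpr as the fixed-size, slog-scaled feature vector.
  FeatureExtractor extractor;
  Feature feature = extractor.Extract(sample, target);
  std::vector<float> feature_numbers = feature.ToFixedSizeVector();
  // Delegate to the xgboost-backed base class, which predicts on a batch of samples.
  std::vector<float> predictions = XgbCostModel::Predict({feature_numbers});
  return predictions.empty() ? 0.0f : predictions[0];
}

}  // namespace auto_schedule
}  // namespace cinn

Train and Update would follow the same pattern over a batch of ModuleExpr samples before delegating to XgbCostModel::Train and XgbCostModel::Update.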
diff --git a/cinn/auto_schedule/search_space/search_state.h b/cinn/auto_schedule/search_space/search_state.h index 94e8071e96..aaa701bc9d 100644 --- a/cinn/auto_schedule/search_space/search_state.h +++ b/cinn/auto_schedule/search_space/search_state.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include #include @@ -42,8 +43,8 @@ class SearchState { // Cost model predicted cost float predicted_cost = NOT_INIT_COST; - // Negative constant standing for a cost not being initialized - static constexpr float NOT_INIT_COST = -1.0; + // Constant standing for a cost not being initialized + static constexpr float NOT_INIT_COST = std::numeric_limits::max(); SearchState() = default; diff --git a/cinn/auto_schedule/search_strategy/evolutionary_search.cc b/cinn/auto_schedule/search_strategy/evolutionary_search.cc index 7581f5c789..8a5ae0111d 100644 --- a/cinn/auto_schedule/search_strategy/evolutionary_search.cc +++ b/cinn/auto_schedule/search_strategy/evolutionary_search.cc @@ -32,7 +32,8 @@ namespace cinn { namespace auto_schedule { -EvolutionarySearch::EvolutionarySearch(const TuneTask& tune_task) : tune_task_(tune_task) { +EvolutionarySearch::EvolutionarySearch(const TuneTask& tune_task, const ExprCostModel& cost_model) + : tune_task_(tune_task), cost_model_(cost_model) { search_space_ = std::make_unique(tune_task); } @@ -116,7 +117,7 @@ std::vector EvolutionarySearch::Evolve(const std::vector evolution_with_cost(ret_num); for (size_t i = 0; i < evolution.size(); ++i) { - evolution_with_cost.Push(search_space_->GetScheduleMutate(evolution[i], *cost_model_)); + evolution_with_cost.Push(search_space_->GetScheduleMutate(evolution[i], cost_model_)); } return evolution_with_cost.ReturnAsContainer>(); diff --git a/cinn/auto_schedule/search_strategy/evolutionary_search.h b/cinn/auto_schedule/search_strategy/evolutionary_search.h index 532817250f..68cdc15ebd 100644 --- a/cinn/auto_schedule/search_strategy/evolutionary_search.h +++ b/cinn/auto_schedule/search_strategy/evolutionary_search.h @@ -17,6 +17,7 @@ #include #include +#include "cinn/auto_schedule/cost_model/expr_cost_model.h" #include "cinn/auto_schedule/search_space/search_space.h" #include "cinn/auto_schedule/search_space/search_state.h" #include "cinn/auto_schedule/task/tune_task.h" @@ -37,7 +38,7 @@ class EvolutionarySearch { * @param tune_task: the TuneTask this class works on. This class doesn't * take ownership of the pointer. 
*/ - EvolutionarySearch(const TuneTask& tune_task); + EvolutionarySearch(const TuneTask& tune_task, const ExprCostModel& cost_model); /** * Destructor @@ -100,7 +101,7 @@ class EvolutionarySearch { const TuneTask& tune_task_; - CostModel* cost_model_; // not owned + const ExprCostModel& cost_model_; // not owned }; } // namespace auto_schedule diff --git a/cinn/auto_schedule/search_strategy/evolutionary_search_test.cc b/cinn/auto_schedule/search_strategy/evolutionary_search_test.cc index a1dd4b57b2..2737e5c69c 100644 --- a/cinn/auto_schedule/search_strategy/evolutionary_search_test.cc +++ b/cinn/auto_schedule/search_strategy/evolutionary_search_test.cc @@ -19,6 +19,7 @@ #include #include +#include "cinn/auto_schedule/cost_model/expr_cost_model.h" #include "cinn/auto_schedule/search_space/search_space.h" #include "cinn/auto_schedule/search_space/search_state.h" #include "cinn/auto_schedule/task/tune_task.h" @@ -57,7 +58,7 @@ class MockSearchSpace : public SearchSpace { return ret; } - SearchState GetScheduleMutate(const SearchState& state, const CostModel& cost_model) override { + SearchState GetScheduleMutate(const SearchState& state, const ExprCostModel& cost_model) override { float cost = 0.0f; std::vector exprs = state.mod_expr.GetExprs(); for (const ir::Expr& expr : exprs) { @@ -75,15 +76,14 @@ class MockSearchSpace : public SearchSpace { TEST(EvolutionarySearch, GetOneBest) { TuneTask mock_tune_task; + ExprCostModel cost_model; TuningOptions options; - EvolutionarySearch evolutionary_search(mock_tune_task); + EvolutionarySearch evolutionary_search(mock_tune_task, cost_model); MockSearchSpace* mock_search_space = new MockSearchSpace(mock_tune_task); // Ownership is transferred so don't delete mock_search_space evolutionary_search.SetSearchSpace(mock_search_space); - - SearchState best_state = evolutionary_search.SearchModuleExpr(options); - + SearchState best_state = evolutionary_search.SearchModuleExpr(options); std::vector exprs = best_state.mod_expr.GetExprs(); EXPECT_GE(exprs.size(), 1UL); for (const ir::Expr& e : exprs) { @@ -93,8 +93,9 @@ TEST(EvolutionarySearch, GetOneBest) { TEST(EvolutionarySearch, GetEpsGreedy) { TuneTask mock_tune_task; + ExprCostModel cost_model; TuningOptions options; - EvolutionarySearch evolutionary_search(mock_tune_task); + EvolutionarySearch evolutionary_search(mock_tune_task, cost_model); MockSearchSpace* mock_search_space = new MockSearchSpace(mock_tune_task); // Ownership is transferred so don't delete mock_search_space diff --git a/cinn/auto_schedule/task/task_optimizer.cc b/cinn/auto_schedule/task/task_optimizer.cc index 3d31b0a300..9f4f5d8c0b 100644 --- a/cinn/auto_schedule/task/task_optimizer.cc +++ b/cinn/auto_schedule/task/task_optimizer.cc @@ -18,13 +18,21 @@ #include +#include "cinn/auto_schedule/cost_model/expr_cost_model.h" #include "cinn/auto_schedule/measure/measure.h" #include "cinn/auto_schedule/search_strategy/evolutionary_search.h" +#include "cinn/ir/ir_schedule.h" #include "cinn/optim/ir_copy.h" +#include "cinn/runtime/flags.h" + +DECLARE_bool(auto_schedule_use_cost_model); namespace cinn { namespace auto_schedule { +TaskOptimizer::TaskOptimizer(const TuneTask& task, ScheduleMeasurer* schedule_measurer) + : task_(&task), schedule_measurer_(schedule_measurer), cost_model_() {} + TuningResult::OptimizedComputeExpr TaskOptimizer::Optimize(const TuningOptions& options) { // TODO(zhhsplendid): develop other optimize methods and configure the method by options. 
return OptimizeByEvolution(options); @@ -44,7 +52,7 @@ TuningResult::OptimizedComputeExpr TaskOptimizer::OptimizeByEvolution(const Tuni if (evolutionary_search_ == nullptr) { // TODO(zhhsplendid): check whether the options is same as previous, // if not, we should create new EvolutionarySearch - evolutionary_search_ = std::make_unique(*task_); + evolutionary_search_ = std::make_unique(*task_, cost_model_); } if (options.num_measure_trials == 0) { @@ -79,7 +87,9 @@ TuningResult::OptimizedComputeExpr TaskOptimizer::OptimizeByEvolution(const Tuni std::vector states = evolutionary_search_->SearchModuleExprEpsGreedy(options); VLOG(4) << "TaskOptimizer run EvolutionarySearch with return size = " << states.size(); std::vector measure_inputs(states.size()); + std::vector cost_model_samples(states.size()); for (size_t i = 0; i < states.size(); ++i) { + cost_model_samples[i] = &(states[i].mod_expr); measure_inputs[i].task = task_; std::vector best_exprs = states[i].mod_expr.GetExprs(); CHECK_EQ(best_exprs.size(), task_->lowered_funcs.size()) @@ -97,6 +107,16 @@ TuningResult::OptimizedComputeExpr TaskOptimizer::OptimizeByEvolution(const Tuni CHECK_EQ(measure_outputs.size(), states.size()) << "ScheduleMeasurer didn't output same number of MeasureOutput of states in TaskOptimizer"; + std::vector cost_model_labels(states.size()); + for (size_t i = 0; i < states.size(); ++i) { + cost_model_labels[i] = measure_outputs[i].execution_cost; + } + + if (FLAGS_auto_schedule_use_cost_model) { + VLOG(6) << "cost_model_samples.size() = " << cost_model_samples.size(); + VLOG(6) << "cost_model_labels.size() = " << cost_model_labels.size(); + cost_model_.Update(cost_model_samples, cost_model_labels, task_->target); + } // TODO(zhhsplendid): write measure record into cache. for (size_t i = 0; i < measure_outputs.size(); ++i) { diff --git a/cinn/auto_schedule/task/task_optimizer.h b/cinn/auto_schedule/task/task_optimizer.h index 2134573594..af527392d4 100644 --- a/cinn/auto_schedule/task/task_optimizer.h +++ b/cinn/auto_schedule/task/task_optimizer.h @@ -16,6 +16,7 @@ #include +#include "cinn/auto_schedule/cost_model/expr_cost_model.h" #include "cinn/auto_schedule/measure/schedule_measurer.h" #include "cinn/auto_schedule/search_strategy/evolutionary_search.h" #include "cinn/auto_schedule/task/tune_task.h" @@ -29,8 +30,7 @@ namespace auto_schedule { // optimal schedule for the task. class TaskOptimizer { public: - TaskOptimizer(const TuneTask& task, ScheduleMeasurer* schedule_measurer) - : task_(&task), schedule_measurer_(schedule_measurer) {} + TaskOptimizer(const TuneTask& task, ScheduleMeasurer* schedule_measurer); TuningResult::OptimizedComputeExpr Optimize(const TuningOptions& options); @@ -42,6 +42,8 @@ class TaskOptimizer { ScheduleMeasurer* schedule_measurer_; std::unique_ptr evolutionary_search_ = nullptr; + + ExprCostModel cost_model_; }; } // namespace auto_schedule diff --git a/cinn/common/CMakeLists.txt b/cinn/common/CMakeLists.txt index 3b12ade99b..70357f0b9e 100644 --- a/cinn/common/CMakeLists.txt +++ b/cinn/common/CMakeLists.txt @@ -17,6 +17,7 @@ gather_srcs(cinnapi_src SRCS arithmatic.cc cas.cc union_find.cc + python_interpreter_guard.cc ) message(STATUS "srcs: ${cinnapi_src}") diff --git a/cinn/common/cost_model.h b/cinn/common/cost_model.h new file mode 100644 index 0000000000..6c5f4cc79b --- /dev/null +++ b/cinn/common/cost_model.h @@ -0,0 +1,40 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace cinn { +namespace auto_schedule { + +/** + * A C++ cost model virtual base class + */ +class CostModel { + public: + virtual void Train(const std::vector>& samples, const std::vector& labels) = 0; + + virtual std::vector Predict(const std::vector>& samples) const = 0; + + virtual void Update(const std::vector>& samples, const std::vector& labels) = 0; + + virtual void Save(const std::string& path) = 0; + + virtual void Load(const std::string& path) = 0; +}; + +} // namespace auto_schedule +} // namespace cinn diff --git a/cinn/common/python_interpreter_guard.cc b/cinn/common/python_interpreter_guard.cc new file mode 100644 index 0000000000..bafe376d97 --- /dev/null +++ b/cinn/common/python_interpreter_guard.cc @@ -0,0 +1,32 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/common/python_interpreter_guard.h" + +#include + +namespace cinn { +namespace common { + +PythonInterpreterGuard::PythonInterpreterGuard() { pybind11::initialize_interpreter(); } + +PythonInterpreterGuard::~PythonInterpreterGuard() { pybind11::finalize_interpreter(); } + +PythonInterpreterGuard& PythonInterpreterGuard::Guard() { + static PythonInterpreterGuard guard; + return guard; +} + +} // namespace common +} // namespace cinn diff --git a/cinn/common/python_interpreter_guard.h b/cinn/common/python_interpreter_guard.h new file mode 100644 index 0000000000..8c7961af81 --- /dev/null +++ b/cinn/common/python_interpreter_guard.h @@ -0,0 +1,43 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+
+namespace cinn {
+namespace common {
+
+/**
+ * Singleton to handle the Python interpreter lifetime. Since pybind11::initialize_interpreter
+ * cannot be called again once pybind11::finalize_interpreter has run, this singleton calls
+ * pybind11::initialize_interpreter when it is constructed and pybind11::finalize_interpreter
+ * when it is destructed.
+ *
+ * With this, every caller can fetch this guard to make sure the pybind11 Python interpreter
+ * is alive.
+ */
+class PythonInterpreterGuard {
+ public:
+  // Destructor
+  ~PythonInterpreterGuard();
+
+  // Singleton get instance
+  static PythonInterpreterGuard& Guard();
+
+ private:
+  // Constructor
+  PythonInterpreterGuard();
+};
+
+}  // namespace common
+}  // namespace cinn
diff --git a/cinn/runtime/flags.cc b/cinn/runtime/flags.cc
index 78d40a1841..ac417c9f6b 100755
--- a/cinn/runtime/flags.cc
+++ b/cinn/runtime/flags.cc
@@ -66,6 +66,11 @@ DEFINE_string(cinn_source_code_save_path,
               StringFromEnv("FLAGS_cinn_source_code_save_path", ""),
               "Specify the directory path of generated source code, which is used for debug.");
 
+DEFINE_bool(auto_schedule_use_cost_model,
+            BoolFromEnv("FLAGS_auto_schedule_use_cost_model", false),
+            "Whether to use the cost model in auto schedule; this is an in-development flag and it will be removed "
+            "when the cost model is stable");
+
 namespace cinn {
 namespace runtime {
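End-to-end, the cost model stays opt-in: SearchSpace::GetScheduleMutate and TaskOptimizer only consult it when the environment variable FLAGS_auto_schedule_use_cost_model is set to true (read through BoolFromEnv above). For the model itself, a minimal usage sketch in the spirit of xgb_cost_model_test.cc follows; it assumes xgboost is importable from the embedded Python interpreter, and the sample values and save path are made up for illustration.

#include <string>
#include <vector>

#include "cinn/auto_schedule/cost_model/xgb_cost_model.h"

int main() {
  using cinn::auto_schedule::XgbCostModel;

  // The constructor fetches the PythonInterpreterGuard singleton and imports xgboost,
  // so no explicit pybind11 interpreter setup is needed here.
  XgbCostModel model;

  std::vector<std::vector<float>> samples = {{1.0f, 2.0f}, {3.0f, 4.0f}};  // made-up features
  std::vector<float> labels = {0.5f, 1.5f};                                // made-up costs

  model.Train(samples, labels);                      // fits an xgboost Booster
  std::vector<float> pred = model.Predict(samples);  // one predicted cost per sample

  model.Update(samples, labels);         // accumulates samples and labels, then re-trains
  model.Save("./xgb_cost_model.model");  // made-up path
  return 0;
}

The first XgbCostModel instance also appends the Python dist-packages directory to sys.path via AddDistPkgToPythonSysPath, and PythonInterpreterGuard keeps the embedded interpreter alive for the rest of the process.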