diff --git a/include/tvm/relax/builder.h b/include/tvm/relax/vm/exec_builder.h similarity index 93% rename from include/tvm/relax/builder.h rename to include/tvm/relax/vm/exec_builder.h index 40e124041e202..415c9533fe3d1 100644 --- a/include/tvm/relax/builder.h +++ b/include/tvm/relax/vm/exec_builder.h @@ -18,11 +18,11 @@ */ /*! - * \file tvm/relax/builder.h + * \file tvm/relax/vm/exec_builder.h * \brief */ -#ifndef TVM_RELAX_BUILDER_H_ -#define TVM_RELAX_BUILDER_H_ +#ifndef TVM_RELAX_EXEC_BUILDER_H_ +#define TVM_RELAX_EXEC_BUILDER_H_ #include #include @@ -30,8 +30,8 @@ #include #include -#include "./vm/bytecode.h" -#include "./vm/executable.h" +#include "./bytecode.h" +#include "./executable.h" namespace tvm { namespace relax { @@ -102,4 +102,4 @@ class ExecBuilder : public ObjectRef { } // namespace relax } // namespace tvm -#endif // TVM_RELAX_BUILDER_H_ +#endif // TVM_RELAX_EXEC_BUILDER_H_ diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index 58f08363294ff..19f1ecf6d729a 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -23,6 +23,7 @@ from . import op from . import parser from . import analysis +from . import transform # Expr diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index 3e13f0ad13bfb..ea55d7d3ced1f 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -37,23 +37,3 @@ def post_order_visit(expr, fvisit): The visitor function to be applied. """ return _ffi_api.post_order_visit(expr, fvisit) - -def fma_rewrite(expr): - """Perform fused multiply add rewriting in dataflow blocks. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression. - """ - return _ffi_api.fma_rewrite(expr) - -def explicit_memory_rewrite(expr): - """Perform explicit memory allocation for call_dps in dataflow blocks. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression. - """ - return _ffi_api.explicit_memory_rewrite(expr) \ No newline at end of file diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py new file mode 100644 index 0000000000000..b50b783e41835 --- /dev/null +++ b/python/tvm/relax/transform/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import, redefined-builtin +"""Relax IR analysis. """ + +from .transform import * diff --git a/python/tvm/relax/transform/_ffi_api.py b/python/tvm/relax/transform/_ffi_api.py new file mode 100644 index 0000000000000..df803e82a960b --- /dev/null +++ b/python/tvm/relax/transform/_ffi_api.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +import tvm._ffi + +tvm._ffi._init_api("relax.transform", __name__) diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py new file mode 100644 index 0000000000000..2a581865b3dbd --- /dev/null +++ b/python/tvm/relax/transform/transform.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return +# pylint: disable=unidiomatic-typecheck +from . import _ffi_api + +def fma_rewrite(expr): + """Perform fused multiply add rewriting in dataflow blocks. + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression. + """ + return _ffi_api.fma_rewrite(expr) + +def explicit_memory_rewrite(expr): + """Perform explicit memory allocation for call_dps in dataflow blocks. + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression. + """ + return _ffi_api.explicit_memory_rewrite(expr) diff --git a/src/relax/expr.cc b/src/relax/ir/expr.cc similarity index 100% rename from src/relax/expr.cc rename to src/relax/ir/expr.cc diff --git a/src/relax/expr_functor.cc b/src/relax/ir/expr_functor.cc similarity index 74% rename from src/relax/expr_functor.cc rename to src/relax/ir/expr_functor.cc index 5068b423b46a6..c277a991fd374 100644 --- a/src/relax/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -18,7 +18,7 @@ */ /*! - * \file src/relay/expr_functor.cc + * \file src/relax/expr_functor.cc * \brief A wrapper around ExprFunctor which functionally updates the AST. * * ExprMutator uses memoization and self return in order to amortize @@ -29,10 +29,6 @@ #include #include #include -#include -#include - -#include "../relay/transforms/pattern_utils.h" namespace tvm { namespace relax { @@ -415,114 +411,5 @@ Expr DataflowMutator::LookupVar(Var var) { return irbuilder_->LookupVar(var); } } - - -// ================== -// EwiseFMARewriter -// Example: -// x0 = mul(a, b) -// z0 = add(x0, c) -// --> -// z0 = ewise_fma(a, b, c) - -// Example 2: -// Question: do we want to support this? -// x0 = mul(a, add(k, b)) -// z0 = add(x0, c) -// --> -// lv0 = add(k, b) -// z0 = ewise_fma(a, lv0, c) - -class EwiseFMARewriter : public DataflowMutator { - Var VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder) override { - static const Op& add_op = Op::Get("relax.add"); - static const Op& multiply_op = Op::Get("relax.multiply"); - static const Op& ewise_fma_op = Op::Get("relax.ewise_fma"); - - // TODO: shape & dtype check - const CallNode* op1 = binding->value.as(); - if (op1 && (op1->op == add_op)) { - Expr value = LookupVar(Downcast(op1->args[0])); - const CallNode* op2 = value.as(); - if (op2 && op2->op == multiply_op) { - Call fma_call = Call(ewise_fma_op, {op2->args[0], op2->args[1], op1->args[1]}, {}, {}); - return ir_builder->Emit(binding->var, fma_call); - } - } - return ir_builder->Emit(binding); - } -}; - -Expr FMARewrite(const Expr& e) { - return EwiseFMARewriter().Mutate(e); -} - -TVM_REGISTER_GLOBAL("relax.analysis.fma_rewrite") -.set_body_typed([](Expr expr) { - return FMARewrite(expr); -}); - -// ================== -// ExplicitMemMutator -// Example: -// y: Tensor[n, m] = rx.call_dps((n, m), op.identity, (x)) -// --> -// lv0 = rx.call("relax.builtin.alloc_tensor", [n, m]) -// rx.call_packed(op.identity, x, lv0) - -class ExplicitMemMutator : public DataflowMutator { - Expr ComputeStorageSize(const Expr& shape, const Type& type) const { - DynTensorType tensor_type = Downcast(type); - DataType dtype = DataType(tensor_type->dtype); - // Question: what if the dtype of tensor_type is unknown? - // Symbolic/static shape case - if (auto* shape_expr = shape.as()) { - PrimExpr num = PrimExpr(dtype.bits()) * PrimExpr(dtype.lanes()); - PrimExpr add = num + 7; - PrimExpr ret = 1; - for (PrimExpr dim : shape_expr->values) { - ret = ret * dim; - } - ret = ret * (add / PrimExpr(8)); - return ShapeExpr({ret}); - } - // Fully dynamic shape case - // will need to dedup with ComputeStorageInRelay when we upstream - Expr prod = relay::Prod(shape, Array(nullptr), false, false); - Expr num = relay::MakeConstantScalar(DataType::Int(64), dtype.bits() * dtype.lanes()); - Expr add = relay::Add(num, relay::MakeConstantScalar(DataType::Int(64), 7)); - Expr div = relay::MakeConstantScalar(DataType::Int(64), 8); - Expr ret = relay::Multiply(prod, relay::Divide(add, div)); - return ret; - } - - Var VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder) override { - static const Op& call_dps_op = Op::Get("relax.call_dps"); - static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor"); - - const CallNode* op = binding->value.as(); - if(op && op->op == call_dps_op) { - // switch current DataflowBlock to an impure BindingBlock - ir_builder->is_dataflow_ = false; - ShapeExpr output_shape = Downcast(op->args[0]); - Type arg_type = Downcast(op->args[2])->fields[0]->checked_type(); - Expr output_size = ComputeStorageSize(output_shape, arg_type); - Var tensor = ir_builder->Emit(Call(alloc_tensor_op, {op->args[0]})); - return ir_builder->Emit(binding->var, Call(op->args[1], {op->args[2], tensor})); - } - return ir_builder->Emit(binding); - } -}; - -Expr ExplicitMemRewrite(const Expr& e) { - return ExplicitMemMutator().Mutate(e); -} - -TVM_REGISTER_GLOBAL("relax.analysis.explicit_memory_rewrite") -.set_body_typed([](Expr expr) { - return ExplicitMemRewrite(expr); -}); - - } // namespace relax } // namespace tvm diff --git a/src/relax/ir_builder.cc b/src/relax/ir/ir_builder.cc similarity index 100% rename from src/relax/ir_builder.cc rename to src/relax/ir/ir_builder.cc diff --git a/src/relax/type.cc b/src/relax/ir/type.cc similarity index 100% rename from src/relax/type.cc rename to src/relax/ir/type.cc diff --git a/src/relax/transform/fma_rewrite.cc b/src/relax/transform/fma_rewrite.cc new file mode 100644 index 0000000000000..c308d402e7055 --- /dev/null +++ b/src/relax/transform/fma_rewrite.cc @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/fma_rewrite.cc + * \brief + */ +#include + +namespace tvm { +namespace relax { + +// ================== +// EwiseFMARewriter +// Example: +// x0 = mul(a, b) +// z0 = add(x0, c) +// --> +// z0 = ewise_fma(a, b, c) + +// Example 2: +// Question: do we want to support this? +// x0 = mul(a, add(k, b)) +// z0 = add(x0, c) +// --> +// lv0 = add(k, b) +// z0 = ewise_fma(a, lv0, c) + +class EwiseFMARewriter : public DataflowMutator { + Var VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder) override { + static const Op& add_op = Op::Get("relax.add"); + static const Op& multiply_op = Op::Get("relax.multiply"); + static const Op& ewise_fma_op = Op::Get("relax.ewise_fma"); + + // TODO: shape & dtype check + const CallNode* op1 = binding->value.as(); + if (op1 && (op1->op == add_op)) { + Expr value = LookupVar(Downcast(op1->args[0])); + const CallNode* op2 = value.as(); + if (op2 && op2->op == multiply_op) { + Call fma_call = Call(ewise_fma_op, {op2->args[0], op2->args[1], op1->args[1]}, {}, {}); + return ir_builder->Emit(binding->var, fma_call); + } + } + return ir_builder->Emit(binding); + } +}; + +Expr FMARewrite(const Expr& e) { + return EwiseFMARewriter().Mutate(e); +} + +TVM_REGISTER_GLOBAL("relax.transform.fma_rewrite") +.set_body_typed([](Expr expr) { + return FMARewrite(expr); +}); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/memory_rewrite.cc b/src/relax/transform/memory_rewrite.cc new file mode 100644 index 0000000000000..ae9832eb90128 --- /dev/null +++ b/src/relax/transform/memory_rewrite.cc @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/memory_rewrite.cc + * \brief + */ +#include +#include +#include +#include "../../relay/transforms/pattern_utils.h" + +namespace tvm { +namespace relax { + +// ================== +// ExplicitMemMutator +// Example: +// y: Tensor[n, m] = rx.call_dps((n, m), op.identity, (x)) +// --> +// lv0 = rx.call("relax.builtin.alloc_tensor", [n, m]) +// rx.call_packed(op.identity, x, lv0) + +class ExplicitMemMutator : public DataflowMutator { + Expr ComputeStorageSize(const Expr& shape, const Type& type) const { + DynTensorType tensor_type = Downcast(type); + DataType dtype = DataType(tensor_type->dtype); + // Question: what if the dtype of tensor_type is unknown? + // Symbolic/static shape case + if (auto* shape_expr = shape.as()) { + PrimExpr num = PrimExpr(dtype.bits()) * PrimExpr(dtype.lanes()); + PrimExpr add = num + 7; + PrimExpr ret = 1; + for (PrimExpr dim : shape_expr->values) { + ret = ret * dim; + } + ret = ret * (add / PrimExpr(8)); + return ShapeExpr({ret}); + } + // Fully dynamic shape case + // will need to dedup with ComputeStorageInRelay when we upstream + Expr prod = relay::Prod(shape, Array(nullptr), false, false); + Expr num = relay::MakeConstantScalar(DataType::Int(64), dtype.bits() * dtype.lanes()); + Expr add = relay::Add(num, relay::MakeConstantScalar(DataType::Int(64), 7)); + Expr div = relay::MakeConstantScalar(DataType::Int(64), 8); + Expr ret = relay::Multiply(prod, relay::Divide(add, div)); + return ret; + } + + Var VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder) override { + static const Op& call_dps_op = Op::Get("relax.call_dps"); + static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor"); + + const CallNode* op = binding->value.as(); + if(op && op->op == call_dps_op) { + // switch current DataflowBlock to an impure BindingBlock + ir_builder->is_dataflow_ = false; + ShapeExpr output_shape = Downcast(op->args[0]); + Type arg_type = Downcast(op->args[2])->fields[0]->checked_type(); + Expr output_size = ComputeStorageSize(output_shape, arg_type); + Var tensor = ir_builder->Emit(Call(alloc_tensor_op, {op->args[0]})); + return ir_builder->Emit(binding->var, Call(op->args[1], {op->args[2], tensor})); + } + return ir_builder->Emit(binding); + } +}; + +Expr ExplicitMemRewrite(const Expr& e) { + return ExplicitMemMutator().Mutate(e); +} + +TVM_REGISTER_GLOBAL("relax.transform.explicit_memory_rewrite") +.set_body_typed([](Expr expr) { + return ExplicitMemRewrite(expr); +}); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/builder.cc b/src/relax/vm/exec_builder.cc similarity index 84% rename from src/relax/builder.cc rename to src/relax/vm/exec_builder.cc index 63661eefe63a9..4f681571fddf7 100644 --- a/src/relax/builder.cc +++ b/src/relax/vm/exec_builder.cc @@ -18,10 +18,9 @@ */ /*! - * \file src/relax/builder.cc + * \file src/relax/vm/exec_builder.cc */ - -#include +#include #include @@ -178,43 +177,52 @@ void ExecBuilderNode::Formalize() { } } -TVM_REGISTER_GLOBAL("relax.ExecBuilderCreate").set_body_typed(ExecBuilderNode::Create); +TVM_REGISTER_GLOBAL("relax.ExecBuilderCreate") +.set_body_typed(ExecBuilderNode::Create); TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitConstant") - .set_body_typed([](ExecBuilder builder, ObjectRef obj) { return builder->EmitConstant(obj); }); +.set_body_typed([](ExecBuilder builder, ObjectRef obj) { + return builder->EmitConstant(obj); +}); TVM_REGISTER_GLOBAL("relax.ExecBuilderFunction") - .set_body_typed([](ExecBuilder builder, String name, int64_t num_inputs) { - return builder->Function(name, num_inputs); - }); +.set_body_typed([](ExecBuilder builder, String name, int64_t num_inputs) { + return builder->Function(name, num_inputs); +}); TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitCall") - .set_body_typed([](ExecBuilder builder, String name, Array args, int64_t dst) { - std::vector args_; - for (size_t i = 0; i < args.size(); ++i) { - args_.push_back(static_cast(args[i]->value)); - } - Instruction::Arg dst_(dst); - CHECK_EQ(dst_.kind(), Instruction::ArgKind::kRegister); - builder->EmitCall(name, args_, dst_.value()); - }); +.set_body_typed([](ExecBuilder builder, String name, Array args, int64_t dst) { + std::vector args_; + for (size_t i = 0; i < args.size(); ++i) { + args_.push_back(static_cast(args[i]->value)); + } + Instruction::Arg dst_(dst); + CHECK_EQ(dst_.kind(), Instruction::ArgKind::kRegister); + builder->EmitCall(name, args_, dst_.value()); +}); TVM_REGISTER_GLOBAL("relax.ExecBuilderEmitRet") - .set_body_typed([](ExecBuilder builder, int64_t result) { builder->EmitRet(result); }); +.set_body_typed([](ExecBuilder builder, int64_t result) { + builder->EmitRet(result); +}); -TVM_REGISTER_GLOBAL("relax.ExecBuilderR").set_body_typed([](ExecBuilder builder, int64_t value) { +TVM_REGISTER_GLOBAL("relax.ExecBuilderR") +.set_body_typed([](ExecBuilder builder, int64_t value) { return Instruction::Arg(Instruction::kRegister, value).data; }); -TVM_REGISTER_GLOBAL("relax.ExecBuilderImm").set_body_typed([](ExecBuilder builder, int64_t value) { +TVM_REGISTER_GLOBAL("relax.ExecBuilderImm") +.set_body_typed([](ExecBuilder builder, int64_t value) { return Instruction::Arg(Instruction::kImmediate, value).data; }); -TVM_REGISTER_GLOBAL("relax.ExecBuilderC").set_body_typed([](ExecBuilder builder, int64_t value) { +TVM_REGISTER_GLOBAL("relax.ExecBuilderC") +.set_body_typed([](ExecBuilder builder, int64_t value) { return Instruction::Arg(Instruction::kConstIdx, value).data; }); -TVM_REGISTER_GLOBAL("relax.ExecBuilderGet").set_body_typed([](ExecBuilder builder) { +TVM_REGISTER_GLOBAL("relax.ExecBuilderGet") +.set_body_typed([](ExecBuilder builder) { return builder->Get(); }); diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py index b9d31ab3d3201..f322fa177b2b6 100644 --- a/tests/python/relax/test_analysis.py +++ b/tests/python/relax/test_analysis.py @@ -61,48 +61,6 @@ def fvisit(e): assert names == ["relax.add", "relax.multiply"] -def test_fma_rewrite(): - m = tir.Var("m", "int32") - n = tir.Var("n", "int32") - dtype0 = rx.DynTensorType(rank=2, dtype="float16") - dtype1 = rx.DynTensorType(rank=2, dtype="float16") - x = rx.Var("x", [m, n], dtype0) - y = rx.Var("y", [m, n], dtype1) - ib = rx.IRBuilder() - with ib.function([x, y]): - with ib.dataflow() as df: - lv0 = ib.emit(rx.op.multiply(x, y)) - lv1 = ib.emit(rx.op.add(lv0, y)) - gv0 = ib.emit_output(lv1) - ib.emit_output(gv0) - expr = ib.get() - - # before rewrite - v0 = expr.body.blocks[0].bindings[1].var - s0 = expr.body.blocks[0].bindings[1].value - assert isinstance(s0, tvm.relay.Call) - assert s0.op.name == "relax.add" - assert structural_equal(v0.shape, rx.ShapeExpr([m, n])) - assert structural_equal(s0.shape, rx.ShapeExpr([m, n])) - assert structural_equal(gv0.shape, rx.ShapeExpr([m, n])) - - # after rewrite - func = rx.analysis.fma_rewrite(expr) - - v1 = func.body.blocks[0].bindings[1].var - s1 = func.body.blocks[0].bindings[1].value - assert isinstance(s1, tvm.relay.Call) - assert s1.op.name == "relax.ewise_fma" - assert structural_equal(v1.shape, rx.ShapeExpr([m, n])) - assert structural_equal(s1.shape, rx.ShapeExpr([m, n])) - - # The var binded to the fma call is reused because the shape - # and type of var are unchanged after rewriting - assert lv1 == v0 - - assert type(func.body.blocks[0].bindings[2].var) == rx.Var - assert type(func.body.blocks[0].bindings[2].value) == rx.DataflowVar - def test_lazy_irbuilder(): m = tir.Var("m", "int32") n = tir.Var("n", "int32") @@ -129,7 +87,7 @@ def test_lazy_irbuilder(): assert s0.op.name == "relax.multiply" # after rewrite (the bindings and the dataflow block are reused) - func = rx.analysis.fma_rewrite(expr) + func = rx.transform.fma_rewrite(expr) block1 = func.body.blocks[0] v1 = func.body.blocks[0].bindings[1].var @@ -140,45 +98,8 @@ def test_lazy_irbuilder(): assert v1 == v0 assert s1 == s0 -def test_explicit_memory_rewrite(): - m = tir.Var("m", "int32") - n = tir.Var("n", "int32") - shape_anno = [m, n] - type_anno = rx.DynTensorType(2, "float32") - x = rx.Var("x", shape_anno, type_anno) - ib = rx.IRBuilder() - with ib.function(x): - with ib.dataflow() as df: - lv0 = rx.call_dps([m, n], rx.extern("test.op.identity"), [x]) - gv0 = ib.emit_output(lv0) - ib.emit_output(gv0) - expr = ib.get() - - # before rewrite - v0 = expr.body.blocks[0].bindings[0].var - s0 = expr.body.blocks[0].bindings[0].value - assert isinstance(s0, tvm.relay.Call) - assert s0.op.name == "relax.call_dps" - - # after rewrite - func = rx.analysis.explicit_memory_rewrite(expr) - - # the dataflow block has changed to binding block due to the rewriting - block = func.body.blocks[0] - assert isinstance(block, rx.BindingBlock) - - s1 = block.bindings[0].value - assert isinstance(s1, tvm.relay.Call) - assert s1.op.name == "relax.builtin.alloc_tensor" - assert isinstance(s1.args[0], rx.ShapeExpr) - assert structural_equal(s1.args[0], rx.ShapeExpr(shape_anno)) - s2 = block.bindings[1].value - assert s2.op.global_symbol == "test.op.identity" - if __name__ == "__main__": test_dispatch_var() test_post_order_visit() - test_fma_rewrite() test_lazy_irbuilder() - test_explicit_memory_rewrite() diff --git a/tests/python/relax/test_expr.py b/tests/python/relax/test_expr.py index f515d331f9ea4..5c3e5ea424d12 100644 --- a/tests/python/relax/test_expr.py +++ b/tests/python/relax/test_expr.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + import tvm from tvm import tir from tvm import relax as rx diff --git a/tests/python/relax/test_parser.py b/tests/python/relax/test_parser.py index 8e58c84ca613d..95df0ea9250e1 100644 --- a/tests/python/relax/test_parser.py +++ b/tests/python/relax/test_parser.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + from __future__ import annotations # must import to defer parsing of annotations import pytest import tvm diff --git a/tests/python/relax/test_printer.py b/tests/python/relax/test_printer.py index bb4363cd20fc0..b99b0a9223d52 100644 --- a/tests/python/relax/test_printer.py +++ b/tests/python/relax/test_printer.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + from __future__ import annotations # must import to defer parsing of annotations import pytest import tvm diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py new file mode 100644 index 0000000000000..304ced55d2ded --- /dev/null +++ b/tests/python/relax/test_transform.py @@ -0,0 +1,104 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm import tir +from tvm import relax as rx +from tvm.ir import structural_equal +import numpy as np + +def test_fma_rewrite(): + m = tir.Var("m", "int32") + n = tir.Var("n", "int32") + dtype0 = rx.DynTensorType(rank=2, dtype="float16") + dtype1 = rx.DynTensorType(rank=2, dtype="float16") + x = rx.Var("x", [m, n], dtype0) + y = rx.Var("y", [m, n], dtype1) + ib = rx.IRBuilder() + with ib.function([x, y]): + with ib.dataflow() as df: + lv0 = ib.emit(rx.op.multiply(x, y)) + lv1 = ib.emit(rx.op.add(lv0, y)) + gv0 = ib.emit_output(lv1) + ib.emit_output(gv0) + expr = ib.get() + + # before rewrite + v0 = expr.body.blocks[0].bindings[1].var + s0 = expr.body.blocks[0].bindings[1].value + assert isinstance(s0, tvm.relay.Call) + assert s0.op.name == "relax.add" + assert structural_equal(v0.shape, rx.ShapeExpr([m, n])) + assert structural_equal(s0.shape, rx.ShapeExpr([m, n])) + assert structural_equal(gv0.shape, rx.ShapeExpr([m, n])) + + # after rewrite + func = rx.transform.fma_rewrite(expr) + + v1 = func.body.blocks[0].bindings[1].var + s1 = func.body.blocks[0].bindings[1].value + assert isinstance(s1, tvm.relay.Call) + assert s1.op.name == "relax.ewise_fma" + assert structural_equal(v1.shape, rx.ShapeExpr([m, n])) + assert structural_equal(s1.shape, rx.ShapeExpr([m, n])) + + # The var binded to the fma call is reused because the shape + # and type of var are unchanged after rewriting + assert lv1 == v0 + + assert type(func.body.blocks[0].bindings[2].var) == rx.Var + assert type(func.body.blocks[0].bindings[2].value) == rx.DataflowVar + + +def test_explicit_memory_rewrite(): + m = tir.Var("m", "int32") + n = tir.Var("n", "int32") + shape_anno = [m, n] + type_anno = rx.DynTensorType(2, "float32") + x = rx.Var("x", shape_anno, type_anno) + ib = rx.IRBuilder() + with ib.function(x): + with ib.dataflow() as df: + lv0 = rx.call_dps([m, n], rx.extern("test.op.identity"), [x]) + gv0 = ib.emit_output(lv0) + ib.emit_output(gv0) + expr = ib.get() + + # before rewrite + v0 = expr.body.blocks[0].bindings[0].var + s0 = expr.body.blocks[0].bindings[0].value + assert isinstance(s0, tvm.relay.Call) + assert s0.op.name == "relax.call_dps" + + # after rewrite + func = rx.transform.explicit_memory_rewrite(expr) + + # the dataflow block has changed to binding block due to the rewriting + block = func.body.blocks[0] + assert isinstance(block, rx.BindingBlock) + + s1 = block.bindings[0].value + assert isinstance(s1, tvm.relay.Call) + assert s1.op.name == "relax.builtin.alloc_tensor" + assert isinstance(s1.args[0], rx.ShapeExpr) + assert structural_equal(s1.args[0], rx.ShapeExpr(shape_anno)) + s2 = block.bindings[1].value + assert s2.op.global_symbol == "test.op.identity" + +if __name__ == "__main__": + test_fma_rewrite() + test_explicit_memory_rewrite()